import pyshark
from collections import defaultdict
from typing import Optional, List, Type, Union

import pandas as pd
from tabulate import tabulate

from .models import FlowKey
from .stats import MultiStats, BaseStats, STATS_TYPES


class PySharkAnalyzer:
    """Packet flow analyzer using PyShark for Wireshark dissector support.

    Groups packets into flows keyed by :class:`FlowKey` (5-tuple plus an
    optional extended protocol tag) and accumulates per-flow statistics
    via :class:`MultiStats`.
    """

    def __init__(self, stats_classes: Optional[List[Type[BaseStats]]] = None):
        """Create an analyzer.

        Args:
            stats_classes: Stats classes to instantiate per flow.
                Defaults to ``[STATS_TYPES['overview']]``.
        """
        if stats_classes is None:
            stats_classes = [STATS_TYPES['overview']]
        self.stats_classes = stats_classes
        # One MultiStats accumulator per flow, created lazily on first packet.
        self.flows = defaultdict(lambda: MultiStats(stats_classes))
        self.packet_count = 0

    def _get_flow_key(self, packet) -> Optional[FlowKey]:
        """Extract a flow key from a PyShark packet.

        Returns:
            A :class:`FlowKey`, or ``None`` for non-IP packets or packets
            missing an expected field (caught as ``AttributeError``).
        """
        try:
            # Only IP traffic is tracked as a flow.
            if not hasattr(packet, 'ip'):
                return None

            src_ip = packet.ip.src
            dst_ip = packet.ip.dst
            # BUGFIX: `transport_layer` is always present as a property on
            # pyshark packets (hasattr() was always True), but its VALUE is
            # None for traffic with no transport layer (e.g. ICMP).  Use a
            # truthiness fallback so such flows are labeled 'IP', not None.
            protocol = packet.transport_layer or 'IP'

            # Ports default to 0 for non-TCP/UDP traffic.
            src_port = 0
            dst_port = 0
            if hasattr(packet, 'tcp'):
                src_port = int(packet.tcp.srcport)
                dst_port = int(packet.tcp.dstport)
                protocol = 'TCP'
            elif hasattr(packet, 'udp'):
                src_port = int(packet.udp.srcport)
                dst_port = int(packet.udp.dstport)
                protocol = 'UDP'

            # Detect extended protocol types layered above transport.
            extended_type = None
            if hasattr(packet, 'ptp'):
                extended_type = 'PTP'
            # Add more protocol detection here as needed.

            return FlowKey(src_ip, src_port, dst_ip, dst_port, protocol,
                           extended_type)
        except AttributeError:
            # A dissector field we expected was absent; skip the packet.
            return None

    def _process_packet(self, packet):
        """Attribute a single packet to its flow and bump the packet count."""
        key = self._get_flow_key(packet)
        if key:
            # sniff_timestamp/length may be missing on malformed captures;
            # fall back to 0 rather than dropping the packet.
            timestamp = (float(packet.sniff_timestamp)
                         if hasattr(packet, 'sniff_timestamp') else 0)
            size = int(packet.length) if hasattr(packet, 'length') else 0
            self.flows[key].add(timestamp, size, packet)
        self.packet_count += 1

    def analyze_pcap(self, file: str, display_filter: Optional[str] = None):
        """Analyze packets from a PCAP file.

        Args:
            file: Path to the capture file.
            display_filter: Optional Wireshark display filter.
        """
        print(f"Analyzing: {file}")
        if display_filter:
            print(f"Filter: {display_filter}")

        try:
            # Use FileCapture for PCAP files.
            capture = pyshark.FileCapture(
                file,
                display_filter=display_filter,
                use_json=True,      # Use JSON output for better performance
                include_raw=False   # Don't include raw packet data
            )

            try:
                # Process packets.
                for packet in capture:
                    self._process_packet(packet)

                    # Show progress every 1000 packets (skip the 0 case).
                    if self.packet_count and self.packet_count % 1000 == 0:
                        print(f"  Processed {self.packet_count} packets...")
            finally:
                # BUGFIX: close the underlying tshark process even if
                # iteration raises, to avoid leaking the subprocess.
                capture.close()

            print(f"Found {len(self.flows)} flows "
                  f"from {self.packet_count} packets")
        except Exception as e:
            print(f"Error analyzing PCAP: {e}")

    def analyze_live(self, interface: str, count: int = 100,
                     display_filter: Optional[str] = None,
                     bpf_filter: Optional[str] = None):
        """Capture and analyze packets from a live interface.

        Args:
            interface: Network interface name.
            count: Number of packets to capture.
            display_filter: Optional Wireshark display filter.
            bpf_filter: Optional BPF capture filter.
        """
        print(f"Capturing {count} packets on {interface}")
        if display_filter:
            print(f"Display filter: {display_filter}")
        if bpf_filter:
            print(f"BPF filter: {bpf_filter}")

        try:
            # Use LiveCapture for live capture.
            capture = pyshark.LiveCapture(
                interface=interface,
                display_filter=display_filter,
                bpf_filter=bpf_filter,
                use_json=True,
                include_raw=False
            )

            try:
                # Capture, then process the buffered packets.
                capture.sniff(packet_count=count)
                for packet in capture:
                    self._process_packet(packet)
            finally:
                # BUGFIX: always release the capture subprocess.
                capture.close()

            print(f"Found {len(self.flows)} flows "
                  f"from {self.packet_count} packets")
        except Exception as e:
            print(f"Error during live capture: {e}")

    def summary(self) -> pd.DataFrame:
        """Generate a summary DataFrame of all flows.

        Returns:
            One row per flow with its 5-tuple, optional 'Type' column, and
            the combined stats columns; sorted by 'Pkts' descending when
            that column is present.
        """
        rows = []
        for key, multi_stats in self.flows.items():
            row = {
                'Src IP': key.src_ip,
                'Src Port': key.src_port,
                'Dst IP': key.dst_ip,
                'Dst Port': key.dst_port,
                'Proto': key.protocol
            }
            if key.extended_type:
                row['Type'] = key.extended_type
            row.update(multi_stats.get_combined_summary())
            rows.append(row)

        # Sort by packet count descending.
        df = pd.DataFrame(rows)
        if not df.empty and 'Pkts' in df.columns:
            df = df.sort_values('Pkts', ascending=False)
        return df

    def print_summary(self):
        """Print a formatted summary of flows, with totals when available."""
        df = self.summary()
        if df.empty:
            print("No flows detected")
            return

        print(f"\n{len(df)} flows:")
        print(tabulate(df, headers='keys', tablefmt='plain', showindex=False))

        if 'Pkts' in df.columns and 'Bytes' in df.columns:
            print(f"\nTotals: {df['Pkts'].sum()} packets, "
                  f"{df['Bytes'].sum()} bytes")

    def get_protocol_summary(self) -> pd.DataFrame:
        """Get packet/byte totals grouped by protocol.

        Returns:
            A DataFrame with 'Proto', 'Pkts', 'Bytes' columns, or the
            empty summary frame when no flows exist.
        """
        df = self.summary()
        if df.empty:
            return df

        # Group by protocol.
        protocol_summary = df.groupby('Proto').agg({
            'Pkts': 'sum',
            'Bytes': 'sum'
        }).reset_index()
        return protocol_summary

    def apply_wireshark_filter(self, display_filter: str):
        """
        Apply a Wireshark display filter to the analysis.
        This demonstrates PyShark's ability to use Wireshark's filtering.

        Note: this is a stub — actual filtering requires re-analyzing with
        the ``display_filter`` parameter of the analyze methods.
        """
        filtered_flows = defaultdict(lambda: MultiStats(self.stats_classes))
        # This would require re-processing with the filter.
        # Shown here as an example of the capability.
        print("Note: To apply Wireshark filters, re-analyze with "
              "display_filter parameter")
        return filtered_flows