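"""Packet flow analyzer built on PyShark.

Groups captured packets into flows keyed by FlowKey and accumulates
per-flow statistics through MultiStats.
"""
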
from collections import defaultdict
from typing import Optional, List, Type

import pandas as pd
import pyshark
from tabulate import tabulate

from .models import FlowKey
from .stats import MultiStats, BaseStats, STATS_TYPES


class PySharkAnalyzer:
    """Packet flow analyzer using PyShark for Wireshark dissector support."""

    def __init__(self, stats_classes: Optional[List[Type[BaseStats]]] = None):
        if stats_classes is None:
            stats_classes = [STATS_TYPES['overview']]
        self.stats_classes = stats_classes
        # Each new flow key lazily gets its own MultiStats accumulator.
        self.flows = defaultdict(lambda: MultiStats(stats_classes))
        self.packet_count = 0

    def _get_flow_key(self, packet) -> Optional[FlowKey]:
        """Extract flow key from PyShark packet."""
        try:
            # Only IP traffic is tracked as flows
            if not hasattr(packet, 'ip'):
                return None

            src_ip = packet.ip.src
            dst_ip = packet.ip.dst
            # transport_layer is a property that always exists but may be
            # None, so `hasattr` is not a useful check here.
            protocol = packet.transport_layer or 'IP'

            # Get ports based on protocol
            src_port = 0
            dst_port = 0

            if hasattr(packet, 'tcp'):
                src_port = int(packet.tcp.srcport)
                dst_port = int(packet.tcp.dstport)
                protocol = 'TCP'
            elif hasattr(packet, 'udp'):
                src_port = int(packet.udp.srcport)
                dst_port = int(packet.udp.dstport)
                protocol = 'UDP'

            # Check for extended protocol types
            extended_type = None
            if hasattr(packet, 'ptp'):
                extended_type = 'PTP'
            # Add more protocol detection here as needed

            return FlowKey(src_ip, src_port, dst_ip, dst_port, protocol, extended_type)

        except AttributeError:
            # Malformed or truncated packets may lack expected fields
            return None

    def _process_packet(self, packet):
        """Process a single packet."""
        key = self._get_flow_key(packet)
        if key:
            # Get timestamp and frame size, defaulting to 0 if absent
            timestamp = float(packet.sniff_timestamp) if hasattr(packet, 'sniff_timestamp') else 0
            size = int(packet.length) if hasattr(packet, 'length') else 0

            self.flows[key].add(timestamp, size, packet)
            self.packet_count += 1

    def analyze_pcap(self, file: str, display_filter: Optional[str] = None):
        """Analyze packets from a PCAP file."""
        print(f"Analyzing: {file}")
        if display_filter:
            print(f"Filter: {display_filter}")

        try:
            # Use FileCapture for PCAP files
            capture = pyshark.FileCapture(
                file,
                display_filter=display_filter,
                use_json=True,      # JSON output parses faster than the default PDML
                include_raw=False,  # Don't include raw packet data
            )

            # Process packets, reporting progress every 1000 packets
            for packet in capture:
                self._process_packet(packet)
                if self.packet_count and self.packet_count % 1000 == 0:
                    print(f"  Processed {self.packet_count} packets...")

            capture.close()
            print(f"Found {len(self.flows)} flows from {self.packet_count} packets")

        except Exception as e:
            print(f"Error analyzing PCAP: {e}")

    def analyze_live(self, interface: str, count: int = 100,
                     display_filter: Optional[str] = None,
                     bpf_filter: Optional[str] = None):
        """Capture and analyze packets from a live interface."""
        print(f"Capturing {count} packets on {interface}")
        if display_filter:
            print(f"Display filter: {display_filter}")
        if bpf_filter:
            print(f"BPF filter: {bpf_filter}")

        try:
            # Use LiveCapture for live capture
            capture = pyshark.LiveCapture(
                interface=interface,
                display_filter=display_filter,
                bpf_filter=bpf_filter,
                use_json=True,
                include_raw=False,
            )

            # Block until `count` packets arrive, then process the buffered packets
            capture.sniff(packet_count=count)
            for packet in capture:
                self._process_packet(packet)

            capture.close()
            print(f"Found {len(self.flows)} flows from {self.packet_count} packets")

        except Exception as e:
            print(f"Error during live capture: {e}")

    def summary(self) -> pd.DataFrame:
        """Generate summary DataFrame of all flows."""
        rows = []
        for key, multi_stats in self.flows.items():
            row = {
                'Src IP': key.src_ip,
                'Src Port': key.src_port,
                'Dst IP': key.dst_ip,
                'Dst Port': key.dst_port,
                'Proto': key.protocol,
            }
            if key.extended_type:
                row['Type'] = key.extended_type
            row.update(multi_stats.get_combined_summary())
            rows.append(row)

        # Sort by packet count descending
        df = pd.DataFrame(rows)
        if not df.empty and 'Pkts' in df.columns:
            df = df.sort_values('Pkts', ascending=False)
        return df

    def print_summary(self):
        """Print formatted summary of flows."""
        df = self.summary()
        if df.empty:
            print("No flows detected")
            return

        print(f"\n{len(df)} flows:")
        print(tabulate(df, headers='keys', tablefmt='plain', showindex=False))

        if 'Pkts' in df.columns and 'Bytes' in df.columns:
            print(f"\nTotals: {df['Pkts'].sum()} packets, {df['Bytes'].sum()} bytes")

    def get_protocol_summary(self) -> pd.DataFrame:
        """Get summary grouped by protocol."""
        df = self.summary()
        # Bail out if the expected columns are missing (e.g. no flows yet)
        if df.empty or not {'Proto', 'Pkts', 'Bytes'}.issubset(df.columns):
            return pd.DataFrame()

        # Group by protocol and total the packet and byte counts
        protocol_summary = df.groupby('Proto').agg({
            'Pkts': 'sum',
            'Bytes': 'sum',
        }).reset_index()

        return protocol_summary

    def apply_wireshark_filter(self, display_filter: str):
        """
        Apply a Wireshark display filter to the analysis.

        This demonstrates PyShark's ability to use Wireshark's display
        filter syntax. Filtering is performed by tshark during capture,
        so already-analyzed packets cannot be re-filtered here; re-run
        the analysis with the display_filter parameter instead.
        """
        filtered_flows = defaultdict(lambda: MultiStats(self.stats_classes))

        # Placeholder: re-processing the capture with the filter would
        # populate filtered_flows.
        print("Note: to apply Wireshark filters, re-analyze with the display_filter parameter")
        return filtered_flows
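

if __name__ == '__main__':
    # Example usage: a minimal sketch, not part of the analyzer itself.
    # 'capture.pcap' is a placeholder path, and because this module uses
    # relative imports it must be run as part of its package (e.g. with
    # `python -m <package>.<module>`).
    analyzer = PySharkAnalyzer()
    # Restrict the analysis to TCP using a Wireshark display filter
    analyzer.analyze_pcap('capture.pcap', display_filter='tcp')
    analyzer.print_summary()
    print(analyzer.get_protocol_summary())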