"""Flow tracking and management."""

import ipaddress
from typing import Dict, Set, Tuple

from ..models import FlowStats, FrameTypeStats
from ..protocols import Chapter10Dissector, PTPDissector, IENADissector, StandardProtocolDissectors

try:
    from scapy.all import Packet, IP, UDP, TCP, Raw
except ImportError:
    print("Error: scapy library required. Install with: pip install scapy")
    import sys
    sys.exit(1)


class FlowManager:
    """Manage network flows and classify the frame types seen in each flow.

    Flows are keyed by the ``(src_ip, dst_ip)`` pair of the IP layer, so the
    two directions of a conversation are tracked as separate flows.
    """

    # Port -> protocol-name tables used by ``_detect_fallback_protocols``.
    # Entries are checked in order and the first match (on either the source
    # or the destination port) wins. Kept at class level so they are not
    # rebuilt for every packet.
    _UDP_PORT_PROTOCOLS: Tuple[Tuple[Tuple[int, ...], str], ...] = (
        ((67, 68), 'DHCP'),
        ((53,), 'DNS'),
        ((123,), 'NTP'),
        ((161, 162), 'SNMP'),
        ((69,), 'TFTP'),
        ((319, 320), 'PTP'),
        ((50000, 50001), 'IENA'),
    )
    _TCP_PORT_PROTOCOLS: Tuple[Tuple[Tuple[int, ...], str], ...] = (
        ((80,), 'HTTP'),
        ((443,), 'HTTPS'),
        ((22,), 'SSH'),
        ((23,), 'Telnet'),
        ((21,), 'FTP'),
        ((25,), 'SMTP'),
        ((110,), 'POP3'),
        ((143,), 'IMAP'),
    )

    def __init__(self, statistics_engine=None):
        """Create an empty flow table and the protocol dissectors.

        Args:
            statistics_engine: Optional engine whose
                ``update_realtime_statistics(flow_key, flow)`` is called after
                each processed packet when its ``enable_realtime`` attribute
                is truthy.
        """
        self.flows: Dict[Tuple[str, str], FlowStats] = {}
        self.statistics_engine = statistics_engine

        # Specialized dissectors; ``_dissect_packet`` tries each of them, in
        # this order, on every packet.
        self.specialized_dissectors = {
            'chapter10': Chapter10Dissector(),
            'ptp': PTPDissector(),
            'iena': IENADissector(),
        }
        self.standard_dissectors = StandardProtocolDissectors()

    def process_packet(self, packet: Packet, frame_num: int) -> None:
        """Process a single packet and update the statistics of its flow.

        Packets without an IP layer are ignored.

        Args:
            packet: Parsed scapy packet.
            frame_num: Frame number of the packet, recorded in the flow and
                frame-type statistics.
        """
        if not packet.haslayer(IP):
            return

        ip_layer = packet[IP]
        src_ip = ip_layer.src
        dst_ip = ip_layer.dst
        timestamp = float(packet.time)
        packet_size = len(packet)

        # Basic transport protocols (UDP / TCP / OTHER).
        protocols = self._detect_basic_protocols(packet)

        flow_key = (src_ip, dst_ip)

        # Initialize the flow statistics the first time this key is seen.
        if flow_key not in self.flows:
            self.flows[flow_key] = FlowStats(
                src_ip=src_ip,
                dst_ip=dst_ip,
                frame_count=0,
                timestamps=[],
                frame_numbers=[],
                inter_arrival_times=[],
                avg_inter_arrival=0.0,
                std_inter_arrival=0.0,
                outlier_frames=[],
                outlier_details=[],
                total_bytes=0,
                protocols=set(),
                detected_protocol_types=set(),
                frame_types={},
            )

        flow = self.flows[flow_key]
        flow.frame_count += 1
        flow.timestamps.append(timestamp)
        flow.frame_numbers.append(frame_num)
        flow.total_bytes += packet_size
        flow.protocols.update(protocols)

        # Protocols reported by the specialized dissectors.
        dissection_results = self._dissect_packet(packet, frame_num)
        enhanced_protocols = self._extract_enhanced_protocols(dissection_results)
        flow.detected_protocol_types.update(enhanced_protocols)

        # Port / IP-header based protocol names, added regardless of what the
        # specialized dissectors found.
        fallback_protocols = self._detect_fallback_protocols(packet, dissection_results)
        flow.detected_protocol_types.update(fallback_protocols)

        # Classify and track per-frame-type statistics.
        frame_type = self._classify_frame_type(packet, dissection_results)
        self._update_frame_type_stats(flow, frame_type, frame_num, timestamp, packet_size)

        # Inter-arrival time relative to the previous packet of this flow.
        if len(flow.timestamps) > 1:
            inter_arrival = timestamp - flow.timestamps[-2]
            flow.inter_arrival_times.append(inter_arrival)

        if self.statistics_engine and self.statistics_engine.enable_realtime:
            self.statistics_engine.update_realtime_statistics(flow_key, flow)

    def _detect_basic_protocols(self, packet: Packet) -> Set[str]:
        """Return the transport protocols present: 'UDP', 'TCP', or {'OTHER'}."""
        protocols = set()
        if packet.haslayer(UDP):
            protocols.add('UDP')
        if packet.haslayer(TCP):
            protocols.add('TCP')
        if not protocols:
            protocols.add('OTHER')
        return protocols

    def _dissect_packet(self, packet: Packet, frame_num: int) -> Dict:
        """Run the standard and specialized dissectors over *packet*.

        Returns:
            A dict with ``frame_number``, ``timestamp``, ``size``, ``layers``
            (layer name -> field dict) and ``protocols`` (names reported by
            the specialized dissectors that produced a result).
        """
        result = {
            'frame_number': frame_num,
            'timestamp': float(packet.time),
            'size': len(packet),
            'layers': {},
            'protocols': [],
        }

        standard_layers = self.standard_dissectors.dissect_all(packet)
        result['layers'].update(standard_layers)

        for name, dissector in self.specialized_dissectors.items():
            # Best effort: a failing dissector is recorded under its layer
            # name as {'error': ...} instead of aborting packet processing.
            try:
                if dissector.can_dissect(packet):
                    dissection = dissector.dissect(packet)
                    if dissection:
                        result['layers'][name] = dissection.fields
                        result['protocols'].append(dissection.protocol.name)

                        if dissection.errors:
                            result['layers'][name]['errors'] = dissection.errors

                        if dissection.payload:
                            result['layers'][name]['payload_size'] = len(dissection.payload)
            except Exception as e:
                result['layers'][name] = {'error': str(e)}

        return result

    def _extract_enhanced_protocols(self, dissection: Dict) -> Set[str]:
        """Return the protocol names reported in ``dissection['protocols']``."""
        protocols = set()
        if dissection.get('protocols'):
            protocols.update(dissection['protocols'])
        return protocols

    @staticmethod
    def _match_port_protocol(sport: int, dport: int,
                             table: Tuple[Tuple[Tuple[int, ...], str], ...],
                             default: str) -> str:
        """Return the first protocol in *table* whose ports contain *sport* or
        *dport*, or *default* when none match."""
        for ports, protocol in table:
            if sport in ports or dport in ports:
                return protocol
        return default

    @staticmethod
    def _is_multicast_address(address) -> bool:
        """Return True if *address* is a multicast IP address.

        For IPv4 this covers the whole 224.0.0.0/4 block (224.x - 239.x).
        Unparseable addresses are treated as non-multicast.
        """
        try:
            return ipaddress.ip_address(str(address)).is_multicast
        except ValueError:
            return False

    def _detect_fallback_protocols(self, packet: Packet, dissection: Dict) -> Set[str]:
        """Detect protocol types from well-known ports and IP header fields.

        Falls back to the generic 'UDP' / 'TCP' name when no known port
        matches. Also adds 'IGMP' / 'ICMP' based on the IP protocol number and
        'Multicast' when the destination is a multicast address.

        ``dissection`` is currently unused; it is kept for interface
        compatibility.
        """
        protocol_types = set()

        if packet.haslayer(UDP):
            udp_layer = packet[UDP]
            protocol_types.add(self._match_port_protocol(
                udp_layer.sport, udp_layer.dport, self._UDP_PORT_PROTOCOLS, 'UDP'))

        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            protocol_types.add(self._match_port_protocol(
                tcp_layer.sport, tcp_layer.dport, self._TCP_PORT_PROTOCOLS, 'TCP'))

        if packet.haslayer(IP):
            ip_layer = packet[IP]
            if ip_layer.proto == 2:  # IGMP protocol number
                protocol_types.add('IGMP')
            elif ip_layer.proto == 1:  # ICMP protocol number
                protocol_types.add('ICMP')

            # Previously only 224.* and 239.* were recognised, which missed
            # 225.* - 238.*; check the full multicast range instead.
            if self._is_multicast_address(ip_layer.dst):
                protocol_types.add('Multicast')

        return protocol_types

    def _classify_frame_type(self, packet: Packet, dissection: Dict) -> str:
        """Classify the frame type from dissection results and port numbers.

        Priority: Chapter 10 ('TMATS' / 'CH10-Data'), PTP ('PTP-<message>'),
        IENA ('IENA-<packet type>'), then port/protocol based names, and
        finally 'OTHER'. A specialized layer whose dict holds an 'error' key
        is skipped.
        """
        layers = dissection.get('layers', {})

        if 'chapter10' in layers and not layers['chapter10'].get('error'):
            ch10_info = layers['chapter10']
            if self._is_tmats_frame(packet, ch10_info):
                return 'TMATS'
            return 'CH10-Data'

        if 'ptp' in layers and not layers['ptp'].get('error'):
            ptp_info = layers['ptp']
            msg_type = ptp_info.get('message_type_name', 'Unknown')
            return f'PTP-{msg_type}'

        if 'iena' in layers and not layers['iena'].get('error'):
            iena_info = layers['iena']
            packet_type = iena_info.get('packet_type_name', 'Unknown')
            return f'IENA-{packet_type}'

        # Fallback: basic protocol classification by port.
        if packet.haslayer(UDP):
            udp_layer = packet[UDP]
            sport, dport = udp_layer.sport, udp_layer.dport
            if sport == 53 or dport == 53:
                return 'DNS'
            if sport in (67, 68) or dport in (67, 68):
                return 'DHCP'
            if sport == 123 or dport == 123:
                return 'NTP'
            return 'UDP'

        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            sport, dport = tcp_layer.sport, tcp_layer.dport
            if sport == 80 or dport == 80:
                return 'HTTP'
            if sport == 443 or dport == 443:
                return 'HTTPS'
            return 'TCP'

        if packet.haslayer(IP):
            ip_layer = packet[IP]
            if ip_layer.proto == 2:  # IGMP
                return 'IGMP'
            if ip_layer.proto == 1:  # ICMP
                return 'ICMP'

        return 'OTHER'

    def _is_tmats_frame(self, packet: Packet, ch10_info: Dict) -> bool:
        """Return True if a Chapter 10 frame appears to carry TMATS data.

        A frame qualifies when its ``data_type`` is 0x01, or when bytes
        50-200 of the Raw payload contain a backslash, a ':' or ';', and one
        of the keywords below (case-insensitive).
        """
        data_type = ch10_info.get('data_type', 0)

        # Data type 0x01 is treated as TMATS.
        if data_type == 0x01:
            return True

        if not packet.haslayer(Raw):
            return False

        raw_data = bytes(packet[Raw])
        # Sample a middle section of the payload to skip header bytes.
        text_sample = raw_data[50:200]
        if b'\\' in text_sample and (b':' in text_sample or b';' in text_sample):
            upper_sample = text_sample.upper()
            # NOTE(review): TMATS general-group attributes are usually written
            # as 'G\...' rather than 'G-...'; confirm this keyword list.
            keywords = (b'TMATS', b'R-', b'G-', b'P-', b'T-')
            if any(keyword in upper_sample for keyword in keywords):
                return True

        return False

    def _update_frame_type_stats(self, flow: FlowStats, frame_type: str, frame_num: int,
                                 timestamp: float, packet_size: int):
        """Record one frame of *frame_type* in ``flow.frame_types``.

        Creates the per-type entry on first sight, then updates its count,
        byte total, timestamps, frame numbers and inter-arrival times.
        """
        if frame_type not in flow.frame_types:
            flow.frame_types[frame_type] = FrameTypeStats(frame_type=frame_type)

        ft_stats = flow.frame_types[frame_type]
        ft_stats.count += 1
        ft_stats.total_bytes += packet_size
        ft_stats.timestamps.append(timestamp)
        ft_stats.frame_numbers.append(frame_num)

        # Inter-arrival time relative to the previous frame of this type.
        if len(ft_stats.timestamps) > 1:
            inter_arrival = timestamp - ft_stats.timestamps[-2]
            ft_stats.inter_arrival_times.append(inter_arrival)

    def get_flows_summary(self) -> Dict:
        """Return the flow count, the number of distinct IPs and the flow table."""
        unique_ips = set()
        for flow in self.flows.values():
            unique_ips.add(flow.src_ip)
            unique_ips.add(flow.dst_ip)

        return {
            'total_flows': len(self.flows),
            'unique_ips': len(unique_ips),
            'flows': self.flows,
        }