# StreamLens/analyzer/analysis/flow_manager.py
"""
Flow tracking and management
"""
from typing import Dict, Set, Tuple
from ..models import FlowStats, FrameTypeStats
from ..protocols import Chapter10Dissector, PTPDissector, IENADissector, StandardProtocolDissectors
try:
    from scapy.all import Packet, IP, UDP, TCP
except ImportError:
    # scapy is mandatory for packet parsing: stop with an install hint
    # rather than an ImportError traceback. (SystemExit(1) is exactly what
    # sys.exit(1) raises.)
    print("Error: scapy library required. Install with: pip install scapy")
    raise SystemExit(1)
class FlowManager:
    """Manages network flows and frame type classification.

    Packets are grouped into directional flows keyed by ``(src_ip, dst_ip)``,
    so A->B and B->A are tracked separately. Each flow accumulates timestamps,
    frame numbers, inter-arrival times, byte counts, protocol names and
    per-frame-type statistics.
    """

    # Port -> protocol tables used by _detect_fallback_protocols. Each entry is
    # (ports, name); the first entry whose ports contain either the source or
    # the destination port wins, so order matters.
    _UDP_PORT_PROTOCOLS = (
        ((67, 68), 'DHCP'),
        ((53,), 'DNS'),
        ((123,), 'NTP'),
        ((161, 162), 'SNMP'),
        ((69,), 'TFTP'),
        ((319, 320), 'PTP'),
        ((50000, 50001), 'IENA'),
    )
    _TCP_PORT_PROTOCOLS = (
        ((80,), 'HTTP'),
        ((443,), 'HTTPS'),
        ((22,), 'SSH'),
        ((23,), 'Telnet'),
        ((21,), 'FTP'),
        ((25,), 'SMTP'),
        ((110,), 'POP3'),
        ((143,), 'IMAP'),
    )

    # IP header protocol numbers.
    _IP_PROTO_ICMP = 1
    _IP_PROTO_IGMP = 2

    # Byte patterns that suggest TMATS-style key/value text in a raw payload.
    _TMATS_KEYWORDS = (b'TMATS', b'R-', b'G-', b'P-', b'T-')

    def __init__(self, statistics_engine=None):
        """Create an empty flow table.

        Args:
            statistics_engine: Optional engine. When set and its
                ``enable_realtime`` attribute is truthy, every processed packet
                calls ``update_realtime_statistics(flow_key, flow)`` on it.
        """
        self.flows: Dict[Tuple[str, str], FlowStats] = {}
        self.statistics_engine = statistics_engine
        # Protocol-specific dissectors; _dissect_packet tries them in this order.
        self.specialized_dissectors = {
            'chapter10': Chapter10Dissector(),
            'ptp': PTPDissector(),
            'iena': IENADissector(),
        }
        self.standard_dissectors = StandardProtocolDissectors()

    def process_packet(self, packet: Packet, frame_num: int) -> None:
        """Process a single packet and update flow statistics.

        Packets without an IP layer are ignored.

        Args:
            packet: Parsed scapy packet.
            frame_num: Frame number of the packet within the capture.
        """
        if not packet.haslayer(IP):
            return

        ip_layer = packet[IP]
        src_ip = ip_layer.src
        dst_ip = ip_layer.dst
        timestamp = float(packet.time)
        packet_size = len(packet)

        # Transport-level protocol names (UDP / TCP / OTHER).
        protocols = self._detect_basic_protocols(packet)

        flow_key = (src_ip, dst_ip)
        flow = self.flows.get(flow_key)
        if flow is None:
            flow = FlowStats(
                src_ip=src_ip,
                dst_ip=dst_ip,
                frame_count=0,
                timestamps=[],
                frame_numbers=[],
                inter_arrival_times=[],
                avg_inter_arrival=0.0,
                std_inter_arrival=0.0,
                outlier_frames=[],
                outlier_details=[],
                total_bytes=0,
                protocols=set(),
                detected_protocol_types=set(),
                frame_types={},
            )
            self.flows[flow_key] = flow

        flow.frame_count += 1
        flow.timestamps.append(timestamp)
        flow.frame_numbers.append(frame_num)
        flow.total_bytes += packet_size
        flow.protocols.update(protocols)

        # Protocols named by the specialized dissectors, plus port/IP-based
        # fallback names.
        dissection_results = self._dissect_packet(packet, frame_num)
        flow.detected_protocol_types.update(
            self._extract_enhanced_protocols(dissection_results))
        flow.detected_protocol_types.update(
            self._detect_fallback_protocols(packet, dissection_results))

        frame_type = self._classify_frame_type(packet, dissection_results)
        self._update_frame_type_stats(flow, frame_type, frame_num, timestamp, packet_size)

        # Inter-arrival time relative to the previous packet of this flow.
        if len(flow.timestamps) > 1:
            flow.inter_arrival_times.append(timestamp - flow.timestamps[-2])

        if self.statistics_engine and self.statistics_engine.enable_realtime:
            self.statistics_engine.update_realtime_statistics(flow_key, flow)

    def _detect_basic_protocols(self, packet: Packet) -> Set[str]:
        """Return the transport protocols present: {'UDP'}, {'TCP'}, both, or {'OTHER'}."""
        protocols = set()
        if packet.haslayer(UDP):
            protocols.add('UDP')
        if packet.haslayer(TCP):
            protocols.add('TCP')
        if not protocols:
            protocols.add('OTHER')
        return protocols

    def _dissect_packet(self, packet: Packet, frame_num: int) -> Dict:
        """Run the standard and specialized dissectors over *packet*.

        Returns:
            A dict with 'frame_number', 'timestamp', 'size', 'layers' (layer
            name -> field dict) and 'protocols' (protocol names reported by
            successful specialized dissections). A specialized dissector that
            raises is recorded as ``{'error': <message>}`` under its name.
        """
        result = {
            'frame_number': frame_num,
            'timestamp': float(packet.time),
            'size': len(packet),
            'layers': {},
            'protocols': [],
        }

        result['layers'].update(self.standard_dissectors.dissect_all(packet))

        for name, dissector in self.specialized_dissectors.items():
            try:
                if not dissector.can_dissect(packet):
                    continue
                dissection = dissector.dissect(packet)
                if not dissection:
                    continue
                # Copy so annotating errors / payload size below never mutates
                # the dissector's own result object.
                layer = dict(dissection.fields)
                if dissection.errors:
                    layer['errors'] = dissection.errors
                if dissection.payload:
                    layer['payload_size'] = len(dissection.payload)
                result['layers'][name] = layer
                result['protocols'].append(dissection.protocol.name)
            except Exception as e:
                # One failing dissector must not abort the packet. Fall back to
                # repr() so the message is never empty: _classify_frame_type
                # treats a falsy 'error' value as a successful dissection.
                result['layers'][name] = {'error': str(e) or repr(e)}

        return result

    def _extract_enhanced_protocols(self, dissection: Dict) -> Set[str]:
        """Return the protocol names reported by the specialized dissectors."""
        return set(dissection.get('protocols') or ())

    @staticmethod
    def _match_port_protocol(sport, dport, table, default: str) -> str:
        """Return the first protocol in *table* whose ports include sport or dport, else *default*."""
        for ports, protocol in table:
            if sport in ports or dport in ports:
                return protocol
        return default

    @staticmethod
    def _is_multicast(ip) -> bool:
        """Return True if *ip* is a multicast address (IPv4 224.0.0.0/4 or IPv6 ff00::/8)."""
        import ipaddress
        try:
            return ipaddress.ip_address(str(ip)).is_multicast
        except ValueError:
            return False

    def _detect_fallback_protocols(self, packet: Packet, dissection: Dict) -> Set[str]:
        """Detect protocol types from well-known ports and IP header fields.

        Unknown UDP/TCP ports fall back to the generic 'UDP'/'TCP' names.
        'IGMP'/'ICMP' come from the IP protocol number, and 'Multicast' is
        added for any multicast destination address.

        Args:
            packet: Parsed scapy packet.
            dissection: Result of _dissect_packet (currently unused).
        """
        protocol_types = set()

        if packet.haslayer(UDP):
            udp_layer = packet[UDP]
            protocol_types.add(self._match_port_protocol(
                udp_layer.sport, udp_layer.dport, self._UDP_PORT_PROTOCOLS, 'UDP'))

        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            protocol_types.add(self._match_port_protocol(
                tcp_layer.sport, tcp_layer.dport, self._TCP_PORT_PROTOCOLS, 'TCP'))

        if packet.haslayer(IP):
            ip_layer = packet[IP]
            if ip_layer.proto == self._IP_PROTO_IGMP:
                protocol_types.add('IGMP')
            elif ip_layer.proto == self._IP_PROTO_ICMP:
                protocol_types.add('ICMP')
            # Checks the whole 224.0.0.0/4 block, not only the 224.* and 239.* prefixes.
            if self._is_multicast(ip_layer.dst):
                protocol_types.add('Multicast')

        return protocol_types

    @staticmethod
    def _layer_ok(layers: Dict, name: str) -> bool:
        """Return True if layer *name* was dissected and carries no truthy 'error' entry."""
        return name in layers and not layers[name].get('error')

    def _classify_frame_type(self, packet: Packet, dissection: Dict) -> str:
        """Classify the frame type based on dissection results.

        Successful specialized dissections win, checked in the order
        Chapter 10, PTP, IENA. Otherwise a name is derived from ports or the
        IP protocol number.

        Returns:
            One of 'TMATS', 'CH10-Data', 'PTP-<message_type_name>',
            'IENA-<packet_type_name>', 'DNS', 'DHCP', 'NTP', 'UDP', 'HTTP',
            'HTTPS', 'TCP', 'IGMP', 'ICMP' or 'OTHER'.
        """
        layers = dissection.get('layers', {})

        if self._layer_ok(layers, 'chapter10'):
            if self._is_tmats_frame(packet, layers['chapter10']):
                return 'TMATS'
            return 'CH10-Data'

        if self._layer_ok(layers, 'ptp'):
            msg_type = layers['ptp'].get('message_type_name', 'Unknown')
            return f'PTP-{msg_type}'

        if self._layer_ok(layers, 'iena'):
            packet_type = layers['iena'].get('packet_type_name', 'Unknown')
            return f'IENA-{packet_type}'

        if packet.haslayer(UDP):
            udp_layer = packet[UDP]
            ports = (udp_layer.sport, udp_layer.dport)
            if 53 in ports:
                return 'DNS'
            if 67 in ports or 68 in ports:
                return 'DHCP'
            if 123 in ports:
                return 'NTP'
            return 'UDP'

        if packet.haslayer(TCP):
            tcp_layer = packet[TCP]
            ports = (tcp_layer.sport, tcp_layer.dport)
            if 80 in ports:
                return 'HTTP'
            if 443 in ports:
                return 'HTTPS'
            return 'TCP'

        if packet.haslayer(IP):
            proto = packet[IP].proto
            if proto == self._IP_PROTO_IGMP:
                return 'IGMP'
            if proto == self._IP_PROTO_ICMP:
                return 'ICMP'

        return 'OTHER'

    def _is_tmats_frame(self, packet: Packet, ch10_info: Dict) -> bool:
        """Check if a Chapter 10 frame contains TMATS data.

        First checks the dissected ``data_type`` field. Then it falls back to
        a heuristic scan of the raw payload for TMATS-like text: a backslash,
        a ':' or ';', and one of the TMATS keywords.
        """
        # Data type 0x01 is typically TMATS
        if ch10_info.get('data_type', 0) == 0x01:
            return True

        if not packet.haslayer('Raw'):
            return False

        from scapy.all import Raw
        raw_data = bytes(packet[Raw])
        # Sample the middle section to avoid headers at the start of the payload.
        text_sample = raw_data[50:200]
        if b'\\' not in text_sample:
            return False
        if b':' not in text_sample and b';' not in text_sample:
            return False
        upper_sample = text_sample.upper()
        return any(keyword in upper_sample for keyword in self._TMATS_KEYWORDS)

    def _update_frame_type_stats(self, flow: FlowStats, frame_type: str,
                                 frame_num: int, timestamp: float, packet_size: int):
        """Record one frame of *frame_type* in *flow*, creating its stats entry on first use."""
        ft_stats = flow.frame_types.get(frame_type)
        if ft_stats is None:
            ft_stats = FrameTypeStats(frame_type=frame_type)
            flow.frame_types[frame_type] = ft_stats

        ft_stats.count += 1
        ft_stats.total_bytes += packet_size
        ft_stats.timestamps.append(timestamp)
        ft_stats.frame_numbers.append(frame_num)

        # Inter-arrival time relative to the previous frame of the same type.
        if len(ft_stats.timestamps) > 1:
            ft_stats.inter_arrival_times.append(timestamp - ft_stats.timestamps[-2])

    def get_flows_summary(self) -> Dict:
        """Return the flow count, the number of distinct endpoint IPs, and the flow table itself."""
        unique_ips = {ip for flow in self.flows.values() for ip in (flow.src_ip, flow.dst_ip)}
        return {
            'total_flows': len(self.flows),
            'unique_ips': len(unique_ips),
            'flows': self.flows,
        }