progress?
analyzer/models/__init__.py
@@ -1,8 +1,54 @@
 """
-Data models for the Ethernet Traffic Analyzer
+StreamLens Data Models
+
+This module provides the core data structures used throughout StreamLens for
+representing network flows, protocol information, and decoded packet data.
+
+The models are organized into several categories:
+- Core models: FlowStats, FrameTypeStats
+- Protocol models: ProtocolInfo, DecodedField, ProtocolRegistry
+- Analysis models: EnhancedAnalysisData, TimingAnalysis
+- Result models: AnalysisResult, DissectionResult
 """
 
 # Core data models
 from .flow_stats import FlowStats, FrameTypeStats
-from .analysis_results import AnalysisResult
+from .analysis_results import AnalysisResult, DissectionResult
 
-__all__ = ['FlowStats', 'FrameTypeStats', 'AnalysisResult']
+# Protocol models (new)
+from .protocols import (
+    ProtocolInfo,
+    DecodedField,
+    ProtocolRegistry,
+    StandardProtocol,
+    EnhancedProtocol
+)
+
+# Enhanced analysis models (refactored)
+from .enhanced_analysis import (
+    EnhancedAnalysisData,
+    TimingAnalysis,
+    QualityMetrics,
+    DecodedData
+)
+
+__all__ = [
+    # Core models
+    'FlowStats',
+    'FrameTypeStats',
+    'AnalysisResult',
+    'DissectionResult',
+
+    # Protocol models
+    'ProtocolInfo',
+    'DecodedField',
+    'ProtocolRegistry',
+    'StandardProtocol',
+    'EnhancedProtocol',
+
+    # Enhanced analysis
+    'EnhancedAnalysisData',
+    'TimingAnalysis',
+    'QualityMetrics',
+    'DecodedData'
+]

analyzer/models/enhanced_analysis.py  (new file, 289 lines)
@@ -0,0 +1,289 @@
"""
Enhanced Analysis Data Models

This module defines data structures for enhanced protocol analysis including
timing analysis, quality metrics, and decoded data representation.
"""

from dataclasses import dataclass, field
from typing import Dict, List, Set, Any, Optional
from enum import Enum


class TimingQuality(Enum):
    """Timing quality classifications"""
    EXCELLENT = "excellent"   # < 1 ppm drift, stable
    GOOD = "good"             # 1-10 ppm drift, mostly stable
    MODERATE = "moderate"     # 10-100 ppm drift, variable
    POOR = "poor"             # > 100 ppm drift, unstable
    UNKNOWN = "unknown"


class TimingStability(Enum):
    """Timing stability classifications"""
    STABLE = "stable"         # Consistent timing behavior
    VARIABLE = "variable"     # Some timing variations
    UNSTABLE = "unstable"     # Highly variable timing
    UNKNOWN = "unknown"


class DataType(Enum):
    """Primary data types in enhanced protocols"""
    ANALOG = "analog"
    PCM = "pcm"
    DISCRETE = "discrete"
    TIME = "time"
    VIDEO = "video"
    TMATS = "tmats"
    UNKNOWN = "unknown"


@dataclass
class TimingAnalysis:
    """Timing analysis results for enhanced protocols"""
    avg_clock_drift_ppm: float = 0.0
    max_clock_drift_ppm: float = 0.0
    min_clock_drift_ppm: float = 0.0
    drift_variance: float = 0.0

    quality: TimingQuality = TimingQuality.UNKNOWN
    stability: TimingStability = TimingStability.UNKNOWN

    # Timing accuracy metrics
    timing_accuracy_percent: float = 0.0
    sync_errors: int = 0
    timing_anomalies: int = 0
    anomaly_rate_percent: float = 0.0

    # Internal timing capabilities
    has_internal_timing: bool = False
    rtc_sync_available: bool = False

    def calculate_quality(self) -> TimingQuality:
        """Calculate timing quality based on drift measurements"""
        max_drift = abs(max(self.max_clock_drift_ppm, self.min_clock_drift_ppm, key=abs))

        if max_drift < 1.0:
            return TimingQuality.EXCELLENT
        elif max_drift < 10.0:
            return TimingQuality.GOOD
        elif max_drift < 100.0:
            return TimingQuality.MODERATE
        else:
            return TimingQuality.POOR

    def calculate_stability(self) -> TimingStability:
        """Calculate timing stability based on variance"""
        if self.drift_variance < 1.0:
            return TimingStability.STABLE
        elif self.drift_variance < 25.0:
            return TimingStability.VARIABLE
        else:
            return TimingStability.UNSTABLE
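The drift thresholds above (1/10/100 ppm) and the variance thresholds (1/25) drive the two classifiers. A minimal sketch of how they behave, using made-up drift numbers:

# Illustrative sketch (not part of this commit): classify a flow whose worst
# observed clock drift is 12 ppm with a drift variance of 4.
timing = TimingAnalysis(
    avg_clock_drift_ppm=3.5,
    max_clock_drift_ppm=12.0,
    min_clock_drift_ppm=-8.0,
    drift_variance=4.0,
)
timing.quality = timing.calculate_quality()       # TimingQuality.MODERATE (10 <= 12 < 100)
timing.stability = timing.calculate_stability()   # TimingStability.VARIABLE (1 <= 4 < 25)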


@dataclass
class QualityMetrics:
    """Quality metrics for enhanced protocol data"""
    # Frame quality metrics
    avg_frame_quality_percent: float = 0.0
    frame_quality_samples: List[float] = field(default_factory=list)

    # Signal quality metrics
    avg_signal_quality_percent: float = 0.0
    signal_quality_samples: List[float] = field(default_factory=list)

    # Error counts
    sequence_gaps: int = 0
    format_errors: int = 0
    overflow_errors: int = 0
    checksum_errors: int = 0

    # Confidence metrics
    avg_confidence_score: float = 0.0
    confidence_samples: List[float] = field(default_factory=list)
    low_confidence_frames: int = 0

    # Data integrity
    corrupted_frames: int = 0
    missing_frames: int = 0
    duplicate_frames: int = 0

    def calculate_overall_quality(self) -> float:
        """Calculate overall quality score (0-100)"""
        if not self.frame_quality_samples and not self.signal_quality_samples:
            return 0.0

        frame_score = self.avg_frame_quality_percent if self.frame_quality_samples else 100.0
        signal_score = self.avg_signal_quality_percent if self.signal_quality_samples else 100.0
        confidence_score = self.avg_confidence_score * 100 if self.confidence_samples else 100.0

        # Weight the scores
        weighted_score = (frame_score * 0.4 + signal_score * 0.4 + confidence_score * 0.2)

        # Apply error penalties
        total_frames = len(self.frame_quality_samples) or 1
        error_rate = (self.format_errors + self.overflow_errors + self.corrupted_frames) / total_frames
        penalty = min(error_rate * 50, 50)  # Max 50% penalty

        return max(0.0, weighted_score - penalty)
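calculate_overall_quality() weights frame and signal quality at 40% each and confidence at 20%, then subtracts an error penalty capped at 50 points. A small worked example with invented numbers:

# Illustrative sketch (not part of this commit): average frame quality 90%,
# average signal quality 80%, no confidence samples, 1 format error over 4
# frame-quality samples.
#   weighted = 90*0.4 + 80*0.4 + 100*0.2 = 88.0
#   penalty  = min((1/4)*50, 50)         = 12.5
#   overall  = 88.0 - 12.5               = 75.5
metrics = QualityMetrics(
    avg_frame_quality_percent=90.0,
    frame_quality_samples=[88.0, 90.0, 91.0, 91.0],
    avg_signal_quality_percent=80.0,
    signal_quality_samples=[80.0],
    format_errors=1,
)
assert abs(metrics.calculate_overall_quality() - 75.5) < 1e-9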


@dataclass
class DecodedData:
    """Container for decoded protocol data"""
    # Channel information
    channel_count: int = 0
    analog_channels: int = 0
    pcm_channels: int = 0
    discrete_channels: int = 0

    # Data type classification
    primary_data_type: DataType = DataType.UNKNOWN
    secondary_data_types: Set[DataType] = field(default_factory=set)

    # Decoded field information
    sample_decoded_fields: Dict[str, Any] = field(default_factory=dict)
    available_field_names: List[str] = field(default_factory=list)
    field_count: int = 0
    critical_fields: List[str] = field(default_factory=list)

    # Frame type analysis
    frame_types: Set[str] = field(default_factory=set)
    frame_type_distribution: Dict[str, int] = field(default_factory=dict)

    # Special frame counts
    tmats_frames: int = 0
    setup_frames: int = 0
    data_frames: int = 0

    # Decoder metadata
    decoder_type: str = "Standard"
    decoder_version: Optional[str] = None
    decode_success_rate: float = 1.0

    def add_frame_type(self, frame_type: str):
        """Add a frame type to the analysis"""
        self.frame_types.add(frame_type)
        self.frame_type_distribution[frame_type] = self.frame_type_distribution.get(frame_type, 0) + 1

    def get_dominant_frame_type(self) -> Optional[str]:
        """Get the most common frame type"""
        if not self.frame_type_distribution:
            return None
        return max(self.frame_type_distribution.items(), key=lambda x: x[1])[0]

    def update_data_type_classification(self):
        """Update primary data type based on channel counts"""
        if self.analog_channels > 0 and self.analog_channels >= self.pcm_channels:
            self.primary_data_type = DataType.ANALOG
        elif self.pcm_channels > 0:
            self.primary_data_type = DataType.PCM
        elif self.discrete_channels > 0:
            self.primary_data_type = DataType.DISCRETE
        elif self.tmats_frames > 0:
            self.primary_data_type = DataType.TMATS

        # Add secondary types
        if self.analog_channels > 0:
            self.secondary_data_types.add(DataType.ANALOG)
        if self.pcm_channels > 0:
            self.secondary_data_types.add(DataType.PCM)
        if self.discrete_channels > 0:
            self.secondary_data_types.add(DataType.DISCRETE)
        if self.tmats_frames > 0:
            self.secondary_data_types.add(DataType.TMATS)
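update_data_type_classification() prefers analog over PCM when both are present and records every observed type as secondary. A short sketch with invented channel counts:

# Illustrative sketch (not part of this commit): a capture with 4 analog and
# 2 PCM channels plus a TMATS frame classifies as primarily analog, with PCM
# and TMATS retained as secondary types.
decoded = DecodedData(analog_channels=4, pcm_channels=2, tmats_frames=1)
decoded.add_frame_type("CH10-Data")
decoded.add_frame_type("CH10-Data")
decoded.add_frame_type("TMATS")
decoded.update_data_type_classification()
assert decoded.primary_data_type is DataType.ANALOG
assert decoded.get_dominant_frame_type() == "CH10-Data"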


@dataclass
class EnhancedAnalysisData:
    """Complete enhanced analysis data combining all analysis types"""
    timing: TimingAnalysis = field(default_factory=TimingAnalysis)
    quality: QualityMetrics = field(default_factory=QualityMetrics)
    decoded: DecodedData = field(default_factory=DecodedData)

    # Legacy compatibility fields (will be deprecated)
    avg_clock_drift_ppm: float = field(init=False)
    max_clock_drift_ppm: float = field(init=False)
    timing_quality: str = field(init=False)
    timing_stability: str = field(init=False)
    anomaly_rate: float = field(init=False)
    avg_confidence_score: float = field(init=False)
    avg_frame_quality: float = field(init=False)
    sequence_gaps: int = field(init=False)
    rtc_sync_errors: int = field(init=False)
    format_errors: int = field(init=False)
    overflow_errors: int = field(init=False)
    channel_count: int = field(init=False)
    analog_channels: int = field(init=False)
    pcm_channels: int = field(init=False)
    tmats_frames: int = field(init=False)
    has_internal_timing: bool = field(init=False)
    primary_data_type: str = field(init=False)
    decoder_type: str = field(init=False)
    sample_decoded_fields: Dict[str, Any] = field(init=False)
    available_field_names: List[str] = field(init=False)
    field_count: int = field(init=False)
    frame_types: Set[str] = field(init=False)
    timing_accuracy: float = field(init=False)
    signal_quality: float = field(init=False)

    def __post_init__(self):
        """Initialize legacy compatibility properties"""
        self._update_legacy_fields()

    def _update_legacy_fields(self):
        """Update legacy fields from new structured data"""
        # Timing fields
        self.avg_clock_drift_ppm = self.timing.avg_clock_drift_ppm
        self.max_clock_drift_ppm = self.timing.max_clock_drift_ppm
        self.timing_quality = self.timing.quality.value
        self.timing_stability = self.timing.stability.value
        self.anomaly_rate = self.timing.anomaly_rate_percent
        self.has_internal_timing = self.timing.has_internal_timing
        self.timing_accuracy = self.timing.timing_accuracy_percent

        # Quality fields
        self.avg_confidence_score = self.quality.avg_confidence_score
        self.avg_frame_quality = self.quality.avg_frame_quality_percent
        self.sequence_gaps = self.quality.sequence_gaps
        self.rtc_sync_errors = self.timing.sync_errors
        self.format_errors = self.quality.format_errors
        self.overflow_errors = self.quality.overflow_errors
        self.signal_quality = self.quality.avg_signal_quality_percent

        # Decoded data fields
        self.channel_count = self.decoded.channel_count
        self.analog_channels = self.decoded.analog_channels
        self.pcm_channels = self.decoded.pcm_channels
        self.tmats_frames = self.decoded.tmats_frames
        self.primary_data_type = self.decoded.primary_data_type.value
        self.decoder_type = self.decoded.decoder_type
        self.sample_decoded_fields = self.decoded.sample_decoded_fields
        self.available_field_names = self.decoded.available_field_names
        self.field_count = self.decoded.field_count
        self.frame_types = self.decoded.frame_types

    def update_from_components(self):
        """Update legacy fields when component objects change"""
        self._update_legacy_fields()

    def get_overall_health_score(self) -> float:
        """Calculate overall health score for the enhanced analysis"""
        quality_score = self.quality.calculate_overall_quality()

        # Factor in timing quality
        timing_score = 100.0
        if self.timing.quality == TimingQuality.EXCELLENT:
            timing_score = 100.0
        elif self.timing.quality == TimingQuality.GOOD:
            timing_score = 80.0
        elif self.timing.quality == TimingQuality.MODERATE:
            timing_score = 60.0
        elif self.timing.quality == TimingQuality.POOR:
            timing_score = 30.0
        else:
            timing_score = 50.0  # Unknown

        # Weight quality more heavily than timing
        return (quality_score * 0.7 + timing_score * 0.3)
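EnhancedAnalysisData keeps the flat legacy attributes in sync with the structured timing/quality/decoded components; callers that mutate a component are expected to refresh via update_from_components(). A minimal sketch with invented values:

# Illustrative sketch (not part of this commit): mutate the timing component,
# refresh, and read back through the legacy view and the health score.
analysis = EnhancedAnalysisData()
analysis.timing.max_clock_drift_ppm = 5.0
analysis.timing.quality = analysis.timing.calculate_quality()   # GOOD (1 <= 5 < 10)
analysis.update_from_components()
assert analysis.timing_quality == "good"   # legacy string mirror of the enum
# quality_score is 0.0 (no samples yet); GOOD timing contributes 80 * 0.3.
assert abs(analysis.get_overall_health_score() - 24.0) < 1e-9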

analyzer/models/protocols.py  (new file, 258 lines)
@@ -0,0 +1,258 @@
"""
|
||||
Protocol Information Data Models
|
||||
|
||||
This module defines data structures for representing protocol information,
|
||||
decoded fields, and protocol registry management.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Dict, List, Set, Optional, Any, Union
|
||||
from enum import Enum, IntEnum
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
|
||||
class ProtocolType(IntEnum):
|
||||
"""Protocol type identifiers"""
|
||||
UNKNOWN = 0
|
||||
|
||||
# Standard protocols
|
||||
UDP = 10
|
||||
TCP = 11
|
||||
ICMP = 12
|
||||
IGMP = 13
|
||||
|
||||
# Enhanced protocols
|
||||
CHAPTER10 = 100
|
||||
CH10 = 100 # Alias for CHAPTER10
|
||||
PTP = 101
|
||||
IENA = 102
|
||||
NTP = 103
|
||||
|
||||
|
||||
class ProtocolCategory(Enum):
|
||||
"""Protocol categories for organization"""
|
||||
TRANSPORT = "transport" # UDP, TCP, ICMP
|
||||
NETWORK = "network" # IP, IGMP
|
||||
ENHANCED = "enhanced" # CH10, PTP, IENA
|
||||
TIMING = "timing" # PTP, NTP
|
||||
TELEMETRY = "telemetry" # CH10, IENA
|
||||
|
||||
|
||||
class FieldType(Enum):
|
||||
"""Types of decoded fields"""
|
||||
INTEGER = "integer"
|
||||
FLOAT = "float"
|
||||
STRING = "string"
|
||||
BOOLEAN = "boolean"
|
||||
TIMESTAMP = "timestamp"
|
||||
IP_ADDRESS = "ip_address"
|
||||
MAC_ADDRESS = "mac_address"
|
||||
BINARY = "binary"
|
||||
ENUM = "enum"
|
||||
|
||||
|
||||
@dataclass
|
||||
class DecodedField:
|
||||
"""Represents a single decoded field from a protocol"""
|
||||
name: str
|
||||
value: Any
|
||||
field_type: FieldType
|
||||
description: Optional[str] = None
|
||||
unit: Optional[str] = None # e.g., "ms", "bytes", "ppm"
|
||||
confidence: float = 1.0 # 0.0 to 1.0
|
||||
is_critical: bool = False # Critical field for protocol operation
|
||||
|
||||
def __str__(self) -> str:
|
||||
unit_str = f" {self.unit}" if self.unit else ""
|
||||
return f"{self.name}: {self.value}{unit_str}"
|
||||
|
||||
def format_value(self) -> str:
|
||||
"""Format the value for display"""
|
||||
if self.field_type == FieldType.TIMESTAMP:
|
||||
import datetime
|
||||
if isinstance(self.value, (int, float)):
|
||||
dt = datetime.datetime.fromtimestamp(self.value)
|
||||
return dt.strftime("%H:%M:%S.%f")[:-3]
|
||||
elif self.field_type == FieldType.FLOAT:
|
||||
return f"{self.value:.3f}"
|
||||
elif self.field_type == FieldType.IP_ADDRESS:
|
||||
return str(self.value)
|
||||
elif self.field_type == FieldType.BINARY:
|
||||
if isinstance(self.value, bytes):
|
||||
return self.value.hex()[:16] + "..." if len(self.value) > 8 else self.value.hex()
|
||||
|
||||
return str(self.value)
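A short sketch of DecodedField in use (values invented for illustration):

# Illustrative sketch (not part of this commit): a decoded drift field and how
# it renders via __str__ and format_value().
drift_field = DecodedField(
    name="clock_drift",
    value=2.71828,
    field_type=FieldType.FLOAT,
    unit="ppm",
    is_critical=True,
)
assert str(drift_field) == "clock_drift: 2.71828 ppm"
assert drift_field.format_value() == "2.718"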


@dataclass
class ProtocolInfo:
    """Information about a detected protocol"""
    protocol_type: ProtocolType
    name: str
    category: ProtocolCategory
    version: Optional[str] = None
    confidence: float = 1.0       # Detection confidence 0.0 to 1.0

    # Protocol-specific metadata
    port: Optional[int] = None
    subtype: Optional[str] = None  # e.g., "CH10-Data", "PTP-Sync"
    vendor: Optional[str] = None

    def __str__(self) -> str:
        version_str = f" v{self.version}" if self.version else ""
        subtype_str = f"-{self.subtype}" if self.subtype else ""
        return f"{self.name}{subtype_str}{version_str}"

    @property
    def is_enhanced(self) -> bool:
        """Check if this is an enhanced protocol requiring special handling"""
        return self.category in [ProtocolCategory.ENHANCED, ProtocolCategory.TIMING, ProtocolCategory.TELEMETRY]
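A short sketch of ProtocolInfo rendering and the is_enhanced check (an invented PTP-Sync example):

# Illustrative sketch (not part of this commit): a PTP Sync flow renders as
# "PTP-Sync v2" and counts as enhanced because TIMING is an enhanced category.
ptp_sync = ProtocolInfo(
    protocol_type=ProtocolType.PTP,
    name="PTP",
    category=ProtocolCategory.TIMING,
    version="2",
    subtype="Sync",
)
assert str(ptp_sync) == "PTP-Sync v2"
assert ptp_sync.is_enhanced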


class StandardProtocol:
    """Standard protocol definitions"""

    UDP = ProtocolInfo(
        protocol_type=ProtocolType.UDP,
        name="UDP",
        category=ProtocolCategory.TRANSPORT
    )

    TCP = ProtocolInfo(
        protocol_type=ProtocolType.TCP,
        name="TCP",
        category=ProtocolCategory.TRANSPORT
    )

    ICMP = ProtocolInfo(
        protocol_type=ProtocolType.ICMP,
        name="ICMP",
        category=ProtocolCategory.NETWORK
    )

    IGMP = ProtocolInfo(
        protocol_type=ProtocolType.IGMP,
        name="IGMP",
        category=ProtocolCategory.NETWORK
    )


class EnhancedProtocol:
    """Enhanced protocol definitions"""

    CHAPTER10 = ProtocolInfo(
        protocol_type=ProtocolType.CHAPTER10,
        name="Chapter 10",
        category=ProtocolCategory.TELEMETRY
    )

    PTP = ProtocolInfo(
        protocol_type=ProtocolType.PTP,
        name="PTP",
        category=ProtocolCategory.TIMING
    )

    IENA = ProtocolInfo(
        protocol_type=ProtocolType.IENA,
        name="IENA",
        category=ProtocolCategory.TELEMETRY
    )

    NTP = ProtocolInfo(
        protocol_type=ProtocolType.NTP,
        name="NTP",
        category=ProtocolCategory.TIMING
    )


@dataclass
class ProtocolDecodeResult:
    """Result of protocol decoding"""
    protocol_info: ProtocolInfo
    fields: List[DecodedField] = field(default_factory=list)
    frame_type: Optional[str] = None  # e.g., "CH10-Data", "PTP-Sync"
    payload_size: int = 0
    errors: List[str] = field(default_factory=list)
    warnings: List[str] = field(default_factory=list)

    def get_field(self, name: str) -> Optional[DecodedField]:
        """Get a specific field by name"""
        for field in self.fields:
            if field.name == name:
                return field
        return None

    def get_critical_fields(self) -> List[DecodedField]:
        """Get all critical fields"""
        return [f for f in self.fields if f.is_critical]

    def has_errors(self) -> bool:
        """Check if decode result has any errors"""
        return len(self.errors) > 0


class ProtocolRegistry:
    """Registry for managing protocol information and detection"""

    def __init__(self):
        self._protocols: Dict[ProtocolType, ProtocolInfo] = {}
        self._register_standard_protocols()
        self._register_enhanced_protocols()

    def _register_standard_protocols(self):
        """Register standard protocols"""
        for attr_name in dir(StandardProtocol):
            if not attr_name.startswith('_'):
                protocol = getattr(StandardProtocol, attr_name)
                if isinstance(protocol, ProtocolInfo):
                    self._protocols[protocol.protocol_type] = protocol

    def _register_enhanced_protocols(self):
        """Register enhanced protocols"""
        for attr_name in dir(EnhancedProtocol):
            if not attr_name.startswith('_'):
                protocol = getattr(EnhancedProtocol, attr_name)
                if isinstance(protocol, ProtocolInfo):
                    self._protocols[protocol.protocol_type] = protocol

    def get_protocol(self, protocol_type: ProtocolType) -> Optional[ProtocolInfo]:
        """Get protocol info by type"""
        return self._protocols.get(protocol_type)

    def get_protocol_by_name(self, name: str) -> Optional[ProtocolInfo]:
        """Get protocol info by name"""
        for protocol in self._protocols.values():
            if protocol.name.lower() == name.lower():
                return protocol
        return None

    def get_enhanced_protocols(self) -> List[ProtocolInfo]:
        """Get all enhanced protocols"""
        return [p for p in self._protocols.values() if p.is_enhanced]

    def get_protocols_by_category(self, category: ProtocolCategory) -> List[ProtocolInfo]:
        """Get all protocols in a category"""
        return [p for p in self._protocols.values() if p.category == category]

    def register_protocol(self, protocol_info: ProtocolInfo):
        """Register a new protocol"""
        self._protocols[protocol_info.protocol_type] = protocol_info

    def is_enhanced_protocol(self, protocol_type: ProtocolType) -> bool:
        """Check if protocol type is enhanced"""
        protocol = self.get_protocol(protocol_type)
        return protocol.is_enhanced if protocol else False


# Global protocol registry instance
PROTOCOL_REGISTRY = ProtocolRegistry()


def get_protocol_info(protocol_type: ProtocolType) -> Optional[ProtocolInfo]:
    """Convenience function to get protocol info"""
    return PROTOCOL_REGISTRY.get_protocol(protocol_type)


def is_enhanced_protocol(protocol_type: ProtocolType) -> bool:
    """Convenience function to check if protocol is enhanced"""
    return PROTOCOL_REGISTRY.is_enhanced_protocol(protocol_type)
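The module-level registry and convenience functions give callers a single lookup path; since ProtocolType.CH10 is an alias of CHAPTER10, both map to the same registry entry. A usage sketch (the custom protocol name is invented):

# Illustrative sketch (not part of this commit): lookups through the global
# registry plus registering a custom ProtocolInfo.
ch10_info = get_protocol_info(ProtocolType.CH10)
assert ch10_info is not None and ch10_info.name == "Chapter 10"
assert is_enhanced_protocol(ProtocolType.CH10)
assert not is_enhanced_protocol(ProtocolType.UDP)

custom = ProtocolInfo(
    protocol_type=ProtocolType.UNKNOWN,
    name="MyVendorProtocol",
    category=ProtocolCategory.ENHANCED,
)
PROTOCOL_REGISTRY.register_protocol(custom)
assert PROTOCOL_REGISTRY.get_protocol_by_name("myvendorprotocol") is custom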