Refactor SkipDetector to streamline skip detection logic and improve clarity; update method signature for better integration with memory system.

This commit is contained in:
mtayfur
2025-10-12 21:44:51 +03:00
parent 1390505665
commit 2db2d3f2c8

View File

@@ -47,11 +47,9 @@ class Constants:
EXTENDED_MAX_MEMORY_MULTIPLIER = 1.5 # Multiplier for expanding memory candidates in advanced operations
LLM_RERANKING_TRIGGER_MULTIPLIER = 0.5 # Multiplier for LLM reranking trigger threshold
# Skip Detection Thresholds
SKIP_DETECTION_SIMILARITY_THRESHOLD = 0.50 # Similarity threshold for skip category detection (tuned for zero-shot)
SKIP_DETECTION_MARGIN = 0.05 # Minimum margin required between skip and conversational similarity to skip
SKIP_DETECTION_CONFIDENT_MARGIN = 0.15 # Margin threshold for confident skips that trigger early exit
# Skip Detection
SKIP_CATEGORY_MARGIN = 0.1 # Margin above conversational similarity for skip category classification
# Safety & Operations
MAX_DELETE_OPERATIONS_RATIO = 0.6 # Maximum delete operations ratio for safety
MIN_OPS_FOR_DELETE_RATIO_CHECK = 6 # Minimum operations to apply ratio check
@@ -637,12 +635,11 @@ class SkipDetector:
return None
def detect_skip_reason(self, message: str, max_message_chars: int = Constants.MAX_MESSAGE_CHARS) -> Optional[str]:
def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: 'Filter') -> Optional[str]:
"""
Detect if a message should be skipped using two-stage detection:
1. Fast-path structural patterns (~95% confidence)
2. Semantic classification (for remaining cases)
Returns:
Skip reason string if content should be skipped, None otherwise
"""
@@ -676,6 +673,9 @@ class SkipDetector:
('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
]
qualifying_categories = []
margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN
for cat_key, skip_reason, descriptions in skip_categories:
similarities = np.dot(
message_embedding,
@@ -683,16 +683,13 @@ class SkipDetector:
)
max_similarity = float(similarities.max())
if max_similarity > Constants.SKIP_DETECTION_SIMILARITY_THRESHOLD:
margin = max_similarity - max_conversational_similarity
if margin > Constants.SKIP_DETECTION_CONFIDENT_MARGIN:
logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})")
return skip_reason.value
if margin > Constants.SKIP_DETECTION_MARGIN:
logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})")
return skip_reason.value
if max_similarity > margin_threshold:
qualifying_categories.append((max_similarity, cat_key, skip_reason))
if qualifying_categories:
highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0])
logger.info(f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})")
return highest_skip_reason.value
return None
@@ -789,18 +786,25 @@ class LLMConsolidationService:
def __init__(self, memory_system):
self.memory_system = memory_system
def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]:
"""Filter consolidation candidates by threshold and return candidates with threshold info."""
consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold]
max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
candidates = candidates[:max_consolidation_memories]
threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})"
return candidates, threshold_info
async def collect_consolidation_candidates(
self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None
) -> List[Dict[str, Any]]:
"""Collect candidate memories for consolidation analysis using cached or computed similarities."""
if cached_similarities:
consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
candidates = [mem for mem in cached_similarities if mem["relevance"] >= consolidation_threshold]
max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
candidates = candidates[:max_consolidation_memories]
candidates, threshold_info = self._filter_consolidation_candidates(cached_similarities)
logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {consolidation_threshold:.3f}, max: {max_consolidation_memories})")
logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})")
self.memory_system._log_retrieved_memories(candidates, "consolidation")
return candidates
@@ -826,13 +830,7 @@ class LLMConsolidationService:
return []
if all_similarities:
consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
candidates = [mem for mem in all_similarities if mem["relevance"] >= consolidation_threshold]
max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
candidates = candidates[:max_consolidation_memories]
threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})"
candidates, threshold_info = self._filter_consolidation_candidates(all_similarities)
else:
candidates = []
threshold_info = 'N/A'
@@ -1039,10 +1037,13 @@ class Filter:
"""Configuration valves for the Memory System."""
model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)")
relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)")
enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)")
@@ -1239,7 +1240,7 @@ class Filter:
if self._skip_detector is None:
raise RuntimeError("🤖 Skip detector not initialized")
skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars)
skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self)
if skip_reason:
status_key = SkipDetector.SkipReason(skip_reason)
return True, SkipDetector.STATUS_MESSAGES[status_key]