From 2db2d3f2c873314949078dcf16c0a439482906b9 Mon Sep 17 00:00:00 2001 From: mtayfur Date: Sun, 12 Oct 2025 21:44:51 +0300 Subject: [PATCH] Refactor SkipDetector to streamline skip detection logic and improve clarity; update method signature for better integration with memory system. --- memory_system.py | 67 ++++++++++++++++++++++++------------------------ 1 file changed, 34 insertions(+), 33 deletions(-) diff --git a/memory_system.py b/memory_system.py index 284a5ec..ec88f2b 100644 --- a/memory_system.py +++ b/memory_system.py @@ -47,11 +47,9 @@ class Constants: EXTENDED_MAX_MEMORY_MULTIPLIER = 1.5 # Multiplier for expanding memory candidates in advanced operations LLM_RERANKING_TRIGGER_MULTIPLIER = 0.5 # Multiplier for LLM reranking trigger threshold - # Skip Detection Thresholds - SKIP_DETECTION_SIMILARITY_THRESHOLD = 0.50 # Similarity threshold for skip category detection (tuned for zero-shot) - SKIP_DETECTION_MARGIN = 0.05 # Minimum margin required between skip and conversational similarity to skip - SKIP_DETECTION_CONFIDENT_MARGIN = 0.15 # Margin threshold for confident skips that trigger early exit - + # Skip Detection + SKIP_CATEGORY_MARGIN = 0.1 # Margin above conversational similarity for skip category classification + # Safety & Operations MAX_DELETE_OPERATIONS_RATIO = 0.6 # Maximum delete operations ratio for safety MIN_OPS_FOR_DELETE_RATIO_CHECK = 6 # Minimum operations to apply ratio check @@ -637,12 +635,11 @@ class SkipDetector: return None - def detect_skip_reason(self, message: str, max_message_chars: int = Constants.MAX_MESSAGE_CHARS) -> Optional[str]: + def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: 'Filter') -> Optional[str]: """ Detect if a message should be skipped using two-stage detection: 1. Fast-path structural patterns (~95% confidence) 2. Semantic classification (for remaining cases) - Returns: Skip reason string if content should be skipped, None otherwise """ @@ -676,6 +673,9 @@ class SkipDetector: ('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS), ] + qualifying_categories = [] + margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN + for cat_key, skip_reason, descriptions in skip_categories: similarities = np.dot( message_embedding, @@ -683,16 +683,13 @@ class SkipDetector: ) max_similarity = float(similarities.max()) - if max_similarity > Constants.SKIP_DETECTION_SIMILARITY_THRESHOLD: - margin = max_similarity - max_conversational_similarity - - if margin > Constants.SKIP_DETECTION_CONFIDENT_MARGIN: - logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})") - return skip_reason.value - - if margin > Constants.SKIP_DETECTION_MARGIN: - logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})") - return skip_reason.value + if max_similarity > margin_threshold: + qualifying_categories.append((max_similarity, cat_key, skip_reason)) + + if qualifying_categories: + highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0]) + logger.info(f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})") + return highest_skip_reason.value return None @@ -789,18 +786,25 @@ class LLMConsolidationService: def __init__(self, memory_system): self.memory_system = memory_system + def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]: + """Filter consolidation candidates by threshold and return candidates with threshold info.""" + consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) + candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold] + + max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) + candidates = candidates[:max_consolidation_memories] + + threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" + return candidates, threshold_info + async def collect_consolidation_candidates( self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None ) -> List[Dict[str, Any]]: """Collect candidate memories for consolidation analysis using cached or computed similarities.""" if cached_similarities: - consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) - candidates = [mem for mem in cached_similarities if mem["relevance"] >= consolidation_threshold] - - max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) - candidates = candidates[:max_consolidation_memories] + candidates, threshold_info = self._filter_consolidation_candidates(cached_similarities) - logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {consolidation_threshold:.3f}, max: {max_consolidation_memories})") + logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})") self.memory_system._log_retrieved_memories(candidates, "consolidation") return candidates @@ -826,13 +830,7 @@ class LLMConsolidationService: return [] if all_similarities: - consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) - candidates = [mem for mem in all_similarities if mem["relevance"] >= consolidation_threshold] - - max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) - candidates = candidates[:max_consolidation_memories] - - threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" + candidates, threshold_info = self._filter_consolidation_candidates(all_similarities) else: candidates = [] threshold_info = 'N/A' @@ -1039,10 +1037,13 @@ class Filter: """Configuration valves for the Memory System.""" model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations") - max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context") + max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations") + max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context") + semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval") - relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)") + relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)") + enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection") llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)") @@ -1239,7 +1240,7 @@ class Filter: if self._skip_detector is None: raise RuntimeError("🤖 Skip detector not initialized") - skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars) + skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self) if skip_reason: status_key = SkipDetector.SkipReason(skip_reason) return True, SkipDetector.STATUS_MESSAGES[status_key]