Mirror of https://github.com/mtayfur/openwebui-memory-system.git (synced 2026-01-22 06:51:01 +01:00)
Refactor SkipDetector to streamline skip detection logic and improve clarity; update method signature for better integration with memory system.
@@ -47,10 +47,8 @@ class Constants:
     EXTENDED_MAX_MEMORY_MULTIPLIER = 1.5  # Multiplier for expanding memory candidates in advanced operations
     LLM_RERANKING_TRIGGER_MULTIPLIER = 0.5  # Multiplier for LLM reranking trigger threshold
 
-    # Skip Detection Thresholds
-    SKIP_DETECTION_SIMILARITY_THRESHOLD = 0.50  # Similarity threshold for skip category detection (tuned for zero-shot)
-    SKIP_DETECTION_MARGIN = 0.05  # Minimum margin required between skip and conversational similarity to skip
-    SKIP_DETECTION_CONFIDENT_MARGIN = 0.15  # Margin threshold for confident skips that trigger early exit
+    # Skip Detection
+    SKIP_CATEGORY_MARGIN = 0.1  # Margin above conversational similarity for skip category classification
 
     # Safety & Operations
     MAX_DELETE_OPERATIONS_RATIO = 0.6  # Maximum delete operations ratio for safety
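For intuition, with illustrative numbers not taken from the diff: if a message's best conversational similarity is 0.42, the single SKIP_CATEGORY_MARGIN of 0.1 means a skip category must score above 0.42 + 0.1 = 0.52 to qualify, replacing the old absolute cutoff (0.50) and the two separate margins (0.05 and 0.15).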
@@ -637,12 +635,11 @@ class SkipDetector:
 
         return None
 
-    def detect_skip_reason(self, message: str, max_message_chars: int = Constants.MAX_MESSAGE_CHARS) -> Optional[str]:
+    def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: 'Filter') -> Optional[str]:
         """
         Detect if a message should be skipped using two-stage detection:
         1. Fast-path structural patterns (~95% confidence)
         2. Semantic classification (for remaining cases)
 
         Returns:
             Skip reason string if content should be skipped, None otherwise
         """
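As a quick orientation, the snippet below is a self-contained sketch of the two-stage flow this docstring describes. The length guard, the regex, and the reason strings are toy stand-ins; the real second stage uses embedding similarities, as the hunks below show.

import re
from typing import Optional

def detect_skip_reason_sketch(message: str, max_message_chars: int) -> Optional[str]:
    # Guard: overly long messages skip memory processing entirely.
    if len(message) > max_message_chars:
        return "message_too_long"  # illustrative reason string
    # Stage 1: fast-path structural pattern, e.g. a message that is only a code block.
    if re.fullmatch(r"```.*```", message.strip(), flags=re.DOTALL):
        return "code_only"  # illustrative reason string
    # Stage 2: semantic classification against skip-category embeddings
    # (see the margin-based selection in the next hunks); None means the
    # message looks conversational and should be processed normally.
    return None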
@@ -676,6 +673,9 @@ class SkipDetector:
             ('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
         ]
 
+        qualifying_categories = []
+        margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN
+
         for cat_key, skip_reason, descriptions in skip_categories:
             similarities = np.dot(
                 message_embedding,
@@ -683,16 +683,13 @@ class SkipDetector:
             )
             max_similarity = float(similarities.max())
 
-            if max_similarity > Constants.SKIP_DETECTION_SIMILARITY_THRESHOLD:
-                margin = max_similarity - max_conversational_similarity
-
-                if margin > Constants.SKIP_DETECTION_CONFIDENT_MARGIN:
-                    logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})")
-                    return skip_reason.value
-
-                if margin > Constants.SKIP_DETECTION_MARGIN:
-                    logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})")
-                    return skip_reason.value
+            if max_similarity > margin_threshold:
+                qualifying_categories.append((max_similarity, cat_key, skip_reason))
 
+        if qualifying_categories:
+            highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0])
+            logger.info(f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})")
+            return highest_skip_reason.value
 
         return None
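A self-contained sketch of the new selection logic follows. The category embeddings are toy stand-ins for the real category descriptions, and vectors are assumed unit-normalized so the dot product acts as cosine similarity; only the margin-based selection mirrors the diff.

import numpy as np

SKIP_CATEGORY_MARGIN = 0.1  # same constant as in the Constants hunk

def classify_skip(message_embedding: np.ndarray,
                  category_embeddings: dict[str, np.ndarray],
                  max_conversational_similarity: float) -> str | None:
    margin_threshold = max_conversational_similarity + SKIP_CATEGORY_MARGIN
    qualifying = []
    for cat_key, embeddings in category_embeddings.items():
        similarities = np.dot(embeddings, message_embedding)
        max_similarity = float(similarities.max())
        if max_similarity > margin_threshold:
            qualifying.append((max_similarity, cat_key))
    if qualifying:
        # The highest-scoring qualifying category wins; there is no longer a
        # per-category early exit with separate "confident" and normal margins.
        return max(qualifying)[1]
    return None

# Example with 3-dimensional toy vectors (already roughly normalized):
msg = np.array([0.0, 0.6, 0.8])
cats = {"pure_math": np.array([[0.0, 0.7, 0.714]]), "small_talk": np.array([[1.0, 0.0, 0.0]])}
print(classify_skip(msg, cats, max_conversational_similarity=0.4))  # -> "pure_math"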
@@ -789,18 +786,25 @@ class LLMConsolidationService:
     def __init__(self, memory_system):
         self.memory_system = memory_system
 
+    def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]:
+        """Filter consolidation candidates by threshold and return candidates with threshold info."""
+        consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
+        candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold]
+
+        max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
+        candidates = candidates[:max_consolidation_memories]
+
+        threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})"
+        return candidates, threshold_info
+
     async def collect_consolidation_candidates(
         self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None
     ) -> List[Dict[str, Any]]:
         """Collect candidate memories for consolidation analysis using cached or computed similarities."""
         if cached_similarities:
-            consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
-            candidates = [mem for mem in cached_similarities if mem["relevance"] >= consolidation_threshold]
-
-            max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
-            candidates = candidates[:max_consolidation_memories]
-
-            logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {consolidation_threshold:.3f}, max: {max_consolidation_memories})")
+            candidates, threshold_info = self._filter_consolidation_candidates(cached_similarities)
+
+            logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})")
 
             self.memory_system._log_retrieved_memories(candidates, "consolidation")
             return candidates
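Below is a minimal stand-alone version of what _filter_consolidation_candidates centralizes: a threshold filter followed by a cap derived from max_memories_returned. The free-function form and its name are illustrative; the structure follows the hunk above.

from typing import Any

EXTENDED_MAX_MEMORY_MULTIPLIER = 1.5  # same multiplier as in Constants

def filter_candidates(similarities: list[dict[str, Any]],
                      consolidation_threshold: float,
                      max_memories_returned: int) -> tuple[list[dict[str, Any]], str]:
    # Keep only memories at or above the consolidation threshold.
    candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold]
    # Cap the candidate list relative to the normal retrieval limit.
    max_consolidation_memories = int(max_memories_returned * EXTENDED_MAX_MEMORY_MULTIPLIER)
    candidates = candidates[:max_consolidation_memories]
    threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})"
    return candidates, threshold_info

Both consolidation paths (the cached-similarities branch above and the all_similarities branch in the next hunk) now delegate to the shared helper instead of repeating this logic.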
@@ -826,13 +830,7 @@ class LLMConsolidationService:
             return []
 
         if all_similarities:
-            consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
-            candidates = [mem for mem in all_similarities if mem["relevance"] >= consolidation_threshold]
-
-            max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
-            candidates = candidates[:max_consolidation_memories]
-
-            threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})"
+            candidates, threshold_info = self._filter_consolidation_candidates(all_similarities)
         else:
             candidates = []
             threshold_info = 'N/A'
@@ -1039,10 +1037,13 @@ class Filter:
         """Configuration valves for the Memory System."""
 
         model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
-        max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
 
         max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
+        max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
 
         semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
         relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)")
 
         enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
         llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)")
@@ -1239,7 +1240,7 @@ class Filter:
         if self._skip_detector is None:
             raise RuntimeError("🤖 Skip detector not initialized")
 
-        skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars)
+        skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self)
         if skip_reason:
             status_key = SkipDetector.SkipReason(skip_reason)
             return True, SkipDetector.STATUS_MESSAGES[status_key]
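Note that because the old default value for max_message_chars was removed from detect_skip_reason (see the SkipDetector hunk above), any caller of the detector must now pass the character limit explicitly (here taken from self.valves.max_message_chars) along with the new memory_system argument.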