diff --git a/memory_system.py b/memory_system.py index 20b8ec1..6bceb7f 100644 --- a/memory_system.py +++ b/memory_system.py @@ -268,7 +268,6 @@ class UnifiedCacheManager: self.EMBEDDING_CACHE = "embedding" self.RETRIEVAL_CACHE = "retrieval" self.MEMORY_CACHE = "memory" - self.SKIP_STATE_CACHE = "skip" async def get(self, user_id: str, cache_type: str, key: str) -> Optional[Any]: """Get value from cache with LRU updates.""" @@ -462,10 +461,17 @@ class SkipDetector: class SkipReason(Enum): SKIP_SIZE = "SKIP_SIZE" SKIP_NON_PERSONAL = "SKIP_NON_PERSONAL" + SKIP_ALL_NON_PERSONAL = "SKIP_ALL_NON_PERSONAL" - STATUS_MESSAGES = { - SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, skipping memory operations", - SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations", + # Inlet (retrieval) status messages + INLET_STATUS_MESSAGES = { + SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, Skipping Memory Retrieval", + SkipReason.SKIP_NON_PERSONAL: "🚫 No Personal Content, Skipping Memory Retrieval", + } + + # Outlet (consolidation) status messages + OUTLET_STATUS_MESSAGES = { + SkipReason.SKIP_ALL_NON_PERSONAL: "🚫 No Personal Content in Context, Skipping Memory Consolidation", } def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Any]): @@ -735,7 +741,7 @@ CANDIDATE MEMORIES: llm_candidates = candidate_memories[:extended_count] await self.memory_system._emit_status( emitter, - f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", + f"🤖 Analyzing {len(llm_candidates)} Memories for Relevance", done=False, level=Constants.STATUS_LEVEL["Intermediate"], ) @@ -745,9 +751,7 @@ CANDIDATE MEMORIES: if not selected_memories: logger.info("📭 No relevant memories after LLM analysis") - await self.memory_system._emit_status( - emitter, f"📭 No Relevant Memories After LLM Analysis", done=True, level=Constants.STATUS_LEVEL["Intermediate"] - ) + await self.memory_system._emit_status(emitter, "📭 No Relevant Memories Found", done=True, level=Constants.STATUS_LEVEL["Intermediate"]) return selected_memories, analysis_info else: logger.info(f"⏩ Skipping LLM reranking: {decision_reason}") @@ -886,7 +890,7 @@ class LLMConsolidationService: ) except Exception as e: logger.warning(f"🤖 LLM consolidation failed during memory processing: {str(e)}") - await self.memory_system._emit_status(emitter, f"⚠️ Memory Consolidation Failed", done=True, level=Constants.STATUS_LEVEL["Basic"]) + await self.memory_system._emit_status(emitter, "⚠️ Memory Consolidation Failed", done=True, level=Constants.STATUS_LEVEL["Basic"]) return [] operations = response.ops @@ -1130,7 +1134,7 @@ class LLMConsolidationService: duration = time.time() - start_time await self.memory_system._emit_status( emitter, - f"✅ Consolidation Complete: No Updates Needed", + "✅ No Memory Updates Needed", done=True, level=Constants.STATUS_LEVEL["Detailed"], ) @@ -1408,14 +1412,31 @@ class Filter: skip_reason = await self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self) if skip_reason: status_key = SkipDetector.SkipReason(skip_reason) - return True, SkipDetector.STATUS_MESSAGES[status_key] + return True, SkipDetector.INLET_STATUS_MESSAGES[status_key] return False, "" + async def _should_skip_consolidation(self, conversation_context: List[str]) -> Tuple[bool, str]: + """Check if consolidation should be skipped based on conversation context. + + Returns (should_skip, reason). Skips only if ALL messages in context are skippable. + """ + logger.info(f"🔍 Evaluating {len(conversation_context)} messages for consolidation") + + for idx, message in enumerate(conversation_context, 1): + skip_reason = await self._skip_detector.detect_skip_reason(message, Constants.MAX_MESSAGE_CHARS, memory_system=self) + if not skip_reason: # Found at least one valuable message + logger.info(f"✅ Found personal content in message {idx}/{len(conversation_context)}, proceeding with consolidation") + return False, "" + + # All messages were skippable + logger.info(f"🚫 All {len(conversation_context)} messages are non-personal, skipping consolidation") + return True, SkipDetector.OUTLET_STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_ALL_NON_PERSONAL] + async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: """Extract user message and determine if memory operations should be skipped.""" user_message = self._get_last_user_message(body["messages"]) if not user_message: - return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE] + return None, True, SkipDetector.INLET_STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE] should_skip, skip_reason = await self._should_skip_memory_operations(user_message) return user_message, should_skip, skip_reason @@ -1523,7 +1544,6 @@ class Filter: self._log_retrieved_memories(final_memories, "semantic") return { "memories": final_memories, - "threshold": self.valves.semantic_retrieval_threshold, "all_similarities": cached_similarities, "reranking_info": reranking_info, } @@ -1534,9 +1554,9 @@ class Filter: if not user_memories: logger.info("📭 No memories found for user") await self._emit_status(emitter, "📭 No Memories Found", done=True, level=Constants.STATUS_LEVEL["Intermediate"]) - return {"memories": [], "threshold": None} + return {"memories": []} - memories, threshold, all_similarities = await self._compute_similarities(user_message, user_id, user_memories) + memories, all_similarities = await self._compute_similarities(user_message, user_id, user_memories) if memories: final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter) @@ -1550,7 +1570,6 @@ class Filter: return { "memories": final_memories, - "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info, } @@ -1611,10 +1630,10 @@ class Filter: memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat() return memory_dict - async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], float, List[Dict]]: + async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], List[Dict]]: """Compute similarity scores between user message and memories.""" if not user_memories: - return [], self.valves.semantic_retrieval_threshold, [] + return [], [] query_embedding = await self._generate_embeddings(user_message, user_id) memory_contents = [memory.content for memory in user_memories] @@ -1632,9 +1651,8 @@ class Filter: memory_data.sort(key=lambda x: x["relevance"], reverse=True) - threshold = self.valves.semantic_retrieval_threshold - filtered_memories = [m for m in memory_data if m["relevance"] >= threshold] - return filtered_memories, threshold, memory_data + filtered_memories = [m for m in memory_data if m["relevance"] >= self.valves.semantic_retrieval_threshold] + return filtered_memories, memory_data async def inlet( self, @@ -1654,14 +1672,6 @@ class Filter: user_message, should_skip, skip_reason = await self._process_user_message(body) - skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message) - await self._cache_manager.put( - user_id, - self._cache_manager.SKIP_STATE_CACHE, - skip_cache_key, - should_skip, - ) - if not user_message or should_skip: if __event_emitter__ and skip_reason: await self._emit_status(__event_emitter__, skip_reason, done=True, level=Constants.STATUS_LEVEL["Intermediate"]) @@ -1680,7 +1690,6 @@ class Filter: ) retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__) memories = retrieval_result.get("memories", []) - threshold = retrieval_result.get("threshold") all_similarities = retrieval_result.get("all_similarities", []) if all_similarities: cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) @@ -1715,18 +1724,16 @@ class Filter: if not user_message: return body - skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message) - should_skip = await self._cache_manager.get(user_id, self._cache_manager.SKIP_STATE_CACHE, skip_cache_key) + conversation_context = self._get_recent_user_messages(body.get("messages", []), self.valves.max_consolidation_context_messages) + should_skip_consolidation, skip_reason = await self._should_skip_consolidation(conversation_context) - if should_skip: - logger.info("⏭️ Skipping outlet: inlet detected skip condition") + if should_skip_consolidation: + await self._emit_status(__event_emitter__, skip_reason, done=True, level=Constants.STATUS_LEVEL["Intermediate"]) return body retrieval_cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, retrieval_cache_key) - conversation_context = self._get_recent_user_messages(body.get("messages", []), self.valves.max_consolidation_context_messages) - task = asyncio.create_task( self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities, conversation_context) ) @@ -1760,8 +1767,7 @@ class Filter: try: retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE) embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE) - skip_state_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.SKIP_STATE_CACHE) - logger.info(f"🔄 Cleared cache: {retrieval_cleared} retrieval, {embedding_cleared} embedding, {skip_state_cleared} skip entries") + logger.info(f"🔄 Cleared cache: {retrieval_cleared} retrieval, {embedding_cleared} embedding entries") user_memories = await self._get_user_memories(user_id) memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)