diff --git a/memory_system.py b/memory_system.py
index 82a555e..aebaa68 100644
--- a/memory_system.py
+++ b/memory_system.py
@@ -254,6 +254,7 @@ class UnifiedCacheManager:
         self.EMBEDDING_CACHE = "embedding"
         self.RETRIEVAL_CACHE = "retrieval"
         self.MEMORY_CACHE = "memory"
+        self.SKIP_STATE_CACHE = "skip"
 
     async def get(self, user_id: str, cache_type: str, key: str) -> Optional[Any]:
         """Get value from cache with LRU updates."""
@@ -747,25 +748,20 @@ class LLMConsolidationService:
         if not existing_memories:
             return None
 
-        try:
-            content_embedding = await self.memory_system._generate_embeddings(content, user_id)
+        content_embedding = await self.memory_system._generate_embeddings(content, user_id)
 
-            for memory in existing_memories:
-                if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS:
-                    continue
+        for memory in existing_memories:
+            if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS:
+                continue
 
-                memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id)
+            memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id)
+            similarity = float(np.dot(content_embedding, memory_embedding))
 
-                similarity = float(np.dot(content_embedding, memory_embedding))
+            if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD:
+                logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
+                return str(memory.id)
 
-                if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD:
-                    logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
-                    return str(memory.id)
-
-            return None
-        except Exception as e:
-            logger.warning(f"⚠️ Semantic duplicate check failed: {str(e)}")
-            return None
+        return None
 
     def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]:
         """Filter consolidation candidates by threshold and return candidates with threshold info."""
@@ -945,21 +941,13 @@ class LLMConsolidationService:
 
     async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
         """Execute consolidation operations with simplified tracking."""
-        if not operations or not user_id:
+        if not operations:
             return 0, 0, 0, 0
 
-        try:
-            user = await asyncio.wait_for(
-                asyncio.to_thread(Users.get_user_by_id, user_id),
-                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-            )
-        except asyncio.TimeoutError:
-            raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
-        except Exception as e:
-            raise RuntimeError(f"👤 User lookup failed: {str(e)}")
-
-        if not user:
-            raise ValueError(f"👤 User not found for consolidation: {user_id}")
+        user = await asyncio.wait_for(
+            asyncio.to_thread(Users.get_user_by_id, user_id),
+            timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+        )
 
         created_count = updated_count = deleted_count = failed_count = 0
 
@@ -982,28 +970,21 @@ class LLMConsolidationService:
 
         memory_contents_for_deletion = {}
         if operations_by_type["DELETE"]:
-            try:
-                user_memories = await self.memory_system._get_user_memories(user_id)
-                memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
-            except Exception as e:
-                logger.warning(f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}")
+            user_memories = await self.memory_system._get_user_memories(user_id)
+            memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
 
         if operations_by_type["CREATE"] or operations_by_type["UPDATE"]:
-            try:
-                current_memories = await self.memory_system._get_user_memories(user_id)
+            current_memories = await self.memory_system._get_user_memories(user_id)
 
-                if operations_by_type["CREATE"]:
-                    operations_by_type["CREATE"] = await self._deduplicate_operations(
-                        operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
-                    )
+            if operations_by_type["CREATE"]:
+                operations_by_type["CREATE"] = await self._deduplicate_operations(
+                    operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
+                )
 
-                if operations_by_type["UPDATE"]:
-                    operations_by_type["UPDATE"] = await self._deduplicate_operations(
-                        operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
-                    )
-
-            except Exception as e:
-                logger.warning(f"⚠️ Semantic deduplication check failed, proceeding with original operations: {str(e)}")
+            if operations_by_type["UPDATE"]:
+                operations_by_type["UPDATE"] = await self._deduplicate_operations(
+                    operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
+                )
 
         for operation_type, ops in operations_by_type.items():
             if not ops:
@@ -1014,38 +995,33 @@ class LLMConsolidationService:
                 task = self.memory_system._execute_single_operation(operation, user)
                 batch_tasks.append(task)
 
-            try:
-                results = await asyncio.gather(*batch_tasks, return_exceptions=True)
-                for idx, result in enumerate(results):
-                    operation = ops[idx]
+            results = await asyncio.gather(*batch_tasks, return_exceptions=True)
+            for idx, result in enumerate(results):
+                operation = ops[idx]
 
-                    if isinstance(result, Exception):
-                        failed_count += 1
-                        await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
-                    elif result == Models.MemoryOperationType.CREATE.value:
-                        created_count += 1
-                        content_preview = self.memory_system._truncate_content(operation.content)
-                        await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False)
-                    elif result == Models.MemoryOperationType.UPDATE.value:
-                        updated_count += 1
-                        content_preview = self.memory_system._truncate_content(operation.content)
-                        await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False)
-                    elif result == Models.MemoryOperationType.DELETE.value:
-                        deleted_count += 1
-                        content_preview = memory_contents_for_deletion.get(operation.id, operation.id)
-                        if content_preview and content_preview != operation.id:
-                            content_preview = self.memory_system._truncate_content(content_preview)
-                        await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
-                    elif result in [
-                        Models.OperationResult.FAILED.value,
-                        Models.OperationResult.UNSUPPORTED.value,
-                    ]:
-                        failed_count += 1
-                        await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
-            except Exception as e:
-                failed_count += len(ops)
-                logger.error(f"❌ Batch {operation_type} operations failed during memory consolidation: {str(e)}")
-                await self.memory_system._emit_status(emitter, f"❌ Batch {operation_type} Failed", done=False)
+            if isinstance(result, Exception):
+                failed_count += 1
+                await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
+            elif result == Models.MemoryOperationType.CREATE.value:
+                created_count += 1
+                content_preview = self.memory_system._truncate_content(operation.content)
+                await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False)
+            elif result == Models.MemoryOperationType.UPDATE.value:
+                updated_count += 1
+                content_preview = self.memory_system._truncate_content(operation.content)
+                await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False)
+            elif result == Models.MemoryOperationType.DELETE.value:
+                deleted_count += 1
+                content_preview = memory_contents_for_deletion.get(operation.id, operation.id)
+                if content_preview and content_preview != operation.id:
+                    content_preview = self.memory_system._truncate_content(content_preview)
+                await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
+            elif result in [
+                Models.OperationResult.FAILED.value,
+                Models.OperationResult.UNSUPPORTED.value,
+            ]:
+                failed_count += 1
+                await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
 
         total_executed = created_count + updated_count + deleted_count
         logger.info(
@@ -1296,17 +1272,9 @@ class Filter:
 
     async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]:
         """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function."""
-        if self._embedding_function is None:
-            raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")
-
         is_single = isinstance(texts, str)
         text_list = [texts] if is_single else texts
 
-        if not text_list:
-            if is_single:
-                raise ValueError("📏 Empty text provided for embedding generation")
-            return []
-
         result_embeddings = []
         uncached_texts = []
         uncached_indices = []
@@ -1360,9 +1328,6 @@ class Filter:
         return result_embeddings
 
     def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
-        if self._skip_detector is None:
-            raise RuntimeError("🤖 Skip detector not initialized")
-
         skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self)
         if skip_reason:
             status_key = SkipDetector.SkipReason(skip_reason)
@@ -1371,13 +1336,6 @@ class Filter:
 
     def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
         """Extract user message and determine if memory operations should be skipped."""
-        if not body or "messages" not in body or not isinstance(body["messages"], list):
-            return (
-                None,
-                True,
-                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
-            )
-
         messages = body["messages"]
         user_message = None
 
@@ -1421,10 +1379,6 @@ class Filter:
             return
 
         scores = [memory["relevance"] for memory in memories]
-
-        if not scores:
-            return
-
         top_score = max(scores)
         lowest_score = min(scores)
         median_score = statistics.median(scores)
@@ -1485,13 +1439,9 @@ class Filter:
             return
 
         payload = {"type": "status", "data": {"description": description, "done": done}}
-
-        try:
-            result = emitter(payload)
-            if asyncio.iscoroutine(result):
-                await result
-        except Exception:
-            pass
+        result = emitter(payload)
+        if asyncio.iscoroutine(result):
+            await result
 
     async def _retrieve_relevant_memories(
         self,
@@ -1549,10 +1499,6 @@ class Filter:
         emitter: Optional[Callable] = None,
    ) -> None:
         """Add memory context to request body with simplified logic."""
-        if not body or "messages" not in body or not body["messages"]:
-            logger.warning("⚠️ Invalid request body or no messages found")
-            return
-
         content_parts = [f"Current Date/Time: {self.format_current_datetime()}"]
         memory_count = 0
 
@@ -1610,26 +1556,16 @@ class Filter:
         memory_contents = [memory.content for memory in user_memories]
         memory_embeddings = await self._generate_embeddings(memory_contents, user_id)
 
-        if len(memory_embeddings) != len(user_memories):
-            logger.error(f"🔢 Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories")
-            return [], self.valves.semantic_retrieval_threshold, []
-
-        similarity_scores = []
         memory_data = []
-
         for memory_index, memory in enumerate(user_memories):
             memory_embedding = memory_embeddings[memory_index]
             if memory_embedding is None:
                 continue
 
             similarity = float(np.dot(query_embedding, memory_embedding))
-            similarity_scores.append(similarity)
             memory_dict = self._build_memory_dict(memory, similarity)
             memory_data.append(memory_dict)
 
-        if not similarity_scores:
-            return [], self.valves.semantic_retrieval_threshold, []
-
         memory_data.sort(key=lambda x: x["relevance"], reverse=True)
 
         threshold = self.valves.semantic_retrieval_threshold
@@ -1657,10 +1593,19 @@ class Filter:
             return body
 
         user_message, should_skip, skip_reason = self._process_user_message(body)
+
         if not user_message or should_skip:
             if __event_emitter__ and skip_reason:
                 await self._emit_status(__event_emitter__, skip_reason, done=True)
             await self._add_memory_context(body, [], user_id, __event_emitter__)
+
+            skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message or "")
+            await self._cache_manager.put(
+                user_id,
+                self._cache_manager.SKIP_STATE_CACHE,
+                skip_cache_key,
+                True,
+            )
             return body
         try:
             memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
@@ -1709,11 +1654,20 @@ class Filter:
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body
-        user_message, should_skip, skip_reason = self._process_user_message(body)
-        if not user_message or should_skip:
+
+        user_message, _, _ = self._process_user_message(body)
+        if not user_message:
             return body
-        cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
-        cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
+
+        skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message)
+        should_skip = await self._cache_manager.get(user_id, self._cache_manager.SKIP_STATE_CACHE, skip_cache_key)
+
+        if should_skip:
+            logger.info("⏭️ Skipping outlet consolidation: inlet already detected skip condition")
+            return body
+
+        retrieval_cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
+        cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, retrieval_cache_key)
 
         task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
         self._background_tasks.add(task)
@@ -1745,7 +1699,10 @@ class Filter:
         try:
             retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE)
             embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE)
-            logger.info(f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}")
+            skip_state_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.SKIP_STATE_CACHE)
+            logger.info(
+                f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding + {skip_state_cleared} skip state cache entries for user {user_id}"
+            )
 
             user_memories = await self._get_user_memories(user_id)
             memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
@@ -1774,59 +1731,49 @@ class Filter:
 
     async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str:
         """Execute a single memory operation."""
-        try:
-            if operation.operation == Models.MemoryOperationType.CREATE:
-                content_stripped = operation.content.strip()
-                if not content_stripped:
-                    logger.warning(f"⚠️ Skipping CREATE operation: empty content")
-                    return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
+        if operation.operation == Models.MemoryOperationType.CREATE:
+            content_stripped = operation.content.strip()
+            if not content_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
 
-                await asyncio.wait_for(
-                    asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
-                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-                )
-                return Models.MemoryOperationType.CREATE.value
+            await asyncio.wait_for(
+                asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
+            return Models.MemoryOperationType.CREATE.value
 
-            elif operation.operation == Models.MemoryOperationType.UPDATE:
-                id_stripped = operation.id.strip()
-                if not id_stripped:
-                    logger.warning(f"⚠️ Skipping UPDATE operation: empty ID")
-                    return Models.OperationResult.SKIPPED_EMPTY_ID.value
+        elif operation.operation == Models.MemoryOperationType.UPDATE:
+            id_stripped = operation.id.strip()
+            if not id_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_ID.value
 
-                content_stripped = operation.content.strip()
-                if not content_stripped:
-                    logger.warning(f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content")
-                    return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
+            content_stripped = operation.content.strip()
+            if not content_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
 
-                await asyncio.wait_for(
-                    asyncio.to_thread(
-                        Memories.update_memory_by_id_and_user_id,
-                        id_stripped,
-                        user.id,
-                        content_stripped,
-                    ),
-                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-                )
-                return Models.MemoryOperationType.UPDATE.value
+            await asyncio.wait_for(
+                asyncio.to_thread(
+                    Memories.update_memory_by_id_and_user_id,
+                    id_stripped,
+                    user.id,
+                    content_stripped,
+                ),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
+            return Models.MemoryOperationType.UPDATE.value
 
-            elif operation.operation == Models.MemoryOperationType.DELETE:
-                id_stripped = operation.id.strip()
-                if not id_stripped:
-                    logger.warning(f"⚠️ Skipping DELETE operation: empty ID")
-                    return Models.OperationResult.SKIPPED_EMPTY_ID.value
+        elif operation.operation == Models.MemoryOperationType.DELETE:
+            id_stripped = operation.id.strip()
+            if not id_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_ID.value
 
-                await asyncio.wait_for(
-                    asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
-                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-                )
-                return Models.MemoryOperationType.DELETE.value
-            else:
-                logger.error(f"❓ Unsupported operation: {operation}")
-                return Models.OperationResult.UNSUPPORTED.value
+            await asyncio.wait_for(
+                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
+            return Models.MemoryOperationType.DELETE.value
 
-        except Exception as e:
-            logger.error(f"💾 Database operation failed for {operation.operation.value}: {str(e)}")
-            return Models.OperationResult.FAILED.value
+        return Models.OperationResult.UNSUPPORTED.value
 
     def _remove_refs_from_schema(self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         """Remove $ref references and ensure required fields for Azure OpenAI."""