From 8997f71f057f92d5a52cdc925c2a66ff374a819b Mon Sep 17 00:00:00 2001
From: mtayfur
Date: Tue, 28 Oct 2025 17:51:09 +0300
Subject: [PATCH] refactor(memory_system): remove excessive try/except and
 input validation, streamline async operations, and add skip state cache

Removes redundant try/except blocks and input validation from several
methods, moving error handling to the callers where appropriate. This
simplifies the logic and improves readability.

Adds a skip state cache that records when the inlet decides memory
operations should be skipped for a message, so the outlet can consult
that decision instead of repeating the skip checks.

Cleans up batch operation execution and extends cache clearing to cover
the new skip state cache.

Together these changes reduce unnecessary code complexity, improve
maintainability, and streamline memory operation flow and cache
management.
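As a rough sketch of the handshake this introduces (a plain dict stands
in for UnifiedCacheManager; the key shape and method names here are
illustrative, not the patch's actual API):

```python
# Sketch only: the inlet records the skip decision, the outlet reads it
# back instead of re-running skip detection on the same message.
class SkipStateSketch:
    def __init__(self) -> None:
        self._cache: dict[tuple[str, str, str], bool] = {}

    async def inlet(self, user_id: str, user_message: str, should_skip: bool) -> None:
        if should_skip:
            # Keyed by user and message, mirroring SKIP_STATE_CACHE in the patch.
            self._cache[(user_id, "skip", user_message)] = True

    async def outlet(self, user_id: str, user_message: str) -> bool:
        # True means the inlet already decided this message should be skipped.
        return self._cache.get((user_id, "skip", user_message), False)
```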
{Constants.DATABASE_OPERATION_TIMEOUT_SEC}s") - except Exception as e: - raise RuntimeError(f"👤 User lookup failed: {str(e)}") - - if not user: - raise ValueError(f"👤 User not found for consolidation: {user_id}") + user = await asyncio.wait_for( + asyncio.to_thread(Users.get_user_by_id, user_id), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, + ) created_count = updated_count = deleted_count = failed_count = 0 @@ -982,28 +970,21 @@ class LLMConsolidationService: memory_contents_for_deletion = {} if operations_by_type["DELETE"]: - try: - user_memories = await self.memory_system._get_user_memories(user_id) - memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} - except Exception as e: - logger.warning(f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}") + user_memories = await self.memory_system._get_user_memories(user_id) + memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} if operations_by_type["CREATE"] or operations_by_type["UPDATE"]: - try: - current_memories = await self.memory_system._get_user_memories(user_id) + current_memories = await self.memory_system._get_user_memories(user_id) - if operations_by_type["CREATE"]: - operations_by_type["CREATE"] = await self._deduplicate_operations( - operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE" - ) + if operations_by_type["CREATE"]: + operations_by_type["CREATE"] = await self._deduplicate_operations( + operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE" + ) - if operations_by_type["UPDATE"]: - operations_by_type["UPDATE"] = await self._deduplicate_operations( - operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"] - ) - - except Exception as e: - logger.warning(f"⚠️ Semantic deduplication check failed, proceeding with original operations: {str(e)}") + if operations_by_type["UPDATE"]: + operations_by_type["UPDATE"] = await self._deduplicate_operations( + operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"] + ) for operation_type, ops in operations_by_type.items(): if not ops: @@ -1014,38 +995,33 @@ class LLMConsolidationService: task = self.memory_system._execute_single_operation(operation, user) batch_tasks.append(task) - try: - results = await asyncio.gather(*batch_tasks, return_exceptions=True) - for idx, result in enumerate(results): - operation = ops[idx] + results = await asyncio.gather(*batch_tasks, return_exceptions=True) + for idx, result in enumerate(results): + operation = ops[idx] - if isinstance(result, Exception): - failed_count += 1 - await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False) - elif result == Models.MemoryOperationType.CREATE.value: - created_count += 1 - content_preview = self.memory_system._truncate_content(operation.content) - await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False) - elif result == Models.MemoryOperationType.UPDATE.value: - updated_count += 1 - content_preview = self.memory_system._truncate_content(operation.content) - await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False) - elif result == Models.MemoryOperationType.DELETE.value: - deleted_count += 1 - content_preview = memory_contents_for_deletion.get(operation.id, operation.id) - if content_preview and content_preview != operation.id: - content_preview = 
@@ -945,21 +941,13 @@ class LLMConsolidationService:
 
     async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
         """Execute consolidation operations with simplified tracking."""
-        if not operations or not user_id:
+        if not operations:
             return 0, 0, 0, 0
 
-        try:
-            user = await asyncio.wait_for(
-                asyncio.to_thread(Users.get_user_by_id, user_id),
-                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-            )
-        except asyncio.TimeoutError:
-            raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
-        except Exception as e:
-            raise RuntimeError(f"👤 User lookup failed: {str(e)}")
-
-        if not user:
-            raise ValueError(f"👤 User not found for consolidation: {user_id}")
+        user = await asyncio.wait_for(
+            asyncio.to_thread(Users.get_user_by_id, user_id),
+            timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+        )
 
         created_count = updated_count = deleted_count = failed_count = 0
 
@@ -982,28 +970,21 @@ class LLMConsolidationService:
 
         memory_contents_for_deletion = {}
         if operations_by_type["DELETE"]:
-            try:
-                user_memories = await self.memory_system._get_user_memories(user_id)
-                memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
-            except Exception as e:
-                logger.warning(f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}")
+            user_memories = await self.memory_system._get_user_memories(user_id)
+            memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
 
         if operations_by_type["CREATE"] or operations_by_type["UPDATE"]:
-            try:
-                current_memories = await self.memory_system._get_user_memories(user_id)
+            current_memories = await self.memory_system._get_user_memories(user_id)
 
-                if operations_by_type["CREATE"]:
-                    operations_by_type["CREATE"] = await self._deduplicate_operations(
-                        operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
-                    )
+            if operations_by_type["CREATE"]:
+                operations_by_type["CREATE"] = await self._deduplicate_operations(
+                    operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
+                )
 
-                if operations_by_type["UPDATE"]:
-                    operations_by_type["UPDATE"] = await self._deduplicate_operations(
-                        operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
-                    )
-
-            except Exception as e:
-                logger.warning(f"⚠️ Semantic deduplication check failed, proceeding with original operations: {str(e)}")
+            if operations_by_type["UPDATE"]:
+                operations_by_type["UPDATE"] = await self._deduplicate_operations(
+                    operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
+                )
 
         for operation_type, ops in operations_by_type.items():
             if not ops:
@@ -1014,38 +995,33 @@ class LLMConsolidationService:
                 task = self.memory_system._execute_single_operation(operation, user)
                 batch_tasks.append(task)
 
-            try:
-                results = await asyncio.gather(*batch_tasks, return_exceptions=True)
-                for idx, result in enumerate(results):
-                    operation = ops[idx]
+            results = await asyncio.gather(*batch_tasks, return_exceptions=True)
+            for idx, result in enumerate(results):
+                operation = ops[idx]
 
-                    if isinstance(result, Exception):
-                        failed_count += 1
-                        await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
-                    elif result == Models.MemoryOperationType.CREATE.value:
-                        created_count += 1
-                        content_preview = self.memory_system._truncate_content(operation.content)
-                        await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False)
-                    elif result == Models.MemoryOperationType.UPDATE.value:
-                        updated_count += 1
-                        content_preview = self.memory_system._truncate_content(operation.content)
-                        await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False)
-                    elif result == Models.MemoryOperationType.DELETE.value:
-                        deleted_count += 1
-                        content_preview = memory_contents_for_deletion.get(operation.id, operation.id)
-                        if content_preview and content_preview != operation.id:
-                            content_preview = self.memory_system._truncate_content(content_preview)
-                        await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
-                    elif result in [
-                        Models.OperationResult.FAILED.value,
-                        Models.OperationResult.UNSUPPORTED.value,
-                    ]:
-                        failed_count += 1
-                        await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
-            except Exception as e:
-                failed_count += len(ops)
-                logger.error(f"❌ Batch {operation_type} operations failed during memory consolidation: {str(e)}")
-                await self.memory_system._emit_status(emitter, f"❌ Batch {operation_type} Failed", done=False)
+                if isinstance(result, Exception):
+                    failed_count += 1
+                    await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
+                elif result == Models.MemoryOperationType.CREATE.value:
+                    created_count += 1
+                    content_preview = self.memory_system._truncate_content(operation.content)
+                    await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False)
+                elif result == Models.MemoryOperationType.UPDATE.value:
+                    updated_count += 1
+                    content_preview = self.memory_system._truncate_content(operation.content)
+                    await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False)
+                elif result == Models.MemoryOperationType.DELETE.value:
+                    deleted_count += 1
+                    content_preview = memory_contents_for_deletion.get(operation.id, operation.id)
+                    if content_preview and content_preview != operation.id:
+                        content_preview = self.memory_system._truncate_content(content_preview)
+                    await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
+                elif result in [
+                    Models.OperationResult.FAILED.value,
+                    Models.OperationResult.UNSUPPORTED.value,
+                ]:
+                    failed_count += 1
+                    await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
 
         total_executed = created_count + updated_count + deleted_count
         logger.info(
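With the batch try/except gone, per-operation failures are surfaced by
asyncio.gather(..., return_exceptions=True): each task's exception comes
back as a regular list element instead of propagating, so the loop can
count failures individually rather than failing the whole batch. A
self-contained model of the pattern (operations and execute are
simplified stand-ins):

```python
import asyncio

async def execute(op: dict) -> str:
    # Stand-in for _execute_single_operation.
    if op.get("fail"):
        raise RuntimeError("simulated database error")
    return "ok"

async def run_batch(operations: list) -> tuple:
    tasks = [execute(op) for op in operations]
    # gather preserves order, so results[i] pairs with operations[i].
    results = await asyncio.gather(*tasks, return_exceptions=True)
    succeeded = failed = 0
    for op, result in zip(operations, results):
        if isinstance(result, Exception):
            failed += 1  # the exception did not propagate
        else:
            succeeded += 1
    return succeeded, failed

print(asyncio.run(run_batch([{"fail": False}, {"fail": True}])))  # -> (1, 1)
```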
Ensure pipeline context is set.") - is_single = isinstance(texts, str) text_list = [texts] if is_single else texts - if not text_list: - if is_single: - raise ValueError("📏 Empty text provided for embedding generation") - return [] - result_embeddings = [] uncached_texts = [] uncached_indices = [] @@ -1360,9 +1328,6 @@ class Filter: return result_embeddings def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: - if self._skip_detector is None: - raise RuntimeError("🤖 Skip detector not initialized") - skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self) if skip_reason: status_key = SkipDetector.SkipReason(skip_reason) @@ -1371,13 +1336,6 @@ class Filter: def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: """Extract user message and determine if memory operations should be skipped.""" - if not body or "messages" not in body or not isinstance(body["messages"], list): - return ( - None, - True, - SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], - ) - messages = body["messages"] user_message = None @@ -1421,10 +1379,6 @@ class Filter: return scores = [memory["relevance"] for memory in memories] - - if not scores: - return - top_score = max(scores) lowest_score = min(scores) median_score = statistics.median(scores) @@ -1485,13 +1439,9 @@ class Filter: return payload = {"type": "status", "data": {"description": description, "done": done}} - - try: - result = emitter(payload) - if asyncio.iscoroutine(result): - await result - except Exception: - pass + result = emitter(payload) + if asyncio.iscoroutine(result): + await result async def _retrieve_relevant_memories( self, @@ -1549,10 +1499,6 @@ class Filter: emitter: Optional[Callable] = None, ) -> None: """Add memory context to request body with simplified logic.""" - if not body or "messages" not in body or not body["messages"]: - logger.warning("⚠️ Invalid request body or no messages found") - return - content_parts = [f"Current Date/Time: {self.format_current_datetime()}"] memory_count = 0 @@ -1610,26 +1556,16 @@ class Filter: memory_contents = [memory.content for memory in user_memories] memory_embeddings = await self._generate_embeddings(memory_contents, user_id) - if len(memory_embeddings) != len(user_memories): - logger.error(f"🔢 Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories") - return [], self.valves.semantic_retrieval_threshold, [] - - similarity_scores = [] memory_data = [] - for memory_index, memory in enumerate(user_memories): memory_embedding = memory_embeddings[memory_index] if memory_embedding is None: continue similarity = float(np.dot(query_embedding, memory_embedding)) - similarity_scores.append(similarity) memory_dict = self._build_memory_dict(memory, similarity) memory_data.append(memory_dict) - if not similarity_scores: - return [], self.valves.semantic_retrieval_threshold, [] - memory_data.sort(key=lambda x: x["relevance"], reverse=True) threshold = self.valves.semantic_retrieval_threshold @@ -1657,10 +1593,19 @@ class Filter: return body user_message, should_skip, skip_reason = self._process_user_message(body) + if not user_message or should_skip: if __event_emitter__ and skip_reason: await self._emit_status(__event_emitter__, skip_reason, done=True) await self._add_memory_context(body, [], user_id, __event_emitter__) + + skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, 
user_message or "") + await self._cache_manager.put( + user_id, + self._cache_manager.SKIP_STATE_CACHE, + skip_cache_key, + True, + ) return body try: memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id) @@ -1709,11 +1654,20 @@ class Filter: user_id = __user__.get("id") if body and __user__ else None if not user_id: return body - user_message, should_skip, skip_reason = self._process_user_message(body) - if not user_message or should_skip: + + user_message, _, _ = self._process_user_message(body) + if not user_message: return body - cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) - cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key) + + skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message) + should_skip = await self._cache_manager.get(user_id, self._cache_manager.SKIP_STATE_CACHE, skip_cache_key) + + if should_skip: + logger.info("⏭️ Skipping outlet consolidation: inlet already detected skip condition") + return body + + retrieval_cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) + cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, retrieval_cache_key) task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities)) self._background_tasks.add(task) @@ -1745,7 +1699,10 @@ class Filter: try: retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE) embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE) - logger.info(f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}") + skip_state_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.SKIP_STATE_CACHE) + logger.info( + f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding + {skip_state_cleared} skip state cache entries for user {user_id}" + ) user_memories = await self._get_user_memories(user_id) memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id) @@ -1774,59 +1731,49 @@ class Filter: async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str: """Execute a single memory operation.""" - try: - if operation.operation == Models.MemoryOperationType.CREATE: - content_stripped = operation.content.strip() - if not content_stripped: - logger.warning(f"⚠️ Skipping CREATE operation: empty content") - return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value + if operation.operation == Models.MemoryOperationType.CREATE: + content_stripped = operation.content.strip() + if not content_stripped: + return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value - await asyncio.wait_for( - asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), - timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, - ) - return Models.MemoryOperationType.CREATE.value + await asyncio.wait_for( + asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, + ) + return Models.MemoryOperationType.CREATE.value - elif operation.operation == Models.MemoryOperationType.UPDATE: - id_stripped = operation.id.strip() - if not id_stripped: - logger.warning(f"⚠️ Skipping UPDATE operation: empty ID") - return 
@@ -1745,7 +1699,10 @@ class Filter:
         try:
             retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE)
             embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE)
-            logger.info(f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}")
+            skip_state_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.SKIP_STATE_CACHE)
+            logger.info(
+                f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding + {skip_state_cleared} skip state cache entries for user {user_id}"
+            )
 
             user_memories = await self._get_user_memories(user_id)
             memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
@@ -1774,59 +1731,49 @@ class Filter:
 
     async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str:
         """Execute a single memory operation."""
-        try:
-            if operation.operation == Models.MemoryOperationType.CREATE:
-                content_stripped = operation.content.strip()
-                if not content_stripped:
-                    logger.warning(f"⚠️ Skipping CREATE operation: empty content")
-                    return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
+        if operation.operation == Models.MemoryOperationType.CREATE:
+            content_stripped = operation.content.strip()
+            if not content_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
 
-                await asyncio.wait_for(
-                    asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
-                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-                )
-                return Models.MemoryOperationType.CREATE.value
+            await asyncio.wait_for(
+                asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
+            return Models.MemoryOperationType.CREATE.value
 
-            elif operation.operation == Models.MemoryOperationType.UPDATE:
-                id_stripped = operation.id.strip()
-                if not id_stripped:
-                    logger.warning(f"⚠️ Skipping UPDATE operation: empty ID")
-                    return Models.OperationResult.SKIPPED_EMPTY_ID.value
+        elif operation.operation == Models.MemoryOperationType.UPDATE:
+            id_stripped = operation.id.strip()
+            if not id_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_ID.value
 
-                content_stripped = operation.content.strip()
-                if not content_stripped:
-                    logger.warning(f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content")
-                    return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
+            content_stripped = operation.content.strip()
+            if not content_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
 
-                await asyncio.wait_for(
-                    asyncio.to_thread(
-                        Memories.update_memory_by_id_and_user_id,
-                        id_stripped,
-                        user.id,
-                        content_stripped,
-                    ),
-                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-                )
-                return Models.MemoryOperationType.UPDATE.value
+            await asyncio.wait_for(
+                asyncio.to_thread(
+                    Memories.update_memory_by_id_and_user_id,
+                    id_stripped,
+                    user.id,
+                    content_stripped,
+                ),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
+            return Models.MemoryOperationType.UPDATE.value
 
-            elif operation.operation == Models.MemoryOperationType.DELETE:
-                id_stripped = operation.id.strip()
-                if not id_stripped:
-                    logger.warning(f"⚠️ Skipping DELETE operation: empty ID")
-                    return Models.OperationResult.SKIPPED_EMPTY_ID.value
+        elif operation.operation == Models.MemoryOperationType.DELETE:
+            id_stripped = operation.id.strip()
+            if not id_stripped:
+                return Models.OperationResult.SKIPPED_EMPTY_ID.value
 
-                await asyncio.wait_for(
-                    asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
-                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
-                )
-                return Models.MemoryOperationType.DELETE.value
-            else:
-                logger.error(f"❓ Unsupported operation: {operation}")
-                return Models.OperationResult.UNSUPPORTED.value
+            await asyncio.wait_for(
+                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
+            return Models.MemoryOperationType.DELETE.value
 
-        except Exception as e:
-            logger.error(f"💾 Database operation failed for {operation.operation.value}: {str(e)}")
-            return Models.OperationResult.FAILED.value
+        return Models.OperationResult.UNSUPPORTED.value
 
     def _remove_refs_from_schema(self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         """Remove $ref references and ensure required fields for Azure OpenAI."""
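_execute_single_operation now lets database errors and timeouts
propagate; they are caught per task by the gather(...,
return_exceptions=True) loop in execute_memory_operations and counted
as failures. The underlying call pattern, reduced to a runnable sketch
(blocking_insert is a stand-in for Memories.insert_new_memory, and the
timeout value is illustrative):

```python
import asyncio

def blocking_insert(user_id: str, content: str) -> None:
    # Stand-in for a synchronous database call such as
    # Memories.insert_new_memory.
    pass

async def insert_with_timeout(user_id: str, content: str, timeout_sec: float = 5.0) -> None:
    # Run the blocking call in a worker thread, bounded by a timeout.
    # On expiry, asyncio.TimeoutError propagates to the caller.
    await asyncio.wait_for(
        asyncio.to_thread(blocking_insert, user_id, content),
        timeout=timeout_sec,
    )

asyncio.run(insert_with_timeout("user-1", "prefers dark roast"))
```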