From 0c87a815fce33521638ca73904ba67fdb4ac29ce Mon Sep 17 00:00:00 2001 From: mtayfur Date: Wed, 26 Nov 2025 10:47:16 +0300 Subject: [PATCH] refactor(memory_system): streamline memory deduplication, logging, and message extraction logic Refactors deduplication to batch embedding generation for efficiency, consolidates user message extraction into a helper for reuse, replaces statistics.median with numpy for consistency, simplifies memory operation execution, and removes redundant logging and unused imports to improve maintainability and performance. --- memory_system.py | 151 +++++++++++++++++------------------------------ 1 file changed, 54 insertions(+), 97 deletions(-) diff --git a/memory_system.py b/memory_system.py index 01b9684..1de4caa 100644 --- a/memory_system.py +++ b/memory_system.py @@ -11,7 +11,6 @@ import asyncio import hashlib import json import logging -import statistics import time from collections import OrderedDict from datetime import datetime, timezone @@ -298,7 +297,7 @@ class UnifiedCacheManager: async with self._lock: if user_id not in self.caches: if len(self.caches) >= self.max_users: - evicted_user, _ = self.caches.popitem(last=False) + self.caches.popitem(last=False) self.caches[user_id] = {} user_cache = self.caches[user_id] @@ -309,7 +308,7 @@ class UnifiedCacheManager: type_cache = user_cache[cache_type] if key not in type_cache and len(type_cache) >= self.max_cache_size_per_type: - evicted_key, _ = type_cache.popitem(last=False) + type_cache.popitem(last=False) if key in type_cache: type_cache[key] = value @@ -776,18 +775,20 @@ class LLMConsolidationService: Check if content is semantically duplicate of existing memories using embeddings. Returns the ID of duplicate memory if found, None otherwise. """ - if not existing_memories: + valid_memories = [m for m in existing_memories if m.content and len(m.content.strip()) >= Constants.MIN_MESSAGE_CHARS] + if not valid_memories: return None - content_embedding = await self.memory_system._generate_embeddings(content, user_id) + memory_contents = [m.content for m in valid_memories] + all_texts = [content] + memory_contents + all_embeddings = await self.memory_system._generate_embeddings(all_texts, user_id) - for memory in existing_memories: - if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS: + content_embedding = all_embeddings[0] + for i, memory in enumerate(valid_memories): + memory_embedding = all_embeddings[i + 1] + if memory_embedding is None: continue - - memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id) similarity = float(np.dot(content_embedding, memory_embedding)) - if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD: logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}") return str(memory.id) @@ -999,23 +1000,16 @@ class LLMConsolidationService: error_message = f"Failed {operation_type} operation{content_preview}: {str(e)}" logger.error(error_message) - memory_contents_for_deletion = {} - if operations_by_type["DELETE"]: - user_memories = await self.memory_system._get_user_memories(user_id) - memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} + user_memories = await self.memory_system._get_user_memories(user_id) + memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} if operations_by_type["DELETE"] else {} - if operations_by_type["CREATE"] or operations_by_type["UPDATE"]: - current_memories = await self.memory_system._get_user_memories(user_id) + if operations_by_type["CREATE"]: + operations_by_type["CREATE"] = await self._deduplicate_operations(operations_by_type["CREATE"], user_memories, user_id, operation_type="CREATE") - if operations_by_type["CREATE"]: - operations_by_type["CREATE"] = await self._deduplicate_operations( - operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE" - ) - - if operations_by_type["UPDATE"]: - operations_by_type["UPDATE"] = await self._deduplicate_operations( - operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"] - ) + if operations_by_type["UPDATE"]: + operations_by_type["UPDATE"] = await self._deduplicate_operations( + operations_by_type["UPDATE"], user_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"] + ) for operation_type, ops in operations_by_type.items(): if not ops: @@ -1194,6 +1188,8 @@ class Filter: self.__user__ = __user__ if __model__: self.__model__ = __model__ + if self.valves.memory_model: + logger.info(f"🤖 Using custom memory model: {__model__}") if __request__: self.__request__ = __request__ @@ -1255,6 +1251,17 @@ class Filter: return content.get("text", "") return "" + def _get_last_user_message(self, messages: List[Dict[str, Any]]) -> Optional[str]: + """Extract the last user message text from a list of messages.""" + for message in reversed(messages): + if not isinstance(message, dict) or message.get("role") != "user": + continue + content = message.get("content", "") + text = self._extract_text_from_content(content) + if text: + return text + return None + def _validate_system_configuration(self) -> None: """Validate configuration and fail if invalid.""" if self.valves.max_memories_returned <= 0: @@ -1280,10 +1287,10 @@ class Filter: async def _detect_embedding_dimension(self) -> None: """Detect embedding dimension by generating a test embedding.""" test_embedding = await self._embedding_function("dummy", prefix=None, user=None) - + if isinstance(test_embedding, list) and len(test_embedding) > 0 and isinstance(test_embedding[0], (list, np.ndarray)): test_embedding = test_embedding[0] - + emb_array = np.squeeze(np.array(test_embedding)) self._embedding_dimension = emb_array.shape[0] if emb_array.ndim > 0 else 1 logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}") @@ -1368,39 +1375,20 @@ class Filter: async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: """Extract user message and determine if memory operations should be skipped.""" - messages = body["messages"] - user_message = None - - for message in reversed(messages): - if not isinstance(message, dict) or message.get("role") != "user": - continue - - content = message.get("content", "") - user_message = self._extract_text_from_content(content) - - if user_message: - break - - if not user_message or not user_message.strip(): - return ( - None, - True, - SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], - ) + user_message = self._get_last_user_message(body["messages"]) + if not user_message: + return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE] should_skip, skip_reason = await self._should_skip_memory_operations(user_message) return user_message, should_skip, skip_reason - async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List: + async def _get_user_memories(self, user_id: str) -> List: """Get user memories with timeout handling.""" - if timeout is None: - timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC - memories = await asyncio.wait_for( asyncio.to_thread(Memories.get_memories_by_user_id, user_id), - timeout=timeout, + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) - return [m for m in (memories or []) if m.content] + return [m for m in memories if m.content] if memories else [] def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None: """Log retrieved memories with concise formatting showing key statistics and semantic values.""" @@ -1410,7 +1398,7 @@ class Filter: scores = [memory["relevance"] for memory in memories] top_score = max(scores) lowest_score = min(scores) - median_score = statistics.median(scores) + median_score = float(np.median(scores)) context_label = "📊 Consolidation candidate memories" if context_type == "consolidation" else "📊 Retrieved memories" max_scores_to_show = int(self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) @@ -1436,11 +1424,7 @@ class Filter: return f"{cache_type}_{user_id}" def format_current_datetime(self) -> str: - try: - now = datetime.now(timezone.utc) - return now.strftime("%A %B %d %Y at %H:%M:%S UTC") - except Exception as e: - raise RuntimeError(f"📅 Failed to format datetime: {str(e)}") + return datetime.now(timezone.utc).strftime("%A %B %d %Y at %H:%M:%S UTC") def _format_memories_for_llm(self, memories: List[Dict[str, Any]]) -> List[str]: """Format memories for LLM consumption with hybrid format and human-readable timestamps.""" @@ -1622,10 +1606,6 @@ class Filter: """Simplified inlet processing for memory retrieval and injection.""" model_to_use = self.valves.memory_model or (body.get("model") if isinstance(body, dict) else None) - - if self.valves.memory_model: - logger.info(f"🤖 Using the custom model for memory : {model_to_use}") - await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__) user_id = __user__.get("id") if body and __user__ else None @@ -1685,26 +1665,13 @@ class Filter: """Simplified outlet processing for background memory consolidation.""" model_to_use = self.valves.memory_model or (body.get("model") if isinstance(body, dict) else None) - - if self.valves.memory_model: - logger.info(f"🤖 Using the custom model for memory : {model_to_use}") - await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__) user_id = __user__.get("id") if body and __user__ else None if not user_id: return body - messages = body.get("messages", []) - user_message = None - for message in reversed(messages): - if not isinstance(message, dict) or message.get("role") != "user": - continue - content = message.get("content", "") - user_message = self._extract_text_from_content(content) - if user_message: - break - + user_message = self._get_last_user_message(body.get("messages", [])) if not user_message: return body @@ -1780,44 +1747,34 @@ class Filter: async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str: """Execute a single memory operation.""" - if operation.operation == Models.MemoryOperationType.CREATE: - content_stripped = operation.content.strip() - if not content_stripped: - return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value + content = operation.content.strip() + memory_id = operation.id.strip() + if operation.operation == Models.MemoryOperationType.CREATE: + if not content: + return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value await asyncio.wait_for( - asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), + asyncio.to_thread(Memories.insert_new_memory, user.id, content), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.CREATE.value elif operation.operation == Models.MemoryOperationType.UPDATE: - id_stripped = operation.id.strip() - if not id_stripped: + if not memory_id: return Models.OperationResult.SKIPPED_EMPTY_ID.value - - content_stripped = operation.content.strip() - if not content_stripped: + if not content: return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value - await asyncio.wait_for( - asyncio.to_thread( - Memories.update_memory_by_id_and_user_id, - id_stripped, - user.id, - content_stripped, - ), + asyncio.to_thread(Memories.update_memory_by_id_and_user_id, memory_id, user.id, content), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.UPDATE.value elif operation.operation == Models.MemoryOperationType.DELETE: - id_stripped = operation.id.strip() - if not id_stripped: + if not memory_id: return Models.OperationResult.SKIPPED_EMPTY_ID.value - await asyncio.wait_for( - asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), + asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, memory_id, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.DELETE.value