refactor(memory_system): streamline memory deduplication, logging, and message extraction logic

Refactors deduplication to batch embedding generation for efficiency, consolidates user message extraction into a helper for reuse, replaces statistics.median with numpy for consistency, simplifies memory operation execution, and removes redundant logging and unused imports to improve maintainability and performance.
This commit is contained in:
mtayfur
2025-11-26 10:47:16 +03:00
parent cdedeee6ba
commit 0c87a815fc

View File

@@ -11,7 +11,6 @@ import asyncio
 import hashlib
 import json
 import logging
-import statistics
 import time
 from collections import OrderedDict
 from datetime import datetime, timezone
@@ -298,7 +297,7 @@ class UnifiedCacheManager:
         async with self._lock:
             if user_id not in self.caches:
                 if len(self.caches) >= self.max_users:
-                    evicted_user, _ = self.caches.popitem(last=False)
+                    self.caches.popitem(last=False)
                 self.caches[user_id] = {}
             user_cache = self.caches[user_id]
@@ -309,7 +308,7 @@ class UnifiedCacheManager:
             type_cache = user_cache[cache_type]
             if key not in type_cache and len(type_cache) >= self.max_cache_size_per_type:
-                evicted_key, _ = type_cache.popitem(last=False)
+                type_cache.popitem(last=False)
             if key in type_cache:
                 type_cache[key] = value
@@ -776,18 +775,20 @@ class LLMConsolidationService:
         Check if content is semantically duplicate of existing memories using embeddings.
         Returns the ID of duplicate memory if found, None otherwise.
         """
-        if not existing_memories:
+        valid_memories = [m for m in existing_memories if m.content and len(m.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
+        if not valid_memories:
             return None
-        content_embedding = await self.memory_system._generate_embeddings(content, user_id)
-        for memory in existing_memories:
-            if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS:
+        memory_contents = [m.content for m in valid_memories]
+        all_texts = [content] + memory_contents
+        all_embeddings = await self.memory_system._generate_embeddings(all_texts, user_id)
+        content_embedding = all_embeddings[0]
+        for i, memory in enumerate(valid_memories):
+            memory_embedding = all_embeddings[i + 1]
+            if memory_embedding is None:
                 continue
-            memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id)
             similarity = float(np.dot(content_embedding, memory_embedding))
             if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD:
                 logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
                 return str(memory.id)
@@ -999,23 +1000,16 @@ class LLMConsolidationService:
                 error_message = f"Failed {operation_type} operation{content_preview}: {str(e)}"
                 logger.error(error_message)

-        memory_contents_for_deletion = {}
-        if operations_by_type["DELETE"]:
-            user_memories = await self.memory_system._get_user_memories(user_id)
-            memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
+        user_memories = await self.memory_system._get_user_memories(user_id)
+        memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} if operations_by_type["DELETE"] else {}

-        if operations_by_type["CREATE"] or operations_by_type["UPDATE"]:
-            current_memories = await self.memory_system._get_user_memories(user_id)
-            if operations_by_type["CREATE"]:
-                operations_by_type["CREATE"] = await self._deduplicate_operations(
-                    operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
-                )
-            if operations_by_type["UPDATE"]:
-                operations_by_type["UPDATE"] = await self._deduplicate_operations(
-                    operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
-                )
+        if operations_by_type["CREATE"]:
+            operations_by_type["CREATE"] = await self._deduplicate_operations(operations_by_type["CREATE"], user_memories, user_id, operation_type="CREATE")
+        if operations_by_type["UPDATE"]:
+            operations_by_type["UPDATE"] = await self._deduplicate_operations(
+                operations_by_type["UPDATE"], user_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
+            )

         for operation_type, ops in operations_by_type.items():
             if not ops:
@@ -1194,6 +1188,8 @@ class Filter:
             self.__user__ = __user__
         if __model__:
             self.__model__ = __model__
+            if self.valves.memory_model:
+                logger.info(f"🤖 Using custom memory model: {__model__}")
         if __request__:
             self.__request__ = __request__
@@ -1255,6 +1251,17 @@ class Filter:
                 return content.get("text", "")
         return ""

+    def _get_last_user_message(self, messages: List[Dict[str, Any]]) -> Optional[str]:
+        """Extract the last user message text from a list of messages."""
+        for message in reversed(messages):
+            if not isinstance(message, dict) or message.get("role") != "user":
+                continue
+            content = message.get("content", "")
+            text = self._extract_text_from_content(content)
+            if text:
+                return text
+        return None
+
     def _validate_system_configuration(self) -> None:
         """Validate configuration and fail if invalid."""
         if self.valves.max_memories_returned <= 0:
@@ -1280,10 +1287,10 @@ class Filter:
     async def _detect_embedding_dimension(self) -> None:
         """Detect embedding dimension by generating a test embedding."""
         test_embedding = await self._embedding_function("dummy", prefix=None, user=None)
         if isinstance(test_embedding, list) and len(test_embedding) > 0 and isinstance(test_embedding[0], (list, np.ndarray)):
             test_embedding = test_embedding[0]
         emb_array = np.squeeze(np.array(test_embedding))
         self._embedding_dimension = emb_array.shape[0] if emb_array.ndim > 0 else 1
         logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}")
@@ -1368,39 +1375,20 @@ class Filter:
     async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
         """Extract user message and determine if memory operations should be skipped."""
-        messages = body["messages"]
-        user_message = None
-
-        for message in reversed(messages):
-            if not isinstance(message, dict) or message.get("role") != "user":
-                continue
-            content = message.get("content", "")
-            user_message = self._extract_text_from_content(content)
-            if user_message:
-                break
-
-        if not user_message or not user_message.strip():
-            return (
-                None,
-                True,
-                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
-            )
+        user_message = self._get_last_user_message(body["messages"])
+        if not user_message:
+            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]

         should_skip, skip_reason = await self._should_skip_memory_operations(user_message)
         return user_message, should_skip, skip_reason

-    async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List:
+    async def _get_user_memories(self, user_id: str) -> List:
         """Get user memories with timeout handling."""
-        if timeout is None:
-            timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
         memories = await asyncio.wait_for(
             asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
-            timeout=timeout,
+            timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
         )
-        return [m for m in (memories or []) if m.content]
+        return [m for m in memories if m.content] if memories else []

     def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None:
         """Log retrieved memories with concise formatting showing key statistics and semantic values."""
@@ -1410,7 +1398,7 @@ class Filter:
         scores = [memory["relevance"] for memory in memories]
         top_score = max(scores)
         lowest_score = min(scores)
-        median_score = statistics.median(scores)
+        median_score = float(np.median(scores))
         context_label = "📊 Consolidation candidate memories" if context_type == "consolidation" else "📊 Retrieved memories"
         max_scores_to_show = int(self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
@@ -1436,11 +1424,7 @@ class Filter:
         return f"{cache_type}_{user_id}"

     def format_current_datetime(self) -> str:
-        try:
-            now = datetime.now(timezone.utc)
-            return now.strftime("%A %B %d %Y at %H:%M:%S UTC")
-        except Exception as e:
-            raise RuntimeError(f"📅 Failed to format datetime: {str(e)}")
+        return datetime.now(timezone.utc).strftime("%A %B %d %Y at %H:%M:%S UTC")

     def _format_memories_for_llm(self, memories: List[Dict[str, Any]]) -> List[str]:
         """Format memories for LLM consumption with hybrid format and human-readable timestamps."""
@@ -1622,10 +1606,6 @@ class Filter:
         """Simplified inlet processing for memory retrieval and injection."""
         model_to_use = self.valves.memory_model or (body.get("model") if isinstance(body, dict) else None)
-        if self.valves.memory_model:
-            logger.info(f"🤖 Using the custom model for memory : {model_to_use}")
-
         await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
         user_id = __user__.get("id") if body and __user__ else None
@@ -1685,26 +1665,13 @@ class Filter:
         """Simplified outlet processing for background memory consolidation."""
         model_to_use = self.valves.memory_model or (body.get("model") if isinstance(body, dict) else None)
-        if self.valves.memory_model:
-            logger.info(f"🤖 Using the custom model for memory : {model_to_use}")
-
         await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body

-        messages = body.get("messages", [])
-        user_message = None
-        for message in reversed(messages):
-            if not isinstance(message, dict) or message.get("role") != "user":
-                continue
-            content = message.get("content", "")
-            user_message = self._extract_text_from_content(content)
-            if user_message:
-                break
+        user_message = self._get_last_user_message(body.get("messages", []))
         if not user_message:
             return body
@@ -1780,44 +1747,34 @@ class Filter:
     async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str:
         """Execute a single memory operation."""
-        if operation.operation == Models.MemoryOperationType.CREATE:
-            content_stripped = operation.content.strip()
-            if not content_stripped:
-                return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
+        content = operation.content.strip()
+        memory_id = operation.id.strip()

+        if operation.operation == Models.MemoryOperationType.CREATE:
+            if not content:
+                return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
             await asyncio.wait_for(
-                asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
+                asyncio.to_thread(Memories.insert_new_memory, user.id, content),
                 timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.CREATE.value

         elif operation.operation == Models.MemoryOperationType.UPDATE:
-            id_stripped = operation.id.strip()
-            if not id_stripped:
+            if not memory_id:
                 return Models.OperationResult.SKIPPED_EMPTY_ID.value
-
-            content_stripped = operation.content.strip()
-            if not content_stripped:
+            if not content:
                 return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
             await asyncio.wait_for(
-                asyncio.to_thread(
-                    Memories.update_memory_by_id_and_user_id,
-                    id_stripped,
-                    user.id,
-                    content_stripped,
-                ),
+                asyncio.to_thread(Memories.update_memory_by_id_and_user_id, memory_id, user.id, content),
                 timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.UPDATE.value

         elif operation.operation == Models.MemoryOperationType.DELETE:
-            id_stripped = operation.id.strip()
-            if not id_stripped:
+            if not memory_id:
                 return Models.OperationResult.SKIPPED_EMPTY_ID.value
             await asyncio.wait_for(
-                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
+                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, memory_id, user.id),
                 timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.DELETE.value