refactor(memory_system): streamline memory deduplication, logging, and message extraction logic

Refactors deduplication to batch embedding generation for efficiency, consolidates user message extraction into a helper for reuse, replaces statistics.median with numpy for consistency, simplifies memory operation execution, and removes redundant logging and unused imports to improve maintainability and performance.
This commit is contained in:
mtayfur
2025-11-26 10:47:16 +03:00
parent cdedeee6ba
commit 0c87a815fc

View File

@@ -11,7 +11,6 @@ import asyncio
import hashlib
import json
import logging
import statistics
import time
from collections import OrderedDict
from datetime import datetime, timezone
@@ -298,7 +297,7 @@ class UnifiedCacheManager:
async with self._lock:
if user_id not in self.caches:
if len(self.caches) >= self.max_users:
evicted_user, _ = self.caches.popitem(last=False)
self.caches.popitem(last=False)
self.caches[user_id] = {}
user_cache = self.caches[user_id]
@@ -309,7 +308,7 @@ class UnifiedCacheManager:
type_cache = user_cache[cache_type]
if key not in type_cache and len(type_cache) >= self.max_cache_size_per_type:
evicted_key, _ = type_cache.popitem(last=False)
type_cache.popitem(last=False)
if key in type_cache:
type_cache[key] = value
@@ -776,18 +775,20 @@ class LLMConsolidationService:
Check if content is semantically duplicate of existing memories using embeddings.
Returns the ID of duplicate memory if found, None otherwise.
"""
if not existing_memories:
valid_memories = [m for m in existing_memories if m.content and len(m.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
if not valid_memories:
return None
content_embedding = await self.memory_system._generate_embeddings(content, user_id)
memory_contents = [m.content for m in valid_memories]
all_texts = [content] + memory_contents
all_embeddings = await self.memory_system._generate_embeddings(all_texts, user_id)
for memory in existing_memories:
if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS:
content_embedding = all_embeddings[0]
for i, memory in enumerate(valid_memories):
memory_embedding = all_embeddings[i + 1]
if memory_embedding is None:
continue
memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id)
similarity = float(np.dot(content_embedding, memory_embedding))
if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD:
logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
return str(memory.id)
@@ -999,23 +1000,16 @@ class LLMConsolidationService:
error_message = f"Failed {operation_type} operation{content_preview}: {str(e)}"
logger.error(error_message)
memory_contents_for_deletion = {}
if operations_by_type["DELETE"]:
user_memories = await self.memory_system._get_user_memories(user_id)
memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
user_memories = await self.memory_system._get_user_memories(user_id)
memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} if operations_by_type["DELETE"] else {}
if operations_by_type["CREATE"] or operations_by_type["UPDATE"]:
current_memories = await self.memory_system._get_user_memories(user_id)
if operations_by_type["CREATE"]:
operations_by_type["CREATE"] = await self._deduplicate_operations(operations_by_type["CREATE"], user_memories, user_id, operation_type="CREATE")
if operations_by_type["CREATE"]:
operations_by_type["CREATE"] = await self._deduplicate_operations(
operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
)
if operations_by_type["UPDATE"]:
operations_by_type["UPDATE"] = await self._deduplicate_operations(
operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
)
if operations_by_type["UPDATE"]:
operations_by_type["UPDATE"] = await self._deduplicate_operations(
operations_by_type["UPDATE"], user_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
)
for operation_type, ops in operations_by_type.items():
if not ops:
@@ -1194,6 +1188,8 @@ class Filter:
self.__user__ = __user__
if __model__:
self.__model__ = __model__
if self.valves.memory_model:
logger.info(f"🤖 Using custom memory model: {__model__}")
if __request__:
self.__request__ = __request__
@@ -1255,6 +1251,17 @@ class Filter:
return content.get("text", "")
return ""
def _get_last_user_message(self, messages: List[Dict[str, Any]]) -> Optional[str]:
"""Extract the last user message text from a list of messages."""
for message in reversed(messages):
if not isinstance(message, dict) or message.get("role") != "user":
continue
content = message.get("content", "")
text = self._extract_text_from_content(content)
if text:
return text
return None
def _validate_system_configuration(self) -> None:
"""Validate configuration and fail if invalid."""
if self.valves.max_memories_returned <= 0:
@@ -1280,10 +1287,10 @@ class Filter:
async def _detect_embedding_dimension(self) -> None:
"""Detect embedding dimension by generating a test embedding."""
test_embedding = await self._embedding_function("dummy", prefix=None, user=None)
if isinstance(test_embedding, list) and len(test_embedding) > 0 and isinstance(test_embedding[0], (list, np.ndarray)):
test_embedding = test_embedding[0]
emb_array = np.squeeze(np.array(test_embedding))
self._embedding_dimension = emb_array.shape[0] if emb_array.ndim > 0 else 1
logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}")
@@ -1368,39 +1375,20 @@ class Filter:
async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
"""Extract user message and determine if memory operations should be skipped."""
messages = body["messages"]
user_message = None
for message in reversed(messages):
if not isinstance(message, dict) or message.get("role") != "user":
continue
content = message.get("content", "")
user_message = self._extract_text_from_content(content)
if user_message:
break
if not user_message or not user_message.strip():
return (
None,
True,
SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
)
user_message = self._get_last_user_message(body["messages"])
if not user_message:
return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
should_skip, skip_reason = await self._should_skip_memory_operations(user_message)
return user_message, should_skip, skip_reason
async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List:
async def _get_user_memories(self, user_id: str) -> List:
"""Get user memories with timeout handling."""
if timeout is None:
timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
memories = await asyncio.wait_for(
asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
timeout=timeout,
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
return [m for m in (memories or []) if m.content]
return [m for m in memories if m.content] if memories else []
def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None:
"""Log retrieved memories with concise formatting showing key statistics and semantic values."""
@@ -1410,7 +1398,7 @@ class Filter:
scores = [memory["relevance"] for memory in memories]
top_score = max(scores)
lowest_score = min(scores)
median_score = statistics.median(scores)
median_score = float(np.median(scores))
context_label = "📊 Consolidation candidate memories" if context_type == "consolidation" else "📊 Retrieved memories"
max_scores_to_show = int(self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
@@ -1436,11 +1424,7 @@ class Filter:
return f"{cache_type}_{user_id}"
def format_current_datetime(self) -> str:
try:
now = datetime.now(timezone.utc)
return now.strftime("%A %B %d %Y at %H:%M:%S UTC")
except Exception as e:
raise RuntimeError(f"📅 Failed to format datetime: {str(e)}")
return datetime.now(timezone.utc).strftime("%A %B %d %Y at %H:%M:%S UTC")
def _format_memories_for_llm(self, memories: List[Dict[str, Any]]) -> List[str]:
"""Format memories for LLM consumption with hybrid format and human-readable timestamps."""
@@ -1622,10 +1606,6 @@ class Filter:
"""Simplified inlet processing for memory retrieval and injection."""
model_to_use = self.valves.memory_model or (body.get("model") if isinstance(body, dict) else None)
if self.valves.memory_model:
logger.info(f"🤖 Using the custom model for memory : {model_to_use}")
await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
user_id = __user__.get("id") if body and __user__ else None
@@ -1685,26 +1665,13 @@ class Filter:
"""Simplified outlet processing for background memory consolidation."""
model_to_use = self.valves.memory_model or (body.get("model") if isinstance(body, dict) else None)
if self.valves.memory_model:
logger.info(f"🤖 Using the custom model for memory : {model_to_use}")
await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
user_id = __user__.get("id") if body and __user__ else None
if not user_id:
return body
messages = body.get("messages", [])
user_message = None
for message in reversed(messages):
if not isinstance(message, dict) or message.get("role") != "user":
continue
content = message.get("content", "")
user_message = self._extract_text_from_content(content)
if user_message:
break
user_message = self._get_last_user_message(body.get("messages", []))
if not user_message:
return body
@@ -1780,44 +1747,34 @@ class Filter:
async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str:
"""Execute a single memory operation."""
if operation.operation == Models.MemoryOperationType.CREATE:
content_stripped = operation.content.strip()
if not content_stripped:
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
content = operation.content.strip()
memory_id = operation.id.strip()
if operation.operation == Models.MemoryOperationType.CREATE:
if not content:
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
await asyncio.wait_for(
asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
asyncio.to_thread(Memories.insert_new_memory, user.id, content),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
return Models.MemoryOperationType.CREATE.value
elif operation.operation == Models.MemoryOperationType.UPDATE:
id_stripped = operation.id.strip()
if not id_stripped:
if not memory_id:
return Models.OperationResult.SKIPPED_EMPTY_ID.value
content_stripped = operation.content.strip()
if not content_stripped:
if not content:
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
await asyncio.wait_for(
asyncio.to_thread(
Memories.update_memory_by_id_and_user_id,
id_stripped,
user.id,
content_stripped,
),
asyncio.to_thread(Memories.update_memory_by_id_and_user_id, memory_id, user.id, content),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
return Models.MemoryOperationType.UPDATE.value
elif operation.operation == Models.MemoryOperationType.DELETE:
id_stripped = operation.id.strip()
if not id_stripped:
if not memory_id:
return Models.OperationResult.SKIPPED_EMPTY_ID.value
await asyncio.wait_for(
asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, memory_id, user.id),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
return Models.MemoryOperationType.DELETE.value