feat(memory_system): enhance memory consolidation with multi-message context and semantic deduplication optimizations

Add support for using recent conversation context to resolve pronouns and ambiguous references during memory consolidation. Improve semantic deduplication by precomputing embeddings, and cache user memories for efficiency. Together, these changes increase the accuracy and performance of extracting and managing user memories, especially in multi-turn conversations.
This commit is contained in:
mtayfur
2025-11-29 00:56:17 +03:00
parent 5bf5f775f4
commit e97137bb4c

View File

@@ -40,6 +40,7 @@ class Constants:
MAX_MEMORIES_PER_RETRIEVAL = 10 # Maximum memories returned per query
MAX_MESSAGE_CHARS = 3000 # Maximum message length for validation
MIN_MESSAGE_CHARS = 10 # Minimum message length for validation
MAX_CONSOLIDATION_CONTEXT_MESSAGES = 3 # Maximum number of recent user messages to include for pronoun/context resolution
DATABASE_OPERATION_TIMEOUT_SEC = 10 # Timeout for DB operations like user lookup
LLM_CONSOLIDATION_TIMEOUT_SEC = 60.0 # Timeout for LLM consolidation operations
@@ -97,7 +98,7 @@ Your goal is to build precise memories of the user's personal narrative with fac
- Ensure Memory Quality:
- High Bar for Creation: Only CREATE memories for significant life facts, relationships, events, or core personal attributes. Skip trivial details or passing interests.
- Conciseness: Limit each memory to {Constants.MAX_MEMORY_SENTENCES} sentences and {Constants.MAX_MEMORY_CONTENT_CHARS} characters. If a new or existing memory exceeds these limits, decompose it into smaller, self-contained facts.
- Entity Resolution: Link pronouns (he, she, they) and generic nouns to their specific named entities before storing.
- Entity Resolution: Link pronouns (he, she, they) and generic nouns to their specific named entities before storing. When a CONVERSATION is provided, use it to resolve ambiguous references (e.g., "the project", "that place") and provide missing context.
- Capture Relationships: Store relationships with complete context. Never store incomplete relationships—always specify with whom.
- Retroactive Enrichment: UPDATE existing memories only when adding a name or correcting a significant fact.
- First-Person Format: Write all memories in English from the user's perspective.
@@ -139,6 +140,15 @@ Message: "Fix this Python function that calculates my age from my birthdate of M
Memories: []
Return: {{"ops": []}}
Explanation: The primary intent is technical (fix code). Personal data (birthdate) is provided as input for the task, not as a direct personal statement. The user is not stating "My birthday is March 15, 1990" but using it as a parameter. SKIP.
### Example 6 (Multi-Message Context)
CONVERSATION:
[1] "I had lunch with my colleague Maria yesterday at the new cafe downtown."
[2] "She's been working on the same AI project as me for 3 months now."
[3] "She mentioned she might be leaving the company next month."
Memories: []
Return: {{"ops": [{{"operation": "CREATE", "id": "", "content": "My colleague Maria and I have been working together on an AI project for 3 months as of September 2025"}}, {{"operation": "CREATE", "id": "", "content": "My colleague Maria mentioned in September 2025 that she might be leaving the company in October 2025"}}]}}
Explanation: The pronoun "She" in messages 2 and 3 refers to "Maria" from message 1. The conversation context enables proper entity resolution. Lunch location is transient and skipped. Work relationship and departure news are significant facts.
"""
MEMORY_RERANKING = f"""You are the Memory Relevance Analyzer.
@@ -610,13 +620,7 @@ class SkipDetector:
return None
async def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
"""
Detect if a message should be skipped using two-stage detection:
1. Fast-path structural patterns (~95% confidence)
2. Binary semantic classification (personal vs non-personal)
Returns:
Skip reason string if content should be skipped, None otherwise
"""
"""Detect if a message should be skipped using two-stage detection: fast-path structural patterns and binary semantic classification."""
size_issue = self.validate_message_size(message, max_message_chars)
if size_issue:
return size_issue
@@ -767,24 +771,21 @@ class LLMConsolidationService:
def __init__(self, memory_system):
self.memory_system = memory_system
async def _check_semantic_duplicate(self, content: str, existing_memories: List, user_id: str) -> Optional[str]:
"""
Check if content is semantically duplicate of existing memories using embeddings.
Returns the ID of duplicate memory if found, None otherwise.
"""
valid_memories = [m for m in existing_memories if m.content and len(m.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
if not valid_memories:
return None
memory_contents = [m.content for m in valid_memories]
all_texts = [content] + memory_contents
all_embeddings = await self.memory_system._generate_embeddings(all_texts, user_id)
content_embedding = all_embeddings[0]
for i, memory in enumerate(valid_memories):
memory_embedding = all_embeddings[i + 1]
async def _check_semantic_duplicate(
self,
content_embedding: np.ndarray,
memory_embeddings: List[np.ndarray],
memories: List,
exclude_id: Optional[str] = None,
) -> Optional[str]:
"""Check if content embedding is semantically duplicate of existing memories."""
for i, memory_embedding in enumerate(memory_embeddings):
if memory_embedding is None:
continue
memory = memories[i]
if exclude_id and str(memory.id) == exclude_id:
continue
similarity = float(np.dot(content_embedding, memory_embedding))
if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD:
logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
@@ -855,6 +856,7 @@ class LLMConsolidationService:
user_message: str,
candidate_memories: List[Dict[str, Any]],
emitter: Optional[Callable] = None,
conversation_context: Optional[List[str]] = None,
) -> List[Dict[str, Any]]:
"""Generate consolidation plan using LLM with clear system/user prompt separation."""
if candidate_memories:
@@ -863,9 +865,15 @@ class LLMConsolidationService:
else:
memory_context = "EXISTING MEMORIES FOR CONSOLIDATION:\n[]\n\nNote: No existing memories found - Focus on extracting new memories from the user message below.\n\n"
if conversation_context and len(conversation_context) > 1:
context_lines = [f'[{i+1}] "{msg}"' for i, msg in enumerate(conversation_context)]
message_section = f"CONVERSATION (recent messages for pronoun/context resolution):\n{chr(10).join(context_lines)}"
else:
message_section = f"USER MESSAGE: {user_message}"
user_prompt = f"""CURRENT DATE/TIME: {self.memory_system.format_current_datetime()}
{memory_context}USER MESSAGE: {user_message}"""
{memory_context}{message_section}"""
try:
response = await asyncio.wait_for(
@@ -936,18 +944,29 @@ class LLMConsolidationService:
async def _deduplicate_operations(
self, operations: List, current_memories: List, user_id: str, operation_type: str, delete_operations: Optional[List] = None
) -> List:
"""
Deduplicate operations against existing memories using semantic similarity.
For UPDATE operations, preserves enriched content and deletes the duplicate.
"""
"""Semantically deduplicate operations against existing memories. For UPDATEs, preserve enriched content and delete duplicates."""
if not operations:
return []
deduplicated = []
for operation in operations:
memories_to_check = current_memories
if operation_type == "UPDATE":
memories_to_check = [m for m in current_memories if str(m.id) != operation.id]
valid_memories = [m for m in current_memories if m.content and len(m.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
memory_embeddings = []
if valid_memories:
memory_contents = [m.content for m in valid_memories]
memory_embeddings = await self.memory_system._generate_embeddings(memory_contents, user_id)
duplicate_id = await self._check_semantic_duplicate(operation.content, memories_to_check, user_id)
op_contents = [op.content for op in operations]
op_embeddings = await self.memory_system._generate_embeddings(op_contents, user_id)
for i, operation in enumerate(operations):
op_embedding = op_embeddings[i]
if op_embedding is None or not valid_memories:
deduplicated.append(operation)
continue
exclude_id = operation.id if operation_type == "UPDATE" else None
duplicate_id = await self._check_semantic_duplicate(op_embedding, memory_embeddings, valid_memories, exclude_id)
if duplicate_id:
if operation_type == "UPDATE" and delete_operations is not None:
@@ -991,7 +1010,18 @@ class LLMConsolidationService:
error_message = f"Failed {operation_type} operation{content_preview}: {str(e)}"
logger.error(error_message)
user_memories = await self.memory_system._get_user_memories(user_id)
memory_cache_key = self.memory_system._cache_key(self.memory_system._cache_manager.MEMORY_CACHE, user_id)
user_memories = await self.memory_system._cache_manager.get(user_id, self.memory_system._cache_manager.MEMORY_CACHE, memory_cache_key)
if user_memories is None:
user_memories = await self.memory_system._get_user_memories(user_id)
await self.memory_system._cache_manager.put(
user_id,
self.memory_system._cache_manager.MEMORY_CACHE,
memory_cache_key,
user_memories,
)
memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} if operations_by_type["DELETE"] else {}
if operations_by_type["CREATE"]:
@@ -1057,6 +1087,7 @@ class LLMConsolidationService:
user_id: str,
emitter: Optional[Callable] = None,
cached_similarities: Optional[List[Dict[str, Any]]] = None,
conversation_context: Optional[List[str]] = None,
) -> None:
"""Complete consolidation pipeline with simplified flow."""
start_time = time.time()
@@ -1068,7 +1099,7 @@ class LLMConsolidationService:
if self.memory_system._shutdown_event.is_set():
return
operations = await self.generate_consolidation_plan(user_message, candidates, emitter)
operations = await self.generate_consolidation_plan(user_message, candidates, emitter, conversation_context)
if self.memory_system._shutdown_event.is_set():
return
@@ -1144,6 +1175,10 @@ class Filter:
default="Intermediate",
description="Status message verbosity level: Basic (summary counts only), Intermediate (summaries and key details), Detailed (all details)",
)
max_consolidation_context_messages: int = Field(
default=Constants.MAX_CONSOLIDATION_CONTEXT_MESSAGES,
description="Maximum number of recent user messages to include in consolidation context",
)
def __init__(self):
"""Initialize the Memory System filter with production validation."""
@@ -1242,16 +1277,25 @@ class Filter:
return content.get("text", "")
return ""
def _get_last_user_message(self, messages: List[Dict[str, Any]]) -> Optional[str]:
"""Extract the last user message text from a list of messages."""
def _get_recent_user_messages(self, messages: List[Dict[str, Any]], max_messages: int = 3) -> List[str]:
"""Extract recent user messages for conversation context in consolidation."""
user_messages = []
for message in reversed(messages):
if not isinstance(message, dict) or message.get("role") != "user":
continue
content = message.get("content", "")
text = self._extract_text_from_content(content)
if text:
return text
return None
user_messages.append(text)
if len(user_messages) >= max_messages:
break
user_messages.reverse()
return user_messages
def _get_last_user_message(self, messages: List[Dict[str, Any]]) -> Optional[str]:
"""Extract the last user message text from a list of messages."""
recent = self._get_recent_user_messages(messages, max_messages=1)
return recent[0] if recent else None
def _validate_system_configuration(self) -> None:
"""Validate configuration and fail if invalid."""
@@ -1261,6 +1305,9 @@ class Filter:
if not (0.0 <= self.valves.semantic_retrieval_threshold <= 1.0):
raise ValueError(f"🎯 Invalid semantic retrieval threshold: {self.valves.semantic_retrieval_threshold} (must be 0.0-1.0)")
if self.valves.max_consolidation_context_messages <= 0:
raise ValueError(f"📊 Invalid max consolidation context messages: {self.valves.max_consolidation_context_messages}")
logger.info("✅ Configuration validated")
async def _get_embedding_cache(self, user_id: str, key: str) -> Optional[Any]:
@@ -1677,7 +1724,12 @@ class Filter:
retrieval_cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, retrieval_cache_key)
task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
conversation_context = self._get_recent_user_messages(body.get("messages", []), self.valves.max_consolidation_context_messages)
task = asyncio.create_task(
self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities, conversation_context)
)
self._background_tasks.add(task)
def safe_cleanup(t: asyncio.Task) -> None: