refactor(memory_system): remove skip state cache and threshold from retrieval, unify skip logic for inlet and outlet

Skip state caching is eliminated to simplify state management, and skip logic is now handled directly in both inlet and outlet, improving clarity and removing redundant cache usage. Additionally, the retrieval threshold is no longer returned in API responses, and skip reasons are now separated for retrieval and consolidation, providing more precise status messaging.
This commit is contained in:
mtayfur
2025-11-29 12:31:06 +03:00
parent e97137bb4c
commit 879ea8a28d

View File

@@ -268,7 +268,6 @@ class UnifiedCacheManager:
self.EMBEDDING_CACHE = "embedding"
self.RETRIEVAL_CACHE = "retrieval"
self.MEMORY_CACHE = "memory"
self.SKIP_STATE_CACHE = "skip"
async def get(self, user_id: str, cache_type: str, key: str) -> Optional[Any]:
"""Get value from cache with LRU updates."""
@@ -462,10 +461,17 @@ class SkipDetector:
class SkipReason(Enum):
SKIP_SIZE = "SKIP_SIZE"
SKIP_NON_PERSONAL = "SKIP_NON_PERSONAL"
SKIP_ALL_NON_PERSONAL = "SKIP_ALL_NON_PERSONAL"
STATUS_MESSAGES = {
SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, skipping memory operations",
SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations",
# Inlet (retrieval) status messages
INLET_STATUS_MESSAGES = {
SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, Skipping Memory Retrieval",
SkipReason.SKIP_NON_PERSONAL: "🚫 No Personal Content, Skipping Memory Retrieval",
}
# Outlet (consolidation) status messages
OUTLET_STATUS_MESSAGES = {
SkipReason.SKIP_ALL_NON_PERSONAL: "🚫 No Personal Content in Context, Skipping Memory Consolidation",
}
def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Any]):
@@ -735,7 +741,7 @@ CANDIDATE MEMORIES:
llm_candidates = candidate_memories[:extended_count]
await self.memory_system._emit_status(
emitter,
f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance",
f"🤖 Analyzing {len(llm_candidates)} Memories for Relevance",
done=False,
level=Constants.STATUS_LEVEL["Intermediate"],
)
@@ -745,9 +751,7 @@ CANDIDATE MEMORIES:
if not selected_memories:
logger.info("📭 No relevant memories after LLM analysis")
await self.memory_system._emit_status(
emitter, f"📭 No Relevant Memories After LLM Analysis", done=True, level=Constants.STATUS_LEVEL["Intermediate"]
)
await self.memory_system._emit_status(emitter, "📭 No Relevant Memories Found", done=True, level=Constants.STATUS_LEVEL["Intermediate"])
return selected_memories, analysis_info
else:
logger.info(f"⏩ Skipping LLM reranking: {decision_reason}")
@@ -886,7 +890,7 @@ class LLMConsolidationService:
)
except Exception as e:
logger.warning(f"🤖 LLM consolidation failed during memory processing: {str(e)}")
await self.memory_system._emit_status(emitter, f"⚠️ Memory Consolidation Failed", done=True, level=Constants.STATUS_LEVEL["Basic"])
await self.memory_system._emit_status(emitter, "⚠️ Memory Consolidation Failed", done=True, level=Constants.STATUS_LEVEL["Basic"])
return []
operations = response.ops
@@ -1130,7 +1134,7 @@ class LLMConsolidationService:
duration = time.time() - start_time
await self.memory_system._emit_status(
emitter,
f"Consolidation Complete: No Updates Needed",
"No Memory Updates Needed",
done=True,
level=Constants.STATUS_LEVEL["Detailed"],
)
@@ -1408,14 +1412,31 @@ class Filter:
skip_reason = await self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self)
if skip_reason:
status_key = SkipDetector.SkipReason(skip_reason)
return True, SkipDetector.STATUS_MESSAGES[status_key]
return True, SkipDetector.INLET_STATUS_MESSAGES[status_key]
return False, ""
async def _should_skip_consolidation(self, conversation_context: List[str]) -> Tuple[bool, str]:
    """Decide whether memory consolidation should be skipped for this context.

    Consolidation is skipped only when EVERY message in the context is
    skippable (non-personal / out of limits); a single valuable message is
    enough to proceed.

    Returns:
        Tuple of (should_skip, reason). ``reason`` is the outlet status
        message when skipping, otherwise an empty string.
    """
    logger.info(f"🔍 Evaluating {len(conversation_context)} messages for consolidation")
    for idx, message in enumerate(conversation_context, 1):
        skip_reason = await self._skip_detector.detect_skip_reason(message, Constants.MAX_MESSAGE_CHARS, memory_system=self)
        if skip_reason:
            # This message is skippable; keep scanning the rest of the context.
            continue
        # Found at least one valuable (personal) message — proceed.
        logger.info(f"✅ Found personal content in message {idx}/{len(conversation_context)}, proceeding with consolidation")
        return False, ""
    # Every message in the context was skippable.
    logger.info(f"🚫 All {len(conversation_context)} messages are non-personal, skipping consolidation")
    return True, SkipDetector.OUTLET_STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_ALL_NON_PERSONAL]
async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
    """Extract the last user message and decide whether memory operations should be skipped.

    Fix: the early-return branch contained two consecutive ``return`` statements
    (merge/diff residue) — the second, referencing the removed ``STATUS_MESSAGES``
    mapping, was unreachable. Only the ``INLET_STATUS_MESSAGES`` return is kept,
    consistent with the inlet skip path elsewhere in this class.

    Returns:
        Tuple of (user_message, should_skip, skip_reason). ``user_message`` is
        None when no user message could be extracted from ``body["messages"]``;
        in that case the size-skip inlet status message is used as the reason.
    """
    user_message = self._get_last_user_message(body["messages"])
    if not user_message:
        # No extractable user message — surface the size-based inlet skip status.
        return None, True, SkipDetector.INLET_STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
    should_skip, skip_reason = await self._should_skip_memory_operations(user_message)
    return user_message, should_skip, skip_reason
@@ -1523,7 +1544,6 @@ class Filter:
self._log_retrieved_memories(final_memories, "semantic")
return {
"memories": final_memories,
"threshold": self.valves.semantic_retrieval_threshold,
"all_similarities": cached_similarities,
"reranking_info": reranking_info,
}
@@ -1534,9 +1554,9 @@ class Filter:
if not user_memories:
logger.info("📭 No memories found for user")
await self._emit_status(emitter, "📭 No Memories Found", done=True, level=Constants.STATUS_LEVEL["Intermediate"])
return {"memories": [], "threshold": None}
return {"memories": []}
memories, threshold, all_similarities = await self._compute_similarities(user_message, user_id, user_memories)
memories, all_similarities = await self._compute_similarities(user_message, user_id, user_memories)
if memories:
final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter)
@@ -1550,7 +1570,6 @@ class Filter:
return {
"memories": final_memories,
"threshold": threshold,
"all_similarities": all_similarities,
"reranking_info": reranking_info,
}
@@ -1611,10 +1630,10 @@ class Filter:
memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat()
return memory_dict
async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], float, List[Dict]]:
async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], List[Dict]]:
"""Compute similarity scores between user message and memories."""
if not user_memories:
return [], self.valves.semantic_retrieval_threshold, []
return [], []
query_embedding = await self._generate_embeddings(user_message, user_id)
memory_contents = [memory.content for memory in user_memories]
@@ -1632,9 +1651,8 @@ class Filter:
memory_data.sort(key=lambda x: x["relevance"], reverse=True)
threshold = self.valves.semantic_retrieval_threshold
filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
return filtered_memories, threshold, memory_data
filtered_memories = [m for m in memory_data if m["relevance"] >= self.valves.semantic_retrieval_threshold]
return filtered_memories, memory_data
async def inlet(
self,
@@ -1654,14 +1672,6 @@ class Filter:
user_message, should_skip, skip_reason = await self._process_user_message(body)
skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message)
await self._cache_manager.put(
user_id,
self._cache_manager.SKIP_STATE_CACHE,
skip_cache_key,
should_skip,
)
if not user_message or should_skip:
if __event_emitter__ and skip_reason:
await self._emit_status(__event_emitter__, skip_reason, done=True, level=Constants.STATUS_LEVEL["Intermediate"])
@@ -1680,7 +1690,6 @@ class Filter:
)
retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
memories = retrieval_result.get("memories", [])
threshold = retrieval_result.get("threshold")
all_similarities = retrieval_result.get("all_similarities", [])
if all_similarities:
cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
@@ -1715,18 +1724,16 @@ class Filter:
if not user_message:
return body
skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message)
should_skip = await self._cache_manager.get(user_id, self._cache_manager.SKIP_STATE_CACHE, skip_cache_key)
conversation_context = self._get_recent_user_messages(body.get("messages", []), self.valves.max_consolidation_context_messages)
should_skip_consolidation, skip_reason = await self._should_skip_consolidation(conversation_context)
if should_skip:
logger.info("⏭️ Skipping outlet: inlet detected skip condition")
if should_skip_consolidation:
await self._emit_status(__event_emitter__, skip_reason, done=True, level=Constants.STATUS_LEVEL["Intermediate"])
return body
retrieval_cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, retrieval_cache_key)
conversation_context = self._get_recent_user_messages(body.get("messages", []), self.valves.max_consolidation_context_messages)
task = asyncio.create_task(
self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities, conversation_context)
)
@@ -1760,8 +1767,7 @@ class Filter:
try:
retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE)
embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE)
skip_state_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.SKIP_STATE_CACHE)
logger.info(f"🔄 Cleared cache: {retrieval_cleared} retrieval, {embedding_cleared} embedding, {skip_state_cleared} skip entries")
logger.info(f"🔄 Cleared cache: {retrieval_cleared} retrieval, {embedding_cleared} embedding entries")
user_memories = await self._get_user_memories(user_id)
memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)