Mirror of https://github.com/mtayfur/openwebui-memory-system.git (synced 2026-01-22 06:51:01 +01:00)
refactor(memory_system): remove skip state cache and threshold from retrieval, unify skip logic for inlet and outlet
Skip state caching is removed to simplify state management: the inlet and outlet now evaluate skip conditions directly instead of sharing cached state, which improves clarity and removes redundant cache usage. The retrieval threshold is no longer returned in retrieval results, and skip reasons are split into separate inlet (retrieval) and outlet (consolidation) sets for more precise status messaging.
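In practice this means the skip decision is recomputed where it is needed rather than cached between calls: the inlet checks only the last user message and reports through INLET_STATUS_MESSAGES, while the outlet scans the recent conversation context and skips consolidation only when every message is skippable, reporting through OUTLET_STATUS_MESSAGES. A minimal synchronous sketch of that shape follows; the detect_skip_reason heuristic and the plain-string messages here are illustrative stand-ins, not the plugin's actual detector:

    from enum import Enum
    from typing import List, Optional, Tuple

    class SkipReason(Enum):
        SKIP_SIZE = "SKIP_SIZE"
        SKIP_NON_PERSONAL = "SKIP_NON_PERSONAL"
        SKIP_ALL_NON_PERSONAL = "SKIP_ALL_NON_PERSONAL"

    # Separate message sets: inlet (retrieval) vs. outlet (consolidation).
    INLET_STATUS_MESSAGES = {
        SkipReason.SKIP_SIZE: "Message Length Out of Limits, Skipping Memory Retrieval",
        SkipReason.SKIP_NON_PERSONAL: "No Personal Content, Skipping Memory Retrieval",
    }
    OUTLET_STATUS_MESSAGES = {
        SkipReason.SKIP_ALL_NON_PERSONAL: "No Personal Content in Context, Skipping Memory Consolidation",
    }

    def detect_skip_reason(message: str, max_chars: int = 2000) -> Optional[SkipReason]:
        # Illustrative stand-in for SkipDetector.detect_skip_reason; the real detector
        # also classifies personal vs. non-personal content semantically.
        if not message or len(message) > max_chars:
            return SkipReason.SKIP_SIZE
        if message.lower().startswith(("what is", "how do", "explain")):
            return SkipReason.SKIP_NON_PERSONAL
        return None

    def should_skip_retrieval(user_message: str) -> Tuple[bool, str]:
        # Inlet: decide from the last user message alone.
        reason = detect_skip_reason(user_message)
        return (True, INLET_STATUS_MESSAGES[reason]) if reason else (False, "")

    def should_skip_consolidation(context: List[str]) -> Tuple[bool, str]:
        # Outlet: skip only when EVERY message in the recent context is skippable.
        if context and all(detect_skip_reason(m) for m in context):
            return True, OUTLET_STATUS_MESSAGES[SkipReason.SKIP_ALL_NON_PERSONAL]
        return False, ""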
@@ -268,7 +268,6 @@ class UnifiedCacheManager:
         self.EMBEDDING_CACHE = "embedding"
         self.RETRIEVAL_CACHE = "retrieval"
         self.MEMORY_CACHE = "memory"
-        self.SKIP_STATE_CACHE = "skip"
 
     async def get(self, user_id: str, cache_type: str, key: str) -> Optional[Any]:
         """Get value from cache with LRU updates."""
@@ -462,10 +461,17 @@ class SkipDetector:
     class SkipReason(Enum):
         SKIP_SIZE = "SKIP_SIZE"
         SKIP_NON_PERSONAL = "SKIP_NON_PERSONAL"
+        SKIP_ALL_NON_PERSONAL = "SKIP_ALL_NON_PERSONAL"
 
-    STATUS_MESSAGES = {
-        SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, skipping memory operations",
-        SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations",
+    # Inlet (retrieval) status messages
+    INLET_STATUS_MESSAGES = {
+        SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, Skipping Memory Retrieval",
+        SkipReason.SKIP_NON_PERSONAL: "🚫 No Personal Content, Skipping Memory Retrieval",
+    }
+
+    # Outlet (consolidation) status messages
+    OUTLET_STATUS_MESSAGES = {
+        SkipReason.SKIP_ALL_NON_PERSONAL: "🚫 No Personal Content in Context, Skipping Memory Consolidation",
     }
 
     def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Any]):
@@ -735,7 +741,7 @@ CANDIDATE MEMORIES:
         llm_candidates = candidate_memories[:extended_count]
         await self.memory_system._emit_status(
             emitter,
-            f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance",
+            f"🤖 Analyzing {len(llm_candidates)} Memories for Relevance",
             done=False,
             level=Constants.STATUS_LEVEL["Intermediate"],
         )
@@ -745,9 +751,7 @@ CANDIDATE MEMORIES:
 
         if not selected_memories:
             logger.info("📭 No relevant memories after LLM analysis")
-            await self.memory_system._emit_status(
-                emitter, f"📭 No Relevant Memories After LLM Analysis", done=True, level=Constants.STATUS_LEVEL["Intermediate"]
-            )
+            await self.memory_system._emit_status(emitter, "📭 No Relevant Memories Found", done=True, level=Constants.STATUS_LEVEL["Intermediate"])
             return selected_memories, analysis_info
         else:
             logger.info(f"⏩ Skipping LLM reranking: {decision_reason}")
@@ -886,7 +890,7 @@ class LLMConsolidationService:
             )
         except Exception as e:
             logger.warning(f"🤖 LLM consolidation failed during memory processing: {str(e)}")
-            await self.memory_system._emit_status(emitter, f"⚠️ Memory Consolidation Failed", done=True, level=Constants.STATUS_LEVEL["Basic"])
+            await self.memory_system._emit_status(emitter, "⚠️ Memory Consolidation Failed", done=True, level=Constants.STATUS_LEVEL["Basic"])
             return []
 
         operations = response.ops
@@ -1130,7 +1134,7 @@ class LLMConsolidationService:
             duration = time.time() - start_time
             await self.memory_system._emit_status(
                 emitter,
-                f"✅ Consolidation Complete: No Updates Needed",
+                "✅ No Memory Updates Needed",
                 done=True,
                 level=Constants.STATUS_LEVEL["Detailed"],
             )
@@ -1408,14 +1412,31 @@ class Filter:
         skip_reason = await self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self)
         if skip_reason:
             status_key = SkipDetector.SkipReason(skip_reason)
-            return True, SkipDetector.STATUS_MESSAGES[status_key]
+            return True, SkipDetector.INLET_STATUS_MESSAGES[status_key]
         return False, ""
 
+    async def _should_skip_consolidation(self, conversation_context: List[str]) -> Tuple[bool, str]:
+        """Check if consolidation should be skipped based on conversation context.
+
+        Returns (should_skip, reason). Skips only if ALL messages in context are skippable.
+        """
+        logger.info(f"🔍 Evaluating {len(conversation_context)} messages for consolidation")
+
+        for idx, message in enumerate(conversation_context, 1):
+            skip_reason = await self._skip_detector.detect_skip_reason(message, Constants.MAX_MESSAGE_CHARS, memory_system=self)
+            if not skip_reason:  # Found at least one valuable message
+                logger.info(f"✅ Found personal content in message {idx}/{len(conversation_context)}, proceeding with consolidation")
+                return False, ""
+
+        # All messages were skippable
+        logger.info(f"🚫 All {len(conversation_context)} messages are non-personal, skipping consolidation")
+        return True, SkipDetector.OUTLET_STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_ALL_NON_PERSONAL]
+
     async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
         """Extract user message and determine if memory operations should be skipped."""
         user_message = self._get_last_user_message(body["messages"])
         if not user_message:
-            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
+            return None, True, SkipDetector.INLET_STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
 
         should_skip, skip_reason = await self._should_skip_memory_operations(user_message)
         return user_message, should_skip, skip_reason
@@ -1523,7 +1544,6 @@ class Filter:
             self._log_retrieved_memories(final_memories, "semantic")
             return {
                 "memories": final_memories,
-                "threshold": self.valves.semantic_retrieval_threshold,
                 "all_similarities": cached_similarities,
                 "reranking_info": reranking_info,
             }
@@ -1534,9 +1554,9 @@ class Filter:
         if not user_memories:
             logger.info("📭 No memories found for user")
             await self._emit_status(emitter, "📭 No Memories Found", done=True, level=Constants.STATUS_LEVEL["Intermediate"])
-            return {"memories": [], "threshold": None}
+            return {"memories": []}
 
-        memories, threshold, all_similarities = await self._compute_similarities(user_message, user_id, user_memories)
+        memories, all_similarities = await self._compute_similarities(user_message, user_id, user_memories)
 
         if memories:
             final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter)
@@ -1550,7 +1570,6 @@ class Filter:
 
         return {
             "memories": final_memories,
-            "threshold": threshold,
            "all_similarities": all_similarities,
             "reranking_info": reranking_info,
         }
@@ -1611,10 +1630,10 @@ class Filter:
         memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat()
         return memory_dict
 
-    async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], float, List[Dict]]:
+    async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], List[Dict]]:
         """Compute similarity scores between user message and memories."""
         if not user_memories:
-            return [], self.valves.semantic_retrieval_threshold, []
+            return [], []
 
         query_embedding = await self._generate_embeddings(user_message, user_id)
         memory_contents = [memory.content for memory in user_memories]
@@ -1632,9 +1651,8 @@ class Filter:
 
         memory_data.sort(key=lambda x: x["relevance"], reverse=True)
 
-        threshold = self.valves.semantic_retrieval_threshold
-        filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
-        return filtered_memories, threshold, memory_data
+        filtered_memories = [m for m in memory_data if m["relevance"] >= self.valves.semantic_retrieval_threshold]
+        return filtered_memories, memory_data
 
     async def inlet(
         self,
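Since the threshold is no longer threaded through return values, _compute_similarities is reduced to returning (filtered_memories, memory_data) and the caller reads the valve directly. A rough standalone sketch of that narrower contract, with the threshold passed as a plain argument in place of self.valves.semantic_retrieval_threshold and a toy cosine similarity standing in for the real embedding pipeline:

    from typing import Dict, List, Tuple

    def compute_similarities(query_vec: List[float], memories: List[Dict], threshold: float) -> Tuple[List[Dict], List[Dict]]:
        # Score every memory, sort by relevance, then keep those at or above the threshold.
        def cosine(a: List[float], b: List[float]) -> float:
            dot = sum(x * y for x, y in zip(a, b))
            norm_a = sum(x * x for x in a) ** 0.5
            norm_b = sum(y * y for y in b) ** 0.5
            return dot / (norm_a * norm_b) if norm_a and norm_b else 0.0

        memory_data = [{**m, "relevance": cosine(query_vec, m["embedding"])} for m in memories]
        memory_data.sort(key=lambda x: x["relevance"], reverse=True)
        filtered = [m for m in memory_data if m["relevance"] >= threshold]
        return filtered, memory_data  # the threshold itself is no longer part of the result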
@@ -1654,14 +1672,6 @@ class Filter:
 
         user_message, should_skip, skip_reason = await self._process_user_message(body)
 
-        skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message)
-        await self._cache_manager.put(
-            user_id,
-            self._cache_manager.SKIP_STATE_CACHE,
-            skip_cache_key,
-            should_skip,
-        )
-
         if not user_message or should_skip:
             if __event_emitter__ and skip_reason:
                 await self._emit_status(__event_emitter__, skip_reason, done=True, level=Constants.STATUS_LEVEL["Intermediate"])
@@ -1680,7 +1690,6 @@ class Filter:
         )
         retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
         memories = retrieval_result.get("memories", [])
-        threshold = retrieval_result.get("threshold")
         all_similarities = retrieval_result.get("all_similarities", [])
         if all_similarities:
             cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
@@ -1715,18 +1724,16 @@ class Filter:
         if not user_message:
             return body
 
-        skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message)
-        should_skip = await self._cache_manager.get(user_id, self._cache_manager.SKIP_STATE_CACHE, skip_cache_key)
+        conversation_context = self._get_recent_user_messages(body.get("messages", []), self.valves.max_consolidation_context_messages)
+        should_skip_consolidation, skip_reason = await self._should_skip_consolidation(conversation_context)
 
-        if should_skip:
-            logger.info("⏭️ Skipping outlet: inlet detected skip condition")
+        if should_skip_consolidation:
+            await self._emit_status(__event_emitter__, skip_reason, done=True, level=Constants.STATUS_LEVEL["Intermediate"])
             return body
 
         retrieval_cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
         cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, retrieval_cache_key)
 
-        conversation_context = self._get_recent_user_messages(body.get("messages", []), self.valves.max_consolidation_context_messages)
-
         task = asyncio.create_task(
             self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities, conversation_context)
         )
@@ -1760,8 +1767,7 @@ class Filter:
         try:
             retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE)
             embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE)
-            skip_state_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.SKIP_STATE_CACHE)
-            logger.info(f"🔄 Cleared cache: {retrieval_cleared} retrieval, {embedding_cleared} embedding, {skip_state_cleared} skip entries")
+            logger.info(f"🔄 Cleared cache: {retrieval_cleared} retrieval, {embedding_cleared} embedding entries")
 
             user_memories = await self._get_user_memories(user_id)
             memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)