refactor(memory_system): remove excessive try/except and input validation, streamline async operations, and add skip state cache

Removes redundant try/except blocks and input validation from several
methods to simplify the logic and improve readability, moving error
handling to higher levels where appropriate. Adds a skip state cache
that records when the inlet has decided memory operations should be
skipped, improving efficiency by letting the outlet reuse that decision
instead of repeating the skip check. Cleans up batch operation
execution, and extends cache clearing to cover the new skip state
cache. Together, these changes reduce unnecessary code complexity,
improve maintainability, and streamline memory operation flow and
cache management.
This commit is contained in:
mtayfur
2025-10-28 17:51:09 +03:00
parent 8ced9aace5
commit 8997f71f05

View File

@@ -254,6 +254,7 @@ class UnifiedCacheManager:
self.EMBEDDING_CACHE = "embedding" self.EMBEDDING_CACHE = "embedding"
self.RETRIEVAL_CACHE = "retrieval" self.RETRIEVAL_CACHE = "retrieval"
self.MEMORY_CACHE = "memory" self.MEMORY_CACHE = "memory"
self.SKIP_STATE_CACHE = "skip"
async def get(self, user_id: str, cache_type: str, key: str) -> Optional[Any]: async def get(self, user_id: str, cache_type: str, key: str) -> Optional[Any]:
"""Get value from cache with LRU updates.""" """Get value from cache with LRU updates."""
@@ -747,25 +748,20 @@ class LLMConsolidationService:
if not existing_memories: if not existing_memories:
return None return None
try: content_embedding = await self.memory_system._generate_embeddings(content, user_id)
content_embedding = await self.memory_system._generate_embeddings(content, user_id)
for memory in existing_memories: for memory in existing_memories:
if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS: if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS:
continue continue
memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id) memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id)
similarity = float(np.dot(content_embedding, memory_embedding))
similarity = float(np.dot(content_embedding, memory_embedding)) if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD:
logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
return str(memory.id)
if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD: return None
logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
return str(memory.id)
return None
except Exception as e:
logger.warning(f"⚠️ Semantic duplicate check failed: {str(e)}")
return None
def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]: def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]:
"""Filter consolidation candidates by threshold and return candidates with threshold info.""" """Filter consolidation candidates by threshold and return candidates with threshold info."""
@@ -945,21 +941,13 @@ class LLMConsolidationService:
async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]: async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
"""Execute consolidation operations with simplified tracking.""" """Execute consolidation operations with simplified tracking."""
if not operations or not user_id: if not operations:
return 0, 0, 0, 0 return 0, 0, 0, 0
try: user = await asyncio.wait_for(
user = await asyncio.wait_for( asyncio.to_thread(Users.get_user_by_id, user_id),
asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, )
)
except asyncio.TimeoutError:
raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
except Exception as e:
raise RuntimeError(f"👤 User lookup failed: {str(e)}")
if not user:
raise ValueError(f"👤 User not found for consolidation: {user_id}")
created_count = updated_count = deleted_count = failed_count = 0 created_count = updated_count = deleted_count = failed_count = 0
@@ -982,28 +970,21 @@ class LLMConsolidationService:
memory_contents_for_deletion = {} memory_contents_for_deletion = {}
if operations_by_type["DELETE"]: if operations_by_type["DELETE"]:
try: user_memories = await self.memory_system._get_user_memories(user_id)
user_memories = await self.memory_system._get_user_memories(user_id) memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories}
except Exception as e:
logger.warning(f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}")
if operations_by_type["CREATE"] or operations_by_type["UPDATE"]: if operations_by_type["CREATE"] or operations_by_type["UPDATE"]:
try: current_memories = await self.memory_system._get_user_memories(user_id)
current_memories = await self.memory_system._get_user_memories(user_id)
if operations_by_type["CREATE"]: if operations_by_type["CREATE"]:
operations_by_type["CREATE"] = await self._deduplicate_operations( operations_by_type["CREATE"] = await self._deduplicate_operations(
operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE" operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
) )
if operations_by_type["UPDATE"]: if operations_by_type["UPDATE"]:
operations_by_type["UPDATE"] = await self._deduplicate_operations( operations_by_type["UPDATE"] = await self._deduplicate_operations(
operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"] operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
) )
except Exception as e:
logger.warning(f"⚠️ Semantic deduplication check failed, proceeding with original operations: {str(e)}")
for operation_type, ops in operations_by_type.items(): for operation_type, ops in operations_by_type.items():
if not ops: if not ops:
@@ -1014,38 +995,33 @@ class LLMConsolidationService:
task = self.memory_system._execute_single_operation(operation, user) task = self.memory_system._execute_single_operation(operation, user)
batch_tasks.append(task) batch_tasks.append(task)
try: results = await asyncio.gather(*batch_tasks, return_exceptions=True)
results = await asyncio.gather(*batch_tasks, return_exceptions=True) for idx, result in enumerate(results):
for idx, result in enumerate(results): operation = ops[idx]
operation = ops[idx]
if isinstance(result, Exception): if isinstance(result, Exception):
failed_count += 1 failed_count += 1
await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False) await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
elif result == Models.MemoryOperationType.CREATE.value: elif result == Models.MemoryOperationType.CREATE.value:
created_count += 1 created_count += 1
content_preview = self.memory_system._truncate_content(operation.content) content_preview = self.memory_system._truncate_content(operation.content)
await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False) await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False)
elif result == Models.MemoryOperationType.UPDATE.value: elif result == Models.MemoryOperationType.UPDATE.value:
updated_count += 1 updated_count += 1
content_preview = self.memory_system._truncate_content(operation.content) content_preview = self.memory_system._truncate_content(operation.content)
await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False) await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False)
elif result == Models.MemoryOperationType.DELETE.value: elif result == Models.MemoryOperationType.DELETE.value:
deleted_count += 1 deleted_count += 1
content_preview = memory_contents_for_deletion.get(operation.id, operation.id) content_preview = memory_contents_for_deletion.get(operation.id, operation.id)
if content_preview and content_preview != operation.id: if content_preview and content_preview != operation.id:
content_preview = self.memory_system._truncate_content(content_preview) content_preview = self.memory_system._truncate_content(content_preview)
await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False) await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
elif result in [ elif result in [
Models.OperationResult.FAILED.value, Models.OperationResult.FAILED.value,
Models.OperationResult.UNSUPPORTED.value, Models.OperationResult.UNSUPPORTED.value,
]: ]:
failed_count += 1 failed_count += 1
await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False) await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
except Exception as e:
failed_count += len(ops)
logger.error(f"❌ Batch {operation_type} operations failed during memory consolidation: {str(e)}")
await self.memory_system._emit_status(emitter, f"❌ Batch {operation_type} Failed", done=False)
total_executed = created_count + updated_count + deleted_count total_executed = created_count + updated_count + deleted_count
logger.info( logger.info(
@@ -1296,17 +1272,9 @@ class Filter:
async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]: async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]:
"""Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function.""" """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function."""
if self._embedding_function is None:
raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")
is_single = isinstance(texts, str) is_single = isinstance(texts, str)
text_list = [texts] if is_single else texts text_list = [texts] if is_single else texts
if not text_list:
if is_single:
raise ValueError("📏 Empty text provided for embedding generation")
return []
result_embeddings = [] result_embeddings = []
uncached_texts = [] uncached_texts = []
uncached_indices = [] uncached_indices = []
@@ -1360,9 +1328,6 @@ class Filter:
return result_embeddings return result_embeddings
def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
if self._skip_detector is None:
raise RuntimeError("🤖 Skip detector not initialized")
skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self) skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self)
if skip_reason: if skip_reason:
status_key = SkipDetector.SkipReason(skip_reason) status_key = SkipDetector.SkipReason(skip_reason)
@@ -1371,13 +1336,6 @@ class Filter:
def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
"""Extract user message and determine if memory operations should be skipped.""" """Extract user message and determine if memory operations should be skipped."""
if not body or "messages" not in body or not isinstance(body["messages"], list):
return (
None,
True,
SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
)
messages = body["messages"] messages = body["messages"]
user_message = None user_message = None
@@ -1421,10 +1379,6 @@ class Filter:
return return
scores = [memory["relevance"] for memory in memories] scores = [memory["relevance"] for memory in memories]
if not scores:
return
top_score = max(scores) top_score = max(scores)
lowest_score = min(scores) lowest_score = min(scores)
median_score = statistics.median(scores) median_score = statistics.median(scores)
@@ -1485,13 +1439,9 @@ class Filter:
return return
payload = {"type": "status", "data": {"description": description, "done": done}} payload = {"type": "status", "data": {"description": description, "done": done}}
result = emitter(payload)
try: if asyncio.iscoroutine(result):
result = emitter(payload) await result
if asyncio.iscoroutine(result):
await result
except Exception:
pass
async def _retrieve_relevant_memories( async def _retrieve_relevant_memories(
self, self,
@@ -1549,10 +1499,6 @@ class Filter:
emitter: Optional[Callable] = None, emitter: Optional[Callable] = None,
) -> None: ) -> None:
"""Add memory context to request body with simplified logic.""" """Add memory context to request body with simplified logic."""
if not body or "messages" not in body or not body["messages"]:
logger.warning("⚠️ Invalid request body or no messages found")
return
content_parts = [f"Current Date/Time: {self.format_current_datetime()}"] content_parts = [f"Current Date/Time: {self.format_current_datetime()}"]
memory_count = 0 memory_count = 0
@@ -1610,26 +1556,16 @@ class Filter:
memory_contents = [memory.content for memory in user_memories] memory_contents = [memory.content for memory in user_memories]
memory_embeddings = await self._generate_embeddings(memory_contents, user_id) memory_embeddings = await self._generate_embeddings(memory_contents, user_id)
if len(memory_embeddings) != len(user_memories):
logger.error(f"🔢 Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories")
return [], self.valves.semantic_retrieval_threshold, []
similarity_scores = []
memory_data = [] memory_data = []
for memory_index, memory in enumerate(user_memories): for memory_index, memory in enumerate(user_memories):
memory_embedding = memory_embeddings[memory_index] memory_embedding = memory_embeddings[memory_index]
if memory_embedding is None: if memory_embedding is None:
continue continue
similarity = float(np.dot(query_embedding, memory_embedding)) similarity = float(np.dot(query_embedding, memory_embedding))
similarity_scores.append(similarity)
memory_dict = self._build_memory_dict(memory, similarity) memory_dict = self._build_memory_dict(memory, similarity)
memory_data.append(memory_dict) memory_data.append(memory_dict)
if not similarity_scores:
return [], self.valves.semantic_retrieval_threshold, []
memory_data.sort(key=lambda x: x["relevance"], reverse=True) memory_data.sort(key=lambda x: x["relevance"], reverse=True)
threshold = self.valves.semantic_retrieval_threshold threshold = self.valves.semantic_retrieval_threshold
@@ -1657,10 +1593,19 @@ class Filter:
return body return body
user_message, should_skip, skip_reason = self._process_user_message(body) user_message, should_skip, skip_reason = self._process_user_message(body)
if not user_message or should_skip: if not user_message or should_skip:
if __event_emitter__ and skip_reason: if __event_emitter__ and skip_reason:
await self._emit_status(__event_emitter__, skip_reason, done=True) await self._emit_status(__event_emitter__, skip_reason, done=True)
await self._add_memory_context(body, [], user_id, __event_emitter__) await self._add_memory_context(body, [], user_id, __event_emitter__)
skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message or "")
await self._cache_manager.put(
user_id,
self._cache_manager.SKIP_STATE_CACHE,
skip_cache_key,
True,
)
return body return body
try: try:
memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id) memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
@@ -1709,11 +1654,20 @@ class Filter:
user_id = __user__.get("id") if body and __user__ else None user_id = __user__.get("id") if body and __user__ else None
if not user_id: if not user_id:
return body return body
user_message, should_skip, skip_reason = self._process_user_message(body)
if not user_message or should_skip: user_message, _, _ = self._process_user_message(body)
if not user_message:
return body return body
cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key) skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message)
should_skip = await self._cache_manager.get(user_id, self._cache_manager.SKIP_STATE_CACHE, skip_cache_key)
if should_skip:
logger.info("⏭️ Skipping outlet consolidation: inlet already detected skip condition")
return body
retrieval_cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, retrieval_cache_key)
task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities)) task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
self._background_tasks.add(task) self._background_tasks.add(task)
@@ -1745,7 +1699,10 @@ class Filter:
try: try:
retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE) retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE)
embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE) embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE)
logger.info(f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}") skip_state_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.SKIP_STATE_CACHE)
logger.info(
f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding + {skip_state_cleared} skip state cache entries for user {user_id}"
)
user_memories = await self._get_user_memories(user_id) user_memories = await self._get_user_memories(user_id)
memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id) memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
@@ -1774,59 +1731,49 @@ class Filter:
async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str: async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str:
"""Execute a single memory operation.""" """Execute a single memory operation."""
try: if operation.operation == Models.MemoryOperationType.CREATE:
if operation.operation == Models.MemoryOperationType.CREATE: content_stripped = operation.content.strip()
content_stripped = operation.content.strip() if not content_stripped:
if not content_stripped: return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
logger.warning(f"⚠️ Skipping CREATE operation: empty content")
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
await asyncio.wait_for( await asyncio.wait_for(
asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
) )
return Models.MemoryOperationType.CREATE.value return Models.MemoryOperationType.CREATE.value
elif operation.operation == Models.MemoryOperationType.UPDATE: elif operation.operation == Models.MemoryOperationType.UPDATE:
id_stripped = operation.id.strip() id_stripped = operation.id.strip()
if not id_stripped: if not id_stripped:
logger.warning(f"⚠️ Skipping UPDATE operation: empty ID") return Models.OperationResult.SKIPPED_EMPTY_ID.value
return Models.OperationResult.SKIPPED_EMPTY_ID.value
content_stripped = operation.content.strip() content_stripped = operation.content.strip()
if not content_stripped: if not content_stripped:
logger.warning(f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content") return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
await asyncio.wait_for( await asyncio.wait_for(
asyncio.to_thread( asyncio.to_thread(
Memories.update_memory_by_id_and_user_id, Memories.update_memory_by_id_and_user_id,
id_stripped, id_stripped,
user.id, user.id,
content_stripped, content_stripped,
), ),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
) )
return Models.MemoryOperationType.UPDATE.value return Models.MemoryOperationType.UPDATE.value
elif operation.operation == Models.MemoryOperationType.DELETE: elif operation.operation == Models.MemoryOperationType.DELETE:
id_stripped = operation.id.strip() id_stripped = operation.id.strip()
if not id_stripped: if not id_stripped:
logger.warning(f"⚠️ Skipping DELETE operation: empty ID") return Models.OperationResult.SKIPPED_EMPTY_ID.value
return Models.OperationResult.SKIPPED_EMPTY_ID.value
await asyncio.wait_for( await asyncio.wait_for(
asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
) )
return Models.MemoryOperationType.DELETE.value return Models.MemoryOperationType.DELETE.value
else:
logger.error(f"❓ Unsupported operation: {operation}")
return Models.OperationResult.UNSUPPORTED.value
except Exception as e: return Models.OperationResult.UNSUPPORTED.value
logger.error(f"💾 Database operation failed for {operation.operation.value}: {str(e)}")
return Models.OperationResult.FAILED.value
def _remove_refs_from_schema(self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: def _remove_refs_from_schema(self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
"""Remove $ref references and ensure required fields for Azure OpenAI.""" """Remove $ref references and ensure required fields for Azure OpenAI."""