fix(memory_system): improve content extraction, hash handling, and memory filtering

Updates the content extraction logic for robustness, makes all hash operations
coerce their input with str() so non-string values are handled safely, filters
out memories with empty content when they are retrieved, and bumps the
required Open WebUI version from 0.6.37 to 0.6.40 for compatibility (plugin
version 1.2.0 -> 1.2.1). These changes address edge cases in content
processing, prevent potential errors, and ensure only valid memories are
processed and embedded.
This commit is contained in:
mtayfur
2025-11-26 01:04:48 +03:00
parent 4502e07fb3
commit cdedeee6ba

View File

@@ -1,10 +1,10 @@
""" """
title: Memory System title: Memory System
description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations. description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
version: 1.2.0 version: 1.2.1
authors: https://github.com/mtayfur authors: https://github.com/mtayfur
license: Apache-2.0 license: Apache-2.0
required_open_webui_version: 0.6.37 required_open_webui_version: 0.6.40
""" """
import asyncio import asyncio
@@ -1249,14 +1249,10 @@ class Filter:
def _extract_text_from_content(self, content) -> str: def _extract_text_from_content(self, content) -> str:
if isinstance(content, str): if isinstance(content, str):
return content return content
elif isinstance(content, list): if isinstance(content, list):
text_parts = [] return " ".join(item.get("text", "") for item in content if isinstance(item, dict) and item.get("type") == "text")
for item in content: if isinstance(content, dict) and "text" in content:
if isinstance(item, dict) and item.get("type") == "text": return content.get("text", "")
text_parts.append(item.get("text", ""))
return " ".join(text_parts)
elif isinstance(content, dict) and "text" in content:
return content["text"]
return "" return ""
def _validate_system_configuration(self) -> None: def _validate_system_configuration(self) -> None:
@@ -1279,7 +1275,7 @@ class Filter:
def _compute_text_hash(self, text: str) -> str: def _compute_text_hash(self, text: str) -> str:
"""Compute SHA256 hash for text caching.""" """Compute SHA256 hash for text caching."""
return hashlib.sha256(text.encode()).hexdigest() return hashlib.sha256(str(text).encode()).hexdigest()
async def _detect_embedding_dimension(self) -> None: async def _detect_embedding_dimension(self) -> None:
"""Detect embedding dimension by generating a test embedding.""" """Detect embedding dimension by generating a test embedding."""
@@ -1400,10 +1396,11 @@ class Filter:
if timeout is None: if timeout is None:
timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
return await asyncio.wait_for( memories = await asyncio.wait_for(
asyncio.to_thread(Memories.get_memories_by_user_id, user_id), asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
timeout=timeout, timeout=timeout,
) )
return [m for m in (memories or []) if m.content]
def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None: def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None:
"""Log retrieved memories with concise formatting showing key statistics and semantic values.""" """Log retrieved memories with concise formatting showing key statistics and semantic values."""
@@ -1434,7 +1431,7 @@ class Filter:
def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str: def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
"""Unified cache key generation for all cache types.""" """Unified cache key generation for all cache types."""
if content: if content:
content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH] content_hash = hashlib.sha256(str(content).encode("utf-8")).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
return f"{cache_type}_{user_id}:{content_hash}" return f"{cache_type}_{user_id}:{content_hash}"
return f"{cache_type}_{user_id}" return f"{cache_type}_{user_id}"
@@ -1637,7 +1634,7 @@ class Filter:
user_message, should_skip, skip_reason = await self._process_user_message(body) user_message, should_skip, skip_reason = await self._process_user_message(body)
skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message or "") skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message)
await self._cache_manager.put( await self._cache_manager.put(
user_id, user_id,
self._cache_manager.SKIP_STATE_CACHE, self._cache_manager.SKIP_STATE_CACHE,
@@ -1771,7 +1768,7 @@ class Filter:
user_memories, user_memories,
) )
memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS] memory_contents = [memory.content for memory in user_memories if len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
if memory_contents: if memory_contents:
await self._generate_embeddings(memory_contents, user_id) await self._generate_embeddings(memory_contents, user_id)