Mirror of https://github.com/mtayfur/openwebui-memory-system.git (synced 2026-01-22 06:51:01 +01:00)
refactor: make skip detection and embedding operations fully async for improved concurrency
Skip detection and embedding-related methods are now asynchronous, allowing non-blocking execution and better concurrency. Embedding function wrappers and initialization routines are updated to support async/await, and the shared skip-detector cache is adapted accordingly. These changes keep the system compatible with async embedding functions, prevent blocking the event loop, and improve scalability and responsiveness under high concurrency.
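In outline, the refactor moves reference-embedding computation out of `__init__` into a lazy, awaitable `initialize()`, and awaits the embedding function everywhere it is called. The following condensed sketch illustrates that pattern only; it is simplified from the diff below, and the class name, category strings, and margin value are illustrative placeholders rather than the project's actual values.

from typing import Any, Callable, List

import numpy as np


class AsyncSkipDetectorSketch:
    """Illustrative sketch of the async skip-detection pattern, not the real class."""

    def __init__(self, embedding_function: Callable[[List[str]], Any]):
        # The embedding function must be awaited, so reference embeddings can no
        # longer be computed synchronously inside __init__.
        self.embedding_function = embedding_function
        self._reference_embeddings = None

    async def initialize(self) -> None:
        # Lazy and idempotent: cheap to call before every detection.
        if self._reference_embeddings is not None:
            return
        personal = await self.embedding_function(["statements about the user's life"])
        non_personal = await self.embedding_function(["general knowledge questions"])
        self._reference_embeddings = {
            "personal": np.array(personal),
            "non_personal": np.array(non_personal),
        }

    async def is_non_personal(self, message: str) -> bool:
        # Ensure reference embeddings exist, then compare cosine-style similarities.
        if self._reference_embeddings is None:
            await self.initialize()
        embedding = np.array((await self.embedding_function([message]))[0])
        personal = float(np.dot(embedding, self._reference_embeddings["personal"].T).max())
        non_personal = float(np.dot(embedding, self._reference_embeddings["non_personal"].T).max())
        return (non_personal - personal) > 0.1  # placeholder margin; configurable in the real code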
memory_system.py (153 changed lines)
@@ -4,6 +4,7 @@ description: A semantic memory management system for Open WebUI that consolidate
 version: 1.0.0
 authors: https://github.com/mtayfur
 license: Apache-2.0
 required_open_webui_version: 0.6.37
 """
 
+import asyncio
@@ -459,29 +460,27 @@ class SkipDetector:
         SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations",
     }
 
-    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
+    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Any]):
         """Initialize the skip detector with an embedding function and compute reference embeddings."""
         self.embedding_function = embedding_function
         self._reference_embeddings = None
-        self._initialize_reference_embeddings()
 
-    def _initialize_reference_embeddings(self) -> None:
+    async def initialize(self) -> None:
         """Compute and cache embeddings for category descriptions."""
-        try:
-            non_personal_embeddings = self.embedding_function(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)
-            personal_embeddings = self.embedding_function(self.PERSONAL_CATEGORY_DESCRIPTIONS)
-
-            self._reference_embeddings = {
-                "non_personal": np.array(non_personal_embeddings),
-                "personal": np.array(personal_embeddings),
-            }
-
-            logger.info(
-                f"SkipDetector initialized with {len(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)} non-personal categories and {len(self.PERSONAL_CATEGORY_DESCRIPTIONS)} personal categories"
-            )
-        except Exception as e:
-            logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}")
-            self._reference_embeddings = None
+        if self._reference_embeddings is not None:
+            return
+
+        non_personal_embeddings = await self.embedding_function(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)
+        personal_embeddings = await self.embedding_function(self.PERSONAL_CATEGORY_DESCRIPTIONS)
+
+        self._reference_embeddings = {
+            "non_personal": np.array(non_personal_embeddings),
+            "personal": np.array(personal_embeddings),
+        }
+
+        logger.info(
+            f"SkipDetector initialized with {len(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)} non-personal categories and {len(self.PERSONAL_CATEGORY_DESCRIPTIONS)} personal categories"
+        )
 
     def validate_message_size(self, message: str, max_message_chars: int) -> Optional[str]:
         """Validate message size constraints."""
@@ -615,7 +614,7 @@ class SkipDetector:
 
         return None
 
-    def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
+    async def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
         """
         Detect if a message should be skipped using two-stage detection:
         1. Fast-path structural patterns (~95% confidence)
@@ -633,30 +632,25 @@ class SkipDetector:
             return self.SkipReason.SKIP_NON_PERSONAL.value
 
         if self._reference_embeddings is None:
-            logger.warning("SkipDetector reference embeddings not initialized, allowing message through")
-            return None
+            await self.initialize()
 
-        try:
-            message_embedding = np.array(self.embedding_function([message.strip()])[0])
+        message_embedding_result = await self.embedding_function([message.strip()])
+        message_embedding = np.array(message_embedding_result[0])
 
-            personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T)
-            max_personal_similarity = float(personal_similarities.max())
+        personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T)
+        max_personal_similarity = float(personal_similarities.max())
 
-            non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T)
-            max_non_personal_similarity = float(non_personal_similarities.max())
+        non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T)
+        max_non_personal_similarity = float(non_personal_similarities.max())
 
-            margin = memory_system.valves.skip_category_margin
-            threshold = max_personal_similarity + margin
-            if (max_non_personal_similarity - max_personal_similarity) > margin:
-                logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})")
-                return self.SkipReason.SKIP_NON_PERSONAL.value
+        margin = memory_system.valves.skip_category_margin
+        threshold = max_personal_similarity + margin
+        if (max_non_personal_similarity - max_personal_similarity) > margin:
+            logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})")
+            return self.SkipReason.SKIP_NON_PERSONAL.value
 
-            logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})")
-            return None
-
-        except Exception as e:
-            logger.error(f"Error in semantic skip detection: {e}")
-            return None
+        logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})")
+        return None
 
 
 class LLMRerankingService:
@@ -1181,6 +1175,8 @@ class Filter:
         self._embedding_dimension = None
         self._skip_detector = None
 
+        self._initialization_lock = asyncio.Lock()
+
         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)
 
@@ -1205,36 +1201,38 @@ class Filter:
            self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
            logger.info(f"✅ Using OpenWebUI's embedding function")
 
-        self._detect_embedding_dimension()
+        if self._embedding_function and self._embedding_dimension is None:
+            async with self._initialization_lock:
+                if self._embedding_dimension is None:
+                    await self._detect_embedding_dimension()
 
-        if self._skip_detector is None:
-            global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
-            embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
-            embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
-            cache_key = f"{embedding_engine}:{embedding_model}"
-
-            async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
-                if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
-                    logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
-                    self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
-                else:
-                    logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
-                    embedding_fn = self._embedding_function
-                    normalize_fn = self._normalize_embedding
-
-                    def embedding_wrapper(
-                        texts: Union[str, List[str]],
-                    ) -> Union[np.ndarray, List[np.ndarray]]:
-                        result = embedding_fn(texts, prefix=None, user=None)
-                        if isinstance(result, list):
-                            if isinstance(result[0], list):
-                                return [normalize_fn(emb) for emb in result]
-                            return np.array([normalize_fn(result)])
-                        return normalize_fn(result)
-
-                    self._skip_detector = SkipDetector(embedding_wrapper)
-                    _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
-                    logger.info(f"✅ Skip detector initialized and cached")
+        if self._embedding_function and self._skip_detector is None:
+            global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
+            embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
+            embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
+            cache_key = f"{embedding_engine}:{embedding_model}"
+
+            async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
+                if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
+                    logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
+                    self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
+                else:
+                    logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
+                    embedding_fn = self._embedding_function
+                    normalize_fn = self._normalize_embedding
+
+                    async def embedding_wrapper(
+                        texts: Union[str, List[str]],
+                    ) -> Union[np.ndarray, List[np.ndarray]]:
+                        result = await embedding_fn(texts, prefix=None, user=None)
+                        if isinstance(result, list) and len(result) > 0 and isinstance(result[0], (list, np.ndarray)):
+                            return [normalize_fn(emb) for emb in result]
+                        return [normalize_fn(result if isinstance(result, (list, np.ndarray)) else [result])]
+
+                    self._skip_detector = SkipDetector(embedding_wrapper)
+                    await self._skip_detector.initialize()
+                    _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
+                    logger.info(f"✅ Skip detector initialized and cached")
 
     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
         """Truncate content with ellipsis if needed."""
@@ -1283,16 +1281,16 @@ class Filter:
         """Compute SHA256 hash for text caching."""
         return hashlib.sha256(text.encode()).hexdigest()
 
-    def _detect_embedding_dimension(self) -> None:
+    async def _detect_embedding_dimension(self) -> None:
         """Detect embedding dimension by generating a test embedding."""
-        try:
-            test_embedding = self._embedding_function(["dummy"], prefix=None, user=None)
-            if isinstance(test_embedding, list):
-                test_embedding = test_embedding[0]
-            self._embedding_dimension = np.squeeze(test_embedding).shape[0]
-            logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}")
-        except Exception as e:
-            raise RuntimeError(f"Failed to detect embedding dimension: {str(e)}")
+        test_embedding = await self._embedding_function("dummy", prefix=None, user=None)
+
+        if isinstance(test_embedding, list) and len(test_embedding) > 0 and isinstance(test_embedding[0], (list, np.ndarray)):
+            test_embedding = test_embedding[0]
+
+        emb_array = np.squeeze(np.array(test_embedding))
+        self._embedding_dimension = emb_array.shape[0] if emb_array.ndim > 0 else 1
+        logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}")
 
     def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray:
         """Normalize embedding vector and ensure 1D shape."""
@@ -1343,8 +1341,7 @@ class Filter:
         if uncached_texts:
             user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None
 
-            loop = asyncio.get_event_loop()
-            raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user)
+            raw_embeddings = await self._embedding_function(uncached_texts, prefix=None, user=user)
 
             if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
                 if isinstance(raw_embeddings[0], (list, np.ndarray)):
@@ -1369,14 +1366,14 @@ class Filter:
         logger.info(f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid")
         return result_embeddings
 
-    def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
-        skip_reason = self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self)
+    async def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
+        skip_reason = await self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self)
         if skip_reason:
             status_key = SkipDetector.SkipReason(skip_reason)
             return True, SkipDetector.STATUS_MESSAGES[status_key]
         return False, ""
 
-    def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
+    async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
         """Extract user message and determine if memory operations should be skipped."""
         messages = body["messages"]
         user_message = None
@@ -1398,7 +1395,7 @@ class Filter:
                 SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
             )
 
-        should_skip, skip_reason = self._should_skip_memory_operations(user_message)
+        should_skip, skip_reason = await self._should_skip_memory_operations(user_message)
         return user_message, should_skip, skip_reason
 
     async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List:
@@ -1645,7 +1642,7 @@ class Filter:
         if not user_id:
             return body
 
-        user_message, should_skip, skip_reason = self._process_user_message(body)
+        user_message, should_skip, skip_reason = await self._process_user_message(body)
 
         skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message or "")
         await self._cache_manager.put(