From 3b84f643927e023ef95a1108cc3a5b0c5a875e03 Mon Sep 17 00:00:00 2001 From: mtayfur Date: Mon, 24 Nov 2025 15:45:03 +0300 Subject: [PATCH] refactor: make skip detection and embedding operations fully async for improved concurrency Skip detection and embedding-related methods are now asynchronous, allowing non-blocking execution and better concurrency; embedding function wrappers and initialization routines are updated to support async/await, and shared skip detector caching is adapted accordingly. These changes are necessary to ensure compatibility with async embedding functions, prevent blocking the event loop, and improve scalability and responsiveness in high-concurrency environments. --- memory_system.py | 153 +++++++++++++++++++++++------------------------ 1 file changed, 75 insertions(+), 78 deletions(-) diff --git a/memory_system.py b/memory_system.py index 1ead04d..33cd035 100644 --- a/memory_system.py +++ b/memory_system.py @@ -4,6 +4,7 @@ description: A semantic memory management system for Open WebUI that consolidate version: 1.0.0 authors: https://github.com/mtayfur license: Apache-2.0 +required_open_webui_version: 0.6.37 """ import asyncio @@ -459,29 +460,27 @@ class SkipDetector: SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations", } - def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]): + def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Any]): """Initialize the skip detector with an embedding function and compute reference embeddings.""" self.embedding_function = embedding_function self._reference_embeddings = None - self._initialize_reference_embeddings() - def _initialize_reference_embeddings(self) -> None: + async def initialize(self) -> None: """Compute and cache embeddings for category descriptions.""" - try: - non_personal_embeddings = self.embedding_function(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS) - personal_embeddings = self.embedding_function(self.PERSONAL_CATEGORY_DESCRIPTIONS) + if self._reference_embeddings is not None: + return - self._reference_embeddings = { - "non_personal": np.array(non_personal_embeddings), - "personal": np.array(personal_embeddings), - } + non_personal_embeddings = await self.embedding_function(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS) + personal_embeddings = await self.embedding_function(self.PERSONAL_CATEGORY_DESCRIPTIONS) - logger.info( - f"SkipDetector initialized with {len(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)} non-personal categories and {len(self.PERSONAL_CATEGORY_DESCRIPTIONS)} personal categories" - ) - except Exception as e: - logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}") - self._reference_embeddings = None + self._reference_embeddings = { + "non_personal": np.array(non_personal_embeddings), + "personal": np.array(personal_embeddings), + } + + logger.info( + f"SkipDetector initialized with {len(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)} non-personal categories and {len(self.PERSONAL_CATEGORY_DESCRIPTIONS)} personal categories" + ) def validate_message_size(self, message: str, max_message_chars: int) -> Optional[str]: """Validate message size constraints.""" @@ -615,7 +614,7 @@ class SkipDetector: return None - def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]: + async def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]: """ Detect if a message should be skipped using two-stage detection: 1. Fast-path structural patterns (~95% confidence) @@ -633,30 +632,25 @@ class SkipDetector: return self.SkipReason.SKIP_NON_PERSONAL.value if self._reference_embeddings is None: - logger.warning("SkipDetector reference embeddings not initialized, allowing message through") - return None + await self.initialize() - try: - message_embedding = np.array(self.embedding_function([message.strip()])[0]) + message_embedding_result = await self.embedding_function([message.strip()]) + message_embedding = np.array(message_embedding_result[0]) - personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T) - max_personal_similarity = float(personal_similarities.max()) + personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T) + max_personal_similarity = float(personal_similarities.max()) - non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T) - max_non_personal_similarity = float(non_personal_similarities.max()) + non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T) + max_non_personal_similarity = float(non_personal_similarities.max()) - margin = memory_system.valves.skip_category_margin - threshold = max_personal_similarity + margin - if (max_non_personal_similarity - max_personal_similarity) > margin: - logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})") - return self.SkipReason.SKIP_NON_PERSONAL.value + margin = memory_system.valves.skip_category_margin + threshold = max_personal_similarity + margin + if (max_non_personal_similarity - max_personal_similarity) > margin: + logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})") + return self.SkipReason.SKIP_NON_PERSONAL.value - logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})") - return None - - except Exception as e: - logger.error(f"Error in semantic skip detection: {e}") - return None + logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})") + return None class LLMRerankingService: @@ -1181,6 +1175,8 @@ class Filter: self._embedding_dimension = None self._skip_detector = None + self._initialization_lock = asyncio.Lock() + self._llm_reranking_service = LLMRerankingService(self) self._llm_consolidation_service = LLMConsolidationService(self) @@ -1205,36 +1201,38 @@ class Filter: self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION logger.info(f"✅ Using OpenWebUI's embedding function") - self._detect_embedding_dimension() + if self._embedding_function and self._embedding_dimension is None: + async with self._initialization_lock: + if self._embedding_dimension is None: + await self._detect_embedding_dimension() - if self._skip_detector is None: - global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK - embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "") - embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "") - cache_key = f"{embedding_engine}:{embedding_model}" + if self._embedding_function and self._skip_detector is None: + global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK + embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "") + embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "") + cache_key = f"{embedding_engine}:{embedding_model}" - async with _SHARED_SKIP_DETECTOR_CACHE_LOCK: - if cache_key in _SHARED_SKIP_DETECTOR_CACHE: - logger.info(f"♻️ Reusing cached skip detector: {cache_key}") - self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] - else: - logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}") - embedding_fn = self._embedding_function - normalize_fn = self._normalize_embedding + async with _SHARED_SKIP_DETECTOR_CACHE_LOCK: + if cache_key in _SHARED_SKIP_DETECTOR_CACHE: + logger.info(f"♻️ Reusing cached skip detector: {cache_key}") + self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] + else: + logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}") + embedding_fn = self._embedding_function + normalize_fn = self._normalize_embedding - def embedding_wrapper( - texts: Union[str, List[str]], - ) -> Union[np.ndarray, List[np.ndarray]]: - result = embedding_fn(texts, prefix=None, user=None) - if isinstance(result, list): - if isinstance(result[0], list): - return [normalize_fn(emb) for emb in result] - return np.array([normalize_fn(result)]) - return normalize_fn(result) + async def embedding_wrapper( + texts: Union[str, List[str]], + ) -> Union[np.ndarray, List[np.ndarray]]: + result = await embedding_fn(texts, prefix=None, user=None) + if isinstance(result, list) and len(result) > 0 and isinstance(result[0], (list, np.ndarray)): + return [normalize_fn(emb) for emb in result] + return [normalize_fn(result if isinstance(result, (list, np.ndarray)) else [result])] - self._skip_detector = SkipDetector(embedding_wrapper) - _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector - logger.info(f"✅ Skip detector initialized and cached") + self._skip_detector = SkipDetector(embedding_wrapper) + await self._skip_detector.initialize() + _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector + logger.info(f"✅ Skip detector initialized and cached") def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str: """Truncate content with ellipsis if needed.""" @@ -1283,16 +1281,16 @@ class Filter: """Compute SHA256 hash for text caching.""" return hashlib.sha256(text.encode()).hexdigest() - def _detect_embedding_dimension(self) -> None: + async def _detect_embedding_dimension(self) -> None: """Detect embedding dimension by generating a test embedding.""" - try: - test_embedding = self._embedding_function(["dummy"], prefix=None, user=None) - if isinstance(test_embedding, list): - test_embedding = test_embedding[0] - self._embedding_dimension = np.squeeze(test_embedding).shape[0] - logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}") - except Exception as e: - raise RuntimeError(f"Failed to detect embedding dimension: {str(e)}") + test_embedding = await self._embedding_function("dummy", prefix=None, user=None) + + if isinstance(test_embedding, list) and len(test_embedding) > 0 and isinstance(test_embedding[0], (list, np.ndarray)): + test_embedding = test_embedding[0] + + emb_array = np.squeeze(np.array(test_embedding)) + self._embedding_dimension = emb_array.shape[0] if emb_array.ndim > 0 else 1 + logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}") def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray: """Normalize embedding vector and ensure 1D shape.""" @@ -1343,8 +1341,7 @@ class Filter: if uncached_texts: user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None - loop = asyncio.get_event_loop() - raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user) + raw_embeddings = await self._embedding_function(uncached_texts, prefix=None, user=user) if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0: if isinstance(raw_embeddings[0], (list, np.ndarray)): @@ -1369,14 +1366,14 @@ class Filter: logger.info(f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid") return result_embeddings - def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: - skip_reason = self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self) + async def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: + skip_reason = await self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self) if skip_reason: status_key = SkipDetector.SkipReason(skip_reason) return True, SkipDetector.STATUS_MESSAGES[status_key] return False, "" - def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: + async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: """Extract user message and determine if memory operations should be skipped.""" messages = body["messages"] user_message = None @@ -1398,7 +1395,7 @@ class Filter: SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], ) - should_skip, skip_reason = self._should_skip_memory_operations(user_message) + should_skip, skip_reason = await self._should_skip_memory_operations(user_message) return user_message, should_skip, skip_reason async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List: @@ -1645,7 +1642,7 @@ class Filter: if not user_id: return body - user_message, should_skip, skip_reason = self._process_user_message(body) + user_message, should_skip, skip_reason = await self._process_user_message(body) skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message or "") await self._cache_manager.put(