diff --git a/memory_system.py b/memory_system.py index 1ead04d..33cd035 100644 --- a/memory_system.py +++ b/memory_system.py @@ -4,6 +4,7 @@ description: A semantic memory management system for Open WebUI that consolidate version: 1.0.0 authors: https://github.com/mtayfur license: Apache-2.0 +required_open_webui_version: 0.6.37 """ import asyncio @@ -459,29 +460,27 @@ class SkipDetector: SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations", } - def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]): + def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Any]): """Initialize the skip detector with an embedding function and compute reference embeddings.""" self.embedding_function = embedding_function self._reference_embeddings = None - self._initialize_reference_embeddings() - def _initialize_reference_embeddings(self) -> None: + async def initialize(self) -> None: """Compute and cache embeddings for category descriptions.""" - try: - non_personal_embeddings = self.embedding_function(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS) - personal_embeddings = self.embedding_function(self.PERSONAL_CATEGORY_DESCRIPTIONS) + if self._reference_embeddings is not None: + return - self._reference_embeddings = { - "non_personal": np.array(non_personal_embeddings), - "personal": np.array(personal_embeddings), - } + non_personal_embeddings = await self.embedding_function(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS) + personal_embeddings = await self.embedding_function(self.PERSONAL_CATEGORY_DESCRIPTIONS) - logger.info( - f"SkipDetector initialized with {len(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)} non-personal categories and {len(self.PERSONAL_CATEGORY_DESCRIPTIONS)} personal categories" - ) - except Exception as e: - logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}") - self._reference_embeddings = None + self._reference_embeddings = { + "non_personal": np.array(non_personal_embeddings), + "personal": np.array(personal_embeddings), + } + + logger.info( + f"SkipDetector initialized with {len(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)} non-personal categories and {len(self.PERSONAL_CATEGORY_DESCRIPTIONS)} personal categories" + ) def validate_message_size(self, message: str, max_message_chars: int) -> Optional[str]: """Validate message size constraints.""" @@ -615,7 +614,7 @@ class SkipDetector: return None - def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]: + async def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]: """ Detect if a message should be skipped using two-stage detection: 1. Fast-path structural patterns (~95% confidence) @@ -633,30 +632,25 @@ class SkipDetector: return self.SkipReason.SKIP_NON_PERSONAL.value if self._reference_embeddings is None: - logger.warning("SkipDetector reference embeddings not initialized, allowing message through") - return None + await self.initialize() - try: - message_embedding = np.array(self.embedding_function([message.strip()])[0]) + message_embedding_result = await self.embedding_function([message.strip()]) + message_embedding = np.array(message_embedding_result[0]) - personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T) - max_personal_similarity = float(personal_similarities.max()) + personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T) + max_personal_similarity = float(personal_similarities.max()) - non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T) - max_non_personal_similarity = float(non_personal_similarities.max()) + non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T) + max_non_personal_similarity = float(non_personal_similarities.max()) - margin = memory_system.valves.skip_category_margin - threshold = max_personal_similarity + margin - if (max_non_personal_similarity - max_personal_similarity) > margin: - logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})") - return self.SkipReason.SKIP_NON_PERSONAL.value + margin = memory_system.valves.skip_category_margin + threshold = max_personal_similarity + margin + if (max_non_personal_similarity - max_personal_similarity) > margin: + logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})") + return self.SkipReason.SKIP_NON_PERSONAL.value - logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})") - return None - - except Exception as e: - logger.error(f"Error in semantic skip detection: {e}") - return None + logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})") + return None class LLMRerankingService: @@ -1181,6 +1175,8 @@ class Filter: self._embedding_dimension = None self._skip_detector = None + self._initialization_lock = asyncio.Lock() + self._llm_reranking_service = LLMRerankingService(self) self._llm_consolidation_service = LLMConsolidationService(self) @@ -1205,36 +1201,38 @@ class Filter: self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION logger.info(f"✅ Using OpenWebUI's embedding function") - self._detect_embedding_dimension() + if self._embedding_function and self._embedding_dimension is None: + async with self._initialization_lock: + if self._embedding_dimension is None: + await self._detect_embedding_dimension() - if self._skip_detector is None: - global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK - embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "") - embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "") - cache_key = f"{embedding_engine}:{embedding_model}" + if self._embedding_function and self._skip_detector is None: + global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK + embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "") + embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "") + cache_key = f"{embedding_engine}:{embedding_model}" - async with _SHARED_SKIP_DETECTOR_CACHE_LOCK: - if cache_key in _SHARED_SKIP_DETECTOR_CACHE: - logger.info(f"♻️ Reusing cached skip detector: {cache_key}") - self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] - else: - logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}") - embedding_fn = self._embedding_function - normalize_fn = self._normalize_embedding + async with _SHARED_SKIP_DETECTOR_CACHE_LOCK: + if cache_key in _SHARED_SKIP_DETECTOR_CACHE: + logger.info(f"♻️ Reusing cached skip detector: {cache_key}") + self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] + else: + logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}") + embedding_fn = self._embedding_function + normalize_fn = self._normalize_embedding - def embedding_wrapper( - texts: Union[str, List[str]], - ) -> Union[np.ndarray, List[np.ndarray]]: - result = embedding_fn(texts, prefix=None, user=None) - if isinstance(result, list): - if isinstance(result[0], list): - return [normalize_fn(emb) for emb in result] - return np.array([normalize_fn(result)]) - return normalize_fn(result) + async def embedding_wrapper( + texts: Union[str, List[str]], + ) -> Union[np.ndarray, List[np.ndarray]]: + result = await embedding_fn(texts, prefix=None, user=None) + if isinstance(result, list) and len(result) > 0 and isinstance(result[0], (list, np.ndarray)): + return [normalize_fn(emb) for emb in result] + return [normalize_fn(result if isinstance(result, (list, np.ndarray)) else [result])] - self._skip_detector = SkipDetector(embedding_wrapper) - _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector - logger.info(f"✅ Skip detector initialized and cached") + self._skip_detector = SkipDetector(embedding_wrapper) + await self._skip_detector.initialize() + _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector + logger.info(f"✅ Skip detector initialized and cached") def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str: """Truncate content with ellipsis if needed.""" @@ -1283,16 +1281,16 @@ class Filter: """Compute SHA256 hash for text caching.""" return hashlib.sha256(text.encode()).hexdigest() - def _detect_embedding_dimension(self) -> None: + async def _detect_embedding_dimension(self) -> None: """Detect embedding dimension by generating a test embedding.""" - try: - test_embedding = self._embedding_function(["dummy"], prefix=None, user=None) - if isinstance(test_embedding, list): - test_embedding = test_embedding[0] - self._embedding_dimension = np.squeeze(test_embedding).shape[0] - logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}") - except Exception as e: - raise RuntimeError(f"Failed to detect embedding dimension: {str(e)}") + test_embedding = await self._embedding_function("dummy", prefix=None, user=None) + + if isinstance(test_embedding, list) and len(test_embedding) > 0 and isinstance(test_embedding[0], (list, np.ndarray)): + test_embedding = test_embedding[0] + + emb_array = np.squeeze(np.array(test_embedding)) + self._embedding_dimension = emb_array.shape[0] if emb_array.ndim > 0 else 1 + logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}") def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray: """Normalize embedding vector and ensure 1D shape.""" @@ -1343,8 +1341,7 @@ class Filter: if uncached_texts: user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None - loop = asyncio.get_event_loop() - raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user) + raw_embeddings = await self._embedding_function(uncached_texts, prefix=None, user=user) if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0: if isinstance(raw_embeddings[0], (list, np.ndarray)): @@ -1369,14 +1366,14 @@ class Filter: logger.info(f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid") return result_embeddings - def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: - skip_reason = self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self) + async def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: + skip_reason = await self._skip_detector.detect_skip_reason(user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self) if skip_reason: status_key = SkipDetector.SkipReason(skip_reason) return True, SkipDetector.STATUS_MESSAGES[status_key] return False, "" - def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: + async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: """Extract user message and determine if memory operations should be skipped.""" messages = body["messages"] user_message = None @@ -1398,7 +1395,7 @@ class Filter: SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], ) - should_skip, skip_reason = self._should_skip_memory_operations(user_message) + should_skip, skip_reason = await self._should_skip_memory_operations(user_message) return user_message, should_skip, skip_reason async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List: @@ -1645,7 +1642,7 @@ class Filter: if not user_id: return body - user_message, should_skip, skip_reason = self._process_user_message(body) + user_message, should_skip, skip_reason = await self._process_user_message(body) skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message or "") await self._cache_manager.put(