diff --git a/memory_system.py b/memory_system.py index e04424d..0b32807 100644 --- a/memory_system.py +++ b/memory_system.py @@ -25,6 +25,7 @@ from fastapi import Request logger = logging.getLogger("MemorySystem") _SHARED_SKIP_DETECTOR_CACHE = {} +_SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock() class Constants: """Centralized configuration constants for the memory system.""" @@ -1062,7 +1063,7 @@ class Filter: self._llm_reranking_service = LLMRerankingService(self) self._llm_consolidation_service = LLMConsolidationService(self) - def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None, + async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None, __model__: Optional[str] = None, __request__: Optional[Request] = None) -> None: """Set pipeline context parameters to avoid duplication in inlet/outlet methods.""" if __event_emitter__: @@ -1079,28 +1080,29 @@ class Filter: logger.info(f"✅ Using OpenWebUI's embedding function") if self._skip_detector is None: - global _SHARED_SKIP_DETECTOR_CACHE + global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '') embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '') cache_key = f"{embedding_engine}:{embedding_model}" - if cache_key in _SHARED_SKIP_DETECTOR_CACHE: - logger.info(f"♻️ Reusing cached skip detector: {cache_key}") - self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] - else: - logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}") - embedding_fn = self._embedding_function - def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]: - result = embedding_fn(texts, prefix=None, user=None) - if isinstance(result, list): - if isinstance(result[0], list): - return [np.array(emb, dtype=np.float16) for emb in result] + async with _SHARED_SKIP_DETECTOR_CACHE_LOCK: + if cache_key in _SHARED_SKIP_DETECTOR_CACHE: + logger.info(f"♻️ Reusing cached skip detector: {cache_key}") + self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] + else: + logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}") + embedding_fn = self._embedding_function + def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]: + result = embedding_fn(texts, prefix=None, user=None) + if isinstance(result, list): + if isinstance(result[0], list): + return [np.array(emb, dtype=np.float16) for emb in result] + return np.array(result, dtype=np.float16) return np.array(result, dtype=np.float16) - return np.array(result, dtype=np.float16) - - self._skip_detector = SkipDetector(embedding_wrapper) - _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector - logger.info(f"✅ Skip detector initialized and cached") + + self._skip_detector = SkipDetector(embedding_wrapper) + _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector + logger.info(f"✅ Skip detector initialized and cached") def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str: @@ -1492,7 +1494,7 @@ class Filter: **kwargs, ) -> Dict[str, Any]: """Simplified inlet processing for memory retrieval and injection.""" - self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) + await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) user_id = __user__.get("id") if body and __user__ else None if not user_id: @@ -1540,7 +1542,7 @@ class Filter: **kwargs, ) -> dict: """Simplified outlet processing for background memory consolidation.""" - self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) + await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) user_id = __user__.get("id") if body and __user__ else None if not user_id: