Refactor Filter class to use async for pipeline context setup; implement a locking mechanism for the shared skip detector cache to improve concurrency safety.

Author: mtayfur
Date: 2025-10-12 23:24:58 +03:00
parent 849dd71a01
commit 2deba4fb2c


@@ -25,6 +25,7 @@ from fastapi import Request
 logger = logging.getLogger("MemorySystem")
 
 _SHARED_SKIP_DETECTOR_CACHE = {}
+_SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()
 
 class Constants:
     """Centralized configuration constants for the memory system."""
@@ -1062,7 +1063,7 @@ class Filter:
         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)
 
-    def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
+    async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
                               __model__: Optional[str] = None, __request__: Optional[Request] = None) -> None:
         """Set pipeline context parameters to avoid duplication in inlet/outlet methods."""
         if __event_emitter__:
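Making this method async means every call site must be updated to await it; the inlet/outlet hunks below do exactly that. A toy sketch, unrelated to the real Filter class, of why this matters: calling a newly-async method without await does not raise, it only creates a coroutine that never runs, so the context would silently stay unset.

    # Illustrative only; Demo and set_ctx are hypothetical names.
    import asyncio

    class Demo:
        def __init__(self) -> None:
            self.ctx = None

        async def set_ctx(self, value: str) -> None:
            self.ctx = value

    async def main() -> None:
        d = Demo()
        d.set_ctx("inlet")        # bug: emits RuntimeWarning, never executes
        print(d.ctx)              # None
        await d.set_ctx("inlet")  # correct call pattern after this commit
        print(d.ctx)              # inlet

    asyncio.run(main())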
@@ -1079,28 +1080,29 @@ class Filter:
             logger.info(f"✅ Using OpenWebUI's embedding function")
 
         if self._skip_detector is None:
-            global _SHARED_SKIP_DETECTOR_CACHE
+            global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
             embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
             embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
             cache_key = f"{embedding_engine}:{embedding_model}"
-            if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
-                logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
-                self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
-            else:
-                logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
-                embedding_fn = self._embedding_function
-                def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
-                    result = embedding_fn(texts, prefix=None, user=None)
-                    if isinstance(result, list):
-                        if isinstance(result[0], list):
-                            return [np.array(emb, dtype=np.float16) for emb in result]
-                        return np.array(result, dtype=np.float16)
-                    return np.array(result, dtype=np.float16)
-                self._skip_detector = SkipDetector(embedding_wrapper)
-                _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
-                logger.info(f"✅ Skip detector initialized and cached")
+            async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
+                if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
+                    logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
+                    self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
+                else:
+                    logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
+                    embedding_fn = self._embedding_function
+                    def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+                        result = embedding_fn(texts, prefix=None, user=None)
+                        if isinstance(result, list):
+                            if isinstance(result[0], list):
+                                return [np.array(emb, dtype=np.float16) for emb in result]
+                            return np.array(result, dtype=np.float16)
+                        return np.array(result, dtype=np.float16)
+                    self._skip_detector = SkipDetector(embedding_wrapper)
+                    _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
+                    logger.info(f"✅ Skip detector initialized and cached")
 
     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
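The hunk above is a lock-guarded get-or-create: the membership test and the insertion both happen while holding the lock, so two overlapping requests with the same cache key can no longer both build a SkipDetector. A self-contained sketch of the same pattern, with hypothetical names (build_detector, get_detector; the real code constructs SkipDetector synchronously inside the lock):

    import asyncio

    _cache: dict = {}
    _lock = asyncio.Lock()
    _builds = 0

    async def build_detector(key: str) -> object:
        global _builds
        _builds += 1
        await asyncio.sleep(0.01)  # stands in for expensive embedding setup
        return object()

    async def get_detector(key: str) -> object:
        async with _lock:
            if key not in _cache:                     # checked under the lock,
                _cache[key] = await build_detector(key)  # so one build per key
            return _cache[key]

    async def main() -> None:
        await asyncio.gather(*(get_detector("engine:model") for _ in range(10)))
        print(f"builds: {_builds}")  # 1, despite 10 concurrent callers

    asyncio.run(main())

A single global lock serializes all cache lookups, not just builds per key; for this use case (a handful of keys, cheap hits once cached) that trade-off keeps the code simple.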
@@ -1492,7 +1494,7 @@ class Filter:
         **kwargs,
     ) -> Dict[str, Any]:
         """Simplified inlet processing for memory retrieval and injection."""
-        self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
 
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
@@ -1540,7 +1542,7 @@ class Filter:
         **kwargs,
     ) -> dict:
         """Simplified outlet processing for background memory consolidation."""
-        self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
 
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
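These last two hunks are the call-site half of the refactor: inlet and outlet now await the shared async context setter, so the lock inside it can actually serialize overlapping requests. A toy sketch of the resulting call pattern (MiniFilter is hypothetical, not the plugin's API):

    import asyncio

    class MiniFilter:
        def __init__(self) -> None:
            self._lock = asyncio.Lock()
            self._context_sets = 0

        async def _set_pipeline_context(self) -> None:
            async with self._lock:      # overlapping callers take turns here
                self._context_sets += 1

        async def inlet(self, body: dict) -> dict:
            await self._set_pipeline_context()
            return body

        async def outlet(self, body: dict) -> dict:
            await self._set_pipeline_context()
            return body

    async def main() -> None:
        f = MiniFilter()
        await asyncio.gather(f.inlet({}), f.outlet({}))
        print(f._context_sets)  # 2: both calls ran, serialized by the lock

    asyncio.run(main())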