Refactor Filter class to use async for pipeline context setup; implement locking mechanism for shared skip detector cache to enhance concurrency safety.

This commit is contained in:
mtayfur
2025-10-12 23:24:58 +03:00
parent 849dd71a01
commit 2deba4fb2c

View File

@@ -25,6 +25,7 @@ from fastapi import Request
logger = logging.getLogger("MemorySystem")
_SHARED_SKIP_DETECTOR_CACHE = {}
_SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()
class Constants:
"""Centralized configuration constants for the memory system."""
@@ -1062,7 +1063,7 @@ class Filter:
self._llm_reranking_service = LLMRerankingService(self)
self._llm_consolidation_service = LLMConsolidationService(self)
def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
__model__: Optional[str] = None, __request__: Optional[Request] = None) -> None:
"""Set pipeline context parameters to avoid duplication in inlet/outlet methods."""
if __event_emitter__:
@@ -1079,28 +1080,29 @@ class Filter:
logger.info(f"✅ Using OpenWebUI's embedding function")
if self._skip_detector is None:
global _SHARED_SKIP_DETECTOR_CACHE
global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
cache_key = f"{embedding_engine}:{embedding_model}"
if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
else:
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
embedding_fn = self._embedding_function
def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
result = embedding_fn(texts, prefix=None, user=None)
if isinstance(result, list):
if isinstance(result[0], list):
return [np.array(emb, dtype=np.float16) for emb in result]
async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
else:
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
embedding_fn = self._embedding_function
def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
result = embedding_fn(texts, prefix=None, user=None)
if isinstance(result, list):
if isinstance(result[0], list):
return [np.array(emb, dtype=np.float16) for emb in result]
return np.array(result, dtype=np.float16)
return np.array(result, dtype=np.float16)
return np.array(result, dtype=np.float16)
self._skip_detector = SkipDetector(embedding_wrapper)
_SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
logger.info(f"✅ Skip detector initialized and cached")
self._skip_detector = SkipDetector(embedding_wrapper)
_SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
logger.info(f"✅ Skip detector initialized and cached")
def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
@@ -1492,7 +1494,7 @@ class Filter:
**kwargs,
) -> Dict[str, Any]:
"""Simplified inlet processing for memory retrieval and injection."""
self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
user_id = __user__.get("id") if body and __user__ else None
if not user_id:
@@ -1540,7 +1542,7 @@ class Filter:
**kwargs,
) -> dict:
"""Simplified outlet processing for background memory consolidation."""
self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
user_id = __user__.get("id") if body and __user__ else None
if not user_id: