mirror of
https://github.com/mtayfur/openwebui-memory-system.git
synced 2026-01-22 06:51:01 +01:00
Refactor the Filter class so that pipeline context setup is asynchronous, and guard the shared skip detector cache with an asyncio lock to improve concurrency safety.
This commit is contained in:
@@ -25,6 +25,7 @@ from fastapi import Request
|
||||
logger = logging.getLogger("MemorySystem")
|
||||
|
||||
_SHARED_SKIP_DETECTOR_CACHE = {}
|
||||
_SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()
|
||||
|
||||
class Constants:
|
||||
"""Centralized configuration constants for the memory system."""
|
||||
@@ -1062,7 +1063,7 @@ class Filter:
|
||||
self._llm_reranking_service = LLMRerankingService(self)
|
||||
self._llm_consolidation_service = LLMConsolidationService(self)
|
||||
|
||||
def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
|
||||
async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
|
||||
__model__: Optional[str] = None, __request__: Optional[Request] = None) -> None:
|
||||
"""Set pipeline context parameters to avoid duplication in inlet/outlet methods."""
|
||||
if __event_emitter__:
|
||||
@@ -1079,28 +1080,29 @@ class Filter:
|
||||
logger.info(f"✅ Using OpenWebUI's embedding function")
|
||||
|
||||
if self._skip_detector is None:
|
||||
global _SHARED_SKIP_DETECTOR_CACHE
|
||||
global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
|
||||
embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
|
||||
embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
|
||||
cache_key = f"{embedding_engine}:{embedding_model}"
|
||||
|
||||
if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
|
||||
logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
|
||||
self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
|
||||
else:
|
||||
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
|
||||
embedding_fn = self._embedding_function
|
||||
def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
|
||||
result = embedding_fn(texts, prefix=None, user=None)
|
||||
if isinstance(result, list):
|
||||
if isinstance(result[0], list):
|
||||
return [np.array(emb, dtype=np.float16) for emb in result]
|
||||
async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
|
||||
if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
|
||||
logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
|
||||
self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
|
||||
else:
|
||||
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
|
||||
embedding_fn = self._embedding_function
|
||||
def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
|
||||
result = embedding_fn(texts, prefix=None, user=None)
|
||||
if isinstance(result, list):
|
||||
if isinstance(result[0], list):
|
||||
return [np.array(emb, dtype=np.float16) for emb in result]
|
||||
return np.array(result, dtype=np.float16)
|
||||
return np.array(result, dtype=np.float16)
|
||||
return np.array(result, dtype=np.float16)
|
||||
|
||||
self._skip_detector = SkipDetector(embedding_wrapper)
|
||||
_SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
|
||||
logger.info(f"✅ Skip detector initialized and cached")
|
||||
|
||||
self._skip_detector = SkipDetector(embedding_wrapper)
|
||||
_SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
|
||||
logger.info(f"✅ Skip detector initialized and cached")
|
||||
|
||||
|
||||
def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
|
||||
@@ -1492,7 +1494,7 @@ class Filter:
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Simplified inlet processing for memory retrieval and injection."""
|
||||
self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
|
||||
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
|
||||
|
||||
user_id = __user__.get("id") if body and __user__ else None
|
||||
if not user_id:
|
||||
@@ -1540,7 +1542,7 @@ class Filter:
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Simplified outlet processing for background memory consolidation."""
|
||||
self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
|
||||
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
|
||||
|
||||
user_id = __user__.get("id") if body and __user__ else None
|
||||
if not user_id:
|
||||
|
||||
Reference in New Issue
Block a user