Mirror of https://github.com/mtayfur/openwebui-memory-system.git, synced 2026-01-22 06:51:01 +01:00
Refactor Filter class to use async for pipeline context setup; implement locking mechanism for shared skip detector cache to enhance concurrency safety.
@@ -25,6 +25,7 @@ from fastapi import Request
 logger = logging.getLogger("MemorySystem")
 
 _SHARED_SKIP_DETECTOR_CACHE = {}
+_SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()
 
 class Constants:
     """Centralized configuration constants for the memory system."""
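A note on creating the lock at module import time: on Python 3.10+, asyncio.Lock no longer binds to an event loop at construction, so a module-level lock works with whatever loop later runs the filter (on older interpreters, constructing synchronization primitives outside a running loop could attach them to the wrong loop). A minimal check of that behavior:

import asyncio

LOCK = asyncio.Lock()  # constructed at import time, with no event loop running

async def critical_section():
    async with LOCK:  # on Python >= 3.10 the lock attaches to the running loop here
        return "ok"

print(asyncio.run(critical_section()))  # -> ok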
@@ -1062,7 +1063,7 @@ class Filter:
         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)
 
-    def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
+    async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
                                __model__: Optional[str] = None, __request__: Optional[Request] = None) -> None:
         """Set pipeline context parameters to avoid duplication in inlet/outlet methods."""
         if __event_emitter__:
@@ -1079,28 +1080,29 @@ class Filter:
             logger.info(f"✅ Using OpenWebUI's embedding function")
 
         if self._skip_detector is None:
-            global _SHARED_SKIP_DETECTOR_CACHE
+            global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
             embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
             embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
             cache_key = f"{embedding_engine}:{embedding_model}"
 
-            if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
-                logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
-                self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
-            else:
-                logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
-                embedding_fn = self._embedding_function
-                def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
-                    result = embedding_fn(texts, prefix=None, user=None)
-                    if isinstance(result, list):
-                        if isinstance(result[0], list):
-                            return [np.array(emb, dtype=np.float16) for emb in result]
-                        return np.array(result, dtype=np.float16)
-                    return np.array(result, dtype=np.float16)
-
-                self._skip_detector = SkipDetector(embedding_wrapper)
-                _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
-                logger.info(f"✅ Skip detector initialized and cached")
+            async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
+                if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
+                    logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
+                    self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
+                else:
+                    logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
+                    embedding_fn = self._embedding_function
+                    def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+                        result = embedding_fn(texts, prefix=None, user=None)
+                        if isinstance(result, list):
+                            if isinstance(result[0], list):
+                                return [np.array(emb, dtype=np.float16) for emb in result]
+                            return np.array(result, dtype=np.float16)
+                        return np.array(result, dtype=np.float16)
+
+                    self._skip_detector = SkipDetector(embedding_wrapper)
+                    _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
+                    logger.info(f"✅ Skip detector initialized and cached")
 
 
     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
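For context, the hunk above applies a standard asyncio pattern: the whole check-then-populate sequence on the shared dict is serialized under one asyncio.Lock, so two concurrent requests with the same cache key can no longer both miss the cache and build duplicate detectors. A minimal standalone sketch, with get_detector and build_detector as illustrative stand-ins rather than names from the repo:

import asyncio

_CACHE = {}
_CACHE_LOCK = asyncio.Lock()

async def build_detector(cache_key: str) -> object:
    await asyncio.sleep(0.01)  # stand-in for expensive construction (e.g. loading embeddings)
    return object()

async def get_detector(cache_key: str) -> object:
    # Holding the lock across both the membership test and the insert
    # closes the race window between "check" and "populate".
    async with _CACHE_LOCK:
        if cache_key not in _CACHE:
            _CACHE[cache_key] = await build_detector(cache_key)
        return _CACHE[cache_key]

async def main():
    a, b = await asyncio.gather(get_detector("engine:model"),
                                get_detector("engine:model"))
    assert a is b  # both concurrent callers get the same cached instance

asyncio.run(main())

One trade-off worth noting: a single global lock also serializes builds for unrelated keys; with only a handful of embedding-engine/model combinations that contention is negligible, which is presumably why the commit favors the simpler design.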
@@ -1492,7 +1494,7 @@ class Filter:
             **kwargs,
     ) -> Dict[str, Any]:
         """Simplified inlet processing for memory retrieval and injection."""
-        self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
 
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
@@ -1540,7 +1542,7 @@ class Filter:
             **kwargs,
     ) -> dict:
         """Simplified outlet processing for background memory consolidation."""
-        self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
 
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
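The two hunks above are the ripple effect of the locking change: `async with` is only legal inside a coroutine, so `_set_pipeline_context` becomes `async def`, and both call sites (`inlet` and `outlet`) must now `await` it. Forgetting the `await` would not raise at the call site; it would just create an un-awaited coroutine object and silently skip the setup. A toy illustration (class and method names here are hypothetical):

import asyncio

class Pipeline:
    def __init__(self):
        self.ready = False
        self._lock = asyncio.Lock()

    async def _set_context(self):
        async with self._lock:  # requires a coroutine, hence the async def
            self.ready = True

    async def inlet(self) -> bool:
        await self._set_context()  # correct: the setup actually runs
        # self._set_context()      # wrong: returns a coroutine that never runs
        #                          # (Python warns "coroutine was never awaited")
        return self.ready

print(asyncio.run(Pipeline().inlet()))  # -> True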