refactor: make skip detection and embedding operations fully async for improved concurrency

Skip detection and embedding-related methods are now asynchronous, so they
execute without blocking the event loop. Embedding-function wrappers and
initialization routines are updated to support async/await, and the shared
skip-detector cache is adapted accordingly. These changes are necessary for
compatibility with async embedding functions and improve scalability and
responsiveness under high concurrency.
This commit is contained in:
mtayfur
2025-11-24 15:45:03 +03:00
parent 960f8ce4a9
commit 3b84f64392

View File

@@ -4,6 +4,7 @@ description: A semantic memory management system for Open WebUI that consolidate
version: 1.0.0
authors: https://github.com/mtayfur
license: Apache-2.0
required_open_webui_version: 0.6.37
"""
import asyncio
@@ -459,29 +460,27 @@ class SkipDetector:
SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations",
}
def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Any]):
    """Initialize the skip detector with an (async) embedding function.

    Reference embeddings are deliberately NOT computed here: the embedding
    function must be awaited, so the work is deferred to the async
    ``initialize()`` coroutine. Call ``await detector.initialize()`` before
    first use.

    Args:
        embedding_function: Awaitable callable mapping a string or list of
            strings to embedding vector(s). Return shape is opaque here
            (``Any``); ``initialize()`` wraps results in ``np.array``.
    """
    self.embedding_function = embedding_function
    # None means "not yet initialized"; populated lazily by initialize().
    self._reference_embeddings = None
async def initialize(self) -> None:
    """Compute and cache reference embeddings for category descriptions.

    Idempotent: returns immediately when embeddings are already cached, so
    repeated calls (e.g. from multiple requests sharing a cached detector)
    do not re-embed the category descriptions.

    Side effects:
        Sets ``self._reference_embeddings`` to a dict with keys
        ``"non_personal"`` and ``"personal"`` holding ``np.ndarray`` matrices
        of category-description embeddings.
    """
    # Fast path: already initialized (shared/cached detector instance).
    if self._reference_embeddings is not None:
        return
    # Await the embedding function so the event loop is never blocked.
    non_personal_embeddings = await self.embedding_function(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)
    personal_embeddings = await self.embedding_function(self.PERSONAL_CATEGORY_DESCRIPTIONS)
    self._reference_embeddings = {
        "non_personal": np.array(non_personal_embeddings),
        "personal": np.array(personal_embeddings),
    }
    logger.info(
        f"SkipDetector initialized with {len(self.NON_PERSONAL_CATEGORY_DESCRIPTIONS)} non-personal categories and {len(self.PERSONAL_CATEGORY_DESCRIPTIONS)} personal categories"
    )
def validate_message_size(self, message: str, max_message_chars: int) -> Optional[str]:
"""Validate message size constraints."""
@@ -615,7 +614,7 @@ class SkipDetector:
return None
def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
async def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
"""
Detect if a message should be skipped using two-stage detection:
1. Fast-path structural patterns (~95% confidence)
@@ -633,30 +632,25 @@ class SkipDetector:
return self.SkipReason.SKIP_NON_PERSONAL.value
if self._reference_embeddings is None:
logger.warning("SkipDetector reference embeddings not initialized, allowing message through")
return None
await self.initialize()
try:
message_embedding = np.array(self.embedding_function([message.strip()])[0])
message_embedding_result = await self.embedding_function([message.strip()])
message_embedding = np.array(message_embedding_result[0])
personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T)
max_personal_similarity = float(personal_similarities.max())
personal_similarities = np.dot(message_embedding, self._reference_embeddings["personal"].T)
max_personal_similarity = float(personal_similarities.max())
non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T)
max_non_personal_similarity = float(non_personal_similarities.max())
non_personal_similarities = np.dot(message_embedding, self._reference_embeddings["non_personal"].T)
max_non_personal_similarity = float(non_personal_similarities.max())
margin = memory_system.valves.skip_category_margin
threshold = max_personal_similarity + margin
if (max_non_personal_similarity - max_personal_similarity) > margin:
logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})")
return self.SkipReason.SKIP_NON_PERSONAL.value
margin = memory_system.valves.skip_category_margin
threshold = max_personal_similarity + margin
if (max_non_personal_similarity - max_personal_similarity) > margin:
logger.info(f"🚫 Skipping: non-personal content (sim {max_non_personal_similarity:.3f} > {threshold:.3f})")
return self.SkipReason.SKIP_NON_PERSONAL.value
logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})")
return None
except Exception as e:
logger.error(f"Error in semantic skip detection: {e}")
return None
logger.info(f"✅ Allowing: personal content (sim {max_non_personal_similarity:.3f} <= {threshold:.3f})")
return None
class LLMRerankingService:
@@ -1181,6 +1175,8 @@ class Filter:
self._embedding_dimension = None
self._skip_detector = None
self._initialization_lock = asyncio.Lock()
self._llm_reranking_service = LLMRerankingService(self)
self._llm_consolidation_service = LLMConsolidationService(self)
@@ -1205,36 +1201,38 @@ class Filter:
self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
logger.info(f"✅ Using OpenWebUI's embedding function")
self._detect_embedding_dimension()
if self._embedding_function and self._embedding_dimension is None:
async with self._initialization_lock:
if self._embedding_dimension is None:
await self._detect_embedding_dimension()
if self._skip_detector is None:
global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
cache_key = f"{embedding_engine}:{embedding_model}"
if self._embedding_function and self._skip_detector is None:
global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
cache_key = f"{embedding_engine}:{embedding_model}"
async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
else:
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
embedding_fn = self._embedding_function
normalize_fn = self._normalize_embedding
async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
else:
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
embedding_fn = self._embedding_function
normalize_fn = self._normalize_embedding
async def embedding_wrapper(
    texts: Union[str, List[str]],
) -> Union[np.ndarray, List[np.ndarray]]:
    """Await the captured embedding function and normalize its output.

    NOTE(review): ``embedding_fn`` and ``normalize_fn`` are closure
    variables bound in the enclosing initialization method. Always returns
    a list of normalized vectors so callers can index ``result[0]``
    uniformly whether one or many texts were embedded.
    """
    result = await embedding_fn(texts, prefix=None, user=None)
    # Batch-shaped result: a non-empty list whose first element is itself
    # a vector (list or ndarray) — normalize each row.
    if isinstance(result, list) and len(result) > 0 and isinstance(result[0], (list, np.ndarray)):
        return [normalize_fn(emb) for emb in result]
    # Single vector (or bare scalar): wrap so the return stays list-like.
    return [normalize_fn(result if isinstance(result, (list, np.ndarray)) else [result])]
self._skip_detector = SkipDetector(embedding_wrapper)
_SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
logger.info(f"✅ Skip detector initialized and cached")
self._skip_detector = SkipDetector(embedding_wrapper)
await self._skip_detector.initialize()
_SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
logger.info(f"✅ Skip detector initialized and cached")
def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
"""Truncate content with ellipsis if needed."""
@@ -1283,16 +1281,16 @@ class Filter:
"""Compute SHA256 hash for text caching."""
return hashlib.sha256(text.encode()).hexdigest()
async def _detect_embedding_dimension(self) -> None:
    """Detect the embedding dimension by embedding a throwaway string.

    Stores the detected size in ``self._embedding_dimension``.

    Raises:
        RuntimeError: if the embedding function fails or returns an
            unusable result — nothing downstream works without a known
            dimension, so failure is surfaced rather than swallowed.
    """
    try:
        test_embedding = await self._embedding_function("dummy", prefix=None, user=None)
        # Batch-shaped result: unwrap the first vector before measuring.
        if isinstance(test_embedding, list) and len(test_embedding) > 0 and isinstance(test_embedding[0], (list, np.ndarray)):
            test_embedding = test_embedding[0]
        emb_array = np.squeeze(np.array(test_embedding))
        # A 0-d (scalar) embedding is treated as dimension 1.
        self._embedding_dimension = emb_array.shape[0] if emb_array.ndim > 0 else 1
        logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}")
    except Exception as e:
        # Chain the original cause for easier debugging.
        raise RuntimeError(f"Failed to detect embedding dimension: {str(e)}") from e
def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray:
"""Normalize embedding vector and ensure 1D shape."""
@@ -1343,8 +1341,7 @@ class Filter:
if uncached_texts:
user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None
loop = asyncio.get_event_loop()
raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user)
raw_embeddings = await self._embedding_function(uncached_texts, prefix=None, user=user)
if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
if isinstance(raw_embeddings[0], (list, np.ndarray)):
@@ -1369,14 +1366,14 @@ class Filter:
logger.info(f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid")
return result_embeddings
async def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
    """Return ``(should_skip, status_message)`` for a user message.

    Delegates to the skip detector's async two-stage detection; a truthy
    skip reason is mapped to its human-readable status message via
    ``SkipDetector.STATUS_MESSAGES``. Returns ``(False, "")`` when memory
    operations should proceed.
    """
    skip_reason = await self._skip_detector.detect_skip_reason(
        user_message, Constants.MAX_MESSAGE_CHARS, memory_system=self
    )
    if skip_reason:
        status_key = SkipDetector.SkipReason(skip_reason)
        return True, SkipDetector.STATUS_MESSAGES[status_key]
    return False, ""
def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
async def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
"""Extract user message and determine if memory operations should be skipped."""
messages = body["messages"]
user_message = None
@@ -1398,7 +1395,7 @@ class Filter:
SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
)
should_skip, skip_reason = self._should_skip_memory_operations(user_message)
should_skip, skip_reason = await self._should_skip_memory_operations(user_message)
return user_message, should_skip, skip_reason
async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List:
@@ -1645,7 +1642,7 @@ class Filter:
if not user_id:
return body
user_message, should_skip, skip_reason = self._process_user_message(body)
user_message, should_skip, skip_reason = await self._process_user_message(body)
skip_cache_key = self._cache_key(self._cache_manager.SKIP_STATE_CACHE, user_id, user_message or "")
await self._cache_manager.put(