From 840d4c59ca16ed8383fa58ef5a3f0d1e7b29877a Mon Sep 17 00:00:00 2001
From: mtayfur
Date: Thu, 9 Oct 2025 23:36:27 +0300
Subject: [PATCH] Refactor SkipDetector to use a callable embedding function
 instead of SentenceTransformer; update requirements to remove unnecessary
 dependencies.
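
SkipDetector now receives its embedding model as a plain callable rather than
a SentenceTransformer instance, so it can run on OpenWebUI's configured
embedding backend and be exercised in isolation. A minimal sketch of the new
contract (the stub below is hypothetical and not part of this patch):

    import numpy as np

    def fake_embedding_function(texts):
        """Deterministic stand-in: one unit-length vector per input text."""
        rng = np.random.default_rng(0)
        batch = texts if isinstance(texts, list) else [texts]
        vecs = [rng.standard_normal(8) for _ in batch]
        return [v / np.linalg.norm(v) for v in vecs]

    detector = SkipDetector(fake_embedding_function)

Any callable with this shape works; in production, the wrapper added below
adapts app.state.EMBEDDING_FUNCTION to it.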

---
 memory_system.py | 177 +++++++++++++++++++++++------------------------
 requirements.txt |   4 +-
 2 files changed, 89 insertions(+), 92 deletions(-)

diff --git a/memory_system.py b/memory_system.py
index 6c582a4..356c56c 100644
--- a/memory_system.py
+++ b/memory_system.py
@@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError
 
-from sentence_transformers import SentenceTransformer
 
 from open_webui.utils.chat import generate_chat_completion
 from open_webui.models.users import Users
@@ -23,11 +22,10 @@ from open_webui.routers.memories import Memories
 from fastapi import Request
 
 logging.getLogger("transformers").setLevel(logging.ERROR)
-logging.getLogger("sentence_transformers").setLevel(logging.ERROR)
 
 logger = logging.getLogger("MemorySystem")
 
-_SHARED_MODEL_CACHE = {}
+_SHARED_SKIP_DETECTOR_CACHE = {}
 
 class Constants:
     """Centralized configuration constants for the memory system."""
@@ -65,7 +63,6 @@ class Constants:
 
     # Default Models
     DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite"
-    DEFAULT_EMBEDDING_MODEL = "Alibaba-NLP/gte-multilingual-base"
 
 class Prompts:
     """Container for all LLM prompts used in the memory system."""
@@ -462,58 +459,46 @@ class SkipDetector:
         SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
     }
 
-    def __init__(self, embedding_model: SentenceTransformer):
-        """Initialize the skip detector with an embedding model and compute reference embeddings."""
-        self.embedding_model = embedding_model
+    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
+        """Initialize the skip detector with an embedding function and compute reference embeddings."""
+        self.embedding_function = embedding_function
         self._reference_embeddings = None
         self._initialize_reference_embeddings()
 
     def _initialize_reference_embeddings(self) -> None:
        """Compute and cache embeddings for category descriptions."""
         try:
-            technical_embeddings = self.embedding_model.encode(
-                self.TECHNICAL_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            technical_embeddings = self.embedding_function(
+                self.TECHNICAL_CATEGORY_DESCRIPTIONS
             )
 
-            instruction_embeddings = self.embedding_model.encode(
-                self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            instruction_embeddings = self.embedding_function(
+                self.INSTRUCTION_CATEGORY_DESCRIPTIONS
             )
 
-            pure_math_embeddings = self.embedding_model.encode(
-                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            pure_math_embeddings = self.embedding_function(
+                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS
             )
 
-            translation_embeddings = self.embedding_model.encode(
-                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            translation_embeddings = self.embedding_function(
+                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS
             )
 
-            grammar_embeddings = self.embedding_model.encode(
-                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            grammar_embeddings = self.embedding_function(
+                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS
             )
 
-            conversational_embeddings = self.embedding_model.encode(
-                self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            conversational_embeddings = self.embedding_function(
+                self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS
             )
 
             self._reference_embeddings = {
-                'technical': technical_embeddings,
-                'instruction': instruction_embeddings,
-                'pure_math': pure_math_embeddings,
-                'translation': translation_embeddings,
-                'grammar': grammar_embeddings,
-                'conversational': conversational_embeddings,
+                'technical': np.array(technical_embeddings),
+                'instruction': np.array(instruction_embeddings),
+                'pure_math': np.array(pure_math_embeddings),
+                'translation': np.array(translation_embeddings),
+                'grammar': np.array(grammar_embeddings),
+                'conversational': np.array(conversational_embeddings),
             }
 
             total_skip_categories = (
@@ -569,7 +554,6 @@
                     parts = line[2:].split()
                     if parts and parts[0].isalnum():
                         actual_command_lines += 1
-                # Check for lines with embedded $ commands (e.g., "Run: $ command")
                 elif '$ ' in line:
                     dollar_index = line.find('$ ')
                     if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'):
@@ -583,7 +567,6 @@
                 elif line.startswith('> ') and len(line) > 2:
                     pass
 
-        # Lowered threshold: even 1 command line with URL/pipe is technical
         if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']):
             return self.SkipReason.SKIP_TECHNICAL.value
         if actual_command_lines >= 3:
@@ -602,23 +585,19 @@
         if markup_chars >= 6:
             if markup_chars / msg_len > 0.10:
                 return self.SkipReason.SKIP_TECHNICAL.value
-            # Special check for JSON-like structures (many nested braces)
-            # Even with low density, if we have lots of curly braces, it's likely JSON
             curly_count = message.count('{') + message.count('}')
             if curly_count >= 10:
                 return self.SkipReason.SKIP_TECHNICAL.value
 
         # Pattern 7: Structured nested content with colons (key: value patterns)
         line_count = message.count('\n')
-        if line_count >= 8:  # At least 8 lines
+        if line_count >= 8:
             lines = message.split('\n')
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                # Count lines with colon patterns (key: value or similar)
                 colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#'))
                 indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
 
-                # If most lines have colons and indentation, it's structured data
                 if (colon_lines / len(non_empty_lines) > 0.4 and
                     indented_lines / len(non_empty_lines) > 0.5):
                     return self.SkipReason.SKIP_TECHNICAL.value
@@ -631,7 +610,6 @@
                 markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>'))
                 structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
 
-                # Require high markup presence or indented structure with technical keywords
                 if markup_in_lines / len(non_empty_lines) > 0.3:
                     return self.SkipReason.SKIP_TECHNICAL.value
                 elif structured_lines / len(non_empty_lines) > 0.6:
@@ -684,18 +662,12 @@
             return None
 
         try:
-            from sentence_transformers import util
+            message_embedding = np.array(self.embedding_function([message.strip()])[0])
 
-            message_embedding = self.embedding_model.encode(
-                message.strip(),
-                convert_to_tensor=True,
-                show_progress_bar=False
-            )
-
-            conversational_similarities = util.cos_sim(
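+            # Vectors from the embedding wrapper are unit-normalized, so the plain
+            # dot products below match the cosine similarities util.cos_sim returned.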
+            conversational_similarities = np.dot(
                 message_embedding,
-                self._reference_embeddings['conversational']
-            )[0]
+                self._reference_embeddings['conversational'].T
+            )
             max_conversational_similarity = float(conversational_similarities.max())
 
             skip_categories = [
@@ -707,10 +679,10 @@
             ]
 
             for cat_key, skip_reason, descriptions in skip_categories:
-                similarities = util.cos_sim(
+                similarities = np.dot(
                     message_embedding,
-                    self._reference_embeddings[cat_key]
-                )[0]
+                    self._reference_embeddings[cat_key].T
+                )
                 max_similarity = float(similarities.max())
 
                 if max_similarity > Constants.SKIP_DETECTION_SIMILARITY_THRESHOLD:
@@ -1069,7 +1041,6 @@ class Filter:
         """Configuration valves for the Memory System."""
 
         model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
-        embedding_model: str = Field(default=Constants.DEFAULT_EMBEDDING_MODEL, description="Sentence transformer model for embeddings")
         max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
         max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
         semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
@@ -1079,7 +1050,7 @@
 
     def __init__(self):
         """Initialize the Memory System filter with production validation."""
-        global _SHARED_MODEL_CACHE
+        global _SHARED_SKIP_DETECTOR_CACHE
 
         self.valves = self.Valves()
         self._validate_system_configuration()
@@ -1088,21 +1059,8 @@
         self._background_tasks: set = set()
         self._shutdown_event = asyncio.Event()
 
-        model_key = self.valves.embedding_model
-
-        if model_key in _SHARED_MODEL_CACHE:
-            logger.info(f"♻️ Reusing cached embedding model: {model_key}")
-            self._model = _SHARED_MODEL_CACHE[model_key]["model"]
-            self._skip_detector = _SHARED_MODEL_CACHE[model_key]["skip_detector"]
-        else:
-            logger.info(f"🤖 Loading embedding model: {model_key} (cache has {len(_SHARED_MODEL_CACHE)} models)")
-            self._model = SentenceTransformer(self.valves.embedding_model, device="auto", trust_remote_code=True)
-            self._skip_detector = SkipDetector(self._model)
-            _SHARED_MODEL_CACHE[model_key] = {
-                "model": self._model,
-                "skip_detector": self._skip_detector
-            }
-            logger.info(f"✅ Embedding model and skip detector initialized and cached")
+        self._embedding_function = None
+        self._skip_detector = None
 
         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)
@@ -1118,6 +1076,35 @@
         self.__model__ = __model__
         if __request__:
             self.__request__ = __request__
+
+            if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'):
+                self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
+                logger.info("✅ Using OpenWebUI's embedding function")
+
+            if self._skip_detector is None:
+                global _SHARED_SKIP_DETECTOR_CACHE
+                embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
+                embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
+                cache_key = f"{embedding_engine}:{embedding_model}"
+
+                if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
+                    logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
+                    self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
+                else:
+                    logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
+                    embedding_fn = self._embedding_function
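+                    # embedding_fn is captured by the closure below, which is cached in
+                    # _SHARED_SKIP_DETECTOR_CACHE and reused across requests.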
+
+                    def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+                        """Adapt OpenWebUI's embedding function to the SkipDetector contract."""
+                        def to_unit(vec: np.ndarray) -> np.ndarray:
+                            # Unit-normalize so downstream dot products are cosine similarities.
+                            norm = np.linalg.norm(vec)
+                            return vec / norm if norm > 0 else vec
+
+                        result = embedding_fn(texts, prefix=None, user=None)
+                        if isinstance(result, list) and result and isinstance(result[0], list):
+                            # Batch input: one embedding per text.
+                            return [to_unit(np.array(emb, dtype=np.float16)) for emb in result]
+                        # Single embedding (flat list or array).
+                        return to_unit(np.array(result, dtype=np.float16))
+
+                    self._skip_detector = SkipDetector(embedding_wrapper)
+                    _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
+                    logger.info("✅ Skip detector initialized and cached")
+
 
     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
         """Truncate content with ellipsis if needed."""
@@ -1169,24 +1156,20 @@
         """Compute SHA256 hash for text caching."""
         return hashlib.sha256(text.encode()).hexdigest()
 
-    def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
+    def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray:
         """Normalize embedding vector."""
-        embedding = embedding.astype(np.float16)
+        if isinstance(embedding, list):
+            embedding = np.array(embedding, dtype=np.float16)
+        else:
+            embedding = embedding.astype(np.float16)
         norm = np.linalg.norm(embedding)
         return embedding / norm if norm > 0 else embedding
 
-    def _generate_embeddings_sync(self, model, texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
-        """Synchronous embedding generation for single text or batch."""
-        is_single = isinstance(texts, str)
-        input_texts = [texts] if is_single else texts
-
-        embeddings = model.encode(input_texts, convert_to_numpy=True, show_progress_bar=False)
-        normalized = [self._normalize_embedding(emb) for emb in embeddings]
-
-        return normalized[0] if is_single else normalized
-
     async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]:
-        """Unified embedding generation for single text or batch with optimized caching."""
+        """Generate embeddings for a single text or a batch, with caching, via OpenWebUI's embedding function."""
+        if self._embedding_function is None:
+            raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")
+
         is_single = isinstance(texts, str)
         text_list = [texts] if is_single else texts
 
@@ -1219,8 +1202,24 @@
                 uncached_hashes.append(text_hash)
 
         if uncached_texts:
+            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None
+
             loop = asyncio.get_event_loop()
-            new_embeddings = await loop.run_in_executor(None, self._generate_embeddings_sync, self._model, uncached_texts)
+            raw_embeddings = await loop.run_in_executor(
+                None,
+                self._embedding_function,
+                uncached_texts,
+                None,
+                user
+            )
+
+            if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
+                if isinstance(raw_embeddings[0], list):
+                    new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings]
+                else:
+                    new_embeddings = [self._normalize_embedding(raw_embeddings)]
+            else:
+                new_embeddings = [self._normalize_embedding(raw_embeddings)]
 
             for j, embedding in enumerate(new_embeddings):
                 original_idx = uncached_indices[j]
diff --git a/requirements.txt b/requirements.txt
index 7abc7cc..e8f6881 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,7 +1,5 @@
 aiohttp>=3.12.15
 pydantic>=2.11.7
-sentence-transformers>=5.1.1
-torch>=2.8.0
-transformers>=4.57.0
+numpy>=2.0.0
 open-webui>=0.6.32
 tiktoken>=0.11.0