Refactor SkipDetector to use a callable embedding function instead of SentenceTransformer; update requirements to remove unnecessary dependencies.

mtayfur
2025-10-09 23:36:27 +03:00
parent 5c0ca1f4ab
commit 840d4c59ca
2 changed files with 89 additions and 92 deletions
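
At a glance: SkipDetector's dependency shrinks from a concrete SentenceTransformer instance to any callable that maps text(s) to embedding vector(s). A minimal sketch of the new contract, with a toy callable standing in for a real backend (the toy function is illustrative, not part of the commit):

import numpy as np
from typing import Callable, List, Union

# The shape SkipDetector now expects: str or list[str] in, vector(s) out.
EmbeddingFn = Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]

def toy_embedding_function(texts: Union[str, List[str]]) -> List[np.ndarray]:
    """Toy stand-in: any model or API returning one vector per input satisfies the contract."""
    batch = [texts] if isinstance(texts, str) else texts
    return [np.ones(8, dtype=np.float16) for _ in batch]

# SkipDetector(toy_embedding_function) now works without importing torch or transformers.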

View File

@@ -15,7 +15,6 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError
-from sentence_transformers import SentenceTransformer
 from open_webui.utils.chat import generate_chat_completion
 from open_webui.models.users import Users
@@ -23,11 +22,10 @@ from open_webui.routers.memories import Memories
 from fastapi import Request

 logging.getLogger("transformers").setLevel(logging.ERROR)
-logging.getLogger("sentence_transformers").setLevel(logging.ERROR)

 logger = logging.getLogger("MemorySystem")

-_SHARED_MODEL_CACHE = {}
+_SHARED_SKIP_DETECTOR_CACHE = {}

 class Constants:
     """Centralized configuration constants for the memory system."""
@@ -65,7 +63,6 @@ class Constants:
     # Default Models
     DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite"
-    DEFAULT_EMBEDDING_MODEL = "Alibaba-NLP/gte-multilingual-base"


 class Prompts:
     """Container for all LLM prompts used in the memory system."""
@@ -462,58 +459,46 @@ class SkipDetector:
         SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
     }

-    def __init__(self, embedding_model: SentenceTransformer):
-        """Initialize the skip detector with an embedding model and compute reference embeddings."""
-        self.embedding_model = embedding_model
+    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
+        """Initialize the skip detector with an embedding function and compute reference embeddings."""
+        self.embedding_function = embedding_function
         self._reference_embeddings = None
         self._initialize_reference_embeddings()

     def _initialize_reference_embeddings(self) -> None:
         """Compute and cache embeddings for category descriptions."""
         try:
-            technical_embeddings = self.embedding_model.encode(
-                self.TECHNICAL_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            technical_embeddings = self.embedding_function(
+                self.TECHNICAL_CATEGORY_DESCRIPTIONS
             )
-            instruction_embeddings = self.embedding_model.encode(
-                self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            instruction_embeddings = self.embedding_function(
+                self.INSTRUCTION_CATEGORY_DESCRIPTIONS
             )
-            pure_math_embeddings = self.embedding_model.encode(
-                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            pure_math_embeddings = self.embedding_function(
+                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS
             )
-            translation_embeddings = self.embedding_model.encode(
-                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            translation_embeddings = self.embedding_function(
+                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS
             )
-            grammar_embeddings = self.embedding_model.encode(
-                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            grammar_embeddings = self.embedding_function(
+                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS
             )
-            conversational_embeddings = self.embedding_model.encode(
-                self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            conversational_embeddings = self.embedding_function(
+                self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS
             )

             self._reference_embeddings = {
-                'technical': technical_embeddings,
-                'instruction': instruction_embeddings,
-                'pure_math': pure_math_embeddings,
-                'translation': translation_embeddings,
-                'grammar': grammar_embeddings,
-                'conversational': conversational_embeddings,
+                'technical': np.array(technical_embeddings),
+                'instruction': np.array(instruction_embeddings),
+                'pure_math': np.array(pure_math_embeddings),
+                'translation': np.array(translation_embeddings),
+                'grammar': np.array(grammar_embeddings),
+                'conversational': np.array(conversational_embeddings),
             }

             total_skip_categories = (
total_skip_categories = (
@@ -569,7 +554,6 @@ class SkipDetector:
                 parts = line[2:].split()
                 if parts and parts[0].isalnum():
                     actual_command_lines += 1
-            # Check for lines with embedded $ commands (e.g., "Run: $ command")
             elif '$ ' in line:
                 dollar_index = line.find('$ ')
                 if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'):
@@ -583,7 +567,6 @@ class SkipDetector:
             elif line.startswith('> ') and len(line) > 2:
                 pass
-        # Lowered threshold: even 1 command line with URL/pipe is technical
         if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']):
             return self.SkipReason.SKIP_TECHNICAL.value
         if actual_command_lines >= 3:
@@ -602,23 +585,19 @@ class SkipDetector:
         if markup_chars >= 6:
             if markup_chars / msg_len > 0.10:
                 return self.SkipReason.SKIP_TECHNICAL.value
-        # Special check for JSON-like structures (many nested braces)
-        # Even with low density, if we have lots of curly braces, it's likely JSON
         curly_count = message.count('{') + message.count('}')
         if curly_count >= 10:
             return self.SkipReason.SKIP_TECHNICAL.value
-        # Pattern 7: Structured nested content with colons (key: value patterns)
         line_count = message.count('\n')
-        if line_count >= 8:  # At least 8 lines
+        if line_count >= 8:
             lines = message.split('\n')
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                # Count lines with colon patterns (key: value or similar)
                 colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#'))
                 indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
-                # If most lines have colons and indentation, it's structured data
                 if (colon_lines / len(non_empty_lines) > 0.4 and
                         indented_lines / len(non_empty_lines) > 0.5):
                     return self.SkipReason.SKIP_TECHNICAL.value
@@ -631,7 +610,6 @@ class SkipDetector:
                 markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>'))
                 structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
-                # Require high markup presence or indented structure with technical keywords
                 if markup_in_lines / len(non_empty_lines) > 0.3:
                     return self.SkipReason.SKIP_TECHNICAL.value
                 elif structured_lines / len(non_empty_lines) > 0.6:
@@ -684,18 +662,12 @@ class SkipDetector:
             return None

         try:
-            from sentence_transformers import util
-
-            message_embedding = self.embedding_model.encode(
-                message.strip(),
-                convert_to_tensor=True,
-                show_progress_bar=False
-            )
+            message_embedding = np.array(self.embedding_function([message.strip()])[0])

-            conversational_similarities = util.cos_sim(
+            conversational_similarities = np.dot(
                 message_embedding,
-                self._reference_embeddings['conversational']
-            )[0]
+                self._reference_embeddings['conversational'].T
+            )
             max_conversational_similarity = float(conversational_similarities.max())

             skip_categories = [
@@ -707,10 +679,10 @@ class SkipDetector:
             ]

             for cat_key, skip_reason, descriptions in skip_categories:
-                similarities = util.cos_sim(
+                similarities = np.dot(
                     message_embedding,
-                    self._reference_embeddings[cat_key]
-                )[0]
+                    self._reference_embeddings[cat_key].T
+                )
                 max_similarity = float(similarities.max())
                 if max_similarity > Constants.SKIP_DETECTION_SIMILARITY_THRESHOLD:
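
A subtlety worth noting in the two hunks above: util.cos_sim computed cosine similarity directly, whereas np.dot is a raw dot product, which equals cosine similarity only when both operands are unit-normalized; whether that holds here depends on the backend behind OpenWebUI's embedding function. A standalone sketch of the equivalence, with made-up vectors:

import numpy as np

def normalize(v: np.ndarray) -> np.ndarray:
    norm = np.linalg.norm(v)
    return v / norm if norm > 0 else v

message = normalize(np.array([0.3, 0.4, 0.5]))
references = np.stack([normalize(np.array([0.3, 0.4, 0.5])),
                       normalize(np.array([0.9, 0.1, 0.0]))])

# On unit vectors, the dot product IS the cosine similarity.
similarities = np.dot(message, references.T)   # shape (2,)
print(float(similarities.max()))               # ~1.0 for the identical direction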
@@ -1069,7 +1041,6 @@ class Filter:
"""Configuration valves for the Memory System."""
model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
embedding_model: str = Field(default=Constants.DEFAULT_EMBEDDING_MODEL, description="Sentence transformer model for embeddings")
max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
@@ -1079,7 +1050,7 @@ class Filter:
     def __init__(self):
         """Initialize the Memory System filter with production validation."""
-        global _SHARED_MODEL_CACHE
+        global _SHARED_SKIP_DETECTOR_CACHE

         self.valves = self.Valves()
         self._validate_system_configuration()
@@ -1088,21 +1059,8 @@ class Filter:
         self._background_tasks: set = set()
         self._shutdown_event = asyncio.Event()

-        model_key = self.valves.embedding_model
-        if model_key in _SHARED_MODEL_CACHE:
-            logger.info(f"♻️ Reusing cached embedding model: {model_key}")
-            self._model = _SHARED_MODEL_CACHE[model_key]["model"]
-            self._skip_detector = _SHARED_MODEL_CACHE[model_key]["skip_detector"]
-        else:
-            logger.info(f"🤖 Loading embedding model: {model_key} (cache has {len(_SHARED_MODEL_CACHE)} models)")
-            self._model = SentenceTransformer(self.valves.embedding_model, device="auto", trust_remote_code=True)
-            self._skip_detector = SkipDetector(self._model)
-            _SHARED_MODEL_CACHE[model_key] = {
-                "model": self._model,
-                "skip_detector": self._skip_detector
-            }
-            logger.info(f"✅ Embedding model and skip detector initialized and cached")
+        self._embedding_function = None
+        self._skip_detector = None

         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)
@@ -1118,6 +1076,35 @@ class Filter:
             self.__model__ = __model__
         if __request__:
             self.__request__ = __request__
+            if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'):
+                self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
+                logger.info(f"✅ Using OpenWebUI's embedding function")
+            if self._skip_detector is None:
+                global _SHARED_SKIP_DETECTOR_CACHE
+                embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
+                embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
+                cache_key = f"{embedding_engine}:{embedding_model}"
+                if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
+                    logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
+                    self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
+                else:
+                    logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
+                    embedding_fn = self._embedding_function
+
+                    def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+                        result = embedding_fn(texts, prefix=None, user=None)
+                        if isinstance(result, list):
+                            if isinstance(result[0], list):
+                                return [np.array(emb, dtype=np.float16) for emb in result]
+                            return np.array(result, dtype=np.float16)
+                        return np.array(result, dtype=np.float16)
+
+                    self._skip_detector = SkipDetector(embedding_wrapper)
+                    _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
+                    logger.info(f"✅ Skip detector initialized and cached")

     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
         """Truncate content with ellipsis if needed."""
@@ -1169,24 +1156,20 @@ class Filter:
"""Compute SHA256 hash for text caching."""
return hashlib.sha256(text.encode()).hexdigest()
def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray:
"""Normalize embedding vector."""
embedding = embedding.astype(np.float16)
if isinstance(embedding, list):
embedding = np.array(embedding, dtype=np.float16)
else:
embedding = embedding.astype(np.float16)
norm = np.linalg.norm(embedding)
return embedding / norm if norm > 0 else embedding
def _generate_embeddings_sync(self, model, texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
"""Synchronous embedding generation for single text or batch."""
is_single = isinstance(texts, str)
input_texts = [texts] if is_single else texts
embeddings = model.encode(input_texts, convert_to_numpy=True, show_progress_bar=False)
normalized = [self._normalize_embedding(emb) for emb in embeddings]
return normalized[0] if is_single else normalized
async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]:
"""Unified embedding generation for single text or batch with optimized caching."""
"""Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function."""
if self._embedding_function is None:
raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")
is_single = isinstance(texts, str)
text_list = [texts] if is_single else texts
@@ -1219,8 +1202,24 @@ class Filter:
                 uncached_hashes.append(text_hash)

         if uncached_texts:
+            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None
             loop = asyncio.get_event_loop()
-            new_embeddings = await loop.run_in_executor(None, self._generate_embeddings_sync, self._model, uncached_texts)
+            raw_embeddings = await loop.run_in_executor(
+                None,
+                self._embedding_function,
+                uncached_texts,
+                None,
+                user
+            )
+            if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
+                if isinstance(raw_embeddings[0], list):
+                    new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings]
+                else:
+                    new_embeddings = [self._normalize_embedding(raw_embeddings)]
+            else:
+                new_embeddings = [self._normalize_embedding(raw_embeddings)]

             for j, embedding in enumerate(new_embeddings):
                 original_idx = uncached_indices[j]

View File

@@ -1,7 +1,5 @@
 aiohttp>=3.12.15
 pydantic>=2.11.7
-sentence-transformers>=5.1.1
-torch>=2.8.0
-transformers>=4.57.0
+numpy>=2.0.0
 open-webui>=0.6.32
 tiktoken>=0.11.0
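
With sentence-transformers, torch, and transformers dropped, numpy becomes a direct requirement: previously it was presumably available only as a transitive dependency of the removed packages, and the skip detector's similarity math now calls it directly.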