mirror of https://github.com/mtayfur/openwebui-memory-system.git
synced 2026-01-22 06:51:01 +01:00
Refactor SkipDetector to use a callable embedding function instead of SentenceTransformer; update requirements to remove unnecessary dependencies.
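A note on the new contract: SkipDetector no longer receives a SentenceTransformer instance; it accepts any callable mapping a string or list of strings to numpy vectors (the Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]] signature in the diff below). A minimal sketch of a compatible callable, assuming only that signature; dummy_embed is a hypothetical stand-in, not part of this repository:

    from typing import List, Union
    import numpy as np

    def dummy_embed(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
        """Toy embedder: one unit-length vector per input text."""
        batch = [texts] if isinstance(texts, str) else texts
        rng = np.random.default_rng(0)  # deterministic placeholder vectors
        vectors = [rng.standard_normal(8) for _ in batch]
        # Unit-normalize so downstream np.dot behaves like cosine similarity.
        vectors = [v / np.linalg.norm(v) for v in vectors]
        return vectors if isinstance(texts, list) else vectors[0]

    # detector = SkipDetector(dummy_embed)  # reference embeddings flow through the callable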
memory_system.py (179 changes)
@@ -15,19 +15,15 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import numpy as np
 
 from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError
-from sentence_transformers import SentenceTransformer
 
 from open_webui.utils.chat import generate_chat_completion
 from open_webui.models.users import Users
 from open_webui.routers.memories import Memories
 from fastapi import Request
 
-logging.getLogger("transformers").setLevel(logging.ERROR)
-logging.getLogger("sentence_transformers").setLevel(logging.ERROR)
-
 logger = logging.getLogger("MemorySystem")
 
-_SHARED_MODEL_CACHE = {}
+_SHARED_SKIP_DETECTOR_CACHE = {}
 
 class Constants:
     """Centralized configuration constants for the memory system."""
@@ -65,7 +61,6 @@ class Constants:
 
     # Default Models
     DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite"
-    DEFAULT_EMBEDDING_MODEL = "Alibaba-NLP/gte-multilingual-base"
 
 class Prompts:
     """Container for all LLM prompts used in the memory system."""
@@ -462,58 +457,46 @@ class SkipDetector:
         SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
     }
 
-    def __init__(self, embedding_model: SentenceTransformer):
-        """Initialize the skip detector with an embedding model and compute reference embeddings."""
-        self.embedding_model = embedding_model
+    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
+        """Initialize the skip detector with an embedding function and compute reference embeddings."""
+        self.embedding_function = embedding_function
         self._reference_embeddings = None
         self._initialize_reference_embeddings()
 
     def _initialize_reference_embeddings(self) -> None:
         """Compute and cache embeddings for category descriptions."""
         try:
-            technical_embeddings = self.embedding_model.encode(
-                self.TECHNICAL_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            technical_embeddings = self.embedding_function(
+                self.TECHNICAL_CATEGORY_DESCRIPTIONS
             )
 
-            instruction_embeddings = self.embedding_model.encode(
-                self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            instruction_embeddings = self.embedding_function(
+                self.INSTRUCTION_CATEGORY_DESCRIPTIONS
             )
 
-            pure_math_embeddings = self.embedding_model.encode(
-                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            pure_math_embeddings = self.embedding_function(
+                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS
            )
 
-            translation_embeddings = self.embedding_model.encode(
-                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            translation_embeddings = self.embedding_function(
+                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS
             )
 
-            grammar_embeddings = self.embedding_model.encode(
-                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            grammar_embeddings = self.embedding_function(
+                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS
             )
 
-            conversational_embeddings = self.embedding_model.encode(
-                self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS,
-                convert_to_tensor=True,
-                show_progress_bar=False
+            conversational_embeddings = self.embedding_function(
+                self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS
             )
 
             self._reference_embeddings = {
-                'technical': technical_embeddings,
-                'instruction': instruction_embeddings,
-                'pure_math': pure_math_embeddings,
-                'translation': translation_embeddings,
-                'grammar': grammar_embeddings,
-                'conversational': conversational_embeddings,
+                'technical': np.array(technical_embeddings),
+                'instruction': np.array(instruction_embeddings),
+                'pure_math': np.array(pure_math_embeddings),
+                'translation': np.array(translation_embeddings),
+                'grammar': np.array(grammar_embeddings),
+                'conversational': np.array(conversational_embeddings),
             }
 
             total_skip_categories = (
@@ -569,7 +552,6 @@ class SkipDetector:
                 parts = line[2:].split()
                 if parts and parts[0].isalnum():
                     actual_command_lines += 1
-            # Check for lines with embedded $ commands (e.g., "Run: $ command")
             elif '$ ' in line:
                 dollar_index = line.find('$ ')
                 if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'):
@@ -583,7 +565,6 @@ class SkipDetector:
             elif line.startswith('> ') and len(line) > 2:
                 pass
 
-        # Lowered threshold: even 1 command line with URL/pipe is technical
         if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']):
             return self.SkipReason.SKIP_TECHNICAL.value
         if actual_command_lines >= 3:
@@ -602,23 +583,19 @@ class SkipDetector:
         if markup_chars >= 6:
             if markup_chars / msg_len > 0.10:
                 return self.SkipReason.SKIP_TECHNICAL.value
-            # Special check for JSON-like structures (many nested braces)
-            # Even with low density, if we have lots of curly braces, it's likely JSON
             curly_count = message.count('{') + message.count('}')
             if curly_count >= 10:
                 return self.SkipReason.SKIP_TECHNICAL.value
 
         # Pattern 7: Structured nested content with colons (key: value patterns)
         line_count = message.count('\n')
-        if line_count >= 8:  # At least 8 lines
+        if line_count >= 8:
             lines = message.split('\n')
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                # Count lines with colon patterns (key: value or similar)
                 colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#'))
                 indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
 
-                # If most lines have colons and indentation, it's structured data
                 if (colon_lines / len(non_empty_lines) > 0.4 and
                     indented_lines / len(non_empty_lines) > 0.5):
                     return self.SkipReason.SKIP_TECHNICAL.value
@@ -631,7 +608,6 @@ class SkipDetector:
             markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>'))
             structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
 
-            # Require high markup presence or indented structure with technical keywords
             if markup_in_lines / len(non_empty_lines) > 0.3:
                 return self.SkipReason.SKIP_TECHNICAL.value
             elif structured_lines / len(non_empty_lines) > 0.6:
@@ -684,18 +660,12 @@ class SkipDetector:
             return None
 
         try:
-            from sentence_transformers import util
-
-            message_embedding = self.embedding_model.encode(
-                message.strip(),
-                convert_to_tensor=True,
-                show_progress_bar=False
-            )
+            message_embedding = np.array(self.embedding_function([message.strip()])[0])
 
-            conversational_similarities = util.cos_sim(
+            conversational_similarities = np.dot(
                 message_embedding,
-                self._reference_embeddings['conversational']
-            )[0]
+                self._reference_embeddings['conversational'].T
+            )
             max_conversational_similarity = float(conversational_similarities.max())
 
             skip_categories = [
@@ -707,10 +677,10 @@ class SkipDetector:
         ]
 
         for cat_key, skip_reason, descriptions in skip_categories:
-            similarities = util.cos_sim(
+            similarities = np.dot(
                 message_embedding,
-                self._reference_embeddings[cat_key]
-            )[0]
+                self._reference_embeddings[cat_key].T
+            )
             max_similarity = float(similarities.max())
 
             if max_similarity > Constants.SKIP_DETECTION_SIMILARITY_THRESHOLD:
@@ -1069,7 +1039,6 @@ class Filter:
         """Configuration valves for the Memory System."""
 
         model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
-        embedding_model: str = Field(default=Constants.DEFAULT_EMBEDDING_MODEL, description="Sentence transformer model for embeddings")
         max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
         max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
         semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
@@ -1079,7 +1048,7 @@ class Filter:
 
     def __init__(self):
         """Initialize the Memory System filter with production validation."""
-        global _SHARED_MODEL_CACHE
+        global _SHARED_SKIP_DETECTOR_CACHE
 
         self.valves = self.Valves()
         self._validate_system_configuration()
@@ -1088,21 +1057,8 @@ class Filter:
         self._background_tasks: set = set()
         self._shutdown_event = asyncio.Event()
 
-        model_key = self.valves.embedding_model
-
-        if model_key in _SHARED_MODEL_CACHE:
-            logger.info(f"♻️ Reusing cached embedding model: {model_key}")
-            self._model = _SHARED_MODEL_CACHE[model_key]["model"]
-            self._skip_detector = _SHARED_MODEL_CACHE[model_key]["skip_detector"]
-        else:
-            logger.info(f"🤖 Loading embedding model: {model_key} (cache has {len(_SHARED_MODEL_CACHE)} models)")
-            self._model = SentenceTransformer(self.valves.embedding_model, device="auto", trust_remote_code=True)
-            self._skip_detector = SkipDetector(self._model)
-            _SHARED_MODEL_CACHE[model_key] = {
-                "model": self._model,
-                "skip_detector": self._skip_detector
-            }
-            logger.info(f"✅ Embedding model and skip detector initialized and cached")
+        self._embedding_function = None
+        self._skip_detector = None
 
         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)
@@ -1118,6 +1074,35 @@ class Filter:
         self.__model__ = __model__
         if __request__:
             self.__request__ = __request__
+
+            if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'):
+                self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
+                logger.info(f"✅ Using OpenWebUI's embedding function")
+
+            if self._skip_detector is None:
+                global _SHARED_SKIP_DETECTOR_CACHE
+                embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
+                embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
+                cache_key = f"{embedding_engine}:{embedding_model}"
+
+                if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
+                    logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
+                    self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
+                else:
+                    logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
+                    embedding_fn = self._embedding_function
+
+                    def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+                        result = embedding_fn(texts, prefix=None, user=None)
+                        if isinstance(result, list):
+                            if isinstance(result[0], list):
+                                return [np.array(emb, dtype=np.float16) for emb in result]
+                            return np.array(result, dtype=np.float16)
+                        return np.array(result, dtype=np.float16)
+
+                    self._skip_detector = SkipDetector(embedding_wrapper)
+                    _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
+                    logger.info(f"✅ Skip detector initialized and cached")
 
     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
         """Truncate content with ellipsis if needed."""
@@ -1169,24 +1154,20 @@ class Filter:
         """Compute SHA256 hash for text caching."""
         return hashlib.sha256(text.encode()).hexdigest()
 
-    def _normalize_embedding(self, embedding: np.ndarray) -> np.ndarray:
+    def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray:
         """Normalize embedding vector."""
-        embedding = embedding.astype(np.float16)
+        if isinstance(embedding, list):
+            embedding = np.array(embedding, dtype=np.float16)
+        else:
+            embedding = embedding.astype(np.float16)
         norm = np.linalg.norm(embedding)
         return embedding / norm if norm > 0 else embedding
 
-    def _generate_embeddings_sync(self, model, texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
-        """Synchronous embedding generation for single text or batch."""
-        is_single = isinstance(texts, str)
-        input_texts = [texts] if is_single else texts
-
-        embeddings = model.encode(input_texts, convert_to_numpy=True, show_progress_bar=False)
-        normalized = [self._normalize_embedding(emb) for emb in embeddings]
-
-        return normalized[0] if is_single else normalized
-
     async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]:
-        """Unified embedding generation for single text or batch with optimized caching."""
+        """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function."""
+        if self._embedding_function is None:
+            raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")
+
         is_single = isinstance(texts, str)
         text_list = [texts] if is_single else texts
 
@@ -1219,8 +1200,24 @@ class Filter:
                 uncached_hashes.append(text_hash)
 
         if uncached_texts:
+            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None
+
             loop = asyncio.get_event_loop()
-            new_embeddings = await loop.run_in_executor(None, self._generate_embeddings_sync, self._model, uncached_texts)
+            raw_embeddings = await loop.run_in_executor(
+                None,
+                self._embedding_function,
+                uncached_texts,
+                None,
+                user
+            )
+
+            if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
+                if isinstance(raw_embeddings[0], list):
+                    new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings]
+                else:
+                    new_embeddings = [self._normalize_embedding(raw_embeddings)]
+            else:
+                new_embeddings = [self._normalize_embedding(raw_embeddings)]
 
             for j, embedding in enumerate(new_embeddings):
                 original_idx = uncached_indices[j]
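The embedding_wrapper added in the Filter hunk adapts OpenWebUI's embedding function, which the diff invokes as fn(texts, prefix=None, user=None), to the single-argument callable SkipDetector expects, coercing results to float16 arrays (halving memory versus float32 at a small precision cost). A self-contained sketch of the same adaptation; openwebui_style_embed is a hypothetical stand-in for request.app.state.EMBEDDING_FUNCTION, whose exact signature is assumed from the diff's call site:

    from typing import List, Optional, Union
    import numpy as np

    def openwebui_style_embed(texts, prefix: Optional[str] = None, user=None):
        # Stand-in: returns one nested list of floats per input text,
        # mimicking the call shape assumed by the diff.
        items = [texts] if isinstance(texts, str) else texts
        return [[float(len(t)), 1.0, 0.5] for t in items]

    def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
        result = openwebui_style_embed(texts, prefix=None, user=None)
        if isinstance(result, list):
            if isinstance(result[0], list):
                # Batch result: one nested list per text.
                return [np.array(emb, dtype=np.float16) for emb in result]
            return np.array(result, dtype=np.float16)
        return np.array(result, dtype=np.float16)

    batch = embedding_wrapper(["alpha", "beta"])  # list of two float16 arrays
    single = embedding_wrapper("gamma")           # a batch of one with this stand-in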