diff --git a/memory_system.py b/memory_system.py
index 60d09ec..32adfe9 100644
--- a/memory_system.py
+++ b/memory_system.py
@@ -15,21 +15,22 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError
-
-from open_webui.utils.chat import generate_chat_completion
+from fastapi import Request
 from open_webui.models.users import Users
 from open_webui.routers.memories import Memories
-from fastapi import Request
+from open_webui.utils.chat import generate_chat_completion
+from pydantic import BaseModel, ConfigDict, Field
+from pydantic import ValidationError as PydanticValidationError
 
 logger = logging.getLogger(__name__)
 
 _SHARED_SKIP_DETECTOR_CACHE = {}
 _SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()
 
+
 class Constants:
     """Centralized configuration constants for the memory system."""
-
+
     # Core System Limits
     MAX_MEMORY_CONTENT_CHARS = 500  # Character limit for LLM prompt memory content
     MAX_MEMORIES_PER_RETRIEVAL = 10  # Maximum memories returned per query
@@ -37,31 +38,32 @@ class Constants:
     MIN_MESSAGE_CHARS = 10  # Minimum message length for validation
     DATABASE_OPERATION_TIMEOUT_SEC = 10  # Timeout for DB operations like user lookup
     LLM_CONSOLIDATION_TIMEOUT_SEC = 60.0  # Timeout for LLM consolidation operations
-
+
     # Cache System
     MAX_CACHE_ENTRIES_PER_TYPE = 500  # Maximum cache entries per cache type
     MAX_CONCURRENT_USER_CACHES = 50  # Maximum concurrent user cache instances
     CACHE_KEY_HASH_PREFIX_LENGTH = 10  # Hash prefix length for cache keys
-
+
     # Retrieval & Similarity
-    SEMANTIC_RETRIEVAL_THRESHOLD = 0.25  # Semantic similarity threshold for retrieval
-    RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8  # Multiplier for relaxed similarity threshold in secondary operations
-    EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6  # Multiplier for expanding memory candidates in advanced operations
-    LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8  # Multiplier for LLM reranking trigger threshold
-
+    SEMANTIC_RETRIEVAL_THRESHOLD = 0.25  # Semantic similarity threshold for retrieval
+    RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8  # Multiplier for relaxed similarity threshold in secondary operations
+    EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6  # Multiplier for expanding memory candidates in advanced operations
+    LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8  # Multiplier for LLM reranking trigger threshold
+
     # Skip Detection
     SKIP_CATEGORY_MARGIN = 0.5  # Margin above conversational similarity for skip category classification
 
     # Safety & Operations
     MAX_DELETE_OPERATIONS_RATIO = 0.6  # Maximum delete operations ratio for safety
     MIN_OPS_FOR_DELETE_RATIO_CHECK = 6  # Minimum operations to apply ratio check
-
+
     # Content Display
-    CONTENT_PREVIEW_LENGTH = 80  # Maximum length for content preview display
-
+    CONTENT_PREVIEW_LENGTH = 80  # Maximum length for content preview display
+
     # Default Models
     DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite"
 
+
 class Prompts:
     """Container for all LLM prompts used in the memory system."""
@@ -181,6 +183,7 @@
 Return: {{"ids": []}}
 Explanation: Query seeks general technical explanation without personal context. Job and family information don't affect how quantum computing concepts should be explained.
""" + class Models: """Container for all Pydantic models used in the memory system.""" @@ -203,7 +206,7 @@ class Models: class MemoryOperation(StrictModel): """Pydantic model for memory operations with validation.""" - operation: 'Models.MemoryOperationType' = Field(description="Type of memory operation to perform") + operation: "Models.MemoryOperationType" = Field(description="Type of memory operation to perform") content: str = Field(description="Memory content (required for CREATE/UPDATE, empty for DELETE)") id: str = Field(description="Memory ID (empty for CREATE, required for UPDATE/DELETE)") @@ -221,7 +224,7 @@ class Models: class ConsolidationResponse(BaseModel): """Pydantic model for memory consolidation LLM response - object containing array of memory operations.""" - ops: List['Models.MemoryOperation'] = Field(default_factory=list, description="List of memory operations to execute") + ops: List["Models.MemoryOperation"] = Field(default_factory=list, description="List of memory operations to execute") class MemoryRerankingResponse(BaseModel): """Pydantic model for memory reranking LLM response - object containing array of memory IDs.""" @@ -442,52 +445,42 @@ class SkipDetector: self.embedding_function = embedding_function self._reference_embeddings = None self._initialize_reference_embeddings() - + def _initialize_reference_embeddings(self) -> None: """Compute and cache embeddings for category descriptions.""" try: - technical_embeddings = self.embedding_function( - self.TECHNICAL_CATEGORY_DESCRIPTIONS - ) - - instruction_embeddings = self.embedding_function( - self.INSTRUCTION_CATEGORY_DESCRIPTIONS - ) - - pure_math_embeddings = self.embedding_function( - self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS - ) - - translation_embeddings = self.embedding_function( - self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS - ) - - grammar_embeddings = self.embedding_function( - self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS - ) - - conversational_embeddings = self.embedding_function( - self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS - ) - + technical_embeddings = self.embedding_function(self.TECHNICAL_CATEGORY_DESCRIPTIONS) + + instruction_embeddings = self.embedding_function(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) + + pure_math_embeddings = self.embedding_function(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) + + translation_embeddings = self.embedding_function(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) + + grammar_embeddings = self.embedding_function(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) + + conversational_embeddings = self.embedding_function(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS) + self._reference_embeddings = { - 'technical': np.array(technical_embeddings), - 'instruction': np.array(instruction_embeddings), - 'pure_math': np.array(pure_math_embeddings), - 'translation': np.array(translation_embeddings), - 'grammar': np.array(grammar_embeddings), - 'conversational': np.array(conversational_embeddings), + "technical": np.array(technical_embeddings), + "instruction": np.array(instruction_embeddings), + "pure_math": np.array(pure_math_embeddings), + "translation": np.array(translation_embeddings), + "grammar": np.array(grammar_embeddings), + "conversational": np.array(conversational_embeddings), } - + total_skip_categories = ( - len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) + - len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) + - len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) + - len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) + - 
-                len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
+                len(self.TECHNICAL_CATEGORY_DESCRIPTIONS)
+                + len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS)
+                + len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS)
+                + len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS)
+                + len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
             )
+
+            logger.info(
+                f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories"
+            )
-
-            logger.info(f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories")
         except Exception as e:
             logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}")
             self._reference_embeddings = None
@@ -504,108 +497,107 @@ class SkipDetector:
     def _fast_path_skip_detection(self, message: str) -> Optional[str]:
         """Language-agnostic structural pattern detection with high confidence and low false positive rate."""
         msg_len = len(message)
-
+
         # Pattern 1: Multiple URLs (5+ full URLs indicates link lists or technical references)
-        url_pattern_count = message.count('http://') + message.count('https://')
+        url_pattern_count = message.count("http://") + message.count("https://")
         if url_pattern_count >= 5:
             return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 2: Long unbroken alphanumeric strings (tokens, hashes, base64)
         words = message.split()
         for word in words:
             cleaned = word.strip('.,;:!?()[]{}"\'"')
-            if len(cleaned) > 80 and cleaned.replace('-', '').replace('_', '').isalnum():
+            if len(cleaned) > 80 and cleaned.replace("-", "").replace("_", "").isalnum():
                 return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 3: Markdown/text separators (repeated ---, ===, ___, ***)
-        separator_patterns = ['---', '===', '___', '***']
+        separator_patterns = ["---", "===", "___", "***"]
         for pattern in separator_patterns:
             if message.count(pattern) >= 2:
                 return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 4: Command-line patterns with context-aware detection
-        lines_stripped = [line.strip() for line in message.split('\n') if line.strip()]
+        lines_stripped = [line.strip() for line in message.split("\n") if line.strip()]
         if lines_stripped:
             actual_command_lines = 0
             for line in lines_stripped:
-                if line.startswith('$ ') and len(line) > 2:
+                if line.startswith("$ ") and len(line) > 2:
                     parts = line[2:].split()
                     if parts and parts[0].isalnum():
                         actual_command_lines += 1
-                elif '$ ' in line:
-                    dollar_index = line.find('$ ')
-                    if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'):
-                        parts = line[dollar_index+2:].split()
-                        if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ['curl', 'wget', 'git', 'npm', 'pip', 'docker']):
+                elif "$ " in line:
+                    dollar_index = line.find("$ ")
+                    if dollar_index > 0 and line[dollar_index - 1] in (" ", ":", "\t"):
+                        parts = line[dollar_index + 2 :].split()
+                        if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ["curl", "wget", "git", "npm", "pip", "docker"]):
                             actual_command_lines += 1
-                elif line.startswith('# ') and len(line) > 2:
+                elif line.startswith("# ") and len(line) > 2:
                     rest = line[2:].strip()
-                    if rest and not rest[0].isupper() and ' ' in rest:
+                    if rest and not rest[0].isupper() and " " in rest:
                         actual_command_lines += 1
-                elif line.startswith('> ') and len(line) > 2:
+                elif line.startswith("> ") and len(line) > 2:
                     pass
-
-            if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']):
+
+            if actual_command_lines >= 1 and any(c in message for c in ["http://", "https://", " | "]):
                 return self.SkipReason.SKIP_TECHNICAL.value
             if actual_command_lines >= 3:
                 return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 5: High path/URL density (dots and slashes suggesting file paths or URLs)
         if msg_len > 30:
-            slash_count = message.count('/') + message.count('\\')
-            dot_count = message.count('.')
+            slash_count = message.count("/") + message.count("\\")
+            dot_count = message.count(".")
             path_chars = slash_count + dot_count
             if path_chars > 10 and (path_chars / msg_len) > 0.15:
                 return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 6: Markup character density (structured data)
-        markup_chars = sum(message.count(c) for c in '{}[]<>')
+        markup_chars = sum(message.count(c) for c in "{}[]<>")
         if markup_chars >= 6:
             if markup_chars / msg_len > 0.10:
                 return self.SkipReason.SKIP_TECHNICAL.value
-            curly_count = message.count('{') + message.count('}')
+            curly_count = message.count("{") + message.count("}")
             if curly_count >= 10:
                 return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 7: Structured nested content with colons (key: value patterns)
-        line_count = message.count('\n')
+        line_count = message.count("\n")
         if line_count >= 8:
-            lines = message.split('\n')
+            lines = message.split("\n")
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#'))
-                indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
-
-                if (colon_lines / len(non_empty_lines) > 0.4 and
-                    indented_lines / len(non_empty_lines) > 0.5):
+                colon_lines = sum(1 for line in non_empty_lines if ":" in line and not line.strip().startswith("#"))
+                indented_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t")))
+
+                if colon_lines / len(non_empty_lines) > 0.4 and indented_lines / len(non_empty_lines) > 0.5:
                     return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 8: Highly structured multi-line content (require markup chars for technical confidence)
         if line_count > 15:
-            lines = message.split('\n')
+            lines = message.split("\n")
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>'))
-                structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
-
+                markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in "{}[]<>"))
+                structured_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t")))
+
                 if markup_in_lines / len(non_empty_lines) > 0.3:
                     return self.SkipReason.SKIP_TECHNICAL.value
                 elif structured_lines / len(non_empty_lines) > 0.6:
-                    technical_keywords = ['function', 'class', 'import', 'return', 'const', 'var', 'let', 'def']
+                    technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
                     if any(keyword in message.lower() for keyword in technical_keywords):
                         return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 9: Code-like indentation pattern (require code indicators to avoid false positives from bullet lists)
         if line_count >= 3:
-            lines = message.split('\n')
+            lines = message.split("\n")
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                indented_lines = sum(1 for line in non_empty_lines if line[0] in (' ', '\t'))
+                indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
                 if indented_lines / len(non_empty_lines) > 0.5:
-                    code_indicators = ['def ', 'class ', 'function ', 'return ', 'import ', 'const ', 'let ', 'var ', 'public ', 'private ']
+                    code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
                     if any(indicator in message.lower() for indicator in code_indicators):
                         return self.SkipReason.SKIP_TECHNICAL.value
-
+
         # Pattern 10: Very high special character ratio (encoded data, technical output)
         if msg_len > 50:
             special_chars = sum(1 for c in message if not c.isalnum() and not c.isspace())
@@ -614,10 +606,10 @@ class SkipDetector:
             alphanumeric = sum(1 for c in message if c.isalnum())
             if alphanumeric / msg_len < 0.50:
                 return self.SkipReason.SKIP_TECHNICAL.value
-
+
         return None
 
-    def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: 'Filter') -> Optional[str]:
+    def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
         """
         Detect if a message should be skipped using two-stage detection:
         1. Fast-path structural patterns (~95% confidence)
@@ -628,53 +620,49 @@ class SkipDetector:
         size_issue = self.validate_message_size(message, max_message_chars)
         if size_issue:
             return size_issue
-
+
         fast_skip = self._fast_path_skip_detection(message)
         if fast_skip:
             logger.info(f"Fast-path skip: {fast_skip}")
             return fast_skip
-
+
         if self._reference_embeddings is None:
             logger.warning("SkipDetector reference embeddings not initialized, allowing message through")
             return None
-
+
         try:
             message_embedding = np.array(self.embedding_function([message.strip()])[0])
-
-            conversational_similarities = np.dot(
-                message_embedding,
-                self._reference_embeddings['conversational'].T
-            )
+
+            conversational_similarities = np.dot(message_embedding, self._reference_embeddings["conversational"].T)
             max_conversational_similarity = float(conversational_similarities.max())
-
+
             skip_categories = [
-                ('instruction', self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
-                ('translation', self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
-                ('grammar', self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
-                ('technical', self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
-                ('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
+                ("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
+                ("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
+                ("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
+                ("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
+                ("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
             ]
-
+
             qualifying_categories = []
             margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN
-
+
             for cat_key, skip_reason, descriptions in skip_categories:
-                similarities = np.dot(
-                    message_embedding,
-                    self._reference_embeddings[cat_key].T
-                )
+                similarities = np.dot(message_embedding, self._reference_embeddings[cat_key].T)
                 max_similarity = float(similarities.max())
-
+
                 if max_similarity > margin_threshold:
                     qualifying_categories.append((max_similarity, cat_key, skip_reason))
-
+
             if qualifying_categories:
                 highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0])
-                logger.info(f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})")
+                logger.info(
+                    f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})"
+                )
                 return highest_skip_reason.value
-
+
             return None
-
+
         except Exception as e:
             logger.error(f"Error in semantic skip detection: {e}")
             return None
@@ -692,7 +680,7 @@ class LLMRerankingService:
         llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier)
 
         if len(memories) > llm_trigger_threshold:
-            return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"
+            return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"
 
         return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}"
@@ -717,33 +705,29 @@ CANDIDATE MEMORIES:
                     selected_memories.append(memory)
 
             logger.info(f"🧠 LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates")
-
+
             return selected_memories
 
         except Exception as e:
-            logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
+            logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
             return candidate_memories
 
-    async def rerank_memories(
-        self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None
-    ) -> Tuple[List[Dict], Dict[str, Any]]:
+    async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
         start_time = time.time()
         max_injection = self.memory_system.valves.max_memories_returned
 
         should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories)
-        analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}
+        analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}
 
         if should_use_llm:
             extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
             llm_candidates = candidate_memories[:extended_count]
-            await self.memory_system._emit_status(
-                emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False
-            )
+            await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
             logger.info(f"Using LLM reranking: {decision_reason}")
 
             selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)
-
+
             if not selected_memories:
                 logger.info("📭 No relevant memories after LLM analysis")
                 await self.memory_system._emit_status(emitter, f"📭 No Relevant Memories After LLM Analysis", done=True)
@@ -751,7 +735,7 @@ CANDIDATE MEMORIES:
         else:
             logger.info(f"Skipping LLM reranking: {decision_reason}")
             selected_memories = candidate_memories[:max_injection]
-
+
         duration = time.time() - start_time
         duration_text = f" in {duration:.2f}s" if duration >= 0.01 else ""
         retrieval_method = "LLM" if should_use_llm else "Semantic"
@@ -769,10 +753,10 @@ class LLMConsolidationService:
         """Filter consolidation candidates by threshold and return candidates with threshold info."""
         consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
         candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold]
-
+
         max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
         candidates = candidates[:max_consolidation_memories]
-
+
         threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})"
         return candidates, threshold_info
@@ -812,7 +796,7 @@ class LLMConsolidationService:
             candidates, threshold_info = self._filter_consolidation_candidates(all_similarities)
         else:
             candidates = []
-            threshold_info = 'N/A'
+            threshold_info = "N/A"
 
         logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})")
@@ -820,7 +804,9 @@ class LLMConsolidationService:
 
         return candidates
 
-    async def generate_consolidation_plan(self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None) -> List[Dict[str, Any]]:
+    async def generate_consolidation_plan(
+        self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
+    ) -> List[Dict[str, Any]]:
         """Generate consolidation plan using LLM with clear system/user prompt separation."""
         if candidate_memories:
             memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
@@ -925,7 +911,7 @@ class LLMConsolidationService:
                 results = await asyncio.gather(*batch_tasks, return_exceptions=True)
                 for idx, result in enumerate(results):
                     operation = ops[idx]
-
+
                     if isinstance(result, Exception):
                         failed_count += 1
                         await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
@@ -982,21 +968,21 @@ class LLMConsolidationService:
 
             if operations:
                 created_count, updated_count, deleted_count, failed_count = await self.execute_memory_operations(operations, user_id, emitter)
-
+
                 duration = time.time() - start_time
                 logger.info(f"💾 Memory Consolidation Complete In {duration:.2f}s")
-
+
                 total_operations = created_count + updated_count + deleted_count
                 if total_operations > 0 or failed_count > 0:
                     await self.memory_system._emit_status(emitter, f"💾 Memory Consolidation Complete in {duration:.2f}s", done=False)
-
+
                     operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count)
                     memory_word = "Memory" if total_operations == 1 else "Memories"
                     operations_summary = f"{', '.join(operation_details)} {memory_word}"
-
+
                     if failed_count > 0:
                         operations_summary += f" (❌ {failed_count} Failed)"
-
+
                     await self.memory_system._emit_status(emitter, operations_summary, done=True)
 
         except Exception as e:
@@ -1016,20 +1002,27 @@ class Filter:
         """Configuration valves for the Memory System."""
 
         model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
-
+
         max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
        max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
-
-        semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
-        relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)")
-
+
+        semantic_retrieval_threshold: float = Field(
+            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
+        )
+        relaxed_semantic_threshold_multiplier: float = Field(
+            default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
+            description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
+        )
+
         enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
-        llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)")
+        llm_reranking_trigger_multiplier: float = Field(
+            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
+        )
 
     def __init__(self):
         """Initialize the Memory System filter with production validation."""
         global _SHARED_SKIP_DETECTOR_CACHE
-
+
         self.valves = self.Valves()
         self._validate_system_configuration()
@@ -1043,8 +1036,13 @@ class Filter:
         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)
 
-    async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
-                                    __model__: Optional[str] = None, __request__: Optional[Request] = None) -> None:
+    async def _set_pipeline_context(
+        self,
+        __event_emitter__: Optional[Callable] = None,
+        __user__: Optional[Dict[str, Any]] = None,
+        __model__: Optional[str] = None,
+        __request__: Optional[Request] = None,
+    ) -> None:
         """Set pipeline context parameters to avoid duplication in inlet/outlet methods."""
         if __event_emitter__:
             self.__current_event_emitter__ = __event_emitter__
@@ -1054,17 +1052,17 @@ class Filter:
             self.__model__ = __model__
         if __request__:
             self.__request__ = __request__
-
-            if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'):
+
+            if self._embedding_function is None and hasattr(__request__.app.state, "EMBEDDING_FUNCTION"):
                 self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
                 logger.info(f"✅ Using OpenWebUI's embedding function")
-
+
             if self._skip_detector is None:
                 global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
-                embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
-                embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
+                embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
+                embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
                 cache_key = f"{embedding_engine}:{embedding_model}"
-
+
                 async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
                     if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
                         logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
@@ -1072,6 +1070,7 @@ class Filter:
                     else:
                         logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
                         embedding_fn = self._embedding_function
+
                         def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
                             result = embedding_fn(texts, prefix=None, user=None)
                             if isinstance(result, list):
@@ -1079,12 +1078,11 @@ class Filter:
                                     return [np.array(emb, dtype=np.float16) for emb in result]
                                 return np.array(result, dtype=np.float16)
                             return np.array(result, dtype=np.float16)
-
+
                         self._skip_detector = SkipDetector(embedding_wrapper)
                         _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
                         logger.info(f"✅ Skip detector initialized and cached")
 
-
     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
         """Truncate content with ellipsis if needed."""
         if max_length is None:
@@ -1148,7 +1146,7 @@ class Filter:
         """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function."""
         if self._embedding_function is None:
             raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")
-
+
         is_single = isinstance(texts, str)
         text_list = [texts] if is_single else texts
@@ -1181,17 +1179,11 @@ class Filter:
                 uncached_hashes.append(text_hash)
 
         if uncached_texts:
-            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None
-
+            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None
+
             loop = asyncio.get_event_loop()
-            raw_embeddings = await loop.run_in_executor(
-                None,
-                self._embedding_function,
-                uncached_texts,
-                None,
-                user
-            )
-
+            raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user)
+
             if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
                 if isinstance(raw_embeddings[0], list):
                     new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings]
@@ -1211,15 +1203,13 @@ class Filter:
             return result_embeddings[0]
         else:
             valid_count = sum(1 for emb in result_embeddings if emb is not None)
-            logger.info(
-                f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid"
-            )
+            logger.info(f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid")
             return result_embeddings
 
     def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
         if self._skip_detector is None:
             raise RuntimeError("🤖 Skip detector not initialized")
-
+
         skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self)
         if skip_reason:
             status_key = SkipDetector.SkipReason(skip_reason)
@@ -1290,7 +1280,7 @@ class Filter:
     def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
         """Unified cache key generation for all cache types."""
         if content:
-            content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
+            content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
             return f"{cache_type}_{user_id}:{content_hash}"
         return f"{cache_type}_{user_id}"
@@ -1310,7 +1300,7 @@ class Filter:
         if record_date:
             try:
                 if isinstance(record_date, str):
-                    parsed_date = datetime.fromisoformat(record_date.replace('Z', '+00:00'))
+                    parsed_date = datetime.fromisoformat(record_date.replace("Z", "+00:00"))
                 else:
                     parsed_date = record_date
                 formatted_date = parsed_date.strftime("%b %d %Y")
@@ -1393,14 +1383,14 @@ class Filter:
             memory_count = len(memories)
             memory_header = f"CONTEXT: The following {'fact' if memory_count == 1 else 'facts'} about the user are provided for background only. Not all facts may be relevant to the current request."
             formatted_memories = []
-
+
             for idx, memory in enumerate(memories, 1):
                 formatted_memory = f"- {' '.join(memory['content'].split())}"
                 formatted_memories.append(formatted_memory)
-
-                content_preview = self._truncate_content(memory['content'])
+
+                content_preview = self._truncate_content(memory["content"])
                 await self._emit_status(emitter, f"💭 {idx}/{memory_count}: {content_preview}", done=False)
-
+
             memory_footer = "IMPORTANT: Do not mention or imply you received this list. These facts are for background context only."
             memory_context_block = f"{memory_header}\n{chr(10).join(formatted_memories)}\n\n{memory_footer}"
             content_parts.append(memory_context_block)
@@ -1413,7 +1403,7 @@ class Filter:
             body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
         else:
             body["messages"].insert(0, {"role": "system", "content": memory_context})
-
+
         if memories and user_id:
             description = f"🧠 Injected {memory_count} {'Memory' if memory_count == 1 else 'Memories'} to Context"
             await self._emit_status(emitter, description, done=True)
@@ -1427,9 +1417,7 @@ class Filter:
             memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat()
         return memory_dict
 
-    async def _compute_similarities(
-        self, user_message: str, user_id: str, user_memories: List
-    ) -> Tuple[List[Dict], float, List[Dict]]:
+    async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], float, List[Dict]]:
         """Compute similarity scores between user message and memories."""
         if not user_memories:
             return [], self.valves.semantic_retrieval_threshold, []
@@ -1461,7 +1449,7 @@ class Filter:
         memory_data.sort(key=lambda x: x["relevance"], reverse=True)
 
         threshold = self.valves.semantic_retrieval_threshold
-        filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
+        filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
         return filtered_memories, threshold, memory_data
 
     async def inlet(
@@ -1536,9 +1524,7 @@ class Filter:
         cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
         cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
 
-        task = asyncio.create_task(
-            self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities)
-        )
+        task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
         self._background_tasks.add(task)
 
         def safe_cleanup(t: asyncio.Task) -> None:
@@ -1582,11 +1568,7 @@ class Filter:
 
         await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
 
-        memory_contents = [
-            memory.content
-            for memory in user_memories
-            if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS
-        ]
+        memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
 
         if memory_contents:
             await self._generate_embeddings(memory_contents, user_id)
@@ -1615,7 +1597,7 @@ class Filter:
         if not id_stripped:
             logger.warning(f"⚠️ Skipping UPDATE operation: empty ID")
             return Models.OperationResult.SKIPPED_EMPTY_ID.value
-
+
         content_stripped = operation.content.strip()
         if not content_stripped:
             logger.warning(f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content")
@@ -1649,18 +1631,18 @@ class Filter:
         """Remove $ref references and ensure required fields for Azure OpenAI."""
         if not isinstance(schema, dict):
             return schema
-
-        if '$ref' in schema:
-            ref_path = schema['$ref']
-            if ref_path.startswith('#/$defs/'):
-                def_name = ref_path.split('/')[-1]
+
+        if "$ref" in schema:
+            ref_path = schema["$ref"]
+            if ref_path.startswith("#/$defs/"):
+                def_name = ref_path.split("/")[-1]
                 if schema_defs and def_name in schema_defs:
                     return self._remove_refs_from_schema(schema_defs[def_name].copy(), schema_defs)
-            return {'type': 'object'}
-
+            return {"type": "object"}
+
         result = {}
         for key, value in schema.items():
-            if key == '$defs':
+            if key == "$defs":
                 continue
             elif isinstance(value, dict):
                 result[key] = self._remove_refs_from_schema(value, schema_defs)
@@ -1668,10 +1650,10 @@ class Filter:
                 result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value]
             else:
                 result[key] = value
-
-        if result.get('type') == 'object' and 'properties' in result:
-            result['required'] = list(result['properties'].keys())
-
+
+        if result.get("type") == "object" and "properties" in result:
+            result["required"] = list(result["properties"].keys())
+
         return result
 
     async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
@@ -1685,16 +1667,16 @@ class Filter:
 
         form_data = {
             "model": model_to_use,
-            "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+            "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
             "max_tokens": 4096,
             "stream": False,
         }
 
         if response_model:
             raw_schema = response_model.model_json_schema()
-            schema_defs = raw_schema.get('$defs', {})
+            schema_defs = raw_schema.get("$defs", {})
             schema = self._remove_refs_from_schema(raw_schema, schema_defs)
-            schema['type'] = 'object'
+            schema["type"] = "object"
             form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
 
         try:
@@ -1718,7 +1700,12 @@ class Filter:
         if isinstance(response_data, dict) and "choices" in response_data and isinstance(response_data["choices"], list) and len(response_data["choices"]) > 0:
             first_choice = response_data["choices"][0]
-            if isinstance(first_choice, dict) and "message" in first_choice and isinstance(first_choice["message"], dict) and "content" in first_choice["message"]:
+            if (
+                isinstance(first_choice, dict)
+                and "message" in first_choice
+                and isinstance(first_choice["message"], dict)
+                and "content" in first_choice["message"]
+            ):
                 content = first_choice["message"]["content"]
             else:
                 raise ValueError("🤖 Invalid response structure: missing content in message")
@@ -1726,7 +1713,7 @@ class Filter:
             raise ValueError(f"🤖 Unexpected LLM response format: {response_data}")
 
         if response_model:
-            try:
+            try:
                 parsed_data = json.loads(content)
                 return response_model.model_validate(parsed_data)
             except json.JSONDecodeError as e: