diff --git a/memory_system.py b/memory_system.py index 60d09ec..d52cd2d 100644 --- a/memory_system.py +++ b/memory_system.py @@ -15,7 +15,12 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError as PydanticValidationError, +) from open_webui.utils.chat import generate_chat_completion from open_webui.models.users import Users @@ -27,9 +32,10 @@ logger = logging.getLogger(__name__) _SHARED_SKIP_DETECTOR_CACHE = {} _SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock() + class Constants: """Centralized configuration constants for the memory system.""" - + # Core System Limits MAX_MEMORY_CONTENT_CHARS = 500 # Character limit for LLM prompt memory content MAX_MEMORIES_PER_RETRIEVAL = 10 # Maximum memories returned per query @@ -37,31 +43,40 @@ class Constants: MIN_MESSAGE_CHARS = 10 # Minimum message length for validation DATABASE_OPERATION_TIMEOUT_SEC = 10 # Timeout for DB operations like user lookup LLM_CONSOLIDATION_TIMEOUT_SEC = 60.0 # Timeout for LLM consolidation operations - + # Cache System MAX_CACHE_ENTRIES_PER_TYPE = 500 # Maximum cache entries per cache type MAX_CONCURRENT_USER_CACHES = 50 # Maximum concurrent user cache instances CACHE_KEY_HASH_PREFIX_LENGTH = 10 # Hash prefix length for cache keys - + # Retrieval & Similarity - SEMANTIC_RETRIEVAL_THRESHOLD = 0.25 # Semantic similarity threshold for retrieval - RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8 # Multiplier for relaxed similarity threshold in secondary operations - EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6 # Multiplier for expanding memory candidates in advanced operations - LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8 # Multiplier for LLM reranking trigger threshold - + SEMANTIC_RETRIEVAL_THRESHOLD = 0.25 # Semantic similarity threshold for retrieval + RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = ( + 0.8 # Multiplier for relaxed similarity threshold in secondary operations + ) + EXTENDED_MAX_MEMORY_MULTIPLIER = ( + 1.6 # Multiplier for expanding memory candidates in advanced operations + ) + LLM_RERANKING_TRIGGER_MULTIPLIER = ( + 0.8 # Multiplier for LLM reranking trigger threshold + ) + # Skip Detection - SKIP_CATEGORY_MARGIN = 0.5 # Margin above conversational similarity for skip category classification + SKIP_CATEGORY_MARGIN = ( + 0.5 # Margin above conversational similarity for skip category classification + ) # Safety & Operations MAX_DELETE_OPERATIONS_RATIO = 0.6 # Maximum delete operations ratio for safety MIN_OPS_FOR_DELETE_RATIO_CHECK = 6 # Minimum operations to apply ratio check - + # Content Display - CONTENT_PREVIEW_LENGTH = 80 # Maximum length for content preview display - + CONTENT_PREVIEW_LENGTH = 80 # Maximum length for content preview display + # Default Models DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite" + class Prompts: """Container for all LLM prompts used in the memory system.""" @@ -181,6 +196,7 @@ Return: {{"ids": []}} Explanation: Query seeks general technical explanation without personal context. Job and family information don't affect how quantum computing concepts should be explained. 
""" + class Models: """Container for all Pydantic models used in the memory system.""" @@ -203,9 +219,15 @@ class Models: class MemoryOperation(StrictModel): """Pydantic model for memory operations with validation.""" - operation: 'Models.MemoryOperationType' = Field(description="Type of memory operation to perform") - content: str = Field(description="Memory content (required for CREATE/UPDATE, empty for DELETE)") - id: str = Field(description="Memory ID (empty for CREATE, required for UPDATE/DELETE)") + operation: "Models.MemoryOperationType" = Field( + description="Type of memory operation to perform" + ) + content: str = Field( + description="Memory content (required for CREATE/UPDATE, empty for DELETE)" + ) + id: str = Field( + description="Memory ID (empty for CREATE, required for UPDATE/DELETE)" + ) def validate_operation(self, existing_memory_ids: Optional[set] = None) -> bool: """Validate the memory operation against existing memory IDs.""" @@ -214,19 +236,27 @@ class Models: if self.operation == Models.MemoryOperationType.CREATE: return True - elif self.operation in [Models.MemoryOperationType.UPDATE, Models.MemoryOperationType.DELETE]: + elif self.operation in [ + Models.MemoryOperationType.UPDATE, + Models.MemoryOperationType.DELETE, + ]: return self.id in existing_memory_ids return False class ConsolidationResponse(BaseModel): """Pydantic model for memory consolidation LLM response - object containing array of memory operations.""" - ops: List['Models.MemoryOperation'] = Field(default_factory=list, description="List of memory operations to execute") + ops: List["Models.MemoryOperation"] = Field( + default_factory=list, description="List of memory operations to execute" + ) class MemoryRerankingResponse(BaseModel): """Pydantic model for memory reranking LLM response - object containing array of memory IDs.""" - ids: List[str] = Field(default_factory=list, description="List of memory IDs selected as most relevant for the user query") + ids: List[str] = Field( + default_factory=list, + description="List of memory IDs selected as most relevant for the user query", + ) class UnifiedCacheManager: @@ -274,7 +304,10 @@ class UnifiedCacheManager: type_cache = user_cache[cache_type] - if key not in type_cache and len(type_cache) >= self.max_cache_size_per_type: + if ( + key not in type_cache + and len(type_cache) >= self.max_cache_size_per_type + ): evicted_key, _ = type_cache.popitem(last=False) if key in type_cache: @@ -285,7 +318,9 @@ class UnifiedCacheManager: self.caches.move_to_end(user_id) - async def clear_user_cache(self, user_id: str, cache_type: Optional[str] = None) -> int: + async def clear_user_cache( + self, user_id: str, cache_type: Optional[str] = None + ) -> int: """Clear specific cache type for user, or all caches for user if cache_type is None.""" async with self._lock: if user_id not in self.caches: @@ -294,7 +329,9 @@ class UnifiedCacheManager: user_cache = self.caches[user_id] if cache_type is None: - total_cleared = sum(len(type_cache) for type_cache in user_cache.values()) + total_cleared = sum( + len(type_cache) for type_cache in user_cache.values() + ) del self.caches[user_id] return total_cleared else: @@ -437,187 +474,254 @@ class SkipDetector: SkipReason.SKIP_GRAMMAR_PROOFREAD: "๐Ÿ“ Grammar/Proofreading Request Detected, skipping memory operations", } - def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]): + def __init__( + self, + embedding_function: Callable[ + [Union[str, List[str]]], 
Union[np.ndarray, List[np.ndarray]] + ], + ): """Initialize the skip detector with an embedding function and compute reference embeddings.""" self.embedding_function = embedding_function self._reference_embeddings = None self._initialize_reference_embeddings() - + def _initialize_reference_embeddings(self) -> None: """Compute and cache embeddings for category descriptions.""" try: technical_embeddings = self.embedding_function( self.TECHNICAL_CATEGORY_DESCRIPTIONS ) - + instruction_embeddings = self.embedding_function( self.INSTRUCTION_CATEGORY_DESCRIPTIONS ) - + pure_math_embeddings = self.embedding_function( self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS ) - + translation_embeddings = self.embedding_function( self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS ) - + grammar_embeddings = self.embedding_function( self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS ) - + conversational_embeddings = self.embedding_function( self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS ) - + self._reference_embeddings = { - 'technical': np.array(technical_embeddings), - 'instruction': np.array(instruction_embeddings), - 'pure_math': np.array(pure_math_embeddings), - 'translation': np.array(translation_embeddings), - 'grammar': np.array(grammar_embeddings), - 'conversational': np.array(conversational_embeddings), + "technical": np.array(technical_embeddings), + "instruction": np.array(instruction_embeddings), + "pure_math": np.array(pure_math_embeddings), + "translation": np.array(translation_embeddings), + "grammar": np.array(grammar_embeddings), + "conversational": np.array(conversational_embeddings), } - + total_skip_categories = ( - len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) + - len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) + - len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) + - len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) + - len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) + len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) + + len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) + + len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) + + len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) + + len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) + ) + + logger.info( + f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories" ) - - logger.info(f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories") except Exception as e: logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}") self._reference_embeddings = None - def validate_message_size(self, message: str, max_message_chars: int) -> Optional[str]: + def validate_message_size( + self, message: str, max_message_chars: int + ) -> Optional[str]: """Validate message size constraints.""" if not message or not message.strip(): return SkipDetector.SkipReason.SKIP_SIZE.value trimmed = message.strip() - if len(trimmed) < Constants.MIN_MESSAGE_CHARS or len(trimmed) > max_message_chars: + if ( + len(trimmed) < Constants.MIN_MESSAGE_CHARS + or len(trimmed) > max_message_chars + ): return SkipDetector.SkipReason.SKIP_SIZE.value return None def _fast_path_skip_detection(self, message: str) -> Optional[str]: """Language-agnostic structural pattern detection with high confidence and low false positive rate.""" msg_len = len(message) - + # Pattern 1: Multiple URLs (5+ full URLs indicates link lists or technical references) - url_pattern_count = message.count('http://') + 
message.count('https://') + url_pattern_count = message.count("http://") + message.count("https://") if url_pattern_count >= 5: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 2: Long unbroken alphanumeric strings (tokens, hashes, base64) words = message.split() for word in words: cleaned = word.strip('.,;:!?()[]{}"\'"') - if len(cleaned) > 80 and cleaned.replace('-', '').replace('_', '').isalnum(): + if ( + len(cleaned) > 80 + and cleaned.replace("-", "").replace("_", "").isalnum() + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 3: Markdown/text separators (repeated ---, ===, ___, ***) - separator_patterns = ['---', '===', '___', '***'] + separator_patterns = ["---", "===", "___", "***"] for pattern in separator_patterns: if message.count(pattern) >= 2: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 4: Command-line patterns with context-aware detection - lines_stripped = [line.strip() for line in message.split('\n') if line.strip()] + lines_stripped = [line.strip() for line in message.split("\n") if line.strip()] if lines_stripped: actual_command_lines = 0 for line in lines_stripped: - if line.startswith('$ ') and len(line) > 2: + if line.startswith("$ ") and len(line) > 2: parts = line[2:].split() if parts and parts[0].isalnum(): actual_command_lines += 1 - elif '$ ' in line: - dollar_index = line.find('$ ') - if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'): - parts = line[dollar_index+2:].split() - if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ['curl', 'wget', 'git', 'npm', 'pip', 'docker']): + elif "$ " in line: + dollar_index = line.find("$ ") + if dollar_index > 0 and line[dollar_index - 1] in (" ", ":", "\t"): + parts = line[dollar_index + 2 :].split() + if ( + parts + and len(parts[0]) > 0 + and ( + parts[0].isalnum() + or parts[0] + in ["curl", "wget", "git", "npm", "pip", "docker"] + ) + ): actual_command_lines += 1 - elif line.startswith('# ') and len(line) > 2: + elif line.startswith("# ") and len(line) > 2: rest = line[2:].strip() - if rest and not rest[0].isupper() and ' ' in rest: + if rest and not rest[0].isupper() and " " in rest: actual_command_lines += 1 - elif line.startswith('> ') and len(line) > 2: + elif line.startswith("> ") and len(line) > 2: pass - - if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']): + + if actual_command_lines >= 1 and any( + c in message for c in ["http://", "https://", " | "] + ): return self.SkipReason.SKIP_TECHNICAL.value if actual_command_lines >= 3: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 5: High path/URL density (dots and slashes suggesting file paths or URLs) if msg_len > 30: - slash_count = message.count('/') + message.count('\\') - dot_count = message.count('.') + slash_count = message.count("/") + message.count("\\") + dot_count = message.count(".") path_chars = slash_count + dot_count if path_chars > 10 and (path_chars / msg_len) > 0.15: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 6: Markup character density (structured data) - markup_chars = sum(message.count(c) for c in '{}[]<>') + markup_chars = sum(message.count(c) for c in "{}[]<>") if markup_chars >= 6: if markup_chars / msg_len > 0.10: return self.SkipReason.SKIP_TECHNICAL.value - curly_count = message.count('{') + message.count('}') + curly_count = message.count("{") + message.count("}") if curly_count >= 10: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 7: Structured nested content with colons (key: value patterns) - 
line_count = message.count('\n') + line_count = message.count("\n") if line_count >= 8: - lines = message.split('\n') + lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#')) - indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t'))) - - if (colon_lines / len(non_empty_lines) > 0.4 and - indented_lines / len(non_empty_lines) > 0.5): + colon_lines = sum( + 1 + for line in non_empty_lines + if ":" in line and not line.strip().startswith("#") + ) + indented_lines = sum( + 1 for line in non_empty_lines if line.startswith((" ", "\t")) + ) + + if ( + colon_lines / len(non_empty_lines) > 0.4 + and indented_lines / len(non_empty_lines) > 0.5 + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 8: Highly structured multi-line content (require markup chars for technical confidence) if line_count > 15: - lines = message.split('\n') + lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>')) - structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t'))) - + markup_in_lines = sum( + 1 for line in non_empty_lines if any(c in line for c in "{}[]<>") + ) + structured_lines = sum( + 1 for line in non_empty_lines if line.startswith((" ", "\t")) + ) + if markup_in_lines / len(non_empty_lines) > 0.3: return self.SkipReason.SKIP_TECHNICAL.value elif structured_lines / len(non_empty_lines) > 0.6: - technical_keywords = ['function', 'class', 'import', 'return', 'const', 'var', 'let', 'def'] - if any(keyword in message.lower() for keyword in technical_keywords): + technical_keywords = [ + "function", + "class", + "import", + "return", + "const", + "var", + "let", + "def", + ] + if any( + keyword in message.lower() for keyword in technical_keywords + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 9: Code-like indentation pattern (require code indicators to avoid false positives from bullet lists) if line_count >= 3: - lines = message.split('\n') + lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - indented_lines = sum(1 for line in non_empty_lines if line[0] in (' ', '\t')) + indented_lines = sum( + 1 for line in non_empty_lines if line[0] in (" ", "\t") + ) if indented_lines / len(non_empty_lines) > 0.5: - code_indicators = ['def ', 'class ', 'function ', 'return ', 'import ', 'const ', 'let ', 'var ', 'public ', 'private '] - if any(indicator in message.lower() for indicator in code_indicators): + code_indicators = [ + "def ", + "class ", + "function ", + "return ", + "import ", + "const ", + "let ", + "var ", + "public ", + "private ", + ] + if any( + indicator in message.lower() for indicator in code_indicators + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 10: Very high special character ratio (encoded data, technical output) if msg_len > 50: - special_chars = sum(1 for c in message if not c.isalnum() and not c.isspace()) + special_chars = sum( + 1 for c in message if not c.isalnum() and not c.isspace() + ) special_ratio = special_chars / msg_len if special_ratio > 0.35: alphanumeric = sum(1 for c in message if c.isalnum()) if alphanumeric / msg_len < 0.50: return self.SkipReason.SKIP_TECHNICAL.value - + return None - def detect_skip_reason(self, message: str, 
max_message_chars: int, memory_system: 'Filter') -> Optional[str]: + def detect_skip_reason( + self, message: str, max_message_chars: int, memory_system: "Filter" + ) -> Optional[str]: """ Detect if a message should be skipped using two-stage detection: 1. Fast-path structural patterns (~95% confidence) 2. @@ -628,53 +732,79 @@ size_issue = self.validate_message_size(message, max_message_chars) if size_issue: return size_issue - + fast_skip = self._fast_path_skip_detection(message) if fast_skip: logger.info(f"Fast-path skip: {fast_skip}") return fast_skip - + if self._reference_embeddings is None: - logger.warning("SkipDetector reference embeddings not initialized, allowing message through") + logger.warning( + "SkipDetector reference embeddings not initialized, allowing message through" + ) return None - + try: message_embedding = np.array(self.embedding_function([message.strip()])[0]) - + conversational_similarities = np.dot( - message_embedding, - self._reference_embeddings['conversational'].T + message_embedding, self._reference_embeddings["conversational"].T ) max_conversational_similarity = float(conversational_similarities.max()) - + skip_categories = [ - ('instruction', self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS), - ('translation', self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS), - ('grammar', self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS), - ('technical', self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS), - ('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS), + ( + "instruction", + self.SkipReason.SKIP_INSTRUCTION, + self.INSTRUCTION_CATEGORY_DESCRIPTIONS, + ), + ( + "translation", + self.SkipReason.SKIP_TRANSLATION, + self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS, + ), + ( + "grammar", + self.SkipReason.SKIP_GRAMMAR_PROOFREAD, + self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS, + ), + ( + "technical", + self.SkipReason.SKIP_TECHNICAL, + self.TECHNICAL_CATEGORY_DESCRIPTIONS, + ), + ( + "pure_math", + self.SkipReason.SKIP_PURE_MATH, + self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS, + ), ] - + qualifying_categories = [] - margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN - + margin_threshold = ( + max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN + ) + for cat_key, skip_reason, descriptions in skip_categories: similarities = np.dot( - message_embedding, - self._reference_embeddings[cat_key].T + message_embedding, self._reference_embeddings[cat_key].T ) max_similarity = float(similarities.max()) - + if max_similarity > margin_threshold: qualifying_categories.append((max_similarity, cat_key, skip_reason)) - + if qualifying_categories: - highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0]) - logger.info(f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})") + highest_similarity, highest_cat_key, highest_skip_reason = max( + qualifying_categories, key=lambda x: x[0] + ) + logger.info( + f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})" + ) return highest_skip_reason.value - + return None - + except Exception as e: logger.error(f"Error in semantic skip detection: {e}") return 
None @@ -690,13 +820,28 @@ class LLMRerankingService: if not self.memory_system.valves.enable_llm_reranking: return False, "LLM reranking disabled" - llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier) + llm_trigger_threshold = int( + self.memory_system.valves.max_memories_returned + * self.memory_system.valves.llm_reranking_trigger_multiplier + ) if len(memories) > llm_trigger_threshold: - return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold" + return ( + True, + f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold", + ) - return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}" + return ( + False, + f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}", + ) - async def _llm_select_memories(self, user_message: str, candidate_memories: List[Dict], max_count: int, emitter: Optional[Callable] = None) -> List[Dict]: + async def _llm_select_memories( + self, + user_message: str, + candidate_memories: List[Dict], + max_count: int, + emitter: Optional[Callable] = None, + ) -> List[Dict]: """Use LLM to select most relevant memories.""" memory_lines = self.memory_system._format_memories_for_llm(candidate_memories) memory_context = "\n".join(memory_lines) @@ -709,53 +854,83 @@ CANDIDATE MEMORIES: {memory_context}""" try: - response = await self.memory_system._query_llm(Prompts.MEMORY_RERANKING, user_prompt, response_model=Models.MemoryRerankingResponse) + response = await self.memory_system._query_llm( + Prompts.MEMORY_RERANKING, + user_prompt, + response_model=Models.MemoryRerankingResponse, + ) selected_memories = [] for memory in candidate_memories: if memory["id"] in response.ids and len(selected_memories) < max_count: selected_memories.append(memory) - logger.info(f"🧠 LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates") - + logger.info( + f"🧠 LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates" + ) + return selected_memories except Exception as e: - logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}") + logger.warning( + f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}" + ) return candidate_memories async def rerank_memories( - self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None + self, + user_message: str, + candidate_memories: List[Dict], + emitter: Optional[Callable] = None, ) -> Tuple[List[Dict], Dict[str, Any]]: start_time = time.time() max_injection = self.memory_system.valves.max_memories_returned - should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories) + should_use_llm, decision_reason = self._should_use_llm_reranking( + candidate_memories + ) - analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)} + analysis_info = { + "llm_decision": should_use_llm, + "decision_reason": decision_reason, + "candidate_count": len(candidate_memories), + } if should_use_llm: - extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) + extended_count = int( + self.memory_system.valves.max_memories_returned + * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER + ) llm_candidates = candidate_memories[:extended_count] await self.memory_system._emit_status( - emitter, f"🤖 LLM 
Analyzing {len(llm_candidates)} Memories for Relevance", done=False + emitter, + f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", + done=False, ) logger.info(f"Using LLM reranking: {decision_reason}") - selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter) - + selected_memories = await self._llm_select_memories( + user_message, llm_candidates, max_injection, emitter + ) + if not selected_memories: logger.info("📭 No relevant memories after LLM analysis") - await self.memory_system._emit_status(emitter, f"📭 No Relevant Memories After LLM Analysis", done=True) + await self.memory_system._emit_status( + emitter, "📭 No Relevant Memories After LLM Analysis", done=True + ) return selected_memories, analysis_info else: logger.info(f"Skipping LLM reranking: {decision_reason}") selected_memories = candidate_memories[:max_injection] - + duration = time.time() - start_time duration_text = f" in {duration:.2f}s" if duration >= 0.01 else "" retrieval_method = "LLM" if should_use_llm else "Semantic" - await self.memory_system._emit_status(emitter, f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}", done=True) + await self.memory_system._emit_status( + emitter, + f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}", + done=True, + ) return selected_memories, analysis_info @@ -765,25 +940,43 @@ class LLMConsolidationService: def __init__(self, memory_system): self.memory_system = memory_system - def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]: + def _filter_consolidation_candidates( + self, similarities: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, Any]], str]: """Filter consolidation candidates by threshold and return candidates with threshold info.""" - consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) - candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold] - - max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) + consolidation_threshold = self.memory_system._get_retrieval_threshold( + is_consolidation=True + ) + candidates = [ + mem for mem in similarities if mem["relevance"] >= consolidation_threshold + ] + + max_consolidation_memories = int( + self.memory_system.valves.max_memories_returned + * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER + ) candidates = candidates[:max_consolidation_memories] - - threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" + + threshold_info = ( + f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" + ) return candidates, threshold_info async def collect_consolidation_candidates( - self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None + self, + user_message: str, + user_id: str, + cached_similarities: Optional[List[Dict[str, Any]]] = None, + ) -> List[Dict[str, Any]]: """Collect candidate memories for consolidation analysis using cached or computed similarities.""" if cached_similarities: - candidates, threshold_info = self._filter_consolidation_candidates(cached_similarities) + candidates, threshold_info = self._filter_consolidation_candidates( + cached_similarities + ) - logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})") + logger.info( + f"🎯 Found {len(candidates)} candidate memories for 
consolidation (threshold: {threshold_info})" + ) self.memory_system._log_retrieved_memories(candidates, "consolidation") return candidates @@ -791,7 +984,9 @@ class LLMConsolidationService: try: user_memories = await self.memory_system._get_user_memories(user_id) except asyncio.TimeoutError: - raise TimeoutError(f"โฑ๏ธ Memory retrieval timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s") + raise TimeoutError( + f"โฑ๏ธ Memory retrieval timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s" + ) except Exception as e: logger.error(f"๐Ÿ’พ Failed to retrieve user memories from database: {str(e)}") return [] @@ -800,30 +995,47 @@ class LLMConsolidationService: logger.info("๐Ÿ’ญ No existing memories found for consolidation") return [] - logger.info(f"๐Ÿš€ Reusing cached user memories for consolidation: {len(user_memories)} memories") + logger.info( + f"๐Ÿš€ Reusing cached user memories for consolidation: {len(user_memories)} memories" + ) try: - all_similarities, _, _ = await self.memory_system._compute_similarities(user_message, user_id, user_memories) + all_similarities, _, _ = await self.memory_system._compute_similarities( + user_message, user_id, user_memories + ) except Exception as e: - logger.error(f"๐Ÿ” Failed to compute memory similarities for retrieval: {str(e)}") + logger.error( + f"๐Ÿ” Failed to compute memory similarities for retrieval: {str(e)}" + ) return [] if all_similarities: - candidates, threshold_info = self._filter_consolidation_candidates(all_similarities) + candidates, threshold_info = self._filter_consolidation_candidates( + all_similarities + ) else: candidates = [] - threshold_info = 'N/A' + threshold_info = "N/A" - logger.info(f"๐ŸŽฏ Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})") + logger.info( + f"๐ŸŽฏ Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})" + ) self.memory_system._log_retrieved_memories(candidates, "consolidation") return candidates - async def generate_consolidation_plan(self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None) -> List[Dict[str, Any]]: + async def generate_consolidation_plan( + self, + user_message: str, + candidate_memories: List[Dict[str, Any]], + emitter: Optional[Callable] = None, + ) -> List[Dict[str, Any]]: """Generate consolidation plan using LLM with clear system/user prompt separation.""" if candidate_memories: - memory_lines = self.memory_system._format_memories_for_llm(candidate_memories) + memory_lines = self.memory_system._format_memories_for_llm( + candidate_memories + ) memory_context = f"EXISTING MEMORIES FOR CONSOLIDATION:\n{chr(10).join(memory_lines)}\n\n" else: memory_context = "EXISTING MEMORIES FOR CONSOLIDATION:\n[]\n\nNote: No existing memories found - Focus on extracting new memories from the user message below.\n\n" @@ -834,51 +1046,96 @@ class LLMConsolidationService: try: response = await asyncio.wait_for( - self.memory_system._query_llm(Prompts.MEMORY_CONSOLIDATION, user_prompt, response_model=Models.ConsolidationResponse), + self.memory_system._query_llm( + Prompts.MEMORY_CONSOLIDATION, + user_prompt, + response_model=Models.ConsolidationResponse, + ), timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC, ) except Exception as e: - logger.warning(f"๐Ÿค– LLM consolidation failed during memory processing: {str(e)}") - await self.memory_system._emit_status(emitter, f"โš ๏ธ Memory Consolidation Failed", done=True) + logger.warning( + f"๐Ÿค– LLM consolidation failed 
during memory processing: {str(e)}" + ) + await self.memory_system._emit_status( + emitter, f"โš ๏ธ Memory Consolidation Failed", done=True + ) return [] operations = response.ops existing_memory_ids = {memory["id"] for memory in candidate_memories} total_operations = len(operations) - delete_operations = [op for op in operations if op.operation == Models.MemoryOperationType.DELETE] - delete_ratio = len(delete_operations) / total_operations if total_operations > 0 else 0 + delete_operations = [ + op for op in operations if op.operation == Models.MemoryOperationType.DELETE + ] + delete_ratio = ( + len(delete_operations) / total_operations if total_operations > 0 else 0 + ) - if delete_ratio > Constants.MAX_DELETE_OPERATIONS_RATIO and total_operations >= Constants.MIN_OPS_FOR_DELETE_RATIO_CHECK: + if ( + delete_ratio > Constants.MAX_DELETE_OPERATIONS_RATIO + and total_operations >= Constants.MIN_OPS_FOR_DELETE_RATIO_CHECK + ): logger.warning( f"โš ๏ธ Consolidation safety: {len(delete_operations)}/{total_operations} operations are deletions ({delete_ratio*100:.1f}%) - rejecting plan" ) return [] - valid_operations = [op.model_dump() for op in operations if op.validate_operation(existing_memory_ids)] + valid_operations = [ + op.model_dump() + for op in operations + if op.validate_operation(existing_memory_ids) + ] if valid_operations: - create_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.CREATE.value) - update_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.UPDATE.value) - delete_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.DELETE.value) + create_count = sum( + 1 + for op in valid_operations + if op.get("operation") == Models.MemoryOperationType.CREATE.value + ) + update_count = sum( + 1 + for op in valid_operations + if op.get("operation") == Models.MemoryOperationType.UPDATE.value + ) + delete_count = sum( + 1 + for op in valid_operations + if op.get("operation") == Models.MemoryOperationType.DELETE.value + ) - operation_details = self.memory_system._build_operation_details(create_count, update_count, delete_count) + operation_details = self.memory_system._build_operation_details( + create_count, update_count, delete_count + ) - logger.info(f"๐ŸŽฏ Planned {len(valid_operations)} memory operations: {', '.join(operation_details)}") + logger.info( + f"๐ŸŽฏ Planned {len(valid_operations)} memory operations: {', '.join(operation_details)}" + ) else: logger.info("๐ŸŽฏ No valid memory operations planned") return valid_operations - async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]: + async def execute_memory_operations( + self, + operations: List[Dict[str, Any]], + user_id: str, + emitter: Optional[Callable] = None, + ) -> Tuple[int, int, int, int]: """Execute consolidation operations with simplified tracking.""" if not operations or not user_id: return 0, 0, 0, 0 try: - user = await asyncio.wait_for(asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC) + user = await asyncio.wait_for( + asyncio.to_thread(Users.get_user_by_id, user_id), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, + ) except asyncio.TimeoutError: - raise TimeoutError(f"โฑ๏ธ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s") + raise TimeoutError( + f"โฑ๏ธ User lookup timed out after 
{Constants.DATABASE_OPERATION_TIMEOUT_SEC}s" + ) except Exception as e: raise RuntimeError(f"๐Ÿ‘ค User lookup failed: {str(e)}") @@ -894,23 +1151,31 @@ class LLMConsolidationService: operations_by_type[operation.operation.value].append(operation) except Exception as e: failed_count += 1 - operation_type = operation_data.get("operation", Models.OperationResult.UNSUPPORTED.value) + operation_type = operation_data.get( + "operation", Models.OperationResult.UNSUPPORTED.value + ) content_preview = "" if "content" in operation_data: content = operation_data.get("content", "") content_preview = f" - Content: {self.memory_system._truncate_content(content, Constants.CONTENT_PREVIEW_LENGTH)}" elif "id" in operation_data: content_preview = f" - ID: {operation_data['id']}" - error_message = f"Failed {operation_type} operation{content_preview}: {str(e)}" + error_message = ( + f"Failed {operation_type} operation{content_preview}: {str(e)}" + ) logger.error(error_message) memory_contents_for_deletion = {} if operations_by_type["DELETE"]: try: user_memories = await self.memory_system._get_user_memories(user_id) - memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} + memory_contents_for_deletion = { + str(mem.id): mem.content for mem in user_memories + } except Exception as e: - logger.warning(f"โš ๏ธ Failed to fetch memories for DELETE preview: {str(e)}") + logger.warning( + f"โš ๏ธ Failed to fetch memories for DELETE preview: {str(e)}" + ) for operation_type, ops in operations_by_type.items(): if not ops: @@ -925,31 +1190,56 @@ class LLMConsolidationService: results = await asyncio.gather(*batch_tasks, return_exceptions=True) for idx, result in enumerate(results): operation = ops[idx] - + if isinstance(result, Exception): failed_count += 1 - await self.memory_system._emit_status(emitter, f"โŒ Failed {operation_type}", done=False) + await self.memory_system._emit_status( + emitter, f"โŒ Failed {operation_type}", done=False + ) elif result == Models.MemoryOperationType.CREATE.value: created_count += 1 - content_preview = self.memory_system._truncate_content(operation.content) - await self.memory_system._emit_status(emitter, f"๐Ÿ“ Created: {content_preview}", done=False) + content_preview = self.memory_system._truncate_content( + operation.content + ) + await self.memory_system._emit_status( + emitter, f"๐Ÿ“ Created: {content_preview}", done=False + ) elif result == Models.MemoryOperationType.UPDATE.value: updated_count += 1 - content_preview = self.memory_system._truncate_content(operation.content) - await self.memory_system._emit_status(emitter, f"โœ๏ธ Updated: {content_preview}", done=False) + content_preview = self.memory_system._truncate_content( + operation.content + ) + await self.memory_system._emit_status( + emitter, f"โœ๏ธ Updated: {content_preview}", done=False + ) elif result == Models.MemoryOperationType.DELETE.value: deleted_count += 1 - content_preview = memory_contents_for_deletion.get(operation.id, operation.id) + content_preview = memory_contents_for_deletion.get( + operation.id, operation.id + ) if content_preview and content_preview != operation.id: - content_preview = self.memory_system._truncate_content(content_preview) - await self.memory_system._emit_status(emitter, f"๐Ÿ—‘๏ธ Deleted: {content_preview}", done=False) - elif result in [Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value]: + content_preview = self.memory_system._truncate_content( + content_preview + ) + await self.memory_system._emit_status( + emitter, 
f"๐Ÿ—‘๏ธ Deleted: {content_preview}", done=False + ) + elif result in [ + Models.OperationResult.FAILED.value, + Models.OperationResult.UNSUPPORTED.value, + ]: failed_count += 1 - await self.memory_system._emit_status(emitter, f"โŒ Failed {operation_type}", done=False) + await self.memory_system._emit_status( + emitter, f"โŒ Failed {operation_type}", done=False + ) except Exception as e: failed_count += len(ops) - logger.error(f"โŒ Batch {operation_type} operations failed during memory consolidation: {str(e)}") - await self.memory_system._emit_status(emitter, f"โŒ Batch {operation_type} Failed", done=False) + logger.error( + f"โŒ Batch {operation_type} operations failed during memory consolidation: {str(e)}" + ) + await self.memory_system._emit_status( + emitter, f"โŒ Batch {operation_type} Failed", done=False + ) total_executed = created_count + updated_count + deleted_count logger.info( @@ -957,14 +1247,20 @@ class LLMConsolidationService: ) if total_executed > 0: - operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count) + operation_details = self.memory_system._build_operation_details( + created_count, updated_count, deleted_count + ) logger.info(f"๐Ÿ”„ Memory Operations: {', '.join(operation_details)}") await self.memory_system._refresh_user_cache(user_id) return created_count, updated_count, deleted_count, failed_count async def run_consolidation_pipeline( - self, user_message: str, user_id: str, emitter: Optional[Callable] = None, cached_similarities: Optional[List[Dict[str, Any]]] = None + self, + user_message: str, + user_id: str, + emitter: Optional[Callable] = None, + cached_similarities: Optional[List[Dict[str, Any]]] = None, ) -> None: """Complete consolidation pipeline with simplified flow.""" start_time = time.time() @@ -972,36 +1268,52 @@ class LLMConsolidationService: if self.memory_system._shutdown_event.is_set(): return - candidates = await self.collect_consolidation_candidates(user_message, user_id, cached_similarities) + candidates = await self.collect_consolidation_candidates( + user_message, user_id, cached_similarities + ) if self.memory_system._shutdown_event.is_set(): return - operations = await self.generate_consolidation_plan(user_message, candidates, emitter) + operations = await self.generate_consolidation_plan( + user_message, candidates, emitter + ) if self.memory_system._shutdown_event.is_set(): return if operations: - created_count, updated_count, deleted_count, failed_count = await self.execute_memory_operations(operations, user_id, emitter) - + created_count, updated_count, deleted_count, failed_count = ( + await self.execute_memory_operations(operations, user_id, emitter) + ) + duration = time.time() - start_time logger.info(f"๐Ÿ’พ Memory Consolidation Complete In {duration:.2f}s") - + total_operations = created_count + updated_count + deleted_count if total_operations > 0 or failed_count > 0: - await self.memory_system._emit_status(emitter, f"๐Ÿ’พ Memory Consolidation Complete in {duration:.2f}s", done=False) - - operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count) + await self.memory_system._emit_status( + emitter, + f"๐Ÿ’พ Memory Consolidation Complete in {duration:.2f}s", + done=False, + ) + + operation_details = self.memory_system._build_operation_details( + created_count, updated_count, deleted_count + ) memory_word = "Memory" if total_operations == 1 else "Memories" operations_summary = f"{', '.join(operation_details)} {memory_word}" 
- + if failed_count > 0: operations_summary += f" (❌ {failed_count} Failed)" - - await self.memory_system._emit_status(emitter, operations_summary, done=True) + + await self.memory_system._emit_status( + emitter, operations_summary, done=True + ) except Exception as e: duration = time.time() - start_time - raise RuntimeError(f"❌ Memory consolidation failed after {duration:.2f}s: {str(e)}") + raise RuntimeError( + f"❌ Memory consolidation failed after {duration:.2f}s: {str(e)}" + ) class Filter: @@ -1015,25 +1327,53 @@ class Filter: class Valves(BaseModel): """Configuration valves for the Memory System.""" - model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations") - - max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations") - max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context") - - semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval") - relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)") - - enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection") - llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)") + model: str = Field( + default=Constants.DEFAULT_LLM_MODEL, + description="Model name for LLM operations", + ) + use_custom_model_for_memory: bool = Field( + default=False, + description="Use a custom model for memory operations instead of the current chat model", + ) + custom_memory_model: str = Field( + default=Constants.DEFAULT_LLM_MODEL, + description="Custom model to use for memory operations when enabled", + ) + max_memories_returned: int = Field( + default=Constants.MAX_MEMORIES_PER_RETRIEVAL, + description="Maximum number of memories to return in context", + ) + max_message_chars: int = Field( + default=Constants.MAX_MESSAGE_CHARS, + description="Maximum user message length before skipping memory operations", + ) + semantic_retrieval_threshold: float = Field( + default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, + description="Minimum similarity threshold for memory retrieval", + ) + relaxed_semantic_threshold_multiplier: float = Field( + default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, + description="Adjusts similarity threshold for memory consolidation (lower = more candidates)", + ) + enable_llm_reranking: bool = Field( + default=True, + description="Enable LLM-based memory reranking for improved contextual selection", + ) + llm_reranking_trigger_multiplier: float = Field( + default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, + description="Controls when LLM reranking activates (lower = more aggressive)", + ) def __init__(self): """Initialize the Memory System filter with production validation.""" global _SHARED_SKIP_DETECTOR_CACHE - + self.valves = self.Valves() self._validate_system_configuration() - self._cache_manager = UnifiedCacheManager(Constants.MAX_CACHE_ENTRIES_PER_TYPE, Constants.MAX_CONCURRENT_USER_CACHES) + self._cache_manager = UnifiedCacheManager( + Constants.MAX_CACHE_ENTRIES_PER_TYPE, 
Constants.MAX_CONCURRENT_USER_CACHES + ) self._background_tasks: set = set() self._shutdown_event = asyncio.Event() @@ -1043,8 +1383,13 @@ class Filter: self._llm_reranking_service = LLMRerankingService(self) self._llm_consolidation_service = LLMConsolidationService(self) - async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None, - __model__: Optional[str] = None, __request__: Optional[Request] = None) -> None: + async def _set_pipeline_context( + self, + __event_emitter__: Optional[Callable] = None, + __user__: Optional[Dict[str, Any]] = None, + __model__: Optional[str] = None, + __request__: Optional[Request] = None, + ) -> None: """Set pipeline context parameters to avoid duplication in inlet/outlet methods.""" if __event_emitter__: self.__current_event_emitter__ = __event_emitter__ @@ -1054,37 +1399,50 @@ class Filter: self.__model__ = __model__ if __request__: self.__request__ = __request__ - - if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'): + + if self._embedding_function is None and hasattr( + __request__.app.state, "EMBEDDING_FUNCTION" + ): self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION logger.info(f"✅ Using OpenWebUI's embedding function") - + if self._skip_detector is None: global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK - embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '') - embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '') + embedding_engine = getattr( + __request__.app.state.config, "RAG_EMBEDDING_ENGINE", "" + ) + embedding_model = getattr( + __request__.app.state.config, "RAG_EMBEDDING_MODEL", "" + ) cache_key = f"{embedding_engine}:{embedding_model}" - + async with _SHARED_SKIP_DETECTOR_CACHE_LOCK: if cache_key in _SHARED_SKIP_DETECTOR_CACHE: logger.info(f"♻️ Reusing cached skip detector: {cache_key}") self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] else: - logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}") + logger.info( + f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}" + ) embedding_fn = self._embedding_function - def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]: + + def embedding_wrapper( + texts: Union[str, List[str]], + ) -> Union[np.ndarray, List[np.ndarray]]: result = embedding_fn(texts, prefix=None, user=None) if isinstance(result, list): if isinstance(result[0], list): - return [np.array(emb, dtype=np.float16) for emb in result] + return [ + np.array(emb, dtype=np.float16) + for emb in result + ] return np.array(result, dtype=np.float16) return np.array(result, dtype=np.float16) - + self._skip_detector = SkipDetector(embedding_wrapper) _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector logger.info(f"✅ Skip detector initialized and cached") - def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str: """Truncate content with ellipsis if needed.""" if max_length is None: @@ -1094,7 +1452,10 @@ class Filter: def _get_retrieval_threshold(self, is_consolidation: bool = False) -> float: """Calculate retrieval threshold for semantic similarity filtering.""" if is_consolidation: - return self.valves.semantic_retrieval_threshold * self.valves.relaxed_semantic_threshold_multiplier + return ( + self.valves.semantic_retrieval_threshold + * self.valves.relaxed_semantic_threshold_multiplier + ) return 
self.valves.semantic_retrieval_threshold def _extract_text_from_content(self, content) -> str: @@ -1116,26 +1477,36 @@ raise ValueError("🤖 Model not specified") if self.valves.max_memories_returned <= 0: - raise ValueError(f"📊 Invalid max memories returned: {self.valves.max_memories_returned}") + raise ValueError( + f"📊 Invalid max memories returned: {self.valves.max_memories_returned}" + ) if not (0.0 <= self.valves.semantic_retrieval_threshold <= 1.0): - raise ValueError(f"🎯 Invalid semantic retrieval threshold: {self.valves.semantic_retrieval_threshold} (must be 0.0-1.0)") + raise ValueError( + f"🎯 Invalid semantic retrieval threshold: {self.valves.semantic_retrieval_threshold} (must be 0.0-1.0)" + ) logger.info("✅ Configuration validated") async def _get_embedding_cache(self, user_id: str, key: str) -> Optional[Any]: """Get embedding from cache.""" - return await self._cache_manager.get(user_id, self._cache_manager.EMBEDDING_CACHE, key) + return await self._cache_manager.get( + user_id, self._cache_manager.EMBEDDING_CACHE, key + ) async def _put_embedding_cache(self, user_id: str, key: str, value: Any) -> None: """Store embedding in cache.""" - await self._cache_manager.put(user_id, self._cache_manager.EMBEDDING_CACHE, key, value) + await self._cache_manager.put( + user_id, self._cache_manager.EMBEDDING_CACHE, key, value + ) def _compute_text_hash(self, text: str) -> str: """Compute SHA256 hash for text caching.""" return hashlib.sha256(text.encode()).hexdigest() - def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray: + def _normalize_embedding( + self, embedding: Union[List[float], np.ndarray] + ) -> np.ndarray: """Normalize embedding vector.""" if isinstance(embedding, list): embedding = np.array(embedding, dtype=np.float16) @@ -1144,11 +1515,15 @@ norm = np.linalg.norm(embedding) return embedding / norm if norm > 0 else embedding - async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]: + async def _generate_embeddings( + self, texts: Union[str, List[str]], user_id: str + ) -> Union[np.ndarray, List[np.ndarray]]: """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function.""" if self._embedding_function is None: - raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.") + raise RuntimeError( + "🤖 Embedding function not initialized. Ensure pipeline context is set." 
+ ) + is_single = isinstance(texts, str) text_list = [texts] if is_single else texts @@ -1181,20 +1556,22 @@ uncached_hashes.append(text_hash) if uncached_texts: - user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None - + user = ( + await asyncio.to_thread(Users.get_user_by_id, user_id) + if hasattr(self, "__user__") + else None + ) + loop = asyncio.get_event_loop() raw_embeddings = await loop.run_in_executor( - None, - self._embedding_function, - uncached_texts, - None, - user + None, self._embedding_function, uncached_texts, None, user ) - + if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0: if isinstance(raw_embeddings[0], list): - new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings] + new_embeddings = [ + self._normalize_embedding(emb) for emb in raw_embeddings + ] else: new_embeddings = [self._normalize_embedding(raw_embeddings)] else: @@ -1207,7 +1584,11 @@ result_embeddings[original_idx] = embedding if is_single: - logger.info("📥 User message embedding: cache hit" if not uncached_texts else "💾 User message embedding: generated and cached") + logger.info( + "📥 User message embedding: cache hit" + if not uncached_texts + else "💾 User message embedding: generated and cached" + ) return result_embeddings[0] else: valid_count = sum(1 for emb in result_embeddings if emb is not None) @@ -1219,17 +1600,25 @@ def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: if self._skip_detector is None: raise RuntimeError("🤖 Skip detector not initialized") - + skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self) + skip_reason = self._skip_detector.detect_skip_reason( + user_message, self.valves.max_message_chars, memory_system=self + ) if skip_reason: status_key = SkipDetector.SkipReason(skip_reason) return True, SkipDetector.STATUS_MESSAGES[status_key] return False, "" - def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: + def _process_user_message( + self, body: Dict[str, Any] + ) -> Tuple[Optional[str], bool, str]: """Extract user message and determine if memory operations should be skipped.""" if not body or "messages" not in body or not isinstance(body["messages"], list): - return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE] + return ( + None, + True, + SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], + ) messages = body["messages"] user_message = None @@ -1245,23 +1634,34 @@ break if not user_message or not user_message.strip(): - return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE] + return ( + None, + True, + SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], + ) should_skip, skip_reason = self._should_skip_memory_operations(user_message) return user_message, should_skip, skip_reason - async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List: + async def _get_user_memories( + self, user_id: str, timeout: Optional[float] = None + ) -> List: """Get user memories with timeout handling.""" if timeout is None: timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC try: - return await asyncio.wait_for(asyncio.to_thread(Memories.get_memories_by_user_id, user_id), timeout=timeout) + return await asyncio.wait_for( + asyncio.to_thread(Memories.get_memories_by_user_id, user_id), + timeout=timeout, + ) 
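# Note the concurrency pattern used for every DB touch in this filter: the
# synchronous OpenWebUI calls (Memories.get_memories_by_user_id here,
# Users.get_user_by_id elsewhere) are moved off the event loop with
# asyncio.to_thread() and bounded by asyncio.wait_for(), so a slow database
# surfaces as the asyncio.TimeoutError handled below instead of a stalled loop.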
except asyncio.TimeoutError: raise TimeoutError(f"⏱️ Memory retrieval timed out after {timeout}s") except Exception as e: raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}") - def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None: + def _log_retrieved_memories( + self, memories: List[Dict[str, Any]], context_type: str = "semantic" + ) -> None: """Log retrieved memories with concise formatting showing key statistics and semantic values.""" if not memories: return @@ -1275,22 +1675,42 @@ lowest_score = min(scores) median_score = statistics.median(scores) - context_label = "📊 Consolidation candidate memories" if context_type == "consolidation" else "📊 Retrieved memories" - max_scores_to_show = int(self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) - scores_str = ", ".join([f"{score:.3f}" for score in scores[:max_scores_to_show]]) + context_label = ( + "📊 Consolidation candidate memories" + if context_type == "consolidation" + else "📊 Retrieved memories" + ) + max_scores_to_show = int( + self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER + ) + scores_str = ", ".join( + [f"{score:.3f}" for score in scores[:max_scores_to_show]] + ) suffix = "..." if len(scores) > max_scores_to_show else "" - logger.info(f"{context_label}: {len(memories)} memories | Top: {top_score:.3f} | Median: {median_score:.3f} | Lowest: {lowest_score:.3f}") + logger.info( + f"{context_label}: {len(memories)} memories | Top: {top_score:.3f} | Median: {median_score:.3f} | Lowest: {lowest_score:.3f}" + ) logger.info(f"Scores: [{scores_str}{suffix}]") - def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]: - operations = [(created_count, "📝 Created"), (updated_count, "✏️ Updated"), (deleted_count, "🗑️ Deleted")] + def _build_operation_details( + self, created_count: int, updated_count: int, deleted_count: int + ) -> List[str]: + operations = [ + (created_count, "📝 Created"), + (updated_count, "✏️ Updated"), + (deleted_count, "🗑️ Deleted"), + ] return [f"{label} {count}" for count, label in operations if count > 0] - def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str: + def _cache_key( + self, cache_type: str, user_id: str, content: Optional[str] = None + ) -> str: """Unified cache key generation for all cache types.""" if content: - content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH] + content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[ + : Constants.CACHE_KEY_HASH_PREFIX_LENGTH + ] return f"{cache_type}_{user_id}:{content_hash}" return f"{cache_type}_{user_id}" @@ -1310,7 +1730,9 @@ if record_date: try: if isinstance(record_date, str): - parsed_date = datetime.fromisoformat(record_date.replace('Z', '+00:00')) + parsed_date = datetime.fromisoformat( + record_date.replace("Z", "+00:00") + ) else: parsed_date = record_date formatted_date = parsed_date.strftime("%b %d %Y") @@ -1321,7 +1743,9 @@ memory_lines.append(line) return memory_lines - async def _emit_status(self, emitter: Optional[Callable], description: str, done: bool = True) -> None: + async def _emit_status( + self, emitter: Optional[Callable], description: str, done: bool = True + ) -> None: """Emit status messages for memory operations.""" if not emitter: return @@ -1345,9 +1769,19 @@ ) -> 
@@ -1345,9 +1769,19 @@ class Filter:
     ) -> Dict[str, Any]:
         """Retrieve memories for injection using similarity computation with optional LLM reranking."""
         if cached_similarities is not None:
-            memories = [m for m in cached_similarities if m.get("relevance", 0) >= self.valves.semantic_retrieval_threshold]
-            logger.info(f"🔍 Using cached similarities for {len(memories)} candidate memories")
-            final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter)
+            memories = [
+                m
+                for m in cached_similarities
+                if m.get("relevance", 0) >= self.valves.semantic_retrieval_threshold
+            ]
+            logger.info(
+                f"🔍 Using cached similarities for {len(memories)} candidate memories"
+            )
+            final_memories, reranking_info = (
+                await self._llm_reranking_service.rerank_memories(
+                    user_message, memories, emitter
+                )
+            )
             self._log_retrieved_memories(final_memories, "semantic")
             return {
                 "memories": final_memories,
@@ -1364,10 +1798,16 @@ class Filter:
             await self._emit_status(emitter, "📭 No Memories Found", done=True)
             return {"memories": [], "threshold": None}

-        memories, threshold, all_similarities = await self._compute_similarities(user_message, user_id, user_memories)
+        memories, threshold, all_similarities = await self._compute_similarities(
+            user_message, user_id, user_memories
+        )

         if memories:
-            final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter)
+            final_memories, reranking_info = (
+                await self._llm_reranking_service.rerank_memories(
+                    user_message, memories, emitter
+                )
+            )
         else:
             logger.info("📭 No relevant memories found above similarity threshold")
             await self._emit_status(emitter, "📭 No Relevant Memories Found", done=True)
@@ -1376,10 +1816,19 @@ class Filter:

         self._log_retrieved_memories(final_memories, "semantic")

-        return {"memories": final_memories, "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info}
+        return {
+            "memories": final_memories,
+            "threshold": threshold,
+            "all_similarities": all_similarities,
+            "reranking_info": reranking_info,
+        }

     async def _add_memory_context(
-        self, body: Dict[str, Any], memories: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None, emitter: Optional[Callable] = None
+        self,
+        body: Dict[str, Any],
+        memories: Optional[List[Dict[str, Any]]] = None,
+        user_id: Optional[str] = None,
+        emitter: Optional[Callable] = None,
     ) -> None:
         """Add memory context to request body with simplified logic."""
         if not body or "messages" not in body or not body["messages"]:
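# --- Illustrative aside (not part of the patch) --------------------------------
# Sketch of the retrieval gate used in _retrieve_relevant_memories above:
# cached or freshly computed candidates are filtered against
# semantic_retrieval_threshold before the LLM reranker ever sees them, so
# reranking cost is bounded by the candidates that cleared the floor.
from typing import Any, Dict, List

def gate_candidates(
    candidates: List[Dict[str, Any]], threshold: float = 0.25
) -> List[Dict[str, Any]]:
    kept = [m for m in candidates if m.get("relevance", 0) >= threshold]
    return sorted(kept, key=lambda m: m["relevance"], reverse=True)

print(gate_candidates([{"id": "a", "relevance": 0.41}, {"id": "b", "relevance": 0.12}]))
# -> [{'id': 'a', 'relevance': 0.41}]  (0.12 falls below the 0.25 floor)
# --------------------------------------------------------------------------------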
@@ -1393,38 +1842,57 @@ class Filter:
         memory_count = len(memories)
         memory_header = f"CONTEXT: The following {'fact' if memory_count == 1 else 'facts'} about the user are provided for background only. Not all facts may be relevant to the current request."
         formatted_memories = []
-
+
         for idx, memory in enumerate(memories, 1):
             formatted_memory = f"- {' '.join(memory['content'].split())}"
             formatted_memories.append(formatted_memory)
-
-            content_preview = self._truncate_content(memory['content'])
-            await self._emit_status(emitter, f"💭 {idx}/{memory_count}: {content_preview}", done=False)
-
+
+            content_preview = self._truncate_content(memory["content"])
+            await self._emit_status(
+                emitter, f"💭 {idx}/{memory_count}: {content_preview}", done=False
+            )
+
         memory_footer = "IMPORTANT: Do not mention or imply you received this list. These facts are for background context only."
         memory_context_block = f"{memory_header}\n{chr(10).join(formatted_memories)}\n\n{memory_footer}"
         content_parts.append(memory_context_block)

         memory_context = "\n\n".join(content_parts)
-        system_index = next((i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), None)
+        system_index = next(
+            (
+                i
+                for i, msg in enumerate(body["messages"])
+                if msg.get("role") == "system"
+            ),
+            None,
+        )

         if system_index is not None:
-            body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
+            body["messages"][system_index][
+                "content"
+            ] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
         else:
             body["messages"].insert(0, {"role": "system", "content": memory_context})
-
+
         if memories and user_id:
             description = f"🧠 Injected {memory_count} {'Memory' if memory_count == 1 else 'Memories'} to Context"
             await self._emit_status(emitter, description, done=True)

     def _build_memory_dict(self, memory, similarity: float) -> Dict[str, Any]:
         """Build memory dictionary with standardized timestamp conversion."""
-        memory_dict = {"id": str(memory.id), "content": memory.content, "relevance": similarity}
+        memory_dict = {
+            "id": str(memory.id),
+            "content": memory.content,
+            "relevance": similarity,
+        }
         if hasattr(memory, "created_at") and memory.created_at:
-            memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat()
+            memory_dict["created_at"] = datetime.fromtimestamp(
+                memory.created_at, tz=timezone.utc
+            ).isoformat()
         if hasattr(memory, "updated_at") and memory.updated_at:
-            memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat()
+            memory_dict["updated_at"] = datetime.fromtimestamp(
+                memory.updated_at, tz=timezone.utc
+            ).isoformat()
         return memory_dict

     async def _compute_similarities(
@@ -1439,7 +1907,9 @@ class Filter:
         memory_embeddings = await self._generate_embeddings(memory_contents, user_id)

         if len(memory_embeddings) != len(user_memories):
-            logger.error(f"🔢 Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories")
+            logger.error(
+                f"🔢 Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories"
+            )
             return [], self.valves.semantic_retrieval_threshold, []

         similarity_scores = []
@@ -1461,7 +1931,7 @@ class Filter:
         memory_data.sort(key=lambda x: x["relevance"], reverse=True)

         threshold = self.valves.semantic_retrieval_threshold
-        filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
+        filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
         return filtered_memories, threshold, memory_data

     async def inlet(
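# --- Illustrative aside (not part of the patch) --------------------------------
# Sketch of the scoring inside _compute_similarities above, assuming the
# embeddings were already L2-normalized by _normalize_embedding; for unit
# vectors, cosine similarity reduces to a plain dot product, which is then
# compared against semantic_retrieval_threshold.
import numpy as np

def cosine_scores(query_emb: np.ndarray, memory_embs: np.ndarray) -> np.ndarray:
    # query_emb: shape (d,); memory_embs: shape (n, d); both unit-norm.
    return memory_embs @ query_emb

q = np.array([1.0, 0.0])
mems = np.array([[0.8, 0.6], [0.0, 1.0]])
print(cosine_scores(q, mems))  # [0.8 0. ] -> only the first clears a 0.25 floor
# --------------------------------------------------------------------------------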
@@ -1474,42 +1944,68 @@ class Filter:
         **kwargs,
     ) -> Dict[str, Any]:
         """Simplified inlet processing for memory retrieval and injection."""
-        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+
+        model_to_use = body.get("model") if isinstance(body, dict) else None
+        if not model_to_use:
+            model_to_use = __model__ or getattr(__request__.state, "model", None)
+        if not model_to_use:
+            model_to_use = Constants.DEFAULT_LLM_MODEL
+            logger.warning(f"⚠️ No model found; using default model: {model_to_use}")
+
+        if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
+            model_to_use = self.valves.custom_memory_model
+            logger.info(f"🧠 Using custom memory model: {model_to_use}")
+
+        self.valves.model = model_to_use
+
+        await self._set_pipeline_context(
+            __event_emitter__, __user__, model_to_use, __request__
+        )

         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body

         user_message, should_skip, skip_reason = self._process_user_message(body)
-
         if not user_message or should_skip:
             if __event_emitter__ and skip_reason:
                 await self._emit_status(__event_emitter__, skip_reason, done=True)
             await self._add_memory_context(body, [], user_id, __event_emitter__)
             return body
-
         try:
-            memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
-            user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key)
-
+            memory_cache_key = self._cache_key(
+                self._cache_manager.MEMORY_CACHE, user_id
+            )
+            user_memories = await self._cache_manager.get(
+                user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key
+            )
             if user_memories is None:
                 user_memories = await self._get_user_memories(user_id)
-                await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
-
-            retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
+                await self._cache_manager.put(
+                    user_id,
+                    self._cache_manager.MEMORY_CACHE,
+                    memory_cache_key,
+                    user_memories,
+                )
+            retrieval_result = await self._retrieve_relevant_memories(
+                user_message, user_id, user_memories, __event_emitter__
+            )
             memories = retrieval_result.get("memories", [])
             threshold = retrieval_result.get("threshold")
             all_similarities = retrieval_result.get("all_similarities", [])
-
             if all_similarities:
-                cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
-                await self._cache_manager.put(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key, all_similarities)
-
+                cache_key = self._cache_key(
+                    self._cache_manager.RETRIEVAL_CACHE, user_id, user_message
+                )
+                await self._cache_manager.put(
+                    user_id,
+                    self._cache_manager.RETRIEVAL_CACHE,
+                    cache_key,
+                    all_similarities,
+                )
             await self._add_memory_context(body, memories, user_id, __event_emitter__)
-
         except Exception as e:
             raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}")
-
         return body

     async def outlet(
@@ -1522,22 +2018,40 @@ class Filter:
         **kwargs,
     ) -> dict:
         """Simplified outlet processing for background memory consolidation."""
-        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+
+        model_to_use = body.get("model") if isinstance(body, dict) else None
+        if not model_to_use:
+            model_to_use = __model__ or getattr(__request__.state, "model", None)
+        if not model_to_use:
+            model_to_use = Constants.DEFAULT_LLM_MODEL
+            logger.warning(f"⚠️ No model found; using default model: {model_to_use}")
+
+        if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
+            model_to_use = self.valves.custom_memory_model
+            logger.info(f"🧠 Using custom memory model: {model_to_use}")
+
+        self.valves.model = model_to_use
+
+        await self._set_pipeline_context(
+            __event_emitter__, __user__, model_to_use, __request__
+        )

         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body
-
         user_message, should_skip, skip_reason = self._process_user_message(body)
-
         if not user_message or should_skip:
             return body
-
-        cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
-        cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
-
+        cache_key = self._cache_key(
+            self._cache_manager.RETRIEVAL_CACHE, user_id, user_message
+        )
+        cached_similarities = await self._cache_manager.get(
+            user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key
+        )
         task = asyncio.create_task(
-            self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities)
+            self._llm_consolidation_service.run_consolidation_pipeline(
+                user_message, user_id, __event_emitter__, cached_similarities
+            )
         )

         self._background_tasks.add(task)
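# --- Illustrative aside (not part of the patch) --------------------------------
# inlet() and outlet() above resolve the working model through the same
# fallback chain (request body -> __model__ -> request.state -> default, with
# the custom-memory-model valve overriding everything). A sketch of that chain
# as a single helper; the names here are illustrative, not part of the patch.
from typing import Any, Dict, Optional

DEFAULT_MODEL = "google/gemini-2.5-flash-lite"

def resolve_model(
    body: Optional[Dict[str, Any]],
    dunder_model: Optional[str] = None,
    request_state_model: Optional[str] = None,
    custom_enabled: bool = False,
    custom_model: str = "",
) -> str:
    model = body.get("model") if isinstance(body, dict) else None
    model = model or dunder_model or request_state_model or DEFAULT_MODEL
    if custom_enabled and custom_model:
        model = custom_model  # the valve override always wins
    return model

print(resolve_model({"model": "gpt-4o"}))  # gpt-4o
print(resolve_model({}, custom_enabled=True, custom_model="x"))  # x
# --------------------------------------------------------------------------------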
@@ -1546,12 +2060,13 @@ class Filter:
                 self._background_tasks.discard(t)
                 if t.exception() and not t.cancelled():
                     exception = t.exception()
-                    logger.error(f"❌ Background memory consolidation task failed: {str(exception)}")
+                    logger.error(
+                        f"❌ Background memory consolidation task failed: {str(exception)}"
+                    )
             except Exception as e:
                 logger.error(f"❌ Failed to cleanup background memory task: {str(e)}")

         task.add_done_callback(safe_cleanup)
-
         return body

     async def shutdown(self) -> None:
@@ -1568,35 +2083,57 @@ class Filter:
         """Refresh user cache - clear stale caches and update with fresh embeddings."""
         start_time = time.time()
         try:
-            retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE)
-            embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE)
-            logger.info(f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}")
+            retrieval_cleared = await self._cache_manager.clear_user_cache(
+                user_id, self._cache_manager.RETRIEVAL_CACHE
+            )
+            embedding_cleared = await self._cache_manager.clear_user_cache(
+                user_id, self._cache_manager.EMBEDDING_CACHE
+            )
+            logger.info(
+                f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}"
+            )

             user_memories = await self._get_user_memories(user_id)
-            memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
+            memory_cache_key = self._cache_key(
+                self._cache_manager.MEMORY_CACHE, user_id
+            )

             if not user_memories:
-                await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, [])
+                await self._cache_manager.put(
+                    user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, []
+                )
                 logger.info("📭 No memories found for user")
                 return

-            await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
+            await self._cache_manager.put(
+                user_id,
+                self._cache_manager.MEMORY_CACHE,
+                memory_cache_key,
+                user_memories,
+            )

             memory_contents = [
                 memory.content
                 for memory in user_memories
-                if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS
+                if memory.content
+                and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS
             ]

             if memory_contents:
                 await self._generate_embeddings(memory_contents, user_id)

             duration = time.time() - start_time
-            logger.info(f"🔄 Cache updated with {len(memory_contents)} embeddings for user {user_id} in {duration:.2f}s")
+            logger.info(
+                f"🔄 Cache updated with {len(memory_contents)} embeddings for user {user_id} in {duration:.2f}s"
+            )

         except Exception as e:
-            raise RuntimeError(f"🧹 Failed to refresh cache for user {user_id} after {(time.time() - start_time):.2f}s: {str(e)}")
+            raise RuntimeError(
+                f"🧹 Failed to refresh cache for user {user_id} after {(time.time() - start_time):.2f}s: {str(e)}"
+            )

-    async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str:
+    async def _execute_single_operation(
+        self, operation: Models.MemoryOperation, user: Any
+    ) -> str:
         """Execute a single memory operation."""
         try:
             if operation.operation == Models.MemoryOperationType.CREATE:
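# --- Illustrative aside (not part of the patch) --------------------------------
# Sketch of the fire-and-forget pattern used for background consolidation
# above: keep a strong reference to the task so it cannot be garbage-collected
# mid-flight, and report failures from a done-callback instead of awaiting.
import asyncio
import logging

logger = logging.getLogger(__name__)
_background_tasks: set = set()

def spawn_background(coro) -> asyncio.Task:
    task = asyncio.create_task(coro)
    _background_tasks.add(task)

    def _cleanup(t: asyncio.Task) -> None:
        _background_tasks.discard(t)
        # Check cancelled() first: Task.exception() raises CancelledError
        # if the task was cancelled.
        if not t.cancelled() and t.exception():
            logger.error("Background task failed: %s", t.exception())

    task.add_done_callback(_cleanup)
    return task
# --------------------------------------------------------------------------------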
@@ -1606,7 +2143,10 @@ class Filter:
                     return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value

                 await asyncio.wait_for(
-                    asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
+                    asyncio.to_thread(
+                        Memories.insert_new_memory, user.id, content_stripped
+                    ),
+                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
                 )
                 return Models.MemoryOperationType.CREATE.value
@@ -1615,14 +2155,21 @@ class Filter:
                 if not id_stripped:
                     logger.warning(f"⚠️ Skipping UPDATE operation: empty ID")
                     return Models.OperationResult.SKIPPED_EMPTY_ID.value
-
+
                 content_stripped = operation.content.strip()
                 if not content_stripped:
-                    logger.warning(f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content")
+                    logger.warning(
+                        f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content"
+                    )
                     return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value

                 await asyncio.wait_for(
-                    asyncio.to_thread(Memories.update_memory_by_id_and_user_id, id_stripped, user.id, content_stripped),
+                    asyncio.to_thread(
+                        Memories.update_memory_by_id_and_user_id,
+                        id_stripped,
+                        user.id,
+                        content_stripped,
+                    ),
                     timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
                 )
                 return Models.MemoryOperationType.UPDATE.value
@@ -1634,7 +2181,10 @@ class Filter:
                     return Models.OperationResult.SKIPPED_EMPTY_ID.value

                 await asyncio.wait_for(
-                    asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
+                    asyncio.to_thread(
+                        Memories.delete_memory_by_id_and_user_id, id_stripped, user.id
+                    ),
+                    timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
                 )
                 return Models.MemoryOperationType.DELETE.value
             else:
@@ -1642,42 +2192,62 @@ class Filter:
                 return Models.OperationResult.UNSUPPORTED.value

         except Exception as e:
-            logger.error(f"💾 Database operation failed for {operation.operation.value}: {str(e)}")
+            logger.error(
+                f"💾 Database operation failed for {operation.operation.value}: {str(e)}"
+            )
             return Models.OperationResult.FAILED.value

-    def _remove_refs_from_schema(self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
+    def _remove_refs_from_schema(
+        self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None
+    ) -> Dict[str, Any]:
         """Remove $ref references and ensure required fields for Azure OpenAI."""
         if not isinstance(schema, dict):
             return schema
-
-        if '$ref' in schema:
-            ref_path = schema['$ref']
-            if ref_path.startswith('#/$defs/'):
-                def_name = ref_path.split('/')[-1]
+
+        if "$ref" in schema:
+            ref_path = schema["$ref"]
+            if ref_path.startswith("#/$defs/"):
+                def_name = ref_path.split("/")[-1]
                 if schema_defs and def_name in schema_defs:
-                    return self._remove_refs_from_schema(schema_defs[def_name].copy(), schema_defs)
-            return {'type': 'object'}
-
+                    return self._remove_refs_from_schema(
+                        schema_defs[def_name].copy(), schema_defs
+                    )
+            return {"type": "object"}
+
         result = {}
         for key, value in schema.items():
-            if key == '$defs':
+            if key == "$defs":
                 continue
             elif isinstance(value, dict):
                 result[key] = self._remove_refs_from_schema(value, schema_defs)
             elif isinstance(value, list):
-                result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value]
+                result[key] = [
+                    (
+                        self._remove_refs_from_schema(item, schema_defs)
+                        if isinstance(item, dict)
+                        else item
+                    )
+                    for item in value
+                ]
             else:
                 result[key] = value
-
-        if result.get('type') == 'object' and 'properties' in result:
-            result['required'] = list(result['properties'].keys())
-
+
+        if result.get("type") == "object" and "properties" in result:
+            result["required"] = list(result["properties"].keys())
+
         return result

-    async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
+    async def _query_llm(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        response_model: Optional[BaseModel] = None,
+    ) -> Union[str, BaseModel]:
         """Query OpenWebUI's internal model system with Pydantic model parsing."""
         if not hasattr(self, "__request__") or not hasattr(self, "__user__"):
-            raise RuntimeError("🔧 Pipeline interface not properly initialized. __request__ and __user__ required.")
+            raise RuntimeError(
+                "🔧 Pipeline interface not properly initialized. __request__ and __user__ required."
+            )

         model_to_use = self.valves.model if self.valves.model else self.__model__
         if not model_to_use:
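# --- Illustrative aside (not part of the patch) --------------------------------
# What _remove_refs_from_schema above does, on a toy schema: Pydantic emits
# nested models under $defs with $ref pointers, which strict structured-output
# backends such as Azure OpenAI reject, so the refs are inlined and every
# object property is marked required. Expected output shown for this input.
toy_schema = {
    "$defs": {"Op": {"type": "object", "properties": {"id": {"type": "string"}}}},
    "type": "object",
    "properties": {"ops": {"type": "array", "items": {"$ref": "#/$defs/Op"}}},
}
# _remove_refs_from_schema(toy_schema, toy_schema["$defs"]) returns roughly:
# {
#     "type": "object",
#     "properties": {
#         "ops": {
#             "type": "array",
#             "items": {
#                 "type": "object",
#                 "properties": {"id": {"type": "string"}},
#                 "required": ["id"],
#             },
#         }
#     },
#     "required": ["ops"],
# }
# --------------------------------------------------------------------------------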
@@ -1685,30 +2255,50 @@ class Filter:
         form_data = {
             "model": model_to_use,
-            "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
             "max_tokens": 4096,
             "stream": False,
         }

         if response_model:
             raw_schema = response_model.model_json_schema()
-            schema_defs = raw_schema.get('$defs', {})
+            schema_defs = raw_schema.get("$defs", {})
             schema = self._remove_refs_from_schema(raw_schema, schema_defs)
-            schema['type'] = 'object'
-            form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
+            schema["type"] = "object"
+            form_data["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_model.__name__,
+                    "strict": True,
+                    "schema": schema,
+                },
+            }

         try:
             response = await asyncio.wait_for(
-                generate_chat_completion(self.__request__, form_data, user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"])),
+                generate_chat_completion(
+                    self.__request__,
+                    form_data,
+                    user=await asyncio.to_thread(
+                        Users.get_user_by_id, self.__user__["id"]
+                    ),
+                ),
                 timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
             )
         except asyncio.TimeoutError:
-            raise TimeoutError(f"⏱️ LLM query timed out after {Constants.LLM_CONSOLIDATION_TIMEOUT_SEC}s")
+            raise TimeoutError(
+                f"⏱️ LLM query timed out after {Constants.LLM_CONSOLIDATION_TIMEOUT_SEC}s"
+            )
         except Exception as e:
             raise RuntimeError(f"🤖 LLM query failed: {str(e)}")

         try:
-            if hasattr(response, "body") and hasattr(getattr(response, "body", None), "decode"):
+            if hasattr(response, "body") and hasattr(
+                getattr(response, "body", None), "decode"
+            ):
                 body = getattr(response, "body")
                 response_data = json.loads(body.decode("utf-8"))
             else:
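# --- Illustrative aside (not part of the patch) --------------------------------
# Sketch of the structured-output request assembled in _query_llm above: the
# Pydantic schema (after $ref inlining) is wrapped in an OpenAI-style
# "json_schema" response_format with strict=True, so the backend constrains
# generation to the declared shape. Prompts here are placeholders.
from typing import List
from pydantic import BaseModel, Field

class MemoryRerankingResponse(BaseModel):
    ids: List[str] = Field(default_factory=list)

schema = MemoryRerankingResponse.model_json_schema()  # flat model, no $refs
form_data = {
    "model": "google/gemini-2.5-flash-lite",
    "messages": [
        {"role": "system", "content": "..."},
        {"role": "user", "content": "..."},
    ],
    "max_tokens": 4096,
    "stream": False,
    "response_format": {
        "type": "json_schema",
        "json_schema": {
            "name": "MemoryRerankingResponse",
            "strict": True,
            "schema": schema,
        },
    },
}
# --------------------------------------------------------------------------------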
@@ -1716,23 +2306,39 @@ class Filter:
         except (json.JSONDecodeError, AttributeError) as e:
             raise RuntimeError(f"🔍 Failed to decode response body: {str(e)}")

-        if isinstance(response_data, dict) and "choices" in response_data and isinstance(response_data["choices"], list) and len(response_data["choices"]) > 0:
+        if (
+            isinstance(response_data, dict)
+            and "choices" in response_data
+            and isinstance(response_data["choices"], list)
+            and len(response_data["choices"]) > 0
+        ):
             first_choice = response_data["choices"][0]
-            if isinstance(first_choice, dict) and "message" in first_choice and isinstance(first_choice["message"], dict) and "content" in first_choice["message"]:
+            if (
+                isinstance(first_choice, dict)
+                and "message" in first_choice
+                and isinstance(first_choice["message"], dict)
+                and "content" in first_choice["message"]
+            ):
                 content = first_choice["message"]["content"]
             else:
-                raise ValueError("🤖 Invalid response structure: missing content in message")
+                raise ValueError(
+                    "🤖 Invalid response structure: missing content in message"
+                )
         else:
             raise ValueError(f"🤖 Unexpected LLM response format: {response_data}")

         if response_model:
-            try: 
+            try:
                 parsed_data = json.loads(content)
                 return response_model.model_validate(parsed_data)
             except json.JSONDecodeError as e:
-                raise ValueError(f"🔍 Invalid JSON from LLM: {str(e)}\nContent: {content}")
+                raise ValueError(
+                    f"🔍 Invalid JSON from LLM: {str(e)}\nContent: {content}"
+                )
             except PydanticValidationError as e:
-                raise ValueError(f"🔍 LLM response validation failed: {str(e)}\nContent: {content}")
+                raise ValueError(
+                    f"🔍 LLM response validation failed: {str(e)}\nContent: {content}"
+                )

         if not content or content.strip() == "":
             raise ValueError("🤖 Empty response from LLM")
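# --- Illustrative aside (not part of the patch) --------------------------------
# Sketch of the final parse-and-validate step above: the message content is
# decoded as JSON and validated against the expected Pydantic model, with
# decode and validation failures surfaced as distinct ValueErrors, mirroring
# the two except branches in _query_llm. The model name here is illustrative.
import json
from typing import List
from pydantic import BaseModel, Field, ValidationError

class RerankIds(BaseModel):
    ids: List[str] = Field(default_factory=list)

def parse_llm_content(content: str) -> RerankIds:
    try:
        return RerankIds.model_validate(json.loads(content))
    except json.JSONDecodeError as e:
        raise ValueError(f"invalid JSON from LLM: {e}") from e
    except ValidationError as e:
        raise ValueError(f"LLM response failed validation: {e}") from e

print(parse_llm_content('{"ids": ["m1", "m2"]}').ids)  # ['m1', 'm2']
# --------------------------------------------------------------------------------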