From 89399f57ccabaa108dcf371a057583109719cb9f Mon Sep 17 00:00:00 2001 From: GlissemanTV Date: Sun, 26 Oct 2025 16:01:13 +0100 Subject: [PATCH] add current model workflow with checkbox --- memory_system.py | 1440 ++++++++++++++++++++++++++++++++-------------- 1 file changed, 1009 insertions(+), 431 deletions(-) diff --git a/memory_system.py b/memory_system.py index 356c56c..d52cd2d 100644 --- a/memory_system.py +++ b/memory_system.py @@ -7,6 +7,7 @@ import asyncio import hashlib import json import logging +import statistics import time from collections import OrderedDict from datetime import datetime, timezone @@ -14,22 +15,27 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Tuple, Union import numpy as np -from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError as PydanticValidationError, +) from open_webui.utils.chat import generate_chat_completion from open_webui.models.users import Users from open_webui.routers.memories import Memories from fastapi import Request -logging.getLogger("transformers").setLevel(logging.ERROR) - -logger = logging.getLogger("MemorySystem") +logger = logging.getLogger(__name__) _SHARED_SKIP_DETECTOR_CACHE = {} +_SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock() + class Constants: """Centralized configuration constants for the memory system.""" - + # Core System Limits MAX_MEMORY_CONTENT_CHARS = 500 # Character limit for LLM prompt memory content MAX_MEMORIES_PER_RETRIEVAL = 10 # Maximum memories returned per query @@ -37,33 +43,40 @@ class Constants: MIN_MESSAGE_CHARS = 10 # Minimum message length for validation DATABASE_OPERATION_TIMEOUT_SEC = 10 # Timeout for DB operations like user lookup LLM_CONSOLIDATION_TIMEOUT_SEC = 60.0 # Timeout for LLM consolidation operations - + # Cache System - MAX_CACHE_ENTRIES_PER_TYPE = 2500 # Maximum cache entries per cache type - MAX_CONCURRENT_USER_CACHES = 250 # Maximum concurrent user cache instances + MAX_CACHE_ENTRIES_PER_TYPE = 500 # Maximum cache entries per cache type + MAX_CONCURRENT_USER_CACHES = 50 # Maximum concurrent user cache instances CACHE_KEY_HASH_PREFIX_LENGTH = 10 # Hash prefix length for cache keys - + # Retrieval & Similarity - SEMANTIC_RETRIEVAL_THRESHOLD = 0.5 # Semantic similarity threshold for retrieval - RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.9 # Multiplier for relaxed similarity threshold in secondary operations - EXTENDED_MAX_MEMORY_MULTIPLIER = 1.5 # Multiplier for expanding memory candidates in advanced operations - LLM_RERANKING_TRIGGER_MULTIPLIER = 0.5 # Multiplier for LLM reranking trigger threshold - - # Skip Detection Thresholds - SKIP_DETECTION_SIMILARITY_THRESHOLD = 0.50 # Similarity threshold for skip category detection (tuned for zero-shot) - SKIP_DETECTION_MARGIN = 0.05 # Minimum margin required between skip and conversational similarity to skip - SKIP_DETECTION_CONFIDENT_MARGIN = 0.15 # Margin threshold for confident skips that trigger early exit - + SEMANTIC_RETRIEVAL_THRESHOLD = 0.25 # Semantic similarity threshold for retrieval + RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = ( + 0.8 # Multiplier for relaxed similarity threshold in secondary operations + ) + EXTENDED_MAX_MEMORY_MULTIPLIER = ( + 1.6 # Multiplier for expanding memory candidates in advanced operations + ) + LLM_RERANKING_TRIGGER_MULTIPLIER = ( + 0.8 # Multiplier for LLM reranking trigger threshold + ) + + # Skip Detection + SKIP_CATEGORY_MARGIN = ( + 0.5 # Margin above conversational similarity for skip category classification + ) + # Safety & Operations MAX_DELETE_OPERATIONS_RATIO = 0.6 # Maximum delete operations ratio for safety MIN_OPS_FOR_DELETE_RATIO_CHECK = 6 # Minimum operations to apply ratio check - + # Content Display - CONTENT_PREVIEW_LENGTH = 80 # Maximum length for content preview display - + CONTENT_PREVIEW_LENGTH = 80 # Maximum length for content preview display + # Default Models DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite" + class Prompts: """Container for all LLM prompts used in the memory system.""" @@ -90,7 +103,7 @@ Build precise memories of the user's personal narrative with factual, temporal s - Retroactive Enrichment: If a name is provided for prior entity, UPDATE only if substantially valuable. - Ensure Memory Quality: - High Bar for Creation: Only CREATE memories for significant life facts, relationships, events, or core personal attributes. Skip trivial details or passing interests. - - Contextual Completeness: Create memories that combine related information into cohesive statements. When multiple facts share connections (same topic, person, event, or timeframe), group them into a single memory rather than fragmenting. Include relevant supporting details that help understand the core fact while respecting boundaries. Only combine facts that are directly related and belong together naturally. Avoid bare statements lacking context and never merge unrelated information. + - Contextual Completeness: Combine related information into cohesive statements. Group connected facts (same topic, person, event, or timeframe) into single memories rather than fragmenting. Include supporting details while respecting boundaries. Only combine directly related facts. Avoid bare statements and never merge unrelated information. - Mandatory Semantic Enhancement: Enhance entities with descriptive categorical nouns for better retrieval. - Verify Nouns/Pronouns: Link pronouns (he, she, they) and nouns to specific entities. - First-Person Format: Write all memories in English from the user's perspective. @@ -183,6 +196,7 @@ Return: {{"ids": []}} Explanation: Query seeks general technical explanation without personal context. Job and family information don't affect how quantum computing concepts should be explained. """ + class Models: """Container for all Pydantic models used in the memory system.""" @@ -205,9 +219,15 @@ class Models: class MemoryOperation(StrictModel): """Pydantic model for memory operations with validation.""" - operation: 'Models.MemoryOperationType' = Field(description="Type of memory operation to perform") - content: str = Field(description="Memory content (required for CREATE/UPDATE, empty for DELETE)") - id: str = Field(description="Memory ID (empty for CREATE, required for UPDATE/DELETE)") + operation: "Models.MemoryOperationType" = Field( + description="Type of memory operation to perform" + ) + content: str = Field( + description="Memory content (required for CREATE/UPDATE, empty for DELETE)" + ) + id: str = Field( + description="Memory ID (empty for CREATE, required for UPDATE/DELETE)" + ) def validate_operation(self, existing_memory_ids: Optional[set] = None) -> bool: """Validate the memory operation against existing memory IDs.""" @@ -216,19 +236,27 @@ class Models: if self.operation == Models.MemoryOperationType.CREATE: return True - elif self.operation in [Models.MemoryOperationType.UPDATE, Models.MemoryOperationType.DELETE]: + elif self.operation in [ + Models.MemoryOperationType.UPDATE, + Models.MemoryOperationType.DELETE, + ]: return self.id in existing_memory_ids return False class ConsolidationResponse(BaseModel): """Pydantic model for memory consolidation LLM response - object containing array of memory operations.""" - ops: List['Models.MemoryOperation'] = Field(default_factory=list, description="List of memory operations to execute") + ops: List["Models.MemoryOperation"] = Field( + default_factory=list, description="List of memory operations to execute" + ) class MemoryRerankingResponse(BaseModel): """Pydantic model for memory reranking LLM response - object containing array of memory IDs.""" - ids: List[str] = Field(default_factory=list, description="List of memory IDs selected as most relevant for the user query") + ids: List[str] = Field( + default_factory=list, + description="List of memory IDs selected as most relevant for the user query", + ) class UnifiedCacheManager: @@ -276,7 +304,10 @@ class UnifiedCacheManager: type_cache = user_cache[cache_type] - if key not in type_cache and len(type_cache) >= self.max_cache_size_per_type: + if ( + key not in type_cache + and len(type_cache) >= self.max_cache_size_per_type + ): evicted_key, _ = type_cache.popitem(last=False) if key in type_cache: @@ -287,7 +318,9 @@ class UnifiedCacheManager: self.caches.move_to_end(user_id) - async def clear_user_cache(self, user_id: str, cache_type: Optional[str] = None) -> int: + async def clear_user_cache( + self, user_id: str, cache_type: Optional[str] = None + ) -> int: """Clear specific cache type for user, or all caches for user if cache_type is None.""" async with self._lock: if user_id not in self.caches: @@ -296,7 +329,9 @@ class UnifiedCacheManager: user_cache = self.caches[user_id] if cache_type is None: - total_cleared = sum(len(type_cache) for type_cache in user_cache.values()) + total_cleared = sum( + len(type_cache) for type_cache in user_cache.values() + ) del self.caches[user_id] return total_cleared else: @@ -315,26 +350,6 @@ class UnifiedCacheManager: async with self._lock: self.caches.clear() - async def get_cache_stats(self) -> Dict[str, Any]: - """Get cache statistics for monitoring.""" - async with self._lock: - total_users = len(self.caches) - total_items = 0 - cache_type_counts = {} - - for user_id, user_cache in self.caches.items(): - for cache_type, type_cache in user_cache.items(): - cache_type_counts[cache_type] = cache_type_counts.get(cache_type, 0) + len(type_cache) - total_items += len(type_cache) - - return { - "total_users": total_users, - "total_items": total_items, - "cache_type_counts": cache_type_counts, - "max_users": self.max_users, - "max_cache_size_per_type": self.max_cache_size_per_type, - } - class SkipDetector: """Semantic-based content classifier using zero-shot classification with category descriptions.""" @@ -459,245 +474,337 @@ class SkipDetector: SkipReason.SKIP_GRAMMAR_PROOFREAD: "๐Ÿ“ Grammar/Proofreading Request Detected, skipping memory operations", } - def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]): + def __init__( + self, + embedding_function: Callable[ + [Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]] + ], + ): """Initialize the skip detector with an embedding function and compute reference embeddings.""" self.embedding_function = embedding_function self._reference_embeddings = None self._initialize_reference_embeddings() - + def _initialize_reference_embeddings(self) -> None: """Compute and cache embeddings for category descriptions.""" try: technical_embeddings = self.embedding_function( self.TECHNICAL_CATEGORY_DESCRIPTIONS ) - + instruction_embeddings = self.embedding_function( self.INSTRUCTION_CATEGORY_DESCRIPTIONS ) - + pure_math_embeddings = self.embedding_function( self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS ) - + translation_embeddings = self.embedding_function( self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS ) - + grammar_embeddings = self.embedding_function( self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS ) - + conversational_embeddings = self.embedding_function( self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS ) - + self._reference_embeddings = { - 'technical': np.array(technical_embeddings), - 'instruction': np.array(instruction_embeddings), - 'pure_math': np.array(pure_math_embeddings), - 'translation': np.array(translation_embeddings), - 'grammar': np.array(grammar_embeddings), - 'conversational': np.array(conversational_embeddings), + "technical": np.array(technical_embeddings), + "instruction": np.array(instruction_embeddings), + "pure_math": np.array(pure_math_embeddings), + "translation": np.array(translation_embeddings), + "grammar": np.array(grammar_embeddings), + "conversational": np.array(conversational_embeddings), } - + total_skip_categories = ( - len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) + - len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) + - len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) + - len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) + - len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) + len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) + + len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) + + len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) + + len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) + + len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) + ) + + logger.info( + f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories" ) - - logger.info(f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories") except Exception as e: logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}") self._reference_embeddings = None - def validate_message_size(self, message: str, max_message_chars: int) -> Optional[str]: + def validate_message_size( + self, message: str, max_message_chars: int + ) -> Optional[str]: """Validate message size constraints.""" if not message or not message.strip(): return SkipDetector.SkipReason.SKIP_SIZE.value trimmed = message.strip() - if len(trimmed) < Constants.MIN_MESSAGE_CHARS or len(trimmed) > max_message_chars: + if ( + len(trimmed) < Constants.MIN_MESSAGE_CHARS + or len(trimmed) > max_message_chars + ): return SkipDetector.SkipReason.SKIP_SIZE.value return None def _fast_path_skip_detection(self, message: str) -> Optional[str]: """Language-agnostic structural pattern detection with high confidence and low false positive rate.""" msg_len = len(message) - + # Pattern 1: Multiple URLs (5+ full URLs indicates link lists or technical references) - url_pattern_count = message.count('http://') + message.count('https://') + url_pattern_count = message.count("http://") + message.count("https://") if url_pattern_count >= 5: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 2: Long unbroken alphanumeric strings (tokens, hashes, base64) words = message.split() for word in words: cleaned = word.strip('.,;:!?()[]{}"\'"') - if len(cleaned) > 80 and cleaned.replace('-', '').replace('_', '').isalnum(): + if ( + len(cleaned) > 80 + and cleaned.replace("-", "").replace("_", "").isalnum() + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 3: Markdown/text separators (repeated ---, ===, ___, ***) - separator_patterns = ['---', '===', '___', '***'] + separator_patterns = ["---", "===", "___", "***"] for pattern in separator_patterns: if message.count(pattern) >= 2: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 4: Command-line patterns with context-aware detection - lines_stripped = [line.strip() for line in message.split('\n') if line.strip()] + lines_stripped = [line.strip() for line in message.split("\n") if line.strip()] if lines_stripped: actual_command_lines = 0 for line in lines_stripped: - if line.startswith('$ ') and len(line) > 2: + if line.startswith("$ ") and len(line) > 2: parts = line[2:].split() if parts and parts[0].isalnum(): actual_command_lines += 1 - elif '$ ' in line: - dollar_index = line.find('$ ') - if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'): - parts = line[dollar_index+2:].split() - if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ['curl', 'wget', 'git', 'npm', 'pip', 'docker']): + elif "$ " in line: + dollar_index = line.find("$ ") + if dollar_index > 0 and line[dollar_index - 1] in (" ", ":", "\t"): + parts = line[dollar_index + 2 :].split() + if ( + parts + and len(parts[0]) > 0 + and ( + parts[0].isalnum() + or parts[0] + in ["curl", "wget", "git", "npm", "pip", "docker"] + ) + ): actual_command_lines += 1 - elif line.startswith('# ') and len(line) > 2: + elif line.startswith("# ") and len(line) > 2: rest = line[2:].strip() - if rest and not rest[0].isupper() and ' ' in rest: + if rest and not rest[0].isupper() and " " in rest: actual_command_lines += 1 - elif line.startswith('> ') and len(line) > 2: + elif line.startswith("> ") and len(line) > 2: pass - - if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']): + + if actual_command_lines >= 1 and any( + c in message for c in ["http://", "https://", " | "] + ): return self.SkipReason.SKIP_TECHNICAL.value if actual_command_lines >= 3: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 5: High path/URL density (dots and slashes suggesting file paths or URLs) if msg_len > 30: - slash_count = message.count('/') + message.count('\\') - dot_count = message.count('.') + slash_count = message.count("/") + message.count("\\") + dot_count = message.count(".") path_chars = slash_count + dot_count if path_chars > 10 and (path_chars / msg_len) > 0.15: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 6: Markup character density (structured data) - markup_chars = sum(message.count(c) for c in '{}[]<>') + markup_chars = sum(message.count(c) for c in "{}[]<>") if markup_chars >= 6: if markup_chars / msg_len > 0.10: return self.SkipReason.SKIP_TECHNICAL.value - curly_count = message.count('{') + message.count('}') + curly_count = message.count("{") + message.count("}") if curly_count >= 10: return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 7: Structured nested content with colons (key: value patterns) - line_count = message.count('\n') + line_count = message.count("\n") if line_count >= 8: - lines = message.split('\n') + lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#')) - indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t'))) - - if (colon_lines / len(non_empty_lines) > 0.4 and - indented_lines / len(non_empty_lines) > 0.5): + colon_lines = sum( + 1 + for line in non_empty_lines + if ":" in line and not line.strip().startswith("#") + ) + indented_lines = sum( + 1 for line in non_empty_lines if line.startswith((" ", "\t")) + ) + + if ( + colon_lines / len(non_empty_lines) > 0.4 + and indented_lines / len(non_empty_lines) > 0.5 + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 8: Highly structured multi-line content (require markup chars for technical confidence) if line_count > 15: - lines = message.split('\n') + lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>')) - structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t'))) - + markup_in_lines = sum( + 1 for line in non_empty_lines if any(c in line for c in "{}[]<>") + ) + structured_lines = sum( + 1 for line in non_empty_lines if line.startswith((" ", "\t")) + ) + if markup_in_lines / len(non_empty_lines) > 0.3: return self.SkipReason.SKIP_TECHNICAL.value elif structured_lines / len(non_empty_lines) > 0.6: - technical_keywords = ['function', 'class', 'import', 'return', 'const', 'var', 'let', 'def'] - if any(keyword in message.lower() for keyword in technical_keywords): + technical_keywords = [ + "function", + "class", + "import", + "return", + "const", + "var", + "let", + "def", + ] + if any( + keyword in message.lower() for keyword in technical_keywords + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 9: Code-like indentation pattern (require code indicators to avoid false positives from bullet lists) if line_count >= 3: - lines = message.split('\n') + lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - indented_lines = sum(1 for line in non_empty_lines if line[0] in (' ', '\t')) + indented_lines = sum( + 1 for line in non_empty_lines if line[0] in (" ", "\t") + ) if indented_lines / len(non_empty_lines) > 0.5: - code_indicators = ['def ', 'class ', 'function ', 'return ', 'import ', 'const ', 'let ', 'var ', 'public ', 'private '] - if any(indicator in message.lower() for indicator in code_indicators): + code_indicators = [ + "def ", + "class ", + "function ", + "return ", + "import ", + "const ", + "let ", + "var ", + "public ", + "private ", + ] + if any( + indicator in message.lower() for indicator in code_indicators + ): return self.SkipReason.SKIP_TECHNICAL.value - + # Pattern 10: Very high special character ratio (encoded data, technical output) if msg_len > 50: - special_chars = sum(1 for c in message if not c.isalnum() and not c.isspace()) + special_chars = sum( + 1 for c in message if not c.isalnum() and not c.isspace() + ) special_ratio = special_chars / msg_len if special_ratio > 0.35: alphanumeric = sum(1 for c in message if c.isalnum()) if alphanumeric / msg_len < 0.50: return self.SkipReason.SKIP_TECHNICAL.value - + return None - def detect_skip_reason(self, message: str, max_message_chars: int = Constants.MAX_MESSAGE_CHARS) -> Optional[str]: + def detect_skip_reason( + self, message: str, max_message_chars: int, memory_system: "Filter" + ) -> Optional[str]: """ Detect if a message should be skipped using two-stage detection: 1. Fast-path structural patterns (~95% confidence) 2. Semantic classification (for remaining cases) - Returns: Skip reason string if content should be skipped, None otherwise """ size_issue = self.validate_message_size(message, max_message_chars) if size_issue: return size_issue - + fast_skip = self._fast_path_skip_detection(message) if fast_skip: logger.info(f"Fast-path skip: {fast_skip}") return fast_skip - + if self._reference_embeddings is None: - logger.warning("SkipDetector reference embeddings not initialized, allowing message through") + logger.warning( + "SkipDetector reference embeddings not initialized, allowing message through" + ) return None - + try: message_embedding = np.array(self.embedding_function([message.strip()])[0]) - + conversational_similarities = np.dot( - message_embedding, - self._reference_embeddings['conversational'].T + message_embedding, self._reference_embeddings["conversational"].T ) max_conversational_similarity = float(conversational_similarities.max()) - + skip_categories = [ - ('instruction', self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS), - ('translation', self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS), - ('grammar', self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS), - ('technical', self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS), - ('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS), + ( + "instruction", + self.SkipReason.SKIP_INSTRUCTION, + self.INSTRUCTION_CATEGORY_DESCRIPTIONS, + ), + ( + "translation", + self.SkipReason.SKIP_TRANSLATION, + self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS, + ), + ( + "grammar", + self.SkipReason.SKIP_GRAMMAR_PROOFREAD, + self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS, + ), + ( + "technical", + self.SkipReason.SKIP_TECHNICAL, + self.TECHNICAL_CATEGORY_DESCRIPTIONS, + ), + ( + "pure_math", + self.SkipReason.SKIP_PURE_MATH, + self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS, + ), ] - + + qualifying_categories = [] + margin_threshold = ( + max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN + ) + for cat_key, skip_reason, descriptions in skip_categories: similarities = np.dot( - message_embedding, - self._reference_embeddings[cat_key].T + message_embedding, self._reference_embeddings[cat_key].T ) max_similarity = float(similarities.max()) - - if max_similarity > Constants.SKIP_DETECTION_SIMILARITY_THRESHOLD: - margin = max_similarity - max_conversational_similarity - - if margin > Constants.SKIP_DETECTION_CONFIDENT_MARGIN: - logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})") - return skip_reason.value - - if margin > Constants.SKIP_DETECTION_MARGIN: - logger.info(f"Skipping message - {skip_reason.value} ({cat_key}: {max_similarity:.3f}, conv: {max_conversational_similarity:.3f}, margin: {margin:.3f})") - return skip_reason.value - + + if max_similarity > margin_threshold: + qualifying_categories.append((max_similarity, cat_key, skip_reason)) + + if qualifying_categories: + highest_similarity, highest_cat_key, highest_skip_reason = max( + qualifying_categories, key=lambda x: x[0] + ) + logger.info( + f"๐Ÿšซ Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})" + ) + return highest_skip_reason.value + return None - + except Exception as e: logger.error(f"Error in semantic skip detection: {e}") return None @@ -713,13 +820,28 @@ class LLMRerankingService: if not self.memory_system.valves.enable_llm_reranking: return False, "LLM reranking disabled" - llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier) + llm_trigger_threshold = int( + self.memory_system.valves.max_memories_returned + * self.memory_system.valves.llm_reranking_trigger_multiplier + ) if len(memories) > llm_trigger_threshold: - return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold" + return ( + True, + f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold", + ) - return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}" + return ( + False, + f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}", + ) - async def _llm_select_memories(self, user_message: str, candidate_memories: List[Dict], max_count: int, emitter: Optional[Callable] = None) -> List[Dict]: + async def _llm_select_memories( + self, + user_message: str, + candidate_memories: List[Dict], + max_count: int, + emitter: Optional[Callable] = None, + ) -> List[Dict]: """Use LLM to select most relevant memories.""" memory_lines = self.memory_system._format_memories_for_llm(candidate_memories) memory_context = "\n".join(memory_lines) @@ -732,56 +854,83 @@ CANDIDATE MEMORIES: {memory_context}""" try: - response = await self.memory_system._query_llm(Prompts.MEMORY_RERANKING, user_prompt, response_model=Models.MemoryRerankingResponse) - - selected_ids = response.ids + response = await self.memory_system._query_llm( + Prompts.MEMORY_RERANKING, + user_prompt, + response_model=Models.MemoryRerankingResponse, + ) selected_memories = [] for memory in candidate_memories: - if memory["id"] in selected_ids and len(selected_memories) < max_count: + if memory["id"] in response.ids and len(selected_memories) < max_count: selected_memories.append(memory) - logger.info(f"๐Ÿง  LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates") - + logger.info( + f"๐Ÿง  LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates" + ) + return selected_memories except Exception as e: - logger.warning(f"๐Ÿค– LLM reranking failed during memory relevance analysis: {str(e)}") + logger.warning( + f"๐Ÿค– LLM reranking failed during memory relevance analysis: {str(e)}" + ) return candidate_memories async def rerank_memories( - self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None + self, + user_message: str, + candidate_memories: List[Dict], + emitter: Optional[Callable] = None, ) -> Tuple[List[Dict], Dict[str, Any]]: start_time = time.time() max_injection = self.memory_system.valves.max_memories_returned - should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories) + should_use_llm, decision_reason = self._should_use_llm_reranking( + candidate_memories + ) - analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)} + analysis_info = { + "llm_decision": should_use_llm, + "decision_reason": decision_reason, + "candidate_count": len(candidate_memories), + } if should_use_llm: - extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) + extended_count = int( + self.memory_system.valves.max_memories_returned + * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER + ) llm_candidates = candidate_memories[:extended_count] await self.memory_system._emit_status( - emitter, f"๐Ÿค– LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False + emitter, + f"๐Ÿค– LLM Analyzing {len(llm_candidates)} Memories for Relevance", + done=False, ) logger.info(f"Using LLM reranking: {decision_reason}") - selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter) - + selected_memories = await self._llm_select_memories( + user_message, llm_candidates, max_injection, emitter + ) + if not selected_memories: logger.info("๐Ÿ“ญ No relevant memories after LLM analysis") - await self.memory_system._emit_status(emitter, f"๐Ÿ“ญ No Relevant Memories After LLM Analysis", done=True) + await self.memory_system._emit_status( + emitter, f"๐Ÿ“ญ No Relevant Memories After LLM Analysis", done=True + ) return selected_memories, analysis_info else: logger.info(f"Skipping LLM reranking: {decision_reason}") selected_memories = candidate_memories[:max_injection] - + duration = time.time() - start_time duration_text = f" in {duration:.2f}s" if duration >= 0.01 else "" retrieval_method = "LLM" if should_use_llm else "Semantic" - await self.memory_system._emit_status(emitter, f"๐ŸŽฏ {retrieval_method} Memory Retrieval Complete{duration_text}", done=True) - logger.info(f"๐ŸŽฏ {retrieval_method} Memory Retrieval Complete{duration_text}") + await self.memory_system._emit_status( + emitter, + f"๐ŸŽฏ {retrieval_method} Memory Retrieval Complete{duration_text}", + done=True, + ) return selected_memories, analysis_info @@ -791,18 +940,43 @@ class LLMConsolidationService: def __init__(self, memory_system): self.memory_system = memory_system + def _filter_consolidation_candidates( + self, similarities: List[Dict[str, Any]] + ) -> Tuple[List[Dict[str, Any]], str]: + """Filter consolidation candidates by threshold and return candidates with threshold info.""" + consolidation_threshold = self.memory_system._get_retrieval_threshold( + is_consolidation=True + ) + candidates = [ + mem for mem in similarities if mem["relevance"] >= consolidation_threshold + ] + + max_consolidation_memories = int( + self.memory_system.valves.max_memories_returned + * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER + ) + candidates = candidates[:max_consolidation_memories] + + threshold_info = ( + f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" + ) + return candidates, threshold_info + async def collect_consolidation_candidates( - self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None + self, + user_message: str, + user_id: str, + cached_similarities: Optional[List[Dict[str, Any]]] = None, ) -> List[Dict[str, Any]]: """Collect candidate memories for consolidation analysis using cached or computed similarities.""" if cached_similarities: - consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) - candidates = [mem for mem in cached_similarities if mem["relevance"] >= consolidation_threshold] - - max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) - candidates = candidates[:max_consolidation_memories] + candidates, threshold_info = self._filter_consolidation_candidates( + cached_similarities + ) - logger.info(f"๐ŸŽฏ Found {len(candidates)} candidate memories for consolidation (threshold: {consolidation_threshold:.3f}, max: {max_consolidation_memories})") + logger.info( + f"๐ŸŽฏ Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})" + ) self.memory_system._log_retrieved_memories(candidates, "consolidation") return candidates @@ -810,7 +984,9 @@ class LLMConsolidationService: try: user_memories = await self.memory_system._get_user_memories(user_id) except asyncio.TimeoutError: - raise TimeoutError(f"โฑ๏ธ Memory retrieval timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s") + raise TimeoutError( + f"โฑ๏ธ Memory retrieval timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s" + ) except Exception as e: logger.error(f"๐Ÿ’พ Failed to retrieve user memories from database: {str(e)}") return [] @@ -818,37 +994,48 @@ class LLMConsolidationService: if not user_memories: logger.info("๐Ÿ’ญ No existing memories found for consolidation") return [] - else: - logger.info(f"๐Ÿš€ Reusing cached user memories for consolidation: {len(user_memories)} memories") + + logger.info( + f"๐Ÿš€ Reusing cached user memories for consolidation: {len(user_memories)} memories" + ) try: - all_similarities, _, _ = await self.memory_system._compute_similarities(user_message, user_id, user_memories) + all_similarities, _, _ = await self.memory_system._compute_similarities( + user_message, user_id, user_memories + ) except Exception as e: - logger.error(f"๐Ÿ” Failed to compute memory similarities for retrieval: {str(e)}") + logger.error( + f"๐Ÿ” Failed to compute memory similarities for retrieval: {str(e)}" + ) return [] if all_similarities: - consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) - candidates = [mem for mem in all_similarities if mem["relevance"] >= consolidation_threshold] - - max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) - candidates = candidates[:max_consolidation_memories] - - threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" + candidates, threshold_info = self._filter_consolidation_candidates( + all_similarities + ) else: candidates = [] - threshold_info = 'N/A' + threshold_info = "N/A" - logger.info(f"๐ŸŽฏ Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})") + logger.info( + f"๐ŸŽฏ Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})" + ) self.memory_system._log_retrieved_memories(candidates, "consolidation") return candidates - async def generate_consolidation_plan(self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None) -> List[Dict[str, Any]]: + async def generate_consolidation_plan( + self, + user_message: str, + candidate_memories: List[Dict[str, Any]], + emitter: Optional[Callable] = None, + ) -> List[Dict[str, Any]]: """Generate consolidation plan using LLM with clear system/user prompt separation.""" if candidate_memories: - memory_lines = self.memory_system._format_memories_for_llm(candidate_memories) + memory_lines = self.memory_system._format_memories_for_llm( + candidate_memories + ) memory_context = f"EXISTING MEMORIES FOR CONSOLIDATION:\n{chr(10).join(memory_lines)}\n\n" else: memory_context = "EXISTING MEMORIES FOR CONSOLIDATION:\n[]\n\nNote: No existing memories found - Focus on extracting new memories from the user message below.\n\n" @@ -859,51 +1046,96 @@ class LLMConsolidationService: try: response = await asyncio.wait_for( - self.memory_system._query_llm(Prompts.MEMORY_CONSOLIDATION, user_prompt, response_model=Models.ConsolidationResponse), + self.memory_system._query_llm( + Prompts.MEMORY_CONSOLIDATION, + user_prompt, + response_model=Models.ConsolidationResponse, + ), timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC, ) except Exception as e: - logger.warning(f"๐Ÿค– LLM consolidation failed during memory processing: {str(e)}") - await self.memory_system._emit_status(emitter, f"โš ๏ธ Memory Consolidation Failed", done=True) + logger.warning( + f"๐Ÿค– LLM consolidation failed during memory processing: {str(e)}" + ) + await self.memory_system._emit_status( + emitter, f"โš ๏ธ Memory Consolidation Failed", done=True + ) return [] operations = response.ops existing_memory_ids = {memory["id"] for memory in candidate_memories} total_operations = len(operations) - delete_operations = [op for op in operations if op.operation == Models.MemoryOperationType.DELETE] - delete_ratio = len(delete_operations) / total_operations if total_operations > 0 else 0 + delete_operations = [ + op for op in operations if op.operation == Models.MemoryOperationType.DELETE + ] + delete_ratio = ( + len(delete_operations) / total_operations if total_operations > 0 else 0 + ) - if delete_ratio > Constants.MAX_DELETE_OPERATIONS_RATIO and total_operations >= Constants.MIN_OPS_FOR_DELETE_RATIO_CHECK: + if ( + delete_ratio > Constants.MAX_DELETE_OPERATIONS_RATIO + and total_operations >= Constants.MIN_OPS_FOR_DELETE_RATIO_CHECK + ): logger.warning( f"โš ๏ธ Consolidation safety: {len(delete_operations)}/{total_operations} operations are deletions ({delete_ratio*100:.1f}%) - rejecting plan" ) return [] - valid_operations = [op.model_dump() for op in operations if op.validate_operation(existing_memory_ids)] + valid_operations = [ + op.model_dump() + for op in operations + if op.validate_operation(existing_memory_ids) + ] if valid_operations: - create_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.CREATE.value) - update_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.UPDATE.value) - delete_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.DELETE.value) + create_count = sum( + 1 + for op in valid_operations + if op.get("operation") == Models.MemoryOperationType.CREATE.value + ) + update_count = sum( + 1 + for op in valid_operations + if op.get("operation") == Models.MemoryOperationType.UPDATE.value + ) + delete_count = sum( + 1 + for op in valid_operations + if op.get("operation") == Models.MemoryOperationType.DELETE.value + ) - operation_details = self.memory_system._build_operation_details(create_count, update_count, delete_count) + operation_details = self.memory_system._build_operation_details( + create_count, update_count, delete_count + ) - logger.info(f"๐ŸŽฏ Planned {len(valid_operations)} memory operations: {', '.join(operation_details)}") + logger.info( + f"๐ŸŽฏ Planned {len(valid_operations)} memory operations: {', '.join(operation_details)}" + ) else: logger.info("๐ŸŽฏ No valid memory operations planned") return valid_operations - async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]: + async def execute_memory_operations( + self, + operations: List[Dict[str, Any]], + user_id: str, + emitter: Optional[Callable] = None, + ) -> Tuple[int, int, int, int]: """Execute consolidation operations with simplified tracking.""" if not operations or not user_id: return 0, 0, 0, 0 try: - user = await asyncio.wait_for(asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC) + user = await asyncio.wait_for( + asyncio.to_thread(Users.get_user_by_id, user_id), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, + ) except asyncio.TimeoutError: - raise TimeoutError(f"โฑ๏ธ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s") + raise TimeoutError( + f"โฑ๏ธ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s" + ) except Exception as e: raise RuntimeError(f"๐Ÿ‘ค User lookup failed: {str(e)}") @@ -919,23 +1151,31 @@ class LLMConsolidationService: operations_by_type[operation.operation.value].append(operation) except Exception as e: failed_count += 1 - operation_type = operation_data.get("operation", Models.OperationResult.UNSUPPORTED.value) + operation_type = operation_data.get( + "operation", Models.OperationResult.UNSUPPORTED.value + ) content_preview = "" if "content" in operation_data: content = operation_data.get("content", "") content_preview = f" - Content: {self.memory_system._truncate_content(content, Constants.CONTENT_PREVIEW_LENGTH)}" elif "id" in operation_data: content_preview = f" - ID: {operation_data['id']}" - error_message = f"Failed {operation_type} operation{content_preview}: {str(e)}" + error_message = ( + f"Failed {operation_type} operation{content_preview}: {str(e)}" + ) logger.error(error_message) memory_contents_for_deletion = {} if operations_by_type["DELETE"]: try: user_memories = await self.memory_system._get_user_memories(user_id) - memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} + memory_contents_for_deletion = { + str(mem.id): mem.content for mem in user_memories + } except Exception as e: - logger.warning(f"โš ๏ธ Failed to fetch memories for DELETE preview: {str(e)}") + logger.warning( + f"โš ๏ธ Failed to fetch memories for DELETE preview: {str(e)}" + ) for operation_type, ops in operations_by_type.items(): if not ops: @@ -950,31 +1190,56 @@ class LLMConsolidationService: results = await asyncio.gather(*batch_tasks, return_exceptions=True) for idx, result in enumerate(results): operation = ops[idx] - + if isinstance(result, Exception): failed_count += 1 - await self.memory_system._emit_status(emitter, f"โŒ Failed {operation_type}", done=False) + await self.memory_system._emit_status( + emitter, f"โŒ Failed {operation_type}", done=False + ) elif result == Models.MemoryOperationType.CREATE.value: created_count += 1 - content_preview = self.memory_system._truncate_content(operation.content) - await self.memory_system._emit_status(emitter, f"๐Ÿ“ Created: {content_preview}", done=False) + content_preview = self.memory_system._truncate_content( + operation.content + ) + await self.memory_system._emit_status( + emitter, f"๐Ÿ“ Created: {content_preview}", done=False + ) elif result == Models.MemoryOperationType.UPDATE.value: updated_count += 1 - content_preview = self.memory_system._truncate_content(operation.content) - await self.memory_system._emit_status(emitter, f"โœ๏ธ Updated: {content_preview}", done=False) + content_preview = self.memory_system._truncate_content( + operation.content + ) + await self.memory_system._emit_status( + emitter, f"โœ๏ธ Updated: {content_preview}", done=False + ) elif result == Models.MemoryOperationType.DELETE.value: deleted_count += 1 - content_preview = memory_contents_for_deletion.get(operation.id, operation.id) + content_preview = memory_contents_for_deletion.get( + operation.id, operation.id + ) if content_preview and content_preview != operation.id: - content_preview = self.memory_system._truncate_content(content_preview) - await self.memory_system._emit_status(emitter, f"๐Ÿ—‘๏ธ Deleted: {content_preview}", done=False) - elif result in [Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value]: + content_preview = self.memory_system._truncate_content( + content_preview + ) + await self.memory_system._emit_status( + emitter, f"๐Ÿ—‘๏ธ Deleted: {content_preview}", done=False + ) + elif result in [ + Models.OperationResult.FAILED.value, + Models.OperationResult.UNSUPPORTED.value, + ]: failed_count += 1 - await self.memory_system._emit_status(emitter, f"โŒ Failed {operation_type}", done=False) + await self.memory_system._emit_status( + emitter, f"โŒ Failed {operation_type}", done=False + ) except Exception as e: failed_count += len(ops) - logger.error(f"โŒ Batch {operation_type} operations failed during memory consolidation: {str(e)}") - await self.memory_system._emit_status(emitter, f"โŒ Batch {operation_type} Failed", done=False) + logger.error( + f"โŒ Batch {operation_type} operations failed during memory consolidation: {str(e)}" + ) + await self.memory_system._emit_status( + emitter, f"โŒ Batch {operation_type} Failed", done=False + ) total_executed = created_count + updated_count + deleted_count logger.info( @@ -982,14 +1247,20 @@ class LLMConsolidationService: ) if total_executed > 0: - operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count) + operation_details = self.memory_system._build_operation_details( + created_count, updated_count, deleted_count + ) logger.info(f"๐Ÿ”„ Memory Operations: {', '.join(operation_details)}") - await self.memory_system._manage_user_cache(user_id) + await self.memory_system._refresh_user_cache(user_id) return created_count, updated_count, deleted_count, failed_count async def run_consolidation_pipeline( - self, user_message: str, user_id: str, emitter: Optional[Callable] = None, cached_similarities: Optional[List[Dict[str, Any]]] = None + self, + user_message: str, + user_id: str, + emitter: Optional[Callable] = None, + cached_similarities: Optional[List[Dict[str, Any]]] = None, ) -> None: """Complete consolidation pipeline with simplified flow.""" start_time = time.time() @@ -997,36 +1268,52 @@ class LLMConsolidationService: if self.memory_system._shutdown_event.is_set(): return - candidates = await self.collect_consolidation_candidates(user_message, user_id, cached_similarities) + candidates = await self.collect_consolidation_candidates( + user_message, user_id, cached_similarities + ) if self.memory_system._shutdown_event.is_set(): return - operations = await self.generate_consolidation_plan(user_message, candidates, emitter) + operations = await self.generate_consolidation_plan( + user_message, candidates, emitter + ) if self.memory_system._shutdown_event.is_set(): return if operations: - created_count, updated_count, deleted_count, failed_count = await self.execute_memory_operations(operations, user_id, emitter) - + created_count, updated_count, deleted_count, failed_count = ( + await self.execute_memory_operations(operations, user_id, emitter) + ) + duration = time.time() - start_time logger.info(f"๐Ÿ’พ Memory Consolidation Complete In {duration:.2f}s") - + total_operations = created_count + updated_count + deleted_count if total_operations > 0 or failed_count > 0: - await self.memory_system._emit_status(emitter, f"๐Ÿ’พ Memory Consolidation Complete in {duration:.2f}s", done=False) - - operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count) + await self.memory_system._emit_status( + emitter, + f"๐Ÿ’พ Memory Consolidation Complete in {duration:.2f}s", + done=False, + ) + + operation_details = self.memory_system._build_operation_details( + created_count, updated_count, deleted_count + ) memory_word = "Memory" if total_operations == 1 else "Memories" operations_summary = f"{', '.join(operation_details)} {memory_word}" - + if failed_count > 0: operations_summary += f" (โŒ {failed_count} Failed)" - - await self.memory_system._emit_status(emitter, operations_summary, done=True) + + await self.memory_system._emit_status( + emitter, operations_summary, done=True + ) except Exception as e: duration = time.time() - start_time - raise RuntimeError(f"โŒ Memory consolidation failed after {duration:.2f}s: {str(e)}") + raise RuntimeError( + f"โŒ Memory consolidation failed after {duration:.2f}s: {str(e)}" + ) class Filter: @@ -1040,22 +1327,53 @@ class Filter: class Valves(BaseModel): """Configuration valves for the Memory System.""" - model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations") - max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context") - max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations") - semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval") - relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)") - enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection") - llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)") + model: str = Field( + default=Constants.DEFAULT_LLM_MODEL, + description="Model name for LLM operations", + ) + use_custom_model_for_memory: bool = Field( + default=False, + description="Use a custom model for memory operations instead of the current chat model", + ) + custom_memory_model: str = Field( + default=Constants.DEFAULT_LLM_MODEL, + description="Custom model to use for memory operations when enabled", + ) + max_memories_returned: int = Field( + default=Constants.MAX_MEMORIES_PER_RETRIEVAL, + description="Maximum number of memories to return in context", + ) + max_message_chars: int = Field( + default=Constants.MAX_MESSAGE_CHARS, + description="Maximum user message length before skipping memory operations", + ) + semantic_retrieval_threshold: float = Field( + default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, + description="Minimum similarity threshold for memory retrieval", + ) + relaxed_semantic_threshold_multiplier: float = Field( + default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, + description="Adjusts similarity threshold for memory consolidation (lower = more candidates)", + ) + enable_llm_reranking: bool = Field( + default=True, + description="Enable LLM-based memory reranking for improved contextual selection", + ) + llm_reranking_trigger_multiplier: float = Field( + default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, + description="Controls when LLM reranking activates (lower = more aggressive)", + ) def __init__(self): """Initialize the Memory System filter with production validation.""" global _SHARED_SKIP_DETECTOR_CACHE - + self.valves = self.Valves() self._validate_system_configuration() - self._cache_manager = UnifiedCacheManager(Constants.MAX_CACHE_ENTRIES_PER_TYPE, Constants.MAX_CONCURRENT_USER_CACHES) + self._cache_manager = UnifiedCacheManager( + Constants.MAX_CACHE_ENTRIES_PER_TYPE, Constants.MAX_CONCURRENT_USER_CACHES + ) self._background_tasks: set = set() self._shutdown_event = asyncio.Event() @@ -1065,8 +1383,13 @@ class Filter: self._llm_reranking_service = LLMRerankingService(self) self._llm_consolidation_service = LLMConsolidationService(self) - def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None, - __model__: Optional[str] = None, __request__: Optional[Request] = None) -> None: + async def _set_pipeline_context( + self, + __event_emitter__: Optional[Callable] = None, + __user__: Optional[Dict[str, Any]] = None, + __model__: Optional[str] = None, + __request__: Optional[Request] = None, + ) -> None: """Set pipeline context parameters to avoid duplication in inlet/outlet methods.""" if __event_emitter__: self.__current_event_emitter__ = __event_emitter__ @@ -1076,35 +1399,49 @@ class Filter: self.__model__ = __model__ if __request__: self.__request__ = __request__ - - if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'): + + if self._embedding_function is None and hasattr( + __request__.app.state, "EMBEDDING_FUNCTION" + ): self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION logger.info(f"โœ… Using OpenWebUI's embedding function") - - if self._skip_detector is None: - global _SHARED_SKIP_DETECTOR_CACHE - embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '') - embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '') - cache_key = f"{embedding_engine}:{embedding_model}" - - if cache_key in _SHARED_SKIP_DETECTOR_CACHE: - logger.info(f"โ™ป๏ธ Reusing cached skip detector: {cache_key}") - self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] - else: - logger.info(f"๐Ÿค– Initializing skip detector with OpenWebUI embeddings: {cache_key}") - embedding_fn = self._embedding_function - def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]: - result = embedding_fn(texts, prefix=None, user=None) - if isinstance(result, list): - if isinstance(result[0], list): - return [np.array(emb, dtype=np.float16) for emb in result] - return np.array(result, dtype=np.float16) - return np.array(result, dtype=np.float16) - - self._skip_detector = SkipDetector(embedding_wrapper) - _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector - logger.info(f"โœ… Skip detector initialized and cached") + if self._skip_detector is None: + global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK + embedding_engine = getattr( + __request__.app.state.config, "RAG_EMBEDDING_ENGINE", "" + ) + embedding_model = getattr( + __request__.app.state.config, "RAG_EMBEDDING_MODEL", "" + ) + cache_key = f"{embedding_engine}:{embedding_model}" + + async with _SHARED_SKIP_DETECTOR_CACHE_LOCK: + if cache_key in _SHARED_SKIP_DETECTOR_CACHE: + logger.info(f"โ™ป๏ธ Reusing cached skip detector: {cache_key}") + self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key] + else: + logger.info( + f"๐Ÿค– Initializing skip detector with OpenWebUI embeddings: {cache_key}" + ) + embedding_fn = self._embedding_function + + def embedding_wrapper( + texts: Union[str, List[str]], + ) -> Union[np.ndarray, List[np.ndarray]]: + result = embedding_fn(texts, prefix=None, user=None) + if isinstance(result, list): + if isinstance(result[0], list): + return [ + np.array(emb, dtype=np.float16) + for emb in result + ] + return np.array(result, dtype=np.float16) + return np.array(result, dtype=np.float16) + + self._skip_detector = SkipDetector(embedding_wrapper) + _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector + logger.info(f"โœ… Skip detector initialized and cached") def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str: """Truncate content with ellipsis if needed.""" @@ -1115,7 +1452,10 @@ class Filter: def _get_retrieval_threshold(self, is_consolidation: bool = False) -> float: """Calculate retrieval threshold for semantic similarity filtering.""" if is_consolidation: - return self.valves.semantic_retrieval_threshold * self.valves.relaxed_semantic_threshold_multiplier + return ( + self.valves.semantic_retrieval_threshold + * self.valves.relaxed_semantic_threshold_multiplier + ) return self.valves.semantic_retrieval_threshold def _extract_text_from_content(self, content) -> str: @@ -1137,26 +1477,36 @@ class Filter: raise ValueError("๐Ÿค– Model not specified") if self.valves.max_memories_returned <= 0: - raise ValueError(f"๐Ÿ“Š Invalid max memories returned: {self.valves.max_memories_returned}") + raise ValueError( + f"๐Ÿ“Š Invalid max memories returned: {self.valves.max_memories_returned}" + ) if not (0.0 <= self.valves.semantic_retrieval_threshold <= 1.0): - raise ValueError(f"๐ŸŽฏ Invalid semantic retrieval threshold: {self.valves.semantic_retrieval_threshold} (must be 0.0-1.0)") + raise ValueError( + f"๐ŸŽฏ Invalid semantic retrieval threshold: {self.valves.semantic_retrieval_threshold} (must be 0.0-1.0)" + ) logger.info("โœ… Configuration validated") async def _get_embedding_cache(self, user_id: str, key: str) -> Optional[Any]: """Get embedding from cache.""" - return await self._cache_manager.get(user_id, self._cache_manager.EMBEDDING_CACHE, key) + return await self._cache_manager.get( + user_id, self._cache_manager.EMBEDDING_CACHE, key + ) async def _put_embedding_cache(self, user_id: str, key: str, value: Any) -> None: """Store embedding in cache.""" - await self._cache_manager.put(user_id, self._cache_manager.EMBEDDING_CACHE, key, value) + await self._cache_manager.put( + user_id, self._cache_manager.EMBEDDING_CACHE, key, value + ) def _compute_text_hash(self, text: str) -> str: """Compute SHA256 hash for text caching.""" return hashlib.sha256(text.encode()).hexdigest() - def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray: + def _normalize_embedding( + self, embedding: Union[List[float], np.ndarray] + ) -> np.ndarray: """Normalize embedding vector.""" if isinstance(embedding, list): embedding = np.array(embedding, dtype=np.float16) @@ -1165,11 +1515,15 @@ class Filter: norm = np.linalg.norm(embedding) return embedding / norm if norm > 0 else embedding - async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]: + async def _generate_embeddings( + self, texts: Union[str, List[str]], user_id: str + ) -> Union[np.ndarray, List[np.ndarray]]: """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function.""" if self._embedding_function is None: - raise RuntimeError("๐Ÿค– Embedding function not initialized. Ensure pipeline context is set.") - + raise RuntimeError( + "๐Ÿค– Embedding function not initialized. Ensure pipeline context is set." + ) + is_single = isinstance(texts, str) text_list = [texts] if is_single else texts @@ -1202,20 +1556,22 @@ class Filter: uncached_hashes.append(text_hash) if uncached_texts: - user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None - + user = ( + await asyncio.to_thread(Users.get_user_by_id, user_id) + if hasattr(self, "__user__") + else None + ) + loop = asyncio.get_event_loop() raw_embeddings = await loop.run_in_executor( - None, - self._embedding_function, - uncached_texts, - None, - user + None, self._embedding_function, uncached_texts, None, user ) - + if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0: if isinstance(raw_embeddings[0], list): - new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings] + new_embeddings = [ + self._normalize_embedding(emb) for emb in raw_embeddings + ] else: new_embeddings = [self._normalize_embedding(raw_embeddings)] else: @@ -1228,7 +1584,11 @@ class Filter: result_embeddings[original_idx] = embedding if is_single: - logger.info("๐Ÿ“ฅ User message embedding: cache hit" if not uncached_texts else "๐Ÿ’พ User message embedding: generated and cached") + logger.info( + "๐Ÿ“ฅ User message embedding: cache hit" + if not uncached_texts + else "๐Ÿ’พ User message embedding: generated and cached" + ) return result_embeddings[0] else: valid_count = sum(1 for emb in result_embeddings if emb is not None) @@ -1240,17 +1600,25 @@ class Filter: def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]: if self._skip_detector is None: raise RuntimeError("๐Ÿค– Skip detector not initialized") - - skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars) + + skip_reason = self._skip_detector.detect_skip_reason( + user_message, self.valves.max_message_chars, memory_system=self + ) if skip_reason: status_key = SkipDetector.SkipReason(skip_reason) return True, SkipDetector.STATUS_MESSAGES[status_key] return False, "" - def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]: + def _process_user_message( + self, body: Dict[str, Any] + ) -> Tuple[Optional[str], bool, str]: """Extract user message and determine if memory operations should be skipped.""" if not body or "messages" not in body or not isinstance(body["messages"], list): - return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE] + return ( + None, + True, + SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], + ) messages = body["messages"] user_message = None @@ -1266,23 +1634,34 @@ class Filter: break if not user_message or not user_message.strip(): - return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE] + return ( + None, + True, + SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE], + ) should_skip, skip_reason = self._should_skip_memory_operations(user_message) return user_message, should_skip, skip_reason - async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List: + async def _get_user_memories( + self, user_id: str, timeout: Optional[float] = None + ) -> List: """Get user memories with timeout handling.""" if timeout is None: timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC try: - return await asyncio.wait_for(asyncio.to_thread(Memories.get_memories_by_user_id, user_id), timeout=timeout) + return await asyncio.wait_for( + asyncio.to_thread(Memories.get_memories_by_user_id, user_id), + timeout=timeout, + ) except asyncio.TimeoutError: raise TimeoutError(f"โฑ๏ธ Memory retrieval timed out after {timeout}s") except Exception as e: raise RuntimeError(f"๐Ÿ’พ Memory retrieval failed: {str(e)}") - def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None: + def _log_retrieved_memories( + self, memories: List[Dict[str, Any]], context_type: str = "semantic" + ) -> None: """Log retrieved memories with concise formatting showing key statistics and semantic values.""" if not memories: return @@ -1294,32 +1673,44 @@ class Filter: top_score = max(scores) lowest_score = min(scores) - median_score = sorted(scores)[len(scores) // 2] + median_score = statistics.median(scores) - context_label = "๐Ÿ“Š Consolidation candidate memories" if context_type == "consolidation" else "๐Ÿ“Š Retrieved memories" - max_scores_to_show = int(self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) - scores_str = ", ".join([f"{score:.3f}" for score in scores[:max_scores_to_show]]) + context_label = ( + "๐Ÿ“Š Consolidation candidate memories" + if context_type == "consolidation" + else "๐Ÿ“Š Retrieved memories" + ) + max_scores_to_show = int( + self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER + ) + scores_str = ", ".join( + [f"{score:.3f}" for score in scores[:max_scores_to_show]] + ) suffix = "..." if len(scores) > max_scores_to_show else "" - logger.info(f"{context_label}: {len(memories)} memories | Top: {top_score:.3f} | Median: {median_score:.3f} | Lowest: {lowest_score:.3f}") + logger.info( + f"{context_label}: {len(memories)} memories | Top: {top_score:.3f} | Median: {median_score:.3f} | Lowest: {lowest_score:.3f}" + ) logger.info(f"Scores: [{scores_str}{suffix}]") - def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]: - """Build operation details list with consistent formatting.""" - operation_details = [] + def _build_operation_details( + self, created_count: int, updated_count: int, deleted_count: int + ) -> List[str]: + operations = [ + (created_count, "๐Ÿ“ Created"), + (updated_count, "โœ๏ธ Updated"), + (deleted_count, "๐Ÿ—‘๏ธ Deleted"), + ] + return [f"{label} {count}" for count, label in operations if count > 0] - operations = [(created_count, "๐Ÿ“ Created"), (updated_count, "โœ๏ธ Updated"), (deleted_count, "๐Ÿ—‘๏ธ Deleted")] - - for count, label in operations: - if count > 0: - operation_details.append(f"{label} {count}") - - return operation_details - - def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str: + def _cache_key( + self, cache_type: str, user_id: str, content: Optional[str] = None + ) -> str: """Unified cache key generation for all cache types.""" if content: - content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH] + content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[ + : Constants.CACHE_KEY_HASH_PREFIX_LENGTH + ] return f"{cache_type}_{user_id}:{content_hash}" return f"{cache_type}_{user_id}" @@ -1339,7 +1730,9 @@ class Filter: if record_date: try: if isinstance(record_date, str): - parsed_date = datetime.fromisoformat(record_date.replace('Z', '+00:00')) + parsed_date = datetime.fromisoformat( + record_date.replace("Z", "+00:00") + ) else: parsed_date = record_date formatted_date = parsed_date.strftime("%b %d %Y") @@ -1350,7 +1743,9 @@ class Filter: memory_lines.append(line) return memory_lines - async def _emit_status(self, emitter: Optional[Callable], description: str, done: bool = True) -> None: + async def _emit_status( + self, emitter: Optional[Callable], description: str, done: bool = True + ) -> None: """Emit status messages for memory operations.""" if not emitter: return @@ -1374,9 +1769,19 @@ class Filter: ) -> Dict[str, Any]: """Retrieve memories for injection using similarity computation with optional LLM reranking.""" if cached_similarities is not None: - memories = [m for m in cached_similarities if m.get("relevance", 0) >= self.valves.semantic_retrieval_threshold] - logger.info(f"๐Ÿ” Using cached similarities for {len(memories)} candidate memories") - final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter) + memories = [ + m + for m in cached_similarities + if m.get("relevance", 0) >= self.valves.semantic_retrieval_threshold + ] + logger.info( + f"๐Ÿ” Using cached similarities for {len(memories)} candidate memories" + ) + final_memories, reranking_info = ( + await self._llm_reranking_service.rerank_memories( + user_message, memories, emitter + ) + ) self._log_retrieved_memories(final_memories, "semantic") return { "memories": final_memories, @@ -1393,10 +1798,16 @@ class Filter: await self._emit_status(emitter, "๐Ÿ“ญ No Memories Found", done=True) return {"memories": [], "threshold": None} - memories, threshold, all_similarities = await self._compute_similarities(user_message, user_id, user_memories) + memories, threshold, all_similarities = await self._compute_similarities( + user_message, user_id, user_memories + ) if memories: - final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter) + final_memories, reranking_info = ( + await self._llm_reranking_service.rerank_memories( + user_message, memories, emitter + ) + ) else: logger.info("๐Ÿ“ญ No relevant memories found above similarity threshold") await self._emit_status(emitter, "๐Ÿ“ญ No Relevant Memories Found", done=True) @@ -1405,10 +1816,19 @@ class Filter: self._log_retrieved_memories(final_memories, "semantic") - return {"memories": final_memories, "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info} + return { + "memories": final_memories, + "threshold": threshold, + "all_similarities": all_similarities, + "reranking_info": reranking_info, + } async def _add_memory_context( - self, body: Dict[str, Any], memories: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None, emitter: Optional[Callable] = None + self, + body: Dict[str, Any], + memories: Optional[List[Dict[str, Any]]] = None, + user_id: Optional[str] = None, + emitter: Optional[Callable] = None, ) -> None: """Add memory context to request body with simplified logic.""" if not body or "messages" not in body or not body["messages"]: @@ -1422,38 +1842,57 @@ class Filter: memory_count = len(memories) memory_header = f"CONTEXT: The following {'fact' if memory_count == 1 else 'facts'} about the user are provided for background only. Not all facts may be relevant to the current request." formatted_memories = [] - + for idx, memory in enumerate(memories, 1): formatted_memory = f"- {' '.join(memory['content'].split())}" formatted_memories.append(formatted_memory) - - content_preview = self._truncate_content(memory['content']) - await self._emit_status(emitter, f"๐Ÿ’ญ {idx}/{memory_count}: {content_preview}", done=False) - + + content_preview = self._truncate_content(memory["content"]) + await self._emit_status( + emitter, f"๐Ÿ’ญ {idx}/{memory_count}: {content_preview}", done=False + ) + memory_footer = "IMPORTANT: Do not mention or imply you received this list. These facts are for background context only." memory_context_block = f"{memory_header}\n{chr(10).join(formatted_memories)}\n\n{memory_footer}" content_parts.append(memory_context_block) memory_context = "\n\n".join(content_parts) - system_index = next((i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), None) + system_index = next( + ( + i + for i, msg in enumerate(body["messages"]) + if msg.get("role") == "system" + ), + None, + ) if system_index is not None: - body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}" + body["messages"][system_index][ + "content" + ] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}" else: body["messages"].insert(0, {"role": "system", "content": memory_context}) - + if memories and user_id: description = f"๐Ÿง  Injected {memory_count} {'Memory' if memory_count == 1 else 'Memories'} to Context" await self._emit_status(emitter, description, done=True) def _build_memory_dict(self, memory, similarity: float) -> Dict[str, Any]: """Build memory dictionary with standardized timestamp conversion.""" - memory_dict = {"id": str(memory.id), "content": memory.content, "relevance": similarity} + memory_dict = { + "id": str(memory.id), + "content": memory.content, + "relevance": similarity, + } if hasattr(memory, "created_at") and memory.created_at: - memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat() + memory_dict["created_at"] = datetime.fromtimestamp( + memory.created_at, tz=timezone.utc + ).isoformat() if hasattr(memory, "updated_at") and memory.updated_at: - memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat() + memory_dict["updated_at"] = datetime.fromtimestamp( + memory.updated_at, tz=timezone.utc + ).isoformat() return memory_dict async def _compute_similarities( @@ -1468,7 +1907,9 @@ class Filter: memory_embeddings = await self._generate_embeddings(memory_contents, user_id) if len(memory_embeddings) != len(user_memories): - logger.error(f"๐Ÿ”ข Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories") + logger.error( + f"๐Ÿ”ข Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories" + ) return [], self.valves.semantic_retrieval_threshold, [] similarity_scores = [] @@ -1490,7 +1931,7 @@ class Filter: memory_data.sort(key=lambda x: x["relevance"], reverse=True) threshold = self.valves.semantic_retrieval_threshold - filtered_memories = [m for m in memory_data if m["relevance"] >= threshold] + filtered_memories = [m for m in memory_data if m["relevance"] >= threshold] return filtered_memories, threshold, memory_data async def inlet( @@ -1503,43 +1944,68 @@ class Filter: **kwargs, ) -> Dict[str, Any]: """Simplified inlet processing for memory retrieval and injection.""" - self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) + + model_to_use = body.get("model") if isinstance(body, dict) else None + if not model_to_use: + model_to_use = __model__ or getattr(__request__.state, "model", None) + if not model_to_use: + model_to_use = Constants.DEFAULT_LLM_MODEL + logger.warning(f"โš ๏ธ No model found, use default model : {model_to_use}") + + if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model: + model_to_use = self.valves.custom_memory_model + logger.info(f"๐Ÿง  Using the custom model for memory : {model_to_use}") + + self.valves.model = model_to_use + + await self._set_pipeline_context( + __event_emitter__, __user__, model_to_use, __request__ + ) user_id = __user__.get("id") if body and __user__ else None if not user_id: return body user_message, should_skip, skip_reason = self._process_user_message(body) - if not user_message or should_skip: if __event_emitter__ and skip_reason: await self._emit_status(__event_emitter__, skip_reason, done=True) await self._add_memory_context(body, [], user_id, __event_emitter__) return body - try: - memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id) - user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key) - + memory_cache_key = self._cache_key( + self._cache_manager.MEMORY_CACHE, user_id + ) + user_memories = await self._cache_manager.get( + user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key + ) if user_memories is None: user_memories = await self._get_user_memories(user_id) - if user_memories: - await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories) - - retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__) + await self._cache_manager.put( + user_id, + self._cache_manager.MEMORY_CACHE, + memory_cache_key, + user_memories, + ) + retrieval_result = await self._retrieve_relevant_memories( + user_message, user_id, user_memories, __event_emitter__ + ) memories = retrieval_result.get("memories", []) threshold = retrieval_result.get("threshold") all_similarities = retrieval_result.get("all_similarities", []) - if all_similarities: - cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) - await self._cache_manager.put(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key, all_similarities) - + cache_key = self._cache_key( + self._cache_manager.RETRIEVAL_CACHE, user_id, user_message + ) + await self._cache_manager.put( + user_id, + self._cache_manager.RETRIEVAL_CACHE, + cache_key, + all_similarities, + ) await self._add_memory_context(body, memories, user_id, __event_emitter__) - except Exception as e: raise RuntimeError(f"๐Ÿ’พ Memory retrieval failed: {str(e)}") - return body async def outlet( @@ -1552,22 +2018,40 @@ class Filter: **kwargs, ) -> dict: """Simplified outlet processing for background memory consolidation.""" - self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) + + model_to_use = body.get("model") if isinstance(body, dict) else None + if not model_to_use: + model_to_use = __model__ or getattr(__request__.state, "model", None) + if not model_to_use: + model_to_use = Constants.DEFAULT_LLM_MODEL + logger.warning(f"โš ๏ธ No model found, use default model : {model_to_use}") + + if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model: + model_to_use = self.valves.custom_memory_model + logger.info(f"๐Ÿง  Using the custom model for memory : {model_to_use}") + + self.valves.model = model_to_use + + await self._set_pipeline_context( + __event_emitter__, __user__, model_to_use, __request__ + ) user_id = __user__.get("id") if body and __user__ else None if not user_id: return body - user_message, should_skip, skip_reason = self._process_user_message(body) - if not user_message or should_skip: return body - - cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) - cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key) - + cache_key = self._cache_key( + self._cache_manager.RETRIEVAL_CACHE, user_id, user_message + ) + cached_similarities = await self._cache_manager.get( + user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key + ) task = asyncio.create_task( - self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities) + self._llm_consolidation_service.run_consolidation_pipeline( + user_message, user_id, __event_emitter__, cached_similarities + ) ) self._background_tasks.add(task) @@ -1576,12 +2060,13 @@ class Filter: self._background_tasks.discard(t) if t.exception() and not t.cancelled(): exception = t.exception() - logger.error(f"โŒ Background memory consolidation task failed: {str(exception)}") + logger.error( + f"โŒ Background memory consolidation task failed: {str(exception)}" + ) except Exception as e: logger.error(f"โŒ Failed to cleanup background memory task: {str(e)}") task.add_done_callback(safe_cleanup) - return body async def shutdown(self) -> None: @@ -1594,75 +2079,112 @@ class Filter: await self._cache_manager.clear_all_caches() - async def _manage_user_cache(self, user_id: str, clear_first: bool = False) -> None: - """Manage user cache - clear, invalidate, and refresh as needed.""" + async def _refresh_user_cache(self, user_id: str) -> None: + """Refresh user cache - clear stale caches and update with fresh embeddings.""" start_time = time.time() try: - if clear_first: - total_removed = await self._cache_manager.clear_user_cache(user_id) - logger.info(f"๐Ÿงน Cleared {total_removed} cache entries for user {user_id}") - else: - retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE) - logger.info(f"๐Ÿ”„ Cleared {retrieval_cleared} retrieval cache entries for user {user_id}") + retrieval_cleared = await self._cache_manager.clear_user_cache( + user_id, self._cache_manager.RETRIEVAL_CACHE + ) + embedding_cleared = await self._cache_manager.clear_user_cache( + user_id, self._cache_manager.EMBEDDING_CACHE + ) + logger.info( + f"๐Ÿ”„ Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}" + ) user_memories = await self._get_user_memories(user_id) - memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id) + memory_cache_key = self._cache_key( + self._cache_manager.MEMORY_CACHE, user_id + ) if not user_memories: - await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, []) + await self._cache_manager.put( + user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, [] + ) logger.info("๐Ÿ“ญ No memories found for user") return - await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories) + await self._cache_manager.put( + user_id, + self._cache_manager.MEMORY_CACHE, + memory_cache_key, + user_memories, + ) memory_contents = [ memory.content for memory in user_memories - if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS + if memory.content + and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS ] if memory_contents: await self._generate_embeddings(memory_contents, user_id) duration = time.time() - start_time - logger.info(f"๏ฟฝ Cache updated with {len(memory_contents)} embeddings for user {user_id} in {duration:.2f}s") + logger.info( + f"๐Ÿ”„ Cache updated with {len(memory_contents)} embeddings for user {user_id} in {duration:.2f}s" + ) except Exception as e: - raise RuntimeError(f"๐Ÿงน Failed to manage cache for user {user_id} after {(time.time() - start_time):.2f}s: {str(e)}") + raise RuntimeError( + f"๐Ÿงน Failed to refresh cache for user {user_id} after {(time.time() - start_time):.2f}s: {str(e)}" + ) - async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str: + async def _execute_single_operation( + self, operation: Models.MemoryOperation, user: Any + ) -> str: """Execute a single memory operation.""" try: if operation.operation == Models.MemoryOperationType.CREATE: - if not operation.content.strip(): + content_stripped = operation.content.strip() + if not content_stripped: logger.warning(f"โš ๏ธ Skipping CREATE operation: empty content") return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value await asyncio.wait_for( - asyncio.to_thread(Memories.insert_new_memory, user.id, operation.content.strip()), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC + asyncio.to_thread( + Memories.insert_new_memory, user.id, content_stripped + ), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.CREATE.value elif operation.operation == Models.MemoryOperationType.UPDATE: - if not operation.id.strip(): + id_stripped = operation.id.strip() + if not id_stripped: logger.warning(f"โš ๏ธ Skipping UPDATE operation: empty ID") return Models.OperationResult.SKIPPED_EMPTY_ID.value - if not operation.content.strip(): - logger.warning(f"โš ๏ธ Skipping UPDATE operation for {operation.id}: empty content") + + content_stripped = operation.content.strip() + if not content_stripped: + logger.warning( + f"โš ๏ธ Skipping UPDATE operation for {id_stripped}: empty content" + ) return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value await asyncio.wait_for( - asyncio.to_thread(Memories.update_memory_by_id_and_user_id, operation.id, user.id, operation.content.strip()), + asyncio.to_thread( + Memories.update_memory_by_id_and_user_id, + id_stripped, + user.id, + content_stripped, + ), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.UPDATE.value elif operation.operation == Models.MemoryOperationType.DELETE: - if not operation.id.strip(): + id_stripped = operation.id.strip() + if not id_stripped: logger.warning(f"โš ๏ธ Skipping DELETE operation: empty ID") return Models.OperationResult.SKIPPED_EMPTY_ID.value await asyncio.wait_for( - asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, operation.id, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC + asyncio.to_thread( + Memories.delete_memory_by_id_and_user_id, id_stripped, user.id + ), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.DELETE.value else: @@ -1670,42 +2192,62 @@ class Filter: return Models.OperationResult.UNSUPPORTED.value except Exception as e: - logger.error(f"๐Ÿ’พ Database operation failed for {operation.operation.value}: {str(e)}") + logger.error( + f"๐Ÿ’พ Database operation failed for {operation.operation.value}: {str(e)}" + ) return Models.OperationResult.FAILED.value - def _remove_refs_from_schema(self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]: + def _remove_refs_from_schema( + self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None + ) -> Dict[str, Any]: """Remove $ref references and ensure required fields for Azure OpenAI.""" if not isinstance(schema, dict): return schema - - if '$ref' in schema: - ref_path = schema['$ref'] - if ref_path.startswith('#/$defs/'): - def_name = ref_path.split('/')[-1] + + if "$ref" in schema: + ref_path = schema["$ref"] + if ref_path.startswith("#/$defs/"): + def_name = ref_path.split("/")[-1] if schema_defs and def_name in schema_defs: - return self._remove_refs_from_schema(schema_defs[def_name].copy(), schema_defs) - return {'type': 'object'} - + return self._remove_refs_from_schema( + schema_defs[def_name].copy(), schema_defs + ) + return {"type": "object"} + result = {} for key, value in schema.items(): - if key == '$defs': + if key == "$defs": continue elif isinstance(value, dict): result[key] = self._remove_refs_from_schema(value, schema_defs) elif isinstance(value, list): - result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value] + result[key] = [ + ( + self._remove_refs_from_schema(item, schema_defs) + if isinstance(item, dict) + else item + ) + for item in value + ] else: result[key] = value - - if result.get('type') == 'object' and 'properties' in result: - result['required'] = list(result['properties'].keys()) - + + if result.get("type") == "object" and "properties" in result: + result["required"] = list(result["properties"].keys()) + return result - async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]: + async def _query_llm( + self, + system_prompt: str, + user_prompt: str, + response_model: Optional[BaseModel] = None, + ) -> Union[str, BaseModel]: """Query OpenWebUI's internal model system with Pydantic model parsing.""" if not hasattr(self, "__request__") or not hasattr(self, "__user__"): - raise RuntimeError("๐Ÿ”ง Pipeline interface not properly initialized. __request__ and __user__ required.") + raise RuntimeError( + "๐Ÿ”ง Pipeline interface not properly initialized. __request__ and __user__ required." + ) model_to_use = self.valves.model if self.valves.model else self.__model__ if not model_to_use: @@ -1713,30 +2255,50 @@ class Filter: form_data = { "model": model_to_use, - "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}], + "messages": [ + {"role": "system", "content": system_prompt}, + {"role": "user", "content": user_prompt}, + ], "max_tokens": 4096, "stream": False, } if response_model: raw_schema = response_model.model_json_schema() - schema_defs = raw_schema.get('$defs', {}) + schema_defs = raw_schema.get("$defs", {}) schema = self._remove_refs_from_schema(raw_schema, schema_defs) - schema['type'] = 'object' - form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}} + schema["type"] = "object" + form_data["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": response_model.__name__, + "strict": True, + "schema": schema, + }, + } try: response = await asyncio.wait_for( - generate_chat_completion(self.__request__, form_data, user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"])), + generate_chat_completion( + self.__request__, + form_data, + user=await asyncio.to_thread( + Users.get_user_by_id, self.__user__["id"] + ), + ), timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC, ) except asyncio.TimeoutError: - raise TimeoutError(f"โฑ๏ธ LLM query timed out after {Constants.LLM_CONSOLIDATION_TIMEOUT_SEC}s") + raise TimeoutError( + f"โฑ๏ธ LLM query timed out after {Constants.LLM_CONSOLIDATION_TIMEOUT_SEC}s" + ) except Exception as e: raise RuntimeError(f"๐Ÿค– LLM query failed: {str(e)}") try: - if hasattr(response, "body") and hasattr(getattr(response, "body", None), "decode"): + if hasattr(response, "body") and hasattr( + getattr(response, "body", None), "decode" + ): body = getattr(response, "body") response_data = json.loads(body.decode("utf-8")) else: @@ -1744,23 +2306,39 @@ class Filter: except (json.JSONDecodeError, AttributeError) as e: raise RuntimeError(f"๐Ÿ” Failed to decode response body: {str(e)}") - if isinstance(response_data, dict) and "choices" in response_data and isinstance(response_data["choices"], list) and len(response_data["choices"]) > 0: + if ( + isinstance(response_data, dict) + and "choices" in response_data + and isinstance(response_data["choices"], list) + and len(response_data["choices"]) > 0 + ): first_choice = response_data["choices"][0] - if isinstance(first_choice, dict) and "message" in first_choice and isinstance(first_choice["message"], dict) and "content" in first_choice["message"]: + if ( + isinstance(first_choice, dict) + and "message" in first_choice + and isinstance(first_choice["message"], dict) + and "content" in first_choice["message"] + ): content = first_choice["message"]["content"] else: - raise ValueError("๐Ÿค– Invalid response structure: missing content in message") + raise ValueError( + "๐Ÿค– Invalid response structure: missing content in message" + ) else: raise ValueError(f"๐Ÿค– Unexpected LLM response format: {response_data}") if response_model: - try: + try: parsed_data = json.loads(content) return response_model.model_validate(parsed_data) except json.JSONDecodeError as e: - raise ValueError(f"๐Ÿ” Invalid JSON from LLM: {str(e)}\nContent: {content}") + raise ValueError( + f"๐Ÿ” Invalid JSON from LLM: {str(e)}\nContent: {content}" + ) except PydanticValidationError as e: - raise ValueError(f"๐Ÿ” LLM response validation failed: {str(e)}\nContent: {content}") + raise ValueError( + f"๐Ÿ” LLM response validation failed: {str(e)}\nContent: {content}" + ) if not content or content.strip() == "": raise ValueError("๐Ÿค– Empty response from LLM")