From 55a8c70bace283856b279edd8cca147b276d168a Mon Sep 17 00:00:00 2001
From: GlissemanTV
Date: Mon, 27 Oct 2025 21:06:35 +0100
Subject: [PATCH] add support for current chat model

---
 memory_system.py | 906 +++++++++++------------------------------------
 1 file changed, 205 insertions(+), 701 deletions(-)

diff --git a/memory_system.py b/memory_system.py
index 8497510..8cef035 100644
--- a/memory_system.py
+++ b/memory_system.py
@@ -1,9 +1,6 @@
 """
 title: Memory System
-description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
 version: 1.0.0
-authors: https://github.com/mtayfur
-license: Apache-2.0
 """
 
 import asyncio
@@ -18,14 +15,8 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
-from pydantic import (
-    BaseModel,
-    ConfigDict,
-    Field,
-    ValidationError as PydanticValidationError,
-)
-
-from open_webui.utils.chat import generate_chat_completion
+from pydantic import BaseModel, ConfigDict, Field
+from fastapi import Request
 from open_webui.models.users import Users
 from open_webui.routers.memories import Memories
 from open_webui.utils.chat import generate_chat_completion
@@ -38,11 +28,9 @@
 
 _SHARED_SKIP_DETECTOR_CACHE = {}
 _SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()
 
-
 class Constants:
     """Centralized configuration constants for the memory system."""
 
-    # Core System Limits
     MAX_MEMORY_CONTENT_CHARS = 500  # Character limit for LLM prompt memory content
     MAX_MEMORIES_PER_RETRIEVAL = 10  # Maximum memories returned per query
@@ -51,52 +39,38 @@ class Constants:
     DATABASE_OPERATION_TIMEOUT_SEC = 10  # Timeout for DB operations like user lookup
     LLM_CONSOLIDATION_TIMEOUT_SEC = 60.0  # Timeout for LLM consolidation operations
 
-    # Cache System
     MAX_CACHE_ENTRIES_PER_TYPE = 500  # Maximum cache entries per cache type
     MAX_CONCURRENT_USER_CACHES = 50  # Maximum concurrent user cache instances
     CACHE_KEY_HASH_PREFIX_LENGTH = 10  # Hash prefix length for cache keys
 
-    # Retrieval & Similarity
     SEMANTIC_RETRIEVAL_THRESHOLD = 0.25  # Semantic similarity threshold for retrieval
-    RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = (
-        0.8  # Multiplier for relaxed similarity threshold in secondary operations
-    )
-    EXTENDED_MAX_MEMORY_MULTIPLIER = (
-        1.6  # Multiplier for expanding memory candidates in advanced operations
-    )
-    LLM_RERANKING_TRIGGER_MULTIPLIER = (
-        0.8  # Multiplier for LLM reranking trigger threshold
-    )
+    RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8  # Multiplier for relaxed similarity threshold in secondary operations
+    EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6  # Multiplier for expanding memory candidates in advanced operations
+    LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8  # Multiplier for LLM reranking trigger threshold
 
     # Skip Detection
-    SKIP_CATEGORY_MARGIN = (
-        0.5  # Margin above conversational similarity for skip category classification
-    )
+    SKIP_CATEGORY_MARGIN = 0.5  # Margin above conversational similarity for skip category classification
 
     # Safety & Operations
     MAX_DELETE_OPERATIONS_RATIO = 0.6  # Maximum delete operations ratio for safety
     MIN_OPS_FOR_DELETE_RATIO_CHECK = 6  # Minimum operations to apply ratio check
 
-    # Content Display
     CONTENT_PREVIEW_LENGTH = 80  # Maximum length for content preview display
 
-    CONTENT_PREVIEW_LENGTH = 80  # Maximum length for content preview display
-
     # Default Models
     DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite"
 
-
 class Prompts:
     """Container for all LLM prompts used in the memory system."""
 
     MEMORY_CONSOLIDATION = f"""You are the Memory System Consolidator, a
specialist in creating precise user memories. ## OBJECTIVE -Your goal is to build precise memories of the user's personal narrative with factual, temporal statements. +Build precise memories of the user's personal narrative with factual, temporal statements. ## AVAILABLE OPERATIONS - CREATE: For new, personal facts. Must be semantically and temporally enhanced. @@ -106,11 +80,6 @@ Your goal is to build precise memories of the user's personal narrative with fac ## PROCESSING GUIDELINES - Personal Facts Only: Store only significant facts with lasting relevance to the user's life and identity. Exclude transient situations, questions, general knowledge, casual mentions, or momentary states. -- **Filter for Intent:** You MUST SKIP if the user's primary intent is instructional, technical, or analytical, even if the message contains personal details. This includes requests to: - - Rewrite, revise, translate, or proofread a block of text (e.g., "revise this review for me"). - - Answer a general knowledge, math, or technical question. - - Explain a concept, perform a calculation, or act as a persona. - **Only store facts when the user is *directly stating* them as part of a personal narrative, not when providing them as content for a task.** - Maintain Temporal Accuracy: - Capture Dates: Record temporal information when explicitly stated or clearly derivable. Convert relative references (last month, yesterday) to specific dates. - Preserve History: Transform superseded facts into past-tense statements with defined time boundaries. @@ -121,13 +90,14 @@ Your goal is to build precise memories of the user's personal narrative with fac - Retroactive Enrichment: If a name is provided for prior entity, UPDATE only if substantially valuable. - Ensure Memory Quality: - High Bar for Creation: Only CREATE memories for significant life facts, relationships, events, or core personal attributes. Skip trivial details or passing interests. + - Contextual Completeness: Combine related information into cohesive statements. Group connected facts (same topic, person, event, or timeframe) into single memories rather than fragmenting. Include supporting details while respecting boundaries. Only combine directly related facts. Avoid bare statements and never merge unrelated information. - Mandatory Semantic Enhancement: Enhance entities with descriptive categorical nouns for better retrieval. - Verify Nouns/Pronouns: Link pronouns (he, she, they) and nouns to specific entities. - First-Person Format: Write all memories in English from the user's perspective. ## DECISION FRAMEWORK -- Selectivity: Verify the user's *primary intent* is to state a direct, personally significant fact with lasting importance. If the intent is instructional, analytical, or a general question, SKIP. Never create duplicate memories. Skip momentary events or casual mentions. Be conservative with CREATE and UPDATE operations. -- Strategy: Strongly prioritize enriching existing memories over creating new ones. Analyze the message holistically to identify naturally connected facts (same person, event, or timeframe) and combine them into a unified, cohesive memory rather than fragmenting them. Each memory must be self-contained and **never** merge unrelated information. +- Selectivity: Verify the user is stating a direct, personally significant fact with lasting importance. If not, SKIP. Never create duplicate memories. Skip momentary events or casual mentions. Be conservative with CREATE and UPDATE operations. 
+- Strategy: Strongly prioritize enriching existing memories over creating new ones. Analyze the message holistically to identify naturally connected facts that should be captured together. When facts share connections (same person, event, situation, or causal relationship), combine them into a unified memory that preserves the complete picture. Each memory should be self-contained and meaningful. - Execution: For new significant facts, use CREATE. For simple attribute changes, use UPDATE only if it meaningfully improves the memory. For significant changes, use UPDATE to make the old memory historical, then CREATE the new one. For contradictions, use DELETE. ## EXAMPLES (Assumes Current Date: September 15, 2025) @@ -142,37 +112,37 @@ Explanation: Multiple facts about the same person (Sarah's active lifestyle, lov Message: "My daughter Emma just turned 12. We adopted a dog named Max for her 11th birthday. What should I give her for her 12th birthday?" Memories: [id:mem-002] My daughter Emma is 10 years old [noted at March 20 2024] [id:mem-101] I have a golden retriever [noted at September 20 2024] Return: {{"ops": [{{"operation": "UPDATE", "id": "mem-002", "content": "My daughter Emma turned 12 years old in September 2025"}}, {{"operation": "UPDATE", "id": "mem-101", "content": "I have a golden retriever named Max that was adopted in September 2024 as a birthday gift for my daughter Emma when she turned 11"}}]}} -Explanation: Dog memory enriched with related context (Emma, birthday gift, age 11) and temporal anchoring (September 2024). The instructional question ("What should I give her...?") is ignored as per the 'Filter for Intent' rule. +Explanation: Dog memory enriched with related context (Emma, birthday gift, age 11) and temporal anchoring (September 2024) - all semantically connected to the same event and relationship. ### Example 3 Message: "Can you recommend some good tapas restaurants in Barcelona? I moved here from Madrid last month." Memories: [id:mem-005] I live in Madrid Spain [noted at June 12 2025] Return: {{"ops": [{{"operation": "UPDATE", "id": "mem-005", "content": "I lived in Madrid Spain until August 2025"}}, {{"operation": "CREATE", "id": "", "content": "I moved to Barcelona Spain in August 2025"}}]}} -Explanation: Relocation is a significant life event. The request for recommendations is instructional and is ignored. +Explanation: Relocation is a significant life event with lasting impact. "Exploring the city" and "adjusting" are transient states and excluded. ### Example 4 Message: "My wife Sofia and I just got married in August. What are some good honeymoon destinations?" Memories: [id:mem-008] I am single [noted at January 5 2025] Return: {{"ops": [{{"operation": "DELETE", "id": "mem-008", "content": ""}}, {{"operation": "CREATE", "id": "", "content": "I married Sofia in August 2025 and she is now my wife"}}]}} -Explanation: Marriage is an enduring life event. The instructional question ("What are some good honeymoon destinations?") is ignored. +Explanation: Marriage is an enduring life event. Wife's name and marriage date are lasting facts combined naturally. "Planning honeymoon" is a transient activity and excluded. ### Example 5 Message: "¡Hola! Me mudé de Madrid a Barcelona el mes pasado y me casé con mi novia Sofía en agosto. ¿Me puedes recomendar un buen restaurante para celebrar?" 
Memories: [id:mem-005] I live in Madrid Spain [noted at June 12 2025] [id:mem-006] I am dating Sofia [noted at February 10 2025] [id:mem-008] I am single [noted at January 5 2025] Return: {{"ops": [{{"operation": "UPDATE", "id": "mem-005", "content": "I lived in Madrid Spain until August 2025"}}, {{"operation": "DELETE", "id": "mem-008", "content": ""}}, {{"operation": "UPDATE", "id": "mem-006", "content": "I moved to Barcelona Spain and married my girlfriend Sofia in August 2025, who is now my wife"}}]}} -Explanation: The user's move and marriage are significant, related life events. They are consolidated into a single memory. The request for a recommendation is ignored. +Explanation: The user's move and marriage are significant, related life events that occurred in the same month. They are consolidated into a single, cohesive memory that enriches the existing relationship context. ### Example 6 Message: "I'm feeling stressed about work this week and looking for some relaxation tips. I have a big presentation coming up on Friday." Memories: [] Return: {{"ops": []}} -Explanation: Transient state (stress) and a request for information (relaxation tips). The primary intent is instructional/analytical, and the facts (presentation) are not significant, lasting personal narrative. Nothing to store. +Explanation: Temporary stress, seeking tips, and upcoming presentation are all transient situations without lasting personal significance. Nothing to store. """ MEMORY_RERANKING = f"""You are the Memory Relevance Analyzer. ## OBJECTIVE -Your goal is to analyze the user's message and select the most relevant memories to personalize the AI's response. Prioritize direct connections and supporting context. +Select relevant memories to personalize the response, prioritizing direct connections and supporting context. ## RELEVANCE CATEGORIES - Direct: Memories explicitly about the query topic, people, or domain. @@ -181,8 +151,9 @@ Your goal is to analyze the user's message and select the most relevant memories ## SELECTION FRAMEWORK - Prioritize Current Info: Give current facts higher relevance than historical ones unless the query is about the past or historical context directly informs the current situation. -- Hierarchy: Prioritize topic matches first (Direct), then context that enhances the response (Contextual), and finally general background (Background). +- Hierarchy: Prioritize Direct → Contextual → Background. - Ordering: Order IDs by relevance, most relevant first. +- Standard: Prioritize topic matches, then context that enhances the response. - Maximum Limit: Return up to {Constants.MAX_MEMORIES_PER_RETRIEVAL} memory IDs. 
## EXAMPLES (Assumes Current Date: September 15, 2025) @@ -235,15 +206,9 @@ class Models: class MemoryOperation(StrictModel): """Pydantic model for memory operations with validation.""" - operation: "Models.MemoryOperationType" = Field( - description="Type of memory operation to perform" - ) - content: str = Field( - description="Memory content (required for CREATE/UPDATE, empty for DELETE)" - ) - id: str = Field( - description="Memory ID (empty for CREATE, required for UPDATE/DELETE)" - ) + operation: "Models.MemoryOperationType" = Field(description="Type of memory operation to perform") + content: str = Field(description="Memory content (required for CREATE/UPDATE, empty for DELETE)") + id: str = Field(description="Memory ID (empty for CREATE, required for UPDATE/DELETE)") def validate_operation(self, existing_memory_ids: Optional[set] = None) -> bool: """Validate the memory operation against existing memory IDs.""" @@ -262,9 +227,7 @@ class Models: class ConsolidationResponse(BaseModel): """Pydantic model for memory consolidation LLM response - object containing array of memory operations.""" - ops: List["Models.MemoryOperation"] = Field( - default_factory=list, description="List of memory operations to execute" - ) + ops: List["Models.MemoryOperation"] = Field(default_factory=list, description="List of memory operations to execute") class MemoryRerankingResponse(BaseModel): """Pydantic model for memory reranking LLM response - object containing array of memory IDs.""" @@ -320,10 +283,7 @@ class UnifiedCacheManager: type_cache = user_cache[cache_type] - if ( - key not in type_cache - and len(type_cache) >= self.max_cache_size_per_type - ): + if key not in type_cache and len(type_cache) >= self.max_cache_size_per_type: evicted_key, _ = type_cache.popitem(last=False) if key in type_cache: @@ -334,9 +294,7 @@ class UnifiedCacheManager: self.caches.move_to_end(user_id) - async def clear_user_cache( - self, user_id: str, cache_type: Optional[str] = None - ) -> int: + async def clear_user_cache(self, user_id: str, cache_type: Optional[str] = None) -> int: """Clear specific cache type for user, or all caches for user if cache_type is None.""" async with self._lock: if user_id not in self.caches: @@ -345,9 +303,7 @@ class UnifiedCacheManager: user_cache = self.caches[user_id] if cache_type is None: - total_cleared = sum( - len(type_cache) for type_cache in user_cache.values() - ) + total_cleared = sum(len(type_cache) for type_cache in user_cache.values()) del self.caches[user_id] return total_cleared else: @@ -368,9 +324,9 @@ class UnifiedCacheManager: class SkipDetector: - """Binary content classifier: personal vs non-personal using semantic analysis.""" + """Semantic-based content classifier using zero-shot classification with category descriptions.""" - NON_PERSONAL_CATEGORY_DESCRIPTIONS = [ + TECHNICAL_CATEGORY_DESCRIPTIONS = [ "programming language syntax, data types like string or integer, algorithm logic, function, method, programming class, object-oriented paradigm, variable scope, control flow, import, module, package, library, framework, recursion, iteration", "software design patterns, creational: singleton, factory, builder; structural: adapter, decorator, facade, proxy; behavioral: observer, strategy, command, mediator, chain of responsibility; abstract interface, polymorphism, composition", "error handling, exception, stack trace, TypeError, NullPointerException, IndexError, segmentation fault, core dump, stack overflow, runtime vs compile-time error, assertion failed, syntax 
error, null pointer dereference, memory leak, bug", @@ -391,6 +347,9 @@ class SkipDetector: "regex pattern, regular expression matching, groups, capturing, backslash escapes, metacharacters, wildcards, quantifiers, character classes, lookaheads, lookbehinds, alternation, anchors, word boundary, multiline flag, global search", "software testing, unit test, assertion, mock, stub, fixture, test suite, test case, verification, automated QA, validation framework, JUnit, pytest, Jest. Integration, end-to-end (E2E), functional, regression, acceptance testing", "cloud computing platforms, infrastructure as a service (IaaS), PaaS, AWS, Azure, GCP, compute instance, region, availability zone, elasticity, distributed system, virtual machine, container, serverless, Lambda, edge computing, CDN", + ] + + INSTRUCTION_CATEGORY_DESCRIPTIONS = [ "format the output as structured data. Return the answer as JSON with specific keys and values, or as YAML. Organize information into a CSV file or a database-style table with columns and rows. Present as a list of objects or an array.", "style the text presentation. Use markdown formatting like bullet points, a numbered list, or a task list. Organize content into a grid or tabular layout with proper alignment. Create a hierarchical structure with nested elements for clarity.", "adjust the response length. Make the answer shorter, more concise, brief, or condensed. Summarize the key points. Trim down the text to reduce the overall word count or meet a specific character limit. Be less verbose and more direct.", @@ -401,6 +360,9 @@ class SkipDetector: "continue the generated response. Keep going with the explanation or list. Provide more information and finish your thought. Complete the rest of the content or story. Proceed with the next steps. Do not stop until you have concluded.", "act as a specific persona or role. Respond as if you were a pirate, a scientist, or a travel guide. Adopt the character's voice, style, and knowledge base in your answer. Maintain the persona throughout the entire response.", "compare and contrast two or more topics. Explain the similarities and differences between A and B. Provide a detailed analysis of what they have in common and how they diverge. Create a table to highlight the key distinctions.", + ] + + PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS = [ "perform a pure arithmetic calculation with explicit numbers. Solve, multiply, add, subtract, and divide. Compute a numeric expression following the order of operations (PEMDAS/BODMAS). What is 23 plus 456 minus 78 times 9 divided by 3?", "evaluate a mathematical expression containing numbers and operators, such as 2 plus 3 times 4 divided by 5. Solve this numerical problem and compute the final result. Simplify the arithmetic and show the final answer. Calculate 123 * 456.", "convert units between measurement systems with numeric values. Convert 100 kilometers to miles, 72 fahrenheit to celsius, or 5 feet 9 inches to centimeters. Change between metric and imperial for distance, weight, volume, or temperature.", @@ -411,6 +373,9 @@ class SkipDetector: "compute descriptive statistics for a dataset of numbers like 12, 15, 18, 20, 22. Calculate the mean, median, mode, average, and standard deviation. Find the variance, range, quartiles, and percentiles for a given sample distribution.", "calculate health and fitness metrics using a numeric formula. Compute the Body Mass Index (BMI) given a weight in pounds or kilograms and height in feet, inches, or meters. 
Find my basal metabolic rate (BMR) or target heart rate.", "calculate the time difference between two dates. How many days, hours, or minutes are between two points in time? Find the duration or elapsed time. Act as an age calculator for a birthday or find the time until a future anniversary.", + ] + + EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS = [ "translate the explicitly quoted text 'Hello, how are you?' to a foreign language like Spanish, French, or German. This is a translation instruction that includes the word 'translate' and the source text in quotes for direct conversion.", "how do you say a specific word or phrase in another language? For example, how do you say 'thank you', 'computer', or 'goodbye' in Japanese, Chinese, or Korean? This is a request for a direct translation of a common expression or term.", "convert a block of text or a paragraph from a source language to a target language. Translate the following content to Italian, Arabic, Portuguese, or Russian. This is a language conversion request for a larger piece of text provided.", @@ -421,6 +386,9 @@ class SkipDetector: "how do I say 'I am learning to code' in German? Convert this specific English phrase into its equivalent in another language. This is a request for a practical, conversational phrase translation for personal or professional use.", "translate this informal or slang expression to its colloquial equivalent in Spanish. How would you say 'What's up?' in Japanese in a casual context? This request focuses on capturing the correct tone and nuance of informal language.", "provide the formal and professional translation for 'Please find the attached document for your review' in French. Translate this business email phrase to German, ensuring the terminology and register are appropriate for a corporate context.", + ] + + GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS = [ "proofread the following text for errors. Here is my draft, please check it for typos and mistakes: 'Teh quick brown fox jumpped'. Review, revise, and correct any misspellings or grammatical issues you find in the provided passage.", "correct the grammar in this sentence: 'She don't like it'. Resolve grammatical issues like subject-verb agreement, incorrect verb tense, pronoun reference errors, or misplaced modifiers in the provided text. Address faulty sentence structure.", "check the spelling and punctuation in this passage. Please review the following text and correct any textual errors: 'its a beautiful day, isnt it'. Amend mistakes with commas, periods, apostrophes, quotation marks, colons, or capitalization.", @@ -433,7 +401,7 @@ class SkipDetector: "check my essay for conciseness and remove any redundancy. Help me edit this text to be more direct and to the point. Identify and eliminate wordiness, filler words, and repetitive phrases to strengthen the overall quality of the writing.", ] - PERSONAL_CATEGORY_DESCRIPTIONS = [ + CONVERSATIONAL_CATEGORY_DESCRIPTIONS = [ "discussing my family members, like my spouse, children, parents, or siblings. Mentioning relatives by name or role, such as my husband, wife, son, daughter, mother, or father. Sharing stories or asking questions about my family.", "expressing lasting personal feelings, core values, beliefs, or principles. My worldview, deeply held opinions, philosophy, or moral standards. Things I love, hate, or feel strongly about in life, such as my passion for animal welfare.", "describing my established personal hobbies, regular activities, or consistent interests. 
My passions and what I do in my leisure time, such as creative outlets like painting, sports like hiking, or other recreational pursuits I enjoy.", @@ -463,7 +431,11 @@ class SkipDetector: class SkipReason(Enum): SKIP_SIZE = "SKIP_SIZE" - SKIP_NON_PERSONAL = "SKIP_NON_PERSONAL" + SKIP_TECHNICAL = "SKIP_TECHNICAL" + SKIP_INSTRUCTION = "SKIP_INSTRUCTION" + SKIP_PURE_MATH = "SKIP_PURE_MATH" + SKIP_TRANSLATION = "SKIP_TRANSLATION" + SKIP_GRAMMAR_PROOFREAD = "SKIP_GRAMMAR_PROOFREAD" STATUS_MESSAGES = { SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, skipping memory operations", @@ -476,42 +448,27 @@ class SkipDetector: def __init__( self, - embedding_function: Callable[ - [Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]] - ], + embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]], ): """Initialize the skip detector with an embedding function and compute reference embeddings.""" self.embedding_function = embedding_function self._reference_embeddings = None self._initialize_reference_embeddings() - def _initialize_reference_embeddings(self) -> None: """Compute and cache embeddings for category descriptions.""" try: - technical_embeddings = self.embedding_function( - self.TECHNICAL_CATEGORY_DESCRIPTIONS - ) + technical_embeddings = self.embedding_function(self.TECHNICAL_CATEGORY_DESCRIPTIONS) - instruction_embeddings = self.embedding_function( - self.INSTRUCTION_CATEGORY_DESCRIPTIONS - ) + instruction_embeddings = self.embedding_function(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) - pure_math_embeddings = self.embedding_function( - self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS - ) + pure_math_embeddings = self.embedding_function(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) - translation_embeddings = self.embedding_function( - self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS - ) + translation_embeddings = self.embedding_function(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) - grammar_embeddings = self.embedding_function( - self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS - ) + grammar_embeddings = self.embedding_function(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) - conversational_embeddings = self.embedding_function( - self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS - ) + conversational_embeddings = self.embedding_function(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS) self._reference_embeddings = { "technical": np.array(technical_embeddings), @@ -520,30 +477,14 @@ class SkipDetector: "translation": np.array(translation_embeddings), "grammar": np.array(grammar_embeddings), "conversational": np.array(conversational_embeddings), - "technical": np.array(technical_embeddings), - "instruction": np.array(instruction_embeddings), - "pure_math": np.array(pure_math_embeddings), - "translation": np.array(translation_embeddings), - "grammar": np.array(grammar_embeddings), - "conversational": np.array(conversational_embeddings), } - total_skip_categories = ( len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) + len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) + len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) + len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) + len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) - len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) - + len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) - + len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) - + len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) - + len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS) - ) - - logger.info( - f"SkipDetector initialized with {total_skip_categories} skip 
categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories" ) logger.info( @@ -553,17 +494,12 @@ class SkipDetector: logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}") self._reference_embeddings = None - def validate_message_size( - self, message: str, max_message_chars: int - ) -> Optional[str]: + def validate_message_size(self, message: str, max_message_chars: int) -> Optional[str]: """Validate message size constraints.""" if not message or not message.strip(): return SkipDetector.SkipReason.SKIP_SIZE.value trimmed = message.strip() - if ( - len(trimmed) < Constants.MIN_MESSAGE_CHARS - or len(trimmed) > max_message_chars - ): + if len(trimmed) < Constants.MIN_MESSAGE_CHARS or len(trimmed) > max_message_chars: return SkipDetector.SkipReason.SKIP_SIZE.value return None @@ -571,40 +507,29 @@ class SkipDetector: """Language-agnostic structural pattern detection with high confidence and low false positive rate.""" msg_len = len(message) - # Pattern 1: Multiple URLs (5+ full URLs indicates link lists or technical references) url_pattern_count = message.count("http://") + message.count("https://") - url_pattern_count = message.count("http://") + message.count("https://") if url_pattern_count >= 5: return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 2: Long unbroken alphanumeric strings (tokens, hashes, base64) words = message.split() for word in words: cleaned = word.strip('.,;:!?()[]{}"\'"') - if ( - len(cleaned) > 80 - and cleaned.replace("-", "").replace("_", "").isalnum() - ): + if len(cleaned) > 80 and cleaned.replace("-", "").replace("_", "").isalnum(): return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 3: Markdown/text separators (repeated ---, ===, ___, ***) separator_patterns = ["---", "===", "___", "***"] - separator_patterns = ["---", "===", "___", "***"] for pattern in separator_patterns: if message.count(pattern) >= 2: return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 4: Command-line patterns with context-aware detection lines_stripped = [line.strip() for line in message.split("\n") if line.strip()] - lines_stripped = [line.strip() for line in message.split("\n") if line.strip()] if lines_stripped: actual_command_lines = 0 for line in lines_stripped: - if line.startswith("$ ") and len(line) > 2: if line.startswith("$ ") and len(line) > 2: parts = line[2:].split() if parts and parts[0].isalnum(): @@ -613,96 +538,59 @@ class SkipDetector: dollar_index = line.find("$ ") if dollar_index > 0 and line[dollar_index - 1] in (" ", ":", "\t"): parts = line[dollar_index + 2 :].split() - if ( - parts - and len(parts[0]) > 0 - and ( - parts[0].isalnum() - or parts[0] - in ["curl", "wget", "git", "npm", "pip", "docker"] - ) - ): + if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ["curl", "wget", "git", "npm", "pip", "docker"]): actual_command_lines += 1 - elif line.startswith("# ") and len(line) > 2: elif line.startswith("# ") and len(line) > 2: rest = line[2:].strip() - if rest and not rest[0].isupper() and " " in rest: if rest and not rest[0].isupper() and " " in rest: actual_command_lines += 1 - elif line.startswith("> ") and len(line) > 2: elif line.startswith("> ") and len(line) > 2: pass - if actual_command_lines >= 1 and any( - c in message for c in ["http://", "https://", " | "] - ): + if actual_command_lines >= 1 and any(c in message for c in ["http://", "https://", " | "]): return self.SkipReason.SKIP_TECHNICAL.value if actual_command_lines >= 3: return self.SkipReason.SKIP_TECHNICAL.value - # 
Pattern 5: High path/URL density (dots and slashes suggesting file paths or URLs) if msg_len > 30: - slash_count = message.count("/") + message.count("\\") - dot_count = message.count(".") slash_count = message.count("/") + message.count("\\") dot_count = message.count(".") path_chars = slash_count + dot_count if path_chars > 10 and (path_chars / msg_len) > 0.15: return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 6: Markup character density (structured data) markup_chars = sum(message.count(c) for c in "{}[]<>") - markup_chars = sum(message.count(c) for c in "{}[]<>") if markup_chars >= 6: if markup_chars / msg_len > 0.10: return self.SkipReason.SKIP_TECHNICAL.value curly_count = message.count("{") + message.count("}") - curly_count = message.count("{") + message.count("}") if curly_count >= 10: return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 7: Structured nested content with colons (key: value patterns) line_count = message.count("\n") - line_count = message.count("\n") if line_count >= 8: - lines = message.split("\n") lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - colon_lines = sum( - 1 - for line in non_empty_lines - if ":" in line and not line.strip().startswith("#") - ) - indented_lines = sum( - 1 for line in non_empty_lines if line.startswith((" ", "\t")) - ) + colon_lines = sum(1 for line in non_empty_lines if ":" in line and not line.strip().startswith("#")) + indented_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t"))) - if ( - colon_lines / len(non_empty_lines) > 0.4 - and indented_lines / len(non_empty_lines) > 0.5 - ): + if colon_lines / len(non_empty_lines) > 0.4 and indented_lines / len(non_empty_lines) > 0.5: return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 8: Highly structured multi-line content (require markup chars for technical confidence) if line_count > 15: - lines = message.split("\n") lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - markup_in_lines = sum( - 1 for line in non_empty_lines if any(c in line for c in "{}[]<>") - ) - structured_lines = sum( - 1 for line in non_empty_lines if line.startswith((" ", "\t")) - ) + markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in "{}[]<>")) + structured_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t"))) if markup_in_lines / len(non_empty_lines) > 0.3: - return self.SkipReason.SKIP_NON_PERSONAL.value + return self.SkipReason.SKIP_TECHNICAL.value elif structured_lines / len(non_empty_lines) > 0.6: technical_keywords = [ "function", @@ -714,21 +602,15 @@ class SkipDetector: "let", "def", ] - if any( - keyword in message.lower() for keyword in technical_keywords - ): + if any(keyword in message.lower() for keyword in technical_keywords): return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 9: Code-like indentation pattern (require code indicators to avoid false positives from bullet lists) if line_count >= 3: - lines = message.split("\n") lines = message.split("\n") non_empty_lines = [line for line in lines if line.strip()] if non_empty_lines: - indented_lines = sum( - 1 for line in non_empty_lines if line[0] in (" ", "\t") - ) + indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t")) if indented_lines / len(non_empty_lines) > 0.5: code_indicators = [ "def ", @@ -742,33 +624,25 @@ class SkipDetector: "public ", "private ", ] - if any( - indicator in message.lower() for indicator in 
code_indicators - ): + if any(indicator in message.lower() for indicator in code_indicators): return self.SkipReason.SKIP_TECHNICAL.value - # Pattern 10: Very high special character ratio (encoded data, technical output) if msg_len > 50: - special_chars = sum( - 1 for c in message if not c.isalnum() and not c.isspace() - ) + special_chars = sum(1 for c in message if not c.isalnum() and not c.isspace()) special_ratio = special_chars / msg_len if special_ratio > 0.35: alphanumeric = sum(1 for c in message if c.isalnum()) if alphanumeric / msg_len < 0.50: return self.SkipReason.SKIP_TECHNICAL.value - return None - def detect_skip_reason( - self, message: str, max_message_chars: int, memory_system: "Filter" - ) -> Optional[str]: + def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]: """ Detect if a message should be skipped using two-stage detection: 1. Fast-path structural patterns (~95% confidence) - 2. Binary semantic classification (personal vs non-personal) + 2. Semantic classification (for remaining cases) Returns: Skip reason string if content should be skipped, None otherwise """ @@ -776,29 +650,21 @@ class SkipDetector: if size_issue: return size_issue - fast_skip = self._fast_path_skip_detection(message) if fast_skip: logger.info(f"Fast-path skip: {fast_skip}") return fast_skip - if self._reference_embeddings is None: - logger.warning( - "SkipDetector reference embeddings not initialized, allowing message through" - ) + logger.warning("SkipDetector reference embeddings not initialized, allowing message through") return None - try: message_embedding = np.array(self.embedding_function([message.strip()])[0]) - conversational_similarities = np.dot( - message_embedding, self._reference_embeddings["conversational"].T - ) + conversational_similarities = np.dot(message_embedding, self._reference_embeddings["conversational"].T) max_conversational_similarity = float(conversational_similarities.max()) - skip_categories = [ ( "instruction", @@ -827,36 +693,25 @@ class SkipDetector: ), ] - qualifying_categories = [] - margin_threshold = ( - max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN - ) + margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN for cat_key, skip_reason, descriptions in skip_categories: - similarities = np.dot( - message_embedding, self._reference_embeddings[cat_key].T - ) + similarities = np.dot(message_embedding, self._reference_embeddings[cat_key].T) max_similarity = float(similarities.max()) - if max_similarity > margin_threshold: qualifying_categories.append((max_similarity, cat_key, skip_reason)) - if qualifying_categories: - highest_similarity, highest_cat_key, highest_skip_reason = max( - qualifying_categories, key=lambda x: x[0] - ) + highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0]) logger.info( f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})" ) return highest_skip_reason.value - return None - except Exception as e: logger.error(f"Error in semantic skip detection: {e}") return None @@ -872,10 +727,7 @@ class LLMRerankingService: if not self.memory_system.valves.enable_llm_reranking: return False, "LLM reranking disabled" - llm_trigger_threshold = int( - self.memory_system.valves.max_memories_returned - * self.memory_system.valves.llm_reranking_trigger_multiplier - ) + llm_trigger_threshold = 
int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier) if len(memories) > llm_trigger_threshold: return ( True, @@ -917,16 +769,12 @@ CANDIDATE MEMORIES: if memory["id"] in response.ids and len(selected_memories) < max_count: selected_memories.append(memory) - logger.info( - f"🧠 LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates" - ) + logger.info(f"🧠 LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates") return selected_memories except Exception as e: - logger.warning( - f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}" - ) + logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}") return candidate_memories async def rerank_memories( @@ -938,9 +786,7 @@ CANDIDATE MEMORIES: start_time = time.time() max_injection = self.memory_system.valves.max_memories_returned - should_use_llm, decision_reason = self._should_use_llm_reranking( - candidate_memories - ) + should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories) analysis_info = { "llm_decision": should_use_llm, @@ -949,10 +795,7 @@ CANDIDATE MEMORIES: } if should_use_llm: - extended_count = int( - self.memory_system.valves.max_memories_returned - * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER - ) + extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) llm_candidates = candidate_memories[:extended_count] await self.memory_system._emit_status( emitter, @@ -961,21 +804,16 @@ CANDIDATE MEMORIES: ) logger.info(f"Using LLM reranking: {decision_reason}") - selected_memories = await self._llm_select_memories( - user_message, llm_candidates, max_injection, emitter - ) + selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter) if not selected_memories: logger.info("📭 No relevant memories after LLM analysis") - await self.memory_system._emit_status( - emitter, f"📭 No Relevant Memories After LLM Analysis", done=True - ) + await self.memory_system._emit_status(emitter, f"📭 No Relevant Memories After LLM Analysis", done=True) return selected_memories, analysis_info else: logger.info(f"Skipping LLM reranking: {decision_reason}") selected_memories = candidate_memories[:max_injection] - duration = time.time() - start_time duration_text = f" in {duration:.2f}s" if duration >= 0.01 else "" retrieval_method = "LLM" if should_use_llm else "Semantic" @@ -993,26 +831,15 @@ class LLMConsolidationService: def __init__(self, memory_system): self.memory_system = memory_system - def _filter_consolidation_candidates( - self, similarities: List[Dict[str, Any]] - ) -> Tuple[List[Dict[str, Any]], str]: + def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]: """Filter consolidation candidates by threshold and return candidates with threshold info.""" - consolidation_threshold = self.memory_system._get_retrieval_threshold( - is_consolidation=True - ) - candidates = [ - mem for mem in similarities if mem["relevance"] >= consolidation_threshold - ] + consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) + candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold] - max_consolidation_memories = int( - self.memory_system.valves.max_memories_returned - * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER - ) + max_consolidation_memories = 
int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER) candidates = candidates[:max_consolidation_memories] - threshold_info = ( - f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" - ) + threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})" return candidates, threshold_info async def collect_consolidation_candidates( @@ -1023,13 +850,9 @@ class LLMConsolidationService: ) -> List[Dict[str, Any]]: """Collect candidate memories for consolidation analysis using cached or computed similarities.""" if cached_similarities: - candidates, threshold_info = self._filter_consolidation_candidates( - cached_similarities - ) + candidates, threshold_info = self._filter_consolidation_candidates(cached_similarities) - logger.info( - f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})" - ) + logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})") self.memory_system._log_retrieved_memories(candidates, "consolidation") return candidates @@ -1037,9 +860,7 @@ class LLMConsolidationService: try: user_memories = await self.memory_system._get_user_memories(user_id) except asyncio.TimeoutError: - raise TimeoutError( - f"⏱️ Memory retrieval timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s" - ) + raise TimeoutError(f"⏱️ Memory retrieval timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s") except Exception as e: logger.error(f"💾 Failed to retrieve user memories from database: {str(e)}") return [] @@ -1048,32 +869,21 @@ class LLMConsolidationService: logger.info("💭 No existing memories found for consolidation") return [] - logger.info( - f"🚀 Reusing cached user memories for consolidation: {len(user_memories)} memories" - ) + logger.info(f"🚀 Reusing cached user memories for consolidation: {len(user_memories)} memories") try: - all_similarities, _, _ = await self.memory_system._compute_similarities( - user_message, user_id, user_memories - ) + all_similarities, _, _ = await self.memory_system._compute_similarities(user_message, user_id, user_memories) except Exception as e: - logger.error( - f"🔍 Failed to compute memory similarities for retrieval: {str(e)}" - ) + logger.error(f"🔍 Failed to compute memory similarities for retrieval: {str(e)}") return [] if all_similarities: - candidates, threshold_info = self._filter_consolidation_candidates( - all_similarities - ) + candidates, threshold_info = self._filter_consolidation_candidates(all_similarities) else: candidates = [] threshold_info = "N/A" - threshold_info = "N/A" - logger.info( - f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})" - ) + logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})") self.memory_system._log_retrieved_memories(candidates, "consolidation") @@ -1087,9 +897,7 @@ class LLMConsolidationService: ) -> List[Dict[str, Any]]: """Generate consolidation plan using LLM with clear system/user prompt separation.""" if candidate_memories: - memory_lines = self.memory_system._format_memories_for_llm( - candidate_memories - ) + memory_lines = self.memory_system._format_memories_for_llm(candidate_memories) memory_context = f"EXISTING MEMORIES FOR CONSOLIDATION:\n{chr(10).join(memory_lines)}\n\n" else: memory_context = "EXISTING MEMORIES FOR CONSOLIDATION:\n[]\n\nNote: No existing memories found - Focus on extracting new memories from the user message below.\n\n" 
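
Both services above derive their cutoffs from the same few multipliers in `Constants`. The following is a minimal standalone sketch of that arithmetic, assuming the default values carried by this patch (0.25 base threshold, 0.8 relaxed and trigger multipliers, 1.6 extended multiplier, 10 returned memories); the free-function names and the `{"relevance": ...}` dict shape are illustrative assumptions, not the patch's actual API:

```python
# Values mirror Constants in this patch; the function names are illustrative.
SEMANTIC_RETRIEVAL_THRESHOLD = 0.25
RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8
LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8
EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6
MAX_MEMORIES_RETURNED = 10


def should_use_llm_reranking(candidates: list) -> bool:
    # LLM reranking only pays off once there are more candidates than the
    # trigger threshold (10 * 0.8 = 8); below that, semantic order stands.
    trigger = int(MAX_MEMORIES_RETURNED * LLM_RERANKING_TRIGGER_MULTIPLIER)
    return len(candidates) > trigger


def filter_consolidation_candidates(similarities: list) -> list:
    # Consolidation relaxes the retrieval threshold (0.25 * 0.8 = 0.2) and
    # widens the candidate cap (10 * 1.6 = 16) so near-miss memories can
    # still be merged, updated, or deleted by the consolidation LLM.
    threshold = SEMANTIC_RETRIEVAL_THRESHOLD * RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER
    cap = int(MAX_MEMORIES_RETURNED * EXTENDED_MAX_MEMORY_MULTIPLIER)
    # `similarities` is assumed to be sorted by descending relevance,
    # as it is at this point in the patch.
    return [m for m in similarities if m["relevance"] >= threshold][:cap]
```

With these defaults, nine or more semantic hits route a query through LLM reranking, while consolidation sees up to 16 candidate memories down to a 0.20 similarity floor.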
@@ -1108,64 +916,33 @@ class LLMConsolidationService: timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC, ) except Exception as e: - logger.warning( - f"🤖 LLM consolidation failed during memory processing: {str(e)}" - ) - await self.memory_system._emit_status( - emitter, f"⚠️ Memory Consolidation Failed", done=True - ) + logger.warning(f"🤖 LLM consolidation failed during memory processing: {str(e)}") + await self.memory_system._emit_status(emitter, f"⚠️ Memory Consolidation Failed", done=True) return [] operations = response.ops existing_memory_ids = {memory["id"] for memory in candidate_memories} total_operations = len(operations) - delete_operations = [ - op for op in operations if op.operation == Models.MemoryOperationType.DELETE - ] - delete_ratio = ( - len(delete_operations) / total_operations if total_operations > 0 else 0 - ) + delete_operations = [op for op in operations if op.operation == Models.MemoryOperationType.DELETE] + delete_ratio = len(delete_operations) / total_operations if total_operations > 0 else 0 - if ( - delete_ratio > Constants.MAX_DELETE_OPERATIONS_RATIO - and total_operations >= Constants.MIN_OPS_FOR_DELETE_RATIO_CHECK - ): + if delete_ratio > Constants.MAX_DELETE_OPERATIONS_RATIO and total_operations >= Constants.MIN_OPS_FOR_DELETE_RATIO_CHECK: logger.warning( f"⚠️ Consolidation safety: {len(delete_operations)}/{total_operations} operations are deletions ({delete_ratio*100:.1f}%) - rejecting plan" ) return [] - valid_operations = [ - op.model_dump() - for op in operations - if op.validate_operation(existing_memory_ids) - ] + valid_operations = [op.model_dump() for op in operations if op.validate_operation(existing_memory_ids)] if valid_operations: - create_count = sum( - 1 - for op in valid_operations - if op.get("operation") == Models.MemoryOperationType.CREATE.value - ) - update_count = sum( - 1 - for op in valid_operations - if op.get("operation") == Models.MemoryOperationType.UPDATE.value - ) - delete_count = sum( - 1 - for op in valid_operations - if op.get("operation") == Models.MemoryOperationType.DELETE.value - ) + create_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.CREATE.value) + update_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.UPDATE.value) + delete_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.DELETE.value) - operation_details = self.memory_system._build_operation_details( - create_count, update_count, delete_count - ) + operation_details = self.memory_system._build_operation_details(create_count, update_count, delete_count) - logger.info( - f"🎯 Planned {len(valid_operations)} memory operations: {', '.join(operation_details)}" - ) + logger.info(f"🎯 Planned {len(valid_operations)} memory operations: {', '.join(operation_details)}") else: logger.info("🎯 No valid memory operations planned") @@ -1187,9 +964,7 @@ class LLMConsolidationService: timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) except asyncio.TimeoutError: - raise TimeoutError( - f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s" - ) + raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s") except Exception as e: raise RuntimeError(f"👤 User lookup failed: {str(e)}") @@ -1205,31 +980,23 @@ class LLMConsolidationService: operations_by_type[operation.operation.value].append(operation) except Exception as e: failed_count += 1 - operation_type = operation_data.get( - "operation", 
Models.OperationResult.UNSUPPORTED.value - ) + operation_type = operation_data.get("operation", Models.OperationResult.UNSUPPORTED.value) content_preview = "" if "content" in operation_data: content = operation_data.get("content", "") content_preview = f" - Content: {self.memory_system._truncate_content(content, Constants.CONTENT_PREVIEW_LENGTH)}" elif "id" in operation_data: content_preview = f" - ID: {operation_data['id']}" - error_message = ( - f"Failed {operation_type} operation{content_preview}: {str(e)}" - ) + error_message = f"Failed {operation_type} operation{content_preview}: {str(e)}" logger.error(error_message) memory_contents_for_deletion = {} if operations_by_type["DELETE"]: try: user_memories = await self.memory_system._get_user_memories(user_id) - memory_contents_for_deletion = { - str(mem.id): mem.content for mem in user_memories - } + memory_contents_for_deletion = {str(mem.id): mem.content for mem in user_memories} except Exception as e: - logger.warning( - f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}" - ) + logger.warning(f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}") for operation_type, ops in operations_by_type.items(): if not ops: @@ -1245,56 +1012,33 @@ class LLMConsolidationService: for idx, result in enumerate(results): operation = ops[idx] - if isinstance(result, Exception): failed_count += 1 - await self.memory_system._emit_status( - emitter, f"❌ Failed {operation_type}", done=False - ) + await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False) elif result == Models.MemoryOperationType.CREATE.value: created_count += 1 - content_preview = self.memory_system._truncate_content( - operation.content - ) - await self.memory_system._emit_status( - emitter, f"📝 Created: {content_preview}", done=False - ) + content_preview = self.memory_system._truncate_content(operation.content) + await self.memory_system._emit_status(emitter, f"📝 Created: {content_preview}", done=False) elif result == Models.MemoryOperationType.UPDATE.value: updated_count += 1 - content_preview = self.memory_system._truncate_content( - operation.content - ) - await self.memory_system._emit_status( - emitter, f"✏️ Updated: {content_preview}", done=False - ) + content_preview = self.memory_system._truncate_content(operation.content) + await self.memory_system._emit_status(emitter, f"✏️ Updated: {content_preview}", done=False) elif result == Models.MemoryOperationType.DELETE.value: deleted_count += 1 - content_preview = memory_contents_for_deletion.get( - operation.id, operation.id - ) + content_preview = memory_contents_for_deletion.get(operation.id, operation.id) if content_preview and content_preview != operation.id: - content_preview = self.memory_system._truncate_content( - content_preview - ) - await self.memory_system._emit_status( - emitter, f"🗑️ Deleted: {content_preview}", done=False - ) + content_preview = self.memory_system._truncate_content(content_preview) + await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False) elif result in [ Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value, ]: failed_count += 1 - await self.memory_system._emit_status( - emitter, f"❌ Failed {operation_type}", done=False - ) + await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False) except Exception as e: failed_count += len(ops) - logger.error( - f"❌ Batch {operation_type} operations failed during memory consolidation: {str(e)}" - ) - await 
self.memory_system._emit_status( - emitter, f"❌ Batch {operation_type} Failed", done=False - ) + logger.error(f"❌ Batch {operation_type} operations failed during memory consolidation: {str(e)}") + await self.memory_system._emit_status(emitter, f"❌ Batch {operation_type} Failed", done=False) total_executed = created_count + updated_count + deleted_count logger.info( @@ -1302,9 +1046,7 @@ class LLMConsolidationService: ) if total_executed > 0: - operation_details = self.memory_system._build_operation_details( - created_count, updated_count, deleted_count - ) + operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count) logger.info(f"🔄 Memory Operations: {', '.join(operation_details)}") await self.memory_system._refresh_user_cache(user_id) @@ -1323,27 +1065,20 @@ class LLMConsolidationService: if self.memory_system._shutdown_event.is_set(): return - candidates = await self.collect_consolidation_candidates( - user_message, user_id, cached_similarities - ) + candidates = await self.collect_consolidation_candidates(user_message, user_id, cached_similarities) if self.memory_system._shutdown_event.is_set(): return - operations = await self.generate_consolidation_plan( - user_message, candidates, emitter - ) + operations = await self.generate_consolidation_plan(user_message, candidates, emitter) if self.memory_system._shutdown_event.is_set(): return if operations: - created_count, updated_count, deleted_count, failed_count = ( - await self.execute_memory_operations(operations, user_id, emitter) - ) + created_count, updated_count, deleted_count, failed_count = await self.execute_memory_operations(operations, user_id, emitter) duration = time.time() - start_time logger.info(f"💾 Memory Consolidation Complete In {duration:.2f}s") - total_operations = created_count + updated_count + deleted_count if total_operations > 0 or failed_count > 0: await self.memory_system._emit_status( @@ -1352,25 +1087,18 @@ class LLMConsolidationService: done=False, ) - operation_details = self.memory_system._build_operation_details( - created_count, updated_count, deleted_count - ) + operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count) memory_word = "Memory" if total_operations == 1 else "Memories" operations_summary = f"{', '.join(operation_details)} {memory_word}" - if failed_count > 0: operations_summary += f" (❌ {failed_count} Failed)" - await self.memory_system._emit_status( - emitter, operations_summary, done=True - ) + await self.memory_system._emit_status(emitter, operations_summary, done=True) except Exception as e: duration = time.time() - start_time - raise RuntimeError( - f"❌ Memory consolidation failed after {duration:.2f}s: {str(e)}" - ) + raise RuntimeError(f"❌ Memory consolidation failed after {duration:.2f}s: {str(e)}") class Filter: @@ -1425,30 +1153,19 @@ class Filter: """Initialize the Memory System filter with production validation.""" global _SHARED_SKIP_DETECTOR_CACHE - self.valves = self.Valves() self._validate_system_configuration() - self._cache_manager = UnifiedCacheManager( - Constants.MAX_CACHE_ENTRIES_PER_TYPE, Constants.MAX_CONCURRENT_USER_CACHES - ) + self._cache_manager = UnifiedCacheManager(Constants.MAX_CACHE_ENTRIES_PER_TYPE, Constants.MAX_CONCURRENT_USER_CACHES) self._background_tasks: set = set() self._shutdown_event = asyncio.Event() self._embedding_function = None - self._embedding_dimension = None self._skip_detector = None self._llm_reranking_service = LLMRerankingService(self) 
         self._llm_consolidation_service = LLMConsolidationService(self)

-    async def _set_pipeline_context(
-        self,
-        __event_emitter__: Optional[Callable] = None,
-        __user__: Optional[Dict[str, Any]] = None,
-        __model__: Optional[str] = None,
-        __request__: Optional[Request] = None,
-    ) -> None:
     async def _set_pipeline_context(
         self,
         __event_emitter__: Optional[Callable] = None,
@@ -1466,32 +1183,22 @@ class Filter:
         if __request__:
             self.__request__ = __request__
-            if self._embedding_function is None and hasattr(
-                __request__.app.state, "EMBEDDING_FUNCTION"
-            ):
+            if self._embedding_function is None and hasattr(__request__.app.state, "EMBEDDING_FUNCTION"):
                 self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
                 logger.info(f"✅ Using OpenWebUI's embedding function")
-
             if self._skip_detector is None:
                 global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
-                embedding_engine = getattr(
-                    __request__.app.state.config, "RAG_EMBEDDING_ENGINE", ""
-                )
-                embedding_model = getattr(
-                    __request__.app.state.config, "RAG_EMBEDDING_MODEL", ""
-                )
+                embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
+                embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
                 cache_key = f"{embedding_engine}:{embedding_model}"
-
                 async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
                     if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
                         logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
                         self._skip_detector = _SHARED_SKIP_DETECTOR_CACHE[cache_key]
                     else:
-                        logger.info(
-                            f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}"
-                        )
+                        logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
                         embedding_fn = self._embedding_function

                         def embedding_wrapper(
@@ -1500,14 +1207,10 @@ class Filter:
                             result = embedding_fn(texts, prefix=None, user=None)
                             if isinstance(result, list):
                                 if isinstance(result[0], list):
-                                    return [
-                                        np.array(emb, dtype=np.float16)
-                                        for emb in result
-                                    ]
+                                    return [np.array(emb, dtype=np.float16) for emb in result]
                                 return np.array(result, dtype=np.float16)
                             return np.array(result, dtype=np.float16)
-
                         self._skip_detector = SkipDetector(embedding_wrapper)
                         _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
                         logger.info(f"✅ Skip detector initialized and cached")
@@ -1521,10 +1224,7 @@ class Filter:
     def _get_retrieval_threshold(self, is_consolidation: bool = False) -> float:
         """Calculate retrieval threshold for semantic similarity filtering."""
         if is_consolidation:
-            return (
-                self.valves.semantic_retrieval_threshold
-                * self.valves.relaxed_semantic_threshold_multiplier
-            )
+            return self.valves.semantic_retrieval_threshold * self.valves.relaxed_semantic_threshold_multiplier
         return self.valves.semantic_retrieval_threshold

     def _extract_text_from_content(self, content) -> str:
@@ -1546,61 +1246,38 @@ class Filter:
             raise ValueError("🤖 Model not specified")

         if self.valves.max_memories_returned <= 0:
-            raise ValueError(
-                f"📊 Invalid max memories returned: {self.valves.max_memories_returned}"
-            )
+            raise ValueError(f"📊 Invalid max memories returned: {self.valves.max_memories_returned}")

         if not (0.0 <= self.valves.semantic_retrieval_threshold <= 1.0):
-            raise ValueError(
-                f"🎯 Invalid semantic retrieval threshold: {self.valves.semantic_retrieval_threshold} (must be 0.0-1.0)"
-            )
+            raise ValueError(f"🎯 Invalid semantic retrieval threshold: {self.valves.semantic_retrieval_threshold} (must be 0.0-1.0)")

         logger.info("✅ Configuration validated")

     async def _get_embedding_cache(self, user_id: str, key: str) -> Optional[Any]:
         """Get embedding from cache."""
-        return await self._cache_manager.get(
-            user_id, self._cache_manager.EMBEDDING_CACHE, key
-        )
+        return await self._cache_manager.get(user_id, self._cache_manager.EMBEDDING_CACHE, key)

     async def _put_embedding_cache(self, user_id: str, key: str, value: Any) -> None:
         """Store embedding in cache."""
-        await self._cache_manager.put(
-            user_id, self._cache_manager.EMBEDDING_CACHE, key, value
-        )
+        await self._cache_manager.put(user_id, self._cache_manager.EMBEDDING_CACHE, key, value)

     def _compute_text_hash(self, text: str) -> str:
         """Compute SHA256 hash for text caching."""
         return hashlib.sha256(text.encode()).hexdigest()

-    def _normalize_embedding(
-        self, embedding: Union[List[float], np.ndarray]
-    ) -> np.ndarray:
+    def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray:
         """Normalize embedding vector."""
         if isinstance(embedding, list):
             embedding = np.array(embedding, dtype=np.float16)
         else:
             embedding = embedding.astype(np.float16)
-
-        embedding = np.squeeze(embedding)
-
-        if embedding.ndim != 1:
-            raise ValueError(f"Embedding must be 1D after squeeze, got shape {embedding.shape}")
-
-        if self._embedding_dimension and embedding.shape[0] != self._embedding_dimension:
-            raise ValueError(f"Embedding must have {self._embedding_dimension} dimensions, got {embedding.shape[0]}")
-
         norm = np.linalg.norm(embedding)
         return embedding / norm if norm > 0 else embedding

-    async def _generate_embeddings(
-        self, texts: Union[str, List[str]], user_id: str
-    ) -> Union[np.ndarray, List[np.ndarray]]:
+    async def _generate_embeddings(self, texts: Union[str, List[str]], user_id: str) -> Union[np.ndarray, List[np.ndarray]]:
         """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function."""
         if self._embedding_function is None:
-            raise RuntimeError(
-                "🤖 Embedding function not initialized. Ensure pipeline context is set."
-            )
+            raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")

         is_single = isinstance(texts, str)
         text_list = [texts] if is_single else texts
@@ -1634,22 +1311,14 @@ class Filter:
                 uncached_hashes.append(text_hash)

         if uncached_texts:
-            user = (
-                await asyncio.to_thread(Users.get_user_by_id, user_id)
-                if hasattr(self, "__user__")
-                else None
-            )
+            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None
             loop = asyncio.get_event_loop()
-            raw_embeddings = await loop.run_in_executor(
-                None, self._embedding_function, uncached_texts, None, user
-            )
+            raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user)

             if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
                 if isinstance(raw_embeddings[0], list):
-                    new_embeddings = [
-                        self._normalize_embedding(emb) for emb in raw_embeddings
-                    ]
+                    new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings]
                 else:
                     new_embeddings = [self._normalize_embedding(raw_embeddings)]
             else:
@@ -1662,11 +1331,7 @@ class Filter:
                 result_embeddings[original_idx] = embedding

         if is_single:
-            logger.info(
-                "📥 User message embedding: cache hit"
-                if not uncached_texts
-                else "💾 User message embedding: generated and cached"
-            )
+            logger.info("📥 User message embedding: cache hit" if not uncached_texts else "💾 User message embedding: generated and cached")
             return result_embeddings[0]
         else:
             valid_count = sum(1 for emb in result_embeddings if emb is not None)
@@ -1677,17 +1342,13 @@ class Filter:
         if self._skip_detector is None:
             raise RuntimeError("🤖 Skip detector not initialized")

-        skip_reason = self._skip_detector.detect_skip_reason(
-            user_message, self.valves.max_message_chars, memory_system=self
-        )
+        skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self)
         if skip_reason:
             status_key = SkipDetector.SkipReason(skip_reason)
             return True, SkipDetector.STATUS_MESSAGES[status_key]
         return False, ""

-    def _process_user_message(
-        self, body: Dict[str, Any]
-    ) -> Tuple[Optional[str], bool, str]:
+    def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
         """Extract user message and determine if memory operations should be skipped."""
         if not body or "messages" not in body or not isinstance(body["messages"], list):
             return (
@@ -1719,9 +1380,7 @@ class Filter:
         should_skip, skip_reason = self._should_skip_memory_operations(user_message)
         return user_message, should_skip, skip_reason

-    async def _get_user_memories(
-        self, user_id: str, timeout: Optional[float] = None
-    ) -> List:
+    async def _get_user_memories(self, user_id: str, timeout: Optional[float] = None) -> List:
         """Get user memories with timeout handling."""
         if timeout is None:
             timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
@@ -1735,9 +1394,7 @@ class Filter:
         except Exception as e:
             raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}")

-    def _log_retrieved_memories(
-        self, memories: List[Dict[str, Any]], context_type: str = "semantic"
-    ) -> None:
+    def _log_retrieved_memories(self, memories: List[Dict[str, Any]], context_type: str = "semantic") -> None:
         """Log retrieved memories with concise formatting showing key statistics and semantic values."""
         if not memories:
             return
@@ -1751,27 +1408,15 @@ class Filter:
         lowest_score = min(scores)
         median_score = statistics.median(scores)

-        context_label = (
-            "📊 Consolidation candidate memories"
-            if context_type == "consolidation"
-            else "📊 Retrieved memories"
-        )
-        max_scores_to_show = int(
-            self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER
-        )
-        scores_str = ", ".join(
-            [f"{score:.3f}" for score in scores[:max_scores_to_show]]
-        )
+        context_label = "📊 Consolidation candidate memories" if context_type == "consolidation" else "📊 Retrieved memories"
+        max_scores_to_show = int(self.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
+        scores_str = ", ".join([f"{score:.3f}" for score in scores[:max_scores_to_show]])
         suffix = "..." if len(scores) > max_scores_to_show else ""

-        logger.info(
-            f"{context_label}: {len(memories)} memories | Top: {top_score:.3f} | Median: {median_score:.3f} | Lowest: {lowest_score:.3f}"
-        )
+        logger.info(f"{context_label}: {len(memories)} memories | Top: {top_score:.3f} | Median: {median_score:.3f} | Lowest: {lowest_score:.3f}")
         logger.info(f"Scores: [{scores_str}{suffix}]")

-    def _build_operation_details(
-        self, created_count: int, updated_count: int, deleted_count: int
-    ) -> List[str]:
+    def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]:
         operations = [
             (created_count, "📝 Created"),
             (updated_count, "✏️ Updated"),
             (deleted_count, "🗑️ Deleted"),
         ]
         return [f"{label} {count}" for count, label in operations if count > 0]

-    def _cache_key(
-        self, cache_type: str, user_id: str, content: Optional[str] = None
-    ) -> str:
+    def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
         """Unified cache key generation for all cache types."""
         if content:
-            content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[
-                : Constants.CACHE_KEY_HASH_PREFIX_LENGTH
-            ]
+            content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
             return f"{cache_type}_{user_id}:{content_hash}"
         return f"{cache_type}_{user_id}"
@@ -1806,9 +1447,7 @@ class Filter:
         if record_date:
             try:
                 if isinstance(record_date, str):
-                    parsed_date = datetime.fromisoformat(
-                        record_date.replace("Z", "+00:00")
-                    )
+                    parsed_date = datetime.fromisoformat(record_date.replace("Z", "+00:00"))
                 else:
                     parsed_date = record_date
                 formatted_date = parsed_date.strftime("%b %d %Y")
@@ -1819,9 +1458,7 @@ class Filter:
             memory_lines.append(line)
         return memory_lines

-    async def _emit_status(
-        self, emitter: Optional[Callable], description: str, done: bool = True
-    ) -> None:
+    async def _emit_status(self, emitter: Optional[Callable], description: str, done: bool = True) -> None:
         """Emit status messages for memory operations."""
         if not emitter:
             return
@@ -1845,19 +1482,9 @@ class Filter:
     ) -> Dict[str, Any]:
         """Retrieve memories for injection using similarity computation with optional LLM reranking."""
         if cached_similarities is not None:
-            memories = [
-                m
-                for m in cached_similarities
-                if m.get("relevance", 0) >= self.valves.semantic_retrieval_threshold
-            ]
-            logger.info(
-                f"🔍 Using cached similarities for {len(memories)} candidate memories"
-            )
-            final_memories, reranking_info = (
-                await self._llm_reranking_service.rerank_memories(
-                    user_message, memories, emitter
-                )
-            )
+            memories = [m for m in cached_similarities if m.get("relevance", 0) >= self.valves.semantic_retrieval_threshold]
+            logger.info(f"🔍 Using cached similarities for {len(memories)} candidate memories")
+            final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter)
             self._log_retrieved_memories(final_memories, "semantic")
             return {
                 "memories": final_memories,
@@ -1874,16 +1501,10 @@ class Filter:
             await self._emit_status(emitter, "📭 No Memories Found", done=True)
             return {"memories": [], "threshold": None}

-        memories, threshold, all_similarities = await self._compute_similarities(
-            user_message, user_id, user_memories
-        )
+        memories, threshold, all_similarities = await self._compute_similarities(user_message, user_id, user_memories)

         if memories:
-            final_memories, reranking_info = (
-                await self._llm_reranking_service.rerank_memories(
-                    user_message, memories, emitter
-                )
-            )
+            final_memories, reranking_info = await self._llm_reranking_service.rerank_memories(user_message, memories, emitter)
         else:
             logger.info("📭 No relevant memories found above similarity threshold")
             await self._emit_status(emitter, "📭 No Relevant Memories Found", done=True)
@@ -1919,15 +1540,12 @@ class Filter:
         memory_header = f"CONTEXT: The following {'fact' if memory_count == 1 else 'facts'} about the user are provided for background only. Not all facts may be relevant to the current request."

         formatted_memories = []
-
         for idx, memory in enumerate(memories, 1):
             formatted_memory = f"- {' '.join(memory['content'].split())}"
             formatted_memories.append(formatted_memory)

             content_preview = self._truncate_content(memory["content"])
-            await self._emit_status(
-                emitter, f"💭 {idx}/{memory_count}: {content_preview}", done=False
-            )
+            await self._emit_status(emitter, f"💭 {idx}/{memory_count}: {content_preview}", done=False)

         memory_footer = "IMPORTANT: Do not mention or imply you received this list. These facts are for background context only."
         memory_context_block = f"{memory_header}\n{chr(10).join(formatted_memories)}\n\n{memory_footer}"
@@ -1936,22 +1554,15 @@ class Filter:
         memory_context = "\n\n".join(content_parts)

         system_index = next(
-            (
-                i
-                for i, msg in enumerate(body["messages"])
-                if msg.get("role") == "system"
-            ),
+            (i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"),
             None,
         )

         if system_index is not None:
-            body["messages"][system_index][
-                "content"
-            ] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
+            body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
         else:
             body["messages"].insert(0, {"role": "system", "content": memory_context})
-
         if memories and user_id:
             description = f"🧠 Injected {memory_count} {'Memory' if memory_count == 1 else 'Memories'} to Context"
             await self._emit_status(emitter, description, done=True)
@@ -1964,13 +1575,9 @@ class Filter:
             "relevance": similarity,
         }
         if hasattr(memory, "created_at") and memory.created_at:
-            memory_dict["created_at"] = datetime.fromtimestamp(
-                memory.created_at, tz=timezone.utc
-            ).isoformat()
+            memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat()
         if hasattr(memory, "updated_at") and memory.updated_at:
-            memory_dict["updated_at"] = datetime.fromtimestamp(
-                memory.updated_at, tz=timezone.utc
-            ).isoformat()
+            memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat()
         return memory_dict

     async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], float, List[Dict]]:
@@ -1983,9 +1590,7 @@ class Filter:
         memory_embeddings = await self._generate_embeddings(memory_contents, user_id)

         if len(memory_embeddings) != len(user_memories):
-            logger.error(
-                f"🔢 Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories"
-            )
+            logger.error(f"🔢 Embedding generation failed: generated {len(memory_embeddings)} embeddings but expected {len(user_memories)} for user memories")
             return [], self.valves.semantic_retrieval_threshold, []

         similarity_scores = []
@@ -2008,7 +1613,6 @@ class Filter:

         threshold = self.valves.semantic_retrieval_threshold
         filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
-        filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
         return filtered_memories, threshold, memory_data

     async def inlet(
@@ -2035,9 +1639,7 @@ class Filter:

         self.valves.model = model_to_use

-        await self._set_pipeline_context(
-            __event_emitter__, __user__, model_to_use, __request__
-        )
+        await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)

         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
@@ -2050,12 +1652,8 @@ class Filter:
             await self._add_memory_context(body, [], user_id, __event_emitter__)
             return body
         try:
-            memory_cache_key = self._cache_key(
-                self._cache_manager.MEMORY_CACHE, user_id
-            )
-            user_memories = await self._cache_manager.get(
-                user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key
-            )
+            memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
+            user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key)
             if user_memories is None:
                 user_memories = await self._get_user_memories(user_id)
                 await self._cache_manager.put(
@@ -2064,16 +1662,12 @@ class Filter:
                    memory_cache_key,
                    user_memories,
                )
-            retrieval_result = await self._retrieve_relevant_memories(
-                user_message, user_id, user_memories, __event_emitter__
-            )
+            retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
             memories = retrieval_result.get("memories", [])
             threshold = retrieval_result.get("threshold")
             all_similarities = retrieval_result.get("all_similarities", [])
             if all_similarities:
-                cache_key = self._cache_key(
-                    self._cache_manager.RETRIEVAL_CACHE, user_id, user_message
-                )
+                cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
                 await self._cache_manager.put(
                     user_id,
                     self._cache_manager.RETRIEVAL_CACHE,
@@ -2109,9 +1703,7 @@ class Filter:

         self.valves.model = model_to_use

-        await self._set_pipeline_context(
-            __event_emitter__, __user__, model_to_use, __request__
-        )
+        await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)

         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
@@ -2119,17 +1711,9 @@ class Filter:
         user_message, should_skip, skip_reason = self._process_user_message(body)
         if not user_message or should_skip:
             return body
-        cache_key = self._cache_key(
-            self._cache_manager.RETRIEVAL_CACHE, user_id, user_message
-        )
-        cached_similarities = await self._cache_manager.get(
-            user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key
-        )
-        task = asyncio.create_task(
-            self._llm_consolidation_service.run_consolidation_pipeline(
-                user_message, user_id, __event_emitter__, cached_similarities
-            )
-        )
+        cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
+        cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
+        task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
         self._background_tasks.add(task)

         def safe_cleanup(t: asyncio.Task) -> None:
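+            # Reading t.exception() also marks the task's exception as retrieved,
+            # which prevents asyncio's "exception was never retrieved" warning.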
             try:
                 self._background_tasks.discard(t)
                 if t.exception() and not t.cancelled():
                     exception = t.exception()
-                    logger.error(
-                        f"❌ Background memory consolidation task failed: {str(exception)}"
-                    )
+                    logger.error(f"❌ Background memory consolidation task failed: {str(exception)}")
             except Exception as e:
                 logger.error(f"❌ Failed to cleanup background memory task: {str(e)}")

@@ -2160,25 +1742,15 @@ class Filter:
         """Refresh user cache - clear stale caches and update with fresh embeddings."""
         start_time = time.time()
         try:
-            retrieval_cleared = await self._cache_manager.clear_user_cache(
-                user_id, self._cache_manager.RETRIEVAL_CACHE
-            )
-            embedding_cleared = await self._cache_manager.clear_user_cache(
-                user_id, self._cache_manager.EMBEDDING_CACHE
-            )
-            logger.info(
-                f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}"
-            )
+            retrieval_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.RETRIEVAL_CACHE)
+            embedding_cleared = await self._cache_manager.clear_user_cache(user_id, self._cache_manager.EMBEDDING_CACHE)
+            logger.info(f"🔄 Cleared {retrieval_cleared} retrieval + {embedding_cleared} embedding cache entries for user {user_id}")

             user_memories = await self._get_user_memories(user_id)
-            memory_cache_key = self._cache_key(
-                self._cache_manager.MEMORY_CACHE, user_id
-            )
+            memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)

             if not user_memories:
-                await self._cache_manager.put(
-                    user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, []
-                )
+                await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, [])
                 logger.info("📭 No memories found for user")
                 return

@@ -2189,28 +1761,17 @@ class Filter:
                 user_memories,
             )

-            memory_contents = [
-                memory.content
-                for memory in user_memories
-                if memory.content
-                and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS
-            ]
+            memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]

             if memory_contents:
                 await self._generate_embeddings(memory_contents, user_id)

             duration = time.time() - start_time
-            logger.info(
-                f"🔄 Cache updated with {len(memory_contents)} embeddings for user {user_id} in {duration:.2f}s"
-            )
+            logger.info(f"🔄 Cache updated with {len(memory_contents)} embeddings for user {user_id} in {duration:.2f}s")

         except Exception as e:
-            raise RuntimeError(
-                f"🧹 Failed to refresh cache for user {user_id} after {(time.time() - start_time):.2f}s: {str(e)}"
-            )
+            raise RuntimeError(f"🧹 Failed to refresh cache for user {user_id} after {(time.time() - start_time):.2f}s: {str(e)}")

-    async def _execute_single_operation(
-        self, operation: Models.MemoryOperation, user: Any
-    ) -> str:
+    async def _execute_single_operation(self, operation: Models.MemoryOperation, user: Any) -> str:
         """Execute a single memory operation."""
         try:
             if operation.operation == Models.MemoryOperationType.CREATE:
@@ -2220,9 +1781,7 @@ class Filter:
                     return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value

                 await asyncio.wait_for(
-                    asyncio.to_thread(
-                        Memories.insert_new_memory, user.id, content_stripped
-                    ),
+                    asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
                     timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
                 )
                 return Models.MemoryOperationType.CREATE.value
@@ -2233,12 +1792,9 @@ class Filter:
                     logger.warning(f"⚠️ Skipping UPDATE operation: empty ID")
                     return Models.OperationResult.SKIPPED_EMPTY_ID.value
-
                 content_stripped = operation.content.strip()
                 if not content_stripped:
-                    logger.warning(
-                        f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content"
-                    )
+                    logger.warning(f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content")
                     return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value

                 await asyncio.wait_for(
@@ -2259,9 +1815,7 @@ class Filter:
                     return Models.OperationResult.SKIPPED_EMPTY_ID.value

                 await asyncio.wait_for(
-                    asyncio.to_thread(
-                        Memories.delete_memory_by_id_and_user_id, id_stripped, user.id
-                    ),
+                    asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
                     timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
                 )
                 return Models.MemoryOperationType.DELETE.value

@@ -2270,56 +1824,33 @@ class Filter:
             return Models.OperationResult.UNSUPPORTED.value

         except Exception as e:
-            logger.error(
-                f"💾 Database operation failed for {operation.operation.value}: {str(e)}"
-            )
+            logger.error(f"💾 Database operation failed for {operation.operation.value}: {str(e)}")
             return Models.OperationResult.FAILED.value

-    def _remove_refs_from_schema(
-        self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None
-    ) -> Dict[str, Any]:
+    def _remove_refs_from_schema(self, schema: Dict[str, Any], schema_defs: Optional[Dict[str, Any]] = None) -> Dict[str, Any]:
         """Remove $ref references and ensure required fields for Azure OpenAI."""
         if not isinstance(schema, dict):
             return schema

-        if "$ref" in schema:
-            ref_path = schema["$ref"]
-            if ref_path.startswith("#/$defs/"):
-                def_name = ref_path.split("/")[-1]
-
         if "$ref" in schema:
             ref_path = schema["$ref"]
             if ref_path.startswith("#/$defs/"):
                 def_name = ref_path.split("/")[-1]
                 if schema_defs and def_name in schema_defs:
-                    return self._remove_refs_from_schema(
-                        schema_defs[def_name].copy(), schema_defs
-                    )
+                    return self._remove_refs_from_schema(schema_defs[def_name].copy(), schema_defs)
             return {"type": "object"}

         result = {}
         for key, value in schema.items():
-            if key == "$defs":
             if key == "$defs":
                 continue
             elif isinstance(value, dict):
                 result[key] = self._remove_refs_from_schema(value, schema_defs)
             elif isinstance(value, list):
-                result[key] = [
-                    (
-                        self._remove_refs_from_schema(item, schema_defs)
-                        if isinstance(item, dict)
-                        else item
-                    )
-                    for item in value
-                ]
+                result[key] = [(self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item) for item in value]
             else:
                 result[key] = value

-        if result.get("type") == "object" and "properties" in result:
-            result["required"] = list(result["properties"].keys())
-
         if result.get("type") == "object" and "properties" in result:
             result["required"] = list(result["properties"].keys())
@@ -2333,9 +1864,7 @@ class Filter:
     ) -> Union[str, BaseModel]:
         """Query OpenWebUI's internal model system with Pydantic model parsing."""
         if not hasattr(self, "__request__") or not hasattr(self, "__user__"):
-            raise RuntimeError(
-                "🔧 Pipeline interface not properly initialized. __request__ and __user__ required."
-            )
+            raise RuntimeError("🔧 Pipeline interface not properly initialized. __request__ and __user__ required.")

         model_to_use = self.valves.model if self.valves.model else self.__model__
         if not model_to_use:
@@ -2354,7 +1883,6 @@ class Filter:
         if response_model:
             raw_schema = response_model.model_json_schema()
             schema_defs = raw_schema.get("$defs", {})
-            schema_defs = raw_schema.get("$defs", {})
             schema = self._remove_refs_from_schema(raw_schema, schema_defs)
             schema["type"] = "object"
             form_data["response_format"] = {
@@ -2371,23 +1899,17 @@ class Filter:
                 generate_chat_completion(
                     self.__request__,
                     form_data,
-                    user=await asyncio.to_thread(
-                        Users.get_user_by_id, self.__user__["id"]
-                    ),
+                    user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"]),
                 ),
                 timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
             )
         except asyncio.TimeoutError:
-            raise TimeoutError(
-                f"⏱️ LLM query timed out after {Constants.LLM_CONSOLIDATION_TIMEOUT_SEC}s"
-            )
+            raise TimeoutError(f"⏱️ LLM query timed out after {Constants.LLM_CONSOLIDATION_TIMEOUT_SEC}s")
         except Exception as e:
            raise RuntimeError(f"🤖 LLM query failed: {str(e)}")

         try:
-            if hasattr(response, "body") and hasattr(
-                getattr(response, "body", None), "decode"
-            ):
+            if hasattr(response, "body") and hasattr(getattr(response, "body", None), "decode"):
                 body = getattr(response, "body")
                 response_data = json.loads(body.decode("utf-8"))
             else:
@@ -2395,19 +1917,8 @@ class Filter:
         except (json.JSONDecodeError, AttributeError) as e:
             raise RuntimeError(f"🔍 Failed to decode response body: {str(e)}")

-        if (
-            isinstance(response_data, dict)
-            and "choices" in response_data
-            and isinstance(response_data["choices"], list)
-            and len(response_data["choices"]) > 0
-        ):
+        if isinstance(response_data, dict) and "choices" in response_data and isinstance(response_data["choices"], list) and len(response_data["choices"]) > 0:
             first_choice = response_data["choices"][0]
-            if (
-                isinstance(first_choice, dict)
-                and "message" in first_choice
-                and isinstance(first_choice["message"], dict)
-                and "content" in first_choice["message"]
-            ):
             if (
                 isinstance(first_choice, dict)
                 and "message" in first_choice
                 and isinstance(first_choice["message"], dict)
                 and "content" in first_choice["message"]
             ):
                 content = first_choice["message"]["content"]
             else:
-                raise ValueError(
-                    "🤖 Invalid response structure: missing content in message"
-                )
+                raise ValueError("🤖 Invalid response structure: missing content in message")
         else:
             raise ValueError(f"🤖 Unexpected LLM response format: {response_data}")

         if response_model:
-            try:
             try:
                 parsed_data = json.loads(content)
                 return response_model.model_validate(parsed_data)
             except json.JSONDecodeError as e:
-                raise ValueError(
-                    f"🔍 Invalid JSON from LLM: {str(e)}\nContent: {content}"
-                )
+                raise ValueError(f"🔍 Invalid JSON from LLM: {str(e)}\nContent: {content}")
             except PydanticValidationError as e:
-                raise ValueError(
-                    f"🔍 LLM response validation failed: {str(e)}\nContent: {content}"
-                )
+                raise ValueError(f"🔍 LLM response validation failed: {str(e)}\nContent: {content}")

         if not content or content.strip() == "":
             raise ValueError("🤖 Empty response from LLM")

-        return content
\ No newline at end of file
+        return content