Mirror of https://github.com/mtayfur/openwebui-memory-system.git (synced 2026-01-22 06:51:01 +01:00)
♻️ (memory_system.py): reformat code for consistency, readability, and maintainability
- Reorder and group imports for clarity and PEP8 compliance.
- Standardize string quoting and whitespace for consistency.
- Refactor long function signatures and dictionary constructions for better readability.
- Use double quotes for all string literals and dictionary keys.
- Improve formatting of multiline statements and function calls.
- Add or adjust line breaks to keep lines within recommended length.
- Reformat class and method docstrings for clarity.
- Use consistent indentation and spacing throughout the file.

These changes improve code readability, maintainability, and consistency, making it easier for future contributors to understand and modify the codebase. No functional logic is changed.
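As a concrete illustration of these conventions, the edits throughout the diff below follow this before/after shape (a hypothetical sketch; the identifiers are illustrative, not taken from the file):

    # Before: single-quoted strings, trailing binary operators in wrapped expressions
    total = (
        len(first_group) +
        len(second_group)
    )
    labels = {'technical': 1, 'instruction': 2}

    # After: double quotes throughout, operators lead continuation lines (PEP 8)
    total = (
        len(first_group)
        + len(second_group)
    )
    labels = {"technical": 1, "instruction": 2}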
memory_system.py (423 lines changed)
@@ -15,21 +15,22 @@ from enum import Enum
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
-from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError
+from fastapi import Request

-from open_webui.utils.chat import generate_chat_completion
 from open_webui.models.users import Users
 from open_webui.routers.memories import Memories
-from fastapi import Request
+from open_webui.utils.chat import generate_chat_completion
+from pydantic import BaseModel, ConfigDict, Field
+from pydantic import ValidationError as PydanticValidationError

 logger = logging.getLogger(__name__)

 _SHARED_SKIP_DETECTOR_CACHE = {}
 _SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()


 class Constants:
     """Centralized configuration constants for the memory system."""

     # Core System Limits
     MAX_MEMORY_CONTENT_CHARS = 500  # Character limit for LLM prompt memory content
     MAX_MEMORIES_PER_RETRIEVAL = 10  # Maximum memories returned per query
@@ -37,31 +38,32 @@ class Constants:
     MIN_MESSAGE_CHARS = 10  # Minimum message length for validation
     DATABASE_OPERATION_TIMEOUT_SEC = 10  # Timeout for DB operations like user lookup
     LLM_CONSOLIDATION_TIMEOUT_SEC = 60.0  # Timeout for LLM consolidation operations

     # Cache System
     MAX_CACHE_ENTRIES_PER_TYPE = 500  # Maximum cache entries per cache type
     MAX_CONCURRENT_USER_CACHES = 50  # Maximum concurrent user cache instances
     CACHE_KEY_HASH_PREFIX_LENGTH = 10  # Hash prefix length for cache keys

     # Retrieval & Similarity
     SEMANTIC_RETRIEVAL_THRESHOLD = 0.25  # Semantic similarity threshold for retrieval
     RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8  # Multiplier for relaxed similarity threshold in secondary operations
     EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6  # Multiplier for expanding memory candidates in advanced operations
     LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8  # Multiplier for LLM reranking trigger threshold

     # Skip Detection
     SKIP_CATEGORY_MARGIN = 0.5  # Margin above conversational similarity for skip category classification

     # Safety & Operations
     MAX_DELETE_OPERATIONS_RATIO = 0.6  # Maximum delete operations ratio for safety
     MIN_OPS_FOR_DELETE_RATIO_CHECK = 6  # Minimum operations to apply ratio check

     # Content Display
     CONTENT_PREVIEW_LENGTH = 80  # Maximum length for content preview display

     # Default Models
     DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite"

+
 class Prompts:
     """Container for all LLM prompts used in the memory system."""

@@ -181,6 +183,7 @@ Return: {{"ids": []}}
 Explanation: Query seeks general technical explanation without personal context. Job and family information don't affect how quantum computing concepts should be explained.
 """

+
 class Models:
     """Container for all Pydantic models used in the memory system."""

@@ -203,7 +206,7 @@ class Models:
     class MemoryOperation(StrictModel):
         """Pydantic model for memory operations with validation."""

-        operation: 'Models.MemoryOperationType' = Field(description="Type of memory operation to perform")
+        operation: "Models.MemoryOperationType" = Field(description="Type of memory operation to perform")
         content: str = Field(description="Memory content (required for CREATE/UPDATE, empty for DELETE)")
         id: str = Field(description="Memory ID (empty for CREATE, required for UPDATE/DELETE)")

@@ -221,7 +224,7 @@ class Models:
     class ConsolidationResponse(BaseModel):
         """Pydantic model for memory consolidation LLM response - object containing array of memory operations."""

-        ops: List['Models.MemoryOperation'] = Field(default_factory=list, description="List of memory operations to execute")
+        ops: List["Models.MemoryOperation"] = Field(default_factory=list, description="List of memory operations to execute")

     class MemoryRerankingResponse(BaseModel):
         """Pydantic model for memory reranking LLM response - object containing array of memory IDs."""
@@ -442,52 +445,42 @@ class SkipDetector:
         self.embedding_function = embedding_function
         self._reference_embeddings = None
         self._initialize_reference_embeddings()

     def _initialize_reference_embeddings(self) -> None:
         """Compute and cache embeddings for category descriptions."""
         try:
-            technical_embeddings = self.embedding_function(
-                self.TECHNICAL_CATEGORY_DESCRIPTIONS
-            )
-
-            instruction_embeddings = self.embedding_function(
-                self.INSTRUCTION_CATEGORY_DESCRIPTIONS
-            )
-
-            pure_math_embeddings = self.embedding_function(
-                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS
-            )
-
-            translation_embeddings = self.embedding_function(
-                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS
-            )
-
-            grammar_embeddings = self.embedding_function(
-                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS
-            )
-
-            conversational_embeddings = self.embedding_function(
-                self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS
-            )
+            technical_embeddings = self.embedding_function(self.TECHNICAL_CATEGORY_DESCRIPTIONS)
+            instruction_embeddings = self.embedding_function(self.INSTRUCTION_CATEGORY_DESCRIPTIONS)
+            pure_math_embeddings = self.embedding_function(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS)
+            translation_embeddings = self.embedding_function(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS)
+            grammar_embeddings = self.embedding_function(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
+            conversational_embeddings = self.embedding_function(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)

             self._reference_embeddings = {
-                'technical': np.array(technical_embeddings),
-                'instruction': np.array(instruction_embeddings),
-                'pure_math': np.array(pure_math_embeddings),
-                'translation': np.array(translation_embeddings),
-                'grammar': np.array(grammar_embeddings),
-                'conversational': np.array(conversational_embeddings),
+                "technical": np.array(technical_embeddings),
+                "instruction": np.array(instruction_embeddings),
+                "pure_math": np.array(pure_math_embeddings),
+                "translation": np.array(translation_embeddings),
+                "grammar": np.array(grammar_embeddings),
+                "conversational": np.array(conversational_embeddings),
             }

             total_skip_categories = (
-                len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) +
-                len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) +
-                len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) +
-                len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) +
-                len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
+                len(self.TECHNICAL_CATEGORY_DESCRIPTIONS)
+                + len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS)
+                + len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS)
+                + len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS)
+                + len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
             )

-            logger.info(f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories")
+            logger.info(
+                f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories"
+            )
         except Exception as e:
             logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}")
             self._reference_embeddings = None
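Note: the hunk above is the setup for embedding-based skip detection. Each category's description list is embedded once and cached as a matrix; an incoming message is later scored against those matrices with dot products. A minimal sketch of that scoring step, assuming unit-normalized embeddings (so the dot product equals cosine similarity); the names are illustrative:

    import numpy as np

    def score_categories(message_embedding: np.ndarray, reference_embeddings: dict) -> dict:
        """Best similarity per category for one message embedding.

        Each value in reference_embeddings has shape (n_descriptions, dim);
        message_embedding has shape (dim,).
        """
        scores = {}
        for category, refs in reference_embeddings.items():
            # One similarity per category description; keep the best match.
            scores[category] = float(np.dot(refs, message_embedding).max())
        return scores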
@@ -504,108 +497,107 @@ class SkipDetector:
     def _fast_path_skip_detection(self, message: str) -> Optional[str]:
         """Language-agnostic structural pattern detection with high confidence and low false positive rate."""
         msg_len = len(message)

         # Pattern 1: Multiple URLs (5+ full URLs indicates link lists or technical references)
-        url_pattern_count = message.count('http://') + message.count('https://')
+        url_pattern_count = message.count("http://") + message.count("https://")
         if url_pattern_count >= 5:
             return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 2: Long unbroken alphanumeric strings (tokens, hashes, base64)
         words = message.split()
         for word in words:
             cleaned = word.strip('.,;:!?()[]{}"\'"')
-            if len(cleaned) > 80 and cleaned.replace('-', '').replace('_', '').isalnum():
+            if len(cleaned) > 80 and cleaned.replace("-", "").replace("_", "").isalnum():
                 return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 3: Markdown/text separators (repeated ---, ===, ___, ***)
-        separator_patterns = ['---', '===', '___', '***']
+        separator_patterns = ["---", "===", "___", "***"]
         for pattern in separator_patterns:
             if message.count(pattern) >= 2:
                 return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 4: Command-line patterns with context-aware detection
-        lines_stripped = [line.strip() for line in message.split('\n') if line.strip()]
+        lines_stripped = [line.strip() for line in message.split("\n") if line.strip()]
         if lines_stripped:
             actual_command_lines = 0
             for line in lines_stripped:
-                if line.startswith('$ ') and len(line) > 2:
+                if line.startswith("$ ") and len(line) > 2:
                     parts = line[2:].split()
                     if parts and parts[0].isalnum():
                         actual_command_lines += 1
-                elif '$ ' in line:
-                    dollar_index = line.find('$ ')
-                    if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'):
-                        parts = line[dollar_index+2:].split()
-                        if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ['curl', 'wget', 'git', 'npm', 'pip', 'docker']):
+                elif "$ " in line:
+                    dollar_index = line.find("$ ")
+                    if dollar_index > 0 and line[dollar_index - 1] in (" ", ":", "\t"):
+                        parts = line[dollar_index + 2 :].split()
+                        if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ["curl", "wget", "git", "npm", "pip", "docker"]):
                             actual_command_lines += 1
-                elif line.startswith('# ') and len(line) > 2:
+                elif line.startswith("# ") and len(line) > 2:
                     rest = line[2:].strip()
-                    if rest and not rest[0].isupper() and ' ' in rest:
+                    if rest and not rest[0].isupper() and " " in rest:
                         actual_command_lines += 1
-                elif line.startswith('> ') and len(line) > 2:
+                elif line.startswith("> ") and len(line) > 2:
                     pass

-            if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']):
+            if actual_command_lines >= 1 and any(c in message for c in ["http://", "https://", " | "]):
                 return self.SkipReason.SKIP_TECHNICAL.value
             if actual_command_lines >= 3:
                 return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 5: High path/URL density (dots and slashes suggesting file paths or URLs)
         if msg_len > 30:
-            slash_count = message.count('/') + message.count('\\')
-            dot_count = message.count('.')
+            slash_count = message.count("/") + message.count("\\")
+            dot_count = message.count(".")
             path_chars = slash_count + dot_count
             if path_chars > 10 and (path_chars / msg_len) > 0.15:
                 return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 6: Markup character density (structured data)
-        markup_chars = sum(message.count(c) for c in '{}[]<>')
+        markup_chars = sum(message.count(c) for c in "{}[]<>")
         if markup_chars >= 6:
             if markup_chars / msg_len > 0.10:
                 return self.SkipReason.SKIP_TECHNICAL.value
-            curly_count = message.count('{') + message.count('}')
+            curly_count = message.count("{") + message.count("}")
             if curly_count >= 10:
                 return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 7: Structured nested content with colons (key: value patterns)
-        line_count = message.count('\n')
+        line_count = message.count("\n")
         if line_count >= 8:
-            lines = message.split('\n')
+            lines = message.split("\n")
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#'))
-                indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
+                colon_lines = sum(1 for line in non_empty_lines if ":" in line and not line.strip().startswith("#"))
+                indented_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t")))

-                if (colon_lines / len(non_empty_lines) > 0.4 and
-                        indented_lines / len(non_empty_lines) > 0.5):
+                if colon_lines / len(non_empty_lines) > 0.4 and indented_lines / len(non_empty_lines) > 0.5:
                     return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 8: Highly structured multi-line content (require markup chars for technical confidence)
         if line_count > 15:
-            lines = message.split('\n')
+            lines = message.split("\n")
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>'))
-                structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
+                markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in "{}[]<>"))
+                structured_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t")))

                 if markup_in_lines / len(non_empty_lines) > 0.3:
                     return self.SkipReason.SKIP_TECHNICAL.value
                 elif structured_lines / len(non_empty_lines) > 0.6:
-                    technical_keywords = ['function', 'class', 'import', 'return', 'const', 'var', 'let', 'def']
+                    technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
                     if any(keyword in message.lower() for keyword in technical_keywords):
                         return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 9: Code-like indentation pattern (require code indicators to avoid false positives from bullet lists)
         if line_count >= 3:
-            lines = message.split('\n')
+            lines = message.split("\n")
             non_empty_lines = [line for line in lines if line.strip()]
             if non_empty_lines:
-                indented_lines = sum(1 for line in non_empty_lines if line[0] in (' ', '\t'))
+                indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
                 if indented_lines / len(non_empty_lines) > 0.5:
-                    code_indicators = ['def ', 'class ', 'function ', 'return ', 'import ', 'const ', 'let ', 'var ', 'public ', 'private ']
+                    code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
                     if any(indicator in message.lower() for indicator in code_indicators):
                         return self.SkipReason.SKIP_TECHNICAL.value

         # Pattern 10: Very high special character ratio (encoded data, technical output)
         if msg_len > 50:
             special_chars = sum(1 for c in message if not c.isalnum() and not c.isspace())
@@ -614,10 +606,10 @@ class SkipDetector:
             alphanumeric = sum(1 for c in message if c.isalnum())
             if alphanumeric / msg_len < 0.50:
                 return self.SkipReason.SKIP_TECHNICAL.value

         return None

-    def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: 'Filter') -> Optional[str]:
+    def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
         """
         Detect if a message should be skipped using two-stage detection:
         1. Fast-path structural patterns (~95% confidence)
@@ -628,53 +620,49 @@ class SkipDetector:
         size_issue = self.validate_message_size(message, max_message_chars)
         if size_issue:
             return size_issue

         fast_skip = self._fast_path_skip_detection(message)
         if fast_skip:
             logger.info(f"Fast-path skip: {fast_skip}")
             return fast_skip

         if self._reference_embeddings is None:
             logger.warning("SkipDetector reference embeddings not initialized, allowing message through")
             return None

         try:
             message_embedding = np.array(self.embedding_function([message.strip()])[0])

-            conversational_similarities = np.dot(
-                message_embedding,
-                self._reference_embeddings['conversational'].T
-            )
+            conversational_similarities = np.dot(message_embedding, self._reference_embeddings["conversational"].T)
             max_conversational_similarity = float(conversational_similarities.max())

             skip_categories = [
-                ('instruction', self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
-                ('translation', self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
-                ('grammar', self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
-                ('technical', self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
-                ('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
+                ("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
+                ("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
+                ("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
+                ("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
+                ("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
             ]

             qualifying_categories = []
             margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN

             for cat_key, skip_reason, descriptions in skip_categories:
-                similarities = np.dot(
-                    message_embedding,
-                    self._reference_embeddings[cat_key].T
-                )
+                similarities = np.dot(message_embedding, self._reference_embeddings[cat_key].T)
                 max_similarity = float(similarities.max())

                 if max_similarity > margin_threshold:
                     qualifying_categories.append((max_similarity, cat_key, skip_reason))

             if qualifying_categories:
                 highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0])
-                logger.info(f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})")
+                logger.info(
+                    f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})"
+                )
                 return highest_skip_reason.value

             return None

         except Exception as e:
             logger.error(f"Error in semantic skip detection: {e}")
             return None
@@ -692,7 +680,7 @@ class LLMRerankingService:

         llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier)
         if len(memories) > llm_trigger_threshold:
             return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"

         return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}"

@@ -717,33 +705,29 @@ CANDIDATE MEMORIES:
                     selected_memories.append(memory)

             logger.info(f"🧠 LLM selected {len(selected_memories)} out of {len(candidate_memories)} candidates")

             return selected_memories

         except Exception as e:
             logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
             return candidate_memories

-    async def rerank_memories(
-        self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None
-    ) -> Tuple[List[Dict], Dict[str, Any]]:
+    async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
         start_time = time.time()
         max_injection = self.memory_system.valves.max_memories_returned

         should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories)

         analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}

         if should_use_llm:
             extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
             llm_candidates = candidate_memories[:extended_count]
-            await self.memory_system._emit_status(
-                emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False
-            )
+            await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
             logger.info(f"Using LLM reranking: {decision_reason}")

             selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)

             if not selected_memories:
                 logger.info("📭 No relevant memories after LLM analysis")
                 await self.memory_system._emit_status(emitter, f"📭 No Relevant Memories After LLM Analysis", done=True)
@@ -751,7 +735,7 @@ CANDIDATE MEMORIES:
         else:
             logger.info(f"Skipping LLM reranking: {decision_reason}")
             selected_memories = candidate_memories[:max_injection]

         duration = time.time() - start_time
         duration_text = f" in {duration:.2f}s" if duration >= 0.01 else ""
         retrieval_method = "LLM" if should_use_llm else "Semantic"
@@ -769,10 +753,10 @@ class LLMConsolidationService:
         """Filter consolidation candidates by threshold and return candidates with threshold info."""
         consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
         candidates = [mem for mem in similarities if mem["relevance"] >= consolidation_threshold]

         max_consolidation_memories = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
         candidates = candidates[:max_consolidation_memories]

         threshold_info = f"{consolidation_threshold:.3f} (max: {max_consolidation_memories})"
         return candidates, threshold_info

@@ -812,7 +796,7 @@ class LLMConsolidationService:
             candidates, threshold_info = self._filter_consolidation_candidates(all_similarities)
         else:
             candidates = []
-            threshold_info = 'N/A'
+            threshold_info = "N/A"

         logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})")

@@ -820,7 +804,9 @@ class LLMConsolidationService:

         return candidates

-    async def generate_consolidation_plan(self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None) -> List[Dict[str, Any]]:
+    async def generate_consolidation_plan(
+        self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
+    ) -> List[Dict[str, Any]]:
         """Generate consolidation plan using LLM with clear system/user prompt separation."""
         if candidate_memories:
             memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
@@ -925,7 +911,7 @@ class LLMConsolidationService:
             results = await asyncio.gather(*batch_tasks, return_exceptions=True)
             for idx, result in enumerate(results):
                 operation = ops[idx]

                 if isinstance(result, Exception):
                     failed_count += 1
                     await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
@@ -982,21 +968,21 @@ class LLMConsolidationService:

             if operations:
                 created_count, updated_count, deleted_count, failed_count = await self.execute_memory_operations(operations, user_id, emitter)

             duration = time.time() - start_time
             logger.info(f"💾 Memory Consolidation Complete In {duration:.2f}s")

             total_operations = created_count + updated_count + deleted_count
             if total_operations > 0 or failed_count > 0:
                 await self.memory_system._emit_status(emitter, f"💾 Memory Consolidation Complete in {duration:.2f}s", done=False)

                 operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count)
                 memory_word = "Memory" if total_operations == 1 else "Memories"
                 operations_summary = f"{', '.join(operation_details)} {memory_word}"

                 if failed_count > 0:
                     operations_summary += f" (❌ {failed_count} Failed)"

                 await self.memory_system._emit_status(emitter, operations_summary, done=True)

         except Exception as e:
@@ -1016,20 +1002,27 @@ class Filter:
         """Configuration valves for the Memory System."""

         model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")

         max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
         max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")

-        semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
-        relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)")
+        semantic_retrieval_threshold: float = Field(
+            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
+        )
+        relaxed_semantic_threshold_multiplier: float = Field(
+            default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
+            description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
+        )

         enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
-        llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)")
+        llm_reranking_trigger_multiplier: float = Field(
+            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
+        )

     def __init__(self):
         """Initialize the Memory System filter with production validation."""
         global _SHARED_SKIP_DETECTOR_CACHE

         self.valves = self.Valves()
         self._validate_system_configuration()

@@ -1043,8 +1036,13 @@ class Filter:
         self._llm_reranking_service = LLMRerankingService(self)
         self._llm_consolidation_service = LLMConsolidationService(self)

-    async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
-                                    __model__: Optional[str] = None, __request__: Optional[Request] = None) -> None:
+    async def _set_pipeline_context(
+        self,
+        __event_emitter__: Optional[Callable] = None,
+        __user__: Optional[Dict[str, Any]] = None,
+        __model__: Optional[str] = None,
+        __request__: Optional[Request] = None,
+    ) -> None:
         """Set pipeline context parameters to avoid duplication in inlet/outlet methods."""
         if __event_emitter__:
             self.__current_event_emitter__ = __event_emitter__
@@ -1054,17 +1052,17 @@ class Filter:
             self.__model__ = __model__
         if __request__:
             self.__request__ = __request__

-            if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'):
+            if self._embedding_function is None and hasattr(__request__.app.state, "EMBEDDING_FUNCTION"):
                 self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
                 logger.info(f"✅ Using OpenWebUI's embedding function")

             if self._skip_detector is None:
                 global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
-                embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
-                embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
+                embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
+                embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
                 cache_key = f"{embedding_engine}:{embedding_model}"

                 async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
                     if cache_key in _SHARED_SKIP_DETECTOR_CACHE:
                         logger.info(f"♻️ Reusing cached skip detector: {cache_key}")
@@ -1072,6 +1070,7 @@ class Filter:
                     else:
                         logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
                         embedding_fn = self._embedding_function
+
                         def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
                             result = embedding_fn(texts, prefix=None, user=None)
                             if isinstance(result, list):
@@ -1079,12 +1078,11 @@ class Filter:
                                     return [np.array(emb, dtype=np.float16) for emb in result]
                                 return np.array(result, dtype=np.float16)
                             return np.array(result, dtype=np.float16)

                         self._skip_detector = SkipDetector(embedding_wrapper)
                         _SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
                         logger.info(f"✅ Skip detector initialized and cached")

-
     def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
         """Truncate content with ellipsis if needed."""
         if max_length is None:
@@ -1148,7 +1146,7 @@ class Filter:
         """Unified embedding generation for single text or batch with optimized caching using OpenWebUI's embedding function."""
         if self._embedding_function is None:
             raise RuntimeError("🤖 Embedding function not initialized. Ensure pipeline context is set.")

         is_single = isinstance(texts, str)
         text_list = [texts] if is_single else texts

@@ -1181,17 +1179,11 @@ class Filter:
                 uncached_hashes.append(text_hash)

         if uncached_texts:
-            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None
+            user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None

             loop = asyncio.get_event_loop()
-            raw_embeddings = await loop.run_in_executor(
-                None,
-                self._embedding_function,
-                uncached_texts,
-                None,
-                user
-            )
+            raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user)

             if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
                 if isinstance(raw_embeddings[0], list):
                     new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings]
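Note: the collapsed run_in_executor call above is the usual way to run a synchronous embedding function without blocking the event loop; run_in_executor forwards the trailing positional arguments to the callable, which is why prefix and user are passed positionally. A sketch of the pattern in isolation (embed_fn, texts, and user are placeholders):

    import asyncio

    async def embed_off_loop(embed_fn, texts, user=None):
        # Passing None selects the default thread pool executor; the
        # remaining positional arguments are forwarded to embed_fn.
        loop = asyncio.get_event_loop()
        return await loop.run_in_executor(None, embed_fn, texts, None, user)

On Python 3.9+, asyncio.to_thread(embed_fn, texts, None, user) is an equivalent alternative that also accepts keyword arguments.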
@@ -1211,15 +1203,13 @@ class Filter:
             return result_embeddings[0]
         else:
             valid_count = sum(1 for emb in result_embeddings if emb is not None)
-            logger.info(
-                f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid"
-            )
+            logger.info(f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid")
             return result_embeddings

     def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
         if self._skip_detector is None:
             raise RuntimeError("🤖 Skip detector not initialized")

         skip_reason = self._skip_detector.detect_skip_reason(user_message, self.valves.max_message_chars, memory_system=self)
         if skip_reason:
             status_key = SkipDetector.SkipReason(skip_reason)
@@ -1290,7 +1280,7 @@ class Filter:
     def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
         """Unified cache key generation for all cache types."""
         if content:
-            content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
+            content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
             return f"{cache_type}_{user_id}:{content_hash}"
         return f"{cache_type}_{user_id}"

@@ -1310,7 +1300,7 @@ class Filter:
         if record_date:
             try:
                 if isinstance(record_date, str):
-                    parsed_date = datetime.fromisoformat(record_date.replace('Z', '+00:00'))
+                    parsed_date = datetime.fromisoformat(record_date.replace("Z", "+00:00"))
                 else:
                     parsed_date = record_date
                 formatted_date = parsed_date.strftime("%b %d %Y")
@@ -1393,14 +1383,14 @@ class Filter:
         memory_count = len(memories)
         memory_header = f"CONTEXT: The following {'fact' if memory_count == 1 else 'facts'} about the user are provided for background only. Not all facts may be relevant to the current request."
         formatted_memories = []

         for idx, memory in enumerate(memories, 1):
             formatted_memory = f"- {' '.join(memory['content'].split())}"
             formatted_memories.append(formatted_memory)

-            content_preview = self._truncate_content(memory['content'])
+            content_preview = self._truncate_content(memory["content"])
             await self._emit_status(emitter, f"💭 {idx}/{memory_count}: {content_preview}", done=False)

         memory_footer = "IMPORTANT: Do not mention or imply you received this list. These facts are for background context only."
         memory_context_block = f"{memory_header}\n{chr(10).join(formatted_memories)}\n\n{memory_footer}"
         content_parts.append(memory_context_block)
@@ -1413,7 +1403,7 @@ class Filter:
             body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
         else:
             body["messages"].insert(0, {"role": "system", "content": memory_context})

         if memories and user_id:
             description = f"🧠 Injected {memory_count} {'Memory' if memory_count == 1 else 'Memories'} to Context"
             await self._emit_status(emitter, description, done=True)
@@ -1427,9 +1417,7 @@ class Filter:
         memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat()
         return memory_dict

-    async def _compute_similarities(
-        self, user_message: str, user_id: str, user_memories: List
-    ) -> Tuple[List[Dict], float, List[Dict]]:
+    async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], float, List[Dict]]:
         """Compute similarity scores between user message and memories."""
         if not user_memories:
             return [], self.valves.semantic_retrieval_threshold, []
@@ -1461,7 +1449,7 @@ class Filter:
         memory_data.sort(key=lambda x: x["relevance"], reverse=True)

         threshold = self.valves.semantic_retrieval_threshold
         filtered_memories = [m for m in memory_data if m["relevance"] >= threshold]
         return filtered_memories, threshold, memory_data

     async def inlet(
@@ -1536,9 +1524,7 @@ class Filter:
             cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
             cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)

-            task = asyncio.create_task(
-                self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities)
-            )
+            task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
             self._background_tasks.add(task)

             def safe_cleanup(t: asyncio.Task) -> None:
@@ -1582,11 +1568,7 @@ class Filter:

             await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)

-            memory_contents = [
-                memory.content
-                for memory in user_memories
-                if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS
-            ]
+            memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]

             if memory_contents:
                 await self._generate_embeddings(memory_contents, user_id)
@@ -1615,7 +1597,7 @@ class Filter:
         if not id_stripped:
             logger.warning(f"⚠️ Skipping UPDATE operation: empty ID")
             return Models.OperationResult.SKIPPED_EMPTY_ID.value

         content_stripped = operation.content.strip()
         if not content_stripped:
             logger.warning(f"⚠️ Skipping UPDATE operation for {id_stripped}: empty content")
@@ -1649,18 +1631,18 @@ class Filter:
         """Remove $ref references and ensure required fields for Azure OpenAI."""
         if not isinstance(schema, dict):
             return schema

-        if '$ref' in schema:
-            ref_path = schema['$ref']
-            if ref_path.startswith('#/$defs/'):
-                def_name = ref_path.split('/')[-1]
+        if "$ref" in schema:
+            ref_path = schema["$ref"]
+            if ref_path.startswith("#/$defs/"):
+                def_name = ref_path.split("/")[-1]
                 if schema_defs and def_name in schema_defs:
                     return self._remove_refs_from_schema(schema_defs[def_name].copy(), schema_defs)
-            return {'type': 'object'}
+            return {"type": "object"}

         result = {}
         for key, value in schema.items():
-            if key == '$defs':
+            if key == "$defs":
                 continue
             elif isinstance(value, dict):
                 result[key] = self._remove_refs_from_schema(value, schema_defs)
@@ -1668,10 +1650,10 @@ class Filter:
                 result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value]
             else:
                 result[key] = value

-        if result.get('type') == 'object' and 'properties' in result:
-            result['required'] = list(result['properties'].keys())
+        if result.get("type") == "object" and "properties" in result:
+            result["required"] = list(result["properties"].keys())

         return result

     async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
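Note: _remove_refs_from_schema, reformatted in the two hunks above, inlines the $defs/$ref indirection that Pydantic's model_json_schema() emits, since strict structured-output backends (the docstring names Azure OpenAI) may reject schemas containing $ref. A standalone sketch of the same idea, written independently of the file (names are illustrative):

    def inline_refs(schema: dict, defs: dict) -> dict:
        """Recursively replace {"$ref": "#/$defs/Name"} nodes with their definitions."""
        if not isinstance(schema, dict):
            return schema
        ref = schema.get("$ref", "")
        if isinstance(ref, str) and ref.startswith("#/$defs/"):
            name = ref.split("/")[-1]
            # Fall back to a bare object when the definition is missing.
            return inline_refs(defs.get(name, {"type": "object"}).copy(), defs)
        out = {}
        for key, value in schema.items():
            if key == "$defs":
                continue
            if isinstance(value, dict):
                out[key] = inline_refs(value, defs)
            elif isinstance(value, list):
                out[key] = [inline_refs(v, defs) if isinstance(v, dict) else v for v in value]
            else:
                out[key] = value
        return out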
@@ -1685,16 +1667,16 @@ class Filter:

         form_data = {
             "model": model_to_use,
             "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
             "max_tokens": 4096,
             "stream": False,
         }

         if response_model:
             raw_schema = response_model.model_json_schema()
-            schema_defs = raw_schema.get('$defs', {})
+            schema_defs = raw_schema.get("$defs", {})
             schema = self._remove_refs_from_schema(raw_schema, schema_defs)
-            schema['type'] = 'object'
+            schema["type"] = "object"
             form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}

         try:
@@ -1718,7 +1700,12 @@ class Filter:

         if isinstance(response_data, dict) and "choices" in response_data and isinstance(response_data["choices"], list) and len(response_data["choices"]) > 0:
             first_choice = response_data["choices"][0]
-            if isinstance(first_choice, dict) and "message" in first_choice and isinstance(first_choice["message"], dict) and "content" in first_choice["message"]:
+            if (
+                isinstance(first_choice, dict)
+                and "message" in first_choice
+                and isinstance(first_choice["message"], dict)
+                and "content" in first_choice["message"]
+            ):
                 content = first_choice["message"]["content"]
             else:
                 raise ValueError("🤖 Invalid response structure: missing content in message")
@@ -1726,7 +1713,7 @@ class Filter:
             raise ValueError(f"🤖 Unexpected LLM response format: {response_data}")

         if response_model:
             try:
                 parsed_data = json.loads(content)
                 return response_model.model_validate(parsed_data)
             except json.JSONDecodeError as e:
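Note: the closing hunks show both halves of the structured-output round trip: the request attaches a response_format built from the Pydantic model's JSON schema, and the reply content is parsed back with json.loads plus model_validate. A minimal sketch of the validation half under the same assumptions (the Ops model is a stand-in):

    import json
    from typing import List
    from pydantic import BaseModel, Field, ValidationError

    class Ops(BaseModel):
        ids: List[str] = Field(default_factory=list)

    def parse_llm_json(content: str) -> Ops:
        # JSONDecodeError covers malformed JSON; ValidationError covers
        # well-formed JSON that does not match the schema.
        try:
            return Ops.model_validate(json.loads(content))
        except (json.JSONDecodeError, ValidationError):
            return Ops()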