♻️ (memory_system.py): reformat code for consistency, readability, and maintainability

- Reorder and group imports for clarity and PEP8 compliance.
- Standardize string quoting and whitespace for consistency.
- Refactor long function signatures and dictionary constructions for better readability.
- Use double quotes for all string literals and dictionary keys.
- Improve formatting of multiline statements and function calls.
- Add or adjust line breaks to keep lines within recommended length.
- Reformat class and method docstrings for clarity.
- Use consistent indentation and spacing throughout the file.

These changes improve code readability, maintainability, and consistency, making it easier for future contributors to understand and modify the codebase. No functional logic is changed.
This commit is contained in:
mtayfur
2025-10-27 00:20:35 +03:00
parent 189c6d4226
commit bb1bd01222

View File

@@ -15,18 +15,19 @@ from enum import Enum
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
import numpy as np
from pydantic import BaseModel, ConfigDict, Field, ValidationError as PydanticValidationError
from open_webui.utils.chat import generate_chat_completion
from fastapi import Request
from open_webui.models.users import Users
from open_webui.routers.memories import Memories
from fastapi import Request
from open_webui.utils.chat import generate_chat_completion
from pydantic import BaseModel, ConfigDict, Field
from pydantic import ValidationError as PydanticValidationError
logger = logging.getLogger(__name__)
_SHARED_SKIP_DETECTOR_CACHE = {}
_SHARED_SKIP_DETECTOR_CACHE_LOCK = asyncio.Lock()
class Constants:
"""Centralized configuration constants for the memory system."""
@@ -44,10 +45,10 @@ class Constants:
CACHE_KEY_HASH_PREFIX_LENGTH = 10 # Hash prefix length for cache keys
# Retrieval & Similarity
SEMANTIC_RETRIEVAL_THRESHOLD = 0.25 # Semantic similarity threshold for retrieval
RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8 # Multiplier for relaxed similarity threshold in secondary operations
EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6 # Multiplier for expanding memory candidates in advanced operations
LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8 # Multiplier for LLM reranking trigger threshold
SEMANTIC_RETRIEVAL_THRESHOLD = 0.25 # Semantic similarity threshold for retrieval
RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER = 0.8 # Multiplier for relaxed similarity threshold in secondary operations
EXTENDED_MAX_MEMORY_MULTIPLIER = 1.6 # Multiplier for expanding memory candidates in advanced operations
LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8 # Multiplier for LLM reranking trigger threshold
# Skip Detection
SKIP_CATEGORY_MARGIN = 0.5 # Margin above conversational similarity for skip category classification
@@ -62,6 +63,7 @@ class Constants:
# Default Models
DEFAULT_LLM_MODEL = "google/gemini-2.5-flash-lite"
class Prompts:
"""Container for all LLM prompts used in the memory system."""
@@ -181,6 +183,7 @@ Return: {{"ids": []}}
Explanation: Query seeks general technical explanation without personal context. Job and family information don't affect how quantum computing concepts should be explained.
"""
class Models:
"""Container for all Pydantic models used in the memory system."""
@@ -203,7 +206,7 @@ class Models:
class MemoryOperation(StrictModel):
"""Pydantic model for memory operations with validation."""
operation: 'Models.MemoryOperationType' = Field(description="Type of memory operation to perform")
operation: "Models.MemoryOperationType" = Field(description="Type of memory operation to perform")
content: str = Field(description="Memory content (required for CREATE/UPDATE, empty for DELETE)")
id: str = Field(description="Memory ID (empty for CREATE, required for UPDATE/DELETE)")
@@ -221,7 +224,7 @@ class Models:
class ConsolidationResponse(BaseModel):
"""Pydantic model for memory consolidation LLM response - object containing array of memory operations."""
ops: List['Models.MemoryOperation'] = Field(default_factory=list, description="List of memory operations to execute")
ops: List["Models.MemoryOperation"] = Field(default_factory=list, description="List of memory operations to execute")
class MemoryRerankingResponse(BaseModel):
"""Pydantic model for memory reranking LLM response - object containing array of memory IDs."""
@@ -446,48 +449,38 @@ class SkipDetector:
def _initialize_reference_embeddings(self) -> None:
"""Compute and cache embeddings for category descriptions."""
try:
technical_embeddings = self.embedding_function(
self.TECHNICAL_CATEGORY_DESCRIPTIONS
)
technical_embeddings = self.embedding_function(self.TECHNICAL_CATEGORY_DESCRIPTIONS)
instruction_embeddings = self.embedding_function(
self.INSTRUCTION_CATEGORY_DESCRIPTIONS
)
instruction_embeddings = self.embedding_function(self.INSTRUCTION_CATEGORY_DESCRIPTIONS)
pure_math_embeddings = self.embedding_function(
self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS
)
pure_math_embeddings = self.embedding_function(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS)
translation_embeddings = self.embedding_function(
self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS
)
translation_embeddings = self.embedding_function(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS)
grammar_embeddings = self.embedding_function(
self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS
)
grammar_embeddings = self.embedding_function(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
conversational_embeddings = self.embedding_function(
self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS
)
conversational_embeddings = self.embedding_function(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)
self._reference_embeddings = {
'technical': np.array(technical_embeddings),
'instruction': np.array(instruction_embeddings),
'pure_math': np.array(pure_math_embeddings),
'translation': np.array(translation_embeddings),
'grammar': np.array(grammar_embeddings),
'conversational': np.array(conversational_embeddings),
"technical": np.array(technical_embeddings),
"instruction": np.array(instruction_embeddings),
"pure_math": np.array(pure_math_embeddings),
"translation": np.array(translation_embeddings),
"grammar": np.array(grammar_embeddings),
"conversational": np.array(conversational_embeddings),
}
total_skip_categories = (
len(self.TECHNICAL_CATEGORY_DESCRIPTIONS) +
len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS) +
len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS) +
len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS) +
len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
len(self.TECHNICAL_CATEGORY_DESCRIPTIONS)
+ len(self.INSTRUCTION_CATEGORY_DESCRIPTIONS)
+ len(self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS)
+ len(self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS)
+ len(self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS)
)
logger.info(f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories")
logger.info(
f"SkipDetector initialized with {total_skip_categories} skip categories and {len(self.CONVERSATIONAL_CATEGORY_DESCRIPTIONS)} personal categories"
)
except Exception as e:
logger.error(f"Failed to initialize SkipDetector reference embeddings: {e}")
self._reference_embeddings = None
@@ -506,7 +499,7 @@ class SkipDetector:
msg_len = len(message)
# Pattern 1: Multiple URLs (5+ full URLs indicates link lists or technical references)
url_pattern_count = message.count('http://') + message.count('https://')
url_pattern_count = message.count("http://") + message.count("https://")
if url_pattern_count >= 5:
return self.SkipReason.SKIP_TECHNICAL.value
@@ -514,95 +507,94 @@ class SkipDetector:
words = message.split()
for word in words:
cleaned = word.strip('.,;:!?()[]{}"\'"')
if len(cleaned) > 80 and cleaned.replace('-', '').replace('_', '').isalnum():
if len(cleaned) > 80 and cleaned.replace("-", "").replace("_", "").isalnum():
return self.SkipReason.SKIP_TECHNICAL.value
# Pattern 3: Markdown/text separators (repeated ---, ===, ___, ***)
separator_patterns = ['---', '===', '___', '***']
separator_patterns = ["---", "===", "___", "***"]
for pattern in separator_patterns:
if message.count(pattern) >= 2:
return self.SkipReason.SKIP_TECHNICAL.value
# Pattern 4: Command-line patterns with context-aware detection
lines_stripped = [line.strip() for line in message.split('\n') if line.strip()]
lines_stripped = [line.strip() for line in message.split("\n") if line.strip()]
if lines_stripped:
actual_command_lines = 0
for line in lines_stripped:
if line.startswith('$ ') and len(line) > 2:
if line.startswith("$ ") and len(line) > 2:
parts = line[2:].split()
if parts and parts[0].isalnum():
actual_command_lines += 1
elif '$ ' in line:
dollar_index = line.find('$ ')
if dollar_index > 0 and line[dollar_index-1] in (' ', ':', '\t'):
parts = line[dollar_index+2:].split()
if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ['curl', 'wget', 'git', 'npm', 'pip', 'docker']):
elif "$ " in line:
dollar_index = line.find("$ ")
if dollar_index > 0 and line[dollar_index - 1] in (" ", ":", "\t"):
parts = line[dollar_index + 2 :].split()
if parts and len(parts[0]) > 0 and (parts[0].isalnum() or parts[0] in ["curl", "wget", "git", "npm", "pip", "docker"]):
actual_command_lines += 1
elif line.startswith('# ') and len(line) > 2:
elif line.startswith("# ") and len(line) > 2:
rest = line[2:].strip()
if rest and not rest[0].isupper() and ' ' in rest:
if rest and not rest[0].isupper() and " " in rest:
actual_command_lines += 1
elif line.startswith('> ') and len(line) > 2:
elif line.startswith("> ") and len(line) > 2:
pass
if actual_command_lines >= 1 and any(c in message for c in ['http://', 'https://', ' | ']):
if actual_command_lines >= 1 and any(c in message for c in ["http://", "https://", " | "]):
return self.SkipReason.SKIP_TECHNICAL.value
if actual_command_lines >= 3:
return self.SkipReason.SKIP_TECHNICAL.value
# Pattern 5: High path/URL density (dots and slashes suggesting file paths or URLs)
if msg_len > 30:
slash_count = message.count('/') + message.count('\\')
dot_count = message.count('.')
slash_count = message.count("/") + message.count("\\")
dot_count = message.count(".")
path_chars = slash_count + dot_count
if path_chars > 10 and (path_chars / msg_len) > 0.15:
return self.SkipReason.SKIP_TECHNICAL.value
# Pattern 6: Markup character density (structured data)
markup_chars = sum(message.count(c) for c in '{}[]<>')
markup_chars = sum(message.count(c) for c in "{}[]<>")
if markup_chars >= 6:
if markup_chars / msg_len > 0.10:
return self.SkipReason.SKIP_TECHNICAL.value
curly_count = message.count('{') + message.count('}')
curly_count = message.count("{") + message.count("}")
if curly_count >= 10:
return self.SkipReason.SKIP_TECHNICAL.value
# Pattern 7: Structured nested content with colons (key: value patterns)
line_count = message.count('\n')
line_count = message.count("\n")
if line_count >= 8:
lines = message.split('\n')
lines = message.split("\n")
non_empty_lines = [line for line in lines if line.strip()]
if non_empty_lines:
colon_lines = sum(1 for line in non_empty_lines if ':' in line and not line.strip().startswith('#'))
indented_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
colon_lines = sum(1 for line in non_empty_lines if ":" in line and not line.strip().startswith("#"))
indented_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t")))
if (colon_lines / len(non_empty_lines) > 0.4 and
indented_lines / len(non_empty_lines) > 0.5):
if colon_lines / len(non_empty_lines) > 0.4 and indented_lines / len(non_empty_lines) > 0.5:
return self.SkipReason.SKIP_TECHNICAL.value
# Pattern 8: Highly structured multi-line content (require markup chars for technical confidence)
if line_count > 15:
lines = message.split('\n')
lines = message.split("\n")
non_empty_lines = [line for line in lines if line.strip()]
if non_empty_lines:
markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in '{}[]<>'))
structured_lines = sum(1 for line in non_empty_lines if line.startswith((' ', '\t')))
markup_in_lines = sum(1 for line in non_empty_lines if any(c in line for c in "{}[]<>"))
structured_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t")))
if markup_in_lines / len(non_empty_lines) > 0.3:
return self.SkipReason.SKIP_TECHNICAL.value
elif structured_lines / len(non_empty_lines) > 0.6:
technical_keywords = ['function', 'class', 'import', 'return', 'const', 'var', 'let', 'def']
technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
if any(keyword in message.lower() for keyword in technical_keywords):
return self.SkipReason.SKIP_TECHNICAL.value
# Pattern 9: Code-like indentation pattern (require code indicators to avoid false positives from bullet lists)
if line_count >= 3:
lines = message.split('\n')
lines = message.split("\n")
non_empty_lines = [line for line in lines if line.strip()]
if non_empty_lines:
indented_lines = sum(1 for line in non_empty_lines if line[0] in (' ', '\t'))
indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
if indented_lines / len(non_empty_lines) > 0.5:
code_indicators = ['def ', 'class ', 'function ', 'return ', 'import ', 'const ', 'let ', 'var ', 'public ', 'private ']
code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
if any(indicator in message.lower() for indicator in code_indicators):
return self.SkipReason.SKIP_TECHNICAL.value
@@ -617,7 +609,7 @@ class SkipDetector:
return None
def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: 'Filter') -> Optional[str]:
def detect_skip_reason(self, message: str, max_message_chars: int, memory_system: "Filter") -> Optional[str]:
"""
Detect if a message should be skipped using two-stage detection:
1. Fast-path structural patterns (~95% confidence)
@@ -641,28 +633,22 @@ class SkipDetector:
try:
message_embedding = np.array(self.embedding_function([message.strip()])[0])
conversational_similarities = np.dot(
message_embedding,
self._reference_embeddings['conversational'].T
)
conversational_similarities = np.dot(message_embedding, self._reference_embeddings["conversational"].T)
max_conversational_similarity = float(conversational_similarities.max())
skip_categories = [
('instruction', self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
('translation', self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
('grammar', self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
('technical', self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
('pure_math', self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
]
qualifying_categories = []
margin_threshold = max_conversational_similarity + Constants.SKIP_CATEGORY_MARGIN
for cat_key, skip_reason, descriptions in skip_categories:
similarities = np.dot(
message_embedding,
self._reference_embeddings[cat_key].T
)
similarities = np.dot(message_embedding, self._reference_embeddings[cat_key].T)
max_similarity = float(similarities.max())
if max_similarity > margin_threshold:
@@ -670,7 +656,9 @@ class SkipDetector:
if qualifying_categories:
highest_similarity, highest_cat_key, highest_skip_reason = max(qualifying_categories, key=lambda x: x[0])
logger.info(f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})")
logger.info(
f"🚫 Skipping message: {highest_skip_reason.value} (sim {highest_similarity:.3f} > conv {max_conversational_similarity:.3f} + {Constants.SKIP_CATEGORY_MARGIN:.3f})"
)
return highest_skip_reason.value
return None
@@ -724,9 +712,7 @@ CANDIDATE MEMORIES:
logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
return candidate_memories
async def rerank_memories(
self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None
) -> Tuple[List[Dict], Dict[str, Any]]:
async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
start_time = time.time()
max_injection = self.memory_system.valves.max_memories_returned
@@ -737,9 +723,7 @@ CANDIDATE MEMORIES:
if should_use_llm:
extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
llm_candidates = candidate_memories[:extended_count]
await self.memory_system._emit_status(
emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False
)
await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
logger.info(f"Using LLM reranking: {decision_reason}")
selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)
@@ -812,7 +796,7 @@ class LLMConsolidationService:
candidates, threshold_info = self._filter_consolidation_candidates(all_similarities)
else:
candidates = []
threshold_info = 'N/A'
threshold_info = "N/A"
logger.info(f"🎯 Found {len(candidates)} candidate memories for consolidation (threshold: {threshold_info})")
@@ -820,7 +804,9 @@ class LLMConsolidationService:
return candidates
async def generate_consolidation_plan(self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None) -> List[Dict[str, Any]]:
async def generate_consolidation_plan(
self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
) -> List[Dict[str, Any]]:
"""Generate consolidation plan using LLM with clear system/user prompt separation."""
if candidate_memories:
memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
@@ -1020,11 +1006,18 @@ class Filter:
max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
semantic_retrieval_threshold: float = Field(default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval")
relaxed_semantic_threshold_multiplier: float = Field(default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER, description="Adjusts similarity threshold for memory consolidation (lower = more candidates)")
semantic_retrieval_threshold: float = Field(
default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
)
relaxed_semantic_threshold_multiplier: float = Field(
default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
)
enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
llm_reranking_trigger_multiplier: float = Field(default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)")
llm_reranking_trigger_multiplier: float = Field(
default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
)
def __init__(self):
"""Initialize the Memory System filter with production validation."""
@@ -1043,8 +1036,13 @@ class Filter:
self._llm_reranking_service = LLMRerankingService(self)
self._llm_consolidation_service = LLMConsolidationService(self)
async def _set_pipeline_context(self, __event_emitter__: Optional[Callable] = None, __user__: Optional[Dict[str, Any]] = None,
__model__: Optional[str] = None, __request__: Optional[Request] = None) -> None:
async def _set_pipeline_context(
self,
__event_emitter__: Optional[Callable] = None,
__user__: Optional[Dict[str, Any]] = None,
__model__: Optional[str] = None,
__request__: Optional[Request] = None,
) -> None:
"""Set pipeline context parameters to avoid duplication in inlet/outlet methods."""
if __event_emitter__:
self.__current_event_emitter__ = __event_emitter__
@@ -1055,14 +1053,14 @@ class Filter:
if __request__:
self.__request__ = __request__
if self._embedding_function is None and hasattr(__request__.app.state, 'EMBEDDING_FUNCTION'):
if self._embedding_function is None and hasattr(__request__.app.state, "EMBEDDING_FUNCTION"):
self._embedding_function = __request__.app.state.EMBEDDING_FUNCTION
logger.info(f"✅ Using OpenWebUI's embedding function")
if self._skip_detector is None:
global _SHARED_SKIP_DETECTOR_CACHE, _SHARED_SKIP_DETECTOR_CACHE_LOCK
embedding_engine = getattr(__request__.app.state.config, 'RAG_EMBEDDING_ENGINE', '')
embedding_model = getattr(__request__.app.state.config, 'RAG_EMBEDDING_MODEL', '')
embedding_engine = getattr(__request__.app.state.config, "RAG_EMBEDDING_ENGINE", "")
embedding_model = getattr(__request__.app.state.config, "RAG_EMBEDDING_MODEL", "")
cache_key = f"{embedding_engine}:{embedding_model}"
async with _SHARED_SKIP_DETECTOR_CACHE_LOCK:
@@ -1072,6 +1070,7 @@ class Filter:
else:
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
embedding_fn = self._embedding_function
def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
result = embedding_fn(texts, prefix=None, user=None)
if isinstance(result, list):
@@ -1084,7 +1083,6 @@ class Filter:
_SHARED_SKIP_DETECTOR_CACHE[cache_key] = self._skip_detector
logger.info(f"✅ Skip detector initialized and cached")
def _truncate_content(self, content: str, max_length: Optional[int] = None) -> str:
"""Truncate content with ellipsis if needed."""
if max_length is None:
@@ -1181,16 +1179,10 @@ class Filter:
uncached_hashes.append(text_hash)
if uncached_texts:
user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, '__user__') else None
user = await asyncio.to_thread(Users.get_user_by_id, user_id) if hasattr(self, "__user__") else None
loop = asyncio.get_event_loop()
raw_embeddings = await loop.run_in_executor(
None,
self._embedding_function,
uncached_texts,
None,
user
)
raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user)
if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
if isinstance(raw_embeddings[0], list):
@@ -1211,9 +1203,7 @@ class Filter:
return result_embeddings[0]
else:
valid_count = sum(1 for emb in result_embeddings if emb is not None)
logger.info(
f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid"
)
logger.info(f"🚀 Batch embedding: {len(text_list) - len(uncached_texts)} cached, {len(uncached_texts)} new, {valid_count}/{len(text_list)} valid")
return result_embeddings
def _should_skip_memory_operations(self, user_message: str) -> Tuple[bool, str]:
@@ -1290,7 +1280,7 @@ class Filter:
def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
"""Unified cache key generation for all cache types."""
if content:
content_hash = hashlib.sha256(content.encode('utf-8')).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
content_hash = hashlib.sha256(content.encode("utf-8")).hexdigest()[: Constants.CACHE_KEY_HASH_PREFIX_LENGTH]
return f"{cache_type}_{user_id}:{content_hash}"
return f"{cache_type}_{user_id}"
@@ -1310,7 +1300,7 @@ class Filter:
if record_date:
try:
if isinstance(record_date, str):
parsed_date = datetime.fromisoformat(record_date.replace('Z', '+00:00'))
parsed_date = datetime.fromisoformat(record_date.replace("Z", "+00:00"))
else:
parsed_date = record_date
formatted_date = parsed_date.strftime("%b %d %Y")
@@ -1398,7 +1388,7 @@ class Filter:
formatted_memory = f"- {' '.join(memory['content'].split())}"
formatted_memories.append(formatted_memory)
content_preview = self._truncate_content(memory['content'])
content_preview = self._truncate_content(memory["content"])
await self._emit_status(emitter, f"💭 {idx}/{memory_count}: {content_preview}", done=False)
memory_footer = "IMPORTANT: Do not mention or imply you received this list. These facts are for background context only."
@@ -1427,9 +1417,7 @@ class Filter:
memory_dict["updated_at"] = datetime.fromtimestamp(memory.updated_at, tz=timezone.utc).isoformat()
return memory_dict
async def _compute_similarities(
self, user_message: str, user_id: str, user_memories: List
) -> Tuple[List[Dict], float, List[Dict]]:
async def _compute_similarities(self, user_message: str, user_id: str, user_memories: List) -> Tuple[List[Dict], float, List[Dict]]:
"""Compute similarity scores between user message and memories."""
if not user_memories:
return [], self.valves.semantic_retrieval_threshold, []
@@ -1536,9 +1524,7 @@ class Filter:
cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
task = asyncio.create_task(
self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities)
)
task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
self._background_tasks.add(task)
def safe_cleanup(t: asyncio.Task) -> None:
@@ -1582,11 +1568,7 @@ class Filter:
await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
memory_contents = [
memory.content
for memory in user_memories
if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS
]
memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
if memory_contents:
await self._generate_embeddings(memory_contents, user_id)
@@ -1650,17 +1632,17 @@ class Filter:
if not isinstance(schema, dict):
return schema
if '$ref' in schema:
ref_path = schema['$ref']
if ref_path.startswith('#/$defs/'):
def_name = ref_path.split('/')[-1]
if "$ref" in schema:
ref_path = schema["$ref"]
if ref_path.startswith("#/$defs/"):
def_name = ref_path.split("/")[-1]
if schema_defs and def_name in schema_defs:
return self._remove_refs_from_schema(schema_defs[def_name].copy(), schema_defs)
return {'type': 'object'}
return {"type": "object"}
result = {}
for key, value in schema.items():
if key == '$defs':
if key == "$defs":
continue
elif isinstance(value, dict):
result[key] = self._remove_refs_from_schema(value, schema_defs)
@@ -1669,8 +1651,8 @@ class Filter:
else:
result[key] = value
if result.get('type') == 'object' and 'properties' in result:
result['required'] = list(result['properties'].keys())
if result.get("type") == "object" and "properties" in result:
result["required"] = list(result["properties"].keys())
return result
@@ -1692,9 +1674,9 @@ class Filter:
if response_model:
raw_schema = response_model.model_json_schema()
schema_defs = raw_schema.get('$defs', {})
schema_defs = raw_schema.get("$defs", {})
schema = self._remove_refs_from_schema(raw_schema, schema_defs)
schema['type'] = 'object'
schema["type"] = "object"
form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
try:
@@ -1718,7 +1700,12 @@ class Filter:
if isinstance(response_data, dict) and "choices" in response_data and isinstance(response_data["choices"], list) and len(response_data["choices"]) > 0:
first_choice = response_data["choices"][0]
if isinstance(first_choice, dict) and "message" in first_choice and isinstance(first_choice["message"], dict) and "content" in first_choice["message"]:
if (
isinstance(first_choice, dict)
and "message" in first_choice
and isinstance(first_choice["message"], dict)
and "content" in first_choice["message"]
):
content = first_choice["message"]["content"]
else:
raise ValueError("🤖 Invalid response structure: missing content in message")