Merge pull request #3 from GlisseManTV/dev

Adds a way to use the current chat model for memory operations instead of a dedicated model.
M. Tayfur
2025-10-27 23:39:06 +03:00
committed by GitHub
3 changed files with 294 additions and 72 deletions
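
The substance of the PR is how the filter picks the model for its internal LLM calls: instead of always using the dedicated model valve, it now follows the chat's current model and only switches to a dedicated one when the new use_custom_model_for_memory valve is enabled. Below is a minimal sketch of the resolution order added to inlet() and outlet() further down; the function name and the example model strings are illustrative, not part of the diff.

from typing import Optional


def resolve_memory_model(
    body_model: Optional[str],          # body["model"] from the incoming request
    pipeline_model: Optional[str],      # the __model__ argument / request.state.model
    use_custom_model_for_memory: bool,  # new valve
    custom_memory_model: str,           # new valve
    default_model: str = "default-model",  # stand-in for Constants.DEFAULT_LLM_MODEL
) -> str:
    # 1) prefer the model named in the request body,
    # 2) then the pipeline's __model__ / request state,
    # 3) then the hard default.
    model = body_model or pipeline_model or default_model
    # The new valves override everything when a dedicated memory model is wanted.
    if use_custom_model_for_memory and custom_memory_model:
        model = custom_memory_model
    return model


# Follow the chat's current model (the new default behaviour):
assert resolve_memory_model("llama3.1", None, False, "") == "llama3.1"
# Or pin memory operations to a dedicated model:
assert resolve_memory_model("llama3.1", None, True, "qwen2.5:7b") == "qwen2.5:7b"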

.gitignore (vendored): 2 lines changed

@@ -1,4 +1,4 @@
 __pycache__/
-.github/instructions/*
 .venv/
+**AGENTS.md
 tests/

Memory System filter (Python source)

@@ -1,6 +1,9 @@
 """
 title: Memory System
+description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
 version: 1.0.0
+authors: https://github.com/mtayfur
+license: Apache-2.0
 """
 import asyncio
@@ -217,7 +220,10 @@ class Models:
         if self.operation == Models.MemoryOperationType.CREATE:
             return True
-        elif self.operation in [Models.MemoryOperationType.UPDATE, Models.MemoryOperationType.DELETE]:
+        elif self.operation in [
+            Models.MemoryOperationType.UPDATE,
+            Models.MemoryOperationType.DELETE,
+        ]:
             return self.id in existing_memory_ids
         return False
@@ -229,7 +235,10 @@ class Models:
     class MemoryRerankingResponse(BaseModel):
         """Pydantic model for memory reranking LLM response - object containing array of memory IDs."""

-        ids: List[str] = Field(default_factory=list, description="List of memory IDs selected as most relevant for the user query")
+        ids: List[str] = Field(
+            default_factory=list,
+            description="List of memory IDs selected as most relevant for the user query",
+        )

 class UnifiedCacheManager:
@@ -440,7 +449,10 @@ class SkipDetector:
         SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
     }

-    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
+    def __init__(
+        self,
+        embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]],
+    ):
         """Initialize the skip detector with an embedding function and compute reference embeddings."""
         self.embedding_function = embedding_function
         self._reference_embeddings = None
@@ -583,7 +595,16 @@ class SkipDetector:
         if markup_in_lines / len(non_empty_lines) > 0.3:
             return self.SkipReason.SKIP_TECHNICAL.value
         elif structured_lines / len(non_empty_lines) > 0.6:
-            technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
+            technical_keywords = [
+                "function",
+                "class",
+                "import",
+                "return",
+                "const",
+                "var",
+                "let",
+                "def",
+            ]
             if any(keyword in message.lower() for keyword in technical_keywords):
                 return self.SkipReason.SKIP_TECHNICAL.value
@@ -594,7 +615,18 @@ class SkipDetector:
         if non_empty_lines:
             indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
             if indented_lines / len(non_empty_lines) > 0.5:
-                code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
+                code_indicators = [
+                    "def ",
+                    "class ",
+                    "function ",
+                    "return ",
+                    "import ",
+                    "const ",
+                    "let ",
+                    "var ",
+                    "public ",
+                    "private ",
+                ]
                 if any(indicator in message.lower() for indicator in code_indicators):
                     return self.SkipReason.SKIP_TECHNICAL.value
@@ -637,11 +669,31 @@ class SkipDetector:
         max_conversational_similarity = float(conversational_similarities.max())

         skip_categories = [
-            ("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
-            ("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
-            ("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
-            ("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
-            ("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
+            (
+                "instruction",
+                self.SkipReason.SKIP_INSTRUCTION,
+                self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "translation",
+                self.SkipReason.SKIP_TRANSLATION,
+                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "grammar",
+                self.SkipReason.SKIP_GRAMMAR_PROOFREAD,
+                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "technical",
+                self.SkipReason.SKIP_TECHNICAL,
+                self.TECHNICAL_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "pure_math",
+                self.SkipReason.SKIP_PURE_MATH,
+                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
+            ),
         ]

         qualifying_categories = []
@@ -680,11 +732,23 @@ class LLMRerankingService:
         llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier)
         if len(memories) > llm_trigger_threshold:
-            return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"
+            return (
+                True,
+                f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold",
+            )

-        return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}"
+        return (
+            False,
+            f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}",
+        )

-    async def _llm_select_memories(self, user_message: str, candidate_memories: List[Dict], max_count: int, emitter: Optional[Callable] = None) -> List[Dict]:
+    async def _llm_select_memories(
+        self,
+        user_message: str,
+        candidate_memories: List[Dict],
+        max_count: int,
+        emitter: Optional[Callable] = None,
+    ) -> List[Dict]:
         """Use LLM to select most relevant memories."""
         memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
         memory_context = "\n".join(memory_lines)
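
For a concrete sense of the trigger logic in the hunk above: the threshold is max_memories_returned scaled by llm_reranking_trigger_multiplier, so the LLM pass only runs when the semantic search yields noticeably more candidates than will actually be injected. A worked example with assumed valve values (the real ones come from the valves and Constants):

max_memories_returned = 10              # assumed valve value
llm_reranking_trigger_multiplier = 1.5  # assumed valve value
llm_trigger_threshold = int(max_memories_returned * llm_reranking_trigger_multiplier)

print(llm_trigger_threshold)       # 15
print(16 > llm_trigger_threshold)  # True: 16 candidates -> LLM reranking
print(12 > llm_trigger_threshold)  # False: 12 candidates -> keep semantic order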
@@ -697,7 +761,11 @@ CANDIDATE MEMORIES:
 {memory_context}"""

         try:
-            response = await self.memory_system._query_llm(Prompts.MEMORY_RERANKING, user_prompt, response_model=Models.MemoryRerankingResponse)
+            response = await self.memory_system._query_llm(
+                Prompts.MEMORY_RERANKING,
+                user_prompt,
+                response_model=Models.MemoryRerankingResponse,
+            )

             selected_memories = []
             for memory in candidate_memories:
@@ -712,18 +780,31 @@ CANDIDATE MEMORIES:
             logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
             return candidate_memories

-    async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
+    async def rerank_memories(
+        self,
+        user_message: str,
+        candidate_memories: List[Dict],
+        emitter: Optional[Callable] = None,
+    ) -> Tuple[List[Dict], Dict[str, Any]]:
         start_time = time.time()
         max_injection = self.memory_system.valves.max_memories_returned

         should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories)
-        analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}
+        analysis_info = {
+            "llm_decision": should_use_llm,
+            "decision_reason": decision_reason,
+            "candidate_count": len(candidate_memories),
+        }

         if should_use_llm:
             extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
             llm_candidates = candidate_memories[:extended_count]
-            await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
+            await self.memory_system._emit_status(
+                emitter,
+                f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance",
+                done=False,
+            )
             logger.info(f"Using LLM reranking: {decision_reason}")

             selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)
@@ -739,7 +820,11 @@ CANDIDATE MEMORIES:
         duration = time.time() - start_time
         duration_text = f" in {duration:.2f}s" if duration >= 0.01 else ""
         retrieval_method = "LLM" if should_use_llm else "Semantic"
-        await self.memory_system._emit_status(emitter, f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}", done=True)
+        await self.memory_system._emit_status(
+            emitter,
+            f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}",
+            done=True,
+        )

         return selected_memories, analysis_info
@@ -761,7 +846,10 @@ class LLMConsolidationService:
         return candidates, threshold_info

     async def collect_consolidation_candidates(
-        self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None
+        self,
+        user_message: str,
+        user_id: str,
+        cached_similarities: Optional[List[Dict[str, Any]]] = None,
     ) -> List[Dict[str, Any]]:
         """Collect candidate memories for consolidation analysis using cached or computed similarities."""
         if cached_similarities:
@@ -805,7 +893,10 @@ class LLMConsolidationService:
         return candidates

     async def generate_consolidation_plan(
-        self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
+        self,
+        user_message: str,
+        candidate_memories: List[Dict[str, Any]],
+        emitter: Optional[Callable] = None,
     ) -> List[Dict[str, Any]]:
         """Generate consolidation plan using LLM with clear system/user prompt separation."""
         if candidate_memories:
@@ -820,7 +911,11 @@ class LLMConsolidationService:
         try:
             response = await asyncio.wait_for(
-                self.memory_system._query_llm(Prompts.MEMORY_CONSOLIDATION, user_prompt, response_model=Models.ConsolidationResponse),
+                self.memory_system._query_llm(
+                    Prompts.MEMORY_CONSOLIDATION,
+                    user_prompt,
+                    response_model=Models.ConsolidationResponse,
+                ),
                 timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
             )
         except Exception as e:
@@ -856,13 +951,21 @@ class LLMConsolidationService:
         return valid_operations

-    async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
+    async def execute_memory_operations(
+        self,
+        operations: List[Dict[str, Any]],
+        user_id: str,
+        emitter: Optional[Callable] = None,
+    ) -> Tuple[int, int, int, int]:
         """Execute consolidation operations with simplified tracking."""
         if not operations or not user_id:
             return 0, 0, 0, 0

         try:
-            user = await asyncio.wait_for(asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC)
+            user = await asyncio.wait_for(
+                asyncio.to_thread(Users.get_user_by_id, user_id),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
         except asyncio.TimeoutError:
             raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
         except Exception as e:
@@ -929,7 +1032,10 @@ class LLMConsolidationService:
                     if content_preview and content_preview != operation.id:
                         content_preview = self.memory_system._truncate_content(content_preview)
                         await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
-                elif result in [Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value]:
+                elif result in [
+                    Models.OperationResult.FAILED.value,
+                    Models.OperationResult.UNSUPPORTED.value,
+                ]:
                     failed_count += 1
                     await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
             except Exception as e:
@@ -950,7 +1056,11 @@ class LLMConsolidationService:
         return created_count, updated_count, deleted_count, failed_count

     async def run_consolidation_pipeline(
-        self, user_message: str, user_id: str, emitter: Optional[Callable] = None, cached_similarities: Optional[List[Dict[str, Any]]] = None
+        self,
+        user_message: str,
+        user_id: str,
+        emitter: Optional[Callable] = None,
+        cached_similarities: Optional[List[Dict[str, Any]]] = None,
     ) -> None:
         """Complete consolidation pipeline with simplified flow."""
         start_time = time.time()
@@ -974,7 +1084,11 @@ class LLMConsolidationService:
         total_operations = created_count + updated_count + deleted_count
         if total_operations > 0 or failed_count > 0:
-            await self.memory_system._emit_status(emitter, f"💾 Memory Consolidation Complete in {duration:.2f}s", done=False)
+            await self.memory_system._emit_status(
+                emitter,
+                f"💾 Memory Consolidation Complete in {duration:.2f}s",
+                done=False,
+            )
             operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count)
             memory_word = "Memory" if total_operations == 1 else "Memories"
@@ -1001,22 +1115,41 @@ class Filter:
     class Valves(BaseModel):
         """Configuration valves for the Memory System."""

-        model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
-        max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
-        max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
+        model: str = Field(
+            default=Constants.DEFAULT_LLM_MODEL,
+            description="Model name for LLM operations",
+        )
+        use_custom_model_for_memory: bool = Field(
+            default=False,
+            description="Use a custom model for memory operations instead of the current chat model",
+        )
+        custom_memory_model: str = Field(
+            default=Constants.DEFAULT_LLM_MODEL,
+            description="Custom model to use for memory operations when enabled",
+        )
+        max_memories_returned: int = Field(
+            default=Constants.MAX_MEMORIES_PER_RETRIEVAL,
+            description="Maximum number of memories to return in context",
+        )
+        max_message_chars: int = Field(
+            default=Constants.MAX_MESSAGE_CHARS,
+            description="Maximum user message length before skipping memory operations",
+        )
         semantic_retrieval_threshold: float = Field(
-            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
+            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD,
+            description="Minimum similarity threshold for memory retrieval",
         )
         relaxed_semantic_threshold_multiplier: float = Field(
             default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
             description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
         )
-        enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
+        enable_llm_reranking: bool = Field(
+            default=True,
+            description="Enable LLM-based memory reranking for improved contextual selection",
+        )
         llm_reranking_trigger_multiplier: float = Field(
-            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
+            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER,
+            description="Controls when LLM reranking activates (lower = more aggressive)",
         )

         def __init__(self):
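
The two new valves above are the whole configuration surface of this PR: with the default use_custom_model_for_memory=False the filter follows the chat's current model, and enabling it pins memory operations to custom_memory_model. A standalone sketch of just these two fields, detached from the Filter class so it runs on its own (the class name and the example model string are illustrative):

from pydantic import BaseModel, Field


class MemoryModelValves(BaseModel):
    use_custom_model_for_memory: bool = Field(
        default=False,
        description="Use a custom model for memory operations instead of the current chat model",
    )
    custom_memory_model: str = Field(
        default="default-model",  # stands in for Constants.DEFAULT_LLM_MODEL
        description="Custom model to use for memory operations when enabled",
    )


print(MemoryModelValves())  # defaults: memory operations follow the chat model
print(MemoryModelValves(use_custom_model_for_memory=True, custom_memory_model="qwen2.5:7b"))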
@@ -1071,7 +1204,9 @@ class Filter:
             logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
             embedding_fn = self._embedding_function

-            def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+            def embedding_wrapper(
+                texts: Union[str, List[str]],
+            ) -> Union[np.ndarray, List[np.ndarray]]:
                 result = embedding_fn(texts, prefix=None, user=None)
                 if isinstance(result, list):
                     if isinstance(result[0], list):
@@ -1219,7 +1354,11 @@ class Filter:
     def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
         """Extract user message and determine if memory operations should be skipped."""
         if not body or "messages" not in body or not isinstance(body["messages"], list):
-            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
+            return (
+                None,
+                True,
+                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
+            )

         messages = body["messages"]
         user_message = None
@@ -1235,7 +1374,11 @@ class Filter:
                 break

         if not user_message or not user_message.strip():
-            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
+            return (
+                None,
+                True,
+                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
+            )

         should_skip, skip_reason = self._should_skip_memory_operations(user_message)
         return user_message, should_skip, skip_reason
@@ -1245,7 +1388,10 @@ class Filter:
         if timeout is None:
             timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
         try:
-            return await asyncio.wait_for(asyncio.to_thread(Memories.get_memories_by_user_id, user_id), timeout=timeout)
+            return await asyncio.wait_for(
+                asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
+                timeout=timeout,
+            )
         except asyncio.TimeoutError:
             raise TimeoutError(f"⏱️ Memory retrieval timed out after {timeout}s")
         except Exception as e:
@@ -1274,7 +1420,11 @@ class Filter:
         logger.info(f"Scores: [{scores_str}{suffix}]")

     def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]:
-        operations = [(created_count, "📝 Created"), (updated_count, "✏️ Updated"), (deleted_count, "🗑️ Deleted")]
+        operations = [
+            (created_count, "📝 Created"),
+            (updated_count, "✏️ Updated"),
+            (deleted_count, "🗑️ Deleted"),
+        ]
         return [f"{label} {count}" for count, label in operations if count > 0]

     def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
@@ -1366,10 +1516,19 @@ class Filter:
         self._log_retrieved_memories(final_memories, "semantic")

-        return {"memories": final_memories, "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info}
+        return {
+            "memories": final_memories,
+            "threshold": threshold,
+            "all_similarities": all_similarities,
+            "reranking_info": reranking_info,
+        }

     async def _add_memory_context(
-        self, body: Dict[str, Any], memories: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None, emitter: Optional[Callable] = None
+        self,
+        body: Dict[str, Any],
+        memories: Optional[List[Dict[str, Any]]] = None,
+        user_id: Optional[str] = None,
+        emitter: Optional[Callable] = None,
     ) -> None:
         """Add memory context to request body with simplified logic."""
         if not body or "messages" not in body or not body["messages"]:
@@ -1397,7 +1556,10 @@ class Filter:
         memory_context = "\n\n".join(content_parts)

-        system_index = next((i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), None)
+        system_index = next(
+            (i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"),
+            None,
+        )

         if system_index is not None:
             body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
@@ -1410,7 +1572,11 @@ class Filter:
     def _build_memory_dict(self, memory, similarity: float) -> Dict[str, Any]:
         """Build memory dictionary with standardized timestamp conversion."""
-        memory_dict = {"id": str(memory.id), "content": memory.content, "relevance": similarity}
+        memory_dict = {
+            "id": str(memory.id),
+            "content": memory.content,
+            "relevance": similarity,
+        }
         if hasattr(memory, "created_at") and memory.created_at:
             memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat()
         if hasattr(memory, "updated_at") and memory.updated_at:
@@ -1462,42 +1628,58 @@ class Filter:
         **kwargs,
     ) -> Dict[str, Any]:
         """Simplified inlet processing for memory retrieval and injection."""
-        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        model_to_use = body.get("model") if isinstance(body, dict) else None
+        if not model_to_use:
+            model_to_use = __model__ or getattr(__request__.state, "model", None)
+        if not model_to_use:
+            model_to_use = Constants.DEFAULT_LLM_MODEL
+            logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
+        if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
+            model_to_use = self.valves.custom_memory_model
+            logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
+        self.valves.model = model_to_use
+        await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)

         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body

         user_message, should_skip, skip_reason = self._process_user_message(body)

         if not user_message or should_skip:
             if __event_emitter__ and skip_reason:
                 await self._emit_status(__event_emitter__, skip_reason, done=True)
             await self._add_memory_context(body, [], user_id, __event_emitter__)
             return body

         try:
             memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
             user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key)

             if user_memories is None:
                 user_memories = await self._get_user_memories(user_id)
-                await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
+                await self._cache_manager.put(
+                    user_id,
+                    self._cache_manager.MEMORY_CACHE,
+                    memory_cache_key,
+                    user_memories,
+                )

             retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
             memories = retrieval_result.get("memories", [])
             threshold = retrieval_result.get("threshold")
             all_similarities = retrieval_result.get("all_similarities", [])

             if all_similarities:
                 cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
-                await self._cache_manager.put(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key, all_similarities)
+                await self._cache_manager.put(
+                    user_id,
+                    self._cache_manager.RETRIEVAL_CACHE,
+                    cache_key,
+                    all_similarities,
+                )

             await self._add_memory_context(body, memories, user_id, __event_emitter__)
         except Exception as e:
             raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}")

         return body

     async def outlet(
@@ -1510,20 +1692,30 @@
         **kwargs,
     ) -> dict:
         """Simplified outlet processing for background memory consolidation."""
-        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        model_to_use = body.get("model") if isinstance(body, dict) else None
+        if not model_to_use:
+            model_to_use = __model__ or getattr(__request__.state, "model", None)
+        if not model_to_use:
+            model_to_use = Constants.DEFAULT_LLM_MODEL
+            logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
+        if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
+            model_to_use = self.valves.custom_memory_model
+            logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
+        self.valves.model = model_to_use
+        await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)

         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body

         user_message, should_skip, skip_reason = self._process_user_message(body)
         if not user_message or should_skip:
             return body

         cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
         cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)

         task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
         self._background_tasks.add(task)
@@ -1537,7 +1729,6 @@ class Filter:
                 logger.error(f"❌ Failed to cleanup background memory task: {str(e)}")

         task.add_done_callback(safe_cleanup)
-
         return body

     async def shutdown(self) -> None:
@@ -1566,7 +1757,12 @@ class Filter:
             logger.info("📭 No memories found for user")
             return

-        await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
+        await self._cache_manager.put(
+            user_id,
+            self._cache_manager.MEMORY_CACHE,
+            memory_cache_key,
+            user_memories,
+        )

         memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
@@ -1588,7 +1784,8 @@ class Filter:
                 return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value

             await asyncio.wait_for(
-                asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
+                asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.CREATE.value
@@ -1604,7 +1801,12 @@ class Filter:
                 return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value

             await asyncio.wait_for(
-                asyncio.to_thread(Memories.update_memory_by_id_and_user_id, id_stripped, user.id, content_stripped),
+                asyncio.to_thread(
+                    Memories.update_memory_by_id_and_user_id,
+                    id_stripped,
+                    user.id,
+                    content_stripped,
+                ),
                 timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.UPDATE.value
@@ -1616,7 +1818,8 @@ class Filter:
                 return Models.OperationResult.SKIPPED_EMPTY_ID.value

             await asyncio.wait_for(
-                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
+                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.DELETE.value
         else:
@@ -1647,7 +1850,7 @@ class Filter:
             elif isinstance(value, dict):
                 result[key] = self._remove_refs_from_schema(value, schema_defs)
             elif isinstance(value, list):
-                result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value]
+                result[key] = [(self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item) for item in value]
             else:
                 result[key] = value
@@ -1656,7 +1859,12 @@ class Filter:
         return result

-    async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
+    async def _query_llm(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        response_model: Optional[BaseModel] = None,
+    ) -> Union[str, BaseModel]:
         """Query OpenWebUI's internal model system with Pydantic model parsing."""
         if not hasattr(self, "__request__") or not hasattr(self, "__user__"):
             raise RuntimeError("🔧 Pipeline interface not properly initialized. __request__ and __user__ required.")
@@ -1667,7 +1875,10 @@ class Filter:
         form_data = {
             "model": model_to_use,
-            "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
             "max_tokens": 4096,
             "stream": False,
         }
@@ -1677,11 +1888,22 @@ class Filter:
             schema_defs = raw_schema.get("$defs", {})
             schema = self._remove_refs_from_schema(raw_schema, schema_defs)
             schema["type"] = "object"
-            form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
+            form_data["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_model.__name__,
+                    "strict": True,
+                    "schema": schema,
+                },
+            }

         try:
             response = await asyncio.wait_for(
-                generate_chat_completion(self.__request__, form_data, user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"])),
+                generate_chat_completion(
+                    self.__request__,
+                    form_data,
+                    user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"]),
+                ),
                 timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
             )
         except asyncio.TimeoutError:
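
When a response_model is supplied, the hunk above requests structured output by attaching a JSON schema derived from the Pydantic model, with $ref/$defs flattened by _remove_refs_from_schema. A rough, self-contained illustration of the payload shape, using a stand-in for Models.MemoryRerankingResponse and assuming Pydantic v2's model_json_schema():

from typing import List
from pydantic import BaseModel, Field


class MemoryRerankingResponse(BaseModel):
    ids: List[str] = Field(default_factory=list)


schema = MemoryRerankingResponse.model_json_schema()  # the filter additionally inlines any $defs
response_format = {
    "type": "json_schema",
    "json_schema": {
        "name": MemoryRerankingResponse.__name__,
        "strict": True,
        "schema": schema,
    },
}
print(response_format["json_schema"]["name"])  # MemoryRerankingResponse
print(sorted(schema["properties"]))            # ['ids']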

Python dependency list

@@ -4,4 +4,4 @@ pydantic
 numpy
 tiktoken
 black
 isort