mirror of
https://github.com/mtayfur/openwebui-memory-system.git
synced 2026-01-22 06:51:01 +01:00
Merge pull request #3 from GlisseManTV/dev
Way to use current model instead of dedicated model.
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -1,4 +1,4 @@
|
||||
__pycache__/
|
||||
.github/instructions/*
|
||||
.venv/
|
||||
**AGENTS.md
|
||||
tests/
|
||||
362
memory_system.py
362
memory_system.py
@@ -1,6 +1,9 @@
|
||||
"""
|
||||
title: Memory System
|
||||
description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
|
||||
version: 1.0.0
|
||||
authors: https://github.com/mtayfur
|
||||
license: Apache-2.0
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
@@ -217,7 +220,10 @@ class Models:
|
||||
|
||||
if self.operation == Models.MemoryOperationType.CREATE:
|
||||
return True
|
||||
elif self.operation in [Models.MemoryOperationType.UPDATE, Models.MemoryOperationType.DELETE]:
|
||||
elif self.operation in [
|
||||
Models.MemoryOperationType.UPDATE,
|
||||
Models.MemoryOperationType.DELETE,
|
||||
]:
|
||||
return self.id in existing_memory_ids
|
||||
return False
|
||||
|
||||
@@ -229,7 +235,10 @@ class Models:
|
||||
class MemoryRerankingResponse(BaseModel):
|
||||
"""Pydantic model for memory reranking LLM response - object containing array of memory IDs."""
|
||||
|
||||
ids: List[str] = Field(default_factory=list, description="List of memory IDs selected as most relevant for the user query")
|
||||
ids: List[str] = Field(
|
||||
default_factory=list,
|
||||
description="List of memory IDs selected as most relevant for the user query",
|
||||
)
|
||||
|
||||
|
||||
class UnifiedCacheManager:
|
||||
@@ -440,7 +449,10 @@ class SkipDetector:
|
||||
SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
|
||||
}
|
||||
|
||||
def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
|
||||
def __init__(
|
||||
self,
|
||||
embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]],
|
||||
):
|
||||
"""Initialize the skip detector with an embedding function and compute reference embeddings."""
|
||||
self.embedding_function = embedding_function
|
||||
self._reference_embeddings = None
|
||||
@@ -583,7 +595,16 @@ class SkipDetector:
|
||||
if markup_in_lines / len(non_empty_lines) > 0.3:
|
||||
return self.SkipReason.SKIP_TECHNICAL.value
|
||||
elif structured_lines / len(non_empty_lines) > 0.6:
|
||||
technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
|
||||
technical_keywords = [
|
||||
"function",
|
||||
"class",
|
||||
"import",
|
||||
"return",
|
||||
"const",
|
||||
"var",
|
||||
"let",
|
||||
"def",
|
||||
]
|
||||
if any(keyword in message.lower() for keyword in technical_keywords):
|
||||
return self.SkipReason.SKIP_TECHNICAL.value
|
||||
|
||||
@@ -594,7 +615,18 @@ class SkipDetector:
|
||||
if non_empty_lines:
|
||||
indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
|
||||
if indented_lines / len(non_empty_lines) > 0.5:
|
||||
code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
|
||||
code_indicators = [
|
||||
"def ",
|
||||
"class ",
|
||||
"function ",
|
||||
"return ",
|
||||
"import ",
|
||||
"const ",
|
||||
"let ",
|
||||
"var ",
|
||||
"public ",
|
||||
"private ",
|
||||
]
|
||||
if any(indicator in message.lower() for indicator in code_indicators):
|
||||
return self.SkipReason.SKIP_TECHNICAL.value
|
||||
|
||||
@@ -637,11 +669,31 @@ class SkipDetector:
|
||||
max_conversational_similarity = float(conversational_similarities.max())
|
||||
|
||||
skip_categories = [
|
||||
("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
|
||||
("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
|
||||
("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
|
||||
("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
|
||||
("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
|
||||
(
|
||||
"instruction",
|
||||
self.SkipReason.SKIP_INSTRUCTION,
|
||||
self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
|
||||
),
|
||||
(
|
||||
"translation",
|
||||
self.SkipReason.SKIP_TRANSLATION,
|
||||
self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
|
||||
),
|
||||
(
|
||||
"grammar",
|
||||
self.SkipReason.SKIP_GRAMMAR_PROOFREAD,
|
||||
self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
|
||||
),
|
||||
(
|
||||
"technical",
|
||||
self.SkipReason.SKIP_TECHNICAL,
|
||||
self.TECHNICAL_CATEGORY_DESCRIPTIONS,
|
||||
),
|
||||
(
|
||||
"pure_math",
|
||||
self.SkipReason.SKIP_PURE_MATH,
|
||||
self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
|
||||
),
|
||||
]
|
||||
|
||||
qualifying_categories = []
|
||||
@@ -680,11 +732,23 @@ class LLMRerankingService:
|
||||
|
||||
llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier)
|
||||
if len(memories) > llm_trigger_threshold:
|
||||
return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"
|
||||
return (
|
||||
True,
|
||||
f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold",
|
||||
)
|
||||
|
||||
return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}"
|
||||
return (
|
||||
False,
|
||||
f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}",
|
||||
)
|
||||
|
||||
async def _llm_select_memories(self, user_message: str, candidate_memories: List[Dict], max_count: int, emitter: Optional[Callable] = None) -> List[Dict]:
|
||||
async def _llm_select_memories(
|
||||
self,
|
||||
user_message: str,
|
||||
candidate_memories: List[Dict],
|
||||
max_count: int,
|
||||
emitter: Optional[Callable] = None,
|
||||
) -> List[Dict]:
|
||||
"""Use LLM to select most relevant memories."""
|
||||
memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
|
||||
memory_context = "\n".join(memory_lines)
|
||||
@@ -697,7 +761,11 @@ CANDIDATE MEMORIES:
|
||||
{memory_context}"""
|
||||
|
||||
try:
|
||||
response = await self.memory_system._query_llm(Prompts.MEMORY_RERANKING, user_prompt, response_model=Models.MemoryRerankingResponse)
|
||||
response = await self.memory_system._query_llm(
|
||||
Prompts.MEMORY_RERANKING,
|
||||
user_prompt,
|
||||
response_model=Models.MemoryRerankingResponse,
|
||||
)
|
||||
|
||||
selected_memories = []
|
||||
for memory in candidate_memories:
|
||||
@@ -712,18 +780,31 @@ CANDIDATE MEMORIES:
|
||||
logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
|
||||
return candidate_memories
|
||||
|
||||
async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
|
||||
async def rerank_memories(
|
||||
self,
|
||||
user_message: str,
|
||||
candidate_memories: List[Dict],
|
||||
emitter: Optional[Callable] = None,
|
||||
) -> Tuple[List[Dict], Dict[str, Any]]:
|
||||
start_time = time.time()
|
||||
max_injection = self.memory_system.valves.max_memories_returned
|
||||
|
||||
should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories)
|
||||
|
||||
analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}
|
||||
analysis_info = {
|
||||
"llm_decision": should_use_llm,
|
||||
"decision_reason": decision_reason,
|
||||
"candidate_count": len(candidate_memories),
|
||||
}
|
||||
|
||||
if should_use_llm:
|
||||
extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
|
||||
llm_candidates = candidate_memories[:extended_count]
|
||||
await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
|
||||
await self.memory_system._emit_status(
|
||||
emitter,
|
||||
f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance",
|
||||
done=False,
|
||||
)
|
||||
logger.info(f"Using LLM reranking: {decision_reason}")
|
||||
|
||||
selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)
|
||||
@@ -739,7 +820,11 @@ CANDIDATE MEMORIES:
|
||||
duration = time.time() - start_time
|
||||
duration_text = f" in {duration:.2f}s" if duration >= 0.01 else ""
|
||||
retrieval_method = "LLM" if should_use_llm else "Semantic"
|
||||
await self.memory_system._emit_status(emitter, f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}", done=True)
|
||||
await self.memory_system._emit_status(
|
||||
emitter,
|
||||
f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}",
|
||||
done=True,
|
||||
)
|
||||
return selected_memories, analysis_info
|
||||
|
||||
|
||||
@@ -761,7 +846,10 @@ class LLMConsolidationService:
|
||||
return candidates, threshold_info
|
||||
|
||||
async def collect_consolidation_candidates(
|
||||
self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None
|
||||
self,
|
||||
user_message: str,
|
||||
user_id: str,
|
||||
cached_similarities: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Collect candidate memories for consolidation analysis using cached or computed similarities."""
|
||||
if cached_similarities:
|
||||
@@ -805,7 +893,10 @@ class LLMConsolidationService:
|
||||
return candidates
|
||||
|
||||
async def generate_consolidation_plan(
|
||||
self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
|
||||
self,
|
||||
user_message: str,
|
||||
candidate_memories: List[Dict[str, Any]],
|
||||
emitter: Optional[Callable] = None,
|
||||
) -> List[Dict[str, Any]]:
|
||||
"""Generate consolidation plan using LLM with clear system/user prompt separation."""
|
||||
if candidate_memories:
|
||||
@@ -820,7 +911,11 @@ class LLMConsolidationService:
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
self.memory_system._query_llm(Prompts.MEMORY_CONSOLIDATION, user_prompt, response_model=Models.ConsolidationResponse),
|
||||
self.memory_system._query_llm(
|
||||
Prompts.MEMORY_CONSOLIDATION,
|
||||
user_prompt,
|
||||
response_model=Models.ConsolidationResponse,
|
||||
),
|
||||
timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -856,13 +951,21 @@ class LLMConsolidationService:
|
||||
|
||||
return valid_operations
|
||||
|
||||
async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
|
||||
async def execute_memory_operations(
|
||||
self,
|
||||
operations: List[Dict[str, Any]],
|
||||
user_id: str,
|
||||
emitter: Optional[Callable] = None,
|
||||
) -> Tuple[int, int, int, int]:
|
||||
"""Execute consolidation operations with simplified tracking."""
|
||||
if not operations or not user_id:
|
||||
return 0, 0, 0, 0
|
||||
|
||||
try:
|
||||
user = await asyncio.wait_for(asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC)
|
||||
user = await asyncio.wait_for(
|
||||
asyncio.to_thread(Users.get_user_by_id, user_id),
|
||||
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
|
||||
except Exception as e:
|
||||
@@ -929,7 +1032,10 @@ class LLMConsolidationService:
|
||||
if content_preview and content_preview != operation.id:
|
||||
content_preview = self.memory_system._truncate_content(content_preview)
|
||||
await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
|
||||
elif result in [Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value]:
|
||||
elif result in [
|
||||
Models.OperationResult.FAILED.value,
|
||||
Models.OperationResult.UNSUPPORTED.value,
|
||||
]:
|
||||
failed_count += 1
|
||||
await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
|
||||
except Exception as e:
|
||||
@@ -950,7 +1056,11 @@ class LLMConsolidationService:
|
||||
return created_count, updated_count, deleted_count, failed_count
|
||||
|
||||
async def run_consolidation_pipeline(
|
||||
self, user_message: str, user_id: str, emitter: Optional[Callable] = None, cached_similarities: Optional[List[Dict[str, Any]]] = None
|
||||
self,
|
||||
user_message: str,
|
||||
user_id: str,
|
||||
emitter: Optional[Callable] = None,
|
||||
cached_similarities: Optional[List[Dict[str, Any]]] = None,
|
||||
) -> None:
|
||||
"""Complete consolidation pipeline with simplified flow."""
|
||||
start_time = time.time()
|
||||
@@ -974,7 +1084,11 @@ class LLMConsolidationService:
|
||||
|
||||
total_operations = created_count + updated_count + deleted_count
|
||||
if total_operations > 0 or failed_count > 0:
|
||||
await self.memory_system._emit_status(emitter, f"💾 Memory Consolidation Complete in {duration:.2f}s", done=False)
|
||||
await self.memory_system._emit_status(
|
||||
emitter,
|
||||
f"💾 Memory Consolidation Complete in {duration:.2f}s",
|
||||
done=False,
|
||||
)
|
||||
|
||||
operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count)
|
||||
memory_word = "Memory" if total_operations == 1 else "Memories"
|
||||
@@ -1001,22 +1115,41 @@ class Filter:
|
||||
class Valves(BaseModel):
|
||||
"""Configuration valves for the Memory System."""
|
||||
|
||||
model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
|
||||
|
||||
max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
|
||||
max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
|
||||
|
||||
model: str = Field(
|
||||
default=Constants.DEFAULT_LLM_MODEL,
|
||||
description="Model name for LLM operations",
|
||||
)
|
||||
use_custom_model_for_memory: bool = Field(
|
||||
default=False,
|
||||
description="Use a custom model for memory operations instead of the current chat model",
|
||||
)
|
||||
custom_memory_model: str = Field(
|
||||
default=Constants.DEFAULT_LLM_MODEL,
|
||||
description="Custom model to use for memory operations when enabled",
|
||||
)
|
||||
max_memories_returned: int = Field(
|
||||
default=Constants.MAX_MEMORIES_PER_RETRIEVAL,
|
||||
description="Maximum number of memories to return in context",
|
||||
)
|
||||
max_message_chars: int = Field(
|
||||
default=Constants.MAX_MESSAGE_CHARS,
|
||||
description="Maximum user message length before skipping memory operations",
|
||||
)
|
||||
semantic_retrieval_threshold: float = Field(
|
||||
default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
|
||||
default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD,
|
||||
description="Minimum similarity threshold for memory retrieval",
|
||||
)
|
||||
relaxed_semantic_threshold_multiplier: float = Field(
|
||||
default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
|
||||
description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
|
||||
)
|
||||
|
||||
enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
|
||||
enable_llm_reranking: bool = Field(
|
||||
default=True,
|
||||
description="Enable LLM-based memory reranking for improved contextual selection",
|
||||
)
|
||||
llm_reranking_trigger_multiplier: float = Field(
|
||||
default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
|
||||
default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER,
|
||||
description="Controls when LLM reranking activates (lower = more aggressive)",
|
||||
)
|
||||
|
||||
def __init__(self):
|
||||
@@ -1071,7 +1204,9 @@ class Filter:
|
||||
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
|
||||
embedding_fn = self._embedding_function
|
||||
|
||||
def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
|
||||
def embedding_wrapper(
|
||||
texts: Union[str, List[str]],
|
||||
) -> Union[np.ndarray, List[np.ndarray]]:
|
||||
result = embedding_fn(texts, prefix=None, user=None)
|
||||
if isinstance(result, list):
|
||||
if isinstance(result[0], list):
|
||||
@@ -1219,7 +1354,11 @@ class Filter:
|
||||
def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
|
||||
"""Extract user message and determine if memory operations should be skipped."""
|
||||
if not body or "messages" not in body or not isinstance(body["messages"], list):
|
||||
return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
|
||||
return (
|
||||
None,
|
||||
True,
|
||||
SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
|
||||
)
|
||||
|
||||
messages = body["messages"]
|
||||
user_message = None
|
||||
@@ -1235,7 +1374,11 @@ class Filter:
|
||||
break
|
||||
|
||||
if not user_message or not user_message.strip():
|
||||
return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
|
||||
return (
|
||||
None,
|
||||
True,
|
||||
SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
|
||||
)
|
||||
|
||||
should_skip, skip_reason = self._should_skip_memory_operations(user_message)
|
||||
return user_message, should_skip, skip_reason
|
||||
@@ -1245,7 +1388,10 @@ class Filter:
|
||||
if timeout is None:
|
||||
timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
|
||||
try:
|
||||
return await asyncio.wait_for(asyncio.to_thread(Memories.get_memories_by_user_id, user_id), timeout=timeout)
|
||||
return await asyncio.wait_for(
|
||||
asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
raise TimeoutError(f"⏱️ Memory retrieval timed out after {timeout}s")
|
||||
except Exception as e:
|
||||
@@ -1274,7 +1420,11 @@ class Filter:
|
||||
logger.info(f"Scores: [{scores_str}{suffix}]")
|
||||
|
||||
def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]:
|
||||
operations = [(created_count, "📝 Created"), (updated_count, "✏️ Updated"), (deleted_count, "🗑️ Deleted")]
|
||||
operations = [
|
||||
(created_count, "📝 Created"),
|
||||
(updated_count, "✏️ Updated"),
|
||||
(deleted_count, "🗑️ Deleted"),
|
||||
]
|
||||
return [f"{label} {count}" for count, label in operations if count > 0]
|
||||
|
||||
def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
|
||||
@@ -1366,10 +1516,19 @@ class Filter:
|
||||
|
||||
self._log_retrieved_memories(final_memories, "semantic")
|
||||
|
||||
return {"memories": final_memories, "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info}
|
||||
return {
|
||||
"memories": final_memories,
|
||||
"threshold": threshold,
|
||||
"all_similarities": all_similarities,
|
||||
"reranking_info": reranking_info,
|
||||
}
|
||||
|
||||
async def _add_memory_context(
|
||||
self, body: Dict[str, Any], memories: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None, emitter: Optional[Callable] = None
|
||||
self,
|
||||
body: Dict[str, Any],
|
||||
memories: Optional[List[Dict[str, Any]]] = None,
|
||||
user_id: Optional[str] = None,
|
||||
emitter: Optional[Callable] = None,
|
||||
) -> None:
|
||||
"""Add memory context to request body with simplified logic."""
|
||||
if not body or "messages" not in body or not body["messages"]:
|
||||
@@ -1397,7 +1556,10 @@ class Filter:
|
||||
|
||||
memory_context = "\n\n".join(content_parts)
|
||||
|
||||
system_index = next((i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), None)
|
||||
system_index = next(
|
||||
(i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"),
|
||||
None,
|
||||
)
|
||||
|
||||
if system_index is not None:
|
||||
body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
|
||||
@@ -1410,7 +1572,11 @@ class Filter:
|
||||
|
||||
def _build_memory_dict(self, memory, similarity: float) -> Dict[str, Any]:
|
||||
"""Build memory dictionary with standardized timestamp conversion."""
|
||||
memory_dict = {"id": str(memory.id), "content": memory.content, "relevance": similarity}
|
||||
memory_dict = {
|
||||
"id": str(memory.id),
|
||||
"content": memory.content,
|
||||
"relevance": similarity,
|
||||
}
|
||||
if hasattr(memory, "created_at") and memory.created_at:
|
||||
memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat()
|
||||
if hasattr(memory, "updated_at") and memory.updated_at:
|
||||
@@ -1462,42 +1628,58 @@ class Filter:
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Simplified inlet processing for memory retrieval and injection."""
|
||||
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
|
||||
|
||||
model_to_use = body.get("model") if isinstance(body, dict) else None
|
||||
if not model_to_use:
|
||||
model_to_use = __model__ or getattr(__request__.state, "model", None)
|
||||
if not model_to_use:
|
||||
model_to_use = Constants.DEFAULT_LLM_MODEL
|
||||
logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
|
||||
|
||||
if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
|
||||
model_to_use = self.valves.custom_memory_model
|
||||
logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
|
||||
|
||||
self.valves.model = model_to_use
|
||||
|
||||
await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
|
||||
|
||||
user_id = __user__.get("id") if body and __user__ else None
|
||||
if not user_id:
|
||||
return body
|
||||
|
||||
user_message, should_skip, skip_reason = self._process_user_message(body)
|
||||
|
||||
if not user_message or should_skip:
|
||||
if __event_emitter__ and skip_reason:
|
||||
await self._emit_status(__event_emitter__, skip_reason, done=True)
|
||||
await self._add_memory_context(body, [], user_id, __event_emitter__)
|
||||
return body
|
||||
|
||||
try:
|
||||
memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
|
||||
user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key)
|
||||
|
||||
if user_memories is None:
|
||||
user_memories = await self._get_user_memories(user_id)
|
||||
await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
|
||||
|
||||
await self._cache_manager.put(
|
||||
user_id,
|
||||
self._cache_manager.MEMORY_CACHE,
|
||||
memory_cache_key,
|
||||
user_memories,
|
||||
)
|
||||
retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
|
||||
memories = retrieval_result.get("memories", [])
|
||||
threshold = retrieval_result.get("threshold")
|
||||
all_similarities = retrieval_result.get("all_similarities", [])
|
||||
|
||||
if all_similarities:
|
||||
cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
|
||||
await self._cache_manager.put(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key, all_similarities)
|
||||
|
||||
await self._cache_manager.put(
|
||||
user_id,
|
||||
self._cache_manager.RETRIEVAL_CACHE,
|
||||
cache_key,
|
||||
all_similarities,
|
||||
)
|
||||
await self._add_memory_context(body, memories, user_id, __event_emitter__)
|
||||
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}")
|
||||
|
||||
return body
|
||||
|
||||
async def outlet(
|
||||
@@ -1510,20 +1692,30 @@ class Filter:
|
||||
**kwargs,
|
||||
) -> dict:
|
||||
"""Simplified outlet processing for background memory consolidation."""
|
||||
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
|
||||
|
||||
model_to_use = body.get("model") if isinstance(body, dict) else None
|
||||
if not model_to_use:
|
||||
model_to_use = __model__ or getattr(__request__.state, "model", None)
|
||||
if not model_to_use:
|
||||
model_to_use = Constants.DEFAULT_LLM_MODEL
|
||||
logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
|
||||
|
||||
if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
|
||||
model_to_use = self.valves.custom_memory_model
|
||||
logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
|
||||
|
||||
self.valves.model = model_to_use
|
||||
|
||||
await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
|
||||
|
||||
user_id = __user__.get("id") if body and __user__ else None
|
||||
if not user_id:
|
||||
return body
|
||||
|
||||
user_message, should_skip, skip_reason = self._process_user_message(body)
|
||||
|
||||
if not user_message or should_skip:
|
||||
return body
|
||||
|
||||
cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
|
||||
cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
|
||||
|
||||
task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
|
||||
self._background_tasks.add(task)
|
||||
|
||||
@@ -1537,7 +1729,6 @@ class Filter:
|
||||
logger.error(f"❌ Failed to cleanup background memory task: {str(e)}")
|
||||
|
||||
task.add_done_callback(safe_cleanup)
|
||||
|
||||
return body
|
||||
|
||||
async def shutdown(self) -> None:
|
||||
@@ -1566,7 +1757,12 @@ class Filter:
|
||||
logger.info("📭 No memories found for user")
|
||||
return
|
||||
|
||||
await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
|
||||
await self._cache_manager.put(
|
||||
user_id,
|
||||
self._cache_manager.MEMORY_CACHE,
|
||||
memory_cache_key,
|
||||
user_memories,
|
||||
)
|
||||
|
||||
memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
|
||||
|
||||
@@ -1588,7 +1784,8 @@ class Filter:
|
||||
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
|
||||
asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
|
||||
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
|
||||
)
|
||||
return Models.MemoryOperationType.CREATE.value
|
||||
|
||||
@@ -1604,7 +1801,12 @@ class Filter:
|
||||
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(Memories.update_memory_by_id_and_user_id, id_stripped, user.id, content_stripped),
|
||||
asyncio.to_thread(
|
||||
Memories.update_memory_by_id_and_user_id,
|
||||
id_stripped,
|
||||
user.id,
|
||||
content_stripped,
|
||||
),
|
||||
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
|
||||
)
|
||||
return Models.MemoryOperationType.UPDATE.value
|
||||
@@ -1616,7 +1818,8 @@ class Filter:
|
||||
return Models.OperationResult.SKIPPED_EMPTY_ID.value
|
||||
|
||||
await asyncio.wait_for(
|
||||
asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
|
||||
asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
|
||||
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
|
||||
)
|
||||
return Models.MemoryOperationType.DELETE.value
|
||||
else:
|
||||
@@ -1647,7 +1850,7 @@ class Filter:
|
||||
elif isinstance(value, dict):
|
||||
result[key] = self._remove_refs_from_schema(value, schema_defs)
|
||||
elif isinstance(value, list):
|
||||
result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value]
|
||||
result[key] = [(self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item) for item in value]
|
||||
else:
|
||||
result[key] = value
|
||||
|
||||
@@ -1656,7 +1859,12 @@ class Filter:
|
||||
|
||||
return result
|
||||
|
||||
async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
|
||||
async def _query_llm(
|
||||
self,
|
||||
system_prompt: str,
|
||||
user_prompt: str,
|
||||
response_model: Optional[BaseModel] = None,
|
||||
) -> Union[str, BaseModel]:
|
||||
"""Query OpenWebUI's internal model system with Pydantic model parsing."""
|
||||
if not hasattr(self, "__request__") or not hasattr(self, "__user__"):
|
||||
raise RuntimeError("🔧 Pipeline interface not properly initialized. __request__ and __user__ required.")
|
||||
@@ -1667,7 +1875,10 @@ class Filter:
|
||||
|
||||
form_data = {
|
||||
"model": model_to_use,
|
||||
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
|
||||
"messages": [
|
||||
{"role": "system", "content": system_prompt},
|
||||
{"role": "user", "content": user_prompt},
|
||||
],
|
||||
"max_tokens": 4096,
|
||||
"stream": False,
|
||||
}
|
||||
@@ -1677,11 +1888,22 @@ class Filter:
|
||||
schema_defs = raw_schema.get("$defs", {})
|
||||
schema = self._remove_refs_from_schema(raw_schema, schema_defs)
|
||||
schema["type"] = "object"
|
||||
form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
|
||||
form_data["response_format"] = {
|
||||
"type": "json_schema",
|
||||
"json_schema": {
|
||||
"name": response_model.__name__,
|
||||
"strict": True,
|
||||
"schema": schema,
|
||||
},
|
||||
}
|
||||
|
||||
try:
|
||||
response = await asyncio.wait_for(
|
||||
generate_chat_completion(self.__request__, form_data, user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"])),
|
||||
generate_chat_completion(
|
||||
self.__request__,
|
||||
form_data,
|
||||
user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"]),
|
||||
),
|
||||
timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
|
||||
)
|
||||
except asyncio.TimeoutError:
|
||||
|
||||
@@ -4,4 +4,4 @@ pydantic
|
||||
numpy
|
||||
tiktoken
|
||||
black
|
||||
isort
|
||||
isort
|
||||
|
||||
Reference in New Issue
Block a user