Merge pull request #3 from GlisseManTV/dev

Way to use current model instead of dedicated model.
This commit is contained in:
M. Tayfur
2025-10-27 23:39:06 +03:00
committed by GitHub
3 changed files with 294 additions and 72 deletions

2
.gitignore vendored
View File

@@ -1,4 +1,4 @@
__pycache__/
.github/instructions/*
.venv/
**AGENTS.md
tests/

View File

@@ -1,6 +1,9 @@
"""
title: Memory System
description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
version: 1.0.0
authors: https://github.com/mtayfur
license: Apache-2.0
"""
import asyncio
@@ -217,7 +220,10 @@ class Models:
if self.operation == Models.MemoryOperationType.CREATE:
return True
elif self.operation in [Models.MemoryOperationType.UPDATE, Models.MemoryOperationType.DELETE]:
elif self.operation in [
Models.MemoryOperationType.UPDATE,
Models.MemoryOperationType.DELETE,
]:
return self.id in existing_memory_ids
return False
@@ -229,7 +235,10 @@ class Models:
class MemoryRerankingResponse(BaseModel):
"""Pydantic model for memory reranking LLM response - object containing array of memory IDs."""
ids: List[str] = Field(default_factory=list, description="List of memory IDs selected as most relevant for the user query")
ids: List[str] = Field(
default_factory=list,
description="List of memory IDs selected as most relevant for the user query",
)
class UnifiedCacheManager:
@@ -440,7 +449,10 @@ class SkipDetector:
SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
}
def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
def __init__(
self,
embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]],
):
"""Initialize the skip detector with an embedding function and compute reference embeddings."""
self.embedding_function = embedding_function
self._reference_embeddings = None
@@ -583,7 +595,16 @@ class SkipDetector:
if markup_in_lines / len(non_empty_lines) > 0.3:
return self.SkipReason.SKIP_TECHNICAL.value
elif structured_lines / len(non_empty_lines) > 0.6:
technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
technical_keywords = [
"function",
"class",
"import",
"return",
"const",
"var",
"let",
"def",
]
if any(keyword in message.lower() for keyword in technical_keywords):
return self.SkipReason.SKIP_TECHNICAL.value
@@ -594,7 +615,18 @@ class SkipDetector:
if non_empty_lines:
indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
if indented_lines / len(non_empty_lines) > 0.5:
code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
code_indicators = [
"def ",
"class ",
"function ",
"return ",
"import ",
"const ",
"let ",
"var ",
"public ",
"private ",
]
if any(indicator in message.lower() for indicator in code_indicators):
return self.SkipReason.SKIP_TECHNICAL.value
@@ -637,11 +669,31 @@ class SkipDetector:
max_conversational_similarity = float(conversational_similarities.max())
skip_categories = [
("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
(
"instruction",
self.SkipReason.SKIP_INSTRUCTION,
self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
),
(
"translation",
self.SkipReason.SKIP_TRANSLATION,
self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
),
(
"grammar",
self.SkipReason.SKIP_GRAMMAR_PROOFREAD,
self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
),
(
"technical",
self.SkipReason.SKIP_TECHNICAL,
self.TECHNICAL_CATEGORY_DESCRIPTIONS,
),
(
"pure_math",
self.SkipReason.SKIP_PURE_MATH,
self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
),
]
qualifying_categories = []
@@ -680,11 +732,23 @@ class LLMRerankingService:
llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier)
if len(memories) > llm_trigger_threshold:
return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"
return (
True,
f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold",
)
return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}"
return (
False,
f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}",
)
async def _llm_select_memories(self, user_message: str, candidate_memories: List[Dict], max_count: int, emitter: Optional[Callable] = None) -> List[Dict]:
async def _llm_select_memories(
self,
user_message: str,
candidate_memories: List[Dict],
max_count: int,
emitter: Optional[Callable] = None,
) -> List[Dict]:
"""Use LLM to select most relevant memories."""
memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
memory_context = "\n".join(memory_lines)
@@ -697,7 +761,11 @@ CANDIDATE MEMORIES:
{memory_context}"""
try:
response = await self.memory_system._query_llm(Prompts.MEMORY_RERANKING, user_prompt, response_model=Models.MemoryRerankingResponse)
response = await self.memory_system._query_llm(
Prompts.MEMORY_RERANKING,
user_prompt,
response_model=Models.MemoryRerankingResponse,
)
selected_memories = []
for memory in candidate_memories:
@@ -712,18 +780,31 @@ CANDIDATE MEMORIES:
logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
return candidate_memories
async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
async def rerank_memories(
self,
user_message: str,
candidate_memories: List[Dict],
emitter: Optional[Callable] = None,
) -> Tuple[List[Dict], Dict[str, Any]]:
start_time = time.time()
max_injection = self.memory_system.valves.max_memories_returned
should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories)
analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}
analysis_info = {
"llm_decision": should_use_llm,
"decision_reason": decision_reason,
"candidate_count": len(candidate_memories),
}
if should_use_llm:
extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
llm_candidates = candidate_memories[:extended_count]
await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
await self.memory_system._emit_status(
emitter,
f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance",
done=False,
)
logger.info(f"Using LLM reranking: {decision_reason}")
selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)
@@ -739,7 +820,11 @@ CANDIDATE MEMORIES:
duration = time.time() - start_time
duration_text = f" in {duration:.2f}s" if duration >= 0.01 else ""
retrieval_method = "LLM" if should_use_llm else "Semantic"
await self.memory_system._emit_status(emitter, f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}", done=True)
await self.memory_system._emit_status(
emitter,
f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}",
done=True,
)
return selected_memories, analysis_info
@@ -761,7 +846,10 @@ class LLMConsolidationService:
return candidates, threshold_info
async def collect_consolidation_candidates(
self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None
self,
user_message: str,
user_id: str,
cached_similarities: Optional[List[Dict[str, Any]]] = None,
) -> List[Dict[str, Any]]:
"""Collect candidate memories for consolidation analysis using cached or computed similarities."""
if cached_similarities:
@@ -805,7 +893,10 @@ class LLMConsolidationService:
return candidates
async def generate_consolidation_plan(
self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
self,
user_message: str,
candidate_memories: List[Dict[str, Any]],
emitter: Optional[Callable] = None,
) -> List[Dict[str, Any]]:
"""Generate consolidation plan using LLM with clear system/user prompt separation."""
if candidate_memories:
@@ -820,7 +911,11 @@ class LLMConsolidationService:
try:
response = await asyncio.wait_for(
self.memory_system._query_llm(Prompts.MEMORY_CONSOLIDATION, user_prompt, response_model=Models.ConsolidationResponse),
self.memory_system._query_llm(
Prompts.MEMORY_CONSOLIDATION,
user_prompt,
response_model=Models.ConsolidationResponse,
),
timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
)
except Exception as e:
@@ -856,13 +951,21 @@ class LLMConsolidationService:
return valid_operations
async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
async def execute_memory_operations(
self,
operations: List[Dict[str, Any]],
user_id: str,
emitter: Optional[Callable] = None,
) -> Tuple[int, int, int, int]:
"""Execute consolidation operations with simplified tracking."""
if not operations or not user_id:
return 0, 0, 0, 0
try:
user = await asyncio.wait_for(asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC)
user = await asyncio.wait_for(
asyncio.to_thread(Users.get_user_by_id, user_id),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
except asyncio.TimeoutError:
raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
except Exception as e:
@@ -929,7 +1032,10 @@ class LLMConsolidationService:
if content_preview and content_preview != operation.id:
content_preview = self.memory_system._truncate_content(content_preview)
await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
elif result in [Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value]:
elif result in [
Models.OperationResult.FAILED.value,
Models.OperationResult.UNSUPPORTED.value,
]:
failed_count += 1
await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
except Exception as e:
@@ -950,7 +1056,11 @@ class LLMConsolidationService:
return created_count, updated_count, deleted_count, failed_count
async def run_consolidation_pipeline(
self, user_message: str, user_id: str, emitter: Optional[Callable] = None, cached_similarities: Optional[List[Dict[str, Any]]] = None
self,
user_message: str,
user_id: str,
emitter: Optional[Callable] = None,
cached_similarities: Optional[List[Dict[str, Any]]] = None,
) -> None:
"""Complete consolidation pipeline with simplified flow."""
start_time = time.time()
@@ -974,7 +1084,11 @@ class LLMConsolidationService:
total_operations = created_count + updated_count + deleted_count
if total_operations > 0 or failed_count > 0:
await self.memory_system._emit_status(emitter, f"💾 Memory Consolidation Complete in {duration:.2f}s", done=False)
await self.memory_system._emit_status(
emitter,
f"💾 Memory Consolidation Complete in {duration:.2f}s",
done=False,
)
operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count)
memory_word = "Memory" if total_operations == 1 else "Memories"
@@ -1001,22 +1115,41 @@ class Filter:
class Valves(BaseModel):
"""Configuration valves for the Memory System."""
model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
model: str = Field(
default=Constants.DEFAULT_LLM_MODEL,
description="Model name for LLM operations",
)
use_custom_model_for_memory: bool = Field(
default=False,
description="Use a custom model for memory operations instead of the current chat model",
)
custom_memory_model: str = Field(
default=Constants.DEFAULT_LLM_MODEL,
description="Custom model to use for memory operations when enabled",
)
max_memories_returned: int = Field(
default=Constants.MAX_MEMORIES_PER_RETRIEVAL,
description="Maximum number of memories to return in context",
)
max_message_chars: int = Field(
default=Constants.MAX_MESSAGE_CHARS,
description="Maximum user message length before skipping memory operations",
)
semantic_retrieval_threshold: float = Field(
default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD,
description="Minimum similarity threshold for memory retrieval",
)
relaxed_semantic_threshold_multiplier: float = Field(
default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
)
enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
enable_llm_reranking: bool = Field(
default=True,
description="Enable LLM-based memory reranking for improved contextual selection",
)
llm_reranking_trigger_multiplier: float = Field(
default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER,
description="Controls when LLM reranking activates (lower = more aggressive)",
)
def __init__(self):
@@ -1071,7 +1204,9 @@ class Filter:
logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
embedding_fn = self._embedding_function
def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
def embedding_wrapper(
texts: Union[str, List[str]],
) -> Union[np.ndarray, List[np.ndarray]]:
result = embedding_fn(texts, prefix=None, user=None)
if isinstance(result, list):
if isinstance(result[0], list):
@@ -1219,7 +1354,11 @@ class Filter:
def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
"""Extract user message and determine if memory operations should be skipped."""
if not body or "messages" not in body or not isinstance(body["messages"], list):
return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
return (
None,
True,
SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
)
messages = body["messages"]
user_message = None
@@ -1235,7 +1374,11 @@ class Filter:
break
if not user_message or not user_message.strip():
return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
return (
None,
True,
SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
)
should_skip, skip_reason = self._should_skip_memory_operations(user_message)
return user_message, should_skip, skip_reason
@@ -1245,7 +1388,10 @@ class Filter:
if timeout is None:
timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
try:
return await asyncio.wait_for(asyncio.to_thread(Memories.get_memories_by_user_id, user_id), timeout=timeout)
return await asyncio.wait_for(
asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
timeout=timeout,
)
except asyncio.TimeoutError:
raise TimeoutError(f"⏱️ Memory retrieval timed out after {timeout}s")
except Exception as e:
@@ -1274,7 +1420,11 @@ class Filter:
logger.info(f"Scores: [{scores_str}{suffix}]")
def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]:
operations = [(created_count, "📝 Created"), (updated_count, "✏️ Updated"), (deleted_count, "🗑️ Deleted")]
operations = [
(created_count, "📝 Created"),
(updated_count, "✏️ Updated"),
(deleted_count, "🗑️ Deleted"),
]
return [f"{label} {count}" for count, label in operations if count > 0]
def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
@@ -1366,10 +1516,19 @@ class Filter:
self._log_retrieved_memories(final_memories, "semantic")
return {"memories": final_memories, "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info}
return {
"memories": final_memories,
"threshold": threshold,
"all_similarities": all_similarities,
"reranking_info": reranking_info,
}
async def _add_memory_context(
self, body: Dict[str, Any], memories: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None, emitter: Optional[Callable] = None
self,
body: Dict[str, Any],
memories: Optional[List[Dict[str, Any]]] = None,
user_id: Optional[str] = None,
emitter: Optional[Callable] = None,
) -> None:
"""Add memory context to request body with simplified logic."""
if not body or "messages" not in body or not body["messages"]:
@@ -1397,7 +1556,10 @@ class Filter:
memory_context = "\n\n".join(content_parts)
system_index = next((i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), None)
system_index = next(
(i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"),
None,
)
if system_index is not None:
body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
@@ -1410,7 +1572,11 @@ class Filter:
def _build_memory_dict(self, memory, similarity: float) -> Dict[str, Any]:
"""Build memory dictionary with standardized timestamp conversion."""
memory_dict = {"id": str(memory.id), "content": memory.content, "relevance": similarity}
memory_dict = {
"id": str(memory.id),
"content": memory.content,
"relevance": similarity,
}
if hasattr(memory, "created_at") and memory.created_at:
memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat()
if hasattr(memory, "updated_at") and memory.updated_at:
@@ -1462,42 +1628,58 @@ class Filter:
**kwargs,
) -> Dict[str, Any]:
"""Simplified inlet processing for memory retrieval and injection."""
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
model_to_use = body.get("model") if isinstance(body, dict) else None
if not model_to_use:
model_to_use = __model__ or getattr(__request__.state, "model", None)
if not model_to_use:
model_to_use = Constants.DEFAULT_LLM_MODEL
logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
model_to_use = self.valves.custom_memory_model
logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
self.valves.model = model_to_use
await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
user_id = __user__.get("id") if body and __user__ else None
if not user_id:
return body
user_message, should_skip, skip_reason = self._process_user_message(body)
if not user_message or should_skip:
if __event_emitter__ and skip_reason:
await self._emit_status(__event_emitter__, skip_reason, done=True)
await self._add_memory_context(body, [], user_id, __event_emitter__)
return body
try:
memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key)
if user_memories is None:
user_memories = await self._get_user_memories(user_id)
await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
await self._cache_manager.put(
user_id,
self._cache_manager.MEMORY_CACHE,
memory_cache_key,
user_memories,
)
retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
memories = retrieval_result.get("memories", [])
threshold = retrieval_result.get("threshold")
all_similarities = retrieval_result.get("all_similarities", [])
if all_similarities:
cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
await self._cache_manager.put(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key, all_similarities)
await self._cache_manager.put(
user_id,
self._cache_manager.RETRIEVAL_CACHE,
cache_key,
all_similarities,
)
await self._add_memory_context(body, memories, user_id, __event_emitter__)
except Exception as e:
raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}")
return body
async def outlet(
@@ -1510,20 +1692,30 @@ class Filter:
**kwargs,
) -> dict:
"""Simplified outlet processing for background memory consolidation."""
await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
model_to_use = body.get("model") if isinstance(body, dict) else None
if not model_to_use:
model_to_use = __model__ or getattr(__request__.state, "model", None)
if not model_to_use:
model_to_use = Constants.DEFAULT_LLM_MODEL
logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
model_to_use = self.valves.custom_memory_model
logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
self.valves.model = model_to_use
await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
user_id = __user__.get("id") if body and __user__ else None
if not user_id:
return body
user_message, should_skip, skip_reason = self._process_user_message(body)
if not user_message or should_skip:
return body
cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
self._background_tasks.add(task)
@@ -1537,7 +1729,6 @@ class Filter:
logger.error(f"❌ Failed to cleanup background memory task: {str(e)}")
task.add_done_callback(safe_cleanup)
return body
async def shutdown(self) -> None:
@@ -1566,7 +1757,12 @@ class Filter:
logger.info("📭 No memories found for user")
return
await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
await self._cache_manager.put(
user_id,
self._cache_manager.MEMORY_CACHE,
memory_cache_key,
user_memories,
)
memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
@@ -1588,7 +1784,8 @@ class Filter:
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
await asyncio.wait_for(
asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
return Models.MemoryOperationType.CREATE.value
@@ -1604,7 +1801,12 @@ class Filter:
return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
await asyncio.wait_for(
asyncio.to_thread(Memories.update_memory_by_id_and_user_id, id_stripped, user.id, content_stripped),
asyncio.to_thread(
Memories.update_memory_by_id_and_user_id,
id_stripped,
user.id,
content_stripped,
),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
return Models.MemoryOperationType.UPDATE.value
@@ -1616,7 +1818,8 @@ class Filter:
return Models.OperationResult.SKIPPED_EMPTY_ID.value
await asyncio.wait_for(
asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
)
return Models.MemoryOperationType.DELETE.value
else:
@@ -1647,7 +1850,7 @@ class Filter:
elif isinstance(value, dict):
result[key] = self._remove_refs_from_schema(value, schema_defs)
elif isinstance(value, list):
result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value]
result[key] = [(self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item) for item in value]
else:
result[key] = value
@@ -1656,7 +1859,12 @@ class Filter:
return result
async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
async def _query_llm(
self,
system_prompt: str,
user_prompt: str,
response_model: Optional[BaseModel] = None,
) -> Union[str, BaseModel]:
"""Query OpenWebUI's internal model system with Pydantic model parsing."""
if not hasattr(self, "__request__") or not hasattr(self, "__user__"):
raise RuntimeError("🔧 Pipeline interface not properly initialized. __request__ and __user__ required.")
@@ -1667,7 +1875,10 @@ class Filter:
form_data = {
"model": model_to_use,
"messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
"messages": [
{"role": "system", "content": system_prompt},
{"role": "user", "content": user_prompt},
],
"max_tokens": 4096,
"stream": False,
}
@@ -1677,11 +1888,22 @@ class Filter:
schema_defs = raw_schema.get("$defs", {})
schema = self._remove_refs_from_schema(raw_schema, schema_defs)
schema["type"] = "object"
form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
form_data["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": response_model.__name__,
"strict": True,
"schema": schema,
},
}
try:
response = await asyncio.wait_for(
generate_chat_completion(self.__request__, form_data, user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"])),
generate_chat_completion(
self.__request__,
form_data,
user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"]),
),
timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
)
except asyncio.TimeoutError:

View File

@@ -4,4 +4,4 @@ pydantic
numpy
tiktoken
black
isort
isort