Mirror of https://github.com/mtayfur/openwebui-memory-system.git, synced 2026-01-22 06:51:01 +01:00.
Merge pull request #3 from GlisseManTV/dev
Adds a way to use the current chat model for memory operations instead of a dedicated model.
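The change centers on how the filter picks the model it calls for memory retrieval and consolidation: instead of always using the dedicated `model` valve, `inlet()` and `outlet()` now resolve the model at request time, preferring the model already named on the request body, then the `__model__` argument (or the model stored on `__request__.state`), and only then the configured default, while two new valves (`use_custom_model_for_memory`, `custom_memory_model`) restore the old dedicated-model behaviour when enabled. A minimal sketch of that resolution order follows; `resolve_memory_model`, `DEFAULT_LLM_MODEL`, and the `SimpleNamespace` valves are illustrative stand-ins, not code from the PR.

from types import SimpleNamespace
from typing import Any, Dict, Optional

# Illustrative stand-in for Constants.DEFAULT_LLM_MODEL in the real filter.
DEFAULT_LLM_MODEL = "gpt-4o-mini"


def resolve_memory_model(
    body: Optional[Dict[str, Any]],
    pipeline_model: Optional[str],
    valves: SimpleNamespace,
) -> str:
    """Mirror the fallback order used by inlet()/outlet() (hypothetical helper)."""
    # 1. Prefer the model already attached to the current chat request.
    model = body.get("model") if isinstance(body, dict) else None
    # 2. Fall back to the model the pipeline passed to the filter.
    if not model:
        model = pipeline_model
    # 3. Last resort: the configured default.
    if not model:
        model = DEFAULT_LLM_MODEL
    # 4. An explicit override valve restores the old dedicated-model behaviour.
    if valves.use_custom_model_for_memory and valves.custom_memory_model:
        model = valves.custom_memory_model
    return model


# Example: a request that names its own model wins over the default.
valves = SimpleNamespace(use_custom_model_for_memory=False, custom_memory_model="")
print(resolve_memory_model({"model": "llama3.1:8b"}, None, valves))  # -> llama3.1:8b

With `use_custom_model_for_memory=True`, the configured `custom_memory_model` would win regardless of which model the chat request names.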
.gitignore (vendored): 2 changed lines
@@ -1,4 +1,4 @@
 __pycache__/
+.github/instructions/*
 .venv/
-**AGENTS.md
 tests/
memory_system.py: 362 changed lines
@@ -1,6 +1,9 @@
 """
 title: Memory System
+description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
 version: 1.0.0
+authors: https://github.com/mtayfur
+license: Apache-2.0
 """
 
 import asyncio
@@ -217,7 +220,10 @@ class Models:
 
             if self.operation == Models.MemoryOperationType.CREATE:
                 return True
-            elif self.operation in [Models.MemoryOperationType.UPDATE, Models.MemoryOperationType.DELETE]:
+            elif self.operation in [
+                Models.MemoryOperationType.UPDATE,
+                Models.MemoryOperationType.DELETE,
+            ]:
                 return self.id in existing_memory_ids
             return False
 
@@ -229,7 +235,10 @@ class Models:
     class MemoryRerankingResponse(BaseModel):
         """Pydantic model for memory reranking LLM response - object containing array of memory IDs."""
 
-        ids: List[str] = Field(default_factory=list, description="List of memory IDs selected as most relevant for the user query")
+        ids: List[str] = Field(
+            default_factory=list,
+            description="List of memory IDs selected as most relevant for the user query",
+        )
 
 
 class UnifiedCacheManager:
@@ -440,7 +449,10 @@ class SkipDetector:
         SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
     }
 
-    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
+    def __init__(
+        self,
+        embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]],
+    ):
         """Initialize the skip detector with an embedding function and compute reference embeddings."""
         self.embedding_function = embedding_function
         self._reference_embeddings = None
@@ -583,7 +595,16 @@ class SkipDetector:
             if markup_in_lines / len(non_empty_lines) > 0.3:
                 return self.SkipReason.SKIP_TECHNICAL.value
             elif structured_lines / len(non_empty_lines) > 0.6:
-                technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
+                technical_keywords = [
+                    "function",
+                    "class",
+                    "import",
+                    "return",
+                    "const",
+                    "var",
+                    "let",
+                    "def",
+                ]
                 if any(keyword in message.lower() for keyword in technical_keywords):
                     return self.SkipReason.SKIP_TECHNICAL.value
 
@@ -594,7 +615,18 @@ class SkipDetector:
         if non_empty_lines:
             indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
             if indented_lines / len(non_empty_lines) > 0.5:
-                code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
+                code_indicators = [
+                    "def ",
+                    "class ",
+                    "function ",
+                    "return ",
+                    "import ",
+                    "const ",
+                    "let ",
+                    "var ",
+                    "public ",
+                    "private ",
+                ]
                 if any(indicator in message.lower() for indicator in code_indicators):
                     return self.SkipReason.SKIP_TECHNICAL.value
 
@@ -637,11 +669,31 @@ class SkipDetector:
         max_conversational_similarity = float(conversational_similarities.max())
 
         skip_categories = [
-            ("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
-            ("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
-            ("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
-            ("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
-            ("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
+            (
+                "instruction",
+                self.SkipReason.SKIP_INSTRUCTION,
+                self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "translation",
+                self.SkipReason.SKIP_TRANSLATION,
+                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "grammar",
+                self.SkipReason.SKIP_GRAMMAR_PROOFREAD,
+                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "technical",
+                self.SkipReason.SKIP_TECHNICAL,
+                self.TECHNICAL_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "pure_math",
+                self.SkipReason.SKIP_PURE_MATH,
+                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
+            ),
         ]
 
         qualifying_categories = []
@@ -680,11 +732,23 @@ class LLMRerankingService:
 
         llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier)
         if len(memories) > llm_trigger_threshold:
-            return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"
+            return (
+                True,
+                f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold",
+            )
 
-        return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}"
+        return (
+            False,
+            f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}",
+        )
 
-    async def _llm_select_memories(self, user_message: str, candidate_memories: List[Dict], max_count: int, emitter: Optional[Callable] = None) -> List[Dict]:
+    async def _llm_select_memories(
+        self,
+        user_message: str,
+        candidate_memories: List[Dict],
+        max_count: int,
+        emitter: Optional[Callable] = None,
+    ) -> List[Dict]:
         """Use LLM to select most relevant memories."""
         memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
         memory_context = "\n".join(memory_lines)
@@ -697,7 +761,11 @@ CANDIDATE MEMORIES:
 {memory_context}"""
 
         try:
-            response = await self.memory_system._query_llm(Prompts.MEMORY_RERANKING, user_prompt, response_model=Models.MemoryRerankingResponse)
+            response = await self.memory_system._query_llm(
+                Prompts.MEMORY_RERANKING,
+                user_prompt,
+                response_model=Models.MemoryRerankingResponse,
+            )
 
             selected_memories = []
             for memory in candidate_memories:
@@ -712,18 +780,31 @@ CANDIDATE MEMORIES:
             logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
             return candidate_memories
 
-    async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
+    async def rerank_memories(
+        self,
+        user_message: str,
+        candidate_memories: List[Dict],
+        emitter: Optional[Callable] = None,
+    ) -> Tuple[List[Dict], Dict[str, Any]]:
         start_time = time.time()
         max_injection = self.memory_system.valves.max_memories_returned
 
         should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories)
 
-        analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}
+        analysis_info = {
+            "llm_decision": should_use_llm,
+            "decision_reason": decision_reason,
+            "candidate_count": len(candidate_memories),
+        }
 
         if should_use_llm:
             extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
             llm_candidates = candidate_memories[:extended_count]
-            await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
+            await self.memory_system._emit_status(
+                emitter,
+                f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance",
+                done=False,
+            )
             logger.info(f"Using LLM reranking: {decision_reason}")
 
             selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)
@@ -739,7 +820,11 @@ CANDIDATE MEMORIES:
         duration = time.time() - start_time
         duration_text = f" in {duration:.2f}s" if duration >= 0.01 else ""
         retrieval_method = "LLM" if should_use_llm else "Semantic"
-        await self.memory_system._emit_status(emitter, f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}", done=True)
+        await self.memory_system._emit_status(
+            emitter,
+            f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}",
+            done=True,
+        )
         return selected_memories, analysis_info
 
 
@@ -761,7 +846,10 @@ class LLMConsolidationService:
         return candidates, threshold_info
 
     async def collect_consolidation_candidates(
-        self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None
+        self,
+        user_message: str,
+        user_id: str,
+        cached_similarities: Optional[List[Dict[str, Any]]] = None,
     ) -> List[Dict[str, Any]]:
         """Collect candidate memories for consolidation analysis using cached or computed similarities."""
         if cached_similarities:
@@ -805,7 +893,10 @@ class LLMConsolidationService:
         return candidates
 
     async def generate_consolidation_plan(
-        self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
+        self,
+        user_message: str,
+        candidate_memories: List[Dict[str, Any]],
+        emitter: Optional[Callable] = None,
     ) -> List[Dict[str, Any]]:
         """Generate consolidation plan using LLM with clear system/user prompt separation."""
         if candidate_memories:
@@ -820,7 +911,11 @@ class LLMConsolidationService:
 
         try:
             response = await asyncio.wait_for(
-                self.memory_system._query_llm(Prompts.MEMORY_CONSOLIDATION, user_prompt, response_model=Models.ConsolidationResponse),
+                self.memory_system._query_llm(
+                    Prompts.MEMORY_CONSOLIDATION,
+                    user_prompt,
+                    response_model=Models.ConsolidationResponse,
+                ),
                 timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
             )
         except Exception as e:
@@ -856,13 +951,21 @@ class LLMConsolidationService:
 
         return valid_operations
 
-    async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
+    async def execute_memory_operations(
+        self,
+        operations: List[Dict[str, Any]],
+        user_id: str,
+        emitter: Optional[Callable] = None,
+    ) -> Tuple[int, int, int, int]:
         """Execute consolidation operations with simplified tracking."""
         if not operations or not user_id:
             return 0, 0, 0, 0
 
         try:
-            user = await asyncio.wait_for(asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC)
+            user = await asyncio.wait_for(
+                asyncio.to_thread(Users.get_user_by_id, user_id),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
         except asyncio.TimeoutError:
             raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
         except Exception as e:
@@ -929,7 +1032,10 @@ class LLMConsolidationService:
                     if content_preview and content_preview != operation.id:
                         content_preview = self.memory_system._truncate_content(content_preview)
                         await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
-                elif result in [Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value]:
+                elif result in [
+                    Models.OperationResult.FAILED.value,
+                    Models.OperationResult.UNSUPPORTED.value,
+                ]:
                     failed_count += 1
                     await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
             except Exception as e:
@@ -950,7 +1056,11 @@ class LLMConsolidationService:
         return created_count, updated_count, deleted_count, failed_count
 
     async def run_consolidation_pipeline(
-        self, user_message: str, user_id: str, emitter: Optional[Callable] = None, cached_similarities: Optional[List[Dict[str, Any]]] = None
+        self,
+        user_message: str,
+        user_id: str,
+        emitter: Optional[Callable] = None,
+        cached_similarities: Optional[List[Dict[str, Any]]] = None,
     ) -> None:
         """Complete consolidation pipeline with simplified flow."""
         start_time = time.time()
@@ -974,7 +1084,11 @@ class LLMConsolidationService:
 
         total_operations = created_count + updated_count + deleted_count
         if total_operations > 0 or failed_count > 0:
-            await self.memory_system._emit_status(emitter, f"💾 Memory Consolidation Complete in {duration:.2f}s", done=False)
+            await self.memory_system._emit_status(
+                emitter,
+                f"💾 Memory Consolidation Complete in {duration:.2f}s",
+                done=False,
+            )
 
             operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count)
             memory_word = "Memory" if total_operations == 1 else "Memories"
@@ -1001,22 +1115,41 @@ class Filter:
     class Valves(BaseModel):
         """Configuration valves for the Memory System."""
 
-        model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
-        max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
-        max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
+        model: str = Field(
+            default=Constants.DEFAULT_LLM_MODEL,
+            description="Model name for LLM operations",
+        )
+        use_custom_model_for_memory: bool = Field(
+            default=False,
+            description="Use a custom model for memory operations instead of the current chat model",
+        )
+        custom_memory_model: str = Field(
+            default=Constants.DEFAULT_LLM_MODEL,
+            description="Custom model to use for memory operations when enabled",
+        )
+        max_memories_returned: int = Field(
+            default=Constants.MAX_MEMORIES_PER_RETRIEVAL,
+            description="Maximum number of memories to return in context",
+        )
+        max_message_chars: int = Field(
+            default=Constants.MAX_MESSAGE_CHARS,
+            description="Maximum user message length before skipping memory operations",
+        )
         semantic_retrieval_threshold: float = Field(
-            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
+            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD,
+            description="Minimum similarity threshold for memory retrieval",
         )
         relaxed_semantic_threshold_multiplier: float = Field(
             default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
             description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
         )
-        enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
+        enable_llm_reranking: bool = Field(
+            default=True,
+            description="Enable LLM-based memory reranking for improved contextual selection",
+        )
         llm_reranking_trigger_multiplier: float = Field(
-            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
+            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER,
+            description="Controls when LLM reranking activates (lower = more aggressive)",
         )
 
     def __init__(self):
@@ -1071,7 +1204,9 @@ class Filter:
             logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
             embedding_fn = self._embedding_function
 
-            def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+            def embedding_wrapper(
+                texts: Union[str, List[str]],
+            ) -> Union[np.ndarray, List[np.ndarray]]:
                 result = embedding_fn(texts, prefix=None, user=None)
                 if isinstance(result, list):
                     if isinstance(result[0], list):
@@ -1219,7 +1354,11 @@ class Filter:
     def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
         """Extract user message and determine if memory operations should be skipped."""
         if not body or "messages" not in body or not isinstance(body["messages"], list):
-            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
+            return (
+                None,
+                True,
+                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
+            )
 
         messages = body["messages"]
         user_message = None
@@ -1235,7 +1374,11 @@ class Filter:
                 break
 
         if not user_message or not user_message.strip():
-            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
+            return (
+                None,
+                True,
+                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
+            )
 
         should_skip, skip_reason = self._should_skip_memory_operations(user_message)
         return user_message, should_skip, skip_reason
@@ -1245,7 +1388,10 @@ class Filter:
         if timeout is None:
             timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
         try:
-            return await asyncio.wait_for(asyncio.to_thread(Memories.get_memories_by_user_id, user_id), timeout=timeout)
+            return await asyncio.wait_for(
+                asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
+                timeout=timeout,
+            )
         except asyncio.TimeoutError:
             raise TimeoutError(f"⏱️ Memory retrieval timed out after {timeout}s")
         except Exception as e:
@@ -1274,7 +1420,11 @@ class Filter:
         logger.info(f"Scores: [{scores_str}{suffix}]")
 
     def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]:
-        operations = [(created_count, "📝 Created"), (updated_count, "✏️ Updated"), (deleted_count, "🗑️ Deleted")]
+        operations = [
+            (created_count, "📝 Created"),
+            (updated_count, "✏️ Updated"),
+            (deleted_count, "🗑️ Deleted"),
+        ]
         return [f"{label} {count}" for count, label in operations if count > 0]
 
     def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
@@ -1366,10 +1516,19 @@ class Filter:
 
         self._log_retrieved_memories(final_memories, "semantic")
 
-        return {"memories": final_memories, "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info}
+        return {
+            "memories": final_memories,
+            "threshold": threshold,
+            "all_similarities": all_similarities,
+            "reranking_info": reranking_info,
+        }
 
     async def _add_memory_context(
-        self, body: Dict[str, Any], memories: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None, emitter: Optional[Callable] = None
+        self,
+        body: Dict[str, Any],
+        memories: Optional[List[Dict[str, Any]]] = None,
+        user_id: Optional[str] = None,
+        emitter: Optional[Callable] = None,
     ) -> None:
         """Add memory context to request body with simplified logic."""
         if not body or "messages" not in body or not body["messages"]:
@@ -1397,7 +1556,10 @@ class Filter:
 
         memory_context = "\n\n".join(content_parts)
 
-        system_index = next((i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), None)
+        system_index = next(
+            (i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"),
+            None,
+        )
 
         if system_index is not None:
             body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}"
@@ -1410,7 +1572,11 @@ class Filter:
 
     def _build_memory_dict(self, memory, similarity: float) -> Dict[str, Any]:
         """Build memory dictionary with standardized timestamp conversion."""
-        memory_dict = {"id": str(memory.id), "content": memory.content, "relevance": similarity}
+        memory_dict = {
+            "id": str(memory.id),
+            "content": memory.content,
+            "relevance": similarity,
+        }
         if hasattr(memory, "created_at") and memory.created_at:
             memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat()
         if hasattr(memory, "updated_at") and memory.updated_at:
@@ -1462,42 +1628,58 @@ class Filter:
         **kwargs,
     ) -> Dict[str, Any]:
         """Simplified inlet processing for memory retrieval and injection."""
-        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        model_to_use = body.get("model") if isinstance(body, dict) else None
+        if not model_to_use:
+            model_to_use = __model__ or getattr(__request__.state, "model", None)
+        if not model_to_use:
+            model_to_use = Constants.DEFAULT_LLM_MODEL
+            logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
+
+        if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
+            model_to_use = self.valves.custom_memory_model
+            logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
+
+        self.valves.model = model_to_use
+
+        await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
 
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body
 
         user_message, should_skip, skip_reason = self._process_user_message(body)
 
         if not user_message or should_skip:
             if __event_emitter__ and skip_reason:
                 await self._emit_status(__event_emitter__, skip_reason, done=True)
             await self._add_memory_context(body, [], user_id, __event_emitter__)
             return body
 
         try:
             memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id)
             user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key)
 
             if user_memories is None:
                 user_memories = await self._get_user_memories(user_id)
-                await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
+                await self._cache_manager.put(
+                    user_id,
+                    self._cache_manager.MEMORY_CACHE,
+                    memory_cache_key,
+                    user_memories,
+                )
             retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__)
             memories = retrieval_result.get("memories", [])
             threshold = retrieval_result.get("threshold")
             all_similarities = retrieval_result.get("all_similarities", [])
 
             if all_similarities:
                 cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
-                await self._cache_manager.put(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key, all_similarities)
+                await self._cache_manager.put(
+                    user_id,
+                    self._cache_manager.RETRIEVAL_CACHE,
+                    cache_key,
+                    all_similarities,
+                )
             await self._add_memory_context(body, memories, user_id, __event_emitter__)
 
         except Exception as e:
             raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}")
 
         return body
 
     async def outlet(
@@ -1510,20 +1692,30 @@ class Filter:
         **kwargs,
     ) -> dict:
         """Simplified outlet processing for background memory consolidation."""
-        await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__)
+        model_to_use = body.get("model") if isinstance(body, dict) else None
+        if not model_to_use:
+            model_to_use = __model__ or getattr(__request__.state, "model", None)
+        if not model_to_use:
+            model_to_use = Constants.DEFAULT_LLM_MODEL
+            logger.warning(f"⚠️ No model found, use default model : {model_to_use}")
+
+        if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model:
+            model_to_use = self.valves.custom_memory_model
+            logger.info(f"🧠 Using the custom model for memory : {model_to_use}")
+
+        self.valves.model = model_to_use
+
+        await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__)
 
         user_id = __user__.get("id") if body and __user__ else None
         if not user_id:
             return body
 
         user_message, should_skip, skip_reason = self._process_user_message(body)
 
         if not user_message or should_skip:
             return body
 
         cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message)
         cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key)
 
         task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities))
         self._background_tasks.add(task)
 
@@ -1537,7 +1729,6 @@ class Filter:
                 logger.error(f"❌ Failed to cleanup background memory task: {str(e)}")
 
         task.add_done_callback(safe_cleanup)
 
         return body
 
     async def shutdown(self) -> None:
@@ -1566,7 +1757,12 @@ class Filter:
             logger.info("📭 No memories found for user")
             return
 
-        await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories)
+        await self._cache_manager.put(
+            user_id,
+            self._cache_manager.MEMORY_CACHE,
+            memory_cache_key,
+            user_memories,
+        )
 
         memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS]
 
@@ -1588,7 +1784,8 @@ class Filter:
                 return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
 
             await asyncio.wait_for(
-                asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
+                asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.CREATE.value
 
@@ -1604,7 +1801,12 @@ class Filter:
                 return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value
 
             await asyncio.wait_for(
-                asyncio.to_thread(Memories.update_memory_by_id_and_user_id, id_stripped, user.id, content_stripped),
+                asyncio.to_thread(
+                    Memories.update_memory_by_id_and_user_id,
+                    id_stripped,
+                    user.id,
+                    content_stripped,
+                ),
                 timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.UPDATE.value
@@ -1616,7 +1818,8 @@ class Filter:
                 return Models.OperationResult.SKIPPED_EMPTY_ID.value
 
             await asyncio.wait_for(
-                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC
+                asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
             )
             return Models.MemoryOperationType.DELETE.value
         else:
@@ -1647,7 +1850,7 @@ class Filter:
             elif isinstance(value, dict):
                 result[key] = self._remove_refs_from_schema(value, schema_defs)
             elif isinstance(value, list):
-                result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item for item in value]
+                result[key] = [(self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item) for item in value]
             else:
                 result[key] = value
 
@@ -1656,7 +1859,12 @@ class Filter:
 
         return result
 
-    async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
+    async def _query_llm(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        response_model: Optional[BaseModel] = None,
+    ) -> Union[str, BaseModel]:
         """Query OpenWebUI's internal model system with Pydantic model parsing."""
         if not hasattr(self, "__request__") or not hasattr(self, "__user__"):
             raise RuntimeError("🔧 Pipeline interface not properly initialized. __request__ and __user__ required.")
@@ -1667,7 +1875,10 @@ class Filter:
 
         form_data = {
             "model": model_to_use,
-            "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
             "max_tokens": 4096,
             "stream": False,
         }
@@ -1677,11 +1888,22 @@ class Filter:
             schema_defs = raw_schema.get("$defs", {})
             schema = self._remove_refs_from_schema(raw_schema, schema_defs)
             schema["type"] = "object"
-            form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
+            form_data["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_model.__name__,
+                    "strict": True,
+                    "schema": schema,
+                },
+            }
 
         try:
             response = await asyncio.wait_for(
-                generate_chat_completion(self.__request__, form_data, user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"])),
+                generate_chat_completion(
+                    self.__request__,
+                    form_data,
+                    user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"]),
+                ),
                 timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
            )
        except asyncio.TimeoutError:
Additional file (name not captured in this view):

@@ -4,4 +4,4 @@ pydantic
 numpy
 tiktoken
 black
 isort