diff --git a/.gitignore b/.gitignore
index f2d0f25..3ec1fec 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,4 +1,4 @@
 __pycache__/
+.github/instructions/*
 .venv/
-**AGENTS.md
 tests/
\ No newline at end of file
diff --git a/memory_system.py b/memory_system.py
index 32adfe9..9c47e17 100644
--- a/memory_system.py
+++ b/memory_system.py
@@ -1,6 +1,9 @@
 """
 title: Memory System
+description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
 version: 1.0.0
+authors: https://github.com/mtayfur
+license: Apache-2.0
 """
 
 import asyncio
@@ -217,7 +220,10 @@ class Models:
 
            if self.operation == Models.MemoryOperationType.CREATE:
                return True
-            elif self.operation in [Models.MemoryOperationType.UPDATE, Models.MemoryOperationType.DELETE]:
+            elif self.operation in [
+                Models.MemoryOperationType.UPDATE,
+                Models.MemoryOperationType.DELETE,
+            ]:
                return self.id in existing_memory_ids
 
            return False
@@ -229,7 +235,10 @@ class Models:
    class MemoryRerankingResponse(BaseModel):
        """Pydantic model for memory reranking LLM response - object containing array of memory IDs."""
 
-        ids: List[str] = Field(default_factory=list, description="List of memory IDs selected as most relevant for the user query")
+        ids: List[str] = Field(
+            default_factory=list,
+            description="List of memory IDs selected as most relevant for the user query",
+        )
 
 
 class UnifiedCacheManager:
@@ -440,7 +449,10 @@ class SkipDetector:
        SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
    }
 
-    def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
+    def __init__(
+        self,
+        embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]],
+    ):
        """Initialize the skip detector with an embedding function and compute reference embeddings."""
        self.embedding_function = embedding_function
        self._reference_embeddings = None
@@ -583,7 +595,16 @@
            if markup_in_lines / len(non_empty_lines) > 0.3:
                return self.SkipReason.SKIP_TECHNICAL.value
            elif structured_lines / len(non_empty_lines) > 0.6:
-                technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
+                technical_keywords = [
+                    "function",
+                    "class",
+                    "import",
+                    "return",
+                    "const",
+                    "var",
+                    "let",
+                    "def",
+                ]
                if any(keyword in message.lower() for keyword in technical_keywords):
                    return self.SkipReason.SKIP_TECHNICAL.value
 
@@ -594,7 +615,18 @@
        if non_empty_lines:
            indented_lines = sum(1 for line in non_empty_lines if line[0] in (" ", "\t"))
            if indented_lines / len(non_empty_lines) > 0.5:
-                code_indicators = ["def ", "class ", "function ", "return ", "import ", "const ", "let ", "var ", "public ", "private "]
+                code_indicators = [
+                    "def ",
+                    "class ",
+                    "function ",
+                    "return ",
+                    "import ",
+                    "const ",
+                    "let ",
+                    "var ",
+                    "public ",
+                    "private ",
+                ]
                if any(indicator in message.lower() for indicator in code_indicators):
                    return self.SkipReason.SKIP_TECHNICAL.value
 
@@ -637,11 +669,31 @@
 
        max_conversational_similarity = float(conversational_similarities.max())
        skip_categories = [
-            ("instruction", self.SkipReason.SKIP_INSTRUCTION, self.INSTRUCTION_CATEGORY_DESCRIPTIONS),
-            ("translation", self.SkipReason.SKIP_TRANSLATION, self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS),
-            ("grammar", self.SkipReason.SKIP_GRAMMAR_PROOFREAD, self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS),
-            ("technical", self.SkipReason.SKIP_TECHNICAL, self.TECHNICAL_CATEGORY_DESCRIPTIONS),
-            ("pure_math", self.SkipReason.SKIP_PURE_MATH, self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS),
+            (
+                "instruction",
+                self.SkipReason.SKIP_INSTRUCTION,
+                self.INSTRUCTION_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "translation",
+                self.SkipReason.SKIP_TRANSLATION,
+                self.EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "grammar",
+                self.SkipReason.SKIP_GRAMMAR_PROOFREAD,
+                self.GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "technical",
+                self.SkipReason.SKIP_TECHNICAL,
+                self.TECHNICAL_CATEGORY_DESCRIPTIONS,
+            ),
+            (
+                "pure_math",
+                self.SkipReason.SKIP_PURE_MATH,
+                self.PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS,
+            ),
        ]
 
        qualifying_categories = []
@@ -680,11 +732,23 @@ class LLMRerankingService:
        llm_trigger_threshold = int(self.memory_system.valves.max_memories_returned * self.memory_system.valves.llm_reranking_trigger_multiplier)
 
        if len(memories) > llm_trigger_threshold:
-            return True, f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold"
+            return (
+                True,
+                f"{len(memories)} candidate memories exceed {llm_trigger_threshold} threshold",
+            )
 
-        return False, f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}"
+        return (
+            False,
+            f"{len(memories)} candidate memories within threshold of {llm_trigger_threshold}",
+        )
 
-    async def _llm_select_memories(self, user_message: str, candidate_memories: List[Dict], max_count: int, emitter: Optional[Callable] = None) -> List[Dict]:
+    async def _llm_select_memories(
+        self,
+        user_message: str,
+        candidate_memories: List[Dict],
+        max_count: int,
+        emitter: Optional[Callable] = None,
+    ) -> List[Dict]:
        """Use LLM to select most relevant memories."""
        memory_lines = self.memory_system._format_memories_for_llm(candidate_memories)
        memory_context = "\n".join(memory_lines)
@@ -697,7 +761,11 @@ CANDIDATE MEMORIES:
 {memory_context}"""
 
        try:
-            response = await self.memory_system._query_llm(Prompts.MEMORY_RERANKING, user_prompt, response_model=Models.MemoryRerankingResponse)
+            response = await self.memory_system._query_llm(
+                Prompts.MEMORY_RERANKING,
+                user_prompt,
+                response_model=Models.MemoryRerankingResponse,
+            )
 
            selected_memories = []
            for memory in candidate_memories:
@@ -712,18 +780,31 @@
            logger.warning(f"🤖 LLM reranking failed during memory relevance analysis: {str(e)}")
            return candidate_memories
 
-    async def rerank_memories(self, user_message: str, candidate_memories: List[Dict], emitter: Optional[Callable] = None) -> Tuple[List[Dict], Dict[str, Any]]:
+    async def rerank_memories(
+        self,
+        user_message: str,
+        candidate_memories: List[Dict],
+        emitter: Optional[Callable] = None,
+    ) -> Tuple[List[Dict], Dict[str, Any]]:
        start_time = time.time()
        max_injection = self.memory_system.valves.max_memories_returned
 
        should_use_llm, decision_reason = self._should_use_llm_reranking(candidate_memories)
-        analysis_info = {"llm_decision": should_use_llm, "decision_reason": decision_reason, "candidate_count": len(candidate_memories)}
+        analysis_info = {
+            "llm_decision": should_use_llm,
+            "decision_reason": decision_reason,
+            "candidate_count": len(candidate_memories),
+        }
 
        if should_use_llm:
            extended_count = int(self.memory_system.valves.max_memories_returned * Constants.EXTENDED_MAX_MEMORY_MULTIPLIER)
            llm_candidates = candidate_memories[:extended_count]
 
-            await self.memory_system._emit_status(emitter, f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance", done=False)
+            await self.memory_system._emit_status(
+                emitter,
+                f"🤖 LLM Analyzing {len(llm_candidates)} Memories for Relevance",
+                done=False,
+            )
            logger.info(f"Using LLM reranking: {decision_reason}")
 
            selected_memories = await self._llm_select_memories(user_message, llm_candidates, max_injection, emitter)
@@ -739,7 +820,11 @@ CANDIDATE MEMORIES:
        duration = time.time() - start_time
        duration_text = f" in {duration:.2f}s" if duration >= 0.01 else ""
        retrieval_method = "LLM" if should_use_llm else "Semantic"
-        await self.memory_system._emit_status(emitter, f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}", done=True)
+        await self.memory_system._emit_status(
+            emitter,
+            f"🎯 {retrieval_method} Memory Retrieval Complete{duration_text}",
+            done=True,
+        )
 
        return selected_memories, analysis_info
 
@@ -761,7 +846,10 @@ class LLMConsolidationService:
        return candidates, threshold_info
 
    async def collect_consolidation_candidates(
-        self, user_message: str, user_id: str, cached_similarities: Optional[List[Dict[str, Any]]] = None
+        self,
+        user_message: str,
+        user_id: str,
+        cached_similarities: Optional[List[Dict[str, Any]]] = None,
    ) -> List[Dict[str, Any]]:
        """Collect candidate memories for consolidation analysis using cached or computed similarities."""
        if cached_similarities:
@@ -805,7 +893,10 @@
        return candidates
 
    async def generate_consolidation_plan(
-        self, user_message: str, candidate_memories: List[Dict[str, Any]], emitter: Optional[Callable] = None
+        self,
+        user_message: str,
+        candidate_memories: List[Dict[str, Any]],
+        emitter: Optional[Callable] = None,
    ) -> List[Dict[str, Any]]:
        """Generate consolidation plan using LLM with clear system/user prompt separation."""
        if candidate_memories:
@@ -820,7 +911,11 @@
 
        try:
            response = await asyncio.wait_for(
-                self.memory_system._query_llm(Prompts.MEMORY_CONSOLIDATION, user_prompt, response_model=Models.ConsolidationResponse),
+                self.memory_system._query_llm(
+                    Prompts.MEMORY_CONSOLIDATION,
+                    user_prompt,
+                    response_model=Models.ConsolidationResponse,
+                ),
                timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
            )
        except Exception as e:
@@ -856,13 +951,21 @@
 
        return valid_operations
 
-    async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
+    async def execute_memory_operations(
+        self,
+        operations: List[Dict[str, Any]],
+        user_id: str,
+        emitter: Optional[Callable] = None,
+    ) -> Tuple[int, int, int, int]:
        """Execute consolidation operations with simplified tracking."""
        if not operations or not user_id:
            return 0, 0, 0, 0
 
        try:
-            user = await asyncio.wait_for(asyncio.to_thread(Users.get_user_by_id, user_id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC)
+            user = await asyncio.wait_for(
+                asyncio.to_thread(Users.get_user_by_id, user_id),
+                timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC,
+            )
        except asyncio.TimeoutError:
            raise TimeoutError(f"⏱️ User lookup timed out after {Constants.DATABASE_OPERATION_TIMEOUT_SEC}s")
        except Exception as e:
@@ -929,7 +1032,10 @@
                    if content_preview and content_preview != operation.id:
                        content_preview = self.memory_system._truncate_content(content_preview)
                        await self.memory_system._emit_status(emitter, f"🗑️ Deleted: {content_preview}", done=False)
-                elif result in [Models.OperationResult.FAILED.value, Models.OperationResult.UNSUPPORTED.value]:
+                elif result in [
+                    Models.OperationResult.FAILED.value,
+                    Models.OperationResult.UNSUPPORTED.value,
+                ]:
                    failed_count += 1
                    await self.memory_system._emit_status(emitter, f"❌ Failed {operation_type}", done=False)
            except Exception as e:
@@ -950,7 +1056,11 @@
        return created_count, updated_count, deleted_count, failed_count
 
    async def run_consolidation_pipeline(
-        self, user_message: str, user_id: str, emitter: Optional[Callable] = None, cached_similarities: Optional[List[Dict[str, Any]]] = None
+        self,
+        user_message: str,
+        user_id: str,
+        emitter: Optional[Callable] = None,
+        cached_similarities: Optional[List[Dict[str, Any]]] = None,
    ) -> None:
        """Complete consolidation pipeline with simplified flow."""
        start_time = time.time()
@@ -974,7 +1084,11 @@
 
        total_operations = created_count + updated_count + deleted_count
        if total_operations > 0 or failed_count > 0:
-            await self.memory_system._emit_status(emitter, f"💾 Memory Consolidation Complete in {duration:.2f}s", done=False)
+            await self.memory_system._emit_status(
+                emitter,
+                f"💾 Memory Consolidation Complete in {duration:.2f}s",
+                done=False,
+            )
 
            operation_details = self.memory_system._build_operation_details(created_count, updated_count, deleted_count)
            memory_word = "Memory" if total_operations == 1 else "Memories"
@@ -1001,22 +1115,41 @@ class Filter:
    class Valves(BaseModel):
        """Configuration valves for the Memory System."""
 
-        model: str = Field(default=Constants.DEFAULT_LLM_MODEL, description="Model name for LLM operations")
-
-        max_message_chars: int = Field(default=Constants.MAX_MESSAGE_CHARS, description="Maximum user message length before skipping memory operations")
-        max_memories_returned: int = Field(default=Constants.MAX_MEMORIES_PER_RETRIEVAL, description="Maximum number of memories to return in context")
-
+        model: str = Field(
+            default=Constants.DEFAULT_LLM_MODEL,
+            description="Model name for LLM operations",
+        )
+        use_custom_model_for_memory: bool = Field(
+            default=False,
+            description="Use a custom model for memory operations instead of the current chat model",
+        )
+        custom_memory_model: str = Field(
+            default=Constants.DEFAULT_LLM_MODEL,
+            description="Custom model to use for memory operations when enabled",
+        )
+        max_memories_returned: int = Field(
+            default=Constants.MAX_MEMORIES_PER_RETRIEVAL,
+            description="Maximum number of memories to return in context",
+        )
+        max_message_chars: int = Field(
+            default=Constants.MAX_MESSAGE_CHARS,
+            description="Maximum user message length before skipping memory operations",
+        )
        semantic_retrieval_threshold: float = Field(
-            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD, description="Minimum similarity threshold for memory retrieval"
+            default=Constants.SEMANTIC_RETRIEVAL_THRESHOLD,
+            description="Minimum similarity threshold for memory retrieval",
        )
        relaxed_semantic_threshold_multiplier: float = Field(
            default=Constants.RELAXED_SEMANTIC_THRESHOLD_MULTIPLIER,
            description="Adjusts similarity threshold for memory consolidation (lower = more candidates)",
        )
-
-        enable_llm_reranking: bool = Field(default=True, description="Enable LLM-based memory reranking for improved contextual selection")
+        enable_llm_reranking: bool = Field(
+            default=True,
+            description="Enable LLM-based memory reranking for improved contextual selection",
+        )
        llm_reranking_trigger_multiplier: float = Field(
-            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER, description="Controls when LLM reranking activates (lower = more aggressive)"
+            default=Constants.LLM_RERANKING_TRIGGER_MULTIPLIER,
+            description="Controls when LLM reranking activates (lower = more aggressive)",
        )
 
    def __init__(self):
@@ -1071,7 +1204,9 @@
        logger.info(f"🤖 Initializing skip detector with OpenWebUI embeddings: {cache_key}")
        embedding_fn = self._embedding_function
 
-        def embedding_wrapper(texts: Union[str, List[str]]) -> Union[np.ndarray, List[np.ndarray]]:
+        def embedding_wrapper(
+            texts: Union[str, List[str]],
+        ) -> Union[np.ndarray, List[np.ndarray]]:
            result = embedding_fn(texts, prefix=None, user=None)
            if isinstance(result, list):
                if isinstance(result[0], list):
@@ -1219,7 +1354,11 @@
    def _process_user_message(self, body: Dict[str, Any]) -> Tuple[Optional[str], bool, str]:
        """Extract user message and determine if memory operations should be skipped."""
        if not body or "messages" not in body or not isinstance(body["messages"], list):
-            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
+            return (
+                None,
+                True,
+                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
+            )
 
        messages = body["messages"]
        user_message = None
@@ -1235,7 +1374,11 @@
                break
 
        if not user_message or not user_message.strip():
-            return None, True, SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE]
+            return (
+                None,
+                True,
+                SkipDetector.STATUS_MESSAGES[SkipDetector.SkipReason.SKIP_SIZE],
+            )
 
        should_skip, skip_reason = self._should_skip_memory_operations(user_message)
        return user_message, should_skip, skip_reason
@@ -1245,7 +1388,10 @@
        if timeout is None:
            timeout = Constants.DATABASE_OPERATION_TIMEOUT_SEC
        try:
-            return await asyncio.wait_for(asyncio.to_thread(Memories.get_memories_by_user_id, user_id), timeout=timeout)
+            return await asyncio.wait_for(
+                asyncio.to_thread(Memories.get_memories_by_user_id, user_id),
+                timeout=timeout,
+            )
        except asyncio.TimeoutError:
            raise TimeoutError(f"⏱️ Memory retrieval timed out after {timeout}s")
        except Exception as e:
@@ -1274,7 +1420,11 @@
        logger.info(f"Scores: [{scores_str}{suffix}]")
 
    def _build_operation_details(self, created_count: int, updated_count: int, deleted_count: int) -> List[str]:
-        operations = [(created_count, "📝 Created"), (updated_count, "✏️ Updated"), (deleted_count, "🗑️ Deleted")]
+        operations = [
+            (created_count, "📝 Created"),
+            (updated_count, "✏️ Updated"),
+            (deleted_count, "🗑️ Deleted"),
+        ]
        return [f"{label} {count}" for count, label in operations if count > 0]
 
    def _cache_key(self, cache_type: str, user_id: str, content: Optional[str] = None) -> str:
@@ -1366,10 +1516,19 @@
 
        self._log_retrieved_memories(final_memories, "semantic")
 
-        return {"memories": final_memories, "threshold": threshold, "all_similarities": all_similarities, "reranking_info": reranking_info}
+        return {
+            "memories": final_memories,
+            "threshold": threshold,
+            "all_similarities": all_similarities,
+            "reranking_info": reranking_info,
+        }
 
    async def _add_memory_context(
-        self, body: Dict[str, Any], memories: Optional[List[Dict[str, Any]]] = None, user_id: Optional[str] = None, emitter: Optional[Callable] = None
+        self,
+        body: Dict[str, Any],
+        memories: Optional[List[Dict[str, Any]]] = None,
+        user_id: Optional[str] = None,
+        emitter: Optional[Callable] = None,
    ) -> None:
        """Add memory context to request body with simplified logic."""
        if not body or "messages" not in body or not body["messages"]:
@@ -1397,7 +1556,10 @@
 
        memory_context = "\n\n".join(content_parts)
 
-        system_index = next((i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), None)
"system"), None) + system_index = next( + (i for i, msg in enumerate(body["messages"]) if msg.get("role") == "system"), + None, + ) if system_index is not None: body["messages"][system_index]["content"] = f"{body['messages'][system_index].get('content', '')}\n\n{memory_context}" @@ -1410,7 +1572,11 @@ class Filter: def _build_memory_dict(self, memory, similarity: float) -> Dict[str, Any]: """Build memory dictionary with standardized timestamp conversion.""" - memory_dict = {"id": str(memory.id), "content": memory.content, "relevance": similarity} + memory_dict = { + "id": str(memory.id), + "content": memory.content, + "relevance": similarity, + } if hasattr(memory, "created_at") and memory.created_at: memory_dict["created_at"] = datetime.fromtimestamp(memory.created_at, tz=timezone.utc).isoformat() if hasattr(memory, "updated_at") and memory.updated_at: @@ -1462,42 +1628,58 @@ class Filter: **kwargs, ) -> Dict[str, Any]: """Simplified inlet processing for memory retrieval and injection.""" - await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) + + model_to_use = body.get("model") if isinstance(body, dict) else None + if not model_to_use: + model_to_use = __model__ or getattr(__request__.state, "model", None) + if not model_to_use: + model_to_use = Constants.DEFAULT_LLM_MODEL + logger.warning(f"⚠️ No model found, use default model : {model_to_use}") + + if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model: + model_to_use = self.valves.custom_memory_model + logger.info(f"🧠 Using the custom model for memory : {model_to_use}") + + self.valves.model = model_to_use + + await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__) user_id = __user__.get("id") if body and __user__ else None if not user_id: return body user_message, should_skip, skip_reason = self._process_user_message(body) - if not user_message or should_skip: if __event_emitter__ and skip_reason: await self._emit_status(__event_emitter__, skip_reason, done=True) await self._add_memory_context(body, [], user_id, __event_emitter__) return body - try: memory_cache_key = self._cache_key(self._cache_manager.MEMORY_CACHE, user_id) user_memories = await self._cache_manager.get(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key) - if user_memories is None: user_memories = await self._get_user_memories(user_id) - await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories) - + await self._cache_manager.put( + user_id, + self._cache_manager.MEMORY_CACHE, + memory_cache_key, + user_memories, + ) retrieval_result = await self._retrieve_relevant_memories(user_message, user_id, user_memories, __event_emitter__) memories = retrieval_result.get("memories", []) threshold = retrieval_result.get("threshold") all_similarities = retrieval_result.get("all_similarities", []) - if all_similarities: cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) - await self._cache_manager.put(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key, all_similarities) - + await self._cache_manager.put( + user_id, + self._cache_manager.RETRIEVAL_CACHE, + cache_key, + all_similarities, + ) await self._add_memory_context(body, memories, user_id, __event_emitter__) - except Exception as e: raise RuntimeError(f"💾 Memory retrieval failed: {str(e)}") - return body async def outlet( @@ -1510,20 +1692,30 @@ class Filter: **kwargs, ) -> dict: """Simplified outlet processing for background memory 
consolidation.""" - await self._set_pipeline_context(__event_emitter__, __user__, __model__, __request__) + + model_to_use = body.get("model") if isinstance(body, dict) else None + if not model_to_use: + model_to_use = __model__ or getattr(__request__.state, "model", None) + if not model_to_use: + model_to_use = Constants.DEFAULT_LLM_MODEL + logger.warning(f"⚠️ No model found, use default model : {model_to_use}") + + if self.valves.use_custom_model_for_memory and self.valves.custom_memory_model: + model_to_use = self.valves.custom_memory_model + logger.info(f"🧠 Using the custom model for memory : {model_to_use}") + + self.valves.model = model_to_use + + await self._set_pipeline_context(__event_emitter__, __user__, model_to_use, __request__) user_id = __user__.get("id") if body and __user__ else None if not user_id: return body - user_message, should_skip, skip_reason = self._process_user_message(body) - if not user_message or should_skip: return body - cache_key = self._cache_key(self._cache_manager.RETRIEVAL_CACHE, user_id, user_message) cached_similarities = await self._cache_manager.get(user_id, self._cache_manager.RETRIEVAL_CACHE, cache_key) - task = asyncio.create_task(self._llm_consolidation_service.run_consolidation_pipeline(user_message, user_id, __event_emitter__, cached_similarities)) self._background_tasks.add(task) @@ -1537,7 +1729,6 @@ class Filter: logger.error(f"❌ Failed to cleanup background memory task: {str(e)}") task.add_done_callback(safe_cleanup) - return body async def shutdown(self) -> None: @@ -1566,7 +1757,12 @@ class Filter: logger.info("📭 No memories found for user") return - await self._cache_manager.put(user_id, self._cache_manager.MEMORY_CACHE, memory_cache_key, user_memories) + await self._cache_manager.put( + user_id, + self._cache_manager.MEMORY_CACHE, + memory_cache_key, + user_memories, + ) memory_contents = [memory.content for memory in user_memories if memory.content and len(memory.content.strip()) >= Constants.MIN_MESSAGE_CHARS] @@ -1588,7 +1784,8 @@ class Filter: return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value await asyncio.wait_for( - asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC + asyncio.to_thread(Memories.insert_new_memory, user.id, content_stripped), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.CREATE.value @@ -1604,7 +1801,12 @@ class Filter: return Models.OperationResult.SKIPPED_EMPTY_CONTENT.value await asyncio.wait_for( - asyncio.to_thread(Memories.update_memory_by_id_and_user_id, id_stripped, user.id, content_stripped), + asyncio.to_thread( + Memories.update_memory_by_id_and_user_id, + id_stripped, + user.id, + content_stripped, + ), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.UPDATE.value @@ -1616,7 +1818,8 @@ class Filter: return Models.OperationResult.SKIPPED_EMPTY_ID.value await asyncio.wait_for( - asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC + asyncio.to_thread(Memories.delete_memory_by_id_and_user_id, id_stripped, user.id), + timeout=Constants.DATABASE_OPERATION_TIMEOUT_SEC, ) return Models.MemoryOperationType.DELETE.value else: @@ -1647,7 +1850,7 @@ class Filter: elif isinstance(value, dict): result[key] = self._remove_refs_from_schema(value, schema_defs) elif isinstance(value, list): - result[key] = [self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else 
+                result[key] = [(self._remove_refs_from_schema(item, schema_defs) if isinstance(item, dict) else item) for item in value]
            else:
                result[key] = value
 
@@ -1656,7 +1859,12 @@
 
        return result
 
-    async def _query_llm(self, system_prompt: str, user_prompt: str, response_model: Optional[BaseModel] = None) -> Union[str, BaseModel]:
+    async def _query_llm(
+        self,
+        system_prompt: str,
+        user_prompt: str,
+        response_model: Optional[BaseModel] = None,
+    ) -> Union[str, BaseModel]:
        """Query OpenWebUI's internal model system with Pydantic model parsing."""
        if not hasattr(self, "__request__") or not hasattr(self, "__user__"):
            raise RuntimeError("🔧 Pipeline interface not properly initialized. __request__ and __user__ required.")
@@ -1667,7 +1875,10 @@
 
        form_data = {
            "model": model_to_use,
-            "messages": [{"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt}],
+            "messages": [
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": user_prompt},
+            ],
            "max_tokens": 4096,
            "stream": False,
        }
@@ -1677,11 +1888,22 @@
            schema_defs = raw_schema.get("$defs", {})
            schema = self._remove_refs_from_schema(raw_schema, schema_defs)
            schema["type"] = "object"
-            form_data["response_format"] = {"type": "json_schema", "json_schema": {"name": response_model.__name__, "strict": True, "schema": schema}}
+            form_data["response_format"] = {
+                "type": "json_schema",
+                "json_schema": {
+                    "name": response_model.__name__,
+                    "strict": True,
+                    "schema": schema,
+                },
+            }
 
        try:
            response = await asyncio.wait_for(
-                generate_chat_completion(self.__request__, form_data, user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"])),
+                generate_chat_completion(
+                    self.__request__,
+                    form_data,
+                    user=await asyncio.to_thread(Users.get_user_by_id, self.__user__["id"]),
+                ),
                timeout=Constants.LLM_CONSOLIDATION_TIMEOUT_SEC,
            )
        except asyncio.TimeoutError:
diff --git a/requirements.txt b/requirements.txt
index 910811c..6b04c53 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -4,4 +4,4 @@ pydantic
 numpy
 tiktoken
 black
-isort
\ No newline at end of file
+isort