diff --git a/.gitignore b/.gitignore index f2d0f25..3ec1fec 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,4 @@ __pycache__/ +.github/instructions/* .venv/ -**AGENTS.md tests/ \ No newline at end of file diff --git a/memory_system.py b/memory_system.py index 32adfe9..394a634 100644 --- a/memory_system.py +++ b/memory_system.py @@ -1,6 +1,9 @@ """ title: Memory System +description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations. version: 1.0.0 +authors: https://github.com/mtayfur +license: Apache-2.0 """ import asyncio @@ -51,7 +54,7 @@ class Constants: LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8 # Multiplier for LLM reranking trigger threshold # Skip Detection - SKIP_CATEGORY_MARGIN = 0.5 # Margin above conversational similarity for skip category classification + SKIP_CATEGORY_MARGIN = 0.20 # Margin above personal similarity for skip category classification # Safety & Operations MAX_DELETE_OPERATIONS_RATIO = 0.6 # Maximum delete operations ratio for safety @@ -318,9 +321,9 @@ class UnifiedCacheManager: class SkipDetector: - """Semantic-based content classifier using zero-shot classification with category descriptions.""" + """Binary content classifier: personal vs non-personal using semantic analysis.""" - TECHNICAL_CATEGORY_DESCRIPTIONS = [ + NON_PERSONAL_CATEGORY_DESCRIPTIONS = [ "programming language syntax, data types like string or integer, algorithm logic, function, method, programming class, object-oriented paradigm, variable scope, control flow, import, module, package, library, framework, recursion, iteration", "software design patterns, creational: singleton, factory, builder; structural: adapter, decorator, facade, proxy; behavioral: observer, strategy, command, mediator, chain of responsibility; abstract interface, polymorphism, composition", "error handling, exception, stack trace, TypeError, NullPointerException, IndexError, segmentation fault, core dump, stack overflow, runtime vs compile-time error, assertion failed, syntax error, null pointer dereference, memory leak, bug", @@ -341,9 +344,6 @@ class SkipDetector: "regex pattern, regular expression matching, groups, capturing, backslash escapes, metacharacters, wildcards, quantifiers, character classes, lookaheads, lookbehinds, alternation, anchors, word boundary, multiline flag, global search", "software testing, unit test, assertion, mock, stub, fixture, test suite, test case, verification, automated QA, validation framework, JUnit, pytest, Jest. Integration, end-to-end (E2E), functional, regression, acceptance testing", "cloud computing platforms, infrastructure as a service (IaaS), PaaS, AWS, Azure, GCP, compute instance, region, availability zone, elasticity, distributed system, virtual machine, container, serverless, Lambda, edge computing, CDN", - ] - - INSTRUCTION_CATEGORY_DESCRIPTIONS = [ "format the output as structured data. Return the answer as JSON with specific keys and values, or as YAML. Organize information into a CSV file or a database-style table with columns and rows. Present as a list of objects or an array.", "style the text presentation. Use markdown formatting like bullet points, a numbered list, or a task list. Organize content into a grid or tabular layout with proper alignment. Create a hierarchical structure with nested elements for clarity.", "adjust the response length. Make the answer shorter, more concise, brief, or condensed. Summarize the key points. Trim down the text to reduce the overall word count or meet a specific character limit. Be less verbose and more direct.", @@ -354,9 +354,6 @@ class SkipDetector: "continue the generated response. Keep going with the explanation or list. Provide more information and finish your thought. Complete the rest of the content or story. Proceed with the next steps. Do not stop until you have concluded.", "act as a specific persona or role. Respond as if you were a pirate, a scientist, or a travel guide. Adopt the character's voice, style, and knowledge base in your answer. Maintain the persona throughout the entire response.", "compare and contrast two or more topics. Explain the similarities and differences between A and B. Provide a detailed analysis of what they have in common and how they diverge. Create a table to highlight the key distinctions.", - ] - - PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS = [ "perform a pure arithmetic calculation with explicit numbers. Solve, multiply, add, subtract, and divide. Compute a numeric expression following the order of operations (PEMDAS/BODMAS). What is 23 plus 456 minus 78 times 9 divided by 3?", "evaluate a mathematical expression containing numbers and operators, such as 2 plus 3 times 4 divided by 5. Solve this numerical problem and compute the final result. Simplify the arithmetic and show the final answer. Calculate 123 * 456.", "convert units between measurement systems with numeric values. Convert 100 kilometers to miles, 72 fahrenheit to celsius, or 5 feet 9 inches to centimeters. Change between metric and imperial for distance, weight, volume, or temperature.", @@ -367,9 +364,6 @@ class SkipDetector: "compute descriptive statistics for a dataset of numbers like 12, 15, 18, 20, 22. Calculate the mean, median, mode, average, and standard deviation. Find the variance, range, quartiles, and percentiles for a given sample distribution.", "calculate health and fitness metrics using a numeric formula. Compute the Body Mass Index (BMI) given a weight in pounds or kilograms and height in feet, inches, or meters. Find my basal metabolic rate (BMR) or target heart rate.", "calculate the time difference between two dates. How many days, hours, or minutes are between two points in time? Find the duration or elapsed time. Act as an age calculator for a birthday or find the time until a future anniversary.", - ] - - EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS = [ "translate the explicitly quoted text 'Hello, how are you?' to a foreign language like Spanish, French, or German. This is a translation instruction that includes the word 'translate' and the source text in quotes for direct conversion.", "how do you say a specific word or phrase in another language? For example, how do you say 'thank you', 'computer', or 'goodbye' in Japanese, Chinese, or Korean? This is a request for a direct translation of a common expression or term.", "convert a block of text or a paragraph from a source language to a target language. Translate the following content to Italian, Arabic, Portuguese, or Russian. This is a language conversion request for a larger piece of text provided.", @@ -380,9 +374,6 @@ class SkipDetector: "how do I say 'I am learning to code' in German? Convert this specific English phrase into its equivalent in another language. This is a request for a practical, conversational phrase translation for personal or professional use.", "translate this informal or slang expression to its colloquial equivalent in Spanish. How would you say 'What's up?' in Japanese in a casual context? This request focuses on capturing the correct tone and nuance of informal language.", "provide the formal and professional translation for 'Please find the attached document for your review' in French. Translate this business email phrase to German, ensuring the terminology and register are appropriate for a corporate context.", - ] - - GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS = [ "proofread the following text for errors. Here is my draft, please check it for typos and mistakes: 'Teh quick brown fox jumpped'. Review, revise, and correct any misspellings or grammatical issues you find in the provided passage.", "correct the grammar in this sentence: 'She don't like it'. Resolve grammatical issues like subject-verb agreement, incorrect verb tense, pronoun reference errors, or misplaced modifiers in the provided text. Address faulty sentence structure.", "check the spelling and punctuation in this passage. Please review the following text and correct any textual errors: 'its a beautiful day, isnt it'. Amend mistakes with commas, periods, apostrophes, quotation marks, colons, or capitalization.", @@ -395,7 +386,7 @@ class SkipDetector: "check my essay for conciseness and remove any redundancy. Help me edit this text to be more direct and to the point. Identify and eliminate wordiness, filler words, and repetitive phrases to strengthen the overall quality of the writing.", ] - CONVERSATIONAL_CATEGORY_DESCRIPTIONS = [ + PERSONAL_CATEGORY_DESCRIPTIONS = [ "discussing my family members, like my spouse, children, parents, or siblings. Mentioning relatives by name or role, such as my husband, wife, son, daughter, mother, or father. Sharing stories or asking questions about my family.", "expressing lasting personal feelings, core values, beliefs, or principles. My worldview, deeply held opinions, philosophy, or moral standards. Things I love, hate, or feel strongly about in life, such as my passion for animal welfare.", "describing my established personal hobbies, regular activities, or consistent interests. My passions and what I do in my leisure time, such as creative outlets like painting, sports like hiking, or other recreational pursuits I enjoy.", @@ -425,19 +416,11 @@ class SkipDetector: class SkipReason(Enum): SKIP_SIZE = "SKIP_SIZE" - SKIP_TECHNICAL = "SKIP_TECHNICAL" - SKIP_INSTRUCTION = "SKIP_INSTRUCTION" - SKIP_PURE_MATH = "SKIP_PURE_MATH" - SKIP_TRANSLATION = "SKIP_TRANSLATION" - SKIP_GRAMMAR_PROOFREAD = "SKIP_GRAMMAR_PROOFREAD" + SKIP_NON_PERSONAL = "SKIP_NON_PERSONAL" STATUS_MESSAGES = { SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, skipping memory operations", - SkipReason.SKIP_TECHNICAL: "💻 Technical Content Detected, skipping memory operations", - SkipReason.SKIP_INSTRUCTION: "💬 Instruction Detected, skipping memory operations", - SkipReason.SKIP_PURE_MATH: "🔢 Mathematical Calculation Detected, skipping memory operations", - SkipReason.SKIP_TRANSLATION: "🌐 Translation Request Detected, skipping memory operations", - SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations", + SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations", } def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]): @@ -581,7 +564,7 @@ class SkipDetector: structured_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t"))) if markup_in_lines / len(non_empty_lines) > 0.3: - return self.SkipReason.SKIP_TECHNICAL.value + return self.SkipReason.SKIP_NON_PERSONAL.value elif structured_lines / len(non_empty_lines) > 0.6: technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"] if any(keyword in message.lower() for keyword in technical_keywords): @@ -613,7 +596,7 @@ class SkipDetector: """ Detect if a message should be skipped using two-stage detection: 1. Fast-path structural patterns (~95% confidence) - 2. Semantic classification (for remaining cases) + 2. Binary semantic classification (personal vs non-personal) Returns: Skip reason string if content should be skipped, None otherwise """ @@ -749,6 +732,34 @@ class LLMConsolidationService: def __init__(self, memory_system): self.memory_system = memory_system + async def _check_semantic_duplicate(self, content: str, existing_memories: List, user_id: str) -> Optional[str]: + """ + Check if content is semantically duplicate of existing memories using embeddings. + Returns the ID of duplicate memory if found, None otherwise. + """ + if not existing_memories: + return None + + try: + content_embedding = await self.memory_system._generate_embeddings(content, user_id) + + for memory in existing_memories: + if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS: + continue + + memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id) + + similarity = float(np.dot(content_embedding, memory_embedding)) + + if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD: + logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}") + return str(memory.id) + + return None + except Exception as e: + logger.warning(f"⚠️ Semantic duplicate check failed: {str(e)}") + return None + def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]: """Filter consolidation candidates by threshold and return candidates with threshold info.""" consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True) @@ -841,7 +852,31 @@ class LLMConsolidationService: ) return [] - valid_operations = [op.model_dump() for op in operations if op.validate_operation(existing_memory_ids)] + deduplicated_operations = [] + seen_contents = set() + seen_update_ids = set() + + for op in operations: + if not op.validate_operation(existing_memory_ids): + continue + + if op.operation == Models.MemoryOperationType.UPDATE and op.id in seen_update_ids: + logger.info(f"⏭️ Skipping duplicate UPDATE for memory {op.id} in LLM response") + continue + + if op.operation in [Models.MemoryOperationType.CREATE, Models.MemoryOperationType.UPDATE]: + normalized_content = op.content.strip().lower() + if normalized_content in seen_contents: + op_type = "CREATE" if op.operation == Models.MemoryOperationType.CREATE else f"UPDATE {op.id}" + logger.info(f"⏭️ Skipping duplicate {op_type} in LLM response: {self.memory_system._truncate_content(op.content)}") + continue + seen_contents.add(normalized_content) + + if op.operation == Models.MemoryOperationType.UPDATE: + seen_update_ids.add(op.id) + deduplicated_operations.append(op.model_dump()) + + valid_operations = deduplicated_operations if valid_operations: create_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.CREATE.value) @@ -856,6 +891,41 @@ class LLMConsolidationService: return valid_operations + async def _deduplicate_operations( + self, operations: List, current_memories: List, user_id: str, operation_type: str, delete_operations: Optional[List] = None + ) -> List: + """ + Deduplicate operations against existing memories using semantic similarity. + For UPDATE operations, preserves enriched content and deletes the duplicate. + """ + deduplicated = [] + + for operation in operations: + memories_to_check = current_memories + if operation_type == "UPDATE": + memories_to_check = [m for m in current_memories if str(m.id) != operation.id] + + duplicate_id = await self._check_semantic_duplicate(operation.content, memories_to_check, user_id) + + if duplicate_id: + if operation_type == "UPDATE" and delete_operations is not None: + logger.info( + f"🔄 UPDATE creates duplicate: keeping enriched content from memory {operation.id}, " f"deleting duplicate memory {duplicate_id}" + ) + deduplicated.append(operation) + delete_operations.append(Models.MemoryOperation(operation=Models.MemoryOperationType.DELETE, content="", id=duplicate_id)) + else: + logger.info( + f"⏭️ Skipping duplicate {operation_type}: " + f"{self.memory_system._truncate_content(operation.content)} " + f"(matches memory {duplicate_id})" + ) + continue + + deduplicated.append(operation) + + return deduplicated + async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]: """Execute consolidation operations with simplified tracking.""" if not operations or not user_id: @@ -898,6 +968,23 @@ class LLMConsolidationService: except Exception as e: logger.warning(f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}") + if operations_by_type["CREATE"] or operations_by_type["UPDATE"]: + try: + current_memories = await self.memory_system._get_user_memories(user_id) + + if operations_by_type["CREATE"]: + operations_by_type["CREATE"] = await self._deduplicate_operations( + operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE" + ) + + if operations_by_type["UPDATE"]: + operations_by_type["UPDATE"] = await self._deduplicate_operations( + operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"] + ) + + except Exception as e: + logger.warning(f"⚠️ Semantic deduplication check failed, proceeding with original operations: {str(e)}") + for operation_type, ops in operations_by_type.items(): if not ops: continue @@ -1031,6 +1118,7 @@ class Filter: self._shutdown_event = asyncio.Event() self._embedding_function = None + self._embedding_dimension = None self._skip_detector = None self._llm_reranking_service = LLMRerankingService(self) @@ -1133,12 +1221,32 @@ class Filter: """Compute SHA256 hash for text caching.""" return hashlib.sha256(text.encode()).hexdigest() + def _detect_embedding_dimension(self) -> None: + """Detect embedding dimension by generating a test embedding.""" + try: + test_embedding = self._embedding_function(["dummy"], prefix=None, user=None) + if isinstance(test_embedding, list): + test_embedding = test_embedding[0] + self._embedding_dimension = np.squeeze(test_embedding).shape[0] + logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}") + except Exception as e: + raise RuntimeError(f"Failed to detect embedding dimension: {str(e)}") + def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray: - """Normalize embedding vector.""" + """Normalize embedding vector and ensure 1D shape.""" if isinstance(embedding, list): embedding = np.array(embedding, dtype=np.float16) else: embedding = embedding.astype(np.float16) + + embedding = np.squeeze(embedding) + + if embedding.ndim != 1: + raise ValueError(f"Embedding must be 1D after squeeze, got shape {embedding.shape}") + + if self._embedding_dimension and embedding.shape[0] != self._embedding_dimension: + raise ValueError(f"Embedding must have {self._embedding_dimension} dimensions, got {embedding.shape[0]}") + norm = np.linalg.norm(embedding) return embedding / norm if norm > 0 else embedding @@ -1185,7 +1293,7 @@ class Filter: raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user) if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0: - if isinstance(raw_embeddings[0], list): + if isinstance(raw_embeddings[0], (list, np.ndarray)): new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings] else: new_embeddings = [self._normalize_embedding(raw_embeddings)] @@ -1199,7 +1307,8 @@ class Filter: result_embeddings[original_idx] = embedding if is_single: - logger.info("📥 User message embedding: cache hit" if not uncached_texts else "💾 User message embedding: generated and cached") + if uncached_texts: + logger.info("💾 User message embedding: generated and cached") return result_embeddings[0] else: valid_count = sum(1 for emb in result_embeddings if emb is not None) diff --git a/requirements.txt b/requirements.txt index 910811c..6b04c53 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,4 +4,4 @@ pydantic numpy tiktoken black -isort \ No newline at end of file +isort