♻️ (memory_system): refactor skip detection and add semantic deduplication

- Unify skip detection into a binary classifier (personal vs.
  non-personal) for improved maintainability and clarity. The separate
  technical/instruction/math/translation/grammar category lists are
  consolidated into NON_PERSONAL and PERSONAL (see the classification
  sketch below).
- Adjust the skip detection margin from 0.5 to 0.20 for more precise
  classification.
- Add semantic deduplication for memory operations using embedding
  similarity, preventing duplicate memory creation and updates (see the
  similarity sketch below).
- Normalize and validate embedding dimensions for robustness
  (illustrated below).
- Add per-user async locks to prevent race conditions during memory
  consolidation (see the locking sketch below).
- Remove version pinning from requirements.txt for easier dependency
  management.
- Improve logging and error handling for embedding and deduplication
  operations.

These changes improve the reliability and accuracy of memory
classification and deduplication, reduce false positives in skip
detection, and prevent duplicate or conflicting memory operations in
concurrent environments. Dependency management is simplified for
broader compatibility.
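
The margin-based decision itself sits outside the hunks shown below; a
minimal sketch of the binary classification, assuming the message and
category-description embeddings are already unit-normalized (helper
names here are illustrative, not the actual SkipDetector API):

    import numpy as np

    SKIP_CATEGORY_MARGIN = 0.20  # matches the new constant in the diff

    def is_non_personal(message_emb, personal_embs, non_personal_embs):
        # With unit-normalized vectors, a dot product is cosine similarity.
        personal_sim = max(float(np.dot(message_emb, e)) for e in personal_embs)
        non_personal_sim = max(float(np.dot(message_emb, e)) for e in non_personal_embs)
        # Skip only when the best non-personal match beats the best
        # personal match by at least the margin.
        return non_personal_sim - personal_sim >= SKIP_CATEGORY_MARGIN

The real detector runs fast-path structural checks first and only falls
back to this semantic stage, per the two-stage docstring in the diff.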
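
The core of the new _check_semantic_duplicate is a cosine comparison of
unit-normalized embeddings against a similarity threshold. A standalone
illustration; the threshold value is assumed, since
DEDUPLICATION_SIMILARITY_THRESHOLD is defined outside the shown hunks:

    import numpy as np

    DEDUPLICATION_SIMILARITY_THRESHOLD = 0.90  # assumed value, for illustration only

    def find_duplicate(candidate_emb, memories):
        # memories: iterable of (memory_id, unit-normalized embedding)
        for memory_id, emb in memories:
            similarity = float(np.dot(candidate_emb, emb))  # cosine similarity
            if similarity >= DEDUPLICATION_SIMILARITY_THRESHOLD:
                return memory_id
        return None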
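
What the hardened _normalize_embedding now guards against, in
miniature: backends that return a 2D (1, n) array or a vector of the
wrong size (the 384 dimension here is illustrative):

    import numpy as np

    emb = np.random.rand(1, 384)     # some backends return a (1, n) array
    emb = np.squeeze(emb)            # -> shape (384,)
    assert emb.ndim == 1             # the real method raises ValueError here
    emb = emb / np.linalg.norm(emb)  # unit norm, so dot == cosine similarity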
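
The per-user lock added for consolidation is not visible in the hunks
shown; a minimal sketch of the usual pattern, assuming one asyncio.Lock
per user id:

    import asyncio
    from collections import defaultdict

    class ConsolidationLocks:
        def __init__(self):
            self._locks = defaultdict(asyncio.Lock)  # one lock per user_id

        async def run(self, user_id, consolidate):
            # Serialize consolidation per user so concurrent requests for
            # the same user cannot interleave memory reads and writes.
            async with self._locks[user_id]:
                return await consolidate(user_id)

Different users still consolidate concurrently; only same-user
operations queue behind one another.
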
mtayfur
2025-10-26 23:31:37 +03:00
parent bb1bd01222
commit 3f9b4c6d48
3 changed files with 143 additions and 34 deletions

.gitignore

@@ -1,4 +1,4 @@
__pycache__/
.github/instructions/*
.venv/
**AGENTS.md
tests/


@@ -1,6 +1,9 @@
"""
title: Memory System
description: A semantic memory management system for Open WebUI that consolidates, deduplicates, and retrieves personalized user memories using LLM operations.
version: 1.0.0
authors: https://github.com/mtayfur
license: Apache-2.0
"""
import asyncio
@@ -51,7 +54,7 @@ class Constants:
LLM_RERANKING_TRIGGER_MULTIPLIER = 0.8 # Multiplier for LLM reranking trigger threshold
# Skip Detection
SKIP_CATEGORY_MARGIN = 0.5 # Margin above conversational similarity for skip category classification
SKIP_CATEGORY_MARGIN = 0.20 # Margin above personal similarity for skip category classification
# Safety & Operations
MAX_DELETE_OPERATIONS_RATIO = 0.6 # Maximum delete operations ratio for safety
@@ -318,9 +321,9 @@ class UnifiedCacheManager:
class SkipDetector:
"""Semantic-based content classifier using zero-shot classification with category descriptions."""
"""Binary content classifier: personal vs non-personal using semantic analysis."""
TECHNICAL_CATEGORY_DESCRIPTIONS = [
NON_PERSONAL_CATEGORY_DESCRIPTIONS = [
"programming language syntax, data types like string or integer, algorithm logic, function, method, programming class, object-oriented paradigm, variable scope, control flow, import, module, package, library, framework, recursion, iteration",
"software design patterns, creational: singleton, factory, builder; structural: adapter, decorator, facade, proxy; behavioral: observer, strategy, command, mediator, chain of responsibility; abstract interface, polymorphism, composition",
"error handling, exception, stack trace, TypeError, NullPointerException, IndexError, segmentation fault, core dump, stack overflow, runtime vs compile-time error, assertion failed, syntax error, null pointer dereference, memory leak, bug",
@@ -341,9 +344,6 @@ class SkipDetector:
"regex pattern, regular expression matching, groups, capturing, backslash escapes, metacharacters, wildcards, quantifiers, character classes, lookaheads, lookbehinds, alternation, anchors, word boundary, multiline flag, global search",
"software testing, unit test, assertion, mock, stub, fixture, test suite, test case, verification, automated QA, validation framework, JUnit, pytest, Jest. Integration, end-to-end (E2E), functional, regression, acceptance testing",
"cloud computing platforms, infrastructure as a service (IaaS), PaaS, AWS, Azure, GCP, compute instance, region, availability zone, elasticity, distributed system, virtual machine, container, serverless, Lambda, edge computing, CDN",
]
INSTRUCTION_CATEGORY_DESCRIPTIONS = [
"format the output as structured data. Return the answer as JSON with specific keys and values, or as YAML. Organize information into a CSV file or a database-style table with columns and rows. Present as a list of objects or an array.",
"style the text presentation. Use markdown formatting like bullet points, a numbered list, or a task list. Organize content into a grid or tabular layout with proper alignment. Create a hierarchical structure with nested elements for clarity.",
"adjust the response length. Make the answer shorter, more concise, brief, or condensed. Summarize the key points. Trim down the text to reduce the overall word count or meet a specific character limit. Be less verbose and more direct.",
@@ -354,9 +354,6 @@ class SkipDetector:
"continue the generated response. Keep going with the explanation or list. Provide more information and finish your thought. Complete the rest of the content or story. Proceed with the next steps. Do not stop until you have concluded.",
"act as a specific persona or role. Respond as if you were a pirate, a scientist, or a travel guide. Adopt the character's voice, style, and knowledge base in your answer. Maintain the persona throughout the entire response.",
"compare and contrast two or more topics. Explain the similarities and differences between A and B. Provide a detailed analysis of what they have in common and how they diverge. Create a table to highlight the key distinctions.",
]
PURE_MATH_CALCULATION_CATEGORY_DESCRIPTIONS = [
"perform a pure arithmetic calculation with explicit numbers. Solve, multiply, add, subtract, and divide. Compute a numeric expression following the order of operations (PEMDAS/BODMAS). What is 23 plus 456 minus 78 times 9 divided by 3?",
"evaluate a mathematical expression containing numbers and operators, such as 2 plus 3 times 4 divided by 5. Solve this numerical problem and compute the final result. Simplify the arithmetic and show the final answer. Calculate 123 * 456.",
"convert units between measurement systems with numeric values. Convert 100 kilometers to miles, 72 fahrenheit to celsius, or 5 feet 9 inches to centimeters. Change between metric and imperial for distance, weight, volume, or temperature.",
@@ -367,9 +364,6 @@ class SkipDetector:
"compute descriptive statistics for a dataset of numbers like 12, 15, 18, 20, 22. Calculate the mean, median, mode, average, and standard deviation. Find the variance, range, quartiles, and percentiles for a given sample distribution.",
"calculate health and fitness metrics using a numeric formula. Compute the Body Mass Index (BMI) given a weight in pounds or kilograms and height in feet, inches, or meters. Find my basal metabolic rate (BMR) or target heart rate.",
"calculate the time difference between two dates. How many days, hours, or minutes are between two points in time? Find the duration or elapsed time. Act as an age calculator for a birthday or find the time until a future anniversary.",
]
EXPLICIT_TRANSLATION_CATEGORY_DESCRIPTIONS = [
"translate the explicitly quoted text 'Hello, how are you?' to a foreign language like Spanish, French, or German. This is a translation instruction that includes the word 'translate' and the source text in quotes for direct conversion.",
"how do you say a specific word or phrase in another language? For example, how do you say 'thank you', 'computer', or 'goodbye' in Japanese, Chinese, or Korean? This is a request for a direct translation of a common expression or term.",
"convert a block of text or a paragraph from a source language to a target language. Translate the following content to Italian, Arabic, Portuguese, or Russian. This is a language conversion request for a larger piece of text provided.",
@@ -380,9 +374,6 @@ class SkipDetector:
"how do I say 'I am learning to code' in German? Convert this specific English phrase into its equivalent in another language. This is a request for a practical, conversational phrase translation for personal or professional use.",
"translate this informal or slang expression to its colloquial equivalent in Spanish. How would you say 'What's up?' in Japanese in a casual context? This request focuses on capturing the correct tone and nuance of informal language.",
"provide the formal and professional translation for 'Please find the attached document for your review' in French. Translate this business email phrase to German, ensuring the terminology and register are appropriate for a corporate context.",
]
GRAMMAR_PROOFREADING_CATEGORY_DESCRIPTIONS = [
"proofread the following text for errors. Here is my draft, please check it for typos and mistakes: 'Teh quick brown fox jumpped'. Review, revise, and correct any misspellings or grammatical issues you find in the provided passage.",
"correct the grammar in this sentence: 'She don't like it'. Resolve grammatical issues like subject-verb agreement, incorrect verb tense, pronoun reference errors, or misplaced modifiers in the provided text. Address faulty sentence structure.",
"check the spelling and punctuation in this passage. Please review the following text and correct any textual errors: 'its a beautiful day, isnt it'. Amend mistakes with commas, periods, apostrophes, quotation marks, colons, or capitalization.",
@@ -395,7 +386,7 @@ class SkipDetector:
"check my essay for conciseness and remove any redundancy. Help me edit this text to be more direct and to the point. Identify and eliminate wordiness, filler words, and repetitive phrases to strengthen the overall quality of the writing.",
]
CONVERSATIONAL_CATEGORY_DESCRIPTIONS = [
PERSONAL_CATEGORY_DESCRIPTIONS = [
"discussing my family members, like my spouse, children, parents, or siblings. Mentioning relatives by name or role, such as my husband, wife, son, daughter, mother, or father. Sharing stories or asking questions about my family.",
"expressing lasting personal feelings, core values, beliefs, or principles. My worldview, deeply held opinions, philosophy, or moral standards. Things I love, hate, or feel strongly about in life, such as my passion for animal welfare.",
"describing my established personal hobbies, regular activities, or consistent interests. My passions and what I do in my leisure time, such as creative outlets like painting, sports like hiking, or other recreational pursuits I enjoy.",
@@ -425,19 +416,11 @@ class SkipDetector:
class SkipReason(Enum):
SKIP_SIZE = "SKIP_SIZE"
SKIP_TECHNICAL = "SKIP_TECHNICAL"
SKIP_INSTRUCTION = "SKIP_INSTRUCTION"
SKIP_PURE_MATH = "SKIP_PURE_MATH"
SKIP_TRANSLATION = "SKIP_TRANSLATION"
SKIP_GRAMMAR_PROOFREAD = "SKIP_GRAMMAR_PROOFREAD"
SKIP_NON_PERSONAL = "SKIP_NON_PERSONAL"
STATUS_MESSAGES = {
SkipReason.SKIP_SIZE: "📏 Message Length Out of Limits, skipping memory operations",
SkipReason.SKIP_TECHNICAL: "💻 Technical Content Detected, skipping memory operations",
SkipReason.SKIP_INSTRUCTION: "💬 Instruction Detected, skipping memory operations",
SkipReason.SKIP_PURE_MATH: "🔢 Mathematical Calculation Detected, skipping memory operations",
SkipReason.SKIP_TRANSLATION: "🌐 Translation Request Detected, skipping memory operations",
SkipReason.SKIP_GRAMMAR_PROOFREAD: "📝 Grammar/Proofreading Request Detected, skipping memory operations",
SkipReason.SKIP_NON_PERSONAL: "🚫 Non-Personal Content Detected, skipping memory operations",
}
def __init__(self, embedding_function: Callable[[Union[str, List[str]]], Union[np.ndarray, List[np.ndarray]]]):
@@ -581,7 +564,7 @@ class SkipDetector:
structured_lines = sum(1 for line in non_empty_lines if line.startswith((" ", "\t")))
if markup_in_lines / len(non_empty_lines) > 0.3:
return self.SkipReason.SKIP_TECHNICAL.value
return self.SkipReason.SKIP_NON_PERSONAL.value
elif structured_lines / len(non_empty_lines) > 0.6:
technical_keywords = ["function", "class", "import", "return", "const", "var", "let", "def"]
if any(keyword in message.lower() for keyword in technical_keywords):
@@ -613,7 +596,7 @@ class SkipDetector:
"""
Detect if a message should be skipped using two-stage detection:
1. Fast-path structural patterns (~95% confidence)
2. Semantic classification (for remaining cases)
2. Binary semantic classification (personal vs non-personal)
Returns:
Skip reason string if content should be skipped, None otherwise
"""
@@ -749,6 +732,34 @@ class LLMConsolidationService:
def __init__(self, memory_system):
self.memory_system = memory_system
async def _check_semantic_duplicate(self, content: str, existing_memories: List, user_id: str) -> Optional[str]:
"""
Check if content is semantically duplicate of existing memories using embeddings.
Returns the ID of duplicate memory if found, None otherwise.
"""
if not existing_memories:
return None
try:
content_embedding = await self.memory_system._generate_embeddings(content, user_id)
for memory in existing_memories:
if not memory.content or len(memory.content.strip()) < Constants.MIN_MESSAGE_CHARS:
continue
memory_embedding = await self.memory_system._generate_embeddings(memory.content, user_id)
similarity = float(np.dot(content_embedding, memory_embedding))
if similarity >= Constants.DEDUPLICATION_SIMILARITY_THRESHOLD:
logger.info(f"🔍 Semantic duplicate detected: similarity={similarity:.3f} with memory {memory.id}")
return str(memory.id)
return None
except Exception as e:
logger.warning(f"⚠️ Semantic duplicate check failed: {str(e)}")
return None
def _filter_consolidation_candidates(self, similarities: List[Dict[str, Any]]) -> Tuple[List[Dict[str, Any]], str]:
"""Filter consolidation candidates by threshold and return candidates with threshold info."""
consolidation_threshold = self.memory_system._get_retrieval_threshold(is_consolidation=True)
@@ -841,7 +852,31 @@ class LLMConsolidationService:
)
return []
valid_operations = [op.model_dump() for op in operations if op.validate_operation(existing_memory_ids)]
deduplicated_operations = []
seen_contents = set()
seen_update_ids = set()
for op in operations:
if not op.validate_operation(existing_memory_ids):
continue
if op.operation == Models.MemoryOperationType.UPDATE and op.id in seen_update_ids:
logger.info(f"⏭️ Skipping duplicate UPDATE for memory {op.id} in LLM response")
continue
if op.operation in [Models.MemoryOperationType.CREATE, Models.MemoryOperationType.UPDATE]:
normalized_content = op.content.strip().lower()
if normalized_content in seen_contents:
op_type = "CREATE" if op.operation == Models.MemoryOperationType.CREATE else f"UPDATE {op.id}"
logger.info(f"⏭️ Skipping duplicate {op_type} in LLM response: {self.memory_system._truncate_content(op.content)}")
continue
seen_contents.add(normalized_content)
if op.operation == Models.MemoryOperationType.UPDATE:
seen_update_ids.add(op.id)
deduplicated_operations.append(op.model_dump())
valid_operations = deduplicated_operations
if valid_operations:
create_count = sum(1 for op in valid_operations if op.get("operation") == Models.MemoryOperationType.CREATE.value)
@@ -856,6 +891,41 @@ class LLMConsolidationService:
return valid_operations
async def _deduplicate_operations(
self, operations: List, current_memories: List, user_id: str, operation_type: str, delete_operations: Optional[List] = None
) -> List:
"""
Deduplicate operations against existing memories using semantic similarity.
For UPDATE operations, preserves enriched content and deletes the duplicate.
"""
deduplicated = []
for operation in operations:
memories_to_check = current_memories
if operation_type == "UPDATE":
memories_to_check = [m for m in current_memories if str(m.id) != operation.id]
duplicate_id = await self._check_semantic_duplicate(operation.content, memories_to_check, user_id)
if duplicate_id:
if operation_type == "UPDATE" and delete_operations is not None:
logger.info(
f"🔄 UPDATE creates duplicate: keeping enriched content from memory {operation.id}, " f"deleting duplicate memory {duplicate_id}"
)
deduplicated.append(operation)
delete_operations.append(Models.MemoryOperation(operation=Models.MemoryOperationType.DELETE, content="", id=duplicate_id))
else:
logger.info(
f"⏭️ Skipping duplicate {operation_type}: "
f"{self.memory_system._truncate_content(operation.content)} "
f"(matches memory {duplicate_id})"
)
continue
deduplicated.append(operation)
return deduplicated
async def execute_memory_operations(self, operations: List[Dict[str, Any]], user_id: str, emitter: Optional[Callable] = None) -> Tuple[int, int, int, int]:
"""Execute consolidation operations with simplified tracking."""
if not operations or not user_id:
@@ -898,6 +968,23 @@ class LLMConsolidationService:
except Exception as e:
logger.warning(f"⚠️ Failed to fetch memories for DELETE preview: {str(e)}")
if operations_by_type["CREATE"] or operations_by_type["UPDATE"]:
try:
current_memories = await self.memory_system._get_user_memories(user_id)
if operations_by_type["CREATE"]:
operations_by_type["CREATE"] = await self._deduplicate_operations(
operations_by_type["CREATE"], current_memories, user_id, operation_type="CREATE"
)
if operations_by_type["UPDATE"]:
operations_by_type["UPDATE"] = await self._deduplicate_operations(
operations_by_type["UPDATE"], current_memories, user_id, operation_type="UPDATE", delete_operations=operations_by_type["DELETE"]
)
except Exception as e:
logger.warning(f"⚠️ Semantic deduplication check failed, proceeding with original operations: {str(e)}")
for operation_type, ops in operations_by_type.items():
if not ops:
continue
@@ -1031,6 +1118,7 @@ class Filter:
self._shutdown_event = asyncio.Event()
self._embedding_function = None
self._embedding_dimension = None
self._skip_detector = None
self._llm_reranking_service = LLMRerankingService(self)
@@ -1133,12 +1221,32 @@ class Filter:
"""Compute SHA256 hash for text caching."""
return hashlib.sha256(text.encode()).hexdigest()
def _detect_embedding_dimension(self) -> None:
"""Detect embedding dimension by generating a test embedding."""
try:
test_embedding = self._embedding_function(["dummy"], prefix=None, user=None)
if isinstance(test_embedding, list):
test_embedding = test_embedding[0]
self._embedding_dimension = np.squeeze(test_embedding).shape[0]
logger.info(f"🎯 Detected embedding dimension: {self._embedding_dimension}")
except Exception as e:
raise RuntimeError(f"Failed to detect embedding dimension: {str(e)}")
def _normalize_embedding(self, embedding: Union[List[float], np.ndarray]) -> np.ndarray:
"""Normalize embedding vector."""
"""Normalize embedding vector and ensure 1D shape."""
if isinstance(embedding, list):
embedding = np.array(embedding, dtype=np.float16)
else:
embedding = embedding.astype(np.float16)
embedding = np.squeeze(embedding)
if embedding.ndim != 1:
raise ValueError(f"Embedding must be 1D after squeeze, got shape {embedding.shape}")
if self._embedding_dimension and embedding.shape[0] != self._embedding_dimension:
raise ValueError(f"Embedding must have {self._embedding_dimension} dimensions, got {embedding.shape[0]}")
norm = np.linalg.norm(embedding)
return embedding / norm if norm > 0 else embedding
@@ -1185,7 +1293,7 @@ class Filter:
raw_embeddings = await loop.run_in_executor(None, self._embedding_function, uncached_texts, None, user)
if isinstance(raw_embeddings, list) and len(raw_embeddings) > 0:
if isinstance(raw_embeddings[0], list):
if isinstance(raw_embeddings[0], (list, np.ndarray)):
new_embeddings = [self._normalize_embedding(emb) for emb in raw_embeddings]
else:
new_embeddings = [self._normalize_embedding(raw_embeddings)]
@@ -1199,7 +1307,8 @@ class Filter:
result_embeddings[original_idx] = embedding
if is_single:
logger.info("📥 User message embedding: cache hit" if not uncached_texts else "💾 User message embedding: generated and cached")
if uncached_texts:
logger.info("💾 User message embedding: generated and cached")
return result_embeddings[0]
else:
valid_count = sum(1 for emb in result_embeddings if emb is not None)