Coverage for ai_integration/services/embedding_service.py: 99%
103 statements
« prev ^ index » next coverage.py v7.9.2, created at 2025-07-05 19:26 +0800
1"""
2Embedding generation service for family content
3Uses OpenAI text-embedding-3-small (cheaper than Anthropic)
4"""
5import hashlib
6import logging
7from typing import List, Optional, Dict, Any
8from django.utils import timezone
9from django.conf import settings
10from django.db import models
11from openai import OpenAI
12from ..models import EmbeddingCache
14logger = logging.getLogger(__name__)
class EmbeddingService:
    """Service for generating and caching content embeddings.

    Embeddings are produced with OpenAI's ``text-embedding-3-small`` model
    (1536 dimensions) and cached in ``EmbeddingCache`` keyed by a SHA-256
    hash of the source text, so identical content is never embedded twice.
    """

    def __init__(self):
        # An absent OPENAI_API_KEY fails lazily on the first API call,
        # not at import time, so the module stays importable in tests.
        self.client = OpenAI(api_key=getattr(settings, 'OPENAI_API_KEY', ''))
        self.model = "text-embedding-3-small"  # 1536 dimensions, $0.02/1M tokens

    def generate_embedding(self, text: str) -> Optional[List[float]]:
        """
        Generate an embedding for text content via the OpenAI API.

        Args:
            text: Text content to embed

        Returns:
            List of embedding vectors, or None if the text is empty/blank
            or the API call fails (best-effort: callers treat None as
            "embedding unavailable").
        """
        if not text or not text.strip():
            return None

        try:
            response = self.client.embeddings.create(
                model=self.model,
                input=text.strip()
            )
            embedding = response.data[0].embedding
            # Lazy %-args: no string formatting unless INFO is enabled.
            logger.info("Generated embedding for text (length: %d)", len(text))
            return embedding
        except Exception as e:
            logger.error("Failed to generate embedding: %s", e)
            return None

    def get_content_hash(self, text: str) -> str:
        """Generate SHA256 hash of content for caching"""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def get_or_create_embedding(self, text: str, content_type: str, content_id: int) -> Optional[List[float]]:
        """
        Get embedding from cache or generate and cache a new one.

        Args:
            text: Text content to embed
            content_type: Type of content (story, event, heritage, health)
            content_id: ID of the content object

        Returns:
            List of embedding vectors or None if failed
        """
        if not text or not text.strip():
            return None

        content_hash = self.get_content_hash(text)

        # Cache hit: any object with byte-identical content shares the entry.
        try:
            cached = EmbeddingCache.objects.get(content_hash=content_hash)
            logger.info("Using cached embedding for %s:%s", content_type, content_id)
            return cached.embedding
        except EmbeddingCache.DoesNotExist:
            pass

        embedding = self.generate_embedding(text)
        if embedding:
            # update_or_create guards against a concurrent writer inserting
            # the same hash between our lookup and this write.
            EmbeddingCache.objects.update_or_create(
                content_hash=content_hash,
                defaults={
                    'content_type': content_type,
                    'content_id': content_id,
                    'embedding': embedding,
                }
            )
            logger.info("Cached new embedding for %s:%s", content_type, content_id)

        return embedding

    def update_model_embedding(self, instance, force_update: bool = False) -> bool:
        """
        Update the stored embedding for a model instance.

        Args:
            instance: Model instance with ``content_embedding`` and
                ``embedding_updated`` fields
            force_update: Force regeneration even if embedding exists

        Returns:
            True if the embedding was (re)generated and saved,
            False if it was skipped or no usable text was found.
        """
        if not hasattr(instance, 'content_embedding'):
            logger.error("Model %s has no content_embedding field", type(instance).__name__)
            return False

        content_text = self._extract_content_text(instance)
        if not content_text:
            logger.warning("No content text found for %s:%s", type(instance).__name__, instance.id)
            return False

        # Skip the API round-trip when the cached embedding for this exact
        # content already matches what the instance carries.
        if not force_update and instance.content_embedding and instance.embedding_updated:
            content_hash = self.get_content_hash(content_text)
            try:
                cached = EmbeddingCache.objects.get(content_hash=content_hash)
                if cached.embedding == instance.content_embedding:
                    logger.info("Embedding up to date for %s:%s", type(instance).__name__, instance.id)
                    return False
            except EmbeddingCache.DoesNotExist:
                pass

        content_type = type(instance).__name__.lower()
        embedding = self.get_or_create_embedding(content_text, content_type, instance.id)

        if embedding:
            instance.content_embedding = embedding
            instance.embedding_updated = timezone.now()
            # Narrow save: don't clobber concurrent edits to other fields.
            instance.save(update_fields=['content_embedding', 'embedding_updated'])
            logger.info("Updated embedding for %s:%s", content_type, instance.id)
            return True

        return False

    def _extract_content_text(self, instance) -> str:
        """Extract text content from model instance for embedding"""
        model_name = type(instance).__name__.lower()

        if model_name == 'story':
            return f"{instance.title}\n\n{instance.content}"
        elif model_name == 'event':
            return f"{instance.name}\n\n{instance.description}"
        elif model_name == 'heritage':
            return f"{instance.title}\n\n{instance.description}"
        elif model_name == 'health':
            return f"{instance.title}\n\n{instance.description}"
        elif model_name == 'person':
            return instance.bio if instance.bio else instance.name
        else:
            # Fallback for unknown models: first non-empty common text field.
            for field in ['content', 'description', 'bio', 'title', 'name']:
                if hasattr(instance, field):
                    value = getattr(instance, field)
                    if value:
                        return str(value)

        return ""

    def bulk_update_embeddings(self, model_class, batch_size: int = 10) -> Dict[str, int]:
        """
        Bulk update embeddings for all instances of a model that are
        missing an embedding or a last-updated timestamp.

        Args:
            model_class: Django model class
            batch_size: Number of instances to process at once

        Returns:
            Dict with statistics: {'updated': int, 'skipped': int, 'failed': int}
        """
        stats = {'updated': 0, 'skipped': 0, 'failed': 0}

        # Snapshot the matching primary keys up front. Updating an instance
        # removes it from this filter, so slicing the live queryset per batch
        # (instances[i:i+batch_size]) would re-evaluate against a shrinking
        # result set and silently skip rows between batches.
        pending_ids = list(
            model_class.objects.filter(
                models.Q(content_embedding__isnull=True) |
                models.Q(embedding_updated__isnull=True)
            ).values_list('pk', flat=True)
        )

        total = len(pending_ids)
        logger.info("Bulk updating embeddings for %d %s instances", total, model_class.__name__)

        for i in range(0, total, batch_size):
            batch_ids = pending_ids[i:i + batch_size]

            for instance in model_class.objects.filter(pk__in=batch_ids):
                try:
                    if self.update_model_embedding(instance):
                        stats['updated'] += 1
                    else:
                        stats['skipped'] += 1
                except Exception as e:
                    # One bad row must not abort the whole batch run.
                    stats['failed'] += 1
                    logger.error("Failed to update embedding for %s:%s: %s",
                                 model_class.__name__, instance.id, e)

        logger.info("Bulk update complete: %s", stats)
        return stats
# Shared module-level instance. NOTE: importing this module constructs the
# OpenAI client eagerly with whatever OPENAI_API_KEY settings provides.
embedding_service = EmbeddingService()