Coverage for ai_integration/services/embedding_service.py: 99%

103 statements  

« prev     ^ index     » next       coverage.py v7.9.2, created at 2025-07-05 19:26 +0800

1""" 

2Embedding generation service for family content 

3Uses OpenAI text-embedding-3-small (cheaper than Anthropic) 

4""" 

5import hashlib 

6import logging 

7from typing import List, Optional, Dict, Any 

8from django.utils import timezone 

9from django.conf import settings 

10from django.db import models 

11from openai import OpenAI 

12from ..models import EmbeddingCache 

13 

14logger = logging.getLogger(__name__) 

15 

16 

class EmbeddingService:
    """Service for generating and managing content embeddings.

    Uses OpenAI ``text-embedding-3-small`` (1536 dimensions) and caches
    vectors in ``EmbeddingCache`` keyed by the SHA-256 of the source text,
    so identical content is embedded (and billed) only once.
    """

    # Known model name -> (headline field, body field); the two values are
    # joined as "headline\n\nbody" for embedding.
    _CONTENT_FIELDS = {
        'story': ('title', 'content'),
        'event': ('name', 'description'),
        'heritage': ('title', 'description'),
        'health': ('title', 'description'),
    }

    def __init__(self):
        # getattr default '' lets the module import without the setting
        # configured; API calls will then fail, be logged, and return None.
        self.client = OpenAI(api_key=getattr(settings, 'OPENAI_API_KEY', ''))
        self.model = "text-embedding-3-small"  # 1536 dimensions, $0.02/1M tokens

    def generate_embedding(self, text: str) -> Optional[List[float]]:
        """
        Generate embedding for text content

        Args:
            text: Text content to embed

        Returns:
            List of embedding vectors or None if failed (empty/whitespace
            input or API error)
        """
        if not text or not text.strip():
            return None

        try:
            response = self.client.embeddings.create(
                model=self.model,
                input=text.strip()
            )
            embedding = response.data[0].embedding
            # Lazy %-args keep formatting cost off the happy path.
            logger.info("Generated embedding for text (length: %d)", len(text))
            return embedding
        except Exception as e:
            # Broad catch is deliberate: an embedding failure must never
            # propagate into the calling request; None signals "no vector".
            logger.error("Failed to generate embedding: %s", e)
            return None

    def get_content_hash(self, text: str) -> str:
        """Generate SHA256 hash of content for caching"""
        return hashlib.sha256(text.encode('utf-8')).hexdigest()

    def get_or_create_embedding(self, text: str, content_type: str, content_id: int) -> Optional[List[float]]:
        """
        Get embedding from cache or generate new one

        Args:
            text: Text content to embed
            content_type: Type of content (story, event, heritage, health)
            content_id: ID of the content object

        Returns:
            List of embedding vectors or None if failed
        """
        if not text or not text.strip():
            return None

        content_hash = self.get_content_hash(text)

        # Cache hit: identical text always hashes to the same key,
        # regardless of which object it belongs to.
        try:
            cached = EmbeddingCache.objects.get(content_hash=content_hash)
            logger.info("Using cached embedding for %s:%s", content_type, content_id)
            return cached.embedding
        except EmbeddingCache.DoesNotExist:
            pass

        # Cache miss: call the API, then store the result.
        embedding = self.generate_embedding(text)
        if embedding:
            # update_or_create guards against a concurrent writer having
            # inserted the same hash between the lookup above and here.
            EmbeddingCache.objects.update_or_create(
                content_hash=content_hash,
                defaults={
                    'content_type': content_type,
                    'content_id': content_id,
                    'embedding': embedding,
                }
            )
            logger.info("Cached new embedding for %s:%s", content_type, content_id)

        return embedding

    def update_model_embedding(self, instance, force_update: bool = False) -> bool:
        """
        Update embedding for a model instance

        Args:
            instance: Model instance with content_embedding field
            force_update: Force regeneration even if embedding exists

        Returns:
            True if embedding was updated, False otherwise
        """
        if not hasattr(instance, 'content_embedding'):
            logger.error("Model %s has no content_embedding field", type(instance).__name__)
            return False

        content_text = self._extract_content_text(instance)
        if not content_text:
            logger.warning("No content text found for %s:%s", type(instance).__name__, instance.id)
            return False

        # Skip the API call when the stored vector already matches the cached
        # vector for the current text (neither text nor embedding changed).
        if not force_update and instance.content_embedding and instance.embedding_updated:
            content_hash = self.get_content_hash(content_text)
            try:
                cached = EmbeddingCache.objects.get(content_hash=content_hash)
                if cached.embedding == instance.content_embedding:
                    logger.info("Embedding up to date for %s:%s", type(instance).__name__, instance.id)
                    return False
            except EmbeddingCache.DoesNotExist:
                pass

        content_type = type(instance).__name__.lower()
        embedding = self.get_or_create_embedding(content_text, content_type, instance.id)

        if embedding:
            instance.content_embedding = embedding
            instance.embedding_updated = timezone.now()
            # update_fields keeps the save cheap and avoids clobbering
            # concurrent edits to unrelated fields.
            instance.save(update_fields=['content_embedding', 'embedding_updated'])
            logger.info("Updated embedding for %s:%s", content_type, instance.id)
            return True

        return False

    def _extract_content_text(self, instance) -> str:
        """Extract text content from model instance for embedding"""
        model_name = type(instance).__name__.lower()

        fields = self._CONTENT_FIELDS.get(model_name)
        if fields:
            head, body = fields
            return f"{getattr(instance, head)}\n\n{getattr(instance, body)}"

        if model_name == 'person':
            return instance.bio if instance.bio else instance.name

        # Unknown model: fall back to the first non-empty common text field.
        for field in ('content', 'description', 'bio', 'title', 'name'):
            value = getattr(instance, field, None)
            if value:
                return str(value)

        return ""

    def bulk_update_embeddings(self, model_class, batch_size: int = 10) -> Dict[str, int]:
        """
        Bulk update embeddings for all instances of a model

        Args:
            model_class: Django model class
            batch_size: Number of instances to process at once

        Returns:
            Dict with statistics: {'updated': int, 'skipped': int, 'failed': int}
        """
        stats = {'updated': 0, 'skipped': 0, 'failed': 0}

        # Snapshot the matching primary keys up front. Slicing the lazy
        # filtered queryset per batch re-evaluates it each time, and rows we
        # just updated drop OUT of the isnull filter — later slices shift and
        # instances get silently skipped.
        pending_pks = list(
            model_class.objects.filter(
                models.Q(content_embedding__isnull=True) |
                models.Q(embedding_updated__isnull=True)
            ).values_list('pk', flat=True)
        )

        total = len(pending_pks)
        logger.info("Bulk updating embeddings for %d %s instances", total, model_class.__name__)

        for start in range(0, total, batch_size):
            batch_pks = pending_pks[start:start + batch_size]
            for instance in model_class.objects.filter(pk__in=batch_pks):
                try:
                    if self.update_model_embedding(instance):
                        stats['updated'] += 1
                    else:
                        stats['skipped'] += 1
                except Exception as e:
                    # Keep going: one bad record must not abort the batch.
                    stats['failed'] += 1
                    logger.error(
                        "Failed to update embedding for %s:%s: %s",
                        model_class.__name__, instance.id, e,
                    )

        logger.info("Bulk update complete: %s", stats)
        return stats

202 

203 

# Global service instance
# NOTE(review): instantiated at import time, so the OpenAI client (and the
# OPENAI_API_KEY settings read) happens as a side effect of importing this
# module — confirm settings are always configured before first import.
embedding_service = EmbeddingService()