Go · 11555 bytes Raw Blame History
1 package llm
2
import (
	"math"
	"sort"
	"strings"
)
7
// ContextualEmbedding is a fixed-width numeric representation of a command
// context, produced without any external ML library.
//
// Alongside the raw vector it keeps a map of named feature values, so callers
// can see which signals went into the embedding.
type ContextualEmbedding struct {
	// Vector holds one value per embedding dimension.
	Vector []float64

	// ContextID labels the context this embedding was built from.
	ContextID string

	// Features maps human-readable feature names to the values recorded for
	// them, for interpretability.
	Features map[string]float64

	// Magnitude caches the Euclidean norm of Vector (set by normalizeVector:
	// 1 after normalization, 0 for an all-zero vector).
	Magnitude float64
}
16
// EmbeddingEngine builds ContextualEmbedding values and compares them.
//
// Obtain one through NewEmbeddingEngine: the context encoders write to fixed
// vector indices up to 31, so dimensions must be at least 32.
type EmbeddingEngine struct {
	// dimensions is the length of every vector the engine allocates.
	dimensions int

	// vocabulary maps words to indices.
	// NOTE(review): not read by any method in this section of the file;
	// confirm it is still needed.
	vocabulary map[string]int

	// idf holds inverse-document-frequency weights keyed by word.
	// NOTE(review): likewise not read in this section; confirm usage.
	idf map[string]float64
}
23
24 // NewEmbeddingEngine creates a new embedding engine
25 func NewEmbeddingEngine() *EmbeddingEngine {
26 return &EmbeddingEngine{
27 dimensions: 32, // 32-dimensional embedding space
28 vocabulary: make(map[string]int),
29 idf: make(map[string]float64),
30 }
31 }
32
33 // CreateEmbedding creates a vector representation of a context
34 func (ee *EmbeddingEngine) CreateEmbedding(ctx *SmartFallbackContext) *ContextualEmbedding {
35 embedding := &ContextualEmbedding{
36 Vector: make([]float64, ee.dimensions),
37 ContextID: ctx.CommandType + "_" + ctx.ErrorPattern,
38 Features: make(map[string]float64),
39 }
40
41 // Feature extraction and encoding
42 // Each feature maps to specific dimensions in the vector
43
44 // Dimension 0-7: Command type encoding
45 ee.encodeCommandType(ctx, embedding, 0)
46
47 // Dimension 8-15: Error pattern encoding
48 ee.encodeErrorPattern(ctx, embedding, 8)
49
50 // Dimension 16-19: Project context
51 ee.encodeProjectContext(ctx, embedding, 16)
52
53 // Dimension 20-23: Temporal context
54 ee.encodeTemporalContext(ctx, embedding, 20)
55
56 // Dimension 24-27: Environmental context
57 ee.encodeEnvironmentalContext(ctx, embedding, 24)
58
59 // Dimension 28-31: Behavioral patterns
60 ee.encodeBehavioralContext(ctx, embedding, 28)
61
62 // Normalize vector
63 ee.normalizeVector(embedding)
64
65 return embedding
66 }
67
68 // FindSimilarContexts finds contexts similar to the given one
69 func (ee *EmbeddingEngine) FindSimilarContexts(
70 target *ContextualEmbedding,
71 candidates []*ContextualEmbedding,
72 topK int,
73 ) []SimilarContext {
74 similarities := make([]SimilarContext, 0, len(candidates))
75
76 for _, candidate := range candidates {
77 similarity := ee.CosineSimilarity(target, candidate)
78 similarities = append(similarities, SimilarContext{
79 Embedding: candidate,
80 Similarity: similarity,
81 })
82 }
83
84 // Sort by similarity (descending)
85 sortSimilarContexts(similarities)
86
87 if len(similarities) > topK {
88 similarities = similarities[:topK]
89 }
90
91 return similarities
92 }
93
94 // CosineSimilarity calculates cosine similarity between two embeddings
95 func (ee *EmbeddingEngine) CosineSimilarity(e1, e2 *ContextualEmbedding) float64 {
96 if len(e1.Vector) != len(e2.Vector) {
97 return 0.0
98 }
99
100 dotProduct := 0.0
101 for i := range e1.Vector {
102 dotProduct += e1.Vector[i] * e2.Vector[i]
103 }
104
105 // Use cached magnitudes
106 magnitude1 := e1.Magnitude
107 magnitude2 := e2.Magnitude
108
109 if magnitude1 == 0 || magnitude2 == 0 {
110 return 0.0
111 }
112
113 return dotProduct / (magnitude1 * magnitude2)
114 }
115
116 // EuclideanDistance calculates Euclidean distance between embeddings
117 func (ee *EmbeddingEngine) EuclideanDistance(e1, e2 *ContextualEmbedding) float64 {
118 if len(e1.Vector) != len(e2.Vector) {
119 return math.Inf(1)
120 }
121
122 sum := 0.0
123 for i := range e1.Vector {
124 diff := e1.Vector[i] - e2.Vector[i]
125 sum += diff * diff
126 }
127
128 return math.Sqrt(sum)
129 }
130
131 // Encoding functions for different context aspects
132
133 func (ee *EmbeddingEngine) encodeCommandType(ctx *SmartFallbackContext, emb *ContextualEmbedding, offset int) {
134 // One-hot-like encoding for common command types
135 commandTypes := map[string]int{
136 "git": 0,
137 "docker": 1,
138 "nodejs": 2,
139 "python": 3,
140 "rust": 4,
141 "golang": 5,
142 "kubernetes": 6,
143 "database": 7,
144 }
145
146 if idx, exists := commandTypes[ctx.CommandType]; exists {
147 emb.Vector[offset+idx] = 1.0
148 emb.Features["command_type_"+ctx.CommandType] = 1.0
149 }
150 }
151
152 func (ee *EmbeddingEngine) encodeErrorPattern(ctx *SmartFallbackContext, emb *ContextualEmbedding, offset int) {
153 // Encode error pattern
154 errorPatterns := map[string]int{
155 "permission_denied": 0,
156 "command_not_found": 1,
157 "network_error": 2,
158 "timeout": 3,
159 "syntax_error": 4,
160 "merge_conflict": 5,
161 "build_failure": 6,
162 "test_failure": 7,
163 }
164
165 if idx, exists := errorPatterns[ctx.ErrorPattern]; exists {
166 emb.Vector[offset+idx] = 1.0
167 emb.Features["error_"+ctx.ErrorPattern] = 1.0
168 }
169
170 // Encode exit code (normalized)
171 if ctx.ExitCode > 0 {
172 normalizedExitCode := math.Min(float64(ctx.ExitCode)/255.0, 1.0)
173 emb.Vector[offset+7] = normalizedExitCode
174 emb.Features["exit_code"] = float64(ctx.ExitCode)
175 }
176 }
177
178 func (ee *EmbeddingEngine) encodeProjectContext(ctx *SmartFallbackContext, emb *ContextualEmbedding, offset int) {
179 // Encode project type
180 projectTypes := map[string]float64{
181 "node": 0.0,
182 "rust": 0.25,
183 "go": 0.5,
184 "python": 0.75,
185 "java": 1.0,
186 }
187
188 if val, exists := projectTypes[ctx.ProjectType]; exists {
189 emb.Vector[offset] = val
190 emb.Features["project_type"] = val
191 }
192
193 // Encode git branch awareness
194 if ctx.GitBranch != "" {
195 branchScore := 0.0
196 if ctx.GitBranch == "main" || ctx.GitBranch == "master" {
197 branchScore = 1.0 // High risk
198 } else if strings.Contains(ctx.GitBranch, "prod") {
199 branchScore = 0.9
200 } else if strings.Contains(ctx.GitBranch, "develop") {
201 branchScore = 0.5
202 }
203 emb.Vector[offset+1] = branchScore
204 emb.Features["branch_risk"] = branchScore
205 }
206
207 // Encode dependency complexity
208 if ctx.DependencyCount > 0 {
209 // Normalize dependency count (log scale)
210 normalized := math.Log(float64(ctx.DependencyCount)+1) / math.Log(100)
211 emb.Vector[offset+2] = math.Min(normalized, 1.0)
212 emb.Features["dependency_complexity"] = normalized
213 }
214
215 // Has build files
216 if ctx.HasDockerfile || ctx.HasMakefile {
217 emb.Vector[offset+3] = 1.0
218 emb.Features["has_build_system"] = 1.0
219 }
220 }
221
222 func (ee *EmbeddingEngine) encodeTemporalContext(ctx *SmartFallbackContext, emb *ContextualEmbedding, offset int) {
223 // Encode time of day (cyclical encoding using sin/cos)
224 hour := float64(ctx.TimeOfDay)
225 hourRadian := (hour / 24.0) * 2.0 * math.Pi
226
227 emb.Vector[offset] = math.Sin(hourRadian)
228 emb.Vector[offset+1] = math.Cos(hourRadian)
229 emb.Features["time_sin"] = emb.Vector[offset]
230 emb.Features["time_cos"] = emb.Vector[offset+1]
231
232 // Late night coding indicator (high value between 22-4)
233 lateNight := 0.0
234 if hour >= 22 || hour <= 4 {
235 lateNight = 1.0
236 }
237 emb.Vector[offset+2] = lateNight
238 emb.Features["late_night"] = lateNight
239
240 // Repeated failure indicator
241 if ctx.IsRepeatedFailure {
242 emb.Vector[offset+3] = 1.0
243 emb.Features["repeated_failure"] = 1.0
244 }
245 }
246
247 func (ee *EmbeddingEngine) encodeEnvironmentalContext(ctx *SmartFallbackContext, emb *ContextualEmbedding, offset int) {
248 // CI/CD environment
249 if ctx.IsCI {
250 emb.Vector[offset] = 1.0
251 emb.Features["is_ci"] = 1.0
252
253 // Encode CI provider
254 ciProviders := map[string]float64{
255 "github": 0.2,
256 "gitlab": 0.4,
257 "jenkins": 0.6,
258 "circle": 0.8,
259 }
260 if val, exists := ciProviders[ctx.CIProvider]; exists {
261 emb.Vector[offset+1] = val
262 emb.Features["ci_provider"] = val
263 }
264 }
265
266 // Shell type encoding
267 shells := map[string]float64{
268 "bash": 0.2,
269 "zsh": 0.4,
270 "fish": 0.6,
271 "sh": 0.8,
272 }
273 if val, exists := shells[ctx.Shell]; exists {
274 emb.Vector[offset+2] = val
275 emb.Features["shell"] = val
276 }
277
278 // Command complexity
279 complexityScore := 0.0
280 if ctx.HasPipes {
281 complexityScore += 0.3
282 }
283 if ctx.HasChaining {
284 complexityScore += 0.3
285 }
286 if ctx.CommandLength > 100 {
287 complexityScore += 0.4
288 }
289 emb.Vector[offset+3] = math.Min(complexityScore, 1.0)
290 emb.Features["complexity"] = complexityScore
291 }
292
293 func (ee *EmbeddingEngine) encodeBehavioralContext(ctx *SmartFallbackContext, emb *ContextualEmbedding, offset int) {
294 // Working directory patterns
295 wdPatterns := map[string]float64{
296 "tmp": 0.8,
297 "downloads": 0.6,
298 "desktop": 0.4,
299 }
300
301 wdLower := strings.ToLower(ctx.WorkingDir)
302 for pattern, score := range wdPatterns {
303 if strings.Contains(wdLower, pattern) {
304 emb.Vector[offset] = score
305 emb.Features["wd_pattern"] = score
306 break
307 }
308 }
309
310 // File extensions present
311 if len(ctx.FileExtensions) > 0 {
312 emb.Vector[offset+1] = 1.0
313 emb.Features["has_files"] = 1.0
314 }
315
316 // Numeric arguments (ports, chmod values, etc.)
317 if len(ctx.NumericArgs) > 0 {
318 // Check for risky values
319 riskyNums := map[int]float64{
320 777: 1.0, // chmod 777
321 666: 0.8, // chmod 666
322 9: 0.6, // kill -9
323 }
324 for _, num := range ctx.NumericArgs {
325 if score, exists := riskyNums[num]; exists {
326 emb.Vector[offset+2] = score
327 emb.Features["risky_numeric"] = score
328 break
329 }
330 }
331 }
332
333 // Git-specific context
334 if ctx.CommandType == "git" {
335 if strings.Contains(strings.ToLower(ctx.FullCommand), "force") {
336 emb.Vector[offset+3] = 1.0
337 emb.Features["force_flag"] = 1.0
338 }
339 }
340 }
341
342 func (ee *EmbeddingEngine) normalizeVector(emb *ContextualEmbedding) {
343 // Calculate magnitude
344 sumSquares := 0.0
345 for _, val := range emb.Vector {
346 sumSquares += val * val
347 }
348 emb.Magnitude = math.Sqrt(sumSquares)
349
350 // Normalize (make unit vector)
351 if emb.Magnitude > 0 {
352 for i := range emb.Vector {
353 emb.Vector[i] /= emb.Magnitude
354 }
355 // Update magnitude to 1.0 after normalization
356 emb.Magnitude = 1.0
357 }
358 }
359
360 // SimilarContext represents a similar context with similarity score
361 type SimilarContext struct {
362 Embedding *ContextualEmbedding
363 Similarity float64
364 }
365
366 func sortSimilarContexts(contexts []SimilarContext) {
367 // Bubble sort by similarity (descending)
368 n := len(contexts)
369 for i := 0; i < n-1; i++ {
370 for j := 0; j < n-i-1; j++ {
371 if contexts[j].Similarity < contexts[j+1].Similarity {
372 contexts[j], contexts[j+1] = contexts[j+1], contexts[j]
373 }
374 }
375 }
376 }
377
378 // GetFeatureImportance returns which features contributed most to the embedding
379 func (emb *ContextualEmbedding) GetFeatureImportance() []FeatureImportance {
380 features := make([]FeatureImportance, 0, len(emb.Features))
381
382 for name, value := range emb.Features {
383 features = append(features, FeatureImportance{
384 Name: name,
385 Value: value,
386 })
387 }
388
389 // Sort by absolute value (descending)
390 sortFeatureImportance(features)
391
392 return features
393 }
394
// FeatureImportance is a single named entry taken from
// ContextualEmbedding.Features, as returned by GetFeatureImportance.
type FeatureImportance struct {
	Name  string  // feature key
	Value float64 // value recorded for that feature
}
400
401 func sortFeatureImportance(features []FeatureImportance) {
402 n := len(features)
403 for i := 0; i < n-1; i++ {
404 for j := 0; j < n-i-1; j++ {
405 absI := features[i].Value
406 if absI < 0 {
407 absI = -absI
408 }
409 absJ := features[j].Value
410 if absJ < 0 {
411 absJ = -absJ
412 }
413
414 if absJ < absI {
415 features[i], features[j] = features[j], features[i]
416 }
417 }
418 }
419 }
420
421 // ExplainSimilarity explains why two contexts are similar
422 func (ee *EmbeddingEngine) ExplainSimilarity(e1, e2 *ContextualEmbedding) string {
423 explanation := "Similarity Breakdown:\n"
424
425 // Compare feature by feature
426 sharedFeatures := make([]string, 0)
427 for feature := range e1.Features {
428 if _, exists := e2.Features[feature]; exists {
429 sharedFeatures = append(sharedFeatures, feature)
430 }
431 }
432
433 if len(sharedFeatures) > 0 {
434 explanation += "Shared features: "
435 for i, feature := range sharedFeatures {
436 if i > 0 {
437 explanation += ", "
438 }
439 explanation += feature
440 }
441 explanation += "\n"
442 }
443
444 // Calculate similarity
445 similarity := ee.CosineSimilarity(e1, e2)
446 explanation += "Cosine similarity: "
447 explanation += formatFloat(similarity)
448 explanation += "\n"
449
450 return explanation
451 }
452