Go · 12262 bytes Raw Blame History
1 package llm
2
import (
	"math"
	"sort"
	"strings"
	"sync"
)
8
9 // EnsembleSystem combines multiple ML techniques for optimal insult selection
10 type EnsembleSystem struct {
11 mu sync.RWMutex
12 tfidfEngine *TFIDFEngine
13 bm25Engine *BM25Engine // NEW: Industry-standard BM25 ranking
14 markovGen *MarkovGenerator
15 insultScorer *InsultScorer
16 database *InsultDatabase
17 history *InsultHistory
18
19 // Ensemble weights
20 semanticWeight float64
21 tagWeight float64
22 markovWeight float64
23 historicalWeight float64
24
25 // Quality thresholds
26 minSemanticScore float64
27 minTagScore float64
28 minEnsembleScore float64
29
30 // Configuration
31 useBM25 bool // Use BM25 instead of TF-IDF (recommended)
32 trained bool // Training state
33 }
34
// EnsembleScore holds the per-signal breakdown and the final blended score
// for a single insult candidate.
type EnsembleScore struct {
	Insult string // candidate text

	// Component signals, each produced by a dedicated scorer.
	SemanticScore    float64 // lexical relevance from BM25 (default) or TF-IDF
	TagScore         float64 // overlap between context/error tags and insult tags
	HistoricalScore  float64 // command-type / error-pattern heuristic
	NoveltyScore     float64 // from InsultHistory.GetNoveltyScore (presumably higher = less repeated)
	PersonalityScore float64 // fit with the requested personality

	EnsembleScore float64 // weighted blend of the signals above
	Confidence    float64 // agreement between signals; 1 means they all agree
	Source        string  // origin label: "semantic", "tag", "markov", "ensemble"
}
47
48 // NewEnsembleSystem creates a new ensemble learning system
49 func NewEnsembleSystem(db *InsultDatabase, scorer *InsultScorer, hist *InsultHistory) *EnsembleSystem {
50 return &EnsembleSystem{
51 tfidfEngine: NewTFIDFEngine(),
52 bm25Engine: NewBM25Engine(),
53 markovGen: NewMarkovGenerator(2), // Bigram model
54 insultScorer: scorer,
55 database: db,
56 history: hist,
57
58 // Default ensemble weights (can be tuned)
59 semanticWeight: 0.35,
60 tagWeight: 0.30,
61 markovWeight: 0.20,
62 historicalWeight: 0.15,
63
64 // Quality thresholds
65 minSemanticScore: 0.25,
66 minTagScore: 0.30,
67 minEnsembleScore: 0.40,
68
69 // Use BM25 by default (proven better than TF-IDF)
70 useBM25: true,
71 trained: false,
72 }
73 }
74
// Train fits every ML component (TF-IDF engine, BM25 engine and Markov
// generator) on the text of all insults in the database. It is idempotent:
// only the first call does any work, and later calls return immediately.
//
// The write lock is held for the entire training run, and trained is set only
// after every engine is built. Concurrent callers therefore block until
// training finishes instead of observing trained == true while the engines
// are still half-built.
func (es *EnsembleSystem) Train() {
	es.mu.Lock()
	defer es.mu.Unlock()

	if es.trained {
		return
	}

	// Gather the corpus once and share it across all engines.
	corpus := make([]string, 0, len(es.database.Insults))
	for _, insult := range es.database.Insults {
		corpus = append(corpus, insult.Text)
	}

	es.tfidfEngine.BuildCorpus(corpus)
	es.bm25Engine.BuildCorpus(corpus)
	es.markovGen.Train(corpus)

	// Publish only once everything above has completed.
	es.trained = true
}
100
101 // GenerateInsult generates the best possible insult using ensemble methods
102 func (es *EnsembleSystem) GenerateInsult(
103 ctx *SmartFallbackContext,
104 personality string,
105 ) string {
106 // Ensure training is done
107 es.mu.RLock()
108 trained := es.trained
109 es.mu.RUnlock()
110 if !trained {
111 es.Train()
112 }
113
114 // Get candidates from multiple sources
115 candidates := es.getAllCandidates(ctx, personality)
116
117 if len(candidates) == 0 {
118 // Last resort: generate using Markov
119 return es.markovGen.Blend(ctx)
120 }
121
122 // Sort by ensemble score
123 sort.Slice(candidates, func(i, j int) bool {
124 return candidates[i].EnsembleScore > candidates[j].EnsembleScore
125 })
126
127 // Get best candidate
128 best := candidates[0]
129
130 // If best score is still low, try Markov generation
131 if best.EnsembleScore < es.minEnsembleScore {
132 markovInsult := es.markovGen.Blend(ctx)
133 if markovInsult != "" && len(markovInsult) > 20 {
134 // Record and return Markov-generated insult
135 es.history.RecordInsult(markovInsult, ctx.FullCommand, 0.5)
136 return markovInsult
137 }
138 }
139
140 // Record selected insult
141 es.history.RecordInsult(best.Insult, ctx.FullCommand, best.EnsembleScore)
142
143 return best.Insult
144 }
145
146 // getAllCandidates gets scored candidates from all sources
147 func (es *EnsembleSystem) getAllCandidates(
148 ctx *SmartFallbackContext,
149 personality string,
150 ) []EnsembleScore {
151 candidates := make([]EnsembleScore, 0, len(es.database.Insults))
152
153 // Score all insults in database using ensemble
154 for _, insult := range es.database.Insults {
155 score := es.scoreInsult(insult, ctx, personality)
156
157 // Only include if above minimum thresholds
158 if score.EnsembleScore >= es.minEnsembleScore {
159 candidates = append(candidates, score)
160 }
161 }
162
163 return candidates
164 }
165
166 // scoreInsult scores a single insult using ensemble methods
167 func (es *EnsembleSystem) scoreInsult(
168 insult TaggedInsult,
169 ctx *SmartFallbackContext,
170 personality string,
171 ) EnsembleScore {
172 score := EnsembleScore{
173 Insult: insult.Text,
174 Source: "ensemble",
175 }
176
177 // 1. Semantic similarity score (TF-IDF)
178 score.SemanticScore = es.calculateSemanticScore(ctx, insult)
179
180 // 2. Tag-based score (existing system)
181 score.TagScore = es.calculateTagScore(ctx, insult)
182
183 // 3. Historical pattern score
184 score.HistoricalScore = es.calculateHistoricalScore(ctx, insult)
185
186 // 4. Novelty score (avoid repetition)
187 score.NoveltyScore = es.history.GetNoveltyScore(insult.Text)
188
189 // 5. Personality fit score
190 score.PersonalityScore = es.calculatePersonalityScore(insult, personality)
191
192 // Calculate weighted ensemble score
193 score.EnsembleScore = (score.SemanticScore * es.semanticWeight) +
194 (score.TagScore * es.tagWeight) +
195 (score.HistoricalScore * es.historicalWeight) +
196 (score.NoveltyScore * 0.10) +
197 (score.PersonalityScore * 0.05)
198
199 // Apply insult base weight
200 score.EnsembleScore *= insult.Weight
201
202 // Calculate confidence (how much methods agree)
203 score.Confidence = es.calculateConfidence(score)
204
205 // Boost score if high confidence
206 if score.Confidence > 0.8 {
207 score.EnsembleScore *= 1.1
208 }
209
210 return score
211 }
212
213 // calculateSemanticScore uses BM25 or TF-IDF for semantic similarity
214 func (es *EnsembleSystem) calculateSemanticScore(
215 ctx *SmartFallbackContext,
216 insult TaggedInsult,
217 ) float64 {
218 // Create a rich context description
219 contextText := es.buildContextText(ctx)
220
221 var score float64
222
223 if es.useBM25 {
224 // Use BM25 (industry standard, proven better)
225 // BM25 scores are typically in range 0-10, normalize to 0-1
226 rawScore := es.bm25Engine.Score(contextText, insult.Text)
227 score = math.Min(rawScore/10.0, 1.0)
228 } else {
229 // Use TF-IDF (for comparison)
230 similarity := es.tfidfEngine.CalculateSemanticScore(contextText, insult.Text)
231 score = sigmoid(similarity * 2.0)
232 }
233
234 return score
235 }
236
237 // buildContextText creates rich text representation of context
238 func (es *EnsembleSystem) buildContextText(ctx *SmartFallbackContext) string {
239 var parts []string
240
241 // Add command and type
242 parts = append(parts, ctx.FullCommand)
243 parts = append(parts, ctx.CommandType)
244 parts = append(parts, ctx.Command)
245
246 // Add error pattern
247 if ctx.ErrorPattern != "" {
248 parts = append(parts, ctx.ErrorPattern)
249 }
250
251 // Add project type
252 if ctx.ProjectType != "" {
253 parts = append(parts, ctx.ProjectType)
254 }
255
256 // Add git branch
257 if ctx.GitBranch != "" {
258 parts = append(parts, ctx.GitBranch)
259 }
260
261 // Add time context
262 if ctx.TimeOfDay >= 22 || ctx.TimeOfDay <= 4 {
263 parts = append(parts, "late night coding")
264 }
265
266 // Add CI context
267 if ctx.IsCI {
268 parts = append(parts, "continuous integration", "ci pipeline")
269 }
270
271 // Add repeated failure context
272 if ctx.IsRepeatedFailure {
273 parts = append(parts, "repeated failure", "again", "still failing")
274 }
275
276 return join(parts, " ")
277 }
278
// calculateTagScore measures how well insult.Tags match the tags derived from
// the parsed command intent and the classified error.
//
// The score is the number of (context tag, insult tag) matches divided by the
// number of context tags. It gets a 20% boost when more than two matches
// occur and is always clamped to at most 1.0. With no context tags at all it
// returns a neutral 0.5.
func (es *EnsembleSystem) calculateTagScore(
	ctx *SmartFallbackContext,
	insult TaggedInsult,
) float64 {
	// NOTE(review): the intent and error classification depend only on ctx,
	// yet they are recomputed for every insult scored. Hoisting them out of
	// the per-insult loop would avoid repeated work.
	intent := NewIntentParser().ParseIntent(ctx.FullCommand)
	contextTags := ContextualTags(ctx, intent)

	categories := NewErrorClassifier().ClassifyError(ctx.FullCommand, ctx.ExitCode, ctx.ErrorPattern)
	errorTags := errorCategoriesToTags(categories)

	// The full slice expression caps capacity, so append copies instead of
	// writing into spare capacity of the slice ContextualTags returned.
	allTags := append(contextTags[:len(contextTags):len(contextTags)], errorTags...)
	if len(allTags) == 0 {
		return 0.5
	}

	matches := 0
	for _, want := range allTags {
		for _, have := range insult.Tags {
			if want == have {
				matches++
			}
		}
	}

	score := float64(matches) / float64(len(allTags))
	if matches > 2 {
		score *= 1.2 // bonus for multiple matches
	}
	// Duplicate tags on the insult can make matches exceed len(allTags), so
	// clamp in every case, not only after the bonus.
	return math.Min(1.0, score)
}
323
// calculateHistoricalScore is a placeholder heuristic for historical patterns.
// It starts at 0.5 and adds 0.2 for each insult tag equal to ctx.CommandType.
// It then adds 0.3 for each tag equal to ctx.ErrorPattern (only when that
// pattern is non-empty). The result is capped at 1.0.
func (es *EnsembleSystem) calculateHistoricalScore(
	ctx *SmartFallbackContext,
	insult TaggedInsult,
) float64 {
	result := 0.5

	for i := range insult.Tags {
		if string(insult.Tags[i]) == ctx.CommandType {
			result += 0.2
		}
	}

	if pattern := ctx.ErrorPattern; pattern != "" {
		for i := range insult.Tags {
			if string(insult.Tags[i]) == pattern {
				result += 0.3
			}
		}
	}

	return math.Min(1.0, result)
}
352
353 // calculatePersonalityScore ensures insult matches personality
354 func (es *EnsembleSystem) calculatePersonalityScore(
355 insult TaggedInsult,
356 personality string,
357 ) float64 {
358 switch personality {
359 case "mild":
360 if hasTag(insult.Tags, TagMild) {
361 return 1.0
362 }
363 if insult.Severity <= 4 {
364 return 0.8
365 }
366 return 0.3
367
368 case "sarcastic":
369 if hasTag(insult.Tags, TagSarcastic) {
370 return 1.0
371 }
372 if insult.Severity >= 4 && insult.Severity <= 7 {
373 return 0.8
374 }
375 return 0.5
376
377 case "savage":
378 if hasTag(insult.Tags, TagSavage) {
379 return 1.0
380 }
381 if insult.Severity >= 6 {
382 return 0.8
383 }
384 return 0.4
385
386 default:
387 return 0.7
388 }
389 }
390
391 // calculateConfidence measures how much different methods agree
392 func (es *EnsembleSystem) calculateConfidence(score EnsembleScore) float64 {
393 scores := []float64{
394 score.SemanticScore,
395 score.TagScore,
396 score.HistoricalScore,
397 score.NoveltyScore,
398 score.PersonalityScore,
399 }
400
401 // Calculate variance
402 mean := 0.0
403 for _, s := range scores {
404 mean += s
405 }
406 mean /= float64(len(scores))
407
408 variance := 0.0
409 for _, s := range scores {
410 variance += (s - mean) * (s - mean)
411 }
412 variance /= float64(len(scores))
413
414 // Low variance = high confidence (methods agree)
415 // Convert variance to confidence (0-1)
416 confidence := 1.0 - math.Min(variance*4.0, 1.0)
417
418 return confidence
419 }
420
421 // GenerateMarkovInsult generates a novel insult using Markov chains
422 func (es *EnsembleSystem) GenerateMarkovInsult(ctx *SmartFallbackContext) string {
423 if !es.trained {
424 es.Train()
425 }
426
427 return es.markovGen.Blend(ctx)
428 }
429
// AnalyzeScoring returns, for debugging, the topN highest-scoring candidates
// (those passing minEnsembleScore) with their full score breakdown, sorted in
// descending order of EnsembleScore.
//
// topN == 0 yields an empty slice. A negative topN returns every candidate
// rather than panicking on a negative slice bound.
func (es *EnsembleSystem) AnalyzeScoring(
	ctx *SmartFallbackContext,
	personality string,
	topN int,
) []EnsembleScore {
	// Read trained under the lock; an unguarded read races with Train.
	es.mu.RLock()
	trained := es.trained
	es.mu.RUnlock()
	if !trained {
		es.Train()
	}

	candidates := es.getAllCandidates(ctx, personality)

	sort.Slice(candidates, func(i, j int) bool {
		return candidates[i].EnsembleScore > candidates[j].EnsembleScore
	})

	if topN >= 0 && len(candidates) > topN {
		candidates = candidates[:topN]
	}

	return candidates
}
453
// UpdateWeights replaces the four ensemble weights, normalizing them so they
// sum to 1.
//
// The update is ignored, and the current weights are kept, when the weights
// do not sum to a positive, finite total. This covers all-zero input, NaN or
// infinite arguments, and negatives that cancel out. Dividing by such a total
// would otherwise store NaN or Inf weights that poison every later score.
func (es *EnsembleSystem) UpdateWeights(
	semanticW, tagW, markovW, historicalW float64,
) {
	total := semanticW + tagW + markovW + historicalW

	// !(total > 0) also catches NaN. A finite, positive total implies every
	// individual weight is finite.
	if !(total > 0) || math.IsInf(total, 0) {
		return
	}

	es.semanticWeight = semanticW / total
	es.tagWeight = tagW / total
	es.markovWeight = markovW / total
	es.historicalWeight = historicalW / total
}
465
466 // GetStats returns ensemble system statistics
467 func (es *EnsembleSystem) GetStats() map[string]interface{} {
468 stats := make(map[string]interface{})
469
470 stats["trained"] = es.trained
471 stats["database_size"] = len(es.database.Insults)
472
473 if es.trained {
474 stats["tfidf_vocabulary"] = len(es.tfidfEngine.vocabulary)
475 stats["markov_stats"] = es.markovGen.GetStats()
476 }
477
478 stats["weights"] = map[string]float64{
479 "semantic": es.semanticWeight,
480 "tag": es.tagWeight,
481 "markov": es.markovWeight,
482 "historical": es.historicalWeight,
483 }
484
485 return stats
486 }
487
488 // Helper functions
489
490 func sigmoid(x float64) float64 {
491 return 1.0 / (1.0 + math.Exp(-x))
492 }
493
// join concatenates parts with sep between consecutive elements. An empty or
// nil slice yields "".
//
// It delegates to strings.Join, which sizes the result once instead of
// re-copying the accumulated string on every concatenation (quadratic in the
// total length).
func join(parts []string, sep string) string {
	return strings.Join(parts, sep)
}
504