| 1 | package llm |
| 2 | |
import (
	"fmt"
	"math"
	"sort"
	"strconv"
	"strings"
	"sync"
)
| 7 | |
// BM25Engine ranks documents against a query with the Okapi BM25 function.
// Compared with basic TF-IDF, BM25 gives diminishing credit for repeated
// terms (controlled by k1) and normalizes for document length (controlled by b).
//
// Documents and queries are both converted to ASCII-lowercased word n-grams
// (sizes ngramRange[0] through ngramRange[1]); every n-gram is treated as a
// term. BuildCorpus fills the corpus statistics below while holding mu's write
// lock, and Score reads them while holding mu's read lock.
type BM25Engine struct {
	mu            sync.RWMutex
	vocabulary    map[string]int     // term -> index, assigned in first-seen order by BuildCorpus
	idf           map[string]float64 // term -> BM25 inverse document frequency
	docLengths    []int              // number of n-grams in each corpus document
	avgDocLength  float64            // mean of docLengths; "avgdl" in the BM25 formula
	documentCount int                // N, the number of corpus documents

	// BM25 parameters (tunable)
	k1 float64 // term frequency saturation parameter (typical: 1.2-2.0)
	b  float64 // document length normalization (typical: 0.75)

	ngramRange [2]int // min and max n-gram size, both inclusive
}
| 24 | |
| 25 | // NewBM25Engine creates a new BM25 ranking engine |
| 26 | func NewBM25Engine() *BM25Engine { |
| 27 | return &BM25Engine{ |
| 28 | vocabulary: make(map[string]int), |
| 29 | idf: make(map[string]float64), |
| 30 | docLengths: make([]int, 0), |
| 31 | documentCount: 0, |
| 32 | |
| 33 | // Standard BM25 parameters (Okapi BM25) |
| 34 | k1: 1.5, // Typical range: 1.2-2.0 |
| 35 | b: 0.75, // Typical range: 0.5-0.9 |
| 36 | |
| 37 | ngramRange: [2]int{1, 3}, // unigrams, bigrams, trigrams |
| 38 | } |
| 39 | } |
| 40 | |
| 41 | // SetParameters allows tuning of BM25 parameters |
| 42 | func (engine *BM25Engine) SetParameters(k1, b float64) { |
| 43 | engine.k1 = k1 |
| 44 | engine.b = b |
| 45 | } |
| 46 | |
// BuildCorpus replaces the engine's corpus statistics with ones computed
// from documents. It computes the vocabulary, the n-gram count of each
// document, the average document length, and a BM25 IDF for every n-gram
// that occurs.
//
// Statistics from any previous call are discarded, not merged, so a rebuild
// never keeps IDF values for terms that are no longer in the corpus. An
// empty documents slice produces an empty corpus with avgDocLength 0.
func (engine *BM25Engine) BuildCorpus(documents []string) {
	engine.mu.Lock()
	defer engine.mu.Unlock()

	// Start from a clean slate; reusing the old maps would leak stale
	// entries and vocabulary indices from the previous corpus.
	engine.vocabulary = make(map[string]int)
	engine.idf = make(map[string]float64)
	engine.docLengths = make([]int, len(documents))
	engine.documentCount = len(documents)

	// First pass: tokenize each document, record its length, and count in
	// how many documents each term appears (document frequency).
	documentFreq := make(map[string]int)
	totalLength := 0
	for docIdx, doc := range documents {
		tokens := engine.extractNGrams(doc)
		engine.docLengths[docIdx] = len(tokens)
		totalLength += len(tokens)

		seen := make(map[string]bool, len(tokens)) // count each term once per document
		for _, token := range tokens {
			if !seen[token] {
				seen[token] = true
				documentFreq[token]++
			}
			if _, exists := engine.vocabulary[token]; !exists {
				engine.vocabulary[token] = len(engine.vocabulary)
			}
		}
	}

	// Guard the empty corpus: 0/0 would store NaN and poison every later score.
	engine.avgDocLength = 0
	if engine.documentCount > 0 {
		engine.avgDocLength = float64(totalLength) / float64(engine.documentCount)
	}

	// IDF = ln((N - df + 0.5) / (df + 0.5) + 1). This is the Robertson–Spärck
	// Jones weight with +1 inside the log (as in Lucene's BM25), which keeps
	// IDF strictly positive even for terms found in most documents.
	numDocs := float64(engine.documentCount)
	for term, df := range documentFreq {
		numerator := numDocs - float64(df) + 0.5
		denominator := float64(df) + 0.5
		engine.idf[term] = math.Log((numerator / denominator) + 1.0)
	}
}
| 89 | |
| 90 | // extractNGrams extracts n-grams from text (same as TF-IDF) |
| 91 | func (engine *BM25Engine) extractNGrams(text string) []string { |
| 92 | text = toLowerSimple(text) |
| 93 | words := engine.tokenize(text) |
| 94 | |
| 95 | var ngrams []string |
| 96 | |
| 97 | for n := engine.ngramRange[0]; n <= engine.ngramRange[1]; n++ { |
| 98 | if n > len(words) { |
| 99 | break |
| 100 | } |
| 101 | |
| 102 | for i := 0; i <= len(words)-n; i++ { |
| 103 | ngram := joinWords(words[i:i+n], " ") |
| 104 | ngrams = append(ngrams, ngram) |
| 105 | } |
| 106 | } |
| 107 | |
| 108 | return ngrams |
| 109 | } |
| 110 | |
| 111 | // tokenize splits text into words |
| 112 | func (engine *BM25Engine) tokenize(text string) []string { |
| 113 | var words []string |
| 114 | var currentWord string |
| 115 | |
| 116 | for _, r := range text { |
| 117 | if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_' { |
| 118 | currentWord += string(r) |
| 119 | } else { |
| 120 | if len(currentWord) > 1 { // Skip single characters |
| 121 | words = append(words, currentWord) |
| 122 | } |
| 123 | currentWord = "" |
| 124 | } |
| 125 | } |
| 126 | |
| 127 | if len(currentWord) > 1 { |
| 128 | words = append(words, currentWord) |
| 129 | } |
| 130 | |
| 131 | return words |
| 132 | } |
| 133 | |
// Score returns the Okapi BM25 relevance of document for query. It uses the
// IDF table and average document length from the most recent BuildCorpus.
//
// Each distinct query n-gram q that occurs in document contributes
//
//	IDF(q) × f(q,D)·(k1+1) / (f(q,D) + k1·(1 - b + b·|D|/avgdl))
//
// where f(q,D) is the n-gram's frequency in document and |D| is the
// document's n-gram count. When no usable average length exists (no corpus
// built, or every corpus document was empty), length normalization is
// skipped instead of dividing by zero.
func (engine *BM25Engine) Score(query string, document string) float64 {
	engine.mu.RLock()
	defer engine.mu.RUnlock()

	// Term frequencies f(q, D) for every n-gram in the document.
	docTerms := engine.extractNGrams(document)
	termFreq := make(map[string]int, len(docTerms))
	for _, term := range docTerms {
		termFreq[term]++
	}

	// Length-normalization factor. Dividing by a zero (or NaN) average would
	// yield Inf/NaN scores, so such documents are treated as average length.
	lengthNorm := 1.0
	if engine.avgDocLength > 0 {
		lengthNorm = 1.0 - engine.b + engine.b*float64(len(docTerms))/engine.avgDocLength
	}

	score := 0.0
	seenQuery := make(map[string]bool) // each distinct query term counts once
	for _, queryTerm := range engine.extractNGrams(query) {
		if seenQuery[queryTerm] {
			continue
		}
		seenQuery[queryTerm] = true

		tf := float64(termFreq[queryTerm])
		if tf == 0 {
			// Absent from the document: contributes exactly 0. Skipping also
			// avoids 0/0 when k1 == 0.
			continue
		}

		idf, exists := engine.idf[queryTerm]
		if !exists {
			// Term not seen in the corpus. This fallback, ln(N+1), is larger
			// than the IDF of any corpus term, so unseen terms outweigh even
			// the rarest corpus term.
			idf = math.Log(float64(engine.documentCount) + 1.0)
		}

		numerator := tf * (engine.k1 + 1.0)
		denominator := tf + engine.k1*lengthNorm
		score += idf * (numerator / denominator)
	}

	return score
}
| 184 | |
| 185 | // ScoreMultiple scores a query against multiple documents |
| 186 | func (engine *BM25Engine) ScoreMultiple(query string, documents []string) []BM25Score { |
| 187 | scores := make([]BM25Score, len(documents)) |
| 188 | |
| 189 | for i, doc := range documents { |
| 190 | scores[i] = BM25Score{ |
| 191 | Index: i, |
| 192 | Document: doc, |
| 193 | Score: engine.Score(query, doc), |
| 194 | } |
| 195 | } |
| 196 | |
| 197 | // Sort by score descending |
| 198 | sortBM25Scores(scores) |
| 199 | |
| 200 | return scores |
| 201 | } |
| 202 | |
| 203 | // FindTopK returns the top K documents for a query |
| 204 | func (engine *BM25Engine) FindTopK(query string, documents []string, k int) []BM25Score { |
| 205 | scores := engine.ScoreMultiple(query, documents) |
| 206 | |
| 207 | if len(scores) > k { |
| 208 | scores = scores[:k] |
| 209 | } |
| 210 | |
| 211 | return scores |
| 212 | } |
| 213 | |
// BM25Score pairs a document with its BM25 score for a query. ScoreMultiple
// and FindTopK return values of this type.
type BM25Score struct {
	Index    int     // position of Document in the slice passed to ScoreMultiple
	Document string  // the scored document text
	Score    float64 // BM25 score from Score; results are sorted on this, highest first
}
| 220 | |
// sortBM25Scores sorts scores in place by descending Score. The sort is
// stable, so documents with equal scores keep their relative input order.
// It runs in O(n log n), replacing the previous O(n²) bubble sort, which had
// the same stable ordering.
func sortBM25Scores(scores []BM25Score) {
	sort.SliceStable(scores, func(i, j int) bool {
		return scores[i].Score > scores[j].Score
	})
}
| 232 | |
| 233 | // Explanation generates a human-readable explanation of the score |
| 234 | func (engine *BM25Engine) Explanation(query string, document string) string { |
| 235 | queryTerms := engine.extractNGrams(query) |
| 236 | docTerms := engine.extractNGrams(document) |
| 237 | |
| 238 | termFreq := make(map[string]int) |
| 239 | for _, term := range docTerms { |
| 240 | termFreq[term]++ |
| 241 | } |
| 242 | |
| 243 | docLength := len(docTerms) |
| 244 | |
| 245 | explanation := "BM25 Score Breakdown:\n" |
| 246 | explanation += "=====================\n\n" |
| 247 | |
| 248 | totalScore := 0.0 |
| 249 | seenQuery := make(map[string]bool) |
| 250 | |
| 251 | for _, queryTerm := range queryTerms { |
| 252 | if seenQuery[queryTerm] { |
| 253 | continue |
| 254 | } |
| 255 | seenQuery[queryTerm] = true |
| 256 | |
| 257 | if termFreq[queryTerm] == 0 { |
| 258 | continue // Term not in document |
| 259 | } |
| 260 | |
| 261 | idf := engine.idf[queryTerm] |
| 262 | tf := float64(termFreq[queryTerm]) |
| 263 | |
| 264 | numerator := tf * (engine.k1 + 1.0) |
| 265 | denominator := tf + engine.k1*(1.0-engine.b+engine.b*float64(docLength)/engine.avgDocLength) |
| 266 | termScore := idf * (numerator / denominator) |
| 267 | |
| 268 | totalScore += termScore |
| 269 | |
| 270 | explanation += formatString("Term: '%s'\n", queryTerm) |
| 271 | explanation += formatString(" TF: %d (occurs %d times)\n", termFreq[queryTerm], termFreq[queryTerm]) |
| 272 | explanation += formatString(" IDF: %.4f\n", idf) |
| 273 | explanation += formatString(" BM25 component: %.4f\n", termScore) |
| 274 | explanation += "\n" |
| 275 | } |
| 276 | |
| 277 | explanation += formatString("Total BM25 Score: %.4f\n", totalScore) |
| 278 | explanation += formatString("Document length: %d (avg: %.1f)\n", docLength, engine.avgDocLength) |
| 279 | |
| 280 | return explanation |
| 281 | } |
| 282 | |
| 283 | // CompareWithTFIDF compares BM25 with basic TF-IDF for analysis |
| 284 | func (engine *BM25Engine) CompareWithTFIDF(query string, document string, tfidfScore float64) string { |
| 285 | bm25Score := engine.Score(query, document) |
| 286 | |
| 287 | comparison := "BM25 vs TF-IDF Comparison:\n" |
| 288 | comparison += "===========================\n\n" |
| 289 | comparison += formatString("BM25 Score: %.4f\n", bm25Score) |
| 290 | comparison += formatString("TF-IDF Score: %.4f\n", tfidfScore) |
| 291 | |
| 292 | diff := bm25Score - tfidfScore |
| 293 | percentDiff := (diff / tfidfScore) * 100 |
| 294 | |
| 295 | if diff > 0 { |
| 296 | comparison += formatString("Difference: +%.4f (+%.1f%%)\n", diff, percentDiff) |
| 297 | comparison += "✅ BM25 scores higher (better)\n" |
| 298 | } else { |
| 299 | comparison += formatString("Difference: %.4f (%.1f%%)\n", diff, percentDiff) |
| 300 | comparison += "⚠️ TF-IDF scores higher\n" |
| 301 | } |
| 302 | |
| 303 | comparison += "\nWhy BM25 is generally better:\n" |
| 304 | comparison += "- Term frequency saturation (diminishing returns)\n" |
| 305 | comparison += "- Document length normalization (fairer comparison)\n" |
| 306 | comparison += "- More sophisticated IDF formula\n" |
| 307 | comparison += "- Industry standard for search engines\n" |
| 308 | |
| 309 | return comparison |
| 310 | } |
| 311 | |
| 312 | // Helper functions |
| 313 | |
// toLowerSimple lowercases the ASCII letters A-Z and leaves every other rune
// untouched; unlike strings.ToLower, it does not fold non-ASCII letters.
// Invalid UTF-8 bytes come out as U+FFFD, as before. strings.Map builds the
// result in one pass instead of the previous per-rune string concatenation,
// which was quadratic in the input length.
func toLowerSimple(s string) string {
	return strings.Map(func(r rune) rune {
		if r >= 'A' && r <= 'Z' {
			return r + ('a' - 'A')
		}
		return r
	}, s)
}
| 325 | |
// joinWords concatenates words with sep between consecutive elements and
// returns "" for an empty slice. It delegates to strings.Join, which sizes
// the result once instead of reallocating on every append.
func joinWords(words []string, sep string) string {
	return strings.Join(words, sep)
}
| 337 | |
| 338 | func formatString(format string, args ...interface{}) string { |
| 339 | // Simple sprintf equivalent for basic formatting |
| 340 | // This is a simplified version - in production use fmt.Sprintf |
| 341 | result := format |
| 342 | for _, arg := range args { |
| 343 | switch v := arg.(type) { |
| 344 | case string: |
| 345 | result = replaceFirst(result, "%s", v) |
| 346 | case int: |
| 347 | result = replaceFirst(result, "%d", intToString(v)) |
| 348 | case float64: |
| 349 | // Simple float formatting |
| 350 | result = replaceFirst(result, "%.4f", floatToString(v, 4)) |
| 351 | result = replaceFirst(result, "%.1f", floatToString(v, 1)) |
| 352 | } |
| 353 | } |
| 354 | return result |
| 355 | } |
| 356 | |
// replaceFirst returns s with the first occurrence of old replaced by repl,
// or s unchanged if old does not occur. An empty old matches at the start of
// s, so repl is prepended. It uses strings.Replace with a limit of one
// instead of the external findSubstring helper. The parameter is named repl
// rather than new to avoid shadowing the builtin.
func replaceFirst(s, old, repl string) string {
	return strings.Replace(s, old, repl, 1)
}
| 364 | |
// intToString returns the base-10 representation of n. It delegates to
// strconv.Itoa, which also handles the most negative int. The previous digit
// loop turned that value into "-" because negating it overflows back to a
// negative number.
func intToString(n int) string {
	return strconv.Itoa(n)
}
| 387 | |
// floatToString formats f in fixed-point notation with precision digits after
// the decimal point, rounding to nearest (1.99999 at precision 2 gives "2.00").
// Negative values keep their sign, and NaN/±Inf render as "NaN", "+Inf" and
// "-Inf". A negative precision is treated as 0, and precision 0 prints no
// decimal point. The previous loop truncated instead of rounding and emitted
// non-digit characters for negative fractions.
func floatToString(f float64, precision int) string {
	if precision < 0 {
		precision = 0
	}
	return strconv.FormatFloat(f, 'f', precision, 64)
}
| 403 |