Go · 10144 bytes Raw Blame History
1 package llm
2
import (
	"fmt"
	"math"
	"sort"
	"sync"
)
7
8 // BM25Engine implements BM25 ranking algorithm (superior to basic TF-IDF)
9 // BM25 is the industry standard for text search and ranking
10 type BM25Engine struct {
11 mu sync.RWMutex
12 vocabulary map[string]int // word -> index
13 idf map[string]float64 // word -> inverse document frequency
14 docLengths []int // document lengths
15 avgDocLength float64 // average document length
16 documentCount int
17
18 // BM25 parameters (tunable)
19 k1 float64 // term frequency saturation parameter (typical: 1.2-2.0)
20 b float64 // document length normalization (typical: 0.75)
21
22 ngramRange [2]int // min and max n-gram size
23 }
24
25 // NewBM25Engine creates a new BM25 ranking engine
26 func NewBM25Engine() *BM25Engine {
27 return &BM25Engine{
28 vocabulary: make(map[string]int),
29 idf: make(map[string]float64),
30 docLengths: make([]int, 0),
31 documentCount: 0,
32
33 // Standard BM25 parameters (Okapi BM25)
34 k1: 1.5, // Typical range: 1.2-2.0
35 b: 0.75, // Typical range: 0.5-0.9
36
37 ngramRange: [2]int{1, 3}, // unigrams, bigrams, trigrams
38 }
39 }
40
41 // SetParameters allows tuning of BM25 parameters
42 func (engine *BM25Engine) SetParameters(k1, b float64) {
43 engine.k1 = k1
44 engine.b = b
45 }
46
47 // BuildCorpus builds the BM25 corpus from documents
48 func (engine *BM25Engine) BuildCorpus(documents []string) {
49 engine.mu.Lock()
50 defer engine.mu.Unlock()
51
52 // First pass: extract terms and calculate document frequencies
53 documentFreq := make(map[string]int)
54 engine.docLengths = make([]int, len(documents))
55 totalLength := 0
56
57 for docIdx, doc := range documents {
58 tokens := engine.extractNGrams(doc)
59 engine.docLengths[docIdx] = len(tokens)
60 totalLength += len(tokens)
61
62 // Track which terms appear in this document
63 seen := make(map[string]bool)
64 for _, token := range tokens {
65 if !seen[token] {
66 documentFreq[token]++
67 seen[token] = true
68 }
69
70 if _, exists := engine.vocabulary[token]; !exists {
71 engine.vocabulary[token] = len(engine.vocabulary)
72 }
73 }
74 }
75
76 engine.documentCount = len(documents)
77 engine.avgDocLength = float64(totalLength) / float64(engine.documentCount)
78
79 // Calculate IDF for each term using BM25 IDF formula
80 // IDF = log((N - df + 0.5) / (df + 0.5) + 1)
81 // This is the Robertson-Sparck Jones formula
82 for term, df := range documentFreq {
83 N := float64(engine.documentCount)
84 numerator := N - float64(df) + 0.5
85 denominator := float64(df) + 0.5
86 engine.idf[term] = math.Log((numerator / denominator) + 1.0)
87 }
88 }
89
90 // extractNGrams extracts n-grams from text (same as TF-IDF)
91 func (engine *BM25Engine) extractNGrams(text string) []string {
92 text = toLowerSimple(text)
93 words := engine.tokenize(text)
94
95 var ngrams []string
96
97 for n := engine.ngramRange[0]; n <= engine.ngramRange[1]; n++ {
98 if n > len(words) {
99 break
100 }
101
102 for i := 0; i <= len(words)-n; i++ {
103 ngram := joinWords(words[i:i+n], " ")
104 ngrams = append(ngrams, ngram)
105 }
106 }
107
108 return ngrams
109 }
110
111 // tokenize splits text into words
112 func (engine *BM25Engine) tokenize(text string) []string {
113 var words []string
114 var currentWord string
115
116 for _, r := range text {
117 if (r >= 'a' && r <= 'z') || (r >= '0' && r <= '9') || r == '-' || r == '_' {
118 currentWord += string(r)
119 } else {
120 if len(currentWord) > 1 { // Skip single characters
121 words = append(words, currentWord)
122 }
123 currentWord = ""
124 }
125 }
126
127 if len(currentWord) > 1 {
128 words = append(words, currentWord)
129 }
130
131 return words
132 }
133
134 // Score calculates BM25 score for a query against a document
135 func (engine *BM25Engine) Score(query string, document string) float64 {
136 engine.mu.RLock()
137 defer engine.mu.RUnlock()
138
139 queryTerms := engine.extractNGrams(query)
140 docTerms := engine.extractNGrams(document)
141
142 // Calculate term frequencies in document
143 termFreq := make(map[string]int)
144 for _, term := range docTerms {
145 termFreq[term]++
146 }
147
148 docLength := len(docTerms)
149
150 // Calculate BM25 score
151 score := 0.0
152
153 // Track query terms we've processed (unique terms only)
154 seenQuery := make(map[string]bool)
155
156 for _, queryTerm := range queryTerms {
157 if seenQuery[queryTerm] {
158 continue
159 }
160 seenQuery[queryTerm] = true
161
162 // Get IDF for this term
163 idf, exists := engine.idf[queryTerm]
164 if !exists {
165 // Term not in vocabulary - use a small IDF
166 idf = math.Log(float64(engine.documentCount) + 1.0)
167 }
168
169 // Get term frequency in document
170 tf := float64(termFreq[queryTerm])
171
172 // BM25 formula:
173 // score = IDF(qi) × (f(qi, D) × (k1 + 1)) / (f(qi, D) + k1 × (1 - b + b × |D| / avgdl))
174
175 numerator := tf * (engine.k1 + 1.0)
176 denominator := tf + engine.k1*(1.0-engine.b+engine.b*float64(docLength)/engine.avgDocLength)
177
178 termScore := idf * (numerator / denominator)
179 score += termScore
180 }
181
182 return score
183 }
184
185 // ScoreMultiple scores a query against multiple documents
186 func (engine *BM25Engine) ScoreMultiple(query string, documents []string) []BM25Score {
187 scores := make([]BM25Score, len(documents))
188
189 for i, doc := range documents {
190 scores[i] = BM25Score{
191 Index: i,
192 Document: doc,
193 Score: engine.Score(query, doc),
194 }
195 }
196
197 // Sort by score descending
198 sortBM25Scores(scores)
199
200 return scores
201 }
202
203 // FindTopK returns the top K documents for a query
204 func (engine *BM25Engine) FindTopK(query string, documents []string, k int) []BM25Score {
205 scores := engine.ScoreMultiple(query, documents)
206
207 if len(scores) > k {
208 scores = scores[:k]
209 }
210
211 return scores
212 }
213
214 // BM25Score represents a scored document
215 type BM25Score struct {
216 Index int
217 Document string
218 Score float64
219 }
220
221 // sortBM25Scores sorts scores in descending order (bubble sort for simplicity)
222 func sortBM25Scores(scores []BM25Score) {
223 n := len(scores)
224 for i := 0; i < n-1; i++ {
225 for j := 0; j < n-i-1; j++ {
226 if scores[j].Score < scores[j+1].Score {
227 scores[j], scores[j+1] = scores[j+1], scores[j]
228 }
229 }
230 }
231 }
232
233 // Explanation generates a human-readable explanation of the score
234 func (engine *BM25Engine) Explanation(query string, document string) string {
235 queryTerms := engine.extractNGrams(query)
236 docTerms := engine.extractNGrams(document)
237
238 termFreq := make(map[string]int)
239 for _, term := range docTerms {
240 termFreq[term]++
241 }
242
243 docLength := len(docTerms)
244
245 explanation := "BM25 Score Breakdown:\n"
246 explanation += "=====================\n\n"
247
248 totalScore := 0.0
249 seenQuery := make(map[string]bool)
250
251 for _, queryTerm := range queryTerms {
252 if seenQuery[queryTerm] {
253 continue
254 }
255 seenQuery[queryTerm] = true
256
257 if termFreq[queryTerm] == 0 {
258 continue // Term not in document
259 }
260
261 idf := engine.idf[queryTerm]
262 tf := float64(termFreq[queryTerm])
263
264 numerator := tf * (engine.k1 + 1.0)
265 denominator := tf + engine.k1*(1.0-engine.b+engine.b*float64(docLength)/engine.avgDocLength)
266 termScore := idf * (numerator / denominator)
267
268 totalScore += termScore
269
270 explanation += formatString("Term: '%s'\n", queryTerm)
271 explanation += formatString(" TF: %d (occurs %d times)\n", termFreq[queryTerm], termFreq[queryTerm])
272 explanation += formatString(" IDF: %.4f\n", idf)
273 explanation += formatString(" BM25 component: %.4f\n", termScore)
274 explanation += "\n"
275 }
276
277 explanation += formatString("Total BM25 Score: %.4f\n", totalScore)
278 explanation += formatString("Document length: %d (avg: %.1f)\n", docLength, engine.avgDocLength)
279
280 return explanation
281 }
282
283 // CompareWithTFIDF compares BM25 with basic TF-IDF for analysis
284 func (engine *BM25Engine) CompareWithTFIDF(query string, document string, tfidfScore float64) string {
285 bm25Score := engine.Score(query, document)
286
287 comparison := "BM25 vs TF-IDF Comparison:\n"
288 comparison += "===========================\n\n"
289 comparison += formatString("BM25 Score: %.4f\n", bm25Score)
290 comparison += formatString("TF-IDF Score: %.4f\n", tfidfScore)
291
292 diff := bm25Score - tfidfScore
293 percentDiff := (diff / tfidfScore) * 100
294
295 if diff > 0 {
296 comparison += formatString("Difference: +%.4f (+%.1f%%)\n", diff, percentDiff)
297 comparison += "✅ BM25 scores higher (better)\n"
298 } else {
299 comparison += formatString("Difference: %.4f (%.1f%%)\n", diff, percentDiff)
300 comparison += "⚠️ TF-IDF scores higher\n"
301 }
302
303 comparison += "\nWhy BM25 is generally better:\n"
304 comparison += "- Term frequency saturation (diminishing returns)\n"
305 comparison += "- Document length normalization (fairer comparison)\n"
306 comparison += "- More sophisticated IDF formula\n"
307 comparison += "- Industry standard for search engines\n"
308
309 return comparison
310 }
311
312 // Helper functions
313
314 func toLowerSimple(s string) string {
315 result := ""
316 for _, r := range s {
317 if r >= 'A' && r <= 'Z' {
318 result += string(r + 32)
319 } else {
320 result += string(r)
321 }
322 }
323 return result
324 }
325
326 func joinWords(words []string, sep string) string {
327 if len(words) == 0 {
328 return ""
329 }
330
331 result := words[0]
332 for i := 1; i < len(words); i++ {
333 result += sep + words[i]
334 }
335 return result
336 }
337
338 func formatString(format string, args ...interface{}) string {
339 // Simple sprintf equivalent for basic formatting
340 // This is a simplified version - in production use fmt.Sprintf
341 result := format
342 for _, arg := range args {
343 switch v := arg.(type) {
344 case string:
345 result = replaceFirst(result, "%s", v)
346 case int:
347 result = replaceFirst(result, "%d", intToString(v))
348 case float64:
349 // Simple float formatting
350 result = replaceFirst(result, "%.4f", floatToString(v, 4))
351 result = replaceFirst(result, "%.1f", floatToString(v, 1))
352 }
353 }
354 return result
355 }
356
357 func replaceFirst(s, old, new string) string {
358 idx := findSubstring(s, old)
359 if idx < 0 {
360 return s
361 }
362 return s[:idx] + new + s[idx+len(old):]
363 }
364
365 func intToString(n int) string {
366 if n == 0 {
367 return "0"
368 }
369
370 negative := n < 0
371 if negative {
372 n = -n
373 }
374
375 digits := ""
376 for n > 0 {
377 digits = string('0'+rune(n%10)) + digits
378 n /= 10
379 }
380
381 if negative {
382 digits = "-" + digits
383 }
384
385 return digits
386 }
387
388 func floatToString(f float64, precision int) string {
389 // Simple float to string conversion
390 intPart := int(f)
391 fracPart := f - float64(intPart)
392
393 result := intToString(intPart) + "."
394
395 for i := 0; i < precision; i++ {
396 fracPart *= 10
397 digit := int(fracPart) % 10
398 result += string('0' + rune(digit))
399 }
400
401 return result
402 }
403