package llm

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"net/url"
	"time"
)

type OllamaClient struct {
	BaseURL string
	Model   string
	Mode    string // "snappy" (fast) or "spicy" (quality)
	client  *http.Client
}

type GenerateRequest struct {
	Model     string           `json:"model"`
	Prompt    string           `json:"prompt"`
	Stream    bool             `json:"stream"`
	KeepAlive string           `json:"keep_alive,omitempty"`
	Options   *GenerateOptions `json:"options,omitempty"`
}

// GenerateOptions controls Ollama generation behavior for speed optimization.
type GenerateOptions struct {
	NumPredict  int     `json:"num_predict,omitempty"` // Max tokens to generate (40-80 is plenty for insults)
	NumCtx      int     `json:"num_ctx,omitempty"`     // Context window size (256-1024 covers these small prompts)
	Temperature float64 `json:"temperature,omitempty"` // Creativity (0.6-0.85 depending on mode)
}

type GenerateResponse struct {
	Response string `json:"response"`
	Done     bool   `json:"done"`
}

func NewOllamaClient(baseURL, model string) *OllamaClient {
	if baseURL == "" {
		baseURL = "http://127.0.0.1:11434" // Use IPv4 explicitly to avoid IPv6 issues
	}
	if model == "" {
		model = "llama3.2:3b"
	}

	return &OllamaClient{
		BaseURL: baseURL,
		Model:   model,
		Mode:    "snappy", // Default to fast mode
		client: &http.Client{
			Timeout: 60 * time.Second, // Upper bound; callers can shorten it per request via context
		},
	}
}

// SetMode sets the generation mode ("snappy" for speed, "spicy" for quality).
// Any other value is ignored.
func (c *OllamaClient) SetMode(mode string) {
	if mode == "spicy" || mode == "snappy" {
		c.Mode = mode
	}
}

// getOptionsForMode returns generation options tuned for the current mode.
func (c *OllamaClient) getOptionsForMode() *GenerateOptions {
	if c.Mode == "spicy" {
		// Spicy mode: richer responses, more creative, willing to wait
		return &GenerateOptions{
			NumPredict:  80,   // Longer responses
			NumCtx:      1024, // Rich context window
			Temperature: 0.85, // More creative
		}
	}
	// Snappy mode (default): fast and punchy
	return &GenerateOptions{
		NumPredict:  40,  // Short and punchy
		NumCtx:      256, // Minimal context
		Temperature: 0.6, // Faster convergence
	}
}

func (c *OllamaClient) Generate(ctx context.Context, prompt string) (string, error) {
	u, err := url.JoinPath(c.BaseURL, "/api/generate")
	if err != nil {
		return "", fmt.Errorf("invalid base URL: %w", err)
	}

	// Keep the model resident longer in quality mode so follow-ups stay warm
	keepAlive := "5m"
	if c.Mode == "spicy" {
		keepAlive = "15m"
	}

	req := GenerateRequest{
		Model:     c.Model,
		Prompt:    prompt,
		Stream:    false,
		KeepAlive: keepAlive,
		Options:   c.getOptionsForMode(), // Mode-specific generation options
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return "", fmt.Errorf("failed to marshal request: %w", err)
	}

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewReader(reqBody))
	if err != nil {
		return "", fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	resp, err := c.client.Do(httpReq)
	if err != nil {
		return "", fmt.Errorf("failed to send request to %s: %w", c.BaseURL, err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		// Surface the start of the error body; a single Read can return
		// short, so drain up to 512 bytes with io.ReadAll instead.
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return "", fmt.Errorf("ollama API returned status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	var genResp GenerateResponse
	if err := json.NewDecoder(resp.Body).Decode(&genResp); err != nil {
		return "", fmt.Errorf("failed to decode response: %w", err)
	}

	return genResp.Response, nil
}
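
// A caller bounds per-request latency through the context it passes in. A
// minimal sketch (the client variable and prompt are illustrative, not part
// of this package):
//
//	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
//	defer cancel()
//	reply, err := client.Generate(ctx, "Roast my commit history.")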

// IsAvailable reports whether the Ollama server responds on /api/version.
func (c *OllamaClient) IsAvailable() bool {
	u, err := url.JoinPath(c.BaseURL, "/api/version")
	if err != nil {
		return false
	}

	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()

	req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil)
	if err != nil {
		return false
	}

	resp, err := c.client.Do(req)
	if err != nil {
		return false
	}
	defer resp.Body.Close()

	return resp.StatusCode == http.StatusOK
}

// WarmupModel preloads the model to avoid cold start delays.
func (c *OllamaClient) WarmupModel() error {
	u, err := url.JoinPath(c.BaseURL, "/api/generate")
	if err != nil {
		return fmt.Errorf("invalid base URL: %w", err)
	}

	req := GenerateRequest{
		Model:     c.Model,
		Prompt:    "Say OK", // Minimal prompt to load model
		Stream:    false,
		KeepAlive: "10m",
		Options: &GenerateOptions{
			NumPredict: 5,   // Minimal output for warmup
			NumCtx:     256, // Minimal context
		},
	}

	reqBody, err := json.Marshal(req)
	if err != nil {
		return fmt.Errorf("failed to marshal request: %w", err)
	}

	// Allow extra time for initial model loading. The shared client's 60s
	// Timeout would cut this short, so use a dedicated client here and let
	// the context govern the deadline.
	ctx, cancel := context.WithTimeout(context.Background(), 90*time.Second)
	defer cancel()

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, u, bytes.NewReader(reqBody))
	if err != nil {
		return fmt.Errorf("failed to create request: %w", err)
	}
	httpReq.Header.Set("Content-Type", "application/json")

	warmupClient := &http.Client{} // No Timeout; the 90s context bounds the request
	resp, err := warmupClient.Do(httpReq)
	if err != nil {
		return fmt.Errorf("failed to warmup model: %w", err)
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		bodyBytes, _ := io.ReadAll(io.LimitReader(resp.Body, 512))
		return fmt.Errorf("warmup returned status %d: %s", resp.StatusCode, string(bodyBytes))
	}

	return nil
}
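
// Usage sketch (illustrative, not part of this package's API surface).
// It assumes an Ollama server on the default address with the configured
// model already pulled; adjust names to your setup.
//
//	client := NewOllamaClient("", "") // defaults to 127.0.0.1:11434 and llama3.2:3b
//	if !client.IsAvailable() {
//		log.Fatal("ollama is not reachable")
//	}
//	if err := client.WarmupModel(); err != nil {
//		log.Printf("warmup failed, first generation may be slow: %v", err)
//	}
//	client.SetMode("spicy")
//	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
//	defer cancel()
//	insult, err := client.Generate(ctx, "Write a one-line roast of my code.")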