1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package repo
4
import (
	"bytes"
	"context"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"net/http"
	"strconv"
	"time"

	"github.com/jackc/pgx/v5"

	"github.com/tenseleyFlow/shithub/internal/actions/logstream"
	actionsdb "github.com/tenseleyFlow/shithub/internal/actions/sqlc"
	"github.com/tenseleyFlow/shithub/internal/auth/policy"
	"github.com/tenseleyFlow/shithub/internal/ratelimit"
	"github.com/tenseleyFlow/shithub/internal/web/middleware"
)
23
// Tunables for the step-log SSE streaming endpoint.
const (
	// actionsLogStreamBatchSize is the page size used when reading log
	// chunks from the database on each flush.
	actionsLogStreamBatchSize = int32(100)
	// actionsLogStreamHeartbeatEvery bounds how long the stream waits for a
	// LISTEN/NOTIFY wakeup before re-polling and emitting a keep-alive.
	actionsLogStreamHeartbeatEvery = 20 * time.Second
	// actionsLogStreamReleaseTimeout caps the time spent releasing the
	// rate-limit lease during stream teardown.
	actionsLogStreamReleaseTimeout = 3 * time.Second
)
29
// actionsLogStreamLimit caps each caller (user or client IP) at 5 live log
// stream leases per 2-minute window under the "actions:logtail" scope.
var actionsLogStreamLimit = ratelimit.Policy{
	Scope:  "actions:logtail",
	Max:    5,
	Window: 2 * time.Minute,
}
35
// actionsLogStreamChunk is the JSON payload of an SSE "chunk" event: the
// chunk's sequence number plus its raw bytes, base64-encoded.
type actionsLogStreamChunk struct {
	Seq      int32  `json:"seq"`
	ChunkB64 string `json:"chunk_b64"`
}
40
// repoActionStepLogStream streams the log of a single workflow step to the
// client as Server-Sent Events (SSE).
//
// Flow: authorize the viewer for ActionRepoRead, resolve the step from the
// runIndex/jobIndex/stepIndex route params, take a rate-limit lease, then
// LISTEN on the step's Postgres notification channel. Chunks persisted
// after the client's resume point (Last-Event-ID header or ?after= query
// param) are replayed first; afterwards the handler blocks on NOTIFY
// wakeups, re-reading the chunk table on every wakeup and on a heartbeat
// timeout. The stream ends with a "done" event once the step reaches a
// terminal status, or silently when the client disconnects or a
// write/database error occurs (headers are already committed by then, so
// no error page is possible).
func (h *Handlers) repoActionStepLogStream(w http.ResponseWriter, r *http.Request) {
	row, owner, ok := h.loadRepoAndAuthorize(w, r, policy.ActionRepoRead)
	if !ok {
		// loadRepoAndAuthorize has already written the response.
		return
	}
	// Malformed route params are reported as 404, indistinguishable from a
	// missing resource.
	runIndex, ok := parsePositiveInt64Param(r, "runIndex")
	if !ok {
		h.d.Render.HTTPError(w, r, http.StatusNotFound, "")
		return
	}
	jobIndex, ok := parseNonNegativeInt32Param(r, "jobIndex")
	if !ok {
		h.d.Render.HTTPError(w, r, http.StatusNotFound, "")
		return
	}
	stepIndex, ok := parseNonNegativeInt32Param(r, "stepIndex")
	if !ok {
		h.d.Render.HTTPError(w, r, http.StatusNotFound, "")
		return
	}
	// Resume point: sequence number of the last chunk the client has seen
	// (-1 means "from the beginning").
	lastSeq, ok := parseLogStreamAfter(r)
	if !ok {
		h.d.Render.HTTPError(w, r, http.StatusBadRequest, "invalid Last-Event-ID")
		return
	}

	run, err := h.loadActionsRunDetail(r.Context(), row.ID, owner.Username, row.Name, runIndex)
	if err != nil {
		if errors.Is(err, pgx.ErrNoRows) {
			h.d.Render.HTTPError(w, r, http.StatusNotFound, "")
		} else {
			h.d.Logger.WarnContext(r.Context(), "repo actions: get run for step log stream", "repo_id", row.ID, "run_index", runIndex, "error", err)
			h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
		}
		return
	}
	_, step, ok := findActionStep(run, jobIndex, stepIndex)
	if !ok {
		h.d.Render.HTTPError(w, r, http.StatusNotFound, "")
		return
	}

	// SSE requires incremental flushing; bail out if the ResponseWriter
	// cannot flush (e.g. some middleware wrappers).
	flusher, ok := w.(http.Flusher)
	if !ok {
		h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "streaming is not supported")
		return
	}

	// Limit concurrent live streams per user/IP. NOTE(review): when
	// AcquireLease fails with an error, decision is its zero value and the
	// request is rejected with 429 — fail-closed; confirm that is intended.
	lease, decision, leaseErr := h.acquireLogStreamLease(r.Context(), w, r)
	if leaseErr != nil && h.d.Logger != nil {
		h.d.Logger.WarnContext(r.Context(), "repo actions: log stream rate-limit failed", "step_id", step.ID, "error", leaseErr)
	}
	if !decision.Allowed {
		w.Header().Set("Retry-After", strconv.Itoa(int(decision.RetryAfter/time.Second)))
		h.d.Render.HTTPError(w, r, http.StatusTooManyRequests, "too many live log streams")
		return
	}
	if lease != nil {
		defer func() {
			// Release with a fresh context: r.Context() is typically already
			// canceled by the time the stream tears down.
			ctx, cancel := context.WithTimeout(context.Background(), actionsLogStreamReleaseTimeout)
			defer cancel()
			if err := lease.Release(ctx); err != nil && h.d.Logger != nil {
				h.d.Logger.WarnContext(r.Context(), "repo actions: release log stream lease", "step_id", step.ID, "error", err)
			}
		}()
	}

	// Dedicated pooled connection: LISTEN state is per-connection, so it
	// cannot go through the pool's shared query path.
	conn, err := h.d.Pool.Acquire(r.Context())
	if err != nil {
		h.d.Logger.WarnContext(r.Context(), "repo actions: acquire log stream conn", "step_id", step.ID, "error", err)
		h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
		return
	}
	defer conn.Release()
	if _, err := conn.Exec(r.Context(), logstream.ListenSQL(step.ID)); err != nil {
		h.d.Logger.WarnContext(r.Context(), "repo actions: listen log stream", "step_id", step.ID, "error", err)
		h.d.Render.HTTPError(w, r, http.StatusInternalServerError, "")
		return
	}

	// From here on the response is committed as an SSE stream; errors can
	// only end the stream, not change the status code.
	w.Header().Set("Content-Type", "text/event-stream; charset=utf-8")
	w.Header().Set("Cache-Control", "no-cache, no-transform")
	w.Header().Set("Connection", "keep-alive")
	w.Header().Set("X-Accel-Buffering", "no")
	w.WriteHeader(http.StatusOK)

	// Replay everything persisted after the client's resume point before
	// waiting for notifications (covers chunks written before LISTEN took
	// effect).
	nextSeq, err := h.flushStepLogChunks(r.Context(), w, conn, step.ID, lastSeq)
	if err != nil {
		h.d.Logger.WarnContext(r.Context(), "repo actions: write initial log chunks", "step_id", step.ID, "error", err)
		return
	}
	done, err := h.stepLogStreamDone(r.Context(), conn, step.ID)
	if err != nil {
		h.d.Logger.WarnContext(r.Context(), "repo actions: check step terminal", "step_id", step.ID, "error", err)
		return
	}
	if done {
		_ = writeSSEEvent(w, "done", -1, []byte(`{}`))
		flusher.Flush()
		return
	}
	flusher.Flush()

	// Tail loop: block on NOTIFY for up to the heartbeat interval, then
	// either process the wakeup or emit a keep-alive.
	for {
		waitCtx, cancel := context.WithTimeout(r.Context(), actionsLogStreamHeartbeatEvery)
		notification, err := conn.Conn().WaitForNotification(waitCtx)
		cancel()
		if r.Context().Err() != nil {
			// Client went away.
			return
		}
		if errors.Is(err, context.DeadlineExceeded) {
			// Heartbeat: re-poll chunks and terminal state in case a
			// notification was missed, then keep the connection warm.
			nextSeq, err = h.flushStepLogChunks(r.Context(), w, conn, step.ID, nextSeq)
			if err != nil {
				h.d.Logger.WarnContext(r.Context(), "repo actions: write heartbeat log chunks", "step_id", step.ID, "error", err)
				return
			}
			done, err := h.stepLogStreamDone(r.Context(), conn, step.ID)
			if err != nil {
				h.d.Logger.WarnContext(r.Context(), "repo actions: heartbeat terminal check", "step_id", step.ID, "error", err)
				return
			}
			if done {
				_ = writeSSEEvent(w, "done", -1, []byte(`{}`))
				flusher.Flush()
				return
			}
			// SSE comment line: ignored by clients but defeats idle
			// proxy/connection timeouts.
			if _, err := fmt.Fprint(w, ": keep-alive\n\n"); err != nil {
				return
			}
			flusher.Flush()
			continue
		}
		if err != nil {
			h.d.Logger.WarnContext(r.Context(), "repo actions: wait log notification", "step_id", step.ID, "error", err)
			return
		}
		if notification.Channel != logstream.Channel(step.ID) {
			// Notification for some other channel on this connection.
			continue
		}
		_, done, ok := logstream.ParsePayload(notification.Payload)
		if !ok {
			h.d.Logger.WarnContext(r.Context(), "repo actions: invalid log notification", "step_id", step.ID, "payload", notification.Payload)
			continue
		}
		nextSeq, err = h.flushStepLogChunks(r.Context(), w, conn, step.ID, nextSeq)
		if err != nil {
			h.d.Logger.WarnContext(r.Context(), "repo actions: write log chunks", "step_id", step.ID, "error", err)
			return
		}
		if done {
			// Producer marked this as the final payload; close out.
			_ = writeSSEEvent(w, "done", -1, []byte(`{}`))
			flusher.Flush()
			return
		}
		flusher.Flush()
	}
}
198
199 func (h *Handlers) acquireLogStreamLease(ctx context.Context, w http.ResponseWriter, r *http.Request) (*ratelimit.Lease, ratelimit.Decision, error) {
200 if h.d.RateLimiter == nil {
201 return nil, ratelimit.Decision{Allowed: true}, nil
202 }
203 key := logStreamRateLimitKey(r)
204 if key == "" {
205 return nil, ratelimit.Decision{Allowed: true}, nil
206 }
207 lease, decision, err := h.d.RateLimiter.AcquireLease(ctx, actionsLogStreamLimit, key)
208 ratelimit.StampHeaders(w, decision)
209 return lease, decision, err
210 }
211
212 func logStreamRateLimitKey(r *http.Request) string {
213 viewer := middleware.CurrentUserFromContext(r.Context())
214 if !viewer.IsAnonymous() {
215 return "u:" + strconv.FormatInt(viewer.ID, 10)
216 }
217 if ip, ok := ratelimit.ClientIP(r, true); ok {
218 return "ip:" + ip.String()
219 }
220 return ""
221 }
222
223 func parseLogStreamAfter(r *http.Request) (int32, bool) {
224 raw := r.Header.Get("Last-Event-ID")
225 if raw == "" {
226 raw = r.URL.Query().Get("after")
227 }
228 if raw == "" {
229 return -1, true
230 }
231 n, err := strconv.ParseInt(raw, 10, 32)
232 if err != nil || n < -1 {
233 return 0, false
234 }
235 return int32(n), true
236 }
237
238 func (h *Handlers) flushStepLogChunks(ctx context.Context, w http.ResponseWriter, db actionsdb.DBTX, stepID int64, afterSeq int32) (int32, error) {
239 q := actionsdb.New()
240 nextSeq := afterSeq
241 for {
242 chunks, err := q.ListStepLogChunks(ctx, db, actionsdb.ListStepLogChunksParams{
243 StepID: stepID,
244 Seq: nextSeq,
245 Limit: actionsLogStreamBatchSize,
246 })
247 if err != nil {
248 return nextSeq, err
249 }
250 for _, chunk := range chunks {
251 payload, err := json.Marshal(actionsLogStreamChunk{
252 Seq: chunk.Seq,
253 ChunkB64: base64.StdEncoding.EncodeToString(chunk.Chunk),
254 })
255 if err != nil {
256 return nextSeq, err
257 }
258 if err := writeSSEEvent(w, "chunk", chunk.Seq, payload); err != nil {
259 return nextSeq, err
260 }
261 nextSeq = chunk.Seq
262 }
263 if int32(len(chunks)) < actionsLogStreamBatchSize {
264 return nextSeq, nil
265 }
266 }
267 }
268
269 func (h *Handlers) stepLogStreamDone(ctx context.Context, db actionsdb.DBTX, stepID int64) (bool, error) {
270 step, err := actionsdb.New().GetWorkflowStepByID(ctx, db, stepID)
271 if err != nil {
272 return false, err
273 }
274 return workflowStepTerminal(step.Status), nil
275 }
276
277 func writeSSEEvent(w http.ResponseWriter, event string, id int32, payload []byte) error {
278 if id >= 0 {
279 if _, err := fmt.Fprintf(w, "id: %d\n", id); err != nil {
280 return err
281 }
282 }
283 if event != "" {
284 if _, err := fmt.Fprintf(w, "event: %s\n", event); err != nil {
285 return err
286 }
287 }
288 if _, err := fmt.Fprintf(w, "data: %s\n\n", payload); err != nil {
289 return err
290 }
291 return nil
292 }
293