Go · 6012 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package event_test
4
5 import (
6 "testing"
7
8 "github.com/tenseleyFlow/shithub/internal/actions/event"
9 "github.com/tenseleyFlow/shithub/internal/actions/expr"
10 )
11
12 // TestPush_Shape pins the documented v1 push payload field set:
13 // ref, before, after, head_commit{message,id,author}. If you're here
14 // because you added a field and the test failed: update the doc
15 // (docs/internal/actions-schema.md), bump the v1→v2 marker if this
16 // is a rename or removal, and update the test in the same PR.
17 func TestPush_Shape(t *testing.T) {
18 t.Parallel()
19 p := event.Push("refs/heads/main", "abc", "def", event.HeadCommit{
20 Message: "fix: thing", ID: "def", Author: "alice",
21 })
22 got := keys(p)
23 wantTop := []string{"ref", "before", "after", "head_commit"}
24 for _, k := range wantTop {
25 if !contains(got, k) {
26 t.Errorf("missing key %q in push payload (have %v)", k, got)
27 }
28 }
29 hc, ok := p["head_commit"].(map[string]any)
30 if !ok {
31 t.Fatalf("head_commit not a map: %T", p["head_commit"])
32 }
33 for _, k := range []string{"message", "id", "author"} {
34 if _, ok := hc[k]; !ok {
35 t.Errorf("missing head_commit.%s", k)
36 }
37 }
38 }
39
40 // TestPullRequest_Shape pins the v1 pull_request schema.
41 func TestPullRequest_Shape(t *testing.T) {
42 t.Parallel()
43 p := event.PullRequest(
44 "opened", 42, "feat: foo",
45 event.PRRef{Ref: "feature", SHA: "aaaaaaaa"},
46 event.PRRef{Ref: "main", SHA: "bbbbbbbb"},
47 "alice",
48 )
49 for _, k := range []string{"action", "number", "pull_request"} {
50 if _, ok := p[k]; !ok {
51 t.Errorf("missing top-level %s", k)
52 }
53 }
54 pr, ok := p["pull_request"].(map[string]any)
55 if !ok {
56 t.Fatalf("pull_request not a map: %T", p["pull_request"])
57 }
58 for _, k := range []string{"title", "head", "base", "user"} {
59 if _, ok := pr[k]; !ok {
60 t.Errorf("missing pull_request.%s", k)
61 }
62 }
63 head := pr["head"].(map[string]any)
64 if head["ref"] != "feature" || head["sha"] != "aaaaaaaa" {
65 t.Errorf("head ref/sha wrong: %v", head)
66 }
67 user := pr["user"].(map[string]any)
68 if user["login"] != "alice" {
69 t.Errorf("user.login wrong: %v", user)
70 }
71 }
72
73 // TestSchedule_IsEmptyMap pins the empty-map invariant. Returning nil
74 // would force callers to nil-check before pgx-encoding; a non-nil
75 // empty map is the safer default.
76 func TestSchedule_IsEmptyMap(t *testing.T) {
77 t.Parallel()
78 p := event.Schedule()
79 if p == nil {
80 t.Fatal("Schedule() returned nil; expected non-nil empty map")
81 }
82 if len(p) != 0 {
83 t.Errorf("Schedule() should be empty, got %v", p)
84 }
85 }
86
87 // TestWorkflowDispatch_Inputs pins the inputs-wrapping shape.
88 // Authors template ${{ shithub.event.inputs.foo }}, so the inputs key
89 // must exist as a nested map even when no inputs are provided.
90 func TestWorkflowDispatch_Inputs(t *testing.T) {
91 t.Parallel()
92 p := event.WorkflowDispatch(map[string]string{"env": "prod", "tag": "v1.2"})
93 inputs, ok := p["inputs"].(map[string]any)
94 if !ok {
95 t.Fatalf("inputs not a map: %T", p["inputs"])
96 }
97 if inputs["env"] != "prod" || inputs["tag"] != "v1.2" {
98 t.Errorf("inputs wrong: %v", inputs)
99 }
100 }
101
102 // TestPush_FlowsThroughEvaluator ties this package to the actual
103 // expr.evalEventPath consumer. Workflow authors template ${{ ... }}
104 // against documented field paths; if the constructor lays out a key
105 // the evaluator can't reach, the contract is broken. Pin both ends.
106 func TestPush_FlowsThroughEvaluator(t *testing.T) {
107 t.Parallel()
108 p := event.Push("refs/heads/trunk", "abc", "def", event.HeadCommit{
109 Message: "fix: thing", ID: "def", Author: "alice",
110 })
111 ctx := &expr.Context{
112 Shithub: expr.ShithubContext{Event: p},
113 Untrusted: expr.DefaultUntrusted(),
114 }
115 cases := []struct {
116 path string
117 want string
118 }{
119 {`shithub.event.ref`, "refs/heads/trunk"},
120 {`shithub.event.head_commit.message`, "fix: thing"},
121 {`shithub.event.head_commit.id`, "def"},
122 {`shithub.event.head_commit.author`, "alice"},
123 {`github.event.head_commit.message`, "fix: thing"}, // alias path
124 }
125 for _, tc := range cases {
126 t.Run(tc.path, func(t *testing.T) {
127 t.Parallel()
128 toks, err := expr.Lex(tc.path)
129 if err != nil {
130 t.Fatalf("lex: %v", err)
131 }
132 ast, err := expr.Parse(toks)
133 if err != nil {
134 t.Fatalf("parse: %v", err)
135 }
136 v, err := expr.Eval(ast, ctx)
137 if err != nil {
138 t.Fatalf("eval: %v", err)
139 }
140 if v.S != tc.want {
141 t.Errorf("got %q, want %q", v.S, tc.want)
142 }
143 if !v.Tainted {
144 t.Errorf("event-derived value must be tainted")
145 }
146 })
147 }
148 }
149
150 // TestPullRequest_FlowsThroughEvaluator does the same end-to-end pin
151 // for the pull_request schema, which has the most authoring surface.
152 func TestPullRequest_FlowsThroughEvaluator(t *testing.T) {
153 t.Parallel()
154 p := event.PullRequest(
155 "opened", 7, "feat: add foo",
156 event.PRRef{Ref: "feature", SHA: "feedbeef"},
157 event.PRRef{Ref: "main", SHA: "deadbeef"},
158 "alice",
159 )
160 ctx := &expr.Context{
161 Shithub: expr.ShithubContext{Event: p},
162 Untrusted: expr.DefaultUntrusted(),
163 }
164 cases := []struct {
165 path string
166 want string
167 }{
168 {`shithub.event.pull_request.title`, "feat: add foo"},
169 {`shithub.event.pull_request.head.ref`, "feature"},
170 {`shithub.event.pull_request.base.sha`, "deadbeef"},
171 {`shithub.event.pull_request.user.login`, "alice"},
172 {`shithub.event.action`, "opened"},
173 }
174 for _, tc := range cases {
175 t.Run(tc.path, func(t *testing.T) {
176 t.Parallel()
177 toks, err := expr.Lex(tc.path)
178 if err != nil {
179 t.Fatalf("lex: %v", err)
180 }
181 ast, err := expr.Parse(toks)
182 if err != nil {
183 t.Fatalf("parse: %v", err)
184 }
185 v, err := expr.Eval(ast, ctx)
186 if err != nil {
187 t.Fatalf("eval: %v", err)
188 }
189 if v.S != tc.want {
190 t.Errorf("got %q, want %q", v.S, tc.want)
191 }
192 if !v.Tainted {
193 t.Errorf("event-derived value must be tainted")
194 }
195 })
196 }
197 }
198
199 func keys(m map[string]any) []string {
200 out := make([]string, 0, len(m))
201 for k := range m {
202 out = append(out, k)
203 }
204 return out
205 }
206
207 func contains(s []string, want string) bool {
208 for _, v := range s {
209 if v == want {
210 return true
211 }
212 }
213 return false
214 }
215