Go · 5527 bytes Raw Blame History
1 // SPDX-License-Identifier: AGPL-3.0-or-later
2
3 package orgs_test
4
5 import (
6 "context"
7 "encoding/json"
8 "fmt"
9 "io"
10 "log/slog"
11 "net/http"
12 "strconv"
13 "strings"
14 "testing"
15
16 "github.com/tenseleyFlow/shithub/internal/auth/secretbox"
17 "github.com/tenseleyFlow/shithub/internal/orgs"
18 "github.com/tenseleyFlow/shithub/internal/worker"
19 )
20
21 func TestNormalizeGitHubOrg(t *testing.T) {
22 t.Parallel()
23 tests := []struct {
24 name string
25 raw string
26 want string
27 wantErr bool
28 }{
29 {name: "bare", raw: "tenseleyFlow", want: "tenseleyFlow"},
30 {name: "url", raw: "https://github.com/FortranGoingOnForty/", want: "FortranGoingOnForty"},
31 {name: "path rejected", raw: "github.com/owner/repo", wantErr: true},
32 {name: "trailing hyphen rejected", raw: "bad-", wantErr: true},
33 {name: "double hyphen rejected", raw: "bad--name", wantErr: true},
34 }
35 for _, tt := range tests {
36 t.Run(tt.name, func(t *testing.T) {
37 got, err := orgs.NormalizeGitHubOrg(tt.raw)
38 if tt.wantErr {
39 if err == nil {
40 t.Fatalf("NormalizeGitHubOrg(%q) succeeded", tt.raw)
41 }
42 return
43 }
44 if err != nil {
45 t.Fatalf("NormalizeGitHubOrg(%q): %v", tt.raw, err)
46 }
47 if got != tt.want {
48 t.Fatalf("NormalizeGitHubOrg(%q)=%q, want %q", tt.raw, got, tt.want)
49 }
50 })
51 }
52 }
53
54 func TestGitHubClientListOrgReposPaginatesAndUsesToken(t *testing.T) {
55 t.Parallel()
56 var seenPages []string
57 rt := roundTripFunc(func(r *http.Request) (*http.Response, error) {
58 if got, want := r.URL.Path, "/orgs/tenseleyFlow/repos"; got != want {
59 t.Fatalf("path=%q, want %q", got, want)
60 }
61 if got, want := r.Header.Get("Authorization"), "Bearer test-token"; got != want {
62 t.Fatalf("Authorization=%q, want %q", got, want)
63 }
64 if got, want := r.URL.Query().Get("type"), "all"; got != want {
65 t.Fatalf("type=%q, want %q", got, want)
66 }
67 seenPages = append(seenPages, r.URL.Query().Get("page"))
68 var body strings.Builder
69 if r.URL.Query().Get("page") == "1" {
70 writeGitHubRepoList(t, &body, 100)
71 } else {
72 _, _ = io.WriteString(&body, `[{"id":101,"name":"last","full_name":"tenseleyFlow/last","clone_url":"https://github.com/tenseleyFlow/last.git","description":"last repo","default_branch":"trunk"}]`)
73 }
74 return &http.Response{
75 StatusCode: http.StatusOK,
76 Status: "200 OK",
77 Header: http.Header{"Content-Type": []string{"application/json"}},
78 Body: io.NopCloser(strings.NewReader(body.String())),
79 Request: r,
80 }, nil
81 })
82
83 repos, err := (orgs.GitHubClient{
84 HTTPClient: &http.Client{Transport: rt},
85 BaseURL: "https://api.github.test",
86 UserAgent: "shithub-test",
87 }).ListOrgRepos(context.Background(), "tenseleyFlow", "test-token")
88 if err != nil {
89 t.Fatalf("ListOrgRepos: %v", err)
90 }
91 if len(repos) != 101 {
92 t.Fatalf("len(repos)=%d, want 101", len(repos))
93 }
94 if got := fmt.Sprint(seenPages); got != "[1 2]" {
95 t.Fatalf("pages=%s, want [1 2]", got)
96 }
97 if repos[100].FullName != "tenseleyFlow/last" || repos[100].Description != "last repo" || repos[100].DefaultBranch != "trunk" {
98 t.Fatalf("last repo decoded incorrectly: %+v", repos[100])
99 }
100 }
101
// roundTripFunc lets an ordinary function serve as an http.RoundTripper, so a
// test can hand canned responses to an http.Client without any network I/O.
type roundTripFunc func(*http.Request) (*http.Response, error)

// Compile-time check that roundTripFunc satisfies http.RoundTripper.
var _ http.RoundTripper = roundTripFunc(nil)

// RoundTrip implements http.RoundTripper by delegating to the wrapped function
// and returning its response and error unchanged.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	return fn(req)
}
107
// writeGitHubRepoList writes a JSON array of count fake GitHub repositories to
// w, terminated by the newline json.Encoder appends. Entry i is named
// repo-%03d, has id i+1, clone URL and full name under tenseleyFlow, a null
// description, default branch "trunk", and private/fork set to false. An
// encoding failure aborts the test through t.Fatalf.
func writeGitHubRepoList(t *testing.T, w io.Writer, count int) {
	t.Helper()

	// Fields are declared in alphabetical JSON-key order so each encoded
	// object lists its keys in sorted order.
	type fakeRepo struct {
		CloneURL      string  `json:"clone_url"`
		DefaultBranch string  `json:"default_branch"`
		Description   *string `json:"description"`
		Fork          bool    `json:"fork"`
		FullName      string  `json:"full_name"`
		ID            int     `json:"id"`
		Name          string  `json:"name"`
		Private       bool    `json:"private"`
	}

	repos := make([]fakeRepo, count)
	for idx := range repos {
		repos[idx] = fakeRepo{
			CloneURL:      fmt.Sprintf("https://github.com/tenseleyFlow/repo-%03d.git", idx),
			DefaultBranch: "trunk",
			FullName:      fmt.Sprintf("tenseleyFlow/repo-%03d", idx),
			ID:            idx + 1,
			Name:          fmt.Sprintf("repo-%03d", idx),
		}
	}

	enc := json.NewEncoder(w)
	if err := enc.Encode(repos); err != nil {
		t.Fatalf("encode GitHub response: %v", err)
	}
}
127
128 func TestStartGitHubImportPersistsEncryptedTokenAndDiscoveryJob(t *testing.T) {
129 pool, deps, alice := setup(t)
130 row, err := orgs.Create(context.Background(), deps, orgs.CreateParams{
131 Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: alice,
132 })
133 if err != nil {
134 t.Fatalf("create org: %v", err)
135 }
136 key, err := secretbox.GenerateKey()
137 if err != nil {
138 t.Fatalf("GenerateKey: %v", err)
139 }
140 box, err := secretbox.FromBytes(key)
141 if err != nil {
142 t.Fatalf("FromBytes: %v", err)
143 }
144 imp, err := orgs.StartGitHubImport(context.Background(), orgs.ImportDeps{
145 Pool: pool,
146 Box: box,
147 Logger: slog.New(slog.NewTextHandler(io.Discard, nil)),
148 }, orgs.StartGitHubImportParams{
149 OrgID: row.ID, SourceOrg: "https://github.com/tenseleyFlow/",
150 RequestedByUserID: alice, Token: "ghp_secret",
151 })
152 if err != nil {
153 t.Fatalf("StartGitHubImport: %v", err)
154 }
155 if imp.SourceOrg != "tenseleyFlow" || !imp.IncludePrivate || !imp.TokenPresent {
156 t.Fatalf("unexpected import row: %+v", imp)
157 }
158 token, err := orgs.DecryptGitHubImportToken(imp, box)
159 if err != nil {
160 t.Fatalf("DecryptGitHubImportToken: %v", err)
161 }
162 if token != "ghp_secret" {
163 t.Fatalf("decrypted token=%q", token)
164 }
165 if string(imp.TokenCiphertext) == "ghp_secret" {
166 t.Fatal("token stored in plaintext")
167 }
168 var jobs int
169 if err := pool.QueryRow(context.Background(),
170 `SELECT count(*) FROM jobs WHERE kind = $1 AND payload->>'import_id' = $2`,
171 worker.KindOrgGitHubImportDiscover, strconv.FormatInt(imp.ID, 10),
172 ).Scan(&jobs); err != nil {
173 t.Fatalf("query jobs: %v", err)
174 }
175 if jobs != 1 {
176 t.Fatalf("discovery jobs=%d, want 1", jobs)
177 }
178 }
179