| 1 | // SPDX-License-Identifier: AGPL-3.0-or-later |
| 2 | |
| 3 | package orgs_test |
| 4 | |
| 5 | import ( |
| 6 | "context" |
| 7 | "encoding/json" |
| 8 | "fmt" |
| 9 | "io" |
| 10 | "log/slog" |
| 11 | "net/http" |
| 12 | "strconv" |
| 13 | "strings" |
| 14 | "testing" |
| 15 | |
| 16 | "github.com/tenseleyFlow/shithub/internal/auth/secretbox" |
| 17 | "github.com/tenseleyFlow/shithub/internal/orgs" |
| 18 | "github.com/tenseleyFlow/shithub/internal/worker" |
| 19 | ) |
| 20 | |
| 21 | func TestNormalizeGitHubOrg(t *testing.T) { |
| 22 | t.Parallel() |
| 23 | tests := []struct { |
| 24 | name string |
| 25 | raw string |
| 26 | want string |
| 27 | wantErr bool |
| 28 | }{ |
| 29 | {name: "bare", raw: "tenseleyFlow", want: "tenseleyFlow"}, |
| 30 | {name: "url", raw: "https://github.com/FortranGoingOnForty/", want: "FortranGoingOnForty"}, |
| 31 | {name: "path rejected", raw: "github.com/owner/repo", wantErr: true}, |
| 32 | {name: "trailing hyphen rejected", raw: "bad-", wantErr: true}, |
| 33 | {name: "double hyphen rejected", raw: "bad--name", wantErr: true}, |
| 34 | } |
| 35 | for _, tt := range tests { |
| 36 | t.Run(tt.name, func(t *testing.T) { |
| 37 | got, err := orgs.NormalizeGitHubOrg(tt.raw) |
| 38 | if tt.wantErr { |
| 39 | if err == nil { |
| 40 | t.Fatalf("NormalizeGitHubOrg(%q) succeeded", tt.raw) |
| 41 | } |
| 42 | return |
| 43 | } |
| 44 | if err != nil { |
| 45 | t.Fatalf("NormalizeGitHubOrg(%q): %v", tt.raw, err) |
| 46 | } |
| 47 | if got != tt.want { |
| 48 | t.Fatalf("NormalizeGitHubOrg(%q)=%q, want %q", tt.raw, got, tt.want) |
| 49 | } |
| 50 | }) |
| 51 | } |
| 52 | } |
| 53 | |
| 54 | func TestGitHubClientListOrgReposPaginatesAndUsesToken(t *testing.T) { |
| 55 | t.Parallel() |
| 56 | var seenPages []string |
| 57 | rt := roundTripFunc(func(r *http.Request) (*http.Response, error) { |
| 58 | if got, want := r.URL.Path, "/orgs/tenseleyFlow/repos"; got != want { |
| 59 | t.Fatalf("path=%q, want %q", got, want) |
| 60 | } |
| 61 | if got, want := r.Header.Get("Authorization"), "Bearer test-token"; got != want { |
| 62 | t.Fatalf("Authorization=%q, want %q", got, want) |
| 63 | } |
| 64 | if got, want := r.URL.Query().Get("type"), "all"; got != want { |
| 65 | t.Fatalf("type=%q, want %q", got, want) |
| 66 | } |
| 67 | seenPages = append(seenPages, r.URL.Query().Get("page")) |
| 68 | var body strings.Builder |
| 69 | if r.URL.Query().Get("page") == "1" { |
| 70 | writeGitHubRepoList(t, &body, 100) |
| 71 | } else { |
| 72 | _, _ = io.WriteString(&body, `[{"id":101,"name":"last","full_name":"tenseleyFlow/last","clone_url":"https://github.com/tenseleyFlow/last.git","description":"last repo","default_branch":"trunk"}]`) |
| 73 | } |
| 74 | return &http.Response{ |
| 75 | StatusCode: http.StatusOK, |
| 76 | Status: "200 OK", |
| 77 | Header: http.Header{"Content-Type": []string{"application/json"}}, |
| 78 | Body: io.NopCloser(strings.NewReader(body.String())), |
| 79 | Request: r, |
| 80 | }, nil |
| 81 | }) |
| 82 | |
| 83 | repos, err := (orgs.GitHubClient{ |
| 84 | HTTPClient: &http.Client{Transport: rt}, |
| 85 | BaseURL: "https://api.github.test", |
| 86 | UserAgent: "shithub-test", |
| 87 | }).ListOrgRepos(context.Background(), "tenseleyFlow", "test-token") |
| 88 | if err != nil { |
| 89 | t.Fatalf("ListOrgRepos: %v", err) |
| 90 | } |
| 91 | if len(repos) != 101 { |
| 92 | t.Fatalf("len(repos)=%d, want 101", len(repos)) |
| 93 | } |
| 94 | if got := fmt.Sprint(seenPages); got != "[1 2]" { |
| 95 | t.Fatalf("pages=%s, want [1 2]", got) |
| 96 | } |
| 97 | if repos[100].FullName != "tenseleyFlow/last" || repos[100].Description != "last repo" || repos[100].DefaultBranch != "trunk" { |
| 98 | t.Fatalf("last repo decoded incorrectly: %+v", repos[100]) |
| 99 | } |
| 100 | } |
| 101 | |
// roundTripFunc lets a plain function stand in for an http.RoundTripper, so a
// test can stub the responses an http.Client receives without a real server.
type roundTripFunc func(*http.Request) (*http.Response, error)

// RoundTrip satisfies http.RoundTripper by handing the request to the wrapped
// function and returning whatever response and error it produces.
func (fn roundTripFunc) RoundTrip(req *http.Request) (*http.Response, error) {
	resp, err := fn(req)
	return resp, err
}
| 107 | |
// writeGitHubRepoList writes a JSON array of count fake GitHub repository
// objects to w, shaped like the GitHub org repos listing. Entries are named
// repo-000, repo-001, … under the tenseleyFlow owner, with ids starting at 1,
// a null description and "trunk" as the default branch. The test is failed
// if encoding fails.
func writeGitHubRepoList(t *testing.T, w io.Writer, count int) {
	t.Helper()
	repos := make([]map[string]any, count)
	for idx := range repos {
		slug := fmt.Sprintf("repo-%03d", idx)
		repos[idx] = map[string]any{
			"id":             idx + 1,
			"name":           slug,
			"full_name":      "tenseleyFlow/" + slug,
			"clone_url":      "https://github.com/tenseleyFlow/" + slug + ".git",
			"description":    nil,
			"default_branch": "trunk",
			"private":        false,
			"fork":           false,
		}
	}
	enc := json.NewEncoder(w)
	if err := enc.Encode(repos); err != nil {
		t.Fatalf("encode GitHub response: %v", err)
	}
}
| 127 | |
| 128 | func TestStartGitHubImportPersistsEncryptedTokenAndDiscoveryJob(t *testing.T) { |
| 129 | pool, deps, alice := setup(t) |
| 130 | row, err := orgs.Create(context.Background(), deps, orgs.CreateParams{ |
| 131 | Slug: "acme", DisplayName: "Acme Inc", CreatedByUserID: alice, |
| 132 | }) |
| 133 | if err != nil { |
| 134 | t.Fatalf("create org: %v", err) |
| 135 | } |
| 136 | key, err := secretbox.GenerateKey() |
| 137 | if err != nil { |
| 138 | t.Fatalf("GenerateKey: %v", err) |
| 139 | } |
| 140 | box, err := secretbox.FromBytes(key) |
| 141 | if err != nil { |
| 142 | t.Fatalf("FromBytes: %v", err) |
| 143 | } |
| 144 | imp, err := orgs.StartGitHubImport(context.Background(), orgs.ImportDeps{ |
| 145 | Pool: pool, |
| 146 | Box: box, |
| 147 | Logger: slog.New(slog.NewTextHandler(io.Discard, nil)), |
| 148 | }, orgs.StartGitHubImportParams{ |
| 149 | OrgID: row.ID, SourceOrg: "https://github.com/tenseleyFlow/", |
| 150 | RequestedByUserID: alice, Token: "ghp_secret", |
| 151 | }) |
| 152 | if err != nil { |
| 153 | t.Fatalf("StartGitHubImport: %v", err) |
| 154 | } |
| 155 | if imp.SourceOrg != "tenseleyFlow" || !imp.IncludePrivate || !imp.TokenPresent { |
| 156 | t.Fatalf("unexpected import row: %+v", imp) |
| 157 | } |
| 158 | token, err := orgs.DecryptGitHubImportToken(imp, box) |
| 159 | if err != nil { |
| 160 | t.Fatalf("DecryptGitHubImportToken: %v", err) |
| 161 | } |
| 162 | if token != "ghp_secret" { |
| 163 | t.Fatalf("decrypted token=%q", token) |
| 164 | } |
| 165 | if string(imp.TokenCiphertext) == "ghp_secret" { |
| 166 | t.Fatal("token stored in plaintext") |
| 167 | } |
| 168 | var jobs int |
| 169 | if err := pool.QueryRow(context.Background(), |
| 170 | `SELECT count(*) FROM jobs WHERE kind = $1 AND payload->>'import_id' = $2`, |
| 171 | worker.KindOrgGitHubImportDiscover, strconv.FormatInt(imp.ID, 10), |
| 172 | ).Scan(&jobs); err != nil { |
| 173 | t.Fatalf("query jobs: %v", err) |
| 174 | } |
| 175 | if jobs != 1 { |
| 176 | t.Fatalf("discovery jobs=%d, want 1", jobs) |
| 177 | } |
| 178 | } |
| 179 |