A community-based topic aggregation platform built on atproto.
1package imageproxy
2
import (
	"errors"
	"strings"
	"testing"
)
7
8func TestValidateDID(t *testing.T) {
9 tests := []struct {
10 name string
11 did string
12 wantErr error
13 }{
14 // Valid DIDs - uses Indigo's syntax.ParseDID for consistency with codebase
15 {
16 name: "valid did:plc",
17 did: "did:plc:z72i7hdynmk6r22z27h6tvur",
18 wantErr: nil,
19 },
20 {
21 name: "valid did:web simple",
22 did: "did:web:example.com",
23 wantErr: nil,
24 },
25 {
26 name: "valid did:web with subdomain",
27 did: "did:web:bsky.social",
28 wantErr: nil,
29 },
30 {
31 name: "valid did:web with path",
32 did: "did:web:example.com:user:alice",
33 wantErr: nil,
34 },
35 // did:key is valid per Indigo library (used in other atproto contexts)
36 {
37 name: "valid did:key",
38 did: "did:key:z6MkhaXgBZDvotDkL5257faiztiGiC2QtKLGpbnnEGta2doK",
39 wantErr: nil,
40 },
41 // Invalid DIDs
42 {
43 name: "empty string",
44 did: "",
45 wantErr: ErrInvalidDID,
46 },
47 {
48 name: "missing did: prefix",
49 did: "plc:z72i7hdynmk6r22z27h6tvur",
50 wantErr: ErrInvalidDID,
51 },
52 {
53 name: "path traversal attempt in did",
54 did: "did:plc:../../../etc/passwd",
55 wantErr: ErrInvalidDID,
56 },
57 {
58 name: "null byte injection",
59 did: "did:plc:abc\x00def",
60 wantErr: ErrInvalidDID,
61 },
62 {
63 name: "forward slash injection",
64 did: "did:plc:abc/def",
65 wantErr: ErrInvalidDID,
66 },
67 {
68 name: "backslash injection",
69 did: "did:plc:abc\\def",
70 wantErr: ErrInvalidDID,
71 },
72 {
73 name: "just did prefix",
74 did: "did:",
75 wantErr: ErrInvalidDID,
76 },
77 {
78 name: "random gibberish",
79 did: "not-a-did-at-all",
80 wantErr: ErrInvalidDID,
81 },
82 }
83
84 for _, tt := range tests {
85 t.Run(tt.name, func(t *testing.T) {
86 err := ValidateDID(tt.did)
87 if !errors.Is(err, tt.wantErr) {
88 t.Errorf("ValidateDID(%q) = %v, want %v", tt.did, err, tt.wantErr)
89 }
90 })
91 }
92}
93
94func TestValidateCID(t *testing.T) {
95 tests := []struct {
96 name string
97 cid string
98 wantErr error
99 }{
100 // Valid CIDs
101 {
102 name: "valid CIDv1 base32 bafy",
103 cid: "bafyreihgdyzzpkkzq2izfnhcmm77ycuacvkuziwbnqxfxtqsz7tmxwhnshi",
104 wantErr: nil,
105 },
106 {
107 name: "valid CIDv1 base32 bafk",
108 cid: "bafkreihgdyzzpkkzq2izfnhcmm77ycuacvkuziwbnqxfxtqsz7tmxwhnshi",
109 wantErr: nil,
110 },
111 {
112 name: "valid CIDv0",
113 cid: "QmYwAPJzv5CZsnA625s3Xf2nemtYgPpHdWEz79ojWnPbdG",
114 wantErr: nil,
115 },
116 // Invalid CIDs
117 {
118 name: "empty string",
119 cid: "",
120 wantErr: ErrInvalidCID,
121 },
122 {
123 name: "too short",
124 cid: "bafyabc",
125 wantErr: ErrInvalidCID,
126 },
127 {
128 name: "path traversal attempt",
129 cid: "../../../etc/passwd",
130 wantErr: ErrInvalidCID,
131 },
132 {
133 name: "contains slash",
134 cid: "bafyrei/abc/def",
135 wantErr: ErrInvalidCID,
136 },
137 {
138 name: "contains backslash",
139 cid: "bafyrei\\abc",
140 wantErr: ErrInvalidCID,
141 },
142 {
143 name: "contains double dot",
144 cid: "bafyrei..abc",
145 wantErr: ErrInvalidCID,
146 },
147 {
148 name: "invalid base32 chars",
149 cid: "bafyreihgdyzzpkkzq2izfnhcmm77ycuacvkuziwbnqxfxtqsz7tmxwhnshi!@#",
150 wantErr: ErrInvalidCID,
151 },
152 {
153 name: "random string not matching any CID pattern",
154 cid: "this_is_not_a_valid_cid_at_all_12345",
155 wantErr: ErrInvalidCID,
156 },
157 {
158 name: "too long",
159 cid: "bafyrei" + string(make([]byte, 200)),
160 wantErr: ErrInvalidCID,
161 },
162 }
163
164 for _, tt := range tests {
165 t.Run(tt.name, func(t *testing.T) {
166 err := ValidateCID(tt.cid)
167 if !errors.Is(err, tt.wantErr) {
168 t.Errorf("ValidateCID(%q) = %v, want %v", tt.cid, err, tt.wantErr)
169 }
170 })
171 }
172}
173
174func TestSanitizePathComponent(t *testing.T) {
175 tests := []struct {
176 name string
177 input string
178 want string
179 }{
180 {
181 name: "clean string unchanged",
182 input: "abc123",
183 want: "abc123",
184 },
185 {
186 name: "forward slashes removed",
187 input: "path/to/file",
188 want: "path_to_file",
189 },
190 {
191 name: "backslashes removed",
192 input: "path\\to\\file",
193 want: "path_to_file",
194 },
195 {
196 name: "path traversal removed",
197 input: "../../../etc/passwd",
198 want: "___etc_passwd",
199 },
200 {
201 name: "colons replaced",
202 input: "did:plc:abc123",
203 want: "did_plc_abc123",
204 },
205 {
206 name: "null bytes removed",
207 input: "abc\x00def",
208 want: "abcdef",
209 },
210 {
211 name: "multiple dangerous chars",
212 input: "../path:to\\file\x00.txt",
213 want: "_path_to_file.txt",
214 },
215 }
216
217 for _, tt := range tests {
218 t.Run(tt.name, func(t *testing.T) {
219 got := SanitizePathComponent(tt.input)
220 if got != tt.want {
221 t.Errorf("SanitizePathComponent(%q) = %q, want %q", tt.input, got, tt.want)
222 }
223 })
224 }
225}
226
227func TestMakeDIDSafe_PathTraversal(t *testing.T) {
228 tests := []struct {
229 name string
230 did string
231 check func(result string) bool
232 }{
233 {
234 name: "normal did:plc is safe",
235 did: "did:plc:abc123",
236 check: func(r string) bool {
237 return r == "did_plc_abc123"
238 },
239 },
240 {
241 name: "path traversal sequences removed",
242 did: "did:plc:../../../etc/passwd",
243 check: func(r string) bool {
244 // Should not contain .. or /
245 return !contains(r, "..") && !contains(r, "/") && !contains(r, "\\")
246 },
247 },
248 {
249 name: "forward slashes removed",
250 did: "did:plc:abc/def",
251 check: func(r string) bool {
252 return !contains(r, "/")
253 },
254 },
255 {
256 name: "backslashes removed",
257 did: "did:plc:abc\\def",
258 check: func(r string) bool {
259 return !contains(r, "\\")
260 },
261 },
262 {
263 name: "null bytes removed",
264 did: "did:plc:abc\x00def",
265 check: func(r string) bool {
266 return !contains(r, "\x00")
267 },
268 },
269 }
270
271 for _, tt := range tests {
272 t.Run(tt.name, func(t *testing.T) {
273 result := makeDIDSafe(tt.did)
274 if !tt.check(result) {
275 t.Errorf("makeDIDSafe(%q) = %q, failed safety check", tt.did, result)
276 }
277 })
278 }
279}
280
281func TestMakeCIDSafe_PathTraversal(t *testing.T) {
282 tests := []struct {
283 name string
284 cid string
285 check func(result string) bool
286 }{
287 {
288 name: "normal CID unchanged",
289 cid: "bafyreiabc123",
290 check: func(r string) bool {
291 return r == "bafyreiabc123"
292 },
293 },
294 {
295 name: "path traversal removed",
296 cid: "../../../etc/passwd",
297 check: func(r string) bool {
298 return !contains(r, "..") && !contains(r, "/")
299 },
300 },
301 {
302 name: "forward slashes removed",
303 cid: "abc/def/ghi",
304 check: func(r string) bool {
305 return !contains(r, "/")
306 },
307 },
308 {
309 name: "backslashes removed",
310 cid: "abc\\def\\ghi",
311 check: func(r string) bool {
312 return !contains(r, "\\")
313 },
314 },
315 {
316 name: "null bytes removed",
317 cid: "abc\x00def",
318 check: func(r string) bool {
319 return !contains(r, "\x00")
320 },
321 },
322 }
323
324 for _, tt := range tests {
325 t.Run(tt.name, func(t *testing.T) {
326 result := makeCIDSafe(tt.cid)
327 if !tt.check(result) {
328 t.Errorf("makeCIDSafe(%q) = %q, failed safety check", tt.cid, result)
329 }
330 })
331 }
332}
333
// contains reports whether substr occurs anywhere in s.
//
// It delegates to strings.Contains instead of hand-rolling a substring scan;
// the short name is kept so existing table-driven checks stay terse. As with
// strings.Contains, an empty substr always matches, and a substr longer than
// s never does.
func contains(s, substr string) bool {
	return strings.Contains(s, substr)
}
342}