+1
migrations/006_initbans.down.sql
+1
migrations/006_initbans.down.sql
···
1
+
DROP TABLE IF EXISTS bans;
+7
migrations/006_initbans.up.sql
+7
migrations/006_initbans.up.sql
+66
-2
server/internal/db/db.go
+66
-2
server/internal/db/db.go
···
9
9
"rvcx/internal/types"
10
10
"time"
11
11
12
+
"github.com/jackc/pgx/v5"
12
13
"github.com/jackc/pgx/v5/pgxpool"
13
14
)
14
15
···
321
322
uri := fmt.Sprintf("at://%s/org.xcvr.feed.channel/%s", did, rkey)
322
323
row := s.pool.QueryRow(ctx, `
323
324
SELECT
324
-
channels.uri,
325
-
channels.host,
325
+
channels.uri,
326
+
channels.host,
326
327
channels.title,
327
328
channels.topic,
328
329
channels.created_at,
···
351
352
_, err := s.pool.Exec(ctx, `DELETE FROM channels WHERE uri = $1`, uri)
352
353
return err
353
354
}
355
+
356
+
func (s *Store) GetBanned(did string, ctx context.Context) (*types.Ban, error) {
357
+
row := s.pool.QueryRow(ctx, `SELECT
358
+
id,
359
+
reason,
360
+
till,
361
+
banned_at
362
+
FROM bans WHERE did = $1`, did)
363
+
var ban types.Ban
364
+
err := row.Scan(&ban.Id, &ban.Reason, &ban.Till, &ban.BannedAt)
365
+
if err != nil {
366
+
return nil, err
367
+
}
368
+
ban.Did = did
369
+
return &ban, nil
370
+
}
371
+
372
+
func (s *Store) GetBanId(id int, ctx context.Context) (*types.Ban, error) {
373
+
row := s.pool.QueryRow(ctx, `SELECT
374
+
did,
375
+
reason,
376
+
till,
377
+
banned_at
378
+
FROM bans WHERE id = $1`, id)
379
+
var ban types.Ban
380
+
err := row.Scan(&ban.Id, &ban.Reason, &ban.Till, &ban.BannedAt)
381
+
if err != nil {
382
+
return nil, err
383
+
}
384
+
ban.Id = id
385
+
return &ban, nil
386
+
}
387
+
388
+
func (s *Store) AddBan(did string, reason *string, till *time.Time, ctx context.Context) error {
389
+
_, err := s.pool.Exec(ctx, `INSERT INTO bans (
390
+
did,
391
+
reason,
392
+
till
393
+
) VALUES (
394
+
$1, $2, $3
395
+
)
396
+
`, did, reason, till)
397
+
return err
398
+
}
399
+
400
+
func (s *Store) IsBanned(did string, ctx context.Context) (bool, error) {
401
+
ban, err := s.GetBanned(did, ctx)
402
+
if ban != nil {
403
+
defbanned := false
404
+
if ban.Till == nil {
405
+
defbanned = true
406
+
} else {
407
+
defbanned = time.Now().Before(*ban.Till)
408
+
}
409
+
if defbanned {
410
+
return true, nil
411
+
}
412
+
}
413
+
if err != nil && !errors.Is(err, pgx.ErrNoRows) {
414
+
return false, err
415
+
}
416
+
return false, nil
417
+
}
+5
server/internal/db/oauth.go
+5
server/internal/db/oauth.go
···
101
101
return nil
102
102
}
103
103
104
+
func (s Store) DeleteAllSessions(ctx context.Context, did string) error {
105
+
_, err := s.pool.Exec(ctx, `DELETE FROM sessions WHERE account_did = $1`)
106
+
return err
107
+
}
108
+
104
109
func (s Store) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
105
110
row := s.pool.QueryRow(ctx, `
106
111
SELECT
+2
server/internal/handler/handler.go
+2
server/internal/handler/handler.go
···
55
55
mux.HandleFunc(oauthJWKSPath(), h.WithCORS(h.serveJWKS))
56
56
mux.HandleFunc("POST /oauth/login", h.oauthLogin)
57
57
mux.HandleFunc("POST /oauth/logout", h.oauthMiddleware(h.oauthLogout))
58
+
mux.HandleFunc("POST /oauth/ban", h.postBan)
59
+
mux.HandleFunc("GET /oauth/ban", h.getBan)
58
60
mux.HandleFunc("GET /oauth/whoami", h.getSession)
59
61
mux.HandleFunc(oauthCallbackPath(), h.WithCORS(h.oauthCallback))
60
62
return h
+83
server/internal/handler/oauthHandlers.go
+83
server/internal/handler/oauthHandlers.go
···
6
6
"fmt"
7
7
"net/http"
8
8
"os"
9
+
"rvcx/internal/atputils"
9
10
"rvcx/internal/oauth"
11
+
"strconv"
10
12
"strings"
13
+
"time"
11
14
12
15
atoauth "github.com/bluesky-social/indigo/atproto/auth/oauth"
13
16
"github.com/bluesky-social/indigo/atproto/syntax"
···
57
60
h.serverError(w, errors.New("my god.... :"+err.Error()))
58
61
return
59
62
}
63
+
isban, err := h.db.IsBanned(sessData.AccountDID.String(), r.Context())
64
+
if err != nil {
65
+
h.serverError(w, errors.New("i'm not sure if user is banned, error, "+err.Error()))
66
+
return
67
+
}
68
+
if isban {
69
+
ban, _ := h.db.GetBanned(sessData.AccountDID.String(), r.Context())
70
+
http.Redirect(w, r, fmt.Sprintf("%s%d", os.Getenv("BAN_ENDPOINT"), ban.Id), http.StatusSeeOther)
71
+
return
72
+
}
73
+
60
74
err = h.rm.CreateInitialProfile(sessData, r.Context())
61
75
if err != nil {
62
76
h.serverError(w, err)
···
152
166
f(cs, w, r)
153
167
}
154
168
}
169
+
170
+
func (h *Handler) postBan(w http.ResponseWriter, r *http.Request) {
171
+
s, _ := h.sessionStore.Get(r, "oauthsession")
172
+
did, bok := s.Values["did"].(string)
173
+
if !bok {
174
+
h.badRequest(w, errors.New("not authorized"))
175
+
return
176
+
}
177
+
handle, err := h.db.ResolveDid(did, r.Context())
178
+
if err != nil {
179
+
h.serverError(w, errors.New("failed to resolve"+err.Error()))
180
+
return
181
+
}
182
+
if handle != os.Getenv("ADMIN_HANDLE") {
183
+
h.badRequest(w, errors.New("must be admin to ban"))
184
+
return
185
+
}
186
+
userhandle := r.Header.Get("user")
187
+
userdid, err := atputils.GetDidFromHandle(r.Context(), userhandle)
188
+
if err != nil {
189
+
h.badRequest(w, errors.New("failed to resolve user handle"))
190
+
return
191
+
}
192
+
daysstring := r.Header.Get("days")
193
+
daysint, err := strconv.Atoi(daysstring)
194
+
var till *time.Time
195
+
if err == nil {
196
+
tillt := time.Now().Add(time.Hour * 24 * time.Duration(daysint))
197
+
till = &tillt
198
+
}
199
+
var reason *string
200
+
reasonstr := r.Header.Get("reason")
201
+
if reasonstr != "" {
202
+
reason = &reasonstr
203
+
}
204
+
err = h.db.AddBan(userdid, reason, till, r.Context())
205
+
if err != nil {
206
+
h.serverError(w, errors.New("failed to ban, "+err.Error()))
207
+
return
208
+
}
209
+
ban, err := h.db.GetBanned(userdid, r.Context())
210
+
if err != nil {
211
+
h.serverError(w, errors.New("succeeded to ban and then failed again"+err.Error()))
212
+
return
213
+
}
214
+
err = h.db.DeleteAllSessions(r.Context(), ban.Did)
215
+
if err != nil {
216
+
h.serverError(w, errors.New("failed to kick user "+ban.Did+err.Error()))
217
+
return
218
+
}
219
+
http.Redirect(w, r, fmt.Sprintf("%s%d", os.Getenv("BAN_ENDPOINT"), ban.Id), http.StatusFound)
220
+
}
221
+
222
+
func (h *Handler) getBan(w http.ResponseWriter, r *http.Request) {
223
+
banid := r.Header.Get("id")
224
+
id, err := strconv.Atoi(banid)
225
+
if err != nil {
226
+
h.badRequest(w, err)
227
+
return
228
+
}
229
+
ban, err := h.db.GetBanId(id, r.Context())
230
+
if err != nil {
231
+
h.serverError(w, err)
232
+
return
233
+
}
234
+
encoder := json.NewEncoder(w)
235
+
w.Header().Add("Content-Type", "application/json")
236
+
encoder.Encode(ban)
237
+
}
+8
-1
server/internal/recordmanager/media.go
+8
-1
server/internal/recordmanager/media.go
···
21
21
}
22
22
23
23
func (rm *RecordManager) AddImageToCache(did string, cid string, ctx context.Context) (string, error) {
24
+
ib, err := rm.db.IsBanned(did, ctx)
25
+
if err != nil {
26
+
return "", err
27
+
}
28
+
if ib {
29
+
return "", errors.New("user banned")
30
+
}
24
31
uploadDir := "./uploads"
25
-
_, err := os.Stat(uploadDir)
32
+
_, err = os.Stat(uploadDir)
26
33
if os.IsNotExist(err) {
27
34
os.Mkdir(uploadDir, 0755)
28
35
}
+8
server/internal/types/oauth.go
+8
server/internal/types/oauth.go