fork of indigo with slightly nicer lexgen
0
fork

Configure Feed

Select the types of activity you want to include in your feed.

medsky radical trim of bigsky (relay) sync 1.1 induction firehose add getRepo 302 redirect note subscribeRepos message deprecation time-seq rename everything from 'bigsky' to 'relay' don't bounce auth key off db handleFedEvent clean-er deprecated {handle,migrate,tombstone} BGS.createExternalUser() is now syncPDSAccount() metric events_warn_counter{pds,warn} a bunch of rename 'user' to 'account' tiny ram blockstore even more minimal than ipld/ipfs do deprecate migrate/handle/tombstone (but don't delete yet so we have the option of measuring stale data received)

authored by

Brian Olson and committed by
Brian Olson
6aa979fa d9a74f69

+7237 -107
+52
.github/workflows/container-relay-aws.yaml
··· 1 + name: container-relay-aws 2 + on: [push] 3 + env: 4 + REGISTRY: ${{ secrets.AWS_ECR_REGISTRY_USEAST2_PACKAGES_REGISTRY }} 5 + USERNAME: ${{ secrets.AWS_ECR_REGISTRY_USEAST2_PACKAGES_USERNAME }} 6 + PASSWORD: ${{ secrets.AWS_ECR_REGISTRY_USEAST2_PACKAGES_PASSWORD }} 7 + # github.repository as <account>/<repo> 8 + IMAGE_NAME: relay 9 + 10 + jobs: 11 + container-relay-aws: 12 + if: github.repository == 'bluesky-social/indigo' 13 + runs-on: ubuntu-latest 14 + permissions: 15 + contents: read 16 + packages: write 17 + id-token: write 18 + 19 + steps: 20 + - name: Checkout repository 21 + uses: actions/checkout@v3 22 + 23 + - name: Setup Docker buildx 24 + uses: docker/setup-buildx-action@v1 25 + 26 + - name: Log into registry ${{ env.REGISTRY }} 27 + uses: docker/login-action@v2 28 + with: 29 + registry: ${{ env.REGISTRY }} 30 + username: ${{ env.USERNAME }} 31 + password: ${{ env.PASSWORD }} 32 + 33 + - name: Extract Docker metadata 34 + id: meta 35 + uses: docker/metadata-action@v4 36 + with: 37 + images: | 38 + ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 39 + tags: | 40 + type=sha,enable=true,priority=100,prefix=,suffix=,format=long 41 + 42 + - name: Build and push Docker image 43 + id: build-and-push 44 + uses: docker/build-push-action@v4 45 + with: 46 + context: . 47 + file: ./cmd/relay/Dockerfile 48 + push: ${{ github.event_name != 'pull_request' }} 49 + tags: ${{ steps.meta.outputs.tags }} 50 + labels: ${{ steps.meta.outputs.labels }} 51 + cache-from: type=gha 52 + cache-to: type=gha,mode=max
+1
atproto/identity/cache_directory.go
··· 11 11 "github.com/hashicorp/golang-lru/v2/expirable" 12 12 ) 13 13 14 + // CacheDirectory is an implementation of identity.Directory with local cache of Handle and DID 14 15 type CacheDirectory struct { 15 16 Inner Directory 16 17 ErrTTL time.Duration
+49 -3
atproto/repo/car.go
··· 3 3 import ( 4 4 "bytes" 5 5 "context" 6 + "errors" 6 7 "fmt" 7 8 "io" 8 9 9 10 "github.com/bluesky-social/indigo/atproto/repo/mst" 10 11 "github.com/bluesky-social/indigo/atproto/syntax" 11 12 12 - "github.com/ipfs/go-datastore" 13 - blockstore "github.com/ipfs/go-ipfs-blockstore" 13 + blocks "github.com/ipfs/go-block-format" 14 14 "github.com/ipld/go-car" 15 15 ) 16 16 17 17 func LoadFromCAR(ctx context.Context, r io.Reader) (*Commit, *Repo, error) { 18 18 19 - bs := blockstore.NewBlockstore(datastore.NewMapDatastore()) 19 + //bs := blockstore.NewBlockstore(datastore.NewMapDatastore()) 20 + bs := NewTinyBlockstore() 20 21 21 22 cr, err := car.NewCarReader(r) 22 23 if err != nil { ··· 71 72 } 72 73 return &commit, &repo, nil 73 74 } 75 + 76 + var ErrNoRoot = errors.New("CAR file missing root CID") 77 + var ErrNoCommit = errors.New("no commit") 78 + 79 + // LoadCARCommit is like LoadFromCAR() but filters to only return the commit object. 80 + // useful for subscribeRepos/firehose `#sync` message 81 + func LoadCARCommit(ctx context.Context, r io.Reader) (*Commit, error) { 82 + cr, err := car.NewCarReader(r) 83 + if err != nil { 84 + return nil, err 85 + } 86 + if cr.Header.Version != 1 { 87 + return nil, fmt.Errorf("unsupported CAR file version: %d", cr.Header.Version) 88 + } 89 + if len(cr.Header.Roots) < 1 { 90 + return nil, ErrNoRoot 91 + } 92 + commitCID := cr.Header.Roots[0] 93 + var commitBlock blocks.Block 94 + for { 95 + blk, err := cr.Next() 96 + if err != nil { 97 + if err == io.EOF { 98 + break 99 + } 100 + return nil, err 101 + } 102 + 103 + if blk.Cid().Equals(commitCID) { 104 + commitBlock = blk 105 + break 106 + } 107 + } 108 + if commitBlock == nil { 109 + return nil, ErrNoCommit 110 + } 111 + var commit Commit 112 + if err := commit.UnmarshalCBOR(bytes.NewReader(commitBlock.RawData())); err != nil { 113 + return nil, fmt.Errorf("parsing commit block from CAR file: %w", err) 114 + } 115 + if err := commit.VerifyStructure(); err != nil { 116 + 
return nil, fmt.Errorf("parsing commit block from CAR file: %w", err) 117 + } 118 + return &commit, nil 119 + }
+1 -1
atproto/repo/mst/encoding.go
··· 199 199 return c, nil 200 200 } 201 201 202 - func loadNodeFromStore(ctx context.Context, bs blockstore.Blockstore, ref cid.Cid) (*Node, error) { 202 + func loadNodeFromStore(ctx context.Context, bs MSTBlockSource, ref cid.Cid) (*Node, error) { 203 203 block, err := bs.Get(ctx, ref) 204 204 if err != nil { 205 205 return nil, err
+7 -1
atproto/repo/mst/tree.go
··· 5 5 "errors" 6 6 "fmt" 7 7 8 + blocks "github.com/ipfs/go-block-format" 8 9 "github.com/ipfs/go-cid" 9 10 blockstore "github.com/ipfs/go-ipfs-blockstore" 10 11 ) ··· 148 149 } 149 150 } 150 151 151 - func LoadTreeFromStore(ctx context.Context, bs blockstore.Blockstore, root cid.Cid) (*Tree, error) { 152 + func LoadTreeFromStore(ctx context.Context, bs MSTBlockSource, root cid.Cid) (*Tree, error) { 152 153 n, err := loadNodeFromStore(ctx, bs, root) 153 154 if err != nil { 154 155 return nil, err ··· 157 158 return &Tree{ 158 159 Root: n, 159 160 }, nil 161 + } 162 + 163 + // subset of Blockstore that we actually need 164 + type MSTBlockSource interface { 165 + Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) 160 166 } 161 167 162 168 // Walks the tree, encodes any "dirty" nodes as CBOR data, and writes that data as blocks to the provided blockstore. Returns root CID.
+16 -12
atproto/repo/repo.go
··· 7 7 "github.com/bluesky-social/indigo/atproto/repo/mst" 8 8 "github.com/bluesky-social/indigo/atproto/syntax" 9 9 10 + blocks "github.com/ipfs/go-block-format" 10 11 "github.com/ipfs/go-cid" 11 - "github.com/ipfs/go-datastore" 12 - blockstore "github.com/ipfs/go-ipfs-blockstore" 13 12 ) 14 13 15 14 // Version of the repo data format implemented in this package ··· 20 19 DID syntax.DID 21 20 Clock *syntax.TIDClock 22 21 23 - RecordStore blockstore.Blockstore 22 + RecordStore RepoBlockSource // formerly blockstore.Blockstore 24 23 MST mst.Tree 25 24 } 26 25 26 + // subset of Blockstore that we actually need 27 + type RepoBlockSource interface { 28 + Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) 29 + } 30 + 27 31 var ErrNotFound = errors.New("record not found in repository") 28 32 29 - func NewEmptyRepo(did syntax.DID) Repo { 30 - clk := syntax.NewTIDClock(0) 31 - return Repo{ 32 - DID: did, 33 - Clock: &clk, 34 - RecordStore: blockstore.NewBlockstore(datastore.NewMapDatastore()), 35 - MST: mst.NewEmptyTree(), 36 - } 37 - } 33 + //func NewEmptyRepo(did syntax.DID) Repo { 34 + // clk := syntax.NewTIDClock(0) 35 + // return Repo{ 36 + // DID: did, 37 + // Clock: &clk, 38 + // RecordStore: blockstore.NewBlockstore(datastore.NewMapDatastore()), 39 + // MST: mst.NewEmptyTree(), 40 + // } 41 + //} 38 42 39 43 func (repo *Repo) GetRecordCID(ctx context.Context, collection syntax.NSID, rkey syntax.RecordKey) (*cid.Cid, error) { 40 44 path := collection.String() + "/" + rkey.String()
+33
atproto/repo/tiny_blockstore.go
··· 1 + package repo 2 + 3 + import ( 4 + "context" 5 + 6 + blocks "github.com/ipfs/go-block-format" 7 + "github.com/ipfs/go-cid" 8 + ipld "github.com/ipfs/go-ipld-format" 9 + ) 10 + 11 + type TinyBlockstore struct { 12 + blocks map[string]blocks.Block 13 + } 14 + 15 + func NewTinyBlockstore() *TinyBlockstore { 16 + return &TinyBlockstore{blocks: make(map[string]blocks.Block, 20)} 17 + } 18 + 19 + func (tb *TinyBlockstore) Put(_ context.Context, block blocks.Block) error { 20 + ncid := block.Cid() 21 + key := ncid.KeyString() 22 + tb.blocks[key] = block 23 + return nil 24 + } 25 + 26 + func (tb *TinyBlockstore) Get(_ context.Context, ncid cid.Cid) (blocks.Block, error) { 27 + key := ncid.KeyString() 28 + block, found := tb.blocks[key] 29 + if found { 30 + return block, nil 31 + } 32 + return nil, &ipld.ErrNotFound{Cid: ncid} 33 + }
+1 -1
cmd/bigsky/main.go
··· 306 306 signals := make(chan os.Signal, 1) 307 307 signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) 308 308 309 - _, err := cliutil.SetupSlog(cliutil.LogOptions{}) 309 + _, _, err := cliutil.SetupSlog(cliutil.LogOptions{}) 310 310 if err != nil { 311 311 return err 312 312 }
+1 -1
cmd/gosky/main.go
··· 81 81 }, 82 82 } 83 83 84 - _, err := cliutil.SetupSlog(cliutil.LogOptions{}) 84 + _, _, err := cliutil.SetupSlog(cliutil.LogOptions{}) 85 85 if err != nil { 86 86 fmt.Fprintf(os.Stderr, "logging setup error: %s\n", err.Error()) 87 87 os.Exit(1)
+49
cmd/relay/Dockerfile
··· 1 + # Run this dockerfile from the top level of the indigo git repository like: 2 + # 3 + # podman build -f ./cmd/relay/Dockerfile -t relay . 4 + 5 + ### Compile stage 6 + FROM golang:1.23-alpine3.20 AS build-env 7 + RUN apk add --no-cache build-base make git 8 + 9 + ADD . /dockerbuild 10 + WORKDIR /dockerbuild 11 + 12 + # timezone data for alpine builds 13 + ENV GOEXPERIMENT=loopvar 14 + RUN GIT_VERSION=$(git describe --tags --long --always) && \ 15 + go build -tags timetzdata -o /relay ./cmd/relay 16 + 17 + ### Build Frontend stage 18 + FROM node:18-alpine as web-builder 19 + 20 + WORKDIR /app 21 + 22 + COPY ts/bgs-dash /app/ 23 + 24 + RUN yarn install --frozen-lockfile 25 + 26 + RUN yarn build 27 + 28 + ### Run stage 29 + FROM alpine:3.20 30 + 31 + RUN apk add --no-cache --update dumb-init ca-certificates runit 32 + ENTRYPOINT ["dumb-init", "--"] 33 + 34 + WORKDIR / 35 + RUN mkdir -p data/relay 36 + COPY --from=build-env /relay / 37 + COPY --from=web-builder /app/dist/ public/ 38 + 39 + # small things to make golang binaries work well under alpine 40 + ENV GODEBUG=netdns=go 41 + ENV TZ=Etc/UTC 42 + 43 + EXPOSE 2470 44 + 45 + CMD ["/relay"] 46 + 47 + LABEL org.opencontainers.image.source=https://github.com/bluesky-social/indigo 48 + LABEL org.opencontainers.image.description="atproto Relay" 49 + LABEL org.opencontainers.image.licenses=MIT
+328
cmd/relay/README.md
··· 1 + 2 + atproto Relay Service 3 + =============================== 4 + 5 + *NOTE: "Relays" used to be called "Big Graph Servers", or "BGS", or "bigsky". Many variables and packages still reference "bgs"* 6 + 7 + This is the implementation of an atproto Relay which is running in the production network, written and operated by Bluesky. 8 + 9 + In atproto, a Relay subscribes to multiple PDS hosts and outputs a combined "firehose" event stream. Downstream services can subscribe to this single firehose and get all relevant events for the entire network, or a specific sub-graph of the network. The Relay maintains a mirror of repo data from all accounts on the upstream PDS instances, and verifies repo data structure integrity and identity signatures. It is agnostic to applications, and does not validate data against atproto Lexicon schemas. 10 + 11 + This Relay implementation is designed to subscribe to the entire global network. The current state of the codebase is informally expected to scale to around 50 million accounts in the network, and thousands of repo events per second (peak). 12 + 13 + Features and design decisions: 14 + 15 + - runs on a single server 16 + - crawling and account state: stored in SQL database 17 + - SQL driver: gorm, with PostgreSQL in production and sqlite for testing 18 + - highly concurrent: not particularly CPU intensive 19 + - single golang binary for easy deployment 20 + - observability: logging, prometheus metrics, OTEL traces 21 + - admin web interface: configure limits, add upstream PDS instances, etc 22 + 23 + This software is not as packaged, documented, and supported for self-hosting as our PDS distribution or Ozone service. But it is relatively simple and inexpensive to get running. 24 + 25 + A note and reminder about Relays in general is that they are more of a convenience in the protocol than a hard requirement. The "firehose" API is the exact same on the PDS and on a Relay. 
Any service which subscribes to the Relay could instead connect to one or more PDS instances directly. 26 + 27 + 28 + ## Development Tips 29 + 30 + The README and Makefile at the top level of this git repo have some generic helpers for testing, linting, formatting code, etc. 31 + 32 + To re-build and run the Relay locally: 33 + 34 + make run-dev-relay 35 + 36 + You can re-build and run the command directly to get a list of configuration flags and env vars; env vars will be loaded from `.env` if that file exists: 37 + 38 + RELAY_ADMIN_KEY=localdev go run ./cmd/relay/ --help 39 + 40 + By default, the daemon will use sqlite for databases (in the directory `./data/bigsky/`), CAR data will be stored as individual shard files in `./data/bigsky/carstore/`), and the HTTP API will be bound to localhost port 2470. 41 + 42 + When the daemon isn't running, sqlite database files can be inspected with: 43 + 44 + sqlite3 data/bigsky/bgs.sqlite 45 + [...] 46 + sqlite> .schema 47 + 48 + Wipe all local data: 49 + 50 + # careful! double-check this destructive command 51 + rm -rf ./data/bigsky/* 52 + 53 + There is a basic web dashboard, though it will not be included unless built and copied to a local directory `./public/`. Run `make build-relay-ui`, and then when running the daemon the dashboard will be available at: <http://localhost:2470/dash/>. Paste in the admin key, eg `localdev`. 54 + 55 + The local admin routes can also be accessed by passing the admin key as a bearer token, for example: 56 + 57 + http get :2470/admin/pds/list Authorization:"Bearer localdev" 58 + 59 + Request crawl of an individual PDS instance like: 60 + 61 + http post :2470/admin/pds/requestCrawl Authorization:"Bearer localdev" hostname=pds.example.com 62 + 63 + 64 + ## Docker Containers 65 + 66 + One way to deploy is running a docker image. You can pull and/or run a specific version of bigsky, referenced by git commit, from the Bluesky Github container registry. 
For example: 67 + 68 + docker pull ghcr.io/bluesky-social/indigo:relay-fd66f93ce1412a3678a1dd3e6d53320b725978a6 69 + docker run ghcr.io/bluesky-social/indigo:relay-fd66f93ce1412a3678a1dd3e6d53320b725978a6 70 + 71 + There is a Dockerfile in this directory, which can be used to build customized/patched versions of the Relay as a container, republish them, run locally, deploy to servers, deploy to an orchestrated cluster, etc. See docs and guides for docker and cluster management systems for details. 72 + 73 + 74 + ## Database Setup 75 + 76 + PostgreSQL and Sqlite are both supported. When using Sqlite, separate files are used for Relay metadata and CarStore metadata. With PostgreSQL a single database server, user, and logical database can all be reused: table names will not conflict. 77 + 78 + Database configuration is passed via the `DATABASE_URL` and `CARSTORE_DATABASE_URL` environment variables, or the corresponding CLI args. 79 + 80 + For PostgreSQL, the user and database must already be configured. Some example SQL commands are: 81 + 82 + CREATE DATABASE bgs; 83 + CREATE DATABASE carstore; 84 + 85 + CREATE USER ${username} WITH PASSWORD '${password}'; 86 + GRANT ALL PRIVILEGES ON DATABASE bgs TO ${username}; 87 + GRANT ALL PRIVILEGES ON DATABASE carstore TO ${username}; 88 + 89 + This service currently uses `gorm` to automatically run database migrations as the regular user. There is no concept of running a separate set of migrations under more privileged database user. 90 + 91 + 92 + ## Deployment 93 + 94 + *NOTE: this is not a complete guide to operating a Relay. There are decisions to be made and communicated about policies, bandwidth use, PDS crawling and rate-limits, financial sustainability, etc, which are not covered here. This is just a quick overview of how to technically get a relay up and running.* 95 + 96 + In a real-world system, you will probably want to use PostgreSQL. 
97 + 98 + Some notable configuration env vars to set: 99 + 100 + - `ENVIRONMENT`: eg, `production` 101 + - `DATABASE_URL`: see section below 102 + - `DATA_DIR`: misc data will go in a subdirectory 103 + - `GOLOG_LOG_LEVEL`: log verbosity 104 + - `RESOLVE_ADDRESS`: DNS server to use 105 + - `FORCE_DNS_UDP`: recommend "true" 106 + 107 + There is a health check endpoint at `/xrpc/_health`. Prometheus metrics are exposed by default on port 2471, path `/metrics`. The service logs fairly verbosely to stderr; use `GOLOG_LOG_LEVEL` to control log volume. 108 + 109 + As a rough guideline for the compute resources needed to run a full-network Relay, in June 2024 an example Relay for over 5 million repositories used: 110 + 111 + - roughly 1 TByte of disk for PostgreSQL 112 + - roughly 1 TByte of disk for event playback buffer 113 + - roughly 5k disk I/O operations per second (all combined) 114 + - roughly 100% of one CPU core (quite low CPU utilization) 115 + - roughly 5GB of RAM for `relay`, and as much RAM as available for PostgreSQL and page cache 116 + - on the order of 1 megabit inbound bandwidth (crawling PDS instances) and 1 megabit outbound per connected client. 1 mbit continuous is approximately 350 GByte/month 117 + 118 + Be sure to double-check bandwidth usage and pricing if running a public relay! Bandwidth prices can vary widely between providers, and popular cloud services (AWS, Google Cloud, Azure) are very expensive compared to alternatives like OVH or Hetzner. 119 + 120 + 121 + ## Bootstrapping the Network 122 + 123 + To bootstrap the entire network, you'll want to start with a list of large PDS instances to backfill from. You could pull from a public dashboard of instances (like [mackuba's](https://blue.mackuba.eu/directory/pdses)), or scrape the full DID PLC directory, parse out all PDS service declarations, and sort by count. 
124 + 125 + Once you have a set of PDS hosts, you can put the bare hostnames (not URLs: no `https://` prefix, port, or path suffix) in a `hosts.txt` file, and then use the `crawl_pds.sh` script to backfill and configure limits for all of them: 126 + 127 + export RELAY_HOST=your.pds.hostname.tld 128 + export RELAY_ADMIN_KEY=your-secret-key 129 + 130 + # both request crawl, and set generous crawl limits for each 131 + cat hosts.txt | parallel -j1 ./crawl_pds.sh {} 132 + 133 + Just consuming from the firehose for a few hours will only backfill accounts with activity during that period. This is fine to get the backfill process started, but eventually you'll want to do full "resync" of all the repositories on the PDS host to the most recent repo rev version. To enqueue that for all the PDS instances: 134 + 135 + # start sync/backfill of all accounts 136 + cat hosts.txt | parallel -j1 ./sync_pds.sh {} 137 + 138 + Lastly, you can monitor progress of any ongoing re-syncs: 139 + 140 + # check sync progress for all hosts 141 + cat hosts.txt | parallel -j1 ./sync_pds.sh {} 142 + 143 + 144 + ## Admin API 145 + 146 + The relay has a number of admin HTTP API endpoints. Given a relay setup listening on port 2470 and with a reasonably secure admin secret: 147 + 148 + ``` 149 + RELAY_ADMIN_PASSWORD=$(openssl rand --hex 16) 150 + relay --api-listen :2470 --admin-key ${RELAY_ADMIN_PASSWORD} ... 151 + ``` 152 + 153 + One can, for example, begin compaction of all repos 154 + 155 + ``` 156 + curl -H 'Authorization: Bearer '${RELAY_ADMIN_PASSWORD} -H 'Content-Type: application/x-www-form-urlencoded' --data '' http://127.0.0.1:2470/admin/repo/compactAll 157 + ``` 158 + 159 + ### /admin/subs/getUpstreamConns 160 + 161 + Return list of PDS host names in json array of strings: ["host", ...] 162 + 163 + ### /admin/subs/perDayLimit 164 + 165 + Return `{"limit": int}` for the number of new PDS subscriptions that the relay may start in a rolling 24 hour window. 
166 + 167 + ### /admin/subs/setPerDayLimit 168 + 169 + POST with `?limit={int}` to set the number of new PDS subscriptions that the relay may start in a rolling 24 hour window. 170 + 171 + ### /admin/subs/setEnabled 172 + 173 + POST with param `?enabled=true` or `?enabled=false` to enable or disable PDS-requested new-PDS crawling. 174 + 175 + ### /admin/subs/getEnabled 176 + 177 + Return `{"enabled": bool}` if non-admin new PDS crawl requests are enabled 178 + 179 + ### /admin/subs/killUpstream 180 + 181 + POST with `?host={pds host name}` to disconnect from their firehose. 182 + 183 + Optionally add `&block=true` to prevent connecting to them in the future. 184 + 185 + ### /admin/subs/listDomainBans 186 + 187 + Return `{"banned_domains": ["host name", ...]}` 188 + 189 + ### /admin/subs/banDomain 190 + 191 + POST `{"Domain": "host name"}` to ban a domain 192 + 193 + ### /admin/subs/unbanDomain 194 + 195 + POST `{"Domain": "host name"}` to un-ban a domain 196 + 197 + ### /admin/repo/takeDown 198 + 199 + POST `{"did": "did:..."}` to take-down a bad repo; deletes all local data for the repo 200 + 201 + ### /admin/repo/reverseTakedown 202 + 203 + POST `?did={did:...}` to reverse a repo take-down 204 + 205 + ### /admin/repo/compact 206 + 207 + POST `?did={did:...}` to compact a repo. Optionally `&fast=true`. HTTP blocks until the compaction finishes. 208 + 209 + ### /admin/repo/compactAll 210 + 211 + POST to begin compaction of all repos. Optional query params: 212 + 213 + * `fast=true` 214 + * `limit={int}` maximum number of repos to compact (biggest first) (default 50) 215 + * `threhsold={int}` minimum number of shard files a repo must have on disk to merit compaction (default 20) 216 + 217 + ### /admin/repo/reset 218 + 219 + POST `?did={did:...}` deletes all local data for the repo 220 + 221 + ### /admin/repo/verify 222 + 223 + POST `?did={did:...}` checks that all repo data is accessible. HTTP blocks until done. 
224 + 225 + ### /admin/pds/requestCrawl 226 + 227 + POST `{"hostname":"pds host"}` to start crawling a PDS 228 + 229 + ### /admin/pds/list 230 + 231 + GET returns JSON list of records 232 + ```json 233 + [{ 234 + "Host": string, 235 + "Did": string, 236 + "SSL": bool, 237 + "Cursor": int, 238 + "Registered": bool, 239 + "Blocked": bool, 240 + "RateLimit": float, 241 + "CrawlRateLimit": float, 242 + "RepoCount": int, 243 + "RepoLimit": int, 244 + "HourlyEventLimit": int, 245 + "DailyEventLimit": int, 246 + 247 + "HasActiveConnection": bool, 248 + "EventsSeenSinceStartup": int, 249 + "PerSecondEventRate": {"Max": float, "Window": float seconds}, 250 + "PerHourEventRate": {"Max": float, "Window": float seconds}, 251 + "PerDayEventRate": {"Max": float, "Window": float seconds}, 252 + "CrawlRate": {"Max": float, "Window": float seconds}, 253 + "UserCount": int, 254 + }, ...] 255 + ``` 256 + 257 + ### /admin/pds/resync 258 + 259 + POST `?host={host}` to start a resync of a PDS 260 + 261 + GET `?host={host}` to get status of a PDS resync, return 262 + 263 + ```json 264 + {"resync": { 265 + "pds": { 266 + "Host": string, 267 + "Did": string, 268 + "SSL": bool, 269 + "Cursor": int, 270 + "Registered": bool, 271 + "Blocked": bool, 272 + "RateLimit": float, 273 + "CrawlRateLimit": float, 274 + "RepoCount": int, 275 + "RepoLimit": int, 276 + "HourlyEventLimit": int, 277 + "DailyEventLimit": int, 278 + }, 279 + "numRepoPages": int, 280 + "numRepos": int, 281 + "numReposChecked": int, 282 + "numReposToResync": int, 283 + "status": string, 284 + "statusChangedAt": time, 285 + }} 286 + ``` 287 + 288 + ### /admin/pds/changeLimits 289 + 290 + POST to set the limits for a PDS. 
body: 291 + 292 + ```json 293 + { 294 + "host": string, 295 + "per_second": int, 296 + "per_hour": int, 297 + "per_day": int, 298 + "crawl_rate": int, 299 + "repo_limit": int, 300 + } 301 + ``` 302 + 303 + ### /admin/pds/block 304 + 305 + POST `?host={host}` to block a PDS 306 + 307 + ### /admin/pds/unblock 308 + 309 + POST `?host={host}` to un-block a PDS 310 + 311 + 312 + ### /admin/pds/addTrustedDomain 313 + 314 + POST `?domain={}` to make a domain trusted 315 + 316 + ### /admin/consumers/list 317 + 318 + GET returns a JSON list of clients currently reading from the relay firehose 319 + 320 + ```json 321 + [{ 322 + "id": int, 323 + "remote_addr": string, 324 + "user_agent": string, 325 + "events_consumed": int, 326 + "connected_at": time, 327 + }, ...] 328 + ```
+539
cmd/relay/bgs/admin.go
··· 1 + package bgs 2 + 3 + import ( 4 + "errors" 5 + "fmt" 6 + "net/http" 7 + "net/url" 8 + "slices" 9 + "strconv" 10 + "strings" 11 + "time" 12 + 13 + "github.com/bluesky-social/indigo/cmd/relay/models" 14 + "github.com/labstack/echo/v4" 15 + dto "github.com/prometheus/client_model/go" 16 + "gorm.io/gorm" 17 + ) 18 + 19 + func (bgs *BGS) handleAdminSetSubsEnabled(e echo.Context) error { 20 + enabled, err := strconv.ParseBool(e.QueryParam("enabled")) 21 + if err != nil { 22 + return &echo.HTTPError{ 23 + Code: 400, 24 + Message: err.Error(), 25 + } 26 + } 27 + 28 + return bgs.slurper.SetNewSubsDisabled(!enabled) 29 + } 30 + 31 + func (bgs *BGS) handleAdminGetSubsEnabled(e echo.Context) error { 32 + return e.JSON(200, map[string]bool{ 33 + "enabled": !bgs.slurper.GetNewSubsDisabledState(), 34 + }) 35 + } 36 + 37 + func (bgs *BGS) handleAdminGetNewPDSPerDayRateLimit(e echo.Context) error { 38 + limit := bgs.slurper.GetNewPDSPerDayLimit() 39 + return e.JSON(200, map[string]int64{ 40 + "limit": limit, 41 + }) 42 + } 43 + 44 + func (bgs *BGS) handleAdminSetNewPDSPerDayRateLimit(e echo.Context) error { 45 + limit, err := strconv.ParseInt(e.QueryParam("limit"), 10, 64) 46 + if err != nil { 47 + return &echo.HTTPError{ 48 + Code: 400, 49 + Message: fmt.Errorf("failed to parse limit: %w", err).Error(), 50 + } 51 + } 52 + 53 + err = bgs.slurper.SetNewPDSPerDayLimit(limit) 54 + if err != nil { 55 + return &echo.HTTPError{ 56 + Code: 500, 57 + Message: fmt.Errorf("failed to set new PDS per day rate limit: %w", err).Error(), 58 + } 59 + } 60 + 61 + return nil 62 + } 63 + 64 + func (bgs *BGS) handleAdminTakeDownRepo(e echo.Context) error { 65 + ctx := e.Request().Context() 66 + 67 + var body map[string]string 68 + if err := e.Bind(&body); err != nil { 69 + return err 70 + } 71 + did, ok := body["did"] 72 + if !ok { 73 + return &echo.HTTPError{ 74 + Code: 400, 75 + Message: "must specify did parameter in body", 76 + } 77 + } 78 + 79 + err := bgs.TakeDownRepo(ctx, did) 80 + if 
err != nil { 81 + if errors.Is(err, gorm.ErrRecordNotFound) { 82 + return &echo.HTTPError{ 83 + Code: http.StatusNotFound, 84 + Message: "repo not found", 85 + } 86 + } 87 + return &echo.HTTPError{ 88 + Code: http.StatusInternalServerError, 89 + Message: err.Error(), 90 + } 91 + } 92 + return nil 93 + } 94 + 95 + func (bgs *BGS) handleAdminReverseTakedown(e echo.Context) error { 96 + did := e.QueryParam("did") 97 + ctx := e.Request().Context() 98 + err := bgs.ReverseTakedown(ctx, did) 99 + 100 + if err != nil { 101 + if errors.Is(err, gorm.ErrRecordNotFound) { 102 + return &echo.HTTPError{ 103 + Code: http.StatusNotFound, 104 + Message: "repo not found", 105 + } 106 + } 107 + return &echo.HTTPError{ 108 + Code: http.StatusInternalServerError, 109 + Message: err.Error(), 110 + } 111 + } 112 + 113 + return nil 114 + } 115 + 116 + type ListTakedownsResponse struct { 117 + Dids []string `json:"dids"` 118 + Cursor int64 `json:"cursor,omitempty"` 119 + } 120 + 121 + func (bgs *BGS) handleAdminListRepoTakeDowns(e echo.Context) error { 122 + ctx := e.Request().Context() 123 + haveMinId := false 124 + minId := int64(-1) 125 + qmin := e.QueryParam("cursor") 126 + if qmin != "" { 127 + tmin, err := strconv.ParseInt(qmin, 10, 64) 128 + if err != nil { 129 + return &echo.HTTPError{Code: 400, Message: "bad cursor"} 130 + } 131 + minId = tmin 132 + haveMinId = true 133 + } 134 + limit := 1000 135 + wat := bgs.db.Model(Account{}).WithContext(ctx).Select("id", "did").Where("taken_down = TRUE") 136 + if haveMinId { 137 + wat = wat.Where("id > ?", minId) 138 + } 139 + //var users []Account 140 + rows, err := wat.Order("id").Limit(limit).Rows() 141 + if err != nil { 142 + return echo.NewHTTPError(http.StatusInternalServerError, "oops").WithInternal(err) 143 + } 144 + var out ListTakedownsResponse 145 + for rows.Next() { 146 + var id int64 147 + var did string 148 + err := rows.Scan(&id, &did) 149 + if err != nil { 150 + return echo.NewHTTPError(http.StatusInternalServerError, 
"oops").WithInternal(err) 151 + } 152 + out.Dids = append(out.Dids, did) 153 + out.Cursor = id 154 + } 155 + if len(out.Dids) < limit { 156 + out.Cursor = 0 157 + } 158 + return e.JSON(200, out) 159 + } 160 + 161 + func (bgs *BGS) handleAdminGetUpstreamConns(e echo.Context) error { 162 + return e.JSON(200, bgs.slurper.GetActiveList()) 163 + } 164 + 165 + type rateLimit struct { 166 + Max float64 `json:"Max"` 167 + WindowSeconds float64 `json:"Window"` 168 + } 169 + 170 + type enrichedPDS struct { 171 + models.PDS 172 + HasActiveConnection bool `json:"HasActiveConnection"` 173 + EventsSeenSinceStartup uint64 `json:"EventsSeenSinceStartup"` 174 + PerSecondEventRate rateLimit `json:"PerSecondEventRate"` 175 + PerHourEventRate rateLimit `json:"PerHourEventRate"` 176 + PerDayEventRate rateLimit `json:"PerDayEventRate"` 177 + UserCount int64 `json:"UserCount"` 178 + } 179 + 180 + type UserCount struct { 181 + PDSID uint `gorm:"column:pds"` 182 + UserCount int64 `gorm:"column:user_count"` 183 + } 184 + 185 + func (bgs *BGS) handleListPDSs(e echo.Context) error { 186 + var pds []models.PDS 187 + if err := bgs.db.Find(&pds).Error; err != nil { 188 + return err 189 + } 190 + 191 + enrichedPDSs := make([]enrichedPDS, len(pds)) 192 + 193 + activePDSHosts := bgs.slurper.GetActiveList() 194 + 195 + for i, p := range pds { 196 + enrichedPDSs[i].PDS = p 197 + enrichedPDSs[i].HasActiveConnection = false 198 + for _, host := range activePDSHosts { 199 + if strings.ToLower(host) == strings.ToLower(p.Host) { 200 + enrichedPDSs[i].HasActiveConnection = true 201 + break 202 + } 203 + } 204 + var m = &dto.Metric{} 205 + if err := eventsReceivedCounter.WithLabelValues(p.Host).Write(m); err != nil { 206 + enrichedPDSs[i].EventsSeenSinceStartup = 0 207 + continue 208 + } 209 + enrichedPDSs[i].EventsSeenSinceStartup = uint64(m.Counter.GetValue()) 210 + 211 + enrichedPDSs[i].PerSecondEventRate = rateLimit{ 212 + Max: p.RateLimit, 213 + WindowSeconds: 1, 214 + } 215 + 216 + 
enrichedPDSs[i].PerHourEventRate = rateLimit{ 217 + Max: float64(p.HourlyEventLimit), 218 + WindowSeconds: 3600, 219 + } 220 + 221 + enrichedPDSs[i].PerDayEventRate = rateLimit{ 222 + Max: float64(p.DailyEventLimit), 223 + WindowSeconds: 86400, 224 + } 225 + } 226 + 227 + return e.JSON(200, enrichedPDSs) 228 + } 229 + 230 + type consumer struct { 231 + ID uint64 `json:"id"` 232 + RemoteAddr string `json:"remote_addr"` 233 + UserAgent string `json:"user_agent"` 234 + EventsConsumed uint64 `json:"events_consumed"` 235 + ConnectedAt time.Time `json:"connected_at"` 236 + } 237 + 238 + func (bgs *BGS) handleAdminListConsumers(e echo.Context) error { 239 + bgs.consumersLk.RLock() 240 + defer bgs.consumersLk.RUnlock() 241 + 242 + consumers := make([]consumer, 0, len(bgs.consumers)) 243 + for id, c := range bgs.consumers { 244 + var m = &dto.Metric{} 245 + if err := c.EventsSent.Write(m); err != nil { 246 + continue 247 + } 248 + consumers = append(consumers, consumer{ 249 + ID: id, 250 + RemoteAddr: c.RemoteAddr, 251 + UserAgent: c.UserAgent, 252 + EventsConsumed: uint64(m.Counter.GetValue()), 253 + ConnectedAt: c.ConnectedAt, 254 + }) 255 + } 256 + 257 + return e.JSON(200, consumers) 258 + } 259 + 260 + func (bgs *BGS) handleAdminKillUpstreamConn(e echo.Context) error { 261 + host := strings.TrimSpace(e.QueryParam("host")) 262 + if host == "" { 263 + return &echo.HTTPError{ 264 + Code: 400, 265 + Message: "must pass a valid host", 266 + } 267 + } 268 + 269 + block := strings.ToLower(e.QueryParam("block")) == "true" 270 + 271 + if err := bgs.slurper.KillUpstreamConnection(host, block); err != nil { 272 + if errors.Is(err, ErrNoActiveConnection) { 273 + return &echo.HTTPError{ 274 + Code: 400, 275 + Message: "no active connection to given host", 276 + } 277 + } 278 + return err 279 + } 280 + 281 + return e.JSON(200, map[string]any{ 282 + "success": "true", 283 + }) 284 + } 285 + 286 + func (bgs *BGS) handleBlockPDS(e echo.Context) error { 287 + host := 
strings.TrimSpace(e.QueryParam("host")) 288 + if host == "" { 289 + return &echo.HTTPError{ 290 + Code: 400, 291 + Message: "must pass a valid host", 292 + } 293 + } 294 + 295 + // Set the block flag to true in the DB 296 + if err := bgs.db.Model(&models.PDS{}).Where("host = ?", host).Update("blocked", true).Error; err != nil { 297 + return err 298 + } 299 + 300 + // don't care if this errors, but we should try to disconnect something we just blocked 301 + _ = bgs.slurper.KillUpstreamConnection(host, false) 302 + 303 + return e.JSON(200, map[string]any{ 304 + "success": "true", 305 + }) 306 + } 307 + 308 + func (bgs *BGS) handleUnblockPDS(e echo.Context) error { 309 + host := strings.TrimSpace(e.QueryParam("host")) 310 + if host == "" { 311 + return &echo.HTTPError{ 312 + Code: 400, 313 + Message: "must pass a valid host", 314 + } 315 + } 316 + 317 + // Set the block flag to false in the DB 318 + if err := bgs.db.Model(&models.PDS{}).Where("host = ?", host).Update("blocked", false).Error; err != nil { 319 + return err 320 + } 321 + 322 + return e.JSON(200, map[string]any{ 323 + "success": "true", 324 + }) 325 + } 326 + 327 + type bannedDomains struct { 328 + BannedDomains []string `json:"banned_domains"` 329 + } 330 + 331 + func (bgs *BGS) handleAdminListDomainBans(c echo.Context) error { 332 + var all []DomainBan 333 + if err := bgs.db.Find(&all).Error; err != nil { 334 + return err 335 + } 336 + 337 + resp := bannedDomains{ 338 + BannedDomains: []string{}, 339 + } 340 + for _, b := range all { 341 + resp.BannedDomains = append(resp.BannedDomains, b.Domain) 342 + } 343 + 344 + return c.JSON(200, resp) 345 + } 346 + 347 + type banDomainBody struct { 348 + Domain string 349 + } 350 + 351 + func (bgs *BGS) handleAdminBanDomain(c echo.Context) error { 352 + var body banDomainBody 353 + if err := c.Bind(&body); err != nil { 354 + return err 355 + } 356 + 357 + // Check if the domain is already banned 358 + var existing DomainBan 359 + if err := bgs.db.Where("domain = 
?", body.Domain).First(&existing).Error; err == nil { 360 + return &echo.HTTPError{ 361 + Code: 400, 362 + Message: "domain is already banned", 363 + } 364 + } 365 + 366 + if err := bgs.db.Create(&DomainBan{ 367 + Domain: body.Domain, 368 + }).Error; err != nil { 369 + return err 370 + } 371 + 372 + return c.JSON(200, map[string]any{ 373 + "success": "true", 374 + }) 375 + } 376 + 377 + func (bgs *BGS) handleAdminUnbanDomain(c echo.Context) error { 378 + var body banDomainBody 379 + if err := c.Bind(&body); err != nil { 380 + return err 381 + } 382 + 383 + if err := bgs.db.Where("domain = ?", body.Domain).Delete(&DomainBan{}).Error; err != nil { 384 + return err 385 + } 386 + 387 + return c.JSON(200, map[string]any{ 388 + "success": "true", 389 + }) 390 + } 391 + 392 + type PDSRates struct { 393 + // core event rate, counts firehose events 394 + PerSecond int64 `json:"per_second,omitempty"` 395 + PerHour int64 `json:"per_hour,omitempty"` 396 + PerDay int64 `json:"per_day,omitempty"` 397 + 398 + RepoLimit int64 `json:"repo_limit,omitempty"` 399 + } 400 + 401 + func (pr *PDSRates) FromSlurper(s *Slurper) { 402 + if pr.PerSecond == 0 { 403 + pr.PerHour = s.DefaultPerSecondLimit 404 + } 405 + if pr.PerHour == 0 { 406 + pr.PerHour = s.DefaultPerHourLimit 407 + } 408 + if pr.PerDay == 0 { 409 + pr.PerDay = s.DefaultPerDayLimit 410 + } 411 + if pr.RepoLimit == 0 { 412 + pr.RepoLimit = s.DefaultRepoLimit 413 + } 414 + } 415 + 416 + type RateLimitChangeRequest struct { 417 + Host string `json:"host"` 418 + PDSRates 419 + } 420 + 421 + func (bgs *BGS) handleAdminChangePDSRateLimits(e echo.Context) error { 422 + var body RateLimitChangeRequest 423 + if err := e.Bind(&body); err != nil { 424 + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid body: %s", err)) 425 + } 426 + 427 + // Get the PDS from the DB 428 + var pds models.PDS 429 + if err := bgs.db.Where("host = ?", body.Host).First(&pds).Error; err != nil { 430 + return err 431 + } 432 + 433 + // Update 
the rate limits in the DB 434 + pds.RateLimit = float64(body.PerSecond) 435 + pds.HourlyEventLimit = body.PerHour 436 + pds.DailyEventLimit = body.PerDay 437 + pds.RepoLimit = body.RepoLimit 438 + 439 + if err := bgs.db.Save(&pds).Error; err != nil { 440 + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Errorf("failed to save rate limit changes: %w", err)) 441 + } 442 + 443 + // Update the rate limit in the limiter 444 + limits := bgs.slurper.GetOrCreateLimiters(pds.ID, body.PerSecond, body.PerHour, body.PerDay) 445 + limits.PerSecond.SetLimit(body.PerSecond) 446 + limits.PerHour.SetLimit(body.PerHour) 447 + limits.PerDay.SetLimit(body.PerDay) 448 + 449 + return e.JSON(200, map[string]any{ 450 + "success": "true", 451 + }) 452 + } 453 + 454 + func (bgs *BGS) handleAdminAddTrustedDomain(e echo.Context) error { 455 + domain := e.QueryParam("domain") 456 + if domain == "" { 457 + return fmt.Errorf("must specify domain in query parameter") 458 + } 459 + 460 + // Check if the domain is already trusted 461 + trustedDomains := bgs.slurper.GetTrustedDomains() 462 + if slices.Contains(trustedDomains, domain) { 463 + return &echo.HTTPError{ 464 + Code: 400, 465 + Message: "domain is already trusted", 466 + } 467 + } 468 + 469 + if err := bgs.slurper.AddTrustedDomain(domain); err != nil { 470 + return err 471 + } 472 + 473 + return e.JSON(200, map[string]any{ 474 + "success": true, 475 + }) 476 + } 477 + 478 + type AdminRequestCrawlRequest struct { 479 + Hostname string `json:"hostname"` 480 + 481 + // optional: 482 + PDSRates 483 + } 484 + 485 + func (bgs *BGS) handleAdminRequestCrawl(e echo.Context) error { 486 + ctx := e.Request().Context() 487 + 488 + var body AdminRequestCrawlRequest 489 + if err := e.Bind(&body); err != nil { 490 + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("invalid body: %s", err)) 491 + } 492 + 493 + host := body.Hostname 494 + if host == "" { 495 + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname") 496 
+ } 497 + 498 + if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") { 499 + if bgs.ssl { 500 + host = "https://" + host 501 + } else { 502 + host = "http://" + host 503 + } 504 + } 505 + 506 + u, err := url.Parse(host) 507 + if err != nil { 508 + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse hostname") 509 + } 510 + 511 + if u.Scheme == "http" && bgs.ssl { 512 + return echo.NewHTTPError(http.StatusBadRequest, "this server requires https") 513 + } 514 + 515 + if u.Scheme == "https" && !bgs.ssl { 516 + return echo.NewHTTPError(http.StatusBadRequest, "this server does not support https") 517 + } 518 + 519 + if u.Path != "" { 520 + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without path") 521 + } 522 + 523 + if u.Query().Encode() != "" { 524 + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without query") 525 + } 526 + 527 + host = u.Host // potentially hostname:port 528 + 529 + banned, err := bgs.domainIsBanned(ctx, host) 530 + if banned { 531 + return echo.NewHTTPError(http.StatusUnauthorized, "domain is banned") 532 + } 533 + 534 + // Skip checking if the server is online for now 535 + rateOverrides := body.PDSRates 536 + rateOverrides.FromSlurper(bgs.slurper) 537 + 538 + return bgs.slurper.SubscribeToPds(ctx, host, true, true, &rateOverrides) // Override Trusted Domain Check 539 + }
+1267
cmd/relay/bgs/bgs.go
··· 1 + package bgs 2 + 3 + import ( 4 + "context" 5 + "errors" 6 + "fmt" 7 + "github.com/bluesky-social/indigo/atproto/identity" 8 + "github.com/bluesky-social/indigo/atproto/syntax" 9 + "github.com/ipfs/go-cid" 10 + "io" 11 + "log/slog" 12 + "net" 13 + "net/http" 14 + _ "net/http/pprof" 15 + "net/url" 16 + "strconv" 17 + "strings" 18 + "sync" 19 + "time" 20 + 21 + comatproto "github.com/bluesky-social/indigo/api/atproto" 22 + "github.com/bluesky-social/indigo/cmd/relay/events" 23 + "github.com/bluesky-social/indigo/cmd/relay/models" 24 + lexutil "github.com/bluesky-social/indigo/lex/util" 25 + "github.com/bluesky-social/indigo/xrpc" 26 + 27 + "github.com/gorilla/websocket" 28 + lru "github.com/hashicorp/golang-lru/v2" 29 + "github.com/labstack/echo/v4" 30 + "github.com/labstack/echo/v4/middleware" 31 + promclient "github.com/prometheus/client_golang/prometheus" 32 + "github.com/prometheus/client_golang/prometheus/promhttp" 33 + dto "github.com/prometheus/client_model/go" 34 + "go.opentelemetry.io/otel" 35 + "go.opentelemetry.io/otel/attribute" 36 + "gorm.io/gorm" 37 + ) 38 + 39 + var tracer = otel.Tracer("bgs") 40 + 41 + // serverListenerBootTimeout is how long to wait for the requested server socket 42 + // to become available for use. This is an arbitrary timeout that should be safe 43 + // on any platform, but there's no great way to weave this timeout without 44 + // adding another parameter to the (at time of writing) long signature of 45 + // NewServer. 
46 + const serverListenerBootTimeout = 5 * time.Second 47 + 48 + type BGS struct { 49 + db *gorm.DB 50 + slurper *Slurper 51 + events *events.EventManager 52 + didd identity.Directory 53 + 54 + // TODO: work on doing away with this flag in favor of more pluggable 55 + // pieces that abstract the need for explicit ssl checks 56 + ssl bool 57 + 58 + // extUserLk serializes a section of syncPDSAccount() 59 + // TODO: at some point we will want to lock specific DIDs, this lock as is 60 + // is overly broad, but i dont expect it to be a bottleneck for now 61 + extUserLk sync.Mutex 62 + 63 + validator *Validator 64 + 65 + // Management of Socket Consumers 66 + consumersLk sync.RWMutex 67 + nextConsumerID uint64 68 + consumers map[uint64]*SocketConsumer 69 + 70 + // Account cache 71 + userCache *lru.Cache[string, *Account] 72 + 73 + // nextCrawlers gets forwarded POST /xrpc/com.atproto.sync.requestCrawl 74 + nextCrawlers []*url.URL 75 + httpClient http.Client 76 + 77 + log *slog.Logger 78 + inductionTraceLog *slog.Logger 79 + 80 + config BGSConfig 81 + } 82 + 83 + type SocketConsumer struct { 84 + UserAgent string 85 + RemoteAddr string 86 + ConnectedAt time.Time 87 + EventsSent promclient.Counter 88 + } 89 + 90 + type BGSConfig struct { 91 + SSL bool 92 + DefaultRepoLimit int64 93 + ConcurrencyPerPDS int64 94 + MaxQueuePerPDS int64 95 + 96 + // NextCrawlers gets forwarded POST /xrpc/com.atproto.sync.requestCrawl 97 + NextCrawlers []*url.URL 98 + 99 + ApplyPDSClientSettings func(c *xrpc.Client) 100 + InductionTraceLog *slog.Logger 101 + 102 + // AdminToken checked against "Authorization: Bearer {}" header 103 + AdminToken string 104 + } 105 + 106 + func DefaultBGSConfig() *BGSConfig { 107 + return &BGSConfig{ 108 + SSL: true, 109 + DefaultRepoLimit: 100, 110 + ConcurrencyPerPDS: 100, 111 + MaxQueuePerPDS: 1_000, 112 + } 113 + } 114 + 115 + func NewBGS(db *gorm.DB, validator *Validator, evtman *events.EventManager, didd identity.Directory, config *BGSConfig) (*BGS, error) 
{ 116 + 117 + if config == nil { 118 + config = DefaultBGSConfig() 119 + } 120 + if err := db.AutoMigrate(DomainBan{}); err != nil { 121 + panic(err) 122 + } 123 + if err := db.AutoMigrate(models.PDS{}); err != nil { 124 + panic(err) 125 + } 126 + if err := db.AutoMigrate(Account{}); err != nil { 127 + panic(err) 128 + } 129 + if err := db.AutoMigrate(AccountPreviousState{}); err != nil { 130 + panic(err) 131 + } 132 + 133 + uc, _ := lru.New[string, *Account](1_000_000) 134 + 135 + bgs := &BGS{ 136 + db: db, 137 + 138 + validator: validator, 139 + events: evtman, 140 + didd: didd, 141 + ssl: config.SSL, 142 + 143 + consumersLk: sync.RWMutex{}, 144 + consumers: make(map[uint64]*SocketConsumer), 145 + 146 + userCache: uc, 147 + 148 + log: slog.Default().With("system", "bgs"), 149 + 150 + config: *config, 151 + 152 + inductionTraceLog: config.InductionTraceLog, 153 + } 154 + 155 + slOpts := DefaultSlurperOptions() 156 + slOpts.SSL = config.SSL 157 + slOpts.DefaultRepoLimit = config.DefaultRepoLimit 158 + slOpts.ConcurrencyPerPDS = config.ConcurrencyPerPDS 159 + slOpts.MaxQueuePerPDS = config.MaxQueuePerPDS 160 + slOpts.Logger = bgs.log 161 + s, err := NewSlurper(db, bgs.handleFedEvent, slOpts) 162 + if err != nil { 163 + return nil, err 164 + } 165 + 166 + bgs.slurper = s 167 + 168 + if err := bgs.slurper.RestartAll(); err != nil { 169 + return nil, err 170 + } 171 + 172 + bgs.nextCrawlers = config.NextCrawlers 173 + bgs.httpClient.Timeout = time.Second * 5 174 + 175 + return bgs, nil 176 + } 177 + 178 + func (bgs *BGS) StartMetrics(listen string) error { 179 + http.Handle("/metrics", promhttp.Handler()) 180 + return http.ListenAndServe(listen, nil) 181 + } 182 + 183 + func (bgs *BGS) Start(addr string, logWriter io.Writer) error { 184 + var lc net.ListenConfig 185 + ctx, cancel := context.WithTimeout(context.Background(), serverListenerBootTimeout) 186 + defer cancel() 187 + 188 + li, err := lc.Listen(ctx, "tcp", addr) 189 + if err != nil { 190 + return err 191 + } 
192 + return bgs.StartWithListener(li, logWriter) 193 + } 194 + 195 + func (bgs *BGS) StartWithListener(listen net.Listener, logWriter io.Writer) error { 196 + e := echo.New() 197 + e.Logger.SetOutput(logWriter) 198 + e.HideBanner = true 199 + 200 + e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 201 + AllowOrigins: []string{"*"}, 202 + AllowHeaders: []string{echo.HeaderOrigin, echo.HeaderContentType, echo.HeaderAccept, echo.HeaderAuthorization}, 203 + })) 204 + 205 + if !bgs.ssl { 206 + e.Use(middleware.LoggerWithConfig(middleware.LoggerConfig{ 207 + Format: "method=${method}, uri=${uri}, status=${status} latency=${latency_human}\n", 208 + })) 209 + } else { 210 + e.Use(middleware.LoggerWithConfig(middleware.DefaultLoggerConfig)) 211 + } 212 + 213 + // React uses a virtual router, so we need to serve the index.html for all 214 + // routes that aren't otherwise handled or in the /assets directory. 215 + e.File("/dash", "public/index.html") 216 + e.File("/dash/*", "public/index.html") 217 + e.Static("/assets", "public/assets") 218 + 219 + e.Use(MetricsMiddleware) 220 + 221 + e.HTTPErrorHandler = func(err error, ctx echo.Context) { 222 + switch err := err.(type) { 223 + case *echo.HTTPError: 224 + if err2 := ctx.JSON(err.Code, map[string]any{ 225 + "error": err.Message, 226 + }); err2 != nil { 227 + bgs.log.Error("Failed to write http error", "err", err2) 228 + } 229 + default: 230 + sendHeader := true 231 + if ctx.Path() == "/xrpc/com.atproto.sync.subscribeRepos" { 232 + sendHeader = false 233 + } 234 + 235 + bgs.log.Warn("HANDLER ERROR: (%s) %s", ctx.Path(), err) 236 + 237 + if strings.HasPrefix(ctx.Path(), "/admin/") { 238 + ctx.JSON(500, map[string]any{ 239 + "error": err.Error(), 240 + }) 241 + return 242 + } 243 + 244 + if sendHeader { 245 + ctx.Response().WriteHeader(500) 246 + } 247 + } 248 + } 249 + 250 + // TODO: this API is temporary until we formalize what we want here 251 + 252 + e.GET("/xrpc/com.atproto.sync.subscribeRepos", bgs.EventsHandler) 
253 + e.POST("/xrpc/com.atproto.sync.requestCrawl", bgs.HandleComAtprotoSyncRequestCrawl) 254 + e.GET("/xrpc/com.atproto.sync.listRepos", bgs.HandleComAtprotoSyncListRepos) 255 + e.GET("/xrpc/com.atproto.sync.getRepo", bgs.HandleComAtprotoSyncGetRepo) // just returns 3xx redirect to source PDS 256 + e.GET("/xrpc/com.atproto.sync.getLatestCommit", bgs.HandleComAtprotoSyncGetLatestCommit) 257 + e.GET("/xrpc/_health", bgs.HandleHealthCheck) 258 + e.GET("/_health", bgs.HandleHealthCheck) 259 + e.GET("/", bgs.HandleHomeMessage) 260 + 261 + admin := e.Group("/admin", bgs.checkAdminAuth) 262 + 263 + // Slurper-related Admin API 264 + admin.GET("/subs/getUpstreamConns", bgs.handleAdminGetUpstreamConns) 265 + admin.GET("/subs/getEnabled", bgs.handleAdminGetSubsEnabled) 266 + admin.GET("/subs/perDayLimit", bgs.handleAdminGetNewPDSPerDayRateLimit) 267 + admin.POST("/subs/setEnabled", bgs.handleAdminSetSubsEnabled) 268 + admin.POST("/subs/killUpstream", bgs.handleAdminKillUpstreamConn) 269 + admin.POST("/subs/setPerDayLimit", bgs.handleAdminSetNewPDSPerDayRateLimit) 270 + 271 + // Domain-related Admin API 272 + admin.GET("/subs/listDomainBans", bgs.handleAdminListDomainBans) 273 + admin.POST("/subs/banDomain", bgs.handleAdminBanDomain) 274 + admin.POST("/subs/unbanDomain", bgs.handleAdminUnbanDomain) 275 + 276 + // Repo-related Admin API 277 + admin.POST("/repo/takeDown", bgs.handleAdminTakeDownRepo) 278 + admin.POST("/repo/reverseTakedown", bgs.handleAdminReverseTakedown) 279 + admin.GET("/repo/takedowns", bgs.handleAdminListRepoTakeDowns) 280 + 281 + // PDS-related Admin API 282 + admin.POST("/pds/requestCrawl", bgs.handleAdminRequestCrawl) 283 + admin.GET("/pds/list", bgs.handleListPDSs) 284 + admin.POST("/pds/changeLimits", bgs.handleAdminChangePDSRateLimits) 285 + admin.POST("/pds/block", bgs.handleBlockPDS) 286 + admin.POST("/pds/unblock", bgs.handleUnblockPDS) 287 + admin.POST("/pds/addTrustedDomain", bgs.handleAdminAddTrustedDomain) 288 + 289 + // Consumer-related 
Admin API 290 + admin.GET("/consumers/list", bgs.handleAdminListConsumers) 291 + 292 + // In order to support booting on random ports in tests, we need to tell the 293 + // Echo instance it's already got a port, and then use its StartServer 294 + // method to re-use that listener. 295 + e.Listener = listen 296 + srv := &http.Server{} 297 + return e.StartServer(srv) 298 + } 299 + 300 + func (bgs *BGS) Shutdown() []error { 301 + errs := bgs.slurper.Shutdown() 302 + 303 + if err := bgs.events.Shutdown(context.TODO()); err != nil { 304 + errs = append(errs, err) 305 + } 306 + 307 + return errs 308 + } 309 + 310 + type HealthStatus struct { 311 + Status string `json:"status"` 312 + Message string `json:"msg,omitempty"` 313 + } 314 + 315 + func (bgs *BGS) HandleHealthCheck(c echo.Context) error { 316 + if err := bgs.db.Exec("SELECT 1").Error; err != nil { 317 + bgs.log.Error("healthcheck can't connect to database", "err", err) 318 + return c.JSON(500, HealthStatus{Status: "error", Message: "can't connect to database"}) 319 + } else { 320 + return c.JSON(200, HealthStatus{Status: "ok"}) 321 + } 322 + } 323 + 324 + var homeMessage string = ` 325 + .########..########.##..........###....##....## 326 + .##.....##.##.......##.........##.##....##..##. 327 + .##.....##.##.......##........##...##....####.. 328 + .########..######...##.......##.....##....##... 329 + .##...##...##.......##.......#########....##... 330 + .##....##..##.......##.......##.....##....##... 331 + .##.....##.########.########.##.....##....##... 
332 + 333 + This is an atproto [https://atproto.com] relay instance, running the 'bigsky' codebase [https://github.com/bluesky-social/indigo] 334 + 335 + The firehose WebSocket path is at: /xrpc/com.atproto.sync.subscribeRepos 336 + ` 337 + 338 + func (bgs *BGS) HandleHomeMessage(c echo.Context) error { 339 + return c.String(http.StatusOK, homeMessage) 340 + } 341 + 342 + const authorizationBearerPrefix = "Bearer " 343 + 344 + func (bgs *BGS) checkAdminAuth(next echo.HandlerFunc) echo.HandlerFunc { 345 + return func(e echo.Context) error { 346 + authheader := e.Request().Header.Get("Authorization") 347 + if !strings.HasPrefix(authheader, authorizationBearerPrefix) { 348 + return echo.ErrForbidden 349 + } 350 + 351 + token := authheader[len(authorizationBearerPrefix):] 352 + 353 + if bgs.config.AdminToken != token { 354 + return echo.ErrForbidden 355 + } 356 + 357 + return next(e) 358 + } 359 + } 360 + 361 + type Account struct { 362 + ID models.Uid `gorm:"primarykey"` 363 + CreatedAt time.Time 364 + UpdatedAt time.Time 365 + DeletedAt gorm.DeletedAt `gorm:"index"` 366 + Did string `gorm:"uniqueIndex"` 367 + PDS uint // foreign key on models.PDS.ID 368 + 369 + // TakenDown is set to true if the user in question has been taken down by an admin action at this relay. 370 + // A user in this state will have all future events related to it dropped 371 + // and no data about this user will be served. 372 + TakenDown bool 373 + 374 + // UpstreamStatus is the state of the user as reported by the upstream PDS through #account messages. 375 + // Additionally, the non-standard string "active" is set to represent an upstream #account message with the active bool true. 
376 + UpstreamStatus string `gorm:"index"` 377 + 378 + lk sync.Mutex 379 + } 380 + 381 + func (account *Account) GetDid() string { 382 + return account.Did 383 + } 384 + 385 + func (account *Account) GetUid() models.Uid { 386 + return account.ID 387 + } 388 + 389 + func (account *Account) SetTakenDown(v bool) { 390 + account.lk.Lock() 391 + defer account.lk.Unlock() 392 + account.TakenDown = v 393 + } 394 + 395 + func (account *Account) GetTakenDown() bool { 396 + account.lk.Lock() 397 + defer account.lk.Unlock() 398 + return account.TakenDown 399 + } 400 + 401 + func (account *Account) SetPDS(pdsId uint) { 402 + account.lk.Lock() 403 + defer account.lk.Unlock() 404 + account.PDS = pdsId 405 + } 406 + 407 + func (account *Account) GetPDS() uint { 408 + account.lk.Lock() 409 + defer account.lk.Unlock() 410 + return account.PDS 411 + } 412 + 413 + func (account *Account) SetUpstreamStatus(v string) { 414 + account.lk.Lock() 415 + defer account.lk.Unlock() 416 + account.UpstreamStatus = v 417 + } 418 + 419 + func (account *Account) GetUpstreamStatus() string { 420 + account.lk.Lock() 421 + defer account.lk.Unlock() 422 + return account.UpstreamStatus 423 + } 424 + 425 + type AccountPreviousState struct { 426 + Uid models.Uid `gorm:"column:uid;primaryKey"` 427 + Cid models.DbCID `gorm:"column:cid"` 428 + Rev string `gorm:"column:rev"` 429 + Seq int64 `gorm:"column:seq"` 430 + } 431 + 432 + func (ups *AccountPreviousState) GetCid() cid.Cid { 433 + return ups.Cid.CID 434 + } 435 + func (ups *AccountPreviousState) GetRev() syntax.TID { 436 + xt, _ := syntax.ParseTID(ups.Rev) 437 + return xt 438 + } 439 + 440 + type addTargetBody struct { 441 + Host string `json:"host"` 442 + } 443 + 444 + func (bgs *BGS) registerConsumer(c *SocketConsumer) uint64 { 445 + bgs.consumersLk.Lock() 446 + defer bgs.consumersLk.Unlock() 447 + 448 + id := bgs.nextConsumerID 449 + bgs.nextConsumerID++ 450 + 451 + bgs.consumers[id] = c 452 + 453 + return id 454 + } 455 + 456 + func (bgs *BGS) 
cleanupConsumer(id uint64) { 457 + bgs.consumersLk.Lock() 458 + defer bgs.consumersLk.Unlock() 459 + 460 + c := bgs.consumers[id] 461 + 462 + var m = &dto.Metric{} 463 + if err := c.EventsSent.Write(m); err != nil { 464 + bgs.log.Error("failed to get sent counter", "err", err) 465 + } 466 + 467 + bgs.log.Info("consumer disconnected", 468 + "consumer_id", id, 469 + "remote_addr", c.RemoteAddr, 470 + "user_agent", c.UserAgent, 471 + "events_sent", m.Counter.GetValue()) 472 + 473 + delete(bgs.consumers, id) 474 + } 475 + 476 + // GET+websocket /xrpc/com.atproto.sync.subscribeRepos 477 + func (bgs *BGS) EventsHandler(c echo.Context) error { 478 + var since *int64 479 + if sinceVal := c.QueryParam("cursor"); sinceVal != "" { 480 + sval, err := strconv.ParseInt(sinceVal, 10, 64) 481 + if err != nil { 482 + return err 483 + } 484 + since = &sval 485 + } 486 + 487 + ctx, cancel := context.WithCancel(c.Request().Context()) 488 + defer cancel() 489 + 490 + conn, err := websocket.Upgrade(c.Response(), c.Request(), c.Response().Header(), 10<<10, 10<<10) 491 + if err != nil { 492 + return fmt.Errorf("upgrading websocket: %w", err) 493 + } 494 + 495 + defer conn.Close() 496 + 497 + lastWriteLk := sync.Mutex{} 498 + lastWrite := time.Now() 499 + 500 + // Start a goroutine to ping the client every 30 seconds to check if it's 501 + // still alive. If the client doesn't respond to a ping within 5 seconds, 502 + // we'll close the connection and teardown the consumer. 
503 + go func() { 504 + ticker := time.NewTicker(30 * time.Second) 505 + defer ticker.Stop() 506 + 507 + for { 508 + select { 509 + case <-ticker.C: 510 + lastWriteLk.Lock() 511 + lw := lastWrite 512 + lastWriteLk.Unlock() 513 + 514 + if time.Since(lw) < 30*time.Second { 515 + continue 516 + } 517 + 518 + if err := conn.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(5*time.Second)); err != nil { 519 + bgs.log.Warn("failed to ping client", "err", err) 520 + cancel() 521 + return 522 + } 523 + case <-ctx.Done(): 524 + return 525 + } 526 + } 527 + }() 528 + 529 + conn.SetPingHandler(func(message string) error { 530 + err := conn.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60)) 531 + if err == websocket.ErrCloseSent { 532 + return nil 533 + } else if e, ok := err.(net.Error); ok && e.Temporary() { 534 + return nil 535 + } 536 + return err 537 + }) 538 + 539 + // Start a goroutine to read messages from the client and discard them. 540 + go func() { 541 + for { 542 + _, _, err := conn.ReadMessage() 543 + if err != nil { 544 + bgs.log.Warn("failed to read message from client", "err", err) 545 + cancel() 546 + return 547 + } 548 + } 549 + }() 550 + 551 + ident := c.RealIP() + "-" + c.Request().UserAgent() 552 + 553 + evts, cleanup, err := bgs.events.Subscribe(ctx, ident, func(evt *events.XRPCStreamEvent) bool { return true }, since) 554 + if err != nil { 555 + return err 556 + } 557 + defer cleanup() 558 + 559 + // Keep track of the consumer for metrics and admin endpoints 560 + consumer := SocketConsumer{ 561 + RemoteAddr: c.RealIP(), 562 + UserAgent: c.Request().UserAgent(), 563 + ConnectedAt: time.Now(), 564 + } 565 + sentCounter := eventsSentCounter.WithLabelValues(consumer.RemoteAddr, consumer.UserAgent) 566 + consumer.EventsSent = sentCounter 567 + 568 + consumerID := bgs.registerConsumer(&consumer) 569 + defer bgs.cleanupConsumer(consumerID) 570 + 571 + logger := bgs.log.With( 572 + "consumer_id", consumerID, 573 + 
"remote_addr", consumer.RemoteAddr, 574 + "user_agent", consumer.UserAgent, 575 + ) 576 + 577 + logger.Info("new consumer", "cursor", since) 578 + 579 + for { 580 + select { 581 + case evt, ok := <-evts: 582 + if !ok { 583 + logger.Error("event stream closed unexpectedly") 584 + return nil 585 + } 586 + 587 + wc, err := conn.NextWriter(websocket.BinaryMessage) 588 + if err != nil { 589 + logger.Error("failed to get next writer", "err", err) 590 + return err 591 + } 592 + 593 + if evt.Preserialized != nil { 594 + _, err = wc.Write(evt.Preserialized) 595 + } else { 596 + err = evt.Serialize(wc) 597 + } 598 + if err != nil { 599 + return fmt.Errorf("failed to write event: %w", err) 600 + } 601 + 602 + if err := wc.Close(); err != nil { 603 + logger.Warn("failed to flush-close our event write", "err", err) 604 + return nil 605 + } 606 + 607 + lastWriteLk.Lock() 608 + lastWrite = time.Now() 609 + lastWriteLk.Unlock() 610 + sentCounter.Inc() 611 + case <-ctx.Done(): 612 + return nil 613 + } 614 + } 615 + } 616 + 617 + // domainIsBanned checks if the given host is banned, starting with the host 618 + // itself, then checking every parent domain up to the tld 619 + func (s *BGS) domainIsBanned(ctx context.Context, host string) (bool, error) { 620 + // ignore ports when checking for ban status 621 + hostport := strings.Split(host, ":") 622 + 623 + segments := strings.Split(hostport[0], ".") 624 + 625 + // TODO: use normalize method once that merges 626 + var cleaned []string 627 + for _, s := range segments { 628 + if s == "" { 629 + continue 630 + } 631 + s = strings.ToLower(s) 632 + 633 + cleaned = append(cleaned, s) 634 + } 635 + segments = cleaned 636 + 637 + for i := 0; i < len(segments)-1; i++ { 638 + dchk := strings.Join(segments[i:], ".") 639 + found, err := s.findDomainBan(ctx, dchk) 640 + if err != nil { 641 + return false, err 642 + } 643 + 644 + if found { 645 + return true, nil 646 + } 647 + } 648 + return false, nil 649 + } 650 + 651 + func (s *BGS) 
findDomainBan(ctx context.Context, host string) (bool, error) { 652 + var db DomainBan 653 + if err := s.db.Find(&db, "domain = ?", host).Error; err != nil { 654 + return false, err 655 + } 656 + 657 + if db.ID == 0 { 658 + return false, nil 659 + } 660 + 661 + return true, nil 662 + } 663 + 664 + var ErrNotFound = errors.New("not found") 665 + 666 + func (bgs *BGS) DidToUid(ctx context.Context, did string) (models.Uid, error) { 667 + xu, err := bgs.lookupUserByDid(ctx, did) 668 + if err != nil { 669 + return 0, err 670 + } 671 + if xu == nil { 672 + return 0, ErrNotFound 673 + } 674 + return xu.ID, nil 675 + } 676 + 677 + func (bgs *BGS) lookupUserByDid(ctx context.Context, did string) (*Account, error) { 678 + ctx, span := tracer.Start(ctx, "lookupUserByDid") 679 + defer span.End() 680 + 681 + cu, ok := bgs.userCache.Get(did) 682 + if ok { 683 + return cu, nil 684 + } 685 + 686 + var u Account 687 + if err := bgs.db.Find(&u, "did = ?", did).Error; err != nil { 688 + return nil, err 689 + } 690 + 691 + if u.ID == 0 { 692 + return nil, gorm.ErrRecordNotFound 693 + } 694 + 695 + bgs.userCache.Add(did, &u) 696 + 697 + return &u, nil 698 + } 699 + 700 + func (bgs *BGS) lookupUserByUID(ctx context.Context, uid models.Uid) (*Account, error) { 701 + ctx, span := tracer.Start(ctx, "lookupUserByUID") 702 + defer span.End() 703 + 704 + var u Account 705 + if err := bgs.db.Find(&u, "id = ?", uid).Error; err != nil { 706 + return nil, err 707 + } 708 + 709 + if u.ID == 0 { 710 + return nil, gorm.ErrRecordNotFound 711 + } 712 + 713 + return &u, nil 714 + } 715 + 716 + func stringLink(lnk *lexutil.LexLink) string { 717 + if lnk == nil { 718 + return "<nil>" 719 + } 720 + 721 + return lnk.String() 722 + } 723 + 724 + // handleFedEvent() is the callback passed to Slurper called from Slurper.handleConnection() 725 + func (bgs *BGS) handleFedEvent(ctx context.Context, host *models.PDS, env *events.XRPCStreamEvent) error { 726 + ctx, span := tracer.Start(ctx, "handleFedEvent") 727 + 
defer span.End() 728 + 729 + start := time.Now() 730 + defer func() { 731 + eventsHandleDuration.WithLabelValues(host.Host).Observe(time.Since(start).Seconds()) 732 + }() 733 + 734 + eventsReceivedCounter.WithLabelValues(host.Host).Add(1) 735 + 736 + switch { 737 + case env.RepoCommit != nil: 738 + repoCommitsReceivedCounter.WithLabelValues(host.Host).Add(1) 739 + return bgs.handleCommit(ctx, host, env.RepoCommit) 740 + case env.RepoSync != nil: 741 + repoSyncReceivedCounter.WithLabelValues(host.Host).Add(1) 742 + return bgs.handleSync(ctx, host, env.RepoSync) 743 + case env.RepoHandle != nil: 744 + eventsWarningsCounter.WithLabelValues(host.Host, "handle").Add(1) 745 + // TODO: rate limit warnings per PDS before we (temporarily?) block them 746 + return nil 747 + case env.RepoIdentity != nil: 748 + bgs.log.Info("bgs got identity event", "did", env.RepoIdentity.Did) 749 + // Flush any cached DID documents for this user 750 + bgs.purgeDidCache(ctx, env.RepoIdentity.Did) 751 + 752 + // Refetch the DID doc and update our cached keys and handle etc. 
753 + account, err := bgs.syncPDSAccount(ctx, env.RepoIdentity.Did, host, nil) 754 + if err != nil { 755 + return err 756 + } 757 + 758 + // Broadcast the identity event to all consumers 759 + err = bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ 760 + RepoIdentity: &comatproto.SyncSubscribeRepos_Identity{ 761 + Did: env.RepoIdentity.Did, 762 + Seq: env.RepoIdentity.Seq, 763 + Time: env.RepoIdentity.Time, 764 + Handle: env.RepoIdentity.Handle, 765 + }, 766 + PrivUid: account.ID, 767 + }) 768 + if err != nil { 769 + bgs.log.Error("failed to broadcast Identity event", "error", err, "did", env.RepoIdentity.Did) 770 + return fmt.Errorf("failed to broadcast Identity event: %w", err) 771 + } 772 + 773 + return nil 774 + case env.RepoAccount != nil: 775 + span.SetAttributes( 776 + attribute.String("did", env.RepoAccount.Did), 777 + attribute.Int64("seq", env.RepoAccount.Seq), 778 + attribute.Bool("active", env.RepoAccount.Active), 779 + ) 780 + 781 + if env.RepoAccount.Status != nil { 782 + span.SetAttributes(attribute.String("repo_status", *env.RepoAccount.Status)) 783 + } 784 + bgs.log.Info("bgs got account event", "did", env.RepoAccount.Did) 785 + 786 + if !env.RepoAccount.Active && env.RepoAccount.Status == nil { 787 + accountVerifyWarnings.WithLabelValues(host.Host, "nostat").Inc() 788 + return nil 789 + } 790 + 791 + // Flush any cached DID documents for this user 792 + bgs.purgeDidCache(ctx, env.RepoAccount.Did) 793 + 794 + // Refetch the DID doc to make sure the PDS is still authoritative 795 + account, err := bgs.syncPDSAccount(ctx, env.RepoAccount.Did, host, nil) 796 + if err != nil { 797 + span.RecordError(err) 798 + return err 799 + } 800 + 801 + // Check if the PDS is still authoritative 802 + // if not we don't want to be propagating this account event 803 + if account.GetPDS() != host.ID { 804 + bgs.log.Error("account event from non-authoritative pds", 805 + "seq", env.RepoAccount.Seq, 806 + "did", env.RepoAccount.Did, 807 + "event_from", host.Host, 808 + 
"did_doc_declared_pds", account.GetPDS(), 809 + "account_evt", env.RepoAccount, 810 + ) 811 + return fmt.Errorf("event from non-authoritative pds") 812 + } 813 + 814 + // Process the account status change 815 + repoStatus := events.AccountStatusActive 816 + if !env.RepoAccount.Active && env.RepoAccount.Status != nil { 817 + repoStatus = *env.RepoAccount.Status 818 + } 819 + 820 + account.SetUpstreamStatus(repoStatus) 821 + err = bgs.db.Save(account).Error 822 + if err != nil { 823 + span.RecordError(err) 824 + return fmt.Errorf("failed to update account status: %w", err) 825 + } 826 + 827 + shouldBeActive := env.RepoAccount.Active 828 + status := env.RepoAccount.Status 829 + 830 + // override with local status 831 + if account.GetTakenDown() { 832 + shouldBeActive = false 833 + status = &events.AccountStatusTakendown 834 + } 835 + 836 + // Broadcast the account event to all consumers 837 + err = bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ 838 + RepoAccount: &comatproto.SyncSubscribeRepos_Account{ 839 + Active: shouldBeActive, 840 + Did: env.RepoAccount.Did, 841 + Seq: env.RepoAccount.Seq, 842 + Status: status, 843 + Time: env.RepoAccount.Time, 844 + }, 845 + PrivUid: account.ID, 846 + }) 847 + if err != nil { 848 + bgs.log.Error("failed to broadcast Account event", "error", err, "did", env.RepoAccount.Did) 849 + return fmt.Errorf("failed to broadcast Account event: %w", err) 850 + } 851 + 852 + return nil 853 + case env.RepoMigrate != nil: 854 + eventsWarningsCounter.WithLabelValues(host.Host, "migrate").Add(1) 855 + // TODO: rate limit warnings per PDS before we (temporarily?) block them 856 + return nil 857 + case env.RepoTombstone != nil: 858 + eventsWarningsCounter.WithLabelValues(host.Host, "tombstone").Add(1) 859 + // TODO: rate limit warnings per PDS before we (temporarily?) 
block them 860 + return nil 861 + default: 862 + return fmt.Errorf("invalid fed event") 863 + } 864 + } 865 + 866 + func (bgs *BGS) newUser(ctx context.Context, host *models.PDS, did string) (*Account, error) { 867 + newUsersDiscovered.Inc() 868 + start := time.Now() 869 + account, err := bgs.syncPDSAccount(ctx, did, host, nil) 870 + newUserDiscoveryDuration.Observe(time.Since(start).Seconds()) 871 + if err != nil { 872 + repoCommitsResultCounter.WithLabelValues(host.Host, "uerr").Inc() 873 + return nil, fmt.Errorf("fed event create external user: %w", err) 874 + } 875 + return account, nil 876 + } 877 + 878 + var ErrCommitNoUser = errors.New("commit no user") 879 + 880 + func (bgs *BGS) handleCommit(ctx context.Context, host *models.PDS, evt *comatproto.SyncSubscribeRepos_Commit) error { 881 + bgs.log.Debug("bgs got repo append event", "seq", evt.Seq, "pdsHost", host.Host, "repo", evt.Repo) 882 + 883 + account, err := bgs.lookupUserByDid(ctx, evt.Repo) 884 + if err != nil { 885 + if !errors.Is(err, gorm.ErrRecordNotFound) { 886 + repoCommitsResultCounter.WithLabelValues(host.Host, "nou").Inc() 887 + return fmt.Errorf("looking up event user: %w", err) 888 + } 889 + 890 + account, err = bgs.newUser(ctx, host, evt.Repo) 891 + if err != nil { 892 + repoCommitsResultCounter.WithLabelValues(host.Host, "nuerr").Inc() 893 + return err 894 + } 895 + } 896 + if account == nil { 897 + repoCommitsResultCounter.WithLabelValues(host.Host, "nou2").Inc() 898 + return ErrCommitNoUser 899 + } 900 + 901 + ustatus := account.GetUpstreamStatus() 902 + 903 + if account.GetTakenDown() || ustatus == events.AccountStatusTakendown { 904 + bgs.log.Debug("dropping commit event from taken down user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) 905 + repoCommitsResultCounter.WithLabelValues(host.Host, "tdu").Inc() 906 + return nil 907 + } 908 + 909 + if ustatus == events.AccountStatusSuspended { 910 + bgs.log.Debug("dropping commit event from suspended user", "did", evt.Repo, "seq", 
evt.Seq, "pdsHost", host.Host) 911 + repoCommitsResultCounter.WithLabelValues(host.Host, "susu").Inc() 912 + return nil 913 + } 914 + 915 + if ustatus == events.AccountStatusDeactivated { 916 + bgs.log.Debug("dropping commit event from deactivated user", "did", evt.Repo, "seq", evt.Seq, "pdsHost", host.Host) 917 + repoCommitsResultCounter.WithLabelValues(host.Host, "du").Inc() 918 + return nil 919 + } 920 + 921 + if evt.Rebase { 922 + repoCommitsResultCounter.WithLabelValues(host.Host, "rebase").Inc() 923 + return fmt.Errorf("rebase was true in event seq:%d,host:%s", evt.Seq, host.Host) 924 + } 925 + 926 + accountPDSId := account.GetPDS() 927 + if host.ID != accountPDSId && accountPDSId != 0 { 928 + bgs.log.Warn("received event for repo from different pds than expected", "repo", evt.Repo, "expPds", accountPDSId, "gotPds", host.Host) 929 + // Flush any cached DID documents for this user 930 + bgs.purgeDidCache(ctx, evt.Repo) 931 + 932 + account, err = bgs.syncPDSAccount(ctx, evt.Repo, host, account) 933 + if err != nil { 934 + repoCommitsResultCounter.WithLabelValues(host.Host, "uerr2").Inc() 935 + return err 936 + } 937 + 938 + if account.GetPDS() != host.ID { 939 + repoCommitsResultCounter.WithLabelValues(host.Host, "noauth").Inc() 940 + return fmt.Errorf("event from non-authoritative pds") 941 + } 942 + } 943 + 944 + var prevState AccountPreviousState 945 + err = bgs.db.First(&prevState, account.ID).Error 946 + prevP := &prevState 947 + if errors.Is(err, gorm.ErrRecordNotFound) { 948 + prevP = nil 949 + } else if err != nil { 950 + bgs.log.Error("failed to get previous root", "err", err) 951 + prevP = nil 952 + } 953 + dbPrevRootStr := "" 954 + dbPrevSeqStr := "" 955 + if prevP != nil { 956 + if prevState.Seq >= evt.Seq && ((prevState.Seq - evt.Seq) < 2000) { 957 + // ignore catchup overlap of 200 on some subscribeRepos restarts 958 + repoCommitsResultCounter.WithLabelValues(host.Host, "dup").Inc() 959 + return nil 960 + } 961 + dbPrevRootStr = 
prevState.Cid.CID.String() 962 + dbPrevSeqStr = strconv.FormatInt(prevState.Seq, 10) 963 + } 964 + evtPrevDataStr := "" 965 + if evt.PrevData != nil { 966 + evtPrevDataStr = ((*cid.Cid)(evt.PrevData)).String() 967 + } 968 + newRootCid, err := bgs.validator.HandleCommit(ctx, host, account, evt, prevP) 969 + if err != nil { 970 + bgs.inductionTraceLog.Error("commit bad", "seq", evt.Seq, "pseq", dbPrevSeqStr, "pdsHost", host.Host, "repo", evt.Repo, "prev", evtPrevDataStr, "dbprev", dbPrevRootStr, "err", err) 971 + bgs.log.Warn("failed handling event", "err", err, "pdsHost", host.Host, "seq", evt.Seq, "repo", account.Did, "commit", evt.Commit.String()) 972 + repoCommitsResultCounter.WithLabelValues(host.Host, "err").Inc() 973 + return fmt.Errorf("handle user event failed: %w", err) 974 + } else { 975 + // store now verified new repo state 976 + err = bgs.upsertPrevState(account.ID, newRootCid, evt.Rev, evt.Seq) 977 + if err != nil { 978 + return fmt.Errorf("failed to set previous root uid=%d: %w", account.ID, err) 979 + } 980 + } 981 + 982 + repoCommitsResultCounter.WithLabelValues(host.Host, "ok").Inc() 983 + 984 + // Broadcast the identity event to all consumers 985 + commitCopy := *evt 986 + err = bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ 987 + RepoCommit: &commitCopy, 988 + PrivUid: account.GetUid(), 989 + }) 990 + if err != nil { 991 + bgs.log.Error("failed to broadcast commit event", "error", err, "did", evt.Repo) 992 + return fmt.Errorf("failed to broadcast commit event: %w", err) 993 + } 994 + 995 + return nil 996 + } 997 + 998 + // handleSync processes #sync messages 999 + func (bgs *BGS) handleSync(ctx context.Context, host *models.PDS, evt *comatproto.SyncSubscribeRepos_Sync) error { 1000 + account, err := bgs.lookupUserByDid(ctx, evt.Did) 1001 + if err != nil { 1002 + if !errors.Is(err, gorm.ErrRecordNotFound) { 1003 + repoCommitsResultCounter.WithLabelValues(host.Host, "nou").Inc() 1004 + return fmt.Errorf("looking up event user: %w", err) 1005 + } 
1006 + 1007 + account, err = bgs.newUser(ctx, host, evt.Did) 1008 + } 1009 + if err != nil { 1010 + return fmt.Errorf("could not get user for did %#v: %w", evt.Did, err) 1011 + } 1012 + 1013 + newRootCid, err := bgs.validator.HandleSync(ctx, host, evt) 1014 + if err != nil { 1015 + return err 1016 + } 1017 + err = bgs.upsertPrevState(account.ID, newRootCid, evt.Rev, evt.Seq) 1018 + if err != nil { 1019 + return fmt.Errorf("could not sync set previous state uid=%d: %w", account.ID, err) 1020 + } 1021 + 1022 + // Broadcast the sync event to all consumers 1023 + evtCopy := *evt 1024 + err = bgs.events.AddEvent(ctx, &events.XRPCStreamEvent{ 1025 + RepoSync: &evtCopy, 1026 + }) 1027 + if err != nil { 1028 + bgs.log.Error("failed to broadcast sync event", "error", err, "did", evt.Did) 1029 + return fmt.Errorf("failed to broadcast sync event: %w", err) 1030 + } 1031 + 1032 + return nil 1033 + } 1034 + 1035 + func (bgs *BGS) upsertPrevState(accountID models.Uid, newRootCid *cid.Cid, rev string, seq int64) error { 1036 + cidBytes := newRootCid.Bytes() 1037 + return bgs.db.Exec( 1038 + "INSERT INTO account_previous_states (uid, cid, rev, seq) VALUES (?, ?, ?, ?) ON CONFLICT (uid) DO UPDATE SET cid = EXCLUDED.cid, rev = EXCLUDED.rev, seq = EXCLUDED.seq", 1039 + accountID, cidBytes, rev, seq, 1040 + ).Error 1041 + } 1042 + 1043 + func (bgs *BGS) purgeDidCache(ctx context.Context, did string) { 1044 + ati, err := syntax.ParseAtIdentifier(did) 1045 + if err != nil { 1046 + return 1047 + } 1048 + _ = bgs.didd.Purge(ctx, *ati) 1049 + } 1050 + 1051 + // syncPDSAccount ensures that a DID has an account record in the database attached to a PDS record in the database 1052 + // Some fields may be updated if needed. 
1053 + // did is the user 1054 + // host is the PDS we received this from, not necessarily the canonical PDS in the DID document 1055 + // cachedAccount is (optionally) the account that we have already looked up from cache or database 1056 + func (bgs *BGS) syncPDSAccount(ctx context.Context, did string, host *models.PDS, cachedAccount *Account) (*Account, error) { 1057 + ctx, span := tracer.Start(ctx, "syncPDSAccount") 1058 + defer span.End() 1059 + 1060 + externalUserCreationAttempts.Inc() 1061 + 1062 + bgs.log.Debug("create external user", "did", did) 1063 + 1064 + // lookup identity so that we know a DID's canonical source PDS 1065 + pdid, err := syntax.ParseDID(did) 1066 + if err != nil { 1067 + return nil, fmt.Errorf("bad did %#v, %w", did, err) 1068 + } 1069 + ident, err := bgs.didd.LookupDID(ctx, pdid) 1070 + if err != nil { 1071 + return nil, fmt.Errorf("no ident for did %s, %w", did, err) 1072 + } 1073 + if len(ident.Services) == 0 { 1074 + return nil, fmt.Errorf("no services for did %s", did) 1075 + } 1076 + pdsService, ok := ident.Services["atproto_pds"] 1077 + if !ok { 1078 + return nil, fmt.Errorf("no atproto_pds service for did %s", did) 1079 + } 1080 + durl, err := url.Parse(pdsService.URL) 1081 + if err != nil { 1082 + return nil, fmt.Errorf("pds bad url %#v, %w", pdsService.URL, err) 1083 + } 1084 + 1085 + // is the canonical PDS banned? 1086 + ban, err := bgs.domainIsBanned(ctx, durl.Host) 1087 + if err != nil { 1088 + return nil, fmt.Errorf("failed to check pds ban status: %w", err) 1089 + } 1090 + if ban { 1091 + return nil, fmt.Errorf("cannot create user on pds with banned domain") 1092 + } 1093 + 1094 + if strings.HasPrefix(durl.Host, "localhost:") { 1095 + durl.Scheme = "http" 1096 + } 1097 + 1098 + var canonicalHost *models.PDS 1099 + if host.Host == durl.Host { 1100 + // we got the message from the canonical PDS, convenient! 
1101 + canonicalHost = host 1102 + } else { 1103 + // we got the message from an intermediate relay 1104 + // check our db for info on canonical PDS 1105 + var peering models.PDS 1106 + if err := bgs.db.Find(&peering, "host = ?", durl.Host).Error; err != nil { 1107 + bgs.log.Error("failed to find pds", "host", durl.Host) 1108 + return nil, err 1109 + } 1110 + canonicalHost = &peering 1111 + } 1112 + 1113 + if canonicalHost.Blocked { 1114 + return nil, fmt.Errorf("refusing to create user with blocked PDS") 1115 + } 1116 + 1117 + if canonicalHost.ID == 0 { 1118 + // we got an event from a non-canonical PDS (an intermediate relay) 1119 + // a non-canonical PDS we haven't seen before; ping it to make sure it's real 1120 + // TODO: what do we actually want to track about the source we immediately got this message from vs the canonical PDS? 1121 + bgs.log.Warn("pds discovered in new user flow", "pds", durl.String(), "did", did) 1122 + 1123 + // Do a trivial API request against the PDS to verify that it exists 1124 + pclient := &xrpc.Client{Host: durl.String()} 1125 + bgs.config.ApplyPDSClientSettings(pclient) 1126 + cfg, err := comatproto.ServerDescribeServer(ctx, pclient) 1127 + if err != nil { 1128 + // TODO: failing this shouldn't halt our indexing 1129 + return nil, fmt.Errorf("failed to check unrecognized pds: %w", err) 1130 + } 1131 + 1132 + // since handles can be anything, checking against this list doesn't matter... 
1133 + _ = cfg 1134 + 1135 + // could check other things, a valid response is good enough for now 1136 + canonicalHost.Host = durl.Host 1137 + canonicalHost.SSL = (durl.Scheme == "https") 1138 + canonicalHost.RateLimit = float64(bgs.slurper.DefaultPerSecondLimit) 1139 + canonicalHost.HourlyEventLimit = bgs.slurper.DefaultPerHourLimit 1140 + canonicalHost.DailyEventLimit = bgs.slurper.DefaultPerDayLimit 1141 + canonicalHost.RepoLimit = bgs.slurper.DefaultRepoLimit 1142 + 1143 + if bgs.ssl && !canonicalHost.SSL { 1144 + return nil, fmt.Errorf("did references non-ssl PDS, this is disallowed in prod: %q %q", did, pdsService.URL) 1145 + } 1146 + 1147 + if err := bgs.db.Create(&canonicalHost).Error; err != nil { 1148 + return nil, err 1149 + } 1150 + } 1151 + 1152 + if canonicalHost.ID == 0 { 1153 + panic("somehow failed to create a pds entry?") 1154 + } 1155 + 1156 + if canonicalHost.RepoCount >= canonicalHost.RepoLimit { 1157 + // TODO: soft-limit / hard-limit ? create account in 'throttled' state, unless there are _really_ too many accounts 1158 + return nil, fmt.Errorf("refusing to create user on PDS at max repo limit for pds %q", canonicalHost.Host) 1159 + } 1160 + 1161 + // this lock just governs the lower half of this function 1162 + bgs.extUserLk.Lock() 1163 + defer bgs.extUserLk.Unlock() 1164 + 1165 + if cachedAccount == nil { 1166 + cachedAccount, err = bgs.lookupUserByDid(ctx, did) 1167 + } 1168 + if errors.Is(err, ErrNotFound) || errors.Is(err, gorm.ErrRecordNotFound) { 1169 + err = nil 1170 + } 1171 + if err != nil { 1172 + return nil, err 1173 + } 1174 + if cachedAccount != nil { 1175 + caPDS := cachedAccount.GetPDS() 1176 + if caPDS != canonicalHost.ID { 1177 + // Account is now on a different PDS, update 1178 + err = bgs.db.Transaction(func(tx *gorm.DB) error { 1179 + if caPDS != 0 { 1180 + // decrement prior PDS's account count 1181 + tx.Model(&models.PDS{}).Where("id = ?", caPDS).Update("repo_count", gorm.Expr("repo_count - 1")) 1182 + } 1183 + // 
update user's PDS ID 1184 + res := tx.Model(Account{}).Where("id = ?", cachedAccount.ID).Update("pds", canonicalHost.ID) 1185 + if res.Error != nil { 1186 + return fmt.Errorf("failed to update users pds: %w", res.Error) 1187 + } 1188 + // increment new PDS's account count 1189 + res = tx.Model(&models.PDS{}).Where("id = ? AND repo_count < repo_limit", canonicalHost.ID).Update("repo_count", gorm.Expr("repo_count + 1")) 1190 + return nil 1191 + }) 1192 + 1193 + cachedAccount.SetPDS(canonicalHost.ID) 1194 + } 1195 + return cachedAccount, nil 1196 + } 1197 + 1198 + newAccount := Account{ 1199 + Did: did, 1200 + PDS: canonicalHost.ID, 1201 + } 1202 + 1203 + err = bgs.db.Transaction(func(tx *gorm.DB) error { 1204 + res := tx.Model(&models.PDS{}).Where("id = ? AND repo_count < repo_limit", canonicalHost.ID).Update("repo_count", gorm.Expr("repo_count + 1")) 1205 + if res.Error != nil { 1206 + return fmt.Errorf("failed to increment repo count for pds %q: %w", canonicalHost.Host, res.Error) 1207 + } 1208 + if terr := bgs.db.Create(&newAccount).Error; terr != nil { 1209 + bgs.log.Error("failed to create user", "did", newAccount.Did, "err", terr) 1210 + return fmt.Errorf("failed to create other pds user: %w", terr) 1211 + } 1212 + return nil 1213 + }) 1214 + if err != nil { 1215 + bgs.log.Error("user create and pds inc err", "err", err) 1216 + return nil, err 1217 + } 1218 + 1219 + bgs.userCache.Add(did, &newAccount) 1220 + 1221 + return &newAccount, nil 1222 + } 1223 + 1224 + func (bgs *BGS) TakeDownRepo(ctx context.Context, did string) error { 1225 + u, err := bgs.lookupUserByDid(ctx, did) 1226 + if err != nil { 1227 + return err 1228 + } 1229 + 1230 + if err := bgs.db.Model(Account{}).Where("id = ?", u.ID).Update("taken_down", true).Error; err != nil { 1231 + return err 1232 + } 1233 + u.SetTakenDown(true) 1234 + 1235 + if err := bgs.events.TakeDownRepo(ctx, u.ID); err != nil { 1236 + return err 1237 + } 1238 + 1239 + return nil 1240 + } 1241 + 1242 + func (bgs *BGS) 
ReverseTakedown(ctx context.Context, did string) error { 1243 + u, err := bgs.lookupUserByDid(ctx, did) 1244 + if err != nil { 1245 + return err 1246 + } 1247 + 1248 + if err := bgs.db.Model(Account{}).Where("id = ?", u.ID).Update("taken_down", false).Error; err != nil { 1249 + return err 1250 + } 1251 + u.SetTakenDown(false) 1252 + 1253 + return nil 1254 + } 1255 + 1256 + func (bgs *BGS) GetRepoRoot(ctx context.Context, user models.Uid) (cid.Cid, error) { 1257 + var prevState AccountPreviousState 1258 + err := bgs.db.First(&prevState, user).Error 1259 + if err == nil { 1260 + return prevState.Cid.CID, nil 1261 + } else if errors.Is(err, gorm.ErrRecordNotFound) { 1262 + return cid.Cid{}, ErrUserStatusUnavailable 1263 + } else { 1264 + bgs.log.Error("user db err", "err", err) 1265 + return cid.Cid{}, fmt.Errorf("user prev db err, %w", err) 1266 + } 1267 + }
+774
cmd/relay/bgs/fedmgr.go
··· 1 + package bgs 2 + 3 + import ( 4 + "context" 5 + "errors" 6 + "fmt" 7 + "log/slog" 8 + "math/rand" 9 + "strings" 10 + "sync" 11 + "time" 12 + 13 + "github.com/RussellLuo/slidingwindow" 14 + comatproto "github.com/bluesky-social/indigo/api/atproto" 15 + "github.com/bluesky-social/indigo/cmd/relay/events" 16 + "github.com/bluesky-social/indigo/cmd/relay/events/schedulers/parallel" 17 + "github.com/bluesky-social/indigo/cmd/relay/models" 18 + 19 + "github.com/gorilla/websocket" 20 + pq "github.com/lib/pq" 21 + "gorm.io/gorm" 22 + ) 23 + 24 + type IndexCallback func(context.Context, *models.PDS, *events.XRPCStreamEvent) error 25 + 26 + type Slurper struct { 27 + cb IndexCallback 28 + 29 + db *gorm.DB 30 + 31 + lk sync.Mutex 32 + active map[string]*activeSub 33 + 34 + LimitMux sync.RWMutex 35 + Limiters map[uint]*Limiters 36 + DefaultPerSecondLimit int64 37 + DefaultPerHourLimit int64 38 + DefaultPerDayLimit int64 39 + 40 + DefaultRepoLimit int64 41 + ConcurrencyPerPDS int64 42 + MaxQueuePerPDS int64 43 + 44 + NewPDSPerDayLimiter *slidingwindow.Limiter 45 + 46 + newSubsDisabled bool 47 + trustedDomains []string 48 + 49 + shutdownChan chan bool 50 + shutdownResult chan []error 51 + 52 + ssl bool 53 + 54 + log *slog.Logger 55 + } 56 + 57 + type Limiters struct { 58 + PerSecond *slidingwindow.Limiter 59 + PerHour *slidingwindow.Limiter 60 + PerDay *slidingwindow.Limiter 61 + } 62 + 63 + type SlurperOptions struct { 64 + SSL bool 65 + DefaultPerSecondLimit int64 66 + DefaultPerHourLimit int64 67 + DefaultPerDayLimit int64 68 + DefaultRepoLimit int64 69 + ConcurrencyPerPDS int64 70 + MaxQueuePerPDS int64 71 + 72 + Logger *slog.Logger 73 + } 74 + 75 + func DefaultSlurperOptions() *SlurperOptions { 76 + return &SlurperOptions{ 77 + SSL: false, 78 + DefaultPerSecondLimit: 50, 79 + DefaultPerHourLimit: 2500, 80 + DefaultPerDayLimit: 20_000, 81 + DefaultRepoLimit: 100, 82 + ConcurrencyPerPDS: 100, 83 + MaxQueuePerPDS: 1_000, 84 + 85 + Logger: slog.Default(), 86 + } 87 + } 
88 + 89 + type activeSub struct { 90 + pds *models.PDS 91 + lk sync.RWMutex 92 + ctx context.Context 93 + cancel func() 94 + } 95 + 96 + func (sub *activeSub) updateCursor(curs int64) { 97 + sub.lk.Lock() 98 + defer sub.lk.Unlock() 99 + sub.pds.Cursor = curs 100 + } 101 + 102 + func NewSlurper(db *gorm.DB, cb IndexCallback, opts *SlurperOptions) (*Slurper, error) { 103 + if opts == nil { 104 + opts = DefaultSlurperOptions() 105 + } 106 + err := db.AutoMigrate(&SlurpConfig{}) 107 + if err != nil { 108 + return nil, err 109 + } 110 + s := &Slurper{ 111 + cb: cb, 112 + db: db, 113 + active: make(map[string]*activeSub), 114 + Limiters: make(map[uint]*Limiters), 115 + DefaultPerSecondLimit: opts.DefaultPerSecondLimit, 116 + DefaultPerHourLimit: opts.DefaultPerHourLimit, 117 + DefaultPerDayLimit: opts.DefaultPerDayLimit, 118 + DefaultRepoLimit: opts.DefaultRepoLimit, 119 + ConcurrencyPerPDS: opts.ConcurrencyPerPDS, 120 + MaxQueuePerPDS: opts.MaxQueuePerPDS, 121 + ssl: opts.SSL, 122 + shutdownChan: make(chan bool), 123 + shutdownResult: make(chan []error), 124 + log: opts.Logger, 125 + } 126 + if err := s.loadConfig(); err != nil { 127 + return nil, err 128 + } 129 + 130 + // Start a goroutine to flush cursors to the DB every 30s 131 + go func() { 132 + for { 133 + select { 134 + case <-s.shutdownChan: 135 + s.log.Info("flushing PDS cursors on shutdown") 136 + ctx := context.Background() 137 + var errs []error 138 + if errs = s.flushCursors(ctx); len(errs) > 0 { 139 + for _, err := range errs { 140 + s.log.Error("failed to flush cursors on shutdown", "err", err) 141 + } 142 + } 143 + s.log.Info("done flushing PDS cursors on shutdown") 144 + s.shutdownResult <- errs 145 + return 146 + case <-time.After(time.Second * 10): 147 + s.log.Debug("flushing PDS cursors") 148 + ctx := context.Background() 149 + if errs := s.flushCursors(ctx); len(errs) > 0 { 150 + for _, err := range errs { 151 + s.log.Error("failed to flush cursors", "err", err) 152 + } 153 + } 154 + 
s.log.Debug("done flushing PDS cursors") 155 + } 156 + } 157 + }() 158 + 159 + return s, nil 160 + } 161 + 162 + func windowFunc() (slidingwindow.Window, slidingwindow.StopFunc) { 163 + return slidingwindow.NewLocalWindow() 164 + } 165 + 166 + func (s *Slurper) GetLimiters(pdsID uint) *Limiters { 167 + s.LimitMux.RLock() 168 + defer s.LimitMux.RUnlock() 169 + return s.Limiters[pdsID] 170 + } 171 + 172 + func (s *Slurper) GetOrCreateLimiters(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) *Limiters { 173 + s.LimitMux.RLock() 174 + defer s.LimitMux.RUnlock() 175 + lim, ok := s.Limiters[pdsID] 176 + if !ok { 177 + perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc) 178 + perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc) 179 + perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc) 180 + lim = &Limiters{ 181 + PerSecond: perSec, 182 + PerHour: perHour, 183 + PerDay: perDay, 184 + } 185 + s.Limiters[pdsID] = lim 186 + } 187 + 188 + return lim 189 + } 190 + 191 + func (s *Slurper) SetLimits(pdsID uint, perSecLimit int64, perHourLimit int64, perDayLimit int64) { 192 + s.LimitMux.Lock() 193 + defer s.LimitMux.Unlock() 194 + lim, ok := s.Limiters[pdsID] 195 + if !ok { 196 + perSec, _ := slidingwindow.NewLimiter(time.Second, perSecLimit, windowFunc) 197 + perHour, _ := slidingwindow.NewLimiter(time.Hour, perHourLimit, windowFunc) 198 + perDay, _ := slidingwindow.NewLimiter(time.Hour*24, perDayLimit, windowFunc) 199 + lim = &Limiters{ 200 + PerSecond: perSec, 201 + PerHour: perHour, 202 + PerDay: perDay, 203 + } 204 + s.Limiters[pdsID] = lim 205 + } 206 + 207 + lim.PerSecond.SetLimit(perSecLimit) 208 + lim.PerHour.SetLimit(perHourLimit) 209 + lim.PerDay.SetLimit(perDayLimit) 210 + } 211 + 212 + // Shutdown shuts down the slurper 213 + func (s *Slurper) Shutdown() []error { 214 + s.shutdownChan <- true 215 + s.log.Info("waiting for slurper shutdown") 216 + errs := <-s.shutdownResult 
217 + if len(errs) > 0 { 218 + for _, err := range errs { 219 + s.log.Error("shutdown error", "err", err) 220 + } 221 + } 222 + s.log.Info("slurper shutdown complete") 223 + return errs 224 + } 225 + 226 + func (s *Slurper) loadConfig() error { 227 + var sc SlurpConfig 228 + if err := s.db.Find(&sc).Error; err != nil { 229 + return err 230 + } 231 + 232 + if sc.ID == 0 { 233 + if err := s.db.Create(&SlurpConfig{}).Error; err != nil { 234 + return err 235 + } 236 + } 237 + 238 + s.newSubsDisabled = sc.NewSubsDisabled 239 + s.trustedDomains = sc.TrustedDomains 240 + 241 + s.NewPDSPerDayLimiter, _ = slidingwindow.NewLimiter(time.Hour*24, sc.NewPDSPerDayLimit, windowFunc) 242 + 243 + return nil 244 + } 245 + 246 + type SlurpConfig struct { 247 + gorm.Model 248 + 249 + NewSubsDisabled bool 250 + TrustedDomains pq.StringArray `gorm:"type:text[]"` 251 + NewPDSPerDayLimit int64 252 + } 253 + 254 + func (s *Slurper) SetNewSubsDisabled(dis bool) error { 255 + s.lk.Lock() 256 + defer s.lk.Unlock() 257 + 258 + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("new_subs_disabled", dis).Error; err != nil { 259 + return err 260 + } 261 + 262 + s.newSubsDisabled = dis 263 + return nil 264 + } 265 + 266 + func (s *Slurper) GetNewSubsDisabledState() bool { 267 + s.lk.Lock() 268 + defer s.lk.Unlock() 269 + return s.newSubsDisabled 270 + } 271 + 272 + func (s *Slurper) SetNewPDSPerDayLimit(limit int64) error { 273 + s.lk.Lock() 274 + defer s.lk.Unlock() 275 + 276 + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("new_pds_per_day_limit", limit).Error; err != nil { 277 + return err 278 + } 279 + 280 + s.NewPDSPerDayLimiter.SetLimit(limit) 281 + return nil 282 + } 283 + 284 + func (s *Slurper) GetNewPDSPerDayLimit() int64 { 285 + s.lk.Lock() 286 + defer s.lk.Unlock() 287 + return s.NewPDSPerDayLimiter.Limit() 288 + } 289 + 290 + func (s *Slurper) AddTrustedDomain(domain string) error { 291 + s.lk.Lock() 292 + defer s.lk.Unlock() 293 + 294 + if err := 
s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", gorm.Expr("array_append(trusted_domains, ?)", domain)).Error; err != nil { 295 + return err 296 + } 297 + 298 + s.trustedDomains = append(s.trustedDomains, domain) 299 + return nil 300 + } 301 + 302 + func (s *Slurper) RemoveTrustedDomain(domain string) error { 303 + s.lk.Lock() 304 + defer s.lk.Unlock() 305 + 306 + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", gorm.Expr("array_remove(trusted_domains, ?)", domain)).Error; err != nil { 307 + if errors.Is(err, gorm.ErrRecordNotFound) { 308 + return nil 309 + } 310 + return err 311 + } 312 + 313 + for i, d := range s.trustedDomains { 314 + if d == domain { 315 + s.trustedDomains = append(s.trustedDomains[:i], s.trustedDomains[i+1:]...) 316 + break 317 + } 318 + } 319 + 320 + return nil 321 + } 322 + 323 + func (s *Slurper) SetTrustedDomains(domains []string) error { 324 + s.lk.Lock() 325 + defer s.lk.Unlock() 326 + 327 + if err := s.db.Model(SlurpConfig{}).Where("id = 1").Update("trusted_domains", domains).Error; err != nil { 328 + return err 329 + } 330 + 331 + s.trustedDomains = domains 332 + return nil 333 + } 334 + 335 + func (s *Slurper) GetTrustedDomains() []string { 336 + s.lk.Lock() 337 + defer s.lk.Unlock() 338 + return s.trustedDomains 339 + } 340 + 341 + var ErrNewSubsDisabled = fmt.Errorf("new subscriptions temporarily disabled") 342 + 343 + // Checks whether a host is allowed to be subscribed to 344 + // must be called with the slurper lock held 345 + func (s *Slurper) canSlurpHost(host string) bool { 346 + // Check if we're over the limit for new PDSs today 347 + if !s.NewPDSPerDayLimiter.Allow() { 348 + return false 349 + } 350 + 351 + // Check if the host is a trusted domain 352 + for _, d := range s.trustedDomains { 353 + // If the domain starts with a *., it's a wildcard 354 + if strings.HasPrefix(d, "*.") { 355 + // Cut off the * so we have .domain.com 356 + if strings.HasSuffix(host, 
strings.TrimPrefix(d, "*")) { 357 + return true 358 + } 359 + } else { 360 + if host == d { 361 + return true 362 + } 363 + } 364 + } 365 + 366 + return !s.newSubsDisabled 367 + } 368 + 369 + func (s *Slurper) SubscribeToPds(ctx context.Context, host string, reg bool, adminOverride bool, rateOverrides *PDSRates) error { 370 + // TODO: for performance, lock on the hostname instead of global 371 + s.lk.Lock() 372 + defer s.lk.Unlock() 373 + 374 + _, ok := s.active[host] 375 + if ok { 376 + return nil 377 + } 378 + 379 + var peering models.PDS 380 + if err := s.db.Find(&peering, "host = ?", host).Error; err != nil { 381 + return err 382 + } 383 + 384 + if peering.Blocked { 385 + return fmt.Errorf("cannot subscribe to blocked pds") 386 + } 387 + 388 + newHost := false 389 + 390 + if peering.ID == 0 { 391 + if !adminOverride && !s.canSlurpHost(host) { 392 + return ErrNewSubsDisabled 393 + } 394 + // New PDS! 395 + npds := models.PDS{ 396 + Host: host, 397 + SSL: s.ssl, 398 + Registered: reg, 399 + RateLimit: float64(s.DefaultPerSecondLimit), 400 + HourlyEventLimit: s.DefaultPerHourLimit, 401 + DailyEventLimit: s.DefaultPerDayLimit, 402 + RepoLimit: s.DefaultRepoLimit, 403 + } 404 + if rateOverrides != nil { 405 + npds.RateLimit = float64(rateOverrides.PerSecond) 406 + npds.HourlyEventLimit = rateOverrides.PerHour 407 + npds.DailyEventLimit = rateOverrides.PerDay 408 + npds.RepoLimit = rateOverrides.RepoLimit 409 + } 410 + if err := s.db.Create(&npds).Error; err != nil { 411 + return err 412 + } 413 + 414 + newHost = true 415 + peering = npds 416 + } 417 + 418 + if !peering.Registered && reg { 419 + peering.Registered = true 420 + if err := s.db.Model(models.PDS{}).Where("id = ?", peering.ID).Update("registered", true).Error; err != nil { 421 + return err 422 + } 423 + } 424 + 425 + ctx, cancel := context.WithCancel(context.Background()) 426 + sub := activeSub{ 427 + pds: &peering, 428 + ctx: ctx, 429 + cancel: cancel, 430 + } 431 + s.active[host] = &sub 432 + 433 + 
s.GetOrCreateLimiters(peering.ID, int64(peering.RateLimit), peering.HourlyEventLimit, peering.DailyEventLimit) 434 + 435 + go s.subscribeWithRedialer(ctx, &peering, &sub, newHost) 436 + 437 + return nil 438 + } 439 + 440 + func (s *Slurper) RestartAll() error { 441 + s.lk.Lock() 442 + defer s.lk.Unlock() 443 + 444 + var all []models.PDS 445 + if err := s.db.Find(&all, "registered = true AND blocked = false").Error; err != nil { 446 + return err 447 + } 448 + 449 + for _, pds := range all { 450 + pds := pds 451 + 452 + ctx, cancel := context.WithCancel(context.Background()) 453 + sub := activeSub{ 454 + pds: &pds, 455 + ctx: ctx, 456 + cancel: cancel, 457 + } 458 + s.active[pds.Host] = &sub 459 + 460 + // Check if we've already got a limiter for this PDS 461 + s.GetOrCreateLimiters(pds.ID, int64(pds.RateLimit), pds.HourlyEventLimit, pds.DailyEventLimit) 462 + go s.subscribeWithRedialer(ctx, &pds, &sub, false) 463 + } 464 + 465 + return nil 466 + } 467 + 468 + func (s *Slurper) subscribeWithRedialer(ctx context.Context, host *models.PDS, sub *activeSub, newHost bool) { 469 + defer func() { 470 + s.lk.Lock() 471 + defer s.lk.Unlock() 472 + 473 + delete(s.active, host.Host) 474 + }() 475 + 476 + d := websocket.Dialer{ 477 + HandshakeTimeout: time.Second * 5, 478 + } 479 + 480 + protocol := "ws" 481 + if s.ssl { 482 + protocol = "wss" 483 + } 484 + 485 + // Special case `.host.bsky.network` PDSs to rewind cursor by 200 events to smooth over unclean shutdowns 486 + if strings.HasSuffix(host.Host, ".host.bsky.network") && host.Cursor > 200 { 487 + host.Cursor -= 200 488 + } 489 + 490 + cursor := host.Cursor 491 + 492 + connectedInbound.Inc() 493 + defer connectedInbound.Dec() 494 + // TODO:? maybe keep a gauge of 'in retry backoff' sources? 
495 + 496 + var backoff int 497 + for { 498 + select { 499 + case <-ctx.Done(): 500 + return 501 + default: 502 + } 503 + 504 + var url string 505 + if newHost { 506 + url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos", protocol, host.Host) 507 + } else { 508 + url = fmt.Sprintf("%s://%s/xrpc/com.atproto.sync.subscribeRepos?cursor=%d", protocol, host.Host, cursor) 509 + } 510 + con, res, err := d.DialContext(ctx, url, nil) 511 + if err != nil { 512 + s.log.Warn("dialing failed", "pdsHost", host.Host, "err", err, "backoff", backoff) 513 + time.Sleep(sleepForBackoff(backoff)) 514 + backoff++ 515 + 516 + if backoff > 15 { 517 + s.log.Warn("pds does not appear to be online, disabling for now", "pdsHost", host.Host) 518 + if err := s.db.Model(&models.PDS{}).Where("id = ?", host.ID).Update("registered", false).Error; err != nil { 519 + s.log.Error("failed to unregister failing pds", "err", err) 520 + } 521 + 522 + return 523 + } 524 + 525 + continue 526 + } 527 + 528 + s.log.Info("event subscription response", "code", res.StatusCode, "url", url) 529 + 530 + curCursor := cursor 531 + if err := s.handleConnection(ctx, host, con, &cursor, sub); err != nil { 532 + if errors.Is(err, ErrTimeoutShutdown) { 533 + s.log.Info("shutting down pds subscription after timeout", "host", host.Host, "time", EventsTimeout) 534 + return 535 + } 536 + s.log.Warn("connection to failed", "host", host.Host, "err", err) 537 + // TODO: measure the last N connection error times and if they're coming too fast reconnect slower or don't reconnect and wait for requestCrawl 538 + } 539 + 540 + if cursor > curCursor { 541 + backoff = 0 542 + } 543 + } 544 + } 545 + 546 + func sleepForBackoff(b int) time.Duration { 547 + if b == 0 { 548 + return 0 549 + } 550 + 551 + if b < 10 { 552 + return (time.Duration(b) * 2) + (time.Millisecond * time.Duration(rand.Intn(1000))) 553 + } 554 + 555 + return time.Second * 30 556 + } 557 + 558 + var ErrTimeoutShutdown = fmt.Errorf("timed out waiting for 
new events") 559 + 560 + var EventsTimeout = time.Minute 561 + 562 + func (s *Slurper) handleConnection(ctx context.Context, host *models.PDS, con *websocket.Conn, lastCursor *int64, sub *activeSub) error { 563 + ctx, cancel := context.WithCancel(ctx) 564 + defer cancel() 565 + 566 + rsc := &events.RepoStreamCallbacks{ 567 + RepoCommit: func(evt *comatproto.SyncSubscribeRepos_Commit) error { 568 + s.log.Debug("got remote repo event", "pdsHost", host.Host, "repo", evt.Repo, "seq", evt.Seq) 569 + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ 570 + RepoCommit: evt, 571 + }); err != nil { 572 + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 573 + } 574 + *lastCursor = evt.Seq 575 + 576 + sub.updateCursor(*lastCursor) 577 + 578 + return nil 579 + }, 580 + RepoSync: func(evt *comatproto.SyncSubscribeRepos_Sync) error { 581 + s.log.Debug("got remote repo event", "pdsHost", host.Host, "repo", evt.Did, "seq", evt.Seq) 582 + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ 583 + RepoSync: evt, 584 + }); err != nil { 585 + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 586 + } 587 + *lastCursor = evt.Seq 588 + 589 + sub.updateCursor(*lastCursor) 590 + 591 + return nil 592 + }, 593 + RepoHandle: func(evt *comatproto.SyncSubscribeRepos_Handle) error { 594 + s.log.Debug("got remote handle update event", "pdsHost", host.Host, "did", evt.Did, "handle", evt.Handle) 595 + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ 596 + RepoHandle: evt, 597 + }); err != nil { 598 + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 599 + } 600 + *lastCursor = evt.Seq 601 + 602 + sub.updateCursor(*lastCursor) 603 + 604 + return nil 605 + }, 606 + RepoMigrate: func(evt *comatproto.SyncSubscribeRepos_Migrate) error { 607 + s.log.Debug("got remote repo migrate event", "pdsHost", host.Host, "did", evt.Did, "migrateTo", evt.MigrateTo) 608 + if err := 
s.cb(context.TODO(), host, &events.XRPCStreamEvent{ 609 + RepoMigrate: evt, 610 + }); err != nil { 611 + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 612 + } 613 + *lastCursor = evt.Seq 614 + 615 + sub.updateCursor(*lastCursor) 616 + 617 + return nil 618 + }, 619 + RepoTombstone: func(evt *comatproto.SyncSubscribeRepos_Tombstone) error { 620 + s.log.Debug("got remote repo tombstone event", "pdsHost", host.Host, "did", evt.Did) 621 + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ 622 + RepoTombstone: evt, 623 + }); err != nil { 624 + s.log.Error("failed handling event", "host", host.Host, "seq", evt.Seq, "err", err) 625 + } 626 + *lastCursor = evt.Seq 627 + 628 + sub.updateCursor(*lastCursor) 629 + 630 + return nil 631 + }, 632 + RepoInfo: func(info *comatproto.SyncSubscribeRepos_Info) error { 633 + s.log.Debug("info event", "name", info.Name, "message", info.Message, "pdsHost", host.Host) 634 + return nil 635 + }, 636 + RepoIdentity: func(ident *comatproto.SyncSubscribeRepos_Identity) error { 637 + s.log.Debug("identity event", "did", ident.Did) 638 + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ 639 + RepoIdentity: ident, 640 + }); err != nil { 641 + s.log.Error("failed handling event", "host", host.Host, "seq", ident.Seq, "err", err) 642 + } 643 + *lastCursor = ident.Seq 644 + 645 + sub.updateCursor(*lastCursor) 646 + 647 + return nil 648 + }, 649 + RepoAccount: func(acct *comatproto.SyncSubscribeRepos_Account) error { 650 + s.log.Debug("account event", "did", acct.Did, "status", acct.Status) 651 + if err := s.cb(context.TODO(), host, &events.XRPCStreamEvent{ 652 + RepoAccount: acct, 653 + }); err != nil { 654 + s.log.Error("failed handling event", "host", host.Host, "seq", acct.Seq, "err", err) 655 + } 656 + *lastCursor = acct.Seq 657 + 658 + sub.updateCursor(*lastCursor) 659 + 660 + return nil 661 + }, 662 + // TODO: all the other event types (handle change, migration, etc) 663 + Error: 
func(errf *events.ErrorFrame) error { 664 + switch errf.Error { 665 + case "FutureCursor": 666 + // if we get a FutureCursor frame, reset our sequence number for this host 667 + if err := s.db.Table("pds").Where("id = ?", host.ID).Update("cursor", 0).Error; err != nil { 668 + return err 669 + } 670 + 671 + *lastCursor = 0 672 + return fmt.Errorf("got FutureCursor frame, reset cursor tracking for host") 673 + default: 674 + return fmt.Errorf("error frame: %s: %s", errf.Error, errf.Message) 675 + } 676 + }, 677 + } 678 + 679 + lims := s.GetOrCreateLimiters(host.ID, int64(host.RateLimit), host.HourlyEventLimit, host.DailyEventLimit) 680 + 681 + limiters := []*slidingwindow.Limiter{ 682 + lims.PerSecond, 683 + lims.PerHour, 684 + lims.PerDay, 685 + } 686 + 687 + instrumentedRSC := events.NewInstrumentedRepoStreamCallbacks(limiters, rsc.EventHandler) 688 + 689 + pool := parallel.NewScheduler( 690 + 100, 691 + 1_000, 692 + con.RemoteAddr().String(), 693 + instrumentedRSC.EventHandler, 694 + ) 695 + return events.HandleRepoStream(ctx, con, pool, nil) 696 + } 697 + 698 + type cursorSnapshot struct { 699 + id uint 700 + cursor int64 701 + } 702 + 703 + // flushCursors updates the PDS cursors in the DB for all active subscriptions 704 + func (s *Slurper) flushCursors(ctx context.Context) []error { 705 + start := time.Now() 706 + //ctx, span := otel.Tracer("feedmgr").Start(ctx, "flushCursors") 707 + //defer span.End() 708 + 709 + var cursors []cursorSnapshot 710 + 711 + s.lk.Lock() 712 + // Iterate over active subs and copy the current cursor 713 + for _, sub := range s.active { 714 + sub.lk.RLock() 715 + cursors = append(cursors, cursorSnapshot{ 716 + id: sub.pds.ID, 717 + cursor: sub.pds.Cursor, 718 + }) 719 + sub.lk.RUnlock() 720 + } 721 + s.lk.Unlock() 722 + 723 + errs := []error{} 724 + okcount := 0 725 + 726 + tx := s.db.WithContext(ctx).Begin() 727 + for _, cursor := range cursors { 728 + if err := tx.WithContext(ctx).Model(models.PDS{}).Where("id = ?", 
cursor.id).UpdateColumn("cursor", cursor.cursor).Error; err != nil { 729 + errs = append(errs, err) 730 + } else { 731 + okcount++ 732 + } 733 + } 734 + if err := tx.WithContext(ctx).Commit().Error; err != nil { 735 + errs = append(errs, err) 736 + } 737 + dt := time.Since(start) 738 + s.log.Info("flushCursors", "dt", dt, "ok", okcount, "errs", len(errs)) 739 + 740 + return errs 741 + } 742 + 743 + func (s *Slurper) GetActiveList() []string { 744 + s.lk.Lock() 745 + defer s.lk.Unlock() 746 + var out []string 747 + for k := range s.active { 748 + out = append(out, k) 749 + } 750 + 751 + return out 752 + } 753 + 754 + var ErrNoActiveConnection = fmt.Errorf("no active connection to host") 755 + 756 + func (s *Slurper) KillUpstreamConnection(host string, block bool) error { 757 + s.lk.Lock() 758 + defer s.lk.Unlock() 759 + 760 + ac, ok := s.active[host] 761 + if !ok { 762 + return fmt.Errorf("killing connection %q: %w", host, ErrNoActiveConnection) 763 + } 764 + ac.cancel() 765 + // cleanup in the run thread subscribeWithRedialer() will delete(s.active, host) 766 + 767 + if block { 768 + if err := s.db.Model(models.PDS{}).Where("id = ?", ac.pds.ID).UpdateColumn("blocked", true).Error; err != nil { 769 + return fmt.Errorf("failed to set host as blocked: %w", err) 770 + } 771 + } 772 + 773 + return nil 774 + }
+199
cmd/relay/bgs/handlers.go
··· 1 + package bgs 2 + 3 + import ( 4 + "bytes" 5 + "context" 6 + "encoding/json" 7 + "errors" 8 + "fmt" 9 + "net/http" 10 + "net/url" 11 + "strings" 12 + 13 + atproto "github.com/bluesky-social/indigo/api/atproto" 14 + comatprototypes "github.com/bluesky-social/indigo/api/atproto" 15 + "github.com/bluesky-social/indigo/cmd/relay/events" 16 + "gorm.io/gorm" 17 + 18 + "github.com/bluesky-social/indigo/xrpc" 19 + "github.com/labstack/echo/v4" 20 + ) 21 + 22 + func (s *BGS) handleComAtprotoSyncRequestCrawl(ctx context.Context, body *comatprototypes.SyncRequestCrawl_Input) error { 23 + host := body.Hostname 24 + if host == "" { 25 + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname") 26 + } 27 + 28 + if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") { 29 + if s.ssl { 30 + host = "https://" + host 31 + } else { 32 + host = "http://" + host 33 + } 34 + } 35 + 36 + u, err := url.Parse(host) 37 + if err != nil { 38 + return echo.NewHTTPError(http.StatusBadRequest, "failed to parse hostname") 39 + } 40 + 41 + if u.Scheme == "http" && s.ssl { 42 + return echo.NewHTTPError(http.StatusBadRequest, "this server requires https") 43 + } 44 + 45 + if u.Scheme == "https" && !s.ssl { 46 + return echo.NewHTTPError(http.StatusBadRequest, "this server does not support https") 47 + } 48 + 49 + if u.Path != "" { 50 + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without path") 51 + } 52 + 53 + if u.Query().Encode() != "" { 54 + return echo.NewHTTPError(http.StatusBadRequest, "must pass hostname without query") 55 + } 56 + 57 + host = u.Host // potentially hostname:port 58 + 59 + banned, err := s.domainIsBanned(ctx, host) 60 + if banned { 61 + return echo.NewHTTPError(http.StatusUnauthorized, "domain is banned") 62 + } 63 + 64 + s.log.Warn("TODO: better host validation for crawl requests") 65 + 66 + clientHost := fmt.Sprintf("%s://%s", u.Scheme, host) 67 + 68 + c := &xrpc.Client{ 69 + Host: clientHost, 70 + Client: 
http.DefaultClient, // not using the client that auto-retries 71 + } 72 + 73 + desc, err := atproto.ServerDescribeServer(ctx, c) 74 + if err != nil { 75 + errMsg := fmt.Sprintf("requested host (%s) failed to respond to describe request", clientHost) 76 + return echo.NewHTTPError(http.StatusBadRequest, errMsg) 77 + } 78 + 79 + // Maybe we could do something with this response later 80 + _ = desc 81 + 82 + if len(s.nextCrawlers) != 0 { 83 + blob, err := json.Marshal(body) 84 + if err != nil { 85 + s.log.Warn("could not forward requestCrawl, json err", "err", err) 86 + } else { 87 + go func(bodyBlob []byte) { 88 + for _, rpu := range s.nextCrawlers { 89 + pu := rpu.JoinPath("/xrpc/com.atproto.sync.requestCrawl") 90 + response, err := s.httpClient.Post(pu.String(), "application/json", bytes.NewReader(bodyBlob)) 91 + if response != nil && response.Body != nil { 92 + response.Body.Close() 93 + } 94 + if err != nil || response == nil { 95 + s.log.Warn("requestCrawl forward failed", "host", rpu, "err", err) 96 + } else if response.StatusCode != http.StatusOK { 97 + s.log.Warn("requestCrawl forward failed", "host", rpu, "status", response.Status) 98 + } else { 99 + s.log.Info("requestCrawl forward successful", "host", rpu) 100 + } 101 + } 102 + }(blob) 103 + } 104 + } 105 + 106 + return s.slurper.SubscribeToPds(ctx, host, true, false, nil) 107 + } 108 + 109 + func (s *BGS) handleComAtprotoSyncListRepos(ctx context.Context, cursor int64, limit int) (*comatprototypes.SyncListRepos_Output, error) { 110 + // Load the accounts 111 + accounts := []*Account{} 112 + if err := s.db.Model(&Account{}).Where("id > ? 
AND NOT taken_down AND (upstream_status IS NULL OR upstream_status = 'active')", cursor).Order("id").Limit(limit).Find(&accounts).Error; err != nil { 113 + if err == gorm.ErrRecordNotFound { 114 + return &comatprototypes.SyncListRepos_Output{}, nil 115 + } 116 + s.log.Error("failed to query accounts", "err", err) 117 + return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to query accounts") 118 + } 119 + 120 + if len(accounts) == 0 { 121 + // resp.Repos is an explicit empty array, not just 'nil' 122 + return &comatprototypes.SyncListRepos_Output{ 123 + Repos: []*comatprototypes.SyncListRepos_Repo{}, 124 + }, nil 125 + } 126 + 127 + resp := &comatprototypes.SyncListRepos_Output{ 128 + Repos: make([]*comatprototypes.SyncListRepos_Repo, len(accounts)), 129 + } 130 + 131 + // Fetch the repo roots for each user 132 + for i := range accounts { 133 + user := accounts[i] 134 + 135 + root, err := s.GetRepoRoot(ctx, user.ID) 136 + if err != nil { 137 + s.log.Error("failed to get repo root", "err", err, "did", user.Did) 138 + return nil, echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to get repo root for (%s): %v", user.Did, err.Error())) 139 + } 140 + 141 + resp.Repos[i] = &comatprototypes.SyncListRepos_Repo{ 142 + Did: user.Did, 143 + Head: root.String(), 144 + } 145 + } 146 + 147 + // If this is not the last page, set the cursor 148 + if len(accounts) >= limit && len(accounts) > 1 { 149 + nextCursor := fmt.Sprintf("%d", accounts[len(accounts)-1].ID) 150 + resp.Cursor = &nextCursor 151 + } 152 + 153 + return resp, nil 154 + } 155 + 156 + var ErrUserStatusUnavailable = errors.New("user status unavailable") 157 + 158 + func (s *BGS) handleComAtprotoSyncGetLatestCommit(ctx context.Context, did string) (*comatprototypes.SyncGetLatestCommit_Output, error) { 159 + u, err := s.lookupUserByDid(ctx, did) 160 + if err != nil { 161 + if errors.Is(err, gorm.ErrRecordNotFound) { 162 + return nil, echo.NewHTTPError(http.StatusNotFound, "user not 
found") 163 + } 164 + return nil, echo.NewHTTPError(http.StatusInternalServerError, "failed to lookup user") 165 + } 166 + 167 + if u.GetTakenDown() { 168 + return nil, fmt.Errorf("account was taken down by the Relay") 169 + } 170 + 171 + ustatus := u.GetUpstreamStatus() 172 + if ustatus == events.AccountStatusTakendown { 173 + return nil, fmt.Errorf("account was taken down by its PDS") 174 + } 175 + 176 + if ustatus == events.AccountStatusDeactivated { 177 + return nil, fmt.Errorf("account is temporarily deactivated") 178 + } 179 + 180 + if ustatus == events.AccountStatusSuspended { 181 + return nil, fmt.Errorf("account is suspended by its PDS") 182 + } 183 + 184 + var prevState AccountPreviousState 185 + err = s.db.First(&prevState, u.ID).Error 186 + if err == nil { 187 + // okay! 188 + } else if errors.Is(err, gorm.ErrRecordNotFound) { 189 + return nil, ErrUserStatusUnavailable 190 + } else { 191 + s.log.Error("user db err", "err", err) 192 + return nil, fmt.Errorf("user prev db err, %w", err) 193 + } 194 + 195 + return &comatprototypes.SyncGetLatestCommit_Output{ 196 + Cid: prevState.Cid.CID.String(), 197 + Rev: prevState.Rev, 198 + }, nil 199 + }
+188
cmd/relay/bgs/metrics.go
··· 1 + package bgs 2 + 3 + import ( 4 + "errors" 5 + "net/http" 6 + "strconv" 7 + "time" 8 + 9 + "github.com/labstack/echo/v4" 10 + "github.com/prometheus/client_golang/prometheus" 11 + "github.com/prometheus/client_golang/prometheus/promauto" 12 + ) 13 + 14 + var eventsReceivedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ 15 + Name: "events_received_counter", 16 + Help: "The total number of events received", 17 + }, []string{"pds"}) 18 + 19 + var eventsWarningsCounter = promauto.NewCounterVec(prometheus.CounterOpts{ 20 + Name: "events_warn_counter", 21 + Help: "Events received with warnings", 22 + }, []string{"pds", "warn"}) 23 + 24 + var eventsHandleDuration = promauto.NewHistogramVec(prometheus.HistogramOpts{ 25 + Name: "events_handle_duration", 26 + Help: "A histogram of handleFedEvent latencies", 27 + Buckets: prometheus.ExponentialBuckets(0.001, 2, 15), 28 + }, []string{"pds"}) 29 + 30 + var repoCommitsReceivedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ 31 + Name: "repo_commits_received_counter", 32 + Help: "The total number of commit events received", 33 + }, []string{"pds"}) 34 + var repoSyncReceivedCounter = promauto.NewCounterVec(prometheus.CounterOpts{ 35 + Name: "repo_sync_received_counter", 36 + Help: "The total number of sync events received", 37 + }, []string{"pds"}) 38 + 39 + var repoCommitsResultCounter = promauto.NewCounterVec(prometheus.CounterOpts{ 40 + Name: "repo_commits_result_counter", 41 + Help: "The results of commit events received", 42 + }, []string{"pds", "status"}) 43 + 44 + var eventsSentCounter = promauto.NewCounterVec(prometheus.CounterOpts{ 45 + Name: "events_sent_counter", 46 + Help: "The total number of events sent to consumers", 47 + }, []string{"remote_addr", "user_agent"}) 48 + 49 + var externalUserCreationAttempts = promauto.NewCounter(prometheus.CounterOpts{ 50 + Name: "bgs_external_user_creation_attempts", 51 + Help: "The total number of external users created", 52 + }) 53 + 54 + var 
connectedInbound = promauto.NewGauge(prometheus.GaugeOpts{ 55 + Name: "bgs_connected_inbound", 56 + Help: "Number of inbound firehoses we are consuming", 57 + }) 58 + 59 + var newUsersDiscovered = promauto.NewCounter(prometheus.CounterOpts{ 60 + Name: "bgs_new_users_discovered", 61 + Help: "The total number of new users discovered directly from the firehose (not from refs)", 62 + }) 63 + 64 + var reqSz = promauto.NewHistogramVec(prometheus.HistogramOpts{ 65 + Name: "http_request_size_bytes", 66 + Help: "A histogram of request sizes for requests.", 67 + Buckets: prometheus.ExponentialBuckets(100, 10, 8), 68 + }, []string{"code", "method", "path"}) 69 + 70 + var reqDur = promauto.NewHistogramVec(prometheus.HistogramOpts{ 71 + Name: "http_request_duration_seconds", 72 + Help: "A histogram of latencies for requests.", 73 + Buckets: prometheus.ExponentialBuckets(0.001, 2, 15), 74 + }, []string{"code", "method", "path"}) 75 + 76 + var reqCnt = promauto.NewCounterVec(prometheus.CounterOpts{ 77 + Name: "http_requests_total", 78 + Help: "A counter for requests to the wrapped handler.", 79 + }, []string{"code", "method", "path"}) 80 + 81 + var resSz = promauto.NewHistogramVec(prometheus.HistogramOpts{ 82 + Name: "http_response_size_bytes", 83 + Help: "A histogram of response sizes for requests.", 84 + Buckets: prometheus.ExponentialBuckets(100, 10, 8), 85 + }, []string{"code", "method", "path"}) 86 + 87 + var newUserDiscoveryDuration = promauto.NewHistogram(prometheus.HistogramOpts{ 88 + Name: "relay_new_user_discovery_duration", 89 + Help: "A histogram of new user discovery latencies", 90 + Buckets: prometheus.ExponentialBuckets(0.001, 2, 15), 91 + }) 92 + 93 + var commitVerifyStarts = promauto.NewCounter(prometheus.CounterOpts{ 94 + Name: "validator_commit_verify_starts", 95 + }) 96 + 97 + var commitVerifyWarnings = promauto.NewCounterVec(prometheus.CounterOpts{ 98 + Name: "validator_commit_verify_warnings", 99 + }, []string{"host", "warn"}) 100 + 101 + // verify error and 
short code for why 102 + var commitVerifyErrors = promauto.NewCounterVec(prometheus.CounterOpts{ 103 + Name: "validator_commit_verify_errors", 104 + }, []string{"host", "err"}) 105 + 106 + // ok and *fully verified* 107 + var commitVerifyOk = promauto.NewCounterVec(prometheus.CounterOpts{ 108 + Name: "validator_commit_verify_ok", 109 + }, []string{"host"}) 110 + 111 + // it's ok, but... {old protocol, no previous root cid, ...} 112 + var commitVerifyOkish = promauto.NewCounterVec(prometheus.CounterOpts{ 113 + Name: "validator_commit_verify_okish", 114 + }, []string{"host", "but"}) 115 + 116 + // verify error and short code for why 117 + var syncVerifyErrors = promauto.NewCounterVec(prometheus.CounterOpts{ 118 + Name: "validator_sync_verify_errors", 119 + }, []string{"host", "err"}) 120 + 121 + var accountVerifyWarnings = promauto.NewCounterVec(prometheus.CounterOpts{ 122 + Name: "validator_account_verify_warnings", 123 + Help: "things that have been a little bit wrong with account messages", 124 + }, []string{"host", "warn"}) 125 + 126 + // MetricsMiddleware defines handler function for metrics middleware 127 + func MetricsMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 128 + return func(c echo.Context) error { 129 + path := c.Path() 130 + if path == "/metrics" || path == "/_health" { 131 + return next(c) 132 + } 133 + 134 + start := time.Now() 135 + requestSize := computeApproximateRequestSize(c.Request()) 136 + 137 + err := next(c) 138 + 139 + status := c.Response().Status 140 + if err != nil { 141 + var httpError *echo.HTTPError 142 + if errors.As(err, &httpError) { 143 + status = httpError.Code 144 + } 145 + if status == 0 || status == http.StatusOK { 146 + status = http.StatusInternalServerError 147 + } 148 + } 149 + 150 + elapsed := float64(time.Since(start)) / float64(time.Second) 151 + 152 + statusStr := strconv.Itoa(status) 153 + method := c.Request().Method 154 + 155 + responseSize := float64(c.Response().Size) 156 + 157 + 
reqDur.WithLabelValues(statusStr, method, path).Observe(elapsed) 158 + reqCnt.WithLabelValues(statusStr, method, path).Inc() 159 + reqSz.WithLabelValues(statusStr, method, path).Observe(float64(requestSize)) 160 + resSz.WithLabelValues(statusStr, method, path).Observe(responseSize) 161 + 162 + return err 163 + } 164 + } 165 + 166 + func computeApproximateRequestSize(r *http.Request) int { 167 + s := 0 168 + if r.URL != nil { 169 + s = len(r.URL.Path) 170 + } 171 + 172 + s += len(r.Method) 173 + s += len(r.Proto) 174 + for name, values := range r.Header { 175 + s += len(name) 176 + for _, value := range values { 177 + s += len(value) 178 + } 179 + } 180 + s += len(r.Host) 181 + 182 + // N.B. r.Form and r.MultipartForm are assumed to be included in r.URL. 183 + 184 + if r.ContentLength != -1 { 185 + s += int(r.ContentLength) 186 + } 187 + return s 188 + }
+8
cmd/relay/bgs/models.go
··· 1 + package bgs 2 + 3 + import "gorm.io/gorm" 4 + 5 + type DomainBan struct { 6 + gorm.Model 7 + Domain string `gorm:"unique"` 8 + }
+142
cmd/relay/bgs/stubs.go
··· 1 + package bgs 2 + 3 + import ( 4 + "errors" 5 + "fmt" 6 + "gorm.io/gorm" 7 + "net/http" 8 + "strconv" 9 + 10 + comatprototypes "github.com/bluesky-social/indigo/api/atproto" 11 + "github.com/bluesky-social/indigo/atproto/syntax" 12 + "github.com/labstack/echo/v4" 13 + "go.opentelemetry.io/otel" 14 + ) 15 + 16 + type XRPCError struct { 17 + Message string `json:"message"` 18 + } 19 + 20 + func (s *BGS) RegisterHandlersAppBsky(e *echo.Echo) error { 21 + return nil 22 + } 23 + 24 + func (s *BGS) RegisterHandlersComAtproto(e *echo.Echo) error { 25 + e.GET("/xrpc/com.atproto.sync.getLatestCommit", s.HandleComAtprotoSyncGetLatestCommit) 26 + e.GET("/xrpc/com.atproto.sync.listRepos", s.HandleComAtprotoSyncListRepos) 27 + e.POST("/xrpc/com.atproto.sync.requestCrawl", s.HandleComAtprotoSyncRequestCrawl) 28 + return nil 29 + } 30 + 31 + func (s *BGS) HandleComAtprotoSyncGetLatestCommit(c echo.Context) error { 32 + ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoSyncGetLatestCommit") 33 + defer span.End() 34 + did := c.QueryParam("did") 35 + 36 + _, err := syntax.ParseDID(did) 37 + if err != nil { 38 + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid did: %s", did)}) 39 + } 40 + 41 + var out *comatprototypes.SyncGetLatestCommit_Output 42 + var handleErr error 43 + // func (s *BGS) handleComAtprotoSyncGetLatestCommit(ctx context.Context,did string) (*comatprototypes.SyncGetLatestCommit_Output, error) 44 + out, handleErr = s.handleComAtprotoSyncGetLatestCommit(ctx, did) 45 + if handleErr != nil { 46 + return handleErr 47 + } 48 + return c.JSON(200, out) 49 + } 50 + 51 + func (s *BGS) HandleComAtprotoSyncListRepos(c echo.Context) error { 52 + ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoSyncListRepos") 53 + defer span.End() 54 + 55 + cursorQuery := c.QueryParam("cursor") 56 + limitQuery := c.QueryParam("limit") 57 + 58 + var err error 59 + 60 + limit := 500 61 + if limitQuery 
!= "" { 62 + limit, err = strconv.Atoi(limitQuery) 63 + if err != nil || limit < 1 || limit > 1000 { 64 + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid limit: %s", limitQuery)}) 65 + } 66 + } 67 + 68 + cursor := int64(0) 69 + if cursorQuery != "" { 70 + cursor, err = strconv.ParseInt(cursorQuery, 10, 64) 71 + if err != nil || cursor < 0 { 72 + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid cursor: %s", cursorQuery)}) 73 + } 74 + } 75 + 76 + out, handleErr := s.handleComAtprotoSyncListRepos(ctx, cursor, limit) 77 + if handleErr != nil { 78 + return handleErr 79 + } 80 + return c.JSON(200, out) 81 + } 82 + 83 + // HandleComAtprotoSyncGetRepo handles /xrpc/com.atproto.sync.getRepo 84 + // returns 3xx to same URL at source PDS 85 + func (s *BGS) HandleComAtprotoSyncGetRepo(c echo.Context) error { 86 + // no request object, only params 87 + params := c.QueryParams() 88 + var did string 89 + hasDid := false 90 + for paramName, pvl := range params { 91 + switch paramName { 92 + case "did": 93 + if len(pvl) == 1 { 94 + did = pvl[0] 95 + hasDid = true 96 + } else if len(pvl) > 1 { 97 + return c.JSON(http.StatusBadRequest, XRPCError{Message: "only allow one did param"}) 98 + } 99 + case "since": 100 + // ok 101 + default: 102 + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid param: %s", paramName)}) 103 + } 104 + } 105 + if !hasDid { 106 + return c.JSON(http.StatusBadRequest, XRPCError{Message: "need did param"}) 107 + } 108 + 109 + var pdsHostname string 110 + err := s.db.Raw("SELECT pds.host FROM users JOIN pds ON users.pds = pds.id WHERE users.did = ?", did).Scan(&pdsHostname).Error 111 + if err != nil { 112 + if errors.Is(err, gorm.ErrRecordNotFound) { 113 + return c.JSON(http.StatusNotFound, XRPCError{Message: "NULL"}) 114 + } 115 + s.log.Error("user.pds.host lookup", "err", err) 116 + return c.JSON(http.StatusInternalServerError, XRPCError{Message: "sorry"}) 117 + } 118 + 119 + 
nextUrl := *(c.Request().URL) 120 + nextUrl.Host = pdsHostname 121 + if nextUrl.Scheme == "" { 122 + nextUrl.Scheme = "https" 123 + } 124 + return c.Redirect(http.StatusFound, nextUrl.String()) 125 + } 126 + 127 + func (s *BGS) HandleComAtprotoSyncRequestCrawl(c echo.Context) error { 128 + ctx, span := otel.Tracer("server").Start(c.Request().Context(), "HandleComAtprotoSyncRequestCrawl") 129 + defer span.End() 130 + 131 + var body comatprototypes.SyncRequestCrawl_Input 132 + if err := c.Bind(&body); err != nil { 133 + return c.JSON(http.StatusBadRequest, XRPCError{Message: fmt.Sprintf("invalid body: %s", err)}) 134 + } 135 + var handleErr error 136 + // func (s *BGS) handleComAtprotoSyncRequestCrawl(ctx context.Context,body *comatprototypes.SyncRequestCrawl_Input) error 137 + handleErr = s.handleComAtprotoSyncRequestCrawl(ctx, &body) 138 + if handleErr != nil { 139 + return handleErr 140 + } 141 + return nil 142 + }
+431
cmd/relay/bgs/validator.go
··· 1 + package bgs 2 + 3 + import ( 4 + "bytes" 5 + "context" 6 + "fmt" 7 + "log/slog" 8 + "sync" 9 + "sync/atomic" 10 + "time" 11 + 12 + atproto "github.com/bluesky-social/indigo/api/atproto" 13 + "github.com/bluesky-social/indigo/atproto/identity" 14 + atrepo "github.com/bluesky-social/indigo/atproto/repo" 15 + "github.com/bluesky-social/indigo/atproto/syntax" 16 + "github.com/bluesky-social/indigo/cmd/relay/models" 17 + "github.com/ipfs/go-cid" 18 + "go.opentelemetry.io/otel" 19 + ) 20 + 21 + const defaultMaxRevFuture = time.Hour 22 + 23 + func NewValidator(directory identity.Directory, inductionTraceLog *slog.Logger) *Validator { 24 + maxRevFuture := defaultMaxRevFuture // TODO: configurable 25 + ErrRevTooFarFuture := fmt.Errorf("new rev is > %s in the future", maxRevFuture) 26 + 27 + return &Validator{ 28 + userLocks: make(map[models.Uid]*userLock), 29 + log: slog.Default().With("system", "validator"), 30 + inductionTraceLog: inductionTraceLog, 31 + directory: directory, 32 + 33 + maxRevFuture: maxRevFuture, 34 + ErrRevTooFarFuture: ErrRevTooFarFuture, 35 + AllowSignatureNotFound: true, // TODO: configurable 36 + } 37 + } 38 + 39 + // Validator contains the context and code necessary to validate #commit and #sync messages 40 + type Validator struct { 41 + lklk sync.Mutex 42 + userLocks map[models.Uid]*userLock 43 + 44 + log *slog.Logger 45 + inductionTraceLog *slog.Logger 46 + 47 + directory identity.Directory 48 + 49 + // maxRevFuture is added to time.Now() for a limit of clock skew we'll accept a `rev` in the future for 50 + maxRevFuture time.Duration 51 + 52 + // ErrRevTooFarFuture is the error we return 53 + // held here because we fmt.Errorf() once with our configured maxRevFuture into the message 54 + ErrRevTooFarFuture error 55 + 56 + // AllowSignatureNotFound enables counting messages without findable public key to pass through with a warning counter 57 + // TODO: refine this for what kind of 'not found' we accept. 
58 + AllowSignatureNotFound bool 59 + } 60 + 61 + type NextCommitHandler interface { 62 + HandleCommit(ctx context.Context, host *models.PDS, uid models.Uid, did string, commit *atproto.SyncSubscribeRepos_Commit) error 63 + } 64 + 65 + type userLock struct { 66 + lk sync.Mutex 67 + waiters atomic.Int32 68 + } 69 + 70 + // lockUser re-serializes access per-user after events may have been fanned out to many worker threads by events/schedulers/parallel 71 + func (val *Validator) lockUser(ctx context.Context, user models.Uid) func() { 72 + ctx, span := otel.Tracer("validator").Start(ctx, "userLock") 73 + defer span.End() 74 + 75 + val.lklk.Lock() 76 + 77 + ulk, ok := val.userLocks[user] 78 + if !ok { 79 + ulk = &userLock{} 80 + val.userLocks[user] = ulk 81 + } 82 + 83 + ulk.waiters.Add(1) 84 + 85 + val.lklk.Unlock() 86 + 87 + ulk.lk.Lock() 88 + 89 + return func() { 90 + val.lklk.Lock() 91 + defer val.lklk.Unlock() 92 + 93 + ulk.lk.Unlock() 94 + 95 + nv := ulk.waiters.Add(-1) 96 + 97 + if nv == 0 { 98 + delete(val.userLocks, user) 99 + } 100 + } 101 + } 102 + 103 + func (val *Validator) HandleCommit(ctx context.Context, host *models.PDS, account *Account, commit *atproto.SyncSubscribeRepos_Commit, prevRoot *AccountPreviousState) (newRoot *cid.Cid, err error) { 104 + uid := account.GetUid() 105 + unlock := val.lockUser(ctx, uid) 106 + defer unlock() 107 + repoFragment, err := val.VerifyCommitMessage(ctx, host, commit, prevRoot) 108 + if err != nil { 109 + return nil, err 110 + } 111 + newRootCid, err := repoFragment.MST.RootCID() 112 + if err != nil { 113 + return nil, err 114 + } 115 + return newRootCid, nil 116 + } 117 + 118 + type revOutOfOrderError struct { 119 + dt time.Duration 120 + } 121 + 122 + func (roooe *revOutOfOrderError) Error() string { 123 + return fmt.Sprintf("new rev is before previous rev by %s", roooe.dt.String()) 124 + } 125 + 126 + var ErrNewRevBeforePrevRev = &revOutOfOrderError{} 127 + 128 + func (val *Validator) VerifyCommitMessage(ctx 
context.Context, host *models.PDS, msg *atproto.SyncSubscribeRepos_Commit, prevRoot *AccountPreviousState) (*atrepo.Repo, error) { 129 + hostname := host.Host 130 + hasWarning := false 131 + commitVerifyStarts.Inc() 132 + logger := slog.Default().With("did", msg.Repo, "rev", msg.Rev, "seq", msg.Seq, "time", msg.Time) 133 + 134 + did, err := syntax.ParseDID(msg.Repo) 135 + if err != nil { 136 + commitVerifyErrors.WithLabelValues(hostname, "did").Inc() 137 + return nil, err 138 + } 139 + rev, err := syntax.ParseTID(msg.Rev) 140 + if err != nil { 141 + commitVerifyErrors.WithLabelValues(hostname, "tid").Inc() 142 + return nil, err 143 + } 144 + if prevRoot != nil { 145 + prevRev := prevRoot.GetRev() 146 + curTime := rev.Time() 147 + prevTime := prevRev.Time() 148 + if curTime.Before(prevTime) { 149 + commitVerifyErrors.WithLabelValues(hostname, "revb").Inc() 150 + dt := prevTime.Sub(curTime) 151 + return nil, &revOutOfOrderError{dt} 152 + } 153 + } 154 + if rev.Time().After(time.Now().Add(val.maxRevFuture)) { 155 + commitVerifyErrors.WithLabelValues(hostname, "revf").Inc() 156 + return nil, val.ErrRevTooFarFuture 157 + } 158 + _, err = syntax.ParseDatetime(msg.Time) 159 + if err != nil { 160 + commitVerifyErrors.WithLabelValues(hostname, "time").Inc() 161 + return nil, err 162 + } 163 + 164 + if msg.TooBig { 165 + //logger.Warn("event with tooBig flag set") 166 + commitVerifyWarnings.WithLabelValues(hostname, "big").Inc() 167 + val.inductionTraceLog.Warn("commit tooBig", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 168 + hasWarning = true 169 + } 170 + if msg.Rebase { 171 + //logger.Warn("event with rebase flag set") 172 + commitVerifyWarnings.WithLabelValues(hostname, "reb").Inc() 173 + val.inductionTraceLog.Warn("commit rebase", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 174 + hasWarning = true 175 + } 176 + 177 + commit, repoFragment, err := atrepo.LoadFromCAR(ctx, bytes.NewReader([]byte(msg.Blocks))) 178 + if err != nil { 179 + 
commitVerifyErrors.WithLabelValues(hostname, "car").Inc() 180 + return nil, err 181 + } 182 + 183 + if commit.Rev != rev.String() { 184 + commitVerifyErrors.WithLabelValues(hostname, "rev").Inc() 185 + return nil, fmt.Errorf("rev did not match commit") 186 + } 187 + if commit.DID != did.String() { 188 + commitVerifyErrors.WithLabelValues(hostname, "did2").Inc() 189 + return nil, fmt.Errorf("rev did not match commit") 190 + } 191 + 192 + err = val.VerifyCommitSignature(ctx, commit, hostname, &hasWarning) 193 + if err != nil { 194 + // signature errors are metrics counted inside VerifyCommitSignature() 195 + return nil, err 196 + } 197 + 198 + // load out all the records 199 + for _, op := range msg.Ops { 200 + if (op.Action == "create" || op.Action == "update") && op.Cid != nil { 201 + c := (*cid.Cid)(op.Cid) 202 + nsid, rkey, err := syntax.ParseRepoPath(op.Path) 203 + if err != nil { 204 + commitVerifyErrors.WithLabelValues(hostname, "opp").Inc() 205 + return nil, fmt.Errorf("invalid repo path in ops list: %w", err) 206 + } 207 + val, err := repoFragment.GetRecordCID(ctx, nsid, rkey) 208 + if err != nil { 209 + commitVerifyErrors.WithLabelValues(hostname, "rcid").Inc() 210 + return nil, err 211 + } 212 + if *c != *val { 213 + commitVerifyErrors.WithLabelValues(hostname, "opc").Inc() 214 + return nil, fmt.Errorf("record op doesn't match MST tree value") 215 + } 216 + _, _, err = repoFragment.GetRecordBytes(ctx, nsid, rkey) 217 + if err != nil { 218 + commitVerifyErrors.WithLabelValues(hostname, "rec").Inc() 219 + return nil, err 220 + } 221 + } 222 + } 223 + 224 + // TODO: once firehose format is fully shipped, remove this 225 + for _, o := range msg.Ops { 226 + switch o.Action { 227 + case "delete": 228 + if o.Prev == nil { 229 + logger.Debug("can't invert legacy op", "action", o.Action) 230 + val.inductionTraceLog.Warn("commit delete op", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 231 + commitVerifyOkish.WithLabelValues(hostname, "del").Inc() 232 + 
return repoFragment, nil 233 + } 234 + case "update": 235 + if o.Prev == nil { 236 + logger.Debug("can't invert legacy op", "action", o.Action) 237 + val.inductionTraceLog.Warn("commit update op", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 238 + commitVerifyOkish.WithLabelValues(hostname, "up").Inc() 239 + return repoFragment, nil 240 + } 241 + } 242 + } 243 + 244 + if msg.PrevData != nil { 245 + c := (*cid.Cid)(msg.PrevData) 246 + if prevRoot != nil { 247 + if *c != prevRoot.GetCid() { 248 + commitVerifyWarnings.WithLabelValues(hostname, "pr").Inc() 249 + val.inductionTraceLog.Warn("commit prevData mismatch", "seq", msg.Seq, "pdsHost", host.Host, "repo", msg.Repo) 250 + hasWarning = true 251 + } 252 + } else { 253 + // see counter below for okish "new" 254 + } 255 + 256 + // check internal consistency that claimed previous root matches the rest of this message 257 + ops, err := ParseCommitOps(msg.Ops) 258 + if err != nil { 259 + commitVerifyErrors.WithLabelValues(hostname, "pop").Inc() 260 + return nil, err 261 + } 262 + ops, err = atrepo.NormalizeOps(ops) 263 + if err != nil { 264 + commitVerifyErrors.WithLabelValues(hostname, "nop").Inc() 265 + return nil, err 266 + } 267 + 268 + invTree := repoFragment.MST.Copy() 269 + for _, op := range ops { 270 + if err := atrepo.InvertOp(&invTree, &op); err != nil { 271 + commitVerifyErrors.WithLabelValues(hostname, "inv").Inc() 272 + return nil, err 273 + } 274 + } 275 + computed, err := invTree.RootCID() 276 + if err != nil { 277 + commitVerifyErrors.WithLabelValues(hostname, "it").Inc() 278 + return nil, err 279 + } 280 + if *computed != *c { 281 + // this is self-inconsistent malformed data 282 + commitVerifyErrors.WithLabelValues(hostname, "pd").Inc() 283 + return nil, fmt.Errorf("inverted tree root didn't match prevData") 284 + } 285 + //logger.Debug("prevData matched", "prevData", c.String(), "computed", computed.String()) 286 + 287 + if prevRoot == nil { 288 + commitVerifyOkish.WithLabelValues(hostname, 
"new").Inc() 289 + } else if hasWarning { 290 + commitVerifyOkish.WithLabelValues(hostname, "warn").Inc() 291 + } else { 292 + // TODO: would it be better to make everything "okish"? 293 + // commitVerifyOkish.WithLabelValues(hostname, "ok").Inc() 294 + commitVerifyOk.WithLabelValues(hostname).Inc() 295 + } 296 + } else { 297 + // this source is still on old protocol without new prevData field 298 + commitVerifyOkish.WithLabelValues(hostname, "old").Inc() 299 + } 300 + 301 + return repoFragment, nil 302 + } 303 + 304 + // HandleSync checks signed commit from a #sync message 305 + func (val *Validator) HandleSync(ctx context.Context, host *models.PDS, msg *atproto.SyncSubscribeRepos_Sync) (newRoot *cid.Cid, err error) { 306 + hostname := host.Host 307 + hasWarning := false 308 + 309 + did, err := syntax.ParseDID(msg.Did) 310 + if err != nil { 311 + syncVerifyErrors.WithLabelValues(hostname, "did").Inc() 312 + return nil, err 313 + } 314 + rev, err := syntax.ParseTID(msg.Rev) 315 + if err != nil { 316 + syncVerifyErrors.WithLabelValues(hostname, "tid").Inc() 317 + return nil, err 318 + } 319 + if rev.Time().After(time.Now().Add(val.maxRevFuture)) { 320 + syncVerifyErrors.WithLabelValues(hostname, "revf").Inc() 321 + return nil, val.ErrRevTooFarFuture 322 + } 323 + _, err = syntax.ParseDatetime(msg.Time) 324 + if err != nil { 325 + syncVerifyErrors.WithLabelValues(hostname, "time").Inc() 326 + return nil, err 327 + } 328 + 329 + commit, err := atrepo.LoadCARCommit(ctx, bytes.NewReader([]byte(msg.Blocks))) 330 + if err != nil { 331 + commitVerifyErrors.WithLabelValues(hostname, "car").Inc() 332 + return nil, err 333 + } 334 + 335 + if commit.Rev != rev.String() { 336 + commitVerifyErrors.WithLabelValues(hostname, "rev").Inc() 337 + return nil, fmt.Errorf("rev did not match commit") 338 + } 339 + if commit.DID != did.String() { 340 + commitVerifyErrors.WithLabelValues(hostname, "did2").Inc() 341 + return nil, fmt.Errorf("rev did not match commit") 342 + } 343 + 344 + 
err = val.VerifyCommitSignature(ctx, commit, hostname, &hasWarning) 345 + if err != nil { 346 + // signature errors are metrics counted inside VerifyCommitSignature() 347 + return nil, err 348 + } 349 + 350 + return &commit.Data, nil 351 + } 352 + 353 + // TODO: lift back to indigo/atproto/repo util code? 354 + func ParseCommitOps(ops []*atproto.SyncSubscribeRepos_RepoOp) ([]atrepo.Operation, error) { 355 + out := []atrepo.Operation{} 356 + for _, rop := range ops { 357 + switch rop.Action { 358 + case "create": 359 + if rop.Cid == nil || rop.Prev != nil { 360 + return nil, fmt.Errorf("invalid repoOp: create") 361 + } 362 + op := atrepo.Operation{ 363 + Path: rop.Path, 364 + Prev: nil, 365 + Value: (*cid.Cid)(rop.Cid), 366 + } 367 + out = append(out, op) 368 + case "delete": 369 + if rop.Cid != nil || rop.Prev == nil { 370 + return nil, fmt.Errorf("invalid repoOp: delete") 371 + } 372 + op := atrepo.Operation{ 373 + Path: rop.Path, 374 + Prev: (*cid.Cid)(rop.Prev), 375 + Value: nil, 376 + } 377 + out = append(out, op) 378 + case "update": 379 + if rop.Cid == nil || rop.Prev == nil { 380 + return nil, fmt.Errorf("invalid repoOp: update") 381 + } 382 + op := atrepo.Operation{ 383 + Path: rop.Path, 384 + Prev: (*cid.Cid)(rop.Prev), 385 + Value: (*cid.Cid)(rop.Cid), 386 + } 387 + out = append(out, op) 388 + default: 389 + return nil, fmt.Errorf("invalid repoOp action: %s", rop.Action) 390 + } 391 + } 392 + return out, nil 393 + } 394 + 395 + // VerifyCommitSignature get's repo's registered public key from Identity Directory, verifies Commit 396 + // hostname is just for metrics in case of error 397 + func (val *Validator) VerifyCommitSignature(ctx context.Context, commit *atrepo.Commit, hostname string, hasWarning *bool) error { 398 + if val.directory == nil { 399 + return nil 400 + } 401 + xdid, err := syntax.ParseDID(commit.DID) 402 + if err != nil { 403 + commitVerifyErrors.WithLabelValues(hostname, "sig1").Inc() 404 + return fmt.Errorf("bad car DID, %w", err) 405 + 
} 406 + ident, err := val.directory.LookupDID(ctx, xdid) 407 + if err != nil { 408 + if val.AllowSignatureNotFound { 409 + // allow not-found conditions to pass without signature check 410 + commitVerifyWarnings.WithLabelValues(hostname, "nok").Inc() 411 + if hasWarning != nil { 412 + *hasWarning = true 413 + } 414 + return nil 415 + } 416 + commitVerifyErrors.WithLabelValues(hostname, "sig2").Inc() 417 + return fmt.Errorf("DID lookup failed, %w", err) 418 + } 419 + pk, err := ident.GetPublicKey("atproto") 420 + if err != nil { 421 + commitVerifyErrors.WithLabelValues(hostname, "sig3").Inc() 422 + return fmt.Errorf("no atproto pubkey, %w", err) 423 + } 424 + err = commit.VerifySignature(pk) 425 + if err != nil { 426 + // TODO: if the DID document was stale, force re-fetch from source and re-try if pubkey has changed 427 + commitVerifyErrors.WithLabelValues(hostname, "sig4").Inc() 428 + return fmt.Errorf("invalid signature, %w", err) 429 + } 430 + return nil 431 + }
+303
cmd/relay/events/cbor_gen.go
// Code generated by github.com/whyrusleeping/cbor-gen. DO NOT EDIT.
// NOTE(review): regenerate with cbor-gen rather than editing by hand.

package events

import (
	"fmt"
	"io"
	"math"
	"sort"

	cid "github.com/ipfs/go-cid"
	cbg "github.com/whyrusleeping/cbor-gen"
	xerrors "golang.org/x/xerrors"
)

// keep-alive references so generated imports are never "unused"
var _ = xerrors.Errorf
var _ = cid.Undef
var _ = math.E
var _ = sort.Sort

// MarshalCBOR encodes the EventHeader as a 2-entry CBOR map {"t", "op"}.
func (t *EventHeader) MarshalCBOR(w io.Writer) error {
	if t == nil {
		_, err := w.Write(cbg.CborNull)
		return err
	}

	cw := cbg.NewCborWriter(w)

	// 0xa2: CBOR map header with 2 entries
	if _, err := cw.Write([]byte{162}); err != nil {
		return err
	}

	// t.MsgType (string) (string)
	if len("t") > 1000000 {
		return xerrors.Errorf("Value in field \"t\" was too long")
	}

	if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("t"))); err != nil {
		return err
	}
	if _, err := cw.WriteString(string("t")); err != nil {
		return err
	}

	if len(t.MsgType) > 1000000 {
		return xerrors.Errorf("Value in field t.MsgType was too long")
	}

	if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.MsgType))); err != nil {
		return err
	}
	if _, err := cw.WriteString(string(t.MsgType)); err != nil {
		return err
	}

	// t.Op (int64) (int64)
	if len("op") > 1000000 {
		return xerrors.Errorf("Value in field \"op\" was too long")
	}

	if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("op"))); err != nil {
		return err
	}
	if _, err := cw.WriteString(string("op")); err != nil {
		return err
	}

	// CBOR encodes negative ints as -1-n under the negative-int major type
	if t.Op >= 0 {
		if err := cw.WriteMajorTypeHeader(cbg.MajUnsignedInt, uint64(t.Op)); err != nil {
			return err
		}
	} else {
		if err := cw.WriteMajorTypeHeader(cbg.MajNegativeInt, uint64(-t.Op-1)); err != nil {
			return err
		}
	}

	return nil
}

// UnmarshalCBOR decodes an EventHeader from a CBOR map, ignoring unknown keys.
func (t *EventHeader) UnmarshalCBOR(r io.Reader) (err error) {
	*t = EventHeader{}

	cr := cbg.NewCborReader(r)

	maj, extra, err := cr.ReadHeader()
	if err != nil {
		return err
	}
	defer func() {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
	}()

	if maj != cbg.MajMap {
		return fmt.Errorf("cbor input should be of type map")
	}

	if extra > cbg.MaxLength {
		return fmt.Errorf("EventHeader: map struct too large (%d)", extra)
	}

	n := extra

	// buffer sized for the longest known key ("op")
	nameBuf := make([]byte, 2)
	for i := uint64(0); i < n; i++ {
		nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 1000000)
		if err != nil {
			return err
		}

		if !ok {
			// Field doesn't exist on this type, so ignore it
			if err := cbg.ScanForLinks(cr, func(cid.Cid) {}); err != nil {
				return err
			}
			continue
		}

		switch string(nameBuf[:nameLen]) {
		// t.MsgType (string) (string)
		case "t":

			{
				sval, err := cbg.ReadStringWithMax(cr, 1000000)
				if err != nil {
					return err
				}

				t.MsgType = string(sval)
			}
		// t.Op (int64) (int64)
		case "op":
			{
				maj, extra, err := cr.ReadHeader()
				if err != nil {
					return err
				}
				var extraI int64
				switch maj {
				case cbg.MajUnsignedInt:
					extraI = int64(extra)
					if extraI < 0 {
						return fmt.Errorf("int64 positive overflow")
					}
				case cbg.MajNegativeInt:
					extraI = int64(extra)
					if extraI < 0 {
						return fmt.Errorf("int64 negative overflow")
					}
					extraI = -1 - extraI
				default:
					return fmt.Errorf("wrong type for int64 field: %d", maj)
				}

				t.Op = int64(extraI)
			}

		default:
			// Field doesn't exist on this type, so ignore it
			if err := cbg.ScanForLinks(r, func(cid.Cid) {}); err != nil {
				return err
			}
		}
	}

	return nil
}

// MarshalCBOR encodes the ErrorFrame as a 2-entry CBOR map {"error", "message"}.
func (t *ErrorFrame) MarshalCBOR(w io.Writer) error {
	if t == nil {
		_, err := w.Write(cbg.CborNull)
		return err
	}

	cw := cbg.NewCborWriter(w)

	// 0xa2: CBOR map header with 2 entries
	if _, err := cw.Write([]byte{162}); err != nil {
		return err
	}

	// t.Error (string) (string)
	if len("error") > 1000000 {
		return xerrors.Errorf("Value in field \"error\" was too long")
	}

	if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("error"))); err != nil {
		return err
	}
	if _, err := cw.WriteString(string("error")); err != nil {
		return err
	}

	if len(t.Error) > 1000000 {
		return xerrors.Errorf("Value in field t.Error was too long")
	}

	if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Error))); err != nil {
		return err
	}
	if _, err := cw.WriteString(string(t.Error)); err != nil {
		return err
	}

	// t.Message (string) (string)
	if len("message") > 1000000 {
		return xerrors.Errorf("Value in field \"message\" was too long")
	}

	if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len("message"))); err != nil {
		return err
	}
	if _, err := cw.WriteString(string("message")); err != nil {
		return err
	}

	if len(t.Message) > 1000000 {
		return xerrors.Errorf("Value in field t.Message was too long")
	}

	if err := cw.WriteMajorTypeHeader(cbg.MajTextString, uint64(len(t.Message))); err != nil {
		return err
	}
	if _, err := cw.WriteString(string(t.Message)); err != nil {
		return err
	}
	return nil
}

// UnmarshalCBOR decodes an ErrorFrame from a CBOR map, ignoring unknown keys.
func (t *ErrorFrame) UnmarshalCBOR(r io.Reader) (err error) {
	*t = ErrorFrame{}

	cr := cbg.NewCborReader(r)

	maj, extra, err := cr.ReadHeader()
	if err != nil {
		return err
	}
	defer func() {
		if err == io.EOF {
			err = io.ErrUnexpectedEOF
		}
	}()

	if maj != cbg.MajMap {
		return fmt.Errorf("cbor input should be of type map")
	}

	if extra > cbg.MaxLength {
		return fmt.Errorf("ErrorFrame: map struct too large (%d)", extra)
	}

	n := extra

	// buffer sized for the longest known key ("message")
	nameBuf := make([]byte, 7)
	for i := uint64(0); i < n; i++ {
		nameLen, ok, err := cbg.ReadFullStringIntoBuf(cr, nameBuf, 1000000)
		if err != nil {
			return err
		}

		if !ok {
			// Field doesn't exist on this type, so ignore it
			if err := cbg.ScanForLinks(cr, func(cid.Cid) {}); err != nil {
				return err
			}
			continue
		}

		switch string(nameBuf[:nameLen]) {
		// t.Error (string) (string)
		case "error":

			{
				sval, err := cbg.ReadStringWithMax(cr, 1000000)
				if err != nil {
					return err
				}

				t.Error = string(sval)
			}
		// t.Message (string) (string)
		case "message":

			{
				sval, err := cbg.ReadStringWithMax(cr, 1000000)
				if err != nil {
					return err
				}

				t.Message = string(sval)
			}

		default:
			// Field doesn't exist on this type, so ignore it
			if err := cbg.ScanForLinks(r, func(cid.Cid) {}); err != nil {
				return err
			}
		}
	}

	return nil
}
+375
cmd/relay/events/consumer.go
package events

import (
	"context"
	"fmt"
	"io"
	"log/slog"
	"net"
	"time"

	"github.com/RussellLuo/slidingwindow"
	comatproto "github.com/bluesky-social/indigo/api/atproto"
	"github.com/prometheus/client_golang/prometheus"

	"github.com/gorilla/websocket"
)

// RepoStreamCallbacks bundles optional per-message-type handlers for a
// subscribeRepos/subscribeLabels stream. A nil callback means the
// corresponding message type is silently ignored.
type RepoStreamCallbacks struct {
	RepoCommit    func(evt *comatproto.SyncSubscribeRepos_Commit) error
	RepoSync      func(evt *comatproto.SyncSubscribeRepos_Sync) error
	RepoHandle    func(evt *comatproto.SyncSubscribeRepos_Handle) error
	RepoIdentity  func(evt *comatproto.SyncSubscribeRepos_Identity) error
	RepoAccount   func(evt *comatproto.SyncSubscribeRepos_Account) error
	RepoInfo      func(evt *comatproto.SyncSubscribeRepos_Info) error
	RepoMigrate   func(evt *comatproto.SyncSubscribeRepos_Migrate) error
	RepoTombstone func(evt *comatproto.SyncSubscribeRepos_Tombstone) error
	LabelLabels   func(evt *comatproto.LabelSubscribeLabels_Labels) error
	LabelInfo     func(evt *comatproto.LabelSubscribeLabels_Info) error
	Error         func(evt *ErrorFrame) error
}

// EventHandler dispatches a stream event to the first callback whose field
// is populated on both the event and this callback set.
func (rsc *RepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error {
	switch {
	case xev.RepoCommit != nil && rsc.RepoCommit != nil:
		return rsc.RepoCommit(xev.RepoCommit)
	case xev.RepoSync != nil && rsc.RepoSync != nil:
		return rsc.RepoSync(xev.RepoSync)
	case xev.RepoHandle != nil && rsc.RepoHandle != nil:
		return rsc.RepoHandle(xev.RepoHandle)
	case xev.RepoInfo != nil && rsc.RepoInfo != nil:
		return rsc.RepoInfo(xev.RepoInfo)
	case xev.RepoMigrate != nil && rsc.RepoMigrate != nil:
		return rsc.RepoMigrate(xev.RepoMigrate)
	case xev.RepoIdentity != nil && rsc.RepoIdentity != nil:
		return rsc.RepoIdentity(xev.RepoIdentity)
	case xev.RepoAccount != nil && rsc.RepoAccount != nil:
		return rsc.RepoAccount(xev.RepoAccount)
	case xev.RepoTombstone != nil && rsc.RepoTombstone != nil:
		return rsc.RepoTombstone(xev.RepoTombstone)
	case xev.LabelLabels != nil && rsc.LabelLabels != nil:
		return rsc.LabelLabels(xev.LabelLabels)
	case xev.LabelInfo != nil && rsc.LabelInfo != nil:
		return rsc.LabelInfo(xev.LabelInfo)
	case xev.Error != nil && rsc.Error != nil:
		return rsc.Error(xev.Error)
	default:
		// unhandled event type: drop without error
		return nil
	}
}

// InstrumentedRepoStreamCallbacks gates event delivery through a set of
// sliding-window rate limiters before forwarding to Next.
type InstrumentedRepoStreamCallbacks struct {
	limiters []*slidingwindow.Limiter
	Next     func(ctx context.Context, xev *XRPCStreamEvent) error
}

// NewInstrumentedRepoStreamCallbacks wires limiters in front of next.
func NewInstrumentedRepoStreamCallbacks(limiters []*slidingwindow.Limiter, next func(ctx context.Context, xev *XRPCStreamEvent) error) *InstrumentedRepoStreamCallbacks {
	return &InstrumentedRepoStreamCallbacks{
		limiters: limiters,
		Next:     next,
	}
}

// waitForLimiter blocks until lim admits one event or ctx is cancelled.
func waitForLimiter(ctx context.Context, lim *slidingwindow.Limiter) error {
	if lim.Allow() {
		return nil
	}

	// wait until the limiter is ready (check every 100ms)
	t := time.NewTicker(100 * time.Millisecond)
	defer t.Stop()

	for !lim.Allow() {
		select {
		case <-ctx.Done():
			return ctx.Err()
		case <-t.C:
		}
	}

	return nil
}

// EventHandler waits on every limiter in order, then forwards the event.
func (rsc *InstrumentedRepoStreamCallbacks) EventHandler(ctx context.Context, xev *XRPCStreamEvent) error {
	// Wait on all limiters before calling the next handler
	for _, lim := range rsc.limiters {
		if err := waitForLimiter(ctx, lim); err != nil {
			return err
		}
	}
	return rsc.Next(ctx, xev)
}

// instrumentedReader wraps a reader and counts bytes read into a
// per-remote-address prometheus counter.
type instrumentedReader struct {
	r            io.Reader
	addr         string
	bytesCounter prometheus.Counter
}

func (sr *instrumentedReader) Read(p []byte) (int, error) {
	n, err := sr.r.Read(p)
	sr.bytesCounter.Add(float64(n))
	return n, err
}

// HandleRepoStream
// con is source of events
// sched gets AddWork for each event
// log may be nil for default logger
//
// Runs until the context is cancelled or a read/decode error occurs; it also
// starts a ping goroutine that closes the connection after repeated ping
// failures.
func HandleRepoStream(ctx context.Context, con *websocket.Conn, sched Scheduler, log *slog.Logger) error {
	if log == nil {
		log = slog.Default().With("system", "events")
	}
	ctx, cancel := context.WithCancel(ctx)
	defer cancel()
	defer sched.Shutdown()

	remoteAddr := con.RemoteAddr().String()

	// keepalive: ping every 30s, give up after 4 consecutive failures
	go func() {
		t := time.NewTicker(time.Second * 30)
		defer t.Stop()
		failcount := 0

		for {

			select {
			case <-t.C:
				if err := con.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil {
					log.Warn("failed to ping", "err", err)
					failcount++
					if failcount >= 4 {
						log.Error("too many ping fails", "count", failcount)
						con.Close()
						return
					}
				} else {
					failcount = 0 // ok ping
				}
			case <-ctx.Done():
				con.Close()
				return
			}
		}
	}()

	con.SetPingHandler(func(message string) error {
		err := con.WriteControl(websocket.PongMessage, []byte(message), time.Now().Add(time.Second*60))
		if err == websocket.ErrCloseSent {
			return nil
		} else if e, ok := err.(net.Error); ok && e.Temporary() {
			// NOTE(review): net.Error.Temporary is deprecated (SA1019); confirm
			// intended semantics before replacing.
			return nil
		}
		return err
	})

	con.SetPongHandler(func(_ string) error {
		// each pong extends the read deadline by one minute
		if err := con.SetReadDeadline(time.Now().Add(time.Minute)); err != nil {
			log.Error("failed to set read deadline", "err", err)
		}

		return nil
	})

	lastSeq := int64(-1)
	for {
		select {
		case <-ctx.Done():
			return ctx.Err()
		default:
		}

		mt, rawReader, err := con.NextReader()
		if err != nil {
			return fmt.Errorf("con err at read: %w", err)
		}

		switch mt {
		default:
			return fmt.Errorf("expected binary message from subscription endpoint")
		case websocket.BinaryMessage:
			// ok
		}

		r := &instrumentedReader{
			r:            rawReader,
			addr:         remoteAddr,
			bytesCounter: bytesFromStreamCounter.WithLabelValues(remoteAddr),
		}

		var header EventHeader
		if err := header.UnmarshalCBOR(r); err != nil {
			return fmt.Errorf("reading header: %w", err)
		}

		eventsFromStreamCounter.WithLabelValues(remoteAddr).Inc()

		switch header.Op {
		case EvtKindMessage:
			switch header.MsgType {
			case "#commit":
				var evt comatproto.SyncSubscribeRepos_Commit
				if err := evt.UnmarshalCBOR(r); err != nil {
					return fmt.Errorf("reading repoCommit event: %w", err)
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}

				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, evt.Repo, &XRPCStreamEvent{
					RepoCommit: &evt,
				}); err != nil {
					return err
				}
			case "#sync":
				var evt comatproto.SyncSubscribeRepos_Sync
				if err := evt.UnmarshalCBOR(r); err != nil {
					return fmt.Errorf("reading repoSync event: %w", err)
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}

				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
					RepoSync: &evt,
				}); err != nil {
					return err
				}
			case "#handle":
				// TODO: DEPRECATED message; warning/counter; drop message
				var evt comatproto.SyncSubscribeRepos_Handle
				if err := evt.UnmarshalCBOR(r); err != nil {
					return err
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}
				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
					RepoHandle: &evt,
				}); err != nil {
					return err
				}
			case "#identity":
				var evt comatproto.SyncSubscribeRepos_Identity
				if err := evt.UnmarshalCBOR(r); err != nil {
					return err
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}
				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
					RepoIdentity: &evt,
				}); err != nil {
					return err
				}
			case "#account":
				var evt comatproto.SyncSubscribeRepos_Account
				if err := evt.UnmarshalCBOR(r); err != nil {
					return err
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}
				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
					RepoAccount: &evt,
				}); err != nil {
					return err
				}
			case "#info":
				// TODO: this might also be a LabelInfo (as opposed to RepoInfo)
				var evt comatproto.SyncSubscribeRepos_Info
				if err := evt.UnmarshalCBOR(r); err != nil {
					return err
				}

				if err := sched.AddWork(ctx, "", &XRPCStreamEvent{
					RepoInfo: &evt,
				}); err != nil {
					return err
				}
			case "#migrate":
				// TODO: DEPRECATED message; warning/counter; drop message
				var evt comatproto.SyncSubscribeRepos_Migrate
				if err := evt.UnmarshalCBOR(r); err != nil {
					return err
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}
				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
					RepoMigrate: &evt,
				}); err != nil {
					return err
				}
			case "#tombstone":
				// TODO: DEPRECATED message; warning/counter; drop message
				var evt comatproto.SyncSubscribeRepos_Tombstone
				if err := evt.UnmarshalCBOR(r); err != nil {
					return err
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}
				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, evt.Did, &XRPCStreamEvent{
					RepoTombstone: &evt,
				}); err != nil {
					return err
				}
			case "#labels":
				var evt comatproto.LabelSubscribeLabels_Labels
				if err := evt.UnmarshalCBOR(r); err != nil {
					return fmt.Errorf("reading Labels event: %w", err)
				}

				if evt.Seq < lastSeq {
					log.Error("Got events out of order from stream", "seq", evt.Seq, "prev", lastSeq)
				}

				lastSeq = evt.Seq

				if err := sched.AddWork(ctx, "", &XRPCStreamEvent{
					LabelLabels: &evt,
				}); err != nil {
					return err
				}
			}

		case EvtKindErrorFrame:
			var errframe ErrorFrame
			if err := errframe.UnmarshalCBOR(r); err != nil {
				return err
			}

			if err := sched.AddWork(ctx, "", &XRPCStreamEvent{
				Error: &errframe,
			}); err != nil {
				return err
			}

		default:
			return fmt.Errorf("unrecognized event stream type: %d", header.Op)
		}

	}
}
+1007
cmd/relay/events/diskpersist/diskpersist.go
package diskpersist

import (
	"bufio"
	"bytes"
	"context"
	"encoding/binary"
	"errors"
	"fmt"
	// NOTE(review): this project import belongs in the third-party group
	// below (goimports would move it).
	"github.com/bluesky-social/indigo/cmd/relay/events"
	"io"
	"log/slog"
	"os"
	"path/filepath"
	"sync"
	"time"

	"github.com/bluesky-social/indigo/api/atproto"
	"github.com/bluesky-social/indigo/cmd/relay/models"
	arc "github.com/hashicorp/golang-lru/arc/v2"
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
	cbg "github.com/whyrusleeping/cbor-gen"
	"gorm.io/gorm"
)

// DiskPersistence is an events.EventPersistence that appends events to
// rotating log files on disk, with file metadata tracked in a gorm DB.
type DiskPersistence struct {
	primaryDir      string // directory for active log files
	archiveDir      string
	eventsPerFile   int64
	writeBufferSize int
	retention       time.Duration

	meta *gorm.DB // holds LogFileRef rows describing each log file

	broadcast func(*events.XRPCStreamEvent)

	logfi *os.File // currently-open log file

	eventCounter int64
	curSeq       int64
	timeSequence bool

	uids     UidSource
	uidCache *arc.ARCCache[models.Uid, string] // TODO: unused
	didCache *arc.ARCCache[string, models.Uid]

	writers *sync.Pool // pool of *cbg.CborWriter
	buffers *sync.Pool // pool of *bytes.Buffer
	scratch []byte

	outbuf *bytes.Buffer // pending bytes not yet written to logfi
	evtbuf []persistJob  // events matching outbuf, broadcast after flush

	shutdown chan struct{}

	log *slog.Logger

	lk sync.Mutex // guards outbuf, evtbuf, logfi, curSeq
}

// persistJob is one buffered event awaiting flush to disk.
type persistJob struct {
	Bytes  []byte
	Evt    *events.XRPCStreamEvent
	Buffer *bytes.Buffer // so we can put it back in the pool when we're done
}

type jobResult struct {
	Err error
	Seq int64
}

const (
	EvtFlagTakedown = 1 << iota
	EvtFlagRebased
)

// compile-time interface check
var _ (events.EventPersistence) = (*DiskPersistence)(nil)

// DiskPersistOptions configures NewDiskPersistence; zero/nil means defaults.
type DiskPersistOptions struct {
	UIDCacheSize    int
	DIDCacheSize    int
	EventsPerFile   int64
	WriteBufferSize int
	Retention       time.Duration

	Logger *slog.Logger

	TimeSequence bool
}

// DefaultDiskPersistOptions returns the default tuning values.
func DefaultDiskPersistOptions() *DiskPersistOptions {
	return &DiskPersistOptions{
		EventsPerFile:   10_000,
		UIDCacheSize:    1_000_000,
		DIDCacheSize:    1_000_000,
		WriteBufferSize: 50,
		Retention:       time.Hour * 24 * 3, // 3 days
	}
}

// UidSource resolves a DID to an internal account uid.
type UidSource interface {
	DidToUid(ctx context.Context, did string) (models.Uid, error)
}

// NewDiskPersistence opens (or creates) the on-disk event log rooted at
// primaryDir, resumes the sequence number from the newest log file, and
// starts the background flush and garbage-collection goroutines.
func NewDiskPersistence(primaryDir, archiveDir string, db *gorm.DB, opts *DiskPersistOptions) (*DiskPersistence, error) {
	if opts == nil {
		opts = DefaultDiskPersistOptions()
	}

	uidCache, err := arc.NewARC[models.Uid, string](opts.UIDCacheSize)
	if err != nil {
		return nil, fmt.Errorf("failed to create uid cache: %w", err)
	}

	didCache, err := arc.NewARC[string, models.Uid](opts.DIDCacheSize)
	if err != nil {
		return nil, fmt.Errorf("failed to create did cache: %w", err)
	}

	// NOTE(review): AutoMigrate error is ignored here — confirm whether a
	// migration failure should abort startup.
	db.AutoMigrate(&LogFileRef{})

	bufpool := &sync.Pool{
		New: func() any {
			return new(bytes.Buffer)
		},
	}

	wrpool := &sync.Pool{
		New: func() any {
			return cbg.NewCborWriter(nil)
		},
	}

	dp := &DiskPersistence{
		meta:            db,
		primaryDir:      primaryDir,
		archiveDir:      archiveDir,
		buffers:         bufpool,
		retention:       opts.Retention,
		writers:         wrpool,
		uidCache:        uidCache,
		didCache:        didCache,
		eventsPerFile:   opts.EventsPerFile,
		scratch:         make([]byte, headerSize),
		outbuf:          new(bytes.Buffer),
		writeBufferSize: opts.WriteBufferSize,
		shutdown:        make(chan struct{}),
		timeSequence:    opts.TimeSequence,
		log:             opts.Logger,
	}
	if dp.log == nil {
		dp.log = slog.Default().With("system", "diskpersist")
	}

	if err := dp.resumeLog(); err != nil {
		return nil, err
	}

	go dp.flushRoutine()

	go dp.garbageCollectRoutine()

	return dp, nil
}

// LogFileRef is the DB record for one on-disk log file; SeqStart is the
// first sequence number contained in the file.
type LogFileRef struct {
	gorm.Model
	Path     string
	Archived bool
	SeqStart int64
}

// SetUidSource injects the DID->uid resolver after construction.
func (dp *DiskPersistence) SetUidSource(uids UidSource) {
	dp.uids = uids
}

// resumeLog reopens the newest log file and resumes curSeq just past the
// last event found in it; creates a fresh log if none exist.
func (dp *DiskPersistence) resumeLog() error {
	var lfr LogFileRef
	if err := dp.meta.Order("seq_start desc").Limit(1).Find(&lfr).Error; err != nil {
		return err
	}

	if lfr.ID == 0 {
		// no files, start anew!
		return dp.initLogFile()
	}

	// 0 for the mode is fine since that is only used if O_CREAT is passed
	fi, err := os.OpenFile(filepath.Join(dp.primaryDir, lfr.Path), os.O_RDWR, 0)
	if err != nil {
		return err
	}

	seq, err := scanForLastSeq(fi, -1)
	if err != nil {
		return fmt.Errorf("failed to scan log file for last seqno: %w", err)
	}

	dp.log.Info("loaded seq", "seq", seq, "now", time.Now().UnixMicro(), "time-seq", dp.timeSequence)

	dp.curSeq = seq + 1
	dp.logfi = fi

	return nil
}

// initLogFile creates the first, empty "evts-0" log file and its DB row.
func (dp *DiskPersistence) initLogFile() error {
	if err := os.MkdirAll(dp.primaryDir, 0775); err != nil {
		return err
	}

	p := filepath.Join(dp.primaryDir, "evts-0")
	fi, err := os.Create(p)
	if err != nil {
		return err
	}

	if err := dp.meta.Create(&LogFileRef{
		Path:     "evts-0",
		SeqStart: 0,
	}).Error; err != nil {
		return err
	}

	dp.logfi = fi
	dp.curSeq = 1
	return nil
}

// swapLog swaps the current log file out for a new empty one
// must only be called while holding dp.lk
func (dp *DiskPersistence) swapLog(ctx context.Context) error {
	if err := dp.logfi.Close(); err != nil {
		return fmt.Errorf("failed to close current log file: %w", err)
	}

	fname := fmt.Sprintf("evts-%d", dp.curSeq)
	nextp := filepath.Join(dp.primaryDir, fname)

	fi, err := os.Create(nextp)
	if err != nil {
		return err
	}

	if err := dp.meta.Create(&LogFileRef{
		Path:     fname,
		SeqStart: dp.curSeq,
	}).Error; err != nil {
		return err
	}

	dp.logfi = fi
	return nil
}

// scanForLastSeq walks event headers in fi and returns the last sequence
// number, or stops (and rewinds to the matching offset) once a header with
// seq > end is reached when end > 0. Returns -1 for an empty file.
func scanForLastSeq(fi *os.File, end int64) (int64, error) {
	scratch := make([]byte, headerSize)

	var lastSeq int64 = -1
	var offset int64
	for {
		eh, err := readHeader(fi, scratch)
		if err != nil {
			if errors.Is(err, io.EOF) {
				return lastSeq, nil
			}
			return 0, err
		}

		if end > 0 && eh.Seq > end {
			// return to beginning of offset
			n, err := fi.Seek(offset, io.SeekStart)
			if err != nil {
				return 0, err
			}

			if n != offset {
				return 0, fmt.Errorf("rewind seek failed")
			}

			return eh.Seq, nil
		}

		lastSeq = eh.Seq

		noff, err := fi.Seek(int64(eh.Len), io.SeekCurrent)
		if err != nil {
			return 0, err
		}

		if noff != offset+headerSize+int64(eh.Len) {
			// TODO: must recover from this
			return 0, fmt.Errorf("did not seek to next event properly")
		}

		offset = noff
	}
}

// on-disk event kind tags written into each event header
const (
	evtKindCommit    = 1
	evtKindHandle    = 2
	evtKindTombstone = 3
	evtKindIdentity  = 4
	evtKindAccount   = 5
	evtKindSync      = 6
)

var emptyHeader = make([]byte, headerSize)

// addJobToQueue persists one serialized event into the write buffer under
// the lock, flushing to disk when the buffer grows large.
func (dp *DiskPersistence) addJobToQueue(ctx context.Context, job persistJob) error {
	dp.lk.Lock()
	defer dp.lk.Unlock()

	if err := dp.doPersist(ctx, job); err != nil {
		return err
	}

	// TODO: for some reason replacing this constant with p.writeBufferSize dramatically reduces perf...
	if len(dp.evtbuf) > 400 {
		if err := dp.flushLog(ctx); err != nil {
			return fmt.Errorf("failed to flush disk log: %w", err)
		}
	}

	return nil
}

// flushRoutine flushes buffered events to disk every 100ms until shutdown.
func (dp *DiskPersistence) flushRoutine() {
	t := time.NewTicker(time.Millisecond * 100)

	for {
		ctx := context.Background()
		select {
		case <-dp.shutdown:
			return
		case <-t.C:
			dp.lk.Lock()
			if err := dp.flushLog(ctx); err != nil {
				// TODO: this happening is quite bad. Need a recovery strategy
				dp.log.Error("failed to flush disk log", "err", err)
			}
			dp.lk.Unlock()
		}
	}
}

// flushLog writes outbuf to the log file, then broadcasts the buffered
// events and returns their buffers to the pool. Caller must hold dp.lk.
func (dp *DiskPersistence) flushLog(ctx context.Context) error {
	if len(dp.evtbuf) == 0 {
		return nil
	}

	_, err := io.Copy(dp.logfi, dp.outbuf)
	if err != nil {
		return err
	}

	dp.outbuf.Truncate(0)

	// broadcast only after the events are durably written
	for _, ej := range dp.evtbuf {
		dp.broadcast(ej.Evt)
		ej.Buffer.Truncate(0)
		dp.buffers.Put(ej.Buffer)
	}

	dp.evtbuf = dp.evtbuf[:0]

	return nil
}

// garbageCollectRoutine runs retention-based cleanup hourly until shutdown.
func (dp *DiskPersistence) garbageCollectRoutine() {
	t := time.NewTicker(time.Hour)

	for {
		ctx := context.Background()
		select {
		// Closing a channel can be listened to with multiple routines: https://goplay.tools/snippet/UcwbC0CeJAL
		case <-dp.shutdown:
			return
		case <-t.C:
			if errs := dp.garbageCollect(ctx); len(errs) > 0 {
				for _, err := range errs {
					dp.log.Error("garbage collection error", "err", err)
				}
			}
		}
	}
}

var garbageCollectionsExecuted = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "disk_persister_garbage_collections_executed",
	Help: "Number of garbage collections executed",
}, []string{})

var garbageCollectionErrors = promauto.NewCounterVec(prometheus.CounterOpts{
	Name:
"disk_persister_garbage_collections_errors", 398 + Help: "Number of errors encountered during garbage collection", 399 + }, []string{}) 400 + 401 + var refsGarbageCollected = promauto.NewCounterVec(prometheus.CounterOpts{ 402 + Name: "disk_persister_garbage_collections_refs_collected", 403 + Help: "Number of refs collected during garbage collection", 404 + }, []string{}) 405 + 406 + var filesGarbageCollected = promauto.NewCounterVec(prometheus.CounterOpts{ 407 + Name: "disk_persister_garbage_collections_files_collected", 408 + Help: "Number of files collected during garbage collection", 409 + }, []string{}) 410 + 411 + func (dp *DiskPersistence) garbageCollect(ctx context.Context) []error { 412 + garbageCollectionsExecuted.WithLabelValues().Inc() 413 + 414 + // Grab refs created before the retention period 415 + var refs []LogFileRef 416 + var errs []error 417 + 418 + defer func() { 419 + garbageCollectionErrors.WithLabelValues().Add(float64(len(errs))) 420 + }() 421 + 422 + if err := dp.meta.WithContext(ctx).Find(&refs, "created_at < ?", time.Now().Add(-dp.retention)).Error; err != nil { 423 + return []error{err} 424 + } 425 + 426 + oldRefsFound := len(refs) 427 + refsDeleted := 0 428 + filesDeleted := 0 429 + 430 + // In the future if we want to support Archiving, we could do that here instead of deleting 431 + for _, r := range refs { 432 + dp.lk.Lock() 433 + currentLogfile := dp.logfi.Name() 434 + dp.lk.Unlock() 435 + 436 + if filepath.Join(dp.primaryDir, r.Path) == currentLogfile { 437 + // Don't delete the current log file 438 + dp.log.Info("skipping deletion of current log file") 439 + continue 440 + } 441 + 442 + // Delete the ref in the database to prevent playback from finding it 443 + if err := dp.meta.WithContext(ctx).Delete(&r).Error; err != nil { 444 + errs = append(errs, err) 445 + continue 446 + } 447 + refsDeleted++ 448 + 449 + // Delete the file from disk 450 + if err := os.Remove(filepath.Join(dp.primaryDir, r.Path)); err != nil { 451 + errs = 
append(errs, err) 452 + continue 453 + } 454 + filesDeleted++ 455 + } 456 + 457 + refsGarbageCollected.WithLabelValues().Add(float64(refsDeleted)) 458 + filesGarbageCollected.WithLabelValues().Add(float64(filesDeleted)) 459 + 460 + dp.log.Info("garbage collection complete", 461 + "filesDeleted", filesDeleted, 462 + "refsDeleted", refsDeleted, 463 + "oldRefsFound", oldRefsFound, 464 + ) 465 + 466 + return errs 467 + } 468 + 469 + func (dp *DiskPersistence) doPersist(ctx context.Context, pjob persistJob) error { 470 + seq := dp.curSeq 471 + if dp.timeSequence { 472 + seq = time.Now().UnixMicro() 473 + if seq < dp.curSeq { 474 + seq = dp.curSeq 475 + } 476 + dp.curSeq = seq + 1 477 + } else { 478 + dp.curSeq++ 479 + } 480 + 481 + // Set sequence number in event header 482 + // the rest of the header is set in DiskPersistence.Persist() 483 + binary.LittleEndian.PutUint64(pjob.Bytes[20:], uint64(seq)) 484 + 485 + // update the seq in the message 486 + // copy the message from outside to a new object, clobber the seq, add it back to the event 487 + switch { 488 + case pjob.Evt.RepoCommit != nil: 489 + pjob.Evt.RepoCommit.Seq = seq 490 + case pjob.Evt.RepoSync != nil: 491 + pjob.Evt.RepoSync.Seq = seq 492 + case pjob.Evt.RepoHandle != nil: 493 + pjob.Evt.RepoHandle.Seq = seq 494 + case pjob.Evt.RepoIdentity != nil: 495 + pjob.Evt.RepoIdentity.Seq = seq 496 + case pjob.Evt.RepoAccount != nil: 497 + pjob.Evt.RepoAccount.Seq = seq 498 + case pjob.Evt.RepoTombstone != nil: 499 + pjob.Evt.RepoTombstone.Seq = seq 500 + default: 501 + // only those three get peristed right now 502 + // we should not actually ever get here... 
503 + return nil 504 + } 505 + 506 + _, err := dp.outbuf.Write(pjob.Bytes) 507 + if err != nil { 508 + return err 509 + } 510 + 511 + dp.evtbuf = append(dp.evtbuf, pjob) 512 + 513 + dp.eventCounter++ 514 + if dp.eventCounter%dp.eventsPerFile == 0 { 515 + if err := dp.flushLog(ctx); err != nil { 516 + return err 517 + } 518 + 519 + // time to roll the log file 520 + if err := dp.swapLog(ctx); err != nil { 521 + return err 522 + } 523 + } 524 + 525 + return nil 526 + } 527 + 528 + // Persist implements events.EventPersistence 529 + // Persist may mutate contents of xevt and what it points to 530 + func (dp *DiskPersistence) Persist(ctx context.Context, xevt *events.XRPCStreamEvent) error { 531 + buffer := dp.buffers.Get().(*bytes.Buffer) 532 + cw := dp.writers.Get().(*cbg.CborWriter) 533 + defer dp.writers.Put(cw) 534 + cw.SetWriter(buffer) 535 + 536 + buffer.Truncate(0) 537 + 538 + buffer.Write(emptyHeader) 539 + 540 + var did string 541 + var evtKind uint32 542 + switch { 543 + case xevt.RepoCommit != nil: 544 + evtKind = evtKindCommit 545 + did = xevt.RepoCommit.Repo 546 + if err := xevt.RepoCommit.MarshalCBOR(cw); err != nil { 547 + return fmt.Errorf("failed to marshal: %w", err) 548 + } 549 + case xevt.RepoSync != nil: 550 + evtKind = evtKindSync 551 + did = xevt.RepoSync.Did 552 + if err := xevt.RepoSync.MarshalCBOR(cw); err != nil { 553 + return fmt.Errorf("failed to marshal: %w", err) 554 + } 555 + case xevt.RepoHandle != nil: 556 + evtKind = evtKindHandle 557 + did = xevt.RepoHandle.Did 558 + if err := xevt.RepoHandle.MarshalCBOR(cw); err != nil { 559 + return fmt.Errorf("failed to marshal: %w", err) 560 + } 561 + case xevt.RepoIdentity != nil: 562 + evtKind = evtKindIdentity 563 + did = xevt.RepoIdentity.Did 564 + if err := xevt.RepoIdentity.MarshalCBOR(cw); err != nil { 565 + return fmt.Errorf("failed to marshal: %w", err) 566 + } 567 + case xevt.RepoAccount != nil: 568 + evtKind = evtKindAccount 569 + did = xevt.RepoAccount.Did 570 + if err := 
xevt.RepoAccount.MarshalCBOR(cw); err != nil { 571 + return fmt.Errorf("failed to marshal: %w", err) 572 + } 573 + case xevt.RepoTombstone != nil: 574 + evtKind = evtKindTombstone 575 + did = xevt.RepoTombstone.Did 576 + if err := xevt.RepoTombstone.MarshalCBOR(cw); err != nil { 577 + return fmt.Errorf("failed to marshal: %w", err) 578 + } 579 + default: 580 + return nil 581 + // only those two get peristed right now 582 + } 583 + 584 + usr, err := dp.uidForDid(ctx, did) 585 + if err != nil { 586 + return err 587 + } 588 + 589 + b := buffer.Bytes() 590 + 591 + // Set flags in header (no flags for now) 592 + binary.LittleEndian.PutUint32(b, 0) 593 + // Set event kind in header 594 + binary.LittleEndian.PutUint32(b[4:], evtKind) 595 + // Set event length in header 596 + binary.LittleEndian.PutUint32(b[8:], uint32(len(b)-headerSize)) 597 + // Set user UID in header 598 + binary.LittleEndian.PutUint64(b[12:], uint64(usr)) 599 + // set seq at [20:] inside mutex section inside doPersist 600 + 601 + return dp.addJobToQueue(ctx, persistJob{ 602 + Bytes: b, 603 + Evt: xevt, 604 + Buffer: buffer, 605 + }) 606 + } 607 + 608 + type evtHeader struct { 609 + Flags uint32 610 + Kind uint32 611 + Seq int64 612 + Usr models.Uid 613 + Len uint32 614 + } 615 + 616 + func (eh *evtHeader) Len64() int64 { 617 + return int64(eh.Len) 618 + } 619 + 620 + const headerSize = 4 + 4 + 4 + 8 + 8 621 + 622 + func readHeader(r io.Reader, scratch []byte) (*evtHeader, error) { 623 + if len(scratch) < headerSize { 624 + return nil, fmt.Errorf("must pass scratch buffer of at least %d bytes", headerSize) 625 + } 626 + 627 + scratch = scratch[:headerSize] 628 + _, err := io.ReadFull(r, scratch) 629 + if err != nil { 630 + return nil, fmt.Errorf("reading header: %w", err) 631 + } 632 + 633 + flags := binary.LittleEndian.Uint32(scratch[:4]) 634 + kind := binary.LittleEndian.Uint32(scratch[4:8]) 635 + l := binary.LittleEndian.Uint32(scratch[8:12]) 636 + usr := binary.LittleEndian.Uint64(scratch[12:20]) 
637 + seq := binary.LittleEndian.Uint64(scratch[20:28]) 638 + 639 + return &evtHeader{ 640 + Flags: flags, 641 + Kind: kind, 642 + Len: l, 643 + Usr: models.Uid(usr), 644 + Seq: int64(seq), 645 + }, nil 646 + } 647 + 648 + func (dp *DiskPersistence) writeHeader(ctx context.Context, flags uint32, kind uint32, l uint32, usr uint64, seq int64) error { 649 + binary.LittleEndian.PutUint32(dp.scratch, flags) 650 + binary.LittleEndian.PutUint32(dp.scratch[4:], kind) 651 + binary.LittleEndian.PutUint32(dp.scratch[8:], l) 652 + binary.LittleEndian.PutUint64(dp.scratch[12:], usr) 653 + binary.LittleEndian.PutUint64(dp.scratch[20:], uint64(seq)) 654 + 655 + nw, err := dp.logfi.Write(dp.scratch) 656 + if err != nil { 657 + return err 658 + } 659 + 660 + if nw != headerSize { 661 + return fmt.Errorf("only wrote %d bytes for header", nw) 662 + } 663 + 664 + return nil 665 + } 666 + 667 + func (dp *DiskPersistence) uidForDid(ctx context.Context, did string) (models.Uid, error) { 668 + if uid, ok := dp.didCache.Get(did); ok { 669 + return uid, nil 670 + } 671 + 672 + uid, err := dp.uids.DidToUid(ctx, did) 673 + if err != nil { 674 + return 0, err 675 + } 676 + 677 + dp.didCache.Add(did, uid) 678 + 679 + return uid, nil 680 + } 681 + 682 + func (dp *DiskPersistence) Playback(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error) error { 683 + var logs []LogFileRef 684 + needslogs := true 685 + if since != 0 { 686 + // find the log file that starts before our since 687 + result := dp.meta.Debug().Order("seq_start desc").Where("seq_start < ?", since).Limit(1).Find(&logs) 688 + if result.Error != nil { 689 + return result.Error 690 + } 691 + if result.RowsAffected != 0 { 692 + needslogs = false 693 + } 694 + } 695 + 696 + // playback data from all the log files we found, then check the db to see if more were written during playback. 697 + // repeat a few times but not unboundedly. 
698 + // don't decrease '10' below 2 because we should always do two passes through this if the above before-chunk query was used. 699 + for i := 0; i < 10; i++ { 700 + if needslogs { 701 + if err := dp.meta.Debug().Order("seq_start asc").Find(&logs, "seq_start >= ?", since).Error; err != nil { 702 + return err 703 + } 704 + } 705 + 706 + lastSeq, err := dp.PlaybackLogfiles(ctx, since, cb, logs) 707 + if err != nil { 708 + return err 709 + } 710 + 711 + // No lastSeq implies that we read until the end of known events 712 + if lastSeq == nil { 713 + break 714 + } 715 + 716 + since = *lastSeq 717 + needslogs = true 718 + } 719 + 720 + return nil 721 + } 722 + 723 + func (dp *DiskPersistence) PlaybackLogfiles(ctx context.Context, since int64, cb func(*events.XRPCStreamEvent) error, logFiles []LogFileRef) (*int64, error) { 724 + for i, lf := range logFiles { 725 + lastSeq, err := dp.readEventsFrom(ctx, since, filepath.Join(dp.primaryDir, lf.Path), cb) 726 + if err != nil { 727 + return nil, err 728 + } 729 + since = 0 730 + if i == len(logFiles)-1 && 731 + lastSeq != nil && 732 + (*lastSeq-lf.SeqStart) == dp.eventsPerFile-1 { 733 + // There may be more log files to read since the last one was full 734 + return lastSeq, nil 735 + } 736 + } 737 + 738 + return nil, nil 739 + } 740 + 741 + func postDoNotEmit(flags uint32) bool { 742 + if flags&(EvtFlagRebased|EvtFlagTakedown) != 0 { 743 + return true 744 + } 745 + 746 + return false 747 + } 748 + 749 + func (dp *DiskPersistence) readEventsFrom(ctx context.Context, since int64, fn string, cb func(*events.XRPCStreamEvent) error) (*int64, error) { 750 + fi, err := os.OpenFile(fn, os.O_RDONLY, 0) 751 + if err != nil { 752 + return nil, err 753 + } 754 + 755 + if since != 0 { 756 + lastSeq, err := scanForLastSeq(fi, since) 757 + if err != nil { 758 + return nil, err 759 + } 760 + if since > lastSeq { 761 + dp.log.Error("playback cursor is greater than last seq of file checked", 762 + "since", since, 763 + "lastSeq", lastSeq, 
764 + "filename", fn, 765 + ) 766 + return nil, nil 767 + } 768 + } 769 + 770 + bufr := bufio.NewReader(fi) 771 + 772 + lastSeq := int64(0) 773 + 774 + scratch := make([]byte, headerSize) 775 + for { 776 + h, err := readHeader(bufr, scratch) 777 + if err != nil { 778 + if errors.Is(err, io.EOF) { 779 + return &lastSeq, nil 780 + } 781 + 782 + return nil, err 783 + } 784 + 785 + lastSeq = h.Seq 786 + 787 + if postDoNotEmit(h.Flags) { 788 + // event taken down, skip 789 + _, err := io.CopyN(io.Discard, bufr, h.Len64()) // would be really nice if the buffered reader had a 'skip' method that does a seek under the hood 790 + if err != nil { 791 + return nil, fmt.Errorf("failed while skipping event (seq: %d, fn: %q): %w", h.Seq, fn, err) 792 + } 793 + continue 794 + } 795 + 796 + switch h.Kind { 797 + case evtKindCommit: 798 + var evt atproto.SyncSubscribeRepos_Commit 799 + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { 800 + return nil, err 801 + } 802 + evt.Seq = h.Seq 803 + if err := cb(&events.XRPCStreamEvent{RepoCommit: &evt}); err != nil { 804 + return nil, err 805 + } 806 + case evtKindSync: 807 + var evt atproto.SyncSubscribeRepos_Sync 808 + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { 809 + return nil, err 810 + } 811 + evt.Seq = h.Seq 812 + if err := cb(&events.XRPCStreamEvent{RepoSync: &evt}); err != nil { 813 + return nil, err 814 + } 815 + case evtKindHandle: 816 + var evt atproto.SyncSubscribeRepos_Handle 817 + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { 818 + return nil, err 819 + } 820 + evt.Seq = h.Seq 821 + if err := cb(&events.XRPCStreamEvent{RepoHandle: &evt}); err != nil { 822 + return nil, err 823 + } 824 + case evtKindIdentity: 825 + var evt atproto.SyncSubscribeRepos_Identity 826 + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { 827 + return nil, err 828 + } 829 + evt.Seq = h.Seq 830 + if err := 
cb(&events.XRPCStreamEvent{RepoIdentity: &evt}); err != nil { 831 + return nil, err 832 + } 833 + case evtKindAccount: 834 + var evt atproto.SyncSubscribeRepos_Account 835 + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { 836 + return nil, err 837 + } 838 + evt.Seq = h.Seq 839 + if err := cb(&events.XRPCStreamEvent{RepoAccount: &evt}); err != nil { 840 + return nil, err 841 + } 842 + case evtKindTombstone: 843 + var evt atproto.SyncSubscribeRepos_Tombstone 844 + if err := evt.UnmarshalCBOR(io.LimitReader(bufr, h.Len64())); err != nil { 845 + return nil, err 846 + } 847 + evt.Seq = h.Seq 848 + if err := cb(&events.XRPCStreamEvent{RepoTombstone: &evt}); err != nil { 849 + return nil, err 850 + } 851 + default: 852 + dp.log.Warn("unrecognized event kind coming from log file", "seq", h.Seq, "kind", h.Kind) 853 + return nil, fmt.Errorf("halting on unrecognized event kind") 854 + } 855 + } 856 + } 857 + 858 + type UserAction struct { 859 + gorm.Model 860 + 861 + Usr models.Uid 862 + RebaseAt int64 863 + Takedown bool 864 + } 865 + 866 + func (dp *DiskPersistence) TakeDownRepo(ctx context.Context, usr models.Uid) error { 867 + /* 868 + if err := p.meta.Create(&UserAction{ 869 + Usr: usr, 870 + Takedown: true, 871 + }).Error; err != nil { 872 + return err 873 + } 874 + */ 875 + 876 + return dp.forEachShardWithUserEvents(ctx, usr, func(ctx context.Context, fn string) error { 877 + if err := dp.deleteEventsForUser(ctx, usr, fn); err != nil { 878 + return err 879 + } 880 + 881 + return nil 882 + }) 883 + } 884 + 885 + func (dp *DiskPersistence) forEachShardWithUserEvents(ctx context.Context, usr models.Uid, cb func(context.Context, string) error) error { 886 + var refs []LogFileRef 887 + if err := dp.meta.Order("created_at desc").Find(&refs).Error; err != nil { 888 + return err 889 + } 890 + 891 + for _, r := range refs { 892 + mhas, err := dp.refMaybeHasUserEvents(ctx, usr, r) 893 + if err != nil { 894 + return err 895 + } 896 + 897 + if mhas { 898 
+ var path string 899 + if r.Archived { 900 + path = filepath.Join(dp.archiveDir, r.Path) 901 + } else { 902 + path = filepath.Join(dp.primaryDir, r.Path) 903 + } 904 + 905 + if err := cb(ctx, path); err != nil { 906 + return err 907 + } 908 + } 909 + } 910 + 911 + return nil 912 + } 913 + 914 + func (dp *DiskPersistence) refMaybeHasUserEvents(ctx context.Context, usr models.Uid, ref LogFileRef) (bool, error) { 915 + // TODO: lazily computed bloom filters for users in each logfile 916 + return true, nil 917 + } 918 + 919 + type zeroReader struct{} 920 + 921 + func (zr *zeroReader) Read(p []byte) (n int, err error) { 922 + for i := range p { 923 + p[i] = 0 924 + } 925 + return len(p), nil 926 + } 927 + 928 + func (dp *DiskPersistence) deleteEventsForUser(ctx context.Context, usr models.Uid, fn string) error { 929 + return dp.mutateUserEventsInLog(ctx, usr, fn, EvtFlagTakedown, true) 930 + } 931 + 932 + func (dp *DiskPersistence) mutateUserEventsInLog(ctx context.Context, usr models.Uid, fn string, flag uint32, zeroEvts bool) error { 933 + fi, err := os.OpenFile(fn, os.O_RDWR, 0) 934 + if err != nil { 935 + return fmt.Errorf("failed to open log file: %w", err) 936 + } 937 + defer fi.Close() 938 + defer fi.Sync() 939 + 940 + scratch := make([]byte, headerSize) 941 + var offset int64 942 + for { 943 + h, err := readHeader(fi, scratch) 944 + if err != nil { 945 + if errors.Is(err, io.EOF) { 946 + return nil 947 + } 948 + 949 + return err 950 + } 951 + 952 + if h.Usr == usr && h.Flags&flag == 0 { 953 + nflag := h.Flags | flag 954 + 955 + binary.LittleEndian.PutUint32(scratch, nflag) 956 + 957 + if _, err := fi.WriteAt(scratch[:4], offset); err != nil { 958 + return fmt.Errorf("failed to write updated flag value: %w", err) 959 + } 960 + 961 + if zeroEvts { 962 + // sync that write before blanking the event data 963 + if err := fi.Sync(); err != nil { 964 + return err 965 + } 966 + 967 + if _, err := fi.Seek(offset+headerSize, io.SeekStart); err != nil { 968 + return 
fmt.Errorf("failed to seek: %w", err) 969 + } 970 + 971 + _, err := io.CopyN(fi, &zeroReader{}, h.Len64()) 972 + if err != nil { 973 + return err 974 + } 975 + } 976 + } 977 + 978 + offset += headerSize + h.Len64() 979 + _, err = fi.Seek(offset, io.SeekStart) 980 + if err != nil { 981 + return fmt.Errorf("failed to seek: %w", err) 982 + } 983 + } 984 + } 985 + 986 + func (dp *DiskPersistence) Flush(ctx context.Context) error { 987 + dp.lk.Lock() 988 + defer dp.lk.Unlock() 989 + if len(dp.evtbuf) > 0 { 990 + return dp.flushLog(ctx) 991 + } 992 + return nil 993 + } 994 + 995 + func (dp *DiskPersistence) Shutdown(ctx context.Context) error { 996 + close(dp.shutdown) 997 + if err := dp.Flush(ctx); err != nil { 998 + return err 999 + } 1000 + 1001 + dp.logfi.Close() 1002 + return nil 1003 + } 1004 + 1005 + func (dp *DiskPersistence) SetEventBroadcaster(f func(*events.XRPCStreamEvent)) { 1006 + dp.broadcast = f 1007 + }
+549
cmd/relay/events/events.go
··· 1 + package events 2 + 3 + import ( 4 + "bytes" 5 + "context" 6 + "errors" 7 + "fmt" 8 + "io" 9 + "log/slog" 10 + "sync" 11 + "time" 12 + 13 + comatproto "github.com/bluesky-social/indigo/api/atproto" 14 + "github.com/bluesky-social/indigo/cmd/relay/models" 15 + lexutil "github.com/bluesky-social/indigo/lex/util" 16 + "github.com/prometheus/client_golang/prometheus" 17 + 18 + cbg "github.com/whyrusleeping/cbor-gen" 19 + "go.opentelemetry.io/otel" 20 + ) 21 + 22 + var log = slog.Default().With("system", "events") 23 + 24 + type Scheduler interface { 25 + AddWork(ctx context.Context, repo string, val *XRPCStreamEvent) error 26 + Shutdown() 27 + } 28 + 29 + type EventManager struct { 30 + subs []*Subscriber 31 + subsLk sync.Mutex 32 + 33 + bufferSize int 34 + crossoverBufferSize int 35 + 36 + persister EventPersistence 37 + 38 + log *slog.Logger 39 + } 40 + 41 + func NewEventManager(persister EventPersistence) *EventManager { 42 + em := &EventManager{ 43 + bufferSize: 16 << 10, 44 + crossoverBufferSize: 512, 45 + persister: persister, 46 + log: slog.Default().With("system", "events"), 47 + } 48 + 49 + persister.SetEventBroadcaster(em.broadcastEvent) 50 + 51 + return em 52 + } 53 + 54 + const ( 55 + opSubscribe = iota 56 + opUnsubscribe 57 + opSend 58 + ) 59 + 60 + type Operation struct { 61 + op int 62 + sub *Subscriber 63 + evt *XRPCStreamEvent 64 + } 65 + 66 + func (em *EventManager) Shutdown(ctx context.Context) error { 67 + return em.persister.Shutdown(ctx) 68 + } 69 + 70 + // broadcastEvent is the target for EventPersistence.SetEventBroadcaster() 71 + func (em *EventManager) broadcastEvent(evt *XRPCStreamEvent) { 72 + // the main thing we do is send it out, so MarshalCBOR once 73 + if err := evt.Preserialize(); err != nil { 74 + em.log.Error("broadcast serialize failed", "err", err) 75 + // serialize isn't going to go better later, this event is cursed 76 + return 77 + } 78 + 79 + em.subsLk.Lock() 80 + defer em.subsLk.Unlock() 81 + 82 + // TODO: for a larger 
fanout we should probably have dedicated goroutines 83 + // for subsets of the subscriber set, and tiered channels to distribute 84 + // events out to them, or some similar architecture 85 + // Alternatively, we might just want to not allow too many subscribers 86 + // directly to the bgs, and have rebroadcasting proxies instead 87 + for _, s := range em.subs { 88 + if s.filter(evt) { 89 + s.enqueuedCounter.Inc() 90 + select { 91 + case s.outgoing <- evt: 92 + // sent evt on this subscriber's chan! yay! 93 + case <-s.done: 94 + // this subscriber is closing, quickly do nothing 95 + default: 96 + // filter out all future messages that would be 97 + // sent to this subscriber, but wait for it to 98 + // actually be removed by the correct bit of 99 + // code 100 + s.filter = func(*XRPCStreamEvent) bool { return false } 101 + 102 + em.log.Warn("dropping slow consumer due to event overflow", "bufferSize", len(s.outgoing), "ident", s.ident) 103 + go func(torem *Subscriber) { 104 + torem.lk.Lock() 105 + if !torem.cleanedUp { 106 + select { 107 + case torem.outgoing <- &XRPCStreamEvent{ 108 + Error: &ErrorFrame{ 109 + Error: "ConsumerTooSlow", 110 + }, 111 + }: 112 + case <-time.After(time.Second * 5): 113 + em.log.Warn("failed to send error frame to backed up consumer", "ident", torem.ident) 114 + } 115 + } 116 + torem.lk.Unlock() 117 + torem.cleanup() 118 + }(s) 119 + } 120 + s.broadcastCounter.Inc() 121 + } 122 + } 123 + } 124 + 125 + func (em *EventManager) persistAndSendEvent(ctx context.Context, evt *XRPCStreamEvent) { 126 + // TODO: can cut 5-10% off of disk persister benchmarks by making this function 127 + // accept a uid. The lookup inside the persister is notably expensive (despite 128 + // being an lru cache?) 
129 + if err := em.persister.Persist(ctx, evt); err != nil { 130 + em.log.Error("failed to persist outbound event", "err", err) 131 + } 132 + } 133 + 134 + type Subscriber struct { 135 + outgoing chan *XRPCStreamEvent 136 + 137 + filter func(*XRPCStreamEvent) bool 138 + 139 + done chan struct{} 140 + 141 + cleanup func() 142 + 143 + lk sync.Mutex 144 + cleanedUp bool 145 + 146 + ident string 147 + enqueuedCounter prometheus.Counter 148 + broadcastCounter prometheus.Counter 149 + } 150 + 151 + const ( 152 + EvtKindErrorFrame = -1 153 + EvtKindMessage = 1 154 + ) 155 + 156 + type EventHeader struct { 157 + Op int64 `cborgen:"op"` 158 + MsgType string `cborgen:"t,omitempty"` 159 + } 160 + 161 + var ( 162 + // AccountStatusActive is not in the spec but used internally 163 + // the alternative would be an additional SQL column for "active" or status="" to imply active 164 + AccountStatusActive = "active" 165 + 166 + AccountStatusDeactivated = "deactivated" 167 + AccountStatusDeleted = "deleted" 168 + AccountStatusDesynchronized = "desynchronized" 169 + AccountStatusSuspended = "suspended" 170 + AccountStatusTakendown = "takendown" 171 + AccountStatusThrottled = "throttled" 172 + ) 173 + 174 + var AccountStatusList = []string{ 175 + AccountStatusActive, 176 + AccountStatusDeactivated, 177 + AccountStatusDeleted, 178 + AccountStatusDesynchronized, 179 + AccountStatusSuspended, 180 + AccountStatusTakendown, 181 + AccountStatusThrottled, 182 + } 183 + var AccountStatuses map[string]bool 184 + 185 + func init() { 186 + AccountStatuses = make(map[string]bool, len(AccountStatusList)) 187 + for _, status := range AccountStatusList { 188 + AccountStatuses[status] = true 189 + } 190 + } 191 + 192 + type XRPCStreamEvent struct { 193 + Error *ErrorFrame 194 + RepoCommit *comatproto.SyncSubscribeRepos_Commit 195 + RepoSync *comatproto.SyncSubscribeRepos_Sync 196 + RepoHandle *comatproto.SyncSubscribeRepos_Handle // DEPRECATED 197 + RepoIdentity 
*comatproto.SyncSubscribeRepos_Identity 198 + RepoInfo *comatproto.SyncSubscribeRepos_Info 199 + RepoMigrate *comatproto.SyncSubscribeRepos_Migrate // DEPRECATED 200 + RepoTombstone *comatproto.SyncSubscribeRepos_Tombstone // DEPRECATED 201 + RepoAccount *comatproto.SyncSubscribeRepos_Account 202 + LabelLabels *comatproto.LabelSubscribeLabels_Labels 203 + LabelInfo *comatproto.LabelSubscribeLabels_Info 204 + 205 + // some private fields for internal routing perf 206 + PrivUid models.Uid `json:"-" cborgen:"-"` 207 + PrivPdsId uint `json:"-" cborgen:"-"` 208 + PrivRelevantPds []uint `json:"-" cborgen:"-"` 209 + Preserialized []byte `json:"-" cborgen:"-"` 210 + } 211 + 212 + func (evt *XRPCStreamEvent) Serialize(wc io.Writer) error { 213 + header := EventHeader{Op: EvtKindMessage} 214 + var obj lexutil.CBOR 215 + 216 + switch { 217 + case evt.Error != nil: 218 + header.Op = EvtKindErrorFrame 219 + obj = evt.Error 220 + case evt.RepoCommit != nil: 221 + header.MsgType = "#commit" 222 + obj = evt.RepoCommit 223 + case evt.RepoSync != nil: 224 + header.MsgType = "#sync" 225 + obj = evt.RepoSync 226 + case evt.RepoHandle != nil: 227 + header.MsgType = "#handle" 228 + obj = evt.RepoHandle 229 + case evt.RepoIdentity != nil: 230 + header.MsgType = "#identity" 231 + obj = evt.RepoIdentity 232 + case evt.RepoAccount != nil: 233 + header.MsgType = "#account" 234 + obj = evt.RepoAccount 235 + case evt.RepoInfo != nil: 236 + header.MsgType = "#info" 237 + obj = evt.RepoInfo 238 + case evt.RepoMigrate != nil: 239 + header.MsgType = "#migrate" 240 + obj = evt.RepoMigrate 241 + case evt.RepoTombstone != nil: 242 + header.MsgType = "#tombstone" 243 + obj = evt.RepoTombstone 244 + default: 245 + return fmt.Errorf("unrecognized event kind") 246 + } 247 + 248 + cborWriter := cbg.NewCborWriter(wc) 249 + if err := header.MarshalCBOR(cborWriter); err != nil { 250 + return fmt.Errorf("failed to write header: %w", err) 251 + } 252 + return obj.MarshalCBOR(cborWriter) 253 + } 254 + 255 + 
// Deserialize reads one framed stream event (CBOR header then CBOR body) from
// r and populates the matching union field of xevt.
// NOTE(review): an unknown MsgType under EvtKindMessage is silently ignored
// (falls through the inner switch leaving xevt empty) — confirm this is intended.
func (xevt *XRPCStreamEvent) Deserialize(r io.Reader) error {
	var header EventHeader
	if err := header.UnmarshalCBOR(r); err != nil {
		return fmt.Errorf("reading header: %w", err)
	}
	switch header.Op {
	case EvtKindMessage:
		switch header.MsgType {
		case "#commit":
			var evt comatproto.SyncSubscribeRepos_Commit
			if err := evt.UnmarshalCBOR(r); err != nil {
				return fmt.Errorf("reading repoCommit event: %w", err)
			}
			xevt.RepoCommit = &evt
		case "#sync":
			var evt comatproto.SyncSubscribeRepos_Sync
			if err := evt.UnmarshalCBOR(r); err != nil {
				return fmt.Errorf("reading repoSync event: %w", err)
			}
			xevt.RepoSync = &evt
		case "#handle":
			// TODO: DEPRECATED message; warning/counter; drop message
			var evt comatproto.SyncSubscribeRepos_Handle
			if err := evt.UnmarshalCBOR(r); err != nil {
				return err
			}
			xevt.RepoHandle = &evt
		case "#identity":
			var evt comatproto.SyncSubscribeRepos_Identity
			if err := evt.UnmarshalCBOR(r); err != nil {
				return err
			}
			xevt.RepoIdentity = &evt
		case "#account":
			var evt comatproto.SyncSubscribeRepos_Account
			if err := evt.UnmarshalCBOR(r); err != nil {
				return err
			}
			xevt.RepoAccount = &evt
		case "#info":
			// TODO: this might also be a LabelInfo (as opposed to RepoInfo)
			var evt comatproto.SyncSubscribeRepos_Info
			if err := evt.UnmarshalCBOR(r); err != nil {
				return err
			}
			xevt.RepoInfo = &evt
		case "#migrate":
			// TODO: DEPRECATED message; warning/counter; drop message
			var evt comatproto.SyncSubscribeRepos_Migrate
			if err := evt.UnmarshalCBOR(r); err != nil {
				return err
			}
			xevt.RepoMigrate = &evt
		case "#tombstone":
			// TODO: DEPRECATED message; warning/counter; drop message
			var evt comatproto.SyncSubscribeRepos_Tombstone
			if err := evt.UnmarshalCBOR(r); err != nil {
				return err
			}
			xevt.RepoTombstone = &evt
		case "#labels":
			var evt comatproto.LabelSubscribeLabels_Labels
			if err := evt.UnmarshalCBOR(r); err != nil {
				return fmt.Errorf("reading Labels event: %w", err)
			}
			xevt.LabelLabels = &evt
		}
	case EvtKindErrorFrame:
		var errframe ErrorFrame
		if err := errframe.UnmarshalCBOR(r); err != nil {
			return err
		}
		xevt.Error = &errframe
	default:
		return fmt.Errorf("unrecognized event stream type: %d", header.Op)
	}
	return nil
}

// ErrNoSeq signals that an event carries no sequence number.
var ErrNoSeq = errors.New("event has no sequence number")

// serialize content into Preserialized cache
func (evt *XRPCStreamEvent) Preserialize() error {
	// Idempotent: a previously cached serialization is reused as-is.
	if evt.Preserialized != nil {
		return nil
	}
	var buf bytes.Buffer
	err := evt.Serialize(&buf)
	if err != nil {
		return err
	}
	evt.Preserialized = buf.Bytes()
	return nil
}

// ErrorFrame is the body of an EvtKindErrorFrame stream frame.
type ErrorFrame struct {
	Error   string `cborgen:"error"`
	Message string `cborgen:"message"`
}

// AddEvent persists ev and broadcasts it to current subscribers.
func (em *EventManager) AddEvent(ctx context.Context, ev *XRPCStreamEvent) error {
	ctx, span := otel.Tracer("events").Start(ctx, "AddEvent")
	defer span.End()

	em.persistAndSendEvent(ctx, ev)
	return nil
}

var (
	ErrPlaybackShutdown = fmt.Errorf("playback shutting down")
	ErrCaughtUp         = fmt.Errorf("caught up")
)

// Subscribe registers a new consumer of the event stream. With since == nil the
// consumer receives only live events; otherwise historical events after *since
// are played back from the persister and the stream is spliced onto the live
// feed. Returns the event channel and a cleanup function the caller must
// invoke when done.
func (em *EventManager) Subscribe(ctx context.Context, ident string, filter func(*XRPCStreamEvent) bool, since *int64) (<-chan *XRPCStreamEvent, func(), error) {
	// TODO: the only known filters are 'true' and 'false', replace the function pointer with a bool
	if filter == nil {
		filter = func(*XRPCStreamEvent) bool { return true }
	}

	done := make(chan struct{})
	sub := &Subscriber{
		ident:            ident,
		outgoing:         make(chan *XRPCStreamEvent, em.bufferSize),
		filter:           filter,
		done:             done,
		enqueuedCounter:  eventsEnqueued.WithLabelValues(ident),
		broadcastCounter: eventsBroadcast.WithLabelValues(ident),
	}

	// sync.OnceFunc makes cleanup safe to call multiple times.
	sub.cleanup = sync.OnceFunc(func() {
		sub.lk.Lock()
		defer sub.lk.Unlock()
		close(done)
		em.rmSubscriber(sub)
		close(sub.outgoing)
		sub.cleanedUp = true
	})

	if since == nil {
		// Live-only subscription: no playback needed.
		em.addSubscriber(sub)
		return sub.outgoing, sub.cleanup, nil
	}

	// Crossover channel bridges playback output into the live stream.
	out := make(chan *XRPCStreamEvent, em.crossoverBufferSize)

	go func() {
		lastSeq := *since
		// run playback to get through *most* of the events, getting our current cursor close to realtime
		if err := em.persister.Playback(ctx, *since, func(e *XRPCStreamEvent) error {
			select {
			case <-done:
				return ErrPlaybackShutdown
			case out <- e:
				seq := SequenceForEvent(e)
				if seq > 0 {
					lastSeq = seq
				}
				return nil
			}
		}); err != nil {
			if errors.Is(err, ErrPlaybackShutdown) {
				em.log.Warn("events playback", "err", err)
			} else {
				em.log.Error("events playback", "err", err)
			}

			// TODO: send an error frame or something?
			close(out)
			return
		}

		// now, start buffering events from the live stream
		em.addSubscriber(sub)

		// NOTE(review): blocks until the first live event arrives; if the
		// subscription is cleaned up first this receives nil from the closed
		// channel (SequenceForEvent(nil) == -1) — confirm that path is benign.
		first := <-sub.outgoing

		// run playback again to get us to the events that have started buffering
		if err := em.persister.Playback(ctx, lastSeq, func(e *XRPCStreamEvent) error {
			seq := SequenceForEvent(e)
			if seq > SequenceForEvent(first) {
				return ErrCaughtUp
			}

			select {
			case <-done:
				return ErrPlaybackShutdown
			case out <- e:
				return nil
			}
		}); err != nil {
			if !errors.Is(err, ErrCaughtUp) {
				em.log.Error("events playback", "err", err)

				// TODO: send an error frame or something?
				close(out)
				em.rmSubscriber(sub)
				return
			}
		}

		// now that we are caught up, just copy events from the channel over
		for evt := range sub.outgoing {
			select {
			case out <- evt:
			case <-done:
				em.rmSubscriber(sub)
				return
			}
		}
	}()

	return out, sub.cleanup, nil
}

// SequenceForEvent returns the event's sequence number, or -1 if it has none.
func SequenceForEvent(evt *XRPCStreamEvent) int64 {
	return evt.Sequence()
}

// Sequence returns the sequence number of whichever union field is set, or -1
// for nil events, info frames, and error frames (which carry no sequence).
func (evt *XRPCStreamEvent) Sequence() int64 {
	switch {
	case evt == nil:
		return -1
	case evt.RepoCommit != nil:
		return evt.RepoCommit.Seq
	case evt.RepoSync != nil:
		return evt.RepoSync.Seq
	case evt.RepoHandle != nil:
		return evt.RepoHandle.Seq
	case evt.RepoMigrate != nil:
		return evt.RepoMigrate.Seq
	case evt.RepoTombstone != nil:
		return evt.RepoTombstone.Seq
	case evt.RepoIdentity != nil:
		return evt.RepoIdentity.Seq
	case evt.RepoAccount != nil:
		return evt.RepoAccount.Seq
	case evt.RepoInfo != nil:
		return -1
	case evt.Error != nil:
		return -1
	default:
		return -1
	}
}

// GetSequence is like Sequence but also reports whether the event actually
// carries a sequence number.
func (evt *XRPCStreamEvent) GetSequence() (int64, bool) {
	switch {
	case evt == nil:
		return -1, false
	case evt.RepoCommit != nil:
		return evt.RepoCommit.Seq, true
	case evt.RepoSync != nil:
		return evt.RepoSync.Seq, true
	case evt.RepoHandle != nil:
		return evt.RepoHandle.Seq, true
	case evt.RepoMigrate != nil:
		return evt.RepoMigrate.Seq, true
	case evt.RepoTombstone != nil:
		return evt.RepoTombstone.Seq, true
	case evt.RepoIdentity != nil:
		return evt.RepoIdentity.Seq, true
	case evt.RepoAccount != nil:
		return evt.RepoAccount.Seq, true
	case evt.RepoInfo != nil:
		return -1, false
	case evt.Error != nil:
		return -1, false
	default:
		return -1, false
	}
}

// rmSubscriber removes sub from the subscriber list (swap-with-last delete).
func (em *EventManager) rmSubscriber(sub *Subscriber) {
	em.subsLk.Lock()
	defer em.subsLk.Unlock()

	for i, s := range em.subs {
		if s == sub {
			em.subs[i] = em.subs[len(em.subs)-1]
			em.subs = em.subs[:len(em.subs)-1]
			break
		}
	}
}

// addSubscriber registers sub to receive live events.
func (em *EventManager) addSubscriber(sub *Subscriber) {
	em.subsLk.Lock()
	defer em.subsLk.Unlock()

	em.subs = append(em.subs, sub)
}

// TakeDownRepo delegates a repo takedown to the persister.
func (em *EventManager) TakeDownRepo(ctx context.Context, user models.Uid) error {
	return em.persister.TakeDownRepo(ctx, user)
}
+26
cmd/relay/events/metrics.go
// Prometheus metrics for the relay event stream. Registered once at package
// init via promauto; referenced from events.go and the stream consumers.
package events

import (
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)

// eventsFromStreamCounter counts inbound stream events, labeled by source host.
var eventsFromStreamCounter = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "indigo_repo_stream_events_received_total",
	Help: "Total number of events received from the stream",
}, []string{"remote_addr"})

// bytesFromStreamCounter counts inbound stream bytes, labeled by source host.
var bytesFromStreamCounter = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "indigo_repo_stream_bytes_total",
	Help: "Total bytes received from the stream",
}, []string{"remote_addr"})

// eventsEnqueued counts events queued for fan-out, labeled by subscriber ident.
var eventsEnqueued = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "indigo_events_enqueued_for_broadcast_total",
	Help: "Total number of events enqueued to broadcast to subscribers",
}, []string{"pool"})

// eventsBroadcast counts events delivered to subscribers, by subscriber ident.
var eventsBroadcast = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "indigo_events_broadcast_total",
	Help: "Total number of events broadcast to subscribers",
}, []string{"pool"})
+99
cmd/relay/events/persist.go
··· 1 + package events 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "sync" 7 + 8 + "github.com/bluesky-social/indigo/cmd/relay/models" 9 + ) 10 + 11 + // Note that this interface looks generic, but some persisters might only work with RepoAppend or LabelLabels 12 + type EventPersistence interface { 13 + Persist(ctx context.Context, e *XRPCStreamEvent) error 14 + Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error 15 + TakeDownRepo(ctx context.Context, usr models.Uid) error 16 + Flush(context.Context) error 17 + Shutdown(context.Context) error 18 + 19 + SetEventBroadcaster(func(*XRPCStreamEvent)) 20 + } 21 + 22 + // MemPersister is the most naive implementation of event persistence 23 + // This EventPersistence option works fine with all event types 24 + // ill do better later 25 + type MemPersister struct { 26 + buf []*XRPCStreamEvent 27 + lk sync.Mutex 28 + seq int64 29 + 30 + broadcast func(*XRPCStreamEvent) 31 + } 32 + 33 + func NewMemPersister() *MemPersister { 34 + return &MemPersister{} 35 + } 36 + 37 + func (mp *MemPersister) Persist(ctx context.Context, e *XRPCStreamEvent) error { 38 + mp.lk.Lock() 39 + defer mp.lk.Unlock() 40 + mp.seq++ 41 + switch { 42 + case e.RepoCommit != nil: 43 + e.RepoCommit.Seq = mp.seq 44 + case e.RepoHandle != nil: 45 + e.RepoHandle.Seq = mp.seq 46 + case e.RepoIdentity != nil: 47 + e.RepoIdentity.Seq = mp.seq 48 + case e.RepoAccount != nil: 49 + e.RepoAccount.Seq = mp.seq 50 + case e.RepoMigrate != nil: 51 + e.RepoMigrate.Seq = mp.seq 52 + case e.RepoTombstone != nil: 53 + e.RepoTombstone.Seq = mp.seq 54 + case e.LabelLabels != nil: 55 + e.LabelLabels.Seq = mp.seq 56 + default: 57 + panic("no event in persist call") 58 + } 59 + mp.buf = append(mp.buf, e) 60 + 61 + mp.broadcast(e) 62 + 63 + return nil 64 + } 65 + 66 + func (mp *MemPersister) Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error { 67 + mp.lk.Lock() 68 + l := len(mp.buf) 69 + mp.lk.Unlock() 70 + 71 + if 
since >= int64(l) { 72 + return nil 73 + } 74 + 75 + // TODO: abusing the fact that buf[0].seq is currently always 1 76 + for _, e := range mp.buf[since:l] { 77 + if err := cb(e); err != nil { 78 + return err 79 + } 80 + } 81 + 82 + return nil 83 + } 84 + 85 + func (mp *MemPersister) TakeDownRepo(ctx context.Context, uid models.Uid) error { 86 + return fmt.Errorf("repo takedowns not currently supported by memory persister, test usage only") 87 + } 88 + 89 + func (mp *MemPersister) Flush(ctx context.Context) error { 90 + return nil 91 + } 92 + 93 + func (mp *MemPersister) SetEventBroadcaster(brc func(*XRPCStreamEvent)) { 94 + mp.broadcast = brc 95 + } 96 + 97 + func (mp *MemPersister) Shutdown(context.Context) error { 98 + return nil 99 + }
+26
cmd/relay/events/schedulers/metrics.go
// Prometheus metrics shared by the relay's scheduler implementations.
// Registered once at package init via promauto; labeled by consumer pool
// name and scheduler type.
package schedulers

import (
	"github.com/prometheus/client_golang/prometheus"
	"github.com/prometheus/client_golang/prometheus/promauto"
)

// WorkItemsAdded counts items submitted to a scheduler via AddWork.
var WorkItemsAdded = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "indigo_scheduler_work_items_added_total",
	Help: "Total number of work items added to the consumer pool",
}, []string{"pool", "scheduler_type"})

// WorkItemsProcessed counts items whose handler has completed.
var WorkItemsProcessed = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "indigo_scheduler_work_items_processed_total",
	Help: "Total number of work items processed by the consumer pool",
}, []string{"pool", "scheduler_type"})

// WorkItemsActive counts items handed to a worker (incremented at dispatch).
var WorkItemsActive = promauto.NewCounterVec(prometheus.CounterOpts{
	Name: "indigo_scheduler_work_items_active_total",
	Help: "Total number of work items passed into a worker",
}, []string{"pool", "scheduler_type"})

// WorkersActive gauges the size of each scheduler's worker pool.
var WorkersActive = promauto.NewGaugeVec(prometheus.GaugeOpts{
	Name: "indigo_scheduler_workers_active",
	Help: "Number of workers currently active",
}, []string{"pool", "scheduler_type"})
+148
cmd/relay/events/schedulers/parallel/parallel.go
··· 1 + package parallel 2 + 3 + import ( 4 + "context" 5 + "log/slog" 6 + "sync" 7 + 8 + "github.com/bluesky-social/indigo/cmd/relay/events" 9 + "github.com/bluesky-social/indigo/events/schedulers" 10 + 11 + "github.com/prometheus/client_golang/prometheus" 12 + ) 13 + 14 + // Scheduler is a parallel scheduler that will run work on a fixed number of workers 15 + type Scheduler struct { 16 + maxConcurrency int 17 + maxQueue int 18 + 19 + do func(context.Context, *events.XRPCStreamEvent) error 20 + 21 + feeder chan *consumerTask 22 + out chan struct{} 23 + 24 + lk sync.Mutex 25 + active map[string][]*consumerTask 26 + 27 + ident string 28 + 29 + // metrics 30 + itemsAdded prometheus.Counter 31 + itemsProcessed prometheus.Counter 32 + itemsActive prometheus.Counter 33 + workesActive prometheus.Gauge 34 + 35 + log *slog.Logger 36 + } 37 + 38 + func NewScheduler(maxC, maxQ int, ident string, do func(context.Context, *events.XRPCStreamEvent) error) *Scheduler { 39 + p := &Scheduler{ 40 + maxConcurrency: maxC, 41 + maxQueue: maxQ, 42 + 43 + do: do, 44 + 45 + feeder: make(chan *consumerTask), 46 + active: make(map[string][]*consumerTask), 47 + out: make(chan struct{}), 48 + 49 + ident: ident, 50 + 51 + itemsAdded: schedulers.WorkItemsAdded.WithLabelValues(ident, "parallel"), 52 + itemsProcessed: schedulers.WorkItemsProcessed.WithLabelValues(ident, "parallel"), 53 + itemsActive: schedulers.WorkItemsActive.WithLabelValues(ident, "parallel"), 54 + workesActive: schedulers.WorkersActive.WithLabelValues(ident, "parallel"), 55 + 56 + log: slog.Default().With("system", "parallel-scheduler"), 57 + } 58 + 59 + for i := 0; i < maxC; i++ { 60 + go p.worker() 61 + } 62 + 63 + p.workesActive.Set(float64(maxC)) 64 + 65 + return p 66 + } 67 + 68 + func (p *Scheduler) Shutdown() { 69 + p.log.Info("shutting down parallel scheduler", "ident", p.ident) 70 + 71 + for i := 0; i < p.maxConcurrency; i++ { 72 + p.feeder <- &consumerTask{ 73 + control: "stop", 74 + } 75 + } 76 + 77 + 
close(p.feeder) 78 + 79 + for i := 0; i < p.maxConcurrency; i++ { 80 + <-p.out 81 + } 82 + 83 + p.log.Info("parallel scheduler shutdown complete") 84 + } 85 + 86 + type consumerTask struct { 87 + repo string 88 + val *events.XRPCStreamEvent 89 + control string 90 + } 91 + 92 + func (p *Scheduler) AddWork(ctx context.Context, repo string, val *events.XRPCStreamEvent) error { 93 + p.itemsAdded.Inc() 94 + t := &consumerTask{ 95 + repo: repo, 96 + val: val, 97 + } 98 + p.lk.Lock() 99 + 100 + a, ok := p.active[repo] 101 + if ok { 102 + p.active[repo] = append(a, t) 103 + p.lk.Unlock() 104 + return nil 105 + } 106 + 107 + p.active[repo] = []*consumerTask{} 108 + p.lk.Unlock() 109 + 110 + select { 111 + case p.feeder <- t: 112 + return nil 113 + case <-ctx.Done(): 114 + return ctx.Err() 115 + } 116 + } 117 + 118 + func (p *Scheduler) worker() { 119 + for work := range p.feeder { 120 + for work != nil { 121 + if work.control == "stop" { 122 + p.out <- struct{}{} 123 + return 124 + } 125 + 126 + p.itemsActive.Inc() 127 + if err := p.do(context.TODO(), work.val); err != nil { 128 + p.log.Error("event handler failed", "err", err) 129 + } 130 + p.itemsProcessed.Inc() 131 + 132 + p.lk.Lock() 133 + rem, ok := p.active[work.repo] 134 + if !ok { 135 + p.log.Error("should always have an 'active' entry if a worker is processing a job") 136 + } 137 + 138 + if len(rem) == 0 { 139 + delete(p.active, work.repo) 140 + work = nil 141 + } else { 142 + work = rem[0] 143 + p.active[work.repo] = rem[1:] 144 + } 145 + p.lk.Unlock() 146 + } 147 + } 148 + }
+1
cmd/relay/events/schedulers/scheduler.go
··· 1 + package schedulers
+478
cmd/relay/main.go
··· 1 + package main 2 + 3 + import ( 4 + "context" 5 + "crypto/rand" 6 + "encoding/base64" 7 + "errors" 8 + "fmt" 9 + "github.com/bluesky-social/indigo/atproto/identity" 10 + "github.com/bluesky-social/indigo/cmd/relay/events/diskpersist" 11 + "gorm.io/gorm" 12 + "io" 13 + "log/slog" 14 + _ "net/http/pprof" 15 + "net/url" 16 + "os" 17 + "os/signal" 18 + "path/filepath" 19 + "strconv" 20 + "strings" 21 + "syscall" 22 + "time" 23 + 24 + libbgs "github.com/bluesky-social/indigo/cmd/relay/bgs" 25 + "github.com/bluesky-social/indigo/cmd/relay/events" 26 + "github.com/bluesky-social/indigo/util" 27 + "github.com/bluesky-social/indigo/util/cliutil" 28 + "github.com/bluesky-social/indigo/xrpc" 29 + 30 + _ "github.com/joho/godotenv/autoload" 31 + _ "go.uber.org/automaxprocs" 32 + 33 + "github.com/carlmjohnson/versioninfo" 34 + "github.com/urfave/cli/v2" 35 + "go.opentelemetry.io/otel" 36 + "go.opentelemetry.io/otel/attribute" 37 + "go.opentelemetry.io/otel/exporters/jaeger" 38 + "go.opentelemetry.io/otel/exporters/otlp/otlptrace/otlptracehttp" 39 + "go.opentelemetry.io/otel/sdk/resource" 40 + tracesdk "go.opentelemetry.io/otel/sdk/trace" 41 + semconv "go.opentelemetry.io/otel/semconv/v1.4.0" 42 + "gorm.io/plugin/opentelemetry/tracing" 43 + ) 44 + 45 + func main() { 46 + if err := run(os.Args); err != nil { 47 + slog.Error(err.Error()) 48 + os.Exit(1) 49 + } 50 + } 51 + 52 + func run(args []string) error { 53 + 54 + app := cli.App{ 55 + Name: "relay", 56 + Usage: "atproto Relay daemon", 57 + Version: versioninfo.Short(), 58 + } 59 + 60 + app.Flags = []cli.Flag{ 61 + &cli.BoolFlag{ 62 + Name: "jaeger", 63 + }, 64 + &cli.StringFlag{ 65 + Name: "db-url", 66 + Usage: "database connection string for BGS database", 67 + Value: "sqlite://./data/bigsky/bgs.sqlite", 68 + EnvVars: []string{"DATABASE_URL"}, 69 + }, 70 + &cli.BoolFlag{ 71 + Name: "db-tracing", 72 + }, 73 + &cli.StringFlag{ 74 + Name: "plc-host", 75 + Usage: "method, hostname, and port of PLC registry", 76 + Value: 
"https://plc.directory", 77 + EnvVars: []string{"ATP_PLC_HOST"}, 78 + }, 79 + &cli.BoolFlag{ 80 + Name: "crawl-insecure-ws", 81 + Usage: "when connecting to PDS instances, use ws:// instead of wss://", 82 + }, 83 + &cli.StringFlag{ 84 + Name: "api-listen", 85 + Value: ":2470", 86 + EnvVars: []string{"RELAY_API_LISTEN"}, 87 + }, 88 + &cli.StringFlag{ 89 + Name: "metrics-listen", 90 + Value: ":2471", 91 + EnvVars: []string{"RELAY_METRICS_LISTEN", "BGS_METRICS_LISTEN"}, 92 + }, 93 + &cli.StringFlag{ 94 + Name: "disk-persister-dir", 95 + Usage: "set directory for disk persister (implicitly enables disk persister)", 96 + EnvVars: []string{"RELAY_PERSISTER_DIR"}, 97 + }, 98 + &cli.StringFlag{ 99 + Name: "admin-key", 100 + EnvVars: []string{"RELAY_ADMIN_KEY", "BGS_ADMIN_KEY"}, 101 + }, 102 + &cli.IntFlag{ 103 + Name: "max-metadb-connections", 104 + EnvVars: []string{"MAX_METADB_CONNECTIONS"}, 105 + Value: 40, 106 + }, 107 + &cli.StringFlag{ 108 + Name: "env", 109 + Value: "dev", 110 + EnvVars: []string{"ENVIRONMENT"}, 111 + Usage: "declared hosting environment (prod, qa, etc); used in metrics", 112 + }, 113 + &cli.StringFlag{ 114 + Name: "otel-exporter-otlp-endpoint", 115 + EnvVars: []string{"OTEL_EXPORTER_OTLP_ENDPOINT"}, 116 + }, 117 + &cli.StringFlag{ 118 + Name: "bsky-social-rate-limit-skip", 119 + EnvVars: []string{"BSKY_SOCIAL_RATE_LIMIT_SKIP"}, 120 + Usage: "ratelimit bypass secret token for *.bsky.social domains", 121 + }, 122 + &cli.IntFlag{ 123 + Name: "default-repo-limit", 124 + Value: 100, 125 + EnvVars: []string{"RELAY_DEFAULT_REPO_LIMIT"}, 126 + }, 127 + &cli.IntFlag{ 128 + Name: "concurrency-per-pds", 129 + EnvVars: []string{"RELAY_CONCURRENCY_PER_PDS"}, 130 + Value: 100, 131 + }, 132 + &cli.IntFlag{ 133 + Name: "max-queue-per-pds", 134 + EnvVars: []string{"RELAY_MAX_QUEUE_PER_PDS"}, 135 + Value: 1_000, 136 + }, 137 + &cli.IntFlag{ 138 + Name: "did-cache-size", 139 + Usage: "in-process cache by number of Did documents", 140 + EnvVars: 
[]string{"RELAY_DID_CACHE_SIZE"}, 141 + Value: 5_000_000, 142 + }, 143 + &cli.DurationFlag{ 144 + Name: "event-playback-ttl", 145 + Usage: "time to live for event playback buffering (only applies to disk persister)", 146 + EnvVars: []string{"RELAY_EVENT_PLAYBACK_TTL"}, 147 + Value: 72 * time.Hour, 148 + }, 149 + &cli.StringSliceFlag{ 150 + Name: "next-crawler", 151 + Usage: "forward POST requestCrawl to this url, should be machine root url and not xrpc/requestCrawl, comma separated list", 152 + EnvVars: []string{"RELAY_NEXT_CRAWLER"}, 153 + }, 154 + &cli.StringFlag{ 155 + Name: "trace-induction", 156 + Usage: "file path to log debug trace stuff about induction firehose", 157 + EnvVars: []string{"RELAY_TRACE_INDUCTION"}, 158 + }, 159 + &cli.BoolFlag{ 160 + Name: "time-seq", 161 + EnvVars: []string{"RELAY_TIME_SEQUENCE"}, 162 + Value: false, 163 + Usage: "make outbound firehose sequence number approximately unix microseconds", 164 + }, 165 + } 166 + 167 + app.Action = runRelay 168 + return app.Run(os.Args) 169 + } 170 + 171 + func setupOTEL(cctx *cli.Context) error { 172 + 173 + env := cctx.String("env") 174 + if env == "" { 175 + env = "dev" 176 + } 177 + if cctx.Bool("jaeger") { 178 + jaegerUrl := "http://localhost:14268/api/traces" 179 + exp, err := jaeger.New(jaeger.WithCollectorEndpoint(jaeger.WithEndpoint(jaegerUrl))) 180 + if err != nil { 181 + return err 182 + } 183 + tp := tracesdk.NewTracerProvider( 184 + // Always be sure to batch in production. 185 + tracesdk.WithBatcher(exp), 186 + // Record information about this application in a Resource. 
187 + tracesdk.WithResource(resource.NewWithAttributes( 188 + semconv.SchemaURL, 189 + semconv.ServiceNameKey.String("bgs"), 190 + attribute.String("env", env), // DataDog 191 + attribute.String("environment", env), // Others 192 + attribute.Int64("ID", 1), 193 + )), 194 + ) 195 + 196 + otel.SetTracerProvider(tp) 197 + } 198 + 199 + // Enable OTLP HTTP exporter 200 + // For relevant environment variables: 201 + // https://pkg.go.dev/go.opentelemetry.io/otel/exporters/otlp/otlptrace#readme-environment-variables 202 + // At a minimum, you need to set 203 + // OTEL_EXPORTER_OTLP_ENDPOINT=http://localhost:4318 204 + if ep := cctx.String("otel-exporter-otlp-endpoint"); ep != "" { 205 + slog.Info("setting up trace exporter", "endpoint", ep) 206 + ctx, cancel := context.WithCancel(context.Background()) 207 + defer cancel() 208 + 209 + exp, err := otlptracehttp.New(ctx) 210 + if err != nil { 211 + slog.Error("failed to create trace exporter", "error", err) 212 + os.Exit(1) 213 + } 214 + defer func() { 215 + ctx, cancel := context.WithTimeout(context.Background(), time.Second) 216 + defer cancel() 217 + if err := exp.Shutdown(ctx); err != nil { 218 + slog.Error("failed to shutdown trace exporter", "error", err) 219 + } 220 + }() 221 + 222 + tp := tracesdk.NewTracerProvider( 223 + tracesdk.WithBatcher(exp), 224 + tracesdk.WithResource(resource.NewWithAttributes( 225 + semconv.SchemaURL, 226 + semconv.ServiceNameKey.String("bgs"), 227 + attribute.String("env", env), // DataDog 228 + attribute.String("environment", env), // Others 229 + attribute.Int64("ID", 1), 230 + )), 231 + ) 232 + otel.SetTracerProvider(tp) 233 + } 234 + 235 + return nil 236 + } 237 + 238 + func runRelay(cctx *cli.Context) error { 239 + // Trap SIGINT to trigger a shutdown. 
240 + signals := make(chan os.Signal, 1) 241 + signal.Notify(signals, syscall.SIGINT, syscall.SIGTERM) 242 + 243 + logger, logWriter, err := cliutil.SetupSlog(cliutil.LogOptions{}) 244 + if err != nil { 245 + return err 246 + } 247 + 248 + var inductionTraceLog *slog.Logger 249 + 250 + if cctx.IsSet("trace-induction") { 251 + traceFname := cctx.String("trace-induction") 252 + traceFout, err := os.OpenFile(traceFname, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0644) 253 + if err != nil { 254 + return fmt.Errorf("%s: could not open trace file: %w", traceFname, err) 255 + } 256 + defer traceFout.Close() 257 + if traceFname != "" { 258 + inductionTraceLog = slog.New(slog.NewJSONHandler(traceFout, &slog.HandlerOptions{Level: slog.LevelDebug})) 259 + } 260 + } else { 261 + inductionTraceLog = slog.New(slog.NewTextHandler(io.Discard, &slog.HandlerOptions{Level: slog.Level(999)})) 262 + } 263 + 264 + // start observability/tracing (OTEL and jaeger) 265 + if err := setupOTEL(cctx); err != nil { 266 + return err 267 + } 268 + 269 + dburl := cctx.String("db-url") 270 + logger.Info("setting up main database", "url", dburl) 271 + db, err := cliutil.SetupDatabase(dburl, cctx.Int("max-metadb-connections")) 272 + if err != nil { 273 + return err 274 + } 275 + if cctx.Bool("db-tracing") { 276 + if err := db.Use(tracing.NewPlugin()); err != nil { 277 + return err 278 + } 279 + } 280 + if err := db.AutoMigrate(RelaySetting{}); err != nil { 281 + panic(err) 282 + } 283 + 284 + // TODO: add shared external cache 285 + baseDir := identity.BaseDirectory{ 286 + SkipHandleVerification: true, 287 + SkipDNSDomainSuffixes: []string{".bsky.social"}, 288 + TryAuthoritativeDNS: true, 289 + } 290 + cacheDir := identity.NewCacheDirectory(&baseDir, cctx.Int("did-cache-size"), time.Hour*24, time.Minute*2, time.Minute*5) 291 + 292 + // TODO: rename repoman 293 + repoman := libbgs.NewValidator(&cacheDir, inductionTraceLog) 294 + 295 + var persister events.EventPersistence 296 + 297 + dpd := 
cctx.String("disk-persister-dir") 298 + if dpd == "" { 299 + logger.Info("empty disk-persister-dir, use current working directory") 300 + cwd, err := os.Getwd() 301 + if err != nil { 302 + return err 303 + } 304 + dpd = filepath.Join(cwd, "relay-persist") 305 + } 306 + logger.Info("setting up disk persister", "dir", dpd) 307 + 308 + pOpts := diskpersist.DefaultDiskPersistOptions() 309 + pOpts.Retention = cctx.Duration("event-playback-ttl") 310 + pOpts.TimeSequence = cctx.Bool("time-seq") 311 + 312 + // ensure that time-ish sequence stays consistent within a server context 313 + storedTimeSeq, hadStoredTimeSeq, err := getRelaySettingBool(db, "time-seq") 314 + if err != nil { 315 + return err 316 + } 317 + if !hadStoredTimeSeq { 318 + if err := setRelaySettingBool(db, "time-seq", pOpts.TimeSequence); err != nil { 319 + return err 320 + } 321 + } else { 322 + if pOpts.TimeSequence != storedTimeSeq { 323 + return fmt.Errorf("time-seq stored as %v but param/env set as %v", storedTimeSeq, pOpts.TimeSequence) 324 + } 325 + } 326 + 327 + dp, err := diskpersist.NewDiskPersistence(dpd, "", db, pOpts) 328 + if err != nil { 329 + return fmt.Errorf("setting up disk persister: %w", err) 330 + } 331 + persister = dp 332 + 333 + evtman := events.NewEventManager(persister) 334 + 335 + ratelimitBypass := cctx.String("bsky-social-rate-limit-skip") 336 + 337 + logger.Info("constructing bgs") 338 + bgsConfig := libbgs.DefaultBGSConfig() 339 + bgsConfig.SSL = !cctx.Bool("crawl-insecure-ws") 340 + bgsConfig.ConcurrencyPerPDS = cctx.Int64("concurrency-per-pds") 341 + bgsConfig.MaxQueuePerPDS = cctx.Int64("max-queue-per-pds") 342 + bgsConfig.DefaultRepoLimit = cctx.Int64("default-repo-limit") 343 + bgsConfig.ApplyPDSClientSettings = makePdsClientSetup(ratelimitBypass) 344 + bgsConfig.InductionTraceLog = inductionTraceLog 345 + nextCrawlers := cctx.StringSlice("next-crawler") 346 + if len(nextCrawlers) != 0 { 347 + nextCrawlerUrls := make([]*url.URL, len(nextCrawlers)) 348 + for i, tu := 
range nextCrawlers { 349 + var err error 350 + nextCrawlerUrls[i], err = url.Parse(tu) 351 + if err != nil { 352 + return fmt.Errorf("failed to parse next-crawler url: %w", err) 353 + } 354 + logger.Info("configuring relay for requestCrawl", "host", nextCrawlerUrls[i]) 355 + } 356 + bgsConfig.NextCrawlers = nextCrawlerUrls 357 + } 358 + if cctx.IsSet("admin-key") { 359 + bgsConfig.AdminToken = cctx.String("admin-key") 360 + } else { 361 + var rblob [10]byte 362 + _, _ = rand.Read(rblob[:]) 363 + bgsConfig.AdminToken = base64.URLEncoding.EncodeToString(rblob[:]) 364 + logger.Info("generated random admin key", "header", "Authorization: Bearer "+bgsConfig.AdminToken) 365 + } 366 + bgs, err := libbgs.NewBGS(db, repoman, evtman, &cacheDir, bgsConfig) 367 + if err != nil { 368 + return err 369 + } 370 + dp.SetUidSource(bgs) 371 + 372 + // set up metrics endpoint 373 + go func() { 374 + if err := bgs.StartMetrics(cctx.String("metrics-listen")); err != nil { 375 + logger.Error("failed to start metrics endpoint", "err", err) 376 + os.Exit(1) 377 + } 378 + }() 379 + 380 + bgsErr := make(chan error, 1) 381 + 382 + go func() { 383 + err := bgs.Start(cctx.String("api-listen"), logWriter) 384 + bgsErr <- err 385 + }() 386 + 387 + logger.Info("startup complete") 388 + select { 389 + case <-signals: 390 + logger.Info("received shutdown signal") 391 + errs := bgs.Shutdown() 392 + for err := range errs { 393 + logger.Error("error during BGS shutdown", "err", err) 394 + } 395 + case err := <-bgsErr: 396 + if err != nil { 397 + logger.Error("error during BGS startup", "err", err) 398 + } 399 + logger.Info("shutting down") 400 + errs := bgs.Shutdown() 401 + for err := range errs { 402 + logger.Error("error during BGS shutdown", "err", err) 403 + } 404 + } 405 + 406 + logger.Info("shutdown complete") 407 + 408 + return nil 409 + } 410 + 411 + func makePdsClientSetup(ratelimitBypass string) func(c *xrpc.Client) { 412 + return func(c *xrpc.Client) { 413 + if c.Client == nil { 414 + 
c.Client = util.RobustHTTPClient() 415 + } 416 + if strings.HasSuffix(c.Host, ".bsky.network") { 417 + c.Client.Timeout = time.Minute * 30 418 + if ratelimitBypass != "" { 419 + c.Headers = map[string]string{ 420 + "x-ratelimit-bypass": ratelimitBypass, 421 + } 422 + } 423 + } else { 424 + // Generic PDS timeout 425 + c.Client.Timeout = time.Minute * 1 426 + } 427 + } 428 + } 429 + 430 + // RelaySetting is a gorm model 431 + type RelaySetting struct { 432 + Name string `gorm:"primarykey"` 433 + Value string 434 + } 435 + 436 + func getRelaySetting(db *gorm.DB, name string) (value string, found bool, err error) { 437 + var setting RelaySetting 438 + dbResult := db.First(&setting, "name = ?", name) 439 + if errors.Is(dbResult.Error, gorm.ErrRecordNotFound) { 440 + return "", false, nil 441 + } 442 + if dbResult.Error != nil { 443 + return "", false, dbResult.Error 444 + } 445 + return setting.Value, true, nil 446 + } 447 + 448 + func setRelaySetting(db *gorm.DB, name string, value string) error { 449 + return db.Transaction(func(tx *gorm.DB) error { 450 + var setting RelaySetting 451 + found := tx.First(&setting, "name = ?", name) 452 + if errors.Is(found.Error, gorm.ErrRecordNotFound) { 453 + // ok! create it 454 + setting.Name = name 455 + setting.Value = value 456 + return tx.Create(&setting).Error 457 + } else if found.Error != nil { 458 + return found.Error 459 + } 460 + setting.Value = value 461 + return tx.Save(&setting).Error 462 + }) 463 + } 464 + 465 + func getRelaySettingBool(db *gorm.DB, name string) (value bool, found bool, err error) { 466 + strval, found, err := getRelaySetting(db, name) 467 + if err != nil || !found { 468 + return false, found, err 469 + } 470 + value, err = strconv.ParseBool(strval) 471 + if err != nil { 472 + return false, false, err 473 + } 474 + return value, true, nil 475 + } 476 + func setRelaySettingBool(db *gorm.DB, name string, value bool) error { 477 + return setRelaySetting(db, name, strconv.FormatBool(value)) 478 + }
+82
cmd/relay/models/models.go
··· 1 + package models 2 + 3 + import ( 4 + "database/sql/driver" 5 + "encoding/json" 6 + "fmt" 7 + "github.com/ipfs/go-cid" 8 + "gorm.io/gorm" 9 + ) 10 + 11 + type Uid uint64 12 + 13 + type DbCID struct { 14 + CID cid.Cid 15 + } 16 + 17 + func (dbc *DbCID) Scan(v interface{}) error { 18 + b, ok := v.([]byte) 19 + if !ok { 20 + return fmt.Errorf("dbcids must get bytes!") 21 + } 22 + 23 + if len(b) == 0 { 24 + return nil 25 + } 26 + 27 + c, err := cid.Cast(b) 28 + if err != nil { 29 + return err 30 + } 31 + 32 + dbc.CID = c 33 + return nil 34 + } 35 + 36 + func (dbc DbCID) Value() (driver.Value, error) { 37 + if !dbc.CID.Defined() { 38 + return nil, fmt.Errorf("cannot serialize undefined cid to database") 39 + } 40 + return dbc.CID.Bytes(), nil 41 + } 42 + 43 + func (dbc DbCID) MarshalJSON() ([]byte, error) { 44 + return json.Marshal(dbc.CID.String()) 45 + } 46 + 47 + func (dbc *DbCID) UnmarshalJSON(b []byte) error { 48 + var s string 49 + if err := json.Unmarshal(b, &s); err != nil { 50 + return err 51 + } 52 + 53 + c, err := cid.Decode(s) 54 + if err != nil { 55 + return err 56 + } 57 + 58 + dbc.CID = c 59 + return nil 60 + } 61 + 62 + func (dbc *DbCID) GormDataType() string { 63 + return "bytes" 64 + } 65 + 66 + type PDS struct { 67 + gorm.Model 68 + 69 + Host string `gorm:"unique"` 70 + SSL bool 71 + Cursor int64 72 + Registered bool 73 + Blocked bool 74 + 75 + RateLimit float64 76 + 77 + RepoCount int64 78 + RepoLimit int64 79 + 80 + HourlyEventLimit int64 81 + DailyEventLimit int64 82 + }
+14 -2
events/consumer.go
··· 33 33 switch { 34 34 case xev.RepoCommit != nil && rsc.RepoCommit != nil: 35 35 return rsc.RepoCommit(xev.RepoCommit) 36 - case xev.RepoSync != nil && rsc.RepoCommit != nil: 36 + case xev.RepoSync != nil && rsc.RepoSync != nil: 37 37 return rsc.RepoSync(xev.RepoSync) 38 38 case xev.RepoHandle != nil && rsc.RepoHandle != nil: 39 39 return rsc.RepoHandle(xev.RepoHandle) ··· 129 129 go func() { 130 130 t := time.NewTicker(time.Second * 30) 131 131 defer t.Stop() 132 + failcount := 0 132 133 133 134 for { 134 135 ··· 136 137 case <-t.C: 137 138 if err := con.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(time.Second*10)); err != nil { 138 139 log.Warn("failed to ping", "err", err) 140 + failcount++ 141 + if failcount >= 4 { 142 + log.Error("too many ping fails", "count", failcount) 143 + con.Close() 144 + return 145 + } 146 + } else { 147 + failcount = 0 // ok ping 139 148 } 140 149 case <-ctx.Done(): 141 150 con.Close() ··· 172 181 173 182 mt, rawReader, err := con.NextReader() 174 183 if err != nil { 175 - return err 184 + return fmt.Errorf("con err at read: %w", err) 176 185 } 177 186 178 187 switch mt { ··· 233 242 return err 234 243 } 235 244 case "#handle": 245 + // TODO: DEPRECATED message; warning/counter; drop message 236 246 var evt comatproto.SyncSubscribeRepos_Handle 237 247 if err := evt.UnmarshalCBOR(r); err != nil { 238 248 return err ··· 293 303 return err 294 304 } 295 305 case "#migrate": 306 + // TODO: DEPRECATED message; warning/counter; drop message 296 307 var evt comatproto.SyncSubscribeRepos_Migrate 297 308 if err := evt.UnmarshalCBOR(r); err != nil { 298 309 return err ··· 309 320 return err 310 321 } 311 322 case "#tombstone": 323 + // TODO: DEPRECATED message; warning/counter; drop message 312 324 var evt comatproto.SyncSubscribeRepos_Tombstone 313 325 if err := evt.UnmarshalCBOR(r); err != nil { 314 326 return err
+34 -8
events/events.go
··· 156 156 } 157 157 158 158 var ( 159 - AccountStatusActive = "active" 160 - AccountStatusTakendown = "takendown" 161 - AccountStatusSuspended = "suspended" 162 - AccountStatusDeleted = "deleted" 163 - AccountStatusDeactivated = "deactivated" 159 + // AccountStatusActive is not in the spec but used internally 160 + // the alternative would be an additional SQL column for "active" or status="" to imply active 161 + AccountStatusActive = "active" 162 + 163 + AccountStatusDeactivated = "deactivated" 164 + AccountStatusDeleted = "deleted" 165 + AccountStatusDesynchronized = "desynchronized" 166 + AccountStatusSuspended = "suspended" 167 + AccountStatusTakendown = "takendown" 168 + AccountStatusThrottled = "throttled" 164 169 ) 165 170 171 + var AccountStatusList = []string{ 172 + AccountStatusActive, 173 + AccountStatusDeactivated, 174 + AccountStatusDeleted, 175 + AccountStatusDesynchronized, 176 + AccountStatusSuspended, 177 + AccountStatusTakendown, 178 + AccountStatusThrottled, 179 + } 180 + var AccountStatuses map[string]bool 181 + 182 + func init() { 183 + AccountStatuses = make(map[string]bool, len(AccountStatusList)) 184 + for _, status := range AccountStatusList { 185 + AccountStatuses[status] = true 186 + } 187 + } 188 + 166 189 type XRPCStreamEvent struct { 167 190 Error *ErrorFrame 168 191 RepoCommit *comatproto.SyncSubscribeRepos_Commit 169 192 RepoSync *comatproto.SyncSubscribeRepos_Sync 170 - RepoHandle *comatproto.SyncSubscribeRepos_Handle 193 + RepoHandle *comatproto.SyncSubscribeRepos_Handle // DEPRECATED 171 194 RepoIdentity *comatproto.SyncSubscribeRepos_Identity 172 195 RepoInfo *comatproto.SyncSubscribeRepos_Info 173 - RepoMigrate *comatproto.SyncSubscribeRepos_Migrate 174 - RepoTombstone *comatproto.SyncSubscribeRepos_Tombstone 196 + RepoMigrate *comatproto.SyncSubscribeRepos_Migrate // DEPRECATED 197 + RepoTombstone *comatproto.SyncSubscribeRepos_Tombstone // DEPRECATED 175 198 RepoAccount *comatproto.SyncSubscribeRepos_Account 176 199 
LabelLabels *comatproto.LabelSubscribeLabels_Labels 177 200 LabelInfo *comatproto.LabelSubscribeLabels_Info ··· 247 270 } 248 271 xevt.RepoSync = &evt 249 272 case "#handle": 273 + // TODO: DEPRECATED message; warning/counter; drop message 250 274 var evt comatproto.SyncSubscribeRepos_Handle 251 275 if err := evt.UnmarshalCBOR(r); err != nil { 252 276 return err ··· 272 296 } 273 297 xevt.RepoInfo = &evt 274 298 case "#migrate": 299 + // TODO: DEPRECATED message; warning/counter; drop message 275 300 var evt comatproto.SyncSubscribeRepos_Migrate 276 301 if err := evt.UnmarshalCBOR(r); err != nil { 277 302 return err 278 303 } 279 304 xevt.RepoMigrate = &evt 280 305 case "#tombstone": 306 + // TODO: DEPRECATED message; warning/counter; drop message 281 307 var evt comatproto.SyncSubscribeRepos_Tombstone 282 308 if err := evt.UnmarshalCBOR(r); err != nil { 283 309 return err
+1
events/persist.go
··· 10 10 11 11 // Note that this interface looks generic, but some persisters might only work with RepoAppend or LabelLabels 12 12 type EventPersistence interface { 13 + // Persist may mutate contents of *XRPCStreamEvent and what it points to 13 14 Persist(ctx context.Context, e *XRPCStreamEvent) error 14 15 Playback(ctx context.Context, since int64, cb func(*XRPCStreamEvent) error) error 15 16 TakeDownRepo(ctx context.Context, usr models.Uid) error
-69
ts/bgs-dash/src/components/Dash/Dash.tsx
··· 69 69 useState<PDS | null>(null); 70 70 const [editingPerDayRateLimit, setEditingPerDayRateLimit] = 71 71 useState<PDS | null>(null); 72 - const [editingCrawlRateLimit, setEditingCrawlRateLimit] = 73 - useState<PDS | null>(null); 74 72 const [editingRepoLimit, setEditingRepoLimit] = 75 73 useState<PDS | null>(null); 76 74 ··· 394 392 per_second: pds.PerSecondEventRate.Max, 395 393 per_hour: pds.PerHourEventRate.Max, 396 394 per_day: pds.PerDayEventRate.Max, 397 - crawl_rate: pds.CrawlRate.Max, 398 395 repo_limit: pds.RepoLimit, 399 396 }), 400 397 } ··· 857 854 className="px-3 py-3.5 text-right text-sm font-semibold text-gray-900 pr-6 whitespace-nowrap" 858 855 > 859 856 <a href="#" className="group inline-flex"> 860 - Crawl Limit 861 - </a> 862 - </th> 863 - <th 864 - scope="col" 865 - className="px-3 py-3.5 text-right text-sm font-semibold text-gray-900 pr-6 whitespace-nowrap" 866 - > 867 - <a href="#" className="group inline-flex"> 868 857 Repo Limit 869 858 </a> 870 859 </th> ··· 1165 1154 className={ 1166 1155 "rounded-md p-2 ml-1 hover:text-green-600 hover:bg-green-100 focus:outline-none focus:ring-2 focus:ring-green-600 focus:ring-offset-2 focus:ring-offset-green-50" + 1167 1156 (editingPerDayRateLimit?.ID === pds.ID 1168 - ? "" 1169 - : " hidden") 1170 - } 1171 - > 1172 - <CheckIcon 1173 - className="h-5 w-5 text-green-500 inline-block align-sub" 1174 - aria-hidden="true" 1175 - /> 1176 - </a> 1177 - </td> 1178 - <td className="whitespace-nowrap px-3 py-2 text-sm text-gray-400 text-center w-8 pr-6"> 1179 - <span 1180 - className={ 1181 - editingCrawlRateLimit?.ID === pds.ID 1182 - ? 
"hidden" 1183 - : "" 1184 - } 1185 - > 1186 - {pds.CrawlRate.Max?.toLocaleString()} 1187 - /sec 1188 - </span> 1189 - <input 1190 - type="number" 1191 - name={`crawl-rate-limit-${pds.ID}`} 1192 - id={`crawl-rate-limit-${pds.ID}`} 1193 - className={ 1194 - `inline-block w-24 rounded-md border-0 py-1.5 text-gray-900 shadow-sm ring-1 ring-inset ring-gray-300 placeholder:text-gray-400 focus:ring-2 focus:ring-inset focus:ring-indigo-600 sm:text-sm sm:leading-6` + 1195 - (editingCrawlRateLimit?.ID === pds.ID 1196 - ? "" 1197 - : " hidden") 1198 - } 1199 - defaultValue={pds.CrawlRate.Max?.toLocaleString()} 1200 - /> 1201 - <a 1202 - href="#" 1203 - onClick={() => setEditingCrawlRateLimit(pds)} 1204 - className={editingCrawlRateLimit ? "hidden" : ""} 1205 - > 1206 - <PencilSquareIcon 1207 - className="h-5 w-5 text-gray-500 ml-1 inline-block align-sub" 1208 - aria-hidden="true" 1209 - /> 1210 - </a> 1211 - <a 1212 - href="#" 1213 - onClick={() => { 1214 - const newRateLimit = document.getElementById( 1215 - `crawl-rate-limit-${pds.ID}` 1216 - ) as HTMLInputElement; 1217 - if (newRateLimit) { 1218 - pds.CrawlRate.Max = +newRateLimit.value; 1219 - updateRateLimits(pds); 1220 - } 1221 - setEditingCrawlRateLimit(null); 1222 - }} 1223 - className={ 1224 - "rounded-md p-2 ml-1 hover:text-green-600 hover:bg-green-100 focus:outline-none focus:ring-2 focus:ring-green-600 focus:ring-offset-2 focus:ring-offset-green-50" + 1225 - (editingCrawlRateLimit?.ID === pds.ID 1226 1157 ? "" 1227 1158 : " hidden") 1228 1159 }
-1
ts/bgs-dash/src/models/pds.ts
··· 16 16 Blocked: boolean; 17 17 HasActiveConnection: boolean; 18 18 EventsSeenSinceStartup?: number; 19 - CrawlRate: RateLimit; 20 19 PerSecondEventRate: RateLimit; 21 20 PerHourEventRate: RateLimit; 22 21 PerDayEventRate: RateLimit;
+8 -8
util/cliutil/util.go
··· 235 235 // The env vars were derived from ipfs logging library, and also respond to some GOLOG_ vars from that library, 236 236 // but BSKYLOG_ variables are preferred because imported code still using the ipfs log library may misbehave 237 237 // if some GOLOG values are set, especially GOLOG_FILE. 238 - func SetupSlog(options LogOptions) (*slog.Logger, error) { 238 + func SetupSlog(options LogOptions) (*slog.Logger, io.Writer, error) { 239 239 fmt.Fprintf(os.Stderr, "SetupSlog\n") 240 240 var hopts slog.HandlerOptions 241 241 hopts.Level = slog.LevelInfo ··· 258 258 case "error": 259 259 hopts.Level = slog.LevelError 260 260 default: 261 - return nil, fmt.Errorf("unknown log level: %#v", options.LogLevel) 261 + return nil, nil, fmt.Errorf("unknown log level: %#v", options.LogLevel) 262 262 } 263 263 } 264 264 if options.LogFormat == "" { ··· 271 271 if format == "json" || format == "text" { 272 272 // ok 273 273 } else { 274 - return nil, fmt.Errorf("invalid log format: %#v", options.LogFormat) 274 + return nil, nil, fmt.Errorf("invalid log format: %#v", options.LogFormat) 275 275 } 276 276 options.LogFormat = format 277 277 } ··· 284 284 if rotateBytesStr != "" { 285 285 rotateBytes, err := strconv.ParseInt(rotateBytesStr, 10, 64) 286 286 if err != nil { 287 - return nil, fmt.Errorf("invalid BSKYLOG_ROTATE_BYTES value: %w", err) 287 + return nil, nil, fmt.Errorf("invalid BSKYLOG_ROTATE_BYTES value: %w", err) 288 288 } 289 289 options.LogRotateBytes = rotateBytes 290 290 } ··· 295 295 if keepOldStr != "" { 296 296 keepOld, err := strconv.ParseInt(keepOldStr, 10, 64) 297 297 if err != nil { 298 - return nil, fmt.Errorf("invalid BSKYLOG_ROTATE_KEEP value: %w", err) 298 + return nil, nil, fmt.Errorf("invalid BSKYLOG_ROTATE_KEEP value: %w", err) 299 299 } 300 300 keepOldUnset = false 301 301 options.KeepOld = int(keepOld) ··· 320 320 var err error 321 321 out, err = os.Create(options.LogPath) 322 322 if err != nil { 323 - return nil, fmt.Errorf("%s: %w", 
options.LogPath, err) 323 + return nil, nil, fmt.Errorf("%s: %w", options.LogPath, err) 324 324 } 325 325 fmt.Fprintf(os.Stderr, "SetupSlog create %#v\n", options.LogPath) 326 326 } ··· 331 331 case "json": 332 332 handler = slog.NewJSONHandler(out, &hopts) 333 333 default: 334 - return nil, fmt.Errorf("unknown log format: %#v", options.LogFormat) 334 + return nil, nil, fmt.Errorf("unknown log format: %#v", options.LogFormat) 335 335 } 336 336 logger := slog.New(handler) 337 337 slog.SetDefault(logger) ··· 341 341 fmt.Fprintf(os.Stdout, "%s\n", filepath.Join(templateDirPart, ent.Name())) 342 342 } 343 343 SetIpfsWriter(out, options.LogFormat, options.LogLevel) 344 - return logger, nil 344 + return logger, out, nil 345 345 } 346 346 347 347 type logRotateWriter struct {