forked from hailey.at/cocoon
An atproto PDS written in Go

Compare changes

Choose any two refs to compare.

Changed files
+5212 -2323
.github
workflows
blockstore
cmd
admin
cocoon
identity
internal
db
helpers
metrics
models
oauth
plc
recording_blockstore
server
templates
sqlite_blockstore
+1 -1
.env.example
··· 6 6 COCOON_RELAYS=https://bsky.network 7 7 # Generate with `openssl rand -hex 16` 8 8 COCOON_ADMIN_PASSWORD= 9 - # openssl rand -hex 32 9 + # Generate with `openssl rand -hex 32` 10 10 COCOON_SESSION_SECRET=
+116
.github/workflows/docker-image.yml
··· 1 + name: Docker image 2 + 3 + on: 4 + workflow_dispatch: 5 + push: 6 + branches: 7 + - main 8 + tags: 9 + - 'v*' 10 + 11 + env: 12 + REGISTRY: ghcr.io 13 + IMAGE_NAME: ${{ github.repository }} 14 + 15 + jobs: 16 + build-and-push-image: 17 + strategy: 18 + matrix: 19 + include: 20 + - arch: amd64 21 + runner: ubuntu-latest 22 + - arch: arm64 23 + runner: ubuntu-24.04-arm 24 + runs-on: ${{ matrix.runner }} 25 + # Sets the permissions granted to the `GITHUB_TOKEN` for the actions in this job. 26 + permissions: 27 + contents: read 28 + packages: write 29 + attestations: write 30 + id-token: write 31 + outputs: 32 + digest-amd64: ${{ matrix.arch == 'amd64' && steps.push.outputs.digest || '' }} 33 + digest-arm64: ${{ matrix.arch == 'arm64' && steps.push.outputs.digest || '' }} 34 + steps: 35 + - name: Checkout repository 36 + uses: actions/checkout@v4 37 + 38 + # Uses the `docker/login-action` action to log in to the Container registry registry using the account and password that will publish the packages. Once published, the packages are scoped to the account defined here. 39 + - name: Log in to the Container registry 40 + uses: docker/login-action@v3 41 + with: 42 + registry: ${{ env.REGISTRY }} 43 + username: ${{ github.actor }} 44 + password: ${{ secrets.GITHUB_TOKEN }} 45 + 46 + # This step uses [docker/metadata-action](https://github.com/docker/metadata-action#about) to extract tags and labels that will be applied to the specified image. The `id` "meta" allows the output of this step to be referenced in a subsequent step. The `images` value provides the base name for the tags and labels. 47 + - name: Extract metadata (tags, labels) for Docker 48 + id: meta 49 + uses: docker/metadata-action@v5 50 + with: 51 + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 52 + tags: | 53 + type=raw,value=latest,enable={{is_default_branch}},suffix=-${{ matrix.arch }} 54 + type=sha,suffix=-${{ matrix.arch }} 55 + type=sha,format=long,suffix=-${{ matrix.arch }} 56 + type=semver,pattern={{version}},suffix=-${{ matrix.arch }} 57 + type=semver,pattern={{major}}.{{minor}},suffix=-${{ matrix.arch }} 58 + 59 + # This step uses the `docker/build-push-action` action to build the image, based on your repository's `Dockerfile`. If the build succeeds, it pushes the image to GitHub Packages. 60 + # It uses the `context` parameter to define the build's context as the set of files located in the specified path. For more information, see "[Usage](https://github.com/docker/build-push-action#usage)" in the README of the `docker/build-push-action` repository. 61 + # It uses the `tags` and `labels` parameters to tag and label the image with the output from the "meta" step. 62 + - name: Build and push Docker image 63 + id: push 64 + uses: docker/build-push-action@v6 65 + with: 66 + context: . 67 + push: true 68 + tags: ${{ steps.meta.outputs.tags }} 69 + labels: ${{ steps.meta.outputs.labels }} 70 + 71 + publish-manifest: 72 + needs: build-and-push-image 73 + runs-on: ubuntu-latest 74 + permissions: 75 + packages: write 76 + attestations: write 77 + id-token: write 78 + steps: 79 + - name: Log in to the Container registry 80 + uses: docker/login-action@v3 81 + with: 82 + registry: ${{ env.REGISTRY }} 83 + username: ${{ github.actor }} 84 + password: ${{ secrets.GITHUB_TOKEN }} 85 + 86 + - name: Extract metadata (tags, labels) for Docker 87 + id: meta 88 + uses: docker/metadata-action@v5 89 + with: 90 + images: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME }} 91 + tags: | 92 + type=raw,value=latest,enable={{is_default_branch}} 93 + type=sha 94 + type=sha,format=long 95 + type=semver,pattern={{version}} 96 + type=semver,pattern={{major}}.{{minor}} 97 + 98 + - name: Create and push manifest 99 + run: | 100 + # Split tags into an array 101 + readarray -t tags <<< "${{ steps.meta.outputs.tags }}" 102 + 103 + # Create and push manifest for each tag 104 + for tag in "${tags[@]}"; do 105 + docker buildx imagetools create -t "$tag" \ 106 + "${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@${{ needs.build-and-push-image.outputs.digest-amd64 }}" \ 107 + "${{ env.REGISTRY }}/${{ env.IMAGE_NAME }}@${{ needs.build-and-push-image.outputs.digest-arm64 }}" 108 + done 109 + 110 + # This step generates an artifact attestation for the image, which is an unforgeable statement about where and how it was built. It increases supply chain security for people who consume the image. For more information, see "[AUTOTITLE](/actions/security-guides/using-artifact-attestations-to-establish-provenance-for-builds)." 111 + - name: Generate artifact attestation 112 + uses: actions/attest-build-provenance@v1 113 + with: 114 + subject-name: ${{ env.REGISTRY }}/${{ env.IMAGE_NAME}} 115 + subject-digest: ${{ needs.build-and-push-image.outputs.digest-amd64 }} 116 + push-to-registry: true
+3
.gitignore
··· 4 4 *.key 5 5 *.secret 6 6 .DS_Store 7 + data/ 8 + keys/ 9 + dist/
+10
Caddyfile
··· 1 + {$COCOON_HOSTNAME} { 2 + reverse_proxy localhost:8080 3 + 4 + encode gzip 5 + 6 + log { 7 + output file /data/access.log 8 + format json 9 + } 10 + }
+10
Caddyfile.postgres
··· 1 + {$COCOON_HOSTNAME} { 2 + reverse_proxy cocoon:8080 3 + 4 + encode gzip 5 + 6 + log { 7 + output file /data/access.log 8 + format json 9 + } 10 + }
+25
Dockerfile
··· 1 + ### Compile stage 2 + FROM golang:1.25.1-bookworm AS build-env 3 + 4 + ADD . /dockerbuild 5 + WORKDIR /dockerbuild 6 + 7 + RUN GIT_VERSION=$(git describe --tags --long --always || echo "dev-local") && \ 8 + go mod tidy && \ 9 + go build -ldflags "-X main.Version=$GIT_VERSION" -o cocoon ./cmd/cocoon 10 + 11 + ### Run stage 12 + FROM debian:bookworm-slim AS run 13 + 14 + RUN apt-get update && apt-get install -y dumb-init runit ca-certificates curl && rm -rf /var/lib/apt/lists/* 15 + ENTRYPOINT ["dumb-init", "--"] 16 + 17 + WORKDIR / 18 + RUN mkdir -p data/cocoon 19 + COPY --from=build-env /dockerbuild/cocoon / 20 + 21 + CMD ["/cocoon", "run"] 22 + 23 + LABEL org.opencontainers.image.source=https://github.com/haileyok/cocoon 24 + LABEL org.opencontainers.image.description="Cocoon ATProto PDS" 25 + LABEL org.opencontainers.image.licenses=MIT
+40
Makefile
··· 4 4 GIT_COMMIT := $(shell git rev-parse --short=9 HEAD) 5 5 VERSION := $(if $(GIT_TAG),$(GIT_TAG),dev-$(GIT_COMMIT)) 6 6 7 + # Build output directory 8 + BUILD_DIR := dist 9 + 10 + # Platforms to build for 11 + PLATFORMS := \ 12 + linux/amd64 \ 13 + linux/arm64 \ 14 + linux/arm \ 15 + darwin/amd64 \ 16 + darwin/arm64 \ 17 + windows/amd64 \ 18 + windows/arm64 \ 19 + freebsd/amd64 \ 20 + freebsd/arm64 \ 21 + openbsd/amd64 \ 22 + openbsd/arm64 23 + 7 24 .PHONY: help 8 25 help: ## Print info about all commands 9 26 @echo "Commands:" ··· 14 31 build: ## Build all executables 15 32 go build -ldflags "-X main.Version=$(VERSION)" -o cocoon ./cmd/cocoon 16 33 34 + .PHONY: build-release 35 + build-all: ## Build binaries for all architectures 36 + @echo "Building for all architectures..." 37 + @mkdir -p $(BUILD_DIR) 38 + @$(foreach platform,$(PLATFORMS), \ 39 + $(eval OS := $(word 1,$(subst /, ,$(platform)))) \ 40 + $(eval ARCH := $(word 2,$(subst /, ,$(platform)))) \ 41 + $(eval EXT := $(if $(filter windows,$(OS)),.exe,)) \ 42 + $(eval OUTPUT := $(BUILD_DIR)/cocoon-$(VERSION)-$(OS)-$(ARCH)$(EXT)) \ 43 + echo "Building $(OS)/$(ARCH)..."; \ 44 + GOOS=$(OS) GOARCH=$(ARCH) go build -ldflags "-X main.Version=$(VERSION)" -o $(OUTPUT) ./cmd/cocoon && \ 45 + echo " โœ“ $(OUTPUT)" || echo " โœ— Failed: $(OS)/$(ARCH)"; \ 46 + ) 47 + @echo "Done! Binaries are in $(BUILD_DIR)/" 48 + 49 + .PHONY: clean-dist 50 + clean-dist: ## Remove all built binaries 51 + rm -rf $(BUILD_DIR) 52 + 17 53 .PHONY: run 18 54 run: 19 55 go build -ldflags "-X main.Version=dev-local" -o cocoon ./cmd/cocoon && ./cocoon run ··· 40 76 41 77 .env: 42 78 if [ ! -f ".env" ]; then cp example.dev.env .env; fi 79 + 80 + .PHONY: docker-build 81 + docker-build: 82 + docker build -t cocoon .
+248 -60
README.md
··· 1 1 # Cocoon 2 2 3 3 > [!WARNING] 4 - You should not use this PDS. You should not rely on this code as a reference for a PDS implementation. You should not trust this code. Using this PDS implementation may result in data loss, corruption, etc. 4 + I migrated and have been running my main account on this PDS for months now without issue, however, I am still not responsible if things go awry, particularly during account migration. Please use caution. 5 5 6 6 Cocoon is a PDS implementation in Go. It is highly experimental, and is not ready for any production use. 7 7 8 - ### Impmlemented Endpoints 8 + ## Quick Start with Docker Compose 9 + 10 + ### Prerequisites 11 + 12 + - Docker and Docker Compose installed 13 + - A domain name pointing to your server (for automatic HTTPS) 14 + - Ports 80 and 443 open in i.e. UFW 15 + 16 + ### Installation 17 + 18 + 1. **Clone the repository** 19 + ```bash 20 + git clone https://github.com/haileyok/cocoon.git 21 + cd cocoon 22 + ``` 23 + 24 + 2. **Create your configuration file** 25 + ```bash 26 + cp .env.example .env 27 + ``` 28 + 29 + 3. **Edit `.env` with your settings** 30 + 31 + Required settings: 32 + ```bash 33 + COCOON_DID="did:web:your-domain.com" 34 + COCOON_HOSTNAME="your-domain.com" 35 + COCOON_CONTACT_EMAIL="you@example.com" 36 + COCOON_RELAYS="https://bsky.network" 37 + 38 + # Generate with: openssl rand -hex 16 39 + COCOON_ADMIN_PASSWORD="your-secure-password" 40 + 41 + # Generate with: openssl rand -hex 32 42 + COCOON_SESSION_SECRET="your-session-secret" 43 + ``` 44 + 45 + 4. **Start the services** 46 + ```bash 47 + # Pull pre-built image from GitHub Container Registry 48 + docker-compose pull 49 + docker-compose up -d 50 + ``` 51 + 52 + Or build locally: 53 + ```bash 54 + docker-compose build 55 + docker-compose up -d 56 + ``` 57 + 58 + **For PostgreSQL deployment:** 59 + ```bash 60 + # Add POSTGRES_PASSWORD to your .env file first! 61 + docker-compose -f docker-compose.postgres.yaml up -d 62 + ``` 63 + 64 + 5. **Get your invite code** 65 + 66 + On first run, an invite code is automatically created. View it with: 67 + ```bash 68 + docker-compose logs create-invite 69 + ``` 70 + 71 + Or check the saved file: 72 + ```bash 73 + cat keys/initial-invite-code.txt 74 + ``` 75 + 76 + **IMPORTANT**: Save this invite code! You'll need it to create your first account. 77 + 78 + 6. **Monitor the services** 79 + ```bash 80 + docker-compose logs -f 81 + ``` 82 + 83 + ### What Gets Set Up 84 + 85 + The Docker Compose setup includes: 86 + 87 + - **init-keys**: Automatically generates cryptographic keys (rotation key and JWK) on first run 88 + - **cocoon**: The main PDS service running on port 8080 89 + - **create-invite**: Automatically creates an initial invite code after Cocoon starts (first run only) 90 + - **caddy**: Reverse proxy with automatic HTTPS via Let's Encrypt 91 + 92 + ### Data Persistence 93 + 94 + The following directories will be created automatically: 95 + 96 + - `./keys/` - Cryptographic keys (generated automatically) 97 + - `rotation.key` - PDS rotation key 98 + - `jwk.key` - JWK private key 99 + - `initial-invite-code.txt` - Your first invite code (first run only) 100 + - `./data/` - SQLite database and blockstore 101 + - Docker volumes for Caddy configuration and certificates 102 + 103 + ### Optional Configuration 104 + 105 + #### Database Configuration 106 + 107 + By default, Cocoon uses SQLite which requires no additional setup. For production deployments with higher traffic, you can use PostgreSQL: 108 + 109 + ```bash 110 + # Database type: sqlite (default) or postgres 111 + COCOON_DB_TYPE="postgres" 112 + 113 + # PostgreSQL connection string (required if db-type is postgres) 114 + # Format: postgres://user:password@host:port/database?sslmode=disable 115 + COCOON_DATABASE_URL="postgres://cocoon:password@localhost:5432/cocoon?sslmode=disable" 116 + 117 + # Or use the standard DATABASE_URL environment variable 118 + DATABASE_URL="postgres://cocoon:password@localhost:5432/cocoon?sslmode=disable" 119 + ``` 120 + 121 + For SQLite (default): 122 + ```bash 123 + COCOON_DB_TYPE="sqlite" 124 + COCOON_DB_NAME="/data/cocoon/cocoon.db" 125 + ``` 126 + 127 + > **Note**: When using PostgreSQL, database backups to S3 are not handled by Cocoon. Use `pg_dump` or your database provider's backup solution instead. 128 + 129 + #### SMTP Email Settings 130 + ```bash 131 + COCOON_SMTP_USER="your-smtp-username" 132 + COCOON_SMTP_PASS="your-smtp-password" 133 + COCOON_SMTP_HOST="smtp.example.com" 134 + COCOON_SMTP_PORT="587" 135 + COCOON_SMTP_EMAIL="noreply@example.com" 136 + COCOON_SMTP_NAME="Cocoon PDS" 137 + ``` 138 + 139 + #### S3 Storage 140 + 141 + Cocoon supports S3-compatible storage for both database backups (SQLite only) and blob storage (images, videos, etc.): 142 + 143 + ```bash 144 + # Enable S3 backups (SQLite databases only - hourly backups) 145 + COCOON_S3_BACKUPS_ENABLED=true 146 + 147 + # Enable S3 for blob storage (images, videos, etc.) 148 + # When enabled, blobs are stored in S3 instead of the database 149 + COCOON_S3_BLOBSTORE_ENABLED=true 150 + 151 + # S3 configuration (works with AWS S3, MinIO, Cloudflare R2, etc.) 152 + COCOON_S3_REGION="us-east-1" 153 + COCOON_S3_BUCKET="your-bucket" 154 + COCOON_S3_ENDPOINT="https://s3.amazonaws.com" 155 + COCOON_S3_ACCESS_KEY="your-access-key" 156 + COCOON_S3_SECRET_KEY="your-secret-key" 157 + 158 + # Optional: CDN/public URL for blob redirects 159 + # When set, com.atproto.sync.getBlob redirects to this URL instead of proxying 160 + COCOON_S3_CDN_URL="https://cdn.example.com" 161 + ``` 162 + 163 + **Blob Storage Options:** 164 + - `COCOON_S3_BLOBSTORE_ENABLED=false` (default): Blobs stored in the database 165 + - `COCOON_S3_BLOBSTORE_ENABLED=true`: Blobs stored in S3 bucket under `blobs/{did}/{cid}` 166 + 167 + **Blob Serving Options:** 168 + - Without `COCOON_S3_CDN_URL`: Blobs are proxied through the PDS server 169 + - With `COCOON_S3_CDN_URL`: `getBlob` returns a 302 redirect to `{CDN_URL}/blobs/{did}/{cid}` 170 + 171 + > **Tip**: For Cloudflare R2, you can use the public bucket URL as the CDN URL. For AWS S3, you can use CloudFront or the S3 bucket URL directly if public access is enabled. 172 + 173 + ### Management Commands 174 + 175 + Create an invite code: 176 + ```bash 177 + docker exec cocoon-pds /cocoon create-invite-code --uses 1 178 + ``` 179 + 180 + Reset a user's password: 181 + ```bash 182 + docker exec cocoon-pds /cocoon reset-password --did "did:plc:xxx" 183 + ``` 184 + 185 + ### Updating 186 + 187 + ```bash 188 + docker-compose pull 189 + docker-compose up -d 190 + ``` 191 + 192 + ## Implemented Endpoints 9 193 10 194 > [!NOTE] 11 - Just because something is implemented doesn't mean it is finisehd. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that. 195 + Just because something is implemented doesn't mean it is finished. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that. 12 196 13 - #### Identity 14 - - [ ] com.atproto.identity.getRecommendedDidCredentials 15 - - [ ] com.atproto.identity.requestPlcOperationSignature 16 - - [x] com.atproto.identity.resolveHandle 17 - - [ ] com.atproto.identity.signPlcOperation 18 - - [ ] com.atproto.identity.submitPlcOperatioin 19 - - [x] com.atproto.identity.updateHandle 197 + ### Identity 20 198 21 - #### Repo 22 - - [x] com.atproto.repo.applyWrites 23 - - [x] com.atproto.repo.createRecord 24 - - [x] com.atproto.repo.putRecord 25 - - [x] com.atproto.repo.deleteRecord 26 - - [x] com.atproto.repo.describeRepo 27 - - [x] com.atproto.repo.getRecord 28 - - [ ] com.atproto.repo.importRepo 29 - - [x] com.atproto.repo.listRecords 30 - - [ ] com.atproto.repo.listMissingBlobs 199 + - [x] `com.atproto.identity.getRecommendedDidCredentials` 200 + - [x] `com.atproto.identity.requestPlcOperationSignature` 201 + - [x] `com.atproto.identity.resolveHandle` 202 + - [x] `com.atproto.identity.signPlcOperation` 203 + - [x] `com.atproto.identity.submitPlcOperation` 204 + - [x] `com.atproto.identity.updateHandle` 31 205 32 - #### Server 33 - - [ ] com.atproto.server.activateAccount 34 - - [x] com.atproto.server.checkAccountStatus 35 - - [x] com.atproto.server.confirmEmail 36 - - [x] com.atproto.server.createAccount 37 - - [x] com.atproto.server.createInviteCode 38 - - [x] com.atproto.server.createInviteCodes 39 - - [ ] com.atproto.server.deactivateAccount 40 - - [ ] com.atproto.server.deleteAccount 41 - - [x] com.atproto.server.deleteSession 42 - - [x] com.atproto.server.describeServer 43 - - [ ] com.atproto.server.getAccountInviteCodes 44 - - [ ] com.atproto.server.getServiceAuth 45 - - ~[ ] com.atproto.server.listAppPasswords~ - not going to add app passwords 46 - - [x] com.atproto.server.refreshSession 47 - - [ ] com.atproto.server.requestAccountDelete 48 - - [x] com.atproto.server.requestEmailConfirmation 49 - - [x] com.atproto.server.requestEmailUpdate 50 - - [x] com.atproto.server.requestPasswordReset 51 - - [ ] com.atproto.server.reserveSigningKey 52 - - [x] com.atproto.server.resetPassword 53 - - ~[ ] com.atproto.server.revokeAppPassword~ - not going to add app passwords 54 - - [x] com.atproto.server.updateEmail 206 + ### Repo 207 + 208 + - [x] `com.atproto.repo.applyWrites` 209 + - [x] `com.atproto.repo.createRecord` 210 + - [x] `com.atproto.repo.putRecord` 211 + - [x] `com.atproto.repo.deleteRecord` 212 + - [x] `com.atproto.repo.describeRepo` 213 + - [x] `com.atproto.repo.getRecord` 214 + - [x] `com.atproto.repo.importRepo` (Works "okay". Use with extreme caution.) 215 + - [x] `com.atproto.repo.listRecords` 216 + - [x] `com.atproto.repo.listMissingBlobs` 217 + 218 + ### Server 219 + 220 + - [x] `com.atproto.server.activateAccount` 221 + - [x] `com.atproto.server.checkAccountStatus` 222 + - [x] `com.atproto.server.confirmEmail` 223 + - [x] `com.atproto.server.createAccount` 224 + - [x] `com.atproto.server.createInviteCode` 225 + - [x] `com.atproto.server.createInviteCodes` 226 + - [x] `com.atproto.server.deactivateAccount` 227 + - [x] `com.atproto.server.deleteAccount` 228 + - [x] `com.atproto.server.deleteSession` 229 + - [x] `com.atproto.server.describeServer` 230 + - [ ] `com.atproto.server.getAccountInviteCodes` 231 + - [x] `com.atproto.server.getServiceAuth` 232 + - ~~[ ] `com.atproto.server.listAppPasswords`~~ - not going to add app passwords 233 + - [x] `com.atproto.server.refreshSession` 234 + - [x] `com.atproto.server.requestAccountDelete` 235 + - [x] `com.atproto.server.requestEmailConfirmation` 236 + - [x] `com.atproto.server.requestEmailUpdate` 237 + - [x] `com.atproto.server.requestPasswordReset` 238 + - [x] `com.atproto.server.reserveSigningKey` 239 + - [x] `com.atproto.server.resetPassword` 240 + - ~~[] `com.atproto.server.revokeAppPassword`~~ - not going to add app passwords 241 + - [x] `com.atproto.server.updateEmail` 242 + 243 + ### Sync 55 244 56 - #### Sync 57 - - [x] com.atproto.sync.getBlob 58 - - [x] com.atproto.sync.getBlocks 59 - - [x] com.atproto.sync.getLatestCommit 60 - - [x] com.atproto.sync.getRecord 61 - - [x] com.atproto.sync.getRepoStatus 62 - - [x] com.atproto.sync.getRepo 63 - - [x] com.atproto.sync.listBlobs 64 - - [x] com.atproto.sync.listRepos 65 - - ~[ ] com.atproto.sync.notifyOfUpdate~ - BGS doesn't even have this implemented lol 66 - - [x] com.atproto.sync.requestCrawl 67 - - [x] com.atproto.sync.subscribeRepos 245 + - [x] `com.atproto.sync.getBlob` 246 + - [x] `com.atproto.sync.getBlocks` 247 + - [x] `com.atproto.sync.getLatestCommit` 248 + - [x] `com.atproto.sync.getRecord` 249 + - [x] `com.atproto.sync.getRepoStatus` 250 + - [x] `com.atproto.sync.getRepo` 251 + - [x] `com.atproto.sync.listBlobs` 252 + - [x] `com.atproto.sync.listRepos` 253 + - ~~[ ] `com.atproto.sync.notifyOfUpdate`~~ - BGS doesn't even have this implemented lol 254 + - [x] `com.atproto.sync.requestCrawl` 255 + - [x] `com.atproto.sync.subscribeRepos` 68 256 69 - #### Other 70 - - [ ] com.atproto.label.queryLabels 71 - - [ ] com.atproto.moderation.createReport 72 - - [x] app.bsky.actor.getPreferences 73 - - [x] app.bsky.actor.putPreferences 257 + ### Other 74 258 259 + - [x] `com.atproto.label.queryLabels` 260 + - [x] `com.atproto.moderation.createReport` (Note: this should be handled by proxying, not actually implemented in the PDS) 261 + - [x] `app.bsky.actor.getPreferences` 262 + - [x] `app.bsky.actor.putPreferences` 75 263 76 264 ## License 77 265
-163
blockstore/blockstore.go
··· 1 - package blockstore 2 - 3 - import ( 4 - "context" 5 - "fmt" 6 - 7 - "github.com/bluesky-social/indigo/atproto/syntax" 8 - "github.com/haileyok/cocoon/internal/db" 9 - "github.com/haileyok/cocoon/models" 10 - blocks "github.com/ipfs/go-block-format" 11 - "github.com/ipfs/go-cid" 12 - "gorm.io/gorm/clause" 13 - ) 14 - 15 - type SqliteBlockstore struct { 16 - db *db.DB 17 - did string 18 - readonly bool 19 - inserts map[cid.Cid]blocks.Block 20 - } 21 - 22 - func New(did string, db *db.DB) *SqliteBlockstore { 23 - return &SqliteBlockstore{ 24 - did: did, 25 - db: db, 26 - readonly: false, 27 - inserts: map[cid.Cid]blocks.Block{}, 28 - } 29 - } 30 - 31 - func NewReadOnly(did string, db *db.DB) *SqliteBlockstore { 32 - return &SqliteBlockstore{ 33 - did: did, 34 - db: db, 35 - readonly: true, 36 - inserts: map[cid.Cid]blocks.Block{}, 37 - } 38 - } 39 - 40 - func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) { 41 - var block models.Block 42 - 43 - maybeBlock, ok := bs.inserts[cid] 44 - if ok { 45 - return maybeBlock, nil 46 - } 47 - 48 - if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 49 - return nil, err 50 - } 51 - 52 - b, err := blocks.NewBlockWithCid(block.Value, cid) 53 - if err != nil { 54 - return nil, err 55 - } 56 - 57 - return b, nil 58 - } 59 - 60 - func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error { 61 - bs.inserts[block.Cid()] = block 62 - 63 - if bs.readonly { 64 - return nil 65 - } 66 - 67 - b := models.Block{ 68 - Did: bs.did, 69 - Cid: block.Cid().Bytes(), 70 - Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 71 - Value: block.RawData(), 72 - } 73 - 74 - if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{ 75 - Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 76 - UpdateAll: true, 77 - }}).Error; err != nil { 78 - return err 79 - } 80 - 81 - return nil 82 - } 83 - 84 - func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error { 85 - panic("not implemented") 86 - } 87 - 88 - func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) { 89 - panic("not implemented") 90 - } 91 - 92 - func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) { 93 - panic("not implemented") 94 - } 95 - 96 - func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 - tx := bs.db.BeginDangerously() 98 - 99 - for _, block := range blocks { 100 - bs.inserts[block.Cid()] = block 101 - 102 - if bs.readonly { 103 - continue 104 - } 105 - 106 - b := models.Block{ 107 - Did: bs.did, 108 - Cid: block.Cid().Bytes(), 109 - Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 110 - Value: block.RawData(), 111 - } 112 - 113 - if err := tx.Clauses(clause.OnConflict{ 114 - Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 115 - UpdateAll: true, 116 - }).Create(&b).Error; err != nil { 117 - tx.Rollback() 118 - return err 119 - } 120 - } 121 - 122 - if bs.readonly { 123 - return nil 124 - } 125 - 126 - tx.Commit() 127 - 128 - return nil 129 - } 130 - 131 - func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) { 132 - panic("not implemented") 133 - } 134 - 135 - func (bs *SqliteBlockstore) HashOnRead(enabled bool) { 136 - panic("not implemented") 137 - } 138 - 139 - func (bs *SqliteBlockstore) UpdateRepo(ctx context.Context, root cid.Cid, rev string) error { 140 - if err := bs.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, bs.did).Error; err != nil { 141 - return err 142 - } 143 - 144 - return nil 145 - } 146 - 147 - func (bs *SqliteBlockstore) Execute(ctx context.Context) error { 148 - if !bs.readonly { 149 - return fmt.Errorf("blockstore was not readonly") 150 - } 151 - 152 - bs.readonly = false 153 - for _, b := range bs.inserts { 154 - bs.Put(ctx, b) 155 - } 156 - bs.readonly = true 157 - 158 - return nil 159 - } 160 - 161 - func (bs *SqliteBlockstore) GetLog() map[cid.Cid]blocks.Block { 162 - return bs.inserts 163 - }
-186
cmd/admin/main.go
··· 1 - package main 2 - 3 - import ( 4 - "crypto/ecdsa" 5 - "crypto/elliptic" 6 - "crypto/rand" 7 - "encoding/json" 8 - "fmt" 9 - "os" 10 - "time" 11 - 12 - "github.com/bluesky-social/indigo/atproto/crypto" 13 - "github.com/bluesky-social/indigo/atproto/syntax" 14 - "github.com/haileyok/cocoon/internal/helpers" 15 - "github.com/lestrrat-go/jwx/v2/jwk" 16 - "github.com/urfave/cli/v2" 17 - "golang.org/x/crypto/bcrypt" 18 - "gorm.io/driver/sqlite" 19 - "gorm.io/gorm" 20 - ) 21 - 22 - func main() { 23 - app := cli.App{ 24 - Name: "admin", 25 - Commands: cli.Commands{ 26 - runCreateRotationKey, 27 - runCreatePrivateJwk, 28 - runCreateInviteCode, 29 - runResetPassword, 30 - }, 31 - ErrWriter: os.Stdout, 32 - } 33 - 34 - app.Run(os.Args) 35 - } 36 - 37 - var runCreateRotationKey = &cli.Command{ 38 - Name: "create-rotation-key", 39 - Usage: "creates a rotation key for your pds", 40 - Flags: []cli.Flag{ 41 - &cli.StringFlag{ 42 - Name: "out", 43 - Required: true, 44 - Usage: "output file for your rotation key", 45 - }, 46 - }, 47 - Action: func(cmd *cli.Context) error { 48 - key, err := crypto.GeneratePrivateKeyK256() 49 - if err != nil { 50 - return err 51 - } 52 - 53 - bytes := key.Bytes() 54 - 55 - if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil { 56 - return err 57 - } 58 - 59 - return nil 60 - }, 61 - } 62 - 63 - var runCreatePrivateJwk = &cli.Command{ 64 - Name: "create-private-jwk", 65 - Usage: "creates a private jwk for your pds", 66 - Flags: []cli.Flag{ 67 - &cli.StringFlag{ 68 - Name: "out", 69 - Required: true, 70 - Usage: "output file for your jwk", 71 - }, 72 - }, 73 - Action: func(cmd *cli.Context) error { 74 - privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 75 - if err != nil { 76 - return err 77 - } 78 - 79 - key, err := jwk.FromRaw(privKey) 80 - if err != nil { 81 - return err 82 - } 83 - 84 - kid := fmt.Sprintf("%d", time.Now().Unix()) 85 - 86 - if err := key.Set(jwk.KeyIDKey, kid); err != nil { 87 - return err 88 - } 89 - 90 - b, err := json.Marshal(key) 91 - if err != nil { 92 - return err 93 - } 94 - 95 - if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil { 96 - return err 97 - } 98 - 99 - return nil 100 - }, 101 - } 102 - 103 - var runCreateInviteCode = &cli.Command{ 104 - Name: "create-invite-code", 105 - Usage: "creates an invite code", 106 - Flags: []cli.Flag{ 107 - &cli.StringFlag{ 108 - Name: "for", 109 - Usage: "optional did to assign the invite code to", 110 - }, 111 - &cli.IntFlag{ 112 - Name: "uses", 113 - Usage: "number of times the invite code can be used", 114 - Value: 1, 115 - }, 116 - }, 117 - Action: func(cmd *cli.Context) error { 118 - db, err := newDb() 119 - if err != nil { 120 - return err 121 - } 122 - 123 - forDid := "did:plc:123" 124 - if cmd.String("for") != "" { 125 - did, err := syntax.ParseDID(cmd.String("for")) 126 - if err != nil { 127 - return err 128 - } 129 - 130 - forDid = did.String() 131 - } 132 - 133 - uses := cmd.Int("uses") 134 - 135 - code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8)) 136 - 137 - if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil { 138 - return err 139 - } 140 - 141 - fmt.Printf("New invite code created with %d uses: %s\n", uses, code) 142 - 143 - return nil 144 - }, 145 - } 146 - 147 - var runResetPassword = &cli.Command{ 148 - Name: "reset-password", 149 - Usage: "resets a password", 150 - Flags: []cli.Flag{ 151 - &cli.StringFlag{ 152 - Name: "did", 153 - Usage: "did of the user who's password you want to reset", 154 - }, 155 - }, 156 - Action: func(cmd *cli.Context) error { 157 - db, err := newDb() 158 - if err != nil { 159 - return err 160 - } 161 - 162 - didStr := cmd.String("did") 163 - did, err := syntax.ParseDID(didStr) 164 - if err != nil { 165 - return err 166 - } 167 - 168 - newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12)) 169 - hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10) 170 - if err != nil { 171 - return err 172 - } 173 - 174 - if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil { 175 - return err 176 - } 177 - 178 - fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass) 179 - 180 - return nil 181 - }, 182 - } 183 - 184 - func newDb() (*gorm.DB, error) { 185 - return gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 186 - }
+274 -48
cmd/cocoon/main.go
··· 1 1 package main 2 2 3 3 import ( 4 + "crypto/ecdsa" 5 + "crypto/elliptic" 6 + "crypto/rand" 7 + "encoding/json" 4 8 "fmt" 5 9 "os" 10 + "time" 6 11 12 + "github.com/bluesky-social/go-util/pkg/telemetry" 13 + "github.com/bluesky-social/indigo/atproto/atcrypto" 14 + "github.com/bluesky-social/indigo/atproto/syntax" 15 + "github.com/haileyok/cocoon/internal/helpers" 7 16 "github.com/haileyok/cocoon/server" 8 17 _ "github.com/joho/godotenv/autoload" 18 + "github.com/lestrrat-go/jwx/v2/jwk" 9 19 "github.com/urfave/cli/v2" 20 + "golang.org/x/crypto/bcrypt" 21 + "gorm.io/driver/postgres" 22 + "gorm.io/driver/sqlite" 23 + "gorm.io/gorm" 10 24 ) 11 25 12 26 var Version = "dev" ··· 27 41 EnvVars: []string{"COCOON_DB_NAME"}, 28 42 }, 29 43 &cli.StringFlag{ 30 - Name: "did", 31 - Required: true, 32 - EnvVars: []string{"COCOON_DID"}, 44 + Name: "db-type", 45 + Value: "sqlite", 46 + Usage: "Database type: sqlite or postgres", 47 + EnvVars: []string{"COCOON_DB_TYPE"}, 48 + }, 49 + &cli.StringFlag{ 50 + Name: "database-url", 51 + Aliases: []string{"db-url"}, 52 + Usage: "PostgreSQL connection string (required if db-type is postgres)", 53 + EnvVars: []string{"COCOON_DATABASE_URL", "DATABASE_URL"}, 33 54 }, 34 55 &cli.StringFlag{ 35 - Name: "hostname", 36 - Required: true, 37 - EnvVars: []string{"COCOON_HOSTNAME"}, 56 + Name: "did", 57 + EnvVars: []string{"COCOON_DID"}, 38 58 }, 39 59 &cli.StringFlag{ 40 - Name: "rotation-key-path", 41 - Required: true, 42 - EnvVars: []string{"COCOON_ROTATION_KEY_PATH"}, 60 + Name: "hostname", 61 + EnvVars: []string{"COCOON_HOSTNAME"}, 62 + }, 63 + &cli.StringFlag{ 64 + Name: "rotation-key-path", 65 + EnvVars: []string{"COCOON_ROTATION_KEY_PATH"}, 43 66 }, 44 67 &cli.StringFlag{ 45 - Name: "jwk-path", 46 - Required: true, 47 - EnvVars: []string{"COCOON_JWK_PATH"}, 68 + Name: "jwk-path", 69 + EnvVars: []string{"COCOON_JWK_PATH"}, 48 70 }, 49 71 &cli.StringFlag{ 50 - Name: "contact-email", 51 - Required: true, 52 - EnvVars: []string{"COCOON_CONTACT_EMAIL"}, 72 + Name: "contact-email", 73 + EnvVars: []string{"COCOON_CONTACT_EMAIL"}, 53 74 }, 54 75 &cli.StringSliceFlag{ 55 - Name: "relays", 56 - Required: true, 57 - EnvVars: []string{"COCOON_RELAYS"}, 76 + Name: "relays", 77 + EnvVars: []string{"COCOON_RELAYS"}, 58 78 }, 59 79 &cli.StringFlag{ 60 - Name: "admin-password", 61 - Required: true, 62 - EnvVars: []string{"COCOON_ADMIN_PASSWORD"}, 80 + Name: "admin-password", 81 + EnvVars: []string{"COCOON_ADMIN_PASSWORD"}, 82 + }, 83 + &cli.BoolFlag{ 84 + Name: "require-invite", 85 + EnvVars: []string{"COCOON_REQUIRE_INVITE"}, 86 + Value: true, 63 87 }, 64 88 &cli.StringFlag{ 65 - Name: "smtp-user", 66 - Required: false, 67 - EnvVars: []string{"COCOON_SMTP_USER"}, 89 + Name: "smtp-user", 90 + EnvVars: []string{"COCOON_SMTP_USER"}, 68 91 }, 69 92 &cli.StringFlag{ 70 - Name: "smtp-pass", 71 - Required: false, 72 - EnvVars: []string{"COCOON_SMTP_PASS"}, 93 + Name: "smtp-pass", 94 + EnvVars: []string{"COCOON_SMTP_PASS"}, 73 95 }, 74 96 &cli.StringFlag{ 75 - Name: "smtp-host", 76 - Required: false, 77 - EnvVars: []string{"COCOON_SMTP_HOST"}, 97 + Name: "smtp-host", 98 + EnvVars: []string{"COCOON_SMTP_HOST"}, 78 99 }, 79 100 &cli.StringFlag{ 80 - Name: "smtp-port", 81 - Required: false, 82 - EnvVars: []string{"COCOON_SMTP_PORT"}, 101 + Name: "smtp-port", 102 + EnvVars: []string{"COCOON_SMTP_PORT"}, 83 103 }, 84 104 &cli.StringFlag{ 85 - Name: "smtp-email", 86 - Required: false, 87 - EnvVars: []string{"COCOON_SMTP_EMAIL"}, 105 + Name: "smtp-email", 106 + EnvVars: []string{"COCOON_SMTP_EMAIL"}, 88 107 }, 89 108 &cli.StringFlag{ 90 - Name: "smtp-name", 91 - Required: false, 92 - EnvVars: []string{"COCOON_SMTP_NAME"}, 109 + Name: "smtp-name", 110 + EnvVars: []string{"COCOON_SMTP_NAME"}, 93 111 }, 94 112 &cli.BoolFlag{ 95 113 Name: "s3-backups-enabled", 96 114 EnvVars: []string{"COCOON_S3_BACKUPS_ENABLED"}, 115 + }, 116 + &cli.BoolFlag{ 117 + Name: "s3-blobstore-enabled", 118 + EnvVars: []string{"COCOON_S3_BLOBSTORE_ENABLED"}, 97 119 }, 98 120 &cli.StringFlag{ 99 121 Name: "s3-region", ··· 116 138 EnvVars: []string{"COCOON_S3_SECRET_KEY"}, 117 139 }, 118 140 &cli.StringFlag{ 141 + Name: "s3-cdn-url", 142 + EnvVars: []string{"COCOON_S3_CDN_URL"}, 143 + Usage: "Public URL for S3 blob redirects (e.g., https://cdn.example.com). When set, getBlob redirects to this URL instead of proxying.", 144 + }, 145 + &cli.StringFlag{ 119 146 Name: "session-secret", 120 147 EnvVars: []string{"COCOON_SESSION_SECRET"}, 121 148 }, 149 + &cli.StringFlag{ 150 + Name: "blockstore-variant", 151 + EnvVars: []string{"COCOON_BLOCKSTORE_VARIANT"}, 152 + Value: "sqlite", 153 + }, 154 + &cli.StringFlag{ 155 + Name: "fallback-proxy", 156 + EnvVars: []string{"COCOON_FALLBACK_PROXY"}, 157 + }, 158 + telemetry.CLIFlagDebug, 159 + telemetry.CLIFlagMetricsListenAddress, 122 160 }, 123 161 Commands: []*cli.Command{ 124 - run, 162 + runServe, 163 + runCreateRotationKey, 164 + runCreatePrivateJwk, 165 + runCreateInviteCode, 166 + runResetPassword, 125 167 }, 126 168 ErrWriter: os.Stdout, 127 169 Version: Version, ··· 132 174 } 133 175 } 134 176 135 - var run = &cli.Command{ 177 + var runServe = &cli.Command{ 136 178 Name: "run", 137 179 Usage: "Start the cocoon PDS", 138 180 Flags: []cli.Flag{}, 139 181 Action: func(cmd *cli.Context) error { 182 + 183 + logger := telemetry.StartLogger(cmd) 184 + telemetry.StartMetrics(cmd) 185 + 140 186 s, err := server.New(&server.Args{ 187 + Logger: logger, 141 188 Addr: cmd.String("addr"), 142 189 DbName: cmd.String("db-name"), 190 + DbType: cmd.String("db-type"), 191 + DatabaseURL: cmd.String("database-url"), 143 192 Did: cmd.String("did"), 144 193 Hostname: cmd.String("hostname"), 145 194 RotationKeyPath: cmd.String("rotation-key-path"), ··· 148 197 Version: Version, 149 198 Relays: cmd.StringSlice("relays"), 150 199 AdminPassword: cmd.String("admin-password"), 200 + RequireInvite: cmd.Bool("require-invite"), 151 201 SmtpUser: cmd.String("smtp-user"), 152 202 SmtpPass: cmd.String("smtp-pass"), 153 203 SmtpHost: cmd.String("smtp-host"), ··· 155 205 SmtpEmail: cmd.String("smtp-email"), 156 206 SmtpName: cmd.String("smtp-name"), 157 207 S3Config: &server.S3Config{ 158 - BackupsEnabled: cmd.Bool("s3-backups-enabled"), 159 - Region: cmd.String("s3-region"), 160 - Bucket: cmd.String("s3-bucket"), 161 - Endpoint: cmd.String("s3-endpoint"), 162 - AccessKey: cmd.String("s3-access-key"), 163 - SecretKey: cmd.String("s3-secret-key"), 208 + BackupsEnabled: cmd.Bool("s3-backups-enabled"), 209 + BlobstoreEnabled: cmd.Bool("s3-blobstore-enabled"), 210 + Region: cmd.String("s3-region"), 211 + Bucket: cmd.String("s3-bucket"), 212 + Endpoint: cmd.String("s3-endpoint"), 213 + AccessKey: cmd.String("s3-access-key"), 214 + SecretKey: cmd.String("s3-secret-key"), 215 + CDNUrl: cmd.String("s3-cdn-url"), 164 216 }, 165 - SessionSecret: cmd.String("session-secret"), 217 + SessionSecret: cmd.String("session-secret"), 218 + BlockstoreVariant: server.MustReturnBlockstoreVariant(cmd.String("blockstore-variant")), 219 + FallbackProxy: cmd.String("fallback-proxy"), 166 220 }) 167 221 if err != nil { 168 222 fmt.Printf("error creating cocoon: %v", err) ··· 177 231 return nil 178 232 }, 179 233 } 234 + 235 + var runCreateRotationKey = &cli.Command{ 236 + Name: "create-rotation-key", 237 + Usage: "creates a rotation key for your pds", 238 + Flags: []cli.Flag{ 239 + &cli.StringFlag{ 240 + Name: "out", 241 + Required: true, 242 + Usage: "output file for your rotation key", 243 + }, 244 + }, 245 + Action: func(cmd *cli.Context) error { 246 + key, err := atcrypto.GeneratePrivateKeyK256() 247 + if err != nil { 248 + return err 249 + } 250 + 251 + bytes := key.Bytes() 252 + 253 + if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil { 254 + return err 255 + } 256 + 257 + return nil 258 + }, 259 + } 260 + 261 + var runCreatePrivateJwk = &cli.Command{ 262 + Name: "create-private-jwk", 263 + Usage: "creates a private jwk for your pds", 264 + Flags: []cli.Flag{ 265 + &cli.StringFlag{ 266 + Name: "out", 267 + Required: true, 268 + Usage: "output file for your jwk", 269 + }, 270 + }, 271 + Action: func(cmd *cli.Context) error { 272 + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 273 + if err != nil { 274 + return err 275 + } 276 + 277 + key, err := jwk.FromRaw(privKey) 278 + if err != nil { 279 + return err 280 + } 281 + 282 + kid := fmt.Sprintf("%d", time.Now().Unix()) 283 + 284 + if err := key.Set(jwk.KeyIDKey, kid); err != nil { 285 + return err 286 + } 287 + 288 + b, err := json.Marshal(key) 289 + if err != nil { 290 + return err 291 + } 292 + 293 + if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil { 294 + return err 295 + } 296 + 297 + return nil 298 + }, 299 + } 300 + 301 + var runCreateInviteCode = &cli.Command{ 302 + Name: "create-invite-code", 303 + Usage: "creates an invite code", 304 + Flags: []cli.Flag{ 305 + &cli.StringFlag{ 306 + Name: "for", 307 + Usage: "optional did to assign the invite code to", 308 + }, 309 + &cli.IntFlag{ 310 + Name: "uses", 311 + Usage: "number of times the invite code can be used", 312 + Value: 1, 313 + }, 314 + }, 315 + Action: func(cmd *cli.Context) error { 316 + db, err := newDb(cmd) 317 + if err != nil { 318 + return err 319 + } 320 + 321 + forDid := "did:plc:123" 322 + if cmd.String("for") != "" { 323 + did, err := syntax.ParseDID(cmd.String("for")) 324 + if err != nil { 325 + return err 326 + } 327 + 328 + forDid = did.String() 329 + } 330 + 331 + uses := cmd.Int("uses") 332 + 333 + code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8)) 334 + 335 + if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil { 336 + return err 337 + } 338 + 339 + fmt.Printf("New invite code created with %d uses: %s\n", uses, code) 340 + 341 + return nil 342 + }, 343 + } 344 + 345 + var runResetPassword = &cli.Command{ 346 + Name: "reset-password", 347 + Usage: "resets a password", 348 + Flags: []cli.Flag{ 349 + &cli.StringFlag{ 350 + Name: "did", 351 + Usage: "did of the user who's password you want to reset", 352 + }, 353 + }, 354 + Action: func(cmd *cli.Context) error { 355 + db, err := newDb(cmd) 356 + if err != nil { 357 + return err 358 + } 359 + 360 + didStr := cmd.String("did") 361 + did, err := syntax.ParseDID(didStr) 362 + if err != nil { 363 + return err 364 + } 365 + 366 + newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12)) 367 + hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10) 368 + if err != nil { 369 + return err 370 + } 371 + 372 + if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil { 373 + return err 374 + } 375 + 376 + fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass) 377 + 378 + return nil 379 + }, 380 + } 381 + 382 + func newDb(cmd *cli.Context) (*gorm.DB, error) { 383 + dbType := cmd.String("db-type") 384 + if dbType == "" { 385 + dbType = "sqlite" 386 + } 387 + 388 + switch dbType { 389 + case "postgres": 390 + databaseURL := cmd.String("database-url") 391 + if databaseURL == "" { 392 + databaseURL = cmd.String("database-url") 393 + } 394 + if databaseURL == "" { 395 + return nil, fmt.Errorf("COCOON_DATABASE_URL or DATABASE_URL must be set when using postgres") 396 + } 397 + return gorm.Open(postgres.Open(databaseURL), &gorm.Config{}) 398 + default: 399 + dbName := cmd.String("db-name") 400 + if dbName == "" { 401 + dbName = "cocoon.db" 402 + } 403 + return gorm.Open(sqlite.Open(dbName), &gorm.Config{}) 404 + } 405 + }
+56
create-initial-invite.sh
··· 1 + #!/bin/sh 2 + 3 + INVITE_FILE="/keys/initial-invite-code.txt" 4 + MARKER="/keys/.invite_created" 5 + 6 + # Check if invite code was already created 7 + if [ -f "$MARKER" ]; then 8 + echo "โœ“ Initial invite code already created" 9 + exit 0 10 + fi 11 + 12 + echo "Waiting for database to be ready..." 13 + sleep 10 14 + 15 + # Try to create invite code - retry until database is ready 16 + MAX_ATTEMPTS=30 17 + ATTEMPT=0 18 + INVITE_CODE="" 19 + 20 + while [ $ATTEMPT -lt $MAX_ATTEMPTS ]; do 21 + ATTEMPT=$((ATTEMPT + 1)) 22 + OUTPUT=$(/cocoon create-invite-code --uses 1 2>&1) 23 + INVITE_CODE=$(echo "$OUTPUT" | grep -oE '[a-zA-Z0-9]{8}-[a-zA-Z0-9]{8}' || echo "") 24 + 25 + if [ -n "$INVITE_CODE" ]; then 26 + break 27 + fi 28 + 29 + if [ $((ATTEMPT % 5)) -eq 0 ]; then 30 + echo " Waiting for database... ($ATTEMPT/$MAX_ATTEMPTS)" 31 + fi 32 + sleep 2 33 + done 34 + 35 + if [ -n "$INVITE_CODE" ]; then 36 + echo "" 37 + echo "โ•”โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•—" 38 + echo "โ•‘ SAVE THIS INVITE CODE! โ•‘" 39 + echo "โ•‘ โ•‘" 40 + echo "โ•‘ $INVITE_CODE โ•‘" 41 + echo "โ•‘ โ•‘" 42 + echo "โ•‘ Use this to create your first โ•‘" 43 + echo "โ•‘ account on your PDS. โ•‘" 44 + echo "โ•šโ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•โ•" 45 + echo "" 46 + 47 + echo "$INVITE_CODE" > "$INVITE_FILE" 48 + echo "โœ“ Invite code saved to: $INVITE_FILE" 49 + 50 + touch "$MARKER" 51 + echo "โœ“ Initial setup complete!" 52 + else 53 + echo "โœ— Failed to create invite code" 54 + echo "Output: $OUTPUT" 55 + exit 1 56 + fi
+45
cspell.json
··· 1 + { 2 + "version": "0.2", 3 + "language": "en", 4 + "words": [ 5 + "atproto", 6 + "bsky", 7 + "Cocoon", 8 + "PDS", 9 + "Plc", 10 + "plc", 11 + "repo", 12 + "InviteCodes", 13 + "InviteCode", 14 + "Invite", 15 + "Signin", 16 + "Signout", 17 + "JWKS", 18 + "dpop", 19 + "BGS", 20 + "pico", 21 + "picocss", 22 + "par", 23 + "blobs", 24 + "blob", 25 + "did", 26 + "DID", 27 + "OAuth", 28 + "oauth", 29 + "par", 30 + "Cocoon", 31 + "memcache", 32 + "db", 33 + "helpers", 34 + "middleware", 35 + "repo", 36 + "static", 37 + "pico", 38 + "picocss", 39 + "MIT", 40 + "Go" 41 + ], 42 + "ignorePaths": [ 43 + "server/static/pico.css" 44 + ] 45 + }
+158
docker-compose.postgres.yaml
··· 1 + # Docker Compose with PostgreSQL 2 + # 3 + # Usage: 4 + # docker-compose -f docker-compose.postgres.yaml up -d 5 + # 6 + # This file extends the base docker-compose.yaml with a PostgreSQL database. 7 + # Set the following in your .env file: 8 + # COCOON_DB_TYPE=postgres 9 + # POSTGRES_PASSWORD=your-secure-password 10 + 11 + version: '3.8' 12 + 13 + services: 14 + postgres: 15 + image: postgres:16-alpine 16 + container_name: cocoon-postgres 17 + environment: 18 + POSTGRES_USER: cocoon 19 + POSTGRES_PASSWORD: ${POSTGRES_PASSWORD:?POSTGRES_PASSWORD is required} 20 + POSTGRES_DB: cocoon 21 + volumes: 22 + - postgres_data:/var/lib/postgresql/data 23 + healthcheck: 24 + test: ["CMD-SHELL", "pg_isready -U cocoon -d cocoon"] 25 + interval: 10s 26 + timeout: 5s 27 + retries: 5 28 + restart: unless-stopped 29 + 30 + init-keys: 31 + build: 32 + context: . 33 + dockerfile: Dockerfile 34 + image: ghcr.io/haileyok/cocoon:latest 35 + container_name: cocoon-init-keys 36 + volumes: 37 + - ./keys:/keys 38 + - ./data:/data/cocoon 39 + - ./init-keys.sh:/init-keys.sh:ro 40 + environment: 41 + COCOON_DID: ${COCOON_DID} 42 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 43 + COCOON_ROTATION_KEY_PATH: /keys/rotation.key 44 + COCOON_JWK_PATH: /keys/jwk.key 45 + COCOON_CONTACT_EMAIL: ${COCOON_CONTACT_EMAIL} 46 + COCOON_RELAYS: ${COCOON_RELAYS:-https://bsky.network} 47 + COCOON_ADMIN_PASSWORD: ${COCOON_ADMIN_PASSWORD} 48 + entrypoint: ["/bin/sh", "/init-keys.sh"] 49 + restart: "no" 50 + 51 + cocoon: 52 + build: 53 + context: . 54 + dockerfile: Dockerfile 55 + image: ghcr.io/haileyok/cocoon:latest 56 + container_name: cocoon-pds 57 + depends_on: 58 + init-keys: 59 + condition: service_completed_successfully 60 + postgres: 61 + condition: service_healthy 62 + ports: 63 + - "8080:8080" 64 + volumes: 65 + - ./data:/data/cocoon 66 + - ./keys/rotation.key:/keys/rotation.key:ro 67 + - ./keys/jwk.key:/keys/jwk.key:ro 68 + environment: 69 + # Required settings 70 + COCOON_DID: ${COCOON_DID} 71 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 72 + COCOON_ROTATION_KEY_PATH: /keys/rotation.key 73 + COCOON_JWK_PATH: /keys/jwk.key 74 + COCOON_CONTACT_EMAIL: ${COCOON_CONTACT_EMAIL} 75 + COCOON_RELAYS: ${COCOON_RELAYS:-https://bsky.network} 76 + COCOON_ADMIN_PASSWORD: ${COCOON_ADMIN_PASSWORD} 77 + COCOON_SESSION_SECRET: ${COCOON_SESSION_SECRET} 78 + 79 + # Database configuration - PostgreSQL 80 + COCOON_ADDR: ":8080" 81 + COCOON_DB_TYPE: postgres 82 + COCOON_DATABASE_URL: postgres://cocoon:${POSTGRES_PASSWORD}@postgres:5432/cocoon?sslmode=disable 83 + COCOON_BLOCKSTORE_VARIANT: ${COCOON_BLOCKSTORE_VARIANT:-sqlite} 84 + 85 + # Optional: SMTP settings for email 86 + COCOON_SMTP_USER: ${COCOON_SMTP_USER:-} 87 + COCOON_SMTP_PASS: ${COCOON_SMTP_PASS:-} 88 + COCOON_SMTP_HOST: ${COCOON_SMTP_HOST:-} 89 + COCOON_SMTP_PORT: ${COCOON_SMTP_PORT:-} 90 + COCOON_SMTP_EMAIL: ${COCOON_SMTP_EMAIL:-} 91 + COCOON_SMTP_NAME: ${COCOON_SMTP_NAME:-} 92 + 93 + # Optional: S3 configuration 94 + COCOON_S3_BACKUPS_ENABLED: ${COCOON_S3_BACKUPS_ENABLED:-false} 95 + COCOON_S3_BLOBSTORE_ENABLED: ${COCOON_S3_BLOBSTORE_ENABLED:-false} 96 + COCOON_S3_REGION: ${COCOON_S3_REGION:-} 97 + COCOON_S3_BUCKET: ${COCOON_S3_BUCKET:-} 98 + COCOON_S3_ENDPOINT: ${COCOON_S3_ENDPOINT:-} 99 + COCOON_S3_ACCESS_KEY: ${COCOON_S3_ACCESS_KEY:-} 100 + COCOON_S3_SECRET_KEY: ${COCOON_S3_SECRET_KEY:-} 101 + 102 + # Optional: Fallback proxy 103 + COCOON_FALLBACK_PROXY: ${COCOON_FALLBACK_PROXY:-} 104 + restart: unless-stopped 105 + healthcheck: 106 + test: ["CMD", "curl", "-f", "http://localhost:8080/xrpc/_health"] 107 + interval: 30s 108 + timeout: 10s 109 + retries: 3 110 + start_period: 40s 111 + 112 + create-invite: 113 + build: 114 + context: . 115 + dockerfile: Dockerfile 116 + image: ghcr.io/haileyok/cocoon:latest 117 + container_name: cocoon-create-invite 118 + volumes: 119 + - ./keys:/keys 120 + - ./create-initial-invite.sh:/create-initial-invite.sh:ro 121 + environment: 122 + COCOON_DID: ${COCOON_DID} 123 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 124 + COCOON_ROTATION_KEY_PATH: /keys/rotation.key 125 + COCOON_JWK_PATH: /keys/jwk.key 126 + COCOON_CONTACT_EMAIL: ${COCOON_CONTACT_EMAIL} 127 + COCOON_RELAYS: ${COCOON_RELAYS:-https://bsky.network} 128 + COCOON_ADMIN_PASSWORD: ${COCOON_ADMIN_PASSWORD} 129 + COCOON_DB_TYPE: postgres 130 + COCOON_DATABASE_URL: postgres://cocoon:${POSTGRES_PASSWORD}@postgres:5432/cocoon?sslmode=disable 131 + depends_on: 132 + cocoon: 133 + condition: service_healthy 134 + entrypoint: ["/bin/sh", "/create-initial-invite.sh"] 135 + restart: "no" 136 + 137 + caddy: 138 + image: caddy:2-alpine 139 + container_name: cocoon-caddy 140 + ports: 141 + - "80:80" 142 + - "443:443" 143 + volumes: 144 + - ./Caddyfile.postgres:/etc/caddy/Caddyfile:ro 145 + - caddy_data:/data 146 + - caddy_config:/config 147 + restart: unless-stopped 148 + environment: 149 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 150 + CADDY_ACME_EMAIL: ${COCOON_CONTACT_EMAIL:-} 151 + 152 + volumes: 153 + postgres_data: 154 + driver: local 155 + caddy_data: 156 + driver: local 157 + caddy_config: 158 + driver: local
+130
docker-compose.yaml
··· 1 + version: '3.8' 2 + 3 + services: 4 + init-keys: 5 + build: 6 + context: . 7 + dockerfile: Dockerfile 8 + image: ghcr.io/haileyok/cocoon:latest 9 + container_name: cocoon-init-keys 10 + volumes: 11 + - ./keys:/keys 12 + - ./data:/data/cocoon 13 + - ./init-keys.sh:/init-keys.sh:ro 14 + environment: 15 + COCOON_DID: ${COCOON_DID} 16 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 17 + COCOON_ROTATION_KEY_PATH: /keys/rotation.key 18 + COCOON_JWK_PATH: /keys/jwk.key 19 + COCOON_CONTACT_EMAIL: ${COCOON_CONTACT_EMAIL} 20 + COCOON_RELAYS: ${COCOON_RELAYS:-https://bsky.network} 21 + COCOON_ADMIN_PASSWORD: ${COCOON_ADMIN_PASSWORD} 22 + entrypoint: ["/bin/sh", "/init-keys.sh"] 23 + restart: "no" 24 + 25 + cocoon: 26 + build: 27 + context: . 28 + dockerfile: Dockerfile 29 + image: ghcr.io/haileyok/cocoon:latest 30 + container_name: cocoon-pds 31 + network_mode: host 32 + depends_on: 33 + init-keys: 34 + condition: service_completed_successfully 35 + volumes: 36 + - ./data:/data/cocoon 37 + - ./keys/rotation.key:/keys/rotation.key:ro 38 + - ./keys/jwk.key:/keys/jwk.key:ro 39 + environment: 40 + # Required settings 41 + COCOON_DID: ${COCOON_DID} 42 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 43 + COCOON_ROTATION_KEY_PATH: /keys/rotation.key 44 + COCOON_JWK_PATH: /keys/jwk.key 45 + COCOON_CONTACT_EMAIL: ${COCOON_CONTACT_EMAIL} 46 + COCOON_RELAYS: ${COCOON_RELAYS:-https://bsky.network} 47 + COCOON_ADMIN_PASSWORD: ${COCOON_ADMIN_PASSWORD} 48 + COCOON_SESSION_SECRET: ${COCOON_SESSION_SECRET} 49 + 50 + # Server configuration 51 + COCOON_ADDR: ":8080" 52 + COCOON_DB_TYPE: ${COCOON_DB_TYPE:-sqlite} 53 + COCOON_DB_NAME: ${COCOON_DB_NAME:-/data/cocoon/cocoon.db} 54 + COCOON_DATABASE_URL: ${COCOON_DATABASE_URL:-} 55 + COCOON_BLOCKSTORE_VARIANT: ${COCOON_BLOCKSTORE_VARIANT:-sqlite} 56 + 57 + # Optional: SMTP settings for email 58 + COCOON_SMTP_USER: ${COCOON_SMTP_USER:-} 59 + COCOON_SMTP_PASS: ${COCOON_SMTP_PASS:-} 60 + COCOON_SMTP_HOST: ${COCOON_SMTP_HOST:-} 61 + COCOON_SMTP_PORT: ${COCOON_SMTP_PORT:-} 62 + COCOON_SMTP_EMAIL: ${COCOON_SMTP_EMAIL:-} 63 + COCOON_SMTP_NAME: ${COCOON_SMTP_NAME:-} 64 + 65 + # Optional: S3 configuration 66 + COCOON_S3_BACKUPS_ENABLED: ${COCOON_S3_BACKUPS_ENABLED:-false} 67 + COCOON_S3_BLOBSTORE_ENABLED: ${COCOON_S3_BLOBSTORE_ENABLED:-false} 68 + COCOON_S3_REGION: ${COCOON_S3_REGION:-} 69 + COCOON_S3_BUCKET: ${COCOON_S3_BUCKET:-} 70 + COCOON_S3_ENDPOINT: ${COCOON_S3_ENDPOINT:-} 71 + COCOON_S3_ACCESS_KEY: ${COCOON_S3_ACCESS_KEY:-} 72 + COCOON_S3_SECRET_KEY: ${COCOON_S3_SECRET_KEY:-} 73 + COCOON_S3_CDN_URL: ${COCOON_S3_CDN_URL:-} 74 + 75 + # Optional: Fallback proxy 76 + COCOON_FALLBACK_PROXY: ${COCOON_FALLBACK_PROXY:-} 77 + restart: unless-stopped 78 + healthcheck: 79 + test: ["CMD", "curl", "-f", "http://localhost:8080/xrpc/_health"] 80 + interval: 30s 81 + timeout: 10s 82 + retries: 3 83 + start_period: 40s 84 + 85 + create-invite: 86 + build: 87 + context: . 88 + dockerfile: Dockerfile 89 + image: ghcr.io/haileyok/cocoon:latest 90 + container_name: cocoon-create-invite 91 + network_mode: host 92 + volumes: 93 + - ./keys:/keys 94 + - ./create-initial-invite.sh:/create-initial-invite.sh:ro 95 + environment: 96 + COCOON_DID: ${COCOON_DID} 97 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 98 + COCOON_ROTATION_KEY_PATH: /keys/rotation.key 99 + COCOON_JWK_PATH: /keys/jwk.key 100 + COCOON_CONTACT_EMAIL: ${COCOON_CONTACT_EMAIL} 101 + COCOON_RELAYS: ${COCOON_RELAYS:-https://bsky.network} 102 + COCOON_ADMIN_PASSWORD: ${COCOON_ADMIN_PASSWORD} 103 + COCOON_DB_TYPE: ${COCOON_DB_TYPE:-sqlite} 104 + COCOON_DB_NAME: ${COCOON_DB_NAME:-/data/cocoon/cocoon.db} 105 + COCOON_DATABASE_URL: ${COCOON_DATABASE_URL:-} 106 + depends_on: 107 + - init-keys 108 + entrypoint: ["/bin/sh", "/create-initial-invite.sh"] 109 + restart: "no" 110 + 111 + caddy: 112 + image: caddy:2-alpine 113 + container_name: cocoon-caddy 114 + network_mode: host 115 + volumes: 116 + - ./Caddyfile:/etc/caddy/Caddyfile:ro 117 + - caddy_data:/data 118 + - caddy_config:/config 119 + restart: unless-stopped 120 + environment: 121 + COCOON_HOSTNAME: ${COCOON_HOSTNAME} 122 + CADDY_ACME_EMAIL: ${COCOON_CONTACT_EMAIL:-} 123 + 124 + volumes: 125 + data: 126 + driver: local 127 + caddy_data: 128 + driver: local 129 + caddy_config: 130 + driver: local
+18 -15
go.mod
··· 1 1 module github.com/haileyok/cocoon 2 2 3 - go 1.24.1 3 + go 1.24.5 4 4 5 5 require ( 6 6 github.com/Azure/go-autorest/autorest/to v0.4.1 7 7 github.com/aws/aws-sdk-go v1.55.7 8 - github.com/bluesky-social/indigo v0.0.0-20250414202759-826fcdeaa36b 8 + github.com/bluesky-social/go-util v0.0.0-20251012040650-2ebbf57f5934 9 + github.com/bluesky-social/indigo v0.0.0-20251009212240-20524de167fe 9 10 github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 10 11 github.com/domodwyer/mailyak/v3 v3.6.2 11 12 github.com/go-pkgz/expirable-cache/v3 v3.0.0 12 13 github.com/go-playground/validator v9.31.0+incompatible 13 14 github.com/golang-jwt/jwt/v4 v4.5.2 14 - github.com/google/uuid v1.4.0 15 + github.com/google/uuid v1.6.0 15 16 github.com/gorilla/sessions v1.4.0 16 17 github.com/gorilla/websocket v1.5.1 18 + github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b 17 19 github.com/hashicorp/golang-lru/v2 v2.0.7 18 20 github.com/ipfs/go-block-format v0.2.0 19 21 github.com/ipfs/go-cid v0.4.1 22 + github.com/ipfs/go-ipfs-blockstore v1.3.1 20 23 github.com/ipfs/go-ipld-cbor v0.1.0 21 24 github.com/ipld/go-car v0.6.1-0.20230509095817-92d28eb23ba4 22 25 github.com/joho/godotenv v1.5.1 ··· 24 27 github.com/labstack/echo/v4 v4.13.3 25 28 github.com/lestrrat-go/jwx/v2 v2.0.12 26 29 github.com/multiformats/go-multihash v0.2.3 30 + github.com/prometheus/client_golang v1.23.2 27 31 github.com/samber/slog-echo v1.16.1 28 32 github.com/urfave/cli/v2 v2.27.6 29 33 github.com/whyrusleeping/cbor-gen v0.2.1-0.20241030202151-b7a6831be65e 30 34 gitlab.com/yawning/secp256k1-voi v0.0.0-20230925100816-f2616030848b 31 - golang.org/x/crypto v0.38.0 35 + golang.org/x/crypto v0.41.0 36 + gorm.io/driver/postgres v1.5.7 32 37 gorm.io/driver/sqlite v1.5.7 33 38 gorm.io/gorm v1.25.12 34 39 ) ··· 54 59 github.com/gorilla/securecookie v1.1.2 // indirect 55 60 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed // indirect 56 61 github.com/hashicorp/go-cleanhttp v0.5.2 // indirect 57 - github.com/hashicorp/go-retryablehttp v0.7.5 // indirect 62 + github.com/hashicorp/go-retryablehttp v0.7.7 // indirect 58 63 github.com/hashicorp/golang-lru v1.0.2 // indirect 59 64 github.com/ipfs/bbloom v0.0.4 // indirect 60 65 github.com/ipfs/go-blockservice v0.5.2 // indirect 61 66 github.com/ipfs/go-datastore v0.6.0 // indirect 62 - github.com/ipfs/go-ipfs-blockstore v1.3.1 // indirect 63 67 github.com/ipfs/go-ipfs-ds-help v1.1.1 // indirect 64 68 github.com/ipfs/go-ipfs-exchange-interface v0.2.1 // indirect 65 69 github.com/ipfs/go-ipfs-util v0.0.3 // indirect ··· 101 105 github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect 102 106 github.com/opentracing/opentracing-go v1.2.0 // indirect 103 107 github.com/polydawn/refmt v0.89.1-0.20221221234430-40501e09de1f // indirect 104 - github.com/prometheus/client_golang v1.22.0 // indirect 105 108 github.com/prometheus/client_model v0.6.2 // indirect 106 - github.com/prometheus/common v0.63.0 // indirect 109 + github.com/prometheus/common v0.66.1 // indirect 107 110 github.com/prometheus/procfs v0.16.1 // indirect 108 111 github.com/russross/blackfriday/v2 v2.1.0 // indirect 109 112 github.com/samber/lo v1.49.1 // indirect ··· 113 116 github.com/valyala/fasttemplate v1.2.2 // indirect 114 117 github.com/xrash/smetrics v0.0.0-20240521201337-686a1a2994c1 // indirect 115 118 gitlab.com/yawning/tuplehash v0.0.0-20230713102510-df83abbf9a02 // indirect 116 - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 // indirect 119 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 // indirect 117 120 go.opentelemetry.io/otel v1.29.0 // indirect 118 121 go.opentelemetry.io/otel/metric v1.29.0 // indirect 119 122 go.opentelemetry.io/otel/trace v1.29.0 // indirect 120 123 go.uber.org/atomic v1.11.0 // indirect 121 124 go.uber.org/multierr v1.11.0 // indirect 122 125 go.uber.org/zap v1.26.0 // indirect 123 - golang.org/x/net v0.40.0 // indirect 124 - golang.org/x/sync v0.14.0 // indirect 125 - golang.org/x/sys v0.33.0 // indirect 126 - golang.org/x/text v0.25.0 // indirect 126 + go.yaml.in/yaml/v2 v2.4.2 // indirect 127 + golang.org/x/net v0.43.0 // indirect 128 + golang.org/x/sync v0.16.0 // indirect 129 + golang.org/x/sys v0.35.0 // indirect 130 + golang.org/x/text v0.28.0 // indirect 127 131 golang.org/x/time v0.11.0 // indirect 128 132 golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 // indirect 129 - google.golang.org/protobuf v1.36.6 // indirect 133 + google.golang.org/protobuf v1.36.9 // indirect 130 134 gopkg.in/go-playground/assert.v1 v1.2.1 // indirect 131 135 gopkg.in/inf.v0 v0.9.1 // indirect 132 - gorm.io/driver/postgres v1.5.7 // indirect 133 136 lukechampine.com/blake3 v1.2.1 // indirect 134 137 )
+46 -37
go.sum
··· 16 16 github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= 17 17 github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932 h1:mXoPYz/Ul5HYEDvkta6I8/rnYM5gSdSV2tJ6XbZuEtY= 18 18 github.com/bitly/go-hostpool v0.0.0-20171023180738-a3a6125de932/go.mod h1:NOuUCSz6Q9T7+igc/hlvDOUdtWKryOrtFyIVABv/p7k= 19 - github.com/bluesky-social/indigo v0.0.0-20250414202759-826fcdeaa36b h1:elwfbe+W7GkUmPKFX1h7HaeHvC/kC0XJWfiEHC62xPg= 20 - github.com/bluesky-social/indigo v0.0.0-20250414202759-826fcdeaa36b/go.mod h1:yjdhLA1LkK8VDS/WPUoYPo25/Hq/8rX38Ftr67EsqKY= 19 + github.com/bluesky-social/go-util v0.0.0-20251012040650-2ebbf57f5934 h1:btHMur2kTRgWEnCHn6LaI3BE9YRgsqTpwpJ1UdB7VEk= 20 + github.com/bluesky-social/go-util v0.0.0-20251012040650-2ebbf57f5934/go.mod h1:LWamyZfbQGW7PaVc5jumFfjgrshJ5mXgDUnR6fK7+BI= 21 + github.com/bluesky-social/indigo v0.0.0-20251009212240-20524de167fe h1:VBhaqE5ewQgXbY5SfSWFZC/AwHFo7cHxZKFYi2ce9Yo= 22 + github.com/bluesky-social/indigo v0.0.0-20251009212240-20524de167fe/go.mod h1:RuQVrCGm42QNsgumKaR6se+XkFKfCPNwdCiTvqKRUck= 21 23 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869 h1:DDGfHa7BWjL4YnC6+E63dPcxHo2sUxDIu8g3QgEJdRY= 22 24 github.com/bmizerany/assert v0.0.0-20160611221934-b7ed37b82869/go.mod h1:Ekp36dRnpXw/yCqJaO+ZrUyxD+3VXMFFr56k5XYrpB4= 23 25 github.com/btcsuite/websocket v0.0.0-20150119174127-31079b680792 h1:R8vQdOQdZ9Y3SkEwmHoWBmX1DNXhXZqlTpq6s4tyJGc= ··· 39 41 github.com/decred/dcrd/dcrec/secp256k1/v4 v4.2.0/go.mod h1:v57UDF4pDQJcEfFUCRop3lJL149eHGSe9Jvczhzjo/0= 40 42 github.com/domodwyer/mailyak/v3 v3.6.2 h1:x3tGMsyFhTCaxp6ycgR0FE/bu5QiNp+hetUuCOBXMn8= 41 43 github.com/domodwyer/mailyak/v3 v3.6.2/go.mod h1:lOm/u9CyCVWHeaAmHIdF4RiKVxKUT/H5XX10lIKAL6c= 44 + github.com/fatih/color v1.16.0 h1:zmkK9Ngbjj+K0yRhTVONQh1p/HknKYSlNT+vZCzyokM= 45 + github.com/fatih/color v1.16.0/go.mod h1:fL2Sau1YI5c0pdGEVCbKQbLXB6edEj1ZgiY4NijnWvE= 42 46 github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= 43 47 github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= 44 48 github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= ··· 77 81 github.com/google/gopacket v1.1.19 h1:ves8RnFZPGiFnTS0uPQStjwru6uO6h+nlr9j6fL7kF8= 78 82 github.com/google/gopacket v1.1.19/go.mod h1:iJ8V8n6KS+z2U1A8pUwu8bW5SyEMkXJB8Yo/Vo+TKTo= 79 83 github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= 80 - github.com/google/uuid v1.4.0 h1:MtMxsa51/r9yyhkyLsVeVt0B+BGQZzpQiTQ4eHZ8bc4= 81 - github.com/google/uuid v1.4.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 84 + github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0= 85 + github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= 82 86 github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1 h1:EGx4pi6eqNxGaHF6qqu48+N2wcFQ5qg5FXgOdqsJ5d8= 83 87 github.com/gopherjs/gopherjs v0.0.0-20181017120253-0766667cb4d1/go.mod h1:wJfORRmW1u3UXTncJ5qlYoELFm8eSnnEO6hX4iZ3EWY= 84 88 github.com/gorilla/context v1.1.2 h1:WRkNAv2uoa03QNIc1A6u4O7DAGMUVoopZhkiXWA2V1o= ··· 91 95 github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= 92 96 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= 93 97 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= 98 + github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b h1:wDUNC2eKiL35DbLvsDhiblTUXHxcOPwQSCzi7xpQUN4= 99 + github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b/go.mod h1:VzxiSdG6j1pi7rwGm/xYI5RbtpBgM8sARDXlvEvxlu0= 94 100 github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= 95 101 github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= 96 - github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI= 97 - github.com/hashicorp/go-hclog v0.9.2/go.mod h1:5CU+agLiy3J7N7QjHK5d05KxGsuXiQLrjA0H7acj2lQ= 98 - github.com/hashicorp/go-retryablehttp v0.7.5 h1:bJj+Pj19UZMIweq/iie+1u5YCdGrnxCT9yvm0e+Nd5M= 99 - github.com/hashicorp/go-retryablehttp v0.7.5/go.mod h1:Jy/gPYAdjqffZ/yFGCFV2doI5wjtH1ewM9u8iYVjtX8= 102 + github.com/hashicorp/go-hclog v1.6.3 h1:Qr2kF+eVWjTiYmU7Y31tYlP1h0q/X3Nl3tPGdaB11/k= 103 + github.com/hashicorp/go-hclog v1.6.3/go.mod h1:W4Qnvbt70Wk/zYJryRzDRU/4r0kIg0PVHBcfoyhpF5M= 104 + github.com/hashicorp/go-retryablehttp v0.7.7 h1:C8hUCYzor8PIfXHa4UrZkU4VvK8o9ISHxT2Q8+VepXU= 105 + github.com/hashicorp/go-retryablehttp v0.7.7/go.mod h1:pkQpWZeYWskR+D1tR2O5OcBFOxfA7DoAO6xtkuQnHTk= 100 106 github.com/hashicorp/golang-lru v1.0.2 h1:dV3g9Z/unq5DpblPpw+Oqcv4dU/1omnb4Ok8iPY6p1c= 101 107 github.com/hashicorp/golang-lru v1.0.2/go.mod h1:iADmTwqILo4mZ8BN3D2Q6+9jd8WM5uGBxy+E8yxSoD4= 102 108 github.com/hashicorp/golang-lru/v2 v2.0.7 h1:a+bsQ5rvGLjzHuww6tVxozPZFVghXaHOwFs4luLUK2k= ··· 111 117 github.com/ipfs/go-block-format v0.2.0/go.mod h1:+jpL11nFx5A/SPpsoBn6Bzkra/zaArfSmsknbPMYgzM= 112 118 github.com/ipfs/go-blockservice v0.5.2 h1:in9Bc+QcXwd1apOVM7Un9t8tixPKdaHQFdLSUM1Xgk8= 113 119 github.com/ipfs/go-blockservice v0.5.2/go.mod h1:VpMblFEqG67A/H2sHKAemeH9vlURVavlysbdUI632yk= 114 - github.com/ipfs/go-bs-sqlite3 v0.0.0-20221122195556-bfcee1be620d h1:9V+GGXCuOfDiFpdAHz58q9mKLg447xp0cQKvqQrAwYE= 115 - github.com/ipfs/go-bs-sqlite3 v0.0.0-20221122195556-bfcee1be620d/go.mod h1:pMbnFyNAGjryYCLCe59YDLRv/ujdN+zGJBT1umlvYRM= 116 120 github.com/ipfs/go-cid v0.4.1 h1:A/T3qGvxi4kpKWWcPC/PgbvDA2bjVLO7n4UeVwnbs/s= 117 121 github.com/ipfs/go-cid v0.4.1/go.mod h1:uQHwDeX4c6CtyrFwdqyhpNcxVewur1M7l7fNU7LKwZk= 118 122 github.com/ipfs/go-datastore v0.6.0 h1:JKyz+Gvz1QEZw0LsX1IBn+JFCJQH4SJVFtM4uWU0Myk= ··· 195 199 github.com/jtolds/gls v4.20.0+incompatible/go.mod h1:QJZ7F/aHp+rZTRtaJ1ow/lLfFfVYBRgL+9YlvaHOwJU= 196 200 github.com/kisielk/errcheck v1.5.0/go.mod h1:pFxgyoBC7bSaBwPgfKdkLd5X25qrDl4LWUI2bnpBCr8= 197 201 github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck= 202 + github.com/klauspost/compress v1.18.0 h1:c/Cqfb0r+Yi+JtIEq73FWXVkRonBlf0CRNYc8Zttxdo= 203 + github.com/klauspost/compress v1.18.0/go.mod h1:2Pp+KzxcywXVXMr50+X0Q/Lsb43OQHYWRCY2AiWywWQ= 198 204 github.com/klauspost/cpuid/v2 v2.2.7 h1:ZWSB3igEs+d0qvnxR/ZBzXVmxkgt8DdzP6m9pfuVLDM= 199 205 github.com/klauspost/cpuid/v2 v2.2.7/go.mod h1:Lcz8mBdAVJIBVzewtcLocK12l3Y+JytZYpaMropDUws= 200 206 github.com/koron/go-ssdp v0.0.3 h1:JivLMY45N76b4p/vsWGOKewBQu6uf39y8l+AQ7sDKx8= ··· 206 212 github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= 207 213 github.com/kr/text v0.2.0 h1:5Nx0Ya0ZqY2ygV366QzturHI13Jq95ApcVaJBhpS+AY= 208 214 github.com/kr/text v0.2.0/go.mod h1:eLer722TekiGuMkidMxC/pM04lWEeraHUUmBw8l2grE= 215 + github.com/kylelemons/godebug v1.1.0 h1:RPNrshWIDI6G2gRW9EHilWtl7Z6Sb1BR0xunSBf0SNc= 216 + github.com/kylelemons/godebug v1.1.0/go.mod h1:9/0rRGxNHcop5bhtWyNeEfOS8JIWk580+fNqagV/RAw= 209 217 github.com/labstack/echo-contrib v0.17.4 h1:g5mfsrJfJTKv+F5uNKCyrjLK7js+ZW6HTjg4FnDxxgk= 210 218 github.com/labstack/echo-contrib v0.17.4/go.mod h1:9O7ZPAHUeMGTOAfg80YqQduHzt0CzLak36PZRldYrZ0= 211 219 github.com/labstack/echo/v4 v4.13.3 h1:pwhpCPrTl5qry5HRdM5FwdXnhXSLSY+WE+YQSeCaafY= ··· 289 297 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= 290 298 github.com/polydawn/refmt v0.89.1-0.20221221234430-40501e09de1f h1:VXTQfuJj9vKR4TCkEuWIckKvdHFeJH/huIFJ9/cXOB0= 291 299 github.com/polydawn/refmt v0.89.1-0.20221221234430-40501e09de1f/go.mod h1:/zvteZs/GwLtCgZ4BL6CBsk9IKIlexP43ObX9AxTqTw= 292 - github.com/prometheus/client_golang v1.22.0 h1:rb93p9lokFEsctTys46VnV1kLCDpVZ0a/Y92Vm0Zc6Q= 293 - github.com/prometheus/client_golang v1.22.0/go.mod h1:R7ljNsLXhuQXYZYtw6GAE9AZg8Y7vEW5scdCXrWRXC0= 300 + github.com/prometheus/client_golang v1.23.2 h1:Je96obch5RDVy3FDMndoUsjAhG5Edi49h0RJWRi/o0o= 301 + github.com/prometheus/client_golang v1.23.2/go.mod h1:Tb1a6LWHB3/SPIzCoaDXI4I8UHKeFTEQ1YCr+0Gyqmg= 294 302 github.com/prometheus/client_model v0.6.2 h1:oBsgwpGs7iVziMvrGhE53c/GrLUsZdHnqNwqPLxwZyk= 295 303 github.com/prometheus/client_model v0.6.2/go.mod h1:y3m2F6Gdpfy6Ut/GBsUqTWZqCUvMVzSfMLjcu6wAwpE= 296 - github.com/prometheus/common v0.63.0 h1:YR/EIY1o3mEFP/kZCD7iDMnLPlGyuU2Gb3HIcXnA98k= 297 - github.com/prometheus/common v0.63.0/go.mod h1:VVFF/fBIoToEnWRVkYoXEkq3R3paCoxG9PXP74SnV18= 304 + github.com/prometheus/common v0.66.1 h1:h5E0h5/Y8niHc5DlaLlWLArTQI7tMrsfQjHV+d9ZoGs= 305 + github.com/prometheus/common v0.66.1/go.mod h1:gcaUsgf3KfRSwHY4dIMXLPV0K/Wg1oZ8+SbZk/HH/dA= 298 306 github.com/prometheus/procfs v0.16.1 h1:hZ15bTNuirocR6u0JZ6BAHHmwS1p8B4P6MRqxtzMyRg= 299 307 github.com/prometheus/procfs v0.16.1/go.mod h1:teAbpZRB1iIAJYREa1LsoWUXykVXA1KlTmWl8x/U+Is= 300 308 github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= ··· 319 327 github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= 320 328 github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw= 321 329 github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo= 322 - github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= 323 330 github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= 324 331 github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= 325 332 github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= ··· 327 334 github.com/stretchr/testify v1.7.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= 328 335 github.com/stretchr/testify v1.8.0/go.mod h1:yNjHg4UonilssWZ8iaSj1OCr/vHnekPRkoO+kdMU+MU= 329 336 github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= 330 - github.com/stretchr/testify v1.10.0 h1:Xv5erBjTwe/5IxqUQTdXv5kgmIvbHo3QQyRwhJsOfJA= 331 - github.com/stretchr/testify v1.10.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY= 337 + github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= 338 + github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= 332 339 github.com/urfave/cli v1.22.10/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtXRu0= 333 340 github.com/urfave/cli/v2 v2.27.6 h1:VdRdS98FNhKZ8/Az8B7MTyGQmpIr36O1EHybx/LaZ4g= 334 341 github.com/urfave/cli/v2 v2.27.6/go.mod h1:3Sevf16NykTbInEnD0yKkjDAeZDS0A6bzhBH5hrMvTQ= ··· 354 361 gitlab.com/yawning/secp256k1-voi v0.0.0-20230925100816-f2616030848b/go.mod h1:/y/V339mxv2sZmYYR64O07VuCpdNZqCTwO8ZcouTMI8= 355 362 gitlab.com/yawning/tuplehash v0.0.0-20230713102510-df83abbf9a02 h1:qwDnMxjkyLmAFgcfgTnfJrmYKWhHnci3GjDqcZp1M3Q= 356 363 gitlab.com/yawning/tuplehash v0.0.0-20230713102510-df83abbf9a02/go.mod h1:JTnUj0mpYiAsuZLmKjTx/ex3AtMowcCgnE7YNyCEP0I= 357 - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1 h1:aFJWCqJMNjENlcleuuOkGAPH82y0yULBScfXcIEdS24= 358 - go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.46.1/go.mod h1:sEGXWArGqc3tVa+ekntsN65DmVbVeW+7lTKTjZF3/Fo= 364 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0 h1:jq9TW8u3so/bN+JPT166wjOI6/vQPF6Xe7nMNIltagk= 365 + go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.49.0/go.mod h1:p8pYQP+m5XfbZm9fxtSKAbM6oIllS7s2AfxrChvc7iw= 359 366 go.opentelemetry.io/otel v1.29.0 h1:PdomN/Al4q/lN6iBJEN3AwPvUiHPMlt93c8bqTG5Llw= 360 367 go.opentelemetry.io/otel v1.29.0/go.mod h1:N/WtXPs1CNCUEx+Agz5uouwCba+i+bJGFicT8SR4NP8= 361 368 go.opentelemetry.io/otel/metric v1.29.0 h1:vPf/HFWTNkPu1aYeIsc98l4ktOQaL6LeSoeV2g+8YLc= ··· 367 374 go.uber.org/atomic v1.11.0 h1:ZvwS0R+56ePWxUNi+Atn9dWONBPp/AUETXlHW0DxSjE= 368 375 go.uber.org/atomic v1.11.0/go.mod h1:LUxbIzbOniOlMKjJjyPfpl4v+PKK2cNJn91OQbhoJI0= 369 376 go.uber.org/goleak v1.1.11-0.20210813005559-691160354723/go.mod h1:cwTWslyiVhfpKIDGSZEM2HlOvcqm+tG4zioyIeLoqMQ= 370 - go.uber.org/goleak v1.2.0 h1:xqgm/S+aQvhWFTtR0XK3Jvg7z8kGV8P4X14IzwN3Eqk= 371 - go.uber.org/goleak v1.2.0/go.mod h1:XJYK+MuIchqpmGmUSAzotztawfKvYLUIgg7guXrwVUo= 377 + go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= 378 + go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= 372 379 go.uber.org/multierr v1.5.0/go.mod h1:FeouvMocqHpRaaGuG9EjoKcStLC43Zu/fmqdUMPcKYU= 373 380 go.uber.org/multierr v1.6.0/go.mod h1:cdWPpRnG4AhwMwsgIHip0KRBQjJy5kYEpYjJxpXp9iU= 374 381 go.uber.org/multierr v1.11.0 h1:blXXJkSxSSfBVBlC76pxqeO+LN3aDfLQo+309xJstO0= ··· 378 385 go.uber.org/zap v1.19.1/go.mod h1:j3DNczoxDZroyBnOT1L/Q79cfUMGZxlv/9dzN7SM1rI= 379 386 go.uber.org/zap v1.26.0 h1:sI7k6L95XOKS281NhVKOFCUNIvv9e0w4BF8N3u+tCRo= 380 387 go.uber.org/zap v1.26.0/go.mod h1:dtElttAiwGvoJ/vj4IwHBS/gXsEu/pZ50mUIRWuG0so= 388 + go.yaml.in/yaml/v2 v2.4.2 h1:DzmwEr2rDGHl7lsFgAHxmNz/1NlQ7xLIrlN2h5d1eGI= 389 + go.yaml.in/yaml/v2 v2.4.2/go.mod h1:081UH+NErpNdqlCXm3TtEran0rJZGxAYx9hb/ELlsPU= 381 390 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= 382 391 golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 383 392 golang.org/x/crypto v0.0.0-20191011191535-87dc89f01550/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= 384 393 golang.org/x/crypto v0.0.0-20200622213623-75b288015ac9/go.mod h1:LzIPMQfyMNhhGPhUkYOs5KpL4U8rLKemX1yGLhDgUto= 385 394 golang.org/x/crypto v0.0.0-20210921155107-089bfa567519/go.mod h1:GvvjBRRGRdwPK5ydBHafDWAxML/pGHZbMvKqRZ5+Abc= 386 395 golang.org/x/crypto v0.12.0/go.mod h1:NF0Gs7EO5K4qLn+Ylc+fih8BSTeIjAP05siRnAh98yw= 387 - golang.org/x/crypto v0.38.0 h1:jt+WWG8IZlBnVbomuhg2Mdq0+BBQaHbtqHEFEigjUV8= 388 - golang.org/x/crypto v0.38.0/go.mod h1:MvrbAqul58NNYPKnOra203SB9vpuZW0e+RRZV+Ggqjw= 396 + golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4= 397 + golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc= 389 398 golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= 390 399 golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= 391 400 golang.org/x/lint v0.0.0-20190930215403-16217165b5de/go.mod h1:6SW0HCj/g11FgYtHlgUYUwCkIfeOF89ocIRzGO/8vkc= ··· 395 404 golang.org/x/mod v0.4.2/go.mod h1:s0Qsj1ACt9ePp/hMypM3fl4fZqREWJwdYDEqhRiZZUA= 396 405 golang.org/x/mod v0.6.0-dev.0.20220419223038-86c51ed26bb4/go.mod h1:jJ57K6gSWd91VN4djpZkiMVwK6gcyfeH4XE8wZrZaV4= 397 406 golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= 398 - golang.org/x/mod v0.17.0 h1:zY54UmvipHiNd+pm+m0x9KhZ9hl1/7QNMyxXbc6ICqA= 399 - golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= 407 + golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg= 408 + golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ= 400 409 golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 401 410 golang.org/x/net v0.0.0-20190404232315-eb5bcb51f2a3/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= 402 411 golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= ··· 407 416 golang.org/x/net v0.0.0-20220722155237-a158d28d115b/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= 408 417 golang.org/x/net v0.6.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= 409 418 golang.org/x/net v0.10.0/go.mod h1:0qNGK6F8kojg2nk9dLZ2mShWaEBan6FAoqfSigmmuDg= 410 - golang.org/x/net v0.40.0 h1:79Xs7wF06Gbdcg4kdCCIQArK11Z1hr5POQ6+fIYHNuY= 411 - golang.org/x/net v0.40.0/go.mod h1:y0hY0exeL2Pku80/zKK7tpntoX23cqL3Oa6njdgRtds= 419 + golang.org/x/net v0.43.0 h1:lat02VYK2j4aLzMzecihNvTlJNQUq316m2Mr9rnM6YE= 420 + golang.org/x/net v0.43.0/go.mod h1:vhO1fvI4dGsIjh73sWfUVjj3N7CA9WkKJNQm2svM6Jg= 412 421 golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 413 422 golang.org/x/sync v0.0.0-20190911185100-cd5d95a43a6e/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 414 423 golang.org/x/sync v0.0.0-20201020160332-67f06af15bc9/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 415 424 golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 416 425 golang.org/x/sync v0.0.0-20220722155255-886fb9371eb4/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 417 426 golang.org/x/sync v0.1.0/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= 418 - golang.org/x/sync v0.14.0 h1:woo0S4Yywslg6hp4eUFjTVOyKt0RookbpAHG4c1HmhQ= 419 - golang.org/x/sync v0.14.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 427 + golang.org/x/sync v0.16.0 h1:ycBJEhp9p4vXvUZNszeOq0kGTPghopOL8q0fq3vstxw= 428 + golang.org/x/sync v0.16.0/go.mod h1:1dzgHSNfp02xaA81J2MS99Qcpr2w7fw1gpm99rleRqA= 420 429 golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= 421 430 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= 422 431 golang.org/x/sys v0.0.0-20200930185726-fdedc70b468f/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= ··· 432 441 golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 433 442 golang.org/x/sys v0.8.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 434 443 golang.org/x/sys v0.11.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= 435 - golang.org/x/sys v0.33.0 h1:q3i8TbbEz+JRD9ywIRlyRAQbM0qF7hu24q3teo2hbuw= 436 - golang.org/x/sys v0.33.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 444 + golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI= 445 + golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= 437 446 golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= 438 447 golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= 439 448 golang.org/x/term v0.5.0/go.mod h1:jMB1sMXY+tzblOD4FWmEbocvup2/aLOaQEp7JmGp78k= ··· 445 454 golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= 446 455 golang.org/x/text v0.9.0/go.mod h1:e1OnstbJyHTd6l/uOt8jFFHp6TRDWZR/bV3emEE/zU8= 447 456 golang.org/x/text v0.12.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE= 448 - golang.org/x/text v0.25.0 h1:qVyWApTSYLk/drJRO5mDlNYskwQznZmkpV2c8q9zls4= 449 - golang.org/x/text v0.25.0/go.mod h1:WEdwpYrmk1qmdHvhkSTNPm3app7v4rsT8F2UD6+VHIA= 457 + golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng= 458 + golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU= 450 459 golang.org/x/time v0.11.0 h1:/bpjEDfN9tkoN/ryeYHnv5hcMlc8ncjMcM4XBk5NWV0= 451 460 golang.org/x/time v0.11.0/go.mod h1:CDIdPxbZBQxdj6cxyCIdrNogrJKMJ7pr37NYpMcMDSg= 452 461 golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= ··· 461 470 golang.org/x/tools v0.1.5/go.mod h1:o0xws9oXOQQZyjljx8fwUC0k7L1pTE6eaCbjGeHmOkk= 462 471 golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc= 463 472 golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= 464 - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d h1:vU5i/LfpvrRCpgM/VPfJLg5KjxD3E+hfT1SH+d9zLwg= 465 - golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= 473 + golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0= 474 + golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw= 466 475 golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 467 476 golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 468 477 golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 469 478 golang.org/x/xerrors v0.0.0-20200804184101-5ec99f83aff1/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= 470 479 golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028 h1:+cNy6SZtPcJQH3LJVLOSmiC7MMxXNOb3PU/VUEz+EhU= 471 480 golang.org/x/xerrors v0.0.0-20231012003039-104605ab7028/go.mod h1:NDW/Ps6MPRej6fsCIbMTohpP40sJ/P/vI1MoTEGwX90= 472 - google.golang.org/protobuf v1.36.6 h1:z1NpPI8ku2WgiWnf+t9wTPsn6eP1L7ksHUlkfLvd9xY= 473 - google.golang.org/protobuf v1.36.6/go.mod h1:jduwjTPXsFjZGTmRluh+L6NjiWu7pchiJ2/5YcXBHnY= 481 + google.golang.org/protobuf v1.36.9 h1:w2gp2mA27hUeUzj9Ex9FBjsBm40zfaDtEWow293U7Iw= 482 + google.golang.org/protobuf v1.36.9/go.mod h1:fuxRtAxBytpl4zzqUh6/eyUujkJdNiuEkXntxiD/uRU= 474 483 gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 475 484 gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= 476 485 gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
+74 -55
identity/identity.go
··· 13 13 "github.com/bluesky-social/indigo/util" 14 14 ) 15 15 16 - func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) { 17 - if cli == nil { 18 - cli = util.RobustHTTPClient() 16 + func ResolveHandleFromTXT(ctx context.Context, handle string) (string, error) { 17 + name := fmt.Sprintf("_atproto.%s", handle) 18 + recs, err := net.LookupTXT(name) 19 + if err != nil { 20 + return "", fmt.Errorf("handle could not be resolved via txt: %w", err) 21 + } 22 + 23 + for _, rec := range recs { 24 + if strings.HasPrefix(rec, "did=") { 25 + maybeDid := strings.Split(rec, "did=")[1] 26 + if _, err := syntax.ParseDID(maybeDid); err == nil { 27 + return maybeDid, nil 28 + } 29 + } 30 + } 31 + 32 + return "", fmt.Errorf("handle could not be resolved via txt: no record found") 33 + } 34 + 35 + func ResolveHandleFromWellKnown(ctx context.Context, cli *http.Client, handle string) (string, error) { 36 + ustr := fmt.Sprintf("https://%s/.well-known/atproto-did", handle) 37 + req, err := http.NewRequestWithContext( 38 + ctx, 39 + "GET", 40 + ustr, 41 + nil, 42 + ) 43 + if err != nil { 44 + return "", fmt.Errorf("handle could not be resolved via web: %w", err) 19 45 } 20 46 21 - var did string 47 + resp, err := cli.Do(req) 48 + if err != nil { 49 + return "", fmt.Errorf("handle could not be resolved via web: %w", err) 50 + } 51 + defer resp.Body.Close() 22 52 23 - _, err := syntax.ParseHandle(handle) 53 + b, err := io.ReadAll(resp.Body) 24 54 if err != nil { 25 - return "", err 55 + return "", fmt.Errorf("handle could not be resolved via web: %w", err) 26 56 } 27 57 28 - recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle)) 29 - if err == nil { 30 - for _, rec := range recs { 31 - if strings.HasPrefix(rec, "did=") { 32 - did = strings.Split(rec, "did=")[1] 33 - break 34 - } 35 - } 36 - } else { 37 - fmt.Printf("erorr getting txt records: %v\n", err) 58 + if resp.StatusCode != http.StatusOK { 59 + return "", fmt.Errorf("handle could not be resolved via web: invalid status code %d", resp.StatusCode) 38 60 } 39 61 40 - if did == "" { 41 - req, err := http.NewRequestWithContext( 42 - ctx, 43 - "GET", 44 - fmt.Sprintf("https://%s/.well-known/atproto-did", handle), 45 - nil, 46 - ) 47 - if err != nil { 48 - return "", nil 49 - } 62 + maybeDid := string(b) 50 63 51 - resp, err := http.DefaultClient.Do(req) 52 - if err != nil { 53 - return "", nil 54 - } 55 - defer resp.Body.Close() 64 + if _, err := syntax.ParseDID(maybeDid); err != nil { 65 + return "", fmt.Errorf("handle could not be resolved via web: invalid did in document") 66 + } 56 67 57 - if resp.StatusCode != http.StatusOK { 58 - io.Copy(io.Discard, resp.Body) 59 - return "", fmt.Errorf("unable to resolve handle") 60 - } 68 + return maybeDid, nil 69 + } 61 70 62 - b, err := io.ReadAll(resp.Body) 63 - if err != nil { 64 - return "", err 65 - } 71 + func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) { 72 + if cli == nil { 73 + cli = util.RobustHTTPClient() 74 + } 66 75 67 - maybeDid := string(b) 76 + _, err := syntax.ParseHandle(handle) 77 + if err != nil { 78 + return "", err 79 + } 68 80 69 - if _, err := syntax.ParseDID(maybeDid); err != nil { 70 - return "", fmt.Errorf("unable to resolve handle") 71 - } 81 + if maybeDidFromTxt, err := ResolveHandleFromTXT(ctx, handle); err == nil { 82 + return maybeDidFromTxt, nil 83 + } 72 84 73 - did = maybeDid 85 + if maybeDidFromWeb, err := ResolveHandleFromWellKnown(ctx, cli, handle); err == nil { 86 + return maybeDidFromWeb, nil 74 87 } 75 88 76 - return did, nil 89 + return "", fmt.Errorf("handle could not be resolved") 90 + } 91 + 92 + func DidToDocUrl(did string) (string, error) { 93 + if strings.HasPrefix(did, "did:plc:") { 94 + return fmt.Sprintf("https://plc.directory/%s", did), nil 95 + } else if after, ok := strings.CutPrefix(did, "did:web:"); ok { 96 + return fmt.Sprintf("https://%s/.well-known/did.json", after), nil 97 + } else { 98 + return "", fmt.Errorf("did was not a supported did type") 99 + } 77 100 } 78 101 79 102 func FetchDidDoc(ctx context.Context, cli *http.Client, did string) (*DidDoc, error) { ··· 81 104 cli = util.RobustHTTPClient() 82 105 } 83 106 84 - var ustr string 85 - if strings.HasPrefix(did, "did:plc:") { 86 - ustr = fmt.Sprintf("https://plc.directory/%s", did) 87 - } else if strings.HasPrefix(did, "did:web:") { 88 - ustr = fmt.Sprintf("https://%s/.well-known/did.json", strings.TrimPrefix(did, "did:web:")) 89 - } else { 90 - return nil, fmt.Errorf("did was not a supported did type") 107 + ustr, err := DidToDocUrl(did) 108 + if err != nil { 109 + return nil, err 91 110 } 92 111 93 112 req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil) ··· 95 114 return nil, err 96 115 } 97 116 98 - resp, err := http.DefaultClient.Do(req) 117 + resp, err := cli.Do(req) 99 118 if err != nil { 100 119 return nil, err 101 120 } ··· 103 122 104 123 if resp.StatusCode != 200 { 105 124 io.Copy(io.Discard, resp.Body) 106 - return nil, fmt.Errorf("could not find identity in plc registry") 125 + return nil, fmt.Errorf("unable to find did doc at url. did: %s. url: %s", did, ustr) 107 126 } 108 127 109 128 var diddoc DidDoc ··· 127 146 return nil, err 128 147 } 129 148 130 - resp, err := http.DefaultClient.Do(req) 149 + resp, err := cli.Do(req) 131 150 if err != nil { 132 151 return nil, err 133 152 }
+15 -5
identity/passport.go
··· 19 19 type Passport struct { 20 20 h *http.Client 21 21 bc BackingCache 22 - lk sync.Mutex 22 + mu sync.RWMutex 23 23 } 24 24 25 25 func NewPassport(h *http.Client, bc BackingCache) *Passport { ··· 30 30 return &Passport{ 31 31 h: h, 32 32 bc: bc, 33 - lk: sync.Mutex{}, 34 33 } 35 34 } 36 35 ··· 38 37 skipCache, _ := ctx.Value("skip-cache").(bool) 39 38 40 39 if !skipCache { 40 + p.mu.RLock() 41 41 cached, ok := p.bc.GetDoc(did) 42 + p.mu.RUnlock() 43 + 42 44 if ok { 43 45 return cached, nil 44 46 } 45 47 } 46 48 47 - p.lk.Lock() // this is pretty pathetic, and i should rethink this. but for now, fuck it 48 - defer p.lk.Unlock() 49 - 50 49 doc, err := FetchDidDoc(ctx, p.h, did) 51 50 if err != nil { 52 51 return nil, err 53 52 } 54 53 54 + p.mu.Lock() 55 55 p.bc.PutDoc(did, doc) 56 + p.mu.Unlock() 56 57 57 58 return doc, nil 58 59 } ··· 61 62 skipCache, _ := ctx.Value("skip-cache").(bool) 62 63 63 64 if !skipCache { 65 + p.mu.RLock() 64 66 cached, ok := p.bc.GetDid(handle) 67 + p.mu.RUnlock() 68 + 65 69 if ok { 66 70 return cached, nil 67 71 } ··· 72 76 return "", err 73 77 } 74 78 79 + p.mu.Lock() 75 80 p.bc.PutDid(handle, did) 81 + p.mu.Unlock() 76 82 77 83 return did, nil 78 84 } 79 85 80 86 func (p *Passport) BustDoc(ctx context.Context, did string) error { 87 + p.mu.Lock() 88 + defer p.mu.Unlock() 81 89 return p.bc.BustDoc(did) 82 90 } 83 91 84 92 func (p *Passport) BustDid(ctx context.Context, handle string) error { 93 + p.mu.Lock() 94 + defer p.mu.Unlock() 85 95 return p.bc.BustDid(handle) 86 96 }
+1 -1
identity/types.go
··· 4 4 Context []string `json:"@context"` 5 5 Id string `json:"id"` 6 6 AlsoKnownAs []string `json:"alsoKnownAs"` 7 - VerificationMethods []DidDocVerificationMethod `json:"verificationMethods"` 7 + VerificationMethods []DidDocVerificationMethod `json:"verificationMethod"` 8 8 Service []DidDocService `json:"service"` 9 9 } 10 10
+34
init-keys.sh
··· 1 + #!/bin/sh 2 + set -e 3 + 4 + mkdir -p /keys 5 + mkdir -p /data/cocoon 6 + 7 + if [ ! -f /keys/rotation.key ]; then 8 + echo "Generating rotation key..." 9 + /cocoon create-rotation-key --out /keys/rotation.key 2>/dev/null || true 10 + if [ -f /keys/rotation.key ]; then 11 + echo "โœ“ Rotation key generated at /keys/rotation.key" 12 + else 13 + echo "โœ— Failed to generate rotation key" 14 + exit 1 15 + fi 16 + else 17 + echo "โœ“ Rotation key already exists" 18 + fi 19 + 20 + if [ ! -f /keys/jwk.key ]; then 21 + echo "Generating JWK..." 22 + /cocoon create-private-jwk --out /keys/jwk.key 2>/dev/null || true 23 + if [ -f /keys/jwk.key ]; then 24 + echo "โœ“ JWK generated at /keys/jwk.key" 25 + else 26 + echo "โœ— Failed to generate JWK" 27 + exit 1 28 + fi 29 + else 30 + echo "โœ“ JWK already exists" 31 + fi 32 + 33 + echo "" 34 + echo "โœ“ Key initialization complete!"
+19 -12
internal/db/db.go
··· 1 1 package db 2 2 3 3 import ( 4 + "context" 4 5 "sync" 5 6 6 7 "gorm.io/gorm" ··· 19 20 } 20 21 } 21 22 22 - func (db *DB) Create(value any, clauses []clause.Expression) *gorm.DB { 23 + func (db *DB) Create(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 23 24 db.mu.Lock() 24 25 defer db.mu.Unlock() 25 - return db.cli.Clauses(clauses...).Create(value) 26 + return db.cli.WithContext(ctx).Clauses(clauses...).Create(value) 26 27 } 27 28 28 - func (db *DB) Exec(sql string, clauses []clause.Expression, values ...any) *gorm.DB { 29 + func (db *DB) Save(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 29 30 db.mu.Lock() 30 31 defer db.mu.Unlock() 31 - return db.cli.Clauses(clauses...).Exec(sql, values...) 32 + return db.cli.WithContext(ctx).Clauses(clauses...).Save(value) 32 33 } 33 34 34 - func (db *DB) Raw(sql string, clauses []clause.Expression, values ...any) *gorm.DB { 35 - return db.cli.Clauses(clauses...).Raw(sql, values...) 35 + func (db *DB) Exec(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { 36 + db.mu.Lock() 37 + defer db.mu.Unlock() 38 + return db.cli.WithContext(ctx).Clauses(clauses...).Exec(sql, values...) 39 + } 40 + 41 + func (db *DB) Raw(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { 42 + return db.cli.WithContext(ctx).Clauses(clauses...).Raw(sql, values...) 36 43 } 37 44 38 45 func (db *DB) AutoMigrate(models ...any) error { 39 46 return db.cli.AutoMigrate(models...) 40 47 } 41 48 42 - func (db *DB) Delete(value any, clauses []clause.Expression) *gorm.DB { 49 + func (db *DB) Delete(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 43 50 db.mu.Lock() 44 51 defer db.mu.Unlock() 45 - return db.cli.Clauses(clauses...).Delete(value) 52 + return db.cli.WithContext(ctx).Clauses(clauses...).Delete(value) 46 53 } 47 54 48 - func (db *DB) First(dest any, conds ...any) *gorm.DB { 49 - return db.cli.First(dest, conds...) 55 + func (db *DB) First(ctx context.Context, dest any, conds ...any) *gorm.DB { 56 + return db.cli.WithContext(ctx).First(dest, conds...) 50 57 } 51 58 52 59 // TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure 53 60 // out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad. 54 61 // e.g. when we do apply writes we should also be using a transcation but we don't right now 55 - func (db *DB) BeginDangerously() *gorm.DB { 56 - return db.cli.Begin() 62 + func (db *DB) BeginDangerously(ctx context.Context) *gorm.DB { 63 + return db.cli.WithContext(ctx).Begin() 57 64 } 58 65 59 66 func (db *DB) Lock() {
+29
internal/helpers/helpers.go
··· 7 7 "math/rand" 8 8 "net/url" 9 9 10 + "github.com/Azure/go-autorest/autorest/to" 10 11 "github.com/labstack/echo/v4" 11 12 "github.com/lestrrat-go/jwx/v2/jwk" 12 13 ) ··· 29 30 msg += ". " + *suffix 30 31 } 31 32 return genericError(e, 400, msg) 33 + } 34 + 35 + func UnauthorizedError(e echo.Context, suffix *string) error { 36 + msg := "Unauthorized" 37 + if suffix != nil { 38 + msg += ". " + *suffix 39 + } 40 + return genericError(e, 401, msg) 41 + } 42 + 43 + func ForbiddenError(e echo.Context, suffix *string) error { 44 + msg := "Forbidden" 45 + if suffix != nil { 46 + msg += ". " + *suffix 47 + } 48 + return genericError(e, 403, msg) 49 + } 50 + 51 + func InvalidTokenError(e echo.Context) error { 52 + return InputError(e, to.StringPtr("InvalidToken")) 53 + } 54 + 55 + func ExpiredTokenError(e echo.Context) error { 56 + // WARN: See https://github.com/bluesky-social/atproto/discussions/3319 57 + return e.JSON(400, map[string]string{ 58 + "error": "ExpiredToken", 59 + "message": "*", 60 + }) 32 61 } 33 62 34 63 func genericError(e echo.Context, code int, msg string) error {
+30
metrics/metrics.go
··· 1 + package metrics 2 + 3 + import ( 4 + "github.com/prometheus/client_golang/prometheus" 5 + "github.com/prometheus/client_golang/prometheus/promauto" 6 + ) 7 + 8 + const ( 9 + NAMESPACE = "cocoon" 10 + ) 11 + 12 + var ( 13 + RelaysConnected = promauto.NewGaugeVec(prometheus.GaugeOpts{ 14 + Namespace: NAMESPACE, 15 + Name: "relays_connected", 16 + Help: "number of connected relays, by host", 17 + }, []string{"host"}) 18 + 19 + RelaySends = promauto.NewCounterVec(prometheus.CounterOpts{ 20 + Namespace: NAMESPACE, 21 + Name: "relay_sends", 22 + Help: "number of events sent to a relay, by host", 23 + }, []string{"host", "kind"}) 24 + 25 + RepoOperations = promauto.NewCounterVec(prometheus.CounterOpts{ 26 + Namespace: NAMESPACE, 27 + Name: "repo_operations", 28 + Help: "number of operations made against repos", 29 + }, []string{"kind"}) 30 + )
+38 -2
models/models.go
··· 4 4 "context" 5 5 "time" 6 6 7 - "github.com/bluesky-social/indigo/atproto/crypto" 7 + "github.com/Azure/go-autorest/autorest/to" 8 + "github.com/bluesky-social/indigo/atproto/atcrypto" 9 + ) 10 + 11 + type TwoFactorType string 12 + 13 + var ( 14 + TwoFactorTypeNone = TwoFactorType("none") 15 + TwoFactorTypeEmail = TwoFactorType("email") 8 16 ) 9 17 10 18 type Repo struct { ··· 18 26 EmailUpdateCodeExpiresAt *time.Time 19 27 PasswordResetCode *string 20 28 PasswordResetCodeExpiresAt *time.Time 29 + PlcOperationCode *string 30 + PlcOperationCodeExpiresAt *time.Time 31 + AccountDeleteCode *string 32 + AccountDeleteCodeExpiresAt *time.Time 21 33 Password string 22 34 SigningKey []byte 23 35 Rev string 24 36 Root []byte 25 37 Preferences []byte 38 + Deactivated bool 39 + TwoFactorCode *string 40 + TwoFactorCodeExpiresAt *time.Time 41 + TwoFactorType TwoFactorType `gorm:"default:none"` 26 42 } 27 43 28 44 func (r *Repo) SignFor(ctx context.Context, did string, msg []byte) ([]byte, error) { 29 - k, err := crypto.ParsePrivateBytesK256(r.SigningKey) 45 + k, err := atcrypto.ParsePrivateBytesK256(r.SigningKey) 30 46 if err != nil { 31 47 return nil, err 32 48 } ··· 39 55 return sig, nil 40 56 } 41 57 58 + func (r *Repo) Status() *string { 59 + var status *string 60 + if r.Deactivated { 61 + status = to.StringPtr("deactivated") 62 + } 63 + return status 64 + } 65 + 66 + func (r *Repo) Active() bool { 67 + return r.Status() == nil 68 + } 69 + 42 70 type Actor struct { 43 71 Did string `gorm:"primaryKey"` 44 72 Handle string `gorm:"uniqueIndex"` ··· 92 120 Did string `gorm:"index;index:idx_blob_did_cid"` 93 121 Cid []byte `gorm:"index;index:idx_blob_did_cid"` 94 122 RefCount int 123 + Storage string `gorm:"default:sqlite"` 95 124 } 96 125 97 126 type BlobPart struct { ··· 100 129 Idx int `gorm:"primaryKey"` 101 130 Data []byte 102 131 } 132 + 133 + type ReservedKey struct { 134 + KeyDid string `gorm:"primaryKey"` 135 + Did *string `gorm:"index"` 136 + PrivateKey []byte 137 + CreatedAt time.Time `gorm:"index"` 138 + }
+8
oauth/client/client.go
··· 1 + package client 2 + 3 + import "github.com/lestrrat-go/jwx/v2/jwk" 4 + 5 + type Client struct { 6 + Metadata *Metadata 7 + JWKS jwk.Key 8 + }
+412
oauth/client/manager.go
··· 1 + package client 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "io" 9 + "log/slog" 10 + "net/http" 11 + "net/url" 12 + "slices" 13 + "strings" 14 + "time" 15 + 16 + cache "github.com/go-pkgz/expirable-cache/v3" 17 + "github.com/haileyok/cocoon/internal/helpers" 18 + "github.com/lestrrat-go/jwx/v2/jwk" 19 + ) 20 + 21 + type Manager struct { 22 + cli *http.Client 23 + logger *slog.Logger 24 + jwksCache cache.Cache[string, jwk.Key] 25 + metadataCache cache.Cache[string, *Metadata] 26 + } 27 + 28 + type ManagerArgs struct { 29 + Cli *http.Client 30 + Logger *slog.Logger 31 + } 32 + 33 + func NewManager(args ManagerArgs) *Manager { 34 + if args.Logger == nil { 35 + args.Logger = slog.Default() 36 + } 37 + 38 + if args.Cli == nil { 39 + args.Cli = http.DefaultClient 40 + } 41 + 42 + jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 43 + metadataCache := cache.NewCache[string, *Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 + 45 + return &Manager{ 46 + cli: args.Cli, 47 + logger: args.Logger, 48 + jwksCache: jwksCache, 49 + metadataCache: metadataCache, 50 + } 51 + } 52 + 53 + func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) { 54 + metadata, err := cm.getClientMetadata(ctx, clientId) 55 + if err != nil { 56 + return nil, err 57 + } 58 + 59 + var jwks jwk.Key 60 + if metadata.TokenEndpointAuthMethod == "private_key_jwt" { 61 + if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 { 62 + // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to 63 + // make sure we use the right one 64 + b, err := json.Marshal(metadata.JWKS.Keys[0]) 65 + if err != nil { 66 + return nil, err 67 + } 68 + 69 + k, err := helpers.ParseJWKFromBytes(b) 70 + if err != nil { 71 + return nil, err 72 + } 73 + 74 + jwks = k 75 + } else if metadata.JWKSURI != nil { 76 + maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 77 + if err != nil { 78 + return nil, err 79 + } 80 + 81 + jwks = maybeJwks 82 + } else { 83 + return nil, fmt.Errorf("no valid jwks found in oauth client metadata") 84 + } 85 + } 86 + 87 + return &Client{ 88 + Metadata: metadata, 89 + JWKS: jwks, 90 + }, nil 91 + } 92 + 93 + func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) { 94 + cached, ok := cm.metadataCache.Get(clientId) 95 + if !ok { 96 + req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 97 + if err != nil { 98 + return nil, err 99 + } 100 + 101 + resp, err := cm.cli.Do(req) 102 + if err != nil { 103 + return nil, err 104 + } 105 + defer resp.Body.Close() 106 + 107 + if resp.StatusCode != http.StatusOK { 108 + io.Copy(io.Discard, resp.Body) 109 + return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 110 + } 111 + 112 + b, err := io.ReadAll(resp.Body) 113 + if err != nil { 114 + return nil, fmt.Errorf("error reading bytes from client response: %w", err) 115 + } 116 + 117 + validated, err := validateAndParseMetadata(clientId, b) 118 + if err != nil { 119 + return nil, err 120 + } 121 + 122 + cm.metadataCache.Set(clientId, validated, 10*time.Minute) 123 + 124 + return validated, nil 125 + } else { 126 + return cached, nil 127 + } 128 + } 129 + 130 + func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 131 + jwks, ok := cm.jwksCache.Get(clientId) 132 + if !ok { 133 + req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 134 + if err != nil { 135 + return nil, err 136 + } 137 + 138 + resp, err := cm.cli.Do(req) 139 + if err != nil { 140 + return nil, err 141 + } 142 + defer resp.Body.Close() 143 + 144 + if resp.StatusCode != http.StatusOK { 145 + io.Copy(io.Discard, resp.Body) 146 + return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 147 + } 148 + 149 + type Keys struct { 150 + Keys []map[string]any `json:"keys"` 151 + } 152 + 153 + var keys Keys 154 + if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 155 + return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 156 + } 157 + 158 + if len(keys.Keys) == 0 { 159 + return nil, errors.New("no keys in jwks response") 160 + } 161 + 162 + // TODO: this is again bad, we should be figuring out which one we need to use... 163 + b, err := json.Marshal(keys.Keys[0]) 164 + if err != nil { 165 + return nil, fmt.Errorf("could not marshal key: %w", err) 166 + } 167 + 168 + k, err := helpers.ParseJWKFromBytes(b) 169 + if err != nil { 170 + return nil, err 171 + } 172 + 173 + jwks = k 174 + } 175 + 176 + return jwks, nil 177 + } 178 + 179 + func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) { 180 + var metadataMap map[string]any 181 + if err := json.Unmarshal(b, &metadataMap); err != nil { 182 + return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 183 + } 184 + 185 + _, jwksOk := metadataMap["jwks"].(string) 186 + _, jwksUriOk := metadataMap["jwks_uri"].(string) 187 + if jwksOk && jwksUriOk { 188 + return nil, errors.New("jwks_uri and jwks are mutually exclusive") 189 + } 190 + 191 + for _, k := range []string{ 192 + "default_max_age", 193 + "userinfo_signed_response_alg", 194 + "id_token_signed_response_alg", 195 + "userinfo_encryhpted_response_alg", 196 + "authorization_encrypted_response_enc", 197 + "authorization_encrypted_response_alg", 198 + "tls_client_certificate_bound_access_tokens", 199 + } { 200 + _, kOk := metadataMap[k] 201 + if kOk { 202 + return nil, fmt.Errorf("unsupported `%s` parameter", k) 203 + } 204 + } 205 + 206 + var metadata Metadata 207 + if err := json.Unmarshal(b, &metadata); err != nil { 208 + return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 209 + } 210 + 211 + if metadata.ClientURI == "" { 212 + u, err := url.Parse(metadata.ClientID) 213 + if err != nil { 214 + return nil, fmt.Errorf("unable to parse client id: %w", err) 215 + } 216 + u.RawPath = "" 217 + u.RawQuery = "" 218 + metadata.ClientURI = u.String() 219 + } 220 + 221 + u, err := url.Parse(metadata.ClientURI) 222 + if err != nil { 223 + return nil, fmt.Errorf("unable to parse client uri: %w", err) 224 + } 225 + 226 + if metadata.ClientName == "" { 227 + metadata.ClientName = metadata.ClientURI 228 + } 229 + 230 + if isLocalHostname(u.Hostname()) { 231 + return nil, fmt.Errorf("`client_uri` hostname is invalid: %s", u.Hostname()) 232 + } 233 + 234 + if metadata.Scope == "" { 235 + return nil, errors.New("missing `scopes` scope") 236 + } 237 + 238 + scopes := strings.Split(metadata.Scope, " ") 239 + if !slices.Contains(scopes, "atproto") { 240 + return nil, errors.New("missing `atproto` scope") 241 + } 242 + 243 + scopesMap := map[string]bool{} 244 + for _, scope := range scopes { 245 + if scopesMap[scope] { 246 + return nil, fmt.Errorf("duplicate scope `%s`", scope) 247 + } 248 + 249 + // TODO: check for unsupported scopes 250 + 251 + scopesMap[scope] = true 252 + } 253 + 254 + grantTypesMap := map[string]bool{} 255 + for _, gt := range metadata.GrantTypes { 256 + if grantTypesMap[gt] { 257 + return nil, fmt.Errorf("duplicate grant type `%s`", gt) 258 + } 259 + 260 + switch gt { 261 + case "implicit": 262 + return nil, errors.New("grantg type `implicit` is not allowed") 263 + case "authorization_code", "refresh_token": 264 + // TODO check if this grant type is supported 265 + default: 266 + return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 267 + } 268 + 269 + grantTypesMap[gt] = true 270 + } 271 + 272 + if metadata.ClientID != clientId { 273 + return nil, errors.New("`client_id` does not match") 274 + } 275 + 276 + subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 277 + if subjectTypeOk && subjectType != "public" { 278 + return nil, errors.New("only public `subject_type` is supported") 279 + } 280 + 281 + switch metadata.TokenEndpointAuthMethod { 282 + case "none": 283 + if metadata.TokenEndpointAuthSigningAlg != "" { 284 + return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 285 + } 286 + case "private_key_jwt": 287 + if metadata.JWKS == nil && metadata.JWKSURI == nil { 288 + return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 289 + } 290 + 291 + if metadata.JWKS != nil && len(metadata.JWKS.Keys) == 0 { 292 + return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 293 + } 294 + 295 + if metadata.TokenEndpointAuthSigningAlg == "" { 296 + return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 297 + } 298 + default: 299 + return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 300 + } 301 + 302 + if !metadata.DpopBoundAccessTokens { 303 + return nil, errors.New("dpop_bound_access_tokens must be true") 304 + } 305 + 306 + if !slices.Contains(metadata.ResponseTypes, "code") { 307 + return nil, errors.New("response_types must inclue `code`") 308 + } 309 + 310 + if !slices.Contains(metadata.GrantTypes, "authorization_code") { 311 + return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 312 + } 313 + 314 + if len(metadata.RedirectURIs) == 0 { 315 + return nil, errors.New("at least one `redirect_uri` is required") 316 + } 317 + 318 + if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" { 319 + return nil, errors.New("native clients must authenticate using `none` method") 320 + } 321 + 322 + if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 323 + for _, ruri := range metadata.RedirectURIs { 324 + u, err := url.Parse(ruri) 325 + if err != nil { 326 + return nil, fmt.Errorf("error parsing redirect uri: %w", err) 327 + } 328 + 329 + if u.Scheme != "https" { 330 + return nil, errors.New("web clients must use https redirect uris") 331 + } 332 + 333 + if u.Hostname() == "localhost" { 334 + return nil, errors.New("web clients must not use localhost as the hostname") 335 + } 336 + } 337 + } 338 + 339 + for _, ruri := range metadata.RedirectURIs { 340 + u, err := url.Parse(ruri) 341 + if err != nil { 342 + return nil, fmt.Errorf("error parsing redirect uri: %w", err) 343 + } 344 + 345 + if u.User != nil { 346 + if u.User.Username() != "" { 347 + return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 348 + } 349 + 350 + if _, hasPass := u.User.Password(); hasPass { 351 + return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 352 + } 353 + } 354 + 355 + switch true { 356 + case u.Hostname() == "localhost": 357 + return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 358 + case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 359 + if metadata.ApplicationType != "native" { 360 + return nil, errors.New("loopback redirect uris are only allowed for native apps") 361 + } 362 + 363 + if u.Port() != "" { 364 + // reference impl doesn't do anything with this? 365 + } 366 + 367 + if u.Scheme != "http" { 368 + return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 369 + } 370 + case u.Scheme == "http": 371 + return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 372 + case u.Scheme == "https": 373 + if isLocalHostname(u.Hostname()) { 374 + return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 375 + } 376 + case strings.Contains(u.Scheme, "."): 377 + if metadata.ApplicationType != "native" { 378 + return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 379 + } 380 + 381 + revdomain := reverseDomain(u.Scheme) 382 + 383 + if isLocalHostname(revdomain) { 384 + return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 385 + } 386 + 387 + if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 388 + return nil, fmt.Errorf("private use uri scheme must be in the form ") 389 + } 390 + default: 391 + return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 392 + } 393 + } 394 + 395 + return &metadata, nil 396 + } 397 + 398 + func isLocalHostname(hostname string) bool { 399 + pts := strings.Split(hostname, ".") 400 + if len(pts) < 2 { 401 + return true 402 + } 403 + 404 + tld := strings.ToLower(pts[len(pts)-1]) 405 + return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 406 + } 407 + 408 + func reverseDomain(domain string) string { 409 + pts := strings.Split(domain, ".") 410 + slices.Reverse(pts) 411 + return strings.Join(pts, ".") 412 + }
+24
oauth/client/metadata.go
··· 1 + package client 2 + 3 + type Metadata struct { 4 + ClientID string `json:"client_id"` 5 + ClientName string `json:"client_name"` 6 + ClientURI string `json:"client_uri"` 7 + LogoURI string `json:"logo_uri"` 8 + TOSURI string `json:"tos_uri"` 9 + PolicyURI string `json:"policy_uri"` 10 + RedirectURIs []string `json:"redirect_uris"` 11 + GrantTypes []string `json:"grant_types"` 12 + ResponseTypes []string `json:"response_types"` 13 + ApplicationType string `json:"application_type"` 14 + DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 15 + JWKSURI *string `json:"jwks_uri,omitempty"` 16 + JWKS *MetadataJwks `json:"jwks,omitempty"` 17 + Scope string `json:"scope"` 18 + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 19 + TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 20 + } 21 + 22 + type MetadataJwks struct { 23 + Keys []any `json:"keys"` 24 + }
-8
oauth/client.go
··· 1 - package oauth 2 - 3 - import "github.com/lestrrat-go/jwx/v2/jwk" 4 - 5 - type Client struct { 6 - Metadata *ClientMetadata 7 - JWKS jwk.Key 8 - }
-390
oauth/client_manager/client_manager.go
··· 1 - package client_manager 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "errors" 7 - "fmt" 8 - "io" 9 - "log/slog" 10 - "net/http" 11 - "net/url" 12 - "slices" 13 - "strings" 14 - "time" 15 - 16 - cache "github.com/go-pkgz/expirable-cache/v3" 17 - "github.com/haileyok/cocoon/internal/helpers" 18 - "github.com/haileyok/cocoon/oauth" 19 - "github.com/lestrrat-go/jwx/v2/jwk" 20 - ) 21 - 22 - type ClientManager struct { 23 - cli *http.Client 24 - logger *slog.Logger 25 - jwksCache cache.Cache[string, jwk.Key] 26 - metadataCache cache.Cache[string, oauth.ClientMetadata] 27 - } 28 - 29 - type Args struct { 30 - Cli *http.Client 31 - Logger *slog.Logger 32 - } 33 - 34 - func New(args Args) *ClientManager { 35 - if args.Logger == nil { 36 - args.Logger = slog.Default() 37 - } 38 - 39 - if args.Cli == nil { 40 - args.Cli = http.DefaultClient 41 - } 42 - 43 - jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 - metadataCache := cache.NewCache[string, oauth.ClientMetadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 45 - 46 - return &ClientManager{ 47 - cli: args.Cli, 48 - logger: args.Logger, 49 - jwksCache: jwksCache, 50 - metadataCache: metadataCache, 51 - } 52 - } 53 - 54 - func (cm *ClientManager) GetClient(ctx context.Context, clientId string) (*oauth.Client, error) { 55 - metadata, err := cm.getClientMetadata(ctx, clientId) 56 - if err != nil { 57 - return nil, err 58 - } 59 - 60 - var jwks jwk.Key 61 - if metadata.JWKS != nil { 62 - // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to 63 - // make sure we use the right one 64 - k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0]) 65 - if err != nil { 66 - return nil, err 67 - } 68 - jwks = k 69 - } else if metadata.JWKSURI != nil { 70 - maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 71 - if err != nil { 72 - return nil, err 73 - } 74 - 75 - jwks = maybeJwks 76 - } 77 - 78 - return &oauth.Client{ 79 - Metadata: metadata, 80 - JWKS: jwks, 81 - }, nil 82 - } 83 - 84 - func (cm *ClientManager) getClientMetadata(ctx context.Context, clientId string) (*oauth.ClientMetadata, error) { 85 - metadataCached, ok := cm.metadataCache.Get(clientId) 86 - if !ok { 87 - req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 88 - if err != nil { 89 - return nil, err 90 - } 91 - 92 - resp, err := cm.cli.Do(req) 93 - if err != nil { 94 - return nil, err 95 - } 96 - defer resp.Body.Close() 97 - 98 - if resp.StatusCode != http.StatusOK { 99 - io.Copy(io.Discard, resp.Body) 100 - return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 101 - } 102 - 103 - b, err := io.ReadAll(resp.Body) 104 - if err != nil { 105 - return nil, fmt.Errorf("error reading bytes from client response: %w", err) 106 - } 107 - 108 - validated, err := validateAndParseMetadata(clientId, b) 109 - if err != nil { 110 - return nil, err 111 - } 112 - 113 - return validated, nil 114 - } else { 115 - return &metadataCached, nil 116 - } 117 - } 118 - 119 - func (cm *ClientManager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 120 - jwks, ok := cm.jwksCache.Get(clientId) 121 - if !ok { 122 - req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 123 - if err != nil { 124 - return nil, err 125 - } 126 - 127 - resp, err := cm.cli.Do(req) 128 - if err != nil { 129 - return nil, err 130 - } 131 - defer resp.Body.Close() 132 - 133 - if resp.StatusCode != http.StatusOK { 134 - io.Copy(io.Discard, resp.Body) 135 - return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 136 - } 137 - 138 - type Keys struct { 139 - Keys []map[string]any `json:"keys"` 140 - } 141 - 142 - var keys Keys 143 - if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 144 - return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 145 - } 146 - 147 - if len(keys.Keys) == 0 { 148 - return nil, errors.New("no keys in jwks response") 149 - } 150 - 151 - // TODO: this is again bad, we should be figuring out which one we need to use... 152 - b, err := json.Marshal(keys.Keys[0]) 153 - if err != nil { 154 - return nil, fmt.Errorf("could not marshal key: %w", err) 155 - } 156 - 157 - k, err := helpers.ParseJWKFromBytes(b) 158 - if err != nil { 159 - return nil, err 160 - } 161 - 162 - jwks = k 163 - } 164 - 165 - return jwks, nil 166 - } 167 - 168 - func validateAndParseMetadata(clientId string, b []byte) (*oauth.ClientMetadata, error) { 169 - var metadataMap map[string]any 170 - if err := json.Unmarshal(b, &metadataMap); err != nil { 171 - return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 172 - } 173 - 174 - _, jwksOk := metadataMap["jwks"].(string) 175 - _, jwksUriOk := metadataMap["jwks_uri"].(string) 176 - if jwksOk && jwksUriOk { 177 - return nil, errors.New("jwks_uri and jwks are mutually exclusive") 178 - } 179 - 180 - for _, k := range []string{ 181 - "default_max_age", 182 - "userinfo_signed_response_alg", 183 - "id_token_signed_response_alg", 184 - "userinfo_encryhpted_response_alg", 185 - "authorization_encrypted_response_enc", 186 - "authorization_encrypted_response_alg", 187 - "tls_client_certificate_bound_access_tokens", 188 - } { 189 - _, kOk := metadataMap[k] 190 - if kOk { 191 - return nil, fmt.Errorf("unsupported `%s` parameter", k) 192 - } 193 - } 194 - 195 - var metadata oauth.ClientMetadata 196 - if err := json.Unmarshal(b, &metadata); err != nil { 197 - return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 198 - } 199 - 200 - u, err := url.Parse(metadata.ClientURI) 201 - if err != nil { 202 - return nil, fmt.Errorf("unable to parse client uri: %w", err) 203 - } 204 - 205 - if isLocalHostname(u.Hostname()) { 206 - return nil, errors.New("`client_uri` hostname is invalid") 207 - } 208 - 209 - if metadata.Scope == "" { 210 - return nil, errors.New("missing `scopes` scope") 211 - } 212 - 213 - scopes := strings.Split(metadata.Scope, " ") 214 - if !slices.Contains(scopes, "atproto") { 215 - return nil, errors.New("missing `atproto` scope") 216 - } 217 - 218 - scopesMap := map[string]bool{} 219 - for _, scope := range scopes { 220 - if scopesMap[scope] { 221 - return nil, fmt.Errorf("duplicate scope `%s`", scope) 222 - } 223 - 224 - // TODO: check for unsupported scopes 225 - 226 - scopesMap[scope] = true 227 - } 228 - 229 - grantTypesMap := map[string]bool{} 230 - for _, gt := range metadata.GrantTypes { 231 - if grantTypesMap[gt] { 232 - return nil, fmt.Errorf("duplicate grant type `%s`", gt) 233 - } 234 - 235 - switch gt { 236 - case "implicit": 237 - return nil, errors.New("grantg type `implicit` is not allowed") 238 - case "authorization_code", "refresh_token": 239 - // TODO check if this grant type is supported 240 - default: 241 - return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 242 - } 243 - 244 - grantTypesMap[gt] = true 245 - } 246 - 247 - if metadata.ClientID != clientId { 248 - return nil, errors.New("`client_id` does not match") 249 - } 250 - 251 - subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 252 - if subjectTypeOk && subjectType != "public" { 253 - return nil, errors.New("only public `subject_type` is supported") 254 - } 255 - 256 - switch metadata.TokenEndpointAuthMethod { 257 - case "none": 258 - if metadata.TokenEndpointAuthSigningAlg != "" { 259 - return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 260 - } 261 - case "private_key_jwt": 262 - if metadata.JWKS == nil && metadata.JWKSURI == nil { 263 - return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 264 - } 265 - 266 - if metadata.JWKS != nil && len(*metadata.JWKS) == 0 { 267 - return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 268 - } 269 - 270 - if metadata.TokenEndpointAuthSigningAlg == "" { 271 - return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 272 - } 273 - default: 274 - return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 275 - } 276 - 277 - if !metadata.DpopBoundAccessTokens { 278 - return nil, errors.New("dpop_bound_access_tokens must be true") 279 - } 280 - 281 - if !slices.Contains(metadata.ResponseTypes, "code") { 282 - return nil, errors.New("response_types must inclue `code`") 283 - } 284 - 285 - if !slices.Contains(metadata.GrantTypes, "authorization_code") { 286 - return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 287 - } 288 - 289 - if len(metadata.RedirectURIs) == 0 { 290 - return nil, errors.New("at least one `redirect_uri` is required") 291 - } 292 - 293 - if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod == "none" { 294 - return nil, errors.New("native clients must authenticate using `none` method") 295 - } 296 - 297 - if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 298 - for _, ruri := range metadata.RedirectURIs { 299 - u, err := url.Parse(ruri) 300 - if err != nil { 301 - return nil, fmt.Errorf("error parsing redirect uri: %w", err) 302 - } 303 - 304 - if u.Scheme != "https" { 305 - return nil, errors.New("web clients must use https redirect uris") 306 - } 307 - 308 - if u.Hostname() == "localhost" { 309 - return nil, errors.New("web clients must not use localhost as the hostname") 310 - } 311 - } 312 - } 313 - 314 - for _, ruri := range metadata.RedirectURIs { 315 - u, err := url.Parse(ruri) 316 - if err != nil { 317 - return nil, fmt.Errorf("error parsing redirect uri: %w", err) 318 - } 319 - 320 - if u.User != nil { 321 - if u.User.Username() != "" { 322 - return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 323 - } 324 - 325 - if _, hasPass := u.User.Password(); hasPass { 326 - return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 327 - } 328 - } 329 - 330 - switch true { 331 - case u.Hostname() == "localhost": 332 - return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 333 - case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 334 - if metadata.ApplicationType != "native" { 335 - return nil, errors.New("loopback redirect uris are only allowed for native apps") 336 - } 337 - 338 - if u.Port() != "" { 339 - // reference impl doesn't do anything with this? 340 - } 341 - 342 - if u.Scheme != "http" { 343 - return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 344 - } 345 - 346 - break 347 - case u.Scheme == "http": 348 - return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 349 - case u.Scheme == "https": 350 - if isLocalHostname(u.Hostname()) { 351 - return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 352 - } 353 - break 354 - case strings.Contains(u.Scheme, "."): 355 - if metadata.ApplicationType != "native" { 356 - return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 357 - } 358 - 359 - revdomain := reverseDomain(u.Scheme) 360 - 361 - if isLocalHostname(revdomain) { 362 - return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 363 - } 364 - 365 - if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 366 - return nil, fmt.Errorf("private use uri scheme must be in the form ") 367 - } 368 - default: 369 - return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 370 - } 371 - } 372 - 373 - return &metadata, nil 374 - } 375 - 376 - func isLocalHostname(hostname string) bool { 377 - pts := strings.Split(hostname, ".") 378 - if len(pts) < 2 { 379 - return true 380 - } 381 - 382 - tld := strings.ToLower(pts[len(pts)-1]) 383 - return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 384 - } 385 - 386 - func reverseDomain(domain string) string { 387 - pts := strings.Split(domain, ".") 388 - slices.Reverse(pts) 389 - return strings.Join(pts, ".") 390 - }
-20
oauth/client_metadata.go
··· 1 - package oauth 2 - 3 - type ClientMetadata struct { 4 - ClientID string `json:"client_id"` 5 - ClientName string `json:"client_name"` 6 - ClientURI string `json:"client_uri"` 7 - LogoURI string `json:"logo_uri"` 8 - TOSURI string `json:"tos_uri"` 9 - PolicyURI string `json:"policy_uri"` 10 - RedirectURIs []string `json:"redirect_uris"` 11 - GrantTypes []string `json:"grant_types"` 12 - ResponseTypes []string `json:"response_types"` 13 - ApplicationType string `json:"application_type"` 14 - DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 15 - JWKSURI *string `json:"jwks_uri,omitempty"` 16 - JWKS *[][]byte `json:"jwks,omitempty"` 17 - Scope string `json:"scope"` 18 - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 19 - TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 20 - }
-251
oauth/dpop/dpop_manager/dpop_manager.go
··· 1 - package dpop_manager 2 - 3 - import ( 4 - "crypto" 5 - "crypto/sha256" 6 - "encoding/base64" 7 - "encoding/json" 8 - "errors" 9 - "fmt" 10 - "log/slog" 11 - "net/http" 12 - "net/url" 13 - "strings" 14 - "time" 15 - 16 - "github.com/golang-jwt/jwt/v4" 17 - "github.com/haileyok/cocoon/internal/helpers" 18 - "github.com/haileyok/cocoon/oauth/constants" 19 - "github.com/haileyok/cocoon/oauth/dpop" 20 - "github.com/haileyok/cocoon/oauth/dpop/nonce" 21 - "github.com/lestrrat-go/jwx/v2/jwa" 22 - "github.com/lestrrat-go/jwx/v2/jwk" 23 - ) 24 - 25 - type DpopManager struct { 26 - nonce *nonce.Nonce 27 - jtiCache *jtiCache 28 - logger *slog.Logger 29 - hostname string 30 - } 31 - 32 - type Args struct { 33 - NonceSecret []byte 34 - NonceRotationInterval time.Duration 35 - OnNonceSecretCreated func([]byte) 36 - JTICacheSize int 37 - Logger *slog.Logger 38 - Hostname string 39 - } 40 - 41 - func New(args Args) *DpopManager { 42 - if args.Logger == nil { 43 - args.Logger = slog.Default() 44 - } 45 - 46 - if args.JTICacheSize == 0 { 47 - args.JTICacheSize = 100_000 48 - } 49 - 50 - if args.NonceSecret == nil { 51 - args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 52 - } 53 - 54 - return &DpopManager{ 55 - nonce: nonce.NewNonce(nonce.Args{ 56 - RotationInterval: args.NonceRotationInterval, 57 - Secret: args.NonceSecret, 58 - OnSecretCreated: args.OnNonceSecretCreated, 59 - }), 60 - jtiCache: newJTICache(args.JTICacheSize), 61 - logger: args.Logger, 62 - hostname: args.Hostname, 63 - } 64 - } 65 - 66 - func (dm *DpopManager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*dpop.Proof, error) { 67 - if reqMethod == "" { 68 - return nil, errors.New("HTTP method is required") 69 - } 70 - 71 - if !strings.HasPrefix(reqUrl, "https://") { 72 - reqUrl = "https://" + dm.hostname + reqUrl 73 - } 74 - 75 - proof := extractProof(headers) 76 - 77 - if proof == "" { 78 - return nil, nil 79 - } 80 - 81 - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 82 - var token *jwt.Token 83 - 84 - token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{}) 85 - if err != nil { 86 - return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err) 87 - } 88 - 89 - typ, _ := token.Header["typ"].(string) 90 - if typ != "dpop+jwt" { 91 - return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`) 92 - } 93 - 94 - dpopJwk, jwkOk := token.Header["jwk"].(map[string]any) 95 - if !jwkOk { 96 - return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`) 97 - } 98 - 99 - jwkb, err := json.Marshal(dpopJwk) 100 - if err != nil { 101 - return nil, fmt.Errorf("failed to marshal jwk: %w", err) 102 - } 103 - 104 - key, err := jwk.ParseKey(jwkb) 105 - if err != nil { 106 - return nil, fmt.Errorf("failed to parse jwk: %w", err) 107 - } 108 - 109 - var pubKey any 110 - if err := key.Raw(&pubKey); err != nil { 111 - return nil, fmt.Errorf("failed to get raw public key: %w", err) 112 - } 113 - 114 - token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) { 115 - alg := t.Header["alg"].(string) 116 - 117 - switch key.KeyType() { 118 - case jwa.EC: 119 - if !strings.HasPrefix(alg, "ES") { 120 - return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg) 121 - } 122 - case jwa.RSA: 123 - if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") { 124 - return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg) 125 - } 126 - case jwa.OKP: 127 - if alg != "EdDSA" { 128 - return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg) 129 - } 130 - } 131 - 132 - return pubKey, nil 133 - }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"})) 134 - if err != nil { 135 - return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err) 136 - } 137 - 138 - if !token.Valid { 139 - return nil, errors.New("dpop proof jwt is invalid") 140 - } 141 - 142 - claims, ok := token.Claims.(jwt.MapClaims) 143 - if !ok { 144 - return nil, errors.New("no claims in dpop proof jwt") 145 - } 146 - 147 - iat, iatOk := claims["iat"].(float64) 148 - if !iatOk { 149 - return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`) 150 - } 151 - 152 - iatTime := time.Unix(int64(iat), 0) 153 - now := time.Now() 154 - 155 - if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance { 156 - return nil, errors.New("dpop proof too old") 157 - } 158 - 159 - if iatTime.Sub(now) > constants.DpopCheckTolerance { 160 - return nil, errors.New("dpop proof iat is in the future") 161 - } 162 - 163 - jti, _ := claims["jti"].(string) 164 - if jti == "" { 165 - return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`) 166 - } 167 - 168 - if dm.jtiCache.add(jti) { 169 - return nil, errors.New("dpop proof replay detected") 170 - } 171 - 172 - htm, _ := claims["htm"].(string) 173 - if htm == "" { 174 - return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`) 175 - } 176 - 177 - if htm != reqMethod { 178 - return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`) 179 - } 180 - 181 - htu, _ := claims["htu"].(string) 182 - if htu == "" { 183 - return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`) 184 - } 185 - 186 - parsedHtu, err := helpers.OauthParseHtu(htu) 187 - if err != nil { 188 - return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`) 189 - } 190 - 191 - u, _ := url.Parse(reqUrl) 192 - if parsedHtu != helpers.OauthNormalizeHtu(u) { 193 - return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u)) 194 - } 195 - 196 - nonce, _ := claims["nonce"].(string) 197 - if nonce == "" { 198 - // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request 199 - return nil, errors.New("use_dpop_nonce") 200 - } 201 - 202 - if nonce != "" && !dm.nonce.Check(nonce) { 203 - // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce 204 - return nil, errors.New("use_dpop_nonce") 205 - } 206 - 207 - ath, _ := claims["ath"].(string) 208 - 209 - if accessToken != nil && *accessToken != "" { 210 - if ath == "" { 211 - return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`) 212 - } 213 - 214 - hash := sha256.Sum256([]byte(*accessToken)) 215 - if ath != base64.RawURLEncoding.EncodeToString(hash[:]) { 216 - return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`) 217 - } 218 - } else if ath != "" { 219 - return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`) 220 - } 221 - 222 - thumbBytes, err := key.Thumbprint(crypto.SHA256) 223 - if err != nil { 224 - return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 225 - } 226 - 227 - thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 228 - 229 - return &dpop.Proof{ 230 - JTI: jti, 231 - JKT: thumb, 232 - HTM: htm, 233 - HTU: htu, 234 - }, nil 235 - } 236 - 237 - func extractProof(headers http.Header) string { 238 - dpopHeaders := headers["Dpop"] 239 - switch len(dpopHeaders) { 240 - case 0: 241 - return "" 242 - case 1: 243 - return dpopHeaders[0] 244 - default: 245 - return "" 246 - } 247 - } 248 - 249 - func (dm *DpopManager) NextNonce() string { 250 - return dm.nonce.NextNonce() 251 - }
-28
oauth/dpop/dpop_manager/jti_cache.go
··· 1 - package dpop_manager 2 - 3 - import ( 4 - "sync" 5 - "time" 6 - 7 - cache "github.com/go-pkgz/expirable-cache/v3" 8 - "github.com/haileyok/cocoon/oauth/constants" 9 - ) 10 - 11 - type jtiCache struct { 12 - mu sync.Mutex 13 - cache cache.Cache[string, bool] 14 - } 15 - 16 - func newJTICache(size int) *jtiCache { 17 - cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl) 18 - return &jtiCache{ 19 - cache: cache, 20 - mu: sync.Mutex{}, 21 - } 22 - } 23 - 24 - func (c *jtiCache) add(jti string) bool { 25 - c.mu.Lock() 26 - defer c.mu.Unlock() 27 - return c.cache.Add(jti, true) 28 - }
+28
oauth/dpop/jti_cache.go
··· 1 + package dpop 2 + 3 + import ( 4 + "sync" 5 + "time" 6 + 7 + cache "github.com/go-pkgz/expirable-cache/v3" 8 + "github.com/haileyok/cocoon/oauth/constants" 9 + ) 10 + 11 + type jtiCache struct { 12 + mu sync.Mutex 13 + cache cache.Cache[string, bool] 14 + } 15 + 16 + func newJTICache(size int) *jtiCache { 17 + cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl).WithMaxKeys(size) 18 + return &jtiCache{ 19 + cache: cache, 20 + mu: sync.Mutex{}, 21 + } 22 + } 23 + 24 + func (c *jtiCache) add(jti string) bool { 25 + c.mu.Lock() 26 + defer c.mu.Unlock() 27 + return c.cache.Add(jti, true) 28 + }
+253
oauth/dpop/manager.go
··· 1 + package dpop 2 + 3 + import ( 4 + "crypto" 5 + "crypto/sha256" 6 + "encoding/base64" 7 + "encoding/json" 8 + "errors" 9 + "fmt" 10 + "log/slog" 11 + "net/http" 12 + "net/url" 13 + "strings" 14 + "time" 15 + 16 + "github.com/golang-jwt/jwt/v4" 17 + "github.com/haileyok/cocoon/internal/helpers" 18 + "github.com/haileyok/cocoon/oauth/constants" 19 + "github.com/lestrrat-go/jwx/v2/jwa" 20 + "github.com/lestrrat-go/jwx/v2/jwk" 21 + ) 22 + 23 + type Manager struct { 24 + nonce *Nonce 25 + jtiCache *jtiCache 26 + logger *slog.Logger 27 + hostname string 28 + } 29 + 30 + type ManagerArgs struct { 31 + NonceSecret []byte 32 + NonceRotationInterval time.Duration 33 + OnNonceSecretCreated func([]byte) 34 + JTICacheSize int 35 + Logger *slog.Logger 36 + Hostname string 37 + } 38 + 39 + var ( 40 + ErrUseDpopNonce = errors.New("use_dpop_nonce") 41 + ) 42 + 43 + func NewManager(args ManagerArgs) *Manager { 44 + if args.Logger == nil { 45 + args.Logger = slog.Default() 46 + } 47 + 48 + if args.JTICacheSize == 0 { 49 + args.JTICacheSize = 100_000 50 + } 51 + 52 + if args.NonceSecret == nil { 53 + args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 54 + } 55 + 56 + return &Manager{ 57 + nonce: NewNonce(NonceArgs{ 58 + RotationInterval: args.NonceRotationInterval, 59 + Secret: args.NonceSecret, 60 + OnSecretCreated: args.OnNonceSecretCreated, 61 + }), 62 + jtiCache: newJTICache(args.JTICacheSize), 63 + logger: args.Logger, 64 + hostname: args.Hostname, 65 + } 66 + } 67 + 68 + func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) { 69 + if reqMethod == "" { 70 + return nil, errors.New("HTTP method is required") 71 + } 72 + 73 + if !strings.HasPrefix(reqUrl, "https://") { 74 + reqUrl = "https://" + dm.hostname + reqUrl 75 + } 76 + 77 + proof := extractProof(headers) 78 + 79 + if proof == "" { 80 + return nil, nil 81 + } 82 + 83 + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 84 + var token *jwt.Token 85 + 86 + token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{}) 87 + if err != nil { 88 + return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err) 89 + } 90 + 91 + typ, _ := token.Header["typ"].(string) 92 + if typ != "dpop+jwt" { 93 + return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`) 94 + } 95 + 96 + dpopJwk, jwkOk := token.Header["jwk"].(map[string]any) 97 + if !jwkOk { 98 + return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`) 99 + } 100 + 101 + jwkb, err := json.Marshal(dpopJwk) 102 + if err != nil { 103 + return nil, fmt.Errorf("failed to marshal jwk: %w", err) 104 + } 105 + 106 + key, err := jwk.ParseKey(jwkb) 107 + if err != nil { 108 + return nil, fmt.Errorf("failed to parse jwk: %w", err) 109 + } 110 + 111 + var pubKey any 112 + if err := key.Raw(&pubKey); err != nil { 113 + return nil, fmt.Errorf("failed to get raw public key: %w", err) 114 + } 115 + 116 + token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) { 117 + alg := t.Header["alg"].(string) 118 + 119 + switch key.KeyType() { 120 + case jwa.EC: 121 + if !strings.HasPrefix(alg, "ES") { 122 + return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg) 123 + } 124 + case jwa.RSA: 125 + if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") { 126 + return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg) 127 + } 128 + case jwa.OKP: 129 + if alg != "EdDSA" { 130 + return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg) 131 + } 132 + } 133 + 134 + return pubKey, nil 135 + }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"})) 136 + if err != nil { 137 + return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err) 138 + } 139 + 140 + if !token.Valid { 141 + return nil, errors.New("dpop proof jwt is invalid") 142 + } 143 + 144 + claims, ok := token.Claims.(jwt.MapClaims) 145 + if !ok { 146 + return nil, errors.New("no claims in dpop proof jwt") 147 + } 148 + 149 + iat, iatOk := claims["iat"].(float64) 150 + if !iatOk { 151 + return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`) 152 + } 153 + 154 + iatTime := time.Unix(int64(iat), 0) 155 + now := time.Now() 156 + 157 + if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance { 158 + return nil, errors.New("dpop proof too old") 159 + } 160 + 161 + if iatTime.Sub(now) > constants.DpopCheckTolerance { 162 + return nil, errors.New("dpop proof iat is in the future") 163 + } 164 + 165 + jti, _ := claims["jti"].(string) 166 + if jti == "" { 167 + return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`) 168 + } 169 + 170 + if dm.jtiCache.add(jti) { 171 + return nil, errors.New("dpop proof replay detected") 172 + } 173 + 174 + htm, _ := claims["htm"].(string) 175 + if htm == "" { 176 + return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`) 177 + } 178 + 179 + if htm != reqMethod { 180 + return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`) 181 + } 182 + 183 + htu, _ := claims["htu"].(string) 184 + if htu == "" { 185 + return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`) 186 + } 187 + 188 + parsedHtu, err := helpers.OauthParseHtu(htu) 189 + if err != nil { 190 + return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`) 191 + } 192 + 193 + u, _ := url.Parse(reqUrl) 194 + if parsedHtu != helpers.OauthNormalizeHtu(u) { 195 + return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u)) 196 + } 197 + 198 + nonce, _ := claims["nonce"].(string) 199 + if nonce == "" { 200 + // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request 201 + return nil, ErrUseDpopNonce 202 + } 203 + 204 + if nonce != "" && !dm.nonce.Check(nonce) { 205 + // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce 206 + return nil, ErrUseDpopNonce 207 + } 208 + 209 + ath, _ := claims["ath"].(string) 210 + 211 + if accessToken != nil && *accessToken != "" { 212 + if ath == "" { 213 + return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`) 214 + } 215 + 216 + hash := sha256.Sum256([]byte(*accessToken)) 217 + if ath != base64.RawURLEncoding.EncodeToString(hash[:]) { 218 + return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`) 219 + } 220 + } else if ath != "" { 221 + return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`) 222 + } 223 + 224 + thumbBytes, err := key.Thumbprint(crypto.SHA256) 225 + if err != nil { 226 + return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 227 + } 228 + 229 + thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 230 + 231 + return &Proof{ 232 + JTI: jti, 233 + JKT: thumb, 234 + HTM: htm, 235 + HTU: htu, 236 + }, nil 237 + } 238 + 239 + func extractProof(headers http.Header) string { 240 + dpopHeaders := headers["Dpop"] 241 + switch len(dpopHeaders) { 242 + case 0: 243 + return "" 244 + case 1: 245 + return dpopHeaders[0] 246 + default: 247 + return "" 248 + } 249 + } 250 + 251 + func (dm *Manager) NextNonce() string { 252 + return dm.nonce.NextNonce() 253 + }
-108
oauth/dpop/nonce/nonce.go
··· 1 - package nonce 2 - 3 - import ( 4 - "crypto/hmac" 5 - "crypto/sha256" 6 - "encoding/base64" 7 - "encoding/binary" 8 - "sync" 9 - "time" 10 - 11 - "github.com/haileyok/cocoon/internal/helpers" 12 - "github.com/haileyok/cocoon/oauth/constants" 13 - ) 14 - 15 - type Nonce struct { 16 - rotationInterval time.Duration 17 - secret []byte 18 - 19 - mu sync.RWMutex 20 - 21 - counter int64 22 - prev string 23 - curr string 24 - next string 25 - } 26 - 27 - type Args struct { 28 - RotationInterval time.Duration 29 - Secret []byte 30 - OnSecretCreated func([]byte) 31 - } 32 - 33 - func NewNonce(args Args) *Nonce { 34 - if args.RotationInterval == 0 { 35 - args.RotationInterval = constants.NonceMaxRotationInterval / 3 36 - } 37 - 38 - if args.RotationInterval > constants.NonceMaxRotationInterval { 39 - args.RotationInterval = constants.NonceMaxRotationInterval 40 - } 41 - 42 - if args.Secret == nil { 43 - args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength) 44 - args.OnSecretCreated(args.Secret) 45 - } 46 - 47 - n := &Nonce{ 48 - rotationInterval: args.RotationInterval, 49 - secret: args.Secret, 50 - mu: sync.RWMutex{}, 51 - } 52 - 53 - n.counter = n.currentCounter() 54 - n.prev = n.compute(n.counter - 1) 55 - n.curr = n.compute(n.counter) 56 - n.next = n.compute(n.counter + 1) 57 - 58 - return n 59 - } 60 - 61 - func (n *Nonce) currentCounter() int64 { 62 - return time.Now().UnixNano() / int64(n.rotationInterval) 63 - } 64 - 65 - func (n *Nonce) compute(counter int64) string { 66 - h := hmac.New(sha256.New, n.secret) 67 - counterBytes := make([]byte, 8) 68 - binary.BigEndian.PutUint64(counterBytes, uint64(counter)) 69 - h.Write(counterBytes) 70 - return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) 71 - } 72 - 73 - func (n *Nonce) rotate() { 74 - counter := n.currentCounter() 75 - diff := counter - n.counter 76 - 77 - switch diff { 78 - case 0: 79 - // counter == n.counter, do nothing 80 - case 1: 81 - n.prev = n.curr 82 - n.curr = n.next 83 - n.next = n.compute(counter + 1) 84 - case 2: 85 - n.prev = n.next 86 - n.curr = n.compute(counter) 87 - n.next = n.compute(counter + 1) 88 - default: 89 - n.prev = n.compute(counter - 1) 90 - n.curr = n.compute(counter) 91 - n.next = n.compute(counter + 1) 92 - } 93 - 94 - n.counter = counter 95 - } 96 - 97 - func (n *Nonce) NextNonce() string { 98 - n.mu.Lock() 99 - defer n.mu.Unlock() 100 - n.rotate() 101 - return n.next 102 - } 103 - 104 - func (n *Nonce) Check(nonce string) bool { 105 - n.mu.RLock() 106 - defer n.mu.RUnlock() 107 - return nonce == n.prev || nonce == n.curr || nonce == n.next 108 - }
+109
oauth/dpop/nonce.go
··· 1 + package dpop 2 + 3 + import ( 4 + "crypto/hmac" 5 + "crypto/sha256" 6 + "encoding/base64" 7 + "encoding/binary" 8 + "sync" 9 + "time" 10 + 11 + "github.com/haileyok/cocoon/internal/helpers" 12 + "github.com/haileyok/cocoon/oauth/constants" 13 + ) 14 + 15 + type Nonce struct { 16 + rotationInterval time.Duration 17 + secret []byte 18 + 19 + mu sync.RWMutex 20 + 21 + counter int64 22 + prev string 23 + curr string 24 + next string 25 + } 26 + 27 + type NonceArgs struct { 28 + RotationInterval time.Duration 29 + Secret []byte 30 + OnSecretCreated func([]byte) 31 + } 32 + 33 + func NewNonce(args NonceArgs) *Nonce { 34 + if args.RotationInterval == 0 { 35 + args.RotationInterval = constants.NonceMaxRotationInterval / 3 36 + } 37 + 38 + if args.RotationInterval > constants.NonceMaxRotationInterval { 39 + args.RotationInterval = constants.NonceMaxRotationInterval 40 + } 41 + 42 + if args.Secret == nil { 43 + args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength) 44 + args.OnSecretCreated(args.Secret) 45 + } 46 + 47 + n := &Nonce{ 48 + rotationInterval: args.RotationInterval, 49 + secret: args.Secret, 50 + mu: sync.RWMutex{}, 51 + } 52 + 53 + n.counter = n.currentCounter() 54 + n.prev = n.compute(n.counter - 1) 55 + n.curr = n.compute(n.counter) 56 + n.next = n.compute(n.counter + 1) 57 + 58 + return n 59 + } 60 + 61 + func (n *Nonce) currentCounter() int64 { 62 + return time.Now().UnixNano() / int64(n.rotationInterval) 63 + } 64 + 65 + func (n *Nonce) compute(counter int64) string { 66 + h := hmac.New(sha256.New, n.secret) 67 + counterBytes := make([]byte, 8) 68 + binary.BigEndian.PutUint64(counterBytes, uint64(counter)) 69 + h.Write(counterBytes) 70 + return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) 71 + } 72 + 73 + func (n *Nonce) rotate() { 74 + counter := n.currentCounter() 75 + diff := counter - n.counter 76 + 77 + switch diff { 78 + case 0: 79 + // counter == n.counter, do nothing 80 + case 1: 81 + n.prev = n.curr 82 + n.curr = n.next 83 + n.next = n.compute(counter + 1) 84 + case 2: 85 + n.prev = n.next 86 + n.curr = n.compute(counter) 87 + n.next = n.compute(counter + 1) 88 + default: 89 + n.prev = n.compute(counter - 1) 90 + n.curr = n.compute(counter) 91 + n.next = n.compute(counter + 1) 92 + } 93 + 94 + n.counter = counter 95 + } 96 + 97 + func (n *Nonce) NextNonce() string { 98 + n.mu.Lock() 99 + defer n.mu.Unlock() 100 + n.rotate() 101 + return n.next 102 + } 103 + 104 + func (n *Nonce) Check(nonce string) bool { 105 + n.mu.Lock() 106 + defer n.mu.Unlock() 107 + n.rotate() 108 + return nonce == n.prev || nonce == n.curr || nonce == n.next 109 + }
+32
oauth/helpers.go
··· 4 4 "errors" 5 5 "fmt" 6 6 "net/url" 7 + "time" 7 8 8 9 "github.com/haileyok/cocoon/internal/helpers" 9 10 "github.com/haileyok/cocoon/oauth/constants" 11 + "github.com/haileyok/cocoon/oauth/provider" 10 12 ) 11 13 12 14 func GenerateCode() string { ··· 46 48 47 49 return reqId, nil 48 50 } 51 + 52 + type SessionAgeResult struct { 53 + SessionAge time.Duration 54 + RefreshAge time.Duration 55 + SessionExpired bool 56 + RefreshExpired bool 57 + } 58 + 59 + func GetSessionAgeFromToken(t provider.OauthToken) SessionAgeResult { 60 + sessionLifetime := constants.PublicClientSessionLifetime 61 + refreshLifetime := constants.PublicClientRefreshLifetime 62 + if t.ClientAuth.Method != "none" { 63 + sessionLifetime = constants.ConfidentialClientSessionLifetime 64 + refreshLifetime = constants.ConfidentialClientRefreshLifetime 65 + } 66 + 67 + res := SessionAgeResult{} 68 + 69 + res.SessionAge = time.Since(t.CreatedAt) 70 + if res.SessionAge > sessionLifetime { 71 + res.SessionExpired = true 72 + } 73 + 74 + refreshAge := time.Since(t.UpdatedAt) 75 + if refreshAge > refreshLifetime { 76 + res.RefreshExpired = true 77 + } 78 + 79 + return res 80 + }
+3 -26
oauth/provider/client_auth.go
··· 3 3 import ( 4 4 "context" 5 5 "crypto" 6 - "database/sql/driver" 7 6 "encoding/base64" 8 - "encoding/json" 9 7 "errors" 10 8 "fmt" 11 9 "time" 12 10 13 11 "github.com/golang-jwt/jwt/v4" 14 - "github.com/haileyok/cocoon/oauth" 12 + "github.com/haileyok/cocoon/oauth/client" 15 13 "github.com/haileyok/cocoon/oauth/constants" 16 14 "github.com/haileyok/cocoon/oauth/dpop" 17 15 ) 18 16 19 - type ClientAuth struct { 20 - Method string 21 - Alg string 22 - Kid string 23 - Jkt string 24 - Jti string 25 - Exp *float64 26 - } 27 - 28 - func (ca *ClientAuth) Scan(value any) error { 29 - b, ok := value.([]byte) 30 - if !ok { 31 - return fmt.Errorf("failed to unmarshal OauthParRequest value") 32 - } 33 - return json.Unmarshal(b, ca) 34 - } 35 - 36 - func (ca ClientAuth) Value() (driver.Value, error) { 37 - return json.Marshal(ca) 38 - } 39 - 40 17 type AuthenticateClientOptions struct { 41 18 AllowMissingDpopProof bool 42 19 } ··· 47 24 ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"` 48 25 } 49 26 50 - func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*oauth.Client, *ClientAuth, error) { 27 + func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) { 51 28 client, err := p.ClientManager.GetClient(ctx, req.ClientID) 52 29 if err != nil { 53 30 return nil, nil, fmt.Errorf("failed to get client: %w", err) ··· 69 46 return client, clientAuth, nil 70 47 } 71 48 72 - func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *oauth.Client) (*ClientAuth, error) { 49 + func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *client.Client) (*ClientAuth, error) { 73 50 metadata := client.Metadata 74 51 75 52 if metadata.TokenEndpointAuthMethod == "none" {
+83
oauth/provider/models.go
··· 1 + package provider 2 + 3 + import ( 4 + "database/sql/driver" 5 + "encoding/json" 6 + "fmt" 7 + "time" 8 + 9 + "gorm.io/gorm" 10 + ) 11 + 12 + type ClientAuth struct { 13 + Method string 14 + Alg string 15 + Kid string 16 + Jkt string 17 + Jti string 18 + Exp *float64 19 + } 20 + 21 + func (ca *ClientAuth) Scan(value any) error { 22 + b, ok := value.([]byte) 23 + if !ok { 24 + return fmt.Errorf("failed to unmarshal OauthParRequest value") 25 + } 26 + return json.Unmarshal(b, ca) 27 + } 28 + 29 + func (ca ClientAuth) Value() (driver.Value, error) { 30 + return json.Marshal(ca) 31 + } 32 + 33 + type ParRequest struct { 34 + AuthenticateClientRequestBase 35 + ResponseType string `form:"response_type" json:"response_type" validate:"required"` 36 + CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"` 37 + CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"` 38 + State string `form:"state" json:"state" validate:"required"` 39 + RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"` 40 + Scope string `form:"scope" json:"scope" validate:"required"` 41 + LoginHint *string `form:"login_hint" json:"login_hint,omitempty"` 42 + DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"` 43 + } 44 + 45 + func (opr *ParRequest) Scan(value any) error { 46 + b, ok := value.([]byte) 47 + if !ok { 48 + return fmt.Errorf("failed to unmarshal OauthParRequest value") 49 + } 50 + return json.Unmarshal(b, opr) 51 + } 52 + 53 + func (opr ParRequest) Value() (driver.Value, error) { 54 + return json.Marshal(opr) 55 + } 56 + 57 + type OauthToken struct { 58 + gorm.Model 59 + ClientId string `gorm:"index"` 60 + ClientAuth ClientAuth `gorm:"type:json"` 61 + Parameters ParRequest `gorm:"type:json"` 62 + ExpiresAt time.Time `gorm:"index"` 63 + DeviceId string 64 + Sub string `gorm:"index"` 65 + Code string `gorm:"index"` 66 + Token string `gorm:"uniqueIndex"` 67 + RefreshToken string `gorm:"uniqueIndex"` 68 + Ip string 69 + } 70 + 71 + type OauthAuthorizationRequest struct { 72 + gorm.Model 73 + RequestId string `gorm:"primaryKey"` 74 + ClientId string `gorm:"index"` 75 + ClientAuth ClientAuth `gorm:"type:json"` 76 + Parameters ParRequest `gorm:"type:json"` 77 + ExpiresAt time.Time `gorm:"index"` 78 + DeviceId *string 79 + Sub *string 80 + Code *string 81 + Accepted *bool 82 + Ip string 83 + }
+8 -64
oauth/provider/provider.go
··· 1 1 package provider 2 2 3 3 import ( 4 - "database/sql/driver" 5 - "encoding/json" 6 - "fmt" 7 - "time" 8 - 9 - "github.com/haileyok/cocoon/oauth/client_manager" 10 - "github.com/haileyok/cocoon/oauth/dpop/dpop_manager" 11 - "gorm.io/gorm" 4 + "github.com/haileyok/cocoon/oauth/client" 5 + "github.com/haileyok/cocoon/oauth/dpop" 12 6 ) 13 7 14 8 type Provider struct { 15 - ClientManager *client_manager.ClientManager 16 - DpopManager *dpop_manager.DpopManager 9 + ClientManager *client.Manager 10 + DpopManager *dpop.Manager 17 11 18 12 hostname string 19 13 } 20 14 21 15 type Args struct { 22 16 Hostname string 23 - ClientManagerArgs client_manager.Args 24 - DpopManagerArgs dpop_manager.Args 17 + ClientManagerArgs client.ManagerArgs 18 + DpopManagerArgs dpop.ManagerArgs 25 19 } 26 20 27 21 func NewProvider(args Args) *Provider { 28 22 return &Provider{ 29 - ClientManager: client_manager.New(args.ClientManagerArgs), 30 - DpopManager: dpop_manager.New(args.DpopManagerArgs), 23 + ClientManager: client.NewManager(args.ClientManagerArgs), 24 + DpopManager: dpop.NewManager(args.DpopManagerArgs), 31 25 hostname: args.Hostname, 32 26 } 33 27 } ··· 35 29 func (p *Provider) NextNonce() string { 36 30 return p.DpopManager.NextNonce() 37 31 } 38 - 39 - type ParRequest struct { 40 - AuthenticateClientRequestBase 41 - ResponseType string `form:"response_type" json:"response_type" validate:"required"` 42 - CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"` 43 - CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"` 44 - State string `form:"state" json:"state" validate:"required"` 45 - RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"` 46 - Scope string `form:"scope" json:"scope" validate:"required"` 47 - LoginHint *string `form:"login_hint" json:"login_hint,omitempty"` 48 - DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"` 49 - } 50 - 51 - func (opr *ParRequest) Scan(value any) error { 52 - b, ok := value.([]byte) 53 - if !ok { 54 - return fmt.Errorf("failed to unmarshal OauthParRequest value") 55 - } 56 - return json.Unmarshal(b, opr) 57 - } 58 - 59 - func (opr ParRequest) Value() (driver.Value, error) { 60 - return json.Marshal(opr) 61 - } 62 - 63 - type OauthToken struct { 64 - gorm.Model 65 - ClientId string `gorm:"index"` 66 - ClientAuth ClientAuth `gorm:"type:json"` 67 - Parameters ParRequest `gorm:"type:json"` 68 - ExpiresAt time.Time `gorm:"index"` 69 - DeviceId string 70 - Sub string `gorm:"index"` 71 - Code string `gorm:"index"` 72 - Token string `gorm:"uniqueIndex"` 73 - RefreshToken string `gorm:"uniqueIndex"` 74 - } 75 - 76 - type OauthAuthorizationRequest struct { 77 - gorm.Model 78 - RequestId string `gorm:"primaryKey"` 79 - ClientId string `gorm:"index"` 80 - ClientAuth ClientAuth `gorm:"type:json"` 81 - Parameters ParRequest `gorm:"type:json"` 82 - ExpiresAt time.Time `gorm:"index"` 83 - DeviceId *string 84 - Sub *string 85 - Code *string 86 - Accepted *bool 87 - }
+36 -20
plc/client.go
··· 13 13 "net/url" 14 14 "strings" 15 15 16 - "github.com/bluesky-social/indigo/atproto/crypto" 16 + "github.com/bluesky-social/indigo/atproto/atcrypto" 17 17 "github.com/bluesky-social/indigo/util" 18 18 "github.com/haileyok/cocoon/identity" 19 19 ) ··· 22 22 h *http.Client 23 23 service string 24 24 pdsHostname string 25 - rotationKey *crypto.PrivateKeyK256 25 + rotationKey *atcrypto.PrivateKeyK256 26 26 } 27 27 28 28 type ClientArgs struct { ··· 41 41 args.H = util.RobustHTTPClient() 42 42 } 43 43 44 - rk, err := crypto.ParsePrivateBytesK256([]byte(args.RotationKey)) 44 + rk, err := atcrypto.ParsePrivateBytesK256([]byte(args.RotationKey)) 45 45 if err != nil { 46 46 return nil, err 47 47 } ··· 54 54 }, nil 55 55 } 56 56 57 - func (c *Client) CreateDID(sigkey *crypto.PrivateKeyK256, recovery string, handle string) (string, *Operation, error) { 58 - pubsigkey, err := sigkey.PublicKey() 57 + func (c *Client) CreateDID(sigkey *atcrypto.PrivateKeyK256, recovery string, handle string) (string, *Operation, error) { 58 + creds, err := c.CreateDidCredentials(sigkey, recovery, handle) 59 59 if err != nil { 60 60 return "", nil, err 61 61 } 62 62 63 - pubrotkey, err := c.rotationKey.PublicKey() 63 + op := Operation{ 64 + Type: "plc_operation", 65 + VerificationMethods: creds.VerificationMethods, 66 + RotationKeys: creds.RotationKeys, 67 + AlsoKnownAs: creds.AlsoKnownAs, 68 + Services: creds.Services, 69 + Prev: nil, 70 + } 71 + 72 + if err := c.SignOp(sigkey, &op); err != nil { 73 + return "", nil, err 74 + } 75 + 76 + did, err := DidFromOp(&op) 64 77 if err != nil { 65 78 return "", nil, err 66 79 } 67 80 81 + return did, &op, nil 82 + } 83 + 84 + func (c *Client) CreateDidCredentials(sigkey *atcrypto.PrivateKeyK256, recovery string, handle string) (*DidCredentials, error) { 85 + pubsigkey, err := sigkey.PublicKey() 86 + if err != nil { 87 + return nil, err 88 + } 89 + 90 + pubrotkey, err := c.rotationKey.PublicKey() 91 + if err != nil { 92 + return nil, err 93 + } 94 + 68 95 // todo 69 96 rotationKeys := []string{pubrotkey.DIDKey()} 70 97 if recovery != "" { ··· 77 104 }(recovery) 78 105 } 79 106 80 - op := Operation{ 81 - Type: "plc_operation", 107 + creds := DidCredentials{ 82 108 VerificationMethods: map[string]string{ 83 109 "atproto": pubsigkey.DIDKey(), 84 110 }, ··· 92 118 Endpoint: "https://" + c.pdsHostname, 93 119 }, 94 120 }, 95 - Prev: nil, 96 121 } 97 122 98 - if err := c.SignOp(sigkey, &op); err != nil { 99 - return "", nil, err 100 - } 101 - 102 - did, err := DidFromOp(&op) 103 - if err != nil { 104 - return "", nil, err 105 - } 106 - 107 - return did, &op, nil 123 + return &creds, nil 108 124 } 109 125 110 - func (c *Client) SignOp(sigkey *crypto.PrivateKeyK256, op *Operation) error { 126 + func (c *Client) SignOp(sigkey *atcrypto.PrivateKeyK256, op *Operation) error { 111 127 b, err := op.MarshalCBOR() 112 128 if err != nil { 113 129 return err
+10 -2
plc/types.go
··· 3 3 import ( 4 4 "encoding/json" 5 5 6 - "github.com/bluesky-social/indigo/atproto/data" 6 + "github.com/bluesky-social/indigo/atproto/atdata" 7 7 "github.com/haileyok/cocoon/identity" 8 8 cbg "github.com/whyrusleeping/cbor-gen" 9 9 ) 10 + 11 + 12 + type DidCredentials struct { 13 + VerificationMethods map[string]string `json:"verificationMethods"` 14 + RotationKeys []string `json:"rotationKeys"` 15 + AlsoKnownAs []string `json:"alsoKnownAs"` 16 + Services map[string]identity.OperationService `json:"services"` 17 + } 10 18 11 19 type Operation struct { 12 20 Type string `json:"type"` ··· 38 46 return nil, err 39 47 } 40 48 41 - b, err = data.MarshalCBOR(m) 49 + b, err = atdata.MarshalCBOR(m) 42 50 if err != nil { 43 51 return nil, err 44 52 }
+85
recording_blockstore/recording_blockstore.go
··· 1 + package recording_blockstore 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + 7 + blockformat "github.com/ipfs/go-block-format" 8 + "github.com/ipfs/go-cid" 9 + blockstore "github.com/ipfs/go-ipfs-blockstore" 10 + ) 11 + 12 + type RecordingBlockstore struct { 13 + base blockstore.Blockstore 14 + 15 + inserts map[cid.Cid]blockformat.Block 16 + reads map[cid.Cid]blockformat.Block 17 + } 18 + 19 + func New(base blockstore.Blockstore) *RecordingBlockstore { 20 + return &RecordingBlockstore{ 21 + base: base, 22 + inserts: make(map[cid.Cid]blockformat.Block), 23 + reads: make(map[cid.Cid]blockformat.Block), 24 + } 25 + } 26 + 27 + func (bs *RecordingBlockstore) Has(ctx context.Context, c cid.Cid) (bool, error) { 28 + return bs.base.Has(ctx, c) 29 + } 30 + 31 + func (bs *RecordingBlockstore) Get(ctx context.Context, c cid.Cid) (blockformat.Block, error) { 32 + b, err := bs.base.Get(ctx, c) 33 + if err != nil { 34 + return nil, err 35 + } 36 + bs.reads[c] = b 37 + return b, nil 38 + } 39 + 40 + func (bs *RecordingBlockstore) GetSize(ctx context.Context, c cid.Cid) (int, error) { 41 + return bs.base.GetSize(ctx, c) 42 + } 43 + 44 + func (bs *RecordingBlockstore) DeleteBlock(ctx context.Context, c cid.Cid) error { 45 + return bs.base.DeleteBlock(ctx, c) 46 + } 47 + 48 + func (bs *RecordingBlockstore) Put(ctx context.Context, block blockformat.Block) error { 49 + if err := bs.base.Put(ctx, block); err != nil { 50 + return err 51 + } 52 + bs.inserts[block.Cid()] = block 53 + return nil 54 + } 55 + 56 + func (bs *RecordingBlockstore) PutMany(ctx context.Context, blocks []blockformat.Block) error { 57 + if err := bs.base.PutMany(ctx, blocks); err != nil { 58 + return err 59 + } 60 + 61 + for _, b := range blocks { 62 + bs.inserts[b.Cid()] = b 63 + } 64 + 65 + return nil 66 + } 67 + 68 + func (bs *RecordingBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) { 69 + return nil, fmt.Errorf("iteration not allowed on recording blockstore") 70 + } 71 + 72 + func (bs *RecordingBlockstore) HashOnRead(enabled bool) { 73 + } 74 + 75 + func (bs *RecordingBlockstore) GetWriteLog() map[cid.Cid]blockformat.Block { 76 + return bs.inserts 77 + } 78 + 79 + func (bs *RecordingBlockstore) GetReadLog() []blockformat.Block { 80 + var blocks []blockformat.Block 81 + for _, b := range bs.reads { 82 + blocks = append(blocks, b) 83 + } 84 + return blocks 85 + }
+30
server/blockstore_variant.go
··· 1 + package server 2 + 3 + import ( 4 + "github.com/haileyok/cocoon/sqlite_blockstore" 5 + blockstore "github.com/ipfs/go-ipfs-blockstore" 6 + ) 7 + 8 + type BlockstoreVariant int 9 + 10 + const ( 11 + BlockstoreVariantSqlite = iota 12 + ) 13 + 14 + func MustReturnBlockstoreVariant(maybeBsv string) BlockstoreVariant { 15 + switch maybeBsv { 16 + case "sqlite": 17 + return BlockstoreVariantSqlite 18 + default: 19 + panic("invalid blockstore variant provided") 20 + } 21 + } 22 + 23 + func (s *Server) getBlockstore(did string) blockstore.Blockstore { 24 + switch s.config.BlockstoreVariant { 25 + case BlockstoreVariantSqlite: 26 + return sqlite_blockstore.New(did, s.db) 27 + default: 28 + return sqlite_blockstore.New(did, s.db) 29 + } 30 + }
+10 -8
server/common.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 5 + 4 6 "github.com/haileyok/cocoon/models" 5 7 ) 6 8 7 - func (s *Server) getActorByHandle(handle string) (*models.Actor, error) { 9 + func (s *Server) getActorByHandle(ctx context.Context, handle string) (*models.Actor, error) { 8 10 var actor models.Actor 9 - if err := s.db.First(&actor, models.Actor{Handle: handle}).Error; err != nil { 11 + if err := s.db.First(ctx, &actor, models.Actor{Handle: handle}).Error; err != nil { 10 12 return nil, err 11 13 } 12 14 return &actor, nil 13 15 } 14 16 15 - func (s *Server) getRepoByEmail(email string) (*models.Repo, error) { 17 + func (s *Server) getRepoByEmail(ctx context.Context, email string) (*models.Repo, error) { 16 18 var repo models.Repo 17 - if err := s.db.First(&repo, models.Repo{Email: email}).Error; err != nil { 19 + if err := s.db.First(ctx, &repo, models.Repo{Email: email}).Error; err != nil { 18 20 return nil, err 19 21 } 20 22 return &repo, nil 21 23 } 22 24 23 - func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) { 25 + func (s *Server) getRepoActorByEmail(ctx context.Context, email string) (*models.RepoActor, error) { 24 26 var repo models.RepoActor 25 - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { 27 + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { 26 28 return nil, err 27 29 } 28 30 return &repo, nil 29 31 } 30 32 31 - func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) { 33 + func (s *Server) getRepoActorByDid(ctx context.Context, did string) (*models.RepoActor, error) { 32 34 var repo models.RepoActor 33 - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { 35 + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { 34 36 return nil, err 35 37 } 36 38 return &repo, nil
+40 -8
server/handle_account.go
··· 3 3 import ( 4 4 "time" 5 5 6 + "github.com/haileyok/cocoon/oauth" 7 + "github.com/haileyok/cocoon/oauth/constants" 6 8 "github.com/haileyok/cocoon/oauth/provider" 9 + "github.com/hako/durafmt" 7 10 "github.com/labstack/echo/v4" 8 11 ) 9 12 10 13 func (s *Server) handleAccount(e echo.Context) error { 14 + ctx := e.Request().Context() 15 + logger := s.logger.With("name", "handleAuth") 16 + 11 17 repo, sess, err := s.getSessionRepoOrErr(e) 12 18 if err != nil { 13 19 return e.Redirect(303, "/account/signin") 14 20 } 15 21 16 - now := time.Now() 22 + oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime) 17 23 18 24 var tokens []provider.OauthToken 19 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND expires_at >= ? ORDER BY created_at ASC", nil, repo.Repo.Did, now).Scan(&tokens).Error; err != nil { 20 - s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) 25 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { 26 + logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) 21 27 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error") 22 28 sess.Save(e.Request(), e.Response()) 23 29 return e.Render(200, "account.html", map[string]any{ ··· 25 31 }) 26 32 } 27 33 34 + var filtered []provider.OauthToken 35 + for _, t := range tokens { 36 + ageRes := oauth.GetSessionAgeFromToken(t) 37 + if ageRes.SessionExpired { 38 + continue 39 + } 40 + filtered = append(filtered, t) 41 + } 42 + 43 + now := time.Now() 44 + 28 45 tokenInfo := []map[string]string{} 29 46 for _, t := range tokens { 47 + ageRes := oauth.GetSessionAgeFromToken(t) 48 + maxTime := constants.PublicClientSessionLifetime 49 + if t.ClientAuth.Method != "none" { 50 + maxTime = constants.ConfidentialClientSessionLifetime 51 + } 52 + 53 + var clientName string 54 + metadata, err := s.oauthProvider.ClientManager.GetClient(ctx, t.ClientId) 55 + if err != nil { 56 + clientName = t.ClientId 57 + } else { 58 + clientName = metadata.Metadata.ClientName 59 + } 60 + 30 61 tokenInfo = append(tokenInfo, map[string]string{ 31 - "ClientId": t.ClientId, 32 - "CreatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"), 33 - "UpdatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"), 34 - "ExpiresAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"), 35 - "Token": t.Token, 62 + "ClientName": clientName, 63 + "Age": durafmt.Parse(ageRes.SessionAge).LimitFirstN(2).String(), 64 + "LastUpdated": durafmt.Parse(now.Sub(t.UpdatedAt)).LimitFirstN(2).String(), 65 + "ExpiresIn": durafmt.Parse(now.Add(maxTime).Sub(now)).LimitFirstN(2).String(), 66 + "Token": t.Token, 67 + "Ip": t.Ip, 36 68 }) 37 69 } 38 70
+8 -5
server/handle_account_revoke.go
··· 5 5 "github.com/labstack/echo/v4" 6 6 ) 7 7 8 - type AccountRevokeRequest struct { 8 + type AccountRevokeInput struct { 9 9 Token string `form:"token"` 10 10 } 11 11 12 12 func (s *Server) handleAccountRevoke(e echo.Context) error { 13 - var req AccountRevokeRequest 13 + ctx := e.Request().Context() 14 + logger := s.logger.With("name", "handleAcocuntRevoke") 15 + 16 + var req AccountRevokeInput 14 17 if err := e.Bind(&req); err != nil { 15 - s.logger.Error("could not bind account revoke request", "error", err) 18 + logger.Error("could not bind account revoke request", "error", err) 16 19 return helpers.ServerError(e, nil) 17 20 } 18 21 ··· 21 24 return e.Redirect(303, "/account/signin") 22 25 } 23 26 24 - if err := s.db.Exec("DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { 25 - s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err) 27 + if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { 28 + logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err) 26 29 sess.AddFlash("Unable to revoke session. See server logs for more details.", "error") 27 30 sess.Save(e.Request(), e.Response()) 28 31 return e.Redirect(303, "/account")
+68 -16
server/handle_account_signin.go
··· 2 2 3 3 import ( 4 4 "errors" 5 + "fmt" 5 6 "strings" 7 + "time" 6 8 7 9 "github.com/bluesky-social/indigo/atproto/syntax" 8 10 "github.com/gorilla/sessions" ··· 14 16 "gorm.io/gorm" 15 17 ) 16 18 17 - type OauthSigninRequest struct { 18 - Username string `form:"username"` 19 - Password string `form:"password"` 20 - QueryParams string `form:"query_params"` 19 + type OauthSigninInput struct { 20 + Username string `form:"username"` 21 + Password string `form:"password"` 22 + AuthFactorToken string `form:"token"` 23 + QueryParams string `form:"query_params"` 21 24 } 22 25 23 26 func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) { 27 + ctx := e.Request().Context() 28 + 24 29 sess, err := session.Get("session", e) 25 30 if err != nil { 26 31 return nil, nil, err ··· 31 36 return nil, sess, errors.New("did was not set in session") 32 37 } 33 38 34 - repo, err := s.getRepoActorByDid(did) 39 + repo, err := s.getRepoActorByDid(ctx, did) 35 40 if err != nil { 36 41 return nil, sess, err 37 42 } ··· 42 47 func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any { 43 48 defer sess.Save(e.Request(), e.Response()) 44 49 return map[string]any{ 45 - "errors": sess.Flashes("error"), 46 - "successes": sess.Flashes("success"), 50 + "errors": sess.Flashes("error"), 51 + "successes": sess.Flashes("success"), 52 + "tokenrequired": sess.Flashes("tokenrequired"), 47 53 } 48 54 } 49 55 ··· 60 66 } 61 67 62 68 func (s *Server) handleAccountSigninPost(e echo.Context) error { 63 - var req OauthSigninRequest 69 + ctx := e.Request().Context() 70 + logger := s.logger.With("name", "handleAccountSigninPost") 71 + 72 + var req OauthSigninInput 64 73 if err := e.Bind(&req); err != nil { 65 - s.logger.Error("error binding sign in req", "error", err) 74 + logger.Error("error binding sign in req", "error", err) 66 75 return helpers.ServerError(e, nil) 67 76 } 68 77 ··· 76 85 idtype = "handle" 77 86 } else { 78 87 idtype = "email" 88 + } 89 + 90 + queryParams := "" 91 + if req.QueryParams != "" { 92 + queryParams = fmt.Sprintf("?%s", req.QueryParams) 79 93 } 80 94 81 95 // TODO: we should make this a helper since we do it for the base create_session as well ··· 83 97 var err error 84 98 switch idtype { 85 99 case "did": 86 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error 100 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error 87 101 case "handle": 88 - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error 102 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error 89 103 case "email": 90 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error 104 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error 91 105 } 92 106 if err != nil { 93 107 if err == gorm.ErrRecordNotFound { ··· 96 110 sess.AddFlash("Something went wrong!", "error") 97 111 } 98 112 sess.Save(e.Request(), e.Response()) 99 - return e.Redirect(303, "/account/signin") 113 + return e.Redirect(303, "/account/signin"+queryParams) 100 114 } 101 115 102 116 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil { ··· 106 120 sess.AddFlash("Something went wrong!", "error") 107 121 } 108 122 sess.Save(e.Request(), e.Response()) 109 - return e.Redirect(303, "/account/signin") 123 + return e.Redirect(303, "/account/signin"+queryParams) 124 + } 125 + 126 + // if repo requires 2FA token and one hasn't been provided, return error prompting for one 127 + if repo.TwoFactorType != models.TwoFactorTypeNone && req.AuthFactorToken == "" { 128 + err = s.createAndSendTwoFactorCode(ctx, repo) 129 + if err != nil { 130 + sess.AddFlash("Something went wrong!", "error") 131 + sess.Save(e.Request(), e.Response()) 132 + return e.Redirect(303, "/account/signin"+queryParams) 133 + } 134 + 135 + sess.AddFlash("requires 2FA token", "tokenrequired") 136 + sess.Save(e.Request(), e.Response()) 137 + return e.Redirect(303, "/account/signin"+queryParams) 138 + } 139 + 140 + // if 2FAis required, now check that the one provided is valid 141 + if repo.TwoFactorType != models.TwoFactorTypeNone { 142 + if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil { 143 + err = s.createAndSendTwoFactorCode(ctx, repo) 144 + if err != nil { 145 + sess.AddFlash("Something went wrong!", "error") 146 + sess.Save(e.Request(), e.Response()) 147 + return e.Redirect(303, "/account/signin"+queryParams) 148 + } 149 + 150 + sess.AddFlash("requires 2FA token", "tokenrequired") 151 + sess.Save(e.Request(), e.Response()) 152 + return e.Redirect(303, "/account/signin"+queryParams) 153 + } 154 + 155 + if *repo.TwoFactorCode != req.AuthFactorToken { 156 + return helpers.InvalidTokenError(e) 157 + } 158 + 159 + if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) { 160 + return helpers.ExpiredTokenError(e) 161 + } 110 162 } 111 163 112 164 sess.Options = &sessions.Options{ ··· 122 174 return err 123 175 } 124 176 125 - if req.QueryParams != "" { 126 - return e.Redirect(303, "/oauth/authorize?"+req.QueryParams) 177 + if queryParams != "" { 178 + return e.Redirect(303, "/oauth/authorize"+queryParams) 127 179 } else { 128 180 return e.Redirect(303, "/account") 129 181 }
+1 -1
server/handle_actor_get_preferences.go
··· 16 16 err := json.Unmarshal(repo.Preferences, &prefs) 17 17 if err != nil || prefs["preferences"] == nil { 18 18 prefs = map[string]any{ 19 - "preferences": map[string]any{}, 19 + "preferences": []any{}, 20 20 } 21 21 } 22 22
+3 -1
server/handle_actor_put_preferences.go
··· 10 10 // This is kinda lame. Not great to implement app.bsky in the pds, but alas 11 11 12 12 func (s *Server) handleActorPutPreferences(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 13 15 repo := e.Get("repo").(*models.RepoActor) 14 16 15 17 var prefs map[string]any ··· 22 24 return err 23 25 } 24 26 25 - if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { 27 + if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { 26 28 return err 27 29 } 28 30
+32
server/handle_identity_request_plc_operation.go
··· 1 + package server 2 + 3 + import ( 4 + "fmt" 5 + "time" 6 + 7 + "github.com/haileyok/cocoon/internal/helpers" 8 + "github.com/haileyok/cocoon/models" 9 + "github.com/labstack/echo/v4" 10 + ) 11 + 12 + func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + logger := s.logger.With("name", "handleIdentityRequestPlcOperationSignature") 15 + 16 + urepo := e.Get("repo").(*models.RepoActor) 17 + 18 + code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 19 + eat := time.Now().Add(10 * time.Minute).UTC() 20 + 21 + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 22 + logger.Error("error updating user", "error", err) 23 + return helpers.ServerError(e, nil) 24 + } 25 + 26 + if err := s.sendPlcTokenReset(urepo.Email, urepo.Handle, code); err != nil { 27 + logger.Error("error sending mail", "error", err) 28 + return helpers.ServerError(e, nil) 29 + } 30 + 31 + return e.NoContent(200) 32 + }
+105
server/handle_identity_sign_plc_operation.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "strings" 6 + "time" 7 + 8 + "github.com/Azure/go-autorest/autorest/to" 9 + "github.com/bluesky-social/indigo/atproto/atcrypto" 10 + "github.com/haileyok/cocoon/identity" 11 + "github.com/haileyok/cocoon/internal/helpers" 12 + "github.com/haileyok/cocoon/models" 13 + "github.com/haileyok/cocoon/plc" 14 + "github.com/labstack/echo/v4" 15 + ) 16 + 17 + type ComAtprotoSignPlcOperationRequest struct { 18 + Token string `json:"token"` 19 + VerificationMethods *map[string]string `json:"verificationMethods"` 20 + RotationKeys *[]string `json:"rotationKeys"` 21 + AlsoKnownAs *[]string `json:"alsoKnownAs"` 22 + Services *map[string]identity.OperationService `json:"services"` 23 + } 24 + 25 + type ComAtprotoSignPlcOperationResponse struct { 26 + Operation plc.Operation `json:"operation"` 27 + } 28 + 29 + func (s *Server) handleSignPlcOperation(e echo.Context) error { 30 + logger := s.logger.With("name", "handleSignPlcOperation") 31 + 32 + repo := e.Get("repo").(*models.RepoActor) 33 + 34 + var req ComAtprotoSignPlcOperationRequest 35 + if err := e.Bind(&req); err != nil { 36 + logger.Error("error binding", "error", err) 37 + return helpers.ServerError(e, nil) 38 + } 39 + 40 + if !strings.HasPrefix(repo.Repo.Did, "did:plc:") { 41 + return helpers.InputError(e, nil) 42 + } 43 + 44 + if repo.PlcOperationCode == nil || repo.PlcOperationCodeExpiresAt == nil { 45 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 46 + } 47 + 48 + if *repo.PlcOperationCode != req.Token { 49 + return helpers.InvalidTokenError(e) 50 + } 51 + 52 + if time.Now().UTC().After(*repo.PlcOperationCodeExpiresAt) { 53 + return helpers.ExpiredTokenError(e) 54 + } 55 + 56 + ctx := context.WithValue(e.Request().Context(), "skip-cache", true) 57 + log, err := identity.FetchDidAuditLog(ctx, nil, repo.Repo.Did) 58 + if err != nil { 59 + logger.Error("error fetching doc", "error", err) 60 + return helpers.ServerError(e, nil) 61 + } 62 + 63 + latest := log[len(log)-1] 64 + 65 + op := plc.Operation{ 66 + Type: "plc_operation", 67 + VerificationMethods: latest.Operation.VerificationMethods, 68 + RotationKeys: latest.Operation.RotationKeys, 69 + AlsoKnownAs: latest.Operation.AlsoKnownAs, 70 + Services: latest.Operation.Services, 71 + Prev: &latest.Cid, 72 + } 73 + if req.VerificationMethods != nil { 74 + op.VerificationMethods = *req.VerificationMethods 75 + } 76 + if req.RotationKeys != nil { 77 + op.RotationKeys = *req.RotationKeys 78 + } 79 + if req.AlsoKnownAs != nil { 80 + op.AlsoKnownAs = *req.AlsoKnownAs 81 + } 82 + if req.Services != nil { 83 + op.Services = *req.Services 84 + } 85 + 86 + k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey) 87 + if err != nil { 88 + logger.Error("error parsing signing key", "error", err) 89 + return helpers.ServerError(e, nil) 90 + } 91 + 92 + if err := s.plcClient.SignOp(k, &op); err != nil { 93 + logger.Error("error signing plc operation", "error", err) 94 + return helpers.ServerError(e, nil) 95 + } 96 + 97 + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { 98 + logger.Error("error updating repo", "error", err) 99 + return helpers.ServerError(e, nil) 100 + } 101 + 102 + return e.JSON(200, ComAtprotoSignPlcOperationResponse{ 103 + Operation: op, 104 + }) 105 + }
+89
server/handle_identity_submit_plc_operation.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "slices" 6 + "strings" 7 + "time" 8 + 9 + "github.com/bluesky-social/indigo/api/atproto" 10 + "github.com/bluesky-social/indigo/atproto/atcrypto" 11 + "github.com/bluesky-social/indigo/events" 12 + "github.com/bluesky-social/indigo/util" 13 + "github.com/haileyok/cocoon/internal/helpers" 14 + "github.com/haileyok/cocoon/models" 15 + "github.com/haileyok/cocoon/plc" 16 + "github.com/labstack/echo/v4" 17 + ) 18 + 19 + type ComAtprotoSubmitPlcOperationRequest struct { 20 + Operation plc.Operation `json:"operation"` 21 + } 22 + 23 + func (s *Server) handleSubmitPlcOperation(e echo.Context) error { 24 + logger := s.logger.With("name", "handleIdentitySubmitPlcOperation") 25 + 26 + repo := e.Get("repo").(*models.RepoActor) 27 + 28 + var req ComAtprotoSubmitPlcOperationRequest 29 + if err := e.Bind(&req); err != nil { 30 + logger.Error("error binding", "error", err) 31 + return helpers.ServerError(e, nil) 32 + } 33 + 34 + if err := e.Validate(req); err != nil { 35 + return helpers.InputError(e, nil) 36 + } 37 + if !strings.HasPrefix(repo.Repo.Did, "did:plc:") { 38 + return helpers.InputError(e, nil) 39 + } 40 + 41 + op := req.Operation 42 + 43 + k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey) 44 + if err != nil { 45 + logger.Error("error parsing key", "error", err) 46 + return helpers.ServerError(e, nil) 47 + } 48 + required, err := s.plcClient.CreateDidCredentials(k, "", repo.Actor.Handle) 49 + if err != nil { 50 + logger.Error("error crating did credentials", "error", err) 51 + return helpers.ServerError(e, nil) 52 + } 53 + 54 + for _, expectedKey := range required.RotationKeys { 55 + if !slices.Contains(op.RotationKeys, expectedKey) { 56 + return helpers.InputError(e, nil) 57 + } 58 + } 59 + if op.Services["atproto_pds"].Type != "AtprotoPersonalDataServer" { 60 + return helpers.InputError(e, nil) 61 + } 62 + if op.Services["atproto_pds"].Endpoint != required.Services["atproto_pds"].Endpoint { 63 + return helpers.InputError(e, nil) 64 + } 65 + if op.VerificationMethods["atproto"] != required.VerificationMethods["atproto"] { 66 + return helpers.InputError(e, nil) 67 + } 68 + if op.AlsoKnownAs[0] != required.AlsoKnownAs[0] { 69 + return helpers.InputError(e, nil) 70 + } 71 + 72 + if err := s.plcClient.SendOperation(e.Request().Context(), repo.Repo.Did, &op); err != nil { 73 + return err 74 + } 75 + 76 + if err := s.passport.BustDoc(context.TODO(), repo.Repo.Did); err != nil { 77 + logger.Warn("error busting did doc", "error", err) 78 + } 79 + 80 + s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 81 + RepoIdentity: &atproto.SyncSubscribeRepos_Identity{ 82 + Did: repo.Repo.Did, 83 + Seq: time.Now().UnixMicro(), // TODO: no 84 + Time: time.Now().Format(util.ISO8601), 85 + }, 86 + }) 87 + 88 + return nil 89 + }
+10 -17
server/handle_identity_update_handle.go
··· 7 7 8 8 "github.com/Azure/go-autorest/autorest/to" 9 9 "github.com/bluesky-social/indigo/api/atproto" 10 - "github.com/bluesky-social/indigo/atproto/crypto" 10 + "github.com/bluesky-social/indigo/atproto/atcrypto" 11 11 "github.com/bluesky-social/indigo/events" 12 12 "github.com/bluesky-social/indigo/util" 13 13 "github.com/haileyok/cocoon/identity" ··· 22 22 } 23 23 24 24 func (s *Server) handleIdentityUpdateHandle(e echo.Context) error { 25 + logger := s.logger.With("name", "handleIdentityUpdateHandle") 26 + 25 27 repo := e.Get("repo").(*models.RepoActor) 26 28 27 29 var req ComAtprotoIdentityUpdateHandleRequest 28 30 if err := e.Bind(&req); err != nil { 29 - s.logger.Error("error binding", "error", err) 31 + logger.Error("error binding", "error", err) 30 32 return helpers.ServerError(e, nil) 31 33 } 32 34 ··· 41 43 if strings.HasPrefix(repo.Repo.Did, "did:plc:") { 42 44 log, err := identity.FetchDidAuditLog(ctx, nil, repo.Repo.Did) 43 45 if err != nil { 44 - s.logger.Error("error fetching doc", "error", err) 46 + logger.Error("error fetching doc", "error", err) 45 47 return helpers.ServerError(e, nil) 46 48 } 47 49 ··· 66 68 Prev: &latest.Cid, 67 69 } 68 70 69 - k, err := crypto.ParsePrivateBytesK256(repo.SigningKey) 71 + k, err := atcrypto.ParsePrivateBytesK256(repo.SigningKey) 70 72 if err != nil { 71 - s.logger.Error("error parsing signing key", "error", err) 73 + logger.Error("error parsing signing key", "error", err) 72 74 return helpers.ServerError(e, nil) 73 75 } 74 76 ··· 82 84 } 83 85 84 86 if err := s.passport.BustDoc(context.TODO(), repo.Repo.Did); err != nil { 85 - s.logger.Warn("error busting did doc", "error", err) 87 + logger.Warn("error busting did doc", "error", err) 86 88 } 87 89 88 90 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 89 - RepoHandle: &atproto.SyncSubscribeRepos_Handle{ 90 - Did: repo.Repo.Did, 91 - Handle: req.Handle, 92 - Seq: time.Now().UnixMicro(), // TODO: no 93 - Time: time.Now().Format(util.ISO8601), 94 - }, 95 - }) 96 - 97 - s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 98 91 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{ 99 92 Did: repo.Repo.Did, 100 93 Handle: to.StringPtr(req.Handle), ··· 103 96 }, 104 97 }) 105 98 106 - if err := s.db.Exec("UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { 107 - s.logger.Error("error updating handle in db", "error", err) 99 + if err := s.db.Exec(ctx, "UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { 100 + logger.Error("error updating handle in db", "error", err) 108 101 return helpers.ServerError(e, nil) 109 102 } 110 103
+17 -15
server/handle_import_repo.go
··· 9 9 10 10 "github.com/bluesky-social/indigo/atproto/syntax" 11 11 "github.com/bluesky-social/indigo/repo" 12 - "github.com/haileyok/cocoon/blockstore" 13 12 "github.com/haileyok/cocoon/internal/helpers" 14 13 "github.com/haileyok/cocoon/models" 15 14 blocks "github.com/ipfs/go-block-format" ··· 19 18 ) 20 19 21 20 func (s *Server) handleRepoImportRepo(e echo.Context) error { 21 + ctx := e.Request().Context() 22 + logger := s.logger.With("name", "handleImportRepo") 23 + 22 24 urepo := e.Get("repo").(*models.RepoActor) 23 25 24 26 b, err := io.ReadAll(e.Request().Body) 25 27 if err != nil { 26 - s.logger.Error("could not read bytes in import request", "error", err) 28 + logger.Error("could not read bytes in import request", "error", err) 27 29 return helpers.ServerError(e, nil) 28 30 } 29 31 30 - bs := blockstore.New(urepo.Repo.Did, s.db) 32 + bs := s.getBlockstore(urepo.Repo.Did) 31 33 32 34 cs, err := car.NewCarReader(bytes.NewReader(b)) 33 35 if err != nil { 34 - s.logger.Error("could not read car in import request", "error", err) 36 + logger.Error("could not read car in import request", "error", err) 35 37 return helpers.ServerError(e, nil) 36 38 } 37 39 38 40 orderedBlocks := []blocks.Block{} 39 41 currBlock, err := cs.Next() 40 42 if err != nil { 41 - s.logger.Error("could not get first block from car", "error", err) 43 + logger.Error("could not get first block from car", "error", err) 42 44 return helpers.ServerError(e, nil) 43 45 } 44 46 currBlockCt := 1 45 47 46 48 for currBlock != nil { 47 - s.logger.Info("someone is importing their repo", "block", currBlockCt) 49 + logger.Info("someone is importing their repo", "block", currBlockCt) 48 50 orderedBlocks = append(orderedBlocks, currBlock) 49 51 next, _ := cs.Next() 50 52 currBlock = next ··· 54 56 slices.Reverse(orderedBlocks) 55 57 56 58 if err := bs.PutMany(context.TODO(), orderedBlocks); err != nil { 57 - s.logger.Error("could not insert blocks", "error", err) 59 + logger.Error("could not insert blocks", "error", err) 58 60 return helpers.ServerError(e, nil) 59 61 } 60 62 61 63 r, err := repo.OpenRepo(context.TODO(), bs, cs.Header.Roots[0]) 62 64 if err != nil { 63 - s.logger.Error("could not open repo", "error", err) 65 + logger.Error("could not open repo", "error", err) 64 66 return helpers.ServerError(e, nil) 65 67 } 66 68 67 - tx := s.db.BeginDangerously() 69 + tx := s.db.BeginDangerously(ctx) 68 70 69 71 clock := syntax.NewTIDClock(0) 70 72 ··· 75 77 cidStr := cid.String() 76 78 b, err := bs.Get(context.TODO(), cid) 77 79 if err != nil { 78 - s.logger.Error("record bytes don't exist in blockstore", "error", err) 80 + logger.Error("record bytes don't exist in blockstore", "error", err) 79 81 return helpers.ServerError(e, nil) 80 82 } 81 83 ··· 88 90 Value: b.RawData(), 89 91 } 90 92 91 - if err := tx.Create(rec).Error; err != nil { 93 + if err := tx.Save(rec).Error; err != nil { 92 94 return err 93 95 } 94 96 95 97 return nil 96 98 }); err != nil { 97 99 tx.Rollback() 98 - s.logger.Error("record bytes don't exist in blockstore", "error", err) 100 + logger.Error("record bytes don't exist in blockstore", "error", err) 99 101 return helpers.ServerError(e, nil) 100 102 } 101 103 ··· 103 105 104 106 root, rev, err := r.Commit(context.TODO(), urepo.SignFor) 105 107 if err != nil { 106 - s.logger.Error("error committing", "error", err) 108 + logger.Error("error committing", "error", err) 107 109 return helpers.ServerError(e, nil) 108 110 } 109 111 110 - if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil { 111 - s.logger.Error("error updating repo after commit", "error", err) 112 + if err := s.UpdateRepo(context.TODO(), urepo.Repo.Did, root, rev); err != nil { 113 + logger.Error("error updating repo after commit", "error", err) 112 114 return helpers.ServerError(e, nil) 113 115 } 114 116
+34
server/handle_label_query_labels.go
··· 1 + package server 2 + 3 + import ( 4 + "github.com/labstack/echo/v4" 5 + ) 6 + 7 + type Label struct { 8 + Ver *int `json:"ver,omitempty"` 9 + Src string `json:"src"` 10 + Uri string `json:"uri"` 11 + Cid *string `json:"cid,omitempty"` 12 + Val string `json:"val"` 13 + Neg *bool `json:"neg,omitempty"` 14 + Cts string `json:"cts"` 15 + Exp *string `json:"exp,omitempty"` 16 + Sig []byte `json:"sig,omitempty"` 17 + } 18 + 19 + type ComAtprotoLabelQueryLabelsResponse struct { 20 + Cursor *string `json:"cursor,omitempty"` 21 + Labels []Label `json:"labels"` 22 + } 23 + 24 + func (s *Server) handleLabelQueryLabels(e echo.Context) error { 25 + svc := e.Request().Header.Get("atproto-proxy") 26 + if svc != "" || s.config.FallbackProxy != "" { 27 + return s.handleProxy(e) 28 + } 29 + 30 + return e.JSON(200, ComAtprotoLabelQueryLabelsResponse{ 31 + Cursor: nil, 32 + Labels: []Label{}, 33 + }) 34 + }
+10 -5
server/handle_oauth_authorize.go
··· 13 13 ) 14 14 15 15 func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 16 + ctx := e.Request().Context() 17 + 16 18 reqUri := e.QueryParam("request_uri") 17 19 if reqUri == "" { 18 20 // render page for logged out dev ··· 38 40 } 39 41 40 42 var req provider.OauthAuthorizationRequest 41 - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 43 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 42 44 return helpers.ServerError(e, to.StringPtr(err.Error())) 43 45 } 44 46 ··· 72 74 } 73 75 74 76 func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 77 + ctx := e.Request().Context() 78 + logger := s.logger.With("name", "handleOauthAuthorizePost") 79 + 75 80 repo, _, err := s.getSessionRepoOrErr(e) 76 81 if err != nil { 77 82 return e.Redirect(303, "/account/signin") ··· 79 84 80 85 var req OauthAuthorizePostRequest 81 86 if err := e.Bind(&req); err != nil { 82 - s.logger.Error("error binding authorize post request", "error", err) 87 + logger.Error("error binding authorize post request", "error", err) 83 88 return helpers.InputError(e, nil) 84 89 } 85 90 ··· 89 94 } 90 95 91 96 var authReq provider.OauthAuthorizationRequest 92 - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 97 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 93 98 return helpers.ServerError(e, to.StringPtr(err.Error())) 94 99 } 95 100 ··· 113 118 114 119 code := oauth.GenerateCode() 115 120 116 - if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, reqId).Error; err != nil { 117 - s.logger.Error("error updating authorization request", "error", err) 121 + if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 122 + logger.Error("error updating authorization request", "error", err) 118 123 return helpers.ServerError(e, nil) 119 124 } 120 125
+24 -9
server/handle_oauth_par.go
··· 1 1 package server 2 2 3 3 import ( 4 + "errors" 4 5 "time" 5 6 6 7 "github.com/Azure/go-autorest/autorest/to" 7 8 "github.com/haileyok/cocoon/internal/helpers" 8 9 "github.com/haileyok/cocoon/oauth" 9 10 "github.com/haileyok/cocoon/oauth/constants" 11 + "github.com/haileyok/cocoon/oauth/dpop" 10 12 "github.com/haileyok/cocoon/oauth/provider" 11 13 "github.com/labstack/echo/v4" 12 14 ) ··· 17 19 } 18 20 19 21 func (s *Server) handleOauthPar(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + logger := s.logger.With("name", "handleOauthPar") 24 + 20 25 var parRequest provider.ParRequest 21 26 if err := e.Bind(&parRequest); err != nil { 22 - s.logger.Error("error binding for par request", "error", err) 27 + logger.Error("error binding for par request", "error", err) 23 28 return helpers.ServerError(e, nil) 24 29 } 25 30 26 31 if err := e.Validate(parRequest); err != nil { 27 - s.logger.Error("missing parameters for par request", "error", err) 32 + logger.Error("missing parameters for par request", "error", err) 28 33 return helpers.InputError(e, nil) 29 34 } 30 35 31 36 // TODO: this seems wrong. should be a way to get the entire request url i believe, but this will work for now 32 37 dpopProof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, nil) 33 38 if err != nil { 34 - s.logger.Error("error getting dpop proof", "error", err) 35 - return helpers.InputError(e, to.StringPtr(err.Error())) 39 + if errors.Is(err, dpop.ErrUseDpopNonce) { 40 + nonce := s.oauthProvider.NextNonce() 41 + if nonce != "" { 42 + e.Response().Header().Set("DPoP-Nonce", nonce) 43 + e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 44 + } 45 + return e.JSON(400, map[string]string{ 46 + "error": "use_dpop_nonce", 47 + }) 48 + } 49 + logger.Error("error getting dpop proof", "error", err) 50 + return helpers.InputError(e, nil) 36 51 } 37 52 38 53 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), parRequest.AuthenticateClientRequestBase, dpopProof, &provider.AuthenticateClientOptions{ ··· 41 56 AllowMissingDpopProof: true, 42 57 }) 43 58 if err != nil { 44 - s.logger.Error("error authenticating client", "error", err) 59 + logger.Error("error authenticating client", "client_id", parRequest.ClientID, "error", err) 45 60 return helpers.InputError(e, to.StringPtr(err.Error())) 46 61 } 47 62 ··· 52 67 } else { 53 68 if !client.Metadata.DpopBoundAccessTokens { 54 69 msg := "dpop bound access tokens are not enabled for this client" 55 - s.logger.Error(msg) 70 + logger.Error(msg) 56 71 return helpers.InputError(e, &msg) 57 72 } 58 73 59 74 if dpopProof.JKT != *parRequest.DpopJkt { 60 75 msg := "supplied dpop jkt does not match header dpop jkt" 61 - s.logger.Error(msg) 76 + logger.Error(msg) 62 77 return helpers.InputError(e, &msg) 63 78 } 64 79 } ··· 74 89 ExpiresAt: eat, 75 90 } 76 91 77 - if err := s.db.Create(authRequest, nil).Error; err != nil { 78 - s.logger.Error("error creating auth request in db", "error", err) 92 + if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 93 + logger.Error("error creating auth request in db", "error", err) 79 94 return helpers.ServerError(e, nil) 80 95 } 81 96
+33 -24
server/handle_oauth_token.go
··· 4 4 "bytes" 5 5 "crypto/sha256" 6 6 "encoding/base64" 7 + "errors" 7 8 "fmt" 8 9 "slices" 9 10 "time" ··· 13 14 "github.com/haileyok/cocoon/internal/helpers" 14 15 "github.com/haileyok/cocoon/oauth" 15 16 "github.com/haileyok/cocoon/oauth/constants" 17 + "github.com/haileyok/cocoon/oauth/dpop" 16 18 "github.com/haileyok/cocoon/oauth/provider" 17 19 "github.com/labstack/echo/v4" 18 20 ) ··· 36 38 } 37 39 38 40 func (s *Server) handleOauthToken(e echo.Context) error { 41 + ctx := e.Request().Context() 42 + logger := s.logger.With("name", "handleOauthToken") 43 + 39 44 var req OauthTokenRequest 40 45 if err := e.Bind(&req); err != nil { 41 - s.logger.Error("error binding token request", "error", err) 46 + logger.Error("error binding token request", "error", err) 42 47 return helpers.ServerError(e, nil) 43 48 } 44 49 45 50 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, e.Request().URL.String(), e.Request().Header, nil) 46 51 if err != nil { 47 - s.logger.Error("error getting dpop proof", "error", err) 48 - return helpers.InputError(e, to.StringPtr(err.Error())) 52 + if errors.Is(err, dpop.ErrUseDpopNonce) { 53 + nonce := s.oauthProvider.NextNonce() 54 + if nonce != "" { 55 + e.Response().Header().Set("DPoP-Nonce", nonce) 56 + e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 57 + } 58 + return e.JSON(400, map[string]string{ 59 + "error": "use_dpop_nonce", 60 + }) 61 + } 62 + logger.Error("error getting dpop proof", "error", err) 63 + return helpers.InputError(e, nil) 49 64 } 50 65 51 66 client, clientAuth, err := s.oauthProvider.AuthenticateClient(e.Request().Context(), req.AuthenticateClientRequestBase, proof, &provider.AuthenticateClientOptions{ 52 67 AllowMissingDpopProof: true, 53 68 }) 54 69 if err != nil { 55 - s.logger.Error("error authenticating client", "error", err) 70 + logger.Error("error authenticating client", "client_id", req.ClientID, "error", err) 56 71 return helpers.InputError(e, to.StringPtr(err.Error())) 57 72 } 58 73 ··· 72 87 73 88 var authReq provider.OauthAuthorizationRequest 74 89 // get the lil guy and delete him 75 - if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 76 - s.logger.Error("error finding authorization request", "error", err) 90 + if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 91 + logger.Error("error finding authorization request", "error", err) 77 92 return helpers.ServerError(e, nil) 78 93 } 79 94 ··· 98 113 case "S256": 99 114 inputChal, err := base64.RawURLEncoding.DecodeString(*authReq.Parameters.CodeChallenge) 100 115 if err != nil { 101 - s.logger.Error("error decoding code challenge", "error", err) 116 + logger.Error("error decoding code challenge", "error", err) 102 117 return helpers.ServerError(e, nil) 103 118 } 104 119 ··· 116 131 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 117 132 } 118 133 119 - repo, err := s.getRepoActorByDid(*authReq.Sub) 134 + repo, err := s.getRepoActorByDid(ctx, *authReq.Sub) 120 135 if err != nil { 121 136 helpers.InputError(e, to.StringPtr("unable to find actor")) 122 137 } ··· 147 162 return err 148 163 } 149 164 150 - if err := s.db.Create(&provider.OauthToken{ 165 + if err := s.db.Create(ctx, &provider.OauthToken{ 151 166 ClientId: authReq.ClientId, 152 167 ClientAuth: *clientAuth, 153 168 Parameters: authReq.Parameters, ··· 157 172 Code: *authReq.Code, 158 173 Token: accessString, 159 174 RefreshToken: refreshToken, 175 + Ip: authReq.Ip, 160 176 }, nil).Error; err != nil { 161 - s.logger.Error("error creating token in db", "error", err) 177 + logger.Error("error creating token in db", "error", err) 162 178 return helpers.ServerError(e, nil) 163 179 } 164 180 ··· 186 202 } 187 203 188 204 var oauthToken provider.OauthToken 189 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 190 - s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 205 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 206 + logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 191 207 return helpers.ServerError(e, nil) 192 208 } 193 209 ··· 203 219 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt")) 204 220 } 205 221 206 - sessionLifetime := constants.PublicClientSessionLifetime 207 - refreshLifetime := constants.PublicClientRefreshLifetime 208 - if clientAuth.Method != "none" { 209 - sessionLifetime = constants.ConfidentialClientSessionLifetime 210 - refreshLifetime = constants.ConfidentialClientRefreshLifetime 211 - } 222 + ageRes := oauth.GetSessionAgeFromToken(oauthToken) 212 223 213 - sessionAge := time.Since(oauthToken.CreatedAt) 214 - if sessionAge > sessionLifetime { 224 + if ageRes.SessionExpired { 215 225 return helpers.InputError(e, to.StringPtr("Session expired")) 216 226 } 217 227 218 - refreshAge := time.Since(oauthToken.UpdatedAt) 219 - if refreshAge > refreshLifetime { 228 + if ageRes.RefreshExpired { 220 229 return helpers.InputError(e, to.StringPtr("Refresh token expired")) 221 230 } 222 231 ··· 251 260 return err 252 261 } 253 262 254 - if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 255 - s.logger.Error("error updating token", "error", err) 263 + if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 264 + logger.Error("error updating token", "error", err) 256 265 return helpers.ServerError(e, nil) 257 266 } 258 267
+43 -18
server/handle_proxy.go
··· 17 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 18 ) 19 19 20 - func (s *Server) handleProxy(e echo.Context) error { 21 - repo, isAuthed := e.Get("repo").(*models.RepoActor) 22 - 23 - pts := strings.Split(e.Request().URL.Path, "/") 24 - if len(pts) != 3 { 25 - return fmt.Errorf("incorrect number of parts") 26 - } 27 - 20 + func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 28 21 svc := e.Request().Header.Get("atproto-proxy") 29 - if svc == "" { 30 - svc = "did:web:api.bsky.app#bsky_appview" // TODO: should be a config var probably 22 + if svc == "" && s.config.FallbackProxy != "" { 23 + svc = s.config.FallbackProxy 31 24 } 32 25 33 26 svcPts := strings.Split(svc, "#") 34 27 if len(svcPts) != 2 { 35 - return fmt.Errorf("invalid service header") 28 + return "", "", fmt.Errorf("invalid service header") 36 29 } 37 30 38 31 svcDid := svcPts[0] ··· 40 33 41 34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 42 35 if err != nil { 43 - return err 36 + return "", "", err 44 37 } 45 38 46 39 var endpoint string ··· 50 43 } 51 44 } 52 45 46 + return endpoint, svcDid, nil 47 + } 48 + 49 + func (s *Server) handleProxy(e echo.Context) error { 50 + logger := s.logger.With("handler", "handleProxy") 51 + 52 + repo, isAuthed := e.Get("repo").(*models.RepoActor) 53 + 54 + pts := strings.Split(e.Request().URL.Path, "/") 55 + if len(pts) != 3 { 56 + return fmt.Errorf("incorrect number of parts") 57 + } 58 + 59 + endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 60 + if err != nil { 61 + logger.Error("could not get atproto proxy", "error", err) 62 + return helpers.ServerError(e, nil) 63 + } 64 + 53 65 requrl := e.Request().URL 54 66 requrl.Host = strings.TrimPrefix(endpoint, "https://") 55 67 requrl.Scheme = "https" ··· 78 90 } 79 91 hj, err := json.Marshal(header) 80 92 if err != nil { 81 - s.logger.Error("error marshaling header", "error", err) 93 + logger.Error("error marshaling header", "error", err) 82 94 return helpers.ServerError(e, nil) 83 95 } 84 96 85 97 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 86 98 99 + // When proxying app.bsky.feed.getFeed the token is actually issued for the 100 + // underlying feed generator and the app view passes it on. This allows the 101 + // getFeed implementation to pass in the desired lxm and aud for the token 102 + // and then just delegate to the general proxying logic 103 + lxm, proxyTokenLxmExists := e.Get("proxyTokenLxm").(string) 104 + if !proxyTokenLxmExists || lxm == "" { 105 + lxm = pts[2] 106 + } 107 + aud, proxyTokenAudExists := e.Get("proxyTokenAud").(string) 108 + if !proxyTokenAudExists || aud == "" { 109 + aud = svcDid 110 + } 111 + 87 112 payload := map[string]any{ 88 113 "iss": repo.Repo.Did, 89 - "aud": svcDid, 90 - "lxm": pts[2], 114 + "aud": aud, 115 + "lxm": lxm, 91 116 "jti": uuid.NewString(), 92 117 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(), 93 118 } 94 119 pj, err := json.Marshal(payload) 95 120 if err != nil { 96 - s.logger.Error("error marashaling payload", "error", err) 121 + logger.Error("error marashaling payload", "error", err) 97 122 return helpers.ServerError(e, nil) 98 123 } 99 124 ··· 104 129 105 130 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 106 131 if err != nil { 107 - s.logger.Error("can't load private key", "error", err) 132 + logger.Error("can't load private key", "error", err) 108 133 return err 109 134 } 110 135 111 136 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 112 137 if err != nil { 113 - s.logger.Error("error signing", "error", err) 138 + logger.Error("error signing", "error", err) 114 139 } 115 140 116 141 rBytes := R.Bytes()
+35
server/handle_proxy_get_feed.go
··· 1 + package server 2 + 3 + import ( 4 + "github.com/Azure/go-autorest/autorest/to" 5 + "github.com/bluesky-social/indigo/api/atproto" 6 + "github.com/bluesky-social/indigo/api/bsky" 7 + "github.com/bluesky-social/indigo/atproto/syntax" 8 + "github.com/bluesky-social/indigo/xrpc" 9 + "github.com/haileyok/cocoon/internal/helpers" 10 + "github.com/labstack/echo/v4" 11 + ) 12 + 13 + func (s *Server) handleProxyBskyFeedGetFeed(e echo.Context) error { 14 + feedUri, err := syntax.ParseATURI(e.QueryParam("feed")) 15 + if err != nil { 16 + return helpers.InputError(e, to.StringPtr("invalid feed uri")) 17 + } 18 + 19 + appViewEndpoint, _, err := s.getAtprotoProxyEndpointFromRequest(e) 20 + if err != nil { 21 + e.Logger().Error("could not get atproto proxy", "error", err) 22 + return helpers.ServerError(e, nil) 23 + } 24 + 25 + appViewClient := xrpc.Client{ 26 + Host: appViewEndpoint, 27 + } 28 + feedRecord, err := atproto.RepoGetRecord(e.Request().Context(), &appViewClient, "", feedUri.Collection().String(), feedUri.Authority().String(), feedUri.RecordKey().String()) 29 + feedGeneratorDid := feedRecord.Value.Val.(*bsky.FeedGenerator).Did 30 + 31 + e.Set("proxyTokenLxm", "app.bsky.feed.getFeedSkeleton") 32 + e.Set("proxyTokenAud", feedGeneratorDid) 33 + 34 + return s.handleProxy(e) 35 + }
+14 -11
server/handle_repo_apply_writes.go
··· 6 6 "github.com/labstack/echo/v4" 7 7 ) 8 8 9 - type ComAtprotoRepoApplyWritesRequest struct { 9 + type ComAtprotoRepoApplyWritesInput struct { 10 10 Repo string `json:"repo" validate:"required,atproto-did"` 11 11 Validate *bool `json:"bool,omitempty"` 12 12 Writes []ComAtprotoRepoApplyWritesItem `json:"writes"` ··· 20 20 Value *MarshalableMap `json:"value,omitempty"` 21 21 } 22 22 23 - type ComAtprotoRepoApplyWritesResponse struct { 23 + type ComAtprotoRepoApplyWritesOutput struct { 24 24 Commit RepoCommit `json:"commit"` 25 25 Results []ApplyWriteResult `json:"results"` 26 26 } 27 27 28 28 func (s *Server) handleApplyWrites(e echo.Context) error { 29 - repo := e.Get("repo").(*models.RepoActor) 29 + ctx := e.Request().Context() 30 + logger := s.logger.With("name", "handleRepoApplyWrites") 30 31 31 - var req ComAtprotoRepoApplyWritesRequest 32 + var req ComAtprotoRepoApplyWritesInput 32 33 if err := e.Bind(&req); err != nil { 33 - s.logger.Error("error binding", "error", err) 34 + logger.Error("error binding", "error", err) 34 35 return helpers.ServerError(e, nil) 35 36 } 36 37 37 38 if err := e.Validate(req); err != nil { 38 - s.logger.Error("error validating", "error", err) 39 + logger.Error("error validating", "error", err) 39 40 return helpers.InputError(e, nil) 40 41 } 41 42 43 + repo := e.Get("repo").(*models.RepoActor) 44 + 42 45 if repo.Repo.Did != req.Repo { 43 - s.logger.Warn("mismatched repo/auth") 46 + logger.Warn("mismatched repo/auth") 44 47 return helpers.InputError(e, nil) 45 48 } 46 49 47 - ops := []Op{} 50 + ops := make([]Op, 0, len(req.Writes)) 48 51 for _, item := range req.Writes { 49 52 ops = append(ops, Op{ 50 53 Type: OpType(item.Type), ··· 54 57 }) 55 58 } 56 59 57 - results, err := s.repoman.applyWrites(repo.Repo, ops, req.SwapCommit) 60 + results, err := s.repoman.applyWrites(ctx, repo.Repo, ops, req.SwapCommit) 58 61 if err != nil { 59 - s.logger.Error("error applying writes", "error", err) 62 + logger.Error("error applying writes", "error", err) 60 63 return helpers.ServerError(e, nil) 61 64 } 62 65 ··· 66 69 results[i].Commit = nil 67 70 } 68 71 69 - return e.JSON(200, ComAtprotoRepoApplyWritesResponse{ 72 + return e.JSON(200, ComAtprotoRepoApplyWritesOutput{ 70 73 Commit: commit, 71 74 Results: results, 72 75 })
+10 -7
server/handle_repo_create_record.go
··· 6 6 "github.com/labstack/echo/v4" 7 7 ) 8 8 9 - type ComAtprotoRepoCreateRecordRequest struct { 9 + type ComAtprotoRepoCreateRecordInput struct { 10 10 Repo string `json:"repo" validate:"required,atproto-did"` 11 11 Collection string `json:"collection" validate:"required,atproto-nsid"` 12 12 Rkey *string `json:"rkey,omitempty"` ··· 17 17 } 18 18 19 19 func (s *Server) handleCreateRecord(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + logger := s.logger.With("name", "handleCreateRecord") 22 + 20 23 repo := e.Get("repo").(*models.RepoActor) 21 24 22 - var req ComAtprotoRepoCreateRecordRequest 25 + var req ComAtprotoRepoCreateRecordInput 23 26 if err := e.Bind(&req); err != nil { 24 - s.logger.Error("error binding", "error", err) 27 + logger.Error("error binding", "error", err) 25 28 return helpers.ServerError(e, nil) 26 29 } 27 30 28 31 if err := e.Validate(req); err != nil { 29 - s.logger.Error("error validating", "error", err) 32 + logger.Error("error validating", "error", err) 30 33 return helpers.InputError(e, nil) 31 34 } 32 35 33 36 if repo.Repo.Did != req.Repo { 34 - s.logger.Warn("mismatched repo/auth") 37 + logger.Warn("mismatched repo/auth") 35 38 return helpers.InputError(e, nil) 36 39 } 37 40 ··· 40 43 optype = OpTypeUpdate 41 44 } 42 45 43 - results, err := s.repoman.applyWrites(repo.Repo, []Op{ 46 + results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{ 44 47 { 45 48 Type: optype, 46 49 Collection: req.Collection, ··· 51 54 }, 52 55 }, req.SwapCommit) 53 56 if err != nil { 54 - s.logger.Error("error applying writes", "error", err) 57 + logger.Error("error applying writes", "error", err) 55 58 return helpers.ServerError(e, nil) 56 59 } 57 60
+10 -7
server/handle_repo_delete_record.go
··· 6 6 "github.com/labstack/echo/v4" 7 7 ) 8 8 9 - type ComAtprotoRepoDeleteRecordRequest struct { 9 + type ComAtprotoRepoDeleteRecordInput struct { 10 10 Repo string `json:"repo" validate:"required,atproto-did"` 11 11 Collection string `json:"collection" validate:"required,atproto-nsid"` 12 12 Rkey string `json:"rkey" validate:"required,atproto-rkey"` ··· 15 15 } 16 16 17 17 func (s *Server) handleDeleteRecord(e echo.Context) error { 18 + ctx := e.Request().Context() 19 + logger := s.logger.With("name", "handleDeleteRecord") 20 + 18 21 repo := e.Get("repo").(*models.RepoActor) 19 22 20 - var req ComAtprotoRepoDeleteRecordRequest 23 + var req ComAtprotoRepoDeleteRecordInput 21 24 if err := e.Bind(&req); err != nil { 22 - s.logger.Error("error binding", "error", err) 25 + logger.Error("error binding", "error", err) 23 26 return helpers.ServerError(e, nil) 24 27 } 25 28 26 29 if err := e.Validate(req); err != nil { 27 - s.logger.Error("error validating", "error", err) 30 + logger.Error("error validating", "error", err) 28 31 return helpers.InputError(e, nil) 29 32 } 30 33 31 34 if repo.Repo.Did != req.Repo { 32 - s.logger.Warn("mismatched repo/auth") 35 + logger.Warn("mismatched repo/auth") 33 36 return helpers.InputError(e, nil) 34 37 } 35 38 36 - results, err := s.repoman.applyWrites(repo.Repo, []Op{ 39 + results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{ 37 40 { 38 41 Type: OpTypeDelete, 39 42 Collection: req.Collection, ··· 42 45 }, 43 46 }, req.SwapCommit) 44 47 if err != nil { 45 - s.logger.Error("error applying writes", "error", err) 48 + logger.Error("error applying writes", "error", err) 46 49 return helpers.ServerError(e, nil) 47 50 } 48 51
+8 -5
server/handle_repo_describe_repo.go
··· 20 20 } 21 21 22 22 func (s *Server) handleDescribeRepo(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + logger := s.logger.With("name", "handleDescribeRepo") 25 + 23 26 did := e.QueryParam("repo") 24 - repo, err := s.getRepoActorByDid(did) 27 + repo, err := s.getRepoActorByDid(ctx, did) 25 28 if err != nil { 26 29 if err == gorm.ErrRecordNotFound { 27 30 return helpers.InputError(e, to.StringPtr("RepoNotFound")) 28 31 } 29 32 30 - s.logger.Error("error looking up repo", "error", err) 33 + logger.Error("error looking up repo", "error", err) 31 34 return helpers.ServerError(e, nil) 32 35 } 33 36 ··· 35 38 36 39 diddoc, err := s.passport.FetchDoc(e.Request().Context(), repo.Repo.Did) 37 40 if err != nil { 38 - s.logger.Error("error fetching diddoc", "error", err) 41 + logger.Error("error fetching diddoc", "error", err) 39 42 return helpers.ServerError(e, nil) 40 43 } 41 44 ··· 64 67 } 65 68 66 69 var records []models.Record 67 - if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { 68 - s.logger.Error("error getting collections", "error", err) 70 + if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { 71 + logger.Error("error getting collections", "error", err) 69 72 return helpers.ServerError(e, nil) 70 73 } 71 74
+5 -3
server/handle_repo_get_record.go
··· 1 1 package server 2 2 3 3 import ( 4 - "github.com/bluesky-social/indigo/atproto/data" 4 + "github.com/bluesky-social/indigo/atproto/atdata" 5 5 "github.com/bluesky-social/indigo/atproto/syntax" 6 6 "github.com/haileyok/cocoon/models" 7 7 "github.com/labstack/echo/v4" ··· 14 14 } 15 15 16 16 func (s *Server) handleRepoGetRecord(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 17 19 repo := e.QueryParam("repo") 18 20 collection := e.QueryParam("collection") 19 21 rkey := e.QueryParam("rkey") ··· 32 34 } 33 35 34 36 var record models.Record 35 - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { 37 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { 36 38 // TODO: handle error nicely 37 39 return err 38 40 } 39 41 40 - val, err := data.UnmarshalCBOR(record.Value) 42 + val, err := atdata.UnmarshalCBOR(record.Value) 41 43 if err != nil { 42 44 return s.handleProxy(e) // TODO: this should be getting handled like...if we don't find it in the db. why doesn't it throw error up there? 43 45 }
+115
server/handle_repo_list_missing_blobs.go
··· 1 + package server 2 + 3 + import ( 4 + "fmt" 5 + "strconv" 6 + 7 + "github.com/bluesky-social/indigo/atproto/atdata" 8 + "github.com/haileyok/cocoon/internal/helpers" 9 + "github.com/haileyok/cocoon/models" 10 + "github.com/ipfs/go-cid" 11 + "github.com/labstack/echo/v4" 12 + ) 13 + 14 + type ComAtprotoRepoListMissingBlobsResponse struct { 15 + Cursor *string `json:"cursor,omitempty"` 16 + Blobs []ComAtprotoRepoListMissingBlobsRecordBlob `json:"blobs"` 17 + } 18 + 19 + type ComAtprotoRepoListMissingBlobsRecordBlob struct { 20 + Cid string `json:"cid"` 21 + RecordUri string `json:"recordUri"` 22 + } 23 + 24 + func (s *Server) handleListMissingBlobs(e echo.Context) error { 25 + ctx := e.Request().Context() 26 + logger := s.logger.With("name", "handleListMissingBlos") 27 + 28 + urepo := e.Get("repo").(*models.RepoActor) 29 + 30 + limitStr := e.QueryParam("limit") 31 + cursor := e.QueryParam("cursor") 32 + 33 + limit := 500 34 + if limitStr != "" { 35 + if l, err := strconv.Atoi(limitStr); err == nil && l > 0 && l <= 1000 { 36 + limit = l 37 + } 38 + } 39 + 40 + var records []models.Record 41 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { 42 + logger.Error("failed to get records for listMissingBlobs", "error", err) 43 + return helpers.ServerError(e, nil) 44 + } 45 + 46 + type blobRef struct { 47 + cid cid.Cid 48 + recordUri string 49 + } 50 + var allBlobRefs []blobRef 51 + 52 + for _, rec := range records { 53 + blobs := getBlobsFromRecord(rec.Value) 54 + recordUri := fmt.Sprintf("at://%s/%s/%s", urepo.Repo.Did, rec.Nsid, rec.Rkey) 55 + for _, b := range blobs { 56 + allBlobRefs = append(allBlobRefs, blobRef{cid: cid.Cid(b.Ref), recordUri: recordUri}) 57 + } 58 + } 59 + 60 + missingBlobs := make([]ComAtprotoRepoListMissingBlobsRecordBlob, 0) 61 + seenCids := make(map[string]bool) 62 + 63 + for _, ref := range allBlobRefs { 64 + cidStr := ref.cid.String() 65 + 66 + if seenCids[cidStr] { 67 + continue 68 + } 69 + 70 + if cursor != "" && cidStr <= cursor { 71 + continue 72 + } 73 + 74 + var count int64 75 + if err := s.db.Raw(ctx, "SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil { 76 + continue 77 + } 78 + 79 + if count == 0 { 80 + missingBlobs = append(missingBlobs, ComAtprotoRepoListMissingBlobsRecordBlob{ 81 + Cid: cidStr, 82 + RecordUri: ref.recordUri, 83 + }) 84 + seenCids[cidStr] = true 85 + 86 + if len(missingBlobs) >= limit { 87 + break 88 + } 89 + } 90 + } 91 + 92 + var nextCursor *string 93 + if len(missingBlobs) > 0 && len(missingBlobs) >= limit { 94 + lastCid := missingBlobs[len(missingBlobs)-1].Cid 95 + nextCursor = &lastCid 96 + } 97 + 98 + return e.JSON(200, ComAtprotoRepoListMissingBlobsResponse{ 99 + Cursor: nextCursor, 100 + Blobs: missingBlobs, 101 + }) 102 + } 103 + 104 + func getBlobsFromRecord(data []byte) []atdata.Blob { 105 + if len(data) == 0 { 106 + return nil 107 + } 108 + 109 + decoded, err := atdata.UnmarshalCBOR(data) 110 + if err != nil { 111 + return nil 112 + } 113 + 114 + return atdata.ExtractBlobs(decoded) 115 + }
+9 -6
server/handle_repo_list_records.go
··· 4 4 "strconv" 5 5 6 6 "github.com/Azure/go-autorest/autorest/to" 7 - "github.com/bluesky-social/indigo/atproto/data" 7 + "github.com/bluesky-social/indigo/atproto/atdata" 8 8 "github.com/bluesky-social/indigo/atproto/syntax" 9 9 "github.com/haileyok/cocoon/internal/helpers" 10 10 "github.com/haileyok/cocoon/models" ··· 46 46 } 47 47 48 48 func (s *Server) handleListRecords(e echo.Context) error { 49 + ctx := e.Request().Context() 50 + logger := s.logger.With("name", "handleListRecords") 51 + 49 52 var req ComAtprotoRepoListRecordsRequest 50 53 if err := e.Bind(&req); err != nil { 51 - s.logger.Error("could not bind list records request", "error", err) 54 + logger.Error("could not bind list records request", "error", err) 52 55 return helpers.ServerError(e, nil) 53 56 } 54 57 ··· 78 81 79 82 did := req.Repo 80 83 if _, err := syntax.ParseDID(did); err != nil { 81 - actor, err := s.getActorByHandle(req.Repo) 84 + actor, err := s.getActorByHandle(ctx, req.Repo) 82 85 if err != nil { 83 86 return helpers.InputError(e, to.StringPtr("RepoNotFound")) 84 87 } ··· 93 96 params = append(params, limit) 94 97 95 98 var records []models.Record 96 - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { 97 - s.logger.Error("error getting records", "error", err) 99 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { 100 + logger.Error("error getting records", "error", err) 98 101 return helpers.ServerError(e, nil) 99 102 } 100 103 101 104 items := []ComAtprotoRepoListRecordsRecordItem{} 102 105 for _, r := range records { 103 - val, err := data.UnmarshalCBOR(r.Value) 106 + val, err := atdata.UnmarshalCBOR(r.Value) 104 107 if err != nil { 105 108 return err 106 109 }
+5 -3
server/handle_repo_list_repos.go
··· 21 21 22 22 // TODO: paginate this bitch 23 23 func (s *Server) handleListRepos(e echo.Context) error { 24 + ctx := e.Request().Context() 25 + 24 26 var repos []models.Repo 25 - if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { 27 + if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { 26 28 return err 27 29 } 28 30 ··· 37 39 Did: r.Did, 38 40 Head: c.String(), 39 41 Rev: r.Rev, 40 - Active: true, 41 - Status: nil, 42 + Active: r.Active(), 43 + Status: r.Status(), 42 44 }) 43 45 } 44 46
+10 -7
server/handle_repo_put_record.go
··· 6 6 "github.com/labstack/echo/v4" 7 7 ) 8 8 9 - type ComAtprotoRepoPutRecordRequest struct { 9 + type ComAtprotoRepoPutRecordInput struct { 10 10 Repo string `json:"repo" validate:"required,atproto-did"` 11 11 Collection string `json:"collection" validate:"required,atproto-nsid"` 12 12 Rkey string `json:"rkey" validate:"required,atproto-rkey"` ··· 17 17 } 18 18 19 19 func (s *Server) handlePutRecord(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + logger := s.logger.With("name", "handlePutRecord") 22 + 20 23 repo := e.Get("repo").(*models.RepoActor) 21 24 22 - var req ComAtprotoRepoPutRecordRequest 25 + var req ComAtprotoRepoPutRecordInput 23 26 if err := e.Bind(&req); err != nil { 24 - s.logger.Error("error binding", "error", err) 27 + logger.Error("error binding", "error", err) 25 28 return helpers.ServerError(e, nil) 26 29 } 27 30 28 31 if err := e.Validate(req); err != nil { 29 - s.logger.Error("error validating", "error", err) 32 + logger.Error("error validating", "error", err) 30 33 return helpers.InputError(e, nil) 31 34 } 32 35 33 36 if repo.Repo.Did != req.Repo { 34 - s.logger.Warn("mismatched repo/auth") 37 + logger.Warn("mismatched repo/auth") 35 38 return helpers.InputError(e, nil) 36 39 } 37 40 ··· 40 43 optype = OpTypeUpdate 41 44 } 42 45 43 - results, err := s.repoman.applyWrites(repo.Repo, []Op{ 46 + results, err := s.repoman.applyWrites(ctx, repo.Repo, []Op{ 44 47 { 45 48 Type: optype, 46 49 Collection: req.Collection, ··· 51 54 }, 52 55 }, req.SwapCommit) 53 56 if err != nil { 54 - s.logger.Error("error applying writes", "error", err) 57 + logger.Error("error applying writes", "error", err) 55 58 return helpers.ServerError(e, nil) 56 59 } 57 60
+59 -14
server/handle_repo_upload_blob.go
··· 2 2 3 3 import ( 4 4 "bytes" 5 + "fmt" 5 6 "io" 6 7 8 + "github.com/aws/aws-sdk-go/aws" 9 + "github.com/aws/aws-sdk-go/aws/credentials" 10 + "github.com/aws/aws-sdk-go/aws/session" 11 + "github.com/aws/aws-sdk-go/service/s3" 7 12 "github.com/haileyok/cocoon/internal/helpers" 8 13 "github.com/haileyok/cocoon/models" 9 14 "github.com/ipfs/go-cid" ··· 27 32 } 28 33 29 34 func (s *Server) handleRepoUploadBlob(e echo.Context) error { 35 + ctx := e.Request().Context() 36 + logger := s.logger.With("name", "handleRepoUploadBlob") 37 + 30 38 urepo := e.Get("repo").(*models.RepoActor) 31 39 32 40 mime := e.Request().Header.Get("content-type") ··· 34 42 mime = "application/octet-stream" 35 43 } 36 44 45 + storage := "sqlite" 46 + s3Upload := s.s3Config != nil && s.s3Config.BlobstoreEnabled 47 + if s3Upload { 48 + storage = "s3" 49 + } 37 50 blob := models.Blob{ 38 51 Did: urepo.Repo.Did, 39 52 RefCount: 0, 40 53 CreatedAt: s.repoman.clock.Next().String(), 54 + Storage: storage, 41 55 } 42 56 43 - if err := s.db.Create(&blob, nil).Error; err != nil { 44 - s.logger.Error("error creating new blob in db", "error", err) 57 + if err := s.db.Create(ctx, &blob, nil).Error; err != nil { 58 + logger.Error("error creating new blob in db", "error", err) 45 59 return helpers.ServerError(e, nil) 46 60 } 47 61 ··· 58 72 break 59 73 } 60 74 } else if err != nil && err != io.ErrUnexpectedEOF { 61 - s.logger.Error("error reading blob", "error", err) 75 + logger.Error("error reading blob", "error", err) 62 76 return helpers.ServerError(e, nil) 63 77 } 64 78 ··· 66 80 read += n 67 81 fulldata.Write(data) 68 82 69 - blobPart := models.BlobPart{ 70 - BlobID: blob.ID, 71 - Idx: part, 72 - Data: data, 73 - } 83 + if !s3Upload { 84 + blobPart := models.BlobPart{ 85 + BlobID: blob.ID, 86 + Idx: part, 87 + Data: data, 88 + } 74 89 75 - if err := s.db.Create(&blobPart, nil).Error; err != nil { 76 - s.logger.Error("error adding blob part to db", "error", err) 77 - return helpers.ServerError(e, nil) 90 + if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil { 91 + logger.Error("error adding blob part to db", "error", err) 92 + return helpers.ServerError(e, nil) 93 + } 78 94 } 79 95 part++ 80 96 ··· 85 101 86 102 c, err := cid.NewPrefixV1(cid.Raw, multihash.SHA2_256).Sum(fulldata.Bytes()) 87 103 if err != nil { 88 - s.logger.Error("error creating cid prefix", "error", err) 104 + logger.Error("error creating cid prefix", "error", err) 89 105 return helpers.ServerError(e, nil) 90 106 } 91 107 92 - if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { 108 + if s3Upload { 109 + config := &aws.Config{ 110 + Region: aws.String(s.s3Config.Region), 111 + Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 112 + } 113 + 114 + if s.s3Config.Endpoint != "" { 115 + config.Endpoint = aws.String(s.s3Config.Endpoint) 116 + config.S3ForcePathStyle = aws.Bool(true) 117 + } 118 + 119 + sess, err := session.NewSession(config) 120 + if err != nil { 121 + logger.Error("error creating aws session", "error", err) 122 + return helpers.ServerError(e, nil) 123 + } 124 + 125 + svc := s3.New(sess) 126 + 127 + if _, err := svc.PutObject(&s3.PutObjectInput{ 128 + Bucket: aws.String(s.s3Config.Bucket), 129 + Key: aws.String(fmt.Sprintf("blobs/%s/%s", urepo.Repo.Did, c.String())), 130 + Body: bytes.NewReader(fulldata.Bytes()), 131 + }); err != nil { 132 + logger.Error("error uploading blob to s3", "error", err) 133 + return helpers.ServerError(e, nil) 134 + } 135 + } 136 + 137 + if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { 93 138 // there should probably be somme handling here if this fails... 94 - s.logger.Error("error updating blob", "error", err) 139 + logger.Error("error updating blob", "error", err) 95 140 return helpers.ServerError(e, nil) 96 141 } 97 142
+48
server/handle_server_activate_account.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "time" 6 + 7 + "github.com/bluesky-social/indigo/api/atproto" 8 + "github.com/bluesky-social/indigo/events" 9 + "github.com/bluesky-social/indigo/util" 10 + "github.com/haileyok/cocoon/internal/helpers" 11 + "github.com/haileyok/cocoon/models" 12 + "github.com/labstack/echo/v4" 13 + ) 14 + 15 + type ComAtprotoServerActivateAccountRequest struct { 16 + // NOTE: this implementation will not pay attention to this value 17 + DeleteAfter time.Time `json:"deleteAfter"` 18 + } 19 + 20 + func (s *Server) handleServerActivateAccount(e echo.Context) error { 21 + ctx := e.Request().Context() 22 + logger := s.logger.With("name", "handleServerActivateAccount") 23 + 24 + var req ComAtprotoServerDeactivateAccountRequest 25 + if err := e.Bind(&req); err != nil { 26 + logger.Error("error binding", "error", err) 27 + return helpers.ServerError(e, nil) 28 + } 29 + 30 + urepo := e.Get("repo").(*models.RepoActor) 31 + 32 + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { 33 + logger.Error("error updating account status to deactivated", "error", err) 34 + return helpers.ServerError(e, nil) 35 + } 36 + 37 + s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 38 + RepoAccount: &atproto.SyncSubscribeRepos_Account{ 39 + Active: true, 40 + Did: urepo.Repo.Did, 41 + Status: nil, 42 + Seq: time.Now().UnixMicro(), // TODO: bad puppy 43 + Time: time.Now().Format(util.ISO8601), 44 + }, 45 + }) 46 + 47 + return e.NoContent(200) 48 + }
+10 -7
server/handle_server_check_account_status.go
··· 20 20 } 21 21 22 22 func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + logger := s.logger.With("name", "handleServerCheckAccountStatus") 25 + 23 26 urepo := e.Get("repo").(*models.RepoActor) 24 27 25 28 resp := ComAtprotoServerCheckAccountStatusResponse{ ··· 31 34 32 35 rootcid, err := cid.Cast(urepo.Root) 33 36 if err != nil { 34 - s.logger.Error("error casting cid", "error", err) 37 + logger.Error("error casting cid", "error", err) 35 38 return helpers.ServerError(e, nil) 36 39 } 37 40 resp.RepoCommit = rootcid.String() ··· 41 44 } 42 45 43 46 var blockCtResp CountResp 44 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 45 - s.logger.Error("error getting block count", "error", err) 47 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 48 + logger.Error("error getting block count", "error", err) 46 49 return helpers.ServerError(e, nil) 47 50 } 48 51 resp.RepoBlocks = blockCtResp.Ct 49 52 50 53 var recCtResp CountResp 51 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 52 - s.logger.Error("error getting record count", "error", err) 54 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 55 + logger.Error("error getting record count", "error", err) 53 56 return helpers.ServerError(e, nil) 54 57 } 55 58 resp.IndexedRecords = recCtResp.Ct 56 59 57 60 var blobCtResp CountResp 58 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 59 - s.logger.Error("error getting record count", "error", err) 61 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 62 + logger.Error("error getting record count", "error", err) 60 63 return helpers.ServerError(e, nil) 61 64 } 62 65 resp.ExpectedBlobs = blobCtResp.Ct
+8 -5
server/handle_server_confirm_email.go
··· 15 15 } 16 16 17 17 func (s *Server) handleServerConfirmEmail(e echo.Context) error { 18 + ctx := e.Request().Context() 19 + logger := s.logger.With("name", "handleServerConfirmEmail") 20 + 18 21 urepo := e.Get("repo").(*models.RepoActor) 19 22 20 23 var req ComAtprotoServerConfirmEmailRequest 21 24 if err := e.Bind(&req); err != nil { 22 - s.logger.Error("error binding", "error", err) 25 + logger.Error("error binding", "error", err) 23 26 return helpers.ServerError(e, nil) 24 27 } 25 28 ··· 28 31 } 29 32 30 33 if urepo.EmailVerificationCode == nil || urepo.EmailVerificationCodeExpiresAt == nil { 31 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 34 + return helpers.ExpiredTokenError(e) 32 35 } 33 36 34 37 if *urepo.EmailVerificationCode != req.Token { ··· 36 39 } 37 40 38 41 if time.Now().UTC().After(*urepo.EmailVerificationCodeExpiresAt) { 39 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 42 + return helpers.ExpiredTokenError(e) 40 43 } 41 44 42 45 now := time.Now().UTC() 43 46 44 - if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { 45 - s.logger.Error("error updating user", "error", err) 47 + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { 48 + logger.Error("error updating user", "error", err) 46 49 return helpers.ServerError(e, nil) 47 50 } 48 51
+110 -75
server/handle_server_create_account.go
··· 9 9 10 10 "github.com/Azure/go-autorest/autorest/to" 11 11 "github.com/bluesky-social/indigo/api/atproto" 12 - "github.com/bluesky-social/indigo/atproto/crypto" 13 - "github.com/bluesky-social/indigo/atproto/syntax" 12 + "github.com/bluesky-social/indigo/atproto/atcrypto" 14 13 "github.com/bluesky-social/indigo/events" 15 14 "github.com/bluesky-social/indigo/repo" 16 15 "github.com/bluesky-social/indigo/util" 17 - "github.com/haileyok/cocoon/blockstore" 18 16 "github.com/haileyok/cocoon/internal/helpers" 19 17 "github.com/haileyok/cocoon/models" 20 18 "github.com/labstack/echo/v4" ··· 27 25 Handle string `json:"handle" validate:"required,atproto-handle"` 28 26 Did *string `json:"did" validate:"atproto-did"` 29 27 Password string `json:"password" validate:"required"` 30 - InviteCode string `json:"inviteCode" validate:"required"` 28 + InviteCode string `json:"inviteCode" validate:"omitempty"` 31 29 } 32 30 33 31 type ComAtprotoServerCreateAccountResponse struct { ··· 38 36 } 39 37 40 38 func (s *Server) handleCreateAccount(e echo.Context) error { 41 - var request ComAtprotoServerCreateAccountRequest 42 - 43 - var signupDid string 44 - customDidHeader := e.Request().Header.Get("authorization") 45 - if customDidHeader != "" { 46 - pts := strings.Split(customDidHeader, " ") 47 - if len(pts) != 2 { 48 - return helpers.InputError(e, to.StringPtr("InvalidDid")) 49 - } 39 + ctx := e.Request().Context() 40 + logger := s.logger.With("name", "handleServerCreateAccount") 50 41 51 - _, err := syntax.ParseDID(pts[1]) 52 - if err != nil { 53 - return helpers.InputError(e, to.StringPtr("InvalidDid")) 54 - } 55 - 56 - signupDid = pts[1] 57 - } 42 + var request ComAtprotoServerCreateAccountRequest 58 43 59 44 if err := e.Bind(&request); err != nil { 60 - s.logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err) 45 + logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err) 61 46 return helpers.ServerError(e, nil) 62 47 } 63 48 64 49 request.Handle = strings.ToLower(request.Handle) 65 50 66 51 if err := e.Validate(request); err != nil { 67 - s.logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err) 52 + logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err) 68 53 69 54 var verr ValidationError 70 55 if errors.As(err, &verr) { ··· 87 72 } 88 73 } 89 74 75 + var signupDid string 76 + if request.Did != nil { 77 + signupDid = *request.Did 78 + 79 + token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) 80 + if token == "" { 81 + return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) 82 + } 83 + authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount") 84 + 85 + if err != nil { 86 + logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err) 87 + return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token")) 88 + } 89 + 90 + if authDid != signupDid { 91 + return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did")) 92 + } 93 + } 94 + 90 95 // see if the handle is already taken 91 - _, err := s.getActorByHandle(request.Handle) 96 + actor, err := s.getActorByHandle(ctx, request.Handle) 92 97 if err != nil && err != gorm.ErrRecordNotFound { 93 - s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 98 + logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 94 99 return helpers.ServerError(e, nil) 95 100 } 96 - if err == nil { 101 + if err == nil && actor.Did != signupDid { 97 102 return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 98 103 } 99 104 100 - if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != "" { 105 + if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid { 101 106 return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 102 107 } 103 108 104 109 var ic models.InviteCode 105 - if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 106 - if err == gorm.ErrRecordNotFound { 110 + if s.config.RequireInvite { 111 + if strings.TrimSpace(request.InviteCode) == "" { 107 112 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 108 113 } 109 - s.logger.Error("error getting invite code from db", "error", err) 110 - return helpers.ServerError(e, nil) 111 - } 112 114 113 - if ic.RemainingUseCount < 1 { 114 - return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 115 + if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 116 + if err == gorm.ErrRecordNotFound { 117 + return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 118 + } 119 + logger.Error("error getting invite code from db", "error", err) 120 + return helpers.ServerError(e, nil) 121 + } 122 + 123 + if ic.RemainingUseCount < 1 { 124 + return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 125 + } 115 126 } 116 127 117 128 // see if the email is already taken 118 - _, err = s.getRepoByEmail(request.Email) 129 + existingRepo, err := s.getRepoByEmail(ctx, request.Email) 119 130 if err != nil && err != gorm.ErrRecordNotFound { 120 - s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 131 + logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 121 132 return helpers.ServerError(e, nil) 122 133 } 123 - if err == nil { 134 + if err == nil && existingRepo.Did != signupDid { 124 135 return helpers.InputError(e, to.StringPtr("EmailNotAvailable")) 125 136 } 126 137 127 138 // TODO: unsupported domains 128 139 129 - k, err := crypto.GeneratePrivateKeyK256() 130 - if err != nil { 131 - s.logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err) 132 - return helpers.ServerError(e, nil) 140 + var k *atcrypto.PrivateKeyK256 141 + 142 + if signupDid != "" { 143 + reservedKey, err := s.getReservedKey(ctx, signupDid) 144 + if err != nil { 145 + logger.Error("error looking up reserved key", "error", err) 146 + } 147 + if reservedKey != nil { 148 + k, err = atcrypto.ParsePrivateBytesK256(reservedKey.PrivateKey) 149 + if err != nil { 150 + logger.Error("error parsing reserved key", "error", err) 151 + k = nil 152 + } else { 153 + defer func() { 154 + if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil { 155 + logger.Error("error deleting reserved key", "error", delErr) 156 + } 157 + }() 158 + } 159 + } 160 + } 161 + 162 + if k == nil { 163 + k, err = atcrypto.GeneratePrivateKeyK256() 164 + if err != nil { 165 + logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err) 166 + return helpers.ServerError(e, nil) 167 + } 133 168 } 134 169 135 170 if signupDid == "" { 136 171 did, op, err := s.plcClient.CreateDID(k, "", request.Handle) 137 172 if err != nil { 138 - s.logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err) 173 + logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err) 139 174 return helpers.ServerError(e, nil) 140 175 } 141 176 142 177 if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil { 143 - s.logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err) 178 + logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err) 144 179 return helpers.ServerError(e, nil) 145 180 } 146 181 signupDid = did ··· 148 183 149 184 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10) 150 185 if err != nil { 151 - s.logger.Error("error hashing password", "error", err) 186 + logger.Error("error hashing password", "error", err) 152 187 return helpers.ServerError(e, nil) 153 188 } 154 189 ··· 161 196 SigningKey: k.Bytes(), 162 197 } 163 198 164 - actor := models.Actor{ 165 - Did: signupDid, 166 - Handle: request.Handle, 167 - } 199 + if actor == nil { 200 + actor = &models.Actor{ 201 + Did: signupDid, 202 + Handle: request.Handle, 203 + } 168 204 169 - if err := s.db.Create(&urepo, nil).Error; err != nil { 170 - s.logger.Error("error inserting new repo", "error", err) 171 - return helpers.ServerError(e, nil) 172 - } 205 + if err := s.db.Create(ctx, &urepo, nil).Error; err != nil { 206 + logger.Error("error inserting new repo", "error", err) 207 + return helpers.ServerError(e, nil) 208 + } 173 209 174 - if err := s.db.Create(&actor, nil).Error; err != nil { 175 - s.logger.Error("error inserting new actor", "error", err) 176 - return helpers.ServerError(e, nil) 210 + if err := s.db.Create(ctx, &actor, nil).Error; err != nil { 211 + logger.Error("error inserting new actor", "error", err) 212 + return helpers.ServerError(e, nil) 213 + } 214 + } else { 215 + if err := s.db.Save(ctx, &actor, nil).Error; err != nil { 216 + logger.Error("error inserting new actor", "error", err) 217 + return helpers.ServerError(e, nil) 218 + } 177 219 } 178 220 179 - if customDidHeader == "" { 180 - bs := blockstore.New(signupDid, s.db) 221 + if request.Did == nil || *request.Did == "" { 222 + bs := s.getBlockstore(signupDid) 181 223 r := repo.NewRepo(context.TODO(), signupDid, bs) 182 224 183 225 root, rev, err := r.Commit(context.TODO(), urepo.SignFor) 184 226 if err != nil { 185 - s.logger.Error("error committing", "error", err) 227 + logger.Error("error committing", "error", err) 186 228 return helpers.ServerError(e, nil) 187 229 } 188 230 189 - if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil { 190 - s.logger.Error("error updating repo after commit", "error", err) 231 + if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil { 232 + logger.Error("error updating repo after commit", "error", err) 191 233 return helpers.ServerError(e, nil) 192 234 } 193 235 194 236 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 195 - RepoHandle: &atproto.SyncSubscribeRepos_Handle{ 196 - Did: urepo.Did, 197 - Handle: request.Handle, 198 - Seq: time.Now().UnixMicro(), // TODO: no 199 - Time: time.Now().Format(util.ISO8601), 200 - }, 201 - }) 202 - 203 - s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 204 237 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{ 205 238 Did: urepo.Did, 206 239 Handle: to.StringPtr(request.Handle), ··· 210 243 }) 211 244 } 212 245 213 - if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 214 - s.logger.Error("error decrementing use count", "error", err) 215 - return helpers.ServerError(e, nil) 246 + if s.config.RequireInvite { 247 + if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 248 + logger.Error("error decrementing use count", "error", err) 249 + return helpers.ServerError(e, nil) 250 + } 216 251 } 217 252 218 - sess, err := s.createSession(&urepo) 253 + sess, err := s.createSession(ctx, &urepo) 219 254 if err != nil { 220 - s.logger.Error("error creating new session", "error", err) 255 + logger.Error("error creating new session", "error", err) 221 256 return helpers.ServerError(e, nil) 222 257 } 223 258 224 259 go func() { 225 260 if err := s.sendEmailVerification(urepo.Email, actor.Handle, *urepo.EmailVerificationCode); err != nil { 226 - s.logger.Error("error sending email verification email", "error", err) 261 + logger.Error("error sending email verification email", "error", err) 227 262 } 228 263 if err := s.sendWelcomeMail(urepo.Email, actor.Handle); err != nil { 229 - s.logger.Error("error sending welcome email", "error", err) 264 + logger.Error("error sending welcome email", "error", err) 230 265 } 231 266 }() 232 267
+7 -4
server/handle_server_create_invite_code.go
··· 17 17 } 18 18 19 19 func (s *Server) handleCreateInviteCode(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + logger := s.logger.With("name", "handleServerCreateInviteCode") 22 + 20 23 var req ComAtprotoServerCreateInviteCodeRequest 21 24 if err := e.Bind(&req); err != nil { 22 - s.logger.Error("error binding", "error", err) 25 + logger.Error("error binding", "error", err) 23 26 return helpers.ServerError(e, nil) 24 27 } 25 28 26 29 if err := e.Validate(req); err != nil { 27 - s.logger.Error("error validating", "error", err) 30 + logger.Error("error validating", "error", err) 28 31 return helpers.InputError(e, nil) 29 32 } 30 33 ··· 37 40 acc = *req.ForAccount 38 41 } 39 42 40 - if err := s.db.Create(&models.InviteCode{ 43 + if err := s.db.Create(ctx, &models.InviteCode{ 41 44 Code: ic, 42 45 Did: acc, 43 46 RemainingUseCount: req.UseCount, 44 47 }, nil).Error; err != nil { 45 - s.logger.Error("error creating invite code", "error", err) 48 + logger.Error("error creating invite code", "error", err) 46 49 return helpers.ServerError(e, nil) 47 50 } 48 51
+7 -4
server/handle_server_create_invite_codes.go
··· 22 22 } 23 23 24 24 func (s *Server) handleCreateInviteCodes(e echo.Context) error { 25 + ctx := e.Request().Context() 26 + logger := s.logger.With("name", "handleServerCreateInviteCodes") 27 + 25 28 var req ComAtprotoServerCreateInviteCodesRequest 26 29 if err := e.Bind(&req); err != nil { 27 - s.logger.Error("error binding", "error", err) 30 + logger.Error("error binding", "error", err) 28 31 return helpers.ServerError(e, nil) 29 32 } 30 33 31 34 if err := e.Validate(req); err != nil { 32 - s.logger.Error("error validating", "error", err) 35 + logger.Error("error validating", "error", err) 33 36 return helpers.InputError(e, nil) 34 37 } 35 38 ··· 50 53 ic := uuid.NewString() 51 54 ics = append(ics, ic) 52 55 53 - if err := s.db.Create(&models.InviteCode{ 56 + if err := s.db.Create(ctx, &models.InviteCode{ 54 57 Code: ic, 55 58 Did: did, 56 59 RemainingUseCount: req.UseCount, 57 60 }, nil).Error; err != nil { 58 - s.logger.Error("error creating invite code", "error", err) 61 + logger.Error("error creating invite code", "error", err) 59 62 return helpers.ServerError(e, nil) 60 63 } 61 64 }
+67 -11
server/handle_server_create_session.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 4 5 "errors" 6 + "fmt" 5 7 "strings" 8 + "time" 6 9 7 10 "github.com/Azure/go-autorest/autorest/to" 8 11 "github.com/bluesky-social/indigo/atproto/syntax" ··· 32 35 } 33 36 34 37 func (s *Server) handleCreateSession(e echo.Context) error { 38 + ctx := e.Request().Context() 39 + logger := s.logger.With("name", "handleServerCreateSession") 40 + 35 41 var req ComAtprotoServerCreateSessionRequest 36 42 if err := e.Bind(&req); err != nil { 37 - s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) 43 + logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) 38 44 return helpers.ServerError(e, nil) 39 45 } 40 46 ··· 65 71 var err error 66 72 switch idtype { 67 73 case "did": 68 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error 74 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error 69 75 case "handle": 70 - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error 76 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error 71 77 case "email": 72 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error 78 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error 73 79 } 74 80 75 81 if err != nil { ··· 77 83 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 78 84 } 79 85 80 - s.logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err) 86 + logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err) 81 87 return helpers.ServerError(e, nil) 82 88 } 83 89 84 90 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil { 85 91 if err != bcrypt.ErrMismatchedHashAndPassword { 86 - s.logger.Error("erorr comparing hash and password", "error", err) 92 + logger.Error("erorr comparing hash and password", "error", err) 87 93 } 88 94 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 89 95 } 90 96 91 - sess, err := s.createSession(&repo.Repo) 97 + // if repo requires 2FA token and one hasn't been provided, return error prompting for one 98 + if repo.TwoFactorType != models.TwoFactorTypeNone && (req.AuthFactorToken == nil || *req.AuthFactorToken == "") { 99 + err = s.createAndSendTwoFactorCode(ctx, repo) 100 + if err != nil { 101 + logger.Error("sending 2FA code", "error", err) 102 + return helpers.ServerError(e, nil) 103 + } 104 + 105 + return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired")) 106 + } 107 + 108 + // if 2FA is required, now check that the one provided is valid 109 + if repo.TwoFactorType != models.TwoFactorTypeNone { 110 + if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil { 111 + err = s.createAndSendTwoFactorCode(ctx, repo) 112 + if err != nil { 113 + logger.Error("sending 2FA code", "error", err) 114 + return helpers.ServerError(e, nil) 115 + } 116 + 117 + return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired")) 118 + } 119 + 120 + if *repo.TwoFactorCode != *req.AuthFactorToken { 121 + return helpers.InvalidTokenError(e) 122 + } 123 + 124 + if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) { 125 + return helpers.ExpiredTokenError(e) 126 + } 127 + } 128 + 129 + sess, err := s.createSession(ctx, &repo.Repo) 92 130 if err != nil { 93 - s.logger.Error("error creating session", "error", err) 131 + logger.Error("error creating session", "error", err) 94 132 return helpers.ServerError(e, nil) 95 133 } 96 134 ··· 101 139 Did: repo.Repo.Did, 102 140 Email: repo.Email, 103 141 EmailConfirmed: repo.EmailConfirmedAt != nil, 104 - EmailAuthFactor: false, 105 - Active: true, // TODO: eventually do takedowns 106 - Status: nil, // TODO eventually do takedowns 142 + EmailAuthFactor: repo.TwoFactorType != models.TwoFactorTypeNone, 143 + Active: repo.Active(), 144 + Status: repo.Status(), 107 145 }) 108 146 } 147 + 148 + func (s *Server) createAndSendTwoFactorCode(ctx context.Context, repo models.RepoActor) error { 149 + // TODO: when implementing a new type of 2FA there should be some logic in here to send the 150 + // right type of code 151 + 152 + code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 153 + eat := time.Now().Add(10 * time.Minute).UTC() 154 + 155 + if err := s.db.Exec(ctx, "UPDATE repos SET two_factor_code = ?, two_factor_code_expires_at = ? WHERE did = ?", nil, code, eat, repo.Repo.Did).Error; err != nil { 156 + return fmt.Errorf("updating repo: %w", err) 157 + } 158 + 159 + if err := s.sendTwoFactorCode(repo.Email, repo.Handle, code); err != nil { 160 + return fmt.Errorf("sending email: %w", err) 161 + } 162 + 163 + return nil 164 + }
+49
server/handle_server_deactivate_account.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "time" 6 + 7 + "github.com/Azure/go-autorest/autorest/to" 8 + "github.com/bluesky-social/indigo/api/atproto" 9 + "github.com/bluesky-social/indigo/events" 10 + "github.com/bluesky-social/indigo/util" 11 + "github.com/haileyok/cocoon/internal/helpers" 12 + "github.com/haileyok/cocoon/models" 13 + "github.com/labstack/echo/v4" 14 + ) 15 + 16 + type ComAtprotoServerDeactivateAccountRequest struct { 17 + // NOTE: this implementation will not pay attention to this value 18 + DeleteAfter time.Time `json:"deleteAfter"` 19 + } 20 + 21 + func (s *Server) handleServerDeactivateAccount(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + logger := s.logger.With("name", "handleServerDeactivateAccount") 24 + 25 + var req ComAtprotoServerDeactivateAccountRequest 26 + if err := e.Bind(&req); err != nil { 27 + logger.Error("error binding", "error", err) 28 + return helpers.ServerError(e, nil) 29 + } 30 + 31 + urepo := e.Get("repo").(*models.RepoActor) 32 + 33 + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { 34 + logger.Error("error updating account status to deactivated", "error", err) 35 + return helpers.ServerError(e, nil) 36 + } 37 + 38 + s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 39 + RepoAccount: &atproto.SyncSubscribeRepos_Account{ 40 + Active: false, 41 + Did: urepo.Repo.Did, 42 + Status: to.StringPtr("deactivated"), 43 + Seq: time.Now().UnixMicro(), // TODO: bad puppy 44 + Time: time.Now().Format(util.ISO8601), 45 + }, 46 + }) 47 + 48 + return e.NoContent(200) 49 + }
+150
server/handle_server_delete_account.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "time" 6 + 7 + "github.com/Azure/go-autorest/autorest/to" 8 + "github.com/bluesky-social/indigo/api/atproto" 9 + "github.com/bluesky-social/indigo/events" 10 + "github.com/bluesky-social/indigo/util" 11 + "github.com/haileyok/cocoon/internal/helpers" 12 + "github.com/labstack/echo/v4" 13 + "golang.org/x/crypto/bcrypt" 14 + ) 15 + 16 + type ComAtprotoServerDeleteAccountRequest struct { 17 + Did string `json:"did" validate:"required"` 18 + Password string `json:"password" validate:"required"` 19 + Token string `json:"token" validate:"required"` 20 + } 21 + 22 + func (s *Server) handleServerDeleteAccount(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + logger := s.logger.With("name", "handleServerDeleteAccount") 25 + 26 + var req ComAtprotoServerDeleteAccountRequest 27 + if err := e.Bind(&req); err != nil { 28 + logger.Error("error binding", "error", err) 29 + return helpers.ServerError(e, nil) 30 + } 31 + 32 + if err := e.Validate(&req); err != nil { 33 + logger.Error("error validating", "error", err) 34 + return helpers.ServerError(e, nil) 35 + } 36 + 37 + urepo, err := s.getRepoActorByDid(ctx, req.Did) 38 + if err != nil { 39 + logger.Error("error getting repo", "error", err) 40 + return echo.NewHTTPError(400, "account not found") 41 + } 42 + 43 + if err := bcrypt.CompareHashAndPassword([]byte(urepo.Repo.Password), []byte(req.Password)); err != nil { 44 + logger.Error("password mismatch", "error", err) 45 + return echo.NewHTTPError(401, "Invalid did or password") 46 + } 47 + 48 + if urepo.Repo.AccountDeleteCode == nil || urepo.Repo.AccountDeleteCodeExpiresAt == nil { 49 + logger.Error("no deletion token found for account") 50 + return echo.NewHTTPError(400, map[string]interface{}{ 51 + "error": "InvalidToken", 52 + "message": "Token is invalid", 53 + }) 54 + } 55 + 56 + if *urepo.Repo.AccountDeleteCode != req.Token { 57 + logger.Error("deletion token mismatch") 58 + return echo.NewHTTPError(400, map[string]interface{}{ 59 + "error": "InvalidToken", 60 + "message": "Token is invalid", 61 + }) 62 + } 63 + 64 + if time.Now().UTC().After(*urepo.Repo.AccountDeleteCodeExpiresAt) { 65 + logger.Error("deletion token expired") 66 + return echo.NewHTTPError(400, map[string]interface{}{ 67 + "error": "ExpiredToken", 68 + "message": "Token is expired", 69 + }) 70 + } 71 + 72 + tx := s.db.BeginDangerously(ctx) 73 + if tx.Error != nil { 74 + logger.Error("error starting transaction", "error", tx.Error) 75 + return helpers.ServerError(e, nil) 76 + } 77 + 78 + status := "error" 79 + func() { 80 + if status == "error" { 81 + if err := tx.Rollback().Error; err != nil { 82 + logger.Error("error rolling back after delete failure", "err", err) 83 + } 84 + } 85 + }() 86 + 87 + if err := tx.Exec("DELETE FROM blocks WHERE did = ?", nil, req.Did).Error; err != nil { 88 + logger.Error("error deleting blocks", "error", err) 89 + return helpers.ServerError(e, nil) 90 + } 91 + 92 + if err := tx.Exec("DELETE FROM records WHERE did = ?", nil, req.Did).Error; err != nil { 93 + logger.Error("error deleting records", "error", err) 94 + return helpers.ServerError(e, nil) 95 + } 96 + 97 + if err := tx.Exec("DELETE FROM blobs WHERE did = ?", nil, req.Did).Error; err != nil { 98 + logger.Error("error deleting blobs", "error", err) 99 + return helpers.ServerError(e, nil) 100 + } 101 + 102 + if err := tx.Exec("DELETE FROM tokens WHERE did = ?", nil, req.Did).Error; err != nil { 103 + logger.Error("error deleting tokens", "error", err) 104 + return helpers.ServerError(e, nil) 105 + } 106 + 107 + if err := tx.Exec("DELETE FROM refresh_tokens WHERE did = ?", nil, req.Did).Error; err != nil { 108 + logger.Error("error deleting refresh tokens", "error", err) 109 + return helpers.ServerError(e, nil) 110 + } 111 + 112 + if err := tx.Exec("DELETE FROM reserved_keys WHERE did = ?", nil, req.Did).Error; err != nil { 113 + logger.Error("error deleting reserved keys", "error", err) 114 + return helpers.ServerError(e, nil) 115 + } 116 + 117 + if err := tx.Exec("DELETE FROM invite_codes WHERE did = ?", nil, req.Did).Error; err != nil { 118 + logger.Error("error deleting invite codes", "error", err) 119 + return helpers.ServerError(e, nil) 120 + } 121 + 122 + if err := tx.Exec("DELETE FROM actors WHERE did = ?", nil, req.Did).Error; err != nil { 123 + logger.Error("error deleting actor", "error", err) 124 + return helpers.ServerError(e, nil) 125 + } 126 + 127 + if err := tx.Exec("DELETE FROM repos WHERE did = ?", nil, req.Did).Error; err != nil { 128 + logger.Error("error deleting repo", "error", err) 129 + return helpers.ServerError(e, nil) 130 + } 131 + 132 + status = "ok" 133 + 134 + if err := tx.Commit().Error; err != nil { 135 + logger.Error("error committing transaction", "error", err) 136 + return helpers.ServerError(e, nil) 137 + } 138 + 139 + s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 140 + RepoAccount: &atproto.SyncSubscribeRepos_Account{ 141 + Active: false, 142 + Did: req.Did, 143 + Status: to.StringPtr("deleted"), 144 + Seq: time.Now().UnixMicro(), 145 + Time: time.Now().Format(util.ISO8601), 146 + }, 147 + }) 148 + 149 + return e.NoContent(200) 150 + }
+4 -2
server/handle_server_delete_session.go
··· 7 7 ) 8 8 9 9 func (s *Server) handleDeleteSession(e echo.Context) error { 10 + ctx := e.Request().Context() 11 + 10 12 token := e.Get("token").(string) 11 13 12 14 var acctok models.Token 13 - if err := s.db.Raw("DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { 15 + if err := s.db.Raw(ctx, "DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { 14 16 s.logger.Error("error deleting access token from db", "error", err) 15 17 return helpers.ServerError(e, nil) 16 18 } 17 19 18 - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { 20 + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { 19 21 s.logger.Error("error deleting refresh token from db", "error", err) 20 22 return helpers.ServerError(e, nil) 21 23 }
+1 -1
server/handle_server_describe_server.go
··· 22 22 23 23 func (s *Server) handleDescribeServer(e echo.Context) error { 24 24 return e.JSON(200, ComAtprotoServerDescribeServerResponse{ 25 - InviteCodeRequired: true, 25 + InviteCodeRequired: s.config.RequireInvite, 26 26 PhoneVerificationRequired: false, 27 27 AvailableUserDomains: []string{"." + s.config.Hostname}, // TODO: more 28 28 Links: ComAtprotoServerDescribeServerResponseLinks{
+24 -13
server/handle_server_get_service_auth.go
··· 19 19 20 20 type ServerGetServiceAuthRequest struct { 21 21 Aud string `query:"aud" validate:"required,atproto-did"` 22 - Exp int64 `query:"exp"` 23 - Lxm string `query:"lxm" validate:"required,atproto-nsid"` 22 + // exp should be a float, as some clients will send a non-integer expiration 23 + Exp float64 `query:"exp"` 24 + Lxm string `query:"lxm"` 24 25 } 25 26 26 27 func (s *Server) handleServerGetServiceAuth(e echo.Context) error { 28 + logger := s.logger.With("name", "handleServerGetServiceAuth") 29 + 27 30 var req ServerGetServiceAuthRequest 28 31 if err := e.Bind(&req); err != nil { 29 - s.logger.Error("could not bind service auth request", "error", err) 32 + logger.Error("could not bind service auth request", "error", err) 30 33 return helpers.ServerError(e, nil) 31 34 } 32 35 ··· 34 37 return helpers.InputError(e, nil) 35 38 } 36 39 40 + exp := int64(req.Exp) 37 41 now := time.Now().Unix() 38 - if req.Exp == 0 { 39 - req.Exp = now + 60 // default 42 + if exp == 0 { 43 + exp = now + 60 // default 40 44 } 41 45 42 46 if req.Lxm == "com.atproto.server.getServiceAuth" { 43 47 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively")) 44 48 } 45 49 46 - maxExp := now + (60 * 30) 47 - if req.Exp > maxExp { 50 + var maxExp int64 51 + if req.Lxm != "" { 52 + maxExp = now + (60 * 60) 53 + } else { 54 + maxExp = now + 60 55 + } 56 + if exp > maxExp { 48 57 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 49 58 } 50 59 ··· 57 66 } 58 67 hj, err := json.Marshal(header) 59 68 if err != nil { 60 - s.logger.Error("error marshaling header", "error", err) 69 + logger.Error("error marshaling header", "error", err) 61 70 return helpers.ServerError(e, nil) 62 71 } 63 72 ··· 66 75 payload := map[string]any{ 67 76 "iss": repo.Repo.Did, 68 77 "aud": req.Aud, 69 - "lxm": req.Lxm, 70 78 "jti": uuid.NewString(), 71 - "exp": req.Exp, 79 + "exp": exp, 72 80 "iat": now, 73 81 } 82 + if req.Lxm != "" { 83 + payload["lxm"] = req.Lxm 84 + } 74 85 pj, err := json.Marshal(payload) 75 86 if err != nil { 76 - s.logger.Error("error marashaling payload", "error", err) 87 + logger.Error("error marashaling payload", "error", err) 77 88 return helpers.ServerError(e, nil) 78 89 } 79 90 ··· 84 95 85 96 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 86 97 if err != nil { 87 - s.logger.Error("can't load private key", "error", err) 98 + logger.Error("can't load private key", "error", err) 88 99 return err 89 100 } 90 101 91 102 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 92 103 if err != nil { 93 - s.logger.Error("error signing", "error", err) 104 + logger.Error("error signing", "error", err) 94 105 return helpers.ServerError(e, nil) 95 106 } 96 107
+3 -3
server/handle_server_get_session.go
··· 23 23 Did: repo.Repo.Did, 24 24 Email: repo.Email, 25 25 EmailConfirmed: repo.EmailConfirmedAt != nil, 26 - EmailAuthFactor: false, // TODO: todo todo 27 - Active: true, 28 - Status: nil, 26 + EmailAuthFactor: repo.TwoFactorType != models.TwoFactorTypeNone, 27 + Active: repo.Active(), 28 + Status: repo.Status(), 29 29 }) 30 30 }
+11 -8
server/handle_server_refresh_session.go
··· 16 16 } 17 17 18 18 func (s *Server) handleRefreshSession(e echo.Context) error { 19 + ctx := e.Request().Context() 20 + logger := s.logger.With("name", "handleServerRefreshSession") 21 + 19 22 token := e.Get("token").(string) 20 23 repo := e.Get("repo").(*models.RepoActor) 21 24 22 - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { 23 - s.logger.Error("error getting refresh token from db", "error", err) 25 + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { 26 + logger.Error("error getting refresh token from db", "error", err) 24 27 return helpers.ServerError(e, nil) 25 28 } 26 29 27 - if err := s.db.Exec("DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { 28 - s.logger.Error("error deleting access token from db", "error", err) 30 + if err := s.db.Exec(ctx, "DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { 31 + logger.Error("error deleting access token from db", "error", err) 29 32 return helpers.ServerError(e, nil) 30 33 } 31 34 32 - sess, err := s.createSession(&repo.Repo) 35 + sess, err := s.createSession(ctx, &repo.Repo) 33 36 if err != nil { 34 - s.logger.Error("error creating new session for refresh", "error", err) 37 + logger.Error("error creating new session for refresh", "error", err) 35 38 return helpers.ServerError(e, nil) 36 39 } 37 40 ··· 40 43 RefreshJwt: sess.RefreshToken, 41 44 Handle: repo.Handle, 42 45 Did: repo.Repo.Did, 43 - Active: true, 44 - Status: nil, 46 + Active: repo.Active(), 47 + Status: repo.Status(), 45 48 }) 46 49 }
+52
server/handle_server_request_account_delete.go
··· 1 + package server 2 + 3 + import ( 4 + "fmt" 5 + "time" 6 + 7 + "github.com/haileyok/cocoon/internal/helpers" 8 + "github.com/haileyok/cocoon/models" 9 + "github.com/labstack/echo/v4" 10 + ) 11 + 12 + func (s *Server) handleServerRequestAccountDelete(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + logger := s.logger.With("name", "handleServerRequestAccountDelete") 15 + 16 + urepo := e.Get("repo").(*models.RepoActor) 17 + 18 + token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 19 + expiresAt := time.Now().UTC().Add(15 * time.Minute) 20 + 21 + if err := s.db.Exec(ctx, "UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { 22 + logger.Error("error setting deletion token", "error", err) 23 + return helpers.ServerError(e, nil) 24 + } 25 + 26 + if urepo.Email != "" { 27 + if err := s.sendAccountDeleteEmail(urepo.Email, urepo.Actor.Handle, token); err != nil { 28 + logger.Error("error sending account deletion email", "error", err) 29 + } 30 + } 31 + 32 + return e.NoContent(200) 33 + } 34 + 35 + func (s *Server) sendAccountDeleteEmail(email, handle, token string) error { 36 + if s.mail == nil { 37 + return nil 38 + } 39 + 40 + s.mailLk.Lock() 41 + defer s.mailLk.Unlock() 42 + 43 + s.mail.To(email) 44 + s.mail.Subject("Account Deletion Request for " + s.config.Hostname) 45 + s.mail.Plain().Set(fmt.Sprintf("Hello %s. Your account deletion code is %s. This code will expire in fifteen minutes. If you did not request this, please ignore this email.", handle, token)) 46 + 47 + if err := s.mail.Send(); err != nil { 48 + return err 49 + } 50 + 51 + return nil 52 + }
+6 -3
server/handle_server_request_email_confirmation.go
··· 11 11 ) 12 12 13 13 func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error { 14 + ctx := e.Request().Context() 15 + logger := s.logger.With("name", "handleServerRequestEmailConfirm") 16 + 14 17 urepo := e.Get("repo").(*models.RepoActor) 15 18 16 19 if urepo.EmailConfirmedAt != nil { ··· 20 23 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 21 24 eat := time.Now().Add(10 * time.Minute).UTC() 22 25 23 - if err := s.db.Exec("UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 24 - s.logger.Error("error updating user", "error", err) 26 + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 27 + logger.Error("error updating user", "error", err) 25 28 return helpers.ServerError(e, nil) 26 29 } 27 30 28 31 if err := s.sendEmailVerification(urepo.Email, urepo.Handle, code); err != nil { 29 - s.logger.Error("error sending mail", "error", err) 32 + logger.Error("error sending mail", "error", err) 30 33 return helpers.ServerError(e, nil) 31 34 } 32 35
+6 -3
server/handle_server_request_email_update.go
··· 14 14 } 15 15 16 16 func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + logger := s.logger.With("name", "handleServerRequestEmailUpdate") 19 + 17 20 urepo := e.Get("repo").(*models.RepoActor) 18 21 19 22 if urepo.EmailConfirmedAt != nil { 20 23 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 21 24 eat := time.Now().Add(10 * time.Minute).UTC() 22 25 23 - if err := s.db.Exec("UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 24 - s.logger.Error("error updating repo", "error", err) 26 + if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 27 + logger.Error("error updating repo", "error", err) 25 28 return helpers.ServerError(e, nil) 26 29 } 27 30 28 31 if err := s.sendEmailUpdate(urepo.Email, urepo.Handle, code); err != nil { 29 - s.logger.Error("error sending email", "error", err) 32 + logger.Error("error sending email", "error", err) 30 33 return helpers.ServerError(e, nil) 31 34 } 32 35 }
+7 -4
server/handle_server_request_password_reset.go
··· 14 14 } 15 15 16 16 func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + logger := s.logger.With("name", "handleServerRequestPasswordReset") 19 + 17 20 urepo, ok := e.Get("repo").(*models.RepoActor) 18 21 if !ok { 19 22 var req ComAtprotoServerRequestPasswordResetRequest ··· 25 28 return err 26 29 } 27 30 28 - murepo, err := s.getRepoActorByEmail(req.Email) 31 + murepo, err := s.getRepoActorByEmail(ctx, req.Email) 29 32 if err != nil { 30 33 return err 31 34 } ··· 36 39 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 37 40 eat := time.Now().Add(10 * time.Minute).UTC() 38 41 39 - if err := s.db.Exec("UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 40 - s.logger.Error("error updating repo", "error", err) 42 + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 43 + logger.Error("error updating repo", "error", err) 41 44 return helpers.ServerError(e, nil) 42 45 } 43 46 44 47 if err := s.sendPasswordReset(urepo.Email, urepo.Handle, code); err != nil { 45 - s.logger.Error("error sending email", "error", err) 48 + logger.Error("error sending email", "error", err) 46 49 return helpers.ServerError(e, nil) 47 50 } 48 51
+99
server/handle_server_reserve_signing_key.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "time" 6 + 7 + "github.com/bluesky-social/indigo/atproto/atcrypto" 8 + "github.com/haileyok/cocoon/internal/helpers" 9 + "github.com/haileyok/cocoon/models" 10 + "github.com/labstack/echo/v4" 11 + ) 12 + 13 + type ServerReserveSigningKeyRequest struct { 14 + Did *string `json:"did"` 15 + } 16 + 17 + type ServerReserveSigningKeyResponse struct { 18 + SigningKey string `json:"signingKey"` 19 + } 20 + 21 + func (s *Server) handleServerReserveSigningKey(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + logger := s.logger.With("name", "handleServerReserveSigningKey") 24 + 25 + var req ServerReserveSigningKeyRequest 26 + if err := e.Bind(&req); err != nil { 27 + logger.Error("could not bind reserve signing key request", "error", err) 28 + return helpers.ServerError(e, nil) 29 + } 30 + 31 + if req.Did != nil && *req.Did != "" { 32 + var existing models.ReservedKey 33 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { 34 + return e.JSON(200, ServerReserveSigningKeyResponse{ 35 + SigningKey: existing.KeyDid, 36 + }) 37 + } 38 + } 39 + 40 + k, err := atcrypto.GeneratePrivateKeyK256() 41 + if err != nil { 42 + logger.Error("error creating signing key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 43 + return helpers.ServerError(e, nil) 44 + } 45 + 46 + pubKey, err := k.PublicKey() 47 + if err != nil { 48 + logger.Error("error getting public key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 49 + return helpers.ServerError(e, nil) 50 + } 51 + 52 + keyDid := pubKey.DIDKey() 53 + 54 + reservedKey := models.ReservedKey{ 55 + KeyDid: keyDid, 56 + Did: req.Did, 57 + PrivateKey: k.Bytes(), 58 + CreatedAt: time.Now(), 59 + } 60 + 61 + if err := s.db.Create(ctx, &reservedKey, nil).Error; err != nil { 62 + logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 63 + return helpers.ServerError(e, nil) 64 + } 65 + 66 + logger.Info("reserved signing key", "keyDid", keyDid, "forDid", req.Did) 67 + 68 + return e.JSON(200, ServerReserveSigningKeyResponse{ 69 + SigningKey: keyDid, 70 + }) 71 + } 72 + 73 + func (s *Server) getReservedKey(ctx context.Context, keyDidOrDid string) (*models.ReservedKey, error) { 74 + var reservedKey models.ReservedKey 75 + 76 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 77 + return &reservedKey, nil 78 + } 79 + 80 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 81 + return &reservedKey, nil 82 + } 83 + 84 + return nil, nil 85 + } 86 + 87 + func (s *Server) deleteReservedKey(ctx context.Context, keyDid string, did *string) error { 88 + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil { 89 + return err 90 + } 91 + 92 + if did != nil && *did != "" { 93 + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil { 94 + return err 95 + } 96 + } 97 + 98 + return nil 99 + }
+9 -6
server/handle_server_reset_password.go
··· 16 16 } 17 17 18 18 func (s *Server) handleServerResetPassword(e echo.Context) error { 19 + ctx := e.Request().Context() 20 + logger := s.logger.With("name", "handleServerResetPassword") 21 + 19 22 urepo := e.Get("repo").(*models.RepoActor) 20 23 21 24 var req ComAtprotoServerResetPasswordRequest 22 25 if err := e.Bind(&req); err != nil { 23 - s.logger.Error("error binding", "error", err) 26 + logger.Error("error binding", "error", err) 24 27 return helpers.ServerError(e, nil) 25 28 } 26 29 ··· 33 36 } 34 37 35 38 if *urepo.PasswordResetCode != req.Token { 36 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 39 + return helpers.InvalidTokenError(e) 37 40 } 38 41 39 42 if time.Now().UTC().After(*urepo.PasswordResetCodeExpiresAt) { 40 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 43 + return helpers.ExpiredTokenError(e) 41 44 } 42 45 43 46 hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10) 44 47 if err != nil { 45 - s.logger.Error("error creating hash", "error", err) 48 + logger.Error("error creating hash", "error", err) 46 49 return helpers.ServerError(e, nil) 47 50 } 48 51 49 - if err := s.db.Exec("UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { 50 - s.logger.Error("error updating repo", "error", err) 52 + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { 53 + logger.Error("error updating repo", "error", err) 51 54 return helpers.ServerError(e, nil) 52 55 } 53 56
+3 -1
server/handle_server_resolve_handle.go
··· 10 10 ) 11 11 12 12 func (s *Server) handleResolveHandle(e echo.Context) error { 13 + logger := s.logger.With("name", "handleServerResolveHandle") 14 + 13 15 type Resp struct { 14 16 Did string `json:"did"` 15 17 } ··· 28 30 ctx := context.WithValue(e.Request().Context(), "skip-cache", true) 29 31 did, err := s.passport.ResolveHandle(ctx, parsed.String()) 30 32 if err != nil { 31 - s.logger.Error("error resolving handle", "error", err) 33 + logger.Error("error resolving handle", "error", err) 32 34 return helpers.ServerError(e, nil) 33 35 } 34 36
+35 -11
server/handle_server_update_email.go
··· 3 3 import ( 4 4 "time" 5 5 6 - "github.com/Azure/go-autorest/autorest/to" 7 6 "github.com/haileyok/cocoon/internal/helpers" 8 7 "github.com/haileyok/cocoon/models" 9 8 "github.com/labstack/echo/v4" ··· 12 11 type ComAtprotoServerUpdateEmailRequest struct { 13 12 Email string `json:"email" validate:"required"` 14 13 EmailAuthFactor bool `json:"emailAuthFactor"` 15 - Token string `json:"token" validate:"required"` 14 + Token string `json:"token"` 16 15 } 17 16 18 17 func (s *Server) handleServerUpdateEmail(e echo.Context) error { 18 + ctx := e.Request().Context() 19 + logger := s.logger.With("name", "handleServerUpdateEmail") 20 + 19 21 urepo := e.Get("repo").(*models.RepoActor) 20 22 21 23 var req ComAtprotoServerUpdateEmailRequest 22 24 if err := e.Bind(&req); err != nil { 23 - s.logger.Error("error binding", "error", err) 25 + logger.Error("error binding", "error", err) 24 26 return helpers.ServerError(e, nil) 25 27 } 26 28 ··· 28 30 return helpers.InputError(e, nil) 29 31 } 30 32 31 - if urepo.EmailUpdateCode == nil || urepo.EmailUpdateCodeExpiresAt == nil { 32 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 33 + // To disable email auth factor a token is required. 34 + // To enable email auth factor a token is not required. 35 + // If updating an email address, a token will be sent anyway 36 + if urepo.TwoFactorType != models.TwoFactorTypeNone && req.EmailAuthFactor == false && req.Token == "" { 37 + return helpers.InvalidTokenError(e) 33 38 } 34 39 35 - if *urepo.EmailUpdateCode != req.Token { 36 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 40 + if req.Token != "" { 41 + if urepo.EmailUpdateCode == nil || urepo.EmailUpdateCodeExpiresAt == nil { 42 + return helpers.InvalidTokenError(e) 43 + } 44 + 45 + if *urepo.EmailUpdateCode != req.Token { 46 + return helpers.InvalidTokenError(e) 47 + } 48 + 49 + if time.Now().UTC().After(*urepo.EmailUpdateCodeExpiresAt) { 50 + return helpers.ExpiredTokenError(e) 51 + } 37 52 } 38 53 39 - if time.Now().UTC().After(*urepo.EmailUpdateCodeExpiresAt) { 40 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 54 + twoFactorType := models.TwoFactorTypeNone 55 + if req.EmailAuthFactor { 56 + twoFactorType = models.TwoFactorTypeEmail 57 + } 58 + 59 + query := "UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, two_factor_type = ?, email = ?" 60 + 61 + if urepo.Email != req.Email { 62 + query += ",email_confirmed_at = NULL" 41 63 } 42 64 43 - if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil { 44 - s.logger.Error("error updating repo", "error", err) 65 + query += " WHERE did = ?" 66 + 67 + if err := s.db.Exec(ctx, query, nil, twoFactorType, req.Email, urepo.Repo.Did).Error; err != nil { 68 + logger.Error("error updating repo", "error", err) 45 69 return helpers.ServerError(e, nil) 46 70 } 47 71
+96 -10
server/handle_sync_get_blob.go
··· 2 2 3 3 import ( 4 4 "bytes" 5 + "fmt" 6 + "io" 5 7 8 + "github.com/Azure/go-autorest/autorest/to" 9 + "github.com/aws/aws-sdk-go/aws" 10 + "github.com/aws/aws-sdk-go/aws/credentials" 11 + "github.com/aws/aws-sdk-go/aws/session" 12 + "github.com/aws/aws-sdk-go/service/s3" 6 13 "github.com/haileyok/cocoon/internal/helpers" 7 14 "github.com/haileyok/cocoon/models" 8 15 "github.com/ipfs/go-cid" ··· 10 17 ) 11 18 12 19 func (s *Server) handleSyncGetBlob(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + logger := s.logger.With("name", "handleSyncGetBlob") 22 + 13 23 did := e.QueryParam("did") 14 24 if did == "" { 15 25 return helpers.InputError(e, nil) ··· 25 35 return helpers.InputError(e, nil) 26 36 } 27 37 38 + urepo, err := s.getRepoActorByDid(ctx, did) 39 + if err != nil { 40 + logger.Error("could not find user for requested blob", "error", err) 41 + return helpers.InputError(e, nil) 42 + } 43 + 44 + status := urepo.Status() 45 + if status != nil { 46 + if *status == "deactivated" { 47 + return helpers.InputError(e, to.StringPtr("RepoDeactivated")) 48 + } 49 + } 50 + 28 51 var blob models.Blob 29 - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { 30 - s.logger.Error("error looking up blob", "error", err) 52 + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { 53 + logger.Error("error looking up blob", "error", err) 31 54 return helpers.ServerError(e, nil) 32 55 } 33 56 34 57 buf := new(bytes.Buffer) 35 58 36 - var parts []models.BlobPart 37 - if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { 38 - s.logger.Error("error getting blob parts", "error", err) 39 - return helpers.ServerError(e, nil) 40 - } 59 + if blob.Storage == "sqlite" { 60 + var parts []models.BlobPart 61 + if err := s.db.Raw(ctx, "SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { 62 + logger.Error("error getting blob parts", "error", err) 63 + return helpers.ServerError(e, nil) 64 + } 41 65 42 - // TODO: we can just stream this, don't need to make a buffer 43 - for _, p := range parts { 44 - buf.Write(p.Data) 66 + // TODO: we can just stream this, don't need to make a buffer 67 + for _, p := range parts { 68 + buf.Write(p.Data) 69 + } 70 + } else if blob.Storage == "s3" { 71 + if !(s.s3Config != nil && s.s3Config.BlobstoreEnabled) { 72 + logger.Error("s3 storage disabled") 73 + return helpers.ServerError(e, nil) 74 + } 75 + 76 + blobKey := fmt.Sprintf("blobs/%s/%s", urepo.Repo.Did, c.String()) 77 + 78 + if s.s3Config.CDNUrl != "" { 79 + redirectUrl := fmt.Sprintf("%s/%s", s.s3Config.CDNUrl, blobKey) 80 + return e.Redirect(302, redirectUrl) 81 + } 82 + 83 + config := &aws.Config{ 84 + Region: aws.String(s.s3Config.Region), 85 + Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 86 + } 87 + 88 + if s.s3Config.Endpoint != "" { 89 + config.Endpoint = aws.String(s.s3Config.Endpoint) 90 + config.S3ForcePathStyle = aws.Bool(true) 91 + } 92 + 93 + sess, err := session.NewSession(config) 94 + if err != nil { 95 + logger.Error("error creating aws session", "error", err) 96 + return helpers.ServerError(e, nil) 97 + } 98 + 99 + svc := s3.New(sess) 100 + if result, err := svc.GetObject(&s3.GetObjectInput{ 101 + Bucket: aws.String(s.s3Config.Bucket), 102 + Key: aws.String(blobKey), 103 + }); err != nil { 104 + logger.Error("error getting blob from s3", "error", err) 105 + return helpers.ServerError(e, nil) 106 + } else { 107 + read := 0 108 + part := 0 109 + partBuf := make([]byte, 0x10000) 110 + 111 + for { 112 + n, err := io.ReadFull(result.Body, partBuf) 113 + if err == io.ErrUnexpectedEOF || err == io.EOF { 114 + if n == 0 { 115 + break 116 + } 117 + } else if err != nil && err != io.ErrUnexpectedEOF { 118 + logger.Error("error reading blob", "error", err) 119 + return helpers.ServerError(e, nil) 120 + } 121 + 122 + data := partBuf[:n] 123 + read += n 124 + buf.Write(data) 125 + part++ 126 + } 127 + } 128 + } else { 129 + logger.Error("unknown storage", "storage", blob.Storage) 130 + return helpers.ServerError(e, nil) 45 131 } 46 132 47 133 e.Response().Header().Set(echo.HeaderContentDisposition, "attachment; filename="+c.String())
+16 -13
server/handle_sync_get_blocks.go
··· 2 2 3 3 import ( 4 4 "bytes" 5 - "context" 6 - "strings" 7 5 8 6 "github.com/bluesky-social/indigo/carstore" 9 - "github.com/haileyok/cocoon/blockstore" 10 7 "github.com/haileyok/cocoon/internal/helpers" 11 8 "github.com/ipfs/go-cid" 12 9 cbor "github.com/ipfs/go-ipld-cbor" 13 10 "github.com/ipld/go-car" 14 11 "github.com/labstack/echo/v4" 15 12 ) 13 + 14 + type ComAtprotoSyncGetBlocksRequest struct { 15 + Did string `query:"did"` 16 + Cids []string `query:"cids"` 17 + } 16 18 17 19 func (s *Server) handleGetBlocks(e echo.Context) error { 18 - did := e.QueryParam("did") 19 - cidsstr := e.QueryParam("cids") 20 - if did == "" { 20 + ctx := e.Request().Context() 21 + logger := s.logger.With("name", "handleSyncGetBlocks") 22 + 23 + var req ComAtprotoSyncGetBlocksRequest 24 + if err := e.Bind(&req); err != nil { 21 25 return helpers.InputError(e, nil) 22 26 } 23 27 24 - cidstrs := strings.Split(cidsstr, ",") 25 - cids := []cid.Cid{} 28 + var cids []cid.Cid 26 29 27 - for _, cs := range cidstrs { 30 + for _, cs := range req.Cids { 28 31 c, err := cid.Cast([]byte(cs)) 29 32 if err != nil { 30 33 return err ··· 33 36 cids = append(cids, c) 34 37 } 35 38 36 - urepo, err := s.getRepoActorByDid(did) 39 + urepo, err := s.getRepoActorByDid(ctx, req.Did) 37 40 if err != nil { 38 41 return helpers.ServerError(e, nil) 39 42 } ··· 50 53 }) 51 54 52 55 if _, err := carstore.LdWrite(buf, hb); err != nil { 53 - s.logger.Error("error writing to car", "error", err) 56 + logger.Error("error writing to car", "error", err) 54 57 return helpers.ServerError(e, nil) 55 58 } 56 59 57 - bs := blockstore.New(urepo.Repo.Did, s.db) 60 + bs := s.getBlockstore(urepo.Repo.Did) 58 61 59 62 for _, c := range cids { 60 - b, err := bs.Get(context.TODO(), c) 63 + b, err := bs.Get(ctx, c) 61 64 if err != nil { 62 65 return err 63 66 }
+3 -1
server/handle_sync_get_latest_commit.go
··· 12 12 } 13 13 14 14 func (s *Server) handleSyncGetLatestCommit(e echo.Context) error { 15 + ctx := e.Request().Context() 16 + 15 17 did := e.QueryParam("did") 16 18 if did == "" { 17 19 return helpers.InputError(e, nil) 18 20 } 19 21 20 - urepo, err := s.getRepoActorByDid(did) 22 + urepo, err := s.getRepoActorByDid(ctx, did) 21 23 if err != nil { 22 24 return err 23 25 }
+8 -5
server/handle_sync_get_record.go
··· 13 13 ) 14 14 15 15 func (s *Server) handleSyncGetRecord(e echo.Context) error { 16 + ctx := e.Request().Context() 17 + logger := s.logger.With("name", "handleSyncGetRecord") 18 + 16 19 did := e.QueryParam("did") 17 20 collection := e.QueryParam("collection") 18 21 rkey := e.QueryParam("rkey") 19 22 20 23 var urepo models.Repo 21 - if err := s.db.Raw("SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { 22 - s.logger.Error("error getting repo", "error", err) 24 + if err := s.db.Raw(ctx, "SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { 25 + logger.Error("error getting repo", "error", err) 23 26 return helpers.ServerError(e, nil) 24 27 } 25 28 26 - root, blocks, err := s.repoman.getRecordProof(urepo, collection, rkey) 29 + root, blocks, err := s.repoman.getRecordProof(ctx, urepo, collection, rkey) 27 30 if err != nil { 28 31 return err 29 32 } ··· 36 39 }) 37 40 38 41 if _, err := carstore.LdWrite(buf, hb); err != nil { 39 - s.logger.Error("error writing to car", "error", err) 42 + logger.Error("error writing to car", "error", err) 40 43 return helpers.ServerError(e, nil) 41 44 } 42 45 43 46 for _, blk := range blocks { 44 47 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 45 - s.logger.Error("error writing to car", "error", err) 48 + logger.Error("error writing to car", "error", err) 46 49 return helpers.ServerError(e, nil) 47 50 } 48 51 }
+6 -3
server/handle_sync_get_repo.go
··· 13 13 ) 14 14 15 15 func (s *Server) handleSyncGetRepo(e echo.Context) error { 16 + ctx := e.Request().Context() 17 + logger := s.logger.With("name", "handleSyncGetRepo") 18 + 16 19 did := e.QueryParam("did") 17 20 if did == "" { 18 21 return helpers.InputError(e, nil) 19 22 } 20 23 21 - urepo, err := s.getRepoActorByDid(did) 24 + urepo, err := s.getRepoActorByDid(ctx, did) 22 25 if err != nil { 23 26 return err 24 27 } ··· 36 39 buf := new(bytes.Buffer) 37 40 38 41 if _, err := carstore.LdWrite(buf, hb); err != nil { 39 - s.logger.Error("error writing to car", "error", err) 42 + logger.Error("error writing to car", "error", err) 40 43 return helpers.ServerError(e, nil) 41 44 } 42 45 43 46 var blocks []models.Block 44 - if err := s.db.Raw("SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { 47 + if err := s.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { 45 48 return err 46 49 } 47 50
+5 -3
server/handle_sync_get_repo_status.go
··· 14 14 15 15 // TODO: make this actually do the right thing 16 16 func (s *Server) handleSyncGetRepoStatus(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 17 19 did := e.QueryParam("did") 18 20 if did == "" { 19 21 return helpers.InputError(e, nil) 20 22 } 21 23 22 - urepo, err := s.getRepoActorByDid(did) 24 + urepo, err := s.getRepoActorByDid(ctx, did) 23 25 if err != nil { 24 26 return err 25 27 } 26 28 27 29 return e.JSON(200, ComAtprotoSyncGetRepoStatusResponse{ 28 30 Did: urepo.Repo.Did, 29 - Active: true, 30 - Status: nil, 31 + Active: urepo.Active(), 32 + Status: urepo.Status(), 31 33 Rev: &urepo.Rev, 32 34 }) 33 35 }
+20 -3
server/handle_sync_list_blobs.go
··· 1 1 package server 2 2 3 3 import ( 4 + "github.com/Azure/go-autorest/autorest/to" 4 5 "github.com/haileyok/cocoon/internal/helpers" 5 6 "github.com/haileyok/cocoon/models" 6 7 "github.com/ipfs/go-cid" ··· 13 14 } 14 15 15 16 func (s *Server) handleSyncListBlobs(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + logger := s.logger.With("name", "handleSyncListBlobs") 19 + 16 20 did := e.QueryParam("did") 17 21 if did == "" { 18 22 return helpers.InputError(e, nil) ··· 34 38 } 35 39 params = append(params, limit) 36 40 41 + urepo, err := s.getRepoActorByDid(ctx, did) 42 + if err != nil { 43 + logger.Error("could not find user for requested blobs", "error", err) 44 + return helpers.InputError(e, nil) 45 + } 46 + 47 + status := urepo.Status() 48 + if status != nil { 49 + if *status == "deactivated" { 50 + return helpers.InputError(e, to.StringPtr("RepoDeactivated")) 51 + } 52 + } 53 + 37 54 var blobs []models.Blob 38 - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { 39 - s.logger.Error("error getting records", "error", err) 55 + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { 56 + logger.Error("error getting records", "error", err) 40 57 return helpers.ServerError(e, nil) 41 58 } 42 59 ··· 44 61 for _, b := range blobs { 45 62 c, err := cid.Cast(b.Cid) 46 63 if err != nil { 47 - s.logger.Error("error casting cid", "error", err) 64 + logger.Error("error casting cid", "error", err) 48 65 return helpers.ServerError(e, nil) 49 66 } 50 67 cstrs = append(cstrs, c.String())
+70 -56
server/handle_sync_subscribe_repos.go
··· 1 1 package server 2 2 3 3 import ( 4 - "fmt" 5 - "net/http" 4 + "context" 5 + "time" 6 6 7 7 "github.com/bluesky-social/indigo/events" 8 8 "github.com/bluesky-social/indigo/lex/util" 9 9 "github.com/btcsuite/websocket" 10 + "github.com/haileyok/cocoon/metrics" 10 11 "github.com/labstack/echo/v4" 11 12 ) 12 13 13 - var upgrader = websocket.Upgrader{ 14 - ReadBufferSize: 1024, 15 - WriteBufferSize: 1024, 16 - CheckOrigin: func(r *http.Request) bool { 17 - return true 18 - }, 19 - } 20 - 21 14 func (s *Server) handleSyncSubscribeRepos(e echo.Context) error { 15 + ctx := e.Request().Context() 16 + logger := s.logger.With("component", "subscribe-repos-websocket") 17 + 22 18 conn, err := websocket.Upgrade(e.Response().Writer, e.Request(), e.Response().Header(), 1<<10, 1<<10) 23 19 if err != nil { 20 + logger.Error("unable to establish websocket with relay", "err", err) 24 21 return err 25 22 } 26 23 27 - s.logger.Info("new connection", "ua", e.Request().UserAgent()) 28 - 29 - ctx := e.Request().Context() 30 - 31 24 ident := e.RealIP() + "-" + e.Request().UserAgent() 25 + logger = logger.With("ident", ident) 26 + logger.Info("new connection established") 27 + 28 + metrics.RelaysConnected.WithLabelValues(ident).Inc() 29 + defer func() { 30 + metrics.RelaysConnected.WithLabelValues(ident).Dec() 31 + }() 32 32 33 33 evts, cancel, err := s.evtman.Subscribe(ctx, ident, func(evt *events.XRPCStreamEvent) bool { 34 34 return true ··· 40 40 41 41 header := events.EventHeader{Op: events.EvtKindMessage} 42 42 for evt := range evts { 43 - wc, err := conn.NextWriter(websocket.BinaryMessage) 44 - if err != nil { 45 - return err 46 - } 43 + func() { 44 + defer func() { 45 + metrics.RelaySends.WithLabelValues(ident, header.MsgType).Inc() 46 + }() 47 47 48 - var obj util.CBOR 48 + wc, err := conn.NextWriter(websocket.BinaryMessage) 49 + if err != nil { 50 + logger.Error("error writing message to relay", "err", err) 51 + return 52 + } 49 53 50 - switch { 51 - case evt.Error != nil: 52 - header.Op = events.EvtKindErrorFrame 53 - obj = evt.Error 54 - case evt.RepoCommit != nil: 55 - header.MsgType = "#commit" 56 - obj = evt.RepoCommit 57 - case evt.RepoHandle != nil: 58 - header.MsgType = "#handle" 59 - obj = evt.RepoHandle 60 - case evt.RepoIdentity != nil: 61 - header.MsgType = "#identity" 62 - obj = evt.RepoIdentity 63 - case evt.RepoAccount != nil: 64 - header.MsgType = "#account" 65 - obj = evt.RepoAccount 66 - case evt.RepoInfo != nil: 67 - header.MsgType = "#info" 68 - obj = evt.RepoInfo 69 - case evt.RepoMigrate != nil: 70 - header.MsgType = "#migrate" 71 - obj = evt.RepoMigrate 72 - case evt.RepoTombstone != nil: 73 - header.MsgType = "#tombstone" 74 - obj = evt.RepoTombstone 75 - default: 76 - return fmt.Errorf("unrecognized event kind") 77 - } 54 + if ctx.Err() != nil { 55 + logger.Error("context error", "err", err) 56 + return 57 + } 58 + 59 + var obj util.CBOR 60 + switch { 61 + case evt.Error != nil: 62 + header.Op = events.EvtKindErrorFrame 63 + obj = evt.Error 64 + case evt.RepoCommit != nil: 65 + header.MsgType = "#commit" 66 + obj = evt.RepoCommit 67 + case evt.RepoIdentity != nil: 68 + header.MsgType = "#identity" 69 + obj = evt.RepoIdentity 70 + case evt.RepoAccount != nil: 71 + header.MsgType = "#account" 72 + obj = evt.RepoAccount 73 + case evt.RepoInfo != nil: 74 + header.MsgType = "#info" 75 + obj = evt.RepoInfo 76 + default: 77 + logger.Warn("unrecognized event kind") 78 + return 79 + } 80 + 81 + if err := header.MarshalCBOR(wc); err != nil { 82 + logger.Error("failed to write header to relay", "err", err) 83 + return 84 + } 78 85 79 - if err := header.MarshalCBOR(wc); err != nil { 80 - return fmt.Errorf("failed to write header: %w", err) 81 - } 86 + if err := obj.MarshalCBOR(wc); err != nil { 87 + logger.Error("failed to write event to relay", "err", err) 88 + return 89 + } 82 90 83 - if err := obj.MarshalCBOR(wc); err != nil { 84 - return fmt.Errorf("failed to write event: %w", err) 85 - } 91 + if err := wc.Close(); err != nil { 92 + logger.Error("failed to flush-close our event write", "err", err) 93 + return 94 + } 95 + }() 96 + } 86 97 87 - if err := wc.Close(); err != nil { 88 - return fmt.Errorf("failed to flush-close our event write: %w", err) 89 - } 98 + // we should tell the relay to request a new crawl at this point if we got disconnected 99 + // use a new context since the old one might be cancelled at this point 100 + ctx, cancel = context.WithTimeout(context.Background(), 10*time.Second) 101 + defer cancel() 102 + if err := s.requestCrawl(ctx); err != nil { 103 + logger.Error("error requesting crawls", "err", err) 90 104 } 91 105 92 106 return nil
+36
server/handle_well_known.go
··· 2 2 3 3 import ( 4 4 "fmt" 5 + "strings" 5 6 6 7 "github.com/Azure/go-autorest/autorest/to" 8 + "github.com/haileyok/cocoon/internal/helpers" 7 9 "github.com/labstack/echo/v4" 10 + "gorm.io/gorm" 8 11 ) 9 12 10 13 var ( ··· 61 64 }, 62 65 }, 63 66 }) 67 + } 68 + 69 + func (s *Server) handleAtprotoDid(e echo.Context) error { 70 + ctx := e.Request().Context() 71 + logger := s.logger.With("name", "handleAtprotoDid") 72 + 73 + host := e.Request().Host 74 + if host == "" { 75 + return helpers.InputError(e, to.StringPtr("Invalid handle.")) 76 + } 77 + 78 + host = strings.Split(host, ":")[0] 79 + host = strings.ToLower(strings.TrimSpace(host)) 80 + 81 + if host == s.config.Hostname { 82 + return e.String(200, s.config.Did) 83 + } 84 + 85 + suffix := "." + s.config.Hostname 86 + if !strings.HasSuffix(host, suffix) { 87 + return e.NoContent(404) 88 + } 89 + 90 + actor, err := s.getActorByHandle(ctx, host) 91 + if err != nil { 92 + if err == gorm.ErrRecordNotFound { 93 + return e.NoContent(404) 94 + } 95 + logger.Error("error looking up actor by handle", "error", err) 96 + return helpers.ServerError(e, nil) 97 + } 98 + 99 + return e.String(200, actor.Did) 64 100 } 65 101 66 102 func (s *Server) handleOauthProtectedResource(e echo.Context) error {
+38
server/mail.go
··· 40 40 return nil 41 41 } 42 42 43 + func (s *Server) sendPlcTokenReset(email, handle, code string) error { 44 + if s.mail == nil { 45 + return nil 46 + } 47 + 48 + s.mailLk.Lock() 49 + defer s.mailLk.Unlock() 50 + 51 + s.mail.To(email) 52 + s.mail.Subject("PLC token for " + s.config.Hostname) 53 + s.mail.Plain().Set(fmt.Sprintf("Hello %s. Your PLC operation code is %s. This code will expire in ten minutes.", handle, code)) 54 + 55 + if err := s.mail.Send(); err != nil { 56 + return err 57 + } 58 + 59 + return nil 60 + } 61 + 43 62 func (s *Server) sendEmailUpdate(email, handle, code string) error { 44 63 if s.mail == nil { 45 64 return nil ··· 77 96 78 97 return nil 79 98 } 99 + 100 + func (s *Server) sendTwoFactorCode(email, handle, code string) error { 101 + if s.mail == nil { 102 + return nil 103 + } 104 + 105 + s.mailLk.Lock() 106 + defer s.mailLk.Unlock() 107 + 108 + s.mail.To(email) 109 + s.mail.Subject("2FA code for " + s.config.Hostname) 110 + s.mail.Plain().Set(fmt.Sprintf("Hello %s. Your 2FA code is %s. This code will expire in ten minutes.", handle, code)) 111 + 112 + if err := s.mail.Send(); err != nil { 113 + return err 114 + } 115 + 116 + return nil 117 + }
+303
server/middleware.go
··· 1 + package server 2 + 3 + import ( 4 + "crypto/sha256" 5 + "encoding/base64" 6 + "errors" 7 + "fmt" 8 + "strings" 9 + "time" 10 + 11 + "github.com/Azure/go-autorest/autorest/to" 12 + "github.com/golang-jwt/jwt/v4" 13 + "github.com/haileyok/cocoon/internal/helpers" 14 + "github.com/haileyok/cocoon/models" 15 + "github.com/haileyok/cocoon/oauth/dpop" 16 + "github.com/haileyok/cocoon/oauth/provider" 17 + "github.com/labstack/echo/v4" 18 + "gitlab.com/yawning/secp256k1-voi" 19 + secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 20 + "gorm.io/gorm" 21 + ) 22 + 23 + func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 24 + return func(e echo.Context) error { 25 + username, password, ok := e.Request().BasicAuth() 26 + if !ok || username != "admin" || password != s.config.AdminPassword { 27 + return helpers.InputError(e, to.StringPtr("Unauthorized")) 28 + } 29 + 30 + if err := next(e); err != nil { 31 + e.Error(err) 32 + } 33 + 34 + return nil 35 + } 36 + } 37 + 38 + func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 + return func(e echo.Context) error { 40 + ctx := e.Request().Context() 41 + logger := s.logger.With("name", "handleLegacySessionMiddleware") 42 + 43 + authheader := e.Request().Header.Get("authorization") 44 + if authheader == "" { 45 + return e.JSON(401, map[string]string{"error": "Unauthorized"}) 46 + } 47 + 48 + pts := strings.Split(authheader, " ") 49 + if len(pts) != 2 { 50 + return helpers.ServerError(e, nil) 51 + } 52 + 53 + // move on to oauth session middleware if this is a dpop token 54 + if pts[0] == "DPoP" { 55 + return next(e) 56 + } 57 + 58 + tokenstr := pts[1] 59 + token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 60 + claims, ok := token.Claims.(jwt.MapClaims) 61 + if !ok { 62 + return helpers.InvalidTokenError(e) 63 + } 64 + 65 + var did string 66 + var repo *models.RepoActor 67 + 68 + // service auth tokens 69 + lxm, hasLxm := claims["lxm"] 70 + if hasLxm { 71 + pts := strings.Split(e.Request().URL.String(), "/") 72 + if lxm != pts[len(pts)-1] { 73 + logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 74 + return helpers.InputError(e, nil) 75 + } 76 + 77 + maybeDid, ok := claims["iss"].(string) 78 + if !ok { 79 + logger.Error("no iss in service auth token", "error", err) 80 + return helpers.InputError(e, nil) 81 + } 82 + did = maybeDid 83 + 84 + maybeRepo, err := s.getRepoActorByDid(ctx, did) 85 + if err != nil { 86 + logger.Error("error fetching repo", "error", err) 87 + return helpers.ServerError(e, nil) 88 + } 89 + repo = maybeRepo 90 + } 91 + 92 + if token.Header["alg"] != "ES256K" { 93 + token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 94 + if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 95 + return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 96 + } 97 + return s.privateKey.Public(), nil 98 + }) 99 + if err != nil { 100 + logger.Error("error parsing jwt", "error", err) 101 + return helpers.ExpiredTokenError(e) 102 + } 103 + 104 + if !token.Valid { 105 + return helpers.InvalidTokenError(e) 106 + } 107 + } else { 108 + kpts := strings.Split(tokenstr, ".") 109 + signingInput := kpts[0] + "." + kpts[1] 110 + hash := sha256.Sum256([]byte(signingInput)) 111 + sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 112 + if err != nil { 113 + logger.Error("error decoding signature bytes", "error", err) 114 + return helpers.ServerError(e, nil) 115 + } 116 + 117 + if len(sigBytes) != 64 { 118 + logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 119 + return helpers.ServerError(e, nil) 120 + } 121 + 122 + rBytes := sigBytes[:32] 123 + sBytes := sigBytes[32:] 124 + rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 125 + ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 126 + 127 + if repo == nil { 128 + sub, ok := claims["sub"].(string) 129 + if !ok { 130 + s.logger.Error("no sub claim in ES256K token and repo not set") 131 + return helpers.InvalidTokenError(e) 132 + } 133 + maybeRepo, err := s.getRepoActorByDid(ctx, sub) 134 + if err != nil { 135 + s.logger.Error("error fetching repo for ES256K verification", "error", err) 136 + return helpers.ServerError(e, nil) 137 + } 138 + repo = maybeRepo 139 + did = sub 140 + } 141 + 142 + sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 143 + if err != nil { 144 + logger.Error("can't load private key", "error", err) 145 + return err 146 + } 147 + 148 + pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 149 + if !ok { 150 + logger.Error("error getting public key from sk") 151 + return helpers.ServerError(e, nil) 152 + } 153 + 154 + verified := pubKey.VerifyRaw(hash[:], rr, ss) 155 + if !verified { 156 + logger.Error("error verifying", "error", err) 157 + return helpers.ServerError(e, nil) 158 + } 159 + } 160 + 161 + isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 162 + scope, _ := claims["scope"].(string) 163 + 164 + if isRefresh && scope != "com.atproto.refresh" { 165 + return helpers.InvalidTokenError(e) 166 + } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 167 + return helpers.InvalidTokenError(e) 168 + } 169 + 170 + table := "tokens" 171 + if isRefresh { 172 + table = "refresh_tokens" 173 + } 174 + 175 + if isRefresh { 176 + type Result struct { 177 + Found bool 178 + } 179 + var result Result 180 + if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 181 + if err == gorm.ErrRecordNotFound { 182 + return helpers.InvalidTokenError(e) 183 + } 184 + 185 + logger.Error("error getting token from db", "error", err) 186 + return helpers.ServerError(e, nil) 187 + } 188 + 189 + if !result.Found { 190 + return helpers.InvalidTokenError(e) 191 + } 192 + } 193 + 194 + exp, ok := claims["exp"].(float64) 195 + if !ok { 196 + logger.Error("error getting iat from token") 197 + return helpers.ServerError(e, nil) 198 + } 199 + 200 + if exp < float64(time.Now().UTC().Unix()) { 201 + return helpers.ExpiredTokenError(e) 202 + } 203 + 204 + if repo == nil { 205 + maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) 206 + if err != nil { 207 + logger.Error("error fetching repo", "error", err) 208 + return helpers.ServerError(e, nil) 209 + } 210 + repo = maybeRepo 211 + did = repo.Repo.Did 212 + } 213 + 214 + e.Set("repo", repo) 215 + e.Set("did", did) 216 + e.Set("token", tokenstr) 217 + 218 + if err := next(e); err != nil { 219 + return helpers.InvalidTokenError(e) 220 + } 221 + 222 + return nil 223 + } 224 + } 225 + 226 + func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 227 + return func(e echo.Context) error { 228 + ctx := e.Request().Context() 229 + logger := s.logger.With("name", "handleOauthSessionMiddleware") 230 + 231 + authheader := e.Request().Header.Get("authorization") 232 + if authheader == "" { 233 + return e.JSON(401, map[string]string{"error": "Unauthorized"}) 234 + } 235 + 236 + pts := strings.Split(authheader, " ") 237 + if len(pts) != 2 { 238 + return helpers.ServerError(e, nil) 239 + } 240 + 241 + if pts[0] != "DPoP" { 242 + return next(e) 243 + } 244 + 245 + accessToken := pts[1] 246 + 247 + nonce := s.oauthProvider.NextNonce() 248 + if nonce != "" { 249 + e.Response().Header().Set("DPoP-Nonce", nonce) 250 + e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 251 + } 252 + 253 + proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 254 + if err != nil { 255 + if errors.Is(err, dpop.ErrUseDpopNonce) { 256 + e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`) 257 + e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 258 + return e.JSON(401, map[string]string{ 259 + "error": "use_dpop_nonce", 260 + }) 261 + } 262 + logger.Error("invalid dpop proof", "error", err) 263 + return helpers.InputError(e, nil) 264 + } 265 + 266 + var oauthToken provider.OauthToken 267 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 268 + logger.Error("error finding access token in db", "error", err) 269 + return helpers.InputError(e, nil) 270 + } 271 + 272 + if oauthToken.Token == "" { 273 + return helpers.InvalidTokenError(e) 274 + } 275 + 276 + if *oauthToken.Parameters.DpopJkt != proof.JKT { 277 + logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 278 + return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 279 + } 280 + 281 + if time.Now().After(oauthToken.ExpiresAt) { 282 + e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`) 283 + e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 284 + return e.JSON(401, map[string]string{ 285 + "error": "invalid_token", 286 + "error_description": "Token expired", 287 + }) 288 + } 289 + 290 + repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) 291 + if err != nil { 292 + logger.Error("could not find actor in db", "error", err) 293 + return helpers.ServerError(e, nil) 294 + } 295 + 296 + e.Set("repo", repo) 297 + e.Set("did", repo.Repo.Did) 298 + e.Set("token", accessToken) 299 + e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 300 + 301 + return next(e) 302 + } 303 + }
+108 -49
server/repo.go
··· 10 10 11 11 "github.com/Azure/go-autorest/autorest/to" 12 12 "github.com/bluesky-social/indigo/api/atproto" 13 - "github.com/bluesky-social/indigo/atproto/data" 13 + "github.com/bluesky-social/indigo/atproto/atdata" 14 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 15 "github.com/bluesky-social/indigo/carstore" 16 16 "github.com/bluesky-social/indigo/events" 17 17 lexutil "github.com/bluesky-social/indigo/lex/util" 18 18 "github.com/bluesky-social/indigo/repo" 19 - "github.com/bluesky-social/indigo/util" 20 - "github.com/haileyok/cocoon/blockstore" 21 19 "github.com/haileyok/cocoon/internal/db" 20 + "github.com/haileyok/cocoon/metrics" 22 21 "github.com/haileyok/cocoon/models" 22 + "github.com/haileyok/cocoon/recording_blockstore" 23 23 blocks "github.com/ipfs/go-block-format" 24 24 "github.com/ipfs/go-cid" 25 25 cbor "github.com/ipfs/go-ipld-cbor" ··· 73 73 } 74 74 75 75 func (mm *MarshalableMap) MarshalCBOR(w io.Writer) error { 76 - data, err := data.MarshalCBOR(*mm) 76 + data, err := atdata.MarshalCBOR(*mm) 77 77 if err != nil { 78 78 return err 79 79 } ··· 97 97 } 98 98 99 99 // TODO make use of swap commit 100 - func (rm *RepoMan) applyWrites(urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) { 100 + func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes []Op, swapCommit *string) ([]ApplyWriteResult, error) { 101 101 rootcid, err := cid.Cast(urepo.Root) 102 102 if err != nil { 103 103 return nil, err 104 104 } 105 105 106 - dbs := blockstore.New(urepo.Did, rm.db) 107 - r, err := repo.OpenRepo(context.TODO(), dbs, rootcid) 106 + dbs := rm.s.getBlockstore(urepo.Did) 107 + bs := recording_blockstore.New(dbs) 108 + r, err := repo.OpenRepo(ctx, bs, rootcid) 108 109 109 - entries := []models.Record{} 110 110 var results []ApplyWriteResult 111 111 112 + entries := make([]models.Record, 0, len(writes)) 112 113 for i, op := range writes { 114 + // updates or deletes must supply an rkey 113 115 if op.Type != OpTypeCreate && op.Rkey == nil { 114 116 return nil, fmt.Errorf("invalid rkey") 115 117 } else if op.Type == OpTypeCreate && op.Rkey != nil { 116 - _, _, err := r.GetRecord(context.TODO(), op.Collection+"/"+*op.Rkey) 118 + // we should conver this op to an update if the rkey already exists 119 + _, _, err := r.GetRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey)) 117 120 if err == nil { 118 121 op.Type = OpTypeUpdate 119 122 } 120 123 } else if op.Rkey == nil { 124 + // creates that don't supply an rkey will have one generated for them 121 125 op.Rkey = to.StringPtr(rm.clock.Next().String()) 122 126 writes[i].Rkey = op.Rkey 123 127 } 124 128 129 + // validate the record key is actually valid 125 130 _, err := syntax.ParseRecordKey(*op.Rkey) 126 131 if err != nil { 127 132 return nil, err ··· 129 134 130 135 switch op.Type { 131 136 case OpTypeCreate: 132 - j, err := json.Marshal(*op.Record) 137 + // HACK: this fixes some type conversions, mainly around integers 138 + // first we convert to json bytes 139 + b, err := json.Marshal(*op.Record) 133 140 if err != nil { 134 141 return nil, err 135 142 } 136 - out, err := data.UnmarshalJSON(j) 143 + // then we use atdata.UnmarshalJSON to convert it back to a map 144 + out, err := atdata.UnmarshalJSON(b) 137 145 if err != nil { 138 146 return nil, err 139 147 } 148 + // finally we can cast to a MarshalableMap 140 149 mm := MarshalableMap(out) 141 - nc, err := r.PutRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm) 150 + 151 + // HACK: if a record doesn't contain a $type, we can manually set it here based on the op's collection 152 + // i forget why this is actually necessary? 153 + if mm["$type"] == "" { 154 + mm["$type"] = op.Collection 155 + } 156 + 157 + nc, err := r.PutRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey), &mm) 142 158 if err != nil { 143 159 return nil, err 144 160 } 145 - d, err := data.MarshalCBOR(mm) 161 + 162 + d, err := atdata.MarshalCBOR(mm) 146 163 if err != nil { 147 164 return nil, err 148 165 } 166 + 149 167 entries = append(entries, models.Record{ 150 168 Did: urepo.Did, 151 169 CreatedAt: rm.clock.Next().String(), ··· 154 172 Cid: nc.String(), 155 173 Value: d, 156 174 }) 175 + 157 176 results = append(results, ApplyWriteResult{ 158 177 Type: to.StringPtr(OpTypeCreate.String()), 159 178 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), ··· 161 180 ValidationStatus: to.StringPtr("valid"), // TODO: obviously this might not be true atm lol 162 181 }) 163 182 case OpTypeDelete: 183 + // try to find the old record in the database 164 184 var old models.Record 165 - if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 185 + if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 166 186 return nil, err 167 187 } 188 + 189 + // TODO: this is really confusing, and looking at it i have no idea why i did this. below when we are doing deletes, we 190 + // check if `cid` here is nil to indicate if we should delete. that really doesn't make much sense and its super illogical 191 + // when reading this code. i dont feel like fixing right now though so 168 192 entries = append(entries, models.Record{ 169 193 Did: urepo.Did, 170 194 Nsid: op.Collection, 171 195 Rkey: *op.Rkey, 172 196 Value: old.Value, 173 197 }) 174 - err := r.DeleteRecord(context.TODO(), op.Collection+"/"+*op.Rkey) 198 + 199 + // delete the record from the repo 200 + err := r.DeleteRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey)) 175 201 if err != nil { 176 202 return nil, err 177 203 } 204 + 205 + // add a result for the delete 178 206 results = append(results, ApplyWriteResult{ 179 207 Type: to.StringPtr(OpTypeDelete.String()), 180 208 }) 181 209 case OpTypeUpdate: 182 - j, err := json.Marshal(*op.Record) 210 + // HACK: same hack as above for type fixes 211 + b, err := json.Marshal(*op.Record) 183 212 if err != nil { 184 213 return nil, err 185 214 } 186 - out, err := data.UnmarshalJSON(j) 215 + out, err := atdata.UnmarshalJSON(b) 187 216 if err != nil { 188 217 return nil, err 189 218 } 190 219 mm := MarshalableMap(out) 191 - nc, err := r.UpdateRecord(context.TODO(), op.Collection+"/"+*op.Rkey, &mm) 220 + 221 + nc, err := r.UpdateRecord(ctx, fmt.Sprintf("%s/%s", op.Collection, *op.Rkey), &mm) 192 222 if err != nil { 193 223 return nil, err 194 224 } 195 - d, err := data.MarshalCBOR(mm) 225 + 226 + d, err := atdata.MarshalCBOR(mm) 196 227 if err != nil { 197 228 return nil, err 198 229 } 230 + 199 231 entries = append(entries, models.Record{ 200 232 Did: urepo.Did, 201 233 CreatedAt: rm.clock.Next().String(), ··· 204 236 Cid: nc.String(), 205 237 Value: d, 206 238 }) 239 + 207 240 results = append(results, ApplyWriteResult{ 208 241 Type: to.StringPtr(OpTypeUpdate.String()), 209 242 Uri: to.StringPtr("at://" + urepo.Did + "/" + op.Collection + "/" + *op.Rkey), ··· 213 246 } 214 247 } 215 248 216 - newroot, rev, err := r.Commit(context.TODO(), urepo.SignFor) 249 + // commit and get the new root 250 + newroot, rev, err := r.Commit(ctx, urepo.SignFor) 217 251 if err != nil { 218 252 return nil, err 219 253 } 220 254 255 + for _, result := range results { 256 + if result.Type != nil { 257 + metrics.RepoOperations.WithLabelValues(*result.Type).Inc() 258 + } 259 + } 260 + 261 + // create a buffer for dumping our new cbor into 221 262 buf := new(bytes.Buffer) 222 263 264 + // first write the car header to the buffer 223 265 hb, err := cbor.DumpObject(&car.CarHeader{ 224 266 Roots: []cid.Cid{newroot}, 225 267 Version: 1, 226 268 }) 227 - 228 269 if _, err := carstore.LdWrite(buf, hb); err != nil { 229 270 return nil, err 230 271 } 231 272 232 - diffops, err := r.DiffSince(context.TODO(), rootcid) 273 + // get a diff of the changes to the repo 274 + diffops, err := r.DiffSince(ctx, rootcid) 233 275 if err != nil { 234 276 return nil, err 235 277 } 236 278 279 + // create the repo ops for the given diff 237 280 ops := make([]*atproto.SyncSubscribeRepos_RepoOp, 0, len(diffops)) 238 - 239 281 for _, op := range diffops { 240 282 var c cid.Cid 241 283 switch op.Op { ··· 264 306 }) 265 307 } 266 308 267 - blk, err := dbs.Get(context.TODO(), c) 309 + blk, err := dbs.Get(ctx, c) 268 310 if err != nil { 269 311 return nil, err 270 312 } 271 313 314 + // write the block to the buffer 272 315 if _, err := carstore.LdWrite(buf, blk.Cid().Bytes(), blk.RawData()); err != nil { 273 316 return nil, err 274 317 } 275 318 } 276 319 277 - for _, op := range dbs.GetLog() { 320 + // write the writelog to the buffer 321 + for _, op := range bs.GetWriteLog() { 278 322 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil { 279 323 return nil, err 280 324 } 281 325 } 282 326 327 + // blob blob blob blob blob :3 283 328 var blobs []lexutil.LexLink 284 329 for _, entry := range entries { 285 330 var cids []cid.Cid 331 + // whenever there is cid present, we know it's a create (dumb) 286 332 if entry.Cid != "" { 287 - if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{ 333 + if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{ 288 334 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 289 335 UpdateAll: true, 290 336 }}).Error; err != nil { 291 337 return nil, err 292 338 } 293 339 294 - cids, err = rm.incrementBlobRefs(urepo, entry.Value) 340 + // increment the given blob refs, yay 341 + cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value) 295 342 if err != nil { 296 343 return nil, err 297 344 } 298 345 } else { 299 - if err := rm.s.db.Delete(&entry, nil).Error; err != nil { 346 + // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey 347 + // is did + collection + rkey. i still really want to separate that out, or use a different type to make 348 + // this less confusing/easy to read. alas, its 2 am and yea no 349 + if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil { 300 350 return nil, err 301 351 } 302 - cids, err = rm.decrementBlobRefs(urepo, entry.Value) 352 + 353 + // TODO: 354 + cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value) 303 355 if err != nil { 304 356 return nil, err 305 357 } 306 358 } 307 359 360 + // add all the relevant blobs to the blobs list of blobs. blob ^.^ 308 361 for _, c := range cids { 309 362 blobs = append(blobs, lexutil.LexLink(c)) 310 363 } 311 364 } 312 365 313 - rm.s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 366 + // NOTE: using the request ctx seems a bit suss here, so using a background context. i'm not sure if this 367 + // runs sync or not 368 + rm.s.evtman.AddEvent(context.Background(), &events.XRPCStreamEvent{ 314 369 RepoCommit: &atproto.SyncSubscribeRepos_Commit{ 315 370 Repo: urepo.Did, 316 371 Blocks: buf.Bytes(), ··· 318 373 Rev: rev, 319 374 Since: &urepo.Rev, 320 375 Commit: lexutil.LexLink(newroot), 321 - Time: time.Now().Format(util.ISO8601), 376 + Time: time.Now().Format(time.RFC3339Nano), 322 377 Ops: ops, 323 378 TooBig: false, 324 379 }, 325 380 }) 326 381 327 - if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil { 382 + if err := rm.s.UpdateRepo(ctx, urepo.Did, newroot, rev); err != nil { 328 383 return nil, err 329 384 } 330 385 ··· 339 394 return results, nil 340 395 } 341 396 342 - func (rm *RepoMan) getRecordProof(urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) { 397 + // this is a fun little guy. to get a proof, we need to read the record out of the blockstore and record how we actually 398 + // got to the guy. we'll wrap a new blockstore in a recording blockstore, then return the log for proof 399 + func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collection, rkey string) (cid.Cid, []blocks.Block, error) { 343 400 c, err := cid.Cast(urepo.Root) 344 401 if err != nil { 345 402 return cid.Undef, nil, err 346 403 } 347 404 348 - dbs := blockstore.New(urepo.Did, rm.db) 349 - bs := util.NewLoggingBstore(dbs) 405 + dbs := rm.s.getBlockstore(urepo.Did) 406 + bs := recording_blockstore.New(dbs) 350 407 351 - r, err := repo.OpenRepo(context.TODO(), bs, c) 408 + r, err := repo.OpenRepo(ctx, bs, c) 352 409 if err != nil { 353 410 return cid.Undef, nil, err 354 411 } 355 412 356 - _, _, err = r.GetRecordBytes(context.TODO(), collection+"/"+rkey) 413 + _, _, err = r.GetRecordBytes(ctx, fmt.Sprintf("%s/%s", collection, rkey)) 357 414 if err != nil { 358 415 return cid.Undef, nil, err 359 416 } 360 417 361 - return c, bs.GetLoggedBlocks(), nil 418 + return c, bs.GetReadLog(), nil 362 419 } 363 420 364 - func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 421 + func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 365 422 cids, err := getBlobCidsFromCbor(cbor) 366 423 if err != nil { 367 424 return nil, err 368 425 } 369 426 370 427 for _, c := range cids { 371 - if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 428 + if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 372 429 return nil, err 373 430 } 374 431 } ··· 376 433 return cids, nil 377 434 } 378 435 379 - func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 436 + func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 380 437 cids, err := getBlobCidsFromCbor(cbor) 381 438 if err != nil { 382 439 return nil, err ··· 387 444 ID uint 388 445 Count int 389 446 } 390 - if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 447 + if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 391 448 return nil, err 392 449 } 393 450 451 + // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what 452 + // storage it is in, and clean up s3!!!! 394 453 if res.Count == 0 { 395 - if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 454 + if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 396 455 return nil, err 397 456 } 398 - if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 457 + if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 399 458 return nil, err 400 459 } 401 460 } ··· 409 468 func getBlobCidsFromCbor(cbor []byte) ([]cid.Cid, error) { 410 469 var cids []cid.Cid 411 470 412 - decoded, err := data.UnmarshalCBOR(cbor) 471 + decoded, err := atdata.UnmarshalCBOR(cbor) 413 472 if err != nil { 414 473 return nil, fmt.Errorf("error unmarshaling cbor: %w", err) 415 474 } 416 475 417 - var deepiter func(interface{}) error 418 - deepiter = func(item interface{}) error { 476 + var deepiter func(any) error 477 + deepiter = func(item any) error { 419 478 switch val := item.(type) { 420 - case map[string]interface{}: 479 + case map[string]any: 421 480 if val["$type"] == "blob" { 422 481 if ref, ok := val["ref"].(string); ok { 423 482 c, err := cid.Parse(ref) ··· 430 489 return deepiter(v) 431 490 } 432 491 } 433 - case []interface{}: 492 + case []any: 434 493 for _, v := range val { 435 494 deepiter(v) 436 495 }
+171 -319
server/server.go
··· 4 4 "bytes" 5 5 "context" 6 6 "crypto/ecdsa" 7 - "crypto/sha256" 8 7 "embed" 9 - "encoding/base64" 10 8 "errors" 11 9 "fmt" 12 10 "io" ··· 15 13 "net/smtp" 16 14 "os" 17 15 "path/filepath" 18 - "strings" 19 16 "sync" 20 17 "text/template" 21 18 "time" 22 19 23 - "github.com/Azure/go-autorest/autorest/to" 24 20 "github.com/aws/aws-sdk-go/aws" 25 21 "github.com/aws/aws-sdk-go/aws/credentials" 26 22 "github.com/aws/aws-sdk-go/aws/session" ··· 32 28 "github.com/bluesky-social/indigo/xrpc" 33 29 "github.com/domodwyer/mailyak/v3" 34 30 "github.com/go-playground/validator" 35 - "github.com/golang-jwt/jwt/v4" 36 31 "github.com/gorilla/sessions" 37 32 "github.com/haileyok/cocoon/identity" 38 33 "github.com/haileyok/cocoon/internal/db" 39 34 "github.com/haileyok/cocoon/internal/helpers" 40 35 "github.com/haileyok/cocoon/models" 41 - "github.com/haileyok/cocoon/oauth/client_manager" 36 + "github.com/haileyok/cocoon/oauth/client" 42 37 "github.com/haileyok/cocoon/oauth/constants" 43 - "github.com/haileyok/cocoon/oauth/dpop/dpop_manager" 38 + "github.com/haileyok/cocoon/oauth/dpop" 44 39 "github.com/haileyok/cocoon/oauth/provider" 45 40 "github.com/haileyok/cocoon/plc" 41 + "github.com/ipfs/go-cid" 42 + "github.com/labstack/echo-contrib/echoprometheus" 46 43 echo_session "github.com/labstack/echo-contrib/session" 47 44 "github.com/labstack/echo/v4" 48 45 "github.com/labstack/echo/v4/middleware" 49 46 slogecho "github.com/samber/slog-echo" 50 - "gitlab.com/yawning/secp256k1-voi" 51 - secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 47 + "gorm.io/driver/postgres" 52 48 "gorm.io/driver/sqlite" 53 49 "gorm.io/gorm" 54 50 ) ··· 58 54 ) 59 55 60 56 type S3Config struct { 61 - BackupsEnabled bool 62 - Endpoint string 63 - Region string 64 - Bucket string 65 - AccessKey string 66 - SecretKey string 57 + BackupsEnabled bool 58 + BlobstoreEnabled bool 59 + Endpoint string 60 + Region string 61 + Bucket string 62 + AccessKey string 63 + SecretKey string 64 + CDNUrl string 67 65 } 68 66 69 67 type Server struct { ··· 81 79 oauthProvider *provider.Provider 82 80 evtman *events.EventManager 83 81 passport *identity.Passport 82 + fallbackProxy string 83 + 84 + lastRequestCrawl time.Time 85 + requestCrawlMu sync.Mutex 84 86 85 87 dbName string 88 + dbType string 86 89 s3Config *S3Config 87 90 } 88 91 89 92 type Args struct { 93 + Logger *slog.Logger 94 + 90 95 Addr string 91 96 DbName string 92 - Logger *slog.Logger 97 + DbType string 98 + DatabaseURL string 93 99 Version string 94 100 Did string 95 101 Hostname string ··· 98 104 ContactEmail string 99 105 Relays []string 100 106 AdminPassword string 107 + RequireInvite bool 101 108 102 109 SmtpUser string 103 110 SmtpPass string ··· 109 116 S3Config *S3Config 110 117 111 118 SessionSecret string 119 + 120 + BlockstoreVariant BlockstoreVariant 121 + FallbackProxy string 112 122 } 113 123 114 124 type config struct { 115 - Version string 116 - Did string 117 - Hostname string 118 - ContactEmail string 119 - EnforcePeering bool 120 - Relays []string 121 - AdminPassword string 122 - SmtpEmail string 123 - SmtpName string 125 + Version string 126 + Did string 127 + Hostname string 128 + ContactEmail string 129 + EnforcePeering bool 130 + Relays []string 131 + AdminPassword string 132 + RequireInvite bool 133 + SmtpEmail string 134 + SmtpName string 135 + BlockstoreVariant BlockstoreVariant 136 + FallbackProxy string 124 137 } 125 138 126 139 type CustomValidator struct { ··· 197 210 return t.templates.ExecuteTemplate(w, name, data) 198 211 } 199 212 200 - func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 201 - return func(e echo.Context) error { 202 - username, password, ok := e.Request().BasicAuth() 203 - if !ok || username != "admin" || password != s.config.AdminPassword { 204 - return helpers.InputError(e, to.StringPtr("Unauthorized")) 205 - } 206 - 207 - if err := next(e); err != nil { 208 - e.Error(err) 209 - } 210 - 211 - return nil 212 - } 213 - } 214 - 215 - func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 216 - return func(e echo.Context) error { 217 - authheader := e.Request().Header.Get("authorization") 218 - if authheader == "" { 219 - return e.JSON(401, map[string]string{"error": "Unauthorized"}) 220 - } 221 - 222 - pts := strings.Split(authheader, " ") 223 - if len(pts) != 2 { 224 - return helpers.ServerError(e, nil) 225 - } 226 - 227 - // move on to oauth session middleware if this is a dpop token 228 - if pts[0] == "DPoP" { 229 - return next(e) 230 - } 231 - 232 - tokenstr := pts[1] 233 - token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 234 - claims, ok := token.Claims.(jwt.MapClaims) 235 - if !ok { 236 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 237 - } 238 - 239 - var did string 240 - var repo *models.RepoActor 241 - 242 - // service auth tokens 243 - lxm, hasLxm := claims["lxm"] 244 - if hasLxm { 245 - pts := strings.Split(e.Request().URL.String(), "/") 246 - if lxm != pts[len(pts)-1] { 247 - s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 248 - return helpers.InputError(e, nil) 249 - } 250 - 251 - maybeDid, ok := claims["iss"].(string) 252 - if !ok { 253 - s.logger.Error("no iss in service auth token", "error", err) 254 - return helpers.InputError(e, nil) 255 - } 256 - did = maybeDid 257 - 258 - maybeRepo, err := s.getRepoActorByDid(did) 259 - if err != nil { 260 - s.logger.Error("error fetching repo", "error", err) 261 - return helpers.ServerError(e, nil) 262 - } 263 - repo = maybeRepo 264 - } 265 - 266 - if token.Header["alg"] != "ES256K" { 267 - token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 268 - if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 269 - return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 270 - } 271 - return s.privateKey.Public(), nil 272 - }) 273 - if err != nil { 274 - s.logger.Error("error parsing jwt", "error", err) 275 - // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 276 - return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 277 - } 278 - 279 - if !token.Valid { 280 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 281 - } 282 - } else { 283 - kpts := strings.Split(tokenstr, ".") 284 - signingInput := kpts[0] + "." + kpts[1] 285 - hash := sha256.Sum256([]byte(signingInput)) 286 - sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 287 - if err != nil { 288 - s.logger.Error("error decoding signature bytes", "error", err) 289 - return helpers.ServerError(e, nil) 290 - } 291 - 292 - if len(sigBytes) != 64 { 293 - s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 294 - return helpers.ServerError(e, nil) 295 - } 296 - 297 - rBytes := sigBytes[:32] 298 - sBytes := sigBytes[32:] 299 - rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 300 - ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 301 - 302 - sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 303 - if err != nil { 304 - s.logger.Error("can't load private key", "error", err) 305 - return err 306 - } 307 - 308 - pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 309 - if !ok { 310 - s.logger.Error("error getting public key from sk") 311 - return helpers.ServerError(e, nil) 312 - } 313 - 314 - verified := pubKey.VerifyRaw(hash[:], rr, ss) 315 - if !verified { 316 - s.logger.Error("error verifying", "error", err) 317 - return helpers.ServerError(e, nil) 318 - } 319 - } 320 - 321 - isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 322 - scope, _ := claims["scope"].(string) 323 - 324 - if isRefresh && scope != "com.atproto.refresh" { 325 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 326 - } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 327 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 328 - } 329 - 330 - table := "tokens" 331 - if isRefresh { 332 - table = "refresh_tokens" 333 - } 334 - 335 - if isRefresh { 336 - type Result struct { 337 - Found bool 338 - } 339 - var result Result 340 - if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 341 - if err == gorm.ErrRecordNotFound { 342 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 343 - } 344 - 345 - s.logger.Error("error getting token from db", "error", err) 346 - return helpers.ServerError(e, nil) 347 - } 348 - 349 - if !result.Found { 350 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 351 - } 352 - } 353 - 354 - exp, ok := claims["exp"].(float64) 355 - if !ok { 356 - s.logger.Error("error getting iat from token") 357 - return helpers.ServerError(e, nil) 358 - } 359 - 360 - if exp < float64(time.Now().UTC().Unix()) { 361 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 362 - } 363 - 364 - if repo == nil { 365 - maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 366 - if err != nil { 367 - s.logger.Error("error fetching repo", "error", err) 368 - return helpers.ServerError(e, nil) 369 - } 370 - repo = maybeRepo 371 - did = repo.Repo.Did 372 - } 373 - 374 - e.Set("repo", repo) 375 - e.Set("did", did) 376 - e.Set("token", tokenstr) 377 - 378 - if err := next(e); err != nil { 379 - e.Error(err) 380 - } 381 - 382 - return nil 213 + func New(args *Args) (*Server, error) { 214 + if args.Logger == nil { 215 + args.Logger = slog.Default() 383 216 } 384 - } 385 217 386 - func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 387 - return func(e echo.Context) error { 388 - authheader := e.Request().Header.Get("authorization") 389 - if authheader == "" { 390 - return e.JSON(401, map[string]string{"error": "Unauthorized"}) 391 - } 392 - 393 - pts := strings.Split(authheader, " ") 394 - if len(pts) != 2 { 395 - return helpers.ServerError(e, nil) 396 - } 397 - 398 - if pts[0] != "DPoP" { 399 - return next(e) 400 - } 401 - 402 - accessToken := pts[1] 403 - 404 - nonce := s.oauthProvider.NextNonce() 405 - if nonce != "" { 406 - e.Response().Header().Set("DPoP-Nonce", nonce) 407 - e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 408 - } 409 - 410 - proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 411 - if err != nil { 412 - s.logger.Error("invalid dpop proof", "error", err) 413 - return helpers.InputError(e, to.StringPtr(err.Error())) 414 - } 415 - 416 - var oauthToken provider.OauthToken 417 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 418 - s.logger.Error("error finding access token in db", "error", err) 419 - return helpers.InputError(e, nil) 420 - } 421 - 422 - if oauthToken.Token == "" { 423 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 424 - } 218 + logger := args.Logger.With("name", "New") 425 219 426 - if *oauthToken.Parameters.DpopJkt != proof.JKT { 427 - s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 428 - return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 429 - } 430 - 431 - if time.Now().After(oauthToken.ExpiresAt) { 432 - return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 433 - } 434 - 435 - repo, err := s.getRepoActorByDid(oauthToken.Sub) 436 - if err != nil { 437 - s.logger.Error("could not find actor in db", "error", err) 438 - return helpers.ServerError(e, nil) 439 - } 440 - 441 - e.Set("repo", repo) 442 - e.Set("did", repo.Repo.Did) 443 - e.Set("token", accessToken) 444 - e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 445 - 446 - return next(e) 447 - } 448 - } 449 - 450 - func New(args *Args) (*Server, error) { 451 220 if args.Addr == "" { 452 221 return nil, fmt.Errorf("addr must be set") 453 222 } ··· 476 245 return nil, fmt.Errorf("admin password must be set") 477 246 } 478 247 479 - if args.Logger == nil { 480 - args.Logger = slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{})) 481 - } 482 - 483 248 if args.SessionSecret == "" { 484 249 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 485 250 } ··· 487 252 e := echo.New() 488 253 489 254 e.Pre(middleware.RemoveTrailingSlash()) 490 - e.Pre(slogecho.New(args.Logger)) 255 + e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 491 256 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 257 + e.Use(echoprometheus.NewMiddleware("cocoon")) 492 258 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 493 259 AllowOrigins: []string{"*"}, 494 260 AllowHeaders: []string{"*"}, ··· 534 300 IdleTimeout: 5 * time.Minute, 535 301 } 536 302 537 - gdb, err := gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 538 - if err != nil { 539 - return nil, err 303 + dbType := args.DbType 304 + if dbType == "" { 305 + dbType = "sqlite" 306 + } 307 + 308 + var gdb *gorm.DB 309 + var err error 310 + switch dbType { 311 + case "postgres": 312 + if args.DatabaseURL == "" { 313 + return nil, fmt.Errorf("database-url must be set when using postgres") 314 + } 315 + gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 316 + if err != nil { 317 + return nil, fmt.Errorf("failed to connect to postgres: %w", err) 318 + } 319 + logger.Info("connected to PostgreSQL database") 320 + default: 321 + gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 322 + if err != nil { 323 + return nil, fmt.Errorf("failed to open sqlite database: %w", err) 324 + } 325 + logger.Info("connected to SQLite database", "path", args.DbName) 540 326 } 541 327 dbw := db.NewDB(gdb) 542 328 ··· 579 365 var nonceSecret []byte 580 366 maybeSecret, err := os.ReadFile("nonce.secret") 581 367 if err != nil && !os.IsNotExist(err) { 582 - args.Logger.Error("error attempting to read nonce secret", "error", err) 368 + logger.Error("error attempting to read nonce secret", "error", err) 583 369 } else { 584 370 nonceSecret = maybeSecret 585 371 } ··· 593 379 plcClient: plcClient, 594 380 privateKey: &pkey, 595 381 config: &config{ 596 - Version: args.Version, 597 - Did: args.Did, 598 - Hostname: args.Hostname, 599 - ContactEmail: args.ContactEmail, 600 - EnforcePeering: false, 601 - Relays: args.Relays, 602 - AdminPassword: args.AdminPassword, 603 - SmtpName: args.SmtpName, 604 - SmtpEmail: args.SmtpEmail, 382 + Version: args.Version, 383 + Did: args.Did, 384 + Hostname: args.Hostname, 385 + ContactEmail: args.ContactEmail, 386 + EnforcePeering: false, 387 + Relays: args.Relays, 388 + AdminPassword: args.AdminPassword, 389 + RequireInvite: args.RequireInvite, 390 + SmtpName: args.SmtpName, 391 + SmtpEmail: args.SmtpEmail, 392 + BlockstoreVariant: args.BlockstoreVariant, 393 + FallbackProxy: args.FallbackProxy, 605 394 }, 606 395 evtman: events.NewEventManager(events.NewMemPersister()), 607 396 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 608 397 609 398 dbName: args.DbName, 399 + dbType: dbType, 610 400 s3Config: args.S3Config, 611 401 612 402 oauthProvider: provider.NewProvider(provider.Args{ 613 403 Hostname: args.Hostname, 614 - ClientManagerArgs: client_manager.Args{ 404 + ClientManagerArgs: client.ManagerArgs{ 615 405 Cli: oauthCli, 616 - Logger: args.Logger, 406 + Logger: args.Logger.With("component", "oauth-client-manager"), 617 407 }, 618 - DpopManagerArgs: dpop_manager.Args{ 408 + DpopManagerArgs: dpop.ManagerArgs{ 619 409 NonceSecret: nonceSecret, 620 410 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 621 411 OnNonceSecretCreated: func(newNonce []byte) { 622 412 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 623 - args.Logger.Error("error writing new nonce secret", "error", err) 413 + logger.Error("error writing new nonce secret", "error", err) 624 414 } 625 415 }, 626 - Logger: args.Logger, 416 + Logger: args.Logger.With("component", "dpop-manager"), 627 417 Hostname: args.Hostname, 628 418 }, 629 419 }), ··· 635 425 636 426 // TODO: should validate these args 637 427 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 638 - args.Logger.Warn("not enough smpt args were provided. mailing will not work for your server.") 428 + args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.") 639 429 } else { 640 430 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 641 431 mail.From(s.config.SmtpEmail) ··· 660 450 s.echo.GET("/", s.handleRoot) 661 451 s.echo.GET("/xrpc/_health", s.handleHealth) 662 452 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 453 + s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 663 454 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 664 455 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 665 456 s.echo.GET("/robots.txt", s.handleRobots) ··· 667 458 // public 668 459 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 669 460 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 670 - s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 671 461 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 672 462 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 463 + s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 673 464 674 465 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 675 466 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) ··· 684 475 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 685 476 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 686 477 478 + // labels 479 + s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 480 + 687 481 // account 688 482 s.echo.GET("/account", s.handleAccount) 689 483 s.echo.POST("/account/revoke", s.handleAccountRevoke) ··· 704 498 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 705 499 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 706 500 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 501 + s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 707 502 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 503 + s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 504 + s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 505 + s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 708 506 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 709 507 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 710 508 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE ··· 713 511 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 714 512 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 715 513 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 514 + s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 515 + s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 516 + s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 517 + s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 716 518 717 519 // repo 520 + s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 718 521 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 719 522 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 720 523 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) ··· 725 528 // stupid silly endpoints 726 529 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 727 530 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 531 + s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 532 + 533 + // admin routes 534 + s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 535 + s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 728 536 729 537 // are there any routes that we should be allowing without auth? i dont think so but idk 730 538 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 731 539 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 732 - 733 - // admin routes 734 - s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 735 - s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 736 540 } 737 541 738 542 func (s *Server) Serve(ctx context.Context) error { 543 + logger := s.logger.With("name", "Serve") 544 + 739 545 s.addRoutes() 740 546 741 - s.logger.Info("migrating...") 547 + logger.Info("migrating...") 742 548 743 549 s.db.AutoMigrate( 744 550 &models.Actor{}, ··· 750 556 &models.Record{}, 751 557 &models.Blob{}, 752 558 &models.BlobPart{}, 559 + &models.ReservedKey{}, 753 560 &provider.OauthToken{}, 754 561 &provider.OauthAuthorizationRequest{}, 755 562 ) 756 563 757 - s.logger.Info("starting cocoon") 564 + logger.Info("starting cocoon") 758 565 759 566 go func() { 760 567 if err := s.httpd.ListenAndServe(); err != nil { ··· 764 571 765 572 go s.backupRoutine() 766 573 574 + go func() { 575 + if err := s.requestCrawl(ctx); err != nil { 576 + logger.Error("error requesting crawls", "err", err) 577 + } 578 + }() 579 + 580 + <-ctx.Done() 581 + 582 + fmt.Println("shut down") 583 + 584 + return nil 585 + } 586 + 587 + func (s *Server) requestCrawl(ctx context.Context) error { 588 + logger := s.logger.With("component", "request-crawl") 589 + s.requestCrawlMu.Lock() 590 + defer s.requestCrawlMu.Unlock() 591 + 592 + logger.Info("requesting crawl with configured relays") 593 + 594 + if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 595 + return fmt.Errorf("a crawl request has already been made within the last minute") 596 + } 597 + 767 598 for _, relay := range s.config.Relays { 599 + logger := logger.With("relay", relay) 600 + logger.Info("requesting crawl from relay") 768 601 cli := xrpc.Client{Host: relay} 769 - atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 602 + if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 770 603 Hostname: s.config.Hostname, 771 - }) 604 + }); err != nil { 605 + logger.Error("error requesting crawl", "err", err) 606 + } else { 607 + logger.Info("crawl requested successfully") 608 + } 772 609 } 773 610 774 - <-ctx.Done() 775 - 776 - fmt.Println("shut down") 611 + s.lastRequestCrawl = time.Now() 777 612 778 613 return nil 779 614 } 780 615 781 616 func (s *Server) doBackup() { 617 + logger := s.logger.With("name", "doBackup") 618 + 619 + if s.dbType == "postgres" { 620 + logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)") 621 + return 622 + } 623 + 782 624 start := time.Now() 783 625 784 - s.logger.Info("beginning backup to s3...") 626 + logger.Info("beginning backup to s3...") 785 627 786 628 var buf bytes.Buffer 787 629 if err := func() error { 788 - s.logger.Info("reading database bytes...") 630 + logger.Info("reading database bytes...") 789 631 s.db.Lock() 790 632 defer s.db.Unlock() 791 633 ··· 801 643 802 644 return nil 803 645 }(); err != nil { 804 - s.logger.Error("error backing up database", "error", err) 646 + logger.Error("error backing up database", "error", err) 805 647 return 806 648 } 807 649 808 650 if err := func() error { 809 - s.logger.Info("sending to s3...") 651 + logger.Info("sending to s3...") 810 652 811 653 currTime := time.Now().Format("2006-01-02_15-04-05") 812 654 key := "cocoon-backup-" + currTime + ".db" ··· 836 678 return fmt.Errorf("error uploading file to s3: %w", err) 837 679 } 838 680 839 - s.logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 681 + logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 840 682 841 683 return nil 842 684 }(); err != nil { 843 - s.logger.Error("error uploading database backup", "error", err) 685 + logger.Error("error uploading database backup", "error", err) 844 686 return 845 687 } 846 688 ··· 848 690 } 849 691 850 692 func (s *Server) backupRoutine() { 693 + logger := s.logger.With("name", "backupRoutine") 694 + 851 695 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 852 696 return 853 697 } 854 698 855 699 if s.s3Config.Region == "" { 856 - s.logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 700 + logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 857 701 return 858 702 } 859 703 860 704 if s.s3Config.Bucket == "" { 861 - s.logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 705 + logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 862 706 return 863 707 } 864 708 865 709 if s.s3Config.AccessKey == "" { 866 - s.logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 710 + logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 867 711 return 868 712 } 869 713 870 714 if s.s3Config.SecretKey == "" { 871 - s.logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 715 + logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 872 716 return 873 717 } 874 718 ··· 894 738 go s.doBackup() 895 739 } 896 740 } 741 + 742 + func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 743 + if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 744 + return err 745 + } 746 + 747 + return nil 748 + }
+91
server/service_auth.go
··· 1 + package server 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "strings" 7 + 8 + "github.com/bluesky-social/indigo/atproto/atcrypto" 9 + "github.com/bluesky-social/indigo/atproto/identity" 10 + atproto_identity "github.com/bluesky-social/indigo/atproto/identity" 11 + "github.com/bluesky-social/indigo/atproto/syntax" 12 + "github.com/golang-jwt/jwt/v4" 13 + ) 14 + 15 + type ES256KSigningMethod struct { 16 + alg string 17 + } 18 + 19 + func (m *ES256KSigningMethod) Alg() string { 20 + return m.alg 21 + } 22 + 23 + func (m *ES256KSigningMethod) Verify(signingString string, signature string, key interface{}) error { 24 + signatureBytes, err := jwt.DecodeSegment(signature) 25 + if err != nil { 26 + return err 27 + } 28 + return key.(atcrypto.PublicKey).HashAndVerifyLenient([]byte(signingString), signatureBytes) 29 + } 30 + 31 + func (m *ES256KSigningMethod) Sign(signingString string, key interface{}) (string, error) { 32 + return "", fmt.Errorf("unimplemented") 33 + } 34 + 35 + func init() { 36 + ES256K := ES256KSigningMethod{alg: "ES256K"} 37 + jwt.RegisterSigningMethod(ES256K.Alg(), func() jwt.SigningMethod { 38 + return &ES256K 39 + }) 40 + } 41 + 42 + func (s *Server) validateServiceAuth(ctx context.Context, rawToken string, nsid string) (string, error) { 43 + token := strings.TrimSpace(rawToken) 44 + 45 + parsedToken, err := jwt.ParseWithClaims(token, jwt.MapClaims{}, func(token *jwt.Token) (interface{}, error) { 46 + did := syntax.DID(token.Claims.(jwt.MapClaims)["iss"].(string)) 47 + didDoc, err := s.passport.FetchDoc(ctx, did.String()); 48 + if err != nil { 49 + return nil, fmt.Errorf("unable to resolve did %s: %s", did, err) 50 + } 51 + 52 + verificationMethods := make([]atproto_identity.DocVerificationMethod, len(didDoc.VerificationMethods)) 53 + for i, verificationMethod := range didDoc.VerificationMethods { 54 + verificationMethods[i] = atproto_identity.DocVerificationMethod{ 55 + ID: verificationMethod.Id, 56 + Type: verificationMethod.Type, 57 + PublicKeyMultibase: verificationMethod.PublicKeyMultibase, 58 + Controller: verificationMethod.Controller, 59 + } 60 + } 61 + services := make([]atproto_identity.DocService, len(didDoc.Service)) 62 + for i, service := range didDoc.Service { 63 + services[i] = atproto_identity.DocService{ 64 + ID: service.Id, 65 + Type: service.Type, 66 + ServiceEndpoint: service.ServiceEndpoint, 67 + } 68 + } 69 + parsedIdentity := atproto_identity.ParseIdentity(&identity.DIDDocument{ 70 + DID: did, 71 + AlsoKnownAs: didDoc.AlsoKnownAs, 72 + VerificationMethod: verificationMethods, 73 + Service: services, 74 + }) 75 + 76 + key, err := parsedIdentity.PublicKey() 77 + if err != nil { 78 + return nil, fmt.Errorf("signing key not found for did %s: %s", did, err) 79 + } 80 + return key, nil 81 + }) 82 + if err != nil { 83 + return "", fmt.Errorf("invalid token: %s", err) 84 + } 85 + 86 + claims := parsedToken.Claims.(jwt.MapClaims) 87 + if claims["lxm"] != nsid { 88 + return "", fmt.Errorf("bad jwt lexicon method (\"lxm\"). must match: %s", nsid) 89 + } 90 + return claims["iss"].(string), nil 91 + }
+4 -3
server/session.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 4 5 "time" 5 6 6 7 "github.com/golang-jwt/jwt/v4" ··· 13 14 RefreshToken string 14 15 } 15 16 16 - func (s *Server) createSession(repo *models.Repo) (*Session, error) { 17 + func (s *Server) createSession(ctx context.Context, repo *models.Repo) (*Session, error) { 17 18 now := time.Now() 18 19 accexp := now.Add(3 * time.Hour) 19 20 refexp := now.Add(7 * 24 * time.Hour) ··· 49 50 return nil, err 50 51 } 51 52 52 - if err := s.db.Create(&models.Token{ 53 + if err := s.db.Create(ctx, &models.Token{ 53 54 Token: accessString, 54 55 Did: repo.Did, 55 56 RefreshToken: refreshString, ··· 59 60 return nil, err 60 61 } 61 62 62 - if err := s.db.Create(&models.RefreshToken{ 63 + if err := s.db.Create(ctx, &models.RefreshToken{ 63 64 Token: refreshString, 64 65 Did: repo.Did, 65 66 CreatedAt: now,
+5 -4
server/templates/account.html
··· 24 24 </div> 25 25 {{ else }} {{ range .Tokens }} 26 26 <div class="base-container"> 27 - <h4>{{ .ClientId }}</h4> 28 - <p>Created: {{ .CreatedAt }}</p> 29 - <p>Updated: {{ .UpdatedAt }}</p> 30 - <p>Expires: {{ .ExpiresAt }}</p> 27 + <h4>{{ .ClientName }}</h4> 28 + <p>Session Age: {{ .Age}}</p> 29 + <p>Last Updated: {{ .LastUpdated }} ago</p> 30 + <p>Expires In: {{ .ExpiresIn }}</p> 31 + <p>IP Address: {{ .Ip }}</p> 31 32 <form action="/account/revoke" method="post"> 32 33 <input type="hidden" name="token" value="{{ .Token }}" /> 33 34 <button type="submit" value="">Revoke</button>
+4
server/templates/signin.html
··· 26 26 type="password" 27 27 placeholder="Password" 28 28 /> 29 + {{ if .flashes.tokenrequired }} 30 + <br /> 31 + <input name="token" id="token" placeholder="Enter your 2FA token" /> 32 + {{ end }} 29 33 <input name="query_params" type="hidden" value="{{ .QueryParams }}" /> 30 34 <button class="primary" type="submit" value="Login">Login</button> 31 35 </form>
+137
sqlite_blockstore/sqlite_blockstore.go
··· 1 + package sqlite_blockstore 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + 7 + "github.com/bluesky-social/indigo/atproto/syntax" 8 + "github.com/haileyok/cocoon/internal/db" 9 + "github.com/haileyok/cocoon/models" 10 + blocks "github.com/ipfs/go-block-format" 11 + "github.com/ipfs/go-cid" 12 + "gorm.io/gorm/clause" 13 + ) 14 + 15 + type SqliteBlockstore struct { 16 + db *db.DB 17 + did string 18 + readonly bool 19 + inserts map[cid.Cid]blocks.Block 20 + } 21 + 22 + func New(did string, db *db.DB) *SqliteBlockstore { 23 + return &SqliteBlockstore{ 24 + did: did, 25 + db: db, 26 + readonly: false, 27 + inserts: map[cid.Cid]blocks.Block{}, 28 + } 29 + } 30 + 31 + func NewReadOnly(did string, db *db.DB) *SqliteBlockstore { 32 + return &SqliteBlockstore{ 33 + did: did, 34 + db: db, 35 + readonly: true, 36 + inserts: map[cid.Cid]blocks.Block{}, 37 + } 38 + } 39 + 40 + func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) { 41 + var block models.Block 42 + 43 + maybeBlock, ok := bs.inserts[cid] 44 + if ok { 45 + return maybeBlock, nil 46 + } 47 + 48 + if err := bs.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 49 + return nil, err 50 + } 51 + 52 + b, err := blocks.NewBlockWithCid(block.Value, cid) 53 + if err != nil { 54 + return nil, err 55 + } 56 + 57 + return b, nil 58 + } 59 + 60 + func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error { 61 + bs.inserts[block.Cid()] = block 62 + 63 + if bs.readonly { 64 + return nil 65 + } 66 + 67 + b := models.Block{ 68 + Did: bs.did, 69 + Cid: block.Cid().Bytes(), 70 + Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 71 + Value: block.RawData(), 72 + } 73 + 74 + if err := bs.db.Create(ctx, &b, []clause.Expression{clause.OnConflict{ 75 + Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 76 + UpdateAll: true, 77 + }}).Error; err != nil { 78 + return err 79 + } 80 + 81 + return nil 82 + } 83 + 84 + func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error { 85 + panic("not implemented") 86 + } 87 + 88 + func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) { 89 + panic("not implemented") 90 + } 91 + 92 + func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) { 93 + panic("not implemented") 94 + } 95 + 96 + func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 + tx := bs.db.BeginDangerously(ctx) 98 + 99 + for _, block := range blocks { 100 + bs.inserts[block.Cid()] = block 101 + 102 + if bs.readonly { 103 + continue 104 + } 105 + 106 + b := models.Block{ 107 + Did: bs.did, 108 + Cid: block.Cid().Bytes(), 109 + Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 110 + Value: block.RawData(), 111 + } 112 + 113 + if err := tx.Clauses(clause.OnConflict{ 114 + Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 115 + UpdateAll: true, 116 + }).Create(&b).Error; err != nil { 117 + tx.Rollback() 118 + return err 119 + } 120 + } 121 + 122 + if bs.readonly { 123 + return nil 124 + } 125 + 126 + tx.Commit() 127 + 128 + return nil 129 + } 130 + 131 + func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) { 132 + return nil, fmt.Errorf("iteration not allowed on sqlite blockstore") 133 + } 134 + 135 + func (bs *SqliteBlockstore) HashOnRead(enabled bool) { 136 + panic("not implemented") 137 + }
+1 -1
test.go
··· 32 32 33 33 u.Path = "xrpc/com.atproto.sync.subscribeRepos" 34 34 conn, _, err := dialer.Dial(u.String(), http.Header{ 35 - "User-Agent": []string{fmt.Sprintf("hot-topic/0.0.0")}, 35 + "User-Agent": []string{"cocoon-test/0.0.0"}, 36 36 }) 37 37 if err != nil { 38 38 return fmt.Errorf("subscribing to firehose failed (dialing): %w", err)