forked from hailey.at/cocoon
An atproto PDS written in Go

Compare changes

Choose any two refs to compare.

+63 -59
README.md
··· 5 5 6 6 Cocoon is a PDS implementation in Go. It is highly experimental, and is not ready for any production use. 7 7 8 - ### Impmlemented Endpoints 8 + ## Implemented Endpoints 9 9 10 10 > [!NOTE] 11 - Just because something is implemented doesn't mean it is finisehd. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that. 11 + Just because something is implemented doesn't mean it is finished. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that. 12 12 13 - #### Identity 14 - - [ ] com.atproto.identity.getRecommendedDidCredentials 15 - - [ ] com.atproto.identity.requestPlcOperationSignature 16 - - [x] com.atproto.identity.resolveHandle 17 - - [ ] com.atproto.identity.signPlcOperation 18 - - [ ] com.atproto.identity.submitPlcOperatioin 19 - - [x] com.atproto.identity.updateHandle 13 + ### Identity 20 14 21 - #### Repo 22 - - [x] com.atproto.repo.applyWrites 23 - - [x] com.atproto.repo.createRecord 24 - - [x] com.atproto.repo.putRecord 25 - - [x] com.atproto.repo.deleteRecord 26 - - [x] com.atproto.repo.describeRepo 27 - - [x] com.atproto.repo.getRecord 28 - - [ ] com.atproto.repo.importRepo 29 - - [x] com.atproto.repo.listRecords 30 - - [ ] com.atproto.repo.listMissingBlobs 15 + - [ ] `com.atproto.identity.getRecommendedDidCredentials` 16 + - [ ] `com.atproto.identity.requestPlcOperationSignature` 17 + - [x] `com.atproto.identity.resolveHandle` 18 + - [ ] `com.atproto.identity.signPlcOperation` 19 + - [ ] `com.atproto.identity.submitPlcOperation` 20 + - [x] `com.atproto.identity.updateHandle` 31 21 32 - #### Server 33 - - [ ] com.atproto.server.activateAccount 34 - - [ ] com.atproto.server.checkAccountStatus 35 - - [x] com.atproto.server.confirmEmail 36 - - [x] com.atproto.server.createAccount 37 - - [x] com.atproto.server.createInviteCode 38 - - [x] com.atproto.server.createInviteCodes 39 - - [ ] com.atproto.server.deactivateAccount 40 - - [ ] com.atproto.server.deleteAccount 41 - - [x] com.atproto.server.deleteSession 42 - - [x] com.atproto.server.describeServer 43 - - [ ] com.atproto.server.getAccountInviteCodes 44 - - [ ] com.atproto.server.getServiceAuth 45 - - ~[ ] com.atproto.server.listAppPasswords~ - not going to add app passwords 46 - - [x] com.atproto.server.refreshSession 47 - - [ ] com.atproto.server.requestAccountDelete 48 - - [x] com.atproto.server.requestEmailConfirmation 49 - - [x] com.atproto.server.requestEmailUpdate 50 - - [x] com.atproto.server.requestPasswordReset 51 - - [ ] com.atproto.server.reserveSigningKey 52 - - [x] com.atproto.server.resetPassword 53 - - ~[ ] com.atproto.server.revokeAppPassword~ - not going to add app passwords 54 - - [x] com.atproto.server.updateEmail 22 + ### Repo 55 23 56 - #### Sync 57 - - [x] com.atproto.sync.getBlob 58 - - [x] com.atproto.sync.getBlocks 59 - - [x] com.atproto.sync.getLatestCommit 60 - - [x] com.atproto.sync.getRecord 61 - - [x] com.atproto.sync.getRepoStatus 62 - - [x] com.atproto.sync.getRepo 63 - - [x] com.atproto.sync.listBlobs 64 - - [x] com.atproto.sync.listRepos 65 - - ~[ ] com.atproto.sync.notifyOfUpdate~ - BGS doesn't even have this implemented lol 66 - - [x] com.atproto.sync.requestCrawl 67 - - [x] com.atproto.sync.subscribeRepos 24 + - [x] `com.atproto.repo.applyWrites` 25 + - [x] `com.atproto.repo.createRecord` 26 + - [x] `com.atproto.repo.putRecord` 27 + - [x] `com.atproto.repo.deleteRecord` 28 + - [x] `com.atproto.repo.describeRepo` 29 + - [x] `com.atproto.repo.getRecord` 30 + - [x] `com.atproto.repo.importRepo` (Works "okay". You still have to handle PLC operations on your own when migrating. Use with extreme caution.) 31 + - [x] `com.atproto.repo.listRecords` 32 + - [ ] `com.atproto.repo.listMissingBlobs` 33 + 34 + ### Server 35 + 36 + - [ ] `com.atproto.server.activateAccount` 37 + - [x] `com.atproto.server.checkAccountStatus` 38 + - [x] `com.atproto.server.confirmEmail` 39 + - [x] `com.atproto.server.createAccount` 40 + - [x] `com.atproto.server.createInviteCode` 41 + - [x] `com.atproto.server.createInviteCodes` 42 + - [ ] `com.atproto.server.deactivateAccount` 43 + - [ ] `com.atproto.server.deleteAccount` 44 + - [x] `com.atproto.server.deleteSession` 45 + - [x] `com.atproto.server.describeServer` 46 + - [ ] `com.atproto.server.getAccountInviteCodes` 47 + - [ ] `com.atproto.server.getServiceAuth` 48 + - ~~[ ] `com.atproto.server.listAppPasswords`~~ - not going to add app passwords 49 + - [x] `com.atproto.server.refreshSession` 50 + - [ ] `com.atproto.server.requestAccountDelete` 51 + - [x] `com.atproto.server.requestEmailConfirmation` 52 + - [x] `com.atproto.server.requestEmailUpdate` 53 + - [x] `com.atproto.server.requestPasswordReset` 54 + - [ ] `com.atproto.server.reserveSigningKey` 55 + - [x] `com.atproto.server.resetPassword` 56 + - ~~[] `com.atproto.server.revokeAppPassword`~~ - not going to add app passwords 57 + - [x] `com.atproto.server.updateEmail` 68 58 69 - #### Other 70 - - [ ] com.atproto.label.queryLabels 71 - - [ ] com.atproto.moderation.createReport 72 - - [x] app.bsky.actor.getPreferences 73 - - [x] app.bsky.actor.putPreferences 59 + ### Sync 74 60 61 + - [x] `com.atproto.sync.getBlob` 62 + - [x] `com.atproto.sync.getBlocks` 63 + - [x] `com.atproto.sync.getLatestCommit` 64 + - [x] `com.atproto.sync.getRecord` 65 + - [x] `com.atproto.sync.getRepoStatus` 66 + - [x] `com.atproto.sync.getRepo` 67 + - [x] `com.atproto.sync.listBlobs` 68 + - [x] `com.atproto.sync.listRepos` 69 + - ~~[ ] `com.atproto.sync.notifyOfUpdate`~~ - BGS doesn't even have this implemented lol 70 + - [x] `com.atproto.sync.requestCrawl` 71 + - [x] `com.atproto.sync.subscribeRepos` 72 + 73 + ### Other 74 + 75 + - [ ] `com.atproto.label.queryLabels` 76 + - [x] `com.atproto.moderation.createReport` (Note: this should be handled by proxying, not actually implemented in the PDS) 77 + - [x] `app.bsky.actor.getPreferences` 78 + - [x] `app.bsky.actor.putPreferences` 75 79 76 80 ## License 77 81
-163
blockstore/blockstore.go
··· 1 - package blockstore 2 - 3 - import ( 4 - "context" 5 - "fmt" 6 - 7 - "github.com/bluesky-social/indigo/atproto/syntax" 8 - "github.com/haileyok/cocoon/internal/db" 9 - "github.com/haileyok/cocoon/models" 10 - blocks "github.com/ipfs/go-block-format" 11 - "github.com/ipfs/go-cid" 12 - "gorm.io/gorm/clause" 13 - ) 14 - 15 - type SqliteBlockstore struct { 16 - db *db.DB 17 - did string 18 - readonly bool 19 - inserts map[cid.Cid]blocks.Block 20 - } 21 - 22 - func New(did string, db *db.DB) *SqliteBlockstore { 23 - return &SqliteBlockstore{ 24 - did: did, 25 - db: db, 26 - readonly: false, 27 - inserts: map[cid.Cid]blocks.Block{}, 28 - } 29 - } 30 - 31 - func NewReadOnly(did string, db *db.DB) *SqliteBlockstore { 32 - return &SqliteBlockstore{ 33 - did: did, 34 - db: db, 35 - readonly: true, 36 - inserts: map[cid.Cid]blocks.Block{}, 37 - } 38 - } 39 - 40 - func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) { 41 - var block models.Block 42 - 43 - maybeBlock, ok := bs.inserts[cid] 44 - if ok { 45 - return maybeBlock, nil 46 - } 47 - 48 - if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 49 - return nil, err 50 - } 51 - 52 - b, err := blocks.NewBlockWithCid(block.Value, cid) 53 - if err != nil { 54 - return nil, err 55 - } 56 - 57 - return b, nil 58 - } 59 - 60 - func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error { 61 - bs.inserts[block.Cid()] = block 62 - 63 - if bs.readonly { 64 - return nil 65 - } 66 - 67 - b := models.Block{ 68 - Did: bs.did, 69 - Cid: block.Cid().Bytes(), 70 - Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 71 - Value: block.RawData(), 72 - } 73 - 74 - if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{ 75 - Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 76 - UpdateAll: true, 77 - }}).Error; err != nil { 78 - return err 79 - } 80 - 81 - return nil 82 - } 83 - 84 - func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error { 85 - panic("not implemented") 86 - } 87 - 88 - func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) { 89 - panic("not implemented") 90 - } 91 - 92 - func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) { 93 - panic("not implemented") 94 - } 95 - 96 - func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 - tx := bs.db.BeginDangerously() 98 - 99 - for _, block := range blocks { 100 - bs.inserts[block.Cid()] = block 101 - 102 - if bs.readonly { 103 - continue 104 - } 105 - 106 - b := models.Block{ 107 - Did: bs.did, 108 - Cid: block.Cid().Bytes(), 109 - Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 110 - Value: block.RawData(), 111 - } 112 - 113 - if err := tx.Clauses(clause.OnConflict{ 114 - Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 115 - UpdateAll: true, 116 - }).Create(&b).Error; err != nil { 117 - tx.Rollback() 118 - return err 119 - } 120 - } 121 - 122 - if bs.readonly { 123 - return nil 124 - } 125 - 126 - tx.Commit() 127 - 128 - return nil 129 - } 130 - 131 - func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) { 132 - panic("not implemented") 133 - } 134 - 135 - func (bs *SqliteBlockstore) HashOnRead(enabled bool) { 136 - panic("not implemented") 137 - } 138 - 139 - func (bs *SqliteBlockstore) UpdateRepo(ctx context.Context, root cid.Cid, rev string) error { 140 - if err := bs.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, bs.did).Error; err != nil { 141 - return err 142 - } 143 - 144 - return nil 145 - } 146 - 147 - func (bs *SqliteBlockstore) Execute(ctx context.Context) error { 148 - if !bs.readonly { 149 - return fmt.Errorf("blockstore was not readonly") 150 - } 151 - 152 - bs.readonly = false 153 - for _, b := range bs.inserts { 154 - bs.Put(ctx, b) 155 - } 156 - bs.readonly = true 157 - 158 - return nil 159 - } 160 - 161 - func (bs *SqliteBlockstore) GetLog() map[cid.Cid]blocks.Block { 162 - return bs.inserts 163 - }
-186
cmd/admin/main.go
··· 1 - package main 2 - 3 - import ( 4 - "crypto/ecdsa" 5 - "crypto/elliptic" 6 - "crypto/rand" 7 - "encoding/json" 8 - "fmt" 9 - "os" 10 - "time" 11 - 12 - "github.com/bluesky-social/indigo/atproto/crypto" 13 - "github.com/bluesky-social/indigo/atproto/syntax" 14 - "github.com/haileyok/cocoon/internal/helpers" 15 - "github.com/lestrrat-go/jwx/v2/jwk" 16 - "github.com/urfave/cli/v2" 17 - "golang.org/x/crypto/bcrypt" 18 - "gorm.io/driver/sqlite" 19 - "gorm.io/gorm" 20 - ) 21 - 22 - func main() { 23 - app := cli.App{ 24 - Name: "admin", 25 - Commands: cli.Commands{ 26 - runCreateRotationKey, 27 - runCreatePrivateJwk, 28 - runCreateInviteCode, 29 - runResetPassword, 30 - }, 31 - ErrWriter: os.Stdout, 32 - } 33 - 34 - app.Run(os.Args) 35 - } 36 - 37 - var runCreateRotationKey = &cli.Command{ 38 - Name: "create-rotation-key", 39 - Usage: "creates a rotation key for your pds", 40 - Flags: []cli.Flag{ 41 - &cli.StringFlag{ 42 - Name: "out", 43 - Required: true, 44 - Usage: "output file for your rotation key", 45 - }, 46 - }, 47 - Action: func(cmd *cli.Context) error { 48 - key, err := crypto.GeneratePrivateKeyK256() 49 - if err != nil { 50 - return err 51 - } 52 - 53 - bytes := key.Bytes() 54 - 55 - if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil { 56 - return err 57 - } 58 - 59 - return nil 60 - }, 61 - } 62 - 63 - var runCreatePrivateJwk = &cli.Command{ 64 - Name: "create-private-jwk", 65 - Usage: "creates a private jwk for your pds", 66 - Flags: []cli.Flag{ 67 - &cli.StringFlag{ 68 - Name: "out", 69 - Required: true, 70 - Usage: "output file for your jwk", 71 - }, 72 - }, 73 - Action: func(cmd *cli.Context) error { 74 - privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 75 - if err != nil { 76 - return err 77 - } 78 - 79 - key, err := jwk.FromRaw(privKey) 80 - if err != nil { 81 - return err 82 - } 83 - 84 - kid := fmt.Sprintf("%d", time.Now().Unix()) 85 - 86 - if err := key.Set(jwk.KeyIDKey, kid); err != nil { 87 - return err 88 - } 89 - 90 - b, err := json.Marshal(key) 91 - if err != nil { 92 - return err 93 - } 94 - 95 - if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil { 96 - return err 97 - } 98 - 99 - return nil 100 - }, 101 - } 102 - 103 - var runCreateInviteCode = &cli.Command{ 104 - Name: "create-invite-code", 105 - Usage: "creates an invite code", 106 - Flags: []cli.Flag{ 107 - &cli.StringFlag{ 108 - Name: "for", 109 - Usage: "optional did to assign the invite code to", 110 - }, 111 - &cli.IntFlag{ 112 - Name: "uses", 113 - Usage: "number of times the invite code can be used", 114 - Value: 1, 115 - }, 116 - }, 117 - Action: func(cmd *cli.Context) error { 118 - db, err := newDb() 119 - if err != nil { 120 - return err 121 - } 122 - 123 - forDid := "did:plc:123" 124 - if cmd.String("for") != "" { 125 - did, err := syntax.ParseDID(cmd.String("for")) 126 - if err != nil { 127 - return err 128 - } 129 - 130 - forDid = did.String() 131 - } 132 - 133 - uses := cmd.Int("uses") 134 - 135 - code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8)) 136 - 137 - if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil { 138 - return err 139 - } 140 - 141 - fmt.Printf("New invite code created with %d uses: %s\n", uses, code) 142 - 143 - return nil 144 - }, 145 - } 146 - 147 - var runResetPassword = &cli.Command{ 148 - Name: "reset-password", 149 - Usage: "resets a password", 150 - Flags: []cli.Flag{ 151 - &cli.StringFlag{ 152 - Name: "did", 153 - Usage: "did of the user who's password you want to reset", 154 - }, 155 - }, 156 - Action: func(cmd *cli.Context) error { 157 - db, err := newDb() 158 - if err != nil { 159 - return err 160 - } 161 - 162 - didStr := cmd.String("did") 163 - did, err := syntax.ParseDID(didStr) 164 - if err != nil { 165 - return err 166 - } 167 - 168 - newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12)) 169 - hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10) 170 - if err != nil { 171 - return err 172 - } 173 - 174 - if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil { 175 - return err 176 - } 177 - 178 - fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass) 179 - 180 - return nil 181 - }, 182 - } 183 - 184 - func newDb() (*gorm.DB, error) { 185 - return gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 186 - }
+183 -3
cmd/cocoon/main.go
··· 1 1 package main 2 2 3 3 import ( 4 + "crypto/ecdsa" 5 + "crypto/elliptic" 6 + "crypto/rand" 7 + "encoding/json" 4 8 "fmt" 5 9 "os" 10 + "time" 6 11 12 + "github.com/bluesky-social/indigo/atproto/crypto" 13 + "github.com/bluesky-social/indigo/atproto/syntax" 14 + "github.com/haileyok/cocoon/internal/helpers" 7 15 "github.com/haileyok/cocoon/server" 8 16 _ "github.com/joho/godotenv/autoload" 17 + "github.com/lestrrat-go/jwx/v2/jwk" 9 18 "github.com/urfave/cli/v2" 19 + "golang.org/x/crypto/bcrypt" 20 + "gorm.io/driver/sqlite" 21 + "gorm.io/gorm" 10 22 ) 11 23 12 24 var Version = "dev" ··· 119 131 Name: "session-secret", 120 132 EnvVars: []string{"COCOON_SESSION_SECRET"}, 121 133 }, 134 + &cli.StringFlag{ 135 + Name: "default-atproto-proxy", 136 + EnvVars: []string{"COCOON_DEFAULT_ATPROTO_PROXY"}, 137 + Value: "did:web:api.bsky.app#bsky_appview", 138 + }, 139 + &cli.StringFlag{ 140 + Name: "blockstore-variant", 141 + EnvVars: []string{"COCOON_BLOCKSTORE_VARIANT"}, 142 + Value: "sqlite", 143 + }, 122 144 }, 123 145 Commands: []*cli.Command{ 124 - run, 146 + runServe, 147 + runCreateRotationKey, 148 + runCreatePrivateJwk, 149 + runCreateInviteCode, 150 + runResetPassword, 125 151 }, 126 152 ErrWriter: os.Stdout, 127 153 Version: Version, ··· 132 158 } 133 159 } 134 160 135 - var run = &cli.Command{ 161 + var runServe = &cli.Command{ 136 162 Name: "run", 137 163 Usage: "Start the cocoon PDS", 138 164 Flags: []cli.Flag{}, 139 165 Action: func(cmd *cli.Context) error { 166 + 140 167 s, err := server.New(&server.Args{ 141 168 Addr: cmd.String("addr"), 142 169 DbName: cmd.String("db-name"), ··· 162 189 AccessKey: cmd.String("s3-access-key"), 163 190 SecretKey: cmd.String("s3-secret-key"), 164 191 }, 165 - SessionSecret: cmd.String("session-secret"), 192 + SessionSecret: cmd.String("session-secret"), 193 + DefaultAtprotoProxy: cmd.String("default-atproto-proxy"), 194 + BlockstoreVariant: server.MustReturnBlockstoreVariant(cmd.String("blockstore-variant")), 166 195 }) 167 196 if err != nil { 168 197 fmt.Printf("error creating cocoon: %v", err) ··· 177 206 return nil 178 207 }, 179 208 } 209 + 210 + var runCreateRotationKey = &cli.Command{ 211 + Name: "create-rotation-key", 212 + Usage: "creates a rotation key for your pds", 213 + Flags: []cli.Flag{ 214 + &cli.StringFlag{ 215 + Name: "out", 216 + Required: true, 217 + Usage: "output file for your rotation key", 218 + }, 219 + }, 220 + Action: func(cmd *cli.Context) error { 221 + key, err := crypto.GeneratePrivateKeyK256() 222 + if err != nil { 223 + return err 224 + } 225 + 226 + bytes := key.Bytes() 227 + 228 + if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil { 229 + return err 230 + } 231 + 232 + return nil 233 + }, 234 + } 235 + 236 + var runCreatePrivateJwk = &cli.Command{ 237 + Name: "create-private-jwk", 238 + Usage: "creates a private jwk for your pds", 239 + Flags: []cli.Flag{ 240 + &cli.StringFlag{ 241 + Name: "out", 242 + Required: true, 243 + Usage: "output file for your jwk", 244 + }, 245 + }, 246 + Action: func(cmd *cli.Context) error { 247 + privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) 248 + if err != nil { 249 + return err 250 + } 251 + 252 + key, err := jwk.FromRaw(privKey) 253 + if err != nil { 254 + return err 255 + } 256 + 257 + kid := fmt.Sprintf("%d", time.Now().Unix()) 258 + 259 + if err := key.Set(jwk.KeyIDKey, kid); err != nil { 260 + return err 261 + } 262 + 263 + b, err := json.Marshal(key) 264 + if err != nil { 265 + return err 266 + } 267 + 268 + if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil { 269 + return err 270 + } 271 + 272 + return nil 273 + }, 274 + } 275 + 276 + var runCreateInviteCode = &cli.Command{ 277 + Name: "create-invite-code", 278 + Usage: "creates an invite code", 279 + Flags: []cli.Flag{ 280 + &cli.StringFlag{ 281 + Name: "for", 282 + Usage: "optional did to assign the invite code to", 283 + }, 284 + &cli.IntFlag{ 285 + Name: "uses", 286 + Usage: "number of times the invite code can be used", 287 + Value: 1, 288 + }, 289 + }, 290 + Action: func(cmd *cli.Context) error { 291 + db, err := newDb() 292 + if err != nil { 293 + return err 294 + } 295 + 296 + forDid := "did:plc:123" 297 + if cmd.String("for") != "" { 298 + did, err := syntax.ParseDID(cmd.String("for")) 299 + if err != nil { 300 + return err 301 + } 302 + 303 + forDid = did.String() 304 + } 305 + 306 + uses := cmd.Int("uses") 307 + 308 + code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8)) 309 + 310 + if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil { 311 + return err 312 + } 313 + 314 + fmt.Printf("New invite code created with %d uses: %s\n", uses, code) 315 + 316 + return nil 317 + }, 318 + } 319 + 320 + var runResetPassword = &cli.Command{ 321 + Name: "reset-password", 322 + Usage: "resets a password", 323 + Flags: []cli.Flag{ 324 + &cli.StringFlag{ 325 + Name: "did", 326 + Usage: "did of the user who's password you want to reset", 327 + }, 328 + }, 329 + Action: func(cmd *cli.Context) error { 330 + db, err := newDb() 331 + if err != nil { 332 + return err 333 + } 334 + 335 + didStr := cmd.String("did") 336 + did, err := syntax.ParseDID(didStr) 337 + if err != nil { 338 + return err 339 + } 340 + 341 + newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12)) 342 + hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10) 343 + if err != nil { 344 + return err 345 + } 346 + 347 + if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil { 348 + return err 349 + } 350 + 351 + fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass) 352 + 353 + return nil 354 + }, 355 + } 356 + 357 + func newDb() (*gorm.DB, error) { 358 + return gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{}) 359 + }
+45
cspell.json
··· 1 + { 2 + "version": "0.2", 3 + "language": "en", 4 + "words": [ 5 + "atproto", 6 + "bsky", 7 + "Cocoon", 8 + "PDS", 9 + "Plc", 10 + "plc", 11 + "repo", 12 + "InviteCodes", 13 + "InviteCode", 14 + "Invite", 15 + "Signin", 16 + "Signout", 17 + "JWKS", 18 + "dpop", 19 + "BGS", 20 + "pico", 21 + "picocss", 22 + "par", 23 + "blobs", 24 + "blob", 25 + "did", 26 + "DID", 27 + "OAuth", 28 + "oauth", 29 + "par", 30 + "Cocoon", 31 + "memcache", 32 + "db", 33 + "helpers", 34 + "middleware", 35 + "repo", 36 + "static", 37 + "pico", 38 + "picocss", 39 + "MIT", 40 + "Go" 41 + ], 42 + "ignorePaths": [ 43 + "server/static/pico.css" 44 + ] 45 + }
+1
go.mod
··· 14 14 github.com/google/uuid v1.4.0 15 15 github.com/gorilla/sessions v1.4.0 16 16 github.com/gorilla/websocket v1.5.1 17 + github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b 17 18 github.com/hashicorp/golang-lru/v2 v2.0.7 18 19 github.com/ipfs/go-block-format v0.2.0 19 20 github.com/ipfs/go-cid v0.4.1
+2
go.sum
··· 91 91 github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY= 92 92 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8= 93 93 github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4= 94 + github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b h1:wDUNC2eKiL35DbLvsDhiblTUXHxcOPwQSCzi7xpQUN4= 95 + github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b/go.mod h1:VzxiSdG6j1pi7rwGm/xYI5RbtpBgM8sARDXlvEvxlu0= 94 96 github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ= 95 97 github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48= 96 98 github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI=
+73 -54
identity/identity.go
··· 13 13 "github.com/bluesky-social/indigo/util" 14 14 ) 15 15 16 - func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) { 17 - if cli == nil { 18 - cli = util.RobustHTTPClient() 19 - } 20 - 21 - var did string 22 - 23 - _, err := syntax.ParseHandle(handle) 16 + func ResolveHandleFromTXT(ctx context.Context, handle string) (string, error) { 17 + name := fmt.Sprintf("_atproto.%s", handle) 18 + recs, err := net.LookupTXT(name) 24 19 if err != nil { 25 - return "", err 20 + return "", fmt.Errorf("handle could not be resolved via txt: %w", err) 26 21 } 27 22 28 - recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle)) 29 - if err == nil { 30 - for _, rec := range recs { 31 - if strings.HasPrefix(rec, "did=") { 32 - did = strings.Split(rec, "did=")[1] 33 - break 23 + for _, rec := range recs { 24 + if strings.HasPrefix(rec, "did=") { 25 + maybeDid := strings.Split(rec, "did=")[1] 26 + if _, err := syntax.ParseDID(maybeDid); err == nil { 27 + return maybeDid, nil 34 28 } 35 29 } 36 - } else { 37 - fmt.Printf("erorr getting txt records: %v\n", err) 38 30 } 39 31 40 - if did == "" { 41 - req, err := http.NewRequestWithContext( 42 - ctx, 43 - "GET", 44 - fmt.Sprintf("https://%s/.well-known/atproto-did", handle), 45 - nil, 46 - ) 47 - if err != nil { 48 - return "", nil 49 - } 32 + return "", fmt.Errorf("handle could not be resolved via txt: no record found") 33 + } 50 34 51 - resp, err := http.DefaultClient.Do(req) 52 - if err != nil { 53 - return "", nil 54 - } 55 - defer resp.Body.Close() 35 + func ResolveHandleFromWellKnown(ctx context.Context, cli *http.Client, handle string) (string, error) { 36 + ustr := fmt.Sprintf("https://%s/.well=known/atproto-did", handle) 37 + req, err := http.NewRequestWithContext( 38 + ctx, 39 + "GET", 40 + ustr, 41 + nil, 42 + ) 43 + if err != nil { 44 + return "", fmt.Errorf("handle could not be resolved via web: %w", err) 45 + } 56 46 57 - if resp.StatusCode != http.StatusOK { 58 - io.Copy(io.Discard, resp.Body) 59 - return "", fmt.Errorf("unable to resolve handle") 60 - } 47 + resp, err := cli.Do(req) 48 + if err != nil { 49 + return "", fmt.Errorf("handle could not be resolved via web: %w", err) 50 + } 51 + defer resp.Body.Close() 61 52 62 - b, err := io.ReadAll(resp.Body) 63 - if err != nil { 64 - return "", err 65 - } 53 + b, err := io.ReadAll(resp.Body) 54 + if err != nil { 55 + return "", fmt.Errorf("handle could not be resolved via web: %w", err) 56 + } 66 57 67 - maybeDid := string(b) 58 + if resp.StatusCode != http.StatusOK { 59 + return "", fmt.Errorf("handle could not be resolved via web: invalid status code %d", resp.StatusCode) 60 + } 68 61 69 - if _, err := syntax.ParseDID(maybeDid); err != nil { 70 - return "", fmt.Errorf("unable to resolve handle") 71 - } 62 + maybeDid := string(b) 72 63 73 - did = maybeDid 64 + if _, err := syntax.ParseDID(maybeDid); err != nil { 65 + return "", fmt.Errorf("handle could not be resolved via web: invalid did in document") 74 66 } 75 67 76 - return did, nil 68 + return maybeDid, nil 77 69 } 78 70 79 - func FetchDidDoc(ctx context.Context, cli *http.Client, did string) (*DidDoc, error) { 71 + func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) { 80 72 if cli == nil { 81 73 cli = util.RobustHTTPClient() 82 74 } 83 75 84 - var ustr string 76 + _, err := syntax.ParseHandle(handle) 77 + if err != nil { 78 + return "", err 79 + } 80 + 81 + if maybeDidFromTxt, err := ResolveHandleFromTXT(ctx, handle); err == nil { 82 + return maybeDidFromTxt, nil 83 + } 84 + 85 + if maybeDidFromWeb, err := ResolveHandleFromWellKnown(ctx, cli, handle); err == nil { 86 + return maybeDidFromWeb, nil 87 + } 88 + 89 + return "", fmt.Errorf("handle could not be resolved") 90 + } 91 + 92 + func DidToDocUrl(did string) (string, error) { 85 93 if strings.HasPrefix(did, "did:plc:") { 86 - ustr = fmt.Sprintf("https://plc.directory/%s", did) 94 + return fmt.Sprintf("https://plc.directory/%s", did), nil 87 95 } else if strings.HasPrefix(did, "did:web:") { 88 - ustr = fmt.Sprintf("https://%s/.well-known/did.json", strings.TrimPrefix(did, "did:web:")) 96 + return fmt.Sprintf("https://%s/.well-known/did.json", strings.TrimPrefix(did, "did:web:")), nil 89 97 } else { 90 - return nil, fmt.Errorf("did was not a supported did type") 98 + return "", fmt.Errorf("did was not a supported did type") 99 + } 100 + } 101 + 102 + func FetchDidDoc(ctx context.Context, cli *http.Client, did string) (*DidDoc, error) { 103 + if cli == nil { 104 + cli = util.RobustHTTPClient() 105 + } 106 + 107 + ustr, err := DidToDocUrl(did) 108 + if err != nil { 109 + return nil, err 91 110 } 92 111 93 112 req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil) ··· 95 114 return nil, err 96 115 } 97 116 98 - resp, err := http.DefaultClient.Do(req) 117 + resp, err := cli.Do(req) 99 118 if err != nil { 100 119 return nil, err 101 120 } ··· 103 122 104 123 if resp.StatusCode != 200 { 105 124 io.Copy(io.Discard, resp.Body) 106 - return nil, fmt.Errorf("could not find identity in plc registry") 125 + return nil, fmt.Errorf("unable to find did doc at url. did: %s. url: %s", did, ustr) 107 126 } 108 127 109 128 var diddoc DidDoc ··· 127 146 return nil, err 128 147 } 129 148 130 - resp, err := http.DefaultClient.Do(req) 149 + resp, err := cli.Do(req) 131 150 if err != nil { 132 151 return nil, err 133 152 }
+16 -5
identity/passport.go
··· 19 19 type Passport struct { 20 20 h *http.Client 21 21 bc BackingCache 22 - lk sync.Mutex 22 + mu sync.RWMutex 23 23 } 24 24 25 25 func NewPassport(h *http.Client, bc BackingCache) *Passport { ··· 30 30 return &Passport{ 31 31 h: h, 32 32 bc: bc, 33 - lk: sync.Mutex{}, 34 33 } 35 34 } 36 35 ··· 38 37 skipCache, _ := ctx.Value("skip-cache").(bool) 39 38 40 39 if !skipCache { 40 + p.mu.RLock() 41 41 cached, ok := p.bc.GetDoc(did) 42 + p.mu.RUnlock() 43 + 42 44 if ok { 43 45 return cached, nil 44 46 } 45 47 } 46 48 47 - p.lk.Lock() // this is pretty pathetic, and i should rethink this. but for now, fuck it 48 - defer p.lk.Unlock() 49 - 49 + // TODO: should coalesce requests here 50 50 doc, err := FetchDidDoc(ctx, p.h, did) 51 51 if err != nil { 52 52 return nil, err 53 53 } 54 54 55 + p.mu.Lock() 55 56 p.bc.PutDoc(did, doc) 57 + p.mu.Unlock() 56 58 57 59 return doc, nil 58 60 } ··· 61 63 skipCache, _ := ctx.Value("skip-cache").(bool) 62 64 63 65 if !skipCache { 66 + p.mu.RLock() 64 67 cached, ok := p.bc.GetDid(handle) 68 + p.mu.RUnlock() 69 + 65 70 if ok { 66 71 return cached, nil 67 72 } ··· 72 77 return "", err 73 78 } 74 79 80 + p.mu.Lock() 75 81 p.bc.PutDid(handle, did) 82 + p.mu.Unlock() 76 83 77 84 return did, nil 78 85 } 79 86 80 87 func (p *Passport) BustDoc(ctx context.Context, did string) error { 88 + p.mu.Lock() 89 + defer p.mu.Unlock() 81 90 return p.bc.BustDoc(did) 82 91 } 83 92 84 93 func (p *Passport) BustDid(ctx context.Context, handle string) error { 94 + p.mu.Lock() 95 + defer p.mu.Unlock() 85 96 return p.bc.BustDid(handle) 86 97 }
+13
internal/helpers/helpers.go
··· 7 7 "math/rand" 8 8 "net/url" 9 9 10 + "github.com/Azure/go-autorest/autorest/to" 10 11 "github.com/labstack/echo/v4" 11 12 "github.com/lestrrat-go/jwx/v2/jwk" 12 13 ) ··· 29 30 msg += ". " + *suffix 30 31 } 31 32 return genericError(e, 400, msg) 33 + } 34 + 35 + func InvalidTokenError(e echo.Context) error { 36 + return InputError(e, to.StringPtr("InvalidToken")) 37 + } 38 + 39 + func ExpiredTokenError(e echo.Context) error { 40 + // WARN: See https://github.com/bluesky-social/atproto/discussions/3319 41 + return e.JSON(400, map[string]string{ 42 + "error": "ExpiredToken", 43 + "message": "*", 44 + }) 32 45 } 33 46 34 47 func genericError(e echo.Context, code int, msg string) error {
+8
oauth/client/client.go
··· 1 + package client 2 + 3 + import "github.com/lestrrat-go/jwx/v2/jwk" 4 + 5 + type Client struct { 6 + Metadata *Metadata 7 + JWKS jwk.Key 8 + }
+389
oauth/client/manager.go
··· 1 + package client 2 + 3 + import ( 4 + "context" 5 + "encoding/json" 6 + "errors" 7 + "fmt" 8 + "io" 9 + "log/slog" 10 + "net/http" 11 + "net/url" 12 + "slices" 13 + "strings" 14 + "time" 15 + 16 + cache "github.com/go-pkgz/expirable-cache/v3" 17 + "github.com/haileyok/cocoon/internal/helpers" 18 + "github.com/lestrrat-go/jwx/v2/jwk" 19 + ) 20 + 21 + type Manager struct { 22 + cli *http.Client 23 + logger *slog.Logger 24 + jwksCache cache.Cache[string, jwk.Key] 25 + metadataCache cache.Cache[string, Metadata] 26 + } 27 + 28 + type ManagerArgs struct { 29 + Cli *http.Client 30 + Logger *slog.Logger 31 + } 32 + 33 + func NewManager(args ManagerArgs) *Manager { 34 + if args.Logger == nil { 35 + args.Logger = slog.Default() 36 + } 37 + 38 + if args.Cli == nil { 39 + args.Cli = http.DefaultClient 40 + } 41 + 42 + jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 43 + metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 + 45 + return &Manager{ 46 + cli: args.Cli, 47 + logger: args.Logger, 48 + jwksCache: jwksCache, 49 + metadataCache: metadataCache, 50 + } 51 + } 52 + 53 + func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) { 54 + metadata, err := cm.getClientMetadata(ctx, clientId) 55 + if err != nil { 56 + return nil, err 57 + } 58 + 59 + var jwks jwk.Key 60 + if metadata.JWKS != nil { 61 + // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to 62 + // make sure we use the right one 63 + k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0]) 64 + if err != nil { 65 + return nil, err 66 + } 67 + jwks = k 68 + } else if metadata.JWKSURI != nil { 69 + maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 70 + if err != nil { 71 + return nil, err 72 + } 73 + 74 + jwks = maybeJwks 75 + } 76 + 77 + return &Client{ 78 + Metadata: metadata, 79 + JWKS: jwks, 80 + }, nil 81 + } 82 + 83 + func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) { 84 + metadataCached, ok := cm.metadataCache.Get(clientId) 85 + if !ok { 86 + req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 87 + if err != nil { 88 + return nil, err 89 + } 90 + 91 + resp, err := cm.cli.Do(req) 92 + if err != nil { 93 + return nil, err 94 + } 95 + defer resp.Body.Close() 96 + 97 + if resp.StatusCode != http.StatusOK { 98 + io.Copy(io.Discard, resp.Body) 99 + return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 100 + } 101 + 102 + b, err := io.ReadAll(resp.Body) 103 + if err != nil { 104 + return nil, fmt.Errorf("error reading bytes from client response: %w", err) 105 + } 106 + 107 + validated, err := validateAndParseMetadata(clientId, b) 108 + if err != nil { 109 + return nil, err 110 + } 111 + 112 + return validated, nil 113 + } else { 114 + return &metadataCached, nil 115 + } 116 + } 117 + 118 + func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 119 + jwks, ok := cm.jwksCache.Get(clientId) 120 + if !ok { 121 + req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 122 + if err != nil { 123 + return nil, err 124 + } 125 + 126 + resp, err := cm.cli.Do(req) 127 + if err != nil { 128 + return nil, err 129 + } 130 + defer resp.Body.Close() 131 + 132 + if resp.StatusCode != http.StatusOK { 133 + io.Copy(io.Discard, resp.Body) 134 + return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 135 + } 136 + 137 + type Keys struct { 138 + Keys []map[string]any `json:"keys"` 139 + } 140 + 141 + var keys Keys 142 + if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 143 + return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 144 + } 145 + 146 + if len(keys.Keys) == 0 { 147 + return nil, errors.New("no keys in jwks response") 148 + } 149 + 150 + // TODO: this is again bad, we should be figuring out which one we need to use... 151 + b, err := json.Marshal(keys.Keys[0]) 152 + if err != nil { 153 + return nil, fmt.Errorf("could not marshal key: %w", err) 154 + } 155 + 156 + k, err := helpers.ParseJWKFromBytes(b) 157 + if err != nil { 158 + return nil, err 159 + } 160 + 161 + jwks = k 162 + } 163 + 164 + return jwks, nil 165 + } 166 + 167 + func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) { 168 + var metadataMap map[string]any 169 + if err := json.Unmarshal(b, &metadataMap); err != nil { 170 + return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 171 + } 172 + 173 + _, jwksOk := metadataMap["jwks"].(string) 174 + _, jwksUriOk := metadataMap["jwks_uri"].(string) 175 + if jwksOk && jwksUriOk { 176 + return nil, errors.New("jwks_uri and jwks are mutually exclusive") 177 + } 178 + 179 + for _, k := range []string{ 180 + "default_max_age", 181 + "userinfo_signed_response_alg", 182 + "id_token_signed_response_alg", 183 + "userinfo_encryhpted_response_alg", 184 + "authorization_encrypted_response_enc", 185 + "authorization_encrypted_response_alg", 186 + "tls_client_certificate_bound_access_tokens", 187 + } { 188 + _, kOk := metadataMap[k] 189 + if kOk { 190 + return nil, fmt.Errorf("unsupported `%s` parameter", k) 191 + } 192 + } 193 + 194 + var metadata Metadata 195 + if err := json.Unmarshal(b, &metadata); err != nil { 196 + return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 197 + } 198 + 199 + u, err := url.Parse(metadata.ClientURI) 200 + if err != nil { 201 + return nil, fmt.Errorf("unable to parse client uri: %w", err) 202 + } 203 + 204 + if isLocalHostname(u.Hostname()) { 205 + return nil, errors.New("`client_uri` hostname is invalid") 206 + } 207 + 208 + if metadata.Scope == "" { 209 + return nil, errors.New("missing `scopes` scope") 210 + } 211 + 212 + scopes := strings.Split(metadata.Scope, " ") 213 + if !slices.Contains(scopes, "atproto") { 214 + return nil, errors.New("missing `atproto` scope") 215 + } 216 + 217 + scopesMap := map[string]bool{} 218 + for _, scope := range scopes { 219 + if scopesMap[scope] { 220 + return nil, fmt.Errorf("duplicate scope `%s`", scope) 221 + } 222 + 223 + // TODO: check for unsupported scopes 224 + 225 + scopesMap[scope] = true 226 + } 227 + 228 + grantTypesMap := map[string]bool{} 229 + for _, gt := range metadata.GrantTypes { 230 + if grantTypesMap[gt] { 231 + return nil, fmt.Errorf("duplicate grant type `%s`", gt) 232 + } 233 + 234 + switch gt { 235 + case "implicit": 236 + return nil, errors.New("grantg type `implicit` is not allowed") 237 + case "authorization_code", "refresh_token": 238 + // TODO check if this grant type is supported 239 + default: 240 + return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 241 + } 242 + 243 + grantTypesMap[gt] = true 244 + } 245 + 246 + if metadata.ClientID != clientId { 247 + return nil, errors.New("`client_id` does not match") 248 + } 249 + 250 + subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 251 + if subjectTypeOk && subjectType != "public" { 252 + return nil, errors.New("only public `subject_type` is supported") 253 + } 254 + 255 + switch metadata.TokenEndpointAuthMethod { 256 + case "none": 257 + if metadata.TokenEndpointAuthSigningAlg != "" { 258 + return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 259 + } 260 + case "private_key_jwt": 261 + if metadata.JWKS == nil && metadata.JWKSURI == nil { 262 + return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 263 + } 264 + 265 + if metadata.JWKS != nil && len(*metadata.JWKS) == 0 { 266 + return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 267 + } 268 + 269 + if metadata.TokenEndpointAuthSigningAlg == "" { 270 + return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 271 + } 272 + default: 273 + return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 274 + } 275 + 276 + if !metadata.DpopBoundAccessTokens { 277 + return nil, errors.New("dpop_bound_access_tokens must be true") 278 + } 279 + 280 + if !slices.Contains(metadata.ResponseTypes, "code") { 281 + return nil, errors.New("response_types must inclue `code`") 282 + } 283 + 284 + if !slices.Contains(metadata.GrantTypes, "authorization_code") { 285 + return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 286 + } 287 + 288 + if len(metadata.RedirectURIs) == 0 { 289 + return nil, errors.New("at least one `redirect_uri` is required") 290 + } 291 + 292 + if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" { 293 + return nil, errors.New("native clients must authenticate using `none` method") 294 + } 295 + 296 + if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 297 + for _, ruri := range metadata.RedirectURIs { 298 + u, err := url.Parse(ruri) 299 + if err != nil { 300 + return nil, fmt.Errorf("error parsing redirect uri: %w", err) 301 + } 302 + 303 + if u.Scheme != "https" { 304 + return nil, errors.New("web clients must use https redirect uris") 305 + } 306 + 307 + if u.Hostname() == "localhost" { 308 + return nil, errors.New("web clients must not use localhost as the hostname") 309 + } 310 + } 311 + } 312 + 313 + for _, ruri := range metadata.RedirectURIs { 314 + u, err := url.Parse(ruri) 315 + if err != nil { 316 + return nil, fmt.Errorf("error parsing redirect uri: %w", err) 317 + } 318 + 319 + if u.User != nil { 320 + if u.User.Username() != "" { 321 + return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 322 + } 323 + 324 + if _, hasPass := u.User.Password(); hasPass { 325 + return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 326 + } 327 + } 328 + 329 + switch true { 330 + case u.Hostname() == "localhost": 331 + return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 332 + case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 333 + if metadata.ApplicationType != "native" { 334 + return nil, errors.New("loopback redirect uris are only allowed for native apps") 335 + } 336 + 337 + if u.Port() != "" { 338 + // reference impl doesn't do anything with this? 339 + } 340 + 341 + if u.Scheme != "http" { 342 + return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 343 + } 344 + 345 + break 346 + case u.Scheme == "http": 347 + return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 348 + case u.Scheme == "https": 349 + if isLocalHostname(u.Hostname()) { 350 + return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 351 + } 352 + break 353 + case strings.Contains(u.Scheme, "."): 354 + if metadata.ApplicationType != "native" { 355 + return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 356 + } 357 + 358 + revdomain := reverseDomain(u.Scheme) 359 + 360 + if isLocalHostname(revdomain) { 361 + return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 362 + } 363 + 364 + if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 365 + return nil, fmt.Errorf("private use uri scheme must be in the form ") 366 + } 367 + default: 368 + return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 369 + } 370 + } 371 + 372 + return &metadata, nil 373 + } 374 + 375 + func isLocalHostname(hostname string) bool { 376 + pts := strings.Split(hostname, ".") 377 + if len(pts) < 2 { 378 + return true 379 + } 380 + 381 + tld := strings.ToLower(pts[len(pts)-1]) 382 + return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 383 + } 384 + 385 + func reverseDomain(domain string) string { 386 + pts := strings.Split(domain, ".") 387 + slices.Reverse(pts) 388 + return strings.Join(pts, ".") 389 + }
+20
oauth/client/metadata.go
··· 1 + package client 2 + 3 + type Metadata struct { 4 + ClientID string `json:"client_id"` 5 + ClientName string `json:"client_name"` 6 + ClientURI string `json:"client_uri"` 7 + LogoURI string `json:"logo_uri"` 8 + TOSURI string `json:"tos_uri"` 9 + PolicyURI string `json:"policy_uri"` 10 + RedirectURIs []string `json:"redirect_uris"` 11 + GrantTypes []string `json:"grant_types"` 12 + ResponseTypes []string `json:"response_types"` 13 + ApplicationType string `json:"application_type"` 14 + DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 15 + JWKSURI *string `json:"jwks_uri,omitempty"` 16 + JWKS *[][]byte `json:"jwks,omitempty"` 17 + Scope string `json:"scope"` 18 + TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 19 + TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 20 + }
-8
oauth/client.go
··· 1 - package oauth 2 - 3 - import "github.com/lestrrat-go/jwx/v2/jwk" 4 - 5 - type Client struct { 6 - Metadata *ClientMetadata 7 - JWKS jwk.Key 8 - }
-390
oauth/client_manager/client_manager.go
··· 1 - package client_manager 2 - 3 - import ( 4 - "context" 5 - "encoding/json" 6 - "errors" 7 - "fmt" 8 - "io" 9 - "log/slog" 10 - "net/http" 11 - "net/url" 12 - "slices" 13 - "strings" 14 - "time" 15 - 16 - cache "github.com/go-pkgz/expirable-cache/v3" 17 - "github.com/haileyok/cocoon/internal/helpers" 18 - "github.com/haileyok/cocoon/oauth" 19 - "github.com/lestrrat-go/jwx/v2/jwk" 20 - ) 21 - 22 - type ClientManager struct { 23 - cli *http.Client 24 - logger *slog.Logger 25 - jwksCache cache.Cache[string, jwk.Key] 26 - metadataCache cache.Cache[string, oauth.ClientMetadata] 27 - } 28 - 29 - type Args struct { 30 - Cli *http.Client 31 - Logger *slog.Logger 32 - } 33 - 34 - func New(args Args) *ClientManager { 35 - if args.Logger == nil { 36 - args.Logger = slog.Default() 37 - } 38 - 39 - if args.Cli == nil { 40 - args.Cli = http.DefaultClient 41 - } 42 - 43 - jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 - metadataCache := cache.NewCache[string, oauth.ClientMetadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 45 - 46 - return &ClientManager{ 47 - cli: args.Cli, 48 - logger: args.Logger, 49 - jwksCache: jwksCache, 50 - metadataCache: metadataCache, 51 - } 52 - } 53 - 54 - func (cm *ClientManager) GetClient(ctx context.Context, clientId string) (*oauth.Client, error) { 55 - metadata, err := cm.getClientMetadata(ctx, clientId) 56 - if err != nil { 57 - return nil, err 58 - } 59 - 60 - var jwks jwk.Key 61 - if metadata.JWKS != nil { 62 - // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to 63 - // make sure we use the right one 64 - k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0]) 65 - if err != nil { 66 - return nil, err 67 - } 68 - jwks = k 69 - } else if metadata.JWKSURI != nil { 70 - maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 71 - if err != nil { 72 - return nil, err 73 - } 74 - 75 - jwks = maybeJwks 76 - } 77 - 78 - return &oauth.Client{ 79 - Metadata: metadata, 80 - JWKS: jwks, 81 - }, nil 82 - } 83 - 84 - func (cm *ClientManager) getClientMetadata(ctx context.Context, clientId string) (*oauth.ClientMetadata, error) { 85 - metadataCached, ok := cm.metadataCache.Get(clientId) 86 - if !ok { 87 - req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 88 - if err != nil { 89 - return nil, err 90 - } 91 - 92 - resp, err := cm.cli.Do(req) 93 - if err != nil { 94 - return nil, err 95 - } 96 - defer resp.Body.Close() 97 - 98 - if resp.StatusCode != http.StatusOK { 99 - io.Copy(io.Discard, resp.Body) 100 - return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 101 - } 102 - 103 - b, err := io.ReadAll(resp.Body) 104 - if err != nil { 105 - return nil, fmt.Errorf("error reading bytes from client response: %w", err) 106 - } 107 - 108 - validated, err := validateAndParseMetadata(clientId, b) 109 - if err != nil { 110 - return nil, err 111 - } 112 - 113 - return validated, nil 114 - } else { 115 - return &metadataCached, nil 116 - } 117 - } 118 - 119 - func (cm *ClientManager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 120 - jwks, ok := cm.jwksCache.Get(clientId) 121 - if !ok { 122 - req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 123 - if err != nil { 124 - return nil, err 125 - } 126 - 127 - resp, err := cm.cli.Do(req) 128 - if err != nil { 129 - return nil, err 130 - } 131 - defer resp.Body.Close() 132 - 133 - if resp.StatusCode != http.StatusOK { 134 - io.Copy(io.Discard, resp.Body) 135 - return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 136 - } 137 - 138 - type Keys struct { 139 - Keys []map[string]any `json:"keys"` 140 - } 141 - 142 - var keys Keys 143 - if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 144 - return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 145 - } 146 - 147 - if len(keys.Keys) == 0 { 148 - return nil, errors.New("no keys in jwks response") 149 - } 150 - 151 - // TODO: this is again bad, we should be figuring out which one we need to use... 152 - b, err := json.Marshal(keys.Keys[0]) 153 - if err != nil { 154 - return nil, fmt.Errorf("could not marshal key: %w", err) 155 - } 156 - 157 - k, err := helpers.ParseJWKFromBytes(b) 158 - if err != nil { 159 - return nil, err 160 - } 161 - 162 - jwks = k 163 - } 164 - 165 - return jwks, nil 166 - } 167 - 168 - func validateAndParseMetadata(clientId string, b []byte) (*oauth.ClientMetadata, error) { 169 - var metadataMap map[string]any 170 - if err := json.Unmarshal(b, &metadataMap); err != nil { 171 - return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 172 - } 173 - 174 - _, jwksOk := metadataMap["jwks"].(string) 175 - _, jwksUriOk := metadataMap["jwks_uri"].(string) 176 - if jwksOk && jwksUriOk { 177 - return nil, errors.New("jwks_uri and jwks are mutually exclusive") 178 - } 179 - 180 - for _, k := range []string{ 181 - "default_max_age", 182 - "userinfo_signed_response_alg", 183 - "id_token_signed_response_alg", 184 - "userinfo_encryhpted_response_alg", 185 - "authorization_encrypted_response_enc", 186 - "authorization_encrypted_response_alg", 187 - "tls_client_certificate_bound_access_tokens", 188 - } { 189 - _, kOk := metadataMap[k] 190 - if kOk { 191 - return nil, fmt.Errorf("unsupported `%s` parameter", k) 192 - } 193 - } 194 - 195 - var metadata oauth.ClientMetadata 196 - if err := json.Unmarshal(b, &metadata); err != nil { 197 - return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 198 - } 199 - 200 - u, err := url.Parse(metadata.ClientURI) 201 - if err != nil { 202 - return nil, fmt.Errorf("unable to parse client uri: %w", err) 203 - } 204 - 205 - if isLocalHostname(u.Hostname()) { 206 - return nil, errors.New("`client_uri` hostname is invalid") 207 - } 208 - 209 - if metadata.Scope == "" { 210 - return nil, errors.New("missing `scopes` scope") 211 - } 212 - 213 - scopes := strings.Split(metadata.Scope, " ") 214 - if !slices.Contains(scopes, "atproto") { 215 - return nil, errors.New("missing `atproto` scope") 216 - } 217 - 218 - scopesMap := map[string]bool{} 219 - for _, scope := range scopes { 220 - if scopesMap[scope] { 221 - return nil, fmt.Errorf("duplicate scope `%s`", scope) 222 - } 223 - 224 - // TODO: check for unsupported scopes 225 - 226 - scopesMap[scope] = true 227 - } 228 - 229 - grantTypesMap := map[string]bool{} 230 - for _, gt := range metadata.GrantTypes { 231 - if grantTypesMap[gt] { 232 - return nil, fmt.Errorf("duplicate grant type `%s`", gt) 233 - } 234 - 235 - switch gt { 236 - case "implicit": 237 - return nil, errors.New("grantg type `implicit` is not allowed") 238 - case "authorization_code", "refresh_token": 239 - // TODO check if this grant type is supported 240 - default: 241 - return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 242 - } 243 - 244 - grantTypesMap[gt] = true 245 - } 246 - 247 - if metadata.ClientID != clientId { 248 - return nil, errors.New("`client_id` does not match") 249 - } 250 - 251 - subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 252 - if subjectTypeOk && subjectType != "public" { 253 - return nil, errors.New("only public `subject_type` is supported") 254 - } 255 - 256 - switch metadata.TokenEndpointAuthMethod { 257 - case "none": 258 - if metadata.TokenEndpointAuthSigningAlg != "" { 259 - return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 260 - } 261 - case "private_key_jwt": 262 - if metadata.JWKS == nil && metadata.JWKSURI == nil { 263 - return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 264 - } 265 - 266 - if metadata.JWKS != nil && len(*metadata.JWKS) == 0 { 267 - return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 268 - } 269 - 270 - if metadata.TokenEndpointAuthSigningAlg == "" { 271 - return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 272 - } 273 - default: 274 - return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 275 - } 276 - 277 - if !metadata.DpopBoundAccessTokens { 278 - return nil, errors.New("dpop_bound_access_tokens must be true") 279 - } 280 - 281 - if !slices.Contains(metadata.ResponseTypes, "code") { 282 - return nil, errors.New("response_types must inclue `code`") 283 - } 284 - 285 - if !slices.Contains(metadata.GrantTypes, "authorization_code") { 286 - return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 287 - } 288 - 289 - if len(metadata.RedirectURIs) == 0 { 290 - return nil, errors.New("at least one `redirect_uri` is required") 291 - } 292 - 293 - if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod == "none" { 294 - return nil, errors.New("native clients must authenticate using `none` method") 295 - } 296 - 297 - if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 298 - for _, ruri := range metadata.RedirectURIs { 299 - u, err := url.Parse(ruri) 300 - if err != nil { 301 - return nil, fmt.Errorf("error parsing redirect uri: %w", err) 302 - } 303 - 304 - if u.Scheme != "https" { 305 - return nil, errors.New("web clients must use https redirect uris") 306 - } 307 - 308 - if u.Hostname() == "localhost" { 309 - return nil, errors.New("web clients must not use localhost as the hostname") 310 - } 311 - } 312 - } 313 - 314 - for _, ruri := range metadata.RedirectURIs { 315 - u, err := url.Parse(ruri) 316 - if err != nil { 317 - return nil, fmt.Errorf("error parsing redirect uri: %w", err) 318 - } 319 - 320 - if u.User != nil { 321 - if u.User.Username() != "" { 322 - return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 323 - } 324 - 325 - if _, hasPass := u.User.Password(); hasPass { 326 - return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 327 - } 328 - } 329 - 330 - switch true { 331 - case u.Hostname() == "localhost": 332 - return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 333 - case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 334 - if metadata.ApplicationType != "native" { 335 - return nil, errors.New("loopback redirect uris are only allowed for native apps") 336 - } 337 - 338 - if u.Port() != "" { 339 - // reference impl doesn't do anything with this? 340 - } 341 - 342 - if u.Scheme != "http" { 343 - return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 344 - } 345 - 346 - break 347 - case u.Scheme == "http": 348 - return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 349 - case u.Scheme == "https": 350 - if isLocalHostname(u.Hostname()) { 351 - return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 352 - } 353 - break 354 - case strings.Contains(u.Scheme, "."): 355 - if metadata.ApplicationType != "native" { 356 - return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 357 - } 358 - 359 - revdomain := reverseDomain(u.Scheme) 360 - 361 - if isLocalHostname(revdomain) { 362 - return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 363 - } 364 - 365 - if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 366 - return nil, fmt.Errorf("private use uri scheme must be in the form ") 367 - } 368 - default: 369 - return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 370 - } 371 - } 372 - 373 - return &metadata, nil 374 - } 375 - 376 - func isLocalHostname(hostname string) bool { 377 - pts := strings.Split(hostname, ".") 378 - if len(pts) < 2 { 379 - return true 380 - } 381 - 382 - tld := strings.ToLower(pts[len(pts)-1]) 383 - return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 384 - } 385 - 386 - func reverseDomain(domain string) string { 387 - pts := strings.Split(domain, ".") 388 - slices.Reverse(pts) 389 - return strings.Join(pts, ".") 390 - }
-20
oauth/client_metadata.go
··· 1 - package oauth 2 - 3 - type ClientMetadata struct { 4 - ClientID string `json:"client_id"` 5 - ClientName string `json:"client_name"` 6 - ClientURI string `json:"client_uri"` 7 - LogoURI string `json:"logo_uri"` 8 - TOSURI string `json:"tos_uri"` 9 - PolicyURI string `json:"policy_uri"` 10 - RedirectURIs []string `json:"redirect_uris"` 11 - GrantTypes []string `json:"grant_types"` 12 - ResponseTypes []string `json:"response_types"` 13 - ApplicationType string `json:"application_type"` 14 - DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"` 15 - JWKSURI *string `json:"jwks_uri,omitempty"` 16 - JWKS *[][]byte `json:"jwks,omitempty"` 17 - Scope string `json:"scope"` 18 - TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"` 19 - TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"` 20 - }
-251
oauth/dpop/dpop_manager/dpop_manager.go
··· 1 - package dpop_manager 2 - 3 - import ( 4 - "crypto" 5 - "crypto/sha256" 6 - "encoding/base64" 7 - "encoding/json" 8 - "errors" 9 - "fmt" 10 - "log/slog" 11 - "net/http" 12 - "net/url" 13 - "strings" 14 - "time" 15 - 16 - "github.com/golang-jwt/jwt/v4" 17 - "github.com/haileyok/cocoon/internal/helpers" 18 - "github.com/haileyok/cocoon/oauth/constants" 19 - "github.com/haileyok/cocoon/oauth/dpop" 20 - "github.com/haileyok/cocoon/oauth/dpop/nonce" 21 - "github.com/lestrrat-go/jwx/v2/jwa" 22 - "github.com/lestrrat-go/jwx/v2/jwk" 23 - ) 24 - 25 - type DpopManager struct { 26 - nonce *nonce.Nonce 27 - jtiCache *jtiCache 28 - logger *slog.Logger 29 - hostname string 30 - } 31 - 32 - type Args struct { 33 - NonceSecret []byte 34 - NonceRotationInterval time.Duration 35 - OnNonceSecretCreated func([]byte) 36 - JTICacheSize int 37 - Logger *slog.Logger 38 - Hostname string 39 - } 40 - 41 - func New(args Args) *DpopManager { 42 - if args.Logger == nil { 43 - args.Logger = slog.Default() 44 - } 45 - 46 - if args.JTICacheSize == 0 { 47 - args.JTICacheSize = 100_000 48 - } 49 - 50 - if args.NonceSecret == nil { 51 - args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 52 - } 53 - 54 - return &DpopManager{ 55 - nonce: nonce.NewNonce(nonce.Args{ 56 - RotationInterval: args.NonceRotationInterval, 57 - Secret: args.NonceSecret, 58 - OnSecretCreated: args.OnNonceSecretCreated, 59 - }), 60 - jtiCache: newJTICache(args.JTICacheSize), 61 - logger: args.Logger, 62 - hostname: args.Hostname, 63 - } 64 - } 65 - 66 - func (dm *DpopManager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*dpop.Proof, error) { 67 - if reqMethod == "" { 68 - return nil, errors.New("HTTP method is required") 69 - } 70 - 71 - if !strings.HasPrefix(reqUrl, "https://") { 72 - reqUrl = "https://" + dm.hostname + reqUrl 73 - } 74 - 75 - proof := extractProof(headers) 76 - 77 - if proof == "" { 78 - return nil, nil 79 - } 80 - 81 - parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 82 - var token *jwt.Token 83 - 84 - token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{}) 85 - if err != nil { 86 - return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err) 87 - } 88 - 89 - typ, _ := token.Header["typ"].(string) 90 - if typ != "dpop+jwt" { 91 - return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`) 92 - } 93 - 94 - dpopJwk, jwkOk := token.Header["jwk"].(map[string]any) 95 - if !jwkOk { 96 - return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`) 97 - } 98 - 99 - jwkb, err := json.Marshal(dpopJwk) 100 - if err != nil { 101 - return nil, fmt.Errorf("failed to marshal jwk: %w", err) 102 - } 103 - 104 - key, err := jwk.ParseKey(jwkb) 105 - if err != nil { 106 - return nil, fmt.Errorf("failed to parse jwk: %w", err) 107 - } 108 - 109 - var pubKey any 110 - if err := key.Raw(&pubKey); err != nil { 111 - return nil, fmt.Errorf("failed to get raw public key: %w", err) 112 - } 113 - 114 - token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) { 115 - alg := t.Header["alg"].(string) 116 - 117 - switch key.KeyType() { 118 - case jwa.EC: 119 - if !strings.HasPrefix(alg, "ES") { 120 - return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg) 121 - } 122 - case jwa.RSA: 123 - if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") { 124 - return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg) 125 - } 126 - case jwa.OKP: 127 - if alg != "EdDSA" { 128 - return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg) 129 - } 130 - } 131 - 132 - return pubKey, nil 133 - }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"})) 134 - if err != nil { 135 - return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err) 136 - } 137 - 138 - if !token.Valid { 139 - return nil, errors.New("dpop proof jwt is invalid") 140 - } 141 - 142 - claims, ok := token.Claims.(jwt.MapClaims) 143 - if !ok { 144 - return nil, errors.New("no claims in dpop proof jwt") 145 - } 146 - 147 - iat, iatOk := claims["iat"].(float64) 148 - if !iatOk { 149 - return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`) 150 - } 151 - 152 - iatTime := time.Unix(int64(iat), 0) 153 - now := time.Now() 154 - 155 - if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance { 156 - return nil, errors.New("dpop proof too old") 157 - } 158 - 159 - if iatTime.Sub(now) > constants.DpopCheckTolerance { 160 - return nil, errors.New("dpop proof iat is in the future") 161 - } 162 - 163 - jti, _ := claims["jti"].(string) 164 - if jti == "" { 165 - return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`) 166 - } 167 - 168 - if dm.jtiCache.add(jti) { 169 - return nil, errors.New("dpop proof replay detected") 170 - } 171 - 172 - htm, _ := claims["htm"].(string) 173 - if htm == "" { 174 - return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`) 175 - } 176 - 177 - if htm != reqMethod { 178 - return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`) 179 - } 180 - 181 - htu, _ := claims["htu"].(string) 182 - if htu == "" { 183 - return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`) 184 - } 185 - 186 - parsedHtu, err := helpers.OauthParseHtu(htu) 187 - if err != nil { 188 - return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`) 189 - } 190 - 191 - u, _ := url.Parse(reqUrl) 192 - if parsedHtu != helpers.OauthNormalizeHtu(u) { 193 - return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u)) 194 - } 195 - 196 - nonce, _ := claims["nonce"].(string) 197 - if nonce == "" { 198 - // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request 199 - return nil, errors.New("use_dpop_nonce") 200 - } 201 - 202 - if nonce != "" && !dm.nonce.Check(nonce) { 203 - // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce 204 - return nil, errors.New("use_dpop_nonce") 205 - } 206 - 207 - ath, _ := claims["ath"].(string) 208 - 209 - if accessToken != nil && *accessToken != "" { 210 - if ath == "" { 211 - return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`) 212 - } 213 - 214 - hash := sha256.Sum256([]byte(*accessToken)) 215 - if ath != base64.RawURLEncoding.EncodeToString(hash[:]) { 216 - return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`) 217 - } 218 - } else if ath != "" { 219 - return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`) 220 - } 221 - 222 - thumbBytes, err := key.Thumbprint(crypto.SHA256) 223 - if err != nil { 224 - return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 225 - } 226 - 227 - thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 228 - 229 - return &dpop.Proof{ 230 - JTI: jti, 231 - JKT: thumb, 232 - HTM: htm, 233 - HTU: htu, 234 - }, nil 235 - } 236 - 237 - func extractProof(headers http.Header) string { 238 - dpopHeaders := headers["Dpop"] 239 - switch len(dpopHeaders) { 240 - case 0: 241 - return "" 242 - case 1: 243 - return dpopHeaders[0] 244 - default: 245 - return "" 246 - } 247 - } 248 - 249 - func (dm *DpopManager) NextNonce() string { 250 - return dm.nonce.NextNonce() 251 - }
-28
oauth/dpop/dpop_manager/jti_cache.go
··· 1 - package dpop_manager 2 - 3 - import ( 4 - "sync" 5 - "time" 6 - 7 - cache "github.com/go-pkgz/expirable-cache/v3" 8 - "github.com/haileyok/cocoon/oauth/constants" 9 - ) 10 - 11 - type jtiCache struct { 12 - mu sync.Mutex 13 - cache cache.Cache[string, bool] 14 - } 15 - 16 - func newJTICache(size int) *jtiCache { 17 - cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl) 18 - return &jtiCache{ 19 - cache: cache, 20 - mu: sync.Mutex{}, 21 - } 22 - } 23 - 24 - func (c *jtiCache) add(jti string) bool { 25 - c.mu.Lock() 26 - defer c.mu.Unlock() 27 - return c.cache.Add(jti, true) 28 - }
+28
oauth/dpop/jti_cache.go
··· 1 + package dpop 2 + 3 + import ( 4 + "sync" 5 + "time" 6 + 7 + cache "github.com/go-pkgz/expirable-cache/v3" 8 + "github.com/haileyok/cocoon/oauth/constants" 9 + ) 10 + 11 + type jtiCache struct { 12 + mu sync.Mutex 13 + cache cache.Cache[string, bool] 14 + } 15 + 16 + func newJTICache(size int) *jtiCache { 17 + cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl) 18 + return &jtiCache{ 19 + cache: cache, 20 + mu: sync.Mutex{}, 21 + } 22 + } 23 + 24 + func (c *jtiCache) add(jti string) bool { 25 + c.mu.Lock() 26 + defer c.mu.Unlock() 27 + return c.cache.Add(jti, true) 28 + }
+249
oauth/dpop/manager.go
··· 1 + package dpop 2 + 3 + import ( 4 + "crypto" 5 + "crypto/sha256" 6 + "encoding/base64" 7 + "encoding/json" 8 + "errors" 9 + "fmt" 10 + "log/slog" 11 + "net/http" 12 + "net/url" 13 + "strings" 14 + "time" 15 + 16 + "github.com/golang-jwt/jwt/v4" 17 + "github.com/haileyok/cocoon/internal/helpers" 18 + "github.com/haileyok/cocoon/oauth/constants" 19 + "github.com/lestrrat-go/jwx/v2/jwa" 20 + "github.com/lestrrat-go/jwx/v2/jwk" 21 + ) 22 + 23 + type Manager struct { 24 + nonce *Nonce 25 + jtiCache *jtiCache 26 + logger *slog.Logger 27 + hostname string 28 + } 29 + 30 + type ManagerArgs struct { 31 + NonceSecret []byte 32 + NonceRotationInterval time.Duration 33 + OnNonceSecretCreated func([]byte) 34 + JTICacheSize int 35 + Logger *slog.Logger 36 + Hostname string 37 + } 38 + 39 + func NewManager(args ManagerArgs) *Manager { 40 + if args.Logger == nil { 41 + args.Logger = slog.Default() 42 + } 43 + 44 + if args.JTICacheSize == 0 { 45 + args.JTICacheSize = 100_000 46 + } 47 + 48 + if args.NonceSecret == nil { 49 + args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 50 + } 51 + 52 + return &Manager{ 53 + nonce: NewNonce(NonceArgs{ 54 + RotationInterval: args.NonceRotationInterval, 55 + Secret: args.NonceSecret, 56 + OnSecretCreated: args.OnNonceSecretCreated, 57 + }), 58 + jtiCache: newJTICache(args.JTICacheSize), 59 + logger: args.Logger, 60 + hostname: args.Hostname, 61 + } 62 + } 63 + 64 + func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) { 65 + if reqMethod == "" { 66 + return nil, errors.New("HTTP method is required") 67 + } 68 + 69 + if !strings.HasPrefix(reqUrl, "https://") { 70 + reqUrl = "https://" + dm.hostname + reqUrl 71 + } 72 + 73 + proof := extractProof(headers) 74 + 75 + if proof == "" { 76 + return nil, nil 77 + } 78 + 79 + parser := jwt.NewParser(jwt.WithoutClaimsValidation()) 80 + var token *jwt.Token 81 + 82 + token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{}) 83 + if err != nil { 84 + return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err) 85 + } 86 + 87 + typ, _ := token.Header["typ"].(string) 88 + if typ != "dpop+jwt" { 89 + return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`) 90 + } 91 + 92 + dpopJwk, jwkOk := token.Header["jwk"].(map[string]any) 93 + if !jwkOk { 94 + return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`) 95 + } 96 + 97 + jwkb, err := json.Marshal(dpopJwk) 98 + if err != nil { 99 + return nil, fmt.Errorf("failed to marshal jwk: %w", err) 100 + } 101 + 102 + key, err := jwk.ParseKey(jwkb) 103 + if err != nil { 104 + return nil, fmt.Errorf("failed to parse jwk: %w", err) 105 + } 106 + 107 + var pubKey any 108 + if err := key.Raw(&pubKey); err != nil { 109 + return nil, fmt.Errorf("failed to get raw public key: %w", err) 110 + } 111 + 112 + token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) { 113 + alg := t.Header["alg"].(string) 114 + 115 + switch key.KeyType() { 116 + case jwa.EC: 117 + if !strings.HasPrefix(alg, "ES") { 118 + return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg) 119 + } 120 + case jwa.RSA: 121 + if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") { 122 + return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg) 123 + } 124 + case jwa.OKP: 125 + if alg != "EdDSA" { 126 + return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg) 127 + } 128 + } 129 + 130 + return pubKey, nil 131 + }, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"})) 132 + if err != nil { 133 + return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err) 134 + } 135 + 136 + if !token.Valid { 137 + return nil, errors.New("dpop proof jwt is invalid") 138 + } 139 + 140 + claims, ok := token.Claims.(jwt.MapClaims) 141 + if !ok { 142 + return nil, errors.New("no claims in dpop proof jwt") 143 + } 144 + 145 + iat, iatOk := claims["iat"].(float64) 146 + if !iatOk { 147 + return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`) 148 + } 149 + 150 + iatTime := time.Unix(int64(iat), 0) 151 + now := time.Now() 152 + 153 + if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance { 154 + return nil, errors.New("dpop proof too old") 155 + } 156 + 157 + if iatTime.Sub(now) > constants.DpopCheckTolerance { 158 + return nil, errors.New("dpop proof iat is in the future") 159 + } 160 + 161 + jti, _ := claims["jti"].(string) 162 + if jti == "" { 163 + return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`) 164 + } 165 + 166 + if dm.jtiCache.add(jti) { 167 + return nil, errors.New("dpop proof replay detected") 168 + } 169 + 170 + htm, _ := claims["htm"].(string) 171 + if htm == "" { 172 + return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`) 173 + } 174 + 175 + if htm != reqMethod { 176 + return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`) 177 + } 178 + 179 + htu, _ := claims["htu"].(string) 180 + if htu == "" { 181 + return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`) 182 + } 183 + 184 + parsedHtu, err := helpers.OauthParseHtu(htu) 185 + if err != nil { 186 + return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`) 187 + } 188 + 189 + u, _ := url.Parse(reqUrl) 190 + if parsedHtu != helpers.OauthNormalizeHtu(u) { 191 + return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u)) 192 + } 193 + 194 + nonce, _ := claims["nonce"].(string) 195 + if nonce == "" { 196 + // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request 197 + return nil, errors.New("use_dpop_nonce") 198 + } 199 + 200 + if nonce != "" && !dm.nonce.Check(nonce) { 201 + // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce 202 + return nil, errors.New("use_dpop_nonce") 203 + } 204 + 205 + ath, _ := claims["ath"].(string) 206 + 207 + if accessToken != nil && *accessToken != "" { 208 + if ath == "" { 209 + return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`) 210 + } 211 + 212 + hash := sha256.Sum256([]byte(*accessToken)) 213 + if ath != base64.RawURLEncoding.EncodeToString(hash[:]) { 214 + return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`) 215 + } 216 + } else if ath != "" { 217 + return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`) 218 + } 219 + 220 + thumbBytes, err := key.Thumbprint(crypto.SHA256) 221 + if err != nil { 222 + return nil, fmt.Errorf("failed to calculate thumbprint: %w", err) 223 + } 224 + 225 + thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 226 + 227 + return &Proof{ 228 + JTI: jti, 229 + JKT: thumb, 230 + HTM: htm, 231 + HTU: htu, 232 + }, nil 233 + } 234 + 235 + func extractProof(headers http.Header) string { 236 + dpopHeaders := headers["Dpop"] 237 + switch len(dpopHeaders) { 238 + case 0: 239 + return "" 240 + case 1: 241 + return dpopHeaders[0] 242 + default: 243 + return "" 244 + } 245 + } 246 + 247 + func (dm *Manager) NextNonce() string { 248 + return dm.nonce.NextNonce() 249 + }
-108
oauth/dpop/nonce/nonce.go
··· 1 - package nonce 2 - 3 - import ( 4 - "crypto/hmac" 5 - "crypto/sha256" 6 - "encoding/base64" 7 - "encoding/binary" 8 - "sync" 9 - "time" 10 - 11 - "github.com/haileyok/cocoon/internal/helpers" 12 - "github.com/haileyok/cocoon/oauth/constants" 13 - ) 14 - 15 - type Nonce struct { 16 - rotationInterval time.Duration 17 - secret []byte 18 - 19 - mu sync.RWMutex 20 - 21 - counter int64 22 - prev string 23 - curr string 24 - next string 25 - } 26 - 27 - type Args struct { 28 - RotationInterval time.Duration 29 - Secret []byte 30 - OnSecretCreated func([]byte) 31 - } 32 - 33 - func NewNonce(args Args) *Nonce { 34 - if args.RotationInterval == 0 { 35 - args.RotationInterval = constants.NonceMaxRotationInterval / 3 36 - } 37 - 38 - if args.RotationInterval > constants.NonceMaxRotationInterval { 39 - args.RotationInterval = constants.NonceMaxRotationInterval 40 - } 41 - 42 - if args.Secret == nil { 43 - args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength) 44 - args.OnSecretCreated(args.Secret) 45 - } 46 - 47 - n := &Nonce{ 48 - rotationInterval: args.RotationInterval, 49 - secret: args.Secret, 50 - mu: sync.RWMutex{}, 51 - } 52 - 53 - n.counter = n.currentCounter() 54 - n.prev = n.compute(n.counter - 1) 55 - n.curr = n.compute(n.counter) 56 - n.next = n.compute(n.counter + 1) 57 - 58 - return n 59 - } 60 - 61 - func (n *Nonce) currentCounter() int64 { 62 - return time.Now().UnixNano() / int64(n.rotationInterval) 63 - } 64 - 65 - func (n *Nonce) compute(counter int64) string { 66 - h := hmac.New(sha256.New, n.secret) 67 - counterBytes := make([]byte, 8) 68 - binary.BigEndian.PutUint64(counterBytes, uint64(counter)) 69 - h.Write(counterBytes) 70 - return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) 71 - } 72 - 73 - func (n *Nonce) rotate() { 74 - counter := n.currentCounter() 75 - diff := counter - n.counter 76 - 77 - switch diff { 78 - case 0: 79 - // counter == n.counter, do nothing 80 - case 1: 81 - n.prev = n.curr 82 - n.curr = n.next 83 - n.next = n.compute(counter + 1) 84 - case 2: 85 - n.prev = n.next 86 - n.curr = n.compute(counter) 87 - n.next = n.compute(counter + 1) 88 - default: 89 - n.prev = n.compute(counter - 1) 90 - n.curr = n.compute(counter) 91 - n.next = n.compute(counter + 1) 92 - } 93 - 94 - n.counter = counter 95 - } 96 - 97 - func (n *Nonce) NextNonce() string { 98 - n.mu.Lock() 99 - defer n.mu.Unlock() 100 - n.rotate() 101 - return n.next 102 - } 103 - 104 - func (n *Nonce) Check(nonce string) bool { 105 - n.mu.RLock() 106 - defer n.mu.RUnlock() 107 - return nonce == n.prev || nonce == n.curr || nonce == n.next 108 - }
+108
oauth/dpop/nonce.go
··· 1 + package dpop 2 + 3 + import ( 4 + "crypto/hmac" 5 + "crypto/sha256" 6 + "encoding/base64" 7 + "encoding/binary" 8 + "sync" 9 + "time" 10 + 11 + "github.com/haileyok/cocoon/internal/helpers" 12 + "github.com/haileyok/cocoon/oauth/constants" 13 + ) 14 + 15 + type Nonce struct { 16 + rotationInterval time.Duration 17 + secret []byte 18 + 19 + mu sync.RWMutex 20 + 21 + counter int64 22 + prev string 23 + curr string 24 + next string 25 + } 26 + 27 + type NonceArgs struct { 28 + RotationInterval time.Duration 29 + Secret []byte 30 + OnSecretCreated func([]byte) 31 + } 32 + 33 + func NewNonce(args NonceArgs) *Nonce { 34 + if args.RotationInterval == 0 { 35 + args.RotationInterval = constants.NonceMaxRotationInterval / 3 36 + } 37 + 38 + if args.RotationInterval > constants.NonceMaxRotationInterval { 39 + args.RotationInterval = constants.NonceMaxRotationInterval 40 + } 41 + 42 + if args.Secret == nil { 43 + args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength) 44 + args.OnSecretCreated(args.Secret) 45 + } 46 + 47 + n := &Nonce{ 48 + rotationInterval: args.RotationInterval, 49 + secret: args.Secret, 50 + mu: sync.RWMutex{}, 51 + } 52 + 53 + n.counter = n.currentCounter() 54 + n.prev = n.compute(n.counter - 1) 55 + n.curr = n.compute(n.counter) 56 + n.next = n.compute(n.counter + 1) 57 + 58 + return n 59 + } 60 + 61 + func (n *Nonce) currentCounter() int64 { 62 + return time.Now().UnixNano() / int64(n.rotationInterval) 63 + } 64 + 65 + func (n *Nonce) compute(counter int64) string { 66 + h := hmac.New(sha256.New, n.secret) 67 + counterBytes := make([]byte, 8) 68 + binary.BigEndian.PutUint64(counterBytes, uint64(counter)) 69 + h.Write(counterBytes) 70 + return base64.RawURLEncoding.EncodeToString(h.Sum(nil)) 71 + } 72 + 73 + func (n *Nonce) rotate() { 74 + counter := n.currentCounter() 75 + diff := counter - n.counter 76 + 77 + switch diff { 78 + case 0: 79 + // counter == n.counter, do nothing 80 + case 1: 81 + n.prev = n.curr 82 + n.curr = n.next 83 + n.next = n.compute(counter + 1) 84 + case 2: 85 + n.prev = n.next 86 + n.curr = n.compute(counter) 87 + n.next = n.compute(counter + 1) 88 + default: 89 + n.prev = n.compute(counter - 1) 90 + n.curr = n.compute(counter) 91 + n.next = n.compute(counter + 1) 92 + } 93 + 94 + n.counter = counter 95 + } 96 + 97 + func (n *Nonce) NextNonce() string { 98 + n.mu.Lock() 99 + defer n.mu.Unlock() 100 + n.rotate() 101 + return n.next 102 + } 103 + 104 + func (n *Nonce) Check(nonce string) bool { 105 + n.mu.RLock() 106 + defer n.mu.RUnlock() 107 + return nonce == n.prev || nonce == n.curr || nonce == n.next 108 + }
+32
oauth/helpers.go
··· 4 4 "errors" 5 5 "fmt" 6 6 "net/url" 7 + "time" 7 8 8 9 "github.com/haileyok/cocoon/internal/helpers" 9 10 "github.com/haileyok/cocoon/oauth/constants" 11 + "github.com/haileyok/cocoon/oauth/provider" 10 12 ) 11 13 12 14 func GenerateCode() string { ··· 46 48 47 49 return reqId, nil 48 50 } 51 + 52 + type SessionAgeResult struct { 53 + SessionAge time.Duration 54 + RefreshAge time.Duration 55 + SessionExpired bool 56 + RefreshExpired bool 57 + } 58 + 59 + func GetSessionAgeFromToken(t provider.OauthToken) SessionAgeResult { 60 + sessionLifetime := constants.PublicClientSessionLifetime 61 + refreshLifetime := constants.PublicClientRefreshLifetime 62 + if t.ClientAuth.Method != "none" { 63 + sessionLifetime = constants.ConfidentialClientSessionLifetime 64 + refreshLifetime = constants.ConfidentialClientRefreshLifetime 65 + } 66 + 67 + res := SessionAgeResult{} 68 + 69 + res.SessionAge = time.Since(t.CreatedAt) 70 + if res.SessionAge > sessionLifetime { 71 + res.SessionExpired = true 72 + } 73 + 74 + refreshAge := time.Since(t.UpdatedAt) 75 + if refreshAge > refreshLifetime { 76 + res.RefreshExpired = true 77 + } 78 + 79 + return res 80 + }
+3 -26
oauth/provider/client_auth.go
··· 3 3 import ( 4 4 "context" 5 5 "crypto" 6 - "database/sql/driver" 7 6 "encoding/base64" 8 - "encoding/json" 9 7 "errors" 10 8 "fmt" 11 9 "time" 12 10 13 11 "github.com/golang-jwt/jwt/v4" 14 - "github.com/haileyok/cocoon/oauth" 12 + "github.com/haileyok/cocoon/oauth/client" 15 13 "github.com/haileyok/cocoon/oauth/constants" 16 14 "github.com/haileyok/cocoon/oauth/dpop" 17 15 ) 18 16 19 - type ClientAuth struct { 20 - Method string 21 - Alg string 22 - Kid string 23 - Jkt string 24 - Jti string 25 - Exp *float64 26 - } 27 - 28 - func (ca *ClientAuth) Scan(value any) error { 29 - b, ok := value.([]byte) 30 - if !ok { 31 - return fmt.Errorf("failed to unmarshal OauthParRequest value") 32 - } 33 - return json.Unmarshal(b, ca) 34 - } 35 - 36 - func (ca ClientAuth) Value() (driver.Value, error) { 37 - return json.Marshal(ca) 38 - } 39 - 40 17 type AuthenticateClientOptions struct { 41 18 AllowMissingDpopProof bool 42 19 } ··· 47 24 ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"` 48 25 } 49 26 50 - func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*oauth.Client, *ClientAuth, error) { 27 + func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) { 51 28 client, err := p.ClientManager.GetClient(ctx, req.ClientID) 52 29 if err != nil { 53 30 return nil, nil, fmt.Errorf("failed to get client: %w", err) ··· 69 46 return client, clientAuth, nil 70 47 } 71 48 72 - func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *oauth.Client) (*ClientAuth, error) { 49 + func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *client.Client) (*ClientAuth, error) { 73 50 metadata := client.Metadata 74 51 75 52 if metadata.TokenEndpointAuthMethod == "none" {
+83
oauth/provider/models.go
··· 1 + package provider 2 + 3 + import ( 4 + "database/sql/driver" 5 + "encoding/json" 6 + "fmt" 7 + "time" 8 + 9 + "gorm.io/gorm" 10 + ) 11 + 12 + type ClientAuth struct { 13 + Method string 14 + Alg string 15 + Kid string 16 + Jkt string 17 + Jti string 18 + Exp *float64 19 + } 20 + 21 + func (ca *ClientAuth) Scan(value any) error { 22 + b, ok := value.([]byte) 23 + if !ok { 24 + return fmt.Errorf("failed to unmarshal OauthParRequest value") 25 + } 26 + return json.Unmarshal(b, ca) 27 + } 28 + 29 + func (ca ClientAuth) Value() (driver.Value, error) { 30 + return json.Marshal(ca) 31 + } 32 + 33 + type ParRequest struct { 34 + AuthenticateClientRequestBase 35 + ResponseType string `form:"response_type" json:"response_type" validate:"required"` 36 + CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"` 37 + CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"` 38 + State string `form:"state" json:"state" validate:"required"` 39 + RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"` 40 + Scope string `form:"scope" json:"scope" validate:"required"` 41 + LoginHint *string `form:"login_hint" json:"login_hint,omitempty"` 42 + DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"` 43 + } 44 + 45 + func (opr *ParRequest) Scan(value any) error { 46 + b, ok := value.([]byte) 47 + if !ok { 48 + return fmt.Errorf("failed to unmarshal OauthParRequest value") 49 + } 50 + return json.Unmarshal(b, opr) 51 + } 52 + 53 + func (opr ParRequest) Value() (driver.Value, error) { 54 + return json.Marshal(opr) 55 + } 56 + 57 + type OauthToken struct { 58 + gorm.Model 59 + ClientId string `gorm:"index"` 60 + ClientAuth ClientAuth `gorm:"type:json"` 61 + Parameters ParRequest `gorm:"type:json"` 62 + ExpiresAt time.Time `gorm:"index"` 63 + DeviceId string 64 + Sub string `gorm:"index"` 65 + Code string `gorm:"index"` 66 + Token string `gorm:"uniqueIndex"` 67 + RefreshToken string `gorm:"uniqueIndex"` 68 + Ip string 69 + } 70 + 71 + type OauthAuthorizationRequest struct { 72 + gorm.Model 73 + RequestId string `gorm:"primaryKey"` 74 + ClientId string `gorm:"index"` 75 + ClientAuth ClientAuth `gorm:"type:json"` 76 + Parameters ParRequest `gorm:"type:json"` 77 + ExpiresAt time.Time `gorm:"index"` 78 + DeviceId *string 79 + Sub *string 80 + Code *string 81 + Accepted *bool 82 + Ip string 83 + }
+8 -64
oauth/provider/provider.go
··· 1 1 package provider 2 2 3 3 import ( 4 - "database/sql/driver" 5 - "encoding/json" 6 - "fmt" 7 - "time" 8 - 9 - "github.com/haileyok/cocoon/oauth/client_manager" 10 - "github.com/haileyok/cocoon/oauth/dpop/dpop_manager" 11 - "gorm.io/gorm" 4 + "github.com/haileyok/cocoon/oauth/client" 5 + "github.com/haileyok/cocoon/oauth/dpop" 12 6 ) 13 7 14 8 type Provider struct { 15 - ClientManager *client_manager.ClientManager 16 - DpopManager *dpop_manager.DpopManager 9 + ClientManager *client.Manager 10 + DpopManager *dpop.Manager 17 11 18 12 hostname string 19 13 } 20 14 21 15 type Args struct { 22 16 Hostname string 23 - ClientManagerArgs client_manager.Args 24 - DpopManagerArgs dpop_manager.Args 17 + ClientManagerArgs client.ManagerArgs 18 + DpopManagerArgs dpop.ManagerArgs 25 19 } 26 20 27 21 func NewProvider(args Args) *Provider { 28 22 return &Provider{ 29 - ClientManager: client_manager.New(args.ClientManagerArgs), 30 - DpopManager: dpop_manager.New(args.DpopManagerArgs), 23 + ClientManager: client.NewManager(args.ClientManagerArgs), 24 + DpopManager: dpop.NewManager(args.DpopManagerArgs), 31 25 hostname: args.Hostname, 32 26 } 33 27 } ··· 35 29 func (p *Provider) NextNonce() string { 36 30 return p.DpopManager.NextNonce() 37 31 } 38 - 39 - type ParRequest struct { 40 - AuthenticateClientRequestBase 41 - ResponseType string `form:"response_type" json:"response_type" validate:"required"` 42 - CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"` 43 - CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"` 44 - State string `form:"state" json:"state" validate:"required"` 45 - RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"` 46 - Scope string `form:"scope" json:"scope" validate:"required"` 47 - LoginHint *string `form:"login_hint" json:"login_hint,omitempty"` 48 - DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"` 49 - } 50 - 51 - func (opr *ParRequest) Scan(value any) error { 52 - b, ok := value.([]byte) 53 - if !ok { 54 - return fmt.Errorf("failed to unmarshal OauthParRequest value") 55 - } 56 - return json.Unmarshal(b, opr) 57 - } 58 - 59 - func (opr ParRequest) Value() (driver.Value, error) { 60 - return json.Marshal(opr) 61 - } 62 - 63 - type OauthToken struct { 64 - gorm.Model 65 - ClientId string `gorm:"index"` 66 - ClientAuth ClientAuth `gorm:"type:json"` 67 - Parameters ParRequest `gorm:"type:json"` 68 - ExpiresAt time.Time `gorm:"index"` 69 - DeviceId string 70 - Sub string `gorm:"index"` 71 - Code string `gorm:"index"` 72 - Token string `gorm:"uniqueIndex"` 73 - RefreshToken string `gorm:"uniqueIndex"` 74 - } 75 - 76 - type OauthAuthorizationRequest struct { 77 - gorm.Model 78 - RequestId string `gorm:"primaryKey"` 79 - ClientId string `gorm:"index"` 80 - ClientAuth ClientAuth `gorm:"type:json"` 81 - Parameters ParRequest `gorm:"type:json"` 82 - ExpiresAt time.Time `gorm:"index"` 83 - DeviceId *string 84 - Sub *string 85 - Code *string 86 - Accepted *bool 87 - }
+77
recording_blockstore/recording_blockstore.go
··· 1 + package recording_blockstore 2 + 3 + import ( 4 + "context" 5 + 6 + blockformat "github.com/ipfs/go-block-format" 7 + "github.com/ipfs/go-cid" 8 + blockstore "github.com/ipfs/go-ipfs-blockstore" 9 + ) 10 + 11 + type RecordingBlockstore struct { 12 + base blockstore.Blockstore 13 + 14 + inserts map[cid.Cid]blockformat.Block 15 + } 16 + 17 + func New(base blockstore.Blockstore) *RecordingBlockstore { 18 + return &RecordingBlockstore{ 19 + base: base, 20 + inserts: make(map[cid.Cid]blockformat.Block), 21 + } 22 + } 23 + 24 + func (bs *RecordingBlockstore) Has(ctx context.Context, c cid.Cid) (bool, error) { 25 + return bs.base.Has(ctx, c) 26 + } 27 + 28 + func (bs *RecordingBlockstore) Get(ctx context.Context, c cid.Cid) (blockformat.Block, error) { 29 + return bs.base.Get(ctx, c) 30 + } 31 + 32 + func (bs *RecordingBlockstore) GetSize(ctx context.Context, c cid.Cid) (int, error) { 33 + return bs.base.GetSize(ctx, c) 34 + } 35 + 36 + func (bs *RecordingBlockstore) DeleteBlock(ctx context.Context, c cid.Cid) error { 37 + return bs.base.DeleteBlock(ctx, c) 38 + } 39 + 40 + func (bs *RecordingBlockstore) Put(ctx context.Context, block blockformat.Block) error { 41 + if err := bs.base.Put(ctx, block); err != nil { 42 + return err 43 + } 44 + bs.inserts[block.Cid()] = block 45 + return nil 46 + } 47 + 48 + func (bs *RecordingBlockstore) PutMany(ctx context.Context, blocks []blockformat.Block) error { 49 + if err := bs.base.PutMany(ctx, blocks); err != nil { 50 + return err 51 + } 52 + 53 + for _, b := range blocks { 54 + bs.inserts[b.Cid()] = b 55 + } 56 + 57 + return nil 58 + } 59 + 60 + func (bs *RecordingBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) { 61 + return bs.AllKeysChan(ctx) 62 + } 63 + 64 + func (bs *RecordingBlockstore) HashOnRead(enabled bool) { 65 + } 66 + 67 + func (bs *RecordingBlockstore) GetLogMap() map[cid.Cid]blockformat.Block { 68 + return bs.inserts 69 + } 70 + 71 + func (bs *RecordingBlockstore) GetLogArray() []blockformat.Block { 72 + var blocks []blockformat.Block 73 + for _, b := range bs.inserts { 74 + blocks = append(blocks, b) 75 + } 76 + return blocks 77 + }
+30
server/blockstore_variant.go
··· 1 + package server 2 + 3 + import ( 4 + "github.com/haileyok/cocoon/sqlite_blockstore" 5 + blockstore "github.com/ipfs/go-ipfs-blockstore" 6 + ) 7 + 8 + type BlockstoreVariant int 9 + 10 + const ( 11 + BlockstoreVariantSqlite = iota 12 + ) 13 + 14 + func MustReturnBlockstoreVariant(maybeBsv string) BlockstoreVariant { 15 + switch maybeBsv { 16 + case "sqlite": 17 + return BlockstoreVariantSqlite 18 + default: 19 + panic("invalid blockstore variant provided") 20 + } 21 + } 22 + 23 + func (s *Server) getBlockstore(did string) blockstore.Blockstore { 24 + switch s.config.BlockstoreVariant { 25 + case BlockstoreVariantSqlite: 26 + return sqlite_blockstore.New(did, s.db) 27 + default: 28 + return sqlite_blockstore.New(did, s.db) 29 + } 30 + }
+37 -7
server/handle_account.go
··· 3 3 import ( 4 4 "time" 5 5 6 + "github.com/haileyok/cocoon/oauth" 7 + "github.com/haileyok/cocoon/oauth/constants" 6 8 "github.com/haileyok/cocoon/oauth/provider" 9 + "github.com/hako/durafmt" 7 10 "github.com/labstack/echo/v4" 8 11 ) 9 12 10 13 func (s *Server) handleAccount(e echo.Context) error { 14 + ctx := e.Request().Context() 11 15 repo, sess, err := s.getSessionRepoOrErr(e) 12 16 if err != nil { 13 17 return e.Redirect(303, "/account/signin") 14 18 } 15 19 16 - now := time.Now() 20 + oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime) 17 21 18 22 var tokens []provider.OauthToken 19 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND expires_at >= ? ORDER BY created_at ASC", nil, repo.Repo.Did, now).Scan(&tokens).Error; err != nil { 23 + if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { 20 24 s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) 21 25 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error") 22 26 sess.Save(e.Request(), e.Response()) ··· 25 29 }) 26 30 } 27 31 32 + var filtered []provider.OauthToken 33 + for _, t := range tokens { 34 + ageRes := oauth.GetSessionAgeFromToken(t) 35 + if ageRes.SessionExpired { 36 + continue 37 + } 38 + filtered = append(filtered, t) 39 + } 40 + 41 + now := time.Now() 42 + 28 43 tokenInfo := []map[string]string{} 29 44 for _, t := range tokens { 45 + ageRes := oauth.GetSessionAgeFromToken(t) 46 + maxTime := constants.PublicClientSessionLifetime 47 + if t.ClientAuth.Method != "none" { 48 + maxTime = constants.ConfidentialClientSessionLifetime 49 + } 50 + 51 + var clientName string 52 + metadata, err := s.oauthProvider.ClientManager.GetClient(ctx, t.ClientId) 53 + if err != nil { 54 + clientName = t.ClientId 55 + } else { 56 + clientName = metadata.Metadata.ClientName 57 + } 58 + 30 59 tokenInfo = append(tokenInfo, map[string]string{ 31 - "ClientId": t.ClientId, 32 - "CreatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"), 33 - "UpdatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"), 34 - "ExpiresAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"), 35 - "Token": t.Token, 60 + "ClientName": clientName, 61 + "Age": durafmt.Parse(ageRes.SessionAge).LimitFirstN(2).String(), 62 + "LastUpdated": durafmt.Parse(now.Sub(t.UpdatedAt)).LimitFirstN(2).String(), 63 + "ExpiresIn": durafmt.Parse(now.Add(maxTime).Sub(now)).LimitFirstN(2).String(), 64 + "Token": t.Token, 65 + "Ip": t.Ip, 36 66 }) 37 67 } 38 68
+2 -3
server/handle_import_repo.go
··· 9 9 10 10 "github.com/bluesky-social/indigo/atproto/syntax" 11 11 "github.com/bluesky-social/indigo/repo" 12 - "github.com/haileyok/cocoon/blockstore" 13 12 "github.com/haileyok/cocoon/internal/helpers" 14 13 "github.com/haileyok/cocoon/models" 15 14 blocks "github.com/ipfs/go-block-format" ··· 27 26 return helpers.ServerError(e, nil) 28 27 } 29 28 30 - bs := blockstore.New(urepo.Repo.Did, s.db) 29 + bs := s.getBlockstore(urepo.Repo.Did) 31 30 32 31 cs, err := car.NewCarReader(bytes.NewReader(b)) 33 32 if err != nil { ··· 107 106 return helpers.ServerError(e, nil) 108 107 } 109 108 110 - if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil { 109 + if err := s.UpdateRepo(context.TODO(), urepo.Repo.Did, root, rev); err != nil { 111 110 s.logger.Error("error updating repo after commit", "error", err) 112 111 return helpers.ServerError(e, nil) 113 112 }
+1 -1
server/handle_oauth_authorize.go
··· 113 113 114 114 code := oauth.GenerateCode() 115 115 116 - if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, reqId).Error; err != nil { 116 + if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 117 117 s.logger.Error("error updating authorization request", "error", err) 118 118 return helpers.ServerError(e, nil) 119 119 }
+4 -10
server/handle_oauth_token.go
··· 157 157 Code: *authReq.Code, 158 158 Token: accessString, 159 159 RefreshToken: refreshToken, 160 + Ip: authReq.Ip, 160 161 }, nil).Error; err != nil { 161 162 s.logger.Error("error creating token in db", "error", err) 162 163 return helpers.ServerError(e, nil) ··· 203 204 return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt")) 204 205 } 205 206 206 - sessionLifetime := constants.PublicClientSessionLifetime 207 - refreshLifetime := constants.PublicClientRefreshLifetime 208 - if clientAuth.Method != "none" { 209 - sessionLifetime = constants.ConfidentialClientSessionLifetime 210 - refreshLifetime = constants.ConfidentialClientRefreshLifetime 211 - } 207 + ageRes := oauth.GetSessionAgeFromToken(oauthToken) 212 208 213 - sessionAge := time.Since(oauthToken.CreatedAt) 214 - if sessionAge > sessionLifetime { 209 + if ageRes.SessionExpired { 215 210 return helpers.InputError(e, to.StringPtr("Session expired")) 216 211 } 217 212 218 - refreshAge := time.Since(oauthToken.UpdatedAt) 219 - if refreshAge > refreshLifetime { 213 + if ageRes.RefreshExpired { 220 214 return helpers.InputError(e, to.StringPtr("Refresh token expired")) 221 215 } 222 216
+27 -15
server/handle_proxy.go
··· 17 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 18 ) 19 19 20 - func (s *Server) handleProxy(e echo.Context) error { 21 - repo, isAuthed := e.Get("repo").(*models.RepoActor) 22 - 23 - pts := strings.Split(e.Request().URL.Path, "/") 24 - if len(pts) != 3 { 25 - return fmt.Errorf("incorrect number of parts") 26 - } 27 - 20 + func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 28 21 svc := e.Request().Header.Get("atproto-proxy") 29 22 if svc == "" { 30 - svc = "did:web:api.bsky.app#bsky_appview" // TODO: should be a config var probably 23 + svc = s.config.DefaultAtprotoProxy 31 24 } 32 25 33 26 svcPts := strings.Split(svc, "#") 34 27 if len(svcPts) != 2 { 35 - return fmt.Errorf("invalid service header") 28 + return "", "", fmt.Errorf("invalid service header") 36 29 } 37 30 38 31 svcDid := svcPts[0] ··· 40 33 41 34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 42 35 if err != nil { 43 - return err 36 + return "", "", err 44 37 } 45 38 46 39 var endpoint string ··· 50 43 } 51 44 } 52 45 46 + return endpoint, svcDid, nil 47 + } 48 + 49 + func (s *Server) handleProxy(e echo.Context) error { 50 + lgr := s.logger.With("handler", "handleProxy") 51 + 52 + repo, isAuthed := e.Get("repo").(*models.RepoActor) 53 + 54 + pts := strings.Split(e.Request().URL.Path, "/") 55 + if len(pts) != 3 { 56 + return fmt.Errorf("incorrect number of parts") 57 + } 58 + 59 + endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 60 + if err != nil { 61 + lgr.Error("could not get atproto proxy", "error", err) 62 + return helpers.ServerError(e, nil) 63 + } 64 + 53 65 requrl := e.Request().URL 54 66 requrl.Host = strings.TrimPrefix(endpoint, "https://") 55 67 requrl.Scheme = "https" ··· 78 90 } 79 91 hj, err := json.Marshal(header) 80 92 if err != nil { 81 - s.logger.Error("error marshaling header", "error", err) 93 + lgr.Error("error marshaling header", "error", err) 82 94 return helpers.ServerError(e, nil) 83 95 } 84 96 ··· 93 105 } 94 106 pj, err := json.Marshal(payload) 95 107 if err != nil { 96 - s.logger.Error("error marashaling payload", "error", err) 108 + lgr.Error("error marashaling payload", "error", err) 97 109 return helpers.ServerError(e, nil) 98 110 } 99 111 ··· 104 116 105 117 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 106 118 if err != nil { 107 - s.logger.Error("can't load private key", "error", err) 119 + lgr.Error("can't load private key", "error", err) 108 120 return err 109 121 } 110 122 111 123 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 112 124 if err != nil { 113 - s.logger.Error("error signing", "error", err) 125 + lgr.Error("error signing", "error", err) 114 126 } 115 127 116 128 rBytes := R.Bytes()
+65
server/handle_server_check_account_status.go
··· 1 + package server 2 + 3 + import ( 4 + "github.com/haileyok/cocoon/internal/helpers" 5 + "github.com/haileyok/cocoon/models" 6 + "github.com/ipfs/go-cid" 7 + "github.com/labstack/echo/v4" 8 + ) 9 + 10 + type ComAtprotoServerCheckAccountStatusResponse struct { 11 + Activated bool `json:"activated"` 12 + ValidDid bool `json:"validDid"` 13 + RepoCommit string `json:"repoCommit"` 14 + RepoRev string `json:"repoRev"` 15 + RepoBlocks int64 `json:"repoBlocks"` 16 + IndexedRecords int64 `json:"indexedRecords"` 17 + PrivateStateValues int64 `json:"privateStateValues"` 18 + ExpectedBlobs int64 `json:"expectedBlobs"` 19 + ImportedBlobs int64 `json:"importedBlobs"` 20 + } 21 + 22 + func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { 23 + urepo := e.Get("repo").(*models.RepoActor) 24 + 25 + resp := ComAtprotoServerCheckAccountStatusResponse{ 26 + Activated: true, // TODO: should allow for deactivation etc. 27 + ValidDid: true, // TODO: should probably verify? 28 + RepoRev: urepo.Rev, 29 + ImportedBlobs: 0, // TODO: ??? 30 + } 31 + 32 + rootcid, err := cid.Cast(urepo.Root) 33 + if err != nil { 34 + s.logger.Error("error casting cid", "error", err) 35 + return helpers.ServerError(e, nil) 36 + } 37 + resp.RepoCommit = rootcid.String() 38 + 39 + type CountResp struct { 40 + Ct int64 41 + } 42 + 43 + var blockCtResp CountResp 44 + if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 45 + s.logger.Error("error getting block count", "error", err) 46 + return helpers.ServerError(e, nil) 47 + } 48 + resp.RepoBlocks = blockCtResp.Ct 49 + 50 + var recCtResp CountResp 51 + if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 52 + s.logger.Error("error getting record count", "error", err) 53 + return helpers.ServerError(e, nil) 54 + } 55 + resp.IndexedRecords = recCtResp.Ct 56 + 57 + var blobCtResp CountResp 58 + if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 59 + s.logger.Error("error getting record count", "error", err) 60 + return helpers.ServerError(e, nil) 61 + } 62 + resp.ExpectedBlobs = blobCtResp.Ct 63 + 64 + return e.JSON(200, resp) 65 + }
+2 -2
server/handle_server_confirm_email.go
··· 28 28 } 29 29 30 30 if urepo.EmailVerificationCode == nil || urepo.EmailVerificationCodeExpiresAt == nil { 31 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 31 + return helpers.ExpiredTokenError(e) 32 32 } 33 33 34 34 if *urepo.EmailVerificationCode != req.Token { ··· 36 36 } 37 37 38 38 if time.Now().UTC().After(*urepo.EmailVerificationCodeExpiresAt) { 39 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 39 + return helpers.ExpiredTokenError(e) 40 40 } 41 41 42 42 now := time.Now().UTC()
+2 -3
server/handle_server_create_account.go
··· 14 14 "github.com/bluesky-social/indigo/events" 15 15 "github.com/bluesky-social/indigo/repo" 16 16 "github.com/bluesky-social/indigo/util" 17 - "github.com/haileyok/cocoon/blockstore" 18 17 "github.com/haileyok/cocoon/internal/helpers" 19 18 "github.com/haileyok/cocoon/models" 20 19 "github.com/labstack/echo/v4" ··· 177 176 } 178 177 179 178 if customDidHeader == "" { 180 - bs := blockstore.New(signupDid, s.db) 179 + bs := s.getBlockstore(signupDid) 181 180 r := repo.NewRepo(context.TODO(), signupDid, bs) 182 181 183 182 root, rev, err := r.Commit(context.TODO(), urepo.SignFor) ··· 186 185 return helpers.ServerError(e, nil) 187 186 } 188 187 189 - if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil { 188 + if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil { 190 189 s.logger.Error("error updating repo after commit", "error", err) 191 190 return helpers.ServerError(e, nil) 192 191 }
+8 -6
server/handle_server_get_service_auth.go
··· 19 19 20 20 type ServerGetServiceAuthRequest struct { 21 21 Aud string `query:"aud" validate:"required,atproto-did"` 22 - Exp int64 `query:"exp"` 23 - Lxm string `query:"lxm" validate:"required,atproto-nsid"` 22 + // exp should be a float, as some clients will send a non-integer expiration 23 + Exp float64 `query:"exp"` 24 + Lxm string `query:"lxm" validate:"required,atproto-nsid"` 24 25 } 25 26 26 27 func (s *Server) handleServerGetServiceAuth(e echo.Context) error { ··· 34 35 return helpers.InputError(e, nil) 35 36 } 36 37 38 + exp := int64(req.Exp) 37 39 now := time.Now().Unix() 38 - if req.Exp == 0 { 39 - req.Exp = now + 60 // default 40 + if exp == 0 { 41 + exp = now + 60 // default 40 42 } 41 43 42 44 if req.Lxm == "com.atproto.server.getServiceAuth" { ··· 44 46 } 45 47 46 48 maxExp := now + (60 * 30) 47 - if req.Exp > maxExp { 49 + if exp > maxExp { 48 50 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 49 51 } 50 52 ··· 68 70 "aud": req.Aud, 69 71 "lxm": req.Lxm, 70 72 "jti": uuid.NewString(), 71 - "exp": req.Exp, 73 + "exp": exp, 72 74 "iat": now, 73 75 } 74 76 pj, err := json.Marshal(payload)
+2 -2
server/handle_server_reset_password.go
··· 33 33 } 34 34 35 35 if *urepo.PasswordResetCode != req.Token { 36 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 36 + return helpers.InvalidTokenError(e) 37 37 } 38 38 39 39 if time.Now().UTC().After(*urepo.PasswordResetCodeExpiresAt) { 40 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 40 + return helpers.ExpiredTokenError(e) 41 41 } 42 42 43 43 hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10)
+3 -4
server/handle_server_update_email.go
··· 3 3 import ( 4 4 "time" 5 5 6 - "github.com/Azure/go-autorest/autorest/to" 7 6 "github.com/haileyok/cocoon/internal/helpers" 8 7 "github.com/haileyok/cocoon/models" 9 8 "github.com/labstack/echo/v4" ··· 29 28 } 30 29 31 30 if urepo.EmailUpdateCode == nil || urepo.EmailUpdateCodeExpiresAt == nil { 32 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 31 + return helpers.InvalidTokenError(e) 33 32 } 34 33 35 34 if *urepo.EmailUpdateCode != req.Token { 36 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 35 + return helpers.InvalidTokenError(e) 37 36 } 38 37 39 38 if time.Now().UTC().After(*urepo.EmailUpdateCodeExpiresAt) { 40 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 39 + return helpers.ExpiredTokenError(e) 41 40 } 42 41 43 42 if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil {
+1 -2
server/handle_sync_get_blocks.go
··· 6 6 "strings" 7 7 8 8 "github.com/bluesky-social/indigo/carstore" 9 - "github.com/haileyok/cocoon/blockstore" 10 9 "github.com/haileyok/cocoon/internal/helpers" 11 10 "github.com/ipfs/go-cid" 12 11 cbor "github.com/ipfs/go-ipld-cbor" ··· 54 53 return helpers.ServerError(e, nil) 55 54 } 56 55 57 - bs := blockstore.New(urepo.Repo.Did, s.db) 56 + bs := s.getBlockstore(urepo.Repo.Did) 58 57 59 58 for _, c := range cids { 60 59 b, err := bs.Get(context.TODO(), c)
+268
server/middleware.go
··· 1 + package server 2 + 3 + import ( 4 + "crypto/sha256" 5 + "encoding/base64" 6 + "fmt" 7 + "strings" 8 + "time" 9 + 10 + "github.com/Azure/go-autorest/autorest/to" 11 + "github.com/golang-jwt/jwt/v4" 12 + "github.com/haileyok/cocoon/internal/helpers" 13 + "github.com/haileyok/cocoon/models" 14 + "github.com/haileyok/cocoon/oauth/provider" 15 + "github.com/labstack/echo/v4" 16 + "gitlab.com/yawning/secp256k1-voi" 17 + secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 + "gorm.io/gorm" 19 + ) 20 + 21 + func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 22 + return func(e echo.Context) error { 23 + username, password, ok := e.Request().BasicAuth() 24 + if !ok || username != "admin" || password != s.config.AdminPassword { 25 + return helpers.InputError(e, to.StringPtr("Unauthorized")) 26 + } 27 + 28 + if err := next(e); err != nil { 29 + e.Error(err) 30 + } 31 + 32 + return nil 33 + } 34 + } 35 + 36 + func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 37 + return func(e echo.Context) error { 38 + authheader := e.Request().Header.Get("authorization") 39 + if authheader == "" { 40 + return e.JSON(401, map[string]string{"error": "Unauthorized"}) 41 + } 42 + 43 + pts := strings.Split(authheader, " ") 44 + if len(pts) != 2 { 45 + return helpers.ServerError(e, nil) 46 + } 47 + 48 + // move on to oauth session middleware if this is a dpop token 49 + if pts[0] == "DPoP" { 50 + return next(e) 51 + } 52 + 53 + tokenstr := pts[1] 54 + token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 55 + claims, ok := token.Claims.(jwt.MapClaims) 56 + if !ok { 57 + return helpers.InvalidTokenError(e) 58 + } 59 + 60 + var did string 61 + var repo *models.RepoActor 62 + 63 + // service auth tokens 64 + lxm, hasLxm := claims["lxm"] 65 + if hasLxm { 66 + pts := strings.Split(e.Request().URL.String(), "/") 67 + if lxm != pts[len(pts)-1] { 68 + s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 69 + return helpers.InputError(e, nil) 70 + } 71 + 72 + maybeDid, ok := claims["iss"].(string) 73 + if !ok { 74 + s.logger.Error("no iss in service auth token", "error", err) 75 + return helpers.InputError(e, nil) 76 + } 77 + did = maybeDid 78 + 79 + maybeRepo, err := s.getRepoActorByDid(did) 80 + if err != nil { 81 + s.logger.Error("error fetching repo", "error", err) 82 + return helpers.ServerError(e, nil) 83 + } 84 + repo = maybeRepo 85 + } 86 + 87 + if token.Header["alg"] != "ES256K" { 88 + token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 89 + if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 90 + return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 91 + } 92 + return s.privateKey.Public(), nil 93 + }) 94 + if err != nil { 95 + s.logger.Error("error parsing jwt", "error", err) 96 + return helpers.ExpiredTokenError(e) 97 + } 98 + 99 + if !token.Valid { 100 + return helpers.InvalidTokenError(e) 101 + } 102 + } else { 103 + kpts := strings.Split(tokenstr, ".") 104 + signingInput := kpts[0] + "." + kpts[1] 105 + hash := sha256.Sum256([]byte(signingInput)) 106 + sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 107 + if err != nil { 108 + s.logger.Error("error decoding signature bytes", "error", err) 109 + return helpers.ServerError(e, nil) 110 + } 111 + 112 + if len(sigBytes) != 64 { 113 + s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 114 + return helpers.ServerError(e, nil) 115 + } 116 + 117 + rBytes := sigBytes[:32] 118 + sBytes := sigBytes[32:] 119 + rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 120 + ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 121 + 122 + sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 123 + if err != nil { 124 + s.logger.Error("can't load private key", "error", err) 125 + return err 126 + } 127 + 128 + pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 129 + if !ok { 130 + s.logger.Error("error getting public key from sk") 131 + return helpers.ServerError(e, nil) 132 + } 133 + 134 + verified := pubKey.VerifyRaw(hash[:], rr, ss) 135 + if !verified { 136 + s.logger.Error("error verifying", "error", err) 137 + return helpers.ServerError(e, nil) 138 + } 139 + } 140 + 141 + isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 142 + scope, _ := claims["scope"].(string) 143 + 144 + if isRefresh && scope != "com.atproto.refresh" { 145 + return helpers.InvalidTokenError(e) 146 + } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 147 + return helpers.InvalidTokenError(e) 148 + } 149 + 150 + table := "tokens" 151 + if isRefresh { 152 + table = "refresh_tokens" 153 + } 154 + 155 + if isRefresh { 156 + type Result struct { 157 + Found bool 158 + } 159 + var result Result 160 + if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 161 + if err == gorm.ErrRecordNotFound { 162 + return helpers.InvalidTokenError(e) 163 + } 164 + 165 + s.logger.Error("error getting token from db", "error", err) 166 + return helpers.ServerError(e, nil) 167 + } 168 + 169 + if !result.Found { 170 + return helpers.InvalidTokenError(e) 171 + } 172 + } 173 + 174 + exp, ok := claims["exp"].(float64) 175 + if !ok { 176 + s.logger.Error("error getting iat from token") 177 + return helpers.ServerError(e, nil) 178 + } 179 + 180 + if exp < float64(time.Now().UTC().Unix()) { 181 + return helpers.ExpiredTokenError(e) 182 + } 183 + 184 + if repo == nil { 185 + maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 186 + if err != nil { 187 + s.logger.Error("error fetching repo", "error", err) 188 + return helpers.ServerError(e, nil) 189 + } 190 + repo = maybeRepo 191 + did = repo.Repo.Did 192 + } 193 + 194 + e.Set("repo", repo) 195 + e.Set("did", did) 196 + e.Set("token", tokenstr) 197 + 198 + if err := next(e); err != nil { 199 + return helpers.InvalidTokenError(e) 200 + } 201 + 202 + return nil 203 + } 204 + } 205 + 206 + func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 207 + return func(e echo.Context) error { 208 + authheader := e.Request().Header.Get("authorization") 209 + if authheader == "" { 210 + return e.JSON(401, map[string]string{"error": "Unauthorized"}) 211 + } 212 + 213 + pts := strings.Split(authheader, " ") 214 + if len(pts) != 2 { 215 + return helpers.ServerError(e, nil) 216 + } 217 + 218 + if pts[0] != "DPoP" { 219 + return next(e) 220 + } 221 + 222 + accessToken := pts[1] 223 + 224 + nonce := s.oauthProvider.NextNonce() 225 + if nonce != "" { 226 + e.Response().Header().Set("DPoP-Nonce", nonce) 227 + e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 228 + } 229 + 230 + proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 231 + if err != nil { 232 + s.logger.Error("invalid dpop proof", "error", err) 233 + return helpers.InputError(e, to.StringPtr(err.Error())) 234 + } 235 + 236 + var oauthToken provider.OauthToken 237 + if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 238 + s.logger.Error("error finding access token in db", "error", err) 239 + return helpers.InputError(e, nil) 240 + } 241 + 242 + if oauthToken.Token == "" { 243 + return helpers.InvalidTokenError(e) 244 + } 245 + 246 + if *oauthToken.Parameters.DpopJkt != proof.JKT { 247 + s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 248 + return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 249 + } 250 + 251 + if time.Now().After(oauthToken.ExpiresAt) { 252 + return helpers.ExpiredTokenError(e) 253 + } 254 + 255 + repo, err := s.getRepoActorByDid(oauthToken.Sub) 256 + if err != nil { 257 + s.logger.Error("could not find actor in db", "error", err) 258 + return helpers.ServerError(e, nil) 259 + } 260 + 261 + e.Set("repo", repo) 262 + e.Set("did", repo.Repo.Did) 263 + e.Set("token", accessToken) 264 + e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 265 + 266 + return next(e) 267 + } 268 + }
+13 -13
server/repo.go
··· 16 16 "github.com/bluesky-social/indigo/events" 17 17 lexutil "github.com/bluesky-social/indigo/lex/util" 18 18 "github.com/bluesky-social/indigo/repo" 19 - "github.com/bluesky-social/indigo/util" 20 - "github.com/haileyok/cocoon/blockstore" 21 19 "github.com/haileyok/cocoon/internal/db" 22 20 "github.com/haileyok/cocoon/models" 21 + "github.com/haileyok/cocoon/recording_blockstore" 23 22 blocks "github.com/ipfs/go-block-format" 24 23 "github.com/ipfs/go-cid" 25 24 cbor "github.com/ipfs/go-ipld-cbor" ··· 103 102 return nil, err 104 103 } 105 104 106 - dbs := blockstore.New(urepo.Did, rm.db) 105 + dbs := rm.s.getBlockstore(urepo.Did) 106 + bs := recording_blockstore.New(dbs) 107 107 r, err := repo.OpenRepo(context.TODO(), dbs, rootcid) 108 108 109 109 entries := []models.Record{} ··· 274 274 } 275 275 } 276 276 277 - for _, op := range dbs.GetLog() { 277 + for _, op := range bs.GetLogMap() { 278 278 if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil { 279 279 return nil, err 280 280 } ··· 318 318 Rev: rev, 319 319 Since: &urepo.Rev, 320 320 Commit: lexutil.LexLink(newroot), 321 - Time: time.Now().Format(util.ISO8601), 321 + Time: time.Now().Format(time.RFC3339Nano), 322 322 Ops: ops, 323 323 TooBig: false, 324 324 }, 325 325 }) 326 326 327 - if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil { 327 + if err := rm.s.UpdateRepo(context.TODO(), urepo.Did, newroot, rev); err != nil { 328 328 return nil, err 329 329 } 330 330 ··· 345 345 return cid.Undef, nil, err 346 346 } 347 347 348 - dbs := blockstore.New(urepo.Did, rm.db) 349 - bs := util.NewLoggingBstore(dbs) 348 + dbs := rm.s.getBlockstore(urepo.Did) 349 + bs := recording_blockstore.New(dbs) 350 350 351 351 r, err := repo.OpenRepo(context.TODO(), bs, c) 352 352 if err != nil { ··· 358 358 return cid.Undef, nil, err 359 359 } 360 360 361 - return c, bs.GetLoggedBlocks(), nil 361 + return c, bs.GetLogArray(), nil 362 362 } 363 363 364 364 func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { ··· 414 414 return nil, fmt.Errorf("error unmarshaling cbor: %w", err) 415 415 } 416 416 417 - var deepiter func(interface{}) error 418 - deepiter = func(item interface{}) error { 417 + var deepiter func(any) error 418 + deepiter = func(item any) error { 419 419 switch val := item.(type) { 420 - case map[string]interface{}: 420 + case map[string]any: 421 421 if val["$type"] == "blob" { 422 422 if ref, ok := val["ref"].(string); ok { 423 423 c, err := cid.Parse(ref) ··· 430 430 return deepiter(v) 431 431 } 432 432 } 433 - case []interface{}: 433 + case []any: 434 434 for _, v := range val { 435 435 deepiter(v) 436 436 }
+44 -283
server/server.go
··· 4 4 "bytes" 5 5 "context" 6 6 "crypto/ecdsa" 7 - "crypto/sha256" 8 7 "embed" 9 - "encoding/base64" 10 8 "errors" 11 9 "fmt" 12 10 "io" ··· 15 13 "net/smtp" 16 14 "os" 17 15 "path/filepath" 18 - "strings" 19 16 "sync" 20 17 "text/template" 21 18 "time" 22 19 23 - "github.com/Azure/go-autorest/autorest/to" 24 20 "github.com/aws/aws-sdk-go/aws" 25 21 "github.com/aws/aws-sdk-go/aws/credentials" 26 22 "github.com/aws/aws-sdk-go/aws/session" ··· 32 28 "github.com/bluesky-social/indigo/xrpc" 33 29 "github.com/domodwyer/mailyak/v3" 34 30 "github.com/go-playground/validator" 35 - "github.com/golang-jwt/jwt/v4" 36 31 "github.com/gorilla/sessions" 37 32 "github.com/haileyok/cocoon/identity" 38 33 "github.com/haileyok/cocoon/internal/db" 39 34 "github.com/haileyok/cocoon/internal/helpers" 40 35 "github.com/haileyok/cocoon/models" 41 - "github.com/haileyok/cocoon/oauth/client_manager" 36 + "github.com/haileyok/cocoon/oauth/client" 42 37 "github.com/haileyok/cocoon/oauth/constants" 43 - "github.com/haileyok/cocoon/oauth/dpop/dpop_manager" 38 + "github.com/haileyok/cocoon/oauth/dpop" 44 39 "github.com/haileyok/cocoon/oauth/provider" 45 40 "github.com/haileyok/cocoon/plc" 41 + "github.com/ipfs/go-cid" 46 42 echo_session "github.com/labstack/echo-contrib/session" 47 43 "github.com/labstack/echo/v4" 48 44 "github.com/labstack/echo/v4/middleware" 49 45 slogecho "github.com/samber/slog-echo" 50 - "gitlab.com/yawning/secp256k1-voi" 51 - secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 52 46 "gorm.io/driver/sqlite" 53 47 "gorm.io/gorm" 54 48 ) ··· 109 103 S3Config *S3Config 110 104 111 105 SessionSecret string 106 + 107 + DefaultAtprotoProxy string 108 + 109 + BlockstoreVariant BlockstoreVariant 112 110 } 113 111 114 112 type config struct { 115 - Version string 116 - Did string 117 - Hostname string 118 - ContactEmail string 119 - EnforcePeering bool 120 - Relays []string 121 - AdminPassword string 122 - SmtpEmail string 123 - SmtpName string 113 + Version string 114 + Did string 115 + Hostname string 116 + ContactEmail string 117 + EnforcePeering bool 118 + Relays []string 119 + AdminPassword string 120 + SmtpEmail string 121 + SmtpName string 122 + DefaultAtprotoProxy string 123 + BlockstoreVariant BlockstoreVariant 124 124 } 125 125 126 126 type CustomValidator struct { ··· 197 197 return t.templates.ExecuteTemplate(w, name, data) 198 198 } 199 199 200 - func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 201 - return func(e echo.Context) error { 202 - username, password, ok := e.Request().BasicAuth() 203 - if !ok || username != "admin" || password != s.config.AdminPassword { 204 - return helpers.InputError(e, to.StringPtr("Unauthorized")) 205 - } 206 - 207 - if err := next(e); err != nil { 208 - e.Error(err) 209 - } 210 - 211 - return nil 212 - } 213 - } 214 - 215 - func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 216 - return func(e echo.Context) error { 217 - authheader := e.Request().Header.Get("authorization") 218 - if authheader == "" { 219 - return e.JSON(401, map[string]string{"error": "Unauthorized"}) 220 - } 221 - 222 - pts := strings.Split(authheader, " ") 223 - if len(pts) != 2 { 224 - return helpers.ServerError(e, nil) 225 - } 226 - 227 - // move on to oauth session middleware if this is a dpop token 228 - if pts[0] == "DPoP" { 229 - return next(e) 230 - } 231 - 232 - tokenstr := pts[1] 233 - token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 234 - claims, ok := token.Claims.(jwt.MapClaims) 235 - if !ok { 236 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 237 - } 238 - 239 - var did string 240 - var repo *models.RepoActor 241 - 242 - // service auth tokens 243 - lxm, hasLxm := claims["lxm"] 244 - if hasLxm { 245 - pts := strings.Split(e.Request().URL.String(), "/") 246 - if lxm != pts[len(pts)-1] { 247 - s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 248 - return helpers.InputError(e, nil) 249 - } 250 - 251 - maybeDid, ok := claims["iss"].(string) 252 - if !ok { 253 - s.logger.Error("no iss in service auth token", "error", err) 254 - return helpers.InputError(e, nil) 255 - } 256 - did = maybeDid 257 - 258 - maybeRepo, err := s.getRepoActorByDid(did) 259 - if err != nil { 260 - s.logger.Error("error fetching repo", "error", err) 261 - return helpers.ServerError(e, nil) 262 - } 263 - repo = maybeRepo 264 - } 265 - 266 - if token.Header["alg"] != "ES256K" { 267 - token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 268 - if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 269 - return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 270 - } 271 - return s.privateKey.Public(), nil 272 - }) 273 - if err != nil { 274 - s.logger.Error("error parsing jwt", "error", err) 275 - // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 276 - return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 277 - } 278 - 279 - if !token.Valid { 280 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 281 - } 282 - } else { 283 - kpts := strings.Split(tokenstr, ".") 284 - signingInput := kpts[0] + "." + kpts[1] 285 - hash := sha256.Sum256([]byte(signingInput)) 286 - sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 287 - if err != nil { 288 - s.logger.Error("error decoding signature bytes", "error", err) 289 - return helpers.ServerError(e, nil) 290 - } 291 - 292 - if len(sigBytes) != 64 { 293 - s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 294 - return helpers.ServerError(e, nil) 295 - } 296 - 297 - rBytes := sigBytes[:32] 298 - sBytes := sigBytes[32:] 299 - rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 300 - ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 301 - 302 - sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 303 - if err != nil { 304 - s.logger.Error("can't load private key", "error", err) 305 - return err 306 - } 307 - 308 - pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 309 - if !ok { 310 - s.logger.Error("error getting public key from sk") 311 - return helpers.ServerError(e, nil) 312 - } 313 - 314 - verified := pubKey.VerifyRaw(hash[:], rr, ss) 315 - if !verified { 316 - s.logger.Error("error verifying", "error", err) 317 - return helpers.ServerError(e, nil) 318 - } 319 - } 320 - 321 - isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 322 - scope, _ := claims["scope"].(string) 323 - 324 - if isRefresh && scope != "com.atproto.refresh" { 325 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 326 - } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 327 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 328 - } 329 - 330 - table := "tokens" 331 - if isRefresh { 332 - table = "refresh_tokens" 333 - } 334 - 335 - if isRefresh { 336 - type Result struct { 337 - Found bool 338 - } 339 - var result Result 340 - if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 341 - if err == gorm.ErrRecordNotFound { 342 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 343 - } 344 - 345 - s.logger.Error("error getting token from db", "error", err) 346 - return helpers.ServerError(e, nil) 347 - } 348 - 349 - if !result.Found { 350 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 351 - } 352 - } 353 - 354 - exp, ok := claims["exp"].(float64) 355 - if !ok { 356 - s.logger.Error("error getting iat from token") 357 - return helpers.ServerError(e, nil) 358 - } 359 - 360 - if exp < float64(time.Now().UTC().Unix()) { 361 - return helpers.InputError(e, to.StringPtr("ExpiredToken")) 362 - } 363 - 364 - if repo == nil { 365 - maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 366 - if err != nil { 367 - s.logger.Error("error fetching repo", "error", err) 368 - return helpers.ServerError(e, nil) 369 - } 370 - repo = maybeRepo 371 - did = repo.Repo.Did 372 - } 373 - 374 - e.Set("repo", repo) 375 - e.Set("did", did) 376 - e.Set("token", tokenstr) 377 - 378 - if err := next(e); err != nil { 379 - e.Error(err) 380 - } 381 - 382 - return nil 383 - } 384 - } 385 - 386 - func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 387 - return func(e echo.Context) error { 388 - authheader := e.Request().Header.Get("authorization") 389 - if authheader == "" { 390 - return e.JSON(401, map[string]string{"error": "Unauthorized"}) 391 - } 392 - 393 - pts := strings.Split(authheader, " ") 394 - if len(pts) != 2 { 395 - return helpers.ServerError(e, nil) 396 - } 397 - 398 - if pts[0] != "DPoP" { 399 - return next(e) 400 - } 401 - 402 - accessToken := pts[1] 403 - 404 - nonce := s.oauthProvider.NextNonce() 405 - if nonce != "" { 406 - e.Response().Header().Set("DPoP-Nonce", nonce) 407 - e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 408 - } 409 - 410 - proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 411 - if err != nil { 412 - s.logger.Error("invalid dpop proof", "error", err) 413 - return helpers.InputError(e, to.StringPtr(err.Error())) 414 - } 415 - 416 - var oauthToken provider.OauthToken 417 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 418 - s.logger.Error("error finding access token in db", "error", err) 419 - return helpers.InputError(e, nil) 420 - } 421 - 422 - if oauthToken.Token == "" { 423 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 424 - } 425 - 426 - if *oauthToken.Parameters.DpopJkt != proof.JKT { 427 - s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 428 - return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 429 - } 430 - 431 - if time.Now().After(oauthToken.ExpiresAt) { 432 - return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 433 - } 434 - 435 - repo, err := s.getRepoActorByDid(oauthToken.Sub) 436 - if err != nil { 437 - s.logger.Error("could not find actor in db", "error", err) 438 - return helpers.ServerError(e, nil) 439 - } 440 - 441 - e.Set("repo", repo) 442 - e.Set("did", repo.Repo.Did) 443 - e.Set("token", accessToken) 444 - e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 445 - 446 - return next(e) 447 - } 448 - } 449 - 450 200 func New(args *Args) (*Server, error) { 451 201 if args.Addr == "" { 452 202 return nil, fmt.Errorf("addr must be set") ··· 593 343 plcClient: plcClient, 594 344 privateKey: &pkey, 595 345 config: &config{ 596 - Version: args.Version, 597 - Did: args.Did, 598 - Hostname: args.Hostname, 599 - ContactEmail: args.ContactEmail, 600 - EnforcePeering: false, 601 - Relays: args.Relays, 602 - AdminPassword: args.AdminPassword, 603 - SmtpName: args.SmtpName, 604 - SmtpEmail: args.SmtpEmail, 346 + Version: args.Version, 347 + Did: args.Did, 348 + Hostname: args.Hostname, 349 + ContactEmail: args.ContactEmail, 350 + EnforcePeering: false, 351 + Relays: args.Relays, 352 + AdminPassword: args.AdminPassword, 353 + SmtpName: args.SmtpName, 354 + SmtpEmail: args.SmtpEmail, 355 + DefaultAtprotoProxy: args.DefaultAtprotoProxy, 356 + BlockstoreVariant: args.BlockstoreVariant, 605 357 }, 606 358 evtman: events.NewEventManager(events.NewMemPersister()), 607 359 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), ··· 611 363 612 364 oauthProvider: provider.NewProvider(provider.Args{ 613 365 Hostname: args.Hostname, 614 - ClientManagerArgs: client_manager.Args{ 366 + ClientManagerArgs: client.ManagerArgs{ 615 367 Cli: oauthCli, 616 368 Logger: args.Logger, 617 369 }, 618 - DpopManagerArgs: dpop_manager.Args{ 370 + DpopManagerArgs: dpop.ManagerArgs{ 619 371 NonceSecret: nonceSecret, 620 372 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 621 373 OnNonceSecretCreated: func(newNonce []byte) { ··· 712 464 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 713 465 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 714 466 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 467 + s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 715 468 716 469 // repo 717 470 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) ··· 725 478 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 726 479 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 727 480 728 - // are there any routes that we should be allowing without auth? i dont think so but idk 729 - s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 730 - s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 731 - 732 481 // admin routes 733 482 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 734 483 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 484 + 485 + // are there any routes that we should be allowing without auth? i dont think so but idk 486 + s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 487 + s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 735 488 } 736 489 737 490 func (s *Server) Serve(ctx context.Context) error { ··· 893 646 go s.doBackup() 894 647 } 895 648 } 649 + 650 + func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 651 + if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 652 + return err 653 + } 654 + 655 + return nil 656 + }
+5 -4
server/templates/account.html
··· 24 24 </div> 25 25 {{ else }} {{ range .Tokens }} 26 26 <div class="base-container"> 27 - <h4>{{ .ClientId }}</h4> 28 - <p>Created: {{ .CreatedAt }}</p> 29 - <p>Updated: {{ .UpdatedAt }}</p> 30 - <p>Expires: {{ .ExpiresAt }}</p> 27 + <h4>{{ .ClientName }}</h4> 28 + <p>Session Age: {{ .Age}}</p> 29 + <p>Last Updated: {{ .LastUpdated }} ago</p> 30 + <p>Expires In: {{ .ExpiresIn }}</p> 31 + <p>IP Address: {{ .Ip }}</p> 31 32 <form action="/account/revoke" method="post"> 32 33 <input type="hidden" name="token" value="{{ .Token }}" /> 33 34 <button type="submit" value="">Revoke</button>
+155
sqlite_blockstore/sqlite_blockstore.go
··· 1 + package sqlite_blockstore 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + 7 + "github.com/bluesky-social/indigo/atproto/syntax" 8 + "github.com/haileyok/cocoon/internal/db" 9 + "github.com/haileyok/cocoon/models" 10 + blocks "github.com/ipfs/go-block-format" 11 + "github.com/ipfs/go-cid" 12 + "gorm.io/gorm/clause" 13 + ) 14 + 15 + type SqliteBlockstore struct { 16 + db *db.DB 17 + did string 18 + readonly bool 19 + inserts map[cid.Cid]blocks.Block 20 + } 21 + 22 + func New(did string, db *db.DB) *SqliteBlockstore { 23 + return &SqliteBlockstore{ 24 + did: did, 25 + db: db, 26 + readonly: false, 27 + inserts: map[cid.Cid]blocks.Block{}, 28 + } 29 + } 30 + 31 + func NewReadOnly(did string, db *db.DB) *SqliteBlockstore { 32 + return &SqliteBlockstore{ 33 + did: did, 34 + db: db, 35 + readonly: true, 36 + inserts: map[cid.Cid]blocks.Block{}, 37 + } 38 + } 39 + 40 + func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) { 41 + var block models.Block 42 + 43 + maybeBlock, ok := bs.inserts[cid] 44 + if ok { 45 + return maybeBlock, nil 46 + } 47 + 48 + if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 49 + return nil, err 50 + } 51 + 52 + b, err := blocks.NewBlockWithCid(block.Value, cid) 53 + if err != nil { 54 + return nil, err 55 + } 56 + 57 + return b, nil 58 + } 59 + 60 + func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error { 61 + bs.inserts[block.Cid()] = block 62 + 63 + if bs.readonly { 64 + return nil 65 + } 66 + 67 + b := models.Block{ 68 + Did: bs.did, 69 + Cid: block.Cid().Bytes(), 70 + Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 71 + Value: block.RawData(), 72 + } 73 + 74 + if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{ 75 + Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 76 + UpdateAll: true, 77 + }}).Error; err != nil { 78 + return err 79 + } 80 + 81 + return nil 82 + } 83 + 84 + func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error { 85 + panic("not implemented") 86 + } 87 + 88 + func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) { 89 + panic("not implemented") 90 + } 91 + 92 + func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) { 93 + panic("not implemented") 94 + } 95 + 96 + func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 + tx := bs.db.BeginDangerously() 98 + 99 + for _, block := range blocks { 100 + bs.inserts[block.Cid()] = block 101 + 102 + if bs.readonly { 103 + continue 104 + } 105 + 106 + b := models.Block{ 107 + Did: bs.did, 108 + Cid: block.Cid().Bytes(), 109 + Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this 110 + Value: block.RawData(), 111 + } 112 + 113 + if err := tx.Clauses(clause.OnConflict{ 114 + Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 115 + UpdateAll: true, 116 + }).Create(&b).Error; err != nil { 117 + tx.Rollback() 118 + return err 119 + } 120 + } 121 + 122 + if bs.readonly { 123 + return nil 124 + } 125 + 126 + tx.Commit() 127 + 128 + return nil 129 + } 130 + 131 + func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) { 132 + panic("not implemented") 133 + } 134 + 135 + func (bs *SqliteBlockstore) HashOnRead(enabled bool) { 136 + panic("not implemented") 137 + } 138 + 139 + func (bs *SqliteBlockstore) Execute(ctx context.Context) error { 140 + if !bs.readonly { 141 + return fmt.Errorf("blockstore was not readonly") 142 + } 143 + 144 + bs.readonly = false 145 + for _, b := range bs.inserts { 146 + bs.Put(ctx, b) 147 + } 148 + bs.readonly = true 149 + 150 + return nil 151 + } 152 + 153 + func (bs *SqliteBlockstore) GetLog() map[cid.Cid]blocks.Block { 154 + return bs.inserts 155 + }