+63
-59
README.md
+63
-59
README.md
···
5
5
6
6
Cocoon is a PDS implementation in Go. It is highly experimental, and is not ready for any production use.
7
7
8
-
### Impmlemented Endpoints
8
+
## Implemented Endpoints
9
9
10
10
> [!NOTE]
11
-
Just because something is implemented doesn't mean it is finisehd. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that.
11
+
Just because something is implemented doesn't mean it is finished. Tons of these are returning bad errors, don't do validation properly, etc. I'll make a "second pass" checklist at some point to do all of that.
12
12
13
-
#### Identity
14
-
- [ ] com.atproto.identity.getRecommendedDidCredentials
15
-
- [ ] com.atproto.identity.requestPlcOperationSignature
16
-
- [x] com.atproto.identity.resolveHandle
17
-
- [ ] com.atproto.identity.signPlcOperation
18
-
- [ ] com.atproto.identity.submitPlcOperatioin
19
-
- [x] com.atproto.identity.updateHandle
13
+
### Identity
20
14
21
-
#### Repo
22
-
- [x] com.atproto.repo.applyWrites
23
-
- [x] com.atproto.repo.createRecord
24
-
- [x] com.atproto.repo.putRecord
25
-
- [x] com.atproto.repo.deleteRecord
26
-
- [x] com.atproto.repo.describeRepo
27
-
- [x] com.atproto.repo.getRecord
28
-
- [ ] com.atproto.repo.importRepo
29
-
- [x] com.atproto.repo.listRecords
30
-
- [ ] com.atproto.repo.listMissingBlobs
15
+
- [ ] `com.atproto.identity.getRecommendedDidCredentials`
16
+
- [ ] `com.atproto.identity.requestPlcOperationSignature`
17
+
- [x] `com.atproto.identity.resolveHandle`
18
+
- [ ] `com.atproto.identity.signPlcOperation`
19
+
- [ ] `com.atproto.identity.submitPlcOperation`
20
+
- [x] `com.atproto.identity.updateHandle`
31
21
32
-
#### Server
33
-
- [ ] com.atproto.server.activateAccount
34
-
- [ ] com.atproto.server.checkAccountStatus
35
-
- [x] com.atproto.server.confirmEmail
36
-
- [x] com.atproto.server.createAccount
37
-
- [x] com.atproto.server.createInviteCode
38
-
- [x] com.atproto.server.createInviteCodes
39
-
- [ ] com.atproto.server.deactivateAccount
40
-
- [ ] com.atproto.server.deleteAccount
41
-
- [x] com.atproto.server.deleteSession
42
-
- [x] com.atproto.server.describeServer
43
-
- [ ] com.atproto.server.getAccountInviteCodes
44
-
- [ ] com.atproto.server.getServiceAuth
45
-
- ~[ ] com.atproto.server.listAppPasswords~ - not going to add app passwords
46
-
- [x] com.atproto.server.refreshSession
47
-
- [ ] com.atproto.server.requestAccountDelete
48
-
- [x] com.atproto.server.requestEmailConfirmation
49
-
- [x] com.atproto.server.requestEmailUpdate
50
-
- [x] com.atproto.server.requestPasswordReset
51
-
- [ ] com.atproto.server.reserveSigningKey
52
-
- [x] com.atproto.server.resetPassword
53
-
- ~[ ] com.atproto.server.revokeAppPassword~ - not going to add app passwords
54
-
- [x] com.atproto.server.updateEmail
22
+
### Repo
55
23
56
-
#### Sync
57
-
- [x] com.atproto.sync.getBlob
58
-
- [x] com.atproto.sync.getBlocks
59
-
- [x] com.atproto.sync.getLatestCommit
60
-
- [x] com.atproto.sync.getRecord
61
-
- [x] com.atproto.sync.getRepoStatus
62
-
- [x] com.atproto.sync.getRepo
63
-
- [x] com.atproto.sync.listBlobs
64
-
- [x] com.atproto.sync.listRepos
65
-
- ~[ ] com.atproto.sync.notifyOfUpdate~ - BGS doesn't even have this implemented lol
66
-
- [x] com.atproto.sync.requestCrawl
67
-
- [x] com.atproto.sync.subscribeRepos
24
+
- [x] `com.atproto.repo.applyWrites`
25
+
- [x] `com.atproto.repo.createRecord`
26
+
- [x] `com.atproto.repo.putRecord`
27
+
- [x] `com.atproto.repo.deleteRecord`
28
+
- [x] `com.atproto.repo.describeRepo`
29
+
- [x] `com.atproto.repo.getRecord`
30
+
- [x] `com.atproto.repo.importRepo` (Works "okay". You still have to handle PLC operations on your own when migrating. Use with extreme caution.)
31
+
- [x] `com.atproto.repo.listRecords`
32
+
- [ ] `com.atproto.repo.listMissingBlobs`
33
+
34
+
### Server
35
+
36
+
- [ ] `com.atproto.server.activateAccount`
37
+
- [x] `com.atproto.server.checkAccountStatus`
38
+
- [x] `com.atproto.server.confirmEmail`
39
+
- [x] `com.atproto.server.createAccount`
40
+
- [x] `com.atproto.server.createInviteCode`
41
+
- [x] `com.atproto.server.createInviteCodes`
42
+
- [ ] `com.atproto.server.deactivateAccount`
43
+
- [ ] `com.atproto.server.deleteAccount`
44
+
- [x] `com.atproto.server.deleteSession`
45
+
- [x] `com.atproto.server.describeServer`
46
+
- [ ] `com.atproto.server.getAccountInviteCodes`
47
+
- [ ] `com.atproto.server.getServiceAuth`
48
+
- ~~[ ] `com.atproto.server.listAppPasswords`~~ - not going to add app passwords
49
+
- [x] `com.atproto.server.refreshSession`
50
+
- [ ] `com.atproto.server.requestAccountDelete`
51
+
- [x] `com.atproto.server.requestEmailConfirmation`
52
+
- [x] `com.atproto.server.requestEmailUpdate`
53
+
- [x] `com.atproto.server.requestPasswordReset`
54
+
- [ ] `com.atproto.server.reserveSigningKey`
55
+
- [x] `com.atproto.server.resetPassword`
56
+
- ~~[] `com.atproto.server.revokeAppPassword`~~ - not going to add app passwords
57
+
- [x] `com.atproto.server.updateEmail`
68
58
69
-
#### Other
70
-
- [ ] com.atproto.label.queryLabels
71
-
- [ ] com.atproto.moderation.createReport
72
-
- [x] app.bsky.actor.getPreferences
73
-
- [x] app.bsky.actor.putPreferences
59
+
### Sync
74
60
61
+
- [x] `com.atproto.sync.getBlob`
62
+
- [x] `com.atproto.sync.getBlocks`
63
+
- [x] `com.atproto.sync.getLatestCommit`
64
+
- [x] `com.atproto.sync.getRecord`
65
+
- [x] `com.atproto.sync.getRepoStatus`
66
+
- [x] `com.atproto.sync.getRepo`
67
+
- [x] `com.atproto.sync.listBlobs`
68
+
- [x] `com.atproto.sync.listRepos`
69
+
- ~~[ ] `com.atproto.sync.notifyOfUpdate`~~ - BGS doesn't even have this implemented lol
70
+
- [x] `com.atproto.sync.requestCrawl`
71
+
- [x] `com.atproto.sync.subscribeRepos`
72
+
73
+
### Other
74
+
75
+
- [ ] `com.atproto.label.queryLabels`
76
+
- [x] `com.atproto.moderation.createReport` (Note: this should be handled by proxying, not actually implemented in the PDS)
77
+
- [x] `app.bsky.actor.getPreferences`
78
+
- [x] `app.bsky.actor.putPreferences`
75
79
76
80
## License
77
81
-163
blockstore/blockstore.go
-163
blockstore/blockstore.go
···
1
-
package blockstore
2
-
3
-
import (
4
-
"context"
5
-
"fmt"
6
-
7
-
"github.com/bluesky-social/indigo/atproto/syntax"
8
-
"github.com/haileyok/cocoon/internal/db"
9
-
"github.com/haileyok/cocoon/models"
10
-
blocks "github.com/ipfs/go-block-format"
11
-
"github.com/ipfs/go-cid"
12
-
"gorm.io/gorm/clause"
13
-
)
14
-
15
-
type SqliteBlockstore struct {
16
-
db *db.DB
17
-
did string
18
-
readonly bool
19
-
inserts map[cid.Cid]blocks.Block
20
-
}
21
-
22
-
func New(did string, db *db.DB) *SqliteBlockstore {
23
-
return &SqliteBlockstore{
24
-
did: did,
25
-
db: db,
26
-
readonly: false,
27
-
inserts: map[cid.Cid]blocks.Block{},
28
-
}
29
-
}
30
-
31
-
func NewReadOnly(did string, db *db.DB) *SqliteBlockstore {
32
-
return &SqliteBlockstore{
33
-
did: did,
34
-
db: db,
35
-
readonly: true,
36
-
inserts: map[cid.Cid]blocks.Block{},
37
-
}
38
-
}
39
-
40
-
func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) {
41
-
var block models.Block
42
-
43
-
maybeBlock, ok := bs.inserts[cid]
44
-
if ok {
45
-
return maybeBlock, nil
46
-
}
47
-
48
-
if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil {
49
-
return nil, err
50
-
}
51
-
52
-
b, err := blocks.NewBlockWithCid(block.Value, cid)
53
-
if err != nil {
54
-
return nil, err
55
-
}
56
-
57
-
return b, nil
58
-
}
59
-
60
-
func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error {
61
-
bs.inserts[block.Cid()] = block
62
-
63
-
if bs.readonly {
64
-
return nil
65
-
}
66
-
67
-
b := models.Block{
68
-
Did: bs.did,
69
-
Cid: block.Cid().Bytes(),
70
-
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
71
-
Value: block.RawData(),
72
-
}
73
-
74
-
if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{
75
-
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
76
-
UpdateAll: true,
77
-
}}).Error; err != nil {
78
-
return err
79
-
}
80
-
81
-
return nil
82
-
}
83
-
84
-
func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error {
85
-
panic("not implemented")
86
-
}
87
-
88
-
func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) {
89
-
panic("not implemented")
90
-
}
91
-
92
-
func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) {
93
-
panic("not implemented")
94
-
}
95
-
96
-
func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error {
97
-
tx := bs.db.BeginDangerously()
98
-
99
-
for _, block := range blocks {
100
-
bs.inserts[block.Cid()] = block
101
-
102
-
if bs.readonly {
103
-
continue
104
-
}
105
-
106
-
b := models.Block{
107
-
Did: bs.did,
108
-
Cid: block.Cid().Bytes(),
109
-
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
110
-
Value: block.RawData(),
111
-
}
112
-
113
-
if err := tx.Clauses(clause.OnConflict{
114
-
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
115
-
UpdateAll: true,
116
-
}).Create(&b).Error; err != nil {
117
-
tx.Rollback()
118
-
return err
119
-
}
120
-
}
121
-
122
-
if bs.readonly {
123
-
return nil
124
-
}
125
-
126
-
tx.Commit()
127
-
128
-
return nil
129
-
}
130
-
131
-
func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) {
132
-
panic("not implemented")
133
-
}
134
-
135
-
func (bs *SqliteBlockstore) HashOnRead(enabled bool) {
136
-
panic("not implemented")
137
-
}
138
-
139
-
func (bs *SqliteBlockstore) UpdateRepo(ctx context.Context, root cid.Cid, rev string) error {
140
-
if err := bs.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, bs.did).Error; err != nil {
141
-
return err
142
-
}
143
-
144
-
return nil
145
-
}
146
-
147
-
func (bs *SqliteBlockstore) Execute(ctx context.Context) error {
148
-
if !bs.readonly {
149
-
return fmt.Errorf("blockstore was not readonly")
150
-
}
151
-
152
-
bs.readonly = false
153
-
for _, b := range bs.inserts {
154
-
bs.Put(ctx, b)
155
-
}
156
-
bs.readonly = true
157
-
158
-
return nil
159
-
}
160
-
161
-
func (bs *SqliteBlockstore) GetLog() map[cid.Cid]blocks.Block {
162
-
return bs.inserts
163
-
}
-186
cmd/admin/main.go
-186
cmd/admin/main.go
···
1
-
package main
2
-
3
-
import (
4
-
"crypto/ecdsa"
5
-
"crypto/elliptic"
6
-
"crypto/rand"
7
-
"encoding/json"
8
-
"fmt"
9
-
"os"
10
-
"time"
11
-
12
-
"github.com/bluesky-social/indigo/atproto/crypto"
13
-
"github.com/bluesky-social/indigo/atproto/syntax"
14
-
"github.com/haileyok/cocoon/internal/helpers"
15
-
"github.com/lestrrat-go/jwx/v2/jwk"
16
-
"github.com/urfave/cli/v2"
17
-
"golang.org/x/crypto/bcrypt"
18
-
"gorm.io/driver/sqlite"
19
-
"gorm.io/gorm"
20
-
)
21
-
22
-
func main() {
23
-
app := cli.App{
24
-
Name: "admin",
25
-
Commands: cli.Commands{
26
-
runCreateRotationKey,
27
-
runCreatePrivateJwk,
28
-
runCreateInviteCode,
29
-
runResetPassword,
30
-
},
31
-
ErrWriter: os.Stdout,
32
-
}
33
-
34
-
app.Run(os.Args)
35
-
}
36
-
37
-
var runCreateRotationKey = &cli.Command{
38
-
Name: "create-rotation-key",
39
-
Usage: "creates a rotation key for your pds",
40
-
Flags: []cli.Flag{
41
-
&cli.StringFlag{
42
-
Name: "out",
43
-
Required: true,
44
-
Usage: "output file for your rotation key",
45
-
},
46
-
},
47
-
Action: func(cmd *cli.Context) error {
48
-
key, err := crypto.GeneratePrivateKeyK256()
49
-
if err != nil {
50
-
return err
51
-
}
52
-
53
-
bytes := key.Bytes()
54
-
55
-
if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil {
56
-
return err
57
-
}
58
-
59
-
return nil
60
-
},
61
-
}
62
-
63
-
var runCreatePrivateJwk = &cli.Command{
64
-
Name: "create-private-jwk",
65
-
Usage: "creates a private jwk for your pds",
66
-
Flags: []cli.Flag{
67
-
&cli.StringFlag{
68
-
Name: "out",
69
-
Required: true,
70
-
Usage: "output file for your jwk",
71
-
},
72
-
},
73
-
Action: func(cmd *cli.Context) error {
74
-
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
75
-
if err != nil {
76
-
return err
77
-
}
78
-
79
-
key, err := jwk.FromRaw(privKey)
80
-
if err != nil {
81
-
return err
82
-
}
83
-
84
-
kid := fmt.Sprintf("%d", time.Now().Unix())
85
-
86
-
if err := key.Set(jwk.KeyIDKey, kid); err != nil {
87
-
return err
88
-
}
89
-
90
-
b, err := json.Marshal(key)
91
-
if err != nil {
92
-
return err
93
-
}
94
-
95
-
if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil {
96
-
return err
97
-
}
98
-
99
-
return nil
100
-
},
101
-
}
102
-
103
-
var runCreateInviteCode = &cli.Command{
104
-
Name: "create-invite-code",
105
-
Usage: "creates an invite code",
106
-
Flags: []cli.Flag{
107
-
&cli.StringFlag{
108
-
Name: "for",
109
-
Usage: "optional did to assign the invite code to",
110
-
},
111
-
&cli.IntFlag{
112
-
Name: "uses",
113
-
Usage: "number of times the invite code can be used",
114
-
Value: 1,
115
-
},
116
-
},
117
-
Action: func(cmd *cli.Context) error {
118
-
db, err := newDb()
119
-
if err != nil {
120
-
return err
121
-
}
122
-
123
-
forDid := "did:plc:123"
124
-
if cmd.String("for") != "" {
125
-
did, err := syntax.ParseDID(cmd.String("for"))
126
-
if err != nil {
127
-
return err
128
-
}
129
-
130
-
forDid = did.String()
131
-
}
132
-
133
-
uses := cmd.Int("uses")
134
-
135
-
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8))
136
-
137
-
if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil {
138
-
return err
139
-
}
140
-
141
-
fmt.Printf("New invite code created with %d uses: %s\n", uses, code)
142
-
143
-
return nil
144
-
},
145
-
}
146
-
147
-
var runResetPassword = &cli.Command{
148
-
Name: "reset-password",
149
-
Usage: "resets a password",
150
-
Flags: []cli.Flag{
151
-
&cli.StringFlag{
152
-
Name: "did",
153
-
Usage: "did of the user who's password you want to reset",
154
-
},
155
-
},
156
-
Action: func(cmd *cli.Context) error {
157
-
db, err := newDb()
158
-
if err != nil {
159
-
return err
160
-
}
161
-
162
-
didStr := cmd.String("did")
163
-
did, err := syntax.ParseDID(didStr)
164
-
if err != nil {
165
-
return err
166
-
}
167
-
168
-
newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12))
169
-
hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10)
170
-
if err != nil {
171
-
return err
172
-
}
173
-
174
-
if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil {
175
-
return err
176
-
}
177
-
178
-
fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass)
179
-
180
-
return nil
181
-
},
182
-
}
183
-
184
-
func newDb() (*gorm.DB, error) {
185
-
return gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
186
-
}
+183
-3
cmd/cocoon/main.go
+183
-3
cmd/cocoon/main.go
···
1
1
package main
2
2
3
3
import (
4
+
"crypto/ecdsa"
5
+
"crypto/elliptic"
6
+
"crypto/rand"
7
+
"encoding/json"
4
8
"fmt"
5
9
"os"
10
+
"time"
6
11
12
+
"github.com/bluesky-social/indigo/atproto/crypto"
13
+
"github.com/bluesky-social/indigo/atproto/syntax"
14
+
"github.com/haileyok/cocoon/internal/helpers"
7
15
"github.com/haileyok/cocoon/server"
8
16
_ "github.com/joho/godotenv/autoload"
17
+
"github.com/lestrrat-go/jwx/v2/jwk"
9
18
"github.com/urfave/cli/v2"
19
+
"golang.org/x/crypto/bcrypt"
20
+
"gorm.io/driver/sqlite"
21
+
"gorm.io/gorm"
10
22
)
11
23
12
24
var Version = "dev"
···
119
131
Name: "session-secret",
120
132
EnvVars: []string{"COCOON_SESSION_SECRET"},
121
133
},
134
+
&cli.StringFlag{
135
+
Name: "default-atproto-proxy",
136
+
EnvVars: []string{"COCOON_DEFAULT_ATPROTO_PROXY"},
137
+
Value: "did:web:api.bsky.app#bsky_appview",
138
+
},
139
+
&cli.StringFlag{
140
+
Name: "blockstore-variant",
141
+
EnvVars: []string{"COCOON_BLOCKSTORE_VARIANT"},
142
+
Value: "sqlite",
143
+
},
122
144
},
123
145
Commands: []*cli.Command{
124
-
run,
146
+
runServe,
147
+
runCreateRotationKey,
148
+
runCreatePrivateJwk,
149
+
runCreateInviteCode,
150
+
runResetPassword,
125
151
},
126
152
ErrWriter: os.Stdout,
127
153
Version: Version,
···
132
158
}
133
159
}
134
160
135
-
var run = &cli.Command{
161
+
var runServe = &cli.Command{
136
162
Name: "run",
137
163
Usage: "Start the cocoon PDS",
138
164
Flags: []cli.Flag{},
139
165
Action: func(cmd *cli.Context) error {
166
+
140
167
s, err := server.New(&server.Args{
141
168
Addr: cmd.String("addr"),
142
169
DbName: cmd.String("db-name"),
···
162
189
AccessKey: cmd.String("s3-access-key"),
163
190
SecretKey: cmd.String("s3-secret-key"),
164
191
},
165
-
SessionSecret: cmd.String("session-secret"),
192
+
SessionSecret: cmd.String("session-secret"),
193
+
DefaultAtprotoProxy: cmd.String("default-atproto-proxy"),
194
+
BlockstoreVariant: server.MustReturnBlockstoreVariant(cmd.String("blockstore-variant")),
166
195
})
167
196
if err != nil {
168
197
fmt.Printf("error creating cocoon: %v", err)
···
177
206
return nil
178
207
},
179
208
}
209
+
210
+
var runCreateRotationKey = &cli.Command{
211
+
Name: "create-rotation-key",
212
+
Usage: "creates a rotation key for your pds",
213
+
Flags: []cli.Flag{
214
+
&cli.StringFlag{
215
+
Name: "out",
216
+
Required: true,
217
+
Usage: "output file for your rotation key",
218
+
},
219
+
},
220
+
Action: func(cmd *cli.Context) error {
221
+
key, err := crypto.GeneratePrivateKeyK256()
222
+
if err != nil {
223
+
return err
224
+
}
225
+
226
+
bytes := key.Bytes()
227
+
228
+
if err := os.WriteFile(cmd.String("out"), bytes, 0644); err != nil {
229
+
return err
230
+
}
231
+
232
+
return nil
233
+
},
234
+
}
235
+
236
+
var runCreatePrivateJwk = &cli.Command{
237
+
Name: "create-private-jwk",
238
+
Usage: "creates a private jwk for your pds",
239
+
Flags: []cli.Flag{
240
+
&cli.StringFlag{
241
+
Name: "out",
242
+
Required: true,
243
+
Usage: "output file for your jwk",
244
+
},
245
+
},
246
+
Action: func(cmd *cli.Context) error {
247
+
privKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
248
+
if err != nil {
249
+
return err
250
+
}
251
+
252
+
key, err := jwk.FromRaw(privKey)
253
+
if err != nil {
254
+
return err
255
+
}
256
+
257
+
kid := fmt.Sprintf("%d", time.Now().Unix())
258
+
259
+
if err := key.Set(jwk.KeyIDKey, kid); err != nil {
260
+
return err
261
+
}
262
+
263
+
b, err := json.Marshal(key)
264
+
if err != nil {
265
+
return err
266
+
}
267
+
268
+
if err := os.WriteFile(cmd.String("out"), b, 0644); err != nil {
269
+
return err
270
+
}
271
+
272
+
return nil
273
+
},
274
+
}
275
+
276
+
var runCreateInviteCode = &cli.Command{
277
+
Name: "create-invite-code",
278
+
Usage: "creates an invite code",
279
+
Flags: []cli.Flag{
280
+
&cli.StringFlag{
281
+
Name: "for",
282
+
Usage: "optional did to assign the invite code to",
283
+
},
284
+
&cli.IntFlag{
285
+
Name: "uses",
286
+
Usage: "number of times the invite code can be used",
287
+
Value: 1,
288
+
},
289
+
},
290
+
Action: func(cmd *cli.Context) error {
291
+
db, err := newDb()
292
+
if err != nil {
293
+
return err
294
+
}
295
+
296
+
forDid := "did:plc:123"
297
+
if cmd.String("for") != "" {
298
+
did, err := syntax.ParseDID(cmd.String("for"))
299
+
if err != nil {
300
+
return err
301
+
}
302
+
303
+
forDid = did.String()
304
+
}
305
+
306
+
uses := cmd.Int("uses")
307
+
308
+
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(8), helpers.RandomVarchar(8))
309
+
310
+
if err := db.Exec("INSERT INTO invite_codes (did, code, remaining_use_count) VALUES (?, ?, ?)", forDid, code, uses).Error; err != nil {
311
+
return err
312
+
}
313
+
314
+
fmt.Printf("New invite code created with %d uses: %s\n", uses, code)
315
+
316
+
return nil
317
+
},
318
+
}
319
+
320
+
var runResetPassword = &cli.Command{
321
+
Name: "reset-password",
322
+
Usage: "resets a password",
323
+
Flags: []cli.Flag{
324
+
&cli.StringFlag{
325
+
Name: "did",
326
+
Usage: "did of the user who's password you want to reset",
327
+
},
328
+
},
329
+
Action: func(cmd *cli.Context) error {
330
+
db, err := newDb()
331
+
if err != nil {
332
+
return err
333
+
}
334
+
335
+
didStr := cmd.String("did")
336
+
did, err := syntax.ParseDID(didStr)
337
+
if err != nil {
338
+
return err
339
+
}
340
+
341
+
newPass := fmt.Sprintf("%s-%s", helpers.RandomVarchar(12), helpers.RandomVarchar(12))
342
+
hashed, err := bcrypt.GenerateFromPassword([]byte(newPass), 10)
343
+
if err != nil {
344
+
return err
345
+
}
346
+
347
+
if err := db.Exec("UPDATE repos SET password = ? WHERE did = ?", hashed, did.String()).Error; err != nil {
348
+
return err
349
+
}
350
+
351
+
fmt.Printf("Password for %s has been reset to: %s", did.String(), newPass)
352
+
353
+
return nil
354
+
},
355
+
}
356
+
357
+
func newDb() (*gorm.DB, error) {
358
+
return gorm.Open(sqlite.Open("cocoon.db"), &gorm.Config{})
359
+
}
+45
cspell.json
+45
cspell.json
···
1
+
{
2
+
"version": "0.2",
3
+
"language": "en",
4
+
"words": [
5
+
"atproto",
6
+
"bsky",
7
+
"Cocoon",
8
+
"PDS",
9
+
"Plc",
10
+
"plc",
11
+
"repo",
12
+
"InviteCodes",
13
+
"InviteCode",
14
+
"Invite",
15
+
"Signin",
16
+
"Signout",
17
+
"JWKS",
18
+
"dpop",
19
+
"BGS",
20
+
"pico",
21
+
"picocss",
22
+
"par",
23
+
"blobs",
24
+
"blob",
25
+
"did",
26
+
"DID",
27
+
"OAuth",
28
+
"oauth",
29
+
"par",
30
+
"Cocoon",
31
+
"memcache",
32
+
"db",
33
+
"helpers",
34
+
"middleware",
35
+
"repo",
36
+
"static",
37
+
"pico",
38
+
"picocss",
39
+
"MIT",
40
+
"Go"
41
+
],
42
+
"ignorePaths": [
43
+
"server/static/pico.css"
44
+
]
45
+
}
+1
go.mod
+1
go.mod
···
14
14
github.com/google/uuid v1.4.0
15
15
github.com/gorilla/sessions v1.4.0
16
16
github.com/gorilla/websocket v1.5.1
17
+
github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b
17
18
github.com/hashicorp/golang-lru/v2 v2.0.7
18
19
github.com/ipfs/go-block-format v0.2.0
19
20
github.com/ipfs/go-cid v0.4.1
+2
go.sum
+2
go.sum
···
91
91
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
92
92
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed h1:5upAirOpQc1Q53c0bnx2ufif5kANL7bfZWcc6VJWJd8=
93
93
github.com/hailocab/go-hostpool v0.0.0-20160125115350-e80d13ce29ed/go.mod h1:tMWxXQ9wFIaZeTI9F+hmhFiGpFmhOHzyShyFUhRm0H4=
94
+
github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b h1:wDUNC2eKiL35DbLvsDhiblTUXHxcOPwQSCzi7xpQUN4=
95
+
github.com/hako/durafmt v0.0.0-20210608085754-5c1018a4e16b/go.mod h1:VzxiSdG6j1pi7rwGm/xYI5RbtpBgM8sARDXlvEvxlu0=
94
96
github.com/hashicorp/go-cleanhttp v0.5.2 h1:035FKYIWjmULyFRBKPs8TBQoi0x6d9G4xc9neXJWAZQ=
95
97
github.com/hashicorp/go-cleanhttp v0.5.2/go.mod h1:kO/YDlP8L1346E6Sodw+PrpBSV4/SoxCXGY6BqNFT48=
96
98
github.com/hashicorp/go-hclog v0.9.2 h1:CG6TE5H9/JXsFWJCfoIVpKFIkFe6ysEuHirp4DxCsHI=
+73
-54
identity/identity.go
+73
-54
identity/identity.go
···
13
13
"github.com/bluesky-social/indigo/util"
14
14
)
15
15
16
-
func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) {
17
-
if cli == nil {
18
-
cli = util.RobustHTTPClient()
19
-
}
20
-
21
-
var did string
22
-
23
-
_, err := syntax.ParseHandle(handle)
16
+
func ResolveHandleFromTXT(ctx context.Context, handle string) (string, error) {
17
+
name := fmt.Sprintf("_atproto.%s", handle)
18
+
recs, err := net.LookupTXT(name)
24
19
if err != nil {
25
-
return "", err
20
+
return "", fmt.Errorf("handle could not be resolved via txt: %w", err)
26
21
}
27
22
28
-
recs, err := net.LookupTXT(fmt.Sprintf("_atproto.%s", handle))
29
-
if err == nil {
30
-
for _, rec := range recs {
31
-
if strings.HasPrefix(rec, "did=") {
32
-
did = strings.Split(rec, "did=")[1]
33
-
break
23
+
for _, rec := range recs {
24
+
if strings.HasPrefix(rec, "did=") {
25
+
maybeDid := strings.Split(rec, "did=")[1]
26
+
if _, err := syntax.ParseDID(maybeDid); err == nil {
27
+
return maybeDid, nil
34
28
}
35
29
}
36
-
} else {
37
-
fmt.Printf("erorr getting txt records: %v\n", err)
38
30
}
39
31
40
-
if did == "" {
41
-
req, err := http.NewRequestWithContext(
42
-
ctx,
43
-
"GET",
44
-
fmt.Sprintf("https://%s/.well-known/atproto-did", handle),
45
-
nil,
46
-
)
47
-
if err != nil {
48
-
return "", nil
49
-
}
32
+
return "", fmt.Errorf("handle could not be resolved via txt: no record found")
33
+
}
50
34
51
-
resp, err := http.DefaultClient.Do(req)
52
-
if err != nil {
53
-
return "", nil
54
-
}
55
-
defer resp.Body.Close()
35
+
func ResolveHandleFromWellKnown(ctx context.Context, cli *http.Client, handle string) (string, error) {
36
+
ustr := fmt.Sprintf("https://%s/.well=known/atproto-did", handle)
37
+
req, err := http.NewRequestWithContext(
38
+
ctx,
39
+
"GET",
40
+
ustr,
41
+
nil,
42
+
)
43
+
if err != nil {
44
+
return "", fmt.Errorf("handle could not be resolved via web: %w", err)
45
+
}
56
46
57
-
if resp.StatusCode != http.StatusOK {
58
-
io.Copy(io.Discard, resp.Body)
59
-
return "", fmt.Errorf("unable to resolve handle")
60
-
}
47
+
resp, err := cli.Do(req)
48
+
if err != nil {
49
+
return "", fmt.Errorf("handle could not be resolved via web: %w", err)
50
+
}
51
+
defer resp.Body.Close()
61
52
62
-
b, err := io.ReadAll(resp.Body)
63
-
if err != nil {
64
-
return "", err
65
-
}
53
+
b, err := io.ReadAll(resp.Body)
54
+
if err != nil {
55
+
return "", fmt.Errorf("handle could not be resolved via web: %w", err)
56
+
}
66
57
67
-
maybeDid := string(b)
58
+
if resp.StatusCode != http.StatusOK {
59
+
return "", fmt.Errorf("handle could not be resolved via web: invalid status code %d", resp.StatusCode)
60
+
}
68
61
69
-
if _, err := syntax.ParseDID(maybeDid); err != nil {
70
-
return "", fmt.Errorf("unable to resolve handle")
71
-
}
62
+
maybeDid := string(b)
72
63
73
-
did = maybeDid
64
+
if _, err := syntax.ParseDID(maybeDid); err != nil {
65
+
return "", fmt.Errorf("handle could not be resolved via web: invalid did in document")
74
66
}
75
67
76
-
return did, nil
68
+
return maybeDid, nil
77
69
}
78
70
79
-
func FetchDidDoc(ctx context.Context, cli *http.Client, did string) (*DidDoc, error) {
71
+
func ResolveHandle(ctx context.Context, cli *http.Client, handle string) (string, error) {
80
72
if cli == nil {
81
73
cli = util.RobustHTTPClient()
82
74
}
83
75
84
-
var ustr string
76
+
_, err := syntax.ParseHandle(handle)
77
+
if err != nil {
78
+
return "", err
79
+
}
80
+
81
+
if maybeDidFromTxt, err := ResolveHandleFromTXT(ctx, handle); err == nil {
82
+
return maybeDidFromTxt, nil
83
+
}
84
+
85
+
if maybeDidFromWeb, err := ResolveHandleFromWellKnown(ctx, cli, handle); err == nil {
86
+
return maybeDidFromWeb, nil
87
+
}
88
+
89
+
return "", fmt.Errorf("handle could not be resolved")
90
+
}
91
+
92
+
func DidToDocUrl(did string) (string, error) {
85
93
if strings.HasPrefix(did, "did:plc:") {
86
-
ustr = fmt.Sprintf("https://plc.directory/%s", did)
94
+
return fmt.Sprintf("https://plc.directory/%s", did), nil
87
95
} else if strings.HasPrefix(did, "did:web:") {
88
-
ustr = fmt.Sprintf("https://%s/.well-known/did.json", strings.TrimPrefix(did, "did:web:"))
96
+
return fmt.Sprintf("https://%s/.well-known/did.json", strings.TrimPrefix(did, "did:web:")), nil
89
97
} else {
90
-
return nil, fmt.Errorf("did was not a supported did type")
98
+
return "", fmt.Errorf("did was not a supported did type")
99
+
}
100
+
}
101
+
102
+
func FetchDidDoc(ctx context.Context, cli *http.Client, did string) (*DidDoc, error) {
103
+
if cli == nil {
104
+
cli = util.RobustHTTPClient()
105
+
}
106
+
107
+
ustr, err := DidToDocUrl(did)
108
+
if err != nil {
109
+
return nil, err
91
110
}
92
111
93
112
req, err := http.NewRequestWithContext(ctx, "GET", ustr, nil)
···
95
114
return nil, err
96
115
}
97
116
98
-
resp, err := http.DefaultClient.Do(req)
117
+
resp, err := cli.Do(req)
99
118
if err != nil {
100
119
return nil, err
101
120
}
···
103
122
104
123
if resp.StatusCode != 200 {
105
124
io.Copy(io.Discard, resp.Body)
106
-
return nil, fmt.Errorf("could not find identity in plc registry")
125
+
return nil, fmt.Errorf("unable to find did doc at url. did: %s. url: %s", did, ustr)
107
126
}
108
127
109
128
var diddoc DidDoc
···
127
146
return nil, err
128
147
}
129
148
130
-
resp, err := http.DefaultClient.Do(req)
149
+
resp, err := cli.Do(req)
131
150
if err != nil {
132
151
return nil, err
133
152
}
+16
-5
identity/passport.go
+16
-5
identity/passport.go
···
19
19
type Passport struct {
20
20
h *http.Client
21
21
bc BackingCache
22
-
lk sync.Mutex
22
+
mu sync.RWMutex
23
23
}
24
24
25
25
func NewPassport(h *http.Client, bc BackingCache) *Passport {
···
30
30
return &Passport{
31
31
h: h,
32
32
bc: bc,
33
-
lk: sync.Mutex{},
34
33
}
35
34
}
36
35
···
38
37
skipCache, _ := ctx.Value("skip-cache").(bool)
39
38
40
39
if !skipCache {
40
+
p.mu.RLock()
41
41
cached, ok := p.bc.GetDoc(did)
42
+
p.mu.RUnlock()
43
+
42
44
if ok {
43
45
return cached, nil
44
46
}
45
47
}
46
48
47
-
p.lk.Lock() // this is pretty pathetic, and i should rethink this. but for now, fuck it
48
-
defer p.lk.Unlock()
49
-
49
+
// TODO: should coalesce requests here
50
50
doc, err := FetchDidDoc(ctx, p.h, did)
51
51
if err != nil {
52
52
return nil, err
53
53
}
54
54
55
+
p.mu.Lock()
55
56
p.bc.PutDoc(did, doc)
57
+
p.mu.Unlock()
56
58
57
59
return doc, nil
58
60
}
···
61
63
skipCache, _ := ctx.Value("skip-cache").(bool)
62
64
63
65
if !skipCache {
66
+
p.mu.RLock()
64
67
cached, ok := p.bc.GetDid(handle)
68
+
p.mu.RUnlock()
69
+
65
70
if ok {
66
71
return cached, nil
67
72
}
···
72
77
return "", err
73
78
}
74
79
80
+
p.mu.Lock()
75
81
p.bc.PutDid(handle, did)
82
+
p.mu.Unlock()
76
83
77
84
return did, nil
78
85
}
79
86
80
87
func (p *Passport) BustDoc(ctx context.Context, did string) error {
88
+
p.mu.Lock()
89
+
defer p.mu.Unlock()
81
90
return p.bc.BustDoc(did)
82
91
}
83
92
84
93
func (p *Passport) BustDid(ctx context.Context, handle string) error {
94
+
p.mu.Lock()
95
+
defer p.mu.Unlock()
85
96
return p.bc.BustDid(handle)
86
97
}
+13
internal/helpers/helpers.go
+13
internal/helpers/helpers.go
···
7
7
"math/rand"
8
8
"net/url"
9
9
10
+
"github.com/Azure/go-autorest/autorest/to"
10
11
"github.com/labstack/echo/v4"
11
12
"github.com/lestrrat-go/jwx/v2/jwk"
12
13
)
···
29
30
msg += ". " + *suffix
30
31
}
31
32
return genericError(e, 400, msg)
33
+
}
34
+
35
+
func InvalidTokenError(e echo.Context) error {
36
+
return InputError(e, to.StringPtr("InvalidToken"))
37
+
}
38
+
39
+
func ExpiredTokenError(e echo.Context) error {
40
+
// WARN: See https://github.com/bluesky-social/atproto/discussions/3319
41
+
return e.JSON(400, map[string]string{
42
+
"error": "ExpiredToken",
43
+
"message": "*",
44
+
})
32
45
}
33
46
34
47
func genericError(e echo.Context, code int, msg string) error {
+8
oauth/client/client.go
+8
oauth/client/client.go
+389
oauth/client/manager.go
+389
oauth/client/manager.go
···
1
+
package client
2
+
3
+
import (
4
+
"context"
5
+
"encoding/json"
6
+
"errors"
7
+
"fmt"
8
+
"io"
9
+
"log/slog"
10
+
"net/http"
11
+
"net/url"
12
+
"slices"
13
+
"strings"
14
+
"time"
15
+
16
+
cache "github.com/go-pkgz/expirable-cache/v3"
17
+
"github.com/haileyok/cocoon/internal/helpers"
18
+
"github.com/lestrrat-go/jwx/v2/jwk"
19
+
)
20
+
21
+
type Manager struct {
22
+
cli *http.Client
23
+
logger *slog.Logger
24
+
jwksCache cache.Cache[string, jwk.Key]
25
+
metadataCache cache.Cache[string, Metadata]
26
+
}
27
+
28
+
type ManagerArgs struct {
29
+
Cli *http.Client
30
+
Logger *slog.Logger
31
+
}
32
+
33
+
func NewManager(args ManagerArgs) *Manager {
34
+
if args.Logger == nil {
35
+
args.Logger = slog.Default()
36
+
}
37
+
38
+
if args.Cli == nil {
39
+
args.Cli = http.DefaultClient
40
+
}
41
+
42
+
jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
43
+
metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
44
+
45
+
return &Manager{
46
+
cli: args.Cli,
47
+
logger: args.Logger,
48
+
jwksCache: jwksCache,
49
+
metadataCache: metadataCache,
50
+
}
51
+
}
52
+
53
+
func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) {
54
+
metadata, err := cm.getClientMetadata(ctx, clientId)
55
+
if err != nil {
56
+
return nil, err
57
+
}
58
+
59
+
var jwks jwk.Key
60
+
if metadata.JWKS != nil {
61
+
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
62
+
// make sure we use the right one
63
+
k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0])
64
+
if err != nil {
65
+
return nil, err
66
+
}
67
+
jwks = k
68
+
} else if metadata.JWKSURI != nil {
69
+
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
70
+
if err != nil {
71
+
return nil, err
72
+
}
73
+
74
+
jwks = maybeJwks
75
+
}
76
+
77
+
return &Client{
78
+
Metadata: metadata,
79
+
JWKS: jwks,
80
+
}, nil
81
+
}
82
+
83
+
func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) {
84
+
metadataCached, ok := cm.metadataCache.Get(clientId)
85
+
if !ok {
86
+
req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
87
+
if err != nil {
88
+
return nil, err
89
+
}
90
+
91
+
resp, err := cm.cli.Do(req)
92
+
if err != nil {
93
+
return nil, err
94
+
}
95
+
defer resp.Body.Close()
96
+
97
+
if resp.StatusCode != http.StatusOK {
98
+
io.Copy(io.Discard, resp.Body)
99
+
return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode)
100
+
}
101
+
102
+
b, err := io.ReadAll(resp.Body)
103
+
if err != nil {
104
+
return nil, fmt.Errorf("error reading bytes from client response: %w", err)
105
+
}
106
+
107
+
validated, err := validateAndParseMetadata(clientId, b)
108
+
if err != nil {
109
+
return nil, err
110
+
}
111
+
112
+
return validated, nil
113
+
} else {
114
+
return &metadataCached, nil
115
+
}
116
+
}
117
+
118
+
func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) {
119
+
jwks, ok := cm.jwksCache.Get(clientId)
120
+
if !ok {
121
+
req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil)
122
+
if err != nil {
123
+
return nil, err
124
+
}
125
+
126
+
resp, err := cm.cli.Do(req)
127
+
if err != nil {
128
+
return nil, err
129
+
}
130
+
defer resp.Body.Close()
131
+
132
+
if resp.StatusCode != http.StatusOK {
133
+
io.Copy(io.Discard, resp.Body)
134
+
return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode)
135
+
}
136
+
137
+
type Keys struct {
138
+
Keys []map[string]any `json:"keys"`
139
+
}
140
+
141
+
var keys Keys
142
+
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
143
+
return nil, fmt.Errorf("error unmarshaling keys response: %w", err)
144
+
}
145
+
146
+
if len(keys.Keys) == 0 {
147
+
return nil, errors.New("no keys in jwks response")
148
+
}
149
+
150
+
// TODO: this is again bad, we should be figuring out which one we need to use...
151
+
b, err := json.Marshal(keys.Keys[0])
152
+
if err != nil {
153
+
return nil, fmt.Errorf("could not marshal key: %w", err)
154
+
}
155
+
156
+
k, err := helpers.ParseJWKFromBytes(b)
157
+
if err != nil {
158
+
return nil, err
159
+
}
160
+
161
+
jwks = k
162
+
}
163
+
164
+
return jwks, nil
165
+
}
166
+
167
+
func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) {
168
+
var metadataMap map[string]any
169
+
if err := json.Unmarshal(b, &metadataMap); err != nil {
170
+
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
171
+
}
172
+
173
+
_, jwksOk := metadataMap["jwks"].(string)
174
+
_, jwksUriOk := metadataMap["jwks_uri"].(string)
175
+
if jwksOk && jwksUriOk {
176
+
return nil, errors.New("jwks_uri and jwks are mutually exclusive")
177
+
}
178
+
179
+
for _, k := range []string{
180
+
"default_max_age",
181
+
"userinfo_signed_response_alg",
182
+
"id_token_signed_response_alg",
183
+
"userinfo_encryhpted_response_alg",
184
+
"authorization_encrypted_response_enc",
185
+
"authorization_encrypted_response_alg",
186
+
"tls_client_certificate_bound_access_tokens",
187
+
} {
188
+
_, kOk := metadataMap[k]
189
+
if kOk {
190
+
return nil, fmt.Errorf("unsupported `%s` parameter", k)
191
+
}
192
+
}
193
+
194
+
var metadata Metadata
195
+
if err := json.Unmarshal(b, &metadata); err != nil {
196
+
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
197
+
}
198
+
199
+
u, err := url.Parse(metadata.ClientURI)
200
+
if err != nil {
201
+
return nil, fmt.Errorf("unable to parse client uri: %w", err)
202
+
}
203
+
204
+
if isLocalHostname(u.Hostname()) {
205
+
return nil, errors.New("`client_uri` hostname is invalid")
206
+
}
207
+
208
+
if metadata.Scope == "" {
209
+
return nil, errors.New("missing `scopes` scope")
210
+
}
211
+
212
+
scopes := strings.Split(metadata.Scope, " ")
213
+
if !slices.Contains(scopes, "atproto") {
214
+
return nil, errors.New("missing `atproto` scope")
215
+
}
216
+
217
+
scopesMap := map[string]bool{}
218
+
for _, scope := range scopes {
219
+
if scopesMap[scope] {
220
+
return nil, fmt.Errorf("duplicate scope `%s`", scope)
221
+
}
222
+
223
+
// TODO: check for unsupported scopes
224
+
225
+
scopesMap[scope] = true
226
+
}
227
+
228
+
grantTypesMap := map[string]bool{}
229
+
for _, gt := range metadata.GrantTypes {
230
+
if grantTypesMap[gt] {
231
+
return nil, fmt.Errorf("duplicate grant type `%s`", gt)
232
+
}
233
+
234
+
switch gt {
235
+
case "implicit":
236
+
return nil, errors.New("grantg type `implicit` is not allowed")
237
+
case "authorization_code", "refresh_token":
238
+
// TODO check if this grant type is supported
239
+
default:
240
+
return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt)
241
+
}
242
+
243
+
grantTypesMap[gt] = true
244
+
}
245
+
246
+
if metadata.ClientID != clientId {
247
+
return nil, errors.New("`client_id` does not match")
248
+
}
249
+
250
+
subjectType, subjectTypeOk := metadataMap["subject_type"].(string)
251
+
if subjectTypeOk && subjectType != "public" {
252
+
return nil, errors.New("only public `subject_type` is supported")
253
+
}
254
+
255
+
switch metadata.TokenEndpointAuthMethod {
256
+
case "none":
257
+
if metadata.TokenEndpointAuthSigningAlg != "" {
258
+
return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg")
259
+
}
260
+
case "private_key_jwt":
261
+
if metadata.JWKS == nil && metadata.JWKSURI == nil {
262
+
return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri")
263
+
}
264
+
265
+
if metadata.JWKS != nil && len(*metadata.JWKS) == 0 {
266
+
return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks")
267
+
}
268
+
269
+
if metadata.TokenEndpointAuthSigningAlg == "" {
270
+
return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata")
271
+
}
272
+
default:
273
+
return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod)
274
+
}
275
+
276
+
if !metadata.DpopBoundAccessTokens {
277
+
return nil, errors.New("dpop_bound_access_tokens must be true")
278
+
}
279
+
280
+
if !slices.Contains(metadata.ResponseTypes, "code") {
281
+
return nil, errors.New("response_types must inclue `code`")
282
+
}
283
+
284
+
if !slices.Contains(metadata.GrantTypes, "authorization_code") {
285
+
return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`")
286
+
}
287
+
288
+
if len(metadata.RedirectURIs) == 0 {
289
+
return nil, errors.New("at least one `redirect_uri` is required")
290
+
}
291
+
292
+
if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" {
293
+
return nil, errors.New("native clients must authenticate using `none` method")
294
+
}
295
+
296
+
if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") {
297
+
for _, ruri := range metadata.RedirectURIs {
298
+
u, err := url.Parse(ruri)
299
+
if err != nil {
300
+
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
301
+
}
302
+
303
+
if u.Scheme != "https" {
304
+
return nil, errors.New("web clients must use https redirect uris")
305
+
}
306
+
307
+
if u.Hostname() == "localhost" {
308
+
return nil, errors.New("web clients must not use localhost as the hostname")
309
+
}
310
+
}
311
+
}
312
+
313
+
for _, ruri := range metadata.RedirectURIs {
314
+
u, err := url.Parse(ruri)
315
+
if err != nil {
316
+
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
317
+
}
318
+
319
+
if u.User != nil {
320
+
if u.User.Username() != "" {
321
+
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
322
+
}
323
+
324
+
if _, hasPass := u.User.Password(); hasPass {
325
+
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
326
+
}
327
+
}
328
+
329
+
switch true {
330
+
case u.Hostname() == "localhost":
331
+
return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)")
332
+
case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]":
333
+
if metadata.ApplicationType != "native" {
334
+
return nil, errors.New("loopback redirect uris are only allowed for native apps")
335
+
}
336
+
337
+
if u.Port() != "" {
338
+
// reference impl doesn't do anything with this?
339
+
}
340
+
341
+
if u.Scheme != "http" {
342
+
return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
343
+
}
344
+
345
+
break
346
+
case u.Scheme == "http":
347
+
return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
348
+
case u.Scheme == "https":
349
+
if isLocalHostname(u.Hostname()) {
350
+
return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
351
+
}
352
+
break
353
+
case strings.Contains(u.Scheme, "."):
354
+
if metadata.ApplicationType != "native" {
355
+
return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")
356
+
}
357
+
358
+
revdomain := reverseDomain(u.Scheme)
359
+
360
+
if isLocalHostname(revdomain) {
361
+
return nil, errors.New("private use uri scheme redirect uris must not be local hostnames")
362
+
}
363
+
364
+
if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" {
365
+
return nil, fmt.Errorf("private use uri scheme must be in the form ")
366
+
}
367
+
default:
368
+
return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme)
369
+
}
370
+
}
371
+
372
+
return &metadata, nil
373
+
}
374
+
375
+
func isLocalHostname(hostname string) bool {
376
+
pts := strings.Split(hostname, ".")
377
+
if len(pts) < 2 {
378
+
return true
379
+
}
380
+
381
+
tld := strings.ToLower(pts[len(pts)-1])
382
+
return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example"
383
+
}
384
+
385
+
func reverseDomain(domain string) string {
386
+
pts := strings.Split(domain, ".")
387
+
slices.Reverse(pts)
388
+
return strings.Join(pts, ".")
389
+
}
+20
oauth/client/metadata.go
+20
oauth/client/metadata.go
···
1
+
package client
2
+
3
+
type Metadata struct {
4
+
ClientID string `json:"client_id"`
5
+
ClientName string `json:"client_name"`
6
+
ClientURI string `json:"client_uri"`
7
+
LogoURI string `json:"logo_uri"`
8
+
TOSURI string `json:"tos_uri"`
9
+
PolicyURI string `json:"policy_uri"`
10
+
RedirectURIs []string `json:"redirect_uris"`
11
+
GrantTypes []string `json:"grant_types"`
12
+
ResponseTypes []string `json:"response_types"`
13
+
ApplicationType string `json:"application_type"`
14
+
DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"`
15
+
JWKSURI *string `json:"jwks_uri,omitempty"`
16
+
JWKS *[][]byte `json:"jwks,omitempty"`
17
+
Scope string `json:"scope"`
18
+
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
19
+
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
20
+
}
-8
oauth/client.go
-8
oauth/client.go
-390
oauth/client_manager/client_manager.go
-390
oauth/client_manager/client_manager.go
···
1
-
package client_manager
2
-
3
-
import (
4
-
"context"
5
-
"encoding/json"
6
-
"errors"
7
-
"fmt"
8
-
"io"
9
-
"log/slog"
10
-
"net/http"
11
-
"net/url"
12
-
"slices"
13
-
"strings"
14
-
"time"
15
-
16
-
cache "github.com/go-pkgz/expirable-cache/v3"
17
-
"github.com/haileyok/cocoon/internal/helpers"
18
-
"github.com/haileyok/cocoon/oauth"
19
-
"github.com/lestrrat-go/jwx/v2/jwk"
20
-
)
21
-
22
-
type ClientManager struct {
23
-
cli *http.Client
24
-
logger *slog.Logger
25
-
jwksCache cache.Cache[string, jwk.Key]
26
-
metadataCache cache.Cache[string, oauth.ClientMetadata]
27
-
}
28
-
29
-
type Args struct {
30
-
Cli *http.Client
31
-
Logger *slog.Logger
32
-
}
33
-
34
-
func New(args Args) *ClientManager {
35
-
if args.Logger == nil {
36
-
args.Logger = slog.Default()
37
-
}
38
-
39
-
if args.Cli == nil {
40
-
args.Cli = http.DefaultClient
41
-
}
42
-
43
-
jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
44
-
metadataCache := cache.NewCache[string, oauth.ClientMetadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
45
-
46
-
return &ClientManager{
47
-
cli: args.Cli,
48
-
logger: args.Logger,
49
-
jwksCache: jwksCache,
50
-
metadataCache: metadataCache,
51
-
}
52
-
}
53
-
54
-
func (cm *ClientManager) GetClient(ctx context.Context, clientId string) (*oauth.Client, error) {
55
-
metadata, err := cm.getClientMetadata(ctx, clientId)
56
-
if err != nil {
57
-
return nil, err
58
-
}
59
-
60
-
var jwks jwk.Key
61
-
if metadata.JWKS != nil {
62
-
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
63
-
// make sure we use the right one
64
-
k, err := helpers.ParseJWKFromBytes((*metadata.JWKS)[0])
65
-
if err != nil {
66
-
return nil, err
67
-
}
68
-
jwks = k
69
-
} else if metadata.JWKSURI != nil {
70
-
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
71
-
if err != nil {
72
-
return nil, err
73
-
}
74
-
75
-
jwks = maybeJwks
76
-
}
77
-
78
-
return &oauth.Client{
79
-
Metadata: metadata,
80
-
JWKS: jwks,
81
-
}, nil
82
-
}
83
-
84
-
func (cm *ClientManager) getClientMetadata(ctx context.Context, clientId string) (*oauth.ClientMetadata, error) {
85
-
metadataCached, ok := cm.metadataCache.Get(clientId)
86
-
if !ok {
87
-
req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
88
-
if err != nil {
89
-
return nil, err
90
-
}
91
-
92
-
resp, err := cm.cli.Do(req)
93
-
if err != nil {
94
-
return nil, err
95
-
}
96
-
defer resp.Body.Close()
97
-
98
-
if resp.StatusCode != http.StatusOK {
99
-
io.Copy(io.Discard, resp.Body)
100
-
return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode)
101
-
}
102
-
103
-
b, err := io.ReadAll(resp.Body)
104
-
if err != nil {
105
-
return nil, fmt.Errorf("error reading bytes from client response: %w", err)
106
-
}
107
-
108
-
validated, err := validateAndParseMetadata(clientId, b)
109
-
if err != nil {
110
-
return nil, err
111
-
}
112
-
113
-
return validated, nil
114
-
} else {
115
-
return &metadataCached, nil
116
-
}
117
-
}
118
-
119
-
func (cm *ClientManager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) {
120
-
jwks, ok := cm.jwksCache.Get(clientId)
121
-
if !ok {
122
-
req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil)
123
-
if err != nil {
124
-
return nil, err
125
-
}
126
-
127
-
resp, err := cm.cli.Do(req)
128
-
if err != nil {
129
-
return nil, err
130
-
}
131
-
defer resp.Body.Close()
132
-
133
-
if resp.StatusCode != http.StatusOK {
134
-
io.Copy(io.Discard, resp.Body)
135
-
return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode)
136
-
}
137
-
138
-
type Keys struct {
139
-
Keys []map[string]any `json:"keys"`
140
-
}
141
-
142
-
var keys Keys
143
-
if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil {
144
-
return nil, fmt.Errorf("error unmarshaling keys response: %w", err)
145
-
}
146
-
147
-
if len(keys.Keys) == 0 {
148
-
return nil, errors.New("no keys in jwks response")
149
-
}
150
-
151
-
// TODO: this is again bad, we should be figuring out which one we need to use...
152
-
b, err := json.Marshal(keys.Keys[0])
153
-
if err != nil {
154
-
return nil, fmt.Errorf("could not marshal key: %w", err)
155
-
}
156
-
157
-
k, err := helpers.ParseJWKFromBytes(b)
158
-
if err != nil {
159
-
return nil, err
160
-
}
161
-
162
-
jwks = k
163
-
}
164
-
165
-
return jwks, nil
166
-
}
167
-
168
-
func validateAndParseMetadata(clientId string, b []byte) (*oauth.ClientMetadata, error) {
169
-
var metadataMap map[string]any
170
-
if err := json.Unmarshal(b, &metadataMap); err != nil {
171
-
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
172
-
}
173
-
174
-
_, jwksOk := metadataMap["jwks"].(string)
175
-
_, jwksUriOk := metadataMap["jwks_uri"].(string)
176
-
if jwksOk && jwksUriOk {
177
-
return nil, errors.New("jwks_uri and jwks are mutually exclusive")
178
-
}
179
-
180
-
for _, k := range []string{
181
-
"default_max_age",
182
-
"userinfo_signed_response_alg",
183
-
"id_token_signed_response_alg",
184
-
"userinfo_encryhpted_response_alg",
185
-
"authorization_encrypted_response_enc",
186
-
"authorization_encrypted_response_alg",
187
-
"tls_client_certificate_bound_access_tokens",
188
-
} {
189
-
_, kOk := metadataMap[k]
190
-
if kOk {
191
-
return nil, fmt.Errorf("unsupported `%s` parameter", k)
192
-
}
193
-
}
194
-
195
-
var metadata oauth.ClientMetadata
196
-
if err := json.Unmarshal(b, &metadata); err != nil {
197
-
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
198
-
}
199
-
200
-
u, err := url.Parse(metadata.ClientURI)
201
-
if err != nil {
202
-
return nil, fmt.Errorf("unable to parse client uri: %w", err)
203
-
}
204
-
205
-
if isLocalHostname(u.Hostname()) {
206
-
return nil, errors.New("`client_uri` hostname is invalid")
207
-
}
208
-
209
-
if metadata.Scope == "" {
210
-
return nil, errors.New("missing `scopes` scope")
211
-
}
212
-
213
-
scopes := strings.Split(metadata.Scope, " ")
214
-
if !slices.Contains(scopes, "atproto") {
215
-
return nil, errors.New("missing `atproto` scope")
216
-
}
217
-
218
-
scopesMap := map[string]bool{}
219
-
for _, scope := range scopes {
220
-
if scopesMap[scope] {
221
-
return nil, fmt.Errorf("duplicate scope `%s`", scope)
222
-
}
223
-
224
-
// TODO: check for unsupported scopes
225
-
226
-
scopesMap[scope] = true
227
-
}
228
-
229
-
grantTypesMap := map[string]bool{}
230
-
for _, gt := range metadata.GrantTypes {
231
-
if grantTypesMap[gt] {
232
-
return nil, fmt.Errorf("duplicate grant type `%s`", gt)
233
-
}
234
-
235
-
switch gt {
236
-
case "implicit":
237
-
return nil, errors.New("grantg type `implicit` is not allowed")
238
-
case "authorization_code", "refresh_token":
239
-
// TODO check if this grant type is supported
240
-
default:
241
-
return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt)
242
-
}
243
-
244
-
grantTypesMap[gt] = true
245
-
}
246
-
247
-
if metadata.ClientID != clientId {
248
-
return nil, errors.New("`client_id` does not match")
249
-
}
250
-
251
-
subjectType, subjectTypeOk := metadataMap["subject_type"].(string)
252
-
if subjectTypeOk && subjectType != "public" {
253
-
return nil, errors.New("only public `subject_type` is supported")
254
-
}
255
-
256
-
switch metadata.TokenEndpointAuthMethod {
257
-
case "none":
258
-
if metadata.TokenEndpointAuthSigningAlg != "" {
259
-
return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg")
260
-
}
261
-
case "private_key_jwt":
262
-
if metadata.JWKS == nil && metadata.JWKSURI == nil {
263
-
return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri")
264
-
}
265
-
266
-
if metadata.JWKS != nil && len(*metadata.JWKS) == 0 {
267
-
return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks")
268
-
}
269
-
270
-
if metadata.TokenEndpointAuthSigningAlg == "" {
271
-
return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata")
272
-
}
273
-
default:
274
-
return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod)
275
-
}
276
-
277
-
if !metadata.DpopBoundAccessTokens {
278
-
return nil, errors.New("dpop_bound_access_tokens must be true")
279
-
}
280
-
281
-
if !slices.Contains(metadata.ResponseTypes, "code") {
282
-
return nil, errors.New("response_types must inclue `code`")
283
-
}
284
-
285
-
if !slices.Contains(metadata.GrantTypes, "authorization_code") {
286
-
return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`")
287
-
}
288
-
289
-
if len(metadata.RedirectURIs) == 0 {
290
-
return nil, errors.New("at least one `redirect_uri` is required")
291
-
}
292
-
293
-
if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod == "none" {
294
-
return nil, errors.New("native clients must authenticate using `none` method")
295
-
}
296
-
297
-
if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") {
298
-
for _, ruri := range metadata.RedirectURIs {
299
-
u, err := url.Parse(ruri)
300
-
if err != nil {
301
-
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
302
-
}
303
-
304
-
if u.Scheme != "https" {
305
-
return nil, errors.New("web clients must use https redirect uris")
306
-
}
307
-
308
-
if u.Hostname() == "localhost" {
309
-
return nil, errors.New("web clients must not use localhost as the hostname")
310
-
}
311
-
}
312
-
}
313
-
314
-
for _, ruri := range metadata.RedirectURIs {
315
-
u, err := url.Parse(ruri)
316
-
if err != nil {
317
-
return nil, fmt.Errorf("error parsing redirect uri: %w", err)
318
-
}
319
-
320
-
if u.User != nil {
321
-
if u.User.Username() != "" {
322
-
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
323
-
}
324
-
325
-
if _, hasPass := u.User.Password(); hasPass {
326
-
return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri)
327
-
}
328
-
}
329
-
330
-
switch true {
331
-
case u.Hostname() == "localhost":
332
-
return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)")
333
-
case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]":
334
-
if metadata.ApplicationType != "native" {
335
-
return nil, errors.New("loopback redirect uris are only allowed for native apps")
336
-
}
337
-
338
-
if u.Port() != "" {
339
-
// reference impl doesn't do anything with this?
340
-
}
341
-
342
-
if u.Scheme != "http" {
343
-
return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
344
-
}
345
-
346
-
break
347
-
case u.Scheme == "http":
348
-
return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
349
-
case u.Scheme == "https":
350
-
if isLocalHostname(u.Hostname()) {
351
-
return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
352
-
}
353
-
break
354
-
case strings.Contains(u.Scheme, "."):
355
-
if metadata.ApplicationType != "native" {
356
-
return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")
357
-
}
358
-
359
-
revdomain := reverseDomain(u.Scheme)
360
-
361
-
if isLocalHostname(revdomain) {
362
-
return nil, errors.New("private use uri scheme redirect uris must not be local hostnames")
363
-
}
364
-
365
-
if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" {
366
-
return nil, fmt.Errorf("private use uri scheme must be in the form ")
367
-
}
368
-
default:
369
-
return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme)
370
-
}
371
-
}
372
-
373
-
return &metadata, nil
374
-
}
375
-
376
-
func isLocalHostname(hostname string) bool {
377
-
pts := strings.Split(hostname, ".")
378
-
if len(pts) < 2 {
379
-
return true
380
-
}
381
-
382
-
tld := strings.ToLower(pts[len(pts)-1])
383
-
return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example"
384
-
}
385
-
386
-
func reverseDomain(domain string) string {
387
-
pts := strings.Split(domain, ".")
388
-
slices.Reverse(pts)
389
-
return strings.Join(pts, ".")
390
-
}
-20
oauth/client_metadata.go
-20
oauth/client_metadata.go
···
1
-
package oauth
2
-
3
-
type ClientMetadata struct {
4
-
ClientID string `json:"client_id"`
5
-
ClientName string `json:"client_name"`
6
-
ClientURI string `json:"client_uri"`
7
-
LogoURI string `json:"logo_uri"`
8
-
TOSURI string `json:"tos_uri"`
9
-
PolicyURI string `json:"policy_uri"`
10
-
RedirectURIs []string `json:"redirect_uris"`
11
-
GrantTypes []string `json:"grant_types"`
12
-
ResponseTypes []string `json:"response_types"`
13
-
ApplicationType string `json:"application_type"`
14
-
DpopBoundAccessTokens bool `json:"dpop_bound_access_tokens"`
15
-
JWKSURI *string `json:"jwks_uri,omitempty"`
16
-
JWKS *[][]byte `json:"jwks,omitempty"`
17
-
Scope string `json:"scope"`
18
-
TokenEndpointAuthMethod string `json:"token_endpoint_auth_method"`
19
-
TokenEndpointAuthSigningAlg string `json:"token_endpoint_auth_signing_alg"`
20
-
}
-251
oauth/dpop/dpop_manager/dpop_manager.go
-251
oauth/dpop/dpop_manager/dpop_manager.go
···
1
-
package dpop_manager
2
-
3
-
import (
4
-
"crypto"
5
-
"crypto/sha256"
6
-
"encoding/base64"
7
-
"encoding/json"
8
-
"errors"
9
-
"fmt"
10
-
"log/slog"
11
-
"net/http"
12
-
"net/url"
13
-
"strings"
14
-
"time"
15
-
16
-
"github.com/golang-jwt/jwt/v4"
17
-
"github.com/haileyok/cocoon/internal/helpers"
18
-
"github.com/haileyok/cocoon/oauth/constants"
19
-
"github.com/haileyok/cocoon/oauth/dpop"
20
-
"github.com/haileyok/cocoon/oauth/dpop/nonce"
21
-
"github.com/lestrrat-go/jwx/v2/jwa"
22
-
"github.com/lestrrat-go/jwx/v2/jwk"
23
-
)
24
-
25
-
type DpopManager struct {
26
-
nonce *nonce.Nonce
27
-
jtiCache *jtiCache
28
-
logger *slog.Logger
29
-
hostname string
30
-
}
31
-
32
-
type Args struct {
33
-
NonceSecret []byte
34
-
NonceRotationInterval time.Duration
35
-
OnNonceSecretCreated func([]byte)
36
-
JTICacheSize int
37
-
Logger *slog.Logger
38
-
Hostname string
39
-
}
40
-
41
-
func New(args Args) *DpopManager {
42
-
if args.Logger == nil {
43
-
args.Logger = slog.Default()
44
-
}
45
-
46
-
if args.JTICacheSize == 0 {
47
-
args.JTICacheSize = 100_000
48
-
}
49
-
50
-
if args.NonceSecret == nil {
51
-
args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.")
52
-
}
53
-
54
-
return &DpopManager{
55
-
nonce: nonce.NewNonce(nonce.Args{
56
-
RotationInterval: args.NonceRotationInterval,
57
-
Secret: args.NonceSecret,
58
-
OnSecretCreated: args.OnNonceSecretCreated,
59
-
}),
60
-
jtiCache: newJTICache(args.JTICacheSize),
61
-
logger: args.Logger,
62
-
hostname: args.Hostname,
63
-
}
64
-
}
65
-
66
-
func (dm *DpopManager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*dpop.Proof, error) {
67
-
if reqMethod == "" {
68
-
return nil, errors.New("HTTP method is required")
69
-
}
70
-
71
-
if !strings.HasPrefix(reqUrl, "https://") {
72
-
reqUrl = "https://" + dm.hostname + reqUrl
73
-
}
74
-
75
-
proof := extractProof(headers)
76
-
77
-
if proof == "" {
78
-
return nil, nil
79
-
}
80
-
81
-
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
82
-
var token *jwt.Token
83
-
84
-
token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{})
85
-
if err != nil {
86
-
return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err)
87
-
}
88
-
89
-
typ, _ := token.Header["typ"].(string)
90
-
if typ != "dpop+jwt" {
91
-
return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`)
92
-
}
93
-
94
-
dpopJwk, jwkOk := token.Header["jwk"].(map[string]any)
95
-
if !jwkOk {
96
-
return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`)
97
-
}
98
-
99
-
jwkb, err := json.Marshal(dpopJwk)
100
-
if err != nil {
101
-
return nil, fmt.Errorf("failed to marshal jwk: %w", err)
102
-
}
103
-
104
-
key, err := jwk.ParseKey(jwkb)
105
-
if err != nil {
106
-
return nil, fmt.Errorf("failed to parse jwk: %w", err)
107
-
}
108
-
109
-
var pubKey any
110
-
if err := key.Raw(&pubKey); err != nil {
111
-
return nil, fmt.Errorf("failed to get raw public key: %w", err)
112
-
}
113
-
114
-
token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) {
115
-
alg := t.Header["alg"].(string)
116
-
117
-
switch key.KeyType() {
118
-
case jwa.EC:
119
-
if !strings.HasPrefix(alg, "ES") {
120
-
return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg)
121
-
}
122
-
case jwa.RSA:
123
-
if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") {
124
-
return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg)
125
-
}
126
-
case jwa.OKP:
127
-
if alg != "EdDSA" {
128
-
return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg)
129
-
}
130
-
}
131
-
132
-
return pubKey, nil
133
-
}, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"}))
134
-
if err != nil {
135
-
return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err)
136
-
}
137
-
138
-
if !token.Valid {
139
-
return nil, errors.New("dpop proof jwt is invalid")
140
-
}
141
-
142
-
claims, ok := token.Claims.(jwt.MapClaims)
143
-
if !ok {
144
-
return nil, errors.New("no claims in dpop proof jwt")
145
-
}
146
-
147
-
iat, iatOk := claims["iat"].(float64)
148
-
if !iatOk {
149
-
return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`)
150
-
}
151
-
152
-
iatTime := time.Unix(int64(iat), 0)
153
-
now := time.Now()
154
-
155
-
if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance {
156
-
return nil, errors.New("dpop proof too old")
157
-
}
158
-
159
-
if iatTime.Sub(now) > constants.DpopCheckTolerance {
160
-
return nil, errors.New("dpop proof iat is in the future")
161
-
}
162
-
163
-
jti, _ := claims["jti"].(string)
164
-
if jti == "" {
165
-
return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`)
166
-
}
167
-
168
-
if dm.jtiCache.add(jti) {
169
-
return nil, errors.New("dpop proof replay detected")
170
-
}
171
-
172
-
htm, _ := claims["htm"].(string)
173
-
if htm == "" {
174
-
return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`)
175
-
}
176
-
177
-
if htm != reqMethod {
178
-
return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`)
179
-
}
180
-
181
-
htu, _ := claims["htu"].(string)
182
-
if htu == "" {
183
-
return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`)
184
-
}
185
-
186
-
parsedHtu, err := helpers.OauthParseHtu(htu)
187
-
if err != nil {
188
-
return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`)
189
-
}
190
-
191
-
u, _ := url.Parse(reqUrl)
192
-
if parsedHtu != helpers.OauthNormalizeHtu(u) {
193
-
return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u))
194
-
}
195
-
196
-
nonce, _ := claims["nonce"].(string)
197
-
if nonce == "" {
198
-
// WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request
199
-
return nil, errors.New("use_dpop_nonce")
200
-
}
201
-
202
-
if nonce != "" && !dm.nonce.Check(nonce) {
203
-
// WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce
204
-
return nil, errors.New("use_dpop_nonce")
205
-
}
206
-
207
-
ath, _ := claims["ath"].(string)
208
-
209
-
if accessToken != nil && *accessToken != "" {
210
-
if ath == "" {
211
-
return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`)
212
-
}
213
-
214
-
hash := sha256.Sum256([]byte(*accessToken))
215
-
if ath != base64.RawURLEncoding.EncodeToString(hash[:]) {
216
-
return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`)
217
-
}
218
-
} else if ath != "" {
219
-
return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`)
220
-
}
221
-
222
-
thumbBytes, err := key.Thumbprint(crypto.SHA256)
223
-
if err != nil {
224
-
return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
225
-
}
226
-
227
-
thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
228
-
229
-
return &dpop.Proof{
230
-
JTI: jti,
231
-
JKT: thumb,
232
-
HTM: htm,
233
-
HTU: htu,
234
-
}, nil
235
-
}
236
-
237
-
func extractProof(headers http.Header) string {
238
-
dpopHeaders := headers["Dpop"]
239
-
switch len(dpopHeaders) {
240
-
case 0:
241
-
return ""
242
-
case 1:
243
-
return dpopHeaders[0]
244
-
default:
245
-
return ""
246
-
}
247
-
}
248
-
249
-
func (dm *DpopManager) NextNonce() string {
250
-
return dm.nonce.NextNonce()
251
-
}
-28
oauth/dpop/dpop_manager/jti_cache.go
-28
oauth/dpop/dpop_manager/jti_cache.go
···
1
-
package dpop_manager
2
-
3
-
import (
4
-
"sync"
5
-
"time"
6
-
7
-
cache "github.com/go-pkgz/expirable-cache/v3"
8
-
"github.com/haileyok/cocoon/oauth/constants"
9
-
)
10
-
11
-
type jtiCache struct {
12
-
mu sync.Mutex
13
-
cache cache.Cache[string, bool]
14
-
}
15
-
16
-
func newJTICache(size int) *jtiCache {
17
-
cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl)
18
-
return &jtiCache{
19
-
cache: cache,
20
-
mu: sync.Mutex{},
21
-
}
22
-
}
23
-
24
-
func (c *jtiCache) add(jti string) bool {
25
-
c.mu.Lock()
26
-
defer c.mu.Unlock()
27
-
return c.cache.Add(jti, true)
28
-
}
+28
oauth/dpop/jti_cache.go
+28
oauth/dpop/jti_cache.go
···
1
+
package dpop
2
+
3
+
import (
4
+
"sync"
5
+
"time"
6
+
7
+
cache "github.com/go-pkgz/expirable-cache/v3"
8
+
"github.com/haileyok/cocoon/oauth/constants"
9
+
)
10
+
11
+
type jtiCache struct {
12
+
mu sync.Mutex
13
+
cache cache.Cache[string, bool]
14
+
}
15
+
16
+
func newJTICache(size int) *jtiCache {
17
+
cache := cache.NewCache[string, bool]().WithTTL(24 * time.Hour).WithLRU().WithTTL(constants.JTITtl)
18
+
return &jtiCache{
19
+
cache: cache,
20
+
mu: sync.Mutex{},
21
+
}
22
+
}
23
+
24
+
func (c *jtiCache) add(jti string) bool {
25
+
c.mu.Lock()
26
+
defer c.mu.Unlock()
27
+
return c.cache.Add(jti, true)
28
+
}
+249
oauth/dpop/manager.go
+249
oauth/dpop/manager.go
···
1
+
package dpop
2
+
3
+
import (
4
+
"crypto"
5
+
"crypto/sha256"
6
+
"encoding/base64"
7
+
"encoding/json"
8
+
"errors"
9
+
"fmt"
10
+
"log/slog"
11
+
"net/http"
12
+
"net/url"
13
+
"strings"
14
+
"time"
15
+
16
+
"github.com/golang-jwt/jwt/v4"
17
+
"github.com/haileyok/cocoon/internal/helpers"
18
+
"github.com/haileyok/cocoon/oauth/constants"
19
+
"github.com/lestrrat-go/jwx/v2/jwa"
20
+
"github.com/lestrrat-go/jwx/v2/jwk"
21
+
)
22
+
23
+
type Manager struct {
24
+
nonce *Nonce
25
+
jtiCache *jtiCache
26
+
logger *slog.Logger
27
+
hostname string
28
+
}
29
+
30
+
type ManagerArgs struct {
31
+
NonceSecret []byte
32
+
NonceRotationInterval time.Duration
33
+
OnNonceSecretCreated func([]byte)
34
+
JTICacheSize int
35
+
Logger *slog.Logger
36
+
Hostname string
37
+
}
38
+
39
+
func NewManager(args ManagerArgs) *Manager {
40
+
if args.Logger == nil {
41
+
args.Logger = slog.Default()
42
+
}
43
+
44
+
if args.JTICacheSize == 0 {
45
+
args.JTICacheSize = 100_000
46
+
}
47
+
48
+
if args.NonceSecret == nil {
49
+
args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.")
50
+
}
51
+
52
+
return &Manager{
53
+
nonce: NewNonce(NonceArgs{
54
+
RotationInterval: args.NonceRotationInterval,
55
+
Secret: args.NonceSecret,
56
+
OnSecretCreated: args.OnNonceSecretCreated,
57
+
}),
58
+
jtiCache: newJTICache(args.JTICacheSize),
59
+
logger: args.Logger,
60
+
hostname: args.Hostname,
61
+
}
62
+
}
63
+
64
+
func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) {
65
+
if reqMethod == "" {
66
+
return nil, errors.New("HTTP method is required")
67
+
}
68
+
69
+
if !strings.HasPrefix(reqUrl, "https://") {
70
+
reqUrl = "https://" + dm.hostname + reqUrl
71
+
}
72
+
73
+
proof := extractProof(headers)
74
+
75
+
if proof == "" {
76
+
return nil, nil
77
+
}
78
+
79
+
parser := jwt.NewParser(jwt.WithoutClaimsValidation())
80
+
var token *jwt.Token
81
+
82
+
token, _, err := parser.ParseUnverified(proof, jwt.MapClaims{})
83
+
if err != nil {
84
+
return nil, fmt.Errorf("could not parse dpop proof jwt: %w", err)
85
+
}
86
+
87
+
typ, _ := token.Header["typ"].(string)
88
+
if typ != "dpop+jwt" {
89
+
return nil, errors.New(`invalid dpop proof jwt: "typ" must be 'dpop+jwt'`)
90
+
}
91
+
92
+
dpopJwk, jwkOk := token.Header["jwk"].(map[string]any)
93
+
if !jwkOk {
94
+
return nil, errors.New(`invalid dpop proof jwt: "jwk" is missing in header`)
95
+
}
96
+
97
+
jwkb, err := json.Marshal(dpopJwk)
98
+
if err != nil {
99
+
return nil, fmt.Errorf("failed to marshal jwk: %w", err)
100
+
}
101
+
102
+
key, err := jwk.ParseKey(jwkb)
103
+
if err != nil {
104
+
return nil, fmt.Errorf("failed to parse jwk: %w", err)
105
+
}
106
+
107
+
var pubKey any
108
+
if err := key.Raw(&pubKey); err != nil {
109
+
return nil, fmt.Errorf("failed to get raw public key: %w", err)
110
+
}
111
+
112
+
token, err = jwt.Parse(proof, func(t *jwt.Token) (any, error) {
113
+
alg := t.Header["alg"].(string)
114
+
115
+
switch key.KeyType() {
116
+
case jwa.EC:
117
+
if !strings.HasPrefix(alg, "ES") {
118
+
return nil, fmt.Errorf("algorithm %s doesn't match EC key type", alg)
119
+
}
120
+
case jwa.RSA:
121
+
if !strings.HasPrefix(alg, "RS") && !strings.HasPrefix(alg, "PS") {
122
+
return nil, fmt.Errorf("algorithm %s doesn't match RSA key type", alg)
123
+
}
124
+
case jwa.OKP:
125
+
if alg != "EdDSA" {
126
+
return nil, fmt.Errorf("algorithm %s doesn't match OKP key type", alg)
127
+
}
128
+
}
129
+
130
+
return pubKey, nil
131
+
}, jwt.WithValidMethods([]string{"ES256", "ES384", "ES512", "RS256", "RS384", "RS512", "PS256", "PS384", "PS512", "EdDSA"}))
132
+
if err != nil {
133
+
return nil, fmt.Errorf("could not verify dpop proof jwt: %w", err)
134
+
}
135
+
136
+
if !token.Valid {
137
+
return nil, errors.New("dpop proof jwt is invalid")
138
+
}
139
+
140
+
claims, ok := token.Claims.(jwt.MapClaims)
141
+
if !ok {
142
+
return nil, errors.New("no claims in dpop proof jwt")
143
+
}
144
+
145
+
iat, iatOk := claims["iat"].(float64)
146
+
if !iatOk {
147
+
return nil, errors.New(`invalid dpop proof jwt: "iat" is missing`)
148
+
}
149
+
150
+
iatTime := time.Unix(int64(iat), 0)
151
+
now := time.Now()
152
+
153
+
if now.Sub(iatTime) > constants.DpopNonceMaxAge+constants.DpopCheckTolerance {
154
+
return nil, errors.New("dpop proof too old")
155
+
}
156
+
157
+
if iatTime.Sub(now) > constants.DpopCheckTolerance {
158
+
return nil, errors.New("dpop proof iat is in the future")
159
+
}
160
+
161
+
jti, _ := claims["jti"].(string)
162
+
if jti == "" {
163
+
return nil, errors.New(`invalid dpop proof jwt: "jti" is missing`)
164
+
}
165
+
166
+
if dm.jtiCache.add(jti) {
167
+
return nil, errors.New("dpop proof replay detected")
168
+
}
169
+
170
+
htm, _ := claims["htm"].(string)
171
+
if htm == "" {
172
+
return nil, errors.New(`invalid dpop proof jwt: "htm" is missing`)
173
+
}
174
+
175
+
if htm != reqMethod {
176
+
return nil, errors.New(`invalid dpop proof jwt: "htm" mismatch`)
177
+
}
178
+
179
+
htu, _ := claims["htu"].(string)
180
+
if htu == "" {
181
+
return nil, errors.New(`invalid dpop proof jwt: "htu" is missing`)
182
+
}
183
+
184
+
parsedHtu, err := helpers.OauthParseHtu(htu)
185
+
if err != nil {
186
+
return nil, errors.New(`invalid dpop proof jwt: "htu" could not be parsed`)
187
+
}
188
+
189
+
u, _ := url.Parse(reqUrl)
190
+
if parsedHtu != helpers.OauthNormalizeHtu(u) {
191
+
return nil, fmt.Errorf(`invalid dpop proof jwt: "htu" mismatch. reqUrl: %s, parsed: %s, normalized: %s`, reqUrl, parsedHtu, helpers.OauthNormalizeHtu(u))
192
+
}
193
+
194
+
nonce, _ := claims["nonce"].(string)
195
+
if nonce == "" {
196
+
// WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request
197
+
return nil, errors.New("use_dpop_nonce")
198
+
}
199
+
200
+
if nonce != "" && !dm.nonce.Check(nonce) {
201
+
// WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce
202
+
return nil, errors.New("use_dpop_nonce")
203
+
}
204
+
205
+
ath, _ := claims["ath"].(string)
206
+
207
+
if accessToken != nil && *accessToken != "" {
208
+
if ath == "" {
209
+
return nil, errors.New(`invalid dpop proof jwt: "ath" is required with access token`)
210
+
}
211
+
212
+
hash := sha256.Sum256([]byte(*accessToken))
213
+
if ath != base64.RawURLEncoding.EncodeToString(hash[:]) {
214
+
return nil, errors.New(`invalid dpop proof jwt: "ath" mismatch`)
215
+
}
216
+
} else if ath != "" {
217
+
return nil, errors.New(`invalid dpop proof jwt: "ath" claim not allowed`)
218
+
}
219
+
220
+
thumbBytes, err := key.Thumbprint(crypto.SHA256)
221
+
if err != nil {
222
+
return nil, fmt.Errorf("failed to calculate thumbprint: %w", err)
223
+
}
224
+
225
+
thumb := base64.RawURLEncoding.EncodeToString(thumbBytes)
226
+
227
+
return &Proof{
228
+
JTI: jti,
229
+
JKT: thumb,
230
+
HTM: htm,
231
+
HTU: htu,
232
+
}, nil
233
+
}
234
+
235
+
func extractProof(headers http.Header) string {
236
+
dpopHeaders := headers["Dpop"]
237
+
switch len(dpopHeaders) {
238
+
case 0:
239
+
return ""
240
+
case 1:
241
+
return dpopHeaders[0]
242
+
default:
243
+
return ""
244
+
}
245
+
}
246
+
247
+
func (dm *Manager) NextNonce() string {
248
+
return dm.nonce.NextNonce()
249
+
}
-108
oauth/dpop/nonce/nonce.go
-108
oauth/dpop/nonce/nonce.go
···
1
-
package nonce
2
-
3
-
import (
4
-
"crypto/hmac"
5
-
"crypto/sha256"
6
-
"encoding/base64"
7
-
"encoding/binary"
8
-
"sync"
9
-
"time"
10
-
11
-
"github.com/haileyok/cocoon/internal/helpers"
12
-
"github.com/haileyok/cocoon/oauth/constants"
13
-
)
14
-
15
-
type Nonce struct {
16
-
rotationInterval time.Duration
17
-
secret []byte
18
-
19
-
mu sync.RWMutex
20
-
21
-
counter int64
22
-
prev string
23
-
curr string
24
-
next string
25
-
}
26
-
27
-
type Args struct {
28
-
RotationInterval time.Duration
29
-
Secret []byte
30
-
OnSecretCreated func([]byte)
31
-
}
32
-
33
-
func NewNonce(args Args) *Nonce {
34
-
if args.RotationInterval == 0 {
35
-
args.RotationInterval = constants.NonceMaxRotationInterval / 3
36
-
}
37
-
38
-
if args.RotationInterval > constants.NonceMaxRotationInterval {
39
-
args.RotationInterval = constants.NonceMaxRotationInterval
40
-
}
41
-
42
-
if args.Secret == nil {
43
-
args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength)
44
-
args.OnSecretCreated(args.Secret)
45
-
}
46
-
47
-
n := &Nonce{
48
-
rotationInterval: args.RotationInterval,
49
-
secret: args.Secret,
50
-
mu: sync.RWMutex{},
51
-
}
52
-
53
-
n.counter = n.currentCounter()
54
-
n.prev = n.compute(n.counter - 1)
55
-
n.curr = n.compute(n.counter)
56
-
n.next = n.compute(n.counter + 1)
57
-
58
-
return n
59
-
}
60
-
61
-
func (n *Nonce) currentCounter() int64 {
62
-
return time.Now().UnixNano() / int64(n.rotationInterval)
63
-
}
64
-
65
-
func (n *Nonce) compute(counter int64) string {
66
-
h := hmac.New(sha256.New, n.secret)
67
-
counterBytes := make([]byte, 8)
68
-
binary.BigEndian.PutUint64(counterBytes, uint64(counter))
69
-
h.Write(counterBytes)
70
-
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
71
-
}
72
-
73
-
func (n *Nonce) rotate() {
74
-
counter := n.currentCounter()
75
-
diff := counter - n.counter
76
-
77
-
switch diff {
78
-
case 0:
79
-
// counter == n.counter, do nothing
80
-
case 1:
81
-
n.prev = n.curr
82
-
n.curr = n.next
83
-
n.next = n.compute(counter + 1)
84
-
case 2:
85
-
n.prev = n.next
86
-
n.curr = n.compute(counter)
87
-
n.next = n.compute(counter + 1)
88
-
default:
89
-
n.prev = n.compute(counter - 1)
90
-
n.curr = n.compute(counter)
91
-
n.next = n.compute(counter + 1)
92
-
}
93
-
94
-
n.counter = counter
95
-
}
96
-
97
-
func (n *Nonce) NextNonce() string {
98
-
n.mu.Lock()
99
-
defer n.mu.Unlock()
100
-
n.rotate()
101
-
return n.next
102
-
}
103
-
104
-
func (n *Nonce) Check(nonce string) bool {
105
-
n.mu.RLock()
106
-
defer n.mu.RUnlock()
107
-
return nonce == n.prev || nonce == n.curr || nonce == n.next
108
-
}
+108
oauth/dpop/nonce.go
+108
oauth/dpop/nonce.go
···
1
+
package dpop
2
+
3
+
import (
4
+
"crypto/hmac"
5
+
"crypto/sha256"
6
+
"encoding/base64"
7
+
"encoding/binary"
8
+
"sync"
9
+
"time"
10
+
11
+
"github.com/haileyok/cocoon/internal/helpers"
12
+
"github.com/haileyok/cocoon/oauth/constants"
13
+
)
14
+
15
+
type Nonce struct {
16
+
rotationInterval time.Duration
17
+
secret []byte
18
+
19
+
mu sync.RWMutex
20
+
21
+
counter int64
22
+
prev string
23
+
curr string
24
+
next string
25
+
}
26
+
27
+
type NonceArgs struct {
28
+
RotationInterval time.Duration
29
+
Secret []byte
30
+
OnSecretCreated func([]byte)
31
+
}
32
+
33
+
func NewNonce(args NonceArgs) *Nonce {
34
+
if args.RotationInterval == 0 {
35
+
args.RotationInterval = constants.NonceMaxRotationInterval / 3
36
+
}
37
+
38
+
if args.RotationInterval > constants.NonceMaxRotationInterval {
39
+
args.RotationInterval = constants.NonceMaxRotationInterval
40
+
}
41
+
42
+
if args.Secret == nil {
43
+
args.Secret = helpers.RandomBytes(constants.NonceSecretByteLength)
44
+
args.OnSecretCreated(args.Secret)
45
+
}
46
+
47
+
n := &Nonce{
48
+
rotationInterval: args.RotationInterval,
49
+
secret: args.Secret,
50
+
mu: sync.RWMutex{},
51
+
}
52
+
53
+
n.counter = n.currentCounter()
54
+
n.prev = n.compute(n.counter - 1)
55
+
n.curr = n.compute(n.counter)
56
+
n.next = n.compute(n.counter + 1)
57
+
58
+
return n
59
+
}
60
+
61
+
func (n *Nonce) currentCounter() int64 {
62
+
return time.Now().UnixNano() / int64(n.rotationInterval)
63
+
}
64
+
65
+
func (n *Nonce) compute(counter int64) string {
66
+
h := hmac.New(sha256.New, n.secret)
67
+
counterBytes := make([]byte, 8)
68
+
binary.BigEndian.PutUint64(counterBytes, uint64(counter))
69
+
h.Write(counterBytes)
70
+
return base64.RawURLEncoding.EncodeToString(h.Sum(nil))
71
+
}
72
+
73
+
func (n *Nonce) rotate() {
74
+
counter := n.currentCounter()
75
+
diff := counter - n.counter
76
+
77
+
switch diff {
78
+
case 0:
79
+
// counter == n.counter, do nothing
80
+
case 1:
81
+
n.prev = n.curr
82
+
n.curr = n.next
83
+
n.next = n.compute(counter + 1)
84
+
case 2:
85
+
n.prev = n.next
86
+
n.curr = n.compute(counter)
87
+
n.next = n.compute(counter + 1)
88
+
default:
89
+
n.prev = n.compute(counter - 1)
90
+
n.curr = n.compute(counter)
91
+
n.next = n.compute(counter + 1)
92
+
}
93
+
94
+
n.counter = counter
95
+
}
96
+
97
+
func (n *Nonce) NextNonce() string {
98
+
n.mu.Lock()
99
+
defer n.mu.Unlock()
100
+
n.rotate()
101
+
return n.next
102
+
}
103
+
104
+
func (n *Nonce) Check(nonce string) bool {
105
+
n.mu.RLock()
106
+
defer n.mu.RUnlock()
107
+
return nonce == n.prev || nonce == n.curr || nonce == n.next
108
+
}
+32
oauth/helpers.go
+32
oauth/helpers.go
···
4
4
"errors"
5
5
"fmt"
6
6
"net/url"
7
+
"time"
7
8
8
9
"github.com/haileyok/cocoon/internal/helpers"
9
10
"github.com/haileyok/cocoon/oauth/constants"
11
+
"github.com/haileyok/cocoon/oauth/provider"
10
12
)
11
13
12
14
func GenerateCode() string {
···
46
48
47
49
return reqId, nil
48
50
}
51
+
52
+
type SessionAgeResult struct {
53
+
SessionAge time.Duration
54
+
RefreshAge time.Duration
55
+
SessionExpired bool
56
+
RefreshExpired bool
57
+
}
58
+
59
+
func GetSessionAgeFromToken(t provider.OauthToken) SessionAgeResult {
60
+
sessionLifetime := constants.PublicClientSessionLifetime
61
+
refreshLifetime := constants.PublicClientRefreshLifetime
62
+
if t.ClientAuth.Method != "none" {
63
+
sessionLifetime = constants.ConfidentialClientSessionLifetime
64
+
refreshLifetime = constants.ConfidentialClientRefreshLifetime
65
+
}
66
+
67
+
res := SessionAgeResult{}
68
+
69
+
res.SessionAge = time.Since(t.CreatedAt)
70
+
if res.SessionAge > sessionLifetime {
71
+
res.SessionExpired = true
72
+
}
73
+
74
+
refreshAge := time.Since(t.UpdatedAt)
75
+
if refreshAge > refreshLifetime {
76
+
res.RefreshExpired = true
77
+
}
78
+
79
+
return res
80
+
}
+3
-26
oauth/provider/client_auth.go
+3
-26
oauth/provider/client_auth.go
···
3
3
import (
4
4
"context"
5
5
"crypto"
6
-
"database/sql/driver"
7
6
"encoding/base64"
8
-
"encoding/json"
9
7
"errors"
10
8
"fmt"
11
9
"time"
12
10
13
11
"github.com/golang-jwt/jwt/v4"
14
-
"github.com/haileyok/cocoon/oauth"
12
+
"github.com/haileyok/cocoon/oauth/client"
15
13
"github.com/haileyok/cocoon/oauth/constants"
16
14
"github.com/haileyok/cocoon/oauth/dpop"
17
15
)
18
16
19
-
type ClientAuth struct {
20
-
Method string
21
-
Alg string
22
-
Kid string
23
-
Jkt string
24
-
Jti string
25
-
Exp *float64
26
-
}
27
-
28
-
func (ca *ClientAuth) Scan(value any) error {
29
-
b, ok := value.([]byte)
30
-
if !ok {
31
-
return fmt.Errorf("failed to unmarshal OauthParRequest value")
32
-
}
33
-
return json.Unmarshal(b, ca)
34
-
}
35
-
36
-
func (ca ClientAuth) Value() (driver.Value, error) {
37
-
return json.Marshal(ca)
38
-
}
39
-
40
17
type AuthenticateClientOptions struct {
41
18
AllowMissingDpopProof bool
42
19
}
···
47
24
ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"`
48
25
}
49
26
50
-
func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*oauth.Client, *ClientAuth, error) {
27
+
func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) {
51
28
client, err := p.ClientManager.GetClient(ctx, req.ClientID)
52
29
if err != nil {
53
30
return nil, nil, fmt.Errorf("failed to get client: %w", err)
···
69
46
return client, clientAuth, nil
70
47
}
71
48
72
-
func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *oauth.Client) (*ClientAuth, error) {
49
+
func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *client.Client) (*ClientAuth, error) {
73
50
metadata := client.Metadata
74
51
75
52
if metadata.TokenEndpointAuthMethod == "none" {
+83
oauth/provider/models.go
+83
oauth/provider/models.go
···
1
+
package provider
2
+
3
+
import (
4
+
"database/sql/driver"
5
+
"encoding/json"
6
+
"fmt"
7
+
"time"
8
+
9
+
"gorm.io/gorm"
10
+
)
11
+
12
+
type ClientAuth struct {
13
+
Method string
14
+
Alg string
15
+
Kid string
16
+
Jkt string
17
+
Jti string
18
+
Exp *float64
19
+
}
20
+
21
+
func (ca *ClientAuth) Scan(value any) error {
22
+
b, ok := value.([]byte)
23
+
if !ok {
24
+
return fmt.Errorf("failed to unmarshal OauthParRequest value")
25
+
}
26
+
return json.Unmarshal(b, ca)
27
+
}
28
+
29
+
func (ca ClientAuth) Value() (driver.Value, error) {
30
+
return json.Marshal(ca)
31
+
}
32
+
33
+
type ParRequest struct {
34
+
AuthenticateClientRequestBase
35
+
ResponseType string `form:"response_type" json:"response_type" validate:"required"`
36
+
CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"`
37
+
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"`
38
+
State string `form:"state" json:"state" validate:"required"`
39
+
RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"`
40
+
Scope string `form:"scope" json:"scope" validate:"required"`
41
+
LoginHint *string `form:"login_hint" json:"login_hint,omitempty"`
42
+
DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"`
43
+
}
44
+
45
+
func (opr *ParRequest) Scan(value any) error {
46
+
b, ok := value.([]byte)
47
+
if !ok {
48
+
return fmt.Errorf("failed to unmarshal OauthParRequest value")
49
+
}
50
+
return json.Unmarshal(b, opr)
51
+
}
52
+
53
+
func (opr ParRequest) Value() (driver.Value, error) {
54
+
return json.Marshal(opr)
55
+
}
56
+
57
+
type OauthToken struct {
58
+
gorm.Model
59
+
ClientId string `gorm:"index"`
60
+
ClientAuth ClientAuth `gorm:"type:json"`
61
+
Parameters ParRequest `gorm:"type:json"`
62
+
ExpiresAt time.Time `gorm:"index"`
63
+
DeviceId string
64
+
Sub string `gorm:"index"`
65
+
Code string `gorm:"index"`
66
+
Token string `gorm:"uniqueIndex"`
67
+
RefreshToken string `gorm:"uniqueIndex"`
68
+
Ip string
69
+
}
70
+
71
+
type OauthAuthorizationRequest struct {
72
+
gorm.Model
73
+
RequestId string `gorm:"primaryKey"`
74
+
ClientId string `gorm:"index"`
75
+
ClientAuth ClientAuth `gorm:"type:json"`
76
+
Parameters ParRequest `gorm:"type:json"`
77
+
ExpiresAt time.Time `gorm:"index"`
78
+
DeviceId *string
79
+
Sub *string
80
+
Code *string
81
+
Accepted *bool
82
+
Ip string
83
+
}
+8
-64
oauth/provider/provider.go
+8
-64
oauth/provider/provider.go
···
1
1
package provider
2
2
3
3
import (
4
-
"database/sql/driver"
5
-
"encoding/json"
6
-
"fmt"
7
-
"time"
8
-
9
-
"github.com/haileyok/cocoon/oauth/client_manager"
10
-
"github.com/haileyok/cocoon/oauth/dpop/dpop_manager"
11
-
"gorm.io/gorm"
4
+
"github.com/haileyok/cocoon/oauth/client"
5
+
"github.com/haileyok/cocoon/oauth/dpop"
12
6
)
13
7
14
8
type Provider struct {
15
-
ClientManager *client_manager.ClientManager
16
-
DpopManager *dpop_manager.DpopManager
9
+
ClientManager *client.Manager
10
+
DpopManager *dpop.Manager
17
11
18
12
hostname string
19
13
}
20
14
21
15
type Args struct {
22
16
Hostname string
23
-
ClientManagerArgs client_manager.Args
24
-
DpopManagerArgs dpop_manager.Args
17
+
ClientManagerArgs client.ManagerArgs
18
+
DpopManagerArgs dpop.ManagerArgs
25
19
}
26
20
27
21
func NewProvider(args Args) *Provider {
28
22
return &Provider{
29
-
ClientManager: client_manager.New(args.ClientManagerArgs),
30
-
DpopManager: dpop_manager.New(args.DpopManagerArgs),
23
+
ClientManager: client.NewManager(args.ClientManagerArgs),
24
+
DpopManager: dpop.NewManager(args.DpopManagerArgs),
31
25
hostname: args.Hostname,
32
26
}
33
27
}
···
35
29
func (p *Provider) NextNonce() string {
36
30
return p.DpopManager.NextNonce()
37
31
}
38
-
39
-
type ParRequest struct {
40
-
AuthenticateClientRequestBase
41
-
ResponseType string `form:"response_type" json:"response_type" validate:"required"`
42
-
CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"`
43
-
CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"`
44
-
State string `form:"state" json:"state" validate:"required"`
45
-
RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"`
46
-
Scope string `form:"scope" json:"scope" validate:"required"`
47
-
LoginHint *string `form:"login_hint" json:"login_hint,omitempty"`
48
-
DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"`
49
-
}
50
-
51
-
func (opr *ParRequest) Scan(value any) error {
52
-
b, ok := value.([]byte)
53
-
if !ok {
54
-
return fmt.Errorf("failed to unmarshal OauthParRequest value")
55
-
}
56
-
return json.Unmarshal(b, opr)
57
-
}
58
-
59
-
func (opr ParRequest) Value() (driver.Value, error) {
60
-
return json.Marshal(opr)
61
-
}
62
-
63
-
type OauthToken struct {
64
-
gorm.Model
65
-
ClientId string `gorm:"index"`
66
-
ClientAuth ClientAuth `gorm:"type:json"`
67
-
Parameters ParRequest `gorm:"type:json"`
68
-
ExpiresAt time.Time `gorm:"index"`
69
-
DeviceId string
70
-
Sub string `gorm:"index"`
71
-
Code string `gorm:"index"`
72
-
Token string `gorm:"uniqueIndex"`
73
-
RefreshToken string `gorm:"uniqueIndex"`
74
-
}
75
-
76
-
type OauthAuthorizationRequest struct {
77
-
gorm.Model
78
-
RequestId string `gorm:"primaryKey"`
79
-
ClientId string `gorm:"index"`
80
-
ClientAuth ClientAuth `gorm:"type:json"`
81
-
Parameters ParRequest `gorm:"type:json"`
82
-
ExpiresAt time.Time `gorm:"index"`
83
-
DeviceId *string
84
-
Sub *string
85
-
Code *string
86
-
Accepted *bool
87
-
}
+77
recording_blockstore/recording_blockstore.go
+77
recording_blockstore/recording_blockstore.go
···
1
+
package recording_blockstore
2
+
3
+
import (
4
+
"context"
5
+
6
+
blockformat "github.com/ipfs/go-block-format"
7
+
"github.com/ipfs/go-cid"
8
+
blockstore "github.com/ipfs/go-ipfs-blockstore"
9
+
)
10
+
11
+
type RecordingBlockstore struct {
12
+
base blockstore.Blockstore
13
+
14
+
inserts map[cid.Cid]blockformat.Block
15
+
}
16
+
17
+
func New(base blockstore.Blockstore) *RecordingBlockstore {
18
+
return &RecordingBlockstore{
19
+
base: base,
20
+
inserts: make(map[cid.Cid]blockformat.Block),
21
+
}
22
+
}
23
+
24
+
func (bs *RecordingBlockstore) Has(ctx context.Context, c cid.Cid) (bool, error) {
25
+
return bs.base.Has(ctx, c)
26
+
}
27
+
28
+
func (bs *RecordingBlockstore) Get(ctx context.Context, c cid.Cid) (blockformat.Block, error) {
29
+
return bs.base.Get(ctx, c)
30
+
}
31
+
32
+
func (bs *RecordingBlockstore) GetSize(ctx context.Context, c cid.Cid) (int, error) {
33
+
return bs.base.GetSize(ctx, c)
34
+
}
35
+
36
+
func (bs *RecordingBlockstore) DeleteBlock(ctx context.Context, c cid.Cid) error {
37
+
return bs.base.DeleteBlock(ctx, c)
38
+
}
39
+
40
+
func (bs *RecordingBlockstore) Put(ctx context.Context, block blockformat.Block) error {
41
+
if err := bs.base.Put(ctx, block); err != nil {
42
+
return err
43
+
}
44
+
bs.inserts[block.Cid()] = block
45
+
return nil
46
+
}
47
+
48
+
func (bs *RecordingBlockstore) PutMany(ctx context.Context, blocks []blockformat.Block) error {
49
+
if err := bs.base.PutMany(ctx, blocks); err != nil {
50
+
return err
51
+
}
52
+
53
+
for _, b := range blocks {
54
+
bs.inserts[b.Cid()] = b
55
+
}
56
+
57
+
return nil
58
+
}
59
+
60
+
func (bs *RecordingBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) {
61
+
return bs.AllKeysChan(ctx)
62
+
}
63
+
64
+
func (bs *RecordingBlockstore) HashOnRead(enabled bool) {
65
+
}
66
+
67
+
func (bs *RecordingBlockstore) GetLogMap() map[cid.Cid]blockformat.Block {
68
+
return bs.inserts
69
+
}
70
+
71
+
func (bs *RecordingBlockstore) GetLogArray() []blockformat.Block {
72
+
var blocks []blockformat.Block
73
+
for _, b := range bs.inserts {
74
+
blocks = append(blocks, b)
75
+
}
76
+
return blocks
77
+
}
+30
server/blockstore_variant.go
+30
server/blockstore_variant.go
···
1
+
package server
2
+
3
+
import (
4
+
"github.com/haileyok/cocoon/sqlite_blockstore"
5
+
blockstore "github.com/ipfs/go-ipfs-blockstore"
6
+
)
7
+
8
+
type BlockstoreVariant int
9
+
10
+
const (
11
+
BlockstoreVariantSqlite = iota
12
+
)
13
+
14
+
func MustReturnBlockstoreVariant(maybeBsv string) BlockstoreVariant {
15
+
switch maybeBsv {
16
+
case "sqlite":
17
+
return BlockstoreVariantSqlite
18
+
default:
19
+
panic("invalid blockstore variant provided")
20
+
}
21
+
}
22
+
23
+
func (s *Server) getBlockstore(did string) blockstore.Blockstore {
24
+
switch s.config.BlockstoreVariant {
25
+
case BlockstoreVariantSqlite:
26
+
return sqlite_blockstore.New(did, s.db)
27
+
default:
28
+
return sqlite_blockstore.New(did, s.db)
29
+
}
30
+
}
+37
-7
server/handle_account.go
+37
-7
server/handle_account.go
···
3
3
import (
4
4
"time"
5
5
6
+
"github.com/haileyok/cocoon/oauth"
7
+
"github.com/haileyok/cocoon/oauth/constants"
6
8
"github.com/haileyok/cocoon/oauth/provider"
9
+
"github.com/hako/durafmt"
7
10
"github.com/labstack/echo/v4"
8
11
)
9
12
10
13
func (s *Server) handleAccount(e echo.Context) error {
14
+
ctx := e.Request().Context()
11
15
repo, sess, err := s.getSessionRepoOrErr(e)
12
16
if err != nil {
13
17
return e.Redirect(303, "/account/signin")
14
18
}
15
19
16
-
now := time.Now()
20
+
oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime)
17
21
18
22
var tokens []provider.OauthToken
19
-
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND expires_at >= ? ORDER BY created_at ASC", nil, repo.Repo.Did, now).Scan(&tokens).Error; err != nil {
23
+
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
20
24
s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err)
21
25
sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error")
22
26
sess.Save(e.Request(), e.Response())
···
25
29
})
26
30
}
27
31
32
+
var filtered []provider.OauthToken
33
+
for _, t := range tokens {
34
+
ageRes := oauth.GetSessionAgeFromToken(t)
35
+
if ageRes.SessionExpired {
36
+
continue
37
+
}
38
+
filtered = append(filtered, t)
39
+
}
40
+
41
+
now := time.Now()
42
+
28
43
tokenInfo := []map[string]string{}
29
44
for _, t := range tokens {
45
+
ageRes := oauth.GetSessionAgeFromToken(t)
46
+
maxTime := constants.PublicClientSessionLifetime
47
+
if t.ClientAuth.Method != "none" {
48
+
maxTime = constants.ConfidentialClientSessionLifetime
49
+
}
50
+
51
+
var clientName string
52
+
metadata, err := s.oauthProvider.ClientManager.GetClient(ctx, t.ClientId)
53
+
if err != nil {
54
+
clientName = t.ClientId
55
+
} else {
56
+
clientName = metadata.Metadata.ClientName
57
+
}
58
+
30
59
tokenInfo = append(tokenInfo, map[string]string{
31
-
"ClientId": t.ClientId,
32
-
"CreatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"),
33
-
"UpdatedAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"),
34
-
"ExpiresAt": t.CreatedAt.Format("02 Jan 06 15:04 MST"),
35
-
"Token": t.Token,
60
+
"ClientName": clientName,
61
+
"Age": durafmt.Parse(ageRes.SessionAge).LimitFirstN(2).String(),
62
+
"LastUpdated": durafmt.Parse(now.Sub(t.UpdatedAt)).LimitFirstN(2).String(),
63
+
"ExpiresIn": durafmt.Parse(now.Add(maxTime).Sub(now)).LimitFirstN(2).String(),
64
+
"Token": t.Token,
65
+
"Ip": t.Ip,
36
66
})
37
67
}
38
68
+2
-3
server/handle_import_repo.go
+2
-3
server/handle_import_repo.go
···
9
9
10
10
"github.com/bluesky-social/indigo/atproto/syntax"
11
11
"github.com/bluesky-social/indigo/repo"
12
-
"github.com/haileyok/cocoon/blockstore"
13
12
"github.com/haileyok/cocoon/internal/helpers"
14
13
"github.com/haileyok/cocoon/models"
15
14
blocks "github.com/ipfs/go-block-format"
···
27
26
return helpers.ServerError(e, nil)
28
27
}
29
28
30
-
bs := blockstore.New(urepo.Repo.Did, s.db)
29
+
bs := s.getBlockstore(urepo.Repo.Did)
31
30
32
31
cs, err := car.NewCarReader(bytes.NewReader(b))
33
32
if err != nil {
···
107
106
return helpers.ServerError(e, nil)
108
107
}
109
108
110
-
if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil {
109
+
if err := s.UpdateRepo(context.TODO(), urepo.Repo.Did, root, rev); err != nil {
111
110
s.logger.Error("error updating repo after commit", "error", err)
112
111
return helpers.ServerError(e, nil)
113
112
}
+4
-10
server/handle_oauth_token.go
+4
-10
server/handle_oauth_token.go
···
157
157
Code: *authReq.Code,
158
158
Token: accessString,
159
159
RefreshToken: refreshToken,
160
+
Ip: authReq.Ip,
160
161
}, nil).Error; err != nil {
161
162
s.logger.Error("error creating token in db", "error", err)
162
163
return helpers.ServerError(e, nil)
···
203
204
return helpers.InputError(e, to.StringPtr("dpop proof does not match expected jkt"))
204
205
}
205
206
206
-
sessionLifetime := constants.PublicClientSessionLifetime
207
-
refreshLifetime := constants.PublicClientRefreshLifetime
208
-
if clientAuth.Method != "none" {
209
-
sessionLifetime = constants.ConfidentialClientSessionLifetime
210
-
refreshLifetime = constants.ConfidentialClientRefreshLifetime
211
-
}
207
+
ageRes := oauth.GetSessionAgeFromToken(oauthToken)
212
208
213
-
sessionAge := time.Since(oauthToken.CreatedAt)
214
-
if sessionAge > sessionLifetime {
209
+
if ageRes.SessionExpired {
215
210
return helpers.InputError(e, to.StringPtr("Session expired"))
216
211
}
217
212
218
-
refreshAge := time.Since(oauthToken.UpdatedAt)
219
-
if refreshAge > refreshLifetime {
213
+
if ageRes.RefreshExpired {
220
214
return helpers.InputError(e, to.StringPtr("Refresh token expired"))
221
215
}
222
216
+27
-15
server/handle_proxy.go
+27
-15
server/handle_proxy.go
···
17
17
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18
18
)
19
19
20
-
func (s *Server) handleProxy(e echo.Context) error {
21
-
repo, isAuthed := e.Get("repo").(*models.RepoActor)
22
-
23
-
pts := strings.Split(e.Request().URL.Path, "/")
24
-
if len(pts) != 3 {
25
-
return fmt.Errorf("incorrect number of parts")
26
-
}
27
-
20
+
func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) {
28
21
svc := e.Request().Header.Get("atproto-proxy")
29
22
if svc == "" {
30
-
svc = "did:web:api.bsky.app#bsky_appview" // TODO: should be a config var probably
23
+
svc = s.config.DefaultAtprotoProxy
31
24
}
32
25
33
26
svcPts := strings.Split(svc, "#")
34
27
if len(svcPts) != 2 {
35
-
return fmt.Errorf("invalid service header")
28
+
return "", "", fmt.Errorf("invalid service header")
36
29
}
37
30
38
31
svcDid := svcPts[0]
···
40
33
41
34
doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
42
35
if err != nil {
43
-
return err
36
+
return "", "", err
44
37
}
45
38
46
39
var endpoint string
···
50
43
}
51
44
}
52
45
46
+
return endpoint, svcDid, nil
47
+
}
48
+
49
+
func (s *Server) handleProxy(e echo.Context) error {
50
+
lgr := s.logger.With("handler", "handleProxy")
51
+
52
+
repo, isAuthed := e.Get("repo").(*models.RepoActor)
53
+
54
+
pts := strings.Split(e.Request().URL.Path, "/")
55
+
if len(pts) != 3 {
56
+
return fmt.Errorf("incorrect number of parts")
57
+
}
58
+
59
+
endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e)
60
+
if err != nil {
61
+
lgr.Error("could not get atproto proxy", "error", err)
62
+
return helpers.ServerError(e, nil)
63
+
}
64
+
53
65
requrl := e.Request().URL
54
66
requrl.Host = strings.TrimPrefix(endpoint, "https://")
55
67
requrl.Scheme = "https"
···
78
90
}
79
91
hj, err := json.Marshal(header)
80
92
if err != nil {
81
-
s.logger.Error("error marshaling header", "error", err)
93
+
lgr.Error("error marshaling header", "error", err)
82
94
return helpers.ServerError(e, nil)
83
95
}
84
96
···
93
105
}
94
106
pj, err := json.Marshal(payload)
95
107
if err != nil {
96
-
s.logger.Error("error marashaling payload", "error", err)
108
+
lgr.Error("error marashaling payload", "error", err)
97
109
return helpers.ServerError(e, nil)
98
110
}
99
111
···
104
116
105
117
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
106
118
if err != nil {
107
-
s.logger.Error("can't load private key", "error", err)
119
+
lgr.Error("can't load private key", "error", err)
108
120
return err
109
121
}
110
122
111
123
R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
112
124
if err != nil {
113
-
s.logger.Error("error signing", "error", err)
125
+
lgr.Error("error signing", "error", err)
114
126
}
115
127
116
128
rBytes := R.Bytes()
+65
server/handle_server_check_account_status.go
+65
server/handle_server_check_account_status.go
···
1
+
package server
2
+
3
+
import (
4
+
"github.com/haileyok/cocoon/internal/helpers"
5
+
"github.com/haileyok/cocoon/models"
6
+
"github.com/ipfs/go-cid"
7
+
"github.com/labstack/echo/v4"
8
+
)
9
+
10
+
type ComAtprotoServerCheckAccountStatusResponse struct {
11
+
Activated bool `json:"activated"`
12
+
ValidDid bool `json:"validDid"`
13
+
RepoCommit string `json:"repoCommit"`
14
+
RepoRev string `json:"repoRev"`
15
+
RepoBlocks int64 `json:"repoBlocks"`
16
+
IndexedRecords int64 `json:"indexedRecords"`
17
+
PrivateStateValues int64 `json:"privateStateValues"`
18
+
ExpectedBlobs int64 `json:"expectedBlobs"`
19
+
ImportedBlobs int64 `json:"importedBlobs"`
20
+
}
21
+
22
+
func (s *Server) handleServerCheckAccountStatus(e echo.Context) error {
23
+
urepo := e.Get("repo").(*models.RepoActor)
24
+
25
+
resp := ComAtprotoServerCheckAccountStatusResponse{
26
+
Activated: true, // TODO: should allow for deactivation etc.
27
+
ValidDid: true, // TODO: should probably verify?
28
+
RepoRev: urepo.Rev,
29
+
ImportedBlobs: 0, // TODO: ???
30
+
}
31
+
32
+
rootcid, err := cid.Cast(urepo.Root)
33
+
if err != nil {
34
+
s.logger.Error("error casting cid", "error", err)
35
+
return helpers.ServerError(e, nil)
36
+
}
37
+
resp.RepoCommit = rootcid.String()
38
+
39
+
type CountResp struct {
40
+
Ct int64
41
+
}
42
+
43
+
var blockCtResp CountResp
44
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
45
+
s.logger.Error("error getting block count", "error", err)
46
+
return helpers.ServerError(e, nil)
47
+
}
48
+
resp.RepoBlocks = blockCtResp.Ct
49
+
50
+
var recCtResp CountResp
51
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
52
+
s.logger.Error("error getting record count", "error", err)
53
+
return helpers.ServerError(e, nil)
54
+
}
55
+
resp.IndexedRecords = recCtResp.Ct
56
+
57
+
var blobCtResp CountResp
58
+
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
59
+
s.logger.Error("error getting record count", "error", err)
60
+
return helpers.ServerError(e, nil)
61
+
}
62
+
resp.ExpectedBlobs = blobCtResp.Ct
63
+
64
+
return e.JSON(200, resp)
65
+
}
+2
-2
server/handle_server_confirm_email.go
+2
-2
server/handle_server_confirm_email.go
···
28
28
}
29
29
30
30
if urepo.EmailVerificationCode == nil || urepo.EmailVerificationCodeExpiresAt == nil {
31
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
31
+
return helpers.ExpiredTokenError(e)
32
32
}
33
33
34
34
if *urepo.EmailVerificationCode != req.Token {
···
36
36
}
37
37
38
38
if time.Now().UTC().After(*urepo.EmailVerificationCodeExpiresAt) {
39
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
39
+
return helpers.ExpiredTokenError(e)
40
40
}
41
41
42
42
now := time.Now().UTC()
+2
-3
server/handle_server_create_account.go
+2
-3
server/handle_server_create_account.go
···
14
14
"github.com/bluesky-social/indigo/events"
15
15
"github.com/bluesky-social/indigo/repo"
16
16
"github.com/bluesky-social/indigo/util"
17
-
"github.com/haileyok/cocoon/blockstore"
18
17
"github.com/haileyok/cocoon/internal/helpers"
19
18
"github.com/haileyok/cocoon/models"
20
19
"github.com/labstack/echo/v4"
···
177
176
}
178
177
179
178
if customDidHeader == "" {
180
-
bs := blockstore.New(signupDid, s.db)
179
+
bs := s.getBlockstore(signupDid)
181
180
r := repo.NewRepo(context.TODO(), signupDid, bs)
182
181
183
182
root, rev, err := r.Commit(context.TODO(), urepo.SignFor)
···
186
185
return helpers.ServerError(e, nil)
187
186
}
188
187
189
-
if err := bs.UpdateRepo(context.TODO(), root, rev); err != nil {
188
+
if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil {
190
189
s.logger.Error("error updating repo after commit", "error", err)
191
190
return helpers.ServerError(e, nil)
192
191
}
+8
-6
server/handle_server_get_service_auth.go
+8
-6
server/handle_server_get_service_auth.go
···
19
19
20
20
type ServerGetServiceAuthRequest struct {
21
21
Aud string `query:"aud" validate:"required,atproto-did"`
22
-
Exp int64 `query:"exp"`
23
-
Lxm string `query:"lxm" validate:"required,atproto-nsid"`
22
+
// exp should be a float, as some clients will send a non-integer expiration
23
+
Exp float64 `query:"exp"`
24
+
Lxm string `query:"lxm" validate:"required,atproto-nsid"`
24
25
}
25
26
26
27
func (s *Server) handleServerGetServiceAuth(e echo.Context) error {
···
34
35
return helpers.InputError(e, nil)
35
36
}
36
37
38
+
exp := int64(req.Exp)
37
39
now := time.Now().Unix()
38
-
if req.Exp == 0 {
39
-
req.Exp = now + 60 // default
40
+
if exp == 0 {
41
+
exp = now + 60 // default
40
42
}
41
43
42
44
if req.Lxm == "com.atproto.server.getServiceAuth" {
···
44
46
}
45
47
46
48
maxExp := now + (60 * 30)
47
-
if req.Exp > maxExp {
49
+
if exp > maxExp {
48
50
return helpers.InputError(e, to.StringPtr("expiration too big. smoller please"))
49
51
}
50
52
···
68
70
"aud": req.Aud,
69
71
"lxm": req.Lxm,
70
72
"jti": uuid.NewString(),
71
-
"exp": req.Exp,
73
+
"exp": exp,
72
74
"iat": now,
73
75
}
74
76
pj, err := json.Marshal(payload)
+2
-2
server/handle_server_reset_password.go
+2
-2
server/handle_server_reset_password.go
···
33
33
}
34
34
35
35
if *urepo.PasswordResetCode != req.Token {
36
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
36
+
return helpers.InvalidTokenError(e)
37
37
}
38
38
39
39
if time.Now().UTC().After(*urepo.PasswordResetCodeExpiresAt) {
40
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
40
+
return helpers.ExpiredTokenError(e)
41
41
}
42
42
43
43
hash, err := bcrypt.GenerateFromPassword([]byte(req.Password), 10)
+3
-4
server/handle_server_update_email.go
+3
-4
server/handle_server_update_email.go
···
3
3
import (
4
4
"time"
5
5
6
-
"github.com/Azure/go-autorest/autorest/to"
7
6
"github.com/haileyok/cocoon/internal/helpers"
8
7
"github.com/haileyok/cocoon/models"
9
8
"github.com/labstack/echo/v4"
···
29
28
}
30
29
31
30
if urepo.EmailUpdateCode == nil || urepo.EmailUpdateCodeExpiresAt == nil {
32
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
31
+
return helpers.InvalidTokenError(e)
33
32
}
34
33
35
34
if *urepo.EmailUpdateCode != req.Token {
36
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
35
+
return helpers.InvalidTokenError(e)
37
36
}
38
37
39
38
if time.Now().UTC().After(*urepo.EmailUpdateCodeExpiresAt) {
40
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
39
+
return helpers.ExpiredTokenError(e)
41
40
}
42
41
43
42
if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil {
+1
-2
server/handle_sync_get_blocks.go
+1
-2
server/handle_sync_get_blocks.go
···
6
6
"strings"
7
7
8
8
"github.com/bluesky-social/indigo/carstore"
9
-
"github.com/haileyok/cocoon/blockstore"
10
9
"github.com/haileyok/cocoon/internal/helpers"
11
10
"github.com/ipfs/go-cid"
12
11
cbor "github.com/ipfs/go-ipld-cbor"
···
54
53
return helpers.ServerError(e, nil)
55
54
}
56
55
57
-
bs := blockstore.New(urepo.Repo.Did, s.db)
56
+
bs := s.getBlockstore(urepo.Repo.Did)
58
57
59
58
for _, c := range cids {
60
59
b, err := bs.Get(context.TODO(), c)
+268
server/middleware.go
+268
server/middleware.go
···
1
+
package server
2
+
3
+
import (
4
+
"crypto/sha256"
5
+
"encoding/base64"
6
+
"fmt"
7
+
"strings"
8
+
"time"
9
+
10
+
"github.com/Azure/go-autorest/autorest/to"
11
+
"github.com/golang-jwt/jwt/v4"
12
+
"github.com/haileyok/cocoon/internal/helpers"
13
+
"github.com/haileyok/cocoon/models"
14
+
"github.com/haileyok/cocoon/oauth/provider"
15
+
"github.com/labstack/echo/v4"
16
+
"gitlab.com/yawning/secp256k1-voi"
17
+
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18
+
"gorm.io/gorm"
19
+
)
20
+
21
+
func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
22
+
return func(e echo.Context) error {
23
+
username, password, ok := e.Request().BasicAuth()
24
+
if !ok || username != "admin" || password != s.config.AdminPassword {
25
+
return helpers.InputError(e, to.StringPtr("Unauthorized"))
26
+
}
27
+
28
+
if err := next(e); err != nil {
29
+
e.Error(err)
30
+
}
31
+
32
+
return nil
33
+
}
34
+
}
35
+
36
+
func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
37
+
return func(e echo.Context) error {
38
+
authheader := e.Request().Header.Get("authorization")
39
+
if authheader == "" {
40
+
return e.JSON(401, map[string]string{"error": "Unauthorized"})
41
+
}
42
+
43
+
pts := strings.Split(authheader, " ")
44
+
if len(pts) != 2 {
45
+
return helpers.ServerError(e, nil)
46
+
}
47
+
48
+
// move on to oauth session middleware if this is a dpop token
49
+
if pts[0] == "DPoP" {
50
+
return next(e)
51
+
}
52
+
53
+
tokenstr := pts[1]
54
+
token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
55
+
claims, ok := token.Claims.(jwt.MapClaims)
56
+
if !ok {
57
+
return helpers.InvalidTokenError(e)
58
+
}
59
+
60
+
var did string
61
+
var repo *models.RepoActor
62
+
63
+
// service auth tokens
64
+
lxm, hasLxm := claims["lxm"]
65
+
if hasLxm {
66
+
pts := strings.Split(e.Request().URL.String(), "/")
67
+
if lxm != pts[len(pts)-1] {
68
+
s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
69
+
return helpers.InputError(e, nil)
70
+
}
71
+
72
+
maybeDid, ok := claims["iss"].(string)
73
+
if !ok {
74
+
s.logger.Error("no iss in service auth token", "error", err)
75
+
return helpers.InputError(e, nil)
76
+
}
77
+
did = maybeDid
78
+
79
+
maybeRepo, err := s.getRepoActorByDid(did)
80
+
if err != nil {
81
+
s.logger.Error("error fetching repo", "error", err)
82
+
return helpers.ServerError(e, nil)
83
+
}
84
+
repo = maybeRepo
85
+
}
86
+
87
+
if token.Header["alg"] != "ES256K" {
88
+
token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
89
+
if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
90
+
return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
91
+
}
92
+
return s.privateKey.Public(), nil
93
+
})
94
+
if err != nil {
95
+
s.logger.Error("error parsing jwt", "error", err)
96
+
return helpers.ExpiredTokenError(e)
97
+
}
98
+
99
+
if !token.Valid {
100
+
return helpers.InvalidTokenError(e)
101
+
}
102
+
} else {
103
+
kpts := strings.Split(tokenstr, ".")
104
+
signingInput := kpts[0] + "." + kpts[1]
105
+
hash := sha256.Sum256([]byte(signingInput))
106
+
sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
107
+
if err != nil {
108
+
s.logger.Error("error decoding signature bytes", "error", err)
109
+
return helpers.ServerError(e, nil)
110
+
}
111
+
112
+
if len(sigBytes) != 64 {
113
+
s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
114
+
return helpers.ServerError(e, nil)
115
+
}
116
+
117
+
rBytes := sigBytes[:32]
118
+
sBytes := sigBytes[32:]
119
+
rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
120
+
ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
121
+
122
+
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
123
+
if err != nil {
124
+
s.logger.Error("can't load private key", "error", err)
125
+
return err
126
+
}
127
+
128
+
pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
129
+
if !ok {
130
+
s.logger.Error("error getting public key from sk")
131
+
return helpers.ServerError(e, nil)
132
+
}
133
+
134
+
verified := pubKey.VerifyRaw(hash[:], rr, ss)
135
+
if !verified {
136
+
s.logger.Error("error verifying", "error", err)
137
+
return helpers.ServerError(e, nil)
138
+
}
139
+
}
140
+
141
+
isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
142
+
scope, _ := claims["scope"].(string)
143
+
144
+
if isRefresh && scope != "com.atproto.refresh" {
145
+
return helpers.InvalidTokenError(e)
146
+
} else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
147
+
return helpers.InvalidTokenError(e)
148
+
}
149
+
150
+
table := "tokens"
151
+
if isRefresh {
152
+
table = "refresh_tokens"
153
+
}
154
+
155
+
if isRefresh {
156
+
type Result struct {
157
+
Found bool
158
+
}
159
+
var result Result
160
+
if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
161
+
if err == gorm.ErrRecordNotFound {
162
+
return helpers.InvalidTokenError(e)
163
+
}
164
+
165
+
s.logger.Error("error getting token from db", "error", err)
166
+
return helpers.ServerError(e, nil)
167
+
}
168
+
169
+
if !result.Found {
170
+
return helpers.InvalidTokenError(e)
171
+
}
172
+
}
173
+
174
+
exp, ok := claims["exp"].(float64)
175
+
if !ok {
176
+
s.logger.Error("error getting iat from token")
177
+
return helpers.ServerError(e, nil)
178
+
}
179
+
180
+
if exp < float64(time.Now().UTC().Unix()) {
181
+
return helpers.ExpiredTokenError(e)
182
+
}
183
+
184
+
if repo == nil {
185
+
maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
186
+
if err != nil {
187
+
s.logger.Error("error fetching repo", "error", err)
188
+
return helpers.ServerError(e, nil)
189
+
}
190
+
repo = maybeRepo
191
+
did = repo.Repo.Did
192
+
}
193
+
194
+
e.Set("repo", repo)
195
+
e.Set("did", did)
196
+
e.Set("token", tokenstr)
197
+
198
+
if err := next(e); err != nil {
199
+
return helpers.InvalidTokenError(e)
200
+
}
201
+
202
+
return nil
203
+
}
204
+
}
205
+
206
+
func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
207
+
return func(e echo.Context) error {
208
+
authheader := e.Request().Header.Get("authorization")
209
+
if authheader == "" {
210
+
return e.JSON(401, map[string]string{"error": "Unauthorized"})
211
+
}
212
+
213
+
pts := strings.Split(authheader, " ")
214
+
if len(pts) != 2 {
215
+
return helpers.ServerError(e, nil)
216
+
}
217
+
218
+
if pts[0] != "DPoP" {
219
+
return next(e)
220
+
}
221
+
222
+
accessToken := pts[1]
223
+
224
+
nonce := s.oauthProvider.NextNonce()
225
+
if nonce != "" {
226
+
e.Response().Header().Set("DPoP-Nonce", nonce)
227
+
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
228
+
}
229
+
230
+
proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
231
+
if err != nil {
232
+
s.logger.Error("invalid dpop proof", "error", err)
233
+
return helpers.InputError(e, to.StringPtr(err.Error()))
234
+
}
235
+
236
+
var oauthToken provider.OauthToken
237
+
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
238
+
s.logger.Error("error finding access token in db", "error", err)
239
+
return helpers.InputError(e, nil)
240
+
}
241
+
242
+
if oauthToken.Token == "" {
243
+
return helpers.InvalidTokenError(e)
244
+
}
245
+
246
+
if *oauthToken.Parameters.DpopJkt != proof.JKT {
247
+
s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
248
+
return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
249
+
}
250
+
251
+
if time.Now().After(oauthToken.ExpiresAt) {
252
+
return helpers.ExpiredTokenError(e)
253
+
}
254
+
255
+
repo, err := s.getRepoActorByDid(oauthToken.Sub)
256
+
if err != nil {
257
+
s.logger.Error("could not find actor in db", "error", err)
258
+
return helpers.ServerError(e, nil)
259
+
}
260
+
261
+
e.Set("repo", repo)
262
+
e.Set("did", repo.Repo.Did)
263
+
e.Set("token", accessToken)
264
+
e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
265
+
266
+
return next(e)
267
+
}
268
+
}
+13
-13
server/repo.go
+13
-13
server/repo.go
···
16
16
"github.com/bluesky-social/indigo/events"
17
17
lexutil "github.com/bluesky-social/indigo/lex/util"
18
18
"github.com/bluesky-social/indigo/repo"
19
-
"github.com/bluesky-social/indigo/util"
20
-
"github.com/haileyok/cocoon/blockstore"
21
19
"github.com/haileyok/cocoon/internal/db"
22
20
"github.com/haileyok/cocoon/models"
21
+
"github.com/haileyok/cocoon/recording_blockstore"
23
22
blocks "github.com/ipfs/go-block-format"
24
23
"github.com/ipfs/go-cid"
25
24
cbor "github.com/ipfs/go-ipld-cbor"
···
103
102
return nil, err
104
103
}
105
104
106
-
dbs := blockstore.New(urepo.Did, rm.db)
105
+
dbs := rm.s.getBlockstore(urepo.Did)
106
+
bs := recording_blockstore.New(dbs)
107
107
r, err := repo.OpenRepo(context.TODO(), dbs, rootcid)
108
108
109
109
entries := []models.Record{}
···
274
274
}
275
275
}
276
276
277
-
for _, op := range dbs.GetLog() {
277
+
for _, op := range bs.GetLogMap() {
278
278
if _, err := carstore.LdWrite(buf, op.Cid().Bytes(), op.RawData()); err != nil {
279
279
return nil, err
280
280
}
···
318
318
Rev: rev,
319
319
Since: &urepo.Rev,
320
320
Commit: lexutil.LexLink(newroot),
321
-
Time: time.Now().Format(util.ISO8601),
321
+
Time: time.Now().Format(time.RFC3339Nano),
322
322
Ops: ops,
323
323
TooBig: false,
324
324
},
325
325
})
326
326
327
-
if err := dbs.UpdateRepo(context.TODO(), newroot, rev); err != nil {
327
+
if err := rm.s.UpdateRepo(context.TODO(), urepo.Did, newroot, rev); err != nil {
328
328
return nil, err
329
329
}
330
330
···
345
345
return cid.Undef, nil, err
346
346
}
347
347
348
-
dbs := blockstore.New(urepo.Did, rm.db)
349
-
bs := util.NewLoggingBstore(dbs)
348
+
dbs := rm.s.getBlockstore(urepo.Did)
349
+
bs := recording_blockstore.New(dbs)
350
350
351
351
r, err := repo.OpenRepo(context.TODO(), bs, c)
352
352
if err != nil {
···
358
358
return cid.Undef, nil, err
359
359
}
360
360
361
-
return c, bs.GetLoggedBlocks(), nil
361
+
return c, bs.GetLogArray(), nil
362
362
}
363
363
364
364
func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
···
414
414
return nil, fmt.Errorf("error unmarshaling cbor: %w", err)
415
415
}
416
416
417
-
var deepiter func(interface{}) error
418
-
deepiter = func(item interface{}) error {
417
+
var deepiter func(any) error
418
+
deepiter = func(item any) error {
419
419
switch val := item.(type) {
420
-
case map[string]interface{}:
420
+
case map[string]any:
421
421
if val["$type"] == "blob" {
422
422
if ref, ok := val["ref"].(string); ok {
423
423
c, err := cid.Parse(ref)
···
430
430
return deepiter(v)
431
431
}
432
432
}
433
-
case []interface{}:
433
+
case []any:
434
434
for _, v := range val {
435
435
deepiter(v)
436
436
}
+44
-283
server/server.go
+44
-283
server/server.go
···
4
4
"bytes"
5
5
"context"
6
6
"crypto/ecdsa"
7
-
"crypto/sha256"
8
7
"embed"
9
-
"encoding/base64"
10
8
"errors"
11
9
"fmt"
12
10
"io"
···
15
13
"net/smtp"
16
14
"os"
17
15
"path/filepath"
18
-
"strings"
19
16
"sync"
20
17
"text/template"
21
18
"time"
22
19
23
-
"github.com/Azure/go-autorest/autorest/to"
24
20
"github.com/aws/aws-sdk-go/aws"
25
21
"github.com/aws/aws-sdk-go/aws/credentials"
26
22
"github.com/aws/aws-sdk-go/aws/session"
···
32
28
"github.com/bluesky-social/indigo/xrpc"
33
29
"github.com/domodwyer/mailyak/v3"
34
30
"github.com/go-playground/validator"
35
-
"github.com/golang-jwt/jwt/v4"
36
31
"github.com/gorilla/sessions"
37
32
"github.com/haileyok/cocoon/identity"
38
33
"github.com/haileyok/cocoon/internal/db"
39
34
"github.com/haileyok/cocoon/internal/helpers"
40
35
"github.com/haileyok/cocoon/models"
41
-
"github.com/haileyok/cocoon/oauth/client_manager"
36
+
"github.com/haileyok/cocoon/oauth/client"
42
37
"github.com/haileyok/cocoon/oauth/constants"
43
-
"github.com/haileyok/cocoon/oauth/dpop/dpop_manager"
38
+
"github.com/haileyok/cocoon/oauth/dpop"
44
39
"github.com/haileyok/cocoon/oauth/provider"
45
40
"github.com/haileyok/cocoon/plc"
41
+
"github.com/ipfs/go-cid"
46
42
echo_session "github.com/labstack/echo-contrib/session"
47
43
"github.com/labstack/echo/v4"
48
44
"github.com/labstack/echo/v4/middleware"
49
45
slogecho "github.com/samber/slog-echo"
50
-
"gitlab.com/yawning/secp256k1-voi"
51
-
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
52
46
"gorm.io/driver/sqlite"
53
47
"gorm.io/gorm"
54
48
)
···
109
103
S3Config *S3Config
110
104
111
105
SessionSecret string
106
+
107
+
DefaultAtprotoProxy string
108
+
109
+
BlockstoreVariant BlockstoreVariant
112
110
}
113
111
114
112
type config struct {
115
-
Version string
116
-
Did string
117
-
Hostname string
118
-
ContactEmail string
119
-
EnforcePeering bool
120
-
Relays []string
121
-
AdminPassword string
122
-
SmtpEmail string
123
-
SmtpName string
113
+
Version string
114
+
Did string
115
+
Hostname string
116
+
ContactEmail string
117
+
EnforcePeering bool
118
+
Relays []string
119
+
AdminPassword string
120
+
SmtpEmail string
121
+
SmtpName string
122
+
DefaultAtprotoProxy string
123
+
BlockstoreVariant BlockstoreVariant
124
124
}
125
125
126
126
type CustomValidator struct {
···
197
197
return t.templates.ExecuteTemplate(w, name, data)
198
198
}
199
199
200
-
func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
201
-
return func(e echo.Context) error {
202
-
username, password, ok := e.Request().BasicAuth()
203
-
if !ok || username != "admin" || password != s.config.AdminPassword {
204
-
return helpers.InputError(e, to.StringPtr("Unauthorized"))
205
-
}
206
-
207
-
if err := next(e); err != nil {
208
-
e.Error(err)
209
-
}
210
-
211
-
return nil
212
-
}
213
-
}
214
-
215
-
func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
216
-
return func(e echo.Context) error {
217
-
authheader := e.Request().Header.Get("authorization")
218
-
if authheader == "" {
219
-
return e.JSON(401, map[string]string{"error": "Unauthorized"})
220
-
}
221
-
222
-
pts := strings.Split(authheader, " ")
223
-
if len(pts) != 2 {
224
-
return helpers.ServerError(e, nil)
225
-
}
226
-
227
-
// move on to oauth session middleware if this is a dpop token
228
-
if pts[0] == "DPoP" {
229
-
return next(e)
230
-
}
231
-
232
-
tokenstr := pts[1]
233
-
token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
234
-
claims, ok := token.Claims.(jwt.MapClaims)
235
-
if !ok {
236
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
237
-
}
238
-
239
-
var did string
240
-
var repo *models.RepoActor
241
-
242
-
// service auth tokens
243
-
lxm, hasLxm := claims["lxm"]
244
-
if hasLxm {
245
-
pts := strings.Split(e.Request().URL.String(), "/")
246
-
if lxm != pts[len(pts)-1] {
247
-
s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
248
-
return helpers.InputError(e, nil)
249
-
}
250
-
251
-
maybeDid, ok := claims["iss"].(string)
252
-
if !ok {
253
-
s.logger.Error("no iss in service auth token", "error", err)
254
-
return helpers.InputError(e, nil)
255
-
}
256
-
did = maybeDid
257
-
258
-
maybeRepo, err := s.getRepoActorByDid(did)
259
-
if err != nil {
260
-
s.logger.Error("error fetching repo", "error", err)
261
-
return helpers.ServerError(e, nil)
262
-
}
263
-
repo = maybeRepo
264
-
}
265
-
266
-
if token.Header["alg"] != "ES256K" {
267
-
token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
268
-
if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
269
-
return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
270
-
}
271
-
return s.privateKey.Public(), nil
272
-
})
273
-
if err != nil {
274
-
s.logger.Error("error parsing jwt", "error", err)
275
-
// NOTE: https://github.com/bluesky-social/atproto/discussions/3319
276
-
return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
277
-
}
278
-
279
-
if !token.Valid {
280
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
281
-
}
282
-
} else {
283
-
kpts := strings.Split(tokenstr, ".")
284
-
signingInput := kpts[0] + "." + kpts[1]
285
-
hash := sha256.Sum256([]byte(signingInput))
286
-
sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
287
-
if err != nil {
288
-
s.logger.Error("error decoding signature bytes", "error", err)
289
-
return helpers.ServerError(e, nil)
290
-
}
291
-
292
-
if len(sigBytes) != 64 {
293
-
s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
294
-
return helpers.ServerError(e, nil)
295
-
}
296
-
297
-
rBytes := sigBytes[:32]
298
-
sBytes := sigBytes[32:]
299
-
rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
300
-
ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
301
-
302
-
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
303
-
if err != nil {
304
-
s.logger.Error("can't load private key", "error", err)
305
-
return err
306
-
}
307
-
308
-
pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
309
-
if !ok {
310
-
s.logger.Error("error getting public key from sk")
311
-
return helpers.ServerError(e, nil)
312
-
}
313
-
314
-
verified := pubKey.VerifyRaw(hash[:], rr, ss)
315
-
if !verified {
316
-
s.logger.Error("error verifying", "error", err)
317
-
return helpers.ServerError(e, nil)
318
-
}
319
-
}
320
-
321
-
isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
322
-
scope, _ := claims["scope"].(string)
323
-
324
-
if isRefresh && scope != "com.atproto.refresh" {
325
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
326
-
} else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
327
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
328
-
}
329
-
330
-
table := "tokens"
331
-
if isRefresh {
332
-
table = "refresh_tokens"
333
-
}
334
-
335
-
if isRefresh {
336
-
type Result struct {
337
-
Found bool
338
-
}
339
-
var result Result
340
-
if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
341
-
if err == gorm.ErrRecordNotFound {
342
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
343
-
}
344
-
345
-
s.logger.Error("error getting token from db", "error", err)
346
-
return helpers.ServerError(e, nil)
347
-
}
348
-
349
-
if !result.Found {
350
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
351
-
}
352
-
}
353
-
354
-
exp, ok := claims["exp"].(float64)
355
-
if !ok {
356
-
s.logger.Error("error getting iat from token")
357
-
return helpers.ServerError(e, nil)
358
-
}
359
-
360
-
if exp < float64(time.Now().UTC().Unix()) {
361
-
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
362
-
}
363
-
364
-
if repo == nil {
365
-
maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
366
-
if err != nil {
367
-
s.logger.Error("error fetching repo", "error", err)
368
-
return helpers.ServerError(e, nil)
369
-
}
370
-
repo = maybeRepo
371
-
did = repo.Repo.Did
372
-
}
373
-
374
-
e.Set("repo", repo)
375
-
e.Set("did", did)
376
-
e.Set("token", tokenstr)
377
-
378
-
if err := next(e); err != nil {
379
-
e.Error(err)
380
-
}
381
-
382
-
return nil
383
-
}
384
-
}
385
-
386
-
func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
387
-
return func(e echo.Context) error {
388
-
authheader := e.Request().Header.Get("authorization")
389
-
if authheader == "" {
390
-
return e.JSON(401, map[string]string{"error": "Unauthorized"})
391
-
}
392
-
393
-
pts := strings.Split(authheader, " ")
394
-
if len(pts) != 2 {
395
-
return helpers.ServerError(e, nil)
396
-
}
397
-
398
-
if pts[0] != "DPoP" {
399
-
return next(e)
400
-
}
401
-
402
-
accessToken := pts[1]
403
-
404
-
nonce := s.oauthProvider.NextNonce()
405
-
if nonce != "" {
406
-
e.Response().Header().Set("DPoP-Nonce", nonce)
407
-
e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
408
-
}
409
-
410
-
proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
411
-
if err != nil {
412
-
s.logger.Error("invalid dpop proof", "error", err)
413
-
return helpers.InputError(e, to.StringPtr(err.Error()))
414
-
}
415
-
416
-
var oauthToken provider.OauthToken
417
-
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
418
-
s.logger.Error("error finding access token in db", "error", err)
419
-
return helpers.InputError(e, nil)
420
-
}
421
-
422
-
if oauthToken.Token == "" {
423
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
424
-
}
425
-
426
-
if *oauthToken.Parameters.DpopJkt != proof.JKT {
427
-
s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
428
-
return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
429
-
}
430
-
431
-
if time.Now().After(oauthToken.ExpiresAt) {
432
-
return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
433
-
}
434
-
435
-
repo, err := s.getRepoActorByDid(oauthToken.Sub)
436
-
if err != nil {
437
-
s.logger.Error("could not find actor in db", "error", err)
438
-
return helpers.ServerError(e, nil)
439
-
}
440
-
441
-
e.Set("repo", repo)
442
-
e.Set("did", repo.Repo.Did)
443
-
e.Set("token", accessToken)
444
-
e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
445
-
446
-
return next(e)
447
-
}
448
-
}
449
-
450
200
func New(args *Args) (*Server, error) {
451
201
if args.Addr == "" {
452
202
return nil, fmt.Errorf("addr must be set")
···
593
343
plcClient: plcClient,
594
344
privateKey: &pkey,
595
345
config: &config{
596
-
Version: args.Version,
597
-
Did: args.Did,
598
-
Hostname: args.Hostname,
599
-
ContactEmail: args.ContactEmail,
600
-
EnforcePeering: false,
601
-
Relays: args.Relays,
602
-
AdminPassword: args.AdminPassword,
603
-
SmtpName: args.SmtpName,
604
-
SmtpEmail: args.SmtpEmail,
346
+
Version: args.Version,
347
+
Did: args.Did,
348
+
Hostname: args.Hostname,
349
+
ContactEmail: args.ContactEmail,
350
+
EnforcePeering: false,
351
+
Relays: args.Relays,
352
+
AdminPassword: args.AdminPassword,
353
+
SmtpName: args.SmtpName,
354
+
SmtpEmail: args.SmtpEmail,
355
+
DefaultAtprotoProxy: args.DefaultAtprotoProxy,
356
+
BlockstoreVariant: args.BlockstoreVariant,
605
357
},
606
358
evtman: events.NewEventManager(events.NewMemPersister()),
607
359
passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
···
611
363
612
364
oauthProvider: provider.NewProvider(provider.Args{
613
365
Hostname: args.Hostname,
614
-
ClientManagerArgs: client_manager.Args{
366
+
ClientManagerArgs: client.ManagerArgs{
615
367
Cli: oauthCli,
616
368
Logger: args.Logger,
617
369
},
618
-
DpopManagerArgs: dpop_manager.Args{
370
+
DpopManagerArgs: dpop.ManagerArgs{
619
371
NonceSecret: nonceSecret,
620
372
NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
621
373
OnNonceSecretCreated: func(newNonce []byte) {
···
712
464
s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
713
465
s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
714
466
s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
467
+
s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
715
468
716
469
// repo
717
470
s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
···
725
478
s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
726
479
s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
727
480
728
-
// are there any routes that we should be allowing without auth? i dont think so but idk
729
-
s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
730
-
s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
731
-
732
481
// admin routes
733
482
s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
734
483
s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
484
+
485
+
// are there any routes that we should be allowing without auth? i dont think so but idk
486
+
s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
487
+
s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
735
488
}
736
489
737
490
func (s *Server) Serve(ctx context.Context) error {
···
893
646
go s.doBackup()
894
647
}
895
648
}
649
+
650
+
func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
651
+
if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
652
+
return err
653
+
}
654
+
655
+
return nil
656
+
}
+5
-4
server/templates/account.html
+5
-4
server/templates/account.html
···
24
24
</div>
25
25
{{ else }} {{ range .Tokens }}
26
26
<div class="base-container">
27
-
<h4>{{ .ClientId }}</h4>
28
-
<p>Created: {{ .CreatedAt }}</p>
29
-
<p>Updated: {{ .UpdatedAt }}</p>
30
-
<p>Expires: {{ .ExpiresAt }}</p>
27
+
<h4>{{ .ClientName }}</h4>
28
+
<p>Session Age: {{ .Age}}</p>
29
+
<p>Last Updated: {{ .LastUpdated }} ago</p>
30
+
<p>Expires In: {{ .ExpiresIn }}</p>
31
+
<p>IP Address: {{ .Ip }}</p>
31
32
<form action="/account/revoke" method="post">
32
33
<input type="hidden" name="token" value="{{ .Token }}" />
33
34
<button type="submit" value="">Revoke</button>
+155
sqlite_blockstore/sqlite_blockstore.go
+155
sqlite_blockstore/sqlite_blockstore.go
···
1
+
package sqlite_blockstore
2
+
3
+
import (
4
+
"context"
5
+
"fmt"
6
+
7
+
"github.com/bluesky-social/indigo/atproto/syntax"
8
+
"github.com/haileyok/cocoon/internal/db"
9
+
"github.com/haileyok/cocoon/models"
10
+
blocks "github.com/ipfs/go-block-format"
11
+
"github.com/ipfs/go-cid"
12
+
"gorm.io/gorm/clause"
13
+
)
14
+
15
+
type SqliteBlockstore struct {
16
+
db *db.DB
17
+
did string
18
+
readonly bool
19
+
inserts map[cid.Cid]blocks.Block
20
+
}
21
+
22
+
func New(did string, db *db.DB) *SqliteBlockstore {
23
+
return &SqliteBlockstore{
24
+
did: did,
25
+
db: db,
26
+
readonly: false,
27
+
inserts: map[cid.Cid]blocks.Block{},
28
+
}
29
+
}
30
+
31
+
func NewReadOnly(did string, db *db.DB) *SqliteBlockstore {
32
+
return &SqliteBlockstore{
33
+
did: did,
34
+
db: db,
35
+
readonly: true,
36
+
inserts: map[cid.Cid]blocks.Block{},
37
+
}
38
+
}
39
+
40
+
func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, error) {
41
+
var block models.Block
42
+
43
+
maybeBlock, ok := bs.inserts[cid]
44
+
if ok {
45
+
return maybeBlock, nil
46
+
}
47
+
48
+
if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil {
49
+
return nil, err
50
+
}
51
+
52
+
b, err := blocks.NewBlockWithCid(block.Value, cid)
53
+
if err != nil {
54
+
return nil, err
55
+
}
56
+
57
+
return b, nil
58
+
}
59
+
60
+
func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error {
61
+
bs.inserts[block.Cid()] = block
62
+
63
+
if bs.readonly {
64
+
return nil
65
+
}
66
+
67
+
b := models.Block{
68
+
Did: bs.did,
69
+
Cid: block.Cid().Bytes(),
70
+
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
71
+
Value: block.RawData(),
72
+
}
73
+
74
+
if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{
75
+
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
76
+
UpdateAll: true,
77
+
}}).Error; err != nil {
78
+
return err
79
+
}
80
+
81
+
return nil
82
+
}
83
+
84
+
func (bs *SqliteBlockstore) DeleteBlock(context.Context, cid.Cid) error {
85
+
panic("not implemented")
86
+
}
87
+
88
+
func (bs *SqliteBlockstore) Has(context.Context, cid.Cid) (bool, error) {
89
+
panic("not implemented")
90
+
}
91
+
92
+
func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) {
93
+
panic("not implemented")
94
+
}
95
+
96
+
func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error {
97
+
tx := bs.db.BeginDangerously()
98
+
99
+
for _, block := range blocks {
100
+
bs.inserts[block.Cid()] = block
101
+
102
+
if bs.readonly {
103
+
continue
104
+
}
105
+
106
+
b := models.Block{
107
+
Did: bs.did,
108
+
Cid: block.Cid().Bytes(),
109
+
Rev: syntax.NewTIDNow(0).String(), // TODO: WARN, this is bad. don't do this
110
+
Value: block.RawData(),
111
+
}
112
+
113
+
if err := tx.Clauses(clause.OnConflict{
114
+
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
115
+
UpdateAll: true,
116
+
}).Create(&b).Error; err != nil {
117
+
tx.Rollback()
118
+
return err
119
+
}
120
+
}
121
+
122
+
if bs.readonly {
123
+
return nil
124
+
}
125
+
126
+
tx.Commit()
127
+
128
+
return nil
129
+
}
130
+
131
+
func (bs *SqliteBlockstore) AllKeysChan(ctx context.Context) (<-chan cid.Cid, error) {
132
+
panic("not implemented")
133
+
}
134
+
135
+
func (bs *SqliteBlockstore) HashOnRead(enabled bool) {
136
+
panic("not implemented")
137
+
}
138
+
139
+
func (bs *SqliteBlockstore) Execute(ctx context.Context) error {
140
+
if !bs.readonly {
141
+
return fmt.Errorf("blockstore was not readonly")
142
+
}
143
+
144
+
bs.readonly = false
145
+
for _, b := range bs.inserts {
146
+
bs.Put(ctx, b)
147
+
}
148
+
bs.readonly = true
149
+
150
+
return nil
151
+
}
152
+
153
+
func (bs *SqliteBlockstore) GetLog() map[cid.Cid]blocks.Block {
154
+
return bs.inserts
155
+
}