An atproto PDS written in Go

refactor: cleanup oauth package

+2 -2
oauth/client.go oauth/client/client.go
··· 1 - package oauth 2 3 import "github.com/lestrrat-go/jwx/v2/jwk" 4 5 type Client struct { 6 - Metadata *ClientMetadata 7 JWKS jwk.Key 8 }
··· 1 + package client 2 3 import "github.com/lestrrat-go/jwx/v2/jwk" 4 5 type Client struct { 6 + Metadata *Metadata 7 JWKS jwk.Key 8 }
+13 -14
oauth/client_manager/client_manager.go oauth/client/manager.go
··· 1 - package client_manager 2 3 import ( 4 "context" ··· 15 16 cache "github.com/go-pkgz/expirable-cache/v3" 17 "github.com/haileyok/cocoon/internal/helpers" 18 - "github.com/haileyok/cocoon/oauth" 19 "github.com/lestrrat-go/jwx/v2/jwk" 20 ) 21 22 - type ClientManager struct { 23 cli *http.Client 24 logger *slog.Logger 25 jwksCache cache.Cache[string, jwk.Key] 26 - metadataCache cache.Cache[string, oauth.ClientMetadata] 27 } 28 29 - type Args struct { 30 Cli *http.Client 31 Logger *slog.Logger 32 } 33 34 - func New(args Args) *ClientManager { 35 if args.Logger == nil { 36 args.Logger = slog.Default() 37 } ··· 41 } 42 43 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 - metadataCache := cache.NewCache[string, oauth.ClientMetadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 45 46 - return &ClientManager{ 47 cli: args.Cli, 48 logger: args.Logger, 49 jwksCache: jwksCache, ··· 51 } 52 } 53 54 - func (cm *ClientManager) GetClient(ctx context.Context, clientId string) (*oauth.Client, error) { 55 metadata, err := cm.getClientMetadata(ctx, clientId) 56 if err != nil { 57 return nil, err ··· 75 jwks = maybeJwks 76 } 77 78 - return &oauth.Client{ 79 Metadata: metadata, 80 JWKS: jwks, 81 }, nil 82 } 83 84 - func (cm *ClientManager) getClientMetadata(ctx context.Context, clientId string) (*oauth.ClientMetadata, error) { 85 metadataCached, ok := cm.metadataCache.Get(clientId) 86 if !ok { 87 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) ··· 116 } 117 } 118 119 - func (cm *ClientManager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 120 jwks, ok := cm.jwksCache.Get(clientId) 121 if !ok { 122 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) ··· 165 return jwks, nil 166 } 167 168 - func validateAndParseMetadata(clientId string, b []byte) (*oauth.ClientMetadata, error) { 169 var metadataMap map[string]any 170 if err := json.Unmarshal(b, &metadataMap); err != nil { 171 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) ··· 192 } 193 } 194 195 - var metadata oauth.ClientMetadata 196 if err := json.Unmarshal(b, &metadata); err != nil { 197 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 198 }
··· 1 + package client 2 3 import ( 4 "context" ··· 15 16 cache "github.com/go-pkgz/expirable-cache/v3" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/lestrrat-go/jwx/v2/jwk" 19 ) 20 21 + type Manager struct { 22 cli *http.Client 23 logger *slog.Logger 24 jwksCache cache.Cache[string, jwk.Key] 25 + metadataCache cache.Cache[string, Metadata] 26 } 27 28 + type ManagerArgs struct { 29 Cli *http.Client 30 Logger *slog.Logger 31 } 32 33 + func NewManager(args ManagerArgs) *Manager { 34 if args.Logger == nil { 35 args.Logger = slog.Default() 36 } ··· 40 } 41 42 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 43 + metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 45 + return &Manager{ 46 cli: args.Cli, 47 logger: args.Logger, 48 jwksCache: jwksCache, ··· 50 } 51 } 52 53 + func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) { 54 metadata, err := cm.getClientMetadata(ctx, clientId) 55 if err != nil { 56 return nil, err ··· 74 jwks = maybeJwks 75 } 76 77 + return &Client{ 78 Metadata: metadata, 79 JWKS: jwks, 80 }, nil 81 } 82 83 + func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) { 84 metadataCached, ok := cm.metadataCache.Get(clientId) 85 if !ok { 86 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) ··· 115 } 116 } 117 118 + func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 119 jwks, ok := cm.jwksCache.Get(clientId) 120 if !ok { 121 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) ··· 164 return jwks, nil 165 } 166 167 + func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) { 168 var metadataMap map[string]any 169 if err := json.Unmarshal(b, &metadataMap); err != nil { 170 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) ··· 191 } 192 } 193 194 + var metadata Metadata 195 if err := json.Unmarshal(b, &metadata); err != nil { 196 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 197 }
+2 -2
oauth/client_metadata.go oauth/client/metadata.go
··· 1 - package oauth 2 3 - type ClientMetadata struct { 4 ClientID string `json:"client_id"` 5 ClientName string `json:"client_name"` 6 ClientURI string `json:"client_uri"`
··· 1 + package client 2 3 + type Metadata struct { 4 ClientID string `json:"client_id"` 5 ClientName string `json:"client_name"` 6 ClientURI string `json:"client_uri"`
+10 -12
oauth/dpop/dpop_manager/dpop_manager.go oauth/dpop/manager.go
··· 1 - package dpop_manager 2 3 import ( 4 "crypto" ··· 16 "github.com/golang-jwt/jwt/v4" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/haileyok/cocoon/oauth/constants" 19 - "github.com/haileyok/cocoon/oauth/dpop" 20 - "github.com/haileyok/cocoon/oauth/dpop/nonce" 21 "github.com/lestrrat-go/jwx/v2/jwa" 22 "github.com/lestrrat-go/jwx/v2/jwk" 23 ) 24 25 - type DpopManager struct { 26 - nonce *nonce.Nonce 27 jtiCache *jtiCache 28 logger *slog.Logger 29 hostname string 30 } 31 32 - type Args struct { 33 NonceSecret []byte 34 NonceRotationInterval time.Duration 35 OnNonceSecretCreated func([]byte) ··· 38 Hostname string 39 } 40 41 - func New(args Args) *DpopManager { 42 if args.Logger == nil { 43 args.Logger = slog.Default() 44 } ··· 51 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 52 } 53 54 - return &DpopManager{ 55 - nonce: nonce.NewNonce(nonce.Args{ 56 RotationInterval: args.NonceRotationInterval, 57 Secret: args.NonceSecret, 58 OnSecretCreated: args.OnNonceSecretCreated, ··· 63 } 64 } 65 66 - func (dm *DpopManager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*dpop.Proof, error) { 67 if reqMethod == "" { 68 return nil, errors.New("HTTP method is required") 69 } ··· 226 227 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 228 229 - return &dpop.Proof{ 230 JTI: jti, 231 JKT: thumb, 232 HTM: htm, ··· 246 } 247 } 248 249 - func (dm *DpopManager) NextNonce() string { 250 return dm.nonce.NextNonce() 251 }
··· 1 + package dpop 2 3 import ( 4 "crypto" ··· 16 "github.com/golang-jwt/jwt/v4" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/haileyok/cocoon/oauth/constants" 19 "github.com/lestrrat-go/jwx/v2/jwa" 20 "github.com/lestrrat-go/jwx/v2/jwk" 21 ) 22 23 + type Manager struct { 24 + nonce *Nonce 25 jtiCache *jtiCache 26 logger *slog.Logger 27 hostname string 28 } 29 30 + type ManagerArgs struct { 31 NonceSecret []byte 32 NonceRotationInterval time.Duration 33 OnNonceSecretCreated func([]byte) ··· 36 Hostname string 37 } 38 39 + func NewManager(args ManagerArgs) *Manager { 40 if args.Logger == nil { 41 args.Logger = slog.Default() 42 } ··· 49 args.Logger.Warn("nonce secret passed to dpop manager was nil. existing sessions may break. consider saving and restoring your nonce.") 50 } 51 52 + return &Manager{ 53 + nonce: NewNonce(NonceArgs{ 54 RotationInterval: args.NonceRotationInterval, 55 Secret: args.NonceSecret, 56 OnSecretCreated: args.OnNonceSecretCreated, ··· 61 } 62 } 63 64 + func (dm *Manager) CheckProof(reqMethod, reqUrl string, headers http.Header, accessToken *string) (*Proof, error) { 65 if reqMethod == "" { 66 return nil, errors.New("HTTP method is required") 67 } ··· 224 225 thumb := base64.RawURLEncoding.EncodeToString(thumbBytes) 226 227 + return &Proof{ 228 JTI: jti, 229 JKT: thumb, 230 HTM: htm, ··· 244 } 245 } 246 247 + func (dm *Manager) NextNonce() string { 248 return dm.nonce.NextNonce() 249 }
+1 -1
oauth/dpop/dpop_manager/jti_cache.go oauth/dpop/jti_cache.go
··· 1 - package dpop_manager 2 3 import ( 4 "sync"
··· 1 + package dpop 2 3 import ( 4 "sync"
+3 -3
oauth/dpop/nonce/nonce.go oauth/dpop/nonce.go
··· 1 - package nonce 2 3 import ( 4 "crypto/hmac" ··· 24 next string 25 } 26 27 - type Args struct { 28 RotationInterval time.Duration 29 Secret []byte 30 OnSecretCreated func([]byte) 31 } 32 33 - func NewNonce(args Args) *Nonce { 34 if args.RotationInterval == 0 { 35 args.RotationInterval = constants.NonceMaxRotationInterval / 3 36 }
··· 1 + package dpop 2 3 import ( 4 "crypto/hmac" ··· 24 next string 25 } 26 27 + type NonceArgs struct { 28 RotationInterval time.Duration 29 Secret []byte 30 OnSecretCreated func([]byte) 31 } 32 33 + func NewNonce(args NonceArgs) *Nonce { 34 if args.RotationInterval == 0 { 35 args.RotationInterval = constants.NonceMaxRotationInterval / 3 36 }
+3 -26
oauth/provider/client_auth.go
··· 3 import ( 4 "context" 5 "crypto" 6 - "database/sql/driver" 7 "encoding/base64" 8 - "encoding/json" 9 "errors" 10 "fmt" 11 "time" 12 13 "github.com/golang-jwt/jwt/v4" 14 - "github.com/haileyok/cocoon/oauth" 15 "github.com/haileyok/cocoon/oauth/constants" 16 "github.com/haileyok/cocoon/oauth/dpop" 17 ) 18 19 - type ClientAuth struct { 20 - Method string 21 - Alg string 22 - Kid string 23 - Jkt string 24 - Jti string 25 - Exp *float64 26 - } 27 - 28 - func (ca *ClientAuth) Scan(value any) error { 29 - b, ok := value.([]byte) 30 - if !ok { 31 - return fmt.Errorf("failed to unmarshal OauthParRequest value") 32 - } 33 - return json.Unmarshal(b, ca) 34 - } 35 - 36 - func (ca ClientAuth) Value() (driver.Value, error) { 37 - return json.Marshal(ca) 38 - } 39 - 40 type AuthenticateClientOptions struct { 41 AllowMissingDpopProof bool 42 } ··· 47 ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"` 48 } 49 50 - func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*oauth.Client, *ClientAuth, error) { 51 client, err := p.ClientManager.GetClient(ctx, req.ClientID) 52 if err != nil { 53 return nil, nil, fmt.Errorf("failed to get client: %w", err) ··· 69 return client, clientAuth, nil 70 } 71 72 - func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *oauth.Client) (*ClientAuth, error) { 73 metadata := client.Metadata 74 75 if metadata.TokenEndpointAuthMethod == "none" {
··· 3 import ( 4 "context" 5 "crypto" 6 "encoding/base64" 7 "errors" 8 "fmt" 9 "time" 10 11 "github.com/golang-jwt/jwt/v4" 12 + "github.com/haileyok/cocoon/oauth/client" 13 "github.com/haileyok/cocoon/oauth/constants" 14 "github.com/haileyok/cocoon/oauth/dpop" 15 ) 16 17 type AuthenticateClientOptions struct { 18 AllowMissingDpopProof bool 19 } ··· 24 ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"` 25 } 26 27 + func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) { 28 client, err := p.ClientManager.GetClient(ctx, req.ClientID) 29 if err != nil { 30 return nil, nil, fmt.Errorf("failed to get client: %w", err) ··· 46 return client, clientAuth, nil 47 } 48 49 + func (p *Provider) Authenticate(_ context.Context, req AuthenticateClientRequestBase, client *client.Client) (*ClientAuth, error) { 50 metadata := client.Metadata 51 52 if metadata.TokenEndpointAuthMethod == "none" {
+81
oauth/provider/models.go
···
··· 1 + package provider 2 + 3 + import ( 4 + "database/sql/driver" 5 + "encoding/json" 6 + "fmt" 7 + "time" 8 + 9 + "gorm.io/gorm" 10 + ) 11 + 12 + type ClientAuth struct { 13 + Method string 14 + Alg string 15 + Kid string 16 + Jkt string 17 + Jti string 18 + Exp *float64 19 + } 20 + 21 + func (ca *ClientAuth) Scan(value any) error { 22 + b, ok := value.([]byte) 23 + if !ok { 24 + return fmt.Errorf("failed to unmarshal OauthParRequest value") 25 + } 26 + return json.Unmarshal(b, ca) 27 + } 28 + 29 + func (ca ClientAuth) Value() (driver.Value, error) { 30 + return json.Marshal(ca) 31 + } 32 + 33 + type ParRequest struct { 34 + AuthenticateClientRequestBase 35 + ResponseType string `form:"response_type" json:"response_type" validate:"required"` 36 + CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"` 37 + CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"` 38 + State string `form:"state" json:"state" validate:"required"` 39 + RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"` 40 + Scope string `form:"scope" json:"scope" validate:"required"` 41 + LoginHint *string `form:"login_hint" json:"login_hint,omitempty"` 42 + DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"` 43 + } 44 + 45 + func (opr *ParRequest) Scan(value any) error { 46 + b, ok := value.([]byte) 47 + if !ok { 48 + return fmt.Errorf("failed to unmarshal OauthParRequest value") 49 + } 50 + return json.Unmarshal(b, opr) 51 + } 52 + 53 + func (opr ParRequest) Value() (driver.Value, error) { 54 + return json.Marshal(opr) 55 + } 56 + 57 + type OauthToken struct { 58 + gorm.Model 59 + ClientId string `gorm:"index"` 60 + ClientAuth ClientAuth `gorm:"type:json"` 61 + Parameters ParRequest `gorm:"type:json"` 62 + ExpiresAt time.Time `gorm:"index"` 63 + DeviceId string 64 + Sub string `gorm:"index"` 65 + Code string `gorm:"index"` 66 + Token string `gorm:"uniqueIndex"` 67 + RefreshToken string `gorm:"uniqueIndex"` 68 + } 69 + 70 + type OauthAuthorizationRequest struct { 71 + gorm.Model 72 + RequestId string `gorm:"primaryKey"` 73 + ClientId string `gorm:"index"` 74 + ClientAuth ClientAuth `gorm:"type:json"` 75 + Parameters ParRequest `gorm:"type:json"` 76 + ExpiresAt time.Time `gorm:"index"` 77 + DeviceId *string 78 + Sub *string 79 + Code *string 80 + Accepted *bool 81 + }
+8 -64
oauth/provider/provider.go
··· 1 package provider 2 3 import ( 4 - "database/sql/driver" 5 - "encoding/json" 6 - "fmt" 7 - "time" 8 - 9 - "github.com/haileyok/cocoon/oauth/client_manager" 10 - "github.com/haileyok/cocoon/oauth/dpop/dpop_manager" 11 - "gorm.io/gorm" 12 ) 13 14 type Provider struct { 15 - ClientManager *client_manager.ClientManager 16 - DpopManager *dpop_manager.DpopManager 17 18 hostname string 19 } 20 21 type Args struct { 22 Hostname string 23 - ClientManagerArgs client_manager.Args 24 - DpopManagerArgs dpop_manager.Args 25 } 26 27 func NewProvider(args Args) *Provider { 28 return &Provider{ 29 - ClientManager: client_manager.New(args.ClientManagerArgs), 30 - DpopManager: dpop_manager.New(args.DpopManagerArgs), 31 hostname: args.Hostname, 32 } 33 } ··· 35 func (p *Provider) NextNonce() string { 36 return p.DpopManager.NextNonce() 37 } 38 - 39 - type ParRequest struct { 40 - AuthenticateClientRequestBase 41 - ResponseType string `form:"response_type" json:"response_type" validate:"required"` 42 - CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"` 43 - CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"` 44 - State string `form:"state" json:"state" validate:"required"` 45 - RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"` 46 - Scope string `form:"scope" json:"scope" validate:"required"` 47 - LoginHint *string `form:"login_hint" json:"login_hint,omitempty"` 48 - DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"` 49 - } 50 - 51 - func (opr *ParRequest) Scan(value any) error { 52 - b, ok := value.([]byte) 53 - if !ok { 54 - return fmt.Errorf("failed to unmarshal OauthParRequest value") 55 - } 56 - return json.Unmarshal(b, opr) 57 - } 58 - 59 - func (opr ParRequest) Value() (driver.Value, error) { 60 - return json.Marshal(opr) 61 - } 62 - 63 - type OauthToken struct { 64 - gorm.Model 65 - ClientId string `gorm:"index"` 66 - ClientAuth ClientAuth `gorm:"type:json"` 67 - Parameters ParRequest `gorm:"type:json"` 68 - ExpiresAt time.Time `gorm:"index"` 69 - DeviceId string 70 - Sub string `gorm:"index"` 71 - Code string `gorm:"index"` 72 - Token string `gorm:"uniqueIndex"` 73 - RefreshToken string `gorm:"uniqueIndex"` 74 - } 75 - 76 - type OauthAuthorizationRequest struct { 77 - gorm.Model 78 - RequestId string `gorm:"primaryKey"` 79 - ClientId string `gorm:"index"` 80 - ClientAuth ClientAuth `gorm:"type:json"` 81 - Parameters ParRequest `gorm:"type:json"` 82 - ExpiresAt time.Time `gorm:"index"` 83 - DeviceId *string 84 - Sub *string 85 - Code *string 86 - Accepted *bool 87 - }
··· 1 package provider 2 3 import ( 4 + "github.com/haileyok/cocoon/oauth/client" 5 + "github.com/haileyok/cocoon/oauth/dpop" 6 ) 7 8 type Provider struct { 9 + ClientManager *client.Manager 10 + DpopManager *dpop.Manager 11 12 hostname string 13 } 14 15 type Args struct { 16 Hostname string 17 + ClientManagerArgs client.ManagerArgs 18 + DpopManagerArgs dpop.ManagerArgs 19 } 20 21 func NewProvider(args Args) *Provider { 22 return &Provider{ 23 + ClientManager: client.NewManager(args.ClientManagerArgs), 24 + DpopManager: dpop.NewManager(args.DpopManagerArgs), 25 hostname: args.Hostname, 26 } 27 } ··· 29 func (p *Provider) NextNonce() string { 30 return p.DpopManager.NextNonce() 31 }
+4 -4
server/server.go
··· 38 "github.com/haileyok/cocoon/internal/db" 39 "github.com/haileyok/cocoon/internal/helpers" 40 "github.com/haileyok/cocoon/models" 41 - "github.com/haileyok/cocoon/oauth/client_manager" 42 "github.com/haileyok/cocoon/oauth/constants" 43 - "github.com/haileyok/cocoon/oauth/dpop/dpop_manager" 44 "github.com/haileyok/cocoon/oauth/provider" 45 "github.com/haileyok/cocoon/plc" 46 echo_session "github.com/labstack/echo-contrib/session" ··· 611 612 oauthProvider: provider.NewProvider(provider.Args{ 613 Hostname: args.Hostname, 614 - ClientManagerArgs: client_manager.Args{ 615 Cli: oauthCli, 616 Logger: args.Logger, 617 }, 618 - DpopManagerArgs: dpop_manager.Args{ 619 NonceSecret: nonceSecret, 620 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 621 OnNonceSecretCreated: func(newNonce []byte) {
··· 38 "github.com/haileyok/cocoon/internal/db" 39 "github.com/haileyok/cocoon/internal/helpers" 40 "github.com/haileyok/cocoon/models" 41 + "github.com/haileyok/cocoon/oauth/client" 42 "github.com/haileyok/cocoon/oauth/constants" 43 + "github.com/haileyok/cocoon/oauth/dpop" 44 "github.com/haileyok/cocoon/oauth/provider" 45 "github.com/haileyok/cocoon/plc" 46 echo_session "github.com/labstack/echo-contrib/session" ··· 611 612 oauthProvider: provider.NewProvider(provider.Args{ 613 Hostname: args.Hostname, 614 + ClientManagerArgs: client.ManagerArgs{ 615 Cli: oauthCli, 616 Logger: args.Logger, 617 }, 618 + DpopManagerArgs: dpop.ManagerArgs{ 619 NonceSecret: nonceSecret, 620 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 621 OnNonceSecretCreated: func(newNonce []byte) {