1package server
2
3import (
4 "errors"
5 "strings"
6
7 "github.com/Azure/go-autorest/autorest/to"
8 "github.com/bluesky-social/indigo/atproto/syntax"
9 "github.com/haileyok/cocoon/internal/helpers"
10 "github.com/haileyok/cocoon/models"
11 "github.com/labstack/echo/v4"
12 "golang.org/x/crypto/bcrypt"
13 "gorm.io/gorm"
14)
15
// ComAtprotoServerCreateSessionRequest is the JSON request body accepted by
// the com.atproto.server.createSession handler.
//
// Identifier may be an account DID, handle, or email address. Identifier and
// Password are marked as required for the request validator. AuthFactorToken
// is optional and is left out of encoded JSON when nil.
type ComAtprotoServerCreateSessionRequest struct {
	Identifier      string  `json:"identifier" validate:"required"`
	Password        string  `json:"password" validate:"required"`
	AuthFactorToken *string `json:"authFactorToken,omitempty"`
}
21
// ComAtprotoServerCreateSessionResponse is the JSON body returned by a
// successful com.atproto.server.createSession call.
//
// AccessJwt and RefreshJwt carry the newly issued session tokens. Handle,
// Did and Email describe the authenticated account. EmailConfirmed,
// EmailAuthFactor and Active are always encoded, even when false. Status is
// omitted from the JSON when nil.
type ComAtprotoServerCreateSessionResponse struct {
	AccessJwt       string  `json:"accessJwt"`
	RefreshJwt      string  `json:"refreshJwt"`
	Handle          string  `json:"handle"`
	Did             string  `json:"did"`
	Email           string  `json:"email"`
	EmailConfirmed  bool    `json:"emailConfirmed"`
	EmailAuthFactor bool    `json:"emailAuthFactor"`
	Active          bool    `json:"active"`
	Status          *string `json:"status,omitempty"`
}
33
34func (s *Server) handleCreateSession(e echo.Context) error {
35 ctx := e.Request().Context()
36 logger := s.logger.With("name", "handleServerCreateSession")
37
38 var req ComAtprotoServerCreateSessionRequest
39 if err := e.Bind(&req); err != nil {
40 logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
41 return helpers.ServerError(e, nil)
42 }
43
44 if err := e.Validate(req); err != nil {
45 var verr ValidationError
46 if errors.As(err, &verr) {
47 if verr.Field == "Identifier" {
48 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
49 }
50
51 if verr.Field == "Password" {
52 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
53 }
54 }
55 }
56
57 req.Identifier = strings.ToLower(req.Identifier)
58 var idtype string
59 if _, err := syntax.ParseDID(req.Identifier); err == nil {
60 idtype = "did"
61 } else if _, err := syntax.ParseHandle(req.Identifier); err == nil {
62 idtype = "handle"
63 } else {
64 idtype = "email"
65 }
66
67 var repo models.RepoActor
68 var err error
69 switch idtype {
70 case "did":
71 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
72 case "handle":
73 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
74 case "email":
75 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
76 }
77
78 if err != nil {
79 if err == gorm.ErrRecordNotFound {
80 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
81 }
82
83 logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
84 return helpers.ServerError(e, nil)
85 }
86
87 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
88 if err != bcrypt.ErrMismatchedHashAndPassword {
89 logger.Error("erorr comparing hash and password", "error", err)
90 }
91 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
92 }
93
94 sess, err := s.createSession(ctx, &repo.Repo)
95 if err != nil {
96 logger.Error("error creating session", "error", err)
97 return helpers.ServerError(e, nil)
98 }
99
100 return e.JSON(200, ComAtprotoServerCreateSessionResponse{
101 AccessJwt: sess.AccessToken,
102 RefreshJwt: sess.RefreshToken,
103 Handle: repo.Handle,
104 Did: repo.Repo.Did,
105 Email: repo.Email,
106 EmailConfirmed: repo.EmailConfirmedAt != nil,
107 EmailAuthFactor: false,
108 Active: repo.Active(),
109 Status: repo.Status(),
110 })
111}