forked from
tangled.org/core
fork
Configure Feed
Select the types of activity you want to include in your feed.
this repo has no description
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strings"
8 "time"
9
10 comatproto "github.com/bluesky-social/indigo/api/atproto"
11 "github.com/bluesky-social/indigo/atproto/identity"
12 "github.com/bluesky-social/indigo/xrpc"
13 "github.com/go-chi/chi/v5"
14 "github.com/sotangled/tangled/appview"
15 "github.com/sotangled/tangled/appview/auth"
16 "github.com/sotangled/tangled/appview/db"
17)
18
// Middleware wraps an http.Handler, returning the handler to serve in its place.
type Middleware func(http.Handler) http.Handler
20
21func AuthMiddleware(s *State) Middleware {
22 return func(next http.Handler) http.Handler {
23 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
24 session, err := s.auth.GetSession(r)
25 if session.IsNew || err != nil {
26 log.Printf("not logged in, redirecting")
27 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
28 return
29 }
30
31 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
32 if !ok || !authorized {
33 log.Printf("not logged in, redirecting")
34 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
35 return
36 }
37
38 // refresh if nearing expiry
39 // TODO: dedup with /login
40 expiryStr := session.Values[appview.SessionExpiry].(string)
41 expiry, err := time.Parse(time.RFC3339, expiryStr)
42 if err != nil {
43 log.Println("invalid expiry time", err)
44 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
45 return
46 }
47 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
48 did, ok2 := session.Values[appview.SessionDid].(string)
49 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
50
51 if !ok1 || !ok2 || !ok3 {
52 log.Println("invalid expiry time", err)
53 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
54 return
55 }
56
57 if time.Now().After(expiry) {
58 log.Println("token expired, refreshing ...")
59
60 client := xrpc.Client{
61 Host: pdsUrl,
62 Auth: &xrpc.AuthInfo{
63 Did: did,
64 AccessJwt: refreshJwt,
65 RefreshJwt: refreshJwt,
66 },
67 }
68 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
69 if err != nil {
70 log.Println("failed to refresh session", err)
71 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
72 return
73 }
74
75 sessionish := auth.RefreshSessionWrapper{atSession}
76
77 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
78 if err != nil {
79 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
80 return
81 }
82
83 log.Println("successfully refreshed token")
84 }
85
86 next.ServeHTTP(w, r)
87 })
88 }
89}
90
91func RoleMiddleware(s *State, group string) Middleware {
92 return func(next http.Handler) http.Handler {
93 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
94 // requires auth also
95 actor := s.auth.GetUser(r)
96 if actor == nil {
97 // we need a logged in user
98 log.Printf("not logged in, redirecting")
99 http.Error(w, "Forbiden", http.StatusUnauthorized)
100 return
101 }
102 domain := chi.URLParam(r, "domain")
103 if domain == "" {
104 http.Error(w, "malformed url", http.StatusBadRequest)
105 return
106 }
107
108 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
109 if err != nil || !ok {
110 // we need a logged in user
111 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
112 http.Error(w, "Forbiden", http.StatusUnauthorized)
113 return
114 }
115
116 next.ServeHTTP(w, r)
117 })
118 }
119}
120
121func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
122 return func(next http.Handler) http.Handler {
123 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
124 // requires auth also
125 actor := s.auth.GetUser(r)
126 if actor == nil {
127 // we need a logged in user
128 log.Printf("not logged in, redirecting")
129 http.Error(w, "Forbiden", http.StatusUnauthorized)
130 return
131 }
132 f, err := fullyResolvedRepo(r)
133 if err != nil {
134 http.Error(w, "malformed url", http.StatusBadRequest)
135 return
136 }
137
138 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
139 if err != nil || !ok {
140 // we need a logged in user
141 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
142 http.Error(w, "Forbiden", http.StatusUnauthorized)
143 return
144 }
145
146 next.ServeHTTP(w, r)
147 })
148 }
149}
150
// StripLeadingAt rewrites request paths of the form "/@<rest>" to "/<rest>"
// before passing the request to next, so "/@handle/repo" is handled like
// "/handle/repo". Other paths are passed through untouched.
//
// The request URL is modified in place. URL.RawPath (the escaped form, which
// chi routes on whenever it is non-empty) is kept consistent with URL.Path:
// it is stripped the same way when it also starts with "/@", and cleared
// otherwise (e.g. when "@" was percent-encoded) so routing falls back to the
// rewritten Path.
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		path := req.URL.Path
		if strings.HasPrefix(path, "/@") {
			req.URL.Path = "/" + strings.TrimPrefix(path, "/@")

			if raw := req.URL.RawPath; raw != "" {
				if strings.HasPrefix(raw, "/@") {
					req.URL.RawPath = "/" + strings.TrimPrefix(raw, "/@")
				} else {
					// The escaped form no longer matches Path; drop it.
					req.URL.RawPath = ""
				}
			}
		}
		next.ServeHTTP(w, req)
	})
}
160
161func ResolveIdent(s *State) Middleware {
162 return func(next http.Handler) http.Handler {
163 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
164 didOrHandle := chi.URLParam(req, "user")
165
166 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
167 if err != nil {
168 // invalid did or handle
169 log.Println("failed to resolve did/handle:", err)
170 w.WriteHeader(http.StatusNotFound)
171 return
172 }
173
174 ctx := context.WithValue(req.Context(), "resolvedId", *id)
175
176 next.ServeHTTP(w, req.WithContext(ctx))
177 })
178 }
179}
180
181func ResolveRepoKnot(s *State) Middleware {
182 return func(next http.Handler) http.Handler {
183 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
184 repoName := chi.URLParam(req, "repo")
185 id, ok := req.Context().Value("resolvedId").(identity.Identity)
186 if !ok {
187 log.Println("malformed middleware")
188 w.WriteHeader(http.StatusInternalServerError)
189 return
190 }
191
192 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
193 if err != nil {
194 // invalid did or handle
195 log.Println("failed to resolve repo")
196 w.WriteHeader(http.StatusNotFound)
197 return
198 }
199
200 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
201 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
202 ctx = context.WithValue(ctx, "repoDescription", repo.Description)
203 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
204 next.ServeHTTP(w, req.WithContext(ctx))
205 })
206 }
207}