forked from
tangled.org/core
fork
Configure Feed
Select the types of activity you want to include in your feed.
this repo has no description
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strconv"
8 "strings"
9 "time"
10
11 comatproto "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/identity"
13 "github.com/bluesky-social/indigo/xrpc"
14 "github.com/go-chi/chi/v5"
15 "tangled.sh/tangled.sh/core/appview"
16 "tangled.sh/tangled.sh/core/appview/auth"
17 "tangled.sh/tangled.sh/core/appview/db"
18)
19
// Middleware takes the next handler in a chain and returns a handler that
// wraps it, running its own logic before (or instead of) calling next.
// Every middleware constructor in this file returns a Middleware.
type Middleware func(next http.Handler) http.Handler
21
22func AuthMiddleware(s *State) Middleware {
23 return func(next http.Handler) http.Handler {
24 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
25 redirectFunc := func(w http.ResponseWriter, r *http.Request) {
26 http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
27 }
28 if r.Header.Get("HX-Request") == "true" {
29 redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
30 w.Header().Set("HX-Redirect", "/login")
31 w.WriteHeader(http.StatusOK)
32 }
33 }
34
35 session, err := s.auth.GetSession(r)
36 if session.IsNew || err != nil {
37 log.Printf("not logged in, redirecting")
38 redirectFunc(w, r)
39 return
40 }
41
42 authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
43 if !ok || !authorized {
44 log.Printf("not logged in, redirecting")
45 redirectFunc(w, r)
46 return
47 }
48
49 // refresh if nearing expiry
50 // TODO: dedup with /login
51 expiryStr := session.Values[appview.SessionExpiry].(string)
52 expiry, err := time.Parse(time.RFC3339, expiryStr)
53 if err != nil {
54 log.Println("invalid expiry time", err)
55 redirectFunc(w, r)
56 return
57 }
58 pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
59 did, ok2 := session.Values[appview.SessionDid].(string)
60 refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
61
62 if !ok1 || !ok2 || !ok3 {
63 log.Println("invalid expiry time", err)
64 redirectFunc(w, r)
65 return
66 }
67
68 if time.Now().After(expiry) {
69 log.Println("token expired, refreshing ...")
70
71 client := xrpc.Client{
72 Host: pdsUrl,
73 Auth: &xrpc.AuthInfo{
74 Did: did,
75 AccessJwt: refreshJwt,
76 RefreshJwt: refreshJwt,
77 },
78 }
79 atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
80 if err != nil {
81 log.Println("failed to refresh session", err)
82 redirectFunc(w, r)
83 return
84 }
85
86 sessionish := auth.RefreshSessionWrapper{atSession}
87
88 err = s.auth.StoreSession(r, w, &sessionish, pdsUrl)
89 if err != nil {
90 log.Printf("failed to store session for did: %s\n: %s", atSession.Did, err)
91 return
92 }
93
94 log.Println("successfully refreshed token")
95 }
96
97 next.ServeHTTP(w, r)
98 })
99 }
100}
101
102func knotRoleMiddleware(s *State, group string) Middleware {
103 return func(next http.Handler) http.Handler {
104 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
105 // requires auth also
106 actor := s.auth.GetUser(r)
107 if actor == nil {
108 // we need a logged in user
109 log.Printf("not logged in, redirecting")
110 http.Error(w, "Forbiden", http.StatusUnauthorized)
111 return
112 }
113 domain := chi.URLParam(r, "domain")
114 if domain == "" {
115 http.Error(w, "malformed url", http.StatusBadRequest)
116 return
117 }
118
119 ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
120 if err != nil || !ok {
121 // we need a logged in user
122 log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
123 http.Error(w, "Forbiden", http.StatusUnauthorized)
124 return
125 }
126
127 next.ServeHTTP(w, r)
128 })
129 }
130}
131
132func KnotOwner(s *State) Middleware {
133 return knotRoleMiddleware(s, "server:owner")
134}
135
136func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
137 return func(next http.Handler) http.Handler {
138 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
139 // requires auth also
140 actor := s.auth.GetUser(r)
141 if actor == nil {
142 // we need a logged in user
143 log.Printf("not logged in, redirecting")
144 http.Error(w, "Forbiden", http.StatusUnauthorized)
145 return
146 }
147 f, err := fullyResolvedRepo(r)
148 if err != nil {
149 http.Error(w, "malformed url", http.StatusBadRequest)
150 return
151 }
152
153 ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.OwnerSlashRepo(), requiredPerm)
154 if err != nil || !ok {
155 // we need a logged in user
156 log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
157 http.Error(w, "Forbiden", http.StatusUnauthorized)
158 return
159 }
160
161 next.ServeHTTP(w, r)
162 })
163 }
164}
165
// StripLeadingAt rewrites request paths of the form /@<ident>/... to
// /<ident>/... before calling next, so "@"-prefixed URLs are handled as if
// the "@" were absent. Other paths pass through unchanged.
//
// Both URL.RawPath and URL.Path are rewritten so they stay consistent:
// net/url only honours RawPath when it is a valid escaping of Path, and
// handlers reading Path would otherwise still see the leading "@".
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		path := req.URL.EscapedPath()
		if strings.HasPrefix(path, "/@") {
			req.URL.RawPath = "/" + strings.TrimPrefix(path, "/@")
			// '@' is never percent-escaped in a path, so Path carries the same "/@" prefix.
			req.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, "/@")
		}
		next.ServeHTTP(w, req)
	})
}
175
176func ResolveIdent(s *State) Middleware {
177 return func(next http.Handler) http.Handler {
178 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
179 didOrHandle := chi.URLParam(req, "user")
180
181 id, err := s.resolver.ResolveIdent(req.Context(), didOrHandle)
182 if err != nil {
183 // invalid did or handle
184 log.Println("failed to resolve did/handle:", err)
185 w.WriteHeader(http.StatusNotFound)
186 return
187 }
188
189 ctx := context.WithValue(req.Context(), "resolvedId", *id)
190
191 next.ServeHTTP(w, req.WithContext(ctx))
192 })
193 }
194}
195
196func ResolveRepo(s *State) Middleware {
197 return func(next http.Handler) http.Handler {
198 return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
199 repoName := chi.URLParam(req, "repo")
200 id, ok := req.Context().Value("resolvedId").(identity.Identity)
201 if !ok {
202 log.Println("malformed middleware")
203 w.WriteHeader(http.StatusInternalServerError)
204 return
205 }
206
207 repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
208 if err != nil {
209 // invalid did or handle
210 log.Println("failed to resolve repo")
211 w.WriteHeader(http.StatusNotFound)
212 return
213 }
214
215 ctx := context.WithValue(req.Context(), "knot", repo.Knot)
216 ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
217 ctx = context.WithValue(ctx, "repoDescription", repo.Description)
218 ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
219 next.ServeHTTP(w, req.WithContext(ctx))
220 })
221 }
222}
223
224// middleware that is tacked on top of /{user}/{repo}/pulls/{pull}
225func ResolvePull(s *State) Middleware {
226 return func(next http.Handler) http.Handler {
227 return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
228 f, err := fullyResolvedRepo(r)
229 if err != nil {
230 log.Println("failed to fully resolve repo", err)
231 http.Error(w, "invalid repo url", http.StatusNotFound)
232 return
233 }
234
235 prId := chi.URLParam(r, "pull")
236 prIdInt, err := strconv.Atoi(prId)
237 if err != nil {
238 http.Error(w, "bad pr id", http.StatusBadRequest)
239 log.Println("failed to parse pr id", err)
240 return
241 }
242
243 pr, err := db.GetPull(s.db, f.RepoAt, prIdInt)
244 if err != nil {
245 log.Println("failed to get pull and comments", err)
246 return
247 }
248
249 ctx := context.WithValue(r.Context(), "pull", pr)
250
251 next.ServeHTTP(w, r.WithContext(ctx))
252 })
253 }
254}