forked from
tangled.org/core
fork
Configure Feed
Select the types of activity you want to include in your feed.
Monorepo for Tangled
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package state
2
3import (
4 "context"
5 "log"
6 "net/http"
7 "strconv"
8 "strings"
9 "time"
10
11 "slices"
12
13 comatproto "github.com/bluesky-social/indigo/api/atproto"
14 "github.com/bluesky-social/indigo/atproto/identity"
15 "github.com/bluesky-social/indigo/xrpc"
16 "github.com/go-chi/chi/v5"
17 "tangled.sh/tangled.sh/core/appview"
18 "tangled.sh/tangled.sh/core/appview/auth"
19 "tangled.sh/tangled.sh/core/appview/db"
20)
21
// Middleware takes the next handler in a chain and returns a handler that
// wraps it with extra behavior; the constructors in this file all return one.
type Middleware func(next http.Handler) http.Handler
23
// AuthMiddleware returns middleware that only lets requests carrying an
// authenticated session through to next. Anything else is sent to /login:
// plain requests get a 307 redirect, while htmx requests (HX-Request: true)
// get a 200 with an HX-Redirect header instead.
//
// When the session's stored expiry time has passed, the session is refreshed
// against the user's PDS with the stored refresh JWT and the refreshed
// session is stored via s.auth.StoreSession before next runs.
func AuthMiddleware(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			redirectFunc := func(w http.ResponseWriter, r *http.Request) {
				http.Redirect(w, r, "/login", http.StatusTemporaryRedirect)
			}
			if r.Header.Get("HX-Request") == "true" {
				redirectFunc = func(w http.ResponseWriter, _ *http.Request) {
					w.Header().Set("HX-Redirect", "/login")
					w.WriteHeader(http.StatusOK)
				}
			}

			// Check err first: the session value must not be dereferenced
			// when the lookup failed.
			session, err := s.auth.GetSession(r)
			if err != nil || session.IsNew {
				log.Printf("not logged in, redirecting")
				redirectFunc(w, r)
				return
			}

			authorized, ok := session.Values[appview.SessionAuthenticated].(bool)
			if !ok || !authorized {
				log.Printf("not logged in, redirecting")
				redirectFunc(w, r)
				return
			}

			// All session values are read with comma-ok assertions so a
			// session missing one of them is treated as logged out rather
			// than panicking the handler.
			expiryStr, ok := session.Values[appview.SessionExpiry].(string)
			if !ok {
				log.Println("missing session expiry, redirecting")
				redirectFunc(w, r)
				return
			}
			expiry, err := time.Parse(time.RFC3339, expiryStr)
			if err != nil {
				log.Println("invalid expiry time", err)
				redirectFunc(w, r)
				return
			}
			pdsUrl, ok1 := session.Values[appview.SessionPds].(string)
			did, ok2 := session.Values[appview.SessionDid].(string)
			refreshJwt, ok3 := session.Values[appview.SessionRefreshJwt].(string)
			if !ok1 || !ok2 || !ok3 {
				log.Println("invalid session: missing pds, did or refresh jwt")
				redirectFunc(w, r)
				return
			}

			// Refresh once the stored expiry time has passed.
			// TODO: dedup with /login
			if time.Now().After(expiry) {
				log.Println("token expired, refreshing ...")

				// The refresh JWT is placed in AccessJwt as well.
				// NOTE(review): presumably because the xrpc client sends
				// AccessJwt as the bearer token and refreshSession must be
				// authorized with the refresh token — confirm against indigo.
				client := xrpc.Client{
					Host: pdsUrl,
					Auth: &xrpc.AuthInfo{
						Did:        did,
						AccessJwt:  refreshJwt,
						RefreshJwt: refreshJwt,
					},
				}
				atSession, err := comatproto.ServerRefreshSession(r.Context(), &client)
				if err != nil {
					log.Println("failed to refresh session", err)
					redirectFunc(w, r)
					return
				}

				sessionish := auth.RefreshSessionWrapper{atSession}

				if err := s.auth.StoreSession(r, w, &sessionish, pdsUrl); err != nil {
					log.Printf("failed to store session for did: %s: %s", atSession.Did, err)
					// Previously this returned without writing anything,
					// yielding an empty 200; treat it like a failed refresh.
					redirectFunc(w, r)
					return
				}

				log.Println("successfully refreshed token")
			}

			next.ServeHTTP(w, r)
		})
	}
}
103
// knotRoleMiddleware returns middleware that admits only logged-in users who
// hold group on the knot named by the {domain} URL parameter, as reported by
// s.enforcer.E.HasGroupingPolicy.
//
// Responses when the request is rejected:
//   - 401 if no user is logged in
//   - 400 if {domain} is empty
//   - 500 if the policy lookup itself fails
//   - 403 if the user is logged in but lacks the role
func knotRoleMiddleware(s *State, group string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// requires auth also
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			domain := chi.URLParam(r, "domain")
			if domain == "" {
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			ok, err := s.enforcer.E.HasGroupingPolicy(actor.Did, group, domain)
			if err != nil {
				// An enforcer failure is a server problem, not a permission denial.
				log.Printf("failed to check whether %s is a %s in domain %s: %s", actor.Did, group, domain, err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				// Authenticated but not permitted: 403, not 401.
				log.Printf("%s does not have perms of a %s in domain %s", actor.Did, group, domain)
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
133
134func KnotOwner(s *State) Middleware {
135 return knotRoleMiddleware(s, "server:owner")
136}
137
// RepoPermissionMiddleware returns middleware that admits only logged-in
// users for whom s.enforcer.E.Enforce grants requiredPerm on the repo
// resolved from the request URL.
//
// Responses when the request is rejected:
//   - 401 if no user is logged in
//   - 400 if the repo cannot be resolved from the URL
//   - 500 if the permission check itself fails
//   - 403 if the user is logged in but lacks requiredPerm
func RepoPermissionMiddleware(s *State, requiredPerm string) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// requires auth also
			actor := s.auth.GetUser(r)
			if actor == nil {
				log.Printf("not logged in")
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			f, err := fullyResolvedRepo(r)
			if err != nil {
				http.Error(w, "malformed url", http.StatusBadRequest)
				return
			}

			ok, err := s.enforcer.E.Enforce(actor.Did, f.Knot, f.DidSlashRepo(), requiredPerm)
			if err != nil {
				// An enforcer failure is a server problem, not a permission denial.
				log.Printf("failed to check %s perms of %s in repo %s: %s", requiredPerm, actor.Did, f.OwnerSlashRepo(), err)
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}
			if !ok {
				// Authenticated but not permitted: 403, not 401.
				log.Printf("%s does not have perms of a %s in repo %s", actor.Did, requiredPerm, f.OwnerSlashRepo())
				http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
				return
			}

			next.ServeHTTP(w, r)
		})
	}
}
167
// StripLeadingAt rewrites request paths of the form "/@name/..." to
// "/name/..." before calling next, so "@"-prefixed URLs route the same as
// bare ones.
//
// Both URL.Path and URL.RawPath are rewritten. net/url ignores a RawPath
// that is not a valid encoding of Path, so rewriting RawPath alone left
// EscapedPath(), String() and any handler reading r.URL.Path still seeing
// the "@".
func StripLeadingAt(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
		path := req.URL.EscapedPath()
		if strings.HasPrefix(path, "/@") {
			// A literal "/@" prefix in the escaped path decodes to the same
			// prefix in the decoded Path, so both can be trimmed in lockstep.
			req.URL.Path = "/" + strings.TrimPrefix(req.URL.Path, "/@")
			// Keep the escaped form in RawPath as before; chi routes on
			// RawPath when it is non-empty.
			req.URL.RawPath = "/" + strings.TrimPrefix(path, "/@")
		}
		next.ServeHTTP(w, req)
	})
}
177
// ResolveIdent returns middleware that resolves the {user} URL parameter (a
// DID or handle) via s.resolver.ResolveIdent and stores the dereferenced
// result in the request context under the "resolvedId" key.
//
// Values on the skip list ("favicon.ico") are passed through unresolved.
// A value that fails to resolve gets an empty 404 response.
func ResolveIdent(s *State) Middleware {
	skip := []string{"favicon.ico"}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			user := chi.URLParam(req, "user")
			forward := req

			if !slices.Contains(skip, user) {
				resolved, err := s.resolver.ResolveIdent(req.Context(), user)
				if err != nil {
					// not a resolvable did or handle
					log.Println("failed to resolve did/handle:", err)
					w.WriteHeader(http.StatusNotFound)
					return
				}
				forward = req.WithContext(context.WithValue(req.Context(), "resolvedId", *resolved))
			}

			next.ServeHTTP(w, forward)
		})
	}
}
203
// ResolveRepo returns middleware that looks up the {repo} URL parameter
// under the identity ResolveIdent stored in the context as "resolvedId".
//
// On success it adds the repo's knot, AT URI, description and creation time
// (formatted as RFC 3339) to the request context under "knot", "repoAt",
// "repoDescription" and "repoAddedAt".
//
// Failure responses:
//   - empty 500 if "resolvedId" is absent (ResolveIdent did not run first)
//   - empty 404 if db.GetRepo returns an error
func ResolveRepo(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
			repoName := chi.URLParam(req, "repo")
			id, ok := req.Context().Value("resolvedId").(identity.Identity)
			if !ok {
				log.Println("malformed middleware")
				w.WriteHeader(http.StatusInternalServerError)
				return
			}

			repo, err := db.GetRepo(s.db, id.DID.String(), repoName)
			if err != nil {
				// No such repo for this identity, or the lookup itself
				// failed; the error is logged so the two can be told apart.
				log.Printf("failed to resolve repo %s/%s: %s", id.DID.String(), repoName, err)
				w.WriteHeader(http.StatusNotFound)
				return
			}

			ctx := context.WithValue(req.Context(), "knot", repo.Knot)
			ctx = context.WithValue(ctx, "repoAt", repo.AtUri)
			ctx = context.WithValue(ctx, "repoDescription", repo.Description)
			ctx = context.WithValue(ctx, "repoAddedAt", repo.Created.Format(time.RFC3339))
			next.ServeHTTP(w, req.WithContext(ctx))
		})
	}
}
231
// ResolvePull is middleware tacked on top of /{user}/{repo}/pulls/{pull}.
// It parses the numeric {pull} URL parameter, loads that pull request for
// the current repo via db.GetPull, and stores it in the request context
// under the "pull" key.
//
// Failure responses:
//   - 404 if the repo cannot be resolved from the URL
//   - 400 if {pull} is not an integer
//   - 404 if the pull request cannot be loaded
func ResolvePull(s *State) Middleware {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			f, err := fullyResolvedRepo(r)
			if err != nil {
				log.Println("failed to fully resolve repo", err)
				http.Error(w, "invalid repo url", http.StatusNotFound)
				return
			}

			prId := chi.URLParam(r, "pull")
			prIdInt, err := strconv.Atoi(prId)
			if err != nil {
				http.Error(w, "bad pr id", http.StatusBadRequest)
				log.Println("failed to parse pr id", err)
				return
			}

			pr, err := db.GetPull(s.db, f.RepoAt, prIdInt)
			if err != nil {
				log.Println("failed to get pull and comments", err)
				// Previously no response was written here, so the client
				// received an empty 200 OK.
				http.Error(w, "pull not found", http.StatusNotFound)
				return
			}

			ctx := context.WithValue(r.Context(), "pull", pr)

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}
262}