fork
Configure Feed
Select the types of activity you want to include in your feed.
Live video on the AT Protocol
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package api
2
import (
	"context"
	"encoding/json"
	"errors"
	"net"
	"net/http"
	"time"

	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/julienschmidt/httprouter"

	apierrors "stream.place/streamplace/pkg/errors"
	"stream.place/streamplace/pkg/log"
	"stream.place/streamplace/pkg/renditions"
	"stream.place/streamplace/pkg/spmetrics"
	"stream.place/streamplace/pkg/streamplace"
)
20
21// todo: does this mean a whole message has to fit within the buffer?
22var upgrader = websocket.Upgrader{
23 ReadBufferSize: 1024,
24 WriteBufferSize: 1024,
25 CheckOrigin: func(r *http.Request) bool {
26 return true
27 },
28}
29
// pingPeriod is the interval at which HandleWebsocket writes a websocket ping
// frame to each connected client.
var pingPeriod = time.Second * 5
31
32func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
33 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
34 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
35 ip, _, err := net.SplitHostPort(req.RemoteAddr)
36 if err != nil {
37 ip = req.RemoteAddr
38 }
39
40 if a.CLI.RateLimitWebsocket > 0 {
41 if !a.connTracker.AddConnection(ip) {
42 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
43 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
44 return
45 }
46
47 defer a.connTracker.RemoveConnection(ip)
48 }
49
50 uu, _ := uuid.NewV7()
51 connID := uu.String()
52
53 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
54 log.Log(ctx, "websocket opened")
55 spmetrics.WebsocketsOpen.Inc()
56 defer spmetrics.WebsocketsOpen.Dec()
57 user := params.ByName("repoDID")
58 if user == "" {
59 apierrors.WriteHTTPBadRequest(w, "user required", nil)
60 return
61 }
62 repoDID, err := a.NormalizeUser(ctx, user)
63 if err != nil {
64 apierrors.WriteHTTPNotFound(w, "user not found", err)
65 return
66 }
67 conn, err := upgrader.Upgrade(w, req, nil)
68 if err != nil {
69 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
70 return
71 }
72 ctx, cancel := context.WithCancel(ctx)
73 defer cancel()
74 defer conn.Close()
75
76 initialBurst := make(chan any, 200)
77 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
78 if err != nil {
79 log.Error(ctx, "could not set read deadline", "error", err)
80 return
81 }
82
83 pongCh := make(chan struct{})
84
85 go func() {
86 for {
87 select {
88 case <-ctx.Done():
89 return
90 case <-pongCh:
91 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
92 if err != nil {
93 log.Error(ctx, "could not set read deadline", "error", err)
94 return
95 }
96 case <-time.After(30 * time.Second):
97 log.Log(ctx, "websocket timeout, closing connection")
98 // timeout!
99 conn.Close()
100 cancel()
101 return
102 }
103 }
104 }()
105
106 conn.SetPongHandler(func(appData string) error {
107 log.Debug(ctx, "received pong", "appData", appData)
108 pongCh <- struct{}{}
109 return nil
110 })
111 go func() {
112
113 ch := a.Bus.Subscribe(repoDID)
114 defer a.Bus.Unsubscribe(repoDID, ch)
115 // Create a ticker that fires every 3 seconds
116 ticker := time.NewTicker(3 * time.Second)
117 pingTicker := time.NewTicker(pingPeriod)
118 defer ticker.Stop()
119 defer pingTicker.Stop()
120
121 send := func(msg any) {
122 bs, err := json.Marshal(msg)
123 if err != nil {
124 log.Error(ctx, "could not marshal message", "error", err)
125 return
126 }
127 log.Debug(ctx, "sending message", "message", string(bs))
128 err = conn.WriteMessage(websocket.TextMessage, bs)
129 if err != nil {
130 log.Error(ctx, "could not write message", "error", err)
131 return
132 }
133 }
134
135 for {
136 select {
137 case msg := <-ch:
138 send(msg)
139 case msg := <-initialBurst:
140 send(msg)
141 case <-ticker.C:
142 count := a.Bus.GetViewerCount(repoDID)
143 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
144 if err != nil {
145 log.Error(ctx, "could not marshal view count", "error", err)
146 continue
147 }
148 err = conn.WriteMessage(websocket.TextMessage, bs)
149 if err != nil {
150 log.Error(ctx, "could not write ping message", "error", err)
151 return
152 }
153 case <-pingTicker.C:
154 err := conn.WriteMessage(websocket.PingMessage, []byte{})
155 if err != nil {
156 log.Error(ctx, "could not write ping message", "error", err)
157 return
158 }
159 case <-ctx.Done():
160 log.Debug(ctx, "context done, stopping websocket sender")
161 return
162 }
163 }
164 }()
165
166 go func() {
167 profile, err := a.Model.GetRepo(repoDID)
168 if err != nil {
169 log.Error(ctx, "could not get profile", "error", err)
170 return
171 }
172 if profile != nil {
173 p := map[string]any{
174 "$type": "app.bsky.actor.defs#profileViewBasic",
175 "did": repoDID,
176 "handle": profile.Handle,
177 }
178 initialBurst <- p
179 }
180 }()
181
182 go func() {
183 seg, err := a.Model.LatestSegmentForUser(repoDID)
184 if err != nil {
185 log.Error(ctx, "could not get replies", "error", err)
186 return
187 }
188 spSeg, err := seg.ToStreamplaceSegment()
189 if err != nil {
190 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
191 return
192 }
193 initialBurst <- spSeg
194 if a.CLI.LivepeerGatewayURL != "" {
195 renditions, err := renditions.GenerateRenditions(spSeg)
196 if err != nil {
197 log.Error(ctx, "could not generate renditions", "error", err)
198 return
199 }
200 outRs := streamplace.Defs_Renditions{
201 LexiconTypeID: "place.stream.defs#renditions",
202 }
203 outRs.Renditions = []*streamplace.Defs_Rendition{}
204 for _, r := range renditions {
205 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
206 LexiconTypeID: "place.stream.defs#rendition",
207 Name: r.Name,
208 })
209 }
210 initialBurst <- outRs
211 }
212 }()
213
214 go func() {
215 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
216 if err != nil {
217 log.Error(ctx, "could not get latest livestream", "error", err)
218 return
219 }
220 lsv, err := ls.ToLivestreamView()
221 if err != nil {
222 log.Error(ctx, "could not marshal livestream", "error", err)
223 return
224 }
225 initialBurst <- lsv
226 }()
227
228 go func() {
229 count := a.Bus.GetViewerCount(repoDID)
230 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
231 }()
232
233 go func() {
234 messages, err := a.Model.MostRecentChatMessages(repoDID)
235 if err != nil {
236 log.Error(ctx, "could not get chat messages", "error", err)
237 return
238 }
239 for _, message := range messages {
240 initialBurst <- message
241 }
242 }()
243
244 for {
245 messageType, message, err := conn.ReadMessage()
246 if err != nil {
247 log.Error(ctx, "error reading message", "error", err)
248 break
249 }
250 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
251 }
252 }
253}