Live video on the AT Protocol
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net"
8 "net/http"
9 "time"
10
11 bsky "github.com/bluesky-social/indigo/api/bsky"
12 "github.com/google/uuid"
13 "github.com/gorilla/websocket"
14 "github.com/julienschmidt/httprouter"
15
16 "stream.place/streamplace/pkg/atproto"
17 apierrors "stream.place/streamplace/pkg/errors"
18 "stream.place/streamplace/pkg/log"
19 "stream.place/streamplace/pkg/renditions"
20 "stream.place/streamplace/pkg/spmetrics"
21 "stream.place/streamplace/pkg/streamplace"
22)
23
24// todo: does this mean a whole message has to fit within the buffer?
25var upgrader = websocket.Upgrader{
26 ReadBufferSize: 1024,
27 WriteBufferSize: 1024,
28 CheckOrigin: func(r *http.Request) bool {
29 return true
30 },
31}
32
33var pingPeriod = 5 * time.Second
34
35func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
36 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
37 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
38 ip, _, err := net.SplitHostPort(req.RemoteAddr)
39 if err != nil {
40 ip = req.RemoteAddr
41 }
42
43 if a.CLI.RateLimitWebsocket > 0 {
44 if !a.connTracker.AddConnection(ip) {
45 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
46 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
47 return
48 }
49
50 defer a.connTracker.RemoveConnection(ip)
51 }
52
53 uu, _ := uuid.NewV7()
54 connID := uu.String()
55
56 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
57 log.Log(ctx, "websocket opened")
58 spmetrics.WebsocketsOpen.Inc()
59 defer spmetrics.WebsocketsOpen.Dec()
60 user := params.ByName("repoDID")
61 if user == "" {
62 apierrors.WriteHTTPBadRequest(w, "user required", nil)
63 return
64 }
65 repoDID, err := a.NormalizeUser(ctx, user)
66 if err != nil {
67 apierrors.WriteHTTPNotFound(w, "user not found", err)
68 return
69 }
70 conn, err := upgrader.Upgrade(w, req, nil)
71 if err != nil {
72 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
73 return
74 }
75 ctx, cancel := context.WithCancel(ctx)
76 defer cancel()
77 defer conn.Close()
78
79 initialBurst := make(chan any, 200)
80 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
81 if err != nil {
82 log.Error(ctx, "could not set read deadline", "error", err)
83 return
84 }
85
86 pongCh := make(chan struct{})
87
88 go func() {
89 for {
90 select {
91 case <-ctx.Done():
92 return
93 case <-pongCh:
94 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
95 if err != nil {
96 log.Error(ctx, "could not set read deadline", "error", err)
97 return
98 }
99 case <-time.After(30 * time.Second):
100 log.Log(ctx, "websocket timeout, closing connection")
101 // timeout!
102 conn.Close()
103 cancel()
104 return
105 }
106 }
107 }()
108
109 conn.SetPongHandler(func(appData string) error {
110 log.Debug(ctx, "received pong", "appData", appData)
111 pongCh <- struct{}{}
112 return nil
113 })
114 go func() {
115
116 ch := a.Bus.Subscribe(repoDID)
117 defer a.Bus.Unsubscribe(repoDID, ch)
118 // Create a ticker that fires every 3 seconds
119 ticker := time.NewTicker(3 * time.Second)
120 pingTicker := time.NewTicker(pingPeriod)
121 defer ticker.Stop()
122 defer pingTicker.Stop()
123
124 send := func(msg any) {
125 bs, err := json.Marshal(msg)
126 if err != nil {
127 log.Error(ctx, "could not marshal message", "error", err)
128 return
129 }
130 log.Debug(ctx, "sending message", "message", string(bs))
131 err = conn.WriteMessage(websocket.TextMessage, bs)
132 if err != nil {
133 log.Error(ctx, "could not write message", "error", err)
134 return
135 }
136 }
137
138 for {
139 select {
140 case msg := <-ch:
141 send(msg)
142 case msg := <-initialBurst:
143 send(msg)
144 case <-ticker.C:
145 count := a.Bus.GetViewerCount(repoDID)
146 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
147 if err != nil {
148 log.Error(ctx, "could not marshal view count", "error", err)
149 continue
150 }
151 err = conn.WriteMessage(websocket.TextMessage, bs)
152 if err != nil {
153 log.Error(ctx, "could not write ping message", "error", err)
154 return
155 }
156 case <-pingTicker.C:
157 err := conn.WriteMessage(websocket.PingMessage, []byte{})
158 if err != nil {
159 log.Error(ctx, "could not write ping message", "error", err)
160 return
161 }
162 case <-ctx.Done():
163 log.Debug(ctx, "context done, stopping websocket sender")
164 return
165 }
166 }
167 }()
168
169 go func() {
170 profile, err := a.Model.GetRepo(repoDID)
171 if err != nil {
172 log.Error(ctx, "could not get profile", "error", err)
173 return
174 }
175 if profile != nil {
176 p := map[string]any{
177 "$type": "app.bsky.actor.defs#profileViewBasic",
178 "did": repoDID,
179 "handle": profile.Handle,
180 }
181 initialBurst <- p
182 }
183 }()
184
185 go func() {
186 seg, err := a.LocalDB.LatestSegmentForUser(repoDID)
187 if err != nil {
188 log.Error(ctx, "could not get replies", "error", err)
189 return
190 }
191 spSeg, err := seg.ToStreamplaceSegment()
192 if err != nil {
193 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
194 return
195 }
196 initialBurst <- spSeg
197 if a.CLI.LivepeerGatewayURL != "" {
198 renditions, err := renditions.GenerateRenditions(spSeg)
199 if err != nil {
200 log.Error(ctx, "could not generate renditions", "error", err)
201 return
202 }
203 outRs := streamplace.Defs_Renditions{
204 LexiconTypeID: "place.stream.defs#renditions",
205 }
206 outRs.Renditions = []*streamplace.Defs_Rendition{}
207 for _, r := range renditions {
208 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
209 LexiconTypeID: "place.stream.defs#rendition",
210 Name: r.Name,
211 })
212 }
213 initialBurst <- outRs
214 }
215 }()
216
217 go func() {
218 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
219 if err != nil {
220 log.Error(ctx, "could not get latest livestream", "error", err)
221 return
222 }
223 lsv, err := ls.ToLivestreamView()
224 if err != nil {
225 log.Error(ctx, "could not marshal livestream", "error", err)
226 return
227 }
228 initialBurst <- lsv
229 }()
230
231 go func() {
232 count := a.Bus.GetViewerCount(repoDID)
233 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
234 }()
235
236 go func() {
237 messages, err := a.Model.MostRecentChatMessages(repoDID)
238 if err != nil {
239 log.Error(ctx, "could not get chat messages", "error", err)
240 return
241 }
242
243 // Add mod badges to messages
244 issuerDID := fmt.Sprintf("did:web:%s", a.CLI.BroadcasterHost)
245 for _, message := range messages {
246 err := atproto.AddModBadgeIfApplicable(ctx, message, repoDID, issuerDID, a.Model)
247 if err != nil {
248 log.Error(ctx, "failed to add mod badge to message", "error", err)
249 }
250 initialBurst <- message
251 }
252 }()
253
254 go func() {
255 teleports, err := a.Model.GetActiveTeleportsToRepo(repoDID)
256 if err != nil {
257 log.Error(ctx, "could not get active teleports", "error", err)
258 return
259 }
260 // just send the latest one if it started <3m ago
261 if len(teleports) > 0 && teleports[0].StartsAt.After(time.Now().Add(-3*time.Minute)) {
262 tp := teleports[0]
263 if tp.Repo == nil {
264 log.Error(ctx, "teleportee repo is nil", "uri", tp.URI)
265 }
266 viewerCount := a.Bus.GetViewerCount(tp.RepoDID)
267 arrivalMsg := streamplace.Livestream_TeleportArrival{
268 LexiconTypeID: "place.stream.livestream#teleportArrival",
269 TeleportUri: tp.URI,
270 Source: &bsky.ActorDefs_ProfileViewBasic{
271 Did: tp.RepoDID,
272 Handle: tp.Repo.Handle,
273 },
274 ViewerCount: int64(viewerCount),
275 StartsAt: tp.StartsAt.Format(time.RFC3339),
276 }
277
278 // get the source chat profile
279 chatProfile, err := a.Model.GetChatProfile(ctx, tp.RepoDID)
280 if err == nil && chatProfile != nil {
281 spcp, err := chatProfile.ToStreamplaceChatProfile()
282 if err == nil {
283 arrivalMsg.ChatProfile = spcp
284 }
285 }
286
287 initialBurst <- arrivalMsg
288 }
289 }()
290
291 for {
292 messageType, message, err := conn.ReadMessage()
293 if err != nil {
294 log.Error(ctx, "error reading message", "error", err)
295 break
296 }
297 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
298 }
299 }
300}