Live video on the AT Protocol
1package api
2
3import (
4 "context"
5 "encoding/json"
6 "net"
7 "net/http"
8 "time"
9
10 "github.com/google/uuid"
11 "github.com/gorilla/websocket"
12 "github.com/julienschmidt/httprouter"
13
14 apierrors "stream.place/streamplace/pkg/errors"
15 "stream.place/streamplace/pkg/log"
16 "stream.place/streamplace/pkg/renditions"
17 "stream.place/streamplace/pkg/spmetrics"
18 "stream.place/streamplace/pkg/streamplace"
19)
20
21// todo: does this mean a whole message has to fit within the buffer?
22var upgrader = websocket.Upgrader{
23 ReadBufferSize: 1024,
24 WriteBufferSize: 1024,
25 CheckOrigin: func(r *http.Request) bool {
26 return true
27 },
28}
29
// pingPeriod is the interval between websocket ping control frames written by
// the sender loop in HandleWebsocket.
var pingPeriod = time.Second * 5
31
32func (a *StreamplaceAPI) HandleWebsocket(ctx context.Context) httprouter.Handle {
33 ctx = log.WithLogValues(ctx, "func", "HandleWebsocket")
34 return func(w http.ResponseWriter, req *http.Request, params httprouter.Params) {
35 ip, _, err := net.SplitHostPort(req.RemoteAddr)
36 if err != nil {
37 ip = req.RemoteAddr
38 }
39
40 if a.CLI.RateLimitWebsocket > 0 {
41 if !a.connTracker.AddConnection(ip) {
42 log.Warn(ctx, "rate limit exceeded", "ip", ip, "path", req.URL.Path)
43 apierrors.WriteHTTPTooManyRequests(w, "rate limit exceeded")
44 return
45 }
46
47 defer a.connTracker.RemoveConnection(ip)
48 }
49
50 uu, _ := uuid.NewV7()
51 connID := uu.String()
52
53 ctx = log.WithLogValues(ctx, "uuid", connID, "remoteAddr", req.RemoteAddr, "url", req.URL.String())
54 log.Log(ctx, "websocket opened")
55 spmetrics.WebsocketsOpen.Inc()
56 defer spmetrics.WebsocketsOpen.Dec()
57 user := params.ByName("repoDID")
58 if user == "" {
59 apierrors.WriteHTTPBadRequest(w, "user required", nil)
60 return
61 }
62 repoDID, err := a.NormalizeUser(ctx, user)
63 if err != nil {
64 apierrors.WriteHTTPNotFound(w, "user not found", err)
65 return
66 }
67 conn, err := upgrader.Upgrade(w, req, nil)
68 if err != nil {
69 apierrors.WriteHTTPInternalServerError(w, "could not upgrade to websocket", err)
70 return
71 }
72 ctx, cancel := context.WithCancel(ctx)
73 defer cancel()
74 defer conn.Close()
75
76 initialBurst := make(chan any, 200)
77 err = conn.SetReadDeadline(time.Now().Add(30 * time.Second))
78 if err != nil {
79 log.Error(ctx, "could not set read deadline", "error", err)
80 return
81 }
82
83 pongCh := make(chan struct{})
84
85 go func() {
86 for {
87 select {
88 case <-ctx.Done():
89 return
90 case <-pongCh:
91 err := conn.SetReadDeadline(time.Now().Add(30 * time.Second))
92 if err != nil {
93 log.Error(ctx, "could not set read deadline", "error", err)
94 return
95 }
96 case <-time.After(30 * time.Second):
97 log.Log(ctx, "websocket timeout, closing connection")
98 // timeout!
99 conn.Close()
100 cancel()
101 return
102 }
103 }
104 }()
105
106 conn.SetPongHandler(func(appData string) error {
107 log.Debug(ctx, "received pong", "appData", appData)
108 pongCh <- struct{}{}
109 return nil
110 })
111 go func() {
112
113 ch := a.Bus.Subscribe(repoDID)
114 defer a.Bus.Unsubscribe(repoDID, ch)
115 // Create a ticker that fires every 3 seconds
116 ticker := time.NewTicker(3 * time.Second)
117 pingTicker := time.NewTicker(pingPeriod)
118 defer ticker.Stop()
119 defer pingTicker.Stop()
120
121 send := func(msg any) {
122 bs, err := json.Marshal(msg)
123 if err != nil {
124 log.Error(ctx, "could not marshal message", "error", err)
125 return
126 }
127 log.Debug(ctx, "sending message", "message", string(bs))
128 err = conn.WriteMessage(websocket.TextMessage, bs)
129 if err != nil {
130 log.Error(ctx, "could not write message", "error", err)
131 return
132 }
133 }
134
135 for {
136 select {
137 case msg := <-ch:
138 send(msg)
139 case msg := <-initialBurst:
140 send(msg)
141 case <-ticker.C:
142 count := a.Bus.GetViewerCount(repoDID)
143 bs, err := json.Marshal(streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"})
144 if err != nil {
145 log.Error(ctx, "could not marshal view count", "error", err)
146 continue
147 }
148 err = conn.WriteMessage(websocket.TextMessage, bs)
149 if err != nil {
150 log.Error(ctx, "could not write ping message", "error", err)
151 return
152 }
153 case <-pingTicker.C:
154 err := conn.WriteMessage(websocket.PingMessage, []byte{})
155 if err != nil {
156 log.Error(ctx, "could not write ping message", "error", err)
157 return
158 }
159 case <-ctx.Done():
160 log.Debug(ctx, "context done, stopping websocket sender")
161 return
162 }
163 }
164 }()
165
166 go func() {
167 profile, err := a.Model.GetRepo(repoDID)
168 if err != nil {
169 log.Error(ctx, "could not get profile", "error", err)
170 return
171 }
172 if profile != nil {
173 p := map[string]any{
174 "$type": "app.bsky.actor.defs#profileViewBasic",
175 "did": repoDID,
176 "handle": profile.Handle,
177 }
178 initialBurst <- p
179 }
180 }()
181
182 go func() {
183 seg, err := a.Model.LatestSegmentForUser(repoDID)
184 if err != nil {
185 log.Error(ctx, "could not get replies", "error", err)
186 return
187 }
188 spSeg, err := seg.ToStreamplaceSegment()
189 if err != nil {
190 log.Error(ctx, "could not convert segment to streamplace segment", "error", err)
191 return
192 }
193 initialBurst <- spSeg
194 if a.CLI.LivepeerGatewayURL != "" {
195 renditions, err := renditions.GenerateRenditions(spSeg)
196 if err != nil {
197 log.Error(ctx, "could not generate renditions", "error", err)
198 return
199 }
200 outRs := streamplace.Defs_Renditions{
201 LexiconTypeID: "place.stream.defs#renditions",
202 }
203 outRs.Renditions = []*streamplace.Defs_Rendition{}
204 for _, r := range renditions {
205 outRs.Renditions = append(outRs.Renditions, &streamplace.Defs_Rendition{
206 LexiconTypeID: "place.stream.defs#rendition",
207 Name: r.Name,
208 })
209 }
210 initialBurst <- outRs
211 }
212 }()
213
214 go func() {
215 ls, err := a.Model.GetLatestLivestreamForRepo(repoDID)
216 if err != nil {
217 log.Error(ctx, "could not get latest livestream", "error", err)
218 return
219 }
220 lsv, err := ls.ToLivestreamView()
221 if err != nil {
222 log.Error(ctx, "could not marshal livestream", "error", err)
223 return
224 }
225 initialBurst <- lsv
226 }()
227
228 go func() {
229 count := a.Bus.GetViewerCount(repoDID)
230 initialBurst <- streamplace.Livestream_ViewerCount{Count: int64(count), LexiconTypeID: "place.stream.livestream#viewerCount"}
231 }()
232
233 go func() {
234 messages, err := a.Model.MostRecentChatMessages(repoDID)
235 if err != nil {
236 log.Error(ctx, "could not get chat messages", "error", err)
237 return
238 }
239 for _, message := range messages {
240 initialBurst <- message
241 }
242 }()
243
244 for {
245 messageType, message, err := conn.ReadMessage()
246 if err != nil {
247 log.Error(ctx, "error reading message", "error", err)
248 break
249 }
250 log.Log(ctx, "received message", "messageType", messageType, "message", string(message))
251 }
252 }
253}