Live video on the AT Protocol
1package websocketrep
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "net/url"
8 "sync"
9
10 "github.com/gorilla/websocket"
11 "golang.org/x/sync/errgroup"
12 "stream.place/streamplace/pkg/bus"
13 "stream.place/streamplace/pkg/config"
14 "stream.place/streamplace/pkg/log"
15 "stream.place/streamplace/pkg/media"
16 "stream.place/streamplace/pkg/model"
17 "stream.place/streamplace/pkg/streamplace"
18)
19
20type WebsocketReplicator struct {
21 bus *bus.Bus
22 cli *config.CLI
23 mod model.Model
24 conns map[string]bool
25 connsMutex sync.RWMutex
26 group *errgroup.Group
27 mm *media.MediaManager
28}
29
30func NewWebsocketReplicator(bus *bus.Bus, mod model.Model, mm *media.MediaManager) *WebsocketReplicator {
31 return &WebsocketReplicator{
32 bus: bus,
33 mod: mod,
34 conns: make(map[string]bool),
35 connsMutex: sync.RWMutex{},
36 mm: mm,
37 }
38}
39
40func (r *WebsocketReplicator) Start(ctx context.Context, cli *config.CLI) error {
41 r.cli = cli
42 _ = r.getMyWebsocketURL() // panic check
43 r.group, ctx = errgroup.WithContext(ctx)
44 return r.startBusSubscribe(ctx)
45}
46
47func (r *WebsocketReplicator) startBusSubscribe(ctx context.Context) error {
48 // start subscription first so we're buffering new origins
49 busCh := r.bus.Subscribe("")
50 originViews, err := r.mod.GetRecentBroadcastOrigins(ctx)
51 if err != nil {
52 return fmt.Errorf("failed to get recent broadcast origins: %w", err)
53 }
54 for _, view := range originViews {
55 err = r.handleOriginMessage(ctx, view)
56 if err != nil {
57 log.Error(ctx, "could not check origin", "error", err)
58 }
59 }
60 log.Log(ctx, "Resumed recent broadcast origins", "count", len(originViews))
61 for {
62 select {
63 case <-ctx.Done():
64 return ctx.Err()
65 case msg := <-busCh:
66 if view, ok := msg.(*streamplace.BroadcastDefs_BroadcastOriginView); ok {
67 log.Debug(ctx, "got broadcast origin view", "view", view)
68 err = r.handleOriginMessage(ctx, view)
69 if err != nil {
70 log.Error(ctx, "could not handle origin message", "error", err)
71 }
72 }
73 }
74 }
75}
76
77func (r *WebsocketReplicator) handleOriginMessage(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error {
78 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin)
79 if !ok {
80 return fmt.Errorf("record is not a BroadcastOrigin")
81 }
82 ctx = log.WithLogValues(ctx, "streamer", view.Author.Did)
83 if origin.WebsocketURL == nil {
84 return fmt.Errorf("origin has no websocket URL author=%s", view.Author.Did)
85 }
86 if !r.cli.ShouldSyndicate(origin.Streamer) {
87 log.Debug(ctx, "not replicating streamer", "streamer", origin.Streamer)
88 return nil
89 }
90 if r.hasConnection(origin.Streamer) {
91 log.Debug(ctx, "already has connection")
92 return nil
93 }
94 myURL := r.getMyWebsocketURL()
95 u, err := url.Parse(*origin.WebsocketURL)
96 if err != nil {
97 return fmt.Errorf("could not parse origin websocket URL: %w", err)
98 }
99 if u.Host == myURL.Host {
100 log.Debug(ctx, "origin websocket URL is on this node, skipping")
101 return nil
102 }
103 r.group.Go(func() error {
104 err := r.openWebsocket(ctx, view)
105 log.Error(ctx, "websocket connection error", "error", err)
106 return nil
107 })
108 return nil
109}
110
111func (r *WebsocketReplicator) openWebsocket(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error {
112 err := r.tryConnection(view.Author.Did)
113 if err != nil {
114 return err
115 }
116 defer r.removeConnection(view.Author.Did)
117 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin)
118 if !ok {
119 return fmt.Errorf("record is not a BroadcastOrigin")
120 }
121 if origin.WebsocketURL == nil {
122 return fmt.Errorf("origin has no websocket URL")
123 }
124 conn, _, err := websocket.DefaultDialer.Dial(*origin.WebsocketURL, nil)
125 if err != nil {
126 return fmt.Errorf("could not dial websocket: %w", err)
127 }
128 defer conn.Close()
129 for {
130 typ, msg, err := conn.ReadMessage()
131 if err != nil {
132 log.Error(ctx, "could not read message", "error", err)
133 return fmt.Errorf("could not read message: %w", err)
134 }
135 if typ != websocket.BinaryMessage {
136 log.Error(ctx, "expected binary message", "type", typ)
137 return fmt.Errorf("expected binary message")
138 }
139 log.Debug(ctx, "received message", "type", typ, "length", len(msg))
140 err = r.mm.ValidateMP4(context.Background(), bytes.NewReader(msg), false)
141 if err != nil {
142 return fmt.Errorf("could not validate segment: %w", err)
143 }
144 }
145}
146
147func (r *WebsocketReplicator) hasConnection(origin string) bool {
148 r.connsMutex.RLock()
149 defer r.connsMutex.RUnlock()
150 return r.conns[origin]
151}
152
153func (r *WebsocketReplicator) tryConnection(origin string) error {
154 r.connsMutex.Lock()
155 defer r.connsMutex.Unlock()
156 if _, ok := r.conns[origin]; ok {
157 return fmt.Errorf("connection already exists")
158 }
159 r.conns[origin] = true
160 return nil
161}
162
163func (r *WebsocketReplicator) removeConnection(origin string) {
164 r.connsMutex.Lock()
165 defer r.connsMutex.Unlock()
166 delete(r.conns, origin)
167}
168
169// we're pull-based, nothing to do here
170func (r *WebsocketReplicator) SendSegment(ctx context.Context, seg *media.NewSegmentNotification) error {
171 return nil
172}
173
174func (r *WebsocketReplicator) BuildOriginRecord(origin *streamplace.BroadcastOrigin) error {
175 u := r.getMyWebsocketURL()
176 u.Path = "/xrpc/place.stream.live.subscribeSegments"
177 u.RawQuery = url.Values{
178 "streamer": []string{origin.Streamer},
179 }.Encode()
180
181 urlStr := u.String()
182 origin.WebsocketURL = &urlStr
183 return nil
184}
185
186func (r *WebsocketReplicator) getMyWebsocketURL() *url.URL {
187 if r.cli.WebsocketURL != "" {
188 u, err := url.Parse(r.cli.WebsocketURL)
189 // chill to panic, we're going to check this on boot
190 if err != nil {
191 panic("invalid websocket override URL: " + r.cli.WebsocketURL)
192 }
193 return u
194 }
195 u := url.URL{
196 Scheme: "ws",
197 Host: r.cli.ServerHost,
198 }
199 if r.cli.HasHTTPS() {
200 u.Scheme = "wss"
201 }
202 return &u
203}