Live video on the AT Protocol
1package websocketrep
2
3import (
4 "bytes"
5 "context"
6 "fmt"
7 "net/url"
8 "sync"
9
10 "github.com/gorilla/websocket"
11 "golang.org/x/sync/errgroup"
12 "stream.place/streamplace/pkg/bus"
13 "stream.place/streamplace/pkg/config"
14 "stream.place/streamplace/pkg/log"
15 "stream.place/streamplace/pkg/media"
16 "stream.place/streamplace/pkg/model"
17 "stream.place/streamplace/pkg/streamplace"
18)
19
20type WebsocketReplicator struct {
21 bus *bus.Bus
22 cli *config.CLI
23 mod model.Model
24 conns map[string]bool
25 connsMutex sync.RWMutex
26 group *errgroup.Group
27 mm *media.MediaManager
28}
29
30func NewWebsocketReplicator(bus *bus.Bus, mod model.Model, mm *media.MediaManager) *WebsocketReplicator {
31 return &WebsocketReplicator{
32 bus: bus,
33 mod: mod,
34 conns: make(map[string]bool),
35 connsMutex: sync.RWMutex{},
36 mm: mm,
37 }
38}
39
40func (r *WebsocketReplicator) Start(ctx context.Context, cli *config.CLI) error {
41 r.cli = cli
42 _ = r.getMyWebsocketURL() // panic check
43 r.group, ctx = errgroup.WithContext(ctx)
44 return r.startBusSubscribe(ctx)
45}
46
47func (r *WebsocketReplicator) startBusSubscribe(ctx context.Context) error {
48 // start subscription first so we're buffering new origins
49 busCh := r.bus.Subscribe("")
50 originViews, err := r.mod.GetRecentBroadcastOrigins(ctx)
51 if err != nil {
52 return fmt.Errorf("failed to get recent broadcast origins: %w", err)
53 }
54 for _, view := range originViews {
55 err = r.handleOriginMessage(ctx, view)
56 if err != nil {
57 log.Error(ctx, "could not check origin", "error", err)
58 }
59 }
60 log.Log(ctx, "Resumed recent broadcast origins", "count", len(originViews))
61 for {
62 select {
63 case <-ctx.Done():
64 return ctx.Err()
65 case msg := <-busCh:
66 if view, ok := msg.(*streamplace.BroadcastDefs_BroadcastOriginView); ok {
67 log.Debug(ctx, "got broadcast origin view", "view", view)
68 err = r.handleOriginMessage(ctx, view)
69 if err != nil {
70 log.Error(ctx, "could not handle origin message", "error", err)
71 }
72 }
73 }
74 }
75}
76
77func (r *WebsocketReplicator) handleOriginMessage(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error {
78 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin)
79 if !ok {
80 return fmt.Errorf("record is not a BroadcastOrigin")
81 }
82 ctx = log.WithLogValues(ctx, "streamer", view.Author.Did)
83 if origin.WebsocketURL == nil {
84 return fmt.Errorf("origin has no websocket URL author=%s", view.Author.Did)
85 }
86 if r.hasConnection(origin.Streamer) {
87 log.Debug(ctx, "already has connection")
88 return nil
89 }
90 myURL := r.getMyWebsocketURL()
91 u, err := url.Parse(*origin.WebsocketURL)
92 if err != nil {
93 return fmt.Errorf("could not parse origin websocket URL: %w", err)
94 }
95 if u.Host == myURL.Host {
96 log.Debug(ctx, "origin websocket URL is on this node, skipping")
97 return nil
98 }
99 r.group.Go(func() error {
100 err := r.openWebsocket(ctx, view)
101 log.Error(ctx, "websocket connection error", "error", err)
102 return nil
103 })
104 return nil
105}
106
107func (r *WebsocketReplicator) openWebsocket(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error {
108 err := r.tryConnection(view.Author.Did)
109 if err != nil {
110 return err
111 }
112 defer r.removeConnection(view.Author.Did)
113 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin)
114 if !ok {
115 return fmt.Errorf("record is not a BroadcastOrigin")
116 }
117 if origin.WebsocketURL == nil {
118 return fmt.Errorf("origin has no websocket URL")
119 }
120 conn, _, err := websocket.DefaultDialer.Dial(*origin.WebsocketURL, nil)
121 if err != nil {
122 return fmt.Errorf("could not dial websocket: %w", err)
123 }
124 defer conn.Close()
125 for {
126 typ, msg, err := conn.ReadMessage()
127 if err != nil {
128 log.Error(ctx, "could not read message", "error", err)
129 return fmt.Errorf("could not read message: %w", err)
130 }
131 if typ != websocket.BinaryMessage {
132 log.Error(ctx, "expected binary message", "type", typ)
133 return fmt.Errorf("expected binary message")
134 }
135 log.Debug(ctx, "received message", "type", typ, "length", len(msg))
136 err = r.mm.ValidateMP4(context.Background(), bytes.NewReader(msg), false)
137 if err != nil {
138 return fmt.Errorf("could not validate segment: %w", err)
139 }
140 }
141}
142
143func (r *WebsocketReplicator) hasConnection(origin string) bool {
144 r.connsMutex.RLock()
145 defer r.connsMutex.RUnlock()
146 return r.conns[origin]
147}
148
149func (r *WebsocketReplicator) tryConnection(origin string) error {
150 r.connsMutex.Lock()
151 defer r.connsMutex.Unlock()
152 if _, ok := r.conns[origin]; ok {
153 return fmt.Errorf("connection already exists")
154 }
155 r.conns[origin] = true
156 return nil
157}
158
159func (r *WebsocketReplicator) removeConnection(origin string) {
160 r.connsMutex.Lock()
161 defer r.connsMutex.Unlock()
162 delete(r.conns, origin)
163}
164
165// we're pull-based, nothing to do here
166func (r *WebsocketReplicator) SendSegment(ctx context.Context, seg *media.NewSegmentNotification) error {
167 return nil
168}
169
170func (r *WebsocketReplicator) BuildOriginRecord(origin *streamplace.BroadcastOrigin) error {
171 u := r.getMyWebsocketURL()
172 u.Path = "/xrpc/place.stream.live.subscribeSegments"
173 u.RawQuery = url.Values{
174 "streamer": []string{origin.Streamer},
175 }.Encode()
176
177 urlStr := u.String()
178 origin.WebsocketURL = &urlStr
179 return nil
180}
181
182func (r *WebsocketReplicator) getMyWebsocketURL() *url.URL {
183 if r.cli.WebsocketURL != "" {
184 u, err := url.Parse(r.cli.WebsocketURL)
185 // chill to panic, we're going to check this on boot
186 if err != nil {
187 panic("invalid websocket override URL: " + r.cli.WebsocketURL)
188 }
189 return u
190 }
191 u := url.URL{
192 Scheme: "ws",
193 Host: r.cli.ServerHost,
194 }
195 if r.cli.HasHTTPS() {
196 u.Scheme = "wss"
197 }
198 return &u
199}