Live video on the AT Protocol
At v0.9.4 · 203 lines · 5.7 kB · view raw
1package websocketrep 2 3import ( 4 "bytes" 5 "context" 6 "fmt" 7 "net/url" 8 "sync" 9 10 "github.com/gorilla/websocket" 11 "golang.org/x/sync/errgroup" 12 "stream.place/streamplace/pkg/bus" 13 "stream.place/streamplace/pkg/config" 14 "stream.place/streamplace/pkg/log" 15 "stream.place/streamplace/pkg/media" 16 "stream.place/streamplace/pkg/model" 17 "stream.place/streamplace/pkg/streamplace" 18) 19 20type WebsocketReplicator struct { 21 bus *bus.Bus 22 cli *config.CLI 23 mod model.Model 24 conns map[string]bool 25 connsMutex sync.RWMutex 26 group *errgroup.Group 27 mm *media.MediaManager 28} 29 30func NewWebsocketReplicator(bus *bus.Bus, mod model.Model, mm *media.MediaManager) *WebsocketReplicator { 31 return &WebsocketReplicator{ 32 bus: bus, 33 mod: mod, 34 conns: make(map[string]bool), 35 connsMutex: sync.RWMutex{}, 36 mm: mm, 37 } 38} 39 40func (r *WebsocketReplicator) Start(ctx context.Context, cli *config.CLI) error { 41 r.cli = cli 42 _ = r.getMyWebsocketURL() // panic check 43 r.group, ctx = errgroup.WithContext(ctx) 44 return r.startBusSubscribe(ctx) 45} 46 47func (r *WebsocketReplicator) startBusSubscribe(ctx context.Context) error { 48 // start subscription first so we're buffering new origins 49 busCh := r.bus.Subscribe("") 50 originViews, err := r.mod.GetRecentBroadcastOrigins(ctx) 51 if err != nil { 52 return fmt.Errorf("failed to get recent broadcast origins: %w", err) 53 } 54 for _, view := range originViews { 55 err = r.handleOriginMessage(ctx, view) 56 if err != nil { 57 log.Error(ctx, "could not check origin", "error", err) 58 } 59 } 60 log.Log(ctx, "Resumed recent broadcast origins", "count", len(originViews)) 61 for { 62 select { 63 case <-ctx.Done(): 64 return ctx.Err() 65 case msg := <-busCh: 66 if view, ok := msg.(*streamplace.BroadcastDefs_BroadcastOriginView); ok { 67 log.Debug(ctx, "got broadcast origin view", "view", view) 68 err = r.handleOriginMessage(ctx, view) 69 if err != nil { 70 log.Error(ctx, "could not handle origin message", 
"error", err) 71 } 72 } 73 } 74 } 75} 76 77func (r *WebsocketReplicator) handleOriginMessage(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error { 78 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin) 79 if !ok { 80 return fmt.Errorf("record is not a BroadcastOrigin") 81 } 82 ctx = log.WithLogValues(ctx, "streamer", view.Author.Did) 83 if origin.WebsocketURL == nil { 84 return fmt.Errorf("origin has no websocket URL author=%s", view.Author.Did) 85 } 86 if !r.cli.ShouldSyndicate(origin.Streamer) { 87 log.Debug(ctx, "not replicating streamer", "streamer", origin.Streamer) 88 return nil 89 } 90 if r.hasConnection(origin.Streamer) { 91 log.Debug(ctx, "already has connection") 92 return nil 93 } 94 myURL := r.getMyWebsocketURL() 95 u, err := url.Parse(*origin.WebsocketURL) 96 if err != nil { 97 return fmt.Errorf("could not parse origin websocket URL: %w", err) 98 } 99 if u.Host == myURL.Host { 100 log.Debug(ctx, "origin websocket URL is on this node, skipping") 101 return nil 102 } 103 r.group.Go(func() error { 104 err := r.openWebsocket(ctx, view) 105 log.Error(ctx, "websocket connection error", "error", err) 106 return nil 107 }) 108 return nil 109} 110 111func (r *WebsocketReplicator) openWebsocket(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error { 112 err := r.tryConnection(view.Author.Did) 113 if err != nil { 114 return err 115 } 116 defer r.removeConnection(view.Author.Did) 117 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin) 118 if !ok { 119 return fmt.Errorf("record is not a BroadcastOrigin") 120 } 121 if origin.WebsocketURL == nil { 122 return fmt.Errorf("origin has no websocket URL") 123 } 124 conn, _, err := websocket.DefaultDialer.Dial(*origin.WebsocketURL, nil) 125 if err != nil { 126 return fmt.Errorf("could not dial websocket: %w", err) 127 } 128 defer conn.Close() 129 for { 130 typ, msg, err := conn.ReadMessage() 131 if err != nil { 132 log.Error(ctx, "could not read 
message", "error", err) 133 return fmt.Errorf("could not read message: %w", err) 134 } 135 if typ != websocket.BinaryMessage { 136 log.Error(ctx, "expected binary message", "type", typ) 137 return fmt.Errorf("expected binary message") 138 } 139 log.Debug(ctx, "received message", "type", typ, "length", len(msg)) 140 err = r.mm.ValidateMP4(context.Background(), bytes.NewReader(msg), false) 141 if err != nil { 142 return fmt.Errorf("could not validate segment: %w", err) 143 } 144 } 145} 146 147func (r *WebsocketReplicator) hasConnection(origin string) bool { 148 r.connsMutex.RLock() 149 defer r.connsMutex.RUnlock() 150 return r.conns[origin] 151} 152 153func (r *WebsocketReplicator) tryConnection(origin string) error { 154 r.connsMutex.Lock() 155 defer r.connsMutex.Unlock() 156 if _, ok := r.conns[origin]; ok { 157 return fmt.Errorf("connection already exists") 158 } 159 r.conns[origin] = true 160 return nil 161} 162 163func (r *WebsocketReplicator) removeConnection(origin string) { 164 r.connsMutex.Lock() 165 defer r.connsMutex.Unlock() 166 delete(r.conns, origin) 167} 168 169// we're pull-based, nothing to do here 170func (r *WebsocketReplicator) SendSegment(ctx context.Context, seg *media.NewSegmentNotification) error { 171 return nil 172} 173 174func (r *WebsocketReplicator) BuildOriginRecord(origin *streamplace.BroadcastOrigin) error { 175 u := r.getMyWebsocketURL() 176 u.Path = "/xrpc/place.stream.live.subscribeSegments" 177 u.RawQuery = url.Values{ 178 "streamer": []string{origin.Streamer}, 179 }.Encode() 180 181 urlStr := u.String() 182 origin.WebsocketURL = &urlStr 183 return nil 184} 185 186func (r *WebsocketReplicator) getMyWebsocketURL() *url.URL { 187 if r.cli.WebsocketURL != "" { 188 u, err := url.Parse(r.cli.WebsocketURL) 189 // chill to panic, we're going to check this on boot 190 if err != nil { 191 panic("invalid websocket override URL: " + r.cli.WebsocketURL) 192 } 193 return u 194 } 195 u := url.URL{ 196 Scheme: "ws", 197 Host: r.cli.ServerHost, 
198 } 199 if r.cli.HasHTTPS() { 200 u.Scheme = "wss" 201 } 202 return &u 203}