Live video on the AT Protocol
at eli/lex-it-up 199 lines 5.6 kB view raw
1package websocketrep 2 3import ( 4 "bytes" 5 "context" 6 "fmt" 7 "net/url" 8 "sync" 9 10 "github.com/gorilla/websocket" 11 "golang.org/x/sync/errgroup" 12 "stream.place/streamplace/pkg/bus" 13 "stream.place/streamplace/pkg/config" 14 "stream.place/streamplace/pkg/log" 15 "stream.place/streamplace/pkg/media" 16 "stream.place/streamplace/pkg/model" 17 "stream.place/streamplace/pkg/streamplace" 18) 19 20type WebsocketReplicator struct { 21 bus *bus.Bus 22 cli *config.CLI 23 mod model.Model 24 conns map[string]bool 25 connsMutex sync.RWMutex 26 group *errgroup.Group 27 mm *media.MediaManager 28} 29 30func NewWebsocketReplicator(bus *bus.Bus, mod model.Model, mm *media.MediaManager) *WebsocketReplicator { 31 return &WebsocketReplicator{ 32 bus: bus, 33 mod: mod, 34 conns: make(map[string]bool), 35 connsMutex: sync.RWMutex{}, 36 mm: mm, 37 } 38} 39 40func (r *WebsocketReplicator) Start(ctx context.Context, cli *config.CLI) error { 41 r.cli = cli 42 _ = r.getMyWebsocketURL() // panic check 43 r.group, ctx = errgroup.WithContext(ctx) 44 return r.startBusSubscribe(ctx) 45} 46 47func (r *WebsocketReplicator) startBusSubscribe(ctx context.Context) error { 48 // start subscription first so we're buffering new origins 49 busCh := r.bus.Subscribe("") 50 originViews, err := r.mod.GetRecentBroadcastOrigins(ctx) 51 if err != nil { 52 return fmt.Errorf("failed to get recent broadcast origins: %w", err) 53 } 54 for _, view := range originViews { 55 err = r.handleOriginMessage(ctx, view) 56 if err != nil { 57 log.Error(ctx, "could not check origin", "error", err) 58 } 59 } 60 log.Log(ctx, "Resumed recent broadcast origins", "count", len(originViews)) 61 for { 62 select { 63 case <-ctx.Done(): 64 return ctx.Err() 65 case msg := <-busCh: 66 if view, ok := msg.(*streamplace.BroadcastDefs_BroadcastOriginView); ok { 67 log.Debug(ctx, "got broadcast origin view", "view", view) 68 err = r.handleOriginMessage(ctx, view) 69 if err != nil { 70 log.Error(ctx, "could not handle origin message", "error", err) 71 } 72 } 73 } 74 } 75} 76 77func (r *WebsocketReplicator) handleOriginMessage(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error { 78 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin) 79 if !ok { 80 return fmt.Errorf("record is not a BroadcastOrigin") 81 } 82 ctx = log.WithLogValues(ctx, "streamer", view.Author.Did) 83 if origin.WebsocketURL == nil { 84 return fmt.Errorf("origin has no websocket URL author=%s", view.Author.Did) 85 } 86 if r.hasConnection(origin.Streamer) { 87 log.Debug(ctx, "already has connection") 88 return nil 89 } 90 myURL := r.getMyWebsocketURL() 91 u, err := url.Parse(*origin.WebsocketURL) 92 if err != nil { 93 return fmt.Errorf("could not parse origin websocket URL: %w", err) 94 } 95 if u.Host == myURL.Host { 96 log.Debug(ctx, "origin websocket URL is on this node, skipping") 97 return nil 98 } 99 r.group.Go(func() error { 100 err := r.openWebsocket(ctx, view) 101 log.Error(ctx, "websocket connection error", "error", err) 102 return nil 103 }) 104 return nil 105} 106 107func (r *WebsocketReplicator) openWebsocket(ctx context.Context, view *streamplace.BroadcastDefs_BroadcastOriginView) error { 108 err := r.tryConnection(view.Author.Did) 109 if err != nil { 110 return err 111 } 112 defer r.removeConnection(view.Author.Did) 113 origin, ok := view.Record.Val.(*streamplace.BroadcastOrigin) 114 if !ok { 115 return fmt.Errorf("record is not a BroadcastOrigin") 116 } 117 if origin.WebsocketURL == nil { 118 return fmt.Errorf("origin has no websocket URL") 119 } 120 conn, _, err := websocket.DefaultDialer.Dial(*origin.WebsocketURL, nil) 121 if err != nil { 122 return fmt.Errorf("could not dial websocket: %w", err) 123 } 124 defer conn.Close() 125 for { 126 typ, msg, err := conn.ReadMessage() 127 if err != nil { 128 log.Error(ctx, "could not read message", "error", err) 129 return fmt.Errorf("could not read message: %w", err) 130 } 131 if typ != websocket.BinaryMessage { 132 log.Error(ctx, "expected binary message", "type", typ) 133 return fmt.Errorf("expected binary message") 134 } 135 log.Debug(ctx, "received message", "type", typ, "length", len(msg)) 136 err = r.mm.ValidateMP4(context.Background(), bytes.NewReader(msg), false) 137 if err != nil { 138 return fmt.Errorf("could not validate segment: %w", err) 139 } 140 } 141} 142 143func (r *WebsocketReplicator) hasConnection(origin string) bool { 144 r.connsMutex.RLock() 145 defer r.connsMutex.RUnlock() 146 return r.conns[origin] 147} 148 149func (r *WebsocketReplicator) tryConnection(origin string) error { 150 r.connsMutex.Lock() 151 defer r.connsMutex.Unlock() 152 if _, ok := r.conns[origin]; ok { 153 return fmt.Errorf("connection already exists") 154 } 155 r.conns[origin] = true 156 return nil 157} 158 159func (r *WebsocketReplicator) removeConnection(origin string) { 160 r.connsMutex.Lock() 161 defer r.connsMutex.Unlock() 162 delete(r.conns, origin) 163} 164 165// we're pull-based, nothing to do here 166func (r *WebsocketReplicator) SendSegment(ctx context.Context, seg *media.NewSegmentNotification) error { 167 return nil 168} 169 170func (r *WebsocketReplicator) BuildOriginRecord(origin *streamplace.BroadcastOrigin) error { 171 u := r.getMyWebsocketURL() 172 u.Path = "/xrpc/place.stream.live.subscribeSegments" 173 u.RawQuery = url.Values{ 174 "streamer": []string{origin.Streamer}, 175 }.Encode() 176 177 urlStr := u.String() 178 origin.WebsocketURL = &urlStr 179 return nil 180} 181 182func (r *WebsocketReplicator) getMyWebsocketURL() *url.URL { 183 if r.cli.WebsocketURL != "" { 184 u, err := url.Parse(r.cli.WebsocketURL) 185 // chill to panic, we're going to check this on boot 186 if err != nil { 187 panic("invalid websocket override URL: " + r.cli.WebsocketURL) 188 } 189 return u 190 } 191 u := url.URL{ 192 Scheme: "ws", 193 Host: r.cli.ServerHost, 194 } 195 if r.cli.HasHTTPS() { 196 u.Scheme = "wss" 197 } 198 return &u 199}