Live video on the AT Protocol
at eli/github-skip-darwin 403 lines 9.8 kB view raw
1package cmd 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net/http" 9 "strings" 10 "time" 11 12 atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 13 "github.com/go-gst/go-gst/gst" 14 "github.com/go-gst/go-gst/gst/app" 15 "github.com/pion/webrtc/v4" 16 pionmedia "github.com/pion/webrtc/v4/pkg/media" 17 "golang.org/x/sync/errgroup" 18 "stream.place/streamplace/pkg/gstinit" 19 "stream.place/streamplace/pkg/log" 20 "stream.place/streamplace/pkg/media" 21) 22 23func WHIP(args []string) error { 24 fs := flag.NewFlagSet("whip", flag.ExitOnError) 25 streamKey := fs.String("stream-key", "", "stream key") 26 count := fs.Int("count", 1, "number of concurrent streams (for load testing)") 27 viewers := fs.Int("viewers", 0, "number of viewers to simulate per stream") 28 duration := fs.Duration("duration", 0, "duration of the stream") 29 file := fs.String("file", "", "file to stream (needs to be an MP4 containing H264 video and Opus audio)") 30 endpoint := fs.String("endpoint", "http://127.0.0.1:38080", "endpoint to send the WHIP request to") 31 freezeAfter := fs.Duration("freeze-after", 0, "freeze the stream after the given duration") 32 err := fs.Parse(args) 33 if *file == "" { 34 return fmt.Errorf("file is required") 35 } 36 if err != nil { 37 return err 38 } 39 gstinit.InitGST() 40 41 ctx := context.Background() 42 if *duration > 0 { 43 var cancel context.CancelFunc 44 ctx, cancel = context.WithTimeout(ctx, *duration) 45 defer cancel() 46 } 47 48 w := &WHIPClient{ 49 StreamKey: *streamKey, 50 File: *file, 51 Endpoint: *endpoint, 52 Count: *count, 53 FreezeAfter: *freezeAfter, 54 Viewers: *viewers, 55 } 56 57 return w.WHIP(ctx) 58} 59 60type WHIPClient struct { 61 StreamKey string 62 File string 63 Endpoint string 64 Count int 65 FreezeAfter time.Duration 66 Viewers int 67} 68 69var failureStates = []webrtc.ICEConnectionState{ 70 webrtc.ICEConnectionStateFailed, 71 webrtc.ICEConnectionStateDisconnected, 72 webrtc.ICEConnectionStateClosed, 73 webrtc.ICEConnectionStateCompleted, 74} 75 76type WHIPConnection struct { 77 peerConnection *webrtc.PeerConnection 78 audioTrack *webrtc.TrackLocalStaticSample 79 videoTrack *webrtc.TrackLocalStaticSample 80 did string 81} 82 83func (w *WHIPClient) WHIP(ctx context.Context) error { 84 ctx, cancel := context.WithCancel(ctx) 85 defer cancel() 86 87 pipelineSlice := []string{ 88 "filesrc name=filesrc ! qtdemux name=demux", 89 "demux.video_0 ! tee name=video_tee", 90 "demux.audio_0 ! tee name=audio_tee", 91 "video_tee. ! queue ! h264parse config-interval=-1 ! video/x-h264,stream-format=byte-stream ! appsink name=videoappsink", 92 "audio_tee. ! queue ! opusparse ! appsink name=audioappsink", 93 // "matroskamux name=mux ! fakesink name=fakesink sync=true", 94 // "video_tee. ! mux.video_0", 95 // "audio_tee. ! mux.audio_0", 96 } 97 98 pipeline, err := gst.NewPipelineFromString(strings.Join(pipelineSlice, "\n")) 99 if err != nil { 100 return err 101 } 102 103 fileSrc, err := pipeline.GetElementByName("filesrc") 104 if err != nil { 105 return err 106 } 107 108 if err := fileSrc.Set("location", w.File); err != nil { 109 return err 110 } 111 112 videoSink, err := pipeline.GetElementByName("videoappsink") 113 if err != nil { 114 return err 115 } 116 117 audioSink, err := pipeline.GetElementByName("audioappsink") 118 if err != nil { 119 return err 120 } 121 122 startTime := time.Now() 123 sinks := []*app.Sink{ 124 app.SinkFromElement(videoSink), 125 app.SinkFromElement(audioSink), 126 } 127 // Create accumulators for tracking elapsed duration 128 accumulators := make([]time.Duration, len(sinks)) 129 130 conns := make([]*WHIPConnection, w.Count) 131 g := &errgroup.Group{} 132 for i := 0; i < w.Count; i++ { 133 ctx := ctx 134 // var streamKey string 135 var did string 136 var streamKey string 137 if w.StreamKey != "" { 138 streamKey = w.StreamKey 139 } else { 140 priv, err := atcrypto.GeneratePrivateKeyK256() 141 if err != nil { 142 return err 143 } 144 pub, err := priv.PublicKey() 145 if err != nil { 146 return err 147 } 148 149 did = pub.DIDKey() 150 ctx = log.WithLogValues(ctx, "did", did) 151 streamKey = priv.Multibase() 152 } 153 154 g.Go(func() error { 155 conn, err := w.StartWHIPConnection(ctx, streamKey, did) 156 if err != nil { 157 return err 158 } 159 conns[i] = conn 160 ctx := log.WithLogValues(ctx, "did", did) 161 conn.peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { 162 log.Log(ctx, "WHIP connection State has changed", "state", connectionState.String()) 163 for _, state := range failureStates { 164 if connectionState == state { 165 log.Log(ctx, "connection failed, cancelling") 166 cancel() 167 } 168 } 169 }) 170 go func() { 171 <-ctx.Done() 172 if conn.peerConnection != nil { 173 conn.peerConnection.Close() 174 } 175 }() 176 return nil 177 }) 178 } 179 180 if err := g.Wait(); err != nil { 181 return err 182 } 183 184 // Start a ticker to print elapsed duration every second 185 go func() { 186 ticker := time.NewTicker(time.Second) 187 defer ticker.Stop() 188 189 for { 190 select { 191 case <-ctx.Done(): 192 return 193 case <-ticker.C: 194 for i, duration := range accumulators { 195 trackType := "video" 196 if i == 1 { 197 trackType = "audio" 198 } 199 target := startTime.Add(time.Duration(accumulators[i])) 200 diff := time.Since(target) 201 log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff) 202 } 203 } 204 } 205 }() 206 207 errCh := make(chan error, 1) 208 209 for i := range sinks { 210 func(i int) { 211 sink := sinks[i] 212 trackType := "video" 213 if i == 1 { 214 trackType = "audio" 215 } 216 217 sink.SetCallbacks(&app.SinkCallbacks{ 218 NewSampleFunc: func(sink *app.Sink) gst.FlowReturn { 219 220 sample := sink.PullSample() 221 if sample == nil { 222 return gst.FlowEOS 223 } 224 225 buffer := sample.GetBuffer() 226 if buffer == nil { 227 return gst.FlowError 228 } 229 230 samples := buffer.Map(gst.MapRead).Bytes() 231 defer buffer.Unmap() 232 233 durationPtr := buffer.Duration().AsDuration() 234 var duration time.Duration 235 if durationPtr == nil { 236 errCh <- fmt.Errorf("%v duration: nil", trackType) 237 return gst.FlowError 238 } else { 239 // fmt.Printf("%v duration: %v\n", trackType, *durationPtr) 240 duration = *durationPtr 241 } 242 243 accumulators[i] += duration 244 245 if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter { 246 for _, conn := range conns { 247 if trackType == "video" { 248 if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 249 log.Log(ctx, "error writing video sample", "error", err) 250 errCh <- err 251 return gst.FlowError 252 } 253 } else { 254 if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 255 log.Log(ctx, "error writing video sample", "error", err) 256 errCh <- err 257 return gst.FlowError 258 } 259 } 260 } 261 } 262 263 return gst.FlowOK 264 }, 265 }) 266 }(i) 267 } 268 269 go func() { 270 if err := media.HandleBusMessages(ctx, pipeline); err != nil { 271 log.Log(ctx, "pipeline error", "error", err) 272 } 273 cancel() 274 }() 275 276 if err = pipeline.SetState(gst.StatePlaying); err != nil { 277 return err 278 } 279 if w.Viewers > 0 { 280 whepG, ctx := errgroup.WithContext(ctx) 281 for i := 0; i < w.Count; i++ { 282 did := conns[i].did 283 w := &WHEPClient{ 284 Endpoint: fmt.Sprintf("%s/api/playback/%s/webrtc", w.Endpoint, did), 285 Count: w.Viewers, 286 } 287 whepG.Go(func() error { 288 return w.WHEP(ctx) 289 }) 290 } 291 if err := whepG.Wait(); err != nil { 292 return err 293 } 294 } 295 296 <-ctx.Done() 297 err = pipeline.BlockSetState(gst.StateNull) 298 if err != nil { 299 return err 300 } 301 302 select { 303 case err := <-errCh: 304 return err 305 case <-ctx.Done(): 306 return ctx.Err() 307 } 308} 309 310func (w *WHIPClient) StartWHIPConnection(ctx context.Context, streamKey string, did string) (*WHIPConnection, error) { 311 312 // Prepare the configuration 313 config := webrtc.Configuration{} 314 315 // Create a new RTCPeerConnection 316 peerConnection, err := webrtc.NewPeerConnection(config) 317 if err != nil { 318 return nil, err 319 } 320 321 // Create a audio track 322 audioTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "audio/opus"}, "audio", "pion1") 323 if err != nil { 324 return nil, err 325 } 326 _, err = peerConnection.AddTrack(audioTrack) 327 if err != nil { 328 return nil, err 329 } 330 331 // Create a video track 332 videoTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "video/h264"}, "video", "pion2") 333 if err != nil { 334 return nil, err 335 } 336 _, err = peerConnection.AddTrack(videoTrack) 337 if err != nil { 338 return nil, err 339 } 340 341 // Create an offer 342 offer, err := peerConnection.CreateOffer(nil) 343 if err != nil { 344 return nil, err 345 } 346 347 // Set the generated offer as our LocalDescription 348 err = peerConnection.SetLocalDescription(offer) 349 if err != nil { 350 return nil, err 351 } 352 353 // Wait for ICE gathering to complete 354 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 355 // <-gatherComplete 356 357 // Create HTTP client and prepare the request 358 client := &http.Client{} 359 360 // Send the WHIP request to the server 361 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP)) 362 if err != nil { 363 return nil, err 364 } 365 req.Header.Set("Authorization", "Bearer "+streamKey) 366 req.Header.Set("Content-Type", "application/sdp") 367 368 // Execute the request 369 resp, err := client.Do(req) 370 if err != nil { 371 return nil, err 372 } 373 defer resp.Body.Close() 374 375 // Read and process the answer 376 answerBytes, err := io.ReadAll(resp.Body) 377 if err != nil { 378 return nil, err 379 } 380 381 // Parse the SDP answer 382 var answer webrtc.SessionDescription 383 answer.Type = webrtc.SDPTypeAnswer 384 answer.SDP = string(answerBytes) 385 386 // Apply the answer as remote description 387 err = peerConnection.SetRemoteDescription(answer) 388 if err != nil { 389 return nil, err 390 } 391 392 gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 393 <-gatherComplete 394 395 conn := &WHIPConnection{ 396 peerConnection: peerConnection, 397 audioTrack: audioTrack, 398 videoTrack: videoTrack, 399 did: did, 400 } 401 402 return conn, nil 403}