Live video on the AT Protocol
1package cmd
2
import (
	"context"
	"flag"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"time"

	atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
	"github.com/go-gst/go-gst/gst"
	"github.com/go-gst/go-gst/gst/app"
	"github.com/pion/webrtc/v4"
	pionmedia "github.com/pion/webrtc/v4/pkg/media"
	"golang.org/x/sync/errgroup"
	"stream.place/streamplace/pkg/gstinit"
	"stream.place/streamplace/pkg/log"
	"stream.place/streamplace/pkg/media"
)
22
23func WHIP(args []string) error {
24 fs := flag.NewFlagSet("whip", flag.ExitOnError)
25 streamKey := fs.String("stream-key", "", "stream key")
26 count := fs.Int("count", 1, "number of concurrent streams (for load testing)")
27 viewers := fs.Int("viewers", 0, "number of viewers to simulate per stream")
28 duration := fs.Duration("duration", 0, "duration of the stream")
29 file := fs.String("file", "", "file to stream (needs to be an MP4 containing H264 video and Opus audio)")
30 endpoint := fs.String("endpoint", "http://127.0.0.1:38080", "endpoint to send the WHIP request to")
31 freezeAfter := fs.Duration("freeze-after", 0, "freeze the stream after the given duration")
32 err := fs.Parse(args)
33 if *file == "" {
34 return fmt.Errorf("file is required")
35 }
36 if err != nil {
37 return err
38 }
39 gstinit.InitGST()
40
41 ctx := context.Background()
42 if *duration > 0 {
43 var cancel context.CancelFunc
44 ctx, cancel = context.WithTimeout(ctx, *duration)
45 defer cancel()
46 }
47
48 w := &WHIPClient{
49 StreamKey: *streamKey,
50 File: *file,
51 Endpoint: *endpoint,
52 Count: *count,
53 FreezeAfter: *freezeAfter,
54 Viewers: *viewers,
55 }
56
57 return w.WHIP(ctx)
58}
59
// WHIPClient publishes a media file to a WHIP endpoint, optionally as several
// concurrent streams, each with simulated WHEP viewers.
type WHIPClient struct {
	StreamKey, File, Endpoint string        // bearer token (generated per stream when empty), MP4 path, WHIP URL
	Count                     int           // number of concurrent streams
	FreezeAfter               time.Duration // stop sending samples after this much wall time; 0 means never
	Viewers                   int           // WHEP viewers to start per stream
}
68
69var failureStates = []webrtc.ICEConnectionState{
70 webrtc.ICEConnectionStateFailed,
71 webrtc.ICEConnectionStateDisconnected,
72 webrtc.ICEConnectionStateClosed,
73 webrtc.ICEConnectionStateCompleted,
74}
75
76type WHIPConnection struct {
77 peerConnection *webrtc.PeerConnection
78 audioTrack *webrtc.TrackLocalStaticSample
79 videoTrack *webrtc.TrackLocalStaticSample
80 did string
81}
82
83func (w *WHIPClient) WHIP(ctx context.Context) error {
84 ctx, cancel := context.WithCancel(ctx)
85 defer cancel()
86
87 pipelineSlice := []string{
88 "filesrc name=filesrc ! qtdemux name=demux",
89 "demux.video_0 ! tee name=video_tee",
90 "demux.audio_0 ! tee name=audio_tee",
91 "video_tee. ! queue ! h264parse config-interval=-1 ! video/x-h264,stream-format=byte-stream ! appsink name=videoappsink",
92 "audio_tee. ! queue ! opusparse ! appsink name=audioappsink",
93 // "matroskamux name=mux ! fakesink name=fakesink sync=true",
94 // "video_tee. ! mux.video_0",
95 // "audio_tee. ! mux.audio_0",
96 }
97
98 pipeline, err := gst.NewPipelineFromString(strings.Join(pipelineSlice, "\n"))
99 if err != nil {
100 return err
101 }
102
103 fileSrc, err := pipeline.GetElementByName("filesrc")
104 if err != nil {
105 return err
106 }
107
108 if err := fileSrc.Set("location", w.File); err != nil {
109 return err
110 }
111
112 videoSink, err := pipeline.GetElementByName("videoappsink")
113 if err != nil {
114 return err
115 }
116
117 audioSink, err := pipeline.GetElementByName("audioappsink")
118 if err != nil {
119 return err
120 }
121
122 startTime := time.Now()
123 sinks := []*app.Sink{
124 app.SinkFromElement(videoSink),
125 app.SinkFromElement(audioSink),
126 }
127 // Create accumulators for tracking elapsed duration
128 accumulators := make([]time.Duration, len(sinks))
129
130 conns := make([]*WHIPConnection, w.Count)
131 g := &errgroup.Group{}
132 for i := 0; i < w.Count; i++ {
133 ctx := ctx
134 // var streamKey string
135 var did string
136 var streamKey string
137 if w.StreamKey != "" {
138 streamKey = w.StreamKey
139 } else {
140 priv, err := atcrypto.GeneratePrivateKeyK256()
141 if err != nil {
142 return err
143 }
144 pub, err := priv.PublicKey()
145 if err != nil {
146 return err
147 }
148
149 did = pub.DIDKey()
150 ctx = log.WithLogValues(ctx, "did", did)
151 streamKey = priv.Multibase()
152 }
153
154 g.Go(func() error {
155 conn, err := w.StartWHIPConnection(ctx, streamKey, did)
156 if err != nil {
157 return err
158 }
159 conns[i] = conn
160 ctx := log.WithLogValues(ctx, "did", did)
161 conn.peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
162 log.Log(ctx, "WHIP connection State has changed", "state", connectionState.String())
163 for _, state := range failureStates {
164 if connectionState == state {
165 log.Log(ctx, "connection failed, cancelling")
166 cancel()
167 }
168 }
169 })
170 go func() {
171 <-ctx.Done()
172 if conn.peerConnection != nil {
173 conn.peerConnection.Close()
174 }
175 }()
176 return nil
177 })
178 }
179
180 if err := g.Wait(); err != nil {
181 return err
182 }
183
184 // Start a ticker to print elapsed duration every second
185 go func() {
186 ticker := time.NewTicker(time.Second)
187 defer ticker.Stop()
188
189 for {
190 select {
191 case <-ctx.Done():
192 return
193 case <-ticker.C:
194 for i, duration := range accumulators {
195 trackType := "video"
196 if i == 1 {
197 trackType = "audio"
198 }
199 target := startTime.Add(time.Duration(accumulators[i]))
200 diff := time.Since(target)
201 log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
202 }
203 }
204 }
205 }()
206
207 errCh := make(chan error, 1)
208
209 for i := range sinks {
210 func(i int) {
211 sink := sinks[i]
212 trackType := "video"
213 if i == 1 {
214 trackType = "audio"
215 }
216
217 sink.SetCallbacks(&app.SinkCallbacks{
218 NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
219
220 sample := sink.PullSample()
221 if sample == nil {
222 return gst.FlowEOS
223 }
224
225 buffer := sample.GetBuffer()
226 if buffer == nil {
227 return gst.FlowError
228 }
229
230 samples := buffer.Map(gst.MapRead).Bytes()
231 defer buffer.Unmap()
232
233 durationPtr := buffer.Duration().AsDuration()
234 var duration time.Duration
235 if durationPtr == nil {
236 errCh <- fmt.Errorf("%v duration: nil", trackType)
237 return gst.FlowError
238 } else {
239 // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
240 duration = *durationPtr
241 }
242
243 accumulators[i] += duration
244
245 if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
246 for _, conn := range conns {
247 if trackType == "video" {
248 if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
249 log.Log(ctx, "error writing video sample", "error", err)
250 errCh <- err
251 return gst.FlowError
252 }
253 } else {
254 if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
255 log.Log(ctx, "error writing video sample", "error", err)
256 errCh <- err
257 return gst.FlowError
258 }
259 }
260 }
261 }
262
263 return gst.FlowOK
264 },
265 })
266 }(i)
267 }
268
269 go func() {
270 if err := media.HandleBusMessages(ctx, pipeline); err != nil {
271 log.Log(ctx, "pipeline error", "error", err)
272 }
273 cancel()
274 }()
275
276 if err = pipeline.SetState(gst.StatePlaying); err != nil {
277 return err
278 }
279 if w.Viewers > 0 {
280 whepG, ctx := errgroup.WithContext(ctx)
281 for i := 0; i < w.Count; i++ {
282 did := conns[i].did
283 w := &WHEPClient{
284 Endpoint: fmt.Sprintf("%s/api/playback/%s/webrtc", w.Endpoint, did),
285 Count: w.Viewers,
286 }
287 whepG.Go(func() error {
288 return w.WHEP(ctx)
289 })
290 }
291 if err := whepG.Wait(); err != nil {
292 return err
293 }
294 }
295
296 <-ctx.Done()
297 err = pipeline.BlockSetState(gst.StateNull)
298 if err != nil {
299 return err
300 }
301
302 select {
303 case err := <-errCh:
304 return err
305 case <-ctx.Done():
306 return ctx.Err()
307 }
308}
309
310func (w *WHIPClient) StartWHIPConnection(ctx context.Context, streamKey string, did string) (*WHIPConnection, error) {
311
312 // Prepare the configuration
313 config := webrtc.Configuration{}
314
315 // Create a new RTCPeerConnection
316 peerConnection, err := webrtc.NewPeerConnection(config)
317 if err != nil {
318 return nil, err
319 }
320
321 // Create a audio track
322 audioTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "audio/opus"}, "audio", "pion1")
323 if err != nil {
324 return nil, err
325 }
326 _, err = peerConnection.AddTrack(audioTrack)
327 if err != nil {
328 return nil, err
329 }
330
331 // Create a video track
332 videoTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "video/h264"}, "video", "pion2")
333 if err != nil {
334 return nil, err
335 }
336 _, err = peerConnection.AddTrack(videoTrack)
337 if err != nil {
338 return nil, err
339 }
340
341 // Create an offer
342 offer, err := peerConnection.CreateOffer(nil)
343 if err != nil {
344 return nil, err
345 }
346
347 // Set the generated offer as our LocalDescription
348 err = peerConnection.SetLocalDescription(offer)
349 if err != nil {
350 return nil, err
351 }
352
353 // Wait for ICE gathering to complete
354 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
355 // <-gatherComplete
356
357 // Create HTTP client and prepare the request
358 client := &http.Client{}
359
360 // Send the WHIP request to the server
361 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
362 if err != nil {
363 return nil, err
364 }
365 req.Header.Set("Authorization", "Bearer "+streamKey)
366 req.Header.Set("Content-Type", "application/sdp")
367
368 // Execute the request
369 resp, err := client.Do(req)
370 if err != nil {
371 return nil, err
372 }
373 defer resp.Body.Close()
374
375 // Read and process the answer
376 answerBytes, err := io.ReadAll(resp.Body)
377 if err != nil {
378 return nil, err
379 }
380
381 // Parse the SDP answer
382 var answer webrtc.SessionDescription
383 answer.Type = webrtc.SDPTypeAnswer
384 answer.SDP = string(answerBytes)
385
386 // Apply the answer as remote description
387 err = peerConnection.SetRemoteDescription(answer)
388 if err != nil {
389 return nil, err
390 }
391
392 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
393 <-gatherComplete
394
395 conn := &WHIPConnection{
396 peerConnection: peerConnection,
397 audioTrack: audioTrack,
398 videoTrack: videoTrack,
399 did: did,
400 }
401
402 return conn, nil
403}