Live video on the AT Protocol
1package cmd
2
import (
	"context"
	"flag"
	"fmt"
	"io"
	"net/http"
	"strings"
	"sync/atomic"
	"time"

	atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
	"github.com/go-gst/go-gst/gst"
	"github.com/go-gst/go-gst/gst/app"
	"github.com/pion/webrtc/v4"
	pionmedia "github.com/pion/webrtc/v4/pkg/media"
	"golang.org/x/sync/errgroup"
	"stream.place/streamplace/pkg/gstinit"
	"stream.place/streamplace/pkg/log"
	"stream.place/streamplace/pkg/media"
)
22
23func WHIP(args []string) error {
24 fs := flag.NewFlagSet("whip", flag.ExitOnError)
25 streamKey := fs.String("stream-key", "", "stream key")
26 count := fs.Int("count", 1, "number of concurrent streams (for load testing)")
27 viewers := fs.Int("viewers", 0, "number of viewers to simulate per stream")
28 duration := fs.Duration("duration", 0, "duration of the stream")
29 file := fs.String("file", "", "file to stream (needs to be an MP4 containing H264 video and Opus audio)")
30 endpoint := fs.String("endpoint", "http://127.0.0.1:38080", "endpoint to send the WHIP request to")
31 freezeAfter := fs.Duration("freeze-after", 0, "freeze the stream after the given duration")
32 err := fs.Parse(args)
33 if *file == "" {
34 return fmt.Errorf("file is required")
35 }
36 if err != nil {
37 return err
38 }
39 gstinit.InitGST()
40
41 ctx := context.Background()
42 if *duration > 0 {
43 var cancel context.CancelFunc
44 ctx, cancel = context.WithTimeout(ctx, *duration)
45 defer cancel()
46 }
47
48 w := &WHIPClient{
49 StreamKey: *streamKey,
50 File: *file,
51 Endpoint: *endpoint,
52 Count: *count,
53 FreezeAfter: *freezeAfter,
54 Viewers: *viewers,
55 }
56
57 return w.WHIP(ctx)
58}
59
// WHIPClient publishes a local media file to a WHIP endpoint. It can use
// several concurrent connections and can attach simulated WHEP viewers.
// The fields correspond one-to-one with the flags of the "whip" subcommand.
type WHIPClient struct {
	// StreamKey is sent as the bearer token. When empty, a fresh key is
	// generated for every connection.
	StreamKey string
	// File is the path handed to the GStreamer filesrc element.
	File string
	// Endpoint is the URL the SDP offer is POSTed to. It is also the base URL
	// for WHEP playback.
	Endpoint string
	// Count is the number of concurrent WHIP connections to open.
	Count int
	// FreezeAfter, when non-zero, stops sample forwarding once this much time
	// has elapsed since the pipeline was set up.
	FreezeAfter time.Duration
	// Viewers is the number of WHEP viewers to start per connection.
	Viewers int
}
68
69var failureStates = []webrtc.ICEConnectionState{
70 webrtc.ICEConnectionStateFailed,
71 webrtc.ICEConnectionStateDisconnected,
72 webrtc.ICEConnectionStateClosed,
73 webrtc.ICEConnectionStateCompleted,
74}
75
76type WHIPConnection struct {
77 peerConnection *webrtc.PeerConnection
78 audioTrack *webrtc.TrackLocalStaticSample
79 videoTrack *webrtc.TrackLocalStaticSample
80 did string
81}
82
83func (w *WHIPClient) WHIP(ctx context.Context) error {
84 ctx, cancel := context.WithCancel(ctx)
85 defer cancel()
86
87 pipelineSlice := []string{
88 "filesrc name=filesrc ! qtdemux name=demux",
89 "demux.video_0 ! tee name=video_tee",
90 "demux.audio_0 ! tee name=audio_tee",
91 "video_tee. ! queue ! h264parse config-interval=-1 ! video/x-h264,stream-format=byte-stream ! appsink name=videoappsink",
92 "audio_tee. ! queue ! opusparse ! appsink name=audioappsink",
93 // "matroskamux name=mux ! fakesink name=fakesink sync=true",
94 // "video_tee. ! mux.video_0",
95 // "audio_tee. ! mux.audio_0",
96 }
97
98 pipeline, err := gst.NewPipelineFromString(strings.Join(pipelineSlice, "\n"))
99 if err != nil {
100 return err
101 }
102
103 fileSrc, err := pipeline.GetElementByName("filesrc")
104 if err != nil {
105 return err
106 }
107
108 fileSrc.Set("location", w.File)
109
110 videoSink, err := pipeline.GetElementByName("videoappsink")
111 if err != nil {
112 return err
113 }
114
115 audioSink, err := pipeline.GetElementByName("audioappsink")
116 if err != nil {
117 return err
118 }
119
120 startTime := time.Now()
121 sinks := []*app.Sink{
122 app.SinkFromElement(videoSink),
123 app.SinkFromElement(audioSink),
124 }
125 // Create accumulators for tracking elapsed duration
126 accumulators := make([]time.Duration, len(sinks))
127
128 conns := make([]*WHIPConnection, w.Count)
129 g := &errgroup.Group{}
130 for i := 0; i < w.Count; i++ {
131 ctx := ctx
132 // var streamKey string
133 var did string
134 var streamKey string
135 if w.StreamKey != "" {
136 streamKey = w.StreamKey
137 } else {
138 priv, err := atcrypto.GeneratePrivateKeyK256()
139 if err != nil {
140 return err
141 }
142 pub, err := priv.PublicKey()
143 if err != nil {
144 return err
145 }
146
147 did = pub.DIDKey()
148 ctx = log.WithLogValues(ctx, "did", did)
149 streamKey = priv.Multibase()
150 }
151
152 g.Go(func() error {
153 conn, err := w.StartWHIPConnection(ctx, streamKey, did)
154 if err != nil {
155 return err
156 }
157 conns[i] = conn
158 ctx := log.WithLogValues(ctx, "did", did)
159 conn.peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) {
160 log.Log(ctx, "WHIP connection State has changed", "state", connectionState.String())
161 for _, state := range failureStates {
162 if connectionState == state {
163 log.Log(ctx, "connection failed, cancelling")
164 cancel()
165 }
166 }
167 })
168 go func() {
169 <-ctx.Done()
170 if conn.peerConnection != nil {
171 conn.peerConnection.Close()
172 }
173 }()
174 return nil
175 })
176 }
177
178 if err := g.Wait(); err != nil {
179 return err
180 }
181
182 // Start a ticker to print elapsed duration every second
183 go func() {
184 ticker := time.NewTicker(time.Second)
185 defer ticker.Stop()
186
187 for {
188 select {
189 case <-ctx.Done():
190 return
191 case <-ticker.C:
192 for i, duration := range accumulators {
193 trackType := "video"
194 if i == 1 {
195 trackType = "audio"
196 }
197 target := startTime.Add(time.Duration(accumulators[i]))
198 diff := time.Since(target)
199 log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff)
200 }
201 }
202 }
203 }()
204
205 errCh := make(chan error, 1)
206
207 for i, _ := range sinks {
208 func(i int) {
209 sink := sinks[i]
210 trackType := "video"
211 if i == 1 {
212 trackType = "audio"
213 }
214
215 sink.SetCallbacks(&app.SinkCallbacks{
216 NewSampleFunc: func(sink *app.Sink) gst.FlowReturn {
217
218 sample := sink.PullSample()
219 if sample == nil {
220 return gst.FlowEOS
221 }
222
223 buffer := sample.GetBuffer()
224 if buffer == nil {
225 return gst.FlowError
226 }
227
228 samples := buffer.Map(gst.MapRead).Bytes()
229 defer buffer.Unmap()
230
231 durationPtr := buffer.Duration().AsDuration()
232 var duration time.Duration
233 if durationPtr == nil {
234 errCh <- fmt.Errorf("%v duration: nil", trackType)
235 return gst.FlowError
236 } else {
237 // fmt.Printf("%v duration: %v\n", trackType, *durationPtr)
238 duration = *durationPtr
239 }
240
241 accumulators[i] += duration
242
243 if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter {
244 for _, conn := range conns {
245 if trackType == "video" {
246 if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
247 log.Log(ctx, "error writing video sample", "error", err)
248 errCh <- err
249 return gst.FlowError
250 }
251 } else {
252 if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil {
253 log.Log(ctx, "error writing video sample", "error", err)
254 errCh <- err
255 return gst.FlowError
256 }
257 }
258 }
259 }
260
261 return gst.FlowOK
262 },
263 })
264 }(i)
265 }
266
267 go func() {
268 media.HandleBusMessages(ctx, pipeline)
269 cancel()
270 }()
271
272 if err = pipeline.SetState(gst.StatePlaying); err != nil {
273 return err
274 }
275 if w.Viewers > 0 {
276 whepG, ctx := errgroup.WithContext(ctx)
277 for i := 0; i < w.Count; i++ {
278 did := conns[i].did
279 w := &WHEPClient{
280 Endpoint: fmt.Sprintf("%s/api/playback/%s/webrtc", w.Endpoint, did),
281 Count: w.Viewers,
282 }
283 whepG.Go(func() error {
284 return w.WHEP(ctx)
285 })
286 }
287 if err := whepG.Wait(); err != nil {
288 return err
289 }
290 }
291
292 <-ctx.Done()
293 err = pipeline.BlockSetState(gst.StateNull)
294 if err != nil {
295 return err
296 }
297
298 select {
299 case err := <-errCh:
300 return err
301 case <-ctx.Done():
302 return ctx.Err()
303 }
304}
305
306func (w *WHIPClient) StartWHIPConnection(ctx context.Context, streamKey string, did string) (*WHIPConnection, error) {
307
308 // Prepare the configuration
309 config := webrtc.Configuration{}
310
311 // Create a new RTCPeerConnection
312 peerConnection, err := webrtc.NewPeerConnection(config)
313 if err != nil {
314 return nil, err
315 }
316
317 // Create a audio track
318 audioTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "audio/opus"}, "audio", "pion1")
319 if err != nil {
320 return nil, err
321 }
322 _, err = peerConnection.AddTrack(audioTrack)
323 if err != nil {
324 return nil, err
325 }
326
327 // Create a video track
328 videoTrack, err := webrtc.NewTrackLocalStaticSample(webrtc.RTPCodecCapability{MimeType: "video/h264"}, "video", "pion2")
329 if err != nil {
330 return nil, err
331 }
332 _, err = peerConnection.AddTrack(videoTrack)
333 if err != nil {
334 return nil, err
335 }
336
337 // Create an offer
338 offer, err := peerConnection.CreateOffer(nil)
339 if err != nil {
340 return nil, err
341 }
342
343 // Set the generated offer as our LocalDescription
344 err = peerConnection.SetLocalDescription(offer)
345 if err != nil {
346 return nil, err
347 }
348
349 // Wait for ICE gathering to complete
350 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
351 // <-gatherComplete
352
353 // Create HTTP client and prepare the request
354 client := &http.Client{}
355
356 // Send the WHIP request to the server
357 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP))
358 if err != nil {
359 return nil, err
360 }
361 req.Header.Set("Authorization", "Bearer "+streamKey)
362 req.Header.Set("Content-Type", "application/sdp")
363
364 // Execute the request
365 resp, err := client.Do(req)
366 if err != nil {
367 return nil, err
368 }
369 defer resp.Body.Close()
370
371 // Read and process the answer
372 answerBytes, err := io.ReadAll(resp.Body)
373 if err != nil {
374 return nil, err
375 }
376
377 // Parse the SDP answer
378 var answer webrtc.SessionDescription
379 answer.Type = webrtc.SDPTypeAnswer
380 answer.SDP = string(answerBytes)
381
382 // Apply the answer as remote description
383 err = peerConnection.SetRemoteDescription(answer)
384 if err != nil {
385 return nil, err
386 }
387
388 gatherComplete := webrtc.GatheringCompletePromise(peerConnection)
389 <-gatherComplete
390
391 conn := &WHIPConnection{
392 peerConnection: peerConnection,
393 audioTrack: audioTrack,
394 videoTrack: videoTrack,
395 did: did,
396 }
397
398 return conn, nil
399}