Live video on the AT Protocol
at eli/docker-linting 371 lines 9.0 kB view raw
1package cmd 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net/http" 9 "strings" 10 "sync" 11 "time" 12 13 "github.com/pion/webrtc/v4" 14 "golang.org/x/sync/errgroup" 15 "stream.place/streamplace/pkg/log" 16) 17 18func WHEP(args []string) error { 19 fs := flag.NewFlagSet("whep", flag.ExitOnError) 20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)") 21 duration := fs.Duration("duration", 0, "stop after this long") 22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to") 23 err := fs.Parse(args) 24 25 if err != nil { 26 return err 27 } 28 29 ctx := context.Background() 30 if *duration > 0 { 31 var cancel context.CancelFunc 32 ctx, cancel = context.WithTimeout(ctx, *duration) 33 defer cancel() 34 } 35 36 w := &WHEPClient{ 37 Endpoint: *endpoint, 38 Count: *count, 39 } 40 41 return w.WHEP(ctx) 42} 43 44type WHEPClient struct { 45 StreamKey string 46 File string 47 Endpoint string 48 Count int 49 FreezeAfter time.Duration 50} 51 52type WHEPConnection struct { 53 peerConnection *webrtc.PeerConnection 54 audioTrack *webrtc.TrackLocalStaticSample 55 videoTrack *webrtc.TrackLocalStaticSample 56 did string 57 Done func() <-chan struct{} 58} 59 60func (w *WHEPClient) StartWHEPConnection(ctx context.Context) (*WHEPConnection, error) { 61 62 ctx, cancel := context.WithCancel(ctx) 63 64 // Prepare the configuration 65 config := webrtc.Configuration{} 66 67 // Create a new RTCPeerConnection 68 peerConnection, err := webrtc.NewPeerConnection(config) 69 if err != nil { 70 return nil, err 71 } 72 73 // Track statistics 74 type trackStats struct { 75 total int 76 lastTotal int 77 lastUpdate time.Time 78 mu sync.Mutex 79 } 80 81 stats := map[string]*trackStats{ 82 "video": {lastUpdate: time.Now()}, 83 "audio": {lastUpdate: time.Now()}, 84 } 85 86 // Create a ticker to print combined bitrate every 5 seconds 87 ticker := time.NewTicker(5 * time.Second) 88 89 // Start a goroutine to print combined bitrate 90 go func() { 91 for { 92 select { 93 case <-ticker.C: 94 currentTime := time.Now() 95 96 // Lock both stats to get a consistent snapshot 97 for _, s := range stats { 98 s.mu.Lock() 99 } 100 101 videoStats := stats["video"] 102 audioStats := stats["audio"] 103 104 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds() 105 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds() 106 107 videoBytes := videoStats.total - videoStats.lastTotal 108 audioBytes := audioStats.total - audioStats.lastTotal 109 110 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps 111 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps 112 113 log.Log(ctx, "bitrate stats", 114 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000), 115 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000), 116 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate)) 117 118 // Update last values 119 videoStats.lastTotal = videoStats.total 120 videoStats.lastUpdate = currentTime 121 audioStats.lastTotal = audioStats.total 122 audioStats.lastUpdate = currentTime 123 124 // Unlock stats 125 for _, s := range stats { 126 s.mu.Unlock() 127 } 128 129 case <-ctx.Done(): 130 ticker.Stop() 131 return 132 } 133 } 134 }() 135 136 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { 137 log.Log(ctx, "track received", "track", track.ID()) 138 139 // Determine track type 140 trackType := "video" 141 if track.Kind() == webrtc.RTPCodecTypeAudio { 142 trackType = "audio" 143 } 144 145 trackStat := stats[trackType] 146 147 for { 148 if ctx.Err() != nil { 149 return 150 } 151 rtp, _, err := track.ReadRTP() 152 if err != nil { 153 log.Log(ctx, "error reading RTP", "error", err) 154 cancel() 155 return 156 } 157 158 trackStat.mu.Lock() 159 trackStat.total += len(rtp.Payload) 160 trackStat.mu.Unlock() 161 } 162 }) 163 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { 164 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String()) 165 for _, state := range failureStates { 166 if connectionState == state { 167 log.Log(ctx, "connection failed, cancelling") 168 cancel() 169 } 170 } 171 }) 172 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { 173 log.Log(ctx, "ICE candidate", "candidate", candidate) 174 }) 175 peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ 176 Direction: webrtc.RTPTransceiverDirectionRecvonly, 177 }) 178 peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ 179 Direction: webrtc.RTPTransceiverDirectionRecvonly, 180 }) 181 182 // Create an offer 183 offer, err := peerConnection.CreateOffer(nil) 184 if err != nil { 185 return nil, err 186 } 187 188 // Set the generated offer as our LocalDescription 189 err = peerConnection.SetLocalDescription(offer) 190 if err != nil { 191 return nil, err 192 } 193 194 // Wait for ICE gathering to complete 195 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 196 // <-gatherComplete 197 198 // Create HTTP client and prepare the request 199 client := &http.Client{} 200 201 // Send the WHIP request to the server 202 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP)) 203 if err != nil { 204 return nil, err 205 } 206 req.Header.Set("Content-Type", "application/sdp") 207 208 // Execute the request 209 resp, err := client.Do(req) 210 if err != nil { 211 return nil, err 212 } 213 defer resp.Body.Close() 214 if resp.StatusCode != 201 { 215 return nil, fmt.Errorf("status code: %d", resp.StatusCode) 216 } 217 218 // Read and process the answer 219 answerBytes, err := io.ReadAll(resp.Body) 220 if err != nil { 221 return nil, err 222 } 223 224 // Parse the SDP answer 225 var answer webrtc.SessionDescription 226 answer.Type = webrtc.SDPTypeAnswer 227 answer.SDP = string(answerBytes) 228 229 // Apply the answer as remote description 230 err = peerConnection.SetRemoteDescription(answer) 231 if err != nil { 232 return nil, err 233 } 234 235 gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 236 <-gatherComplete 237 238 conn := &WHEPConnection{ 239 peerConnection: peerConnection, 240 Done: ctx.Done, 241 } 242 243 return conn, nil 244} 245 246func (w *WHEPClient) WHEP(ctx context.Context) error { 247 ctx, cancel := context.WithCancel(ctx) 248 defer cancel() 249 250 conns := make([]*WHEPConnection, w.Count) 251 g := &errgroup.Group{} 252 for i := 0; i < w.Count; i++ { 253 g.Go(func() error { 254 conn, err := w.StartWHEPConnection(ctx) 255 if err != nil { 256 return err 257 } 258 conns[i] = conn 259 260 <-conn.Done() 261 262 return nil 263 }) 264 } 265 266 err := g.Wait() 267 if err != nil { 268 return err 269 } 270 // if err := g.Wait(); err != nil { 271 // if err := g.Wait(); err != nil { 272 // return err 273 // } 274 // // Start a ticker to print elapsed duration every second 275 // go func() { 276 // ticker := time.NewTicker(time.Second) 277 // defer ticker.Stop() 278 279 // for { 280 // <-ticker.C 281 // for i, duration := range accumulators { 282 // trackType := "video" 283 // if i == 1 { 284 // trackType = "audio" 285 // } 286 // target := startTime.Add(time.Duration(accumulators[i])) 287 // diff := time.Since(target) 288 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff) 289 // } 290 // } 291 // }() 292 293 // errCh := make(chan error, 1) 294 295 // for i, _ := range sinks { 296 // func(i int) { 297 // sink := sinks[i] 298 // trackType := "video" 299 // if i == 1 { 300 // trackType = "audio" 301 // } 302 303 // sink.SetCallbacks(&app.SinkCallbacks{ 304 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn { 305 306 // sample := sink.PullSample() 307 // if sample == nil { 308 // return gst.FlowEOS 309 // } 310 311 // buffer := sample.GetBuffer() 312 // if buffer == nil { 313 // return gst.FlowError 314 // } 315 316 // samples := buffer.Map(gst.MapRead).Bytes() 317 // defer buffer.Unmap() 318 319 // durationPtr := buffer.Duration().AsDuration() 320 // var duration time.Duration 321 // if durationPtr == nil { 322 // errCh <- fmt.Errorf("%v duration: nil", trackType) 323 // return gst.FlowError 324 // } else { 325 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr) 326 // duration = *durationPtr 327 // } 328 329 // accumulators[i] += duration 330 331 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter { 332 // for _, conn := range conns { 333 // if trackType == "video" { 334 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 335 // log.Log(ctx, "error writing video sample", "error", err) 336 // errCh <- err 337 // return gst.FlowError 338 // } 339 // } else { 340 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 341 // log.Log(ctx, "error writing video sample", "error", err) 342 // errCh <- err 343 // return gst.FlowError 344 // } 345 // } 346 // } 347 // } 348 349 // return gst.FlowOK 350 // }, 351 // }) 352 // }(i) 353 // } 354 355 // go func() { 356 // media.HandleBusMessages(ctx, pipeline) 357 // cancel() 358 // }() 359 360 // if err = pipeline.SetState(gst.StatePlaying); err != nil { 361 // return err 362 // } 363 // select { 364 // case err := <-errCh: 365 // return err 366 // case <-ctx.Done(): 367 // return ctx.Err() 368 // } 369 370 return nil 371}