Live video on the AT Protocol
at eli/docs-url-fix 379 lines 9.3 kB view raw
1package cmd 2 3import ( 4 "context" 5 "flag" 6 "fmt" 7 "io" 8 "net/http" 9 "strings" 10 "sync" 11 "time" 12 13 "github.com/pion/webrtc/v4" 14 "golang.org/x/sync/errgroup" 15 "stream.place/streamplace/pkg/log" 16) 17 18func WHEP(args []string) error { 19 fs := flag.NewFlagSet("whep", flag.ExitOnError) 20 count := fs.Int("count", 1, "number of concurrent streams (for load testing)") 21 duration := fs.Duration("duration", 0, "stop after this long") 22 endpoint := fs.String("endpoint", "", "endpoint to send the WHEP request to") 23 err := fs.Parse(args) 24 25 if err != nil { 26 return err 27 } 28 29 ctx := context.Background() 30 if *duration > 0 { 31 var cancel context.CancelFunc 32 ctx, cancel = context.WithTimeout(ctx, *duration) 33 defer cancel() 34 } 35 36 w := &WHEPClient{ 37 Endpoint: *endpoint, 38 Count: *count, 39 } 40 41 return w.WHEP(ctx) 42} 43 44type WHEPClient struct { 45 StreamKey string 46 File string 47 Endpoint string 48 Count int 49 FreezeAfter time.Duration 50} 51 52type WHEPConnection struct { 53 peerConnection *webrtc.PeerConnection 54 audioTrack *webrtc.TrackLocalStaticSample 55 videoTrack *webrtc.TrackLocalStaticSample 56 did string 57 Done func() <-chan struct{} 58} 59 60func (w *WHEPClient) StartWHEPConnection(ctx context.Context) (*WHEPConnection, error) { 61 62 // Prepare the configuration 63 config := webrtc.Configuration{} 64 65 // Create a new RTCPeerConnection 66 peerConnection, err := webrtc.NewPeerConnection(config) 67 if err != nil { 68 return nil, err 69 } 70 71 // Track statistics 72 type trackStats struct { 73 total int 74 lastTotal int 75 lastUpdate time.Time 76 mu sync.Mutex 77 } 78 79 stats := map[string]*trackStats{ 80 "video": {lastUpdate: time.Now()}, 81 "audio": {lastUpdate: time.Now()}, 82 } 83 84 // Create a ticker to print combined bitrate every 5 seconds 85 ticker := time.NewTicker(5 * time.Second) 86 87 // Start a goroutine to print combined bitrate 88 go func() { 89 for { 90 select { 91 case <-ticker.C: 92 currentTime := time.Now() 93 94 // Lock both stats to get a consistent snapshot 95 for _, s := range stats { 96 s.mu.Lock() 97 } 98 99 videoStats := stats["video"] 100 audioStats := stats["audio"] 101 102 videoElapsed := currentTime.Sub(videoStats.lastUpdate).Seconds() 103 audioElapsed := currentTime.Sub(audioStats.lastUpdate).Seconds() 104 105 videoBytes := videoStats.total - videoStats.lastTotal 106 audioBytes := audioStats.total - audioStats.lastTotal 107 108 videoBitrate := float64(videoBytes) * 8 / videoElapsed / 1000 // kbps 109 audioBitrate := float64(audioBytes) * 8 / audioElapsed / 1000 // kbps 110 111 log.Log(ctx, "bitrate stats", 112 "video", fmt.Sprintf("%.2f kbps (%.2f KB)", videoBitrate, float64(videoBytes)/1000), 113 "audio", fmt.Sprintf("%.2f kbps (%.2f KB)", audioBitrate, float64(audioBytes)/1000), 114 "total", fmt.Sprintf("%.2f kbps", videoBitrate+audioBitrate)) 115 116 // Update last values 117 videoStats.lastTotal = videoStats.total 118 videoStats.lastUpdate = currentTime 119 audioStats.lastTotal = audioStats.total 120 audioStats.lastUpdate = currentTime 121 122 // Unlock stats 123 for _, s := range stats { 124 s.mu.Unlock() 125 } 126 127 case <-ctx.Done(): 128 ticker.Stop() 129 return 130 } 131 } 132 }() 133 134 go func() { 135 ctx, cancel := context.WithCancel(ctx) 136 peerConnection.OnTrack(func(track *webrtc.TrackRemote, receiver *webrtc.RTPReceiver) { 137 log.Log(ctx, "track received", "track", track.ID()) 138 139 // Determine track type 140 trackType := "video" 141 if track.Kind() == webrtc.RTPCodecTypeAudio { 142 trackType = "audio" 143 } 144 145 trackStat := stats[trackType] 146 147 for { 148 if ctx.Err() != nil { 149 return 150 } 151 rtp, _, err := track.ReadRTP() 152 if err != nil { 153 log.Log(ctx, "error reading RTP", "error", err) 154 cancel() 155 return 156 } 157 158 trackStat.mu.Lock() 159 trackStat.total += len(rtp.Payload) 160 trackStat.mu.Unlock() 161 } 162 }) 163 peerConnection.OnICEConnectionStateChange(func(connectionState webrtc.ICEConnectionState) { 164 log.Log(ctx, "WHEP connection State has changed", "state", connectionState.String()) 165 for _, state := range failureStates { 166 if connectionState == state { 167 log.Log(ctx, "connection failed, cancelling") 168 cancel() 169 } 170 } 171 }) 172 173 <-ctx.Done() 174 peerConnection.Close() 175 }() 176 peerConnection.OnICECandidate(func(candidate *webrtc.ICECandidate) { 177 log.Log(ctx, "ICE candidate", "candidate", candidate) 178 }) 179 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeVideo, webrtc.RTPTransceiverInit{ 180 Direction: webrtc.RTPTransceiverDirectionRecvonly, 181 }); err != nil { 182 return nil, fmt.Errorf("failed to add video transceiver: %w", err) 183 } 184 if _, err := peerConnection.AddTransceiverFromKind(webrtc.RTPCodecTypeAudio, webrtc.RTPTransceiverInit{ 185 Direction: webrtc.RTPTransceiverDirectionRecvonly, 186 }); err != nil { 187 return nil, fmt.Errorf("failed to add audio transceiver: %w", err) 188 } 189 190 // Create an offer 191 offer, err := peerConnection.CreateOffer(nil) 192 if err != nil { 193 return nil, err 194 } 195 196 // Set the generated offer as our LocalDescription 197 err = peerConnection.SetLocalDescription(offer) 198 if err != nil { 199 return nil, err 200 } 201 202 // Wait for ICE gathering to complete 203 // gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 204 // <-gatherComplete 205 206 // Create HTTP client and prepare the request 207 client := &http.Client{} 208 209 // Send the WHIP request to the server 210 req, err := http.NewRequest("POST", w.Endpoint, strings.NewReader(offer.SDP)) 211 if err != nil { 212 return nil, err 213 } 214 req.Header.Set("Content-Type", "application/sdp") 215 216 // Execute the request 217 resp, err := client.Do(req) 218 if err != nil { 219 return nil, err 220 } 221 defer resp.Body.Close() 222 if resp.StatusCode != 201 { 223 return nil, fmt.Errorf("status code: %d", resp.StatusCode) 224 } 225 226 // Read and process the answer 227 answerBytes, err := io.ReadAll(resp.Body) 228 if err != nil { 229 return nil, err 230 } 231 232 // Parse the SDP answer 233 var answer webrtc.SessionDescription 234 answer.Type = webrtc.SDPTypeAnswer 235 answer.SDP = string(answerBytes) 236 237 // Apply the answer as remote description 238 err = peerConnection.SetRemoteDescription(answer) 239 if err != nil { 240 return nil, err 241 } 242 243 gatherComplete := webrtc.GatheringCompletePromise(peerConnection) 244 <-gatherComplete 245 246 conn := &WHEPConnection{ 247 peerConnection: peerConnection, 248 Done: ctx.Done, 249 } 250 251 return conn, nil 252} 253 254func (w *WHEPClient) WHEP(ctx context.Context) error { 255 ctx, cancel := context.WithCancel(ctx) 256 defer cancel() 257 258 conns := make([]*WHEPConnection, w.Count) 259 g := &errgroup.Group{} 260 for i := 0; i < w.Count; i++ { 261 g.Go(func() error { 262 conn, err := w.StartWHEPConnection(ctx) 263 if err != nil { 264 return err 265 } 266 conns[i] = conn 267 268 <-conn.Done() 269 270 return nil 271 }) 272 } 273 274 err := g.Wait() 275 if err != nil { 276 return err 277 } 278 // if err := g.Wait(); err != nil { 279 // if err := g.Wait(); err != nil { 280 // return err 281 // } 282 // // Start a ticker to print elapsed duration every second 283 // go func() { 284 // ticker := time.NewTicker(time.Second) 285 // defer ticker.Stop() 286 287 // for { 288 // <-ticker.C 289 // for i, duration := range accumulators { 290 // trackType := "video" 291 // if i == 1 { 292 // trackType = "audio" 293 // } 294 // target := startTime.Add(time.Duration(accumulators[i])) 295 // diff := time.Since(target) 296 // log.Debug(ctx, "elapsed duration", "track", trackType, "duration", duration, "diff", diff) 297 // } 298 // } 299 // }() 300 301 // errCh := make(chan error, 1) 302 303 // for i, _ := range sinks { 304 // func(i int) { 305 // sink := sinks[i] 306 // trackType := "video" 307 // if i == 1 { 308 // trackType = "audio" 309 // } 310 311 // sink.SetCallbacks(&app.SinkCallbacks{ 312 // NewSampleFunc: func(sink *app.Sink) gst.FlowReturn { 313 314 // sample := sink.PullSample() 315 // if sample == nil { 316 // return gst.FlowEOS 317 // } 318 319 // buffer := sample.GetBuffer() 320 // if buffer == nil { 321 // return gst.FlowError 322 // } 323 324 // samples := buffer.Map(gst.MapRead).Bytes() 325 // defer buffer.Unmap() 326 327 // durationPtr := buffer.Duration().AsDuration() 328 // var duration time.Duration 329 // if durationPtr == nil { 330 // errCh <- fmt.Errorf("%v duration: nil", trackType) 331 // return gst.FlowError 332 // } else { 333 // // fmt.Printf("%v duration: %v\n", trackType, *durationPtr) 334 // duration = *durationPtr 335 // } 336 337 // accumulators[i] += duration 338 339 // if w.FreezeAfter == 0 || time.Since(startTime) < w.FreezeAfter { 340 // for _, conn := range conns { 341 // if trackType == "video" { 342 // if err := conn.videoTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 343 // log.Log(ctx, "error writing video sample", "error", err) 344 // errCh <- err 345 // return gst.FlowError 346 // } 347 // } else { 348 // if err := conn.audioTrack.WriteSample(pionmedia.Sample{Data: samples, Duration: duration}); err != nil { 349 // log.Log(ctx, "error writing video sample", "error", err) 350 // errCh <- err 351 // return gst.FlowError 352 // } 353 // } 354 // } 355 // } 356 357 // return gst.FlowOK 358 // }, 359 // }) 360 // }(i) 361 // } 362 363 // go func() { 364 // media.HandleBusMessages(ctx, pipeline) 365 // cancel() 366 // }() 367 368 // if err = pipeline.SetState(gst.StatePlaying); err != nil { 369 // return err 370 // } 371 // select { 372 // case err := <-errCh: 373 // return err 374 // case <-ctx.Done(): 375 // return ctx.Err() 376 // } 377 378 return nil 379}