Live video on the AT Protocol

wrap context with arc mutex (this is bad actually)

+128 -95
+128 -95
rust/iroh-streamplace/src/lib.rs
··· 23 23 str::FromStr, 24 24 sync::Arc, 25 25 }; 26 + use tokio::sync::Mutex; 26 27 27 28 use bytes::Bytes; 28 - use iroh::{NodeId, PublicKey, RelayMode, SecretKey, discovery::static_provider::StaticProvider}; 29 + use iroh::{NodeId, RelayMode, SecretKey, discovery::static_provider::StaticProvider}; 29 30 use iroh_base::ticket::NodeTicket; 30 31 use iroh_gossip::{net::Gossip, proto::TopicId}; 31 32 use irpc::{WithChannels, rpc::RemoteService}; ··· 190 191 Receiver(Arc<dyn DataHandler>), 191 192 } 192 193 194 + impl Clone for HandlerMode { 195 + fn clone(&self) -> Self { 196 + match self { 197 + HandlerMode::Sender => HandlerMode::Sender, 198 + HandlerMode::Forwarder => HandlerMode::Forwarder, 199 + HandlerMode::Receiver(h) => HandlerMode::Receiver(h.clone()), 200 + } 201 + } 202 + } 203 + 193 204 impl HandlerMode { 194 205 pub fn mode_str(&self) -> &'static str { 195 206 match self { ··· 211 222 rpc_rx: tokio::sync::mpsc::Receiver<RpcMessage>, 212 223 /// Receiver for API messages from the user 213 224 api_rx: tokio::sync::mpsc::Receiver<ApiMessage>, 225 + /// Shared state wrapped for concurrent access 226 + state: Arc<Mutex<ActorState>>, 227 + } 228 + 229 + /// Shared mutable state for the actor 230 + struct ActorState { 214 231 /// nodes I need to send to for each stream 215 232 subscribers: BTreeMap<String, BTreeSet<NodeId>>, 216 233 /// nodes I am subscribed to ··· 311 328 let client = iroh_smol_kv::Client::local(topic, db_config); 312 329 let write = db::WriteScope::new(client.write(secret.clone())); 313 330 let client = db::Db::new(client); 314 - let actor = Self { 315 - rpc_rx, 316 - api_rx, 331 + let state = Arc::new(Mutex::new(ActorState { 317 332 subscribers: BTreeMap::new(), 318 333 subscriptions: BTreeMap::new(), 319 334 connections: ConnectionPool::new(router.endpoint().clone()), ··· 324 339 client: client.clone(), 325 340 tasks: FuturesUnordered::new(), 326 341 config: Arc::new(config), 342 + })); 343 + let actor = Self { 344 + rpc_rx, 345 + api_rx, 346 + state, 327 347 }; 328 348 let api = Node { 329 349 client: Arc::new(client), ··· 347 367 error!("rpc channel closed"); 348 368 break; 349 369 }; 350 - self.handle_rpc(msg).instrument(trace_span!("rpc")).await; 370 + let state = self.state.clone(); 371 + tokio::spawn(async move { 372 + Self::handle_rpc(state, msg).instrument(trace_span!("rpc")).await; 373 + }); 351 374 } 352 375 msg = self.api_rx.recv() => { 353 376 trace!("received remote rpc message"); 354 377 let Some(msg) = msg else { 355 378 break; 356 379 }; 357 - if let Some(shutdown) = self.handle_api(msg).instrument(trace_span!("api")).await { 380 + let state = self.state.clone(); 381 + let shutdown = tokio::spawn(async move { 382 + Self::handle_api(state, msg).instrument(trace_span!("api")).await 383 + }).await.ok().flatten(); 384 + if let Some(shutdown) = shutdown { 358 385 shutdown.send(()).await.ok(); 359 386 break; 360 387 } 361 388 } 362 - res = self.tasks.next(), if !self.tasks.is_empty() => { 389 + else => { 363 390 trace!("processing task"); 364 - let Some((remote_id, res)) = res else { 365 - error!("task finished but no result"); 366 - break; 367 - }; 368 - match res { 369 - Ok(()) => {} 370 - Err(RpcTaskError::Timeout { source }) => { 371 - warn!("call to {remote_id} timed out: {source}"); 372 - } 373 - Err(RpcTaskError::Task { source }) => { 374 - warn!("call to {remote_id} failed: {source}"); 391 + // poll tasks 392 + let mut state = self.state.lock().await; 393 + if !state.tasks.is_empty() { 394 + if let Some((remote_id, res)) = state.tasks.next().await { 395 + match res { 396 + Ok(()) => {} 397 + Err(RpcTaskError::Timeout { source }) => { 398 + warn!("call to {remote_id} timed out: {source}"); 399 + } 400 + Err(RpcTaskError::Task { source }) => { 401 + warn!("call to {remote_id} failed: {source}"); 402 + } 403 + } 404 + state.connections.remove(&remote_id); 375 405 } 376 406 } 377 - self.connections.remove(&remote_id); 378 407 } 379 408 } 380 409 } 381 410 warn!("RPC Actor loop has closed"); 382 411 } 383 412 384 - async fn update_subscriber_meta(&mut self, key: &str) { 385 - let n = self 413 + async fn update_subscriber_meta(state: &mut ActorState, key: &str) { 414 + let n = state 386 415 .subscribers 387 416 .get(key) 388 417 .map(|s| s.len()) 389 418 .unwrap_or_default(); 390 419 let v = n.to_string().into_bytes(); 391 - self.write 420 + state 421 + .write 392 422 .put_impl(Some(key.as_bytes().to_vec()), b"subscribers", v.into()) 393 423 .await 394 424 .ok(); 395 425 } 396 426 397 427 /// Requests from remote nodes 398 - async fn handle_rpc(&mut self, msg: RpcMessage) { 428 + async fn handle_rpc(state: Arc<Mutex<ActorState>>, msg: RpcMessage) { 399 429 trace!("RPC.handle_rpc"); 400 430 match msg { 401 431 RpcMessage::Subscribe(msg) => { ··· 405 435 inner: rpc::Subscribe { key, remote_id }, 406 436 .. 407 437 } = msg; 408 - self.subscribers 409 - .entry(key.clone()) 410 - .or_default() 411 - .insert(remote_id); 412 - self.update_subscriber_meta(&key).await; 438 + { 439 + let mut state = state.lock().await; 440 + state 441 + .subscribers 442 + .entry(key.clone()) 443 + .or_default() 444 + .insert(remote_id); 445 + Self::update_subscriber_meta(&mut state, &key).await; 446 + } 413 447 tx.send(()).await.ok(); 414 448 } 415 449 RpcMessage::Unsubscribe(msg) => { ··· 419 453 inner: rpc::Unsubscribe { key, remote_id }, 420 454 .. 421 455 } = msg; 422 - if let Some(e) = self.subscribers.get_mut(&key) 423 - && !e.remove(&remote_id) 424 456 { 425 - warn!( 426 - "unsubscribe: no subscription for {} from {}", 427 - key, remote_id 428 - ); 457 + let mut state = state.lock().await; 458 + if let Some(e) = state.subscribers.get_mut(&key) 459 + && !e.remove(&remote_id) 460 + { 461 + warn!( 462 + "unsubscribe: no subscription for {} from {}", 463 + key, remote_id 464 + ); 465 + } 466 + if let Some(subscriptions) = state.subscribers.get(&key) 467 + && subscriptions.is_empty() 468 + { 469 + state.subscribers.remove(&key); 470 + } 471 + Self::update_subscriber_meta(&mut state, &key).await; 429 472 } 430 - if let Some(subscriptions) = self.subscribers.get(&key) 431 - && subscriptions.is_empty() 432 - { 433 - self.subscribers.remove(&key); 434 - } 435 - self.update_subscriber_meta(&key).await; 436 473 tx.send(()).await.ok(); 437 474 } 438 475 RpcMessage::RecvSegment(msg) => { ··· 442 479 inner: rpc::RecvSegment { key, from, data }, 443 480 .. 444 481 } = msg; 445 - match &self.handler { 482 + let mut state = state.lock().await; 483 + match &state.handler { 446 484 HandlerMode::Sender => { 447 485 warn!("received segment but in sender mode"); 448 486 } 449 487 HandlerMode::Forwarder => { 450 488 trace!("forwarding segment"); 451 - if let Some(remotes) = self.subscribers.get(&key) { 452 - Self::handle_send( 453 - &mut self.tasks, 454 - &mut self.connections, 455 - &self.config, 456 - key, 457 - data, 458 - remotes, 459 - ); 489 + if let Some(remotes) = state.subscribers.get(&key).cloned() { 490 + Self::handle_send(&mut state, key, data, &remotes); 460 491 } else { 461 492 trace!("no subscribers for stream {}", key); 462 493 } 463 494 } 464 495 HandlerMode::Receiver(handler) => { 465 - if self.subscriptions.contains_key(&key) { 496 + if state.subscriptions.contains_key(&key) { 466 497 let from = Arc::new(from.into()); 498 + let handler = handler.clone(); 499 + drop(state); // release lock before async call 467 500 handler.handle_data(from, key, data.to_vec()).await; 468 501 } else { 469 502 warn!("received segment for unsubscribed key: {}", key); ··· 475 508 } 476 509 } 477 510 478 - async fn handle_api(&mut self, msg: ApiMessage) -> Option<irpc::channel::oneshot::Sender<()>> { 511 + async fn handle_api( 512 + state: Arc<Mutex<ActorState>>, 513 + msg: ApiMessage, 514 + ) -> Option<irpc::channel::oneshot::Sender<()>> { 479 515 trace!("RPC.handle_api"); 480 516 match msg { 481 517 ApiMessage::SendSegment(msg) => { ··· 485 521 inner: api::SendSegment { key, data }, 486 522 .. 487 523 } = msg; 488 - if let Some(remotes) = self.subscribers.get(&key) { 489 - Self::handle_send( 490 - &mut self.tasks, 491 - &mut self.connections, 492 - &self.config, 493 - key, 494 - data, 495 - remotes, 496 - ); 497 - } else { 498 - trace!("no subscribers for stream {}", key); 524 + { 525 + let mut state = state.lock().await; 526 + if let Some(remotes) = state.subscribers.get(&key).cloned() { 527 + Self::handle_send(&mut state, key, data, &remotes); 528 + } else { 529 + trace!("no subscribers for stream {}", key); 530 + } 499 531 } 500 532 tx.send(()).await.ok(); 501 533 } ··· 506 538 inner: api::Subscribe { key, remote_id }, 507 539 .. 508 540 } = msg; 509 - let conn = self.connections.get(&remote_id); 541 + let conn = { 542 + let mut state = state.lock().await; 543 + state.connections.get(&remote_id) 544 + }; 545 + let node_id = state.lock().await.router.endpoint().node_id(); 510 546 tx.send(()).await.ok(); 511 547 trace!(remote = %remote_id.fmt_short(), key = %key, "send rpc::Subscribe message"); 512 548 conn.rpc 513 549 .rpc(rpc::Subscribe { 514 550 key: key.clone(), 515 - remote_id: self.node_id(), 551 + remote_id: node_id, 516 552 }) 517 553 .await 518 554 .ok(); 519 - trace!(remote = %remote_id.fmt_short(), key = %key, "inserting subscription"); 520 555 self.subscriptions.insert(key, remote_id); 521 - trace!("finished inserting subscription"); 556 + tx.send(()).await.ok(); 522 557 } 523 558 ApiMessage::Unsubscribe(msg) => { 524 559 trace!(inner = ?msg.inner, "ApiMessage::Unsubscribe"); ··· 527 562 inner: api::Unsubscribe { key, remote_id }, 528 563 .. 529 564 } = msg; 530 - let conn = self.connections.get(&remote_id); 565 + let conn = { 566 + let mut state = state.lock().await; 567 + state.connections.get(&remote_id) 568 + }; 569 + let node_id = state.lock().await.router.endpoint().node_id(); 531 570 tx.send(()).await.ok(); 532 571 conn.rpc 533 572 .rpc(rpc::Unsubscribe { 534 573 key: key.clone(), 535 - remote_id: self.node_id(), 574 + remote_id: node_id, 536 575 }) 537 576 .await 538 577 .ok(); 539 - self.subscriptions.remove(&key); 578 + state.lock().await.subscriptions.remove(&key); 579 + tx.send(()).await.ok(); 540 580 } 541 581 ApiMessage::AddTickets(msg) => { 542 582 trace!(inner = ?msg.inner, "ApiMessage::AddTickets"); ··· 545 585 inner: api::AddTickets { peers }, 546 586 .. 547 587 } = msg; 588 + let state = state.lock().await; 548 589 for addr in &peers { 549 - self.sp.add_node_info(addr.clone()); 590 + state.sp.add_node_info(addr.clone()); 550 591 } 551 - // self.client.inner().join_peers(ids).await.ok(); 552 592 tx.send(()).await.ok(); 553 593 } 554 594 ApiMessage::JoinPeers(msg) => { ··· 558 598 inner: api::JoinPeers { peers }, 559 599 .. 560 600 } = msg; 601 + let state = state.lock().await; 602 + let node_id = state.router.endpoint().node_id(); 561 603 let ids = peers 562 604 .iter() 563 605 .map(|a| a.node_id) 564 - .filter(|id| *id != self.node_id()) 606 + .filter(|id| *id != node_id) 565 607 .collect::<HashSet<_>>(); 566 608 for addr in &peers { 567 - self.sp.add_node_info(addr.clone()); 609 + state.sp.add_node_info(addr.clone()); 568 610 } 569 - self.client.inner().join_peers(ids).await.ok(); 611 + let client = state.client.clone(); 612 + drop(state); 613 + client.inner().join_peers(ids).await.ok(); 570 614 tx.send(()).await.ok(); 571 615 } 572 616 ApiMessage::GetNodeAddr(msg) => { 573 617 trace!(inner = ?msg.inner, "ApiMessage::GetNodeAddr"); 574 618 let WithChannels { tx, .. } = msg; 575 - if !self.config.disable_relay { 576 - // don't await home relay if we have disabled relays, this will hang forever 577 - self.router.endpoint().online().await; 619 + let state = state.lock().await; 620 + if !state.config.disable_relay { 621 + state.router.endpoint().online().await; 578 622 } 579 - let addr = self.router.endpoint().node_addr(); 623 + let addr = state.router.endpoint().node_addr(); 580 624 tx.send(addr).await.ok(); 581 625 } 582 626 ApiMessage::Shutdown(msg) => { ··· 587 631 None 588 632 } 589 633 590 - fn handle_send( 591 - tasks: &mut Tasks, 592 - connections: &mut ConnectionPool, 593 - config: &Arc<Config>, 594 - key: String, 595 - data: Bytes, 596 - remotes: &BTreeSet<NodeId>, 597 - ) { 598 - let me = connections.endpoint.node_id(); 634 + fn handle_send(state: &mut ActorState, key: String, data: Bytes, remotes: &BTreeSet<NodeId>) { 635 + let me = state.connections.endpoint.node_id(); 599 636 let msg = rpc::RecvSegment { 600 637 key, 601 638 data, ··· 603 640 }; 604 641 for remote in remotes { 605 642 trace!(remote = %remote.fmt_short(), key = %msg.key, "handle_send to remote"); 606 - let conn = connections.get(remote); 607 - tasks.push(Box::pin(Self::forward_task( 608 - config.clone(), 643 + let conn = state.connections.get(remote); 644 + state.tasks.push(Box::pin(Self::forward_task( 645 + state.config.clone(), 609 646 conn, 610 647 msg.clone(), 611 648 ))); ··· 624 661 } 625 662 .await; 626 663 (id, res) 627 - } 628 - 629 - fn node_id(&self) -> PublicKey { 630 - self.router.endpoint().node_id() 631 664 } 632 665 } 633 666