328 lines
10 KiB
Rust
328 lines
10 KiB
Rust
use std::sync::Arc;
|
|
|
|
use dashmap::DashMap;
|
|
use tracing::Instrument;
|
|
use webrtc::api::interceptor_registry::register_default_interceptors;
|
|
use webrtc::api::media_engine::MIME_TYPE_OPUS;
|
|
use webrtc::api::{API, APIBuilder};
|
|
use webrtc::interceptor::registry::Registry;
|
|
use webrtc::peer_connection::configuration::RTCConfiguration;
|
|
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
|
|
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
|
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
|
|
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
|
|
use webrtc::track::track_local::{TrackLocal, TrackLocalWriter};
|
|
use webrtc::track::track_remote::TrackRemote;
|
|
|
|
use crate::entity;
|
|
|
|
type PeerId = entity::user::Id;
|
|
type RoomId = entity::channel::Id;
|
|
|
|
struct PeerState {
|
|
peer_id: PeerId,
|
|
peer_connection: Arc<webrtc::peer_connection::RTCPeerConnection>,
|
|
outgoing_audio_track: Arc<TrackLocalStaticRTP>,
|
|
}
|
|
|
|
struct RoomState {
|
|
room_id: RoomId,
|
|
peers: DashMap<PeerId, PeerState>,
|
|
close_signal: tokio::sync::mpsc::UnboundedSender<()>,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct Offer {
|
|
pub peer_id: PeerId,
|
|
pub sdp_offer: RTCSessionDescription,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct AnswerSignal {
|
|
pub sdp_answer: RTCSessionDescription,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub enum WebRtcSignal {
|
|
Offer(OfferSignal),
|
|
Disconnect(PeerId),
|
|
RequestPeers {
|
|
response: tokio::sync::oneshot::Sender<Vec<PeerId>>,
|
|
},
|
|
Close,
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
pub struct OfferSignal {
|
|
pub offer: Offer,
|
|
pub response: tokio::sync::oneshot::Sender<AnswerSignal>,
|
|
}
|
|
|
|
#[tracing::instrument(skip(signal))]
|
|
pub async fn webrtc_task(
|
|
room_id: RoomId,
|
|
signal: tokio::sync::mpsc::UnboundedReceiver<WebRtcSignal>,
|
|
) -> anyhow::Result<()> {
|
|
tracing::info!("Starting WebRTC task");
|
|
|
|
let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel();
|
|
|
|
let mut skip_timeout = false;
|
|
|
|
let state = Arc::new(RoomState {
|
|
room_id,
|
|
peers: DashMap::new(),
|
|
close_signal,
|
|
});
|
|
|
|
let mut signal = signal;
|
|
let mut media_engine = webrtc::api::media_engine::MediaEngine::default();
|
|
media_engine.register_default_codecs()?;
|
|
let mut registry = Registry::new();
|
|
registry = register_default_interceptors(registry, &mut media_engine)?;
|
|
|
|
let api = APIBuilder::new()
|
|
.with_media_engine(media_engine)
|
|
.with_interceptor_registry(registry)
|
|
.build();
|
|
|
|
let api = Arc::new(api);
|
|
|
|
loop {
|
|
tokio::select! {
|
|
biased;
|
|
_ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => {
|
|
tracing::debug!("initial timeout reached");
|
|
break;
|
|
}
|
|
_ = close_receiver.recv() => {
|
|
tracing::debug!("WebRTC task stopped");
|
|
break;
|
|
}
|
|
Some(signal) = signal.recv() => {
|
|
skip_timeout = true;
|
|
match signal {
|
|
WebRtcSignal::Offer(offer_signal) => {
|
|
let room_state = state.clone();
|
|
let api = api.clone();
|
|
|
|
tokio::spawn(async move {
|
|
if let Err(e) = handle_peer(api, room_state, offer_signal).await {
|
|
tracing::error!("error handling peer: {}", e);
|
|
}
|
|
}.instrument(tracing::Span::current()));
|
|
}
|
|
WebRtcSignal::RequestPeers { response } => {
|
|
let peers = state
|
|
.peers
|
|
.iter()
|
|
.map(|pair| pair.key().clone())
|
|
.collect::<Vec<_>>();
|
|
|
|
let _ = response.send(peers);
|
|
}
|
|
WebRtcSignal::Disconnect(peer_id) => {
|
|
tracing::debug!("received disconnect signal for peer {}", peer_id);
|
|
cleanup_peer(state.clone(), peer_id).await;
|
|
}
|
|
WebRtcSignal::Close => {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id))]
|
|
async fn handle_peer(
|
|
api: Arc<API>,
|
|
room_state: Arc<RoomState>,
|
|
offer_signal: OfferSignal,
|
|
) -> anyhow::Result<()> {
|
|
tracing::debug!("handling peer");
|
|
|
|
let config = RTCConfiguration {
|
|
..Default::default()
|
|
};
|
|
|
|
let outgoing_track = Arc::new(TrackLocalStaticRTP::new(
|
|
RTCRtpCodecCapability {
|
|
mime_type: MIME_TYPE_OPUS.to_string(),
|
|
..Default::default()
|
|
},
|
|
format!("audio-{}", offer_signal.offer.peer_id),
|
|
format!("room-{}", room_state.room_id),
|
|
));
|
|
|
|
let peer_connection = Arc::new(api.new_peer_connection(config).await?);
|
|
|
|
let rtp_sender = peer_connection
|
|
.add_track(Arc::clone(&outgoing_track) as Arc<dyn TrackLocal + Send + Sync>)
|
|
.await?;
|
|
|
|
tracing::debug!("added track to peer connection: {:?}", rtp_sender);
|
|
|
|
// Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI)
|
|
let outgoing_track_ = Arc::clone(&outgoing_track);
|
|
tokio::spawn(
|
|
async move {
|
|
let mut rtcp_buf = vec![0u8; 1500];
|
|
while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {
|
|
// Process RTCP if needed (e.g., bandwidth estimation, custom feedback)
|
|
}
|
|
tracing::debug!(
|
|
"RTCP reader loop for outgoing track {} ended",
|
|
outgoing_track_.id()
|
|
);
|
|
}
|
|
.instrument(tracing::Span::current()),
|
|
);
|
|
|
|
let room_state_ = room_state.clone();
|
|
peer_connection.on_peer_connection_state_change(Box::new(move |state| {
|
|
let room_state_ = room_state_.clone();
|
|
Box::pin(async move {
|
|
tracing::debug!("peer connection state changed: {:?}", state);
|
|
if state == RTCPeerConnectionState::Disconnected
|
|
|| state == RTCPeerConnectionState::Failed
|
|
{
|
|
tracing::debug!("peer connection closed");
|
|
cleanup_peer(room_state_, offer_signal.offer.peer_id).await;
|
|
}
|
|
})
|
|
}));
|
|
|
|
let room_state_ = room_state.clone();
|
|
peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| {
|
|
tracing::debug!("track received: {:?}", track);
|
|
|
|
let room_state_ = room_state_.clone();
|
|
Box::pin(async move {
|
|
if track.kind() == RTPCodecType::Audio {
|
|
tracing::debug!("audio track received: {:?}", track);
|
|
tokio::spawn(async move {
|
|
if let Err(e) =
|
|
forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await
|
|
{
|
|
tracing::error!("error forwarding audio track: {}", e);
|
|
}
|
|
});
|
|
} else {
|
|
tracing::warn!("received non-audio track: {:?}", track);
|
|
}
|
|
})
|
|
}));
|
|
|
|
peer_connection
|
|
.set_remote_description(offer_signal.offer.sdp_offer)
|
|
.await?;
|
|
let answer = peer_connection.create_answer(None).await?;
|
|
|
|
let mut gathering_complete = peer_connection.gathering_complete_promise().await;
|
|
|
|
peer_connection.set_local_description(answer).await?;
|
|
|
|
gathering_complete.recv().await;
|
|
|
|
tracing::debug!("ICE gathering complete");
|
|
|
|
let local_description = peer_connection
|
|
.local_description()
|
|
.await
|
|
.ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?;
|
|
|
|
let peer_state = PeerState {
|
|
peer_id: offer_signal.offer.peer_id,
|
|
peer_connection: Arc::clone(&peer_connection),
|
|
outgoing_audio_track: Arc::clone(&outgoing_track),
|
|
};
|
|
|
|
{
|
|
if let Some(old_peer) = room_state
|
|
.peers
|
|
.insert(offer_signal.offer.peer_id, peer_state)
|
|
{
|
|
let _ = old_peer.peer_connection.close().await;
|
|
}
|
|
}
|
|
|
|
let _ = offer_signal.response.send(AnswerSignal {
|
|
sdp_answer: local_description,
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id))]
|
|
async fn forward_audio_track(
|
|
room_state: Arc<RoomState>,
|
|
peer_id: PeerId,
|
|
track: Arc<TrackRemote>,
|
|
) -> anyhow::Result<()> {
|
|
let mut rtp_buf = vec![0u8; 1500];
|
|
while let Ok((rtp_packet, _attr)) = track.read(&mut rtp_buf).await {
|
|
let other_peer_tracks = room_state
|
|
.peers
|
|
.iter()
|
|
.filter_map(|pair| {
|
|
let peer_state = pair.value();
|
|
if peer_state.peer_id != peer_id {
|
|
Some(peer_state.outgoing_audio_track.clone())
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
if other_peer_tracks.is_empty() {
|
|
// tracing::warn!("no other peers to forward audio track to");
|
|
continue;
|
|
}
|
|
|
|
let write_futures = other_peer_tracks
|
|
.iter()
|
|
.map(|outgoing_track| outgoing_track.write_rtp(&rtp_packet));
|
|
|
|
let results = futures::future::join_all(write_futures).await;
|
|
for result in results {
|
|
if let Err(e) = result {
|
|
tracing::error!("error writing RTP packet: {}", e);
|
|
}
|
|
}
|
|
}
|
|
tracing::debug!(
|
|
"RTP Read loop ended for track {} from peer {}",
|
|
track.id(),
|
|
peer_id
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))]
|
|
async fn cleanup_peer(room_state: Arc<RoomState>, peer_id: PeerId) {
|
|
tracing::debug!("cleaning up peer");
|
|
|
|
if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) {
|
|
tracing::debug!("removed peer");
|
|
let pc = Arc::clone(&peer_state.peer_connection);
|
|
tokio::spawn(async move {
|
|
if let Err(e) = pc.close().await {
|
|
if !matches!(e, webrtc::Error::ErrConnectionClosed) {
|
|
tracing::warn!("error closing peer connection: {}", e);
|
|
}
|
|
}
|
|
|
|
tracing::debug!("peer connection closed");
|
|
});
|
|
}
|
|
|
|
if room_state.peers.is_empty() {
|
|
tracing::debug!("no more peers in room, closing room");
|
|
let _ = room_state.close_signal.send(());
|
|
}
|
|
}
|