Files
diplom/src/webrtc/mod.rs
2025-05-17 23:52:20 +03:00

328 lines
10 KiB
Rust

use std::sync::Arc;
use dashmap::DashMap;
use tracing::Instrument;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MIME_TYPE_OPUS;
use webrtc::api::{API, APIBuilder};
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
use webrtc::track::track_local::{TrackLocal, TrackLocalWriter};
use webrtc::track::track_remote::TrackRemote;
use crate::entity;
/// Identifier of a peer in a room; peers are users (`entity::user::Id`).
type PeerId = entity::user::Id;
/// Identifier of a room; a room is identified by the id of its channel.
type RoomId = entity::channel::Id;
/// Bookkeeping for one peer registered in `RoomState::peers`.
struct PeerState {
    /// Id of this peer; duplicates the map key so it is available while iterating values.
    peer_id: PeerId,
    /// The negotiated WebRTC connection to this peer.
    peer_connection: Arc<webrtc::peer_connection::RTCPeerConnection>,
    /// Local Opus track sent to this peer; audio from the other peers is written into it.
    outgoing_audio_track: Arc<TrackLocalStaticRTP>,
}
/// Shared state of one room, owned by `webrtc_task` and shared (via `Arc`) with the
/// per-peer tasks and WebRTC callbacks.
struct RoomState {
    /// Id of the room (the channel id).
    room_id: RoomId,
    /// Peers currently registered in the room, keyed by peer id.
    peers: DashMap<PeerId, PeerState>,
    /// Tells `webrtc_task` to stop; `cleanup_peer` sends on it when the room becomes empty.
    close_signal: tokio::sync::mpsc::UnboundedSender<()>,
}
/// An SDP offer from a peer that wants to join the room.
#[derive(Debug)]
pub struct Offer {
    /// The peer making the offer.
    pub peer_id: PeerId,
    /// The peer's session description (SDP offer).
    pub sdp_offer: RTCSessionDescription,
}
/// The server's reply to an `Offer`, delivered through `OfferSignal::response`.
#[derive(Debug)]
pub struct AnswerSignal {
    /// The server's session description (SDP answer) for the peer.
    pub sdp_answer: RTCSessionDescription,
}
/// Messages accepted by `webrtc_task` on its signalling channel.
#[derive(Debug)]
pub enum WebRtcSignal {
    /// A peer wants to join; the answer is sent back on the embedded response channel.
    Offer(OfferSignal),
    /// Remove the given peer from the room and close its connection.
    Disconnect(PeerId),
    /// Ask for the ids of the peers currently registered in the room.
    RequestPeers {
        /// Receives the list of peer ids.
        response: tokio::sync::oneshot::Sender<Vec<PeerId>>,
    },
    /// Stop the room's signalling loop (`webrtc_task` returns).
    Close,
}
/// An `Offer` together with the channel on which its answer is delivered.
#[derive(Debug)]
pub struct OfferSignal {
    /// The peer's offer.
    pub offer: Offer,
    /// Receives the SDP answer. If negotiation fails, this sender is dropped without a value.
    pub response: tokio::sync::oneshot::Sender<AnswerSignal>,
}
#[tracing::instrument(skip(signal))]
pub async fn webrtc_task(
room_id: RoomId,
signal: tokio::sync::mpsc::UnboundedReceiver<WebRtcSignal>,
) -> anyhow::Result<()> {
tracing::info!("Starting WebRTC task");
let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel();
let mut skip_timeout = false;
let state = Arc::new(RoomState {
room_id,
peers: DashMap::new(),
close_signal,
});
let mut signal = signal;
let mut media_engine = webrtc::api::media_engine::MediaEngine::default();
media_engine.register_default_codecs()?;
let mut registry = Registry::new();
registry = register_default_interceptors(registry, &mut media_engine)?;
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(registry)
.build();
let api = Arc::new(api);
loop {
tokio::select! {
biased;
_ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => {
tracing::debug!("initial timeout reached");
break;
}
_ = close_receiver.recv() => {
tracing::debug!("WebRTC task stopped");
break;
}
Some(signal) = signal.recv() => {
skip_timeout = true;
match signal {
WebRtcSignal::Offer(offer_signal) => {
let room_state = state.clone();
let api = api.clone();
tokio::spawn(async move {
if let Err(e) = handle_peer(api, room_state, offer_signal).await {
tracing::error!("error handling peer: {}", e);
}
}.instrument(tracing::Span::current()));
}
WebRtcSignal::RequestPeers { response } => {
let peers = state
.peers
.iter()
.map(|pair| pair.key().clone())
.collect::<Vec<_>>();
let _ = response.send(peers);
}
WebRtcSignal::Disconnect(peer_id) => {
tracing::debug!("received disconnect signal for peer {}", peer_id);
cleanup_peer(state.clone(), peer_id).await;
}
WebRtcSignal::Close => {
break;
}
}
}
}
}
Ok(())
}
#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id))]
async fn handle_peer(
api: Arc<API>,
room_state: Arc<RoomState>,
offer_signal: OfferSignal,
) -> anyhow::Result<()> {
tracing::debug!("handling peer");
let config = RTCConfiguration {
..Default::default()
};
let outgoing_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: MIME_TYPE_OPUS.to_string(),
..Default::default()
},
format!("audio-{}", offer_signal.offer.peer_id),
format!("room-{}", room_state.room_id),
));
let peer_connection = Arc::new(api.new_peer_connection(config).await?);
let rtp_sender = peer_connection
.add_track(Arc::clone(&outgoing_track) as Arc<dyn TrackLocal + Send + Sync>)
.await?;
tracing::debug!("added track to peer connection: {:?}", rtp_sender);
// Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI)
let outgoing_track_ = Arc::clone(&outgoing_track);
tokio::spawn(
async move {
let mut rtcp_buf = vec![0u8; 1500];
while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {
// Process RTCP if needed (e.g., bandwidth estimation, custom feedback)
}
tracing::debug!(
"RTCP reader loop for outgoing track {} ended",
outgoing_track_.id()
);
}
.instrument(tracing::Span::current()),
);
let room_state_ = room_state.clone();
peer_connection.on_peer_connection_state_change(Box::new(move |state| {
let room_state_ = room_state_.clone();
Box::pin(async move {
tracing::debug!("peer connection state changed: {:?}", state);
if state == RTCPeerConnectionState::Disconnected
|| state == RTCPeerConnectionState::Failed
{
tracing::debug!("peer connection closed");
cleanup_peer(room_state_, offer_signal.offer.peer_id).await;
}
})
}));
let room_state_ = room_state.clone();
peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| {
tracing::debug!("track received: {:?}", track);
let room_state_ = room_state_.clone();
Box::pin(async move {
if track.kind() == RTPCodecType::Audio {
tracing::debug!("audio track received: {:?}", track);
tokio::spawn(async move {
if let Err(e) =
forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await
{
tracing::error!("error forwarding audio track: {}", e);
}
});
} else {
tracing::warn!("received non-audio track: {:?}", track);
}
})
}));
peer_connection
.set_remote_description(offer_signal.offer.sdp_offer)
.await?;
let answer = peer_connection.create_answer(None).await?;
let mut gathering_complete = peer_connection.gathering_complete_promise().await;
peer_connection.set_local_description(answer).await?;
gathering_complete.recv().await;
tracing::debug!("ICE gathering complete");
let local_description = peer_connection
.local_description()
.await
.ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?;
let peer_state = PeerState {
peer_id: offer_signal.offer.peer_id,
peer_connection: Arc::clone(&peer_connection),
outgoing_audio_track: Arc::clone(&outgoing_track),
};
{
if let Some(old_peer) = room_state
.peers
.insert(offer_signal.offer.peer_id, peer_state)
{
let _ = old_peer.peer_connection.close().await;
}
}
let _ = offer_signal.response.send(AnswerSignal {
sdp_answer: local_description,
});
Ok(())
}
/// Forwards every RTP packet read from `track` (sent by `peer_id`) to the outgoing audio
/// track of every other peer currently registered in the room.
///
/// Runs until reading from `track` returns an error. That error is logged at debug level
/// and the function returns `Ok(())`. A failed write to one peer's track is logged and does
/// not stop forwarding.
///
/// NOTE(review): `TrackLocalStaticRTP::write_rtp` appears to rewrite only the SSRC and
/// payload type. With three or more peers, packets from several senders would then be
/// interleaved on one outgoing track with unrelated sequence numbers and timestamps.
/// Confirm receivers cope with this, or use one outgoing track per sender.
///
/// # Errors
/// Currently never returns an error; the `Result` is kept for callers.
#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id))]
async fn forward_audio_track(
    room_state: Arc<RoomState>,
    peer_id: PeerId,
    track: Arc<TrackRemote>,
) -> anyhow::Result<()> {
    let mut rtp_buf = vec![0u8; 1500];
    loop {
        let rtp_packet = match track.read(&mut rtp_buf).await {
            Ok((packet, _attr)) => packet,
            Err(e) => {
                tracing::debug!(
                    "RTP Read loop ended for track {} from peer {}: {}",
                    track.id(),
                    peer_id,
                    e
                );
                break;
            }
        };

        // Snapshot the target tracks so no DashMap shard guard is held across `.await`.
        let other_peer_tracks = room_state
            .peers
            .iter()
            .filter(|pair| pair.value().peer_id != peer_id)
            .map(|pair| Arc::clone(&pair.value().outgoing_audio_track))
            .collect::<Vec<_>>();
        if other_peer_tracks.is_empty() {
            continue;
        }

        let write_futures = other_peer_tracks
            .iter()
            .map(|outgoing_track| outgoing_track.write_rtp(&rtp_packet));
        let results = futures::future::join_all(write_futures).await;
        for e in results.into_iter().filter_map(Result::err) {
            tracing::error!("error writing RTP packet: {}", e);
        }
    }
    Ok(())
}
#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))]
async fn cleanup_peer(room_state: Arc<RoomState>, peer_id: PeerId) {
tracing::debug!("cleaning up peer");
if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) {
tracing::debug!("removed peer");
let pc = Arc::clone(&peer_state.peer_connection);
tokio::spawn(async move {
if let Err(e) = pc.close().await {
if !matches!(e, webrtc::Error::ErrConnectionClosed) {
tracing::warn!("error closing peer connection: {}", e);
}
}
tracing::debug!("peer connection closed");
});
}
if room_state.peers.is_empty() {
tracing::debug!("no more peers in room, closing room");
let _ = room_state.close_signal.send(());
}
}