diff --git a/Cargo.lock b/Cargo.lock index fa501b6..e70671b 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,7 +23,7 @@ version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" dependencies = [ - "crypto-common", + "crypto-common 0.1.6", "generic-array", ] @@ -107,9 +107,9 @@ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" [[package]] name = "argon2" -version = "0.5.3" +version = "0.6.0-pre.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072" +checksum = "8f95281c5706985b6c00f8a2270438f968d475672aa68a4a85cddcb57a68577b" dependencies = [ "base64ct", "blake2", @@ -405,11 +405,11 @@ dependencies = [ [[package]] name = "blake2" -version = "0.10.6" +version = "0.11.0-pre.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe" +checksum = "e6dbf347378982186052c47f25f33fc1a6eb439ee840d778eb3ec132e304379d" dependencies = [ - "digest", + "digest 0.11.0-pre.9", ] [[package]] @@ -421,6 +421,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "block-buffer" +version = "0.11.0-rc.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a229bfd78e4827c91b9b95784f69492c1b77c1ab75a45a8a037b139215086f94" +dependencies = [ + "hybrid-array", +] + [[package]] name = "block-padding" version = "0.3.3" @@ -556,7 +565,7 @@ version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" dependencies = [ - "crypto-common", + "crypto-common 0.1.6", "inout", ] @@ -715,6 +724,15 @@ dependencies = [ "typenum", ] +[[package]] +name = "crypto-common" +version = "0.2.0-rc.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "170d71b5b14dec99db7739f6fc7d6ec2db80b78c3acb77db48392ccc3d8a9ea0" +dependencies = [ + "hybrid-array", +] + [[package]] name = "ctr" version = "0.9.2" @@ -867,9 +885,20 @@ version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ - "block-buffer", + "block-buffer 0.10.4", "const-oid", - "crypto-common", + "crypto-common 0.1.6", + "subtle", +] + +[[package]] +name = "digest" +version = "0.11.0-pre.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf2e3d6615d99707295a9673e889bf363a04b2a466bd320c65a72536f7577379" +dependencies = [ + "block-buffer 0.11.0-rc.4", + "crypto-common 0.2.0-rc.2", "subtle", ] @@ -889,12 +918,11 @@ dependencies = [ "dashmap", "derive_more", "futures", - "hex", "jsonwebtoken", "mime", "rand 0.9.1", - "rand_core 0.6.4", "regex", + "scc", "serde", "serde_json", "sha2", @@ -943,7 +971,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca" dependencies = [ "der", - "digest", + "digest 0.10.7", "elliptic-curve", "rfc6979", "signature", @@ -967,7 +995,7 @@ checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47" dependencies = [ "base16ct", "crypto-bigint", - "digest", + "digest 0.10.7", "ff", "generic-array", "group", @@ -1336,7 +1364,7 @@ version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" dependencies = [ - "digest", + "digest 0.10.7", ] [[package]] @@ -1394,6 +1422,15 @@ version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9" +[[package]] +name = "hybrid-array" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "891d15931895091dea5c47afa5b3c9a01ba634b311919fd4d41388fa0e3d76af" +dependencies = [ + "typenum", +] + [[package]] name = "hyper" version = "1.6.0" @@ -1588,9 +1625,9 @@ dependencies = [ [[package]] name = "interceptor" -version = "0.13.0" +version = "0.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e5ab04c530fd82e414e40394cabe5f0ebfe30d119f10fe29d6e3561926af412e" +checksum = "1ac0781c825d602095113772e389ef0607afcb869ae0e68a590d8e0799cdcef8" dependencies = [ "async-trait", "bytes", @@ -1735,7 +1772,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" dependencies = [ "cfg-if", - "digest", + "digest 0.10.7", ] [[package]] @@ -1999,9 +2036,9 @@ dependencies = [ [[package]] name = "password-hash" -version = "0.5.0" +version = "0.6.0-rc.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166" +checksum = "ec3b470a56963403c40f9dbb41eaee539759de9d026d3324da705a0ae0d269cd" dependencies = [ "base64ct", "rand_core 0.6.4", @@ -2462,7 +2499,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78928ac1ed176a5ca1d17e578a1825f3d81ca54cf41053a592584b020cfd691b" dependencies = [ "const-oid", - "digest", + "digest 0.10.7", "num-bigint-dig", "num-integer", "num-traits", @@ -2477,9 +2514,9 @@ dependencies = [ [[package]] name = "rtcp" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8306430fb118b7834bbee50e744dc34826eca1da2158657a3d6cbc70e24c2096" +checksum = "e9689528bf3a9eb311fd938d05516dd546412f9ce4fffc8acfc1db27cc3dbf72" dependencies = [ "bytes", "thiserror 1.0.69", @@ -2488,9 +2525,9 @@ dependencies = [ [[package]] name = "rtp" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e68baca5b6cb4980678713f0d06ef3a432aa642baefcbfd0f4dd2ef9eb5ab550" +checksum = "c54733451a67d76caf9caa07a7a2cec6871ea9dda92a7847f98063d459200f4b" dependencies = [ "bytes", "memchr", @@ -2611,6 +2648,15 @@ version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f" +[[package]] +name = "scc" +version = "2.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22b2d775fb28f245817589471dd49c5edf64237f4a19d10ce9a92ff4651a27f4" +dependencies = [ + "sdd", +] + [[package]] name = "scopeguard" version = "1.2.0" @@ -2618,10 +2664,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] -name = "sdp" -version = "0.7.0" +name = "sdd" +version = "3.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "02a526161f474ae94b966ba622379d939a8fe46c930eebbadb73e339622599d5" +checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21" + +[[package]] +name = "sdp" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4cd277015eada44a0bb810a4b84d3bf6e810573fa62fb442f457edf6a1087a69" dependencies = [ "rand 0.8.5", "substring", @@ -2726,7 +2778,7 @@ checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", - "digest", + "digest 0.10.7", ] [[package]] @@ -2737,7 +2789,7 @@ checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" dependencies = [ "cfg-if", "cpufeatures", - "digest", + "digest 0.10.7", ] [[package]] @@ -2770,7 +2822,7 @@ version = "2.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de" dependencies = [ - "digest", + "digest 0.10.7", "rand_core 0.6.4", ] @@ -2949,7 +3001,7 @@ dependencies = [ "bytes", "chrono", "crc", - "digest", + "digest 0.10.7", "dotenvy", "either", "futures-channel", @@ -3070,9 +3122,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" [[package]] name = "stun" -version = "0.7.0" +version = "0.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ea256fb46a13f9204e9dee9982997b2c3097db175a9fddaa8350310d03c4d5a3" +checksum = "7dbc2bab375524093c143dc362a03fb6a1fb79e938391cdb21665688f88a088a" dependencies = [ "base64 0.22.1", "crc", @@ -3519,9 +3571,9 @@ dependencies = [ [[package]] name = "turn" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0044fdae001dd8a1e247ea6289abf12f4fcea1331a2364da512f9cd680bbd8cb" +checksum = "3f5aea1116456e1da71c45586b87c72e3b43164fbf435eb93ff6aa475416a9a4" dependencies = [ "async-trait", "base64 0.22.1", @@ -3601,7 +3653,7 @@ version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" dependencies = [ - "crypto-common", + "crypto-common 0.1.6", "subtle", ] @@ -3784,9 +3836,9 @@ dependencies = [ [[package]] name = "webrtc" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "30367074d9f18231d28a74fab0120856b2b665da108d71a12beab7185a36f97b" +checksum = "24bab7195998d605c862772f90a452ba655b90a2f463c850ac032038890e367a" dependencies = [ "arc-swap", "async-trait", @@ -3828,9 +3880,9 @@ dependencies = [ [[package]] name = "webrtc-data" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dec93b991efcd01b73c5b3503fa8adba159d069abe5785c988ebe14fcf8f05d1" +checksum = "4e97b932854da633a767eff0cc805425a2222fc6481e96f463e57b015d949d1d" dependencies = [ "bytes", "log", @@ -3843,9 +3895,9 @@ dependencies = [ [[package]] name = "webrtc-dtls" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b7c9b89fc909f9da0499283b1112cd98f72fec28e55a54a9e352525ca65cd95c" +checksum = "5ccbe4d9049390ab52695c3646c1395c877e16c15fb05d3bda8eee0c7351711c" dependencies = [ "aes", "aes-gcm", @@ -3880,9 +3932,9 @@ dependencies = [ [[package]] name = "webrtc-ice" -version = "0.12.0" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0348b28b593f7709ac98d872beb58c0009523df652c78e01b950ab9c537ff17d" +checksum = "eb51bde0d790f109a15bfe4d04f1b56fb51d567da231643cb3f21bb74d678997" dependencies = [ "arc-swap", "async-trait", @@ -3905,9 +3957,9 @@ dependencies = [ [[package]] name = "webrtc-mdns" -version = "0.8.0" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e6dfe9686c6c9c51428da4de415cb6ca2dc0591ce2b63212e23fd9cccf0e316b" +checksum = "979cc85259c53b7b620803509d10d35e2546fa505d228850cbe3f08765ea6ea8" dependencies = [ "log", "socket2", @@ -3918,9 +3970,9 @@ dependencies = [ [[package]] name = "webrtc-media" -version = "0.9.0" +version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e153be16b8650021ad3e9e49ab6e5fa9fb7f6d1c23c213fd8bbd1a1135a4c704" +checksum = "80041211deccda758a3e19aa93d6b10bc1d37c9183b519054b40a83691d13810" dependencies = [ "byteorder", "bytes", @@ -3931,9 +3983,9 @@ dependencies = [ [[package]] name = "webrtc-sctp" -version = "0.11.0" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5faf3846ec4b7e64b56338d62cbafe084aa79806b0379dff5cc74a8b7a2b3063" +checksum = "07439c134425d51d2f10907aaf2f815fdfb587dce19fe94a4ae8b5faf2aae5ae" dependencies = [ "arc-swap", "async-trait", @@ -3949,9 +4001,9 @@ dependencies = [ [[package]] name = "webrtc-srtp" -version = "0.14.0" +version = "0.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "771db9993712a8fb3886d5be4613ebf27250ef422bd4071988bf55f1ed1a64fa" +checksum = "01e773f79b09b057ffbda6b03fe7b43403b012a240cf8d05d630674c3723b5bb" dependencies = [ "aead", "aes", @@ -3972,9 +4024,9 @@ dependencies = [ [[package]] name = "webrtc-util" -version = "0.10.0" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1438a8fd0d69c5775afb4a71470af92242dbd04059c61895163aa3c1ef933375" +checksum = "64bfb10dbe6d762f80169ae07cf252bafa1f764b9594d140008a0231c0cdce58" dependencies = [ "async-trait", "bitflags 1.3.2", diff --git a/Cargo.toml b/Cargo.toml index b5ddacc..c6f8d42 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -5,14 +5,13 @@ edition = "2024" [dependencies] anyhow = "1.0" -argon2 = "0.5" +argon2 = "0.6.0-pre.1" axum = { version = "0.8", features = ["ws", "multipart", "macros"] } axum-extra = { version = "0.10", features = ["typed-header"] } chrono = { version = "0.4", features = ["serde"] } config = "0.15" derive_more = { version = "2.0", features = ["full"] } jsonwebtoken = "9.3" -rand_core = { version = "0.6", features = ["getrandom"] } serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sqlx = { version = "0.8", features = ["postgres", "runtime-tokio", "uuid", "chrono"] } @@ -23,15 +22,15 @@ tracing-appender = "0.2" tracing-subscriber = { version = "0.3", features = ["serde", "env-filter", "chrono"] } uuid = { version = "1.16", features = ["fast-rng", "serde", "v7"] } url = { version = "2.5", features = ["serde"] } -validator = { version = "0.20.0", features = ["derive"] } -regex = "1.11.1" -mime = "0.3.17" -axum_typed_multipart = "0.16.2" -async-trait = "0.1.88" -futures = "0.3.31" -webrtc = "0.12.0" -dashmap = "6.1.0" -rand = "0.9.1" -sha2 = "0.10.9" -hex = "0.4.3" -base64 = "0.22.1" +validator = { version = "0.20", features = ["derive"] } +regex = "1.11" +mime = "0.3" +axum_typed_multipart = "0.16" +async-trait = "0.1" +futures = "0.3" +webrtc = "0.13" +dashmap = "6.1" +rand = "0.9" +sha2 = "0.10" +base64 = "0.22" +scc = "2.3" diff --git a/src/config.rs b/src/config.rs index dd63430..26e0240 100644 --- a/src/config.rs +++ b/src/config.rs @@ -18,11 +18,30 @@ pub fn config() -> &'static Config { #[derive(Deserialize)] pub struct Config { - pub port: u16, - pub jwt_secret: String, + pub server: ServerConfig, + pub security: SecurityConfig, + pub gateway: GatewayConfig, pub database: DatabaseConfig, } +#[derive(Deserialize)] +pub struct ServerConfig { + pub host: std::net::Ipv4Addr, + pub port: u16, +} + +#[derive(Deserialize)] +pub struct SecurityConfig { + pub auth_secret: String, + pub voice_secret: String, +} + +#[derive(Deserialize)] +pub struct GatewayConfig { + #[serde(deserialize_with = "crate::util::deserialize_duration_seconds")] + pub voice_token_lifetime: std::time::Duration, +} + #[derive(Debug, Deserialize)] #[serde(untagged)] #[serde(deny_unknown_fields)] diff --git a/src/jwt.rs b/src/jwt.rs index ca373ee..d569f6b 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -10,20 +10,21 @@ use crate::config; #[derive(Debug, Serialize, Deserialize)] pub struct Claims { - pub user_id: T, + #[serde(flatten)] + pub data: T, pub iat: i64, } -pub fn generate_jwt(user_id: T) -> Result { +pub fn generate_jwt(data: T) -> Result { let claims = Claims { - user_id, + data, iat: Utc::now().timestamp_millis(), }; let token = jsonwebtoken::encode( &jsonwebtoken::Header::default(), &claims, - &jsonwebtoken::EncodingKey::from_secret(config::config().jwt_secret.as_ref()), + &jsonwebtoken::EncodingKey::from_secret(config::config().security.auth_secret.as_ref()), ) .map_err(|_| Error::CouldNotEncodeToken)?; @@ -34,16 +35,16 @@ pub fn verify_jwt(token: &str) -> Result { tracing::debug!("verifying token: {}", token); let mut validation = jsonwebtoken::Validation::default(); - validation.required_spec_claims = HashSet::new(); + validation.set_required_spec_claims::(&[]); let token_data = jsonwebtoken::decode::>( token, - &jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()), + &jsonwebtoken::DecodingKey::from_secret(config::config().security.auth_secret.as_ref()), &validation, ) .map_err(|_| Error::CouldNotVerifyToken)?; - Ok(token_data.claims.user_id) + Ok(token_data.claims.data) } pub type Result = std::result::Result; diff --git a/src/main.rs b/src/main.rs index 189bf60..03d472c 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,8 +1,6 @@ -use std::collections::HashMap; use std::sync::Arc; use argon2::Argon2; -use tokio::sync::RwLock; use crate::database::Database; use crate::state::AppState; @@ -25,8 +23,8 @@ async fn main() -> anyhow::Result<()> { let state = AppState { database, hasher: Arc::new(Argon2::default()), - event_connected_users: Arc::new(RwLock::new(HashMap::new())), - voice_rooms: Arc::new(RwLock::new(HashMap::new())), + gateway_state: Default::default(), + voice_rooms: Default::default(), }; web::run(state).await?; diff --git a/src/state.rs b/src/state.rs index 982e369..9074860 100644 --- a/src/state.rs +++ b/src/state.rs @@ -2,12 +2,11 @@ use std::collections::HashMap; use std::sync::Arc; use argon2::Argon2; -use dashmap::DashMap; use tokio::sync::{RwLock, mpsc}; use uuid::Uuid; use crate::database::Database; -use crate::web::ws::gateway::{EventWsState, message}; +use crate::web::ws::gateway::{GatewayWsState, SessionKey, event}; use crate::webrtc::OfferSignal; #[derive(Clone)] @@ -15,37 +14,50 @@ pub struct AppState { pub database: Database, pub hasher: Arc>, - pub event_connected_users: Arc>>, + pub gateway_state: Arc, pub voice_rooms: Arc>>>, } +#[derive(Debug, Default)] +pub struct GatewayState { + pub connected: scc::HashMap, +} + impl AppState { - pub async fn register_event_connected_user( + pub async fn register_gateway_connected_user( &self, user_id: Uuid, - session_id: String, - event_sender: mpsc::UnboundedSender, + session_key: SessionKey, + event_sender: mpsc::UnboundedSender, ) { - let mut connected_users = self.event_connected_users.write().await; - if let Some(state) = connected_users.get_mut(&user_id) { - state.connection_instance.insert(session_id, event_sender); - } else { - let state = EventWsState { - connection_instance: DashMap::new(), - }; - state.connection_instance.insert(session_id, event_sender); - connected_users.insert(user_id, state); - } + self.gateway_state + .connected + .entry_async(user_id) + .await + .or_default() + .get_mut() + .instances + .insert(session_key, event_sender); } - pub async fn unregister_event_connected_user(&self, user_id: Uuid, session_id: &str) { - let mut connected_users = self.event_connected_users.write().await; - if let Some(state) = connected_users.get_mut(&user_id) { - state.connection_instance.remove(session_id); - if state.connection_instance.is_empty() { - connected_users.remove(&user_id); - } + pub async fn unregister_gateway_connected_user(&self, user_id: Uuid, session_id: &SessionKey) { + let is_empty = { + let mut entry = self + .gateway_state + .connected + .entry_async(user_id) + .await + .or_default(); + + let entry = entry.get_mut(); + + entry.instances.remove(session_id); + entry.instances.is_empty() + }; + + if is_empty { + self.gateway_state.connected.remove_async(&user_id).await; } } } diff --git a/src/util.rs b/src/util.rs index 00bb8f0..3feeade 100644 --- a/src/util.rs +++ b/src/util.rs @@ -1,6 +1,6 @@ use axum::extract::multipart::Field; use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError}; -use serde::Serialize; +use serde::{Deserialize, Serialize}; #[derive(Debug, derive_more::Deref)] pub struct SerdeFieldData(pub FieldData); @@ -50,3 +50,13 @@ where let seconds = duration.as_secs(); seconds.serialize(serializer) } + +pub fn deserialize_duration_seconds<'de, D>( + deserializer: D, +) -> Result +where + D: serde::Deserializer<'de>, +{ + let seconds = u64::deserialize(deserializer)?; + Ok(std::time::Duration::from_secs(seconds)) +} diff --git a/src/web/context.rs b/src/web/context.rs index b256ba5..19288a5 100644 --- a/src/web/context.rs +++ b/src/web/context.rs @@ -3,7 +3,7 @@ use axum::http::request::Parts; use crate::entity; -#[derive(Debug, Copy, Clone)] +#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)] pub struct UserContext { pub user_id: entity::user::Id, } diff --git a/src/web/middleware/auth.rs b/src/web/middleware/auth.rs index 0373760..e8604c0 100644 --- a/src/web/middleware/auth.rs +++ b/src/web/middleware/auth.rs @@ -9,6 +9,7 @@ use axum_extra::typed_header::TypedHeaderRejectionReason; use crate::jwt; use crate::state::AppState; +use crate::web::context::UserContext; use crate::web::{self, context}; pub async fn require_context( @@ -51,13 +52,13 @@ async fn get_context(state: &AppState, request: &mut Request) -> context::UserCo } pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult { - let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?; + let context = jwt::verify_jwt::(token).map_err(|_| context::Error::BadToken)?; let _ = state .database - .select_user_by_id(user_id) + .select_user_by_id(context.user_id) .await .map_err(|_| context::Error::BadToken)?; - Ok(context::UserContext { user_id }) + Ok(context) } diff --git a/src/web/mod.rs b/src/web/mod.rs index 5f2a826..0044361 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -12,9 +12,9 @@ use crate::{config, state}; pub async fn run(state: state::AppState) -> anyhow::Result<()> { let config = config::config(); - let addr: std::net::SocketAddr = ([127, 0, 0, 1], config.port).into(); - tracing::info!("listening on {}", addr); + let addr: std::net::SocketAddr = (config.server.host, config.server.port).into(); let listener = tokio::net::TcpListener::bind(addr).await?; + tracing::info!("listening on {}", addr); axum::serve(listener, router(state)) .with_graceful_shutdown(shutdown_signal()) @@ -36,10 +36,12 @@ fn router(state: state::AppState) -> axum::Router { Router::new() // websocket + .route("/gateway/ws", get(ws::gateway::ws_handler)) + .route("/voice/ws", get(ws::voice::ws_handler)) + // api .nest( "/api/v1", Router::new() - .route("/ws", get(ws::gateway::ws_handler)) .route("/auth/login", post(auth::login)) .route("/auth/register", post(auth::register)) .merge(protected_router()) @@ -48,6 +50,7 @@ fn router(state: state::AppState) -> axum::Router { middleware::resolve_context, )), ) + // middleware .layer(axum::middleware::from_fn(middleware::response_map)) .layer(cors) .layer(tower_http::trace::TraceLayer::new_for_http()) @@ -70,7 +73,6 @@ fn protected_router() -> axum::Router { .route("/servers", post(server::create)) .route("/servers/{server_id}", get(server::get)) .route("/servers/{server_id}/channels", get(server::channel::list)) - .route("/voice/{channel_id}/connect", post(voice::connect)) // middleware .route_layer(axum::middleware::from_fn(middleware::require_context)) } diff --git a/src/web/route/auth/login.rs b/src/web/route/auth/login.rs index 10a0b8a..88ef85b 100644 --- a/src/web/route/auth/login.rs +++ b/src/web/route/auth/login.rs @@ -5,6 +5,7 @@ use axum::response::IntoResponse; use serde::{Deserialize, Serialize}; use crate::state::AppState; +use crate::web::context::UserContext; use crate::web::route::user::FullUser; use crate::{jwt, web}; @@ -38,7 +39,7 @@ pub async fn login( .verify_password(payload.password.as_bytes(), &password_hash) .map_err(|_| web::error::ClientError::WrongPassword)?; - let token = jwt::generate_jwt(user.id)?; + let token = jwt::generate_jwt(UserContext { user_id: user.id })?; let response = LoginResponse { user: user.into(), diff --git a/src/web/route/server/channel/mod.rs b/src/web/route/server/channel/mod.rs index 4ce14ea..19f2172 100644 --- a/src/web/route/server/channel/mod.rs +++ b/src/web/route/server/channel/mod.rs @@ -1,3 +1,3 @@ mod list; -pub use list::*; +pub use list::list; diff --git a/src/web/route/server/create.rs b/src/web/route/server/create.rs index cfed669..7e2f869 100644 --- a/src/web/route/server/create.rs +++ b/src/web/route/server/create.rs @@ -73,7 +73,7 @@ pub async fn create( ws::gateway::util::send_message( &state, context.user_id, - ws::gateway::message::Event::AddServer { + ws::gateway::event::Event::AddServer { server: server.clone(), }, ) diff --git a/src/web/route/server/mod.rs b/src/web/route/server/mod.rs index c65146e..4f7e49a 100644 --- a/src/web/route/server/mod.rs +++ b/src/web/route/server/mod.rs @@ -3,6 +3,6 @@ mod create; mod get; mod list; -pub use create::*; -pub use get::*; -pub use list::*; +pub use create::create; +pub use get::get; +pub use list::list; diff --git a/src/web/route/user/channel/mod.rs b/src/web/route/user/channel/mod.rs index 4ce14ea..19f2172 100644 --- a/src/web/route/user/channel/mod.rs +++ b/src/web/route/user/channel/mod.rs @@ -1,3 +1,3 @@ mod list; -pub use list::*; +pub use list::list; diff --git a/src/web/route/user/mod.rs b/src/web/route/user/mod.rs index 91132b1..08dd7d6 100644 --- a/src/web/route/user/mod.rs +++ b/src/web/route/user/mod.rs @@ -2,8 +2,8 @@ pub mod channel; mod get; mod me; -pub use get::*; -pub use me::*; +pub use get::get_by_id; +pub use me::me; use crate::entity::user; diff --git a/src/web/route/voice/mod.rs b/src/web/route/voice/mod.rs index 08cbc12..fc9caa8 100644 --- a/src/web/route/voice/mod.rs +++ b/src/web/route/voice/mod.rs @@ -1,3 +1,3 @@ mod connect; -pub use connect::*; +pub use connect::connect; diff --git a/src/web/ws/error.rs b/src/web/ws/error.rs new file mode 100644 index 0000000..55eb7d5 --- /dev/null +++ b/src/web/ws/error.rs @@ -0,0 +1,14 @@ +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::From, derive_more::Display)] +pub enum Error { + #[from] + Json(serde_json::Error), + + #[from] + AcknowledgementError(tokio::sync::oneshot::error::RecvError), + + WrongMessageType, + + WebSocketClosed, +} diff --git a/src/web/ws/gateway/connection.rs b/src/web/ws/gateway/connection.rs index 2672931..bb678b2 100644 --- a/src/web/ws/gateway/connection.rs +++ b/src/web/ws/gateway/connection.rs @@ -1,24 +1,22 @@ -// src/web/ws/connection.rs - use std::ops::ControlFlow; -use std::time::Duration; use axum::extract::ws::{Message as AxumMessage, WebSocket}; -use base64::Engine as _; // Bring trait into scope +use base64::Engine as _; use futures::stream::SplitStream; -use futures::{SinkExt, StreamExt}; +use futures::{Sink, SinkExt, StreamExt}; +use serde::Serialize; use sha2::{Digest, Sha256}; use tokio::time::Instant; -// Use items from sibling modules within `ws` use super::error::{self, Error as WsError}; -// Assuming Event type is from ws::message -use super::message::Event as WsEvent; -use super::protocol::{ - SendWsMessage, WsClientMessage, WsServerMessage, deserialize_ws_message, serialize_ws_message, -}; +use super::event::Event as WsEvent; +use super::protocol::{WsClientMessage, WsServerMessage}; use super::state::{WsContext, WsState, WsUserContext}; +use crate::jwt; use crate::state::AppState; +use crate::web::ws::gateway::SessionKey; +use crate::web::ws::util::{SendWsMessage, deserialize_ws_message, serialize_ws_message}; +use crate::web::ws::{util, voice}; /// Main handler for an individual WebSocket connection's lifecycle. /// Spawned by Axum upon successful WebSocket upgrade. @@ -26,57 +24,20 @@ use crate::state::AppState; pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) { let (ws_sink, ws_stream) = websocket.split(); - let (internal_send_tx, mut internal_send_rx) = tokio::sync::mpsc::unbounded_channel(); + let (internal_send_tx, internal_send_rx) = tokio::sync::mpsc::unbounded_channel(); - // Writer task: consumes messages from MPSC channel and sends them to the WebSocket sink. - let writer_task = tokio::spawn(async move { - let mut ws_sink_mut = ws_sink; - while let Some(SendWsMessage { - message, - response_ch, - }) = internal_send_rx.recv().await - { - let send_result = match serialize_ws_message(message) { - Ok(ws_msg) => { - if ws_sink_mut.send(ws_msg).await.is_err() { - Err(WsError::WebSocketClosed) // Send to client failed - } else { - Ok(()) - } - }, - Err(e) => Err(e), // Serialization error itself - }; - - if let Some(ch) = response_ch { - if ch.send(send_result).is_err() { - // Log if the receiver of the acknowledgement was dropped, though this is unlikely - // if send_with_response is awaiting it. - tracing::debug!("Failed to send acknowledgement; receiver dropped."); - } - } else if let Err(e) = send_result { - // For fire-and-forget, log critical errors (not just WebSocketClosed). - if !matches!(e, WsError::WebSocketClosed) { - tracing::warn!("Error in fire-and-forget WebSocket send: {:?}", e); - } - } - } - // MPSC channel closed, attempt to gracefully close WebSocket. - if ws_sink_mut.close().await.is_err() { - tracing::debug!("Error closing WebSocket sink; connection might be already dead."); - } - }); + let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx); let mut context = WsContext { connection_state: WsState::Initialize, user_context: None, - heartbeat_interval: std::time::Duration::from_secs(30), // Assuming config path - next_ping_deadline: Instant::now(), // Will be properly set before first use event_channel: None, }; + let processing_result = process_websocket_messages( &mut context, ws_stream, - &internal_send_tx, // Pass as reference + &internal_send_tx, &app_state, ) .await; @@ -84,7 +45,7 @@ pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) // --- Cleanup --- if let Some(user_ctx_data) = &context.user_context { app_state - .unregister_event_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key) + .unregister_gateway_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key) .await; tracing::info!(user_id = ?user_ctx_data.user_id, session_key = %user_ctx_data.session_key, "Unregistered WebSocket user."); } @@ -97,13 +58,13 @@ pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) if let Err(err_to_report) = &processing_result { if !matches!( err_to_report, - WsError::WebSocketClosed | WsError::HeartbeatTimeout + WsError::WebSocketClosed ) { tracing::warn!( "WebSocket processing error, attempting to notify client: {:?}", err_to_report ); - let client_err_code = err_to_report.into_client_error(); + let client_err_code = err_to_report.as_client_error(); let error_ws_message = WsServerMessage::Error { code: client_err_code, }; @@ -127,32 +88,19 @@ pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) /// Main loop for processing incoming WebSocket messages and outgoing application events. /// Manages state transitions (Initialize -> Connected) and heartbeating. -#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))] +#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) +))] async fn process_websocket_messages( context: &mut WsContext, mut ws_stream: SplitStream, - sender: &tokio::sync::mpsc::UnboundedSender, // Changed to reference + sender: &tokio::sync::mpsc::UnboundedSender>, app_state: &AppState, ) -> error::Result<()> { - // Send initial heartbeat interval and set first deadline. - SendWsMessage::send_with_response( - sender, - WsServerMessage::HeartbeatInterval { - interval: context.heartbeat_interval, - }, - ) - .await?; - context.reset_deadline(); - loop { match context.connection_state { WsState::Initialize => { tokio::select! { - biased; // Prefer timeout check if multiple branches are ready - _ = tokio::time::sleep_until(context.next_ping_deadline) => { - tracing::warn!("Initial connection timeout (no Authenticate or Ping)."); - return Err(WsError::HeartbeatTimeout); - } + biased; maybe_message = ws_stream.next() => { match maybe_message { Some(Ok(message)) => { @@ -191,10 +139,6 @@ async fn process_websocket_messages( tokio::select! { biased; - _ = tokio::time::sleep_until(context.next_ping_deadline) => { - tracing::warn!(user_id = ?user_ctx.user_id, "Heartbeat timeout."); - return Err(WsError::HeartbeatTimeout); - } // Listen for application events to send to the client maybe_app_event = event_rx.recv() => { if let Some(app_event_data) = maybe_app_event { @@ -234,20 +178,13 @@ async fn process_websocket_messages( async fn handle_initial_message( context: &mut WsContext, message: AxumMessage, - sender: &tokio::sync::mpsc::UnboundedSender, // Changed to reference + sender: &tokio::sync::mpsc::UnboundedSender>, // Changed to reference app_state: &AppState, ) -> error::Result> { // Break(NewState) or Continue(()) match deserialize_ws_message(message)? { WsClientMessage::Authenticate { token } => { - // IMPORTANT: Adjust the call below to your actual token validation logic. - // Assuming `get_context_from_token` returns `Result` - match crate::web::middleware::get_context_from_token( - &app_state, // Example: Pass necessary parts of AppState - &token, - ) - .await - { + match crate::web::middleware::get_context_from_token(&app_state, &token).await { Ok(auth_user_context) => { // auth_user_context is `crate::web::context::UserContext` let user_id = auth_user_context.user_id; @@ -256,12 +193,14 @@ async fn handle_initial_message( context.event_channel = Some((event_tx.clone(), event_rx)); let random_key_part = rand::random::(); - let current_session_key = { + let current_session_key: SessionKey = { let mut hasher = Sha256::new(); hasher.update(token.as_bytes()); hasher.update(user_id.to_string().as_bytes()); hasher.update(&random_key_part.to_be_bytes()); - base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize()) + base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(hasher.finalize()) + .into() }; context.user_context = Some(WsUserContext { @@ -270,7 +209,7 @@ async fn handle_initial_message( }); app_state - .register_event_connected_user( + .register_gateway_connected_user( user_id, current_session_key.clone(), event_tx, // This is ws::state::EventSender -> mpsc::UnboundedSender @@ -281,7 +220,7 @@ async fn handle_initial_message( sender, WsServerMessage::AuthenticateAccepted { user_id, - session_key: current_session_key.clone(), + session_key: current_session_key, }, ) .await?; @@ -301,11 +240,6 @@ async fn handle_initial_message( }, } }, - WsClientMessage::Ping => { - context.reset_deadline(); // Reset deadline on successful ping - SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await?; - Ok(ControlFlow::Continue(())) - }, // Per original code, only Authenticate and Ping are expected in Initialize. // If WsClientMessage has other variants, this might need adjustment. #[allow(unreachable_patterns)] @@ -318,17 +252,45 @@ async fn handle_initial_message( /// Handles messages received when the connection is in the `Connected` state. /// Primarily expects `Ping` messages to keep the connection alive. -#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))] +#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) +))] async fn handle_connected_message( - context: &mut WsContext, // Although not heavily used here, good for consistency and tracing + context: &mut WsContext, message: AxumMessage, - sender: &tokio::sync::mpsc::UnboundedSender, // Changed to reference + sender: &tokio::sync::mpsc::UnboundedSender>, ) -> error::Result<()> { match deserialize_ws_message(message)? { - WsClientMessage::Ping => { - tracing::debug!("Ping received."); - context.reset_deadline(); - SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await?; + WsClientMessage::VoiceStateUpdate { + server_id, + channel_id, + } => { + // TODO: check if can join this channel + + let claims = voice::claims::VoiceClaims { + user_id: context + .user_context + .as_ref() + .expect("user context should be present") + .user_id, + server_id, + channel_id, + iat: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime) + .timestamp(), + }; + + let token = jwt::generate_jwt(claims).map_err(|_| WsError::TokenGenerationFailed)?; + + SendWsMessage::send_with_response( + sender, + WsServerMessage::Event { + event: WsEvent::VoiceServerUpdate { + server_id, + channel_id, + token, + }, + }, + ) + .await?; Ok(()) }, diff --git a/src/web/ws/gateway/error.rs b/src/web/ws/gateway/error.rs index 03ebc5d..1d1eb94 100644 --- a/src/web/ws/gateway/error.rs +++ b/src/web/ws/gateway/error.rs @@ -17,8 +17,9 @@ pub enum Error { WebSocketClosed, - HeartbeatTimeout, AuthenticationFailed, + + TokenGenerationFailed, } #[derive(Debug, Clone, serde::Serialize)] @@ -27,15 +28,13 @@ pub enum ClientError { DeserializationError, NotAuthenticated, AlreadyAuthenticated, - HeartbeatTimeout, Unknown, } impl Error { - pub fn into_client_error(&self) -> ClientError { + pub fn as_client_error(&self) -> ClientError { match self { - Error::HeartbeatTimeout => ClientError::HeartbeatTimeout, Error::Json(_) => ClientError::DeserializationError, Error::UnexpectedMessageType => ClientError::Unknown, Error::WrongMessageType => ClientError::Unknown, @@ -44,3 +43,14 @@ impl Error { } } } + +impl From for Error { + fn from(err: crate::web::ws::error::Error) -> Self { + match err { + crate::web::ws::error::Error::Json(e) => Error::Json(e), + crate::web::ws::error::Error::AcknowledgementError(e) => Error::AcknowledgementError(e), + crate::web::ws::error::Error::WrongMessageType => Error::WrongMessageType, + crate::web::ws::error::Error::WebSocketClosed => Error::WebSocketClosed, + } + } +} diff --git a/src/web/ws/gateway/message.rs b/src/web/ws/gateway/event.rs similarity index 55% rename from src/web/ws/gateway/message.rs rename to src/web/ws/gateway/event.rs index 688d3f5..624b20c 100644 --- a/src/web/ws/gateway/message.rs +++ b/src/web/ws/gateway/event.rs @@ -4,12 +4,25 @@ use crate::entity; #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum Event { + #[serde(rename_all = "camelCase")] AddServer { server: entity::server::Server }, + #[serde(rename_all = "camelCase")] RemoveServer { server_id: entity::server::Id }, + #[serde(rename_all = "camelCase")] AddDmChannel { channel: entity::channel::Channel }, + #[serde(rename_all = "camelCase")] RemoveDmChannel { channel_id: entity::channel::Id }, + #[serde(rename_all = "camelCase")] AddServerChannel { channel: entity::channel::Channel }, + #[serde(rename_all = "camelCase")] RemoveServerChannel { channel_id: entity::channel::Id }, + + #[serde(rename_all = "camelCase")] + VoiceServerUpdate { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + token: String, + }, } diff --git a/src/web/ws/gateway/mod.rs b/src/web/ws/gateway/mod.rs index 116cdad..49137f6 100644 --- a/src/web/ws/gateway/mod.rs +++ b/src/web/ws/gateway/mod.rs @@ -1,9 +1,6 @@ use axum::extract::{State, WebSocketUpgrade}; use axum::response::IntoResponse; -use base64::Engine; use dashmap::DashMap; -use futures::{SinkExt, StreamExt}; -use sha2::Digest; use crate::state::AppState; use crate::web::ws::gateway::connection::handle_socket_connection; @@ -11,14 +8,29 @@ use crate::web::ws::gateway::state::EventSender; mod connection; mod error; -pub mod message; +pub mod event; mod protocol; mod state; pub mod util; +#[derive( + Debug, + Clone, + Eq, + PartialEq, + Hash, + Default, + derive_more::Display, + serde::Serialize, + serde::Deserialize, + derive_more::From, +)] +#[serde(transparent)] +pub struct SessionKey(String); + #[derive(Debug, Default)] -pub struct EventWsState { - pub connection_instance: DashMap, +pub struct GatewayWsState { + pub instances: DashMap, } pub async fn ws_handler( diff --git a/src/web/ws/gateway/protocol.rs b/src/web/ws/gateway/protocol.rs index 95f857a..39ffdb5 100644 --- a/src/web/ws/gateway/protocol.rs +++ b/src/web/ws/gateway/protocol.rs @@ -1,33 +1,29 @@ -// src/web/ws/protocol.rs - -use axum::extract::ws::Message as AxumMessage; -use serde::{Deserialize, Serialize}; use std::time::Duration; -use super::error::{self, ClientError, Error as WsError}; -use super::message as ws_local_message; // For ws::message::Event +use serde::{Deserialize, Serialize}; + +use super::error::ClientError; +use super::{SessionKey, event as ws_local_message}; use crate::entity; -use crate::util as crate_root_util; // For crate::util::serialize_duration_seconds #[derive(Debug, Serialize)] #[serde(tag = "type", content = "data")] #[serde(rename_all = "SCREAMING_SNAKE_CASE")] pub enum WsServerMessage { - HeartbeatInterval { - #[serde(serialize_with = "crate_root_util::serialize_duration_seconds")] - interval: Duration, - }, AuthenticateDenied, + #[serde(rename_all = "camelCase")] AuthenticateAccepted { user_id: entity::user::Id, - session_key: String, + session_key: SessionKey, }, + #[serde(rename_all = "camelCase")] Event { - event: ws_local_message::Event, // Assumes Event is defined in ws::message + event: ws_local_message::Event, }, - Pong, + + #[serde(rename_all = "camelCase")] Error { code: ClientError, }, @@ -41,59 +37,10 @@ pub enum WsClientMessage { Authenticate { token: String, }, - Ping, + + #[serde(rename_all = "camelCase")] + VoiceStateUpdate { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + }, } - -/// Deserializes an Axum WebSocket message into a `WsClientMessage`. -pub fn deserialize_ws_message(message: AxumMessage) -> error::Result { - match message { - AxumMessage::Text(text) => serde_json::from_str(&text).map_err(WsError::from), - AxumMessage::Close(_) => Err(WsError::WebSocketClosed), - _ => Err(WsError::WrongMessageType), // e.g. Binary, Ping, Pong from axum::Message - } -} - -/// Serializes a `WsServerMessage` into an Axum WebSocket message. -pub fn serialize_ws_message(message: WsServerMessage) -> error::Result { - serde_json::to_string(&message) - .map(Into::into) - .map(AxumMessage::Text) - .map_err(WsError::from) -} - -/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task. -/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer. -pub struct SendWsMessage { - pub message: WsServerMessage, - pub response_ch: Option>>, -} - -impl SendWsMessage { - /// Sends a message over the MPSC channel and awaits a response via a oneshot channel. - pub async fn send_with_response( - tx: &tokio::sync::mpsc::UnboundedSender, // Changed to reference - message: WsServerMessage, - ) -> error::Result<()> { - let (response_tx, response_rx) = tokio::sync::oneshot::channel(); - let send_message = SendWsMessage { - message, - response_ch: Some(response_tx), - }; - - if tx.send(send_message).is_err() { - Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead - } else { - // Wait for the writer task to acknowledge the send attempt. - // This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error. - response_rx.await? // Propagates RecvError into WsError::AcknowledgementError - } - } - - /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). - pub fn new_no_response(message: WsServerMessage) -> Self { - SendWsMessage { - message, - response_ch: None, - } - } -} \ No newline at end of file diff --git a/src/web/ws/gateway/state.rs b/src/web/ws/gateway/state.rs index 60e5b99..273f427 100644 --- a/src/web/ws/gateway/state.rs +++ b/src/web/ws/gateway/state.rs @@ -1,10 +1,8 @@ -// src/web/ws/state.rs - -use std::time::{Duration}; +use std::time::Duration; use tokio::sync::mpsc; -use super::message; +use super::{event, SessionKey}; use crate::entity; // For entity::user::Id // For ws::message::Event used in EventSender/Receiver /// Represents the current state of a single WebSocket connection. @@ -15,35 +13,25 @@ pub enum WsState { } /// Contextual information for an authenticated WebSocket user session. -#[derive(Debug, Clone)] // Clone might be useful +#[derive(Debug)] pub struct WsUserContext { pub user_id: entity::user::Id, - pub session_key: String, // Unique key for this specific WebSocket session instance + pub session_key: SessionKey, // Unique key for this specific WebSocket session instance } /// Sender part of an MPSC channel used to send `ws::message::Event`s to a connected client. -pub type EventSender = mpsc::UnboundedSender; +pub type EventSender = mpsc::UnboundedSender; /// Receiver part of an MPSC channel used by a connection task to receive `ws::message::Event`s. -pub type EventReceiver = mpsc::UnboundedReceiver; +pub type EventReceiver = mpsc::UnboundedReceiver; /// Holds the full context for a single WebSocket connection's lifecycle. /// This struct is managed per-connection. pub struct WsContext { pub connection_state: WsState, pub user_context: Option, - pub heartbeat_interval: Duration, - pub next_ping_deadline: tokio::time::Instant, /// Channel for receiving application-specific events to be sent to this client. /// The `EventSender` (tx) part is given to `AppState` for broadcasting. /// The `EventReceiver` (rx) part is polled by the connection task. pub event_channel: Option<(EventSender, EventReceiver)>, } -impl WsContext { - /// Resets the ping deadline based on the current time and heartbeat interval. - /// This should be called after successfully receiving a ping from the client - /// or after sending a message that implies activity (like Pong or initial auth). - pub fn reset_deadline(&mut self) { - self.next_ping_deadline = tokio::time::Instant::now() + self.heartbeat_interval; - } -} diff --git a/src/web/ws/gateway/util.rs b/src/web/ws/gateway/util.rs index 4a23798..d8d798a 100644 --- a/src/web/ws/gateway/util.rs +++ b/src/web/ws/gateway/util.rs @@ -1,11 +1,11 @@ use crate::entity; use crate::state::AppState; -use crate::web::ws::gateway::message; +use crate::web::ws::gateway::event; -pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: message::Event) { - let connected_users = state.event_connected_users.read().await; - if let Some(state) = connected_users.get(&user_id) { - for instance in state.connection_instance.iter() { +pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: event::Event) { + let connected_users = state.gateway_state.connected.get_async(&user_id).await; + if let Some(session) = connected_users { + for instance in session.instances.iter() { if let Err(e) = instance.send(message.clone()) { tracing::error!("failed to send message: {}", e); } diff --git a/src/web/ws/mod.rs b/src/web/ws/mod.rs index 4f27526..658f016 100644 --- a/src/web/ws/mod.rs +++ b/src/web/ws/mod.rs @@ -1 +1,4 @@ +mod error; pub mod gateway; +mod util; +pub mod voice; diff --git a/src/web/ws/util.rs b/src/web/ws/util.rs new file mode 100644 index 0000000..8817c77 --- /dev/null +++ b/src/web/ws/util.rs @@ -0,0 +1,95 @@ +use axum::extract::ws::Message as AxumMessage; +use futures::{Sink, SinkExt}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use tokio::sync::{mpsc, oneshot}; + +pub fn spawn_writer_task( + mut ws_sink: S, + mut writer_rx: mpsc::UnboundedReceiver>, +) -> tokio::task::JoinHandle<()> +where + S: Sink + Unpin + Send + 'static, + T: Serialize + Send + 'static, +{ + tokio::spawn(async move { + while let Some(SendWsMessage { + message, + response_ch, + }) = writer_rx.recv().await + { + let send_result = match serialize_ws_message(message) { + Ok(ws_msg) => { + if ws_sink.send(ws_msg).await.is_err() { + Err(super::error::Error::WebSocketClosed) + } else { + Ok(()) + } + }, + Err(e) => Err(e.into()), + }; + + if let Some(ch) = response_ch { + let _ = ch.send(send_result); + } + } + + let _ = ws_sink.close().await; + }) +} + +/// Deserializes an Axum WebSocket message into a `WsClientMessage`. +pub fn deserialize_ws_message( + message: AxumMessage, +) -> super::error::Result { + match message { + AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from), + AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed), + _ => Err(super::error::Error::WrongMessageType), + } +} + +/// Serializes a `WsServerMessage` into an Axum WebSocket message. +pub fn serialize_ws_message(message: T) -> super::error::Result { + serde_json::to_string(&message) + .map(Into::into) + .map(AxumMessage::Text) + .map_err(super::error::Error::from) +} + +/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task. +/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer. +pub struct SendWsMessage { + pub message: T, + pub response_ch: Option>>, +} + +impl SendWsMessage { + /// Sends a message over the MPSC channel and awaits a response via a oneshot channel. + pub async fn send_with_response( + tx: &mpsc::UnboundedSender, // Changed to reference + message: T, + ) -> super::error::Result<()> { + let (response_tx, response_rx) = oneshot::channel(); + let send_message = SendWsMessage { + message, + response_ch: Some(response_tx), + }; + + if tx.send(send_message).is_err() { + Err(super::error::Error::WebSocketClosed) // MPSC channel closed, writer task likely dead + } else { + // Wait for the writer task to acknowledge the send attempt. + // This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error. + response_rx.await? // Propagates RecvError into WsError::AcknowledgementError + } + } + + /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). + pub fn new_no_response(message: T) -> Self { + SendWsMessage { + message, + response_ch: None, + } + } +} diff --git a/src/web/ws/voice/claims.rs b/src/web/ws/voice/claims.rs new file mode 100644 index 0000000..e1d3ea7 --- /dev/null +++ b/src/web/ws/voice/claims.rs @@ -0,0 +1,9 @@ +use crate::entity; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct VoiceClaims { + pub user_id: entity::user::Id, + pub server_id: entity::server::Id, + pub channel_id: entity::channel::Id, + pub iat: i64, +} diff --git a/src/web/ws/voice/connection.rs b/src/web/ws/voice/connection.rs new file mode 100644 index 0000000..c1a9b07 --- /dev/null +++ b/src/web/ws/voice/connection.rs @@ -0,0 +1,6 @@ +use axum::extract::ws::WebSocket; + +use crate::state::AppState; + +#[tracing::instrument(skip_all, name = "ws_connection_handler")] +pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) {} diff --git a/src/web/ws/voice/error.rs b/src/web/ws/voice/error.rs new file mode 100644 index 0000000..5766e1d --- /dev/null +++ b/src/web/ws/voice/error.rs @@ -0,0 +1,48 @@ +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::From, derive_more::Display)] +pub enum Error { + #[from] + Axum(axum::Error), + + #[from] + Json(serde_json::Error), + + #[from] + AcknowledgementError(tokio::sync::oneshot::error::RecvError), + + UnexpectedMessageType, + + WrongMessageType, + + WebSocketClosed, + + HeartbeatTimeout, + AuthenticationFailed, + + TokenGenerationFailed, +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum ClientError { + DeserializationError, + NotAuthenticated, + AlreadyAuthenticated, + HeartbeatTimeout, + + Unknown, +} + +impl Error { + pub fn into_client_error(&self) -> ClientError { + match self { + Error::HeartbeatTimeout => ClientError::HeartbeatTimeout, + Error::Json(_) => ClientError::DeserializationError, + Error::UnexpectedMessageType => ClientError::Unknown, + Error::WrongMessageType => ClientError::Unknown, + Error::WebSocketClosed => ClientError::Unknown, + _ => ClientError::Unknown, + } + } +} diff --git a/src/web/ws/voice/mod.rs b/src/web/ws/voice/mod.rs new file mode 100644 index 0000000..bddb0d6 --- /dev/null +++ b/src/web/ws/voice/mod.rs @@ -0,0 +1,17 @@ +pub mod claims; +mod connection; +mod error; +mod protocol; + +use axum::extract::{State, WebSocketUpgrade}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::ws::voice::connection::handle_socket_connection; + +pub async fn ws_handler( + State(app_state): State, + ws: WebSocketUpgrade, +) -> crate::web::error::Result { + Ok(ws.on_upgrade(|socket| handle_socket_connection(socket, app_state))) +} diff --git a/src/web/ws/voice/protocol.rs b/src/web/ws/voice/protocol.rs new file mode 100644 index 0000000..118bdb6 --- /dev/null +++ b/src/web/ws/voice/protocol.rs @@ -0,0 +1,105 @@ +use std::time::Duration; + +use axum::extract::ws::Message as AxumMessage; +use serde::{Deserialize, Serialize}; +use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; + +use super::error::{self, ClientError, Error as WsError}; +use crate::{entity, util as crate_root_util}; // For crate::util::serialize_duration_seconds + +#[derive(Debug, Serialize)] +#[serde(tag = "type", content = "data")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum WsServerMessage { + HeartbeatInterval { + #[serde(serialize_with = "crate_root_util::serialize_duration_seconds")] + interval: Duration, + }, + + AuthenticateDenied, + + AuthenticateAccepted, + + #[serde(rename_all = "camelCase")] + SdpAnswer { + sdp: RTCSessionDescription, + }, + + #[serde(rename_all = "camelCase")] + Error { + code: ClientError, + }, + + Pong, +} + +#[derive(Debug, Deserialize)] +#[serde(tag = "type", content = "data")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum WsClientMessage { + #[serde(rename_all = "camelCase")] + Authenticate { + token: String, + }, + + #[serde(rename_all = "camelCase")] + SdpOffer { + sdp: RTCSessionDescription, + }, + + Ping, +} + +/// Deserializes an Axum WebSocket message into a `WsClientMessage`. +pub fn deserialize_ws_message(message: AxumMessage) -> error::Result { + match message { + AxumMessage::Text(text) => serde_json::from_str(&text).map_err(WsError::from), + AxumMessage::Close(_) => Err(WsError::WebSocketClosed), + _ => Err(WsError::WrongMessageType), // e.g. Binary, Ping, Pong from axum::Message + } +} + +/// Serializes a `WsServerMessage` into an Axum WebSocket message. +pub fn serialize_ws_message(message: WsServerMessage) -> error::Result { + serde_json::to_string(&message) + .map(Into::into) + .map(AxumMessage::Text) + .map_err(WsError::from) +} + +/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task. +/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer. +pub struct SendWsMessage { + pub message: WsServerMessage, + pub response_ch: Option>>, +} + +impl SendWsMessage { + /// Sends a message over the MPSC channel and awaits a response via a oneshot channel. + pub async fn send_with_response( + tx: &tokio::sync::mpsc::UnboundedSender, // Changed to reference + message: WsServerMessage, + ) -> error::Result<()> { + let (response_tx, response_rx) = tokio::sync::oneshot::channel(); + let send_message = SendWsMessage { + message, + response_ch: Some(response_tx), + }; + + if tx.send(send_message).is_err() { + Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead + } else { + // Wait for the writer task to acknowledge the send attempt. + // This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error. + response_rx.await? // Propagates RecvError into WsError::AcknowledgementError + } + } + + /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). + pub fn new_no_response(message: WsServerMessage) -> Self { + SendWsMessage { + message, + response_ch: None, + } + } +}