This commit is contained in:
2025-05-14 10:41:46 +03:00
parent f04f5e6e41
commit 02f45aeac6
32 changed files with 650 additions and 326 deletions

160
Cargo.lock generated
View File

@@ -23,7 +23,7 @@ version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0"
dependencies = [
"crypto-common",
"crypto-common 0.1.6",
"generic-array",
]
@@ -107,9 +107,9 @@ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457"
[[package]]
name = "argon2"
version = "0.5.3"
version = "0.6.0-pre.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c3610892ee6e0cbce8ae2700349fcf8f98adb0dbfbee85aec3c9179d29cc072"
checksum = "8f95281c5706985b6c00f8a2270438f968d475672aa68a4a85cddcb57a68577b"
dependencies = [
"base64ct",
"blake2",
@@ -405,11 +405,11 @@ dependencies = [
[[package]]
name = "blake2"
version = "0.10.6"
version = "0.11.0-pre.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46502ad458c9a52b69d4d4d32775c788b7a1b85e8bc9d482d92250fc0e3f8efe"
checksum = "e6dbf347378982186052c47f25f33fc1a6eb439ee840d778eb3ec132e304379d"
dependencies = [
"digest",
"digest 0.11.0-pre.9",
]
[[package]]
@@ -421,6 +421,15 @@ dependencies = [
"generic-array",
]
[[package]]
name = "block-buffer"
version = "0.11.0-rc.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a229bfd78e4827c91b9b95784f69492c1b77c1ab75a45a8a037b139215086f94"
dependencies = [
"hybrid-array",
]
[[package]]
name = "block-padding"
version = "0.3.3"
@@ -556,7 +565,7 @@ version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad"
dependencies = [
"crypto-common",
"crypto-common 0.1.6",
"inout",
]
@@ -715,6 +724,15 @@ dependencies = [
"typenum",
]
[[package]]
name = "crypto-common"
version = "0.2.0-rc.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "170d71b5b14dec99db7739f6fc7d6ec2db80b78c3acb77db48392ccc3d8a9ea0"
dependencies = [
"hybrid-array",
]
[[package]]
name = "ctr"
version = "0.9.2"
@@ -867,9 +885,20 @@ version = "0.10.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
dependencies = [
"block-buffer",
"block-buffer 0.10.4",
"const-oid",
"crypto-common",
"crypto-common 0.1.6",
"subtle",
]
[[package]]
name = "digest"
version = "0.11.0-pre.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cf2e3d6615d99707295a9673e889bf363a04b2a466bd320c65a72536f7577379"
dependencies = [
"block-buffer 0.11.0-rc.4",
"crypto-common 0.2.0-rc.2",
"subtle",
]
@@ -889,12 +918,11 @@ dependencies = [
"dashmap",
"derive_more",
"futures",
"hex",
"jsonwebtoken",
"mime",
"rand 0.9.1",
"rand_core 0.6.4",
"regex",
"scc",
"serde",
"serde_json",
"sha2",
@@ -943,7 +971,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ee27f32b5c5292967d2d4a9d7f1e0b0aed2c15daded5a60300e4abb9d8020bca"
dependencies = [
"der",
"digest",
"digest 0.10.7",
"elliptic-curve",
"rfc6979",
"signature",
@@ -967,7 +995,7 @@ checksum = "b5e6043086bf7973472e0c7dff2142ea0b680d30e18d9cc40f267efbf222bd47"
dependencies = [
"base16ct",
"crypto-bigint",
"digest",
"digest 0.10.7",
"ff",
"generic-array",
"group",
@@ -1336,7 +1364,7 @@ version = "0.12.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
dependencies = [
"digest",
"digest 0.10.7",
]
[[package]]
@@ -1394,6 +1422,15 @@ version = "1.0.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df3b46402a9d5adb4c86a0cf463f42e19994e3ee891101b1841f30a545cb49a9"
[[package]]
name = "hybrid-array"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "891d15931895091dea5c47afa5b3c9a01ba634b311919fd4d41388fa0e3d76af"
dependencies = [
"typenum",
]
[[package]]
name = "hyper"
version = "1.6.0"
@@ -1588,9 +1625,9 @@ dependencies = [
[[package]]
name = "interceptor"
version = "0.13.0"
version = "0.14.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e5ab04c530fd82e414e40394cabe5f0ebfe30d119f10fe29d6e3561926af412e"
checksum = "1ac0781c825d602095113772e389ef0607afcb869ae0e68a590d8e0799cdcef8"
dependencies = [
"async-trait",
"bytes",
@@ -1735,7 +1772,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
dependencies = [
"cfg-if",
"digest",
"digest 0.10.7",
]
[[package]]
@@ -1999,9 +2036,9 @@ dependencies = [
[[package]]
name = "password-hash"
version = "0.5.0"
version = "0.6.0-rc.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "346f04948ba92c43e8469c1ee6736c7563d71012b17d40745260fe106aac2166"
checksum = "ec3b470a56963403c40f9dbb41eaee539759de9d026d3324da705a0ae0d269cd"
dependencies = [
"base64ct",
"rand_core 0.6.4",
@@ -2462,7 +2499,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "78928ac1ed176a5ca1d17e578a1825f3d81ca54cf41053a592584b020cfd691b"
dependencies = [
"const-oid",
"digest",
"digest 0.10.7",
"num-bigint-dig",
"num-integer",
"num-traits",
@@ -2477,9 +2514,9 @@ dependencies = [
[[package]]
name = "rtcp"
version = "0.12.0"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8306430fb118b7834bbee50e744dc34826eca1da2158657a3d6cbc70e24c2096"
checksum = "e9689528bf3a9eb311fd938d05516dd546412f9ce4fffc8acfc1db27cc3dbf72"
dependencies = [
"bytes",
"thiserror 1.0.69",
@@ -2488,9 +2525,9 @@ dependencies = [
[[package]]
name = "rtp"
version = "0.12.0"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e68baca5b6cb4980678713f0d06ef3a432aa642baefcbfd0f4dd2ef9eb5ab550"
checksum = "c54733451a67d76caf9caa07a7a2cec6871ea9dda92a7847f98063d459200f4b"
dependencies = [
"bytes",
"memchr",
@@ -2611,6 +2648,15 @@ version = "1.0.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "28d3b2b1366ec20994f1fd18c3c594f05c5dd4bc44d8bb0c1c632c8d6829481f"
[[package]]
name = "scc"
version = "2.3.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22b2d775fb28f245817589471dd49c5edf64237f4a19d10ce9a92ff4651a27f4"
dependencies = [
"sdd",
]
[[package]]
name = "scopeguard"
version = "1.2.0"
@@ -2618,10 +2664,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49"
[[package]]
name = "sdp"
version = "0.7.0"
name = "sdd"
version = "3.0.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "02a526161f474ae94b966ba622379d939a8fe46c930eebbadb73e339622599d5"
checksum = "584e070911c7017da6cb2eb0788d09f43d789029b5877d3e5ecc8acf86ceee21"
[[package]]
name = "sdp"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd277015eada44a0bb810a4b84d3bf6e810573fa62fb442f457edf6a1087a69"
dependencies = [
"rand 0.8.5",
"substring",
@@ -2726,7 +2778,7 @@ checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
"digest 0.10.7",
]
[[package]]
@@ -2737,7 +2789,7 @@ checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
"digest 0.10.7",
]
[[package]]
@@ -2770,7 +2822,7 @@ version = "2.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "77549399552de45a898a580c1b41d445bf730df867cc44e6c0233bbc4b8329de"
dependencies = [
"digest",
"digest 0.10.7",
"rand_core 0.6.4",
]
@@ -2949,7 +3001,7 @@ dependencies = [
"bytes",
"chrono",
"crc",
"digest",
"digest 0.10.7",
"dotenvy",
"either",
"futures-channel",
@@ -3070,9 +3122,9 @@ checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f"
[[package]]
name = "stun"
version = "0.7.0"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea256fb46a13f9204e9dee9982997b2c3097db175a9fddaa8350310d03c4d5a3"
checksum = "7dbc2bab375524093c143dc362a03fb6a1fb79e938391cdb21665688f88a088a"
dependencies = [
"base64 0.22.1",
"crc",
@@ -3519,9 +3571,9 @@ dependencies = [
[[package]]
name = "turn"
version = "0.9.0"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0044fdae001dd8a1e247ea6289abf12f4fcea1331a2364da512f9cd680bbd8cb"
checksum = "3f5aea1116456e1da71c45586b87c72e3b43164fbf435eb93ff6aa475416a9a4"
dependencies = [
"async-trait",
"base64 0.22.1",
@@ -3601,7 +3653,7 @@ version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea"
dependencies = [
"crypto-common",
"crypto-common 0.1.6",
"subtle",
]
@@ -3784,9 +3836,9 @@ dependencies = [
[[package]]
name = "webrtc"
version = "0.12.0"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "30367074d9f18231d28a74fab0120856b2b665da108d71a12beab7185a36f97b"
checksum = "24bab7195998d605c862772f90a452ba655b90a2f463c850ac032038890e367a"
dependencies = [
"arc-swap",
"async-trait",
@@ -3828,9 +3880,9 @@ dependencies = [
[[package]]
name = "webrtc-data"
version = "0.10.0"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dec93b991efcd01b73c5b3503fa8adba159d069abe5785c988ebe14fcf8f05d1"
checksum = "4e97b932854da633a767eff0cc805425a2222fc6481e96f463e57b015d949d1d"
dependencies = [
"bytes",
"log",
@@ -3843,9 +3895,9 @@ dependencies = [
[[package]]
name = "webrtc-dtls"
version = "0.11.0"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b7c9b89fc909f9da0499283b1112cd98f72fec28e55a54a9e352525ca65cd95c"
checksum = "5ccbe4d9049390ab52695c3646c1395c877e16c15fb05d3bda8eee0c7351711c"
dependencies = [
"aes",
"aes-gcm",
@@ -3880,9 +3932,9 @@ dependencies = [
[[package]]
name = "webrtc-ice"
version = "0.12.0"
version = "0.13.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0348b28b593f7709ac98d872beb58c0009523df652c78e01b950ab9c537ff17d"
checksum = "eb51bde0d790f109a15bfe4d04f1b56fb51d567da231643cb3f21bb74d678997"
dependencies = [
"arc-swap",
"async-trait",
@@ -3905,9 +3957,9 @@ dependencies = [
[[package]]
name = "webrtc-mdns"
version = "0.8.0"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e6dfe9686c6c9c51428da4de415cb6ca2dc0591ce2b63212e23fd9cccf0e316b"
checksum = "979cc85259c53b7b620803509d10d35e2546fa505d228850cbe3f08765ea6ea8"
dependencies = [
"log",
"socket2",
@@ -3918,9 +3970,9 @@ dependencies = [
[[package]]
name = "webrtc-media"
version = "0.9.0"
version = "0.10.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e153be16b8650021ad3e9e49ab6e5fa9fb7f6d1c23c213fd8bbd1a1135a4c704"
checksum = "80041211deccda758a3e19aa93d6b10bc1d37c9183b519054b40a83691d13810"
dependencies = [
"byteorder",
"bytes",
@@ -3931,9 +3983,9 @@ dependencies = [
[[package]]
name = "webrtc-sctp"
version = "0.11.0"
version = "0.12.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5faf3846ec4b7e64b56338d62cbafe084aa79806b0379dff5cc74a8b7a2b3063"
checksum = "07439c134425d51d2f10907aaf2f815fdfb587dce19fe94a4ae8b5faf2aae5ae"
dependencies = [
"arc-swap",
"async-trait",
@@ -3949,9 +4001,9 @@ dependencies = [
[[package]]
name = "webrtc-srtp"
version = "0.14.0"
version = "0.15.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "771db9993712a8fb3886d5be4613ebf27250ef422bd4071988bf55f1ed1a64fa"
checksum = "01e773f79b09b057ffbda6b03fe7b43403b012a240cf8d05d630674c3723b5bb"
dependencies = [
"aead",
"aes",
@@ -3972,9 +4024,9 @@ dependencies = [
[[package]]
name = "webrtc-util"
version = "0.10.0"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1438a8fd0d69c5775afb4a71470af92242dbd04059c61895163aa3c1ef933375"
checksum = "64bfb10dbe6d762f80169ae07cf252bafa1f764b9594d140008a0231c0cdce58"
dependencies = [
"async-trait",
"bitflags 1.3.2",

View File

@@ -5,14 +5,13 @@ edition = "2024"
[dependencies]
anyhow = "1.0"
argon2 = "0.5"
argon2 = "0.6.0-pre.1"
axum = { version = "0.8", features = ["ws", "multipart", "macros"] }
axum-extra = { version = "0.10", features = ["typed-header"] }
chrono = { version = "0.4", features = ["serde"] }
config = "0.15"
derive_more = { version = "2.0", features = ["full"] }
jsonwebtoken = "9.3"
rand_core = { version = "0.6", features = ["getrandom"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
sqlx = { version = "0.8", features = ["postgres", "runtime-tokio", "uuid", "chrono"] }
@@ -23,15 +22,15 @@ tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["serde", "env-filter", "chrono"] }
uuid = { version = "1.16", features = ["fast-rng", "serde", "v7"] }
url = { version = "2.5", features = ["serde"] }
validator = { version = "0.20.0", features = ["derive"] }
regex = "1.11.1"
mime = "0.3.17"
axum_typed_multipart = "0.16.2"
async-trait = "0.1.88"
futures = "0.3.31"
webrtc = "0.12.0"
dashmap = "6.1.0"
rand = "0.9.1"
sha2 = "0.10.9"
hex = "0.4.3"
base64 = "0.22.1"
validator = { version = "0.20", features = ["derive"] }
regex = "1.11"
mime = "0.3"
axum_typed_multipart = "0.16"
async-trait = "0.1"
futures = "0.3"
webrtc = "0.13"
dashmap = "6.1"
rand = "0.9"
sha2 = "0.10"
base64 = "0.22"
scc = "2.3"

View File

@@ -18,11 +18,30 @@ pub fn config() -> &'static Config {
#[derive(Deserialize)]
pub struct Config {
pub port: u16,
pub jwt_secret: String,
pub server: ServerConfig,
pub security: SecurityConfig,
pub gateway: GatewayConfig,
pub database: DatabaseConfig,
}
#[derive(Deserialize)]
pub struct ServerConfig {
pub host: std::net::Ipv4Addr,
pub port: u16,
}
#[derive(Deserialize)]
pub struct SecurityConfig {
pub auth_secret: String,
pub voice_secret: String,
}
#[derive(Deserialize)]
pub struct GatewayConfig {
#[serde(deserialize_with = "crate::util::deserialize_duration_seconds")]
pub voice_token_lifetime: std::time::Duration,
}
#[derive(Debug, Deserialize)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]

View File

@@ -10,20 +10,21 @@ use crate::config;
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims<T> {
pub user_id: T,
#[serde(flatten)]
pub data: T,
pub iat: i64,
}
pub fn generate_jwt<T: Serialize>(user_id: T) -> Result<String> {
pub fn generate_jwt<T: Serialize>(data: T) -> Result<String> {
let claims = Claims {
user_id,
data,
iat: Utc::now().timestamp_millis(),
};
let token = jsonwebtoken::encode(
&jsonwebtoken::Header::default(),
&claims,
&jsonwebtoken::EncodingKey::from_secret(config::config().jwt_secret.as_ref()),
&jsonwebtoken::EncodingKey::from_secret(config::config().security.auth_secret.as_ref()),
)
.map_err(|_| Error::CouldNotEncodeToken)?;
@@ -34,16 +35,16 @@ pub fn verify_jwt<T: DeserializeOwned>(token: &str) -> Result<T> {
tracing::debug!("verifying token: {}", token);
let mut validation = jsonwebtoken::Validation::default();
validation.required_spec_claims = HashSet::new();
validation.set_required_spec_claims::<String>(&[]);
let token_data = jsonwebtoken::decode::<Claims<T>>(
token,
&jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),
&jsonwebtoken::DecodingKey::from_secret(config::config().security.auth_secret.as_ref()),
&validation,
)
.map_err(|_| Error::CouldNotVerifyToken)?;
Ok(token_data.claims.user_id)
Ok(token_data.claims.data)
}
pub type Result<T> = std::result::Result<T, Error>;

View File

@@ -1,8 +1,6 @@
use std::collections::HashMap;
use std::sync::Arc;
use argon2::Argon2;
use tokio::sync::RwLock;
use crate::database::Database;
use crate::state::AppState;
@@ -25,8 +23,8 @@ async fn main() -> anyhow::Result<()> {
let state = AppState {
database,
hasher: Arc::new(Argon2::default()),
event_connected_users: Arc::new(RwLock::new(HashMap::new())),
voice_rooms: Arc::new(RwLock::new(HashMap::new())),
gateway_state: Default::default(),
voice_rooms: Default::default(),
};
web::run(state).await?;

View File

@@ -2,12 +2,11 @@ use std::collections::HashMap;
use std::sync::Arc;
use argon2::Argon2;
use dashmap::DashMap;
use tokio::sync::{RwLock, mpsc};
use uuid::Uuid;
use crate::database::Database;
use crate::web::ws::gateway::{EventWsState, message};
use crate::web::ws::gateway::{GatewayWsState, SessionKey, event};
use crate::webrtc::OfferSignal;
#[derive(Clone)]
@@ -15,37 +14,50 @@ pub struct AppState {
pub database: Database,
pub hasher: Arc<Argon2<'static>>,
pub event_connected_users: Arc<RwLock<HashMap<Uuid, EventWsState>>>,
pub gateway_state: Arc<GatewayState>,
pub voice_rooms: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<OfferSignal>>>>,
}
impl AppState {
pub async fn register_event_connected_user(
&self,
user_id: Uuid,
session_id: String,
event_sender: mpsc::UnboundedSender<message::Event>,
) {
let mut connected_users = self.event_connected_users.write().await;
if let Some(state) = connected_users.get_mut(&user_id) {
state.connection_instance.insert(session_id, event_sender);
} else {
let state = EventWsState {
connection_instance: DashMap::new(),
};
state.connection_instance.insert(session_id, event_sender);
connected_users.insert(user_id, state);
}
#[derive(Debug, Default)]
pub struct GatewayState {
pub connected: scc::HashMap<Uuid, GatewayWsState>,
}
pub async fn unregister_event_connected_user(&self, user_id: Uuid, session_id: &str) {
let mut connected_users = self.event_connected_users.write().await;
if let Some(state) = connected_users.get_mut(&user_id) {
state.connection_instance.remove(session_id);
if state.connection_instance.is_empty() {
connected_users.remove(&user_id);
}
impl AppState {
pub async fn register_gateway_connected_user(
&self,
user_id: Uuid,
session_key: SessionKey,
event_sender: mpsc::UnboundedSender<event::Event>,
) {
self.gateway_state
.connected
.entry_async(user_id)
.await
.or_default()
.get_mut()
.instances
.insert(session_key, event_sender);
}
pub async fn unregister_gateway_connected_user(&self, user_id: Uuid, session_id: &SessionKey) {
let is_empty = {
let mut entry = self
.gateway_state
.connected
.entry_async(user_id)
.await
.or_default();
let entry = entry.get_mut();
entry.instances.remove(session_id);
entry.instances.is_empty()
};
if is_empty {
self.gateway_state.connected.remove_async(&user_id).await;
}
}
}

View File

@@ -1,6 +1,6 @@
use axum::extract::multipart::Field;
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
use serde::Serialize;
use serde::{Deserialize, Serialize};
#[derive(Debug, derive_more::Deref)]
pub struct SerdeFieldData<T>(pub FieldData<T>);
@@ -50,3 +50,13 @@ where
let seconds = duration.as_secs();
seconds.serialize(serializer)
}
pub fn deserialize_duration_seconds<'de, D>(
deserializer: D,
) -> Result<std::time::Duration, D::Error>
where
D: serde::Deserializer<'de>,
{
let seconds = u64::deserialize(deserializer)?;
Ok(std::time::Duration::from_secs(seconds))
}

View File

@@ -3,7 +3,7 @@ use axum::http::request::Parts;
use crate::entity;
#[derive(Debug, Copy, Clone)]
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
pub struct UserContext {
pub user_id: entity::user::Id,
}

View File

@@ -9,6 +9,7 @@ use axum_extra::typed_header::TypedHeaderRejectionReason;
use crate::jwt;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::{self, context};
pub async fn require_context(
@@ -51,13 +52,13 @@ async fn get_context(state: &AppState, request: &mut Request) -> context::UserCo
}
pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult {
let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;
let context = jwt::verify_jwt::<UserContext>(token).map_err(|_| context::Error::BadToken)?;
let _ = state
.database
.select_user_by_id(user_id)
.select_user_by_id(context.user_id)
.await
.map_err(|_| context::Error::BadToken)?;
Ok(context::UserContext { user_id })
Ok(context)
}

View File

@@ -12,9 +12,9 @@ use crate::{config, state};
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
let config = config::config();
let addr: std::net::SocketAddr = ([127, 0, 0, 1], config.port).into();
tracing::info!("listening on {}", addr);
let addr: std::net::SocketAddr = (config.server.host, config.server.port).into();
let listener = tokio::net::TcpListener::bind(addr).await?;
tracing::info!("listening on {}", addr);
axum::serve(listener, router(state))
.with_graceful_shutdown(shutdown_signal())
@@ -36,10 +36,12 @@ fn router(state: state::AppState) -> axum::Router {
Router::new()
// websocket
.route("/gateway/ws", get(ws::gateway::ws_handler))
.route("/voice/ws", get(ws::voice::ws_handler))
// api
.nest(
"/api/v1",
Router::new()
.route("/ws", get(ws::gateway::ws_handler))
.route("/auth/login", post(auth::login))
.route("/auth/register", post(auth::register))
.merge(protected_router())
@@ -48,6 +50,7 @@ fn router(state: state::AppState) -> axum::Router {
middleware::resolve_context,
)),
)
// middleware
.layer(axum::middleware::from_fn(middleware::response_map))
.layer(cors)
.layer(tower_http::trace::TraceLayer::new_for_http())
@@ -70,7 +73,6 @@ fn protected_router() -> axum::Router<state::AppState> {
.route("/servers", post(server::create))
.route("/servers/{server_id}", get(server::get))
.route("/servers/{server_id}/channels", get(server::channel::list))
.route("/voice/{channel_id}/connect", post(voice::connect))
// middleware
.route_layer(axum::middleware::from_fn(middleware::require_context))
}

View File

@@ -5,6 +5,7 @@ use axum::response::IntoResponse;
use serde::{Deserialize, Serialize};
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::route::user::FullUser;
use crate::{jwt, web};
@@ -38,7 +39,7 @@ pub async fn login(
.verify_password(payload.password.as_bytes(), &password_hash)
.map_err(|_| web::error::ClientError::WrongPassword)?;
let token = jwt::generate_jwt(user.id)?;
let token = jwt::generate_jwt(UserContext { user_id: user.id })?;
let response = LoginResponse {
user: user.into(),

View File

@@ -1,3 +1,3 @@
mod list;
pub use list::*;
pub use list::list;

View File

@@ -73,7 +73,7 @@ pub async fn create(
ws::gateway::util::send_message(
&state,
context.user_id,
ws::gateway::message::Event::AddServer {
ws::gateway::event::Event::AddServer {
server: server.clone(),
},
)

View File

@@ -3,6 +3,6 @@ mod create;
mod get;
mod list;
pub use create::*;
pub use get::*;
pub use list::*;
pub use create::create;
pub use get::get;
pub use list::list;

View File

@@ -1,3 +1,3 @@
mod list;
pub use list::*;
pub use list::list;

View File

@@ -2,8 +2,8 @@ pub mod channel;
mod get;
mod me;
pub use get::*;
pub use me::*;
pub use get::get_by_id;
pub use me::me;
use crate::entity::user;

View File

@@ -1,3 +1,3 @@
mod connect;
pub use connect::*;
pub use connect::connect;

14
src/web/ws/error.rs Normal file
View File

@@ -0,0 +1,14 @@
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error {
#[from]
Json(serde_json::Error),
#[from]
AcknowledgementError(tokio::sync::oneshot::error::RecvError),
WrongMessageType,
WebSocketClosed,
}

View File

@@ -1,24 +1,22 @@
// src/web/ws/connection.rs
use std::ops::ControlFlow;
use std::time::Duration;
use axum::extract::ws::{Message as AxumMessage, WebSocket};
use base64::Engine as _; // Bring trait into scope
use base64::Engine as _;
use futures::stream::SplitStream;
use futures::{SinkExt, StreamExt};
use futures::{Sink, SinkExt, StreamExt};
use serde::Serialize;
use sha2::{Digest, Sha256};
use tokio::time::Instant;
// Use items from sibling modules within `ws`
use super::error::{self, Error as WsError};
// Assuming Event type is from ws::message
use super::message::Event as WsEvent;
use super::protocol::{
SendWsMessage, WsClientMessage, WsServerMessage, deserialize_ws_message, serialize_ws_message,
};
use super::event::Event as WsEvent;
use super::protocol::{WsClientMessage, WsServerMessage};
use super::state::{WsContext, WsState, WsUserContext};
use crate::jwt;
use crate::state::AppState;
use crate::web::ws::gateway::SessionKey;
use crate::web::ws::util::{SendWsMessage, deserialize_ws_message, serialize_ws_message};
use crate::web::ws::{util, voice};
/// Main handler for an individual WebSocket connection's lifecycle.
/// Spawned by Axum upon successful WebSocket upgrade.
@@ -26,57 +24,20 @@ use crate::state::AppState;
pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) {
let (ws_sink, ws_stream) = websocket.split();
let (internal_send_tx, mut internal_send_rx) = tokio::sync::mpsc::unbounded_channel();
let (internal_send_tx, internal_send_rx) = tokio::sync::mpsc::unbounded_channel();
// Writer task: consumes messages from MPSC channel and sends them to the WebSocket sink.
let writer_task = tokio::spawn(async move {
let mut ws_sink_mut = ws_sink;
while let Some(SendWsMessage {
message,
response_ch,
}) = internal_send_rx.recv().await
{
let send_result = match serialize_ws_message(message) {
Ok(ws_msg) => {
if ws_sink_mut.send(ws_msg).await.is_err() {
Err(WsError::WebSocketClosed) // Send to client failed
} else {
Ok(())
}
},
Err(e) => Err(e), // Serialization error itself
};
if let Some(ch) = response_ch {
if ch.send(send_result).is_err() {
// Log if the receiver of the acknowledgement was dropped, though this is unlikely
// if send_with_response is awaiting it.
tracing::debug!("Failed to send acknowledgement; receiver dropped.");
}
} else if let Err(e) = send_result {
// For fire-and-forget, log critical errors (not just WebSocketClosed).
if !matches!(e, WsError::WebSocketClosed) {
tracing::warn!("Error in fire-and-forget WebSocket send: {:?}", e);
}
}
}
// MPSC channel closed, attempt to gracefully close WebSocket.
if ws_sink_mut.close().await.is_err() {
tracing::debug!("Error closing WebSocket sink; connection might be already dead.");
}
});
let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx);
let mut context = WsContext {
connection_state: WsState::Initialize,
user_context: None,
heartbeat_interval: std::time::Duration::from_secs(30), // Assuming config path
next_ping_deadline: Instant::now(), // Will be properly set before first use
event_channel: None,
};
let processing_result = process_websocket_messages(
&mut context,
ws_stream,
&internal_send_tx, // Pass as reference
&internal_send_tx,
&app_state,
)
.await;
@@ -84,7 +45,7 @@ pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState)
// --- Cleanup ---
if let Some(user_ctx_data) = &context.user_context {
app_state
.unregister_event_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key)
.unregister_gateway_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key)
.await;
tracing::info!(user_id = ?user_ctx_data.user_id, session_key = %user_ctx_data.session_key, "Unregistered WebSocket user.");
}
@@ -97,13 +58,13 @@ pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState)
if let Err(err_to_report) = &processing_result {
if !matches!(
err_to_report,
WsError::WebSocketClosed | WsError::HeartbeatTimeout
WsError::WebSocketClosed
) {
tracing::warn!(
"WebSocket processing error, attempting to notify client: {:?}",
err_to_report
);
let client_err_code = err_to_report.into_client_error();
let client_err_code = err_to_report.as_client_error();
let error_ws_message = WsServerMessage::Error {
code: client_err_code,
};
@@ -127,32 +88,19 @@ pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState)
/// Main loop for processing incoming WebSocket messages and outgoing application events.
/// Manages state transitions (Initialize -> Connected) and heartbeating.
#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))]
#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)
))]
async fn process_websocket_messages(
context: &mut WsContext,
mut ws_stream: SplitStream<WebSocket>,
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>, // Changed to reference
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage<WsServerMessage>>,
app_state: &AppState,
) -> error::Result<()> {
// Send initial heartbeat interval and set first deadline.
SendWsMessage::send_with_response(
sender,
WsServerMessage::HeartbeatInterval {
interval: context.heartbeat_interval,
},
)
.await?;
context.reset_deadline();
loop {
match context.connection_state {
WsState::Initialize => {
tokio::select! {
biased; // Prefer timeout check if multiple branches are ready
_ = tokio::time::sleep_until(context.next_ping_deadline) => {
tracing::warn!("Initial connection timeout (no Authenticate or Ping).");
return Err(WsError::HeartbeatTimeout);
}
biased;
maybe_message = ws_stream.next() => {
match maybe_message {
Some(Ok(message)) => {
@@ -191,10 +139,6 @@ async fn process_websocket_messages(
tokio::select! {
biased;
_ = tokio::time::sleep_until(context.next_ping_deadline) => {
tracing::warn!(user_id = ?user_ctx.user_id, "Heartbeat timeout.");
return Err(WsError::HeartbeatTimeout);
}
// Listen for application events to send to the client
maybe_app_event = event_rx.recv() => {
if let Some(app_event_data) = maybe_app_event {
@@ -234,20 +178,13 @@ async fn process_websocket_messages(
async fn handle_initial_message(
context: &mut WsContext,
message: AxumMessage,
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>, // Changed to reference
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage<WsServerMessage>>, // Changed to reference
app_state: &AppState,
) -> error::Result<ControlFlow<WsState, ()>> {
// Break(NewState) or Continue(())
match deserialize_ws_message(message)? {
WsClientMessage::Authenticate { token } => {
// IMPORTANT: Adjust the call below to your actual token validation logic.
// Assuming `get_context_from_token` returns `Result<crate::web::context::UserContext, YourAuthError>`
match crate::web::middleware::get_context_from_token(
&app_state, // Example: Pass necessary parts of AppState
&token,
)
.await
{
match crate::web::middleware::get_context_from_token(&app_state, &token).await {
Ok(auth_user_context) => {
// auth_user_context is `crate::web::context::UserContext`
let user_id = auth_user_context.user_id;
@@ -256,12 +193,14 @@ async fn handle_initial_message(
context.event_channel = Some((event_tx.clone(), event_rx));
let random_key_part = rand::random::<u64>();
let current_session_key = {
let current_session_key: SessionKey = {
let mut hasher = Sha256::new();
hasher.update(token.as_bytes());
hasher.update(user_id.to_string().as_bytes());
hasher.update(&random_key_part.to_be_bytes());
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize())
base64::engine::general_purpose::URL_SAFE_NO_PAD
.encode(hasher.finalize())
.into()
};
context.user_context = Some(WsUserContext {
@@ -270,7 +209,7 @@ async fn handle_initial_message(
});
app_state
.register_event_connected_user(
.register_gateway_connected_user(
user_id,
current_session_key.clone(),
event_tx, // This is ws::state::EventSender -> mpsc::UnboundedSender<ws::message::Event>
@@ -281,7 +220,7 @@ async fn handle_initial_message(
sender,
WsServerMessage::AuthenticateAccepted {
user_id,
session_key: current_session_key.clone(),
session_key: current_session_key,
},
)
.await?;
@@ -301,11 +240,6 @@ async fn handle_initial_message(
},
}
},
WsClientMessage::Ping => {
context.reset_deadline(); // Reset deadline on successful ping
SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await?;
Ok(ControlFlow::Continue(()))
},
// Per original code, only Authenticate and Ping are expected in Initialize.
// If WsClientMessage has other variants, this might need adjustment.
#[allow(unreachable_patterns)]
@@ -318,17 +252,45 @@ async fn handle_initial_message(
/// Handles messages received when the connection is in the `Connected` state.
/// Primarily expects `Ping` messages to keep the connection alive.
#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))]
#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)
))]
async fn handle_connected_message(
context: &mut WsContext, // Although not heavily used here, good for consistency and tracing
context: &mut WsContext,
message: AxumMessage,
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>, // Changed to reference
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage<WsServerMessage>>,
) -> error::Result<()> {
match deserialize_ws_message(message)? {
WsClientMessage::Ping => {
tracing::debug!("Ping received.");
context.reset_deadline();
SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await?;
WsClientMessage::VoiceStateUpdate {
server_id,
channel_id,
} => {
// TODO: check if can join this channel
let claims = voice::claims::VoiceClaims {
user_id: context
.user_context
.as_ref()
.expect("user context should be present")
.user_id,
server_id,
channel_id,
iat: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime)
.timestamp(),
};
let token = jwt::generate_jwt(claims).map_err(|_| WsError::TokenGenerationFailed)?;
SendWsMessage::send_with_response(
sender,
WsServerMessage::Event {
event: WsEvent::VoiceServerUpdate {
server_id,
channel_id,
token,
},
},
)
.await?;
Ok(())
},

View File

@@ -17,8 +17,9 @@ pub enum Error {
WebSocketClosed,
HeartbeatTimeout,
AuthenticationFailed,
TokenGenerationFailed,
}
#[derive(Debug, Clone, serde::Serialize)]
@@ -27,15 +28,13 @@ pub enum ClientError {
DeserializationError,
NotAuthenticated,
AlreadyAuthenticated,
HeartbeatTimeout,
Unknown,
}
impl Error {
pub fn into_client_error(&self) -> ClientError {
pub fn as_client_error(&self) -> ClientError {
match self {
Error::HeartbeatTimeout => ClientError::HeartbeatTimeout,
Error::Json(_) => ClientError::DeserializationError,
Error::UnexpectedMessageType => ClientError::Unknown,
Error::WrongMessageType => ClientError::Unknown,
@@ -44,3 +43,14 @@ impl Error {
}
}
}
impl From<crate::web::ws::error::Error> for Error {
fn from(err: crate::web::ws::error::Error) -> Self {
match err {
crate::web::ws::error::Error::Json(e) => Error::Json(e),
crate::web::ws::error::Error::AcknowledgementError(e) => Error::AcknowledgementError(e),
crate::web::ws::error::Error::WrongMessageType => Error::WrongMessageType,
crate::web::ws::error::Error::WebSocketClosed => Error::WebSocketClosed,
}
}
}

View File

@@ -4,12 +4,25 @@ use crate::entity;
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Event {
#[serde(rename_all = "camelCase")]
AddServer { server: entity::server::Server },
#[serde(rename_all = "camelCase")]
RemoveServer { server_id: entity::server::Id },
#[serde(rename_all = "camelCase")]
AddDmChannel { channel: entity::channel::Channel },
#[serde(rename_all = "camelCase")]
RemoveDmChannel { channel_id: entity::channel::Id },
#[serde(rename_all = "camelCase")]
AddServerChannel { channel: entity::channel::Channel },
#[serde(rename_all = "camelCase")]
RemoveServerChannel { channel_id: entity::channel::Id },
#[serde(rename_all = "camelCase")]
VoiceServerUpdate {
server_id: entity::server::Id,
channel_id: entity::channel::Id,
token: String,
},
}

View File

@@ -1,9 +1,6 @@
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use base64::Engine;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use sha2::Digest;
use crate::state::AppState;
use crate::web::ws::gateway::connection::handle_socket_connection;
@@ -11,14 +8,29 @@ use crate::web::ws::gateway::state::EventSender;
mod connection;
mod error;
pub mod message;
pub mod event;
mod protocol;
mod state;
pub mod util;
#[derive(
Debug,
Clone,
Eq,
PartialEq,
Hash,
Default,
derive_more::Display,
serde::Serialize,
serde::Deserialize,
derive_more::From,
)]
#[serde(transparent)]
pub struct SessionKey(String);
#[derive(Debug, Default)]
pub struct EventWsState {
pub connection_instance: DashMap<String, EventSender>,
pub struct GatewayWsState {
pub instances: DashMap<SessionKey, EventSender>,
}
pub async fn ws_handler(

View File

@@ -1,33 +1,29 @@
// src/web/ws/protocol.rs
use axum::extract::ws::Message as AxumMessage;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::error::{self, ClientError, Error as WsError};
use super::message as ws_local_message; // For ws::message::Event
use serde::{Deserialize, Serialize};
use super::error::ClientError;
use super::{SessionKey, event as ws_local_message};
use crate::entity;
use crate::util as crate_root_util; // For crate::util::serialize_duration_seconds
#[derive(Debug, Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage {
HeartbeatInterval {
#[serde(serialize_with = "crate_root_util::serialize_duration_seconds")]
interval: Duration,
},
AuthenticateDenied,
#[serde(rename_all = "camelCase")]
AuthenticateAccepted {
user_id: entity::user::Id,
session_key: String,
session_key: SessionKey,
},
#[serde(rename_all = "camelCase")]
Event {
event: ws_local_message::Event, // Assumes Event is defined in ws::message
event: ws_local_message::Event,
},
Pong,
#[serde(rename_all = "camelCase")]
Error {
code: ClientError,
},
@@ -41,59 +37,10 @@ pub enum WsClientMessage {
Authenticate {
token: String,
},
Ping,
}
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
pub fn deserialize_ws_message(message: AxumMessage) -> error::Result<WsClientMessage> {
match message {
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(WsError::from),
AxumMessage::Close(_) => Err(WsError::WebSocketClosed),
_ => Err(WsError::WrongMessageType), // e.g. Binary, Ping, Pong from axum::Message
}
}
/// Serializes a `WsServerMessage` into an Axum WebSocket message.
pub fn serialize_ws_message(message: WsServerMessage) -> error::Result<AxumMessage> {
serde_json::to_string(&message)
.map(Into::into)
.map(AxumMessage::Text)
.map_err(WsError::from)
}
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer.
pub struct SendWsMessage {
pub message: WsServerMessage,
pub response_ch: Option<tokio::sync::oneshot::Sender<error::Result<()>>>,
}
impl SendWsMessage {
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
pub async fn send_with_response(
tx: &tokio::sync::mpsc::UnboundedSender<Self>, // Changed to reference
message: WsServerMessage,
) -> error::Result<()> {
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let send_message = SendWsMessage {
message,
response_ch: Some(response_tx),
};
if tx.send(send_message).is_err() {
Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead
} else {
// Wait for the writer task to acknowledge the send attempt.
// This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
}
}
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
pub fn new_no_response(message: WsServerMessage) -> Self {
SendWsMessage {
message,
response_ch: None,
}
}
#[serde(rename_all = "camelCase")]
VoiceStateUpdate {
server_id: entity::server::Id,
channel_id: entity::channel::Id,
},
}

View File

@@ -1,10 +1,8 @@
// src/web/ws/state.rs
use std::time::{Duration};
use std::time::Duration;
use tokio::sync::mpsc;
use super::message;
use super::{event, SessionKey};
use crate::entity; // For entity::user::Id // For ws::message::Event used in EventSender/Receiver
/// Represents the current state of a single WebSocket connection.
@@ -15,35 +13,25 @@ pub enum WsState {
}
/// Contextual information for an authenticated WebSocket user session.
#[derive(Debug, Clone)] // Clone might be useful
#[derive(Debug)]
pub struct WsUserContext {
pub user_id: entity::user::Id,
pub session_key: String, // Unique key for this specific WebSocket session instance
pub session_key: SessionKey, // Unique key for this specific WebSocket session instance
}
/// Sender part of an MPSC channel used to send `ws::message::Event`s to a connected client.
pub type EventSender = mpsc::UnboundedSender<message::Event>;
pub type EventSender = mpsc::UnboundedSender<event::Event>;
/// Receiver part of an MPSC channel used by a connection task to receive `ws::message::Event`s.
pub type EventReceiver = mpsc::UnboundedReceiver<message::Event>;
pub type EventReceiver = mpsc::UnboundedReceiver<event::Event>;
/// Holds the full context for a single WebSocket connection's lifecycle.
/// This struct is managed per-connection.
pub struct WsContext {
pub connection_state: WsState,
pub user_context: Option<WsUserContext>,
pub heartbeat_interval: Duration,
pub next_ping_deadline: tokio::time::Instant,
/// Channel for receiving application-specific events to be sent to this client.
/// The `EventSender` (tx) part is given to `AppState` for broadcasting.
/// The `EventReceiver` (rx) part is polled by the connection task.
pub event_channel: Option<(EventSender, EventReceiver)>,
}
impl WsContext {
/// Resets the ping deadline based on the current time and heartbeat interval.
/// This should be called after successfully receiving a ping from the client
/// or after sending a message that implies activity (like Pong or initial auth).
pub fn reset_deadline(&mut self) {
self.next_ping_deadline = tokio::time::Instant::now() + self.heartbeat_interval;
}
}

View File

@@ -1,11 +1,11 @@
use crate::entity;
use crate::state::AppState;
use crate::web::ws::gateway::message;
use crate::web::ws::gateway::event;
pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: message::Event) {
let connected_users = state.event_connected_users.read().await;
if let Some(state) = connected_users.get(&user_id) {
for instance in state.connection_instance.iter() {
pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: event::Event) {
let connected_users = state.gateway_state.connected.get_async(&user_id).await;
if let Some(session) = connected_users {
for instance in session.instances.iter() {
if let Err(e) = instance.send(message.clone()) {
tracing::error!("failed to send message: {}", e);
}

View File

@@ -1 +1,4 @@
mod error;
pub mod gateway;
mod util;
pub mod voice;

95
src/web/ws/util.rs Normal file
View File

@@ -0,0 +1,95 @@
use axum::extract::ws::Message as AxumMessage;
use futures::{Sink, SinkExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::{mpsc, oneshot};
pub fn spawn_writer_task<S, T>(
mut ws_sink: S,
mut writer_rx: mpsc::UnboundedReceiver<SendWsMessage<T>>,
) -> tokio::task::JoinHandle<()>
where
S: Sink<axum::extract::ws::Message> + Unpin + Send + 'static,
T: Serialize + Send + 'static,
{
tokio::spawn(async move {
while let Some(SendWsMessage {
message,
response_ch,
}) = writer_rx.recv().await
{
let send_result = match serialize_ws_message(message) {
Ok(ws_msg) => {
if ws_sink.send(ws_msg).await.is_err() {
Err(super::error::Error::WebSocketClosed)
} else {
Ok(())
}
},
Err(e) => Err(e.into()),
};
if let Some(ch) = response_ch {
let _ = ch.send(send_result);
}
}
let _ = ws_sink.close().await;
})
}
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
pub fn deserialize_ws_message<T: DeserializeOwned>(
message: AxumMessage,
) -> super::error::Result<T> {
match message {
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from),
AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed),
_ => Err(super::error::Error::WrongMessageType),
}
}
/// Serializes a `WsServerMessage` into an Axum WebSocket message.
pub fn serialize_ws_message<T: Serialize>(message: T) -> super::error::Result<AxumMessage> {
serde_json::to_string(&message)
.map(Into::into)
.map(AxumMessage::Text)
.map_err(super::error::Error::from)
}
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer.
pub struct SendWsMessage<T> {
pub message: T,
pub response_ch: Option<oneshot::Sender<super::error::Result<()>>>,
}
impl<T> SendWsMessage<T> {
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
pub async fn send_with_response(
tx: &mpsc::UnboundedSender<Self>, // Changed to reference
message: T,
) -> super::error::Result<()> {
let (response_tx, response_rx) = oneshot::channel();
let send_message = SendWsMessage {
message,
response_ch: Some(response_tx),
};
if tx.send(send_message).is_err() {
Err(super::error::Error::WebSocketClosed) // MPSC channel closed, writer task likely dead
} else {
// Wait for the writer task to acknowledge the send attempt.
// This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
}
}
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
pub fn new_no_response(message: T) -> Self {
SendWsMessage {
message,
response_ch: None,
}
}
}

View File

@@ -0,0 +1,9 @@
use crate::entity;
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VoiceClaims {
pub user_id: entity::user::Id,
pub server_id: entity::server::Id,
pub channel_id: entity::channel::Id,
pub iat: i64,
}

View File

@@ -0,0 +1,6 @@
use axum::extract::ws::WebSocket;
use crate::state::AppState;
#[tracing::instrument(skip_all, name = "ws_connection_handler")]
pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) {}

48
src/web/ws/voice/error.rs Normal file
View File

@@ -0,0 +1,48 @@
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error {
#[from]
Axum(axum::Error),
#[from]
Json(serde_json::Error),
#[from]
AcknowledgementError(tokio::sync::oneshot::error::RecvError),
UnexpectedMessageType,
WrongMessageType,
WebSocketClosed,
HeartbeatTimeout,
AuthenticationFailed,
TokenGenerationFailed,
}
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientError {
DeserializationError,
NotAuthenticated,
AlreadyAuthenticated,
HeartbeatTimeout,
Unknown,
}
impl Error {
pub fn into_client_error(&self) -> ClientError {
match self {
Error::HeartbeatTimeout => ClientError::HeartbeatTimeout,
Error::Json(_) => ClientError::DeserializationError,
Error::UnexpectedMessageType => ClientError::Unknown,
Error::WrongMessageType => ClientError::Unknown,
Error::WebSocketClosed => ClientError::Unknown,
_ => ClientError::Unknown,
}
}
}

17
src/web/ws/voice/mod.rs Normal file
View File

@@ -0,0 +1,17 @@
pub mod claims;
mod connection;
mod error;
mod protocol;
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::ws::voice::connection::handle_socket_connection;
pub async fn ws_handler(
State(app_state): State<AppState>,
ws: WebSocketUpgrade,
) -> crate::web::error::Result<impl IntoResponse> {
Ok(ws.on_upgrade(|socket| handle_socket_connection(socket, app_state)))
}

View File

@@ -0,0 +1,105 @@
use std::time::Duration;
use axum::extract::ws::Message as AxumMessage;
use serde::{Deserialize, Serialize};
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use super::error::{self, ClientError, Error as WsError};
use crate::{entity, util as crate_root_util}; // For crate::util::serialize_duration_seconds
#[derive(Debug, Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage {
HeartbeatInterval {
#[serde(serialize_with = "crate_root_util::serialize_duration_seconds")]
interval: Duration,
},
AuthenticateDenied,
AuthenticateAccepted,
#[serde(rename_all = "camelCase")]
SdpAnswer {
sdp: RTCSessionDescription,
},
#[serde(rename_all = "camelCase")]
Error {
code: ClientError,
},
Pong,
}
#[derive(Debug, Deserialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage {
#[serde(rename_all = "camelCase")]
Authenticate {
token: String,
},
#[serde(rename_all = "camelCase")]
SdpOffer {
sdp: RTCSessionDescription,
},
Ping,
}
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
pub fn deserialize_ws_message(message: AxumMessage) -> error::Result<WsClientMessage> {
match message {
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(WsError::from),
AxumMessage::Close(_) => Err(WsError::WebSocketClosed),
_ => Err(WsError::WrongMessageType), // e.g. Binary, Ping, Pong from axum::Message
}
}
/// Serializes a `WsServerMessage` into an Axum WebSocket message.
pub fn serialize_ws_message(message: WsServerMessage) -> error::Result<AxumMessage> {
serde_json::to_string(&message)
.map(Into::into)
.map(AxumMessage::Text)
.map_err(WsError::from)
}
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer.
pub struct SendWsMessage {
pub message: WsServerMessage,
pub response_ch: Option<tokio::sync::oneshot::Sender<error::Result<()>>>,
}
impl SendWsMessage {
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
pub async fn send_with_response(
tx: &tokio::sync::mpsc::UnboundedSender<Self>, // Changed to reference
message: WsServerMessage,
) -> error::Result<()> {
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
let send_message = SendWsMessage {
message,
response_ch: Some(response_tx),
};
if tx.send(send_message).is_err() {
Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead
} else {
// Wait for the writer task to acknowledge the send attempt.
// This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
}
}
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
pub fn new_no_response(message: WsServerMessage) -> Self {
SendWsMessage {
message,
response_ch: None,
}
}
}