diff --git a/src/database.rs b/src/database.rs index be0af84..821d286 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,5 +1,6 @@ use crate::config::DatabaseConfig; use crate::entity; +use crate::util::tristate::TriState; #[derive(Clone, derive_more::AsRef)] pub struct Database { @@ -90,18 +91,28 @@ impl Database { pub async fn update_user_by_id( &self, user_id: entity::user::Id, - display_name: Option<&str>, - avatar_id: Option, + display_name: TriState<&str>, + avatar_id: TriState, ) -> Result { - let user = sqlx::query_as!( - entity::user::User, - r#"UPDATE "user" SET "display_name" = COALESCE($2, "display_name"), "avatar_id" = COALESCE($3, "avatar_id") WHERE "id" = $1 RETURNING "user".*"#, - user_id, - display_name, - avatar_id - ) - .fetch_one(&self.pool) - .await?; + let mut query_builder = sqlx::QueryBuilder::new(r#"UPDATE "user" SET "#); + + let mut separated = query_builder.separated(", "); + + if !display_name.is_absent() { + separated.push(r#""display_name" = "#); + separated.push_bind_unseparated(display_name.into_option()); + } + + if !avatar_id.is_absent() { + separated.push(r#""avatar_id" = "#); + separated.push_bind_unseparated(avatar_id.into_option()); + } + + query_builder.push(" WHERE \"id\" = "); + query_builder.push_bind(user_id); + query_builder.push(" RETURNING \"user\".*"); + + let user = query_builder.build_query_as().fetch_one(&self.pool).await?; Ok(user) } @@ -551,6 +562,39 @@ impl Database { Ok(file) } + + pub async fn insert_message_attachment( + &self, + message_id: entity::message::Id, + file_id: entity::file::Id, + ) -> Result<()> { + sqlx::query!( + r#"INSERT INTO "message_attachment"("message_id", "file_id", "order") VALUES ($1, $2, 0)"#, + message_id, + file_id + ) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn select_message_attachments( + &self, + message_id: entity::message::Id, + ) -> Result> { + let attachments = sqlx::query_as!( + entity::file::File, + r#"SELECT * FROM "file" WHERE "id" IN ( + SELECT "file_id" FROM "message_attachment" WHERE "message_id" = $1 + )"#, + message_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(attachments) + } pub async fn select_related_user_ids( &self, diff --git a/src/util.rs b/src/util/mod.rs similarity index 99% rename from src/util.rs rename to src/util/mod.rs index 9587904..3d09741 100644 --- a/src/util.rs +++ b/src/util/mod.rs @@ -1,3 +1,5 @@ +pub mod tristate; + use axum::extract::multipart::Field; use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError}; use serde::{Deserialize, Serialize}; diff --git a/src/util/tristate.rs b/src/util/tristate.rs new file mode 100644 index 0000000..2cdfb2a --- /dev/null +++ b/src/util/tristate.rs @@ -0,0 +1,139 @@ +use std::ops::Deref; + +use serde::{Deserialize, Deserializer, Serialize}; +use validator::ValidateLength; + +#[derive(Debug, Clone)] +pub enum TriState { + Value(T), + None, + Absent, +} + +impl TriState { + pub fn as_deref(&self) -> TriState<&D> + where + T: Deref, + { + match self { + TriState::Value(v) => TriState::Value(v.deref()), + TriState::None => TriState::None, + TriState::Absent => TriState::Absent, + } + } + + pub fn is_value(&self) -> bool { + matches!(self, TriState::Value(_)) + } + + pub fn is_none(&self) -> bool { + matches!(self, TriState::None) + } + + pub fn is_absent(&self) -> bool { + matches!(self, TriState::Absent) + } + + pub fn as_option(&self) -> Option<&T> { + match self { + TriState::Value(v) => Some(v), + TriState::None => None, + TriState::Absent => None, + } + } + + pub fn into_option(self) -> Option { + match self { + TriState::Value(v) => Some(v), + TriState::None => None, + TriState::Absent => None, + } + } +} + +impl Default for TriState { + fn default() -> Self { + TriState::Absent + } +} + +pub(crate) struct TriStateFieldVisitor { + marker: std::marker::PhantomData, +} + +impl Serialize for TriState +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + TriState::Value(v) => v.serialize(serializer), + TriState::None => serializer.serialize_none(), + TriState::Absent => serializer.serialize_unit(), + } + } +} + +impl<'de, T> Deserialize<'de> for TriState +where + T: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_option(TriStateFieldVisitor:: { + marker: std::marker::PhantomData, + }) + } +} +impl<'de, T> serde::de::Visitor<'de> for TriStateFieldVisitor +where + T: Deserialize<'de>, +{ + type Value = TriState; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("TriStateField") + } + + #[inline] + fn visit_none(self) -> Result, E> + where + E: serde::de::Error, + { + Ok(TriState::None) + } + + #[inline] + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + T::deserialize(deserializer).map(TriState::Value) + } + + #[inline] + fn visit_unit(self) -> Result, E> + where + E: serde::de::Error, + { + Ok(TriState::None) + } +} + +impl ValidateLength for TriState +where + T: ValidateLength, +{ + fn length(&self) -> Option { + match self { + TriState::Value(v) => v.length(), + TriState::None => None, + TriState::Absent => None, + } + } +} diff --git a/src/web/entity/file.rs b/src/web/entity/file.rs new file mode 100644 index 0000000..ee87000 --- /dev/null +++ b/src/web/entity/file.rs @@ -0,0 +1,26 @@ +use serde::Serialize; + +use crate::entity::file::Id; +use crate::util; + +#[derive(Debug, Clone, sqlx::FromRow, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct File { + pub id: Id, + pub filename: String, + pub content_type: String, + pub size: i64, + pub url: String, +} + +impl From for File { + fn from(file: crate::entity::file::File) -> Self { + Self { + id: file.id, + filename: file.filename, + content_type: file.content_type, + size: file.size, + url: util::file_id_to_url(&file.id).unwrap_or_default(), + } + } +} diff --git a/src/web/entity/message.rs b/src/web/entity/message.rs index 0f932cc..bf57b70 100644 --- a/src/web/entity/message.rs +++ b/src/web/entity/message.rs @@ -11,10 +11,14 @@ pub struct Message { pub channel_id: channel::Id, pub content: String, pub created_at: chrono::DateTime, + pub attachments: Vec, } -impl From for Message { - fn from(message: crate::entity::message::Message) -> Self { +impl Message { + pub fn from_message_with_attachments( + message: crate::entity::message::Message, + attachments: Vec, + ) -> Self { Self { id: message.id, author_id: message.author_id, @@ -30,6 +34,7 @@ impl From for Message { }) .flatten() .unwrap_or_default(), + attachments, } } -} +} \ No newline at end of file diff --git a/src/web/entity/mod.rs b/src/web/entity/mod.rs index 6747680..906aea9 100644 --- a/src/web/entity/mod.rs +++ b/src/web/entity/mod.rs @@ -1,3 +1,4 @@ +pub mod file; pub mod message; +pub mod server; pub mod user; -pub mod server; \ No newline at end of file diff --git a/src/web/route/channel/message/create.rs b/src/web/route/channel/message/create.rs index aa75250..18767ed 100644 --- a/src/web/route/channel/message/create.rs +++ b/src/web/route/channel/message/create.rs @@ -5,15 +5,31 @@ use validator::Validate; use crate::state::AppState; use crate::web::context::UserContext; +use crate::web::entity::file::File; use crate::web::entity::message::Message; use crate::web::ws; use crate::{entity, web}; #[derive(Debug, serde::Deserialize, Validate)] #[serde(rename_all = "camelCase")] +#[validate(schema(function = "validate_create_payload"))] pub struct CreatePayload { - #[validate(length(min = 1, max = 2000))] + #[validate(length(min = 0, max = 2000))] pub content: String, + + #[serde(default)] + #[validate(length(min = 0, max = 10))] + pub attachments: Vec, +} + +fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> { + if payload.content.is_empty() && payload.attachments.is_empty() { + return Err(validator::ValidationError::new( + "at_least_one_field_required", + )); + } + + Ok(()) } pub async fn create( @@ -32,8 +48,23 @@ pub async fn create( .database .insert_channel_message(context.user_id, channel_id, &payload.content) .await?; + + for attachment in payload.attachments.iter() { + state + .database + .insert_message_attachment(message.id, *attachment) + .await?; + } - let message = Message::from(message); + let attachments = state + .database + .select_message_attachments(message.id) + .await? + .into_iter() + .map(File::from) + .collect::>(); + + let message = Message::from_message_with_attachments(message, attachments); // TODO: check permissions ws::gateway::util::send_message_channel( diff --git a/src/web/route/channel/message/page.rs b/src/web/route/channel/message/page.rs index a96d0b8..1b4af42 100644 --- a/src/web/route/channel/message/page.rs +++ b/src/web/route/channel/message/page.rs @@ -2,6 +2,7 @@ use axum::Json; use axum::extract::{Path, Query, State}; use axum::response::IntoResponse; use serde::Deserialize; +use tracing::error; use validator::Validate; use crate::state::AppState; @@ -35,13 +36,24 @@ pub async fn page( Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), }; - let messages = state + let messages_futures = state .database .select_channel_messages_paginated(channel_id, params.before, params.limit as i64) .await? .into_iter() - .map(Message::from) - .collect::>(); + .map(|message| async { + let attachments = state + .database + .select_message_attachments(message.id) + .await? + .into_iter() + .map(web::entity::file::File::from) + .collect::>(); + + Ok::<_, web::Error>(Message::from_message_with_attachments(message, attachments)) + }); + + let messages = futures::future::try_join_all(messages_futures).await?; Ok(Json(messages)) } diff --git a/src/web/route/server/channel/create.rs b/src/web/route/server/channel/create.rs index 6a7b6af..7a804d9 100644 --- a/src/web/route/server/channel/create.rs +++ b/src/web/route/server/channel/create.rs @@ -41,6 +41,11 @@ pub async fn create( // TODO: check permissions let server = state.database.select_server_by_id(server_id).await?; + + if server.owner_id != context.user_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + let channel = state .database .insert_server_channel(&payload.name, 0, payload.r#type, server_id, None) diff --git a/src/web/route/server/channel/delete.rs b/src/web/route/server/channel/delete.rs index a8832a1..6a79c9d 100644 --- a/src/web/route/server/channel/delete.rs +++ b/src/web/route/server/channel/delete.rs @@ -5,6 +5,7 @@ use axum::response::IntoResponse; use crate::state::AppState; use crate::web::context::UserContext; use crate::web::ws; +use crate::webrtc::WebRtcSignal; use crate::{entity, web}; pub async fn delete( @@ -15,6 +16,11 @@ pub async fn delete( // TODO: check permissions let channel = state.database.select_channel_by_id(channel_id).await?; + let server = state.database.select_server_by_id(server_id).await?; + + if server.owner_id != context.user_id { + return Err(web::error::ClientError::NotAllowed.into()); + } if let Some(channel_server_id) = channel.server_id { if channel_server_id != server_id { @@ -25,9 +31,9 @@ pub async fn delete( } let channel = state.database.delete_channel_by_id(channel_id).await?; - + ws::gateway::util::send_message_server( - state, + state.clone(), server_id, ws::gateway::event::Event::RemoveServerChannel { server_id: server_id.clone(), @@ -35,5 +41,12 @@ pub async fn delete( }, ); + tokio::spawn(async move { + let voice_rooms = state.voice_rooms.read().await; + if let Some(voice_room) = voice_rooms.get(&channel_id) { + let _ = voice_room.send(WebRtcSignal::Close); + } + }); + Ok(Json(channel)) } diff --git a/src/web/route/user/patch.rs b/src/web/route/user/patch.rs index c23e187..7939bf2 100644 --- a/src/web/route/user/patch.rs +++ b/src/web/route/user/patch.rs @@ -9,17 +9,28 @@ use crate::state::AppState; use crate::web::context::UserContext; use crate::web::entity::user::{FullUser, PartialUser}; use crate::web::ws; -use crate::{entity, web}; +use crate::{entity, util, web}; #[derive(Debug, Validate, Deserialize)] #[serde(rename_all = "camelCase")] +#[validate(schema(function = "validate_create_payload"))] pub struct CreatePayload { #[validate(length(min = 1, max = 32))] #[serde(default)] - display_name: Option, + display_name: util::tristate::TriState, #[serde(default)] - avatar_id: Option, + avatar_id: util::tristate::TriState, +} + +fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> { + if payload.display_name.is_absent() && payload.avatar_id.is_absent() { + return Err(validator::ValidationError::new( + "at_least_one_field_required", + )); + } + + Ok(()) } pub async fn patch( diff --git a/src/web/ws/gateway/connection.rs b/src/web/ws/gateway/connection.rs index 14fb04b..bad238f 100644 --- a/src/web/ws/gateway/connection.rs +++ b/src/web/ws/gateway/connection.rs @@ -86,7 +86,7 @@ where tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err); return Err(crate::web::ws::error::Error::WebSocketClosed); } - None => { // Stream closed by client + None => { tracing::debug!("WebSocket stream ended by client during Initialize state."); return Err(crate::web::ws::error::Error::WebSocketClosed); } @@ -123,7 +123,7 @@ where tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err); return Err(crate::web::ws::error::Error::WebSocketClosed); } - None => { // Stream closed by client + None => { tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state."); return Err(crate::web::ws::error::Error::WebSocketClosed); } @@ -139,14 +139,13 @@ where async fn handle_initial_message( context: &mut WsContext, message: AxumMessage, - sender: &mpsc::UnboundedSender>, // Changed to reference + sender: &mpsc::UnboundedSender>, app_state: &AppState, ) -> crate::web::ws::error::Result<(), WsError> { match deserialize_ws_message(message)? { WsClientMessage::Authenticate { token } => { match crate::web::middleware::get_context_from_token(&app_state, &token).await { Ok(auth_user_context) => { - // auth_user_context is `crate::web::context::UserContext` let user_id = auth_user_context.user_id; let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); @@ -210,7 +209,7 @@ async fn handle_initial_message( async fn handle_connected_message( context: &mut WsContext, message: AxumMessage, - sender: &mpsc::UnboundedSender>, // Changed to reference + sender: &mpsc::UnboundedSender>, app_state: &AppState, ) -> crate::web::ws::error::Result<(), WsError> { match deserialize_ws_message(message)? { @@ -218,6 +217,7 @@ async fn handle_connected_message( server_id, channel_id, } => { + let server_id = server_id.unwrap_or_else(|| channel_id.clone().into()); // TODO: check if can join this channel let claims = voice::claims::VoiceClaims { diff --git a/src/web/ws/gateway/protocol.rs b/src/web/ws/gateway/protocol.rs index 819101b..cfe2de6 100644 --- a/src/web/ws/gateway/protocol.rs +++ b/src/web/ws/gateway/protocol.rs @@ -33,7 +33,7 @@ pub enum WsClientMessage { #[serde(rename_all = "camelCase")] VoiceStateUpdate { - server_id: entity::server::Id, + server_id: Option, channel_id: entity::channel::Id, }, diff --git a/src/web/ws/voice/connection.rs b/src/web/ws/voice/connection.rs index 4fb2a82..bf77edd 100644 --- a/src/web/ws/voice/connection.rs +++ b/src/web/ws/voice/connection.rs @@ -42,10 +42,11 @@ impl WebSocketHandler for WsContext { server_id, channel_id, user_id, + .. } => { - ws::gateway::util::send_message_server( + ws::gateway::util::send_message_channel( app_state.clone(), - *server_id, + *channel_id, ws::gateway::event::Event::VoiceChannelDisconnected { server_id: *server_id, channel_id: *channel_id, @@ -128,7 +129,7 @@ async fn handle_initial_message( let signal_channel = ws::voice::state::get_signaling_channel(app_state, claims.channel_id).await; - + context.connection_state = WsState::Connected { signal_channel, server_id: claims.server_id, @@ -136,9 +137,9 @@ async fn handle_initial_message( user_id: claims.user_id, }; - ws::gateway::util::send_message_server( + ws::gateway::util::send_message_channel( app_state.clone(), - claims.server_id, + claims.channel_id, ws::gateway::event::Event::VoiceChannelConnected { server_id: claims.server_id, channel_id: claims.channel_id, @@ -174,7 +175,7 @@ async fn handle_connected_message( ) -> ws::error::Result<(), error::Error> { match deserialize_ws_message(message)? { WsClientMessage::SdpOffer { sdp } => { - let (signal_channel, user_id) = match &context.connection_state { + let (signal_channel, user_id) = match &mut context.connection_state { WsState::Connected { signal_channel, user_id, diff --git a/src/web/ws/voice/state.rs b/src/web/ws/voice/state.rs index 5485edf..07fa916 100644 --- a/src/web/ws/voice/state.rs +++ b/src/web/ws/voice/state.rs @@ -1,10 +1,11 @@ +use tokio::sync::oneshot; use tokio::sync::mpsc; use crate::entity; use crate::state::AppState; use crate::webrtc::{OfferSignal, WebRtcSignal}; -#[derive(Debug, Clone)] +#[derive(Debug)] pub enum WsState { Initialize, Connected { diff --git a/src/webrtc/mod.rs b/src/webrtc/mod.rs index 9d92671..a72eaac 100644 --- a/src/webrtc/mod.rs +++ b/src/webrtc/mod.rs @@ -241,14 +241,20 @@ async fn handle_peer( }; { - if let Some(old_peer) = room_state + if let Some((_, old_peer)) = room_state .peers - .insert(offer_signal.offer.peer_id, peer_state) + .remove(&offer_signal.offer.peer_id) { let _ = old_peer.peer_connection.close().await; } } + { + room_state + .peers + .insert(offer_signal.offer.peer_id, peer_state); + } + let _ = offer_signal.response.send(AnswerSignal { sdp_answer: local_description, });