.
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
use crate::config::DatabaseConfig;
|
||||
use crate::entity;
|
||||
use crate::util::tristate::TriState;
|
||||
|
||||
#[derive(Clone, derive_more::AsRef)]
|
||||
pub struct Database {
|
||||
@@ -90,18 +91,28 @@ impl Database {
|
||||
pub async fn update_user_by_id(
|
||||
&self,
|
||||
user_id: entity::user::Id,
|
||||
display_name: Option<&str>,
|
||||
avatar_id: Option<entity::file::Id>,
|
||||
display_name: TriState<&str>,
|
||||
avatar_id: TriState<entity::file::Id>,
|
||||
) -> Result<entity::user::User> {
|
||||
let user = sqlx::query_as!(
|
||||
entity::user::User,
|
||||
r#"UPDATE "user" SET "display_name" = COALESCE($2, "display_name"), "avatar_id" = COALESCE($3, "avatar_id") WHERE "id" = $1 RETURNING "user".*"#,
|
||||
user_id,
|
||||
display_name,
|
||||
avatar_id
|
||||
)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
let mut query_builder = sqlx::QueryBuilder::new(r#"UPDATE "user" SET "#);
|
||||
|
||||
let mut separated = query_builder.separated(", ");
|
||||
|
||||
if !display_name.is_absent() {
|
||||
separated.push(r#""display_name" = "#);
|
||||
separated.push_bind_unseparated(display_name.into_option());
|
||||
}
|
||||
|
||||
if !avatar_id.is_absent() {
|
||||
separated.push(r#""avatar_id" = "#);
|
||||
separated.push_bind_unseparated(avatar_id.into_option());
|
||||
}
|
||||
|
||||
query_builder.push(" WHERE \"id\" = ");
|
||||
query_builder.push_bind(user_id);
|
||||
query_builder.push(" RETURNING \"user\".*");
|
||||
|
||||
let user = query_builder.build_query_as().fetch_one(&self.pool).await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
@@ -552,6 +563,39 @@ impl Database {
|
||||
Ok(file)
|
||||
}
|
||||
|
||||
pub async fn insert_message_attachment(
|
||||
&self,
|
||||
message_id: entity::message::Id,
|
||||
file_id: entity::file::Id,
|
||||
) -> Result<()> {
|
||||
sqlx::query!(
|
||||
r#"INSERT INTO "message_attachment"("message_id", "file_id", "order") VALUES ($1, $2, 0)"#,
|
||||
message_id,
|
||||
file_id
|
||||
)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn select_message_attachments(
|
||||
&self,
|
||||
message_id: entity::message::Id,
|
||||
) -> Result<Vec<entity::file::File>> {
|
||||
let attachments = sqlx::query_as!(
|
||||
entity::file::File,
|
||||
r#"SELECT * FROM "file" WHERE "id" IN (
|
||||
SELECT "file_id" FROM "message_attachment" WHERE "message_id" = $1
|
||||
)"#,
|
||||
message_id
|
||||
)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(attachments)
|
||||
}
|
||||
|
||||
pub async fn select_related_user_ids(
|
||||
&self,
|
||||
user_id: entity::user::Id,
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
pub mod tristate;
|
||||
|
||||
use axum::extract::multipart::Field;
|
||||
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
|
||||
use serde::{Deserialize, Serialize};
|
||||
139
src/util/tristate.rs
Normal file
139
src/util/tristate.rs
Normal file
@@ -0,0 +1,139 @@
|
||||
use std::ops::Deref;
|
||||
|
||||
use serde::{Deserialize, Deserializer, Serialize};
|
||||
use validator::ValidateLength;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum TriState<T> {
|
||||
Value(T),
|
||||
None,
|
||||
Absent,
|
||||
}
|
||||
|
||||
impl<T> TriState<T> {
|
||||
pub fn as_deref<D: ?Sized>(&self) -> TriState<&D>
|
||||
where
|
||||
T: Deref<Target = D>,
|
||||
{
|
||||
match self {
|
||||
TriState::Value(v) => TriState::Value(v.deref()),
|
||||
TriState::None => TriState::None,
|
||||
TriState::Absent => TriState::Absent,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_value(&self) -> bool {
|
||||
matches!(self, TriState::Value(_))
|
||||
}
|
||||
|
||||
pub fn is_none(&self) -> bool {
|
||||
matches!(self, TriState::None)
|
||||
}
|
||||
|
||||
pub fn is_absent(&self) -> bool {
|
||||
matches!(self, TriState::Absent)
|
||||
}
|
||||
|
||||
pub fn as_option(&self) -> Option<&T> {
|
||||
match self {
|
||||
TriState::Value(v) => Some(v),
|
||||
TriState::None => None,
|
||||
TriState::Absent => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_option(self) -> Option<T> {
|
||||
match self {
|
||||
TriState::Value(v) => Some(v),
|
||||
TriState::None => None,
|
||||
TriState::Absent => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Default for TriState<T> {
|
||||
fn default() -> Self {
|
||||
TriState::Absent
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) struct TriStateFieldVisitor<T> {
|
||||
marker: std::marker::PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Serialize for TriState<T>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
match self {
|
||||
TriState::Value(v) => v.serialize(serializer),
|
||||
TriState::None => serializer.serialize_none(),
|
||||
TriState::Absent => serializer.serialize_unit(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T> Deserialize<'de> for TriState<T>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
deserializer.deserialize_option(TriStateFieldVisitor::<T> {
|
||||
marker: std::marker::PhantomData,
|
||||
})
|
||||
}
|
||||
}
|
||||
impl<'de, T> serde::de::Visitor<'de> for TriStateFieldVisitor<T>
|
||||
where
|
||||
T: Deserialize<'de>,
|
||||
{
|
||||
type Value = TriState<T>;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("TriStateField<T>")
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_none<E>(self) -> Result<TriState<T>, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(TriState::None)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
||||
where
|
||||
D: Deserializer<'de>,
|
||||
{
|
||||
T::deserialize(deserializer).map(TriState::Value)
|
||||
}
|
||||
|
||||
#[inline]
|
||||
fn visit_unit<E>(self) -> Result<TriState<T>, E>
|
||||
where
|
||||
E: serde::de::Error,
|
||||
{
|
||||
Ok(TriState::None)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ValidateLength<u64> for TriState<T>
|
||||
where
|
||||
T: ValidateLength<u64>,
|
||||
{
|
||||
fn length(&self) -> Option<u64> {
|
||||
match self {
|
||||
TriState::Value(v) => v.length(),
|
||||
TriState::None => None,
|
||||
TriState::Absent => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
26
src/web/entity/file.rs
Normal file
26
src/web/entity/file.rs
Normal file
@@ -0,0 +1,26 @@
|
||||
use serde::Serialize;
|
||||
|
||||
use crate::entity::file::Id;
|
||||
use crate::util;
|
||||
|
||||
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct File {
|
||||
pub id: Id,
|
||||
pub filename: String,
|
||||
pub content_type: String,
|
||||
pub size: i64,
|
||||
pub url: String,
|
||||
}
|
||||
|
||||
impl From<crate::entity::file::File> for File {
|
||||
fn from(file: crate::entity::file::File) -> Self {
|
||||
Self {
|
||||
id: file.id,
|
||||
filename: file.filename,
|
||||
content_type: file.content_type,
|
||||
size: file.size,
|
||||
url: util::file_id_to_url(&file.id).unwrap_or_default(),
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -11,10 +11,14 @@ pub struct Message {
|
||||
pub channel_id: channel::Id,
|
||||
pub content: String,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub attachments: Vec<super::file::File>,
|
||||
}
|
||||
|
||||
impl From<crate::entity::message::Message> for Message {
|
||||
fn from(message: crate::entity::message::Message) -> Self {
|
||||
impl Message {
|
||||
pub fn from_message_with_attachments(
|
||||
message: crate::entity::message::Message,
|
||||
attachments: Vec<super::file::File>,
|
||||
) -> Self {
|
||||
Self {
|
||||
id: message.id,
|
||||
author_id: message.author_id,
|
||||
@@ -30,6 +34,7 @@ impl From<crate::entity::message::Message> for Message {
|
||||
})
|
||||
.flatten()
|
||||
.unwrap_or_default(),
|
||||
attachments,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -1,3 +1,4 @@
|
||||
pub mod file;
|
||||
pub mod message;
|
||||
pub mod user;
|
||||
pub mod server;
|
||||
pub mod user;
|
||||
|
||||
@@ -5,15 +5,31 @@ use validator::Validate;
|
||||
|
||||
use crate::state::AppState;
|
||||
use crate::web::context::UserContext;
|
||||
use crate::web::entity::file::File;
|
||||
use crate::web::entity::message::Message;
|
||||
use crate::web::ws;
|
||||
use crate::{entity, web};
|
||||
|
||||
#[derive(Debug, serde::Deserialize, Validate)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[validate(schema(function = "validate_create_payload"))]
|
||||
pub struct CreatePayload {
|
||||
#[validate(length(min = 1, max = 2000))]
|
||||
#[validate(length(min = 0, max = 2000))]
|
||||
pub content: String,
|
||||
|
||||
#[serde(default)]
|
||||
#[validate(length(min = 0, max = 10))]
|
||||
pub attachments: Vec<entity::file::Id>,
|
||||
}
|
||||
|
||||
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
|
||||
if payload.content.is_empty() && payload.attachments.is_empty() {
|
||||
return Err(validator::ValidationError::new(
|
||||
"at_least_one_field_required",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create(
|
||||
@@ -33,7 +49,22 @@ pub async fn create(
|
||||
.insert_channel_message(context.user_id, channel_id, &payload.content)
|
||||
.await?;
|
||||
|
||||
let message = Message::from(message);
|
||||
for attachment in payload.attachments.iter() {
|
||||
state
|
||||
.database
|
||||
.insert_message_attachment(message.id, *attachment)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let attachments = state
|
||||
.database
|
||||
.select_message_attachments(message.id)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(File::from)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
let message = Message::from_message_with_attachments(message, attachments);
|
||||
|
||||
// TODO: check permissions
|
||||
ws::gateway::util::send_message_channel(
|
||||
|
||||
@@ -2,6 +2,7 @@ use axum::Json;
|
||||
use axum::extract::{Path, Query, State};
|
||||
use axum::response::IntoResponse;
|
||||
use serde::Deserialize;
|
||||
use tracing::error;
|
||||
use validator::Validate;
|
||||
|
||||
use crate::state::AppState;
|
||||
@@ -35,13 +36,24 @@ pub async fn page(
|
||||
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
|
||||
};
|
||||
|
||||
let messages = state
|
||||
let messages_futures = state
|
||||
.database
|
||||
.select_channel_messages_paginated(channel_id, params.before, params.limit as i64)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(Message::from)
|
||||
.map(|message| async {
|
||||
let attachments = state
|
||||
.database
|
||||
.select_message_attachments(message.id)
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(web::entity::file::File::from)
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
Ok::<_, web::Error>(Message::from_message_with_attachments(message, attachments))
|
||||
});
|
||||
|
||||
let messages = futures::future::try_join_all(messages_futures).await?;
|
||||
|
||||
Ok(Json(messages))
|
||||
}
|
||||
|
||||
@@ -41,6 +41,11 @@ pub async fn create(
|
||||
|
||||
// TODO: check permissions
|
||||
let server = state.database.select_server_by_id(server_id).await?;
|
||||
|
||||
if server.owner_id != context.user_id {
|
||||
return Err(web::error::ClientError::NotAllowed.into());
|
||||
}
|
||||
|
||||
let channel = state
|
||||
.database
|
||||
.insert_server_channel(&payload.name, 0, payload.r#type, server_id, None)
|
||||
|
||||
@@ -5,6 +5,7 @@ use axum::response::IntoResponse;
|
||||
use crate::state::AppState;
|
||||
use crate::web::context::UserContext;
|
||||
use crate::web::ws;
|
||||
use crate::webrtc::WebRtcSignal;
|
||||
use crate::{entity, web};
|
||||
|
||||
pub async fn delete(
|
||||
@@ -15,6 +16,11 @@ pub async fn delete(
|
||||
// TODO: check permissions
|
||||
|
||||
let channel = state.database.select_channel_by_id(channel_id).await?;
|
||||
let server = state.database.select_server_by_id(server_id).await?;
|
||||
|
||||
if server.owner_id != context.user_id {
|
||||
return Err(web::error::ClientError::NotAllowed.into());
|
||||
}
|
||||
|
||||
if let Some(channel_server_id) = channel.server_id {
|
||||
if channel_server_id != server_id {
|
||||
@@ -27,7 +33,7 @@ pub async fn delete(
|
||||
let channel = state.database.delete_channel_by_id(channel_id).await?;
|
||||
|
||||
ws::gateway::util::send_message_server(
|
||||
state,
|
||||
state.clone(),
|
||||
server_id,
|
||||
ws::gateway::event::Event::RemoveServerChannel {
|
||||
server_id: server_id.clone(),
|
||||
@@ -35,5 +41,12 @@ pub async fn delete(
|
||||
},
|
||||
);
|
||||
|
||||
tokio::spawn(async move {
|
||||
let voice_rooms = state.voice_rooms.read().await;
|
||||
if let Some(voice_room) = voice_rooms.get(&channel_id) {
|
||||
let _ = voice_room.send(WebRtcSignal::Close);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(Json(channel))
|
||||
}
|
||||
|
||||
@@ -9,17 +9,28 @@ use crate::state::AppState;
|
||||
use crate::web::context::UserContext;
|
||||
use crate::web::entity::user::{FullUser, PartialUser};
|
||||
use crate::web::ws;
|
||||
use crate::{entity, web};
|
||||
use crate::{entity, util, web};
|
||||
|
||||
#[derive(Debug, Validate, Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[validate(schema(function = "validate_create_payload"))]
|
||||
pub struct CreatePayload {
|
||||
#[validate(length(min = 1, max = 32))]
|
||||
#[serde(default)]
|
||||
display_name: Option<String>,
|
||||
display_name: util::tristate::TriState<String>,
|
||||
|
||||
#[serde(default)]
|
||||
avatar_id: Option<entity::file::Id>,
|
||||
avatar_id: util::tristate::TriState<entity::file::Id>,
|
||||
}
|
||||
|
||||
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
|
||||
if payload.display_name.is_absent() && payload.avatar_id.is_absent() {
|
||||
return Err(validator::ValidationError::new(
|
||||
"at_least_one_field_required",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn patch(
|
||||
|
||||
@@ -86,7 +86,7 @@ where
|
||||
tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err);
|
||||
return Err(crate::web::ws::error::Error::WebSocketClosed);
|
||||
}
|
||||
None => { // Stream closed by client
|
||||
None => {
|
||||
tracing::debug!("WebSocket stream ended by client during Initialize state.");
|
||||
return Err(crate::web::ws::error::Error::WebSocketClosed);
|
||||
}
|
||||
@@ -123,7 +123,7 @@ where
|
||||
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err);
|
||||
return Err(crate::web::ws::error::Error::WebSocketClosed);
|
||||
}
|
||||
None => { // Stream closed by client
|
||||
None => {
|
||||
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state.");
|
||||
return Err(crate::web::ws::error::Error::WebSocketClosed);
|
||||
}
|
||||
@@ -139,14 +139,13 @@ where
|
||||
async fn handle_initial_message(
|
||||
context: &mut WsContext,
|
||||
message: AxumMessage,
|
||||
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>, // Changed to reference
|
||||
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
|
||||
app_state: &AppState,
|
||||
) -> crate::web::ws::error::Result<(), WsError> {
|
||||
match deserialize_ws_message(message)? {
|
||||
WsClientMessage::Authenticate { token } => {
|
||||
match crate::web::middleware::get_context_from_token(&app_state, &token).await {
|
||||
Ok(auth_user_context) => {
|
||||
// auth_user_context is `crate::web::context::UserContext`
|
||||
let user_id = auth_user_context.user_id;
|
||||
|
||||
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<WsEvent>();
|
||||
@@ -210,7 +209,7 @@ async fn handle_initial_message(
|
||||
async fn handle_connected_message(
|
||||
context: &mut WsContext,
|
||||
message: AxumMessage,
|
||||
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>, // Changed to reference
|
||||
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
|
||||
app_state: &AppState,
|
||||
) -> crate::web::ws::error::Result<(), WsError> {
|
||||
match deserialize_ws_message(message)? {
|
||||
@@ -218,6 +217,7 @@ async fn handle_connected_message(
|
||||
server_id,
|
||||
channel_id,
|
||||
} => {
|
||||
let server_id = server_id.unwrap_or_else(|| channel_id.clone().into());
|
||||
// TODO: check if can join this channel
|
||||
|
||||
let claims = voice::claims::VoiceClaims {
|
||||
|
||||
@@ -33,7 +33,7 @@ pub enum WsClientMessage {
|
||||
|
||||
#[serde(rename_all = "camelCase")]
|
||||
VoiceStateUpdate {
|
||||
server_id: entity::server::Id,
|
||||
server_id: Option<entity::server::Id>,
|
||||
channel_id: entity::channel::Id,
|
||||
},
|
||||
|
||||
|
||||
@@ -42,10 +42,11 @@ impl WebSocketHandler for WsContext {
|
||||
server_id,
|
||||
channel_id,
|
||||
user_id,
|
||||
..
|
||||
} => {
|
||||
ws::gateway::util::send_message_server(
|
||||
ws::gateway::util::send_message_channel(
|
||||
app_state.clone(),
|
||||
*server_id,
|
||||
*channel_id,
|
||||
ws::gateway::event::Event::VoiceChannelDisconnected {
|
||||
server_id: *server_id,
|
||||
channel_id: *channel_id,
|
||||
@@ -136,9 +137,9 @@ async fn handle_initial_message(
|
||||
user_id: claims.user_id,
|
||||
};
|
||||
|
||||
ws::gateway::util::send_message_server(
|
||||
ws::gateway::util::send_message_channel(
|
||||
app_state.clone(),
|
||||
claims.server_id,
|
||||
claims.channel_id,
|
||||
ws::gateway::event::Event::VoiceChannelConnected {
|
||||
server_id: claims.server_id,
|
||||
channel_id: claims.channel_id,
|
||||
@@ -174,7 +175,7 @@ async fn handle_connected_message(
|
||||
) -> ws::error::Result<(), error::Error> {
|
||||
match deserialize_ws_message(message)? {
|
||||
WsClientMessage::SdpOffer { sdp } => {
|
||||
let (signal_channel, user_id) = match &context.connection_state {
|
||||
let (signal_channel, user_id) = match &mut context.connection_state {
|
||||
WsState::Connected {
|
||||
signal_channel,
|
||||
user_id,
|
||||
|
||||
@@ -1,10 +1,11 @@
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use crate::entity;
|
||||
use crate::state::AppState;
|
||||
use crate::webrtc::{OfferSignal, WebRtcSignal};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[derive(Debug)]
|
||||
pub enum WsState {
|
||||
Initialize,
|
||||
Connected {
|
||||
|
||||
@@ -241,14 +241,20 @@ async fn handle_peer(
|
||||
};
|
||||
|
||||
{
|
||||
if let Some(old_peer) = room_state
|
||||
if let Some((_, old_peer)) = room_state
|
||||
.peers
|
||||
.insert(offer_signal.offer.peer_id, peer_state)
|
||||
.remove(&offer_signal.offer.peer_id)
|
||||
{
|
||||
let _ = old_peer.peer_connection.close().await;
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
room_state
|
||||
.peers
|
||||
.insert(offer_signal.offer.peer_id, peer_state);
|
||||
}
|
||||
|
||||
let _ = offer_signal.response.send(AnswerSignal {
|
||||
sdp_answer: local_description,
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user