This commit is contained in:
2025-05-20 11:13:56 +03:00
parent de11b5a4c3
commit 384e6aede2
16 changed files with 337 additions and 40 deletions

View File

@@ -1,5 +1,6 @@
use crate::config::DatabaseConfig;
use crate::entity;
use crate::util::tristate::TriState;
#[derive(Clone, derive_more::AsRef)]
pub struct Database {
@@ -90,18 +91,28 @@ impl Database {
pub async fn update_user_by_id(
&self,
user_id: entity::user::Id,
display_name: Option<&str>,
avatar_id: Option<entity::file::Id>,
display_name: TriState<&str>,
avatar_id: TriState<entity::file::Id>,
) -> Result<entity::user::User> {
let user = sqlx::query_as!(
entity::user::User,
r#"UPDATE "user" SET "display_name" = COALESCE($2, "display_name"), "avatar_id" = COALESCE($3, "avatar_id") WHERE "id" = $1 RETURNING "user".*"#,
user_id,
display_name,
avatar_id
)
.fetch_one(&self.pool)
.await?;
let mut query_builder = sqlx::QueryBuilder::new(r#"UPDATE "user" SET "#);
let mut separated = query_builder.separated(", ");
if !display_name.is_absent() {
separated.push(r#""display_name" = "#);
separated.push_bind_unseparated(display_name.into_option());
}
if !avatar_id.is_absent() {
separated.push(r#""avatar_id" = "#);
separated.push_bind_unseparated(avatar_id.into_option());
}
query_builder.push(" WHERE \"id\" = ");
query_builder.push_bind(user_id);
query_builder.push(" RETURNING \"user\".*");
let user = query_builder.build_query_as().fetch_one(&self.pool).await?;
Ok(user)
}
@@ -552,6 +563,39 @@ impl Database {
Ok(file)
}
pub async fn insert_message_attachment(
&self,
message_id: entity::message::Id,
file_id: entity::file::Id,
) -> Result<()> {
sqlx::query!(
r#"INSERT INTO "message_attachment"("message_id", "file_id", "order") VALUES ($1, $2, 0)"#,
message_id,
file_id
)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn select_message_attachments(
&self,
message_id: entity::message::Id,
) -> Result<Vec<entity::file::File>> {
let attachments = sqlx::query_as!(
entity::file::File,
r#"SELECT * FROM "file" WHERE "id" IN (
SELECT "file_id" FROM "message_attachment" WHERE "message_id" = $1
)"#,
message_id
)
.fetch_all(&self.pool)
.await?;
Ok(attachments)
}
pub async fn select_related_user_ids(
&self,
user_id: entity::user::Id,

View File

@@ -1,3 +1,5 @@
pub mod tristate;
use axum::extract::multipart::Field;
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
use serde::{Deserialize, Serialize};

139
src/util/tristate.rs Normal file
View File

@@ -0,0 +1,139 @@
use std::ops::Deref;
use serde::{Deserialize, Deserializer, Serialize};
use validator::ValidateLength;
#[derive(Debug, Clone)]
pub enum TriState<T> {
Value(T),
None,
Absent,
}
impl<T> TriState<T> {
pub fn as_deref<D: ?Sized>(&self) -> TriState<&D>
where
T: Deref<Target = D>,
{
match self {
TriState::Value(v) => TriState::Value(v.deref()),
TriState::None => TriState::None,
TriState::Absent => TriState::Absent,
}
}
pub fn is_value(&self) -> bool {
matches!(self, TriState::Value(_))
}
pub fn is_none(&self) -> bool {
matches!(self, TriState::None)
}
pub fn is_absent(&self) -> bool {
matches!(self, TriState::Absent)
}
pub fn as_option(&self) -> Option<&T> {
match self {
TriState::Value(v) => Some(v),
TriState::None => None,
TriState::Absent => None,
}
}
pub fn into_option(self) -> Option<T> {
match self {
TriState::Value(v) => Some(v),
TriState::None => None,
TriState::Absent => None,
}
}
}
impl<T> Default for TriState<T> {
fn default() -> Self {
TriState::Absent
}
}
pub(crate) struct TriStateFieldVisitor<T> {
marker: std::marker::PhantomData<T>,
}
impl<T> Serialize for TriState<T>
where
T: Serialize,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
match self {
TriState::Value(v) => v.serialize(serializer),
TriState::None => serializer.serialize_none(),
TriState::Absent => serializer.serialize_unit(),
}
}
}
impl<'de, T> Deserialize<'de> for TriState<T>
where
T: Deserialize<'de>,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_option(TriStateFieldVisitor::<T> {
marker: std::marker::PhantomData,
})
}
}
impl<'de, T> serde::de::Visitor<'de> for TriStateFieldVisitor<T>
where
T: Deserialize<'de>,
{
type Value = TriState<T>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
formatter.write_str("TriStateField<T>")
}
#[inline]
fn visit_none<E>(self) -> Result<TriState<T>, E>
where
E: serde::de::Error,
{
Ok(TriState::None)
}
#[inline]
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
T::deserialize(deserializer).map(TriState::Value)
}
#[inline]
fn visit_unit<E>(self) -> Result<TriState<T>, E>
where
E: serde::de::Error,
{
Ok(TriState::None)
}
}
impl<T> ValidateLength<u64> for TriState<T>
where
T: ValidateLength<u64>,
{
fn length(&self) -> Option<u64> {
match self {
TriState::Value(v) => v.length(),
TriState::None => None,
TriState::Absent => None,
}
}
}

26
src/web/entity/file.rs Normal file
View File

@@ -0,0 +1,26 @@
use serde::Serialize;
use crate::entity::file::Id;
use crate::util;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct File {
pub id: Id,
pub filename: String,
pub content_type: String,
pub size: i64,
pub url: String,
}
impl From<crate::entity::file::File> for File {
fn from(file: crate::entity::file::File) -> Self {
Self {
id: file.id,
filename: file.filename,
content_type: file.content_type,
size: file.size,
url: util::file_id_to_url(&file.id).unwrap_or_default(),
}
}
}

View File

@@ -11,10 +11,14 @@ pub struct Message {
pub channel_id: channel::Id,
pub content: String,
pub created_at: chrono::DateTime<chrono::Utc>,
pub attachments: Vec<super::file::File>,
}
impl From<crate::entity::message::Message> for Message {
fn from(message: crate::entity::message::Message) -> Self {
impl Message {
pub fn from_message_with_attachments(
message: crate::entity::message::Message,
attachments: Vec<super::file::File>,
) -> Self {
Self {
id: message.id,
author_id: message.author_id,
@@ -30,6 +34,7 @@ impl From<crate::entity::message::Message> for Message {
})
.flatten()
.unwrap_or_default(),
attachments,
}
}
}

View File

@@ -1,3 +1,4 @@
pub mod file;
pub mod message;
pub mod user;
pub mod server;
pub mod user;

View File

@@ -5,15 +5,31 @@ use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::file::File;
use crate::web::entity::message::Message;
use crate::web::ws;
use crate::{entity, web};
#[derive(Debug, serde::Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
#[validate(schema(function = "validate_create_payload"))]
pub struct CreatePayload {
#[validate(length(min = 1, max = 2000))]
#[validate(length(min = 0, max = 2000))]
pub content: String,
#[serde(default)]
#[validate(length(min = 0, max = 10))]
pub attachments: Vec<entity::file::Id>,
}
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
if payload.content.is_empty() && payload.attachments.is_empty() {
return Err(validator::ValidationError::new(
"at_least_one_field_required",
));
}
Ok(())
}
pub async fn create(
@@ -33,7 +49,22 @@ pub async fn create(
.insert_channel_message(context.user_id, channel_id, &payload.content)
.await?;
let message = Message::from(message);
for attachment in payload.attachments.iter() {
state
.database
.insert_message_attachment(message.id, *attachment)
.await?;
}
let attachments = state
.database
.select_message_attachments(message.id)
.await?
.into_iter()
.map(File::from)
.collect::<Vec<_>>();
let message = Message::from_message_with_attachments(message, attachments);
// TODO: check permissions
ws::gateway::util::send_message_channel(

View File

@@ -2,6 +2,7 @@ use axum::Json;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use serde::Deserialize;
use tracing::error;
use validator::Validate;
use crate::state::AppState;
@@ -35,13 +36,24 @@ pub async fn page(
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let messages = state
let messages_futures = state
.database
.select_channel_messages_paginated(channel_id, params.before, params.limit as i64)
.await?
.into_iter()
.map(Message::from)
.map(|message| async {
let attachments = state
.database
.select_message_attachments(message.id)
.await?
.into_iter()
.map(web::entity::file::File::from)
.collect::<Vec<_>>();
Ok::<_, web::Error>(Message::from_message_with_attachments(message, attachments))
});
let messages = futures::future::try_join_all(messages_futures).await?;
Ok(Json(messages))
}

View File

@@ -41,6 +41,11 @@ pub async fn create(
// TODO: check permissions
let server = state.database.select_server_by_id(server_id).await?;
if server.owner_id != context.user_id {
return Err(web::error::ClientError::NotAllowed.into());
}
let channel = state
.database
.insert_server_channel(&payload.name, 0, payload.r#type, server_id, None)

View File

@@ -5,6 +5,7 @@ use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::ws;
use crate::webrtc::WebRtcSignal;
use crate::{entity, web};
pub async fn delete(
@@ -15,6 +16,11 @@ pub async fn delete(
// TODO: check permissions
let channel = state.database.select_channel_by_id(channel_id).await?;
let server = state.database.select_server_by_id(server_id).await?;
if server.owner_id != context.user_id {
return Err(web::error::ClientError::NotAllowed.into());
}
if let Some(channel_server_id) = channel.server_id {
if channel_server_id != server_id {
@@ -27,7 +33,7 @@ pub async fn delete(
let channel = state.database.delete_channel_by_id(channel_id).await?;
ws::gateway::util::send_message_server(
state,
state.clone(),
server_id,
ws::gateway::event::Event::RemoveServerChannel {
server_id: server_id.clone(),
@@ -35,5 +41,12 @@ pub async fn delete(
},
);
tokio::spawn(async move {
let voice_rooms = state.voice_rooms.read().await;
if let Some(voice_room) = voice_rooms.get(&channel_id) {
let _ = voice_room.send(WebRtcSignal::Close);
}
});
Ok(Json(channel))
}

View File

@@ -9,17 +9,28 @@ use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::user::{FullUser, PartialUser};
use crate::web::ws;
use crate::{entity, web};
use crate::{entity, util, web};
#[derive(Debug, Validate, Deserialize)]
#[serde(rename_all = "camelCase")]
#[validate(schema(function = "validate_create_payload"))]
pub struct CreatePayload {
#[validate(length(min = 1, max = 32))]
#[serde(default)]
display_name: Option<String>,
display_name: util::tristate::TriState<String>,
#[serde(default)]
avatar_id: Option<entity::file::Id>,
avatar_id: util::tristate::TriState<entity::file::Id>,
}
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
if payload.display_name.is_absent() && payload.avatar_id.is_absent() {
return Err(validator::ValidationError::new(
"at_least_one_field_required",
));
}
Ok(())
}
pub async fn patch(

View File

@@ -86,7 +86,7 @@ where
tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err);
return Err(crate::web::ws::error::Error::WebSocketClosed);
}
None => { // Stream closed by client
None => {
tracing::debug!("WebSocket stream ended by client during Initialize state.");
return Err(crate::web::ws::error::Error::WebSocketClosed);
}
@@ -123,7 +123,7 @@ where
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err);
return Err(crate::web::ws::error::Error::WebSocketClosed);
}
None => { // Stream closed by client
None => {
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state.");
return Err(crate::web::ws::error::Error::WebSocketClosed);
}
@@ -139,14 +139,13 @@ where
async fn handle_initial_message(
context: &mut WsContext,
message: AxumMessage,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>, // Changed to reference
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
app_state: &AppState,
) -> crate::web::ws::error::Result<(), WsError> {
match deserialize_ws_message(message)? {
WsClientMessage::Authenticate { token } => {
match crate::web::middleware::get_context_from_token(&app_state, &token).await {
Ok(auth_user_context) => {
// auth_user_context is `crate::web::context::UserContext`
let user_id = auth_user_context.user_id;
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<WsEvent>();
@@ -210,7 +209,7 @@ async fn handle_initial_message(
async fn handle_connected_message(
context: &mut WsContext,
message: AxumMessage,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>, // Changed to reference
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
app_state: &AppState,
) -> crate::web::ws::error::Result<(), WsError> {
match deserialize_ws_message(message)? {
@@ -218,6 +217,7 @@ async fn handle_connected_message(
server_id,
channel_id,
} => {
let server_id = server_id.unwrap_or_else(|| channel_id.clone().into());
// TODO: check if can join this channel
let claims = voice::claims::VoiceClaims {

View File

@@ -33,7 +33,7 @@ pub enum WsClientMessage {
#[serde(rename_all = "camelCase")]
VoiceStateUpdate {
server_id: entity::server::Id,
server_id: Option<entity::server::Id>,
channel_id: entity::channel::Id,
},

View File

@@ -42,10 +42,11 @@ impl WebSocketHandler for WsContext {
server_id,
channel_id,
user_id,
..
} => {
ws::gateway::util::send_message_server(
ws::gateway::util::send_message_channel(
app_state.clone(),
*server_id,
*channel_id,
ws::gateway::event::Event::VoiceChannelDisconnected {
server_id: *server_id,
channel_id: *channel_id,
@@ -136,9 +137,9 @@ async fn handle_initial_message(
user_id: claims.user_id,
};
ws::gateway::util::send_message_server(
ws::gateway::util::send_message_channel(
app_state.clone(),
claims.server_id,
claims.channel_id,
ws::gateway::event::Event::VoiceChannelConnected {
server_id: claims.server_id,
channel_id: claims.channel_id,
@@ -174,7 +175,7 @@ async fn handle_connected_message(
) -> ws::error::Result<(), error::Error> {
match deserialize_ws_message(message)? {
WsClientMessage::SdpOffer { sdp } => {
let (signal_channel, user_id) = match &context.connection_state {
let (signal_channel, user_id) = match &mut context.connection_state {
WsState::Connected {
signal_channel,
user_id,

View File

@@ -1,10 +1,11 @@
use tokio::sync::oneshot;
use tokio::sync::mpsc;
use crate::entity;
use crate::state::AppState;
use crate::webrtc::{OfferSignal, WebRtcSignal};
#[derive(Debug, Clone)]
#[derive(Debug)]
pub enum WsState {
Initialize,
Connected {

View File

@@ -241,14 +241,20 @@ async fn handle_peer(
};
{
if let Some(old_peer) = room_state
if let Some((_, old_peer)) = room_state
.peers
.insert(offer_signal.offer.peer_id, peer_state)
.remove(&offer_signal.offer.peer_id)
{
let _ = old_peer.peer_connection.close().await;
}
}
{
room_state
.peers
.insert(offer_signal.offer.peer_id, peer_state);
}
let _ = offer_signal.response.send(AnswerSignal {
sdp_answer: local_description,
});