From e0a45216b0e29e95f6febb26fd4b8676f0d08153 Mon Sep 17 00:00:00 2001 From: Lionarius Date: Tue, 3 Jun 2025 11:42:51 +0300 Subject: [PATCH] . --- backend.txt | 5062 +++++++++++++++++ migrations/20250510184011_channel_message.sql | 60 +- src/database.rs | 27 +- src/web/mod.rs | 1 + src/web/route/user/channel/create.rs | 75 + src/web/route/user/channel/list.rs | 15 +- src/web/route/user/channel/mod.rs | 13 + src/web/ws/gateway/event.rs | 9 +- src/webrtc/mod.rs | 30 +- 9 files changed, 5262 insertions(+), 30 deletions(-) create mode 100644 backend.txt create mode 100644 src/web/route/user/channel/create.rs diff --git a/backend.txt b/backend.txt new file mode 100644 index 0000000..8742ed2 --- /dev/null +++ b/backend.txt @@ -0,0 +1,5062 @@ + +// file: migrations/20250510180101_uuidv7.sql +CREATE EXTENSION IF NOT EXISTS pg_uuidv7; +// file: migrations/20250510182916_file.sql +CREATE TABLE IF NOT EXISTS "file" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "filename" VARCHAR NOT NULL, + "content_type" VARCHAR NOT NULL, + "size" INT8 NOT NULL +); + + +// file: migrations/20250510183102_user.sql +CREATE TABLE IF NOT EXISTS "user" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "avatar_id" UUID REFERENCES "file" ("id") ON DELETE SET NULL, + "username" VARCHAR NOT NULL UNIQUE, + "display_name" VARCHAR, + "email" VARCHAR NOT NULL, + "password_hash" VARCHAR NOT NULL, + "bot" BOOLEAN NOT NULL DEFAULT FALSE, + "system" BOOLEAN NOT NULL DEFAULT FALSE, + "settings" JSONB NOT NULL DEFAULT '{}'::JSONB +); + +CREATE TABLE IF NOT EXISTS "user_relation" +( + "user_id" UUID NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE, + "other_id" UUID NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE, + "type" INT2 NOT NULL, + "created_at" TIMESTAMPTZ NOT NULL DEFAULT now(), + "updated_at" TIMESTAMPTZ NOT NULL DEFAULT now(), + PRIMARY KEY ("user_id", "other_id") +); + +-- create system account +INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system") +VALUES ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE); + +CREATE OR REPLACE FUNCTION check_avatar_is_image() + RETURNS TRIGGER AS +$$ +DECLARE + file_content_type VARCHAR; +BEGIN + -- Skip check if icon_id is null + IF NEW.avatar_id IS NULL THEN + RETURN NEW; + END IF; + + -- Retrieve content_type from file table + SELECT content_type + INTO file_content_type + FROM file + WHERE id = NEW.avatar_id; + + -- Raise exception if content_type does not start with 'image/' + IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN + RAISE EXCEPTION 'avatar_id must reference a file with content_type starting with image/'; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_check_icon_is_image + BEFORE INSERT OR UPDATE + ON "user" + FOR EACH ROW +EXECUTE FUNCTION check_avatar_is_image(); + +CREATE OR REPLACE FUNCTION fn_on_user_relation_update() + RETURNS TRIGGER + LANGUAGE plpgsql +AS +$$ +BEGIN + IF (TG_OP = 'UPDATE') THEN + NEW.updated_at := now(); + END IF; + + RETURN NEW; +END; +$$; + +CREATE TRIGGER trg_user_relation_update + BEFORE UPDATE + ON "user_relation" + FOR EACH ROW +EXECUTE FUNCTION fn_on_user_relation_update(); +// file: migrations/20250510183125_server.sql +CREATE TABLE IF NOT EXISTS "server" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "owner_id" UUID NOT NULL REFERENCES "user" ("id"), + "name" VARCHAR NOT NULL, + "icon_id" UUID REFERENCES "file" ("id") ON 
DELETE SET NULL +); + +CREATE TABLE IF NOT EXISTS "server_role" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE, + "name" VARCHAR NOT NULL, + "color" VARCHAR, + "display" BOOLEAN NOT NULL DEFAULT FALSE, + "permissions" JSONB NOT NULL DEFAULT '{}'::JSONB, + "position" SMALLINT NOT NULL +); + +CREATE TABLE IF NOT EXISTS "server_member" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE, + "user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE, + "nickname" VARCHAR, + "avatar_url" VARCHAR, + UNIQUE ("server_id", "user_id") +); + +CREATE TABLE IF NOT EXISTS "server_member_role" +( + "member_id" UUID NOT NULL REFERENCES "server_member" ("id") ON DELETE CASCADE, + "role_id" UUID NOT NULL REFERENCES "server_role" ("id") ON DELETE CASCADE, + PRIMARY KEY ("member_id", "role_id") +); + +CREATE TABLE IF NOT EXISTS "server_invite" +( + "code" VARCHAR NOT NULL PRIMARY KEY, + "server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE, + "inviter_id" UUID REFERENCES "user" ("id") ON DELETE SET NULL, + "expires_at" TIMESTAMPTZ +); + +CREATE OR REPLACE FUNCTION check_icon_is_image() + RETURNS TRIGGER AS +$$ +DECLARE + file_content_type VARCHAR; +BEGIN + -- Skip check if icon_id is null + IF NEW.icon_id IS NULL THEN + RETURN NEW; + END IF; + + -- Retrieve content_type from file table + SELECT content_type + INTO file_content_type + FROM file + WHERE id = NEW.icon_id; + + -- Raise exception if content_type does not start with 'image/' + IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN + RAISE EXCEPTION 'icon_id must reference a file with content_type starting with image/'; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_check_icon_is_image + BEFORE INSERT OR UPDATE + ON "server" + FOR EACH ROW +EXECUTE FUNCTION check_icon_is_image(); + +CREATE OR REPLACE FUNCTION check_server_user_role_server_id() + RETURNS TRIGGER AS +$$ +DECLARE + member_server_id UUID; + role_server_id UUID; +BEGIN + -- Get server_id from server_user + SELECT server_id + INTO member_server_id + FROM server_member + WHERE id = NEW.member_id; + + -- Get server_id from server_role + SELECT server_id + INTO role_server_id + FROM server_role + WHERE id = NEW.role_id; + + -- Check if server_ids match + IF member_server_id != role_server_id THEN + RAISE EXCEPTION 'Cannot assign role from a different server: server_user server_id (%) does not match server_role server_id (%)', member_server_id, role_server_id; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER enforce_server_user_role_server_id + BEFORE INSERT OR UPDATE + ON server_member_role + FOR EACH ROW +EXECUTE FUNCTION check_server_user_role_server_id(); +// file: migrations/20250510184011_channel_message.sql +CREATE TABLE IF NOT EXISTS "channel" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "name" VARCHAR NOT NULL, + "type" INT2 NOT NULL, + "position" INT2 NOT NULL, + "server_id" UUID REFERENCES "server" ("id") ON DELETE CASCADE, -- only for server channels + "parent" UUID REFERENCES "channel" ("id") ON DELETE SET NULL, -- only for server channels + "owner_id" UUID REFERENCES "user" ("id") ON DELETE SET NULL -- only for group channels +); + +-- only for dm or group channels +CREATE TABLE IF NOT EXISTS "channel_recipient" +( + "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE, 
+ "user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE, + PRIMARY KEY ("channel_id", "user_id") +); + +CREATE TABLE IF NOT EXISTS "message" +( + "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), + "author_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE, + "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE, + "content" TEXT NOT NULL +); + +CREATE TABLE IF NOT EXISTS "message_attachment" +( + "message_id" UUID NOT NULL REFERENCES "message" ON DELETE CASCADE, + "file_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE, + "order" INT2 NOT NULL, + PRIMARY KEY ("message_id", "file_id") +); + +ALTER TABLE "channel" + ADD COLUMN "last_message_id" UUID REFERENCES "message" ("id") ON DELETE SET NULL; + + +CREATE OR REPLACE FUNCTION create_dm_channel( + user1_id UUID, + user2_id UUID, + channel_type INT2 +) + RETURNS UUID + LANGUAGE plpgsql +AS +$$ +DECLARE + new_channel_id UUID; + channel_name VARCHAR; +BEGIN + -- Validate parameters + IF user1_id IS NULL OR user2_id IS NULL THEN + RAISE EXCEPTION 'Both user IDs must be provided'; + END IF; + + IF user1_id = user2_id THEN + RAISE EXCEPTION 'Cannot create a DM channel with the same user'; + END IF; + +-- Check if users exist + IF NOT exists (SELECT 1 FROM "user" WHERE id = user1_id) OR + NOT exists (SELECT 1 FROM "user" WHERE id = user2_id) THEN + RAISE EXCEPTION 'One or both users do not exist'; + END IF; + +-- Check if DM already exists between these users + IF exists (SELECT 1 + FROM channel c + JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id + JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id + WHERE c.type = channel_type + AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2) THEN + -- Find and return the existing channel ID + SELECT c.id + INTO new_channel_id + FROM channel c + JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id + JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id + WHERE c.type = channel_type + AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2 + LIMIT 1; + + RAISE NOTICE 'DM channel already exists between these users with ID: %', new_channel_id; + RETURN new_channel_id; + END IF; + +-- Generate channel name (conventionally uses user IDs in DMs) + channel_name := concat(user1_id, '-', user2_id); + + -- Create new channel + INSERT INTO "channel" ("name", "type", "position") + VALUES (channel_name, channel_type, 0) + RETURNING id INTO new_channel_id; + +-- Add both users as recipients + INSERT INTO "channel_recipient" ("channel_id", "user_id") + VALUES (new_channel_id, user1_id), + (new_channel_id, user2_id); + + RAISE NOTICE 'DM channel created with ID: %', new_channel_id; + RETURN new_channel_id; +END; +$$; + +CREATE OR REPLACE FUNCTION create_group_channel( + p_creator_id UUID, -- The user initiating the group creation (will be owner) + p_recipient_user_ids UUID[], -- Array of all user IDs for the group (must include creator) + p_group_channel_type INT2 -- The channel type identifier for groups (e.g., 1) +) + RETURNS UUID + LANGUAGE plpgsql +AS +$$ +DECLARE + new_channel_id UUID; + existing_channel_id UUID; + final_channel_name VARCHAR; + unique_sorted_recipient_ids UUID[]; + num_recipients INT; + uid UUID; + -- Threshold for detailed name vs. 
summary name
+    MAX_MEMBERS_FOR_DETAILED_NAME CONSTANT INT := 3;
+BEGIN
+    -- Validate and process recipient IDs
+    IF p_recipient_user_ids IS NULL OR array_length(p_recipient_user_ids, 1) IS NULL THEN
+        RAISE EXCEPTION 'Recipient user IDs array must be provided and not empty.';
+    END IF;
+
+    -- Get unique, sorted recipient IDs for consistent checking and to avoid duplicates.
+    SELECT array_agg(DISTINCT u ORDER BY u) INTO unique_sorted_recipient_ids FROM unnest(p_recipient_user_ids) u;
+    num_recipients := array_length(unique_sorted_recipient_ids, 1);
+
+    -- Validate minimum number of recipients for a group
+    IF num_recipients < 2 THEN -- Groups must have at least 2 members
+        RAISE EXCEPTION 'Group channels (type %) must have at least 2 recipients. Found %.', p_group_channel_type, num_recipients;
+    END IF;
+
+    -- Create new group channel
+    final_channel_name := 'Group';
+
+    INSERT INTO "channel" ("name", "type", "position", "owner_id", "server_id", "parent")
+    VALUES (final_channel_name,
+            p_group_channel_type,
+            0,    -- Default position
+            p_creator_id,
+            NULL, -- Not a server channel
+            NULL  -- Not a nested server channel
+           )
+    RETURNING id INTO new_channel_id;
+
+    -- Add the creator and all recipients to the channel_recipient table
+    INSERT INTO "channel_recipient" ("channel_id", "user_id")
+    VALUES (new_channel_id, p_creator_id);
+
+    -- The recipient array must include the creator, so skip duplicates
+    INSERT INTO "channel_recipient" ("channel_id", "user_id")
+    SELECT new_channel_id, r_id
+    FROM unnest(unique_sorted_recipient_ids) AS r_id
+    ON CONFLICT DO NOTHING;
+
+    RAISE NOTICE 'Group channel (type %) named "%" created with ID: % by owner % for recipients: %',
+        p_group_channel_type, final_channel_name, new_channel_id, p_creator_id, unique_sorted_recipient_ids;
+    RETURN new_channel_id;
+END;
+$$;
+
+// file: migrations/20250517190855_util.sql
+CREATE OR REPLACE FUNCTION get_users_that_can_see_user(target_user_id UUID)
+    RETURNS TABLE (user_id UUID) AS $$
+BEGIN
+    RETURN QUERY
+    -- Users directly related to the target user
+    SELECT ur.user_id
+    FROM user_relation ur
+    WHERE ur.other_id = target_user_id
+
+    UNION
+
+    -- Users where target user is related to them
+    SELECT ur.other_id AS user_id
+    FROM user_relation ur
+    WHERE ur.user_id = target_user_id
+
+    UNION
+
+    -- Users who share a server with the target user
+    SELECT sm.user_id
+    FROM server_member sm
+    JOIN server_member sm2 ON sm.server_id = sm2.server_id
+    WHERE sm2.user_id = target_user_id
+      AND sm.user_id != target_user_id
+
+    UNION
+
+    -- Users who share a channel with the target user (DM or group)
+    SELECT cr.user_id
+    FROM channel_recipient cr
+    JOIN channel_recipient cr2 ON cr.channel_id = cr2.channel_id
+    WHERE cr2.user_id = target_user_id
+      AND cr.user_id != target_user_id;
+END;
+$$ LANGUAGE plpgsql;
+// file: src/config.rs
+use std::sync::OnceLock;
+
+use serde::Deserialize;
+
+pub fn config() -> &'static Config {
+    static INSTANCE: OnceLock<Config> = OnceLock::new();
+
+    INSTANCE.get_or_init(|| {
+        config::Config::builder()
+            .add_source(config::File::with_name("config"))
+            .add_source(config::Environment::with_prefix("DIPLOM_"))
+            .build()
+            .expect("config builder")
+            .try_deserialize()
+            .expect("config deserialize")
+    })
+}
+
+#[derive(Deserialize)]
+pub struct Config {
+    pub server: ServerConfig,
+    pub security: SecurityConfig,
+    pub gateway: GatewayConfig,
+    pub database: DatabaseConfig,
+    pub object_store: ObjectStoreConfig,
+}
+
+#[derive(Deserialize)]
+pub struct ServerConfig {
+    pub hostname: url::Url,
+    pub host: std::net::Ipv4Addr,
+    pub port: u16,
+}
+
+#[derive(Deserialize)]
+pub struct SecurityConfig {
+    pub auth_secret: String,
+    pub voice_secret: String,
+}
+
+#[derive(Deserialize)]
+pub struct GatewayConfig {
+    #[serde(deserialize_with = "crate::util::deserialize_duration_seconds")]
+    pub voice_token_lifetime: std::time::Duration,
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(untagged)]
+#[serde(deny_unknown_fields)]
+pub enum DatabaseConfig {
+    Url {
+        url: url::Url,
+    },
+    Full {
+        driver: String,
+        username: String,
+        password: String,
+        host: url::Host,
+        port: u16,
+        name: String,
+    },
+}
+
+#[derive(Debug, Deserialize)]
+#[serde(deny_unknown_fields)]
+pub struct ObjectStoreConfig {
+    pub endpoint: url::Url,
+    pub region: String,
+    pub bucket: String,
+    pub access_key: String,
+    pub secret_key: String,
+}
+
+impl DatabaseConfig {
+    pub fn url(&self) -> Option<url::Url> {
+        match self {
+            Self::Url { url } => Some(url.clone()),
+            Self::Full {
+                driver,
+                username,
+                password,
+                host,
+                port,
+                name,
+            } => {
+                let str = format!(
+                    "{}://{}:{}@{}:{}/{}",
+                    driver, username, password, host, port, name
+                );
+                let url = url::Url::parse(&str).ok()?;
+                Some(url)
+            },
+        }
+    }
+}
+
+// file: src/database.rs
+use crate::config::DatabaseConfig;
+use crate::entity;
+use crate::util::tristate::TriState;
+
+#[derive(Clone, derive_more::AsRef)]
+pub struct Database {
+    pool: sqlx::PgPool,
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
+pub enum Error {
+    #[from]
+    Migrate(sqlx::migrate::MigrateError),
+
+    #[from]
+    Sqlx(sqlx::Error),
+
+    UserDoesNotExists,
+    UserAlreadyExists,
+
+    ServerDoesNotExists,
+
+    MemberAlreadyExists,
+
+    ChannelDoesNotExists,
+
+    InviteDoesNotExists,
+
+    MessageDoesNotExists,
+
+    FileDoesNotExists,
+}
+
+impl Database {
+    pub async fn connect(config: &DatabaseConfig) -> sqlx::Result<Self> {
+        static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
+
+        let pool = sqlx::postgres::PgPoolOptions::new()
+            .connect(config.url().ok_or(sqlx::Error::BeginFailed)?.as_str())
+            .await?;
+
+        MIGRATOR.run(&pool).await?;
+
+        Ok(Self { pool })
+    }
+}
+
+impl Database {
+    pub async fn insert_user(
+        &self,
+        username: &str,
+        display_name: Option<&str>,
+        email: &str,
+        password_hash: &str,
+    ) -> Result<entity::user::User> {
+        let user = match sqlx::query_as!(
+            entity::user::User,
+            r#"INSERT INTO "user"("username", "display_name", "email", "password_hash") VALUES ($1, $2, $3, $4) RETURNING "user".*"#,
+            username,
+            display_name,
+            email,
+            password_hash
+        )
+        .fetch_one(&self.pool)
+        .await {
+            Ok(user) => user,
+            Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
+                return Err(Error::UserAlreadyExists);
+            }
+            Err(e) => return Err(e.into()),
+        };
+
+        Ok(user)
+    }
+
+    pub async fn select_user_by_id(&self, user_id: entity::user::Id) -> Result<entity::user::User> {
+        let user = sqlx::query_as!(
+            entity::user::User,
+            r#"SELECT * FROM "user" WHERE "id" = $1"#,
+            user_id
+        )
+        .fetch_optional(&self.pool)
+        .await?
+        .ok_or(Error::UserDoesNotExists)?;
+
+        Ok(user)
+    }
+
+    pub async fn update_user_by_id(
+        &self,
+        user_id: entity::user::Id,
+        display_name: TriState<&str>,
+        avatar_id: TriState<entity::file::Id>,
+    ) -> Result<entity::user::User> {
+        let mut query_builder = sqlx::QueryBuilder::new(r#"UPDATE "user" SET "#);
+
+        let mut separated = query_builder.separated(", ");
+
+        if !display_name.is_absent() {
+            separated.push(r#""display_name" = "#);
+            separated.push_bind_unseparated(display_name.into_option());
+        }
+
+        if !avatar_id.is_absent() {
+            separated.push(r#""avatar_id" = "#);
+            separated.push_bind_unseparated(avatar_id.into_option());
+        }
+
+        query_builder.push(" WHERE \"id\" = ");
+        query_builder.push_bind(user_id);
+        query_builder.push(" RETURNING \"user\".*");
+
+        let user = query_builder.build_query_as().fetch_one(&self.pool).await?;
+
+        Ok(user)
+    }
+
+    pub async fn select_users_by_ids(
+        &self,
+        user_ids: &[entity::user::Id],
+    ) -> Result<Vec<entity::user::User>> {
+        let mut query_builder = sqlx::QueryBuilder::new(r#"SELECT * FROM "user" WHERE "id" IN ("#);
+
+        let mut separated = query_builder.separated(", ");
+        for id in user_ids.iter() {
+            separated.push_bind(id);
+        }
+
+        query_builder.push(")");
+
+        let users = query_builder.build_query_as().fetch_all(&self.pool).await?;
+
+        Ok(users)
+    }
+
+    pub async fn select_user_by_username(&self, username: &str) -> Result<entity::user::User> {
+        let user = sqlx::query_as!(
+            entity::user::User,
+            r#"SELECT * FROM "user" WHERE "username" = $1"#,
+            username
+        )
+        .fetch_optional(&self.pool)
+        .await?
+        .ok_or(Error::UserDoesNotExists)?;
+
+        Ok(user)
+    }
+
+    pub async fn select_server_by_id(
+        &self,
+        server_id: entity::server::Id,
+    ) -> Result<entity::server::Server> {
+        let server = sqlx::query_as!(
+            entity::server::Server,
+            r#"SELECT * FROM "server" WHERE "id" = $1"#,
+            server_id
+        )
+        .fetch_optional(&self.pool)
+        .await?
+ .ok_or(Error::ServerDoesNotExists)?; + + Ok(server) + } + + pub async fn select_user_servers( + &self, + user_id: entity::user::Id, + ) -> Result> { + let servers = sqlx::query_as!( + entity::server::Server, + r#"SELECT * FROM "server" WHERE "id" IN ( + SELECT "server_id" FROM "server_member" + WHERE "user_id" = $1 + )"#, + user_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(servers) + } + + pub async fn select_server_members( + &self, + server_id: entity::server::Id, + ) -> Result> { + let users = sqlx::query_as!( + entity::user::User, + r#"SELECT * FROM "user" WHERE "id" IN ( + SELECT "user_id" FROM "server_member" WHERE "server_id" = $1 + )"#, + server_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(users) + } + + pub async fn select_channel_members( + &self, + channel_id: entity::channel::Id, + ) -> Result> { + let users = sqlx::query_as!( + entity::user::User, + r#"SELECT * FROM "user" WHERE "id" IN ( + SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1 + UNION SELECT "user_id" FROM "server_member" WHERE "server_id" IN ( + SELECT "server_id" FROM "channel" WHERE "id" = $1 + ) + )"#, + channel_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(users) + } + + pub async fn select_user_channels( + &self, + user_id: entity::user::Id, + ) -> Result> { + let channels = sqlx::query_as!( + entity::channel::Channel, + r#"SELECT * FROM "channel" WHERE "id" IN ( + SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1 + )"#, + user_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(channels) + } + + pub async fn insert_server( + &self, + name: &str, + icon_id: Option, + owner_id: entity::user::Id, + ) -> Result { + let server = sqlx::query_as!( + entity::server::Server, + r#"INSERT INTO "server"("name", "icon_id", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#, + name, + icon_id, + owner_id + ) + .fetch_one(&self.pool) + .await?; + + Ok(server) + } + + pub async fn insert_server_role( + &self, + id: Option, + server_id: entity::server::Id, + name: &str, + color: Option<&str>, + display: bool, + permissions: serde_json::Value, + position: u16, + ) -> Result { + let role = sqlx::query_as!( + entity::server::role::ServerRole, + r#"INSERT INTO "server_role"("id", "server_id", "name", "color", "display", "permissions", "position") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING "server_role".*"#, + id, + server_id, + name, + color, + display, + permissions, + position as i16 + ) + .fetch_one(&self.pool) + .await?; + + Ok(role) + } + + pub async fn insert_server_member( + &self, + server_id: entity::server::Id, + user_id: entity::user::Id, + ) -> Result { + let member = match sqlx::query_as!( + entity::server::member::ServerMember, + r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#, + server_id, + user_id + ) + .fetch_one(&self.pool) + .await { + Ok(member) => member, + Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => { + return Err(Error::MemberAlreadyExists); + } + Err(e) => return Err(e.into()), + }; + + Ok(member) + } + + pub async fn insert_server_member_role( + &self, + server_member_id: entity::server::member::Id, + server_role_id: entity::server::role::Id, + ) -> Result<()> { + sqlx::query!( + r#"INSERT INTO "server_member_role"("member_id", "role_id") VALUES ($1, $2)"#, + server_member_id, + server_role_id + ) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn select_channel_by_id( + &self, + channel_id: entity::channel::Id, + ) -> Result { + let channel = 
sqlx::query_as!( + entity::channel::Channel, + r#"SELECT * FROM "channel" WHERE "id" = $1"#, + channel_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::ChannelDoesNotExists)?; + + Ok(channel) + } + + pub async fn select_channel_recipients( + &self, + channel_id: entity::channel::Id, + ) -> Result>> { + let recipients = sqlx::query_as!( + entity::user::User, + r#"SELECT * FROM "user" WHERE "id" IN ( + SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1 + )"#, + channel_id + ) + .fetch_all(&self.pool) + .await?; + + if recipients.is_empty() { + return Ok(None); + } + + Ok(Some(recipients)) + } + + pub async fn select_server_channels( + &self, + server_id: entity::server::Id, + ) -> Result> { + let channels = sqlx::query_as!( + entity::channel::Channel, + r#"SELECT * FROM "channel" WHERE "server_id" = $1"#, + server_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(channels) + } + + pub async fn delete_server_by_id( + &self, + server_id: entity::server::Id, + ) -> Result { + let server = sqlx::query_as!( + entity::server::Server, + r#"DELETE FROM "server" WHERE "id" = $1 RETURNING "server".*"#, + server_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::ServerDoesNotExists)?; + + Ok(server) + } + + pub async fn delete_channel_by_id( + &self, + channel_id: entity::channel::Id, + ) -> Result { + let channel = sqlx::query_as!( + entity::channel::Channel, + r#"DELETE FROM "channel" WHERE "id" = $1 RETURNING "channel".*"#, + channel_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::ChannelDoesNotExists)?; + + Ok(channel) + } + + pub async fn insert_server_channel( + &self, + name: &str, + position: u16, + r#type: entity::channel::ChannelType, + server_id: entity::server::Id, + parent: Option, + ) -> Result { + let channel = sqlx::query_as!( + entity::channel::Channel, + r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#, + name, + r#type as i16, + position as i16, + server_id, + parent + ) + .fetch_one(&self.pool) + .await?; + + Ok(channel) + } + + pub async fn insert_server_invite( + &self, + code: &str, + server_id: entity::server::Id, + inviter_id: Option, + expires_at: Option>, + ) -> Result { + let invite = sqlx::query_as!( + entity::server::invite::ServerInvite, + r#"INSERT INTO "server_invite"("code", "server_id", "inviter_id", "expires_at") VALUES ($1, $2, $3, $4) RETURNING "server_invite".*"#, + code, + server_id, + inviter_id, + expires_at + ) + .fetch_one(&self.pool) + .await?; + + Ok(invite) + } + + pub async fn select_server_invite_by_code( + &self, + code: &str, + ) -> Result { + let invite = sqlx::query_as!( + entity::server::invite::ServerInvite, + r#"SELECT * FROM "server_invite" WHERE "code" = $1"#, + code + ) + .fetch_optional(&self.pool) + .await? 
+ .ok_or(Error::InviteDoesNotExists)?; + + Ok(invite) + } + + pub async fn delete_server_invite_by_code( + &self, + code: &str, + ) -> Result> { + let invite = sqlx::query_as!( + entity::server::invite::ServerInvite, + r#"DELETE FROM "server_invite" WHERE "code" = $1 RETURNING "server_invite".*"#, + code + ) + .fetch_optional(&self.pool) + .await?; + + Ok(invite) + } + + pub async fn select_channel_messages_paginated( + &self, + channel_id: entity::channel::Id, + before: Option, + limit: i64, + ) -> Result> { + let messages = sqlx::query_as!( + entity::message::Message, + r#"SELECT * FROM "message" WHERE "channel_id" = $1 AND ($2::uuid IS NULL OR "id" < $2::uuid) ORDER BY "id" DESC LIMIT $3"#, + channel_id, + before, + limit + ) + .fetch_all(&self.pool) + .await?; + + Ok(messages) + } + + pub async fn insert_channel_message( + &self, + user_id: entity::user::Id, + channel_id: entity::channel::Id, + content: &str, + ) -> Result { + let message = sqlx::query_as!( + entity::message::Message, + r#"INSERT INTO "message"("channel_id", "author_id", "content") VALUES ($1, $2, $3) RETURNING "message".*"#, + channel_id, + user_id, + content + ) + .fetch_one(&self.pool) + .await?; + + Ok(message) + } + + pub async fn select_file_by_id(&self, file_id: entity::file::Id) -> Result { + let file = sqlx::query_as!( + entity::file::File, + r#"SELECT * FROM "file" WHERE "id" = $1"#, + file_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::FileDoesNotExists)?; + + Ok(file) + } + + pub async fn delete_file_by_id(&self, file_id: entity::file::Id) -> Result { + let file = sqlx::query_as!( + entity::file::File, + r#"DELETE FROM "file" WHERE "id" = $1 RETURNING "file".*"#, + file_id + ) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::FileDoesNotExists)?; + + Ok(file) + } + + pub async fn insert_file( + &self, + filename: &str, + content_type: &str, + size: usize, + ) -> Result { + let file = sqlx::query_as!( + entity::file::File, + r#"INSERT INTO "file"("filename", "content_type", "size") VALUES ($1, $2, $3) RETURNING "file".*"#, + filename, + content_type, + size as i64 + ) + .fetch_one(&self.pool) + .await?; + + Ok(file) + } + + pub async fn insert_message_attachment( + &self, + message_id: entity::message::Id, + file_id: entity::file::Id, + ) -> Result<()> { + sqlx::query!( + r#"INSERT INTO "message_attachment"("message_id", "file_id", "order") VALUES ($1, $2, 0)"#, + message_id, + file_id + ) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn select_message_attachments( + &self, + message_id: entity::message::Id, + ) -> Result> { + let attachments = sqlx::query_as!( + entity::file::File, + r#"SELECT * FROM "file" WHERE "id" IN ( + SELECT "file_id" FROM "message_attachment" WHERE "message_id" = $1 + )"#, + message_id + ) + .fetch_all(&self.pool) + .await?; + + Ok(attachments) + } + + pub async fn select_related_user_ids( + &self, + user_id: entity::user::Id, + ) -> Result> { + #[derive(sqlx::FromRow)] + struct UserId { + user_id: entity::user::Id, + } + + let user_ids = + sqlx::query_as::<_, UserId>(r#"SELECT * FROM get_users_that_can_see_user($1)"#) + .bind(user_id) + .fetch_all(&self.pool) + .await? 
+        .into_iter()
+        .map(|row| row.user_id)
+        .collect();
+
+        Ok(user_ids)
+    }
+
+    pub async fn procedure_create_dm_channel(
+        &self,
+        user1_id: entity::user::Id,
+        user2_id: entity::user::Id,
+    ) -> Result<entity::channel::Id> {
+        let channel_id = sqlx::query_scalar!(
+            r#"SELECT create_dm_channel($1, $2, $3)"#,
+            user1_id,
+            user2_id,
+            entity::channel::ChannelType::DirectMessage as i16
+        )
+        .fetch_one(&self.pool)
+        .await?
+        .expect("channel_id is null");
+
+        Ok(channel_id)
+    }
+
+    pub async fn procedure_create_group_channel(
+        &self,
+        creator_id: entity::user::Id,
+        users: &[entity::user::Id],
+    ) -> Result<entity::channel::Id> {
+        let channel_id = sqlx::query_scalar!(
+            r#"SELECT create_group_channel($1, $2, $3)"#,
+            creator_id,
+            users,
+            entity::channel::ChannelType::Group as i16
+        )
+        .fetch_one(&self.pool)
+        .await?
+        .expect("channel_id is null");
+
+        Ok(channel_id)
+    }
+}
+
+// file: src/entity/channel.rs
+use serde::{Deserialize, Serialize};
+
+use crate::entity::{message, server, user};
+
+pub type Id = uuid::Uuid;
+
+#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct Channel {
+    pub id: Id,
+    pub name: String,
+    pub r#type: ChannelType,
+    pub position: i16,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub server_id: Option<server::Id>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub parent: Option<Id>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub owner_id: Option<user::Id>,
+    #[serde(skip_serializing_if = "Option::is_none")]
+    pub last_message_id: Option<message::Id>,
+}
+
+#[derive(Debug, Clone, sqlx::Type, Serialize, Deserialize)]
+#[non_exhaustive]
+#[serde(rename_all = "snake_case")]
+#[repr(i16)]
+pub enum ChannelType {
+    ServerText = 1,
+    ServerVoice = 2,
+    ServerCategory = 3,
+
+    DirectMessage = 4,
+    Group = 5,
+}
+
+impl From<i16> for ChannelType {
+    fn from(value: i16) -> Self {
+        match value {
+            1 => ChannelType::ServerText,
+            2 => ChannelType::ServerVoice,
+            3 => ChannelType::ServerCategory,
+            4 => ChannelType::DirectMessage,
+            5 => ChannelType::Group,
+            _ => ChannelType::ServerText,
+        }
+    }
+}
+
+// file: src/entity/file.rs
+use serde::Serialize;
+
+pub type Id = uuid::Uuid;
+
+#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
+pub struct File {
+    pub id: Id,
+    pub filename: String,
+    pub content_type: String,
+    pub size: i64,
+}
+
+// file: src/entity/message.rs
+use crate::entity::{channel, user};
+
+pub type Id = uuid::Uuid;
+
+#[derive(Debug, Clone, sqlx::FromRow)]
+pub struct Message {
+    pub id: Id,
+    pub author_id: user::Id,
+    pub channel_id: channel::Id,
+    pub content: String,
+}
+
+// file: src/entity/mod.rs
+pub mod file;
+pub mod channel;
+pub mod message;
+pub mod server;
+pub mod user;
+
+// file: src/entity/server/invite.rs
+use chrono::{DateTime, Utc};
+use serde::Serialize;
+
+use crate::entity::{server, user};
+
+#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerInvite {
+    pub code: String,
+    pub server_id: server::Id,
+    pub inviter_id: Option<user::Id>,
+    pub expires_at: Option<DateTime<Utc>>,
+}
+
+// file: src/entity/server/member.rs
+use serde::Serialize;
+
+use crate::entity::{server, user};
+
+pub type Id = uuid::Uuid;
+
+#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerMember {
+    pub id: Id,
+    pub server_id: server::Id,
+    pub user_id: user::Id,
+    pub nickname: Option<String>,
+    pub avatar_url: Option<String>,
+}
+
+// file: src/entity/server/role.rs
+use serde::Serialize;
+
+use crate::entity::server;
+
+pub type Id = uuid::Uuid;
+
+#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
+#[serde(rename_all = "camelCase")]
+pub struct ServerRole {
+    pub id: Id,
+    pub server_id: server::Id,
+    pub name: String,
+    pub color: Option<String>,
+    pub display: bool,
+    pub permissions: serde_json::Value,
+    pub position: i16,
+}
+
+// file: src/entity/server.rs
+pub mod invite;
+pub mod member;
+pub mod role;
+
+use crate::entity::{file, user};
+
+pub type Id = uuid::Uuid;
+
+#[derive(Debug, Clone, sqlx::FromRow)]
+pub struct Server {
+    pub id: Id,
+    pub owner_id: user::Id,
+    pub name: String,
+    pub icon_id: Option<file::Id>,
+}
+
+// file: src/entity/user.rs
+use std::sync::LazyLock;
+
+use regex::Regex;
+use crate::entity::file;
+
+pub static USERNAME_REGEX: LazyLock<Regex> =
+    LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap());
+
+pub type Id = uuid::Uuid;
+
+#[derive(Debug, Clone, sqlx::FromRow)]
+pub struct User {
+    pub id: Id,
+    pub avatar_id: Option<file::Id>,
+    pub username: String,
+    pub display_name: Option<String>,
+    pub email: String,
+    pub password_hash: String,
+    pub bot: bool,
+    pub system: bool,
+    pub settings: serde_json::Value,
+}
+
+// file: src/jwt.rs
+#![allow(unused)]
+
+use std::collections::HashSet;
+
+use chrono::{Duration, Local, Utc};
+use serde::de::DeserializeOwned;
+use serde::{Deserialize, Serialize};
+
+use crate::config;
+
+#[derive(Debug, Serialize, Deserialize)]
+pub struct Claims<T> {
+    #[serde(flatten)]
+    pub data: T,
+    pub iat: i64,
+}
+
+pub fn generate_jwt<T: Serialize>(data: T, secret: &[u8]) -> Result<String> {
+    let claims = Claims {
+        data,
+        iat: Utc::now().timestamp_millis(),
+    };
+
+    let token = jsonwebtoken::encode(
+        &jsonwebtoken::Header::default(),
+        &claims,
+        &jsonwebtoken::EncodingKey::from_secret(secret),
+    )
+    .map_err(|_| Error::CouldNotEncodeToken)?;
+
+    Ok(token)
+}
+
+pub fn verify_jwt<T: DeserializeOwned>(token: &str, secret: &[u8]) -> Result<T> {
+    tracing::debug!("verifying token: {}", token);
+
+    let mut validation = jsonwebtoken::Validation::default();
+    validation.set_required_spec_claims::<&str>(&[]);
+
+    let token_data = jsonwebtoken::decode::<Claims<T>>(
+        token,
+        &jsonwebtoken::DecodingKey::from_secret(secret),
+        &validation,
+    )
+    .inspect_err(|err| {
+        tracing::error!("Failed to decode JWT: {:?}", err);
+    })
+    .map_err(|_| Error::CouldNotVerifyToken)?;
+
+    Ok(token_data.claims.data)
+}
+
+pub type Result<T> = std::result::Result<T, Error>;
+
+#[derive(Debug, derive_more::Error, derive_more::Display)]
+pub enum Error {
+    CouldNotEncodeToken,
+    CouldNotVerifyToken,
+}
+
+// file: src/log.rs
+use std::io;
+
+use tracing::level_filters::LevelFilter;
+use tracing_appender::non_blocking::WorkerGuard;
+use tracing_subscriber::EnvFilter;
+use tracing_subscriber::fmt::time::ChronoLocal;
+use tracing_subscriber::layer::SubscriberExt;
+use tracing_subscriber::util::SubscriberInitExt;
+
+pub fn init_logging() -> io::Result<(WorkerGuard, WorkerGuard)> {
+    let logger_timer = ChronoLocal::new("%FT%H:%M:%S%.6f%:::z".to_string());
+
+    let file = tracing_appender::rolling::daily("./logs", "log");
+
+    let (file, file_guard) = tracing_appender::non_blocking(file);
+
+    let file_logger = tracing_subscriber::fmt::layer()
+        .with_ansi(false)
+        .with_timer(logger_timer.clone())
+        .with_writer(file);
+
+    let (stdout, stdout_guard) = tracing_appender::non_blocking(io::stdout());
+
+    let stdout_logger = tracing_subscriber::fmt::layer()
+        .with_timer(logger_timer)
+        .with_writer(stdout);
+    #[cfg(debug_assertions)]
+    let stdout_logger = stdout_logger.pretty();
+
+    tracing_subscriber::registry()
+        .with(
+            EnvFilter::builder()
+                .with_default_directive(LevelFilter::INFO.into())
+                .from_env_lossy(),
+        )
+        .with(file_logger)
.with(stdout_logger) + .init(); + + tracing::info!("initialized logging"); + + Ok((file_guard, stdout_guard)) +} + +// file: src/main.rs +use std::sync::Arc; + +use argon2::Argon2; + +use crate::database::Database; +use crate::state::AppState; + +mod config; +mod database; +mod entity; +mod jwt; +mod log; +mod object_store; +mod state; +mod util; +mod web; +mod webrtc; + +#[tokio::main] +async fn main() -> anyhow::Result<()> { + let _guard = log::init_logging()?; + + let database = Database::connect(&config::config().database).await?; + let object_store = object_store::ObjectStore::connect(&config::config().object_store).await?; + let state = AppState { + database, + object_store, + hasher: Arc::new(Argon2::default()), + gateway_state: Default::default(), + voice_rooms: Default::default(), + }; + + web::run(state).await?; + + Ok(()) +} + +// file: src/object_store.rs +use crate::config::ObjectStoreConfig; + +#[derive(Clone, derive_more::AsRef, derive_more::Deref)] +pub struct ObjectStore { + inner: Box, +} + +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)] +pub enum Error { + #[from] + Credentials(s3::creds::error::CredentialsError), + + #[from] + S3(s3::error::S3Error), +} + +impl ObjectStore { + pub async fn connect(config: &ObjectStoreConfig) -> Result { + let region = s3::region::Region::Custom { + region: config.region.clone(), + endpoint: config.endpoint.origin().ascii_serialization(), + }; + + let credentials = s3::creds::Credentials::new( + Some(&config.access_key), + Some(&config.secret_key), + None, + None, + None, + )?; + + let mut bucket = + s3::bucket::Bucket::new(&config.bucket, region.clone(), credentials.clone())? + .with_path_style(); + + if !bucket.exists().await? { + bucket = s3::bucket::Bucket::create_with_path_style( + &config.bucket, + region, + credentials, + s3::BucketConfiguration::default(), + ) + .await? 
+ .bucket; + } + + Ok(Self { inner: bucket }) + } +} + +// file: src/state.rs +use std::collections::HashMap; +use std::sync::Arc; + +use argon2::Argon2; +use tokio::sync::{RwLock, mpsc}; +use uuid::Uuid; + +use crate::database::Database; +use crate::object_store::ObjectStore; +use crate::web::ws::gateway::{GatewayWsState, SessionKey, event}; +use crate::webrtc::WebRtcSignal; + +#[derive(Clone)] +pub struct AppState { + pub database: Database, + pub object_store: ObjectStore, + pub hasher: Arc>, + + pub gateway_state: Arc, + + pub voice_rooms: Arc>>>, +} + +#[derive(Debug, Default)] +pub struct GatewayState { + pub connected: scc::HashMap, +} + +impl AppState { + pub async fn register_gateway_connected_user( + &self, + user_id: Uuid, + session_key: SessionKey, + event_sender: mpsc::UnboundedSender, + ) { + self.gateway_state + .connected + .entry_async(user_id) + .await + .or_default() + .get_mut() + .instances + .insert(session_key, event_sender); + } + + pub async fn unregister_gateway_connected_user(&self, user_id: Uuid, session_id: &SessionKey) { + let is_empty = { + let mut entry = self + .gateway_state + .connected + .entry_async(user_id) + .await + .or_default(); + + let entry = entry.get_mut(); + + entry.instances.remove(session_id); + entry.instances.is_empty() + }; + + if is_empty { + self.gateway_state.connected.remove_async(&user_id).await; + } + } + + pub async fn register_voice_room( + &self, + room_id: Uuid, + sender: mpsc::UnboundedSender, + ) { + self.voice_rooms.write().await.insert(room_id, sender); + } + + pub async fn unregister_voice_room(&self, room_id: Uuid) { + self.voice_rooms.write().await.remove(&room_id); + } +} + +// file: src/util/mod.rs +pub mod tristate; + +use axum::extract::multipart::Field; +use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError}; +use serde::{Deserialize, Serialize}; + +use crate::entity; + +pub fn file_id_to_url(file_id: &entity::file::Id) -> Option { + Some( + crate::config::config() + .server + .hostname + .join("files/") + .ok()? + .join(&file_id.to_string()) + .ok()? 
+ .to_string(), + ) +} + +#[derive(Debug, derive_more::Deref)] +pub struct SerdeFieldData(pub FieldData); + +#[async_trait::async_trait] +impl TryFromField for SerdeFieldData { + async fn try_from_field( + field: Field<'_>, + limit_bytes: Option, + ) -> Result { + let field = FieldData::try_from_field(field, limit_bytes).await?; + + Ok(Self(field)) + } +} + +impl Serialize for SerdeFieldData { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + #[derive(Serialize)] + #[serde(rename_all = "camelCase")] + struct Metadata { + name: Option, + file_name: Option, + content_type: Option, + } + + let metadata = Metadata { + name: self.0.metadata.name.clone(), + file_name: self.0.metadata.file_name.clone(), + content_type: self.0.metadata.content_type.clone(), + }; + + metadata.serialize(serializer) + } +} + +pub fn serialize_duration_seconds( + duration: &std::time::Duration, + serializer: S, +) -> Result +where + S: serde::Serializer, +{ + let seconds = duration.as_secs(); + seconds.serialize(serializer) +} + +pub fn deserialize_duration_seconds<'de, D>( + deserializer: D, +) -> Result +where + D: serde::Deserializer<'de>, +{ + let seconds = u64::deserialize(deserializer)?; + Ok(std::time::Duration::from_secs(seconds)) +} + +pub fn serialize_duration_seconds_option( + duration: &Option, + serializer: S, +) -> Result +where + S: serde::Serializer, +{ + match duration { + Some(duration) => serialize_duration_seconds(duration, serializer), + None => serializer.serialize_none(), + } +} + +pub fn deserialize_duration_seconds_option<'de, D>( + deserializer: D, +) -> Result, D::Error> +where + D: serde::Deserializer<'de>, +{ + Ok(deserialize_duration_seconds(deserializer).ok()) +} + +// file: src/util/tristate.rs +use std::ops::Deref; + +use serde::{Deserialize, Deserializer, Serialize}; +use validator::ValidateLength; + +#[derive(Debug, Clone)] +pub enum TriState { + Value(T), + None, + Absent, +} + +impl TriState { + pub fn as_deref(&self) -> TriState<&D> + where + T: Deref, + { + match self { + TriState::Value(v) => TriState::Value(v.deref()), + TriState::None => TriState::None, + TriState::Absent => TriState::Absent, + } + } + + pub fn is_value(&self) -> bool { + matches!(self, TriState::Value(_)) + } + + pub fn is_none(&self) -> bool { + matches!(self, TriState::None) + } + + pub fn is_absent(&self) -> bool { + matches!(self, TriState::Absent) + } + + pub fn as_option(&self) -> Option<&T> { + match self { + TriState::Value(v) => Some(v), + TriState::None => None, + TriState::Absent => None, + } + } + + pub fn into_option(self) -> Option { + match self { + TriState::Value(v) => Some(v), + TriState::None => None, + TriState::Absent => None, + } + } +} + +impl Default for TriState { + fn default() -> Self { + TriState::Absent + } +} + +pub(crate) struct TriStateFieldVisitor { + marker: std::marker::PhantomData, +} + +impl Serialize for TriState +where + T: Serialize, +{ + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + match self { + TriState::Value(v) => v.serialize(serializer), + TriState::None => serializer.serialize_none(), + TriState::Absent => serializer.serialize_unit(), + } + } +} + +impl<'de, T> Deserialize<'de> for TriState +where + T: Deserialize<'de>, +{ + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + deserializer.deserialize_option(TriStateFieldVisitor:: { + marker: std::marker::PhantomData, + }) + } +} +impl<'de, T> serde::de::Visitor<'de> for TriStateFieldVisitor +where + T: 
Deserialize<'de>, +{ + type Value = TriState; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("TriStateField") + } + + #[inline] + fn visit_none(self) -> Result, E> + where + E: serde::de::Error, + { + Ok(TriState::None) + } + + #[inline] + fn visit_some(self, deserializer: D) -> Result + where + D: Deserializer<'de>, + { + T::deserialize(deserializer).map(TriState::Value) + } + + #[inline] + fn visit_unit(self) -> Result, E> + where + E: serde::de::Error, + { + Ok(TriState::None) + } +} + +impl ValidateLength for TriState +where + T: ValidateLength, +{ + fn length(&self) -> Option { + match self { + TriState::Value(v) => v.length(), + TriState::None => None, + TriState::Absent => None, + } + } +} + +// file: src/web/context.rs +use axum::extract::FromRequestParts; +use axum::http::request::Parts; + +use crate::entity; + +#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)] +pub struct UserContext { + pub user_id: entity::user::Id, +} + +pub type UserContextResult = Result; + +#[derive(Debug, Clone, Copy, derive_more::Display)] +pub enum Error { + NotInRequest, + NotInHeader, + BadCharacters, + WrongTokenType, + BadToken, + Model, +} + +impl FromRequestParts for UserContext { + type Rejection = super::error::Error; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { + let context = parts + .extensions + .get::() + .cloned() + .ok_or(Error::NotInRequest)??; + + Ok(context) + } +} + +// file: src/web/entity/file.rs +use serde::Serialize; + +use crate::entity::file::Id; +use crate::util; + +#[derive(Debug, Clone, sqlx::FromRow, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct File { + pub id: Id, + pub filename: String, + pub content_type: String, + pub size: i64, + pub url: String, +} + +impl From for File { + fn from(file: crate::entity::file::File) -> Self { + Self { + id: file.id, + filename: file.filename, + content_type: file.content_type, + size: file.size, + url: util::file_id_to_url(&file.id).unwrap_or_default(), + } + } +} + +// file: src/web/entity/message.rs +use serde::Serialize; + +use crate::entity::message::Id; +use crate::entity::{channel, user}; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Message { + pub id: Id, + pub author_id: user::Id, + pub channel_id: channel::Id, + pub content: String, + pub created_at: chrono::DateTime, + pub attachments: Vec, +} + +impl Message { + pub fn from_message_with_attachments( + message: crate::entity::message::Message, + attachments: Vec, + ) -> Self { + Self { + id: message.id, + author_id: message.author_id, + channel_id: message.channel_id, + content: message.content, + created_at: message + .id + .get_timestamp() + .as_ref() + .map(uuid::Timestamp::to_unix) + .map(|(secs, nsecs)| { + chrono::DateTime::::from_timestamp(secs as i64, nsecs) + }) + .flatten() + .unwrap_or_default(), + attachments, + } + } +} +// file: src/web/entity/mod.rs +pub mod file; +pub mod message; +pub mod server; +pub mod user; + +// file: src/web/entity/server.rs +use serde::Serialize; + +use crate::entity::server::Id; +use crate::entity::user; +use crate::util; + +#[derive(Debug, Clone, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Server { + pub id: Id, + pub owner_id: user::Id, + pub name: String, + pub icon_url: Option, +} + +impl From for Server { + fn from(server: crate::entity::server::Server) -> Self { + Self { + id: server.id, + owner_id: server.owner_id, + name: server.name, + icon_url: 
server.icon_id.as_ref().map(util::file_id_to_url).flatten(), + } + } +} + +// file: src/web/entity/user.rs +use crate::entity::user; +use crate::util; + +#[derive(serde::Serialize, Debug)] +#[serde(rename_all = "camelCase")] +pub struct FullUser { + pub id: user::Id, + pub avatar_url: Option, + pub username: String, + pub display_name: Option, + pub email: String, + pub bot: bool, + pub system: bool, + pub settings: serde_json::Value, +} + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(rename_all = "camelCase")] +pub struct PartialUser { + pub id: user::Id, + pub avatar_url: Option, + pub username: String, + pub display_name: Option, + pub bot: bool, + pub system: bool, +} + +impl From for FullUser { + fn from(user: user::User) -> Self { + Self { + id: user.id, + avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(), + username: user.username, + display_name: user.display_name, + email: user.email, + bot: user.bot, + system: user.system, + settings: user.settings, + } + } +} + +impl From for PartialUser { + fn from(user: user::User) -> Self { + Self { + id: user.id, + avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(), + username: user.username, + display_name: user.display_name, + bot: user.bot, + system: user.system, + } + } +} + +// file: src/web/error.rs +use std::sync::Arc; + +use axum::http::StatusCode; +use axum::response::IntoResponse; + +use crate::web::context; +use crate::{database, jwt, object_store}; + +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::From)] +pub enum Error { + #[from] + Context(context::Error), + + #[from] + Jwt(jwt::Error), + + #[from] + Hash(argon2::password_hash::Error), + + #[from] + Database(database::Error), + + #[from] + ObjectStore(object_store::Error), + + #[from] + Json(serde_json::error::Error), + + #[from] + JsonRejection(axum::extract::rejection::JsonRejection), + + #[from] + Client(ClientError), +} + +#[derive(Debug, Clone, derive_more::From, serde::Serialize)] +#[serde(tag = "message", content = "details")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum ClientError { + UserAlreadyExists, + UserDoesNotExist, + + ServerDoesNotExist, + + ChannelDoesNotExist, + + MessageDoesNotExist, + + NotAuthorized, + WrongPassword, + NotAllowed, + + InvalidJson(JsonRejection), + + #[from] + ValidationFailed(validator::ValidationErrors), + + InternalServerError, + + Unknown, +} + +#[derive(derive_more::Debug, Clone, serde::Serialize)] +pub struct JsonRejection { + #[serde(skip)] + status: StatusCode, + + reason: String, +} + +impl From<&axum::extract::rejection::JsonRejection> for JsonRejection { + fn from(value: &axum::extract::rejection::JsonRejection) -> Self { + use std::error::Error; + + use axum::extract::rejection::JsonRejection::*; + + let reason = match value { + JsonDataError(e) => e.source().map(ToString::to_string), + JsonSyntaxError(e) => e.source().map(ToString::to_string), + MissingJsonContentType(e) => e.source().map(ToString::to_string), + BytesRejection(e) => e.source().map(ToString::to_string), + _ => None, + } + .unwrap_or_else(|| value.body_text()); + + Self { + status: value.status(), + reason, + } + } +} + +impl Error { + pub fn as_client_error(&self) -> ClientError { + match self { + Error::Context(_) | Error::Jwt(_) => ClientError::NotAuthorized, + Error::Database(database::Error::UserAlreadyExists) => ClientError::UserAlreadyExists, + Error::Database(database::Error::UserDoesNotExists) => ClientError::UserDoesNotExist, + + 
Error::Database(database::Error::ServerDoesNotExists) => {
+                ClientError::ServerDoesNotExist
+            },
+
+            Error::Database(database::Error::ChannelDoesNotExists) => {
+                ClientError::ChannelDoesNotExist
+            },
+
+            Error::Database(database::Error::MessageDoesNotExists) => {
+                ClientError::MessageDoesNotExist
+            },
+            // Error::WrongPassword => ClientError::WrongPassword,
+            // Error::NotAllowed => ClientError::NotAllowed,
+            Error::JsonRejection(e) => ClientError::InvalidJson(e.into()),
+            Error::Client(client_error) => client_error.clone(),
+            _ => ClientError::InternalServerError,
+        }
+    }
+}
+
+impl IntoResponse for Error {
+    fn into_response(self) -> axum::response::Response {
+        let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
+
+        response.extensions_mut().insert(Arc::new(self));
+
+        response
+    }
+}
+
+impl ClientError {
+    pub fn status_code(&self) -> StatusCode {
+        match self {
+            ClientError::UserAlreadyExists => StatusCode::CONFLICT,
+            ClientError::UserDoesNotExist
+            | ClientError::ServerDoesNotExist
+            | ClientError::ChannelDoesNotExist
+            | ClientError::MessageDoesNotExist => StatusCode::NOT_FOUND,
+
+            ClientError::NotAuthorized | ClientError::WrongPassword => StatusCode::UNAUTHORIZED,
+
+            ClientError::NotAllowed => StatusCode::FORBIDDEN,
+
+            ClientError::ValidationFailed(_) => StatusCode::UNPROCESSABLE_ENTITY,
+
+            ClientError::InvalidJson(e) => e.status,
+
+            _ => StatusCode::INTERNAL_SERVER_ERROR,
+        }
+    }
+}
+
+// file: src/web/middleware/auth.rs
+use axum::extract::{Request, State};
+use axum::middleware::Next;
+use axum::response::{IntoResponse, Response};
+use axum::{Extension, RequestExt};
+use axum_extra::TypedHeader;
+use axum_extra::headers::Authorization;
+use axum_extra::headers::authorization::Bearer;
+use axum_extra::typed_header::TypedHeaderRejectionReason;
+
+use crate::jwt;
+use crate::state::AppState;
+use crate::web::context::UserContext;
+use crate::web::{self, context};
+
+pub async fn require_context(
+    Extension(context): Extension<context::UserContextResult>,
+    request: Request,
+    next: Next,
+) -> web::error::Result<Response> {
+    context?;
+
+    Ok(next.run(request).await)
+}
+
+pub async fn resolve_context(
+    State(state): State<AppState>,
+    mut request: Request,
+    next: Next,
+) -> impl IntoResponse {
+    let context = get_context(&state, &mut request).await;
+
+    request.extensions_mut().insert(context);
+
+    next.run(request).await
+}
+
+async fn get_context(state: &AppState, request: &mut Request) -> context::UserContextResult {
+    let bearer = request
+        .extract_parts::<TypedHeader<Authorization<Bearer>>>()
+        .await
+        .map_err(|error| match error.reason() {
+            TypedHeaderRejectionReason::Missing => context::Error::NotInHeader,
+            TypedHeaderRejectionReason::Error(_) => context::Error::WrongTokenType,
+            _ => context::Error::NotInHeader,
+        })?;
+
+    let token = bearer.token();
+
+    let context = get_context_from_token(state, token).await?;
+
+    Ok(context)
+}
+
+pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult {
+    let context = jwt::verify_jwt::<UserContext>(
+        token,
+        crate::config::config().security.auth_secret.as_ref(),
+    )
+    .map_err(|_| context::Error::BadToken)?;
+
+    let _ = state
+        .database
+        .select_user_by_id(context.user_id)
+        .await
+        .map_err(|_| context::Error::BadToken)?;
+
+    Ok(context)
+}
+
+// file: src/web/middleware/mod.rs
+mod auth;
+mod response_map;
+
+pub use auth::*;
+pub use response_map::*;
+
+// file: src/web/middleware/response_map.rs
+use std::sync::Arc;
+
+use axum::Json;
+use axum::extract::Request;
+use axum::middleware::Next;
+use axum::response::IntoResponse;
+use
serde_json::json; + +use crate::web; + +pub async fn response_map(request: Request, next: Next) -> impl IntoResponse { + let response = next.run(request).await; + + let error = response.extensions().get::>(); + + if error.is_some() { + tracing::error!("{:?}", error); + } + + let error_response = error.map(|error| { + let client_error = error.as_client_error(); + + ( + client_error.status_code(), + Json(json!({"error": client_error})), + ) + .into_response() + }); + + error_response.unwrap_or(response) +} + +// file: src/web/mod.rs +mod context; +mod entity; +mod error; +mod middleware; +mod route; +pub mod ws; + +use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin}; + +pub use self::error::{Error, Result}; +use crate::{config, state}; + +pub async fn run(state: state::AppState) -> anyhow::Result<()> { + let config = config::config(); + + let addr: std::net::SocketAddr = (config.server.host, config.server.port).into(); + let listener = tokio::net::TcpListener::bind(addr).await?; + tracing::info!("listening on {}", addr); + + axum::serve(listener, router(state)) + .with_graceful_shutdown(shutdown_signal()) + .await?; + + Ok(()) +} + +fn router(state: state::AppState) -> axum::Router { + use axum::Router; + use axum::routing::*; + + use self::route::*; + + let cors = tower_http::cors::CorsLayer::new() + .allow_origin(AllowOrigin::any()) + .allow_methods(AllowMethods::any()) + .allow_headers(AllowHeaders::any()); + + Router::new() + // websocket + .route("/gateway/ws", get(ws::gateway::ws_handler)) + .route("/voice/ws", get(ws::voice::ws_handler)) + // file + .route("/files/{file_id}", get(file::get)) + // api + .nest( + "/api/v1", + Router::new() + .route("/auth/login", post(auth::login)) + .route("/auth/register", post(auth::register)) + .merge(protected_router()) + .route_layer(axum::middleware::from_fn_with_state( + state.clone(), + middleware::resolve_context, + )), + ) + // middleware + .layer(axum::middleware::from_fn(middleware::response_map)) + .layer(cors) + .layer(tower_http::trace::TraceLayer::new_for_http()) + .with_state(state.clone()) +} + +fn protected_router() -> axum::Router { + use axum::Router; + use axum::routing::*; + + use self::route::*; + + Router::new() + // user + .route("/users/@me", get(user::me)) + .route("/users/@me", patch(user::patch)) + .route("/users/@me/channels", get(user::channel::list)) + .route("/users/@me/channels", post(user::channel::create)) + .route("/users/{id}", get(user::get_by_id)) + // channel + .route( + "/channels/{channel_id}/messages", + get(channel::message::page), + ) + .route( + "/channels/{channel_id}/messages", + post(channel::message::create), + ) + // server + .route("/servers", get(server::list)) + .route("/servers", post(server::create)) + .route("/servers/{server_id}", get(server::get)) + .route("/servers/{server_id}", delete(server::delete)) + .route("/servers/{server_id}/channels", get(server::channel::list)) + .route( + "/servers/{server_id}/channels", + post(server::channel::create), + ) + .route( + "/servers/{server_id}/channels/{channel_id}", + get(server::channel::get), + ) + .route( + "/servers/{server_id}/channels/{channel_id}", + delete(server::channel::delete), + ) + // invite + .route("/servers/{server_id}/invites", post(server::invite::create)) + .route("/invites/{code}", get(server::invite::get)) + // file + .route("/files", post(file::upload)) + // middleware + .route_layer(axum::middleware::from_fn(middleware::require_context)) +} + +async fn shutdown_signal() { + _ = tokio::signal::ctrl_c().await; +} + 
+// file: src/web/route/auth/login.rs +use argon2::{PasswordHash, PasswordVerifier}; +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; +use serde::{Deserialize, Serialize}; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::{jwt, web}; +use crate::web::entity::user::FullUser; + +#[derive(Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct LoginPayload { + username: String, + password: String, +} + +#[derive(Serialize)] +#[serde(rename_all = "camelCase")] +pub struct LoginResponse { + user: FullUser, + token: String, +} + +pub async fn login( + State(state): State, + Json(payload): Json, +) -> web::Result { + let user = state + .database + .select_user_by_username(&payload.username) + .await?; + + let password_hash = PasswordHash::new(&user.password_hash)?; + + state + .hasher + .verify_password(payload.password.as_bytes(), &password_hash) + .map_err(|_| web::error::ClientError::WrongPassword)?; + + let token = jwt::generate_jwt(UserContext { user_id: user.id }, crate::config::config().security.auth_secret.as_ref())?; + + let response = LoginResponse { + user: user.into(), + token, + }; + + Ok(Json(response)) +} + +// file: src/web/route/auth/mod.rs +pub mod login; +pub mod register; + +pub use login::*; +pub use register::*; + +// file: src/web/route/auth/register.rs +use argon2::password_hash::rand_core::OsRng; +use argon2::password_hash::{PasswordHasher, SaltString}; +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; +use axum_extra::extract::WithRejection; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web; +use crate::web::entity::user::FullUser; +use crate::web::error::ClientError; + +#[derive(Validate, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct RegisterPayload { + #[validate(regex( + path = "crate::entity::user::USERNAME_REGEX", + code = "invalid_username" + ))] + #[validate(length(min = 3, max = 64))] + username: String, + + #[validate(length(min = 1, max = 64))] + display_name: Option, + + #[validate(email)] + #[validate(length(min = 3, max = 128))] + email: String, + + #[validate(length(min = 3, max = 128))] + password: String, +} + +pub async fn register( + State(state): State, + WithRejection(Json(payload), _): WithRejection, web::Error>, +) -> web::Result { + payload.validate().map_err(ClientError::ValidationFailed)?; + + let salt = SaltString::generate(&mut OsRng); + + let password_hash = state + .hasher + .hash_password(payload.password.as_bytes(), &salt)? 
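// hash_password() yields a PasswordHash; to_string() renders it in the PHC string format
// (e.g. "$argon2id$v=19$..." depending on the configured hasher), which is what gets stored
// and later re-parsed with PasswordHash::new() during login.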
+ .to_string(); + + let user = state + .database + .insert_user( + &payload.username, + payload.display_name.as_ref().map(String::as_str), + &payload.email, + &password_hash, + ) + .await?; + + let system_user = state.database.select_user_by_username("system").await?; + if system_user.system { + state + .database + .procedure_create_dm_channel(user.id, system_user.id) + .await?; + } + + Ok(Json(FullUser::from(user))) +} + +// file: src/web/route/channel/message/create.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::entity::file::File; +use crate::web::entity::message::Message; +use crate::web::ws; +use crate::{entity, web}; + +#[derive(Debug, serde::Deserialize, Validate)] +#[serde(rename_all = "camelCase")] +#[validate(schema(function = "validate_create_payload"))] +pub struct CreatePayload { + #[validate(length(min = 0, max = 2000))] + pub content: String, + + #[serde(default)] + #[validate(length(min = 0, max = 10))] + pub attachments: Vec, +} + +fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> { + if payload.content.is_empty() && payload.attachments.is_empty() { + return Err(validator::ValidationError::new( + "at_least_one_field_required", + )); + } + + Ok(()) +} + +pub async fn create( + State(state): State, + context: UserContext, + Path(channel_id): Path, + Json(payload): Json, +) -> web::Result { + // TODO: check permissions + match payload.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let message = state + .database + .insert_channel_message(context.user_id, channel_id, &payload.content) + .await?; + + for attachment in payload.attachments.iter() { + state + .database + .insert_message_attachment(message.id, *attachment) + .await?; + } + + let attachments = state + .database + .select_message_attachments(message.id) + .await? 
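// The attachments are read back from the database so the response carries full file
// metadata for each attachment rather than just the IDs supplied in the payload.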
+ .into_iter() + .map(File::from) + .collect::>(); + + let message = Message::from_message_with_attachments(message, attachments); + + // TODO: check permissions + ws::gateway::util::send_message_channel( + state, + message.channel_id, + ws::gateway::event::Event::AddMessage { + channel_id, + message: message.clone(), + }, + ); + + Ok(Json(message)) +} + +// file: src/web/route/channel/message/mod.rs +mod page; +mod create; + +pub use page::page; +pub use create::create; +// file: src/web/route/channel/message/page.rs +use axum::Json; +use axum::extract::{Path, Query, State}; +use axum::response::IntoResponse; +use serde::Deserialize; +use tracing::error; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::entity::message::Message; +use crate::{entity, web}; + +#[derive(Debug, Deserialize, Validate)] +#[serde(rename_all = "camelCase")] +pub struct PageParams { + #[serde(default = "limit_default")] + #[validate(range(min = 1, max = 100))] + pub limit: u32, + #[serde(default)] + pub before: Option, +} + +fn limit_default() -> u32 { + 50 +} + +pub async fn page( + State(state): State, + context: UserContext, + Path(channel_id): Path, + Query(params): Query, +) -> web::Result { + // TODO: check permissions + match params.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let messages_futures = state + .database + .select_channel_messages_paginated(channel_id, params.before, params.limit as i64) + .await? + .into_iter() + .map(|message| async { + let attachments = state + .database + .select_message_attachments(message.id) + .await? + .into_iter() + .map(web::entity::file::File::from) + .collect::>(); + + Ok::<_, web::Error>(Message::from_message_with_attachments(message, attachments)) + }); + + let messages = futures::future::try_join_all(messages_futures).await?; + + Ok(Json(messages)) +} + +// file: src/web/route/channel/mod.rs +pub mod message; + +// file: src/web/route/file/get.rs +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::{entity, object_store, web}; + +pub async fn get( + State(state): State, + Path(file_id): Path, +) -> web::Result { + let file = match state.database.select_file_by_id(file_id).await { + Ok(file) => file, + Err(e) => { + return Ok(axum::http::StatusCode::NOT_FOUND.into_response()); + }, + }; + + let data = match state + .object_store + .get_object_stream(&file.id.to_string()) + .await + { + Ok(data) => data, + Err(s3::error::S3Error::HttpFailWithBody(403 | 404, _)) => { + let _ = state.database.delete_file_by_id(file.id).await?; + + return Ok(axum::http::StatusCode::NOT_FOUND.into_response()); + }, + Err(e) => { + return Err(object_store::Error::from(e).into()); + }, + }; + + let headers = axum::response::AppendHeaders([ + (axum::http::header::CONTENT_TYPE, file.content_type.clone()), + ( + axum::http::header::CONTENT_DISPOSITION, + format!("filename=\"{}\"", file.filename), + ), + ]); + + Ok((headers, axum::body::Body::from_stream(data.bytes)).into_response()) +} + +// file: src/web/route/file/mod.rs +mod get; +mod upload; + +pub use get::get; +pub use upload::upload; + +// file: src/web/route/file/upload.rs +use axum::Json; +use axum::body::Bytes; +use axum::extract::State; +use axum::response::IntoResponse; +use axum_typed_multipart::{TryFromMultipart, TypedMultipart}; +use validator::Validate; + +use crate::state::AppState; +use crate::util::SerdeFieldData; +use 
crate::{object_store, web}; + +#[derive(Debug, Validate, TryFromMultipart)] +#[try_from_multipart(rename_all = "camelCase")] +pub struct UploadPayload { + #[form_data(limit = "50MB")] + #[validate(length(min = 1, max = 16))] + files: Vec>, +} + +pub async fn upload( + State(state): State, + TypedMultipart(payload): TypedMultipart, +) -> web::Result { + match payload.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let mut file_ids = Vec::new(); + + for file in payload.files { + let db_file = state + .database + .insert_file( + file.metadata + .file_name + .as_deref() + .unwrap_or_else(|| "unknown"), + file.metadata.content_type.as_deref().unwrap_or_default(), + file.contents.len(), + ) + .await?; + + state + .object_store + .put_object(&db_file.id.to_string(), &file.contents) + .await + .map_err(object_store::Error::from)?; + + file_ids.push(db_file.id); + } + + Ok(Json(file_ids)) +} + +// file: src/web/route/mod.rs +pub mod auth; +pub mod channel; +pub mod file; +pub mod server; +pub mod user; + +// file: src/web/route/server/channel/create.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; +use axum_extra::extract::WithRejection; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::error::ClientError; +use crate::web::ws; +use crate::{entity, web}; + +#[derive(Debug, Validate, Deserialize)] +pub struct CreatePayload { + #[validate(length(min = 1, max = 32))] + name: String, + + #[validate(custom(function = "validate_server_channel_type"))] + r#type: entity::channel::ChannelType, +} + +fn validate_server_channel_type( + r#type: &entity::channel::ChannelType, +) -> Result<(), validator::ValidationError> { + match r#type { + entity::channel::ChannelType::ServerText => Ok(()), + entity::channel::ChannelType::ServerVoice => Ok(()), + entity::channel::ChannelType::ServerCategory => Ok(()), + _ => Err(validator::ValidationError::new("invalid_channel_type")), + } +} + +pub async fn create( + State(state): State, + context: UserContext, + Path(server_id): Path, + WithRejection(Json(payload), _): WithRejection, web::Error>, +) -> web::Result { + payload.validate().map_err(ClientError::ValidationFailed)?; + + // TODO: check permissions + let server = state.database.select_server_by_id(server_id).await?; + + if server.owner_id != context.user_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + + let channel = state + .database + .insert_server_channel(&payload.name, 0, payload.r#type, server_id, None) + .await?; + + ws::gateway::util::send_message_server( + state, + server_id, + ws::gateway::event::Event::AddServerChannel { + channel: channel.clone(), + }, + ); + + Ok(Json(channel)) +} + +// file: src/web/route/server/channel/delete.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::ws; +use crate::webrtc::WebRtcSignal; +use crate::{entity, web}; + +pub async fn delete( + State(state): State, + context: UserContext, + Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>, +) -> web::Result { + // TODO: check permissions + + let channel = state.database.select_channel_by_id(channel_id).await?; + let server = state.database.select_server_by_id(server_id).await?; + + if server.owner_id != context.user_id { + return 
Err(web::error::ClientError::NotAllowed.into()); + } + + if let Some(channel_server_id) = channel.server_id { + if channel_server_id != server_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + } else { + return Err(web::error::ClientError::NotAllowed.into()); + } + + let channel = state.database.delete_channel_by_id(channel_id).await?; + + ws::gateway::util::send_message_server( + state.clone(), + server_id, + ws::gateway::event::Event::RemoveServerChannel { + server_id: server_id.clone(), + channel_id: channel.id.clone(), + }, + ); + + tokio::spawn(async move { + let voice_rooms = state.voice_rooms.read().await; + if let Some(voice_room) = voice_rooms.get(&channel_id) { + let _ = voice_room.send(WebRtcSignal::Close); + } + }); + + Ok(Json(channel)) +} + +// file: src/web/route/server/channel/get.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::{entity, web}; + +pub async fn get( + State(state): State, + context: UserContext, + Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>, +) -> web::Result { + // TODO: check permissions + + let channel = state.database.select_channel_by_id(channel_id).await?; + + if let Some(channel_server_id) = channel.server_id { + if channel_server_id != server_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + } else { + return Err(web::error::ClientError::NotAllowed.into()); + } + + Ok(Json(channel)) +} + +// file: src/web/route/server/channel/list.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::{entity, web}; + +pub async fn list( + State(state): State, + context: UserContext, + Path(server_id): Path, +) -> web::Result { + let channels = state.database.select_server_channels(server_id).await?; + + Ok(Json(channels)) +} + +// file: src/web/route/server/channel/mod.rs +mod create; +mod delete; +mod get; +mod list; + +pub use create::create; +pub use delete::delete; +pub use get::get; +pub use list::list; + +// file: src/web/route/server/create.rs +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; +use axum_extra::extract::WithRejection; +use axum_typed_multipart::TryFromMultipart; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::error::ClientError; +use crate::web::ws; +use crate::{entity, web}; +use crate::web::entity::server::Server; + +#[derive(Debug, Validate, Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreatePayload { + #[validate(length(min = 1, max = 32))] + name: String, + + icon_id: Option, +} + +pub async fn create( + State(state): State, + context: UserContext, + WithRejection(Json(payload), _): WithRejection, web::Error>, +) -> web::Result { + payload.validate().map_err(ClientError::ValidationFailed)?; + + let server = state + .database + .insert_server(&payload.name, payload.icon_id, context.user_id) + .await?; + + let role = state + .database + .insert_server_role( + Some(server.id.into()), + server.id, + "@everyone", + None, + false, + serde_json::json!({}), + 0, + ) + .await?; + + let member = state + .database + .insert_server_member(server.id, context.user_id) + .await?; + + state + .database + .insert_server_member_role(member.id, role.id) + .await?; + + let server = Server::from(server); + + 
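// Push the freshly created server to the creator's connected gateway sessions so clients
// can render it without an extra fetch.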
ws::gateway::util::send_message( + state, + context.user_id, + ws::gateway::event::Event::AddServer { + server: server.clone(), + }, + ); + + Ok(Json(server)) +} + +// file: src/web/route/server/delete.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::ws; +use crate::webrtc::WebRtcSignal; +use crate::{entity, web}; +use crate::web::entity::server::Server; + +pub async fn delete( + State(state): State, + context: UserContext, + Path(server_id): Path, +) -> web::Result { + let server = state.database.select_server_by_id(server_id).await?; + + if server.owner_id != context.user_id { + return Err(web::error::ClientError::NotAllowed.into()); + } + + let members = state + .database + .select_server_members(server_id) + .await? + .iter() + .map(|u| u.id) + .collect::>(); + + let channels = state + .database + .select_server_channels(server_id) + .await? + .iter() + .map(|c| c.id) + .collect::>(); + + let state_clone = state.clone(); + tokio::spawn(async move { + let voice_rooms = state_clone.voice_rooms.read().await; + for channel_id in channels { + if let Some(voice_room) = voice_rooms.get(&channel_id) { + let _ = voice_room.send(WebRtcSignal::Close); + } + } + }); + + let server = state.database.delete_server_by_id(server_id).await?; + + ws::gateway::util::send_message_many( + state.clone(), + &members, + ws::gateway::event::Event::RemoveServer { + server_id: server.id, + }, + ); + + Ok(Json(Server::from(server))) +} + +// file: src/web/route/server/get.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::{entity, web}; +use crate::web::entity::server::Server; + +pub async fn get( + State(state): State, + Path(server_id): Path, +) -> web::Result { + // TODO: check permissions + + let server = state.database.select_server_by_id(server_id).await?; + + Ok(Json(Server::from(server))) +} + +// file: src/web/route/server/invite/create.rs +use axum::Json; +use axum::extract::{Path, Query, State}; +use axum::response::IntoResponse; +use base64::Engine; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::{entity, web}; + +#[derive(serde::Deserialize, Debug)] +pub struct CreateParams { + #[serde(deserialize_with = "crate::util::deserialize_duration_seconds_option")] + #[serde(default)] + pub expires_in: Option, +} + +pub async fn create( + State(state): State, + context: UserContext, + Path(server_id): Path, + Query(params): Query, +) -> web::Result { + // TODO: check permissions + + let code = { + use rand::Rng; + + let mut rng = rand::rng(); + let mut code = [0u8; 32]; + rng.fill(&mut code); + base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(&code) + }; + + let expires_at = params.expires_in.map(|d| { + let now = chrono::Utc::now(); + now + chrono::Duration::from_std(d).expect("valid duration") + }); + + let invite = state + .database + .insert_server_invite(&code, server_id, Some(context.user_id), expires_at) + .await?; + + Ok(Json(invite)) +} + +// file: src/web/route/server/invite/get.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::ws; +use crate::{database, web}; +use crate::web::entity::server::Server; + +pub async fn get( + State(state): State, + context: UserContext, + Path(code): Path, +) -> web::Result { + let invite = 
state.database.select_server_invite_by_code(&code).await?; + let server = state.database.select_server_by_id(invite.server_id).await?; + + let member = match state + .database + .insert_server_member(invite.server_id, context.user_id) + .await + { + Ok(member) => member, + Err(database::Error::MemberAlreadyExists) => return Ok(Json(Server::from(server))), + Err(e) => return Err(e.into()), + }; + + state + .database + .insert_server_member_role(member.id, invite.server_id) + .await?; + + let user = state.database.select_user_by_id(context.user_id).await?; + + ws::gateway::util::send_message_server( + state, + invite.server_id, + ws::gateway::event::Event::AddServerMember { + server_id: server.id, + member: user.into(), + }, + ); + + Ok(Json(Server::from(server))) +} + +// file: src/web/route/server/invite/mod.rs +mod create; +mod get; + +pub use create::create; +pub use get::get; + +// file: src/web/route/server/list.rs +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web; +use crate::web::context::UserContext; +use crate::web::entity::server::Server; + +pub async fn list( + State(state): State, + context: UserContext, +) -> web::Result { + let servers = state + .database + .select_user_servers(context.user_id) + .await? + .into_iter() + .map(Server::from) + .collect::>(); + + Ok(Json(servers)) +} + +// file: src/web/route/server/mod.rs +pub mod channel; +mod create; +mod delete; +mod get; +pub mod invite; +mod list; + +pub use create::create; +pub use delete::delete; +pub use get::get; +pub use list::list; + +// file: src/web/route/user/channel/create.rs +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; +use axum_extra::extract::WithRejection; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::entity::user::PartialUser; +use crate::web::route::user::channel::RecipientChannel; +use crate::web::ws; +use crate::{entity, web}; + +#[derive(Debug, Validate, Deserialize)] +pub struct CreatePayload { + #[validate(length(min = 1, max = 32))] + recipients: Vec, +} + +pub async fn create( + State(state): State, + context: UserContext, + WithRejection(Json(payload), _): WithRejection, web::Error>, +) -> web::Result { + match payload.validate() { + Ok(_) => {}, + Err(err) => { + return Err(web::error::ClientError::ValidationFailed(err).into()); + }, + } + + let channel_id = match payload.recipients.len() { + 1 => { + let recipient = payload.recipients[0]; + state + .database + .procedure_create_dm_channel(context.user_id, recipient) + .await? + }, + _ => { + state + .database + .procedure_create_group_channel(context.user_id, &payload.recipients) + .await? + }, + }; + + let channel = state.database.select_channel_by_id(channel_id).await?; + + let recipients = state + .database + .select_channel_recipients(channel_id) + .await? 
+ .unwrap_or_default() + .into_iter() + .map(|user| user.id) + .collect::>(); + + let recipient_channels = RecipientChannel { + channel: channel.clone(), + recipients: recipients.clone(), + }; + + ws::gateway::util::send_message_channel( + state, + channel_id, + ws::gateway::event::Event::AddDmChannel { + channel, + recipients, + }, + ); + + Ok(Json(recipient_channels)) +} + +// file: src/web/route/user/channel/list.rs +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web; +use crate::web::context::UserContext; +use crate::web::route::user::channel::RecipientChannel; + +pub async fn list( + State(state): State, + context: UserContext, +) -> web::Result { + let channels = state.database.select_user_channels(context.user_id).await?; + + let mut recipient_channels = Vec::new(); + for channel in channels { + let recipients = state.database.select_channel_recipients(channel.id).await?; + + let recipients = match recipients { + Some(recipients) => recipients + .into_iter() + .map(|user| user.id) + .collect(), + None => { + continue; + }, + }; + + recipient_channels.push(RecipientChannel { + channel, + recipients, + }); + } + + Ok(Json(recipient_channels)) +} + +// file: src/web/route/user/channel/mod.rs +mod create; +mod list; + +pub use create::create; +pub use list::list; +use serde::Serialize; + +use crate::entity::{channel, user}; + +#[derive(Debug, Serialize)] +#[serde(rename_all = "camelCase")] +struct RecipientChannel { + #[serde(flatten)] + pub channel: channel::Channel, + pub recipients: Vec, +} + +// file: src/web/route/user/get.rs +use axum::Json; +use axum::extract::{Path, State}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web; +use crate::web::entity::user::PartialUser; + +pub async fn get_by_id( + State(state): State, + Path(user_id): Path, +) -> web::Result { + let user = state.database.select_user_by_id(user_id).await?; + + Ok(Json(PartialUser::from(user))) +} + +// file: src/web/route/user/me.rs +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web; +use crate::web::context::UserContext; +use crate::web::entity::user::FullUser; + +pub async fn me( + State(state): State, + context: UserContext, +) -> web::Result { + let user = state.database.select_user_by_id(context.user_id).await?; + + Ok(Json(FullUser::from(user))) +} + +// file: src/web/route/user/mod.rs +pub mod channel; +mod get; +mod me; +mod patch; + +pub use get::get_by_id; +pub use me::me; +pub use patch::patch; + + +// file: src/web/route/user/patch.rs +use axum::Json; +use axum::extract::State; +use axum::response::IntoResponse; +use axum_extra::extract::WithRejection; +use serde::Deserialize; +use validator::Validate; + +use crate::state::AppState; +use crate::web::context::UserContext; +use crate::web::entity::user::{FullUser, PartialUser}; +use crate::web::ws; +use crate::{entity, util, web}; + +#[derive(Debug, Validate, Deserialize)] +#[serde(rename_all = "camelCase")] +#[validate(schema(function = "validate_create_payload"))] +pub struct CreatePayload { + #[validate(length(min = 1, max = 32))] + #[serde(default)] + display_name: util::tristate::TriState, + + #[serde(default)] + avatar_id: util::tristate::TriState, +} + +fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> { + if payload.display_name.is_absent() && payload.avatar_id.is_absent() { + return Err(validator::ValidationError::new( + 
"at_least_one_field_required", + )); + } + + Ok(()) +} + +pub async fn patch( + State(state): State, + context: UserContext, + WithRejection(Json(payload), _): WithRejection, web::Error>, +) -> web::Result { + match payload.validate() { + Ok(_) => {}, + Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), + }; + + let user = state + .database + .update_user_by_id( + context.user_id, + payload.display_name.as_deref(), + payload.avatar_id, + ) + .await?; + + ws::gateway::util::send_message( + state.clone(), + context.user_id, + ws::gateway::event::Event::AddUser { + user: PartialUser::from(user.clone()), + }, + ); + + ws::gateway::util::send_message_related( + state.clone(), + context.user_id, + ws::gateway::event::Event::AddUser { + user: PartialUser::from(user.clone()), + }, + ); + + Ok(Json(FullUser::from(user))) +} + +// file: src/web/ws/error.rs +pub type Result = std::result::Result>; + +#[derive(Debug, derive_more::From, derive_more::Display)] +pub enum Error { + #[from] + Custom(T), + + #[from] + Json(serde_json::Error), + + #[from] + AcknowledgementError(tokio::sync::oneshot::error::RecvError), + + WrongMessageType, + + WebSocketClosed, + + UnknownError, +} + +pub trait CustomError {} + +// file: src/web/ws/gateway/connection.rs +use axum::extract::ws::Message as AxumMessage; +use base64::Engine as _; +use futures::{Stream, StreamExt}; +use sha2::{Digest, Sha256}; +use tokio::sync::mpsc; + +use super::error::Error as WsError; +use super::event::Event as WsEvent; +use super::protocol::{WsClientMessage, WsServerMessage}; +use super::state::{WsContext, WsState, WsUserContext}; +use crate::jwt; +use crate::state::AppState; +use crate::web::ws::gateway::SessionKey; +use crate::web::ws::general::WebSocketHandler; +use crate::web::ws::util::{SendWsMessage, deserialize_ws_message}; +use crate::web::ws::voice; +use crate::webrtc::WebRtcSignal; + +impl WebSocketHandler for WsContext { + type ServerMessage = WsServerMessage; + type ClientMessage = WsClientMessage; + type Error = WsError; + + async fn handle_stream( + &mut self, + stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, + ) -> crate::web::ws::error::Result<(), Self::Error> + where + S: Stream> + Unpin, + { + process_websocket_messages(self, stream, sender, app_state).await?; + + Ok(()) + } + + async fn cleanup(&mut self, app_state: &AppState) { + if let Some(user_ctx_data) = &self.user_context { + app_state + .unregister_gateway_connected_user( + user_ctx_data.user_id, + &user_ctx_data.session_key, + ) + .await; + } + + drop(self.event_channel.take()); + } + + async fn handle_result_error( + &mut self, + error: Self::Error, + sender: &mpsc::UnboundedSender>, + ) { + let error_ws_message = WsServerMessage::Error { code: error }; + + let _ = sender.send(SendWsMessage::new_no_response(error_ws_message)); + } +} + +#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) +))] +async fn process_websocket_messages( + context: &mut WsContext, + mut ws_stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> crate::web::ws::error::Result<(), WsError> +where + S: Stream> + Unpin, +{ + loop { + match context.connection_state { + WsState::Initialize => { + tokio::select! 
{ + biased; + maybe_message = ws_stream.next() => { + match maybe_message { + Some(Ok(message)) => { + handle_initial_message(context, message, sender, app_state).await?; + context.connection_state = WsState::Connected; + tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected."); + } + Some(Err(axum_ws_err)) => { + tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err); + return Err(crate::web::ws::error::Error::WebSocketClosed); + } + None => { + tracing::debug!("WebSocket stream ended by client during Initialize state."); + return Err(crate::web::ws::error::Error::WebSocketClosed); + } + } + } + } + }, + WsState::Connected => { + let user_ctx = context + .user_context + .as_ref() + .expect("User context must be set in Connected state"); + let (_event_tx_ref, event_rx) = context + .event_channel + .as_mut() + .expect("Event channel must be set in Connected state"); + + tokio::select! { + biased; + maybe_app_event = event_rx.recv() => { + if let Some(app_event_data) = maybe_app_event { + SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?; + } else { + tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket."); + return Ok(()); + } + } + maybe_ws_message = ws_stream.next() => { + match maybe_ws_message { + Some(Ok(message)) => { + handle_connected_message(context, message, sender, &app_state).await?; + } + Some(Err(axum_ws_err)) => { + tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err); + return Err(crate::web::ws::error::Error::WebSocketClosed); + } + None => { + tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state."); + return Err(crate::web::ws::error::Error::WebSocketClosed); + } + } + } + } + }, + } + } +} + +#[tracing::instrument(skip_all, fields(state = ?context.connection_state))] +async fn handle_initial_message( + context: &mut WsContext, + message: AxumMessage, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> crate::web::ws::error::Result<(), WsError> { + match deserialize_ws_message(message)? 
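// In the Initialize state only an Authenticate message is accepted; any other client
// message is rejected as WrongMessageType and the connection is torn down.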
{ + WsClientMessage::Authenticate { token } => { + match crate::web::middleware::get_context_from_token(&app_state, &token).await { + Ok(auth_user_context) => { + let user_id = auth_user_context.user_id; + + let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); + context.event_channel = Some((event_tx.clone(), event_rx)); + + let random_key_part = rand::random::(); + let current_session_key: SessionKey = { + let mut hasher = Sha256::new(); + hasher.update(token.as_bytes()); + hasher.update(user_id.to_string().as_bytes()); + hasher.update(&random_key_part.to_be_bytes()); + base64::engine::general_purpose::URL_SAFE_NO_PAD + .encode(hasher.finalize()) + .into() + }; + + context.user_context = Some(WsUserContext { + user_id, + session_key: current_session_key.clone(), + }); + + app_state + .register_gateway_connected_user( + user_id, + current_session_key.clone(), + event_tx, + ) + .await; + + SendWsMessage::send_with_response( + sender, + WsServerMessage::AuthenticateAccepted { + user_id, + session_key: current_session_key, + }, + ) + .await?; + Ok(()) + }, + Err(_auth_err) => { + tracing::warn!(token = %token, "Authentication failed for token."); + let _ = SendWsMessage::send_with_response( + sender, + WsServerMessage::AuthenticateDenied, + ) + .await; + Err(WsError::AuthenticationFailed.into()) + }, + } + }, + #[allow(unreachable_patterns)] + _ => { + tracing::warn!("Unexpected message type received during Initialize state."); + Err(crate::web::ws::error::Error::WrongMessageType) + }, + } +} + +#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) +))] +async fn handle_connected_message( + context: &mut WsContext, + message: AxumMessage, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> crate::web::ws::error::Result<(), WsError> { + match deserialize_ws_message(message)? 
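// Connected-state messages: VoiceStateUpdate mints a short-lived voice JWT for the requested
// channel, RequestVoiceStates replays the current occupants of every voice channel on the server.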
{ + WsClientMessage::VoiceStateUpdate { + server_id, + channel_id, + } => { + let server_id = server_id.unwrap_or_else(|| channel_id.clone().into()); + // TODO: check if can join this channel + + let claims = voice::claims::VoiceClaims { + user_id: context + .user_context + .as_ref() + .expect("user context should be present") + .user_id, + server_id, + channel_id, + exp: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime) + .timestamp(), + }; + + let token = jwt::generate_jwt( + claims, + crate::config::config().security.voice_secret.as_ref(), + ) + .map_err(|_| WsError::TokenGenerationFailed)?; + + SendWsMessage::send_with_response( + sender, + WsServerMessage::Event { + event: WsEvent::VoiceServerUpdate { + server_id, + channel_id, + token, + }, + }, + ) + .await?; + + Ok(()) + }, + WsClientMessage::RequestVoiceStates { server_id } => { + let channels = app_state + .database + .select_server_channels(server_id) + .await + .map_err(|_| crate::web::ws::error::Error::UnknownError)?; + + for channel in channels { + let (tx, rx) = tokio::sync::oneshot::channel(); + + let webrtc_sender = + { app_state.voice_rooms.read().await.get(&channel.id).cloned() }; + + if let Some(voice_room) = webrtc_sender { + let _ = voice_room.send(WebRtcSignal::RequestPeers { response: tx }); + + let peers = match rx.await { + Ok(peers) => peers, + Err(_) => { + continue; + }, + }; + + for peer in peers { + let _ = + sender.send(SendWsMessage::new_no_response(WsServerMessage::Event { + event: WsEvent::VoiceChannelConnected { + server_id, + channel_id: channel.id, + user_id: peer, + }, + })); + } + } + } + + Ok(()) + }, + other_message => { + tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state."); + Err(crate::web::ws::error::Error::WrongMessageType) + }, + } +} + +// file: src/web/ws/gateway/error.rs +use crate::web::ws::error::CustomError; + +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::From, derive_more::Display, serde::Serialize)] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum Error { + AuthenticationFailed, + TokenGenerationFailed, +} + +impl CustomError for Error {} + +// file: src/web/ws/gateway/event.rs +use crate::{entity, web}; + +#[derive(Debug, Clone, serde::Serialize)] +#[serde(tag = "type", content = "data")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum Event { + #[serde(rename_all = "camelCase")] + AddServer { server: web::entity::server::Server }, + + #[serde(rename_all = "camelCase")] + RemoveServer { server_id: entity::server::Id }, + + #[serde(rename_all = "camelCase")] + AddDmChannel { + channel: entity::channel::Channel, + recipients: Vec, + }, + + #[serde(rename_all = "camelCase")] + RemoveDmChannel { channel_id: entity::channel::Id }, + + #[serde(rename_all = "camelCase")] + AddServerChannel { channel: entity::channel::Channel }, + + #[serde(rename_all = "camelCase")] + RemoveServerChannel { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + }, + + #[serde(rename_all = "camelCase")] + AddUser { + user: web::entity::user::PartialUser, + }, + + #[serde(rename_all = "camelCase")] + RemoveUser { user_id: entity::user::Id }, + + #[serde(rename_all = "camelCase")] + AddServerMember { + server_id: entity::server::Id, + member: web::entity::user::PartialUser, + }, + + #[serde(rename_all = "camelCase")] + RemoveServerMember { + server_id: entity::server::Id, + member_id: entity::user::Id, + }, + + #[serde(rename_all = "camelCase")] + AddMessage { + channel_id: 
entity::channel::Id, + message: web::entity::message::Message, + }, + + #[serde(rename_all = "camelCase")] + RemoveMessage { + channel_id: entity::channel::Id, + message_id: entity::message::Id, + }, + + #[serde(rename_all = "camelCase")] + VoiceChannelConnected { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + user_id: entity::user::Id, + }, + + #[serde(rename_all = "camelCase")] + VoiceChannelDisconnected { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + user_id: entity::user::Id, + }, + + #[serde(rename_all = "camelCase")] + VoiceServerUpdate { + server_id: entity::server::Id, + channel_id: entity::channel::Id, + token: String, + }, +} + +// file: src/web/ws/gateway/mod.rs +use axum::extract::{State, WebSocketUpgrade}; +use axum::response::IntoResponse; +use dashmap::DashMap; + +use crate::state::AppState; +use crate::web::ws::gateway::state::{EventSender, WsContext}; +use crate::web::ws::general; + +mod connection; +mod error; +pub mod event; +mod protocol; +mod state; +pub mod util; + +#[derive( + Debug, + Clone, + Eq, + PartialEq, + Hash, + Default, + derive_more::Display, + serde::Serialize, + serde::Deserialize, + derive_more::From, +)] +#[serde(transparent)] +pub struct SessionKey(String); + +#[derive(Debug, Default)] +pub struct GatewayWsState { + pub instances: DashMap, +} + +pub async fn ws_handler( + State(app_state): State, + ws: WebSocketUpgrade, +) -> crate::web::error::Result { + Ok(ws.on_upgrade(|socket| { + general::handle_websocket_connection(socket, app_state, WsContext::default()) + })) +} + +// file: src/web/ws/gateway/protocol.rs +use super::{SessionKey, error, event}; +use crate::entity; + +#[derive(Debug, serde::Serialize)] +#[serde(tag = "type", content = "data")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum WsServerMessage { + AuthenticateDenied, + + #[serde(rename_all = "camelCase")] + AuthenticateAccepted { + user_id: entity::user::Id, + session_key: SessionKey, + }, + + #[serde(rename_all = "camelCase")] + Event { + event: event::Event, + }, + + #[serde(rename_all = "camelCase")] + Error { + code: error::Error, + }, +} + +#[derive(Debug, serde::Deserialize)] +#[serde(tag = "type", content = "data")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum WsClientMessage { + #[serde(rename_all = "camelCase")] + Authenticate { token: String }, + + #[serde(rename_all = "camelCase")] + VoiceStateUpdate { + server_id: Option, + channel_id: entity::channel::Id, + }, + + #[serde(rename_all = "camelCase")] + RequestVoiceStates { server_id: entity::server::Id }, +} + +// file: src/web/ws/gateway/state.rs +use tokio::sync::mpsc; + +use super::{SessionKey, event}; +use crate::entity; + +#[derive(Debug, Eq, PartialEq, Clone, Copy)] +pub enum WsState { + Initialize, + Connected, +} + +#[derive(Debug)] +pub struct WsUserContext { + pub user_id: entity::user::Id, + pub session_key: SessionKey, // Unique key for this specific WebSocket session instance +} + +pub type EventSender = mpsc::UnboundedSender; +pub type EventReceiver = mpsc::UnboundedReceiver; + +pub struct WsContext { + pub connection_state: WsState, + pub user_context: Option, + pub event_channel: Option<(EventSender, EventReceiver)>, +} + +impl Default for WsContext { + fn default() -> Self { + Self { + connection_state: WsState::Initialize, + user_context: None, + event_channel: None, + } + } +} + +// file: src/web/ws/gateway/util.rs +use crate::entity; +use crate::state::AppState; +use crate::web::ws::gateway::event; + +pub fn send_message(state: 
AppState, user_id: entity::user::Id, message: event::Event) { + tokio::spawn(async move { + let connected_users = state.gateway_state.connected.get_async(&user_id).await; + if let Some(session) = connected_users { + for instance in session.instances.iter() { + if let Err(e) = instance.send(message.clone()) { + tracing::error!("failed to send message: {}", e); + } + } + } + }); +} + +pub fn send_message_many(state: AppState, user_ids: &[entity::user::Id], message: event::Event) { + for id in user_ids.iter() { + send_message(state.clone(), *id, message.clone()); + } +} + +pub fn send_message_server(state: AppState, server_id: entity::server::Id, message: event::Event) { + tokio::spawn(async move { + let users = state + .database + .select_server_members(server_id) + .await + .unwrap_or_else(|_| vec![]) + .iter() + .map(|u| u.id) + .collect::>(); + + send_message_many(state, &users, message); + }); +} + +pub fn send_message_channel( + state: AppState, + channel_id: entity::channel::Id, + message: event::Event, +) { + tokio::spawn(async move { + let users = state + .database + .select_channel_members(channel_id) + .await + .unwrap_or_else(|_| vec![]) + .iter() + .map(|u| u.id) + .collect::>(); + + send_message_many(state, &users, message); + }); +} + +pub fn send_message_related(state: AppState, user_id: entity::user::Id, message: event::Event) { + tokio::spawn(async move { + let users = state + .database + .select_related_user_ids(user_id) + .await + .unwrap_or_else(|_| vec![]); + + send_message_many(state, &users, message); + }); +} + +// file: src/web/ws/general.rs +use std::fmt::Debug; + +use axum::extract::ws::WebSocket; +use futures::{Stream, StreamExt}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use tokio::sync::mpsc; + +use crate::state::AppState; +use crate::web::ws::error::CustomError; +use crate::web::ws::util; +use crate::web::ws::util::SendWsMessage; + +pub trait WebSocketHandler { + type ServerMessage: Serialize + Send; + type ClientMessage: DeserializeOwned; + type Error: CustomError + Send + Debug; + + async fn handle_stream( + &mut self, + stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, + ) -> crate::web::ws::error::Result<(), Self::Error> + where + S: Stream> + Unpin; + + async fn cleanup(&mut self, app_state: &AppState); + + async fn handle_result_error( + &mut self, + error: Self::Error, + sender: &mpsc::UnboundedSender>, + ); +} + +#[tracing::instrument(skip_all)] +pub async fn handle_websocket_connection( + websocket: WebSocket, + app_state: AppState, + mut handler: impl WebSocketHandler + 'static, +) { + let (ws_sink, ws_stream) = websocket.split(); + + let (internal_send_tx, internal_send_rx) = mpsc::unbounded_channel(); + + let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx); + + let processing_result = handler + .handle_stream(ws_stream, &internal_send_tx, &app_state) + .await; + + handler.cleanup(&app_state).await; + + match processing_result { + Ok(_) => {}, + Err(crate::web::ws::error::Error::Custom(err_to_report)) => { + handler + .handle_result_error(err_to_report, &internal_send_tx) + .await; + }, + Err(e) => { + tracing::info!("WebSocket connection closed: {:?}", e); + }, + } + + drop(internal_send_tx); + if let Err(e) = writer_task.await { + tracing::error!( + "WebSocket writer task panicked or encountered an error: {:?}", + e + ); + } +} + +// file: src/web/ws/mod.rs +mod error; +pub mod gateway; +mod util; +pub mod voice; +mod general; + +// file: src/web/ws/util.rs +use axum::extract::ws::Message as 
AxumMessage; +use futures::{Sink, SinkExt}; +use serde::Serialize; +use serde::de::DeserializeOwned; +use tokio::sync::{mpsc, oneshot}; + +use crate::web::ws::error::CustomError; + +pub fn spawn_writer_task( + mut ws_sink: S, + mut writer_rx: mpsc::UnboundedReceiver>, +) -> tokio::task::JoinHandle<()> +where + S: Sink + Unpin + Send + 'static, + T: Serialize + Send + 'static, + E: CustomError + Send + 'static, +{ + tokio::spawn(async move { + while let Some(SendWsMessage { + message, + response_ch, + }) = writer_rx.recv().await + { + let send_result = match serialize_ws_message(message) { + Ok(ws_msg) => { + if ws_sink.send(ws_msg).await.is_err() { + Err(super::error::Error::WebSocketClosed) + } else { + Ok(()) + } + }, + Err(e) => Err(e.into()), + }; + + if let Some(ch) = response_ch { + let _ = ch.send(send_result); + } + } + + let _ = ws_sink.close().await; + }) +} + +/// Deserializes an Axum WebSocket message into a `WsClientMessage`. +pub fn deserialize_ws_message( + message: AxumMessage, +) -> super::error::Result { + match message { + AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from), + AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed), + _ => Err(super::error::Error::WrongMessageType), + } +} + +/// Serializes a `WsServerMessage` into an Axum WebSocket message. +pub fn serialize_ws_message( + message: T, +) -> super::error::Result { + serde_json::to_string(&message) + .map(Into::into) + .map(AxumMessage::Text) + .map_err(super::error::Error::from) +} + +/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task. +/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer. +pub struct SendWsMessage { + pub message: T, + pub response_ch: Option>>, +} + +impl SendWsMessage { + /// Sends a message over the MPSC channel and awaits a response via a oneshot channel. + pub async fn send_with_response( + tx: &mpsc::UnboundedSender, // Changed to reference + message: T, + ) -> super::error::Result<(), E> { + let (response_tx, response_rx) = oneshot::channel(); + let send_message = SendWsMessage { + message, + response_ch: Some(response_tx), + }; + + if tx.send(send_message).is_err() { + Err(super::error::Error::WebSocketClosed) // MPSC channel closed, writer task likely dead + } else { + // Wait for the writer task to acknowledge the send attempt. + // This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error. + response_rx.await? // Propagates RecvError into WsError::AcknowledgementError + } + } + + /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). 
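/// Write failures are silently dropped for such messages; use send_with_response when the
/// caller needs to know whether the frame reached the socket.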
+ pub fn new_no_response(message: T) -> Self { + Self { + message, + response_ch: None, + } + } +} + +// file: src/web/ws/voice/claims.rs +use crate::entity; + +#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] +pub struct VoiceClaims { + pub user_id: entity::user::Id, + pub server_id: entity::server::Id, + pub channel_id: entity::channel::Id, + pub exp: i64, +} + +// file: src/web/ws/voice/connection.rs +use axum::extract::ws::Message as AxumMessage; +use futures::{Stream, StreamExt}; +use tokio::sync::{mpsc, oneshot}; + +use super::error::{self, Error as WsError}; +use super::protocol::{WsClientMessage, WsServerMessage}; +use crate::jwt; +use crate::state::AppState; +use crate::web::ws; +use crate::web::ws::general::WebSocketHandler; +use crate::web::ws::util::{SendWsMessage, deserialize_ws_message}; +use crate::web::ws::voice::claims::VoiceClaims; +use crate::web::ws::voice::protocol::WsServerMessage::SdpAnswer; +use crate::web::ws::voice::state::{WsContext, WsState}; +use crate::webrtc::{Offer, OfferSignal, WebRtcSignal}; + +impl WebSocketHandler for WsContext { + type ServerMessage = WsServerMessage; + type ClientMessage = WsClientMessage; + type Error = WsError; + + async fn handle_stream( + &mut self, + stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, + ) -> ws::error::Result<(), Self::Error> + where + S: Stream> + Unpin, + { + process_websocket_messages(self, stream, sender, app_state).await?; + + Ok(()) + } + + async fn cleanup(&mut self, app_state: &AppState) { + tracing::debug!("Cleaning up WebSocket connection."); + + match &self.connection_state { + WsState::Connected { + signal_channel, + server_id, + channel_id, + user_id, + .. + } => { + ws::gateway::util::send_message_channel( + app_state.clone(), + *channel_id, + ws::gateway::event::Event::VoiceChannelDisconnected { + server_id: *server_id, + channel_id: *channel_id, + user_id: *user_id, + }, + ); + + let _ = signal_channel.send(WebRtcSignal::Disconnect(user_id.clone())); + }, + WsState::Initialize => {}, + } + } + + async fn handle_result_error( + &mut self, + error: Self::Error, + sender: &mpsc::UnboundedSender>, + ) { + tracing::error!("WebSocket error: {:?}", error); + } +} + +#[tracing::instrument(skip_all)] +async fn process_websocket_messages( + context: &mut WsContext, + mut ws_stream: S, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> ws::error::Result<(), WsError> +where + S: Stream> + Unpin, +{ + loop { + match &context.connection_state { + WsState::Initialize => { + while let Some(Ok(message)) = ws_stream.next().await { + handle_initial_message(context, message, sender, &app_state).await?; + break; + } + }, + WsState::Connected { signal_channel, .. } => { + let signal_channel = signal_channel.clone(); + loop { + tokio::select! { + biased; + _ = signal_channel.closed() => { + tracing::debug!("Signal channel closed."); + break; + } + Some(Ok(message)) = ws_stream.next() => { + handle_connected_message(context, message, sender, &app_state).await?; + } + else => { + break; + } + } + } + + return Err(ws::error::Error::WebSocketClosed); + }, + } + } +} + +#[tracing::instrument(skip_all)] +async fn handle_initial_message( + context: &mut WsContext, + message: AxumMessage, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> ws::error::Result<(), error::Error> { + match deserialize_ws_message(message)? 
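// The voice socket authenticates with the short-lived JWT minted by the gateway's
// VoiceStateUpdate handler; it is verified against security.voice_secret rather than the
// regular auth secret.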
{ + WsClientMessage::Authenticate { token } => match jwt::verify_jwt::( + &token, + crate::config::config().security.voice_secret.as_ref(), + ) { + Ok(claims) => { + SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateAccepted) + .await?; + + let signal_channel = + ws::voice::state::get_signaling_channel(app_state, claims.channel_id).await; + + context.connection_state = WsState::Connected { + signal_channel, + server_id: claims.server_id, + channel_id: claims.channel_id, + user_id: claims.user_id, + }; + + ws::gateway::util::send_message_channel( + app_state.clone(), + claims.channel_id, + ws::gateway::event::Event::VoiceChannelConnected { + server_id: claims.server_id, + channel_id: claims.channel_id, + user_id: claims.user_id, + }, + ); + + Ok(()) + }, + Err(auth_err) => { + tracing::warn!("Authentication failed: {:?}", auth_err); + + let _ = + SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateDenied) + .await; + Err(error::Error::AuthenticationFailed.into()) + }, + }, + #[allow(unreachable_patterns)] + _ => { + tracing::warn!("Unexpected message type received during Initialize state."); + Err(ws::error::Error::WrongMessageType) + }, + } +} + +#[tracing::instrument(skip_all)] +async fn handle_connected_message( + context: &mut WsContext, + message: AxumMessage, + sender: &mpsc::UnboundedSender>, + app_state: &AppState, +) -> ws::error::Result<(), error::Error> { + match deserialize_ws_message(message)? { + WsClientMessage::SdpOffer { sdp } => { + let (signal_channel, user_id) = match &mut context.connection_state { + WsState::Connected { + signal_channel, + user_id, + .. + } => (signal_channel.clone(), *user_id), + _ => return Err(ws::error::Error::WrongMessageType), + }; + + let (tx, rx) = oneshot::channel(); + + let _ = signal_channel.send(WebRtcSignal::Offer(OfferSignal { + offer: Offer { + peer_id: user_id, + sdp_offer: sdp, + }, + response: tx, + })); + + let answer_signal = rx.await?; + + sender + .send(SendWsMessage::new_no_response(SdpAnswer { + sdp: answer_signal.sdp_answer, + })) + .map_err(|_| ws::error::Error::WebSocketClosed)?; + + Ok(()) + }, + other_message => { + tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state."); + Err(ws::error::Error::WrongMessageType) + }, + } +} + +// file: src/web/ws/voice/error.rs +use crate::web::ws::error::CustomError; + +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::From, derive_more::Display)] +pub enum Error { + AuthenticationFailed, +} + +impl CustomError for Error {} + +// file: src/web/ws/voice/mod.rs +pub mod claims; +mod connection; +mod error; +mod protocol; +mod state; + +use axum::extract::{State, WebSocketUpgrade}; +use axum::response::IntoResponse; + +use crate::state::AppState; +use crate::web::ws::general; +use crate::web::ws::voice::state::WsContext; + +pub async fn ws_handler( + State(app_state): State, + ws: WebSocketUpgrade, +) -> crate::web::error::Result { + Ok(ws.on_upgrade(|socket| { + general::handle_websocket_connection(socket, app_state, WsContext::default()) + })) +} + +// file: src/web/ws/voice/protocol.rs +use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; + +#[derive(Debug, serde::Serialize)] +#[serde(tag = "type", content = "data")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum WsServerMessage { + AuthenticateDenied, + + AuthenticateAccepted, + + #[serde(rename_all = "camelCase")] + SdpAnswer { + sdp: RTCSessionDescription, + }, +} + +#[derive(Debug, 
serde::Deserialize)] +#[serde(tag = "type", content = "data")] +#[serde(rename_all = "SCREAMING_SNAKE_CASE")] +pub enum WsClientMessage { + #[serde(rename_all = "camelCase")] + Authenticate { token: String }, + + #[serde(rename_all = "camelCase")] + SdpOffer { sdp: RTCSessionDescription }, +} + +// file: src/web/ws/voice/state.rs +use tokio::sync::oneshot; +use tokio::sync::mpsc; + +use crate::entity; +use crate::state::AppState; +use crate::webrtc::{OfferSignal, WebRtcSignal}; + +#[derive(Debug)] +pub enum WsState { + Initialize, + Connected { + signal_channel: mpsc::UnboundedSender, + server_id: entity::server::Id, + channel_id: entity::channel::Id, + user_id: entity::user::Id, + }, +} + +pub struct WsContext { + pub connection_state: WsState, +} + +impl Default for WsContext { + fn default() -> Self { + Self { + connection_state: WsState::Initialize, + } + } +} + +pub async fn get_signaling_channel( + app_state: &AppState, + channel_id: entity::channel::Id, +) -> mpsc::UnboundedSender { + let room_sender = { + app_state + .voice_rooms + .read() + .await + .get(&channel_id) + .map(|room| room.clone()) + }; + + match room_sender { + Some(room) => room, + None => { + let (tx, rx) = mpsc::unbounded_channel(); + + let app_state_ = app_state.clone(); + tokio::spawn(async move { + crate::webrtc::webrtc_task(channel_id, rx) + .await + .unwrap_or_else(|err| { + tracing::error!("webrtc task error: {:?}", err); + }); + + { + app_state_.unregister_voice_room(channel_id).await; + } + }); + + { + app_state.register_voice_room(channel_id, tx.clone()).await; + } + + tx + }, + } +} + +// file: src/webrtc/mod.rs +use std::sync::Arc; + +use dashmap::DashMap; +use tracing::Instrument; +use webrtc::api::interceptor_registry::register_default_interceptors; +use webrtc::api::media_engine::MIME_TYPE_OPUS; +use webrtc::api::{API, APIBuilder}; +use webrtc::interceptor::registry::Registry; +use webrtc::peer_connection::configuration::RTCConfiguration; +use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; +use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; +use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType}; +use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; +use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; +use webrtc::track::track_remote::TrackRemote; + +use crate::entity; + +type PeerId = entity::user::Id; +type RoomId = entity::channel::Id; + +struct PeerState { + peer_id: PeerId, + peer_connection: Arc, + outgoing_audio_track: Arc, + ssrc: u32, +} + +struct RoomState { + room_id: RoomId, + peers: DashMap, + close_signal: tokio::sync::mpsc::UnboundedSender<()>, +} + +#[derive(Debug)] +pub struct Offer { + pub peer_id: PeerId, + pub sdp_offer: RTCSessionDescription, +} + +#[derive(Debug)] +pub struct AnswerSignal { + pub sdp_answer: RTCSessionDescription, +} + +#[derive(Debug)] +pub enum WebRtcSignal { + Offer(OfferSignal), + Disconnect(PeerId), + RequestPeers { + response: tokio::sync::oneshot::Sender>, + }, + Close, +} + +#[derive(Debug)] +pub struct OfferSignal { + pub offer: Offer, + pub response: tokio::sync::oneshot::Sender, +} + +#[tracing::instrument(skip(signal))] +pub async fn webrtc_task( + room_id: RoomId, + signal: tokio::sync::mpsc::UnboundedReceiver, +) -> anyhow::Result<()> { + tracing::info!("Starting WebRTC task"); + + let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel(); + + let mut skip_timeout = false; + + let state = Arc::new(RoomState { + 
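// One RoomState exists per voice channel: it tracks the connected peers and owns the
// close signal used to stop this task once the room empties.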
room_id, + peers: DashMap::new(), + close_signal, + }); + + let mut signal = signal; + let mut media_engine = webrtc::api::media_engine::MediaEngine::default(); + media_engine.register_default_codecs()?; + let mut registry = Registry::new(); + registry = register_default_interceptors(registry, &mut media_engine)?; + + let api = APIBuilder::new() + .with_media_engine(media_engine) + .with_interceptor_registry(registry) + .build(); + + let api = Arc::new(api); + + loop { + tokio::select! { + biased; + _ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => { + tracing::debug!("initial timeout reached"); + break; + } + _ = close_receiver.recv() => { + tracing::debug!("WebRTC task stopped"); + break; + } + Some(signal) = signal.recv() => { + skip_timeout = true; + match signal { + WebRtcSignal::Offer(offer_signal) => { + let room_state = state.clone(); + let api = api.clone(); + + tokio::spawn(async move { + if let Err(e) = handle_peer(api, room_state, offer_signal).await { + tracing::error!("error handling peer: {}", e); + } + }.instrument(tracing::Span::current())); + } + WebRtcSignal::RequestPeers { response } => { + let peers = state + .peers + .iter() + .map(|pair| pair.key().clone()) + .collect::>(); + + let _ = response.send(peers); + } + WebRtcSignal::Disconnect(peer_id) => { + tracing::debug!("received disconnect signal for peer {}", peer_id); + cleanup_peer(state.clone(), peer_id).await; + } + WebRtcSignal::Close => { + break; + } + } + } + } + } + + Ok(()) +} + +#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id +))] +async fn handle_peer( + api: Arc, + room_state: Arc, + offer_signal: OfferSignal, +) -> anyhow::Result<()> { + tracing::debug!("handling peer"); + + let config = RTCConfiguration { + ..Default::default() + }; + + let outgoing_track = Arc::new(TrackLocalStaticRTP::new( + RTCRtpCodecCapability { + mime_type: MIME_TYPE_OPUS.to_string(), + ..Default::default() + }, + format!("audio-{}", offer_signal.offer.peer_id), + format!("room-{}", room_state.room_id), + )); + + let peer_connection = Arc::new(api.new_peer_connection(config).await?); + + let rtp_sender = peer_connection + .add_track(Arc::clone(&outgoing_track) as Arc) + .await?; + + tracing::debug!("added track to peer connection: {:?}", rtp_sender); + + // Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI) + let outgoing_track_ = Arc::clone(&outgoing_track); + tokio::spawn( + async move { + let mut rtcp_buf = vec![0u8; 1500]; + while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await { + // Process RTCP if needed (e.g., bandwidth estimation, custom feedback) + } + tracing::debug!( + "RTCP reader loop for outgoing track {} ended", + outgoing_track_.id() + ); + } + .instrument(tracing::Span::current()), + ); + + let room_state_ = room_state.clone(); + peer_connection.on_peer_connection_state_change(Box::new(move |state| { + let room_state_ = room_state_.clone(); + Box::pin(async move { + tracing::debug!("peer connection state changed: {:?}", state); + if state == RTCPeerConnectionState::Disconnected + || state == RTCPeerConnectionState::Failed + { + tracing::debug!("peer connection closed"); + cleanup_peer(room_state_, offer_signal.offer.peer_id).await; + } + }) + })); + + let room_state_ = room_state.clone(); + peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| { + tracing::debug!("track received: {:?}", track); + + let room_state_ = room_state_.clone(); + Box::pin(async move { + if 
+            if track.kind() == RTPCodecType::Audio {
+                tracing::debug!("audio track received: {:?}", track);
+                tokio::spawn(async move {
+                    if let Err(e) =
+                        forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await
+                    {
+                        tracing::error!("error forwarding audio track: {}", e);
+                    }
+                });
+            } else {
+                tracing::warn!("received non-audio track: {:?}", track);
+            }
+        })
+    }));
+
+    peer_connection
+        .set_remote_description(offer_signal.offer.sdp_offer)
+        .await?;
+    let answer = peer_connection.create_answer(None).await?;
+
+    let mut gathering_complete = peer_connection.gathering_complete_promise().await;
+
+    peer_connection.set_local_description(answer).await?;
+
+    gathering_complete.recv().await;
+
+    tracing::debug!("ICE gathering complete");
+
+    let local_description = peer_connection
+        .local_description()
+        .await
+        .ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?;
+
+    let peer_state = PeerState {
+        peer_id: offer_signal.offer.peer_id,
+        peer_connection: Arc::clone(&peer_connection),
+        outgoing_audio_track: Arc::clone(&outgoing_track),
+        ssrc: 0,
+    };
+
+    {
+        if let Some((_, old_peer)) = room_state.peers.remove(&offer_signal.offer.peer_id) {
+            let _ = old_peer.peer_connection.close().await;
+        }
+    }
+
+    {
+        room_state
+            .peers
+            .insert(offer_signal.offer.peer_id, peer_state);
+    }
+
+    let _ = offer_signal.response.send(AnswerSignal {
+        sdp_answer: local_description,
+    });
+
+    Ok(())
+}
+
+#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id
+))]
+async fn forward_audio_track(
+    room_state: Arc<RoomState>,
+    peer_id: PeerId,
+    track: Arc<TrackRemote>,
+) -> anyhow::Result<()> {
+    let mut rtp_buf = vec![0u8; 1500];
+
+    {
+        let mut peer_state = room_state
+            .peers
+            .get_mut(&peer_id)
+            .expect("peer state not found");
+
+        peer_state.ssrc = track.ssrc();
+    }
+
+    while let Ok((rtp_packet, _attr)) = track.read(&mut rtp_buf).await {
+        let other_peer_tracks = room_state
+            .peers
+            .iter()
+            .filter_map(|pair| {
+                let peer_state = pair.value();
+                if peer_state.peer_id != peer_id {
+                    let mut rtp_packet = rtp_packet.clone();
+                    rtp_packet.header.ssrc = peer_state.ssrc;
+
+                    Some((peer_state.outgoing_audio_track.clone(), rtp_packet))
+                } else {
+                    None
+                }
+            })
+            .collect::<Vec<_>>();
+
+        if other_peer_tracks.is_empty() {
+            // tracing::warn!("no other peers to forward audio track to");
+            continue;
+        }
+
+        let write_futures = other_peer_tracks
+            .iter()
+            .map(|(outgoing_track, packet)| outgoing_track.write_rtp(&packet));
+
+        let results = futures::future::join_all(write_futures).await;
+        for result in results {
+            if let Err(e) = result {
+                tracing::error!("error writing RTP packet: {}", e);
+            }
+        }
+    }
+    tracing::debug!(
+        "RTP Read loop ended for track {} from peer {}",
+        track.id(),
+        peer_id
+    );
+
+    Ok(())
+}
+
+#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))]
+async fn cleanup_peer(room_state: Arc<RoomState>, peer_id: PeerId) {
+    tracing::debug!("cleaning up peer");
+
+    if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) {
+        tracing::debug!("removed peer");
+        let pc = Arc::clone(&peer_state.peer_connection);
+        tokio::spawn(async move {
+            if let Err(e) = pc.close().await {
+                if !matches!(e, webrtc::Error::ErrConnectionClosed) {
+                    tracing::warn!("error closing peer connection: {}", e);
+                }
+            }
+
+            tracing::debug!("peer connection closed");
+        });
+    }
+
+    if room_state.peers.is_empty() {
+        tracing::debug!("no more peers in room, closing room");
+        let _ = room_state.close_signal.send(());
+    }
+}
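+
+// Illustrative sketch (not part of the build): how a voice-gateway handler is
+// expected to drive this module. Names defined outside this file (app_state,
+// user_id, sdp_offer) are assumed, and error handling is elided.
+//
+//     let room = get_signaling_channel(&app_state, channel_id).await;
+//     let (tx, rx) = tokio::sync::oneshot::channel();
+//     room.send(WebRtcSignal::Offer(OfferSignal {
+//         offer: Offer { peer_id: user_id, sdp_offer },
+//         response: tx,
+//     }))?;
+//     let AnswerSignal { sdp_answer } = rx.await?;
+//     // return `sdp_answer` to the client over the voice websocket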
diff --git a/migrations/20250510184011_channel_message.sql b/migrations/20250510184011_channel_message.sql
index 93a0cbf..24bbd3a 100644
--- a/migrations/20250510184011_channel_message.sql
+++ b/migrations/20250510184011_channel_message.sql
@@ -102,4 +102,62 @@ BEGIN
     RAISE NOTICE 'DM channel created with ID: %', new_channel_id;
     RETURN new_channel_id;
 END;
-$$;
\ No newline at end of file
+$$;
+
+CREATE OR REPLACE FUNCTION create_group_channel(
+    p_creator_id UUID, -- The user initiating the group creation (will be owner)
+    p_recipient_user_ids UUID[], -- Array of user IDs to include in the group (the creator is always added)
+    p_group_channel_type INT2 -- The channel type identifier for groups (e.g., 1)
+)
+    RETURNS UUID
+    LANGUAGE plpgsql
+AS
+$$
+DECLARE
+    new_channel_id UUID;
+    existing_channel_id UUID;
+    final_channel_name VARCHAR;
+    unique_sorted_recipient_ids UUID[];
+    num_recipients INT;
+    uid UUID;
+    -- Threshold for detailed name vs. summary name
+    MAX_MEMBERS_FOR_DETAILED_NAME CONSTANT INT := 3;
+BEGIN
+    -- Validate and process recipient IDs
+    IF p_recipient_user_ids IS NULL OR array_length(p_recipient_user_ids, 1) IS NULL THEN
+        RAISE EXCEPTION 'Recipient user IDs array must be provided and not empty.';
+    END IF;
+
+    -- Get unique, sorted recipient IDs for consistent checking and to avoid duplicates.
+    SELECT array_agg(DISTINCT u ORDER BY u) INTO unique_sorted_recipient_ids FROM unnest(p_recipient_user_ids) u;
+    num_recipients := array_length(unique_sorted_recipient_ids, 1);
+
+    -- Validate minimum number of recipients for a group
+    IF num_recipients < 2 THEN -- a single recipient should be a DM channel, not a group
+        RAISE EXCEPTION 'Group channels (type %) must have at least 2 recipients. Found %.', p_group_channel_type, num_recipients;
+    END IF;
+
+    -- Create new group channel
+    INSERT INTO "channel" ("name", "type", "position", "owner_id", "server_id", "parent")
+    VALUES ('Group',
+            p_group_channel_type,
+            0, -- Default position
+            p_creator_id,
+            NULL, -- Not a server channel
+            NULL -- Not a nested server channel
+           )
+    RETURNING id INTO new_channel_id;
+
+    -- Add the creator and all recipients to the channel_recipient table
+    INSERT INTO "channel_recipient" ("channel_id", "user_id")
+    VALUES (new_channel_id, p_creator_id);
+
+    INSERT INTO "channel_recipient" ("channel_id", "user_id")
+    SELECT new_channel_id, r_id
+    FROM unnest(unique_sorted_recipient_ids) AS r_id WHERE r_id <> p_creator_id; -- creator already added above
+
+    RAISE NOTICE 'Group channel (type %) created with ID: % by owner % for recipients: %',
+        p_group_channel_type, new_channel_id, p_creator_id, unique_sorted_recipient_ids;
+    RETURN new_channel_id;
+END;
+$$;
diff --git a/src/database.rs b/src/database.rs
index 821d286..18682b7 100644
--- a/src/database.rs
+++ b/src/database.rs
@@ -562,7 +562,7 @@ impl Database {
 
         Ok(file)
     }
-    
+
     pub async fn insert_message_attachment(
         &self,
         message_id: entity::message::Id,
@@ -578,7 +578,7 @@ impl Database {
 
         Ok(())
     }
-    
+
     pub async fn select_message_attachments(
         &self,
         message_id: entity::message::Id,
@@ -621,7 +621,7 @@
         &self,
         user1_id: entity::user::Id,
         user2_id: entity::user::Id,
-    ) -> Result<Option<entity::channel::Id>> {
+    ) -> Result<entity::channel::Id> {
         let channel_id = sqlx::query_scalar!(
             r#"SELECT create_dm_channel($1, $2, $3)"#,
             user1_id,
@@ -629,7 +629,26 @@
             user2_id,
             entity::channel::ChannelType::DirectMessage as i16
         )
         .fetch_one(&self.pool)
-        .await?;
+        .await?
+        .expect("channel_id is null");
+
+        Ok(channel_id)
+    }
+
+    pub async fn procedure_create_group_channel(
+        &self,
+        creator_id: entity::user::Id,
+        users: &[entity::user::Id],
+    ) -> Result<entity::channel::Id> {
+        let channel_id = sqlx::query_scalar!(
+            r#"SELECT create_group_channel($1, $2, $3)"#,
+            creator_id,
+            users,
+            entity::channel::ChannelType::DirectMessage as i16
+        )
+        .fetch_one(&self.pool)
+        .await?
+        .expect("channel_id is null");
 
         Ok(channel_id)
     }
diff --git a/src/web/mod.rs b/src/web/mod.rs
index 452f33d..475f3b8 100644
--- a/src/web/mod.rs
+++ b/src/web/mod.rs
@@ -71,6 +71,7 @@ fn protected_router() -> axum::Router {
         .route("/users/@me", get(user::me))
         .route("/users/@me", patch(user::patch))
         .route("/users/@me/channels", get(user::channel::list))
+        .route("/users/@me/channels", post(user::channel::create))
         .route("/users/{id}", get(user::get_by_id))
         // channel
         .route(
diff --git a/src/web/route/user/channel/create.rs b/src/web/route/user/channel/create.rs
new file mode 100644
index 0000000..0acdc6c
--- /dev/null
+++ b/src/web/route/user/channel/create.rs
@@ -0,0 +1,75 @@
+use axum::Json;
+use axum::extract::State;
+use axum::response::IntoResponse;
+use axum_extra::extract::WithRejection;
+use serde::Deserialize;
+use validator::Validate;
+
+use crate::state::AppState;
+use crate::web::context::UserContext;
+use crate::web::entity::user::PartialUser;
+use crate::web::route::user::channel::RecipientChannel;
+use crate::web::ws;
+use crate::{entity, web};
+
+#[derive(Debug, Validate, Deserialize)]
+pub struct CreatePayload {
+    #[validate(length(min = 1, max = 32))]
+    recipients: Vec<entity::user::Id>,
+}
+
+pub async fn create(
+    State(state): State<AppState>,
+    context: UserContext,
+    WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
+) -> web::Result<impl IntoResponse> {
+    match payload.validate() {
+        Ok(_) => {},
+        Err(err) => {
+            return Err(web::error::ClientError::ValidationFailed(err).into());
+        },
+    }
+
+    let channel_id = match payload.recipients.len() {
+        1 => {
+            let recipient = payload.recipients[0];
+            state
+                .database
+                .procedure_create_dm_channel(context.user_id, recipient)
+                .await?
+        },
+        _ => {
+            state
+                .database
+                .procedure_create_group_channel(context.user_id, &payload.recipients)
+                .await?
+        },
+    };
+
+    let channel = state.database.select_channel_by_id(channel_id).await?;
+
+    let recipients = state
+        .database
+        .select_channel_recipients(channel_id)
+        .await?
+        .unwrap_or_default()
+        .into_iter()
+        .map(|user| user.id)
+        .collect::<Vec<_>>();
+
+    let recipient_channels = RecipientChannel {
+        channel: channel.clone(),
+        recipients: recipients.clone(),
+    };
+
+    ws::gateway::util::send_message_channel(
+        state,
+        channel_id,
+        ws::gateway::event::Event::AddDmChannel {
+            channel,
+            recipients,
+        },
+    );
+
+    Ok(Json(recipient_channels))
+}
diff --git a/src/web/route/user/channel/list.rs b/src/web/route/user/channel/list.rs
index 79e9450..68f410c 100644
--- a/src/web/route/user/channel/list.rs
+++ b/src/web/route/user/channel/list.rs
@@ -1,21 +1,11 @@
 use axum::Json;
 use axum::extract::State;
 use axum::response::IntoResponse;
-use serde::Serialize;
 
-use crate::entity::channel;
 use crate::state::AppState;
 use crate::web;
 use crate::web::context::UserContext;
-use crate::web::entity::user::PartialUser;
-
-#[derive(Debug, sqlx::FromRow, Serialize)]
-#[serde(rename_all = "camelCase")]
-pub struct RecipientChannel {
-    #[serde(flatten)]
-    pub channel: channel::Channel,
-    pub recipients: Vec<PartialUser>,
-}
+use crate::web::route::user::channel::RecipientChannel;
 
 pub async fn list(
     State(state): State<AppState>,
@@ -30,8 +20,7 @@ pub async fn list(
         let recipients = match recipients {
             Some(recipients) => recipients
                 .into_iter()
-                .filter(|user| user.id != context.user_id)
-                .map(PartialUser::from)
+                .map(|user| user.id)
                 .collect(),
             None => {
                 continue;
diff --git a/src/web/route/user/channel/mod.rs b/src/web/route/user/channel/mod.rs
index 19f2172..9b93cbd 100644
--- a/src/web/route/user/channel/mod.rs
+++ b/src/web/route/user/channel/mod.rs
@@ -1,3 +1,16 @@
+mod create;
 mod list;
 
+pub use create::create;
 pub use list::list;
+use serde::Serialize;
+
+use crate::entity::{channel, user};
+
+#[derive(Debug, Serialize)]
+#[serde(rename_all = "camelCase")]
+struct RecipientChannel {
+    #[serde(flatten)]
+    pub channel: channel::Channel,
+    pub recipients: Vec<user::Id>,
+}
diff --git a/src/web/ws/gateway/event.rs b/src/web/ws/gateway/event.rs
index 6734621..ec02281 100644
--- a/src/web/ws/gateway/event.rs
+++ b/src/web/ws/gateway/event.rs
@@ -11,7 +11,10 @@ pub enum Event {
     RemoveServer { server_id: entity::server::Id },
 
     #[serde(rename_all = "camelCase")]
-    AddDmChannel { channel: entity::channel::Channel },
+    AddDmChannel {
+        channel: entity::channel::Channel,
+        recipients: Vec<entity::user::Id>,
+    },
 
     #[serde(rename_all = "camelCase")]
     RemoveDmChannel { channel_id: entity::channel::Id },
@@ -31,9 +34,7 @@ pub enum Event {
     },
 
     #[serde(rename_all = "camelCase")]
-    RemoveUser {
-        user_id: entity::user::Id,
-    },
+    RemoveUser { user_id: entity::user::Id },
 
     #[serde(rename_all = "camelCase")]
     AddServerMember {
diff --git a/src/webrtc/mod.rs b/src/webrtc/mod.rs
index a72eaac..cb754e2 100644
--- a/src/webrtc/mod.rs
+++ b/src/webrtc/mod.rs
@@ -23,6 +23,7 @@ struct PeerState {
     peer_id: PeerId,
     peer_connection: Arc<RTCPeerConnection>,
     outgoing_audio_track: Arc<TrackLocalStaticRTP>,
+    ssrc: u32,
 }
 
 struct RoomState {
@@ -136,7 +137,8 @@ pub async fn webrtc_task(
     Ok(())
 }
 
-#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id))]
+#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id
+))]
 async fn handle_peer(
     api: Arc<API>,
     room_state: Arc<RoomState>,
@@ -238,13 +240,11 @@ async fn handle_peer(
         peer_id: offer_signal.offer.peer_id,
         peer_connection: Arc::clone(&peer_connection),
         outgoing_audio_track: Arc::clone(&outgoing_track),
+        ssrc: 0,
     };
 
     {
-        if let Some((_, old_peer)) = room_state
-            .peers
-            .remove(&offer_signal.offer.peer_id)
-        {
+        if let Some((_, old_peer)) = room_state.peers.remove(&offer_signal.offer.peer_id) {
             let _ = old_peer.peer_connection.close().await;
         }
     }
@@ -262,13 +262,24 @@ async fn handle_peer(
     Ok(())
 }
 
-#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id))]
+#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id
+))]
 async fn forward_audio_track(
     room_state: Arc<RoomState>,
     peer_id: PeerId,
     track: Arc<TrackRemote>,
 ) -> anyhow::Result<()> {
     let mut rtp_buf = vec![0u8; 1500];
+
+    {
+        let mut peer_state = room_state
+            .peers
+            .get_mut(&peer_id)
+            .expect("peer state not found");
+
+        peer_state.ssrc = track.ssrc();
+    }
+
     while let Ok((rtp_packet, _attr)) = track.read(&mut rtp_buf).await {
         let other_peer_tracks = room_state
             .peers
@@ -276,7 +287,10 @@ async fn forward_audio_track(
             .filter_map(|pair| {
                 let peer_state = pair.value();
                 if peer_state.peer_id != peer_id {
-                    Some(peer_state.outgoing_audio_track.clone())
+                    let mut rtp_packet = rtp_packet.clone();
+                    rtp_packet.header.ssrc = peer_state.ssrc;
+
+                    Some((peer_state.outgoing_audio_track.clone(), rtp_packet))
                 } else {
                     None
                 }
@@ -290,7 +304,7 @@ async fn forward_audio_track(
 
         let write_futures = other_peer_tracks
             .iter()
-            .map(|outgoing_track| outgoing_track.write_rtp(&rtp_packet));
+            .map(|(outgoing_track, packet)| outgoing_track.write_rtp(&packet));
 
         let results = futures::future::join_all(write_futures).await;
         for result in results {