diff --git a/backend.txt b/backend.txt deleted file mode 100644 index 8742ed2..0000000 --- a/backend.txt +++ /dev/null @@ -1,5062 +0,0 @@ - -// file: migrations/20250510180101_uuidv7.sql -CREATE EXTENSION IF NOT EXISTS pg_uuidv7; -// file: migrations/20250510182916_file.sql -CREATE TABLE IF NOT EXISTS "file" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "filename" VARCHAR NOT NULL, - "content_type" VARCHAR NOT NULL, - "size" INT8 NOT NULL -); - - -// file: migrations/20250510183102_user.sql -CREATE TABLE IF NOT EXISTS "user" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "avatar_id" UUID REFERENCES "file" ("id") ON DELETE SET NULL, - "username" VARCHAR NOT NULL UNIQUE, - "display_name" VARCHAR, - "email" VARCHAR NOT NULL, - "password_hash" VARCHAR NOT NULL, - "bot" BOOLEAN NOT NULL DEFAULT FALSE, - "system" BOOLEAN NOT NULL DEFAULT FALSE, - "settings" JSONB NOT NULL DEFAULT '{}'::JSONB -); - -CREATE TABLE IF NOT EXISTS "user_relation" -( - "user_id" UUID NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE, - "other_id" UUID NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE, - "type" INT2 NOT NULL, - "created_at" TIMESTAMPTZ NOT NULL DEFAULT now(), - "updated_at" TIMESTAMPTZ NOT NULL DEFAULT now(), - PRIMARY KEY ("user_id", "other_id") -); - --- create system account -INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system") -VALUES ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE); - -CREATE OR REPLACE FUNCTION check_avatar_is_image() - RETURNS TRIGGER AS -$$ -DECLARE - file_content_type VARCHAR; -BEGIN - -- Skip check if icon_id is null - IF NEW.avatar_id IS NULL THEN - RETURN NEW; - END IF; - - -- Retrieve content_type from file table - SELECT content_type - INTO file_content_type - FROM file - WHERE id = NEW.avatar_id; - - -- Raise exception if content_type does not start with 'image/' - IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN - RAISE EXCEPTION 'avatar_id must reference a file with content_type starting with image/'; - END IF; - - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - -CREATE TRIGGER trigger_check_icon_is_image - BEFORE INSERT OR UPDATE - ON "user" - FOR EACH ROW -EXECUTE FUNCTION check_avatar_is_image(); - -CREATE OR REPLACE FUNCTION fn_on_user_relation_update() - RETURNS TRIGGER - LANGUAGE plpgsql -AS -$$ -BEGIN - IF (TG_OP = 'UPDATE') THEN - NEW.updated_at := now(); - END IF; - - RETURN NEW; -END; -$$; - -CREATE TRIGGER trg_user_relation_update - BEFORE UPDATE - ON "user_relation" - FOR EACH ROW -EXECUTE FUNCTION fn_on_user_relation_update(); -// file: migrations/20250510183125_server.sql -CREATE TABLE IF NOT EXISTS "server" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "owner_id" UUID NOT NULL REFERENCES "user" ("id"), - "name" VARCHAR NOT NULL, - "icon_id" UUID REFERENCES "file" ("id") ON DELETE SET NULL -); - -CREATE TABLE IF NOT EXISTS "server_role" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE, - "name" VARCHAR NOT NULL, - "color" VARCHAR, - "display" BOOLEAN NOT NULL DEFAULT FALSE, - "permissions" JSONB NOT NULL DEFAULT '{}'::JSONB, - "position" SMALLINT NOT NULL -); - -CREATE TABLE IF NOT EXISTS "server_member" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE, - "user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE, - "nickname" VARCHAR, - "avatar_url" VARCHAR, - UNIQUE ("server_id", "user_id") -); - -CREATE TABLE IF NOT EXISTS "server_member_role" -( - "member_id" UUID NOT NULL REFERENCES "server_member" ("id") ON DELETE CASCADE, - "role_id" UUID NOT NULL REFERENCES "server_role" ("id") ON DELETE CASCADE, - PRIMARY KEY ("member_id", "role_id") -); - -CREATE TABLE IF NOT EXISTS "server_invite" -( - "code" VARCHAR NOT NULL PRIMARY KEY, - "server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE, - "inviter_id" UUID REFERENCES "user" ("id") ON DELETE SET NULL, - "expires_at" TIMESTAMPTZ -); - -CREATE OR REPLACE FUNCTION check_icon_is_image() - RETURNS TRIGGER AS -$$ -DECLARE - file_content_type VARCHAR; -BEGIN - -- Skip check if icon_id is null - IF NEW.icon_id IS NULL THEN - RETURN NEW; - END IF; - - -- Retrieve content_type from file table - SELECT content_type - INTO file_content_type - FROM file - WHERE id = NEW.icon_id; - - -- Raise exception if content_type does not start with 'image/' - IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN - RAISE EXCEPTION 'icon_id must reference a file with content_type starting with image/'; - END IF; - - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - -CREATE TRIGGER trigger_check_icon_is_image - BEFORE INSERT OR UPDATE - ON "server" - FOR EACH ROW -EXECUTE FUNCTION check_icon_is_image(); - -CREATE OR REPLACE FUNCTION check_server_user_role_server_id() - RETURNS TRIGGER AS -$$ -DECLARE - member_server_id UUID; - role_server_id UUID; -BEGIN - -- Get server_id from server_user - SELECT server_id - INTO member_server_id - FROM server_member - WHERE id = NEW.member_id; - - -- Get server_id from server_role - SELECT server_id - INTO role_server_id - FROM server_role - WHERE id = NEW.role_id; - - -- Check if server_ids match - IF member_server_id != role_server_id THEN - RAISE EXCEPTION 'Cannot assign role from a different server: server_user server_id (%) does not match server_role server_id (%)', member_server_id, role_server_id; - END IF; - - RETURN NEW; -END; -$$ LANGUAGE plpgsql; - -CREATE TRIGGER enforce_server_user_role_server_id - BEFORE INSERT OR UPDATE - ON server_member_role - FOR EACH ROW -EXECUTE FUNCTION check_server_user_role_server_id(); -// file: migrations/20250510184011_channel_message.sql -CREATE TABLE IF NOT EXISTS "channel" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "name" VARCHAR NOT NULL, - "type" INT2 NOT NULL, - "position" INT2 NOT NULL, - "server_id" UUID REFERENCES "server" ("id") ON DELETE CASCADE, -- only for server channels - "parent" UUID REFERENCES "channel" ("id") ON DELETE SET NULL, -- only for server channels - "owner_id" UUID REFERENCES "user" ("id") ON DELETE SET NULL -- only for group channels -); - --- only for dm or group channels -CREATE TABLE IF NOT EXISTS "channel_recipient" -( - "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE, - "user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE, - PRIMARY KEY ("channel_id", "user_id") -); - -CREATE TABLE IF NOT EXISTS "message" -( - "id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(), - "author_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE, - "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE, - "content" TEXT NOT NULL -); - -CREATE TABLE IF NOT EXISTS "message_attachment" -( - "message_id" UUID NOT NULL REFERENCES "message" ON DELETE CASCADE, - "file_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE, - "order" INT2 NOT NULL, - PRIMARY KEY ("message_id", "file_id") -); - -ALTER TABLE "channel" - ADD COLUMN "last_message_id" UUID REFERENCES "message" ("id") ON DELETE SET NULL; - - -CREATE OR REPLACE FUNCTION create_dm_channel( - user1_id UUID, - user2_id UUID, - channel_type INT2 -) - RETURNS UUID - LANGUAGE plpgsql -AS -$$ -DECLARE - new_channel_id UUID; - channel_name VARCHAR; -BEGIN - -- Validate parameters - IF user1_id IS NULL OR user2_id IS NULL THEN - RAISE EXCEPTION 'Both user IDs must be provided'; - END IF; - - IF user1_id = user2_id THEN - RAISE EXCEPTION 'Cannot create a DM channel with the same user'; - END IF; - --- Check if users exist - IF NOT exists (SELECT 1 FROM "user" WHERE id = user1_id) OR - NOT exists (SELECT 1 FROM "user" WHERE id = user2_id) THEN - RAISE EXCEPTION 'One or both users do not exist'; - END IF; - --- Check if DM already exists between these users - IF exists (SELECT 1 - FROM channel c - JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id - JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id - WHERE c.type = channel_type - AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2) THEN - -- Find and return the existing channel ID - SELECT c.id - INTO new_channel_id - FROM channel c - JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id - JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id - WHERE c.type = channel_type - AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2 - LIMIT 1; - - RAISE NOTICE 'DM channel already exists between these users with ID: %', new_channel_id; - RETURN new_channel_id; - END IF; - --- Generate channel name (conventionally uses user IDs in DMs) - channel_name := concat(user1_id, '-', user2_id); - - -- Create new channel - INSERT INTO "channel" ("name", "type", "position") - VALUES (channel_name, channel_type, 0) - RETURNING id INTO new_channel_id; - --- Add both users as recipients - INSERT INTO "channel_recipient" ("channel_id", "user_id") - VALUES (new_channel_id, user1_id), - (new_channel_id, user2_id); - - RAISE NOTICE 'DM channel created with ID: %', new_channel_id; - RETURN new_channel_id; -END; -$$; - -CREATE OR REPLACE FUNCTION create_group_channel( - p_creator_id UUID, -- The user initiating the group creation (will be owner) - p_recipient_user_ids UUID[], -- Array of all user IDs for the group (must include creator) - p_group_channel_type INT2 -- The channel type identifier for groups (e.g., 1) -) - RETURNS UUID - LANGUAGE plpgsql -AS -$$ -DECLARE - new_channel_id UUID; - existing_channel_id UUID; - final_channel_name VARCHAR; - unique_sorted_recipient_ids UUID[]; - num_recipients INT; - uid UUID; - -- Threshold for detailed name vs. summary name - MAX_MEMBERS_FOR_DETAILED_NAME CONSTANT INT := 3; -BEGIN - -- Validate and process recipient IDs - IF p_recipient_user_ids IS NULL OR array_length(p_recipient_user_ids, 1) IS NULL THEN - RAISE EXCEPTION 'Recipient user IDs array must be provided and not empty.'; - END IF; - - -- Get unique, sorted recipient IDs for consistent checking and to avoid duplicates. - SELECT array_agg(DISTINCT u ORDER BY u) INTO unique_sorted_recipient_ids FROM unnest(p_recipient_user_ids) u; - num_recipients := array_length(unique_sorted_recipient_ids, 1); - - -- Validate minimum number of recipients for a group - IF num_recipients < 1 THEN -- Groups typically have at least 2 members - RAISE EXCEPTION 'Group channels (type %) must have at least 2 recipients. Found %.', p_group_channel_type, num_recipients; - END IF; - - -- Create new group channel - INSERT INTO "channel" ("name", "type", "position", "owner_id", "server_id", "parent") - VALUES ('Group', - p_group_channel_type, - 0, -- Default position - p_creator_id, - NULL, -- Not a server channel - NULL -- Not a nested server channel - ) - RETURNING id INTO new_channel_id; - - -- Add all recipients to the channel_recipient table - INSERT INTO "channel_recipient" ("channel_id", "user_id") - VALUES (new_channel_id, p_creator_id); - - INSERT INTO "channel_recipient" ("channel_id", "user_id") - SELECT new_channel_id, r_id - FROM unnest(unique_sorted_recipient_ids) AS r_id; - - RAISE NOTICE 'Group channel (type %) named "%" created with ID: % by owner % for recipients: %', - p_group_channel_type, final_channel_name, new_channel_id, p_creator_id, unique_sorted_recipient_ids; - RETURN new_channel_id; -END; -$$; - -// file: migrations/20250517190855_util.sql -CREATE OR REPLACE FUNCTION get_users_that_can_see_user(target_user_id UUID) - RETURNS TABLE (user_id UUID) AS $$ -BEGIN - RETURN QUERY - -- Users directly related to the target user - SELECT ur.user_id - FROM user_relation ur - WHERE ur.other_id = target_user_id - - UNION - - -- Users where target user is related to them - SELECT ur.other_id AS user_id - FROM user_relation ur - WHERE ur.user_id = target_user_id - - UNION - - -- Users who share a server with the target user - SELECT sm.user_id - FROM server_member sm - JOIN server_member sm2 ON sm.server_id = sm2.server_id - WHERE sm2.user_id = target_user_id - AND sm.user_id != target_user_id - - UNION - - -- Users who share a channel with the target user (DM or group) - SELECT cr.user_id - FROM channel_recipient cr - JOIN channel_recipient cr2 ON cr.channel_id = cr2.channel_id - WHERE cr2.user_id = target_user_id - AND cr.user_id != target_user_id; -END; -$$ LANGUAGE plpgsql; -// file: src/config.rs -use std::sync::OnceLock; - -use serde::Deserialize; - -pub fn config() -> &'static Config { - static INSTANCE: OnceLock = OnceLock::new(); - - INSTANCE.get_or_init(|| { - config::Config::builder() - .add_source(config::File::with_name("config")) - .add_source(config::Environment::with_prefix("DIPLOM_")) - .build() - .expect("config builder") - .try_deserialize() - .expect("config deserialize") - }) -} - -#[derive(Deserialize)] -pub struct Config { - pub server: ServerConfig, - pub security: SecurityConfig, - pub gateway: GatewayConfig, - pub database: DatabaseConfig, - pub object_store: ObjectStoreConfig, -} - -#[derive(Deserialize)] -pub struct ServerConfig { - pub hostname: url::Url, - pub host: std::net::Ipv4Addr, - pub port: u16, -} - -#[derive(Deserialize)] -pub struct SecurityConfig { - pub auth_secret: String, - pub voice_secret: String, -} - -#[derive(Deserialize)] -pub struct GatewayConfig { - #[serde(deserialize_with = "crate::util::deserialize_duration_seconds")] - pub voice_token_lifetime: std::time::Duration, -} - -#[derive(Debug, Deserialize)] -#[serde(untagged)] -#[serde(deny_unknown_fields)] -pub enum DatabaseConfig { - Url { - url: url::Url, - }, - Full { - driver: String, - username: String, - password: String, - host: url::Host, - port: u16, - name: String, - }, -} - -#[derive(Debug, Deserialize)] -#[serde(deny_unknown_fields)] -pub struct ObjectStoreConfig { - pub endpoint: url::Url, - pub region: String, - pub bucket: String, - pub access_key: String, - pub secret_key: String, -} - -impl DatabaseConfig { - pub fn url(&self) -> Option { - match self { - Self::Url { url } => Some(url.clone()), - Self::Full { - driver, - username, - password, - host, - port, - name, - } => { - let str = format!( - "{}://{}:{}@{}:{}/{}", - driver, username, password, host, port, name - ); - let url = url::Url::parse(&str).ok()?; - Some(url) - }, - } - } -} - -// file: src/database.rs -use crate::config::DatabaseConfig; -use crate::entity; -use crate::util::tristate::TriState; - -#[derive(Clone, derive_more::AsRef)] -pub struct Database { - pool: sqlx::PgPool, -} - -pub type Result = std::result::Result; - -#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)] -pub enum Error { - #[from] - Migrate(sqlx::migrate::MigrateError), - - #[from] - Sqlx(sqlx::Error), - - UserDoesNotExists, - UserAlreadyExists, - - ServerDoesNotExists, - - MemberAlreadyExists, - - ChannelDoesNotExists, - - InviteDoesNotExists, - - MessageDoesNotExists, - - FileDoesNotExists, -} - -impl Database { - pub async fn connect(config: &DatabaseConfig) -> sqlx::Result { - static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!(); - - let pool = sqlx::postgres::PgPoolOptions::new() - .connect(config.url().ok_or(sqlx::Error::BeginFailed)?.as_str()) - .await?; - - MIGRATOR.run(&pool).await?; - - Ok(Self { pool }) - } -} - -impl Database { - pub async fn insert_user( - &self, - username: &str, - display_name: Option<&str>, - email: &str, - password_hash: &str, - ) -> Result { - let user = match sqlx::query_as!( - entity::user::User, - r#"INSERT INTO "user"("username", "display_name", "email", "password_hash") VALUES ($1, $2, $3, $4) RETURNING "user".*"#, - username, - display_name, - email, - password_hash - ) - .fetch_one(&self.pool) - .await { - Ok(user) => user, - Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => { - return Err(Error::UserAlreadyExists); - } - Err(e) => return Err(e.into()), - }; - - Ok(user) - } - - pub async fn select_user_by_id(&self, user_id: entity::user::Id) -> Result { - let user = sqlx::query_as!( - entity::user::User, - r#"SELECT * FROM "user" WHERE "id" = $1"#, - user_id - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::UserDoesNotExists)?; - - Ok(user) - } - - pub async fn update_user_by_id( - &self, - user_id: entity::user::Id, - display_name: TriState<&str>, - avatar_id: TriState, - ) -> Result { - let mut query_builder = sqlx::QueryBuilder::new(r#"UPDATE "user" SET "#); - - let mut separated = query_builder.separated(", "); - - if !display_name.is_absent() { - separated.push(r#""display_name" = "#); - separated.push_bind_unseparated(display_name.into_option()); - } - - if !avatar_id.is_absent() { - separated.push(r#""avatar_id" = "#); - separated.push_bind_unseparated(avatar_id.into_option()); - } - - query_builder.push(" WHERE \"id\" = "); - query_builder.push_bind(user_id); - query_builder.push(" RETURNING \"user\".*"); - - let user = query_builder.build_query_as().fetch_one(&self.pool).await?; - - Ok(user) - } - - pub async fn select_users_by_ids( - &self, - user_ids: &[entity::user::Id], - ) -> Result> { - let mut query_builder = sqlx::QueryBuilder::new(r#"SELECT * FROM "user" WHERE "id" IN ("#); - - let mut separated = query_builder.separated(", "); - for id in user_ids.iter() { - separated.push_bind(id); - } - - query_builder.push(")"); - - let users = query_builder.build_query_as().fetch_all(&self.pool).await?; - - Ok(users) - } - - pub async fn select_user_by_username(&self, username: &str) -> Result { - let user = sqlx::query_as!( - entity::user::User, - r#"SELECT * FROM "user" WHERE "username" = $1"#, - username - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::UserDoesNotExists)?; - - Ok(user) - } - - pub async fn select_server_by_id( - &self, - server_id: entity::server::Id, - ) -> Result { - let server = sqlx::query_as!( - entity::server::Server, - r#"SELECT * FROM "server" WHERE "id" = $1"#, - server_id - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::ServerDoesNotExists)?; - - Ok(server) - } - - pub async fn select_user_servers( - &self, - user_id: entity::user::Id, - ) -> Result> { - let servers = sqlx::query_as!( - entity::server::Server, - r#"SELECT * FROM "server" WHERE "id" IN ( - SELECT "server_id" FROM "server_member" - WHERE "user_id" = $1 - )"#, - user_id - ) - .fetch_all(&self.pool) - .await?; - - Ok(servers) - } - - pub async fn select_server_members( - &self, - server_id: entity::server::Id, - ) -> Result> { - let users = sqlx::query_as!( - entity::user::User, - r#"SELECT * FROM "user" WHERE "id" IN ( - SELECT "user_id" FROM "server_member" WHERE "server_id" = $1 - )"#, - server_id - ) - .fetch_all(&self.pool) - .await?; - - Ok(users) - } - - pub async fn select_channel_members( - &self, - channel_id: entity::channel::Id, - ) -> Result> { - let users = sqlx::query_as!( - entity::user::User, - r#"SELECT * FROM "user" WHERE "id" IN ( - SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1 - UNION SELECT "user_id" FROM "server_member" WHERE "server_id" IN ( - SELECT "server_id" FROM "channel" WHERE "id" = $1 - ) - )"#, - channel_id - ) - .fetch_all(&self.pool) - .await?; - - Ok(users) - } - - pub async fn select_user_channels( - &self, - user_id: entity::user::Id, - ) -> Result> { - let channels = sqlx::query_as!( - entity::channel::Channel, - r#"SELECT * FROM "channel" WHERE "id" IN ( - SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1 - )"#, - user_id - ) - .fetch_all(&self.pool) - .await?; - - Ok(channels) - } - - pub async fn insert_server( - &self, - name: &str, - icon_id: Option, - owner_id: entity::user::Id, - ) -> Result { - let server = sqlx::query_as!( - entity::server::Server, - r#"INSERT INTO "server"("name", "icon_id", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#, - name, - icon_id, - owner_id - ) - .fetch_one(&self.pool) - .await?; - - Ok(server) - } - - pub async fn insert_server_role( - &self, - id: Option, - server_id: entity::server::Id, - name: &str, - color: Option<&str>, - display: bool, - permissions: serde_json::Value, - position: u16, - ) -> Result { - let role = sqlx::query_as!( - entity::server::role::ServerRole, - r#"INSERT INTO "server_role"("id", "server_id", "name", "color", "display", "permissions", "position") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING "server_role".*"#, - id, - server_id, - name, - color, - display, - permissions, - position as i16 - ) - .fetch_one(&self.pool) - .await?; - - Ok(role) - } - - pub async fn insert_server_member( - &self, - server_id: entity::server::Id, - user_id: entity::user::Id, - ) -> Result { - let member = match sqlx::query_as!( - entity::server::member::ServerMember, - r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#, - server_id, - user_id - ) - .fetch_one(&self.pool) - .await { - Ok(member) => member, - Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => { - return Err(Error::MemberAlreadyExists); - } - Err(e) => return Err(e.into()), - }; - - Ok(member) - } - - pub async fn insert_server_member_role( - &self, - server_member_id: entity::server::member::Id, - server_role_id: entity::server::role::Id, - ) -> Result<()> { - sqlx::query!( - r#"INSERT INTO "server_member_role"("member_id", "role_id") VALUES ($1, $2)"#, - server_member_id, - server_role_id - ) - .execute(&self.pool) - .await?; - - Ok(()) - } - - pub async fn select_channel_by_id( - &self, - channel_id: entity::channel::Id, - ) -> Result { - let channel = sqlx::query_as!( - entity::channel::Channel, - r#"SELECT * FROM "channel" WHERE "id" = $1"#, - channel_id - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::ChannelDoesNotExists)?; - - Ok(channel) - } - - pub async fn select_channel_recipients( - &self, - channel_id: entity::channel::Id, - ) -> Result>> { - let recipients = sqlx::query_as!( - entity::user::User, - r#"SELECT * FROM "user" WHERE "id" IN ( - SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1 - )"#, - channel_id - ) - .fetch_all(&self.pool) - .await?; - - if recipients.is_empty() { - return Ok(None); - } - - Ok(Some(recipients)) - } - - pub async fn select_server_channels( - &self, - server_id: entity::server::Id, - ) -> Result> { - let channels = sqlx::query_as!( - entity::channel::Channel, - r#"SELECT * FROM "channel" WHERE "server_id" = $1"#, - server_id - ) - .fetch_all(&self.pool) - .await?; - - Ok(channels) - } - - pub async fn delete_server_by_id( - &self, - server_id: entity::server::Id, - ) -> Result { - let server = sqlx::query_as!( - entity::server::Server, - r#"DELETE FROM "server" WHERE "id" = $1 RETURNING "server".*"#, - server_id - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::ServerDoesNotExists)?; - - Ok(server) - } - - pub async fn delete_channel_by_id( - &self, - channel_id: entity::channel::Id, - ) -> Result { - let channel = sqlx::query_as!( - entity::channel::Channel, - r#"DELETE FROM "channel" WHERE "id" = $1 RETURNING "channel".*"#, - channel_id - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::ChannelDoesNotExists)?; - - Ok(channel) - } - - pub async fn insert_server_channel( - &self, - name: &str, - position: u16, - r#type: entity::channel::ChannelType, - server_id: entity::server::Id, - parent: Option, - ) -> Result { - let channel = sqlx::query_as!( - entity::channel::Channel, - r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#, - name, - r#type as i16, - position as i16, - server_id, - parent - ) - .fetch_one(&self.pool) - .await?; - - Ok(channel) - } - - pub async fn insert_server_invite( - &self, - code: &str, - server_id: entity::server::Id, - inviter_id: Option, - expires_at: Option>, - ) -> Result { - let invite = sqlx::query_as!( - entity::server::invite::ServerInvite, - r#"INSERT INTO "server_invite"("code", "server_id", "inviter_id", "expires_at") VALUES ($1, $2, $3, $4) RETURNING "server_invite".*"#, - code, - server_id, - inviter_id, - expires_at - ) - .fetch_one(&self.pool) - .await?; - - Ok(invite) - } - - pub async fn select_server_invite_by_code( - &self, - code: &str, - ) -> Result { - let invite = sqlx::query_as!( - entity::server::invite::ServerInvite, - r#"SELECT * FROM "server_invite" WHERE "code" = $1"#, - code - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::InviteDoesNotExists)?; - - Ok(invite) - } - - pub async fn delete_server_invite_by_code( - &self, - code: &str, - ) -> Result> { - let invite = sqlx::query_as!( - entity::server::invite::ServerInvite, - r#"DELETE FROM "server_invite" WHERE "code" = $1 RETURNING "server_invite".*"#, - code - ) - .fetch_optional(&self.pool) - .await?; - - Ok(invite) - } - - pub async fn select_channel_messages_paginated( - &self, - channel_id: entity::channel::Id, - before: Option, - limit: i64, - ) -> Result> { - let messages = sqlx::query_as!( - entity::message::Message, - r#"SELECT * FROM "message" WHERE "channel_id" = $1 AND ($2::uuid IS NULL OR "id" < $2::uuid) ORDER BY "id" DESC LIMIT $3"#, - channel_id, - before, - limit - ) - .fetch_all(&self.pool) - .await?; - - Ok(messages) - } - - pub async fn insert_channel_message( - &self, - user_id: entity::user::Id, - channel_id: entity::channel::Id, - content: &str, - ) -> Result { - let message = sqlx::query_as!( - entity::message::Message, - r#"INSERT INTO "message"("channel_id", "author_id", "content") VALUES ($1, $2, $3) RETURNING "message".*"#, - channel_id, - user_id, - content - ) - .fetch_one(&self.pool) - .await?; - - Ok(message) - } - - pub async fn select_file_by_id(&self, file_id: entity::file::Id) -> Result { - let file = sqlx::query_as!( - entity::file::File, - r#"SELECT * FROM "file" WHERE "id" = $1"#, - file_id - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::FileDoesNotExists)?; - - Ok(file) - } - - pub async fn delete_file_by_id(&self, file_id: entity::file::Id) -> Result { - let file = sqlx::query_as!( - entity::file::File, - r#"DELETE FROM "file" WHERE "id" = $1 RETURNING "file".*"#, - file_id - ) - .fetch_optional(&self.pool) - .await? - .ok_or(Error::FileDoesNotExists)?; - - Ok(file) - } - - pub async fn insert_file( - &self, - filename: &str, - content_type: &str, - size: usize, - ) -> Result { - let file = sqlx::query_as!( - entity::file::File, - r#"INSERT INTO "file"("filename", "content_type", "size") VALUES ($1, $2, $3) RETURNING "file".*"#, - filename, - content_type, - size as i64 - ) - .fetch_one(&self.pool) - .await?; - - Ok(file) - } - - pub async fn insert_message_attachment( - &self, - message_id: entity::message::Id, - file_id: entity::file::Id, - ) -> Result<()> { - sqlx::query!( - r#"INSERT INTO "message_attachment"("message_id", "file_id", "order") VALUES ($1, $2, 0)"#, - message_id, - file_id - ) - .execute(&self.pool) - .await?; - - Ok(()) - } - - pub async fn select_message_attachments( - &self, - message_id: entity::message::Id, - ) -> Result> { - let attachments = sqlx::query_as!( - entity::file::File, - r#"SELECT * FROM "file" WHERE "id" IN ( - SELECT "file_id" FROM "message_attachment" WHERE "message_id" = $1 - )"#, - message_id - ) - .fetch_all(&self.pool) - .await?; - - Ok(attachments) - } - - pub async fn select_related_user_ids( - &self, - user_id: entity::user::Id, - ) -> Result> { - #[derive(sqlx::FromRow)] - struct UserId { - user_id: entity::user::Id, - } - - let user_ids = - sqlx::query_as::<_, UserId>(r#"SELECT * FROM get_users_that_can_see_user($1)"#) - .bind(user_id) - .fetch_all(&self.pool) - .await? - .into_iter() - .map(|row| row.user_id) - .collect(); - - Ok(user_ids) - } - - pub async fn procedure_create_dm_channel( - &self, - user1_id: entity::user::Id, - user2_id: entity::user::Id, - ) -> Result { - let channel_id = sqlx::query_scalar!( - r#"SELECT create_dm_channel($1, $2, $3)"#, - user1_id, - user2_id, - entity::channel::ChannelType::DirectMessage as i16 - ) - .fetch_one(&self.pool) - .await? - .expect("channel_id is null"); - - Ok(channel_id) - } - - pub async fn procedure_create_group_channel( - &self, - creator_id: entity::user::Id, - users: &[entity::user::Id], - ) -> Result { - let channel_id = sqlx::query_scalar!( - r#"SELECT create_group_channel($1, $2, $3)"#, - creator_id, - users, - entity::channel::ChannelType::DirectMessage as i16 - ) - .fetch_one(&self.pool) - .await? - .expect("channel_id is null"); - - Ok(channel_id) - } -} - -// file: src/entity/channel.rs -use serde::{Deserialize, Serialize}; - -use crate::entity::{message, server, user}; - -pub type Id = uuid::Uuid; - -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct Channel { - pub id: Id, - pub name: String, - pub r#type: ChannelType, - pub position: i16, - #[serde(skip_serializing_if = "Option::is_none")] - pub server_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub parent: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub owner_id: Option, - #[serde(skip_serializing_if = "Option::is_none")] - pub last_message_id: Option, -} - -#[derive(Debug, Clone, sqlx::Type, Serialize, Deserialize)] -#[non_exhaustive] -#[serde(rename_all = "snake_case")] -#[repr(i16)] -pub enum ChannelType { - ServerText = 1, - ServerVoice = 2, - ServerCategory = 3, - - DirectMessage = 4, - Group = 5, -} - -impl From for ChannelType { - fn from(value: i16) -> Self { - match value { - 1 => ChannelType::ServerText, - 2 => ChannelType::ServerVoice, - 3 => ChannelType::ServerCategory, - 4 => ChannelType::DirectMessage, - 5 => ChannelType::Group, - _ => ChannelType::ServerText, - } - } -} - -// file: src/entity/file.rs -use serde::Serialize; - -pub type Id = uuid::Uuid; - -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -pub struct File { - pub id: Id, - pub filename: String, - pub content_type: String, - pub size: i64, -} - -// file: src/entity/message.rs -use crate::entity::{channel, user}; - -pub type Id = uuid::Uuid; - -#[derive(Debug, Clone, sqlx::FromRow)] -pub struct Message { - pub id: Id, - pub author_id: user::Id, - pub channel_id: channel::Id, - pub content: String, -} - -// file: src/entity/mod.rs -pub mod file; -pub mod channel; -pub mod message; -pub mod server; -pub mod user; - -// file: src/entity/server/invite.rs -use chrono::{DateTime, Utc}; -use serde::Serialize; - -use crate::entity::{server, user}; - -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct ServerInvite { - pub code: String, - pub server_id: server::Id, - pub inviter_id: Option, - pub expires_at: Option>, -} - -// file: src/entity/server/member.rs -use serde::Serialize; - -use crate::entity::{server, user}; - -pub type Id = uuid::Uuid; - -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct ServerMember { - pub id: Id, - pub server_id: server::Id, - pub user_id: user::Id, - pub nickname: Option, - pub avatar_url: Option, -} - -// file: src/entity/server/role.rs -use serde::Serialize; - -use crate::entity::server; - -pub type Id = uuid::Uuid; - -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct ServerRole { - pub id: Id, - pub server_id: server::Id, - pub name: String, - pub color: Option, - pub display: bool, - pub permissions: serde_json::Value, - pub position: i16, -} - -// file: src/entity/server.rs -pub mod invite; -pub mod member; -pub mod role; - -use crate::entity::{file, user}; - -pub type Id = uuid::Uuid; - -#[derive(Debug, Clone, sqlx::FromRow)] -pub struct Server { - pub id: Id, - pub owner_id: user::Id, - pub name: String, - pub icon_id: Option, -} - -// file: src/entity/user.rs -use std::sync::LazyLock; - -use regex::Regex; -use crate::entity::file; - -pub static USERNAME_REGEX: LazyLock = - LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap()); - -pub type Id = uuid::Uuid; - -#[derive(Debug, Clone, sqlx::FromRow)] -pub struct User { - pub id: Id, - pub avatar_id: Option, - pub username: String, - pub display_name: Option, - pub email: String, - pub password_hash: String, - pub bot: bool, - pub system: bool, - pub settings: serde_json::Value, -} - -// file: src/jwt.rs -#![allow(unused)] - -use std::collections::HashSet; - -use chrono::{Duration, Local, Utc}; -use serde::de::DeserializeOwned; -use serde::{Deserialize, Serialize}; - -use crate::config; - -#[derive(Debug, Serialize, Deserialize)] -pub struct Claims { - #[serde(flatten)] - pub data: T, - pub iat: i64, -} - -pub fn generate_jwt(data: T, secret: &[u8]) -> Result { - let claims = Claims { - data, - iat: Utc::now().timestamp_millis(), - }; - - let token = jsonwebtoken::encode( - &jsonwebtoken::Header::default(), - &claims, - &jsonwebtoken::EncodingKey::from_secret(secret), - ) - .map_err(|_| Error::CouldNotEncodeToken)?; - - Ok(token) -} - -pub fn verify_jwt(token: &str, secret: &[u8]) -> Result { - tracing::debug!("verifying token: {}", token); - - let mut validation = jsonwebtoken::Validation::default(); - validation.set_required_spec_claims::(&[]); - - let token_data = jsonwebtoken::decode::>( - token, - &jsonwebtoken::DecodingKey::from_secret(secret), - &validation, - ) - .inspect_err(|err| { - tracing::error!("Failed to decode JWT: {:?}", err); - }) - .map_err(|_| Error::CouldNotVerifyToken)?; - - Ok(token_data.claims.data) -} - -pub type Result = std::result::Result; - -#[derive(Debug, derive_more::Error, derive_more::Display)] -pub enum Error { - CouldNotEncodeToken, - CouldNotVerifyToken, -} - -// file: src/log.rs -use std::io; - -use tracing::level_filters::LevelFilter; -use tracing_appender::non_blocking::WorkerGuard; -use tracing_subscriber::EnvFilter; -use tracing_subscriber::fmt::time::ChronoLocal; -use tracing_subscriber::layer::SubscriberExt; -use tracing_subscriber::util::SubscriberInitExt; - -pub fn init_logging() -> io::Result<(WorkerGuard, WorkerGuard)> { - let logger_timer = ChronoLocal::new("%FT%H:%M:%S%.6f%:::z".to_string()); - - let file = tracing_appender::rolling::daily("./logs", "log"); - - let (file, file_guard) = tracing_appender::non_blocking(file); - - let file_logger = tracing_subscriber::fmt::layer() - .with_ansi(false) - .with_timer(logger_timer.clone()) - .with_writer(file); - - let (stdout, stdout_guard) = tracing_appender::non_blocking(io::stdout()); - - let stdout_logger = tracing_subscriber::fmt::layer() - .with_timer(logger_timer) - .with_writer(stdout); - #[cfg(debug_assertions)] - let stdout_logger = stdout_logger.pretty(); - - tracing_subscriber::registry() - .with( - EnvFilter::builder() - .with_default_directive(LevelFilter::INFO.into()) - .from_env_lossy(), - ) - .with(file_logger) - .with(stdout_logger) - .init(); - - tracing::info!("initialized logging"); - - Ok((file_guard, stdout_guard)) -} - -// file: src/main.rs -use std::sync::Arc; - -use argon2::Argon2; - -use crate::database::Database; -use crate::state::AppState; - -mod config; -mod database; -mod entity; -mod jwt; -mod log; -mod object_store; -mod state; -mod util; -mod web; -mod webrtc; - -#[tokio::main] -async fn main() -> anyhow::Result<()> { - let _guard = log::init_logging()?; - - let database = Database::connect(&config::config().database).await?; - let object_store = object_store::ObjectStore::connect(&config::config().object_store).await?; - let state = AppState { - database, - object_store, - hasher: Arc::new(Argon2::default()), - gateway_state: Default::default(), - voice_rooms: Default::default(), - }; - - web::run(state).await?; - - Ok(()) -} - -// file: src/object_store.rs -use crate::config::ObjectStoreConfig; - -#[derive(Clone, derive_more::AsRef, derive_more::Deref)] -pub struct ObjectStore { - inner: Box, -} - -pub type Result = std::result::Result; - -#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)] -pub enum Error { - #[from] - Credentials(s3::creds::error::CredentialsError), - - #[from] - S3(s3::error::S3Error), -} - -impl ObjectStore { - pub async fn connect(config: &ObjectStoreConfig) -> Result { - let region = s3::region::Region::Custom { - region: config.region.clone(), - endpoint: config.endpoint.origin().ascii_serialization(), - }; - - let credentials = s3::creds::Credentials::new( - Some(&config.access_key), - Some(&config.secret_key), - None, - None, - None, - )?; - - let mut bucket = - s3::bucket::Bucket::new(&config.bucket, region.clone(), credentials.clone())? - .with_path_style(); - - if !bucket.exists().await? { - bucket = s3::bucket::Bucket::create_with_path_style( - &config.bucket, - region, - credentials, - s3::BucketConfiguration::default(), - ) - .await? - .bucket; - } - - Ok(Self { inner: bucket }) - } -} - -// file: src/state.rs -use std::collections::HashMap; -use std::sync::Arc; - -use argon2::Argon2; -use tokio::sync::{RwLock, mpsc}; -use uuid::Uuid; - -use crate::database::Database; -use crate::object_store::ObjectStore; -use crate::web::ws::gateway::{GatewayWsState, SessionKey, event}; -use crate::webrtc::WebRtcSignal; - -#[derive(Clone)] -pub struct AppState { - pub database: Database, - pub object_store: ObjectStore, - pub hasher: Arc>, - - pub gateway_state: Arc, - - pub voice_rooms: Arc>>>, -} - -#[derive(Debug, Default)] -pub struct GatewayState { - pub connected: scc::HashMap, -} - -impl AppState { - pub async fn register_gateway_connected_user( - &self, - user_id: Uuid, - session_key: SessionKey, - event_sender: mpsc::UnboundedSender, - ) { - self.gateway_state - .connected - .entry_async(user_id) - .await - .or_default() - .get_mut() - .instances - .insert(session_key, event_sender); - } - - pub async fn unregister_gateway_connected_user(&self, user_id: Uuid, session_id: &SessionKey) { - let is_empty = { - let mut entry = self - .gateway_state - .connected - .entry_async(user_id) - .await - .or_default(); - - let entry = entry.get_mut(); - - entry.instances.remove(session_id); - entry.instances.is_empty() - }; - - if is_empty { - self.gateway_state.connected.remove_async(&user_id).await; - } - } - - pub async fn register_voice_room( - &self, - room_id: Uuid, - sender: mpsc::UnboundedSender, - ) { - self.voice_rooms.write().await.insert(room_id, sender); - } - - pub async fn unregister_voice_room(&self, room_id: Uuid) { - self.voice_rooms.write().await.remove(&room_id); - } -} - -// file: src/util/mod.rs -pub mod tristate; - -use axum::extract::multipart::Field; -use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError}; -use serde::{Deserialize, Serialize}; - -use crate::entity; - -pub fn file_id_to_url(file_id: &entity::file::Id) -> Option { - Some( - crate::config::config() - .server - .hostname - .join("files/") - .ok()? - .join(&file_id.to_string()) - .ok()? - .to_string(), - ) -} - -#[derive(Debug, derive_more::Deref)] -pub struct SerdeFieldData(pub FieldData); - -#[async_trait::async_trait] -impl TryFromField for SerdeFieldData { - async fn try_from_field( - field: Field<'_>, - limit_bytes: Option, - ) -> Result { - let field = FieldData::try_from_field(field, limit_bytes).await?; - - Ok(Self(field)) - } -} - -impl Serialize for SerdeFieldData { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - #[derive(Serialize)] - #[serde(rename_all = "camelCase")] - struct Metadata { - name: Option, - file_name: Option, - content_type: Option, - } - - let metadata = Metadata { - name: self.0.metadata.name.clone(), - file_name: self.0.metadata.file_name.clone(), - content_type: self.0.metadata.content_type.clone(), - }; - - metadata.serialize(serializer) - } -} - -pub fn serialize_duration_seconds( - duration: &std::time::Duration, - serializer: S, -) -> Result -where - S: serde::Serializer, -{ - let seconds = duration.as_secs(); - seconds.serialize(serializer) -} - -pub fn deserialize_duration_seconds<'de, D>( - deserializer: D, -) -> Result -where - D: serde::Deserializer<'de>, -{ - let seconds = u64::deserialize(deserializer)?; - Ok(std::time::Duration::from_secs(seconds)) -} - -pub fn serialize_duration_seconds_option( - duration: &Option, - serializer: S, -) -> Result -where - S: serde::Serializer, -{ - match duration { - Some(duration) => serialize_duration_seconds(duration, serializer), - None => serializer.serialize_none(), - } -} - -pub fn deserialize_duration_seconds_option<'de, D>( - deserializer: D, -) -> Result, D::Error> -where - D: serde::Deserializer<'de>, -{ - Ok(deserialize_duration_seconds(deserializer).ok()) -} - -// file: src/util/tristate.rs -use std::ops::Deref; - -use serde::{Deserialize, Deserializer, Serialize}; -use validator::ValidateLength; - -#[derive(Debug, Clone)] -pub enum TriState { - Value(T), - None, - Absent, -} - -impl TriState { - pub fn as_deref(&self) -> TriState<&D> - where - T: Deref, - { - match self { - TriState::Value(v) => TriState::Value(v.deref()), - TriState::None => TriState::None, - TriState::Absent => TriState::Absent, - } - } - - pub fn is_value(&self) -> bool { - matches!(self, TriState::Value(_)) - } - - pub fn is_none(&self) -> bool { - matches!(self, TriState::None) - } - - pub fn is_absent(&self) -> bool { - matches!(self, TriState::Absent) - } - - pub fn as_option(&self) -> Option<&T> { - match self { - TriState::Value(v) => Some(v), - TriState::None => None, - TriState::Absent => None, - } - } - - pub fn into_option(self) -> Option { - match self { - TriState::Value(v) => Some(v), - TriState::None => None, - TriState::Absent => None, - } - } -} - -impl Default for TriState { - fn default() -> Self { - TriState::Absent - } -} - -pub(crate) struct TriStateFieldVisitor { - marker: std::marker::PhantomData, -} - -impl Serialize for TriState -where - T: Serialize, -{ - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - match self { - TriState::Value(v) => v.serialize(serializer), - TriState::None => serializer.serialize_none(), - TriState::Absent => serializer.serialize_unit(), - } - } -} - -impl<'de, T> Deserialize<'de> for TriState -where - T: Deserialize<'de>, -{ - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - deserializer.deserialize_option(TriStateFieldVisitor:: { - marker: std::marker::PhantomData, - }) - } -} -impl<'de, T> serde::de::Visitor<'de> for TriStateFieldVisitor -where - T: Deserialize<'de>, -{ - type Value = TriState; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("TriStateField") - } - - #[inline] - fn visit_none(self) -> Result, E> - where - E: serde::de::Error, - { - Ok(TriState::None) - } - - #[inline] - fn visit_some(self, deserializer: D) -> Result - where - D: Deserializer<'de>, - { - T::deserialize(deserializer).map(TriState::Value) - } - - #[inline] - fn visit_unit(self) -> Result, E> - where - E: serde::de::Error, - { - Ok(TriState::None) - } -} - -impl ValidateLength for TriState -where - T: ValidateLength, -{ - fn length(&self) -> Option { - match self { - TriState::Value(v) => v.length(), - TriState::None => None, - TriState::Absent => None, - } - } -} - -// file: src/web/context.rs -use axum::extract::FromRequestParts; -use axum::http::request::Parts; - -use crate::entity; - -#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)] -pub struct UserContext { - pub user_id: entity::user::Id, -} - -pub type UserContextResult = Result; - -#[derive(Debug, Clone, Copy, derive_more::Display)] -pub enum Error { - NotInRequest, - NotInHeader, - BadCharacters, - WrongTokenType, - BadToken, - Model, -} - -impl FromRequestParts for UserContext { - type Rejection = super::error::Error; - - async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result { - let context = parts - .extensions - .get::() - .cloned() - .ok_or(Error::NotInRequest)??; - - Ok(context) - } -} - -// file: src/web/entity/file.rs -use serde::Serialize; - -use crate::entity::file::Id; -use crate::util; - -#[derive(Debug, Clone, sqlx::FromRow, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct File { - pub id: Id, - pub filename: String, - pub content_type: String, - pub size: i64, - pub url: String, -} - -impl From for File { - fn from(file: crate::entity::file::File) -> Self { - Self { - id: file.id, - filename: file.filename, - content_type: file.content_type, - size: file.size, - url: util::file_id_to_url(&file.id).unwrap_or_default(), - } - } -} - -// file: src/web/entity/message.rs -use serde::Serialize; - -use crate::entity::message::Id; -use crate::entity::{channel, user}; - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct Message { - pub id: Id, - pub author_id: user::Id, - pub channel_id: channel::Id, - pub content: String, - pub created_at: chrono::DateTime, - pub attachments: Vec, -} - -impl Message { - pub fn from_message_with_attachments( - message: crate::entity::message::Message, - attachments: Vec, - ) -> Self { - Self { - id: message.id, - author_id: message.author_id, - channel_id: message.channel_id, - content: message.content, - created_at: message - .id - .get_timestamp() - .as_ref() - .map(uuid::Timestamp::to_unix) - .map(|(secs, nsecs)| { - chrono::DateTime::::from_timestamp(secs as i64, nsecs) - }) - .flatten() - .unwrap_or_default(), - attachments, - } - } -} -// file: src/web/entity/mod.rs -pub mod file; -pub mod message; -pub mod server; -pub mod user; - -// file: src/web/entity/server.rs -use serde::Serialize; - -use crate::entity::server::Id; -use crate::entity::user; -use crate::util; - -#[derive(Debug, Clone, Serialize)] -#[serde(rename_all = "camelCase")] -pub struct Server { - pub id: Id, - pub owner_id: user::Id, - pub name: String, - pub icon_url: Option, -} - -impl From for Server { - fn from(server: crate::entity::server::Server) -> Self { - Self { - id: server.id, - owner_id: server.owner_id, - name: server.name, - icon_url: server.icon_id.as_ref().map(util::file_id_to_url).flatten(), - } - } -} - -// file: src/web/entity/user.rs -use crate::entity::user; -use crate::util; - -#[derive(serde::Serialize, Debug)] -#[serde(rename_all = "camelCase")] -pub struct FullUser { - pub id: user::Id, - pub avatar_url: Option, - pub username: String, - pub display_name: Option, - pub email: String, - pub bot: bool, - pub system: bool, - pub settings: serde_json::Value, -} - -#[derive(Debug, Clone, serde::Serialize)] -#[serde(rename_all = "camelCase")] -pub struct PartialUser { - pub id: user::Id, - pub avatar_url: Option, - pub username: String, - pub display_name: Option, - pub bot: bool, - pub system: bool, -} - -impl From for FullUser { - fn from(user: user::User) -> Self { - Self { - id: user.id, - avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(), - username: user.username, - display_name: user.display_name, - email: user.email, - bot: user.bot, - system: user.system, - settings: user.settings, - } - } -} - -impl From for PartialUser { - fn from(user: user::User) -> Self { - Self { - id: user.id, - avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(), - username: user.username, - display_name: user.display_name, - bot: user.bot, - system: user.system, - } - } -} - -// file: src/web/error.rs -use std::sync::Arc; - -use axum::http::StatusCode; -use axum::response::IntoResponse; - -use crate::web::context; -use crate::{database, jwt, object_store}; - -pub type Result = std::result::Result; - -#[derive(Debug, derive_more::From)] -pub enum Error { - #[from] - Context(context::Error), - - #[from] - Jwt(jwt::Error), - - #[from] - Hash(argon2::password_hash::Error), - - #[from] - Database(database::Error), - - #[from] - ObjectStore(object_store::Error), - - #[from] - Json(serde_json::error::Error), - - #[from] - JsonRejection(axum::extract::rejection::JsonRejection), - - #[from] - Client(ClientError), -} - -#[derive(Debug, Clone, derive_more::From, serde::Serialize)] -#[serde(tag = "message", content = "details")] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum ClientError { - UserAlreadyExists, - UserDoesNotExist, - - ServerDoesNotExist, - - ChannelDoesNotExist, - - MessageDoesNotExist, - - NotAuthorized, - WrongPassword, - NotAllowed, - - InvalidJson(JsonRejection), - - #[from] - ValidationFailed(validator::ValidationErrors), - - InternalServerError, - - Unknown, -} - -#[derive(derive_more::Debug, Clone, serde::Serialize)] -pub struct JsonRejection { - #[serde(skip)] - status: StatusCode, - - reason: String, -} - -impl From<&axum::extract::rejection::JsonRejection> for JsonRejection { - fn from(value: &axum::extract::rejection::JsonRejection) -> Self { - use std::error::Error; - - use axum::extract::rejection::JsonRejection::*; - - let reason = match value { - JsonDataError(e) => e.source().map(ToString::to_string), - JsonSyntaxError(e) => e.source().map(ToString::to_string), - MissingJsonContentType(e) => e.source().map(ToString::to_string), - BytesRejection(e) => e.source().map(ToString::to_string), - _ => None, - } - .unwrap_or_else(|| value.body_text()); - - Self { - status: value.status(), - reason, - } - } -} - -impl Error { - pub fn as_client_error(&self) -> ClientError { - match self { - Error::Context(_) | Error::Jwt(_) => ClientError::NotAuthorized, - Error::Database(database::Error::UserAlreadyExists) => ClientError::UserAlreadyExists, - Error::Database(database::Error::UserDoesNotExists) => ClientError::UserDoesNotExist, - - Error::Database(database::Error::ServerDoesNotExists) => { - ClientError::ServerDoesNotExist - }, - - Error::Database(database::Error::ChannelDoesNotExists) => { - ClientError::ChannelDoesNotExist - }, - - Error::Database(database::Error::MessageDoesNotExists) => { - ClientError::MessageDoesNotExist - }, - // Error::WrongPassword => ClientError::WrongPassword, - // Error::NotAllowed => ClientError::NotAllowed, - Error::JsonRejection(e) => ClientError::InvalidJson(e.into()), - Error::Client(client_error) => client_error.clone(), - _ => ClientError::InternalServerError, - } - } -} - -impl IntoResponse for Error { - fn into_response(self) -> axum::response::Response { - let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response(); - - response.extensions_mut().insert(Arc::new(self)); - - response - } -} - -impl ClientError { - pub fn status_code(&self) -> StatusCode { - match self { - ClientError::UserAlreadyExists => StatusCode::CONFLICT, - ClientError::UserDoesNotExist - | ClientError::ServerDoesNotExist - | ClientError::ChannelDoesNotExist - | ClientError::MessageDoesNotExist => StatusCode::NOT_FOUND, - - ClientError::NotAuthorized | ClientError::WrongPassword => StatusCode::UNAUTHORIZED, - - ClientError::NotAllowed => StatusCode::FORBIDDEN, - - ClientError::ValidationFailed(_) => StatusCode::UNPROCESSABLE_ENTITY, - - ClientError::InvalidJson(e) => e.status, - - _ => StatusCode::INTERNAL_SERVER_ERROR, - } - } -} - -// file: src/web/middleware/auth.rs -use axum::extract::{Request, State}; -use axum::middleware::Next; -use axum::response::{IntoResponse, Response}; -use axum::{Extension, RequestExt}; -use axum_extra::TypedHeader; -use axum_extra::headers::Authorization; -use axum_extra::headers::authorization::Bearer; -use axum_extra::typed_header::TypedHeaderRejectionReason; - -use crate::jwt; -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::{self, context}; - -pub async fn require_context( - Extension(context): Extension, - request: Request, - next: Next, -) -> web::error::Result { - context?; - - Ok(next.run(request).await) -} - -pub async fn resolve_context( - State(state): State, - mut request: Request, - next: Next, -) -> impl IntoResponse { - let context = get_context(&state, &mut request).await; - - request.extensions_mut().insert(context); - - next.run(request).await -} - -async fn get_context(state: &AppState, request: &mut Request) -> context::UserContextResult { - let bearer = request - .extract_parts::>>() - .await - .map_err(|error| match error.reason() { - TypedHeaderRejectionReason::Missing => context::Error::NotInHeader, - TypedHeaderRejectionReason::Error(_) => context::Error::WrongTokenType, - _ => context::Error::NotInHeader, - })?; - - let token = bearer.token(); - - let context = get_context_from_token(state, token).await?; - - Ok(context) -} - -pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult { - let context = jwt::verify_jwt::( - token, - crate::config::config().security.auth_secret.as_ref(), - ) - .map_err(|_| context::Error::BadToken)?; - - let _ = state - .database - .select_user_by_id(context.user_id) - .await - .map_err(|_| context::Error::BadToken)?; - - Ok(context) -} - -// file: src/web/middleware/mod.rs -mod auth; -mod response_map; - -pub use auth::*; -pub use response_map::*; - -// file: src/web/middleware/response_map.rs -use std::sync::Arc; - -use axum::Json; -use axum::extract::Request; -use axum::middleware::Next; -use axum::response::IntoResponse; -use serde_json::json; - -use crate::web; - -pub async fn response_map(request: Request, next: Next) -> impl IntoResponse { - let response = next.run(request).await; - - let error = response.extensions().get::>(); - - if error.is_some() { - tracing::error!("{:?}", error); - } - - let error_response = error.map(|error| { - let client_error = error.as_client_error(); - - ( - client_error.status_code(), - Json(json!({"error": client_error})), - ) - .into_response() - }); - - error_response.unwrap_or(response) -} - -// file: src/web/mod.rs -mod context; -mod entity; -mod error; -mod middleware; -mod route; -pub mod ws; - -use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin}; - -pub use self::error::{Error, Result}; -use crate::{config, state}; - -pub async fn run(state: state::AppState) -> anyhow::Result<()> { - let config = config::config(); - - let addr: std::net::SocketAddr = (config.server.host, config.server.port).into(); - let listener = tokio::net::TcpListener::bind(addr).await?; - tracing::info!("listening on {}", addr); - - axum::serve(listener, router(state)) - .with_graceful_shutdown(shutdown_signal()) - .await?; - - Ok(()) -} - -fn router(state: state::AppState) -> axum::Router { - use axum::Router; - use axum::routing::*; - - use self::route::*; - - let cors = tower_http::cors::CorsLayer::new() - .allow_origin(AllowOrigin::any()) - .allow_methods(AllowMethods::any()) - .allow_headers(AllowHeaders::any()); - - Router::new() - // websocket - .route("/gateway/ws", get(ws::gateway::ws_handler)) - .route("/voice/ws", get(ws::voice::ws_handler)) - // file - .route("/files/{file_id}", get(file::get)) - // api - .nest( - "/api/v1", - Router::new() - .route("/auth/login", post(auth::login)) - .route("/auth/register", post(auth::register)) - .merge(protected_router()) - .route_layer(axum::middleware::from_fn_with_state( - state.clone(), - middleware::resolve_context, - )), - ) - // middleware - .layer(axum::middleware::from_fn(middleware::response_map)) - .layer(cors) - .layer(tower_http::trace::TraceLayer::new_for_http()) - .with_state(state.clone()) -} - -fn protected_router() -> axum::Router { - use axum::Router; - use axum::routing::*; - - use self::route::*; - - Router::new() - // user - .route("/users/@me", get(user::me)) - .route("/users/@me", patch(user::patch)) - .route("/users/@me/channels", get(user::channel::list)) - .route("/users/@me/channels", post(user::channel::create)) - .route("/users/{id}", get(user::get_by_id)) - // channel - .route( - "/channels/{channel_id}/messages", - get(channel::message::page), - ) - .route( - "/channels/{channel_id}/messages", - post(channel::message::create), - ) - // server - .route("/servers", get(server::list)) - .route("/servers", post(server::create)) - .route("/servers/{server_id}", get(server::get)) - .route("/servers/{server_id}", delete(server::delete)) - .route("/servers/{server_id}/channels", get(server::channel::list)) - .route( - "/servers/{server_id}/channels", - post(server::channel::create), - ) - .route( - "/servers/{server_id}/channels/{channel_id}", - get(server::channel::get), - ) - .route( - "/servers/{server_id}/channels/{channel_id}", - delete(server::channel::delete), - ) - // invite - .route("/servers/{server_id}/invites", post(server::invite::create)) - .route("/invites/{code}", get(server::invite::get)) - // file - .route("/files", post(file::upload)) - // middleware - .route_layer(axum::middleware::from_fn(middleware::require_context)) -} - -async fn shutdown_signal() { - _ = tokio::signal::ctrl_c().await; -} - -// file: src/web/route/auth/login.rs -use argon2::{PasswordHash, PasswordVerifier}; -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; -use serde::{Deserialize, Serialize}; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::{jwt, web}; -use crate::web::entity::user::FullUser; - -#[derive(Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct LoginPayload { - username: String, - password: String, -} - -#[derive(Serialize)] -#[serde(rename_all = "camelCase")] -pub struct LoginResponse { - user: FullUser, - token: String, -} - -pub async fn login( - State(state): State, - Json(payload): Json, -) -> web::Result { - let user = state - .database - .select_user_by_username(&payload.username) - .await?; - - let password_hash = PasswordHash::new(&user.password_hash)?; - - state - .hasher - .verify_password(payload.password.as_bytes(), &password_hash) - .map_err(|_| web::error::ClientError::WrongPassword)?; - - let token = jwt::generate_jwt(UserContext { user_id: user.id }, crate::config::config().security.auth_secret.as_ref())?; - - let response = LoginResponse { - user: user.into(), - token, - }; - - Ok(Json(response)) -} - -// file: src/web/route/auth/mod.rs -pub mod login; -pub mod register; - -pub use login::*; -pub use register::*; - -// file: src/web/route/auth/register.rs -use argon2::password_hash::rand_core::OsRng; -use argon2::password_hash::{PasswordHasher, SaltString}; -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; -use axum_extra::extract::WithRejection; -use serde::Deserialize; -use validator::Validate; - -use crate::state::AppState; -use crate::web; -use crate::web::entity::user::FullUser; -use crate::web::error::ClientError; - -#[derive(Validate, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct RegisterPayload { - #[validate(regex( - path = "crate::entity::user::USERNAME_REGEX", - code = "invalid_username" - ))] - #[validate(length(min = 3, max = 64))] - username: String, - - #[validate(length(min = 1, max = 64))] - display_name: Option, - - #[validate(email)] - #[validate(length(min = 3, max = 128))] - email: String, - - #[validate(length(min = 3, max = 128))] - password: String, -} - -pub async fn register( - State(state): State, - WithRejection(Json(payload), _): WithRejection, web::Error>, -) -> web::Result { - payload.validate().map_err(ClientError::ValidationFailed)?; - - let salt = SaltString::generate(&mut OsRng); - - let password_hash = state - .hasher - .hash_password(payload.password.as_bytes(), &salt)? - .to_string(); - - let user = state - .database - .insert_user( - &payload.username, - payload.display_name.as_ref().map(String::as_str), - &payload.email, - &password_hash, - ) - .await?; - - let system_user = state.database.select_user_by_username("system").await?; - if system_user.system { - state - .database - .procedure_create_dm_channel(user.id, system_user.id) - .await?; - } - - Ok(Json(FullUser::from(user))) -} - -// file: src/web/route/channel/message/create.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; -use validator::Validate; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::entity::file::File; -use crate::web::entity::message::Message; -use crate::web::ws; -use crate::{entity, web}; - -#[derive(Debug, serde::Deserialize, Validate)] -#[serde(rename_all = "camelCase")] -#[validate(schema(function = "validate_create_payload"))] -pub struct CreatePayload { - #[validate(length(min = 0, max = 2000))] - pub content: String, - - #[serde(default)] - #[validate(length(min = 0, max = 10))] - pub attachments: Vec, -} - -fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> { - if payload.content.is_empty() && payload.attachments.is_empty() { - return Err(validator::ValidationError::new( - "at_least_one_field_required", - )); - } - - Ok(()) -} - -pub async fn create( - State(state): State, - context: UserContext, - Path(channel_id): Path, - Json(payload): Json, -) -> web::Result { - // TODO: check permissions - match payload.validate() { - Ok(_) => {}, - Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), - }; - - let message = state - .database - .insert_channel_message(context.user_id, channel_id, &payload.content) - .await?; - - for attachment in payload.attachments.iter() { - state - .database - .insert_message_attachment(message.id, *attachment) - .await?; - } - - let attachments = state - .database - .select_message_attachments(message.id) - .await? - .into_iter() - .map(File::from) - .collect::>(); - - let message = Message::from_message_with_attachments(message, attachments); - - // TODO: check permissions - ws::gateway::util::send_message_channel( - state, - message.channel_id, - ws::gateway::event::Event::AddMessage { - channel_id, - message: message.clone(), - }, - ); - - Ok(Json(message)) -} - -// file: src/web/route/channel/message/mod.rs -mod page; -mod create; - -pub use page::page; -pub use create::create; -// file: src/web/route/channel/message/page.rs -use axum::Json; -use axum::extract::{Path, Query, State}; -use axum::response::IntoResponse; -use serde::Deserialize; -use tracing::error; -use validator::Validate; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::entity::message::Message; -use crate::{entity, web}; - -#[derive(Debug, Deserialize, Validate)] -#[serde(rename_all = "camelCase")] -pub struct PageParams { - #[serde(default = "limit_default")] - #[validate(range(min = 1, max = 100))] - pub limit: u32, - #[serde(default)] - pub before: Option, -} - -fn limit_default() -> u32 { - 50 -} - -pub async fn page( - State(state): State, - context: UserContext, - Path(channel_id): Path, - Query(params): Query, -) -> web::Result { - // TODO: check permissions - match params.validate() { - Ok(_) => {}, - Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), - }; - - let messages_futures = state - .database - .select_channel_messages_paginated(channel_id, params.before, params.limit as i64) - .await? - .into_iter() - .map(|message| async { - let attachments = state - .database - .select_message_attachments(message.id) - .await? - .into_iter() - .map(web::entity::file::File::from) - .collect::>(); - - Ok::<_, web::Error>(Message::from_message_with_attachments(message, attachments)) - }); - - let messages = futures::future::try_join_all(messages_futures).await?; - - Ok(Json(messages)) -} - -// file: src/web/route/channel/mod.rs -pub mod message; - -// file: src/web/route/file/get.rs -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::{entity, object_store, web}; - -pub async fn get( - State(state): State, - Path(file_id): Path, -) -> web::Result { - let file = match state.database.select_file_by_id(file_id).await { - Ok(file) => file, - Err(e) => { - return Ok(axum::http::StatusCode::NOT_FOUND.into_response()); - }, - }; - - let data = match state - .object_store - .get_object_stream(&file.id.to_string()) - .await - { - Ok(data) => data, - Err(s3::error::S3Error::HttpFailWithBody(403 | 404, _)) => { - let _ = state.database.delete_file_by_id(file.id).await?; - - return Ok(axum::http::StatusCode::NOT_FOUND.into_response()); - }, - Err(e) => { - return Err(object_store::Error::from(e).into()); - }, - }; - - let headers = axum::response::AppendHeaders([ - (axum::http::header::CONTENT_TYPE, file.content_type.clone()), - ( - axum::http::header::CONTENT_DISPOSITION, - format!("filename=\"{}\"", file.filename), - ), - ]); - - Ok((headers, axum::body::Body::from_stream(data.bytes)).into_response()) -} - -// file: src/web/route/file/mod.rs -mod get; -mod upload; - -pub use get::get; -pub use upload::upload; - -// file: src/web/route/file/upload.rs -use axum::Json; -use axum::body::Bytes; -use axum::extract::State; -use axum::response::IntoResponse; -use axum_typed_multipart::{TryFromMultipart, TypedMultipart}; -use validator::Validate; - -use crate::state::AppState; -use crate::util::SerdeFieldData; -use crate::{object_store, web}; - -#[derive(Debug, Validate, TryFromMultipart)] -#[try_from_multipart(rename_all = "camelCase")] -pub struct UploadPayload { - #[form_data(limit = "50MB")] - #[validate(length(min = 1, max = 16))] - files: Vec>, -} - -pub async fn upload( - State(state): State, - TypedMultipart(payload): TypedMultipart, -) -> web::Result { - match payload.validate() { - Ok(_) => {}, - Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), - }; - - let mut file_ids = Vec::new(); - - for file in payload.files { - let db_file = state - .database - .insert_file( - file.metadata - .file_name - .as_deref() - .unwrap_or_else(|| "unknown"), - file.metadata.content_type.as_deref().unwrap_or_default(), - file.contents.len(), - ) - .await?; - - state - .object_store - .put_object(&db_file.id.to_string(), &file.contents) - .await - .map_err(object_store::Error::from)?; - - file_ids.push(db_file.id); - } - - Ok(Json(file_ids)) -} - -// file: src/web/route/mod.rs -pub mod auth; -pub mod channel; -pub mod file; -pub mod server; -pub mod user; - -// file: src/web/route/server/channel/create.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; -use axum_extra::extract::WithRejection; -use serde::Deserialize; -use validator::Validate; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::error::ClientError; -use crate::web::ws; -use crate::{entity, web}; - -#[derive(Debug, Validate, Deserialize)] -pub struct CreatePayload { - #[validate(length(min = 1, max = 32))] - name: String, - - #[validate(custom(function = "validate_server_channel_type"))] - r#type: entity::channel::ChannelType, -} - -fn validate_server_channel_type( - r#type: &entity::channel::ChannelType, -) -> Result<(), validator::ValidationError> { - match r#type { - entity::channel::ChannelType::ServerText => Ok(()), - entity::channel::ChannelType::ServerVoice => Ok(()), - entity::channel::ChannelType::ServerCategory => Ok(()), - _ => Err(validator::ValidationError::new("invalid_channel_type")), - } -} - -pub async fn create( - State(state): State, - context: UserContext, - Path(server_id): Path, - WithRejection(Json(payload), _): WithRejection, web::Error>, -) -> web::Result { - payload.validate().map_err(ClientError::ValidationFailed)?; - - // TODO: check permissions - let server = state.database.select_server_by_id(server_id).await?; - - if server.owner_id != context.user_id { - return Err(web::error::ClientError::NotAllowed.into()); - } - - let channel = state - .database - .insert_server_channel(&payload.name, 0, payload.r#type, server_id, None) - .await?; - - ws::gateway::util::send_message_server( - state, - server_id, - ws::gateway::event::Event::AddServerChannel { - channel: channel.clone(), - }, - ); - - Ok(Json(channel)) -} - -// file: src/web/route/server/channel/delete.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::ws; -use crate::webrtc::WebRtcSignal; -use crate::{entity, web}; - -pub async fn delete( - State(state): State, - context: UserContext, - Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>, -) -> web::Result { - // TODO: check permissions - - let channel = state.database.select_channel_by_id(channel_id).await?; - let server = state.database.select_server_by_id(server_id).await?; - - if server.owner_id != context.user_id { - return Err(web::error::ClientError::NotAllowed.into()); - } - - if let Some(channel_server_id) = channel.server_id { - if channel_server_id != server_id { - return Err(web::error::ClientError::NotAllowed.into()); - } - } else { - return Err(web::error::ClientError::NotAllowed.into()); - } - - let channel = state.database.delete_channel_by_id(channel_id).await?; - - ws::gateway::util::send_message_server( - state.clone(), - server_id, - ws::gateway::event::Event::RemoveServerChannel { - server_id: server_id.clone(), - channel_id: channel.id.clone(), - }, - ); - - tokio::spawn(async move { - let voice_rooms = state.voice_rooms.read().await; - if let Some(voice_room) = voice_rooms.get(&channel_id) { - let _ = voice_room.send(WebRtcSignal::Close); - } - }); - - Ok(Json(channel)) -} - -// file: src/web/route/server/channel/get.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::{entity, web}; - -pub async fn get( - State(state): State, - context: UserContext, - Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>, -) -> web::Result { - // TODO: check permissions - - let channel = state.database.select_channel_by_id(channel_id).await?; - - if let Some(channel_server_id) = channel.server_id { - if channel_server_id != server_id { - return Err(web::error::ClientError::NotAllowed.into()); - } - } else { - return Err(web::error::ClientError::NotAllowed.into()); - } - - Ok(Json(channel)) -} - -// file: src/web/route/server/channel/list.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::{entity, web}; - -pub async fn list( - State(state): State, - context: UserContext, - Path(server_id): Path, -) -> web::Result { - let channels = state.database.select_server_channels(server_id).await?; - - Ok(Json(channels)) -} - -// file: src/web/route/server/channel/mod.rs -mod create; -mod delete; -mod get; -mod list; - -pub use create::create; -pub use delete::delete; -pub use get::get; -pub use list::list; - -// file: src/web/route/server/create.rs -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; -use axum_extra::extract::WithRejection; -use axum_typed_multipart::TryFromMultipart; -use serde::Deserialize; -use validator::Validate; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::error::ClientError; -use crate::web::ws; -use crate::{entity, web}; -use crate::web::entity::server::Server; - -#[derive(Debug, Validate, Deserialize)] -#[serde(rename_all = "camelCase")] -pub struct CreatePayload { - #[validate(length(min = 1, max = 32))] - name: String, - - icon_id: Option, -} - -pub async fn create( - State(state): State, - context: UserContext, - WithRejection(Json(payload), _): WithRejection, web::Error>, -) -> web::Result { - payload.validate().map_err(ClientError::ValidationFailed)?; - - let server = state - .database - .insert_server(&payload.name, payload.icon_id, context.user_id) - .await?; - - let role = state - .database - .insert_server_role( - Some(server.id.into()), - server.id, - "@everyone", - None, - false, - serde_json::json!({}), - 0, - ) - .await?; - - let member = state - .database - .insert_server_member(server.id, context.user_id) - .await?; - - state - .database - .insert_server_member_role(member.id, role.id) - .await?; - - let server = Server::from(server); - - ws::gateway::util::send_message( - state, - context.user_id, - ws::gateway::event::Event::AddServer { - server: server.clone(), - }, - ); - - Ok(Json(server)) -} - -// file: src/web/route/server/delete.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::ws; -use crate::webrtc::WebRtcSignal; -use crate::{entity, web}; -use crate::web::entity::server::Server; - -pub async fn delete( - State(state): State, - context: UserContext, - Path(server_id): Path, -) -> web::Result { - let server = state.database.select_server_by_id(server_id).await?; - - if server.owner_id != context.user_id { - return Err(web::error::ClientError::NotAllowed.into()); - } - - let members = state - .database - .select_server_members(server_id) - .await? - .iter() - .map(|u| u.id) - .collect::>(); - - let channels = state - .database - .select_server_channels(server_id) - .await? - .iter() - .map(|c| c.id) - .collect::>(); - - let state_clone = state.clone(); - tokio::spawn(async move { - let voice_rooms = state_clone.voice_rooms.read().await; - for channel_id in channels { - if let Some(voice_room) = voice_rooms.get(&channel_id) { - let _ = voice_room.send(WebRtcSignal::Close); - } - } - }); - - let server = state.database.delete_server_by_id(server_id).await?; - - ws::gateway::util::send_message_many( - state.clone(), - &members, - ws::gateway::event::Event::RemoveServer { - server_id: server.id, - }, - ); - - Ok(Json(Server::from(server))) -} - -// file: src/web/route/server/get.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::{entity, web}; -use crate::web::entity::server::Server; - -pub async fn get( - State(state): State, - Path(server_id): Path, -) -> web::Result { - // TODO: check permissions - - let server = state.database.select_server_by_id(server_id).await?; - - Ok(Json(Server::from(server))) -} - -// file: src/web/route/server/invite/create.rs -use axum::Json; -use axum::extract::{Path, Query, State}; -use axum::response::IntoResponse; -use base64::Engine; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::{entity, web}; - -#[derive(serde::Deserialize, Debug)] -pub struct CreateParams { - #[serde(deserialize_with = "crate::util::deserialize_duration_seconds_option")] - #[serde(default)] - pub expires_in: Option, -} - -pub async fn create( - State(state): State, - context: UserContext, - Path(server_id): Path, - Query(params): Query, -) -> web::Result { - // TODO: check permissions - - let code = { - use rand::Rng; - - let mut rng = rand::rng(); - let mut code = [0u8; 32]; - rng.fill(&mut code); - base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(&code) - }; - - let expires_at = params.expires_in.map(|d| { - let now = chrono::Utc::now(); - now + chrono::Duration::from_std(d).expect("valid duration") - }); - - let invite = state - .database - .insert_server_invite(&code, server_id, Some(context.user_id), expires_at) - .await?; - - Ok(Json(invite)) -} - -// file: src/web/route/server/invite/get.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::ws; -use crate::{database, web}; -use crate::web::entity::server::Server; - -pub async fn get( - State(state): State, - context: UserContext, - Path(code): Path, -) -> web::Result { - let invite = state.database.select_server_invite_by_code(&code).await?; - let server = state.database.select_server_by_id(invite.server_id).await?; - - let member = match state - .database - .insert_server_member(invite.server_id, context.user_id) - .await - { - Ok(member) => member, - Err(database::Error::MemberAlreadyExists) => return Ok(Json(Server::from(server))), - Err(e) => return Err(e.into()), - }; - - state - .database - .insert_server_member_role(member.id, invite.server_id) - .await?; - - let user = state.database.select_user_by_id(context.user_id).await?; - - ws::gateway::util::send_message_server( - state, - invite.server_id, - ws::gateway::event::Event::AddServerMember { - server_id: server.id, - member: user.into(), - }, - ); - - Ok(Json(Server::from(server))) -} - -// file: src/web/route/server/invite/mod.rs -mod create; -mod get; - -pub use create::create; -pub use get::get; - -// file: src/web/route/server/list.rs -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web; -use crate::web::context::UserContext; -use crate::web::entity::server::Server; - -pub async fn list( - State(state): State, - context: UserContext, -) -> web::Result { - let servers = state - .database - .select_user_servers(context.user_id) - .await? - .into_iter() - .map(Server::from) - .collect::>(); - - Ok(Json(servers)) -} - -// file: src/web/route/server/mod.rs -pub mod channel; -mod create; -mod delete; -mod get; -pub mod invite; -mod list; - -pub use create::create; -pub use delete::delete; -pub use get::get; -pub use list::list; - -// file: src/web/route/user/channel/create.rs -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; -use axum_extra::extract::WithRejection; -use serde::Deserialize; -use validator::Validate; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::entity::user::PartialUser; -use crate::web::route::user::channel::RecipientChannel; -use crate::web::ws; -use crate::{entity, web}; - -#[derive(Debug, Validate, Deserialize)] -pub struct CreatePayload { - #[validate(length(min = 1, max = 32))] - recipients: Vec, -} - -pub async fn create( - State(state): State, - context: UserContext, - WithRejection(Json(payload), _): WithRejection, web::Error>, -) -> web::Result { - match payload.validate() { - Ok(_) => {}, - Err(err) => { - return Err(web::error::ClientError::ValidationFailed(err).into()); - }, - } - - let channel_id = match payload.recipients.len() { - 1 => { - let recipient = payload.recipients[0]; - state - .database - .procedure_create_dm_channel(context.user_id, recipient) - .await? - }, - _ => { - state - .database - .procedure_create_group_channel(context.user_id, &payload.recipients) - .await? - }, - }; - - let channel = state.database.select_channel_by_id(channel_id).await?; - - let recipients = state - .database - .select_channel_recipients(channel_id) - .await? - .unwrap_or_default() - .into_iter() - .map(|user| user.id) - .collect::>(); - - let recipient_channels = RecipientChannel { - channel: channel.clone(), - recipients: recipients.clone(), - }; - - ws::gateway::util::send_message_channel( - state, - channel_id, - ws::gateway::event::Event::AddDmChannel { - channel, - recipients, - }, - ); - - Ok(Json(recipient_channels)) -} - -// file: src/web/route/user/channel/list.rs -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web; -use crate::web::context::UserContext; -use crate::web::route::user::channel::RecipientChannel; - -pub async fn list( - State(state): State, - context: UserContext, -) -> web::Result { - let channels = state.database.select_user_channels(context.user_id).await?; - - let mut recipient_channels = Vec::new(); - for channel in channels { - let recipients = state.database.select_channel_recipients(channel.id).await?; - - let recipients = match recipients { - Some(recipients) => recipients - .into_iter() - .map(|user| user.id) - .collect(), - None => { - continue; - }, - }; - - recipient_channels.push(RecipientChannel { - channel, - recipients, - }); - } - - Ok(Json(recipient_channels)) -} - -// file: src/web/route/user/channel/mod.rs -mod create; -mod list; - -pub use create::create; -pub use list::list; -use serde::Serialize; - -use crate::entity::{channel, user}; - -#[derive(Debug, Serialize)] -#[serde(rename_all = "camelCase")] -struct RecipientChannel { - #[serde(flatten)] - pub channel: channel::Channel, - pub recipients: Vec, -} - -// file: src/web/route/user/get.rs -use axum::Json; -use axum::extract::{Path, State}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web; -use crate::web::entity::user::PartialUser; - -pub async fn get_by_id( - State(state): State, - Path(user_id): Path, -) -> web::Result { - let user = state.database.select_user_by_id(user_id).await?; - - Ok(Json(PartialUser::from(user))) -} - -// file: src/web/route/user/me.rs -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web; -use crate::web::context::UserContext; -use crate::web::entity::user::FullUser; - -pub async fn me( - State(state): State, - context: UserContext, -) -> web::Result { - let user = state.database.select_user_by_id(context.user_id).await?; - - Ok(Json(FullUser::from(user))) -} - -// file: src/web/route/user/mod.rs -pub mod channel; -mod get; -mod me; -mod patch; - -pub use get::get_by_id; -pub use me::me; -pub use patch::patch; - - -// file: src/web/route/user/patch.rs -use axum::Json; -use axum::extract::State; -use axum::response::IntoResponse; -use axum_extra::extract::WithRejection; -use serde::Deserialize; -use validator::Validate; - -use crate::state::AppState; -use crate::web::context::UserContext; -use crate::web::entity::user::{FullUser, PartialUser}; -use crate::web::ws; -use crate::{entity, util, web}; - -#[derive(Debug, Validate, Deserialize)] -#[serde(rename_all = "camelCase")] -#[validate(schema(function = "validate_create_payload"))] -pub struct CreatePayload { - #[validate(length(min = 1, max = 32))] - #[serde(default)] - display_name: util::tristate::TriState, - - #[serde(default)] - avatar_id: util::tristate::TriState, -} - -fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> { - if payload.display_name.is_absent() && payload.avatar_id.is_absent() { - return Err(validator::ValidationError::new( - "at_least_one_field_required", - )); - } - - Ok(()) -} - -pub async fn patch( - State(state): State, - context: UserContext, - WithRejection(Json(payload), _): WithRejection, web::Error>, -) -> web::Result { - match payload.validate() { - Ok(_) => {}, - Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()), - }; - - let user = state - .database - .update_user_by_id( - context.user_id, - payload.display_name.as_deref(), - payload.avatar_id, - ) - .await?; - - ws::gateway::util::send_message( - state.clone(), - context.user_id, - ws::gateway::event::Event::AddUser { - user: PartialUser::from(user.clone()), - }, - ); - - ws::gateway::util::send_message_related( - state.clone(), - context.user_id, - ws::gateway::event::Event::AddUser { - user: PartialUser::from(user.clone()), - }, - ); - - Ok(Json(FullUser::from(user))) -} - -// file: src/web/ws/error.rs -pub type Result = std::result::Result>; - -#[derive(Debug, derive_more::From, derive_more::Display)] -pub enum Error { - #[from] - Custom(T), - - #[from] - Json(serde_json::Error), - - #[from] - AcknowledgementError(tokio::sync::oneshot::error::RecvError), - - WrongMessageType, - - WebSocketClosed, - - UnknownError, -} - -pub trait CustomError {} - -// file: src/web/ws/gateway/connection.rs -use axum::extract::ws::Message as AxumMessage; -use base64::Engine as _; -use futures::{Stream, StreamExt}; -use sha2::{Digest, Sha256}; -use tokio::sync::mpsc; - -use super::error::Error as WsError; -use super::event::Event as WsEvent; -use super::protocol::{WsClientMessage, WsServerMessage}; -use super::state::{WsContext, WsState, WsUserContext}; -use crate::jwt; -use crate::state::AppState; -use crate::web::ws::gateway::SessionKey; -use crate::web::ws::general::WebSocketHandler; -use crate::web::ws::util::{SendWsMessage, deserialize_ws_message}; -use crate::web::ws::voice; -use crate::webrtc::WebRtcSignal; - -impl WebSocketHandler for WsContext { - type ServerMessage = WsServerMessage; - type ClientMessage = WsClientMessage; - type Error = WsError; - - async fn handle_stream( - &mut self, - stream: S, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, - ) -> crate::web::ws::error::Result<(), Self::Error> - where - S: Stream> + Unpin, - { - process_websocket_messages(self, stream, sender, app_state).await?; - - Ok(()) - } - - async fn cleanup(&mut self, app_state: &AppState) { - if let Some(user_ctx_data) = &self.user_context { - app_state - .unregister_gateway_connected_user( - user_ctx_data.user_id, - &user_ctx_data.session_key, - ) - .await; - } - - drop(self.event_channel.take()); - } - - async fn handle_result_error( - &mut self, - error: Self::Error, - sender: &mpsc::UnboundedSender>, - ) { - let error_ws_message = WsServerMessage::Error { code: error }; - - let _ = sender.send(SendWsMessage::new_no_response(error_ws_message)); - } -} - -#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) -))] -async fn process_websocket_messages( - context: &mut WsContext, - mut ws_stream: S, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, -) -> crate::web::ws::error::Result<(), WsError> -where - S: Stream> + Unpin, -{ - loop { - match context.connection_state { - WsState::Initialize => { - tokio::select! { - biased; - maybe_message = ws_stream.next() => { - match maybe_message { - Some(Ok(message)) => { - handle_initial_message(context, message, sender, app_state).await?; - context.connection_state = WsState::Connected; - tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected."); - } - Some(Err(axum_ws_err)) => { - tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err); - return Err(crate::web::ws::error::Error::WebSocketClosed); - } - None => { - tracing::debug!("WebSocket stream ended by client during Initialize state."); - return Err(crate::web::ws::error::Error::WebSocketClosed); - } - } - } - } - }, - WsState::Connected => { - let user_ctx = context - .user_context - .as_ref() - .expect("User context must be set in Connected state"); - let (_event_tx_ref, event_rx) = context - .event_channel - .as_mut() - .expect("Event channel must be set in Connected state"); - - tokio::select! { - biased; - maybe_app_event = event_rx.recv() => { - if let Some(app_event_data) = maybe_app_event { - SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?; - } else { - tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket."); - return Ok(()); - } - } - maybe_ws_message = ws_stream.next() => { - match maybe_ws_message { - Some(Ok(message)) => { - handle_connected_message(context, message, sender, &app_state).await?; - } - Some(Err(axum_ws_err)) => { - tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err); - return Err(crate::web::ws::error::Error::WebSocketClosed); - } - None => { - tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state."); - return Err(crate::web::ws::error::Error::WebSocketClosed); - } - } - } - } - }, - } - } -} - -#[tracing::instrument(skip_all, fields(state = ?context.connection_state))] -async fn handle_initial_message( - context: &mut WsContext, - message: AxumMessage, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, -) -> crate::web::ws::error::Result<(), WsError> { - match deserialize_ws_message(message)? { - WsClientMessage::Authenticate { token } => { - match crate::web::middleware::get_context_from_token(&app_state, &token).await { - Ok(auth_user_context) => { - let user_id = auth_user_context.user_id; - - let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::(); - context.event_channel = Some((event_tx.clone(), event_rx)); - - let random_key_part = rand::random::(); - let current_session_key: SessionKey = { - let mut hasher = Sha256::new(); - hasher.update(token.as_bytes()); - hasher.update(user_id.to_string().as_bytes()); - hasher.update(&random_key_part.to_be_bytes()); - base64::engine::general_purpose::URL_SAFE_NO_PAD - .encode(hasher.finalize()) - .into() - }; - - context.user_context = Some(WsUserContext { - user_id, - session_key: current_session_key.clone(), - }); - - app_state - .register_gateway_connected_user( - user_id, - current_session_key.clone(), - event_tx, - ) - .await; - - SendWsMessage::send_with_response( - sender, - WsServerMessage::AuthenticateAccepted { - user_id, - session_key: current_session_key, - }, - ) - .await?; - Ok(()) - }, - Err(_auth_err) => { - tracing::warn!(token = %token, "Authentication failed for token."); - let _ = SendWsMessage::send_with_response( - sender, - WsServerMessage::AuthenticateDenied, - ) - .await; - Err(WsError::AuthenticationFailed.into()) - }, - } - }, - #[allow(unreachable_patterns)] - _ => { - tracing::warn!("Unexpected message type received during Initialize state."); - Err(crate::web::ws::error::Error::WrongMessageType) - }, - } -} - -#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id) -))] -async fn handle_connected_message( - context: &mut WsContext, - message: AxumMessage, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, -) -> crate::web::ws::error::Result<(), WsError> { - match deserialize_ws_message(message)? { - WsClientMessage::VoiceStateUpdate { - server_id, - channel_id, - } => { - let server_id = server_id.unwrap_or_else(|| channel_id.clone().into()); - // TODO: check if can join this channel - - let claims = voice::claims::VoiceClaims { - user_id: context - .user_context - .as_ref() - .expect("user context should be present") - .user_id, - server_id, - channel_id, - exp: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime) - .timestamp(), - }; - - let token = jwt::generate_jwt( - claims, - crate::config::config().security.voice_secret.as_ref(), - ) - .map_err(|_| WsError::TokenGenerationFailed)?; - - SendWsMessage::send_with_response( - sender, - WsServerMessage::Event { - event: WsEvent::VoiceServerUpdate { - server_id, - channel_id, - token, - }, - }, - ) - .await?; - - Ok(()) - }, - WsClientMessage::RequestVoiceStates { server_id } => { - let channels = app_state - .database - .select_server_channels(server_id) - .await - .map_err(|_| crate::web::ws::error::Error::UnknownError)?; - - for channel in channels { - let (tx, rx) = tokio::sync::oneshot::channel(); - - let webrtc_sender = - { app_state.voice_rooms.read().await.get(&channel.id).cloned() }; - - if let Some(voice_room) = webrtc_sender { - let _ = voice_room.send(WebRtcSignal::RequestPeers { response: tx }); - - let peers = match rx.await { - Ok(peers) => peers, - Err(_) => { - continue; - }, - }; - - for peer in peers { - let _ = - sender.send(SendWsMessage::new_no_response(WsServerMessage::Event { - event: WsEvent::VoiceChannelConnected { - server_id, - channel_id: channel.id, - user_id: peer, - }, - })); - } - } - } - - Ok(()) - }, - other_message => { - tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state."); - Err(crate::web::ws::error::Error::WrongMessageType) - }, - } -} - -// file: src/web/ws/gateway/error.rs -use crate::web::ws::error::CustomError; - -pub type Result = std::result::Result; - -#[derive(Debug, derive_more::From, derive_more::Display, serde::Serialize)] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum Error { - AuthenticationFailed, - TokenGenerationFailed, -} - -impl CustomError for Error {} - -// file: src/web/ws/gateway/event.rs -use crate::{entity, web}; - -#[derive(Debug, Clone, serde::Serialize)] -#[serde(tag = "type", content = "data")] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum Event { - #[serde(rename_all = "camelCase")] - AddServer { server: web::entity::server::Server }, - - #[serde(rename_all = "camelCase")] - RemoveServer { server_id: entity::server::Id }, - - #[serde(rename_all = "camelCase")] - AddDmChannel { - channel: entity::channel::Channel, - recipients: Vec, - }, - - #[serde(rename_all = "camelCase")] - RemoveDmChannel { channel_id: entity::channel::Id }, - - #[serde(rename_all = "camelCase")] - AddServerChannel { channel: entity::channel::Channel }, - - #[serde(rename_all = "camelCase")] - RemoveServerChannel { - server_id: entity::server::Id, - channel_id: entity::channel::Id, - }, - - #[serde(rename_all = "camelCase")] - AddUser { - user: web::entity::user::PartialUser, - }, - - #[serde(rename_all = "camelCase")] - RemoveUser { user_id: entity::user::Id }, - - #[serde(rename_all = "camelCase")] - AddServerMember { - server_id: entity::server::Id, - member: web::entity::user::PartialUser, - }, - - #[serde(rename_all = "camelCase")] - RemoveServerMember { - server_id: entity::server::Id, - member_id: entity::user::Id, - }, - - #[serde(rename_all = "camelCase")] - AddMessage { - channel_id: entity::channel::Id, - message: web::entity::message::Message, - }, - - #[serde(rename_all = "camelCase")] - RemoveMessage { - channel_id: entity::channel::Id, - message_id: entity::message::Id, - }, - - #[serde(rename_all = "camelCase")] - VoiceChannelConnected { - server_id: entity::server::Id, - channel_id: entity::channel::Id, - user_id: entity::user::Id, - }, - - #[serde(rename_all = "camelCase")] - VoiceChannelDisconnected { - server_id: entity::server::Id, - channel_id: entity::channel::Id, - user_id: entity::user::Id, - }, - - #[serde(rename_all = "camelCase")] - VoiceServerUpdate { - server_id: entity::server::Id, - channel_id: entity::channel::Id, - token: String, - }, -} - -// file: src/web/ws/gateway/mod.rs -use axum::extract::{State, WebSocketUpgrade}; -use axum::response::IntoResponse; -use dashmap::DashMap; - -use crate::state::AppState; -use crate::web::ws::gateway::state::{EventSender, WsContext}; -use crate::web::ws::general; - -mod connection; -mod error; -pub mod event; -mod protocol; -mod state; -pub mod util; - -#[derive( - Debug, - Clone, - Eq, - PartialEq, - Hash, - Default, - derive_more::Display, - serde::Serialize, - serde::Deserialize, - derive_more::From, -)] -#[serde(transparent)] -pub struct SessionKey(String); - -#[derive(Debug, Default)] -pub struct GatewayWsState { - pub instances: DashMap, -} - -pub async fn ws_handler( - State(app_state): State, - ws: WebSocketUpgrade, -) -> crate::web::error::Result { - Ok(ws.on_upgrade(|socket| { - general::handle_websocket_connection(socket, app_state, WsContext::default()) - })) -} - -// file: src/web/ws/gateway/protocol.rs -use super::{SessionKey, error, event}; -use crate::entity; - -#[derive(Debug, serde::Serialize)] -#[serde(tag = "type", content = "data")] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum WsServerMessage { - AuthenticateDenied, - - #[serde(rename_all = "camelCase")] - AuthenticateAccepted { - user_id: entity::user::Id, - session_key: SessionKey, - }, - - #[serde(rename_all = "camelCase")] - Event { - event: event::Event, - }, - - #[serde(rename_all = "camelCase")] - Error { - code: error::Error, - }, -} - -#[derive(Debug, serde::Deserialize)] -#[serde(tag = "type", content = "data")] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum WsClientMessage { - #[serde(rename_all = "camelCase")] - Authenticate { token: String }, - - #[serde(rename_all = "camelCase")] - VoiceStateUpdate { - server_id: Option, - channel_id: entity::channel::Id, - }, - - #[serde(rename_all = "camelCase")] - RequestVoiceStates { server_id: entity::server::Id }, -} - -// file: src/web/ws/gateway/state.rs -use tokio::sync::mpsc; - -use super::{SessionKey, event}; -use crate::entity; - -#[derive(Debug, Eq, PartialEq, Clone, Copy)] -pub enum WsState { - Initialize, - Connected, -} - -#[derive(Debug)] -pub struct WsUserContext { - pub user_id: entity::user::Id, - pub session_key: SessionKey, // Unique key for this specific WebSocket session instance -} - -pub type EventSender = mpsc::UnboundedSender; -pub type EventReceiver = mpsc::UnboundedReceiver; - -pub struct WsContext { - pub connection_state: WsState, - pub user_context: Option, - pub event_channel: Option<(EventSender, EventReceiver)>, -} - -impl Default for WsContext { - fn default() -> Self { - Self { - connection_state: WsState::Initialize, - user_context: None, - event_channel: None, - } - } -} - -// file: src/web/ws/gateway/util.rs -use crate::entity; -use crate::state::AppState; -use crate::web::ws::gateway::event; - -pub fn send_message(state: AppState, user_id: entity::user::Id, message: event::Event) { - tokio::spawn(async move { - let connected_users = state.gateway_state.connected.get_async(&user_id).await; - if let Some(session) = connected_users { - for instance in session.instances.iter() { - if let Err(e) = instance.send(message.clone()) { - tracing::error!("failed to send message: {}", e); - } - } - } - }); -} - -pub fn send_message_many(state: AppState, user_ids: &[entity::user::Id], message: event::Event) { - for id in user_ids.iter() { - send_message(state.clone(), *id, message.clone()); - } -} - -pub fn send_message_server(state: AppState, server_id: entity::server::Id, message: event::Event) { - tokio::spawn(async move { - let users = state - .database - .select_server_members(server_id) - .await - .unwrap_or_else(|_| vec![]) - .iter() - .map(|u| u.id) - .collect::>(); - - send_message_many(state, &users, message); - }); -} - -pub fn send_message_channel( - state: AppState, - channel_id: entity::channel::Id, - message: event::Event, -) { - tokio::spawn(async move { - let users = state - .database - .select_channel_members(channel_id) - .await - .unwrap_or_else(|_| vec![]) - .iter() - .map(|u| u.id) - .collect::>(); - - send_message_many(state, &users, message); - }); -} - -pub fn send_message_related(state: AppState, user_id: entity::user::Id, message: event::Event) { - tokio::spawn(async move { - let users = state - .database - .select_related_user_ids(user_id) - .await - .unwrap_or_else(|_| vec![]); - - send_message_many(state, &users, message); - }); -} - -// file: src/web/ws/general.rs -use std::fmt::Debug; - -use axum::extract::ws::WebSocket; -use futures::{Stream, StreamExt}; -use serde::Serialize; -use serde::de::DeserializeOwned; -use tokio::sync::mpsc; - -use crate::state::AppState; -use crate::web::ws::error::CustomError; -use crate::web::ws::util; -use crate::web::ws::util::SendWsMessage; - -pub trait WebSocketHandler { - type ServerMessage: Serialize + Send; - type ClientMessage: DeserializeOwned; - type Error: CustomError + Send + Debug; - - async fn handle_stream( - &mut self, - stream: S, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, - ) -> crate::web::ws::error::Result<(), Self::Error> - where - S: Stream> + Unpin; - - async fn cleanup(&mut self, app_state: &AppState); - - async fn handle_result_error( - &mut self, - error: Self::Error, - sender: &mpsc::UnboundedSender>, - ); -} - -#[tracing::instrument(skip_all)] -pub async fn handle_websocket_connection( - websocket: WebSocket, - app_state: AppState, - mut handler: impl WebSocketHandler + 'static, -) { - let (ws_sink, ws_stream) = websocket.split(); - - let (internal_send_tx, internal_send_rx) = mpsc::unbounded_channel(); - - let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx); - - let processing_result = handler - .handle_stream(ws_stream, &internal_send_tx, &app_state) - .await; - - handler.cleanup(&app_state).await; - - match processing_result { - Ok(_) => {}, - Err(crate::web::ws::error::Error::Custom(err_to_report)) => { - handler - .handle_result_error(err_to_report, &internal_send_tx) - .await; - }, - Err(e) => { - tracing::info!("WebSocket connection closed: {:?}", e); - }, - } - - drop(internal_send_tx); - if let Err(e) = writer_task.await { - tracing::error!( - "WebSocket writer task panicked or encountered an error: {:?}", - e - ); - } -} - -// file: src/web/ws/mod.rs -mod error; -pub mod gateway; -mod util; -pub mod voice; -mod general; - -// file: src/web/ws/util.rs -use axum::extract::ws::Message as AxumMessage; -use futures::{Sink, SinkExt}; -use serde::Serialize; -use serde::de::DeserializeOwned; -use tokio::sync::{mpsc, oneshot}; - -use crate::web::ws::error::CustomError; - -pub fn spawn_writer_task( - mut ws_sink: S, - mut writer_rx: mpsc::UnboundedReceiver>, -) -> tokio::task::JoinHandle<()> -where - S: Sink + Unpin + Send + 'static, - T: Serialize + Send + 'static, - E: CustomError + Send + 'static, -{ - tokio::spawn(async move { - while let Some(SendWsMessage { - message, - response_ch, - }) = writer_rx.recv().await - { - let send_result = match serialize_ws_message(message) { - Ok(ws_msg) => { - if ws_sink.send(ws_msg).await.is_err() { - Err(super::error::Error::WebSocketClosed) - } else { - Ok(()) - } - }, - Err(e) => Err(e.into()), - }; - - if let Some(ch) = response_ch { - let _ = ch.send(send_result); - } - } - - let _ = ws_sink.close().await; - }) -} - -/// Deserializes an Axum WebSocket message into a `WsClientMessage`. -pub fn deserialize_ws_message( - message: AxumMessage, -) -> super::error::Result { - match message { - AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from), - AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed), - _ => Err(super::error::Error::WrongMessageType), - } -} - -/// Serializes a `WsServerMessage` into an Axum WebSocket message. -pub fn serialize_ws_message( - message: T, -) -> super::error::Result { - serde_json::to_string(&message) - .map(Into::into) - .map(AxumMessage::Text) - .map_err(super::error::Error::from) -} - -/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task. -/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer. -pub struct SendWsMessage { - pub message: T, - pub response_ch: Option>>, -} - -impl SendWsMessage { - /// Sends a message over the MPSC channel and awaits a response via a oneshot channel. - pub async fn send_with_response( - tx: &mpsc::UnboundedSender, // Changed to reference - message: T, - ) -> super::error::Result<(), E> { - let (response_tx, response_rx) = oneshot::channel(); - let send_message = SendWsMessage { - message, - response_ch: Some(response_tx), - }; - - if tx.send(send_message).is_err() { - Err(super::error::Error::WebSocketClosed) // MPSC channel closed, writer task likely dead - } else { - // Wait for the writer task to acknowledge the send attempt. - // This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error. - response_rx.await? // Propagates RecvError into WsError::AcknowledgementError - } - } - - /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected). - pub fn new_no_response(message: T) -> Self { - Self { - message, - response_ch: None, - } - } -} - -// file: src/web/ws/voice/claims.rs -use crate::entity; - -#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)] -pub struct VoiceClaims { - pub user_id: entity::user::Id, - pub server_id: entity::server::Id, - pub channel_id: entity::channel::Id, - pub exp: i64, -} - -// file: src/web/ws/voice/connection.rs -use axum::extract::ws::Message as AxumMessage; -use futures::{Stream, StreamExt}; -use tokio::sync::{mpsc, oneshot}; - -use super::error::{self, Error as WsError}; -use super::protocol::{WsClientMessage, WsServerMessage}; -use crate::jwt; -use crate::state::AppState; -use crate::web::ws; -use crate::web::ws::general::WebSocketHandler; -use crate::web::ws::util::{SendWsMessage, deserialize_ws_message}; -use crate::web::ws::voice::claims::VoiceClaims; -use crate::web::ws::voice::protocol::WsServerMessage::SdpAnswer; -use crate::web::ws::voice::state::{WsContext, WsState}; -use crate::webrtc::{Offer, OfferSignal, WebRtcSignal}; - -impl WebSocketHandler for WsContext { - type ServerMessage = WsServerMessage; - type ClientMessage = WsClientMessage; - type Error = WsError; - - async fn handle_stream( - &mut self, - stream: S, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, - ) -> ws::error::Result<(), Self::Error> - where - S: Stream> + Unpin, - { - process_websocket_messages(self, stream, sender, app_state).await?; - - Ok(()) - } - - async fn cleanup(&mut self, app_state: &AppState) { - tracing::debug!("Cleaning up WebSocket connection."); - - match &self.connection_state { - WsState::Connected { - signal_channel, - server_id, - channel_id, - user_id, - .. - } => { - ws::gateway::util::send_message_channel( - app_state.clone(), - *channel_id, - ws::gateway::event::Event::VoiceChannelDisconnected { - server_id: *server_id, - channel_id: *channel_id, - user_id: *user_id, - }, - ); - - let _ = signal_channel.send(WebRtcSignal::Disconnect(user_id.clone())); - }, - WsState::Initialize => {}, - } - } - - async fn handle_result_error( - &mut self, - error: Self::Error, - sender: &mpsc::UnboundedSender>, - ) { - tracing::error!("WebSocket error: {:?}", error); - } -} - -#[tracing::instrument(skip_all)] -async fn process_websocket_messages( - context: &mut WsContext, - mut ws_stream: S, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, -) -> ws::error::Result<(), WsError> -where - S: Stream> + Unpin, -{ - loop { - match &context.connection_state { - WsState::Initialize => { - while let Some(Ok(message)) = ws_stream.next().await { - handle_initial_message(context, message, sender, &app_state).await?; - break; - } - }, - WsState::Connected { signal_channel, .. } => { - let signal_channel = signal_channel.clone(); - loop { - tokio::select! { - biased; - _ = signal_channel.closed() => { - tracing::debug!("Signal channel closed."); - break; - } - Some(Ok(message)) = ws_stream.next() => { - handle_connected_message(context, message, sender, &app_state).await?; - } - else => { - break; - } - } - } - - return Err(ws::error::Error::WebSocketClosed); - }, - } - } -} - -#[tracing::instrument(skip_all)] -async fn handle_initial_message( - context: &mut WsContext, - message: AxumMessage, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, -) -> ws::error::Result<(), error::Error> { - match deserialize_ws_message(message)? { - WsClientMessage::Authenticate { token } => match jwt::verify_jwt::( - &token, - crate::config::config().security.voice_secret.as_ref(), - ) { - Ok(claims) => { - SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateAccepted) - .await?; - - let signal_channel = - ws::voice::state::get_signaling_channel(app_state, claims.channel_id).await; - - context.connection_state = WsState::Connected { - signal_channel, - server_id: claims.server_id, - channel_id: claims.channel_id, - user_id: claims.user_id, - }; - - ws::gateway::util::send_message_channel( - app_state.clone(), - claims.channel_id, - ws::gateway::event::Event::VoiceChannelConnected { - server_id: claims.server_id, - channel_id: claims.channel_id, - user_id: claims.user_id, - }, - ); - - Ok(()) - }, - Err(auth_err) => { - tracing::warn!("Authentication failed: {:?}", auth_err); - - let _ = - SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateDenied) - .await; - Err(error::Error::AuthenticationFailed.into()) - }, - }, - #[allow(unreachable_patterns)] - _ => { - tracing::warn!("Unexpected message type received during Initialize state."); - Err(ws::error::Error::WrongMessageType) - }, - } -} - -#[tracing::instrument(skip_all)] -async fn handle_connected_message( - context: &mut WsContext, - message: AxumMessage, - sender: &mpsc::UnboundedSender>, - app_state: &AppState, -) -> ws::error::Result<(), error::Error> { - match deserialize_ws_message(message)? { - WsClientMessage::SdpOffer { sdp } => { - let (signal_channel, user_id) = match &mut context.connection_state { - WsState::Connected { - signal_channel, - user_id, - .. - } => (signal_channel.clone(), *user_id), - _ => return Err(ws::error::Error::WrongMessageType), - }; - - let (tx, rx) = oneshot::channel(); - - let _ = signal_channel.send(WebRtcSignal::Offer(OfferSignal { - offer: Offer { - peer_id: user_id, - sdp_offer: sdp, - }, - response: tx, - })); - - let answer_signal = rx.await?; - - sender - .send(SendWsMessage::new_no_response(SdpAnswer { - sdp: answer_signal.sdp_answer, - })) - .map_err(|_| ws::error::Error::WebSocketClosed)?; - - Ok(()) - }, - other_message => { - tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state."); - Err(ws::error::Error::WrongMessageType) - }, - } -} - -// file: src/web/ws/voice/error.rs -use crate::web::ws::error::CustomError; - -pub type Result = std::result::Result; - -#[derive(Debug, derive_more::From, derive_more::Display)] -pub enum Error { - AuthenticationFailed, -} - -impl CustomError for Error {} - -// file: src/web/ws/voice/mod.rs -pub mod claims; -mod connection; -mod error; -mod protocol; -mod state; - -use axum::extract::{State, WebSocketUpgrade}; -use axum::response::IntoResponse; - -use crate::state::AppState; -use crate::web::ws::general; -use crate::web::ws::voice::state::WsContext; - -pub async fn ws_handler( - State(app_state): State, - ws: WebSocketUpgrade, -) -> crate::web::error::Result { - Ok(ws.on_upgrade(|socket| { - general::handle_websocket_connection(socket, app_state, WsContext::default()) - })) -} - -// file: src/web/ws/voice/protocol.rs -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; - -#[derive(Debug, serde::Serialize)] -#[serde(tag = "type", content = "data")] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum WsServerMessage { - AuthenticateDenied, - - AuthenticateAccepted, - - #[serde(rename_all = "camelCase")] - SdpAnswer { - sdp: RTCSessionDescription, - }, -} - -#[derive(Debug, serde::Deserialize)] -#[serde(tag = "type", content = "data")] -#[serde(rename_all = "SCREAMING_SNAKE_CASE")] -pub enum WsClientMessage { - #[serde(rename_all = "camelCase")] - Authenticate { token: String }, - - #[serde(rename_all = "camelCase")] - SdpOffer { sdp: RTCSessionDescription }, -} - -// file: src/web/ws/voice/state.rs -use tokio::sync::oneshot; -use tokio::sync::mpsc; - -use crate::entity; -use crate::state::AppState; -use crate::webrtc::{OfferSignal, WebRtcSignal}; - -#[derive(Debug)] -pub enum WsState { - Initialize, - Connected { - signal_channel: mpsc::UnboundedSender, - server_id: entity::server::Id, - channel_id: entity::channel::Id, - user_id: entity::user::Id, - }, -} - -pub struct WsContext { - pub connection_state: WsState, -} - -impl Default for WsContext { - fn default() -> Self { - Self { - connection_state: WsState::Initialize, - } - } -} - -pub async fn get_signaling_channel( - app_state: &AppState, - channel_id: entity::channel::Id, -) -> mpsc::UnboundedSender { - let room_sender = { - app_state - .voice_rooms - .read() - .await - .get(&channel_id) - .map(|room| room.clone()) - }; - - match room_sender { - Some(room) => room, - None => { - let (tx, rx) = mpsc::unbounded_channel(); - - let app_state_ = app_state.clone(); - tokio::spawn(async move { - crate::webrtc::webrtc_task(channel_id, rx) - .await - .unwrap_or_else(|err| { - tracing::error!("webrtc task error: {:?}", err); - }); - - { - app_state_.unregister_voice_room(channel_id).await; - } - }); - - { - app_state.register_voice_room(channel_id, tx.clone()).await; - } - - tx - }, - } -} - -// file: src/webrtc/mod.rs -use std::sync::Arc; - -use dashmap::DashMap; -use tracing::Instrument; -use webrtc::api::interceptor_registry::register_default_interceptors; -use webrtc::api::media_engine::MIME_TYPE_OPUS; -use webrtc::api::{API, APIBuilder}; -use webrtc::interceptor::registry::Registry; -use webrtc::peer_connection::configuration::RTCConfiguration; -use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState; -use webrtc::peer_connection::sdp::session_description::RTCSessionDescription; -use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType}; -use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP; -use webrtc::track::track_local::{TrackLocal, TrackLocalWriter}; -use webrtc::track::track_remote::TrackRemote; - -use crate::entity; - -type PeerId = entity::user::Id; -type RoomId = entity::channel::Id; - -struct PeerState { - peer_id: PeerId, - peer_connection: Arc, - outgoing_audio_track: Arc, - ssrc: u32, -} - -struct RoomState { - room_id: RoomId, - peers: DashMap, - close_signal: tokio::sync::mpsc::UnboundedSender<()>, -} - -#[derive(Debug)] -pub struct Offer { - pub peer_id: PeerId, - pub sdp_offer: RTCSessionDescription, -} - -#[derive(Debug)] -pub struct AnswerSignal { - pub sdp_answer: RTCSessionDescription, -} - -#[derive(Debug)] -pub enum WebRtcSignal { - Offer(OfferSignal), - Disconnect(PeerId), - RequestPeers { - response: tokio::sync::oneshot::Sender>, - }, - Close, -} - -#[derive(Debug)] -pub struct OfferSignal { - pub offer: Offer, - pub response: tokio::sync::oneshot::Sender, -} - -#[tracing::instrument(skip(signal))] -pub async fn webrtc_task( - room_id: RoomId, - signal: tokio::sync::mpsc::UnboundedReceiver, -) -> anyhow::Result<()> { - tracing::info!("Starting WebRTC task"); - - let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel(); - - let mut skip_timeout = false; - - let state = Arc::new(RoomState { - room_id, - peers: DashMap::new(), - close_signal, - }); - - let mut signal = signal; - let mut media_engine = webrtc::api::media_engine::MediaEngine::default(); - media_engine.register_default_codecs()?; - let mut registry = Registry::new(); - registry = register_default_interceptors(registry, &mut media_engine)?; - - let api = APIBuilder::new() - .with_media_engine(media_engine) - .with_interceptor_registry(registry) - .build(); - - let api = Arc::new(api); - - loop { - tokio::select! { - biased; - _ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => { - tracing::debug!("initial timeout reached"); - break; - } - _ = close_receiver.recv() => { - tracing::debug!("WebRTC task stopped"); - break; - } - Some(signal) = signal.recv() => { - skip_timeout = true; - match signal { - WebRtcSignal::Offer(offer_signal) => { - let room_state = state.clone(); - let api = api.clone(); - - tokio::spawn(async move { - if let Err(e) = handle_peer(api, room_state, offer_signal).await { - tracing::error!("error handling peer: {}", e); - } - }.instrument(tracing::Span::current())); - } - WebRtcSignal::RequestPeers { response } => { - let peers = state - .peers - .iter() - .map(|pair| pair.key().clone()) - .collect::>(); - - let _ = response.send(peers); - } - WebRtcSignal::Disconnect(peer_id) => { - tracing::debug!("received disconnect signal for peer {}", peer_id); - cleanup_peer(state.clone(), peer_id).await; - } - WebRtcSignal::Close => { - break; - } - } - } - } - } - - Ok(()) -} - -#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id -))] -async fn handle_peer( - api: Arc, - room_state: Arc, - offer_signal: OfferSignal, -) -> anyhow::Result<()> { - tracing::debug!("handling peer"); - - let config = RTCConfiguration { - ..Default::default() - }; - - let outgoing_track = Arc::new(TrackLocalStaticRTP::new( - RTCRtpCodecCapability { - mime_type: MIME_TYPE_OPUS.to_string(), - ..Default::default() - }, - format!("audio-{}", offer_signal.offer.peer_id), - format!("room-{}", room_state.room_id), - )); - - let peer_connection = Arc::new(api.new_peer_connection(config).await?); - - let rtp_sender = peer_connection - .add_track(Arc::clone(&outgoing_track) as Arc) - .await?; - - tracing::debug!("added track to peer connection: {:?}", rtp_sender); - - // Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI) - let outgoing_track_ = Arc::clone(&outgoing_track); - tokio::spawn( - async move { - let mut rtcp_buf = vec![0u8; 1500]; - while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await { - // Process RTCP if needed (e.g., bandwidth estimation, custom feedback) - } - tracing::debug!( - "RTCP reader loop for outgoing track {} ended", - outgoing_track_.id() - ); - } - .instrument(tracing::Span::current()), - ); - - let room_state_ = room_state.clone(); - peer_connection.on_peer_connection_state_change(Box::new(move |state| { - let room_state_ = room_state_.clone(); - Box::pin(async move { - tracing::debug!("peer connection state changed: {:?}", state); - if state == RTCPeerConnectionState::Disconnected - || state == RTCPeerConnectionState::Failed - { - tracing::debug!("peer connection closed"); - cleanup_peer(room_state_, offer_signal.offer.peer_id).await; - } - }) - })); - - let room_state_ = room_state.clone(); - peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| { - tracing::debug!("track received: {:?}", track); - - let room_state_ = room_state_.clone(); - Box::pin(async move { - if track.kind() == RTPCodecType::Audio { - tracing::debug!("audio track received: {:?}", track); - tokio::spawn(async move { - if let Err(e) = - forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await - { - tracing::error!("error forwarding audio track: {}", e); - } - }); - } else { - tracing::warn!("received non-audio track: {:?}", track); - } - }) - })); - - peer_connection - .set_remote_description(offer_signal.offer.sdp_offer) - .await?; - let answer = peer_connection.create_answer(None).await?; - - let mut gathering_complete = peer_connection.gathering_complete_promise().await; - - peer_connection.set_local_description(answer).await?; - - gathering_complete.recv().await; - - tracing::debug!("ICE gathering complete"); - - let local_description = peer_connection - .local_description() - .await - .ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?; - - let peer_state = PeerState { - peer_id: offer_signal.offer.peer_id, - peer_connection: Arc::clone(&peer_connection), - outgoing_audio_track: Arc::clone(&outgoing_track), - ssrc: 0, - }; - - { - if let Some((_, old_peer)) = room_state.peers.remove(&offer_signal.offer.peer_id) { - let _ = old_peer.peer_connection.close().await; - } - } - - { - room_state - .peers - .insert(offer_signal.offer.peer_id, peer_state); - } - - let _ = offer_signal.response.send(AnswerSignal { - sdp_answer: local_description, - }); - - Ok(()) -} - -#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id -))] -async fn forward_audio_track( - room_state: Arc, - peer_id: PeerId, - track: Arc, -) -> anyhow::Result<()> { - let mut rtp_buf = vec![0u8; 1500]; - - { - let mut peer_state = room_state - .peers - .get_mut(&peer_id) - .expect("peer state not found"); - - peer_state.ssrc = track.ssrc(); - } - - while let Ok((rtp_packet, _attr)) = track.read(&mut rtp_buf).await { - let other_peer_tracks = room_state - .peers - .iter() - .filter_map(|pair| { - let peer_state = pair.value(); - if peer_state.peer_id != peer_id { - let mut rtp_packet = rtp_packet.clone(); - rtp_packet.header.ssrc = peer_state.ssrc; - - Some((peer_state.outgoing_audio_track.clone(), rtp_packet)) - } else { - None - } - }) - .collect::>(); - - if other_peer_tracks.is_empty() { - // tracing::warn!("no other peers to forward audio track to"); - continue; - } - - let write_futures = other_peer_tracks - .iter() - .map(|(outgoing_track, packet)| outgoing_track.write_rtp(&packet)); - - let results = futures::future::join_all(write_futures).await; - for result in results { - if let Err(e) = result { - tracing::error!("error writing RTP packet: {}", e); - } - } - } - tracing::debug!( - "RTP Read loop ended for track {} from peer {}", - track.id(), - peer_id - ); - - Ok(()) -} - -#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))] -async fn cleanup_peer(room_state: Arc, peer_id: PeerId) { - tracing::debug!("cleaning up peer"); - - if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) { - tracing::debug!("removed peer"); - let pc = Arc::clone(&peer_state.peer_connection); - tokio::spawn(async move { - if let Err(e) = pc.close().await { - if !matches!(e, webrtc::Error::ErrConnectionClosed) { - tracing::warn!("error closing peer connection: {}", e); - } - } - - tracing::debug!("peer connection closed"); - }); - } - - if room_state.peers.is_empty() { - tracing::debug!("no more peers in room, closing room"); - let _ = room_state.close_signal.send(()); - } -}