Files
diplom/backend.txt
2025-06-03 11:42:51 +03:00

5063 lines
140 KiB
Plaintext

// file: migrations/20250510180101_uuidv7.sql
-- Provides uuid_generate_v7(), used as the DEFAULT for the UUID primary keys created
-- by the later migrations.
CREATE EXTENSION IF NOT EXISTS "pg_uuidv7";
// file: migrations/20250510182916_file.sql
-- Metadata for a stored file. "size" is a byte count (NOTE(review): unit assumed; confirm
-- against the upload handler).
CREATE TABLE IF NOT EXISTS "file"
(
    "id"           UUID    DEFAULT uuid_generate_v7() NOT NULL PRIMARY KEY,
    "filename"     VARCHAR NOT NULL,
    "content_type" VARCHAR NOT NULL,
    "size"         BIGINT  NOT NULL
);
// file: migrations/20250510183102_user.sql
-- Accounts. "username" is unique; "email" is NOT unique at the database level.
-- "settings" is a free-form JSONB document defaulting to an empty object.
CREATE TABLE IF NOT EXISTS "user"
(
    "id"            UUID    DEFAULT uuid_generate_v7() NOT NULL PRIMARY KEY,
    "avatar_id"     UUID    REFERENCES "file" ("id") ON DELETE SET NULL,
    "username"      VARCHAR NOT NULL UNIQUE,
    "display_name"  VARCHAR,
    "email"         VARCHAR NOT NULL,
    "password_hash" VARCHAR NOT NULL,
    "bot"           BOOL    DEFAULT FALSE NOT NULL,
    "system"        BOOL    DEFAULT FALSE NOT NULL,
    "settings"      JSONB   DEFAULT CAST('{}' AS JSONB) NOT NULL
);
-- Directed relation from "user_id" to "other_id"; "type" is an application-defined INT2 code.
-- "updated_at" is refreshed on UPDATE by the trg_user_relation_update trigger.
CREATE TABLE IF NOT EXISTS "user_relation"
(
    "user_id"    UUID        NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    "other_id"   UUID        NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    "type"       INT2        NOT NULL,
    "created_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
    "updated_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
    PRIMARY KEY ("user_id", "other_id"),
    -- A user cannot relate to themselves. Declared once as a table constraint: a
    -- cross-column CHECK does not belong on a single column, and repeating it on both
    -- columns created two identical constraints.
    CONSTRAINT "user_relation_check" CHECK ("user_id" <> "other_id")
);
-- The primary key only serves lookups by "user_id"; reverse lookups ("who relates to X")
-- filter on "other_id" and need their own index.
CREATE INDEX IF NOT EXISTS "user_relation_other_id_idx" ON "user_relation" ("other_id");
-- Seed the built-in "system" account (bot + system flags set). Its password_hash is the
-- empty string (NOTE(review): presumably never matches a real hash, so it cannot log in — confirm).
INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system")
SELECT 'system', 'System', 'system@lionarius.ru', '', TRUE, TRUE;
-- Trigger function for "user": rejects a non-NULL avatar_id whose "file" row is missing
-- or whose content_type does not start with 'image/'.
CREATE OR REPLACE FUNCTION check_avatar_is_image()
    RETURNS TRIGGER AS
$$
DECLARE
    file_content_type VARCHAR;
BEGIN
    -- Skip check if avatar_id is null (no avatar set / avatar cleared).
    IF NEW.avatar_id IS NULL THEN
        RETURN NEW;
    END IF;
    -- On UPDATE, only validate when avatar_id actually changes, so unrelated updates of
    -- the user row (e.g. display_name) do not pay for a "file" lookup. The nested IF keeps
    -- OLD from being evaluated on INSERT.
    IF TG_OP = 'UPDATE' THEN
        IF NEW.avatar_id IS NOT DISTINCT FROM OLD.avatar_id THEN
            RETURN NEW;
        END IF;
    END IF;
    -- Retrieve content_type from file table (NULL if the file does not exist).
    SELECT f.content_type
    INTO file_content_type
    FROM "file" AS f
    WHERE f.id = NEW.avatar_id;
    -- Raise exception if content_type does not start with 'image/'
    IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN
        RAISE EXCEPTION 'avatar_id must reference a file with content_type starting with image/';
    END IF;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Validate "user".avatar_id on every insert and update.
-- NOTE: the name says "icon" although it guards avatar_id; renaming it would need a
-- follow-up migration, so it is kept as is.
CREATE TRIGGER trigger_check_icon_is_image
    BEFORE UPDATE OR INSERT ON "user"
    FOR EACH ROW EXECUTE FUNCTION check_avatar_is_image();
-- BEFORE-trigger function: on UPDATE, stamps updated_at with the transaction timestamp.
-- Any other operation passes the row through unchanged.
CREATE OR REPLACE FUNCTION fn_on_user_relation_update()
    RETURNS TRIGGER
AS
$$
BEGIN
    IF TG_OP <> 'UPDATE' THEN
        RETURN NEW;
    END IF;
    NEW.updated_at := CURRENT_TIMESTAMP;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Keep user_relation.updated_at current whenever a relation row is updated.
CREATE TRIGGER trg_user_relation_update
    BEFORE UPDATE ON "user_relation"
    FOR EACH ROW EXECUTE FUNCTION fn_on_user_relation_update();
// file: migrations/20250510183125_server.sql
-- Servers. owner_id has no ON DELETE action, so deleting a user who still owns a server
-- is rejected by the foreign key. icon_id is validated by check_icon_is_image().
CREATE TABLE IF NOT EXISTS "server"
(
    "id"       UUID    DEFAULT uuid_generate_v7() NOT NULL PRIMARY KEY,
    "owner_id" UUID    NOT NULL REFERENCES "user" ("id"),
    "name"     VARCHAR NOT NULL,
    "icon_id"  UUID    REFERENCES "file" ("id") ON DELETE SET NULL
);
-- Roles defined within a server; removed together with their server.
CREATE TABLE IF NOT EXISTS "server_role"
(
    "id"          UUID    DEFAULT uuid_generate_v7() NOT NULL PRIMARY KEY,
    "server_id"   UUID    NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
    "name"        VARCHAR NOT NULL,
    "color"       VARCHAR,
    "display"     BOOL    DEFAULT FALSE NOT NULL,
    "permissions" JSONB   DEFAULT CAST('{}' AS JSONB) NOT NULL,
    "position"    INT2    NOT NULL
);
-- Membership of a user in a server; at most one row per (server, user) pair.
CREATE TABLE IF NOT EXISTS "server_member"
(
    "id"         UUID    DEFAULT uuid_generate_v7() NOT NULL PRIMARY KEY,
    "server_id"  UUID    NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
    "user_id"    UUID    NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    "nickname"   VARCHAR,
    "avatar_url" VARCHAR,
    UNIQUE ("server_id", "user_id")
);
-- Role assignments (member <-> role). Same-server consistency is enforced by
-- the enforce_server_user_role_server_id trigger.
CREATE TABLE IF NOT EXISTS "server_member_role"
(
    "member_id" UUID NOT NULL REFERENCES "server_member" ("id") ON DELETE CASCADE,
    "role_id"   UUID NOT NULL REFERENCES "server_role" ("id") ON DELETE CASCADE,
    PRIMARY KEY ("member_id", "role_id")
);
-- Invite codes for a server. inviter_id is nulled if the inviting user is deleted;
-- a NULL expires_at presumably means "never expires" (NOTE(review): confirm in app code).
CREATE TABLE IF NOT EXISTS "server_invite"
(
    "code"       VARCHAR                  NOT NULL PRIMARY KEY,
    "server_id"  UUID                     NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
    "inviter_id" UUID                     REFERENCES "user" ("id") ON DELETE SET NULL,
    "expires_at" TIMESTAMP WITH TIME ZONE
);
-- Trigger function for "server": rejects a non-NULL icon_id whose "file" row is missing
-- or whose content_type does not start with 'image/'.
CREATE OR REPLACE FUNCTION check_icon_is_image()
    RETURNS TRIGGER AS
$$
DECLARE
    file_content_type VARCHAR;
BEGIN
    -- Skip check if icon_id is null
    IF NEW.icon_id IS NULL THEN
        RETURN NEW;
    END IF;
    -- On UPDATE, only validate when icon_id actually changes (e.g. a server rename does not
    -- need a "file" lookup). The nested IF keeps OLD from being evaluated on INSERT.
    IF TG_OP = 'UPDATE' THEN
        IF NEW.icon_id IS NOT DISTINCT FROM OLD.icon_id THEN
            RETURN NEW;
        END IF;
    END IF;
    -- Retrieve content_type from file table (NULL if the file does not exist).
    SELECT f.content_type
    INTO file_content_type
    FROM "file" AS f
    WHERE f.id = NEW.icon_id;
    -- Raise exception if content_type does not start with 'image/'
    IF file_content_type IS NULL OR file_content_type NOT LIKE 'image/%' THEN
        RAISE EXCEPTION 'icon_id must reference a file with content_type starting with image/';
    END IF;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Validate "server".icon_id on every insert and update.
CREATE TRIGGER trigger_check_icon_is_image
    BEFORE UPDATE OR INSERT ON "server"
    FOR EACH ROW EXECUTE FUNCTION check_icon_is_image();
-- Trigger function for server_member_role: rejects assigning a role that belongs to a
-- different server than the member it is assigned to.
CREATE OR REPLACE FUNCTION check_server_user_role_server_id()
    RETURNS TRIGGER AS
$$
DECLARE
    member_server_id UUID;
    role_server_id   UUID;
BEGIN
    -- Server of the member being assigned the role.
    SELECT sm.server_id INTO member_server_id FROM server_member AS sm WHERE sm.id = NEW.member_id;
    -- Server that owns the role.
    SELECT sr.server_id INTO role_server_id FROM server_role AS sr WHERE sr.id = NEW.role_id;
    -- If either lookup found nothing, the comparison is NULL and the row passes here;
    -- the member_id / role_id foreign keys then reject it.
    IF member_server_id <> role_server_id THEN
        RAISE EXCEPTION 'Cannot assign role from a different server: server_user server_id (%) does not match server_role server_id (%)', member_server_id, role_server_id;
    END IF;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Check member/role server consistency on every role assignment insert or update.
CREATE TRIGGER enforce_server_user_role_server_id
    BEFORE UPDATE OR INSERT ON server_member_role
    FOR EACH ROW EXECUTE FUNCTION check_server_user_role_server_id();
// file: migrations/20250510184011_channel_message.sql
-- Channels of every kind; "type" holds the ChannelType code used by the backend.
-- server_id/parent apply to server channels, owner_id to group channels.
CREATE TABLE IF NOT EXISTS "channel"
(
    "id"        UUID     DEFAULT uuid_generate_v7() NOT NULL PRIMARY KEY,
    "name"      VARCHAR  NOT NULL,
    "type"      SMALLINT NOT NULL,
    "position"  SMALLINT NOT NULL,
    "server_id" UUID     REFERENCES "server" ("id") ON DELETE CASCADE,  -- server channels only
    "parent"    UUID     REFERENCES "channel" ("id") ON DELETE SET NULL, -- server channels only
    "owner_id"  UUID     REFERENCES "user" ("id") ON DELETE SET NULL     -- group channels only
);
-- Participants of DM and group channels (server channels use server_member instead).
CREATE TABLE IF NOT EXISTS "channel_recipient"
(
    "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE,
    "user_id"    UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    PRIMARY KEY ("channel_id", "user_id")
);
-- Chat messages; deleted together with their author or their channel.
CREATE TABLE IF NOT EXISTS "message"
(
    "id"         UUID DEFAULT uuid_generate_v7() NOT NULL PRIMARY KEY,
    "author_id"  UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE,
    "content"    TEXT NOT NULL
);
-- Files attached to a message; "order" gives the attachment's position within the message.
CREATE TABLE IF NOT EXISTS "message_attachment"
(
    "message_id" UUID     NOT NULL REFERENCES "message" ("id") ON DELETE CASCADE,
    "file_id"    UUID     NOT NULL REFERENCES "file" ("id") ON DELETE CASCADE,
    "order"      SMALLINT NOT NULL,
    PRIMARY KEY ("message_id", "file_id")
);
-- Added only now because "channel" and "message" reference each other.
ALTER TABLE "channel"
    ADD COLUMN "last_message_id" UUID
        REFERENCES "message" ("id") ON DELETE SET NULL;
-- Returns the id of the channel of type channel_type whose recipients are exactly
-- user1_id and user2_id, creating it (name "<user1>-<user2>", position 0) when absent.
-- Raises if either id is NULL, both ids are equal, or either user does not exist.
CREATE OR REPLACE FUNCTION create_dm_channel(
    user1_id UUID,
    user2_id UUID,
    channel_type INT2
)
    RETURNS UUID
    LANGUAGE plpgsql
AS
$$
DECLARE
    new_channel_id UUID;
    channel_name   VARCHAR;
BEGIN
    -- Validate parameters
    IF user1_id IS NULL OR user2_id IS NULL THEN
        RAISE EXCEPTION 'Both user IDs must be provided';
    END IF;
    IF user1_id = user2_id THEN
        RAISE EXCEPTION 'Cannot create a DM channel with the same user';
    END IF;
    -- Check if users exist
    IF NOT exists (SELECT 1 FROM "user" WHERE id = user1_id) OR
       NOT exists (SELECT 1 FROM "user" WHERE id = user2_id) THEN
        RAISE EXCEPTION 'One or both users do not exist';
    END IF;
    -- Lock both user rows (in id order, to avoid deadlocks) until the transaction ends so
    -- concurrent calls for the same pair cannot both miss the lookup below and each create
    -- a channel. FOR NO KEY UPDATE does not block foreign-key checks against these rows.
    PERFORM 1
    FROM "user"
    WHERE id IN (user1_id, user2_id)
    ORDER BY id
    FOR NO KEY UPDATE;
    -- Reuse an existing channel of this type whose only recipients are these two users.
    -- A single lookup replaces the former EXISTS probe followed by an identical SELECT;
    -- ORDER BY makes the pick deterministic if duplicates already exist.
    SELECT c.id
    INTO new_channel_id
    FROM channel c
             JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id
             JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id
    WHERE c.type = channel_type
      AND (SELECT count(*) FROM channel_recipient cr WHERE cr.channel_id = c.id) = 2
    ORDER BY c.id
    LIMIT 1;
    IF FOUND THEN
        RAISE NOTICE 'DM channel already exists between these users with ID: %', new_channel_id;
        RETURN new_channel_id;
    END IF;
    -- Generate channel name (conventionally uses user IDs in DMs)
    channel_name := concat(user1_id, '-', user2_id);
    -- Create new channel
    INSERT INTO "channel" ("name", "type", "position")
    VALUES (channel_name, channel_type, 0)
    RETURNING id INTO new_channel_id;
    -- Add both users as recipients
    INSERT INTO "channel_recipient" ("channel_id", "user_id")
    VALUES (new_channel_id, user1_id),
           (new_channel_id, user2_id);
    RAISE NOTICE 'DM channel created with ID: %', new_channel_id;
    RETURN new_channel_id;
END;
$$;
-- Creates a group channel named 'Group', owned by p_creator_id, whose recipients are the
-- creator plus every non-NULL id in p_recipient_user_ids (deduplicated). Returns its id.
-- Raises if the creator is NULL, the array is NULL/empty, or fewer than 2 distinct members
-- result.
CREATE OR REPLACE FUNCTION create_group_channel(
    p_creator_id UUID,           -- The user initiating the group creation (becomes owner_id)
    p_recipient_user_ids UUID[], -- Members of the group; the creator may or may not be included
    p_group_channel_type INT2    -- The channel type code stored for the group
)
    RETURNS UUID
    LANGUAGE plpgsql
AS
$$
DECLARE
    new_channel_id UUID;
    member_ids     UUID[];
    num_members    INT;
    channel_name   CONSTANT VARCHAR := 'Group';
BEGIN
    IF p_creator_id IS NULL THEN
        RAISE EXCEPTION 'Creator user ID must be provided.';
    END IF;
    IF p_recipient_user_ids IS NULL OR array_length(p_recipient_user_ids, 1) IS NULL THEN
        RAISE EXCEPTION 'Recipient user IDs array must be provided and not empty.';
    END IF;
    -- Build the member list once: creator + recipients, NULLs dropped, deduplicated and
    -- sorted. The creator used to be inserted separately and then again from the array,
    -- which raised a primary-key violation whenever the array contained the creator.
    SELECT array_agg(DISTINCT m.member_id ORDER BY m.member_id)
    INTO member_ids
    FROM (SELECT p_creator_id AS member_id
          UNION ALL
          SELECT r.member_id FROM unnest(p_recipient_user_ids) AS r(member_id)) AS m
    WHERE m.member_id IS NOT NULL;
    num_members := coalesce(array_length(member_ids, 1), 0);
    -- A group needs the creator plus at least one other member.
    IF num_members < 2 THEN
        RAISE EXCEPTION 'Group channels (type %) must have at least 2 recipients. Found %.', p_group_channel_type, num_members;
    END IF;
    -- Create new group channel
    INSERT INTO "channel" ("name", "type", "position", "owner_id", "server_id", "parent")
    VALUES (channel_name,
            p_group_channel_type,
            0,            -- Default position
            p_creator_id,
            NULL,         -- Not a server channel
            NULL          -- Not a nested server channel
           )
    RETURNING id INTO new_channel_id;
    -- Add every member (creator included) exactly once
    INSERT INTO "channel_recipient" ("channel_id", "user_id")
    SELECT new_channel_id, member_id
    FROM unnest(member_ids) AS u(member_id);
    RAISE NOTICE 'Group channel (type %) named "%" created with ID: % by owner % for recipients: %',
        p_group_channel_type, channel_name, new_channel_id, p_creator_id, member_ids;
    RETURN new_channel_id;
END;
$$;
// file: migrations/20250517190855_util.sql
-- Returns every user with visibility of target_user_id: either side of a user_relation
-- row involving the target, a co-member of any server the target belongs to, or a
-- co-recipient of any DM/group channel the target is in. UNION removes duplicates.
-- All column references are qualified so they cannot clash with the OUT column "user_id".
CREATE OR REPLACE FUNCTION get_users_that_can_see_user(target_user_id UUID)
    RETURNS TABLE (user_id UUID)
AS
$$
BEGIN
    RETURN QUERY
        -- relations pointing at the target
        SELECT rel_in.user_id
        FROM user_relation AS rel_in
        WHERE rel_in.other_id = target_user_id
        UNION
        -- relations originating from the target
        SELECT rel_out.other_id AS user_id
        FROM user_relation AS rel_out
        WHERE rel_out.user_id = target_user_id
        UNION
        -- everyone else in a server the target is a member of
        SELECT peer_m.user_id
        FROM server_member AS peer_m
                 JOIN server_member AS mine_m ON peer_m.server_id = mine_m.server_id
        WHERE mine_m.user_id = target_user_id
          AND peer_m.user_id <> target_user_id
        UNION
        -- everyone else in a DM/group channel the target participates in
        SELECT peer_r.user_id
        FROM channel_recipient AS peer_r
                 JOIN channel_recipient AS mine_r ON peer_r.channel_id = mine_r.channel_id
        WHERE mine_r.user_id = target_user_id
          AND peer_r.user_id <> target_user_id;
END;
$$ LANGUAGE plpgsql;
// file: src/config.rs
use std::sync::OnceLock;
use serde::Deserialize;
/// Returns the process-wide configuration, loading it on first call.
///
/// Sources: the `config` file (any format the `config` crate recognises by extension),
/// overridden by environment variables with the `DIPLOM_` prefix.
///
/// NOTE(review): depending on the `config` crate version, a `_` separator may be appended
/// to this prefix (expecting `DIPLOM__...`), and without `.separator(..)` nested keys such
/// as `server.port` cannot be set from the environment — confirm overrides actually apply.
///
/// # Panics
/// If the sources cannot be read or do not deserialize into [`Config`].
pub fn config() -> &'static Config {
    // Builds the configuration from scratch; only ever invoked once via the OnceLock.
    fn load() -> Config {
        let sources = config::Config::builder()
            .add_source(config::File::with_name("config"))
            .add_source(config::Environment::with_prefix("DIPLOM_"))
            .build()
            .expect("config builder");
        sources.try_deserialize().expect("config deserialize")
    }

    static INSTANCE: OnceLock<Config> = OnceLock::new();
    INSTANCE.get_or_init(load)
}
/// Root application configuration, deserialized by [`config`].
#[derive(Deserialize)]
pub struct Config {
/// HTTP server binding and public hostname.
pub server: ServerConfig,
/// Secrets used for token signing.
pub security: SecurityConfig,
/// Gateway (realtime) settings.
pub gateway: GatewayConfig,
/// PostgreSQL connection settings.
pub database: DatabaseConfig,
/// Object storage settings.
pub object_store: ObjectStoreConfig,
}
/// Where the HTTP server listens and the URL it is reachable under.
#[derive(Deserialize)]
pub struct ServerConfig {
/// Public base URL of the server.
pub hostname: url::Url,
/// IPv4 address to bind to.
pub host: std::net::Ipv4Addr,
/// TCP port to bind to.
pub port: u16,
}
/// Signing secrets. Deliberately has no `Debug` derive, so secrets cannot be printed with `{:?}`.
#[derive(Deserialize)]
pub struct SecurityConfig {
/// Secret for auth tokens (NOTE(review): presumably passed to `jwt::generate_jwt`/`verify_jwt` — confirm).
pub auth_secret: String,
/// Secret for voice tokens.
pub voice_secret: String,
}
/// Gateway settings.
#[derive(Deserialize)]
pub struct GatewayConfig {
/// Lifetime of issued voice tokens, given in the config as a number of seconds.
#[serde(deserialize_with = "crate::util::deserialize_duration_seconds")]
pub voice_token_lifetime: std::time::Duration,
}
/// Database connection settings: either a complete URL, or the individual parts from
/// which `DatabaseConfig::url` assembles one. Variants are tried in order (`untagged`).
///
/// `Debug` is implemented by hand instead of derived so that the database password never
/// reaches log output when the configuration is printed with `{:?}`.
#[derive(Deserialize)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum DatabaseConfig {
    /// A complete connection URL, used verbatim.
    Url {
        url: url::Url,
    },
    /// Individual connection parameters.
    Full {
        driver: String,
        username: String,
        password: String,
        host: url::Host,
        port: u16,
        name: String,
    },
}

impl std::fmt::Debug for DatabaseConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Url { url } => {
                // A URL may embed credentials; mask its password component.
                let mut shown = url.clone();
                if shown.password().is_some() {
                    let _ = shown.set_password(Some("***"));
                }
                f.debug_struct("Url").field("url", &shown).finish()
            },
            Self::Full {
                driver,
                username,
                host,
                port,
                name,
                ..
            } => f
                .debug_struct("Full")
                .field("driver", driver)
                .field("username", username)
                .field("password", &"***")
                .field("host", host)
                .field("port", port)
                .field("name", name)
                .finish(),
        }
    }
}
/// Object storage settings (NOTE(review): fields suggest an S3-compatible store — confirm).
///
/// `Debug` is implemented by hand instead of derived so that `secret_key` is masked
/// whenever the configuration is printed with `{:?}`.
#[derive(Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ObjectStoreConfig {
    pub endpoint: url::Url,
    pub region: String,
    pub bucket: String,
    pub access_key: String,
    pub secret_key: String,
}

impl std::fmt::Debug for ObjectStoreConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ObjectStoreConfig")
            .field("endpoint", &self.endpoint)
            .field("region", &self.region)
            .field("bucket", &self.bucket)
            .field("access_key", &self.access_key)
            .field("secret_key", &"***")
            .finish()
    }
}
impl DatabaseConfig {
    /// Returns the connection URL, or `None` if the configured parts do not form a valid URL.
    ///
    /// For `Full`, the credentials and database name are applied through `url::Url`
    /// setters, which percent-encode them. Interpolating them raw (as before) produced a
    /// wrong or unparsable URL for values containing `@`, `:`, `/`, `?` or `#` — common in
    /// generated passwords.
    pub fn url(&self) -> Option<url::Url> {
        match self {
            Self::Url { url } => Some(url.clone()),
            Self::Full {
                driver,
                username,
                password,
                host,
                port,
                name,
            } => {
                // `url::Host`'s Display brackets IPv6 addresses, as the URL syntax requires.
                let mut url = url::Url::parse(&format!("{}://{}:{}", driver, host, port)).ok()?;
                url.set_username(username).ok()?;
                url.set_password(Some(password)).ok()?;
                // Database name as a single, percent-encoded path segment.
                url.path_segments_mut().ok()?.clear().push(name);
                Some(url)
            },
        }
    }
}
// file: src/database.rs
use crate::config::DatabaseConfig;
use crate::entity;
use crate::util::tristate::TriState;
/// Data-access layer: wraps the PostgreSQL pool that every query in this module runs on.
/// Cloning clones the `sqlx::PgPool` handle; `AsRef` exposes the pool directly.
#[derive(Clone, derive_more::AsRef)]
pub struct Database {
pool: sqlx::PgPool,
}
/// Result alias for database operations.
pub type Result<T> = std::result::Result<T, Error>;
/// Errors returned by [`Database`] methods.
#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
pub enum Error {
/// Applying embedded migrations failed.
#[from]
Migrate(sqlx::migrate::MigrateError),
/// Any other driver/database error.
#[from]
Sqlx(sqlx::Error),
/// No `"user"` row matched.
UserDoesNotExists,
/// Inserting a user hit a unique violation (SQLSTATE 23505).
UserAlreadyExists,
/// No `"server"` row matched.
ServerDoesNotExists,
/// The user is already a member of the server (unique violation on insert).
MemberAlreadyExists,
/// No `"channel"` row matched.
ChannelDoesNotExists,
/// No `"server_invite"` row matched the code.
InviteDoesNotExists,
/// No `"message"` row matched (NOTE(review): not produced by the visible methods).
MessageDoesNotExists,
/// No `"file"` row matched.
FileDoesNotExists,
}
impl Database {
pub async fn connect(config: &DatabaseConfig) -> sqlx::Result<Self> {
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
let pool = sqlx::postgres::PgPoolOptions::new()
.connect(config.url().ok_or(sqlx::Error::BeginFailed)?.as_str())
.await?;
MIGRATOR.run(&pool).await?;
Ok(Self { pool })
}
}
impl Database {
    /// Inserts a user and returns the stored row, including database-side defaults.
    ///
    /// # Errors
    /// [`Error::UserAlreadyExists`] on a unique violation (SQLSTATE `23505`, e.g. a taken
    /// `username`); any other failure as [`Error::Sqlx`].
    pub async fn insert_user(
        &self,
        username: &str,
        display_name: Option<&str>,
        email: &str,
        password_hash: &str,
    ) -> Result<entity::user::User> {
        let user = match sqlx::query_as!(
            entity::user::User,
            r#"INSERT INTO "user"("username", "display_name", "email", "password_hash") VALUES ($1, $2, $3, $4) RETURNING "user".*"#,
            username,
            display_name,
            email,
            password_hash
        )
        .fetch_one(&self.pool)
        .await {
            Ok(user) => user,
            Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
                return Err(Error::UserAlreadyExists);
            }
            Err(e) => return Err(e.into()),
        };
        Ok(user)
    }

    /// Fetches one user by id.
    ///
    /// # Errors
    /// [`Error::UserDoesNotExists`] if no row matches.
    pub async fn select_user_by_id(&self, user_id: entity::user::Id) -> Result<entity::user::User> {
        let user = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "id" = $1"#,
            user_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::UserDoesNotExists)?;
        Ok(user)
    }

    /// Partially updates a user and returns the resulting row.
    ///
    /// Each [`TriState`] argument is skipped when absent; otherwise its column is set to
    /// the contained value, or to `NULL` via `into_option`. When both arguments are absent
    /// no `UPDATE` is issued and the current row is returned.
    ///
    /// # Errors
    /// [`Error::UserDoesNotExists`] if no row has `user_id`.
    pub async fn update_user_by_id(
        &self,
        user_id: entity::user::Id,
        display_name: TriState<&str>,
        avatar_id: TriState<entity::file::Id>,
    ) -> Result<entity::user::User> {
        // With nothing to set, the builder would emit `UPDATE "user" SET  WHERE ...`,
        // which PostgreSQL rejects as a syntax error.
        if display_name.is_absent() && avatar_id.is_absent() {
            return self.select_user_by_id(user_id).await;
        }
        let mut query_builder = sqlx::QueryBuilder::new(r#"UPDATE "user" SET "#);
        let mut separated = query_builder.separated(", ");
        if !display_name.is_absent() {
            separated.push(r#""display_name" = "#);
            separated.push_bind_unseparated(display_name.into_option());
        }
        if !avatar_id.is_absent() {
            separated.push(r#""avatar_id" = "#);
            separated.push_bind_unseparated(avatar_id.into_option());
        }
        query_builder.push(" WHERE \"id\" = ");
        query_builder.push_bind(user_id);
        query_builder.push(" RETURNING \"user\".*");
        // `fetch_optional` so a missing user maps to the same error as the other lookups
        // instead of surfacing as `sqlx::Error::RowNotFound`.
        let user: Option<entity::user::User> = query_builder
            .build_query_as()
            .fetch_optional(&self.pool)
            .await?;
        let user = user.ok_or(Error::UserDoesNotExists)?;
        Ok(user)
    }

    /// Fetches all users whose id is in `user_ids`; unknown ids are silently skipped and
    /// the order of the result is unspecified.
    pub async fn select_users_by_ids(
        &self,
        user_ids: &[entity::user::Id],
    ) -> Result<Vec<entity::user::User>> {
        // Nothing to look up. (The former `IN (...)` builder rendered an empty slice as
        // `IN ()`, which is invalid SQL.)
        if user_ids.is_empty() {
            return Ok(Vec::new());
        }
        // A single array parameter keeps the SQL text constant and cannot exceed
        // PostgreSQL's bind-parameter limit for long lists.
        let users = sqlx::query_as::<_, entity::user::User>(
            r#"SELECT * FROM "user" WHERE "id" = ANY($1)"#,
        )
        .bind(user_ids)
        .fetch_all(&self.pool)
        .await?;
        Ok(users)
    }

    /// Fetches one user by exact username.
    ///
    /// # Errors
    /// [`Error::UserDoesNotExists`] if no row matches.
    pub async fn select_user_by_username(&self, username: &str) -> Result<entity::user::User> {
        let user = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "username" = $1"#,
            username
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::UserDoesNotExists)?;
        Ok(user)
    }

    /// Fetches one server by id.
    ///
    /// # Errors
    /// [`Error::ServerDoesNotExists`] if no row matches.
    pub async fn select_server_by_id(
        &self,
        server_id: entity::server::Id,
    ) -> Result<entity::server::Server> {
        let server = sqlx::query_as!(
            entity::server::Server,
            r#"SELECT * FROM "server" WHERE "id" = $1"#,
            server_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::ServerDoesNotExists)?;
        Ok(server)
    }

    /// Lists the servers the user is a member of.
    pub async fn select_user_servers(
        &self,
        user_id: entity::user::Id,
    ) -> Result<Vec<entity::server::Server>> {
        let servers = sqlx::query_as!(
            entity::server::Server,
            r#"SELECT * FROM "server" WHERE "id" IN (
SELECT "server_id" FROM "server_member"
WHERE "user_id" = $1
)"#,
            user_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(servers)
    }

    /// Lists the users that are members of the server.
    pub async fn select_server_members(
        &self,
        server_id: entity::server::Id,
    ) -> Result<Vec<entity::user::User>> {
        let users = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "id" IN (
SELECT "user_id" FROM "server_member" WHERE "server_id" = $1
)"#,
            server_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(users)
    }

    /// Lists everyone with access to a channel: its direct recipients (DM/group) plus all
    /// members of the channel's server (server channels).
    pub async fn select_channel_members(
        &self,
        channel_id: entity::channel::Id,
    ) -> Result<Vec<entity::user::User>> {
        let users = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "id" IN (
SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1
UNION SELECT "user_id" FROM "server_member" WHERE "server_id" IN (
SELECT "server_id" FROM "channel" WHERE "id" = $1
)
)"#,
            channel_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(users)
    }

    /// Lists the DM/group channels the user is a recipient of (not server channels).
    pub async fn select_user_channels(
        &self,
        user_id: entity::user::Id,
    ) -> Result<Vec<entity::channel::Channel>> {
        let channels = sqlx::query_as!(
            entity::channel::Channel,
            r#"SELECT * FROM "channel" WHERE "id" IN (
SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1
)"#,
            user_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(channels)
    }

    /// Creates a server owned by `owner_id` and returns it. A non-image `icon_id` is
    /// rejected by a database trigger (surfaced as [`Error::Sqlx`]).
    pub async fn insert_server(
        &self,
        name: &str,
        icon_id: Option<entity::file::Id>,
        owner_id: entity::user::Id,
    ) -> Result<entity::server::Server> {
        let server = sqlx::query_as!(
            entity::server::Server,
            r#"INSERT INTO "server"("name", "icon_id", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#,
            name,
            icon_id,
            owner_id
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(server)
    }

    /// Creates a role in a server. `id` may be supplied explicitly.
    /// NOTE(review): `position` is cast with `as i16`, so values above `i16::MAX` wrap.
    pub async fn insert_server_role(
        &self,
        id: Option<entity::server::role::Id>,
        server_id: entity::server::Id,
        name: &str,
        color: Option<&str>,
        display: bool,
        permissions: serde_json::Value,
        position: u16,
    ) -> Result<entity::server::role::ServerRole> {
        let role = sqlx::query_as!(
            entity::server::role::ServerRole,
            r#"INSERT INTO "server_role"("id", "server_id", "name", "color", "display", "permissions", "position") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING "server_role".*"#,
            id,
            server_id,
            name,
            color,
            display,
            permissions,
            position as i16
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(role)
    }

    /// Adds a user to a server and returns the membership row.
    ///
    /// # Errors
    /// [`Error::MemberAlreadyExists`] on a unique violation (SQLSTATE `23505`).
    pub async fn insert_server_member(
        &self,
        server_id: entity::server::Id,
        user_id: entity::user::Id,
    ) -> Result<entity::server::member::ServerMember> {
        let member = match sqlx::query_as!(
            entity::server::member::ServerMember,
            r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#,
            server_id,
            user_id
        )
        .fetch_one(&self.pool)
        .await {
            Ok(member) => member,
            Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
                return Err(Error::MemberAlreadyExists);
            }
            Err(e) => return Err(e.into()),
        };
        Ok(member)
    }

    /// Assigns a role to a member. A role from another server is rejected by a database
    /// trigger (surfaced as [`Error::Sqlx`]).
    pub async fn insert_server_member_role(
        &self,
        server_member_id: entity::server::member::Id,
        server_role_id: entity::server::role::Id,
    ) -> Result<()> {
        sqlx::query!(
            r#"INSERT INTO "server_member_role"("member_id", "role_id") VALUES ($1, $2)"#,
            server_member_id,
            server_role_id
        )
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Fetches one channel by id.
    ///
    /// # Errors
    /// [`Error::ChannelDoesNotExists`] if no row matches.
    pub async fn select_channel_by_id(
        &self,
        channel_id: entity::channel::Id,
    ) -> Result<entity::channel::Channel> {
        let channel = sqlx::query_as!(
            entity::channel::Channel,
            r#"SELECT * FROM "channel" WHERE "id" = $1"#,
            channel_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::ChannelDoesNotExists)?;
        Ok(channel)
    }

    /// Lists the direct recipients of a channel; `None` when it has none (e.g. a server
    /// channel or an unknown id).
    pub async fn select_channel_recipients(
        &self,
        channel_id: entity::channel::Id,
    ) -> Result<Option<Vec<entity::user::User>>> {
        let recipients = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "id" IN (
SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1
)"#,
            channel_id
        )
        .fetch_all(&self.pool)
        .await?;
        if recipients.is_empty() {
            return Ok(None);
        }
        Ok(Some(recipients))
    }

    /// Lists all channels belonging to a server.
    pub async fn select_server_channels(
        &self,
        server_id: entity::server::Id,
    ) -> Result<Vec<entity::channel::Channel>> {
        let channels = sqlx::query_as!(
            entity::channel::Channel,
            r#"SELECT * FROM "channel" WHERE "server_id" = $1"#,
            server_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(channels)
    }

    /// Deletes a server (dependent rows cascade) and returns the deleted row.
    ///
    /// # Errors
    /// [`Error::ServerDoesNotExists`] if no row matches.
    pub async fn delete_server_by_id(
        &self,
        server_id: entity::server::Id,
    ) -> Result<entity::server::Server> {
        let server = sqlx::query_as!(
            entity::server::Server,
            r#"DELETE FROM "server" WHERE "id" = $1 RETURNING "server".*"#,
            server_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::ServerDoesNotExists)?;
        Ok(server)
    }

    /// Deletes a channel and returns the deleted row.
    ///
    /// # Errors
    /// [`Error::ChannelDoesNotExists`] if no row matches.
    pub async fn delete_channel_by_id(
        &self,
        channel_id: entity::channel::Id,
    ) -> Result<entity::channel::Channel> {
        let channel = sqlx::query_as!(
            entity::channel::Channel,
            r#"DELETE FROM "channel" WHERE "id" = $1 RETURNING "channel".*"#,
            channel_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::ChannelDoesNotExists)?;
        Ok(channel)
    }

    /// Creates a channel inside a server, optionally under a parent channel.
    /// NOTE(review): `position` is cast with `as i16`, so values above `i16::MAX` wrap.
    pub async fn insert_server_channel(
        &self,
        name: &str,
        position: u16,
        r#type: entity::channel::ChannelType,
        server_id: entity::server::Id,
        parent: Option<entity::channel::Id>,
    ) -> Result<entity::channel::Channel> {
        let channel = sqlx::query_as!(
            entity::channel::Channel,
            r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#,
            name,
            r#type as i16,
            position as i16,
            server_id,
            parent
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(channel)
    }

    /// Stores a new invite code for a server.
    pub async fn insert_server_invite(
        &self,
        code: &str,
        server_id: entity::server::Id,
        inviter_id: Option<entity::user::Id>,
        expires_at: Option<chrono::DateTime<chrono::Utc>>,
    ) -> Result<entity::server::invite::ServerInvite> {
        let invite = sqlx::query_as!(
            entity::server::invite::ServerInvite,
            r#"INSERT INTO "server_invite"("code", "server_id", "inviter_id", "expires_at") VALUES ($1, $2, $3, $4) RETURNING "server_invite".*"#,
            code,
            server_id,
            inviter_id,
            expires_at
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(invite)
    }

    /// Looks up an invite by code. Expiry is NOT checked here.
    ///
    /// # Errors
    /// [`Error::InviteDoesNotExists`] if no row matches.
    pub async fn select_server_invite_by_code(
        &self,
        code: &str,
    ) -> Result<entity::server::invite::ServerInvite> {
        let invite = sqlx::query_as!(
            entity::server::invite::ServerInvite,
            r#"SELECT * FROM "server_invite" WHERE "code" = $1"#,
            code
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::InviteDoesNotExists)?;
        Ok(invite)
    }

    /// Deletes an invite; returns the deleted row, or `None` if the code was unknown.
    pub async fn delete_server_invite_by_code(
        &self,
        code: &str,
    ) -> Result<Option<entity::server::invite::ServerInvite>> {
        let invite = sqlx::query_as!(
            entity::server::invite::ServerInvite,
            r#"DELETE FROM "server_invite" WHERE "code" = $1 RETURNING "server_invite".*"#,
            code
        )
        .fetch_optional(&self.pool)
        .await?;
        Ok(invite)
    }

    /// Returns up to `limit` messages of a channel, newest first, optionally only those
    /// with an id lower than `before` (keyset pagination on the id column).
    pub async fn select_channel_messages_paginated(
        &self,
        channel_id: entity::channel::Id,
        before: Option<entity::message::Id>,
        limit: i64,
    ) -> Result<Vec<entity::message::Message>> {
        let messages = sqlx::query_as!(
            entity::message::Message,
            r#"SELECT * FROM "message" WHERE "channel_id" = $1 AND ($2::uuid IS NULL OR "id" < $2::uuid) ORDER BY "id" DESC LIMIT $3"#,
            channel_id,
            before,
            limit
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(messages)
    }

    /// Stores a message authored by `user_id` in `channel_id`.
    pub async fn insert_channel_message(
        &self,
        user_id: entity::user::Id,
        channel_id: entity::channel::Id,
        content: &str,
    ) -> Result<entity::message::Message> {
        let message = sqlx::query_as!(
            entity::message::Message,
            r#"INSERT INTO "message"("channel_id", "author_id", "content") VALUES ($1, $2, $3) RETURNING "message".*"#,
            channel_id,
            user_id,
            content
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(message)
    }

    /// Fetches file metadata by id.
    ///
    /// # Errors
    /// [`Error::FileDoesNotExists`] if no row matches.
    pub async fn select_file_by_id(&self, file_id: entity::file::Id) -> Result<entity::file::File> {
        let file = sqlx::query_as!(
            entity::file::File,
            r#"SELECT * FROM "file" WHERE "id" = $1"#,
            file_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::FileDoesNotExists)?;
        Ok(file)
    }

    /// Deletes file metadata and returns the deleted row.
    ///
    /// # Errors
    /// [`Error::FileDoesNotExists`] if no row matches.
    pub async fn delete_file_by_id(&self, file_id: entity::file::Id) -> Result<entity::file::File> {
        let file = sqlx::query_as!(
            entity::file::File,
            r#"DELETE FROM "file" WHERE "id" = $1 RETURNING "file".*"#,
            file_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::FileDoesNotExists)?;
        Ok(file)
    }

    /// Stores file metadata and returns the new row.
    pub async fn insert_file(
        &self,
        filename: &str,
        content_type: &str,
        size: usize,
    ) -> Result<entity::file::File> {
        let file = sqlx::query_as!(
            entity::file::File,
            r#"INSERT INTO "file"("filename", "content_type", "size") VALUES ($1, $2, $3) RETURNING "file".*"#,
            filename,
            content_type,
            size as i64
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(file)
    }

    /// Links a file to a message as an attachment (always with `"order"` = 0).
    pub async fn insert_message_attachment(
        &self,
        message_id: entity::message::Id,
        file_id: entity::file::Id,
    ) -> Result<()> {
        sqlx::query!(
            r#"INSERT INTO "message_attachment"("message_id", "file_id", "order") VALUES ($1, $2, 0)"#,
            message_id,
            file_id
        )
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Lists the files attached to a message (order of the result unspecified).
    pub async fn select_message_attachments(
        &self,
        message_id: entity::message::Id,
    ) -> Result<Vec<entity::file::File>> {
        let attachments = sqlx::query_as!(
            entity::file::File,
            r#"SELECT * FROM "file" WHERE "id" IN (
SELECT "file_id" FROM "message_attachment" WHERE "message_id" = $1
)"#,
            message_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(attachments)
    }

    /// Returns the ids of users that can see `user_id`, via the SQL function
    /// `get_users_that_can_see_user`.
    pub async fn select_related_user_ids(
        &self,
        user_id: entity::user::Id,
    ) -> Result<Vec<entity::user::Id>> {
        #[derive(sqlx::FromRow)]
        struct UserId {
            user_id: entity::user::Id,
        }
        let user_ids =
            sqlx::query_as::<_, UserId>(r#"SELECT * FROM get_users_that_can_see_user($1)"#)
                .bind(user_id)
                .fetch_all(&self.pool)
                .await?
                .into_iter()
                .map(|row| row.user_id)
                .collect();
        Ok(user_ids)
    }

    /// Returns the DM channel between two users, creating it if needed, via the SQL
    /// function `create_dm_channel`.
    ///
    /// # Errors
    /// Database errors raised by the function (same user twice, unknown user, ...), or
    /// [`Error::ChannelDoesNotExists`] if it unexpectedly yields `NULL`.
    pub async fn procedure_create_dm_channel(
        &self,
        user1_id: entity::user::Id,
        user2_id: entity::user::Id,
    ) -> Result<entity::channel::Id> {
        let channel_id = sqlx::query_scalar!(
            r#"SELECT create_dm_channel($1, $2, $3)"#,
            user1_id,
            user2_id,
            entity::channel::ChannelType::DirectMessage as i16
        )
        .fetch_one(&self.pool)
        .await?
        // The function returns an id or raises; map an unexpected NULL to an error instead
        // of panicking the calling task.
        .ok_or(Error::ChannelDoesNotExists)?;
        Ok(channel_id)
    }

    /// Creates a group channel owned by `creator_id` with `users` as members, via the SQL
    /// function `create_group_channel`.
    ///
    /// # Errors
    /// Database errors raised by the function (too few members, unknown user, ...), or
    /// [`Error::ChannelDoesNotExists`] if it unexpectedly yields `NULL`.
    pub async fn procedure_create_group_channel(
        &self,
        creator_id: entity::user::Id,
        users: &[entity::user::Id],
    ) -> Result<entity::channel::Id> {
        let channel_id = sqlx::query_scalar!(
            r#"SELECT create_group_channel($1, $2, $3)"#,
            creator_id,
            users,
            // Group channels must be stored as `Group`; this previously passed
            // `DirectMessage`, making groups indistinguishable from DMs.
            entity::channel::ChannelType::Group as i16
        )
        .fetch_one(&self.pool)
        .await?
        .ok_or(Error::ChannelDoesNotExists)?;
        Ok(channel_id)
    }
}
// file: src/entity/channel.rs
use serde::{Deserialize, Serialize};
use crate::entity::{message, server, user};
/// Primary key of a `channel` row.
pub type Id = uuid::Uuid;
/// A row of the `channel` table, serialized to clients in camelCase.
/// Optional fields are omitted from JSON when `None`.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Channel {
pub id: Id,
pub name: String,
pub r#type: ChannelType,
pub position: i16,
/// Set for server channels only.
#[serde(skip_serializing_if = "Option::is_none")]
pub server_id: Option<server::Id>,
/// Parent channel (e.g. a category) for nested server channels.
#[serde(skip_serializing_if = "Option::is_none")]
pub parent: Option<Id>,
/// Owner of a group channel.
#[serde(skip_serializing_if = "Option::is_none")]
pub owner_id: Option<user::Id>,
#[serde(skip_serializing_if = "Option::is_none")]
pub last_message_id: Option<message::Id>,
}
/// Kind of channel. The `i16` discriminants are the codes stored in `channel.type`
/// (INT2); they must not be renumbered. Serialized in snake_case.
#[derive(Debug, Clone, sqlx::Type, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
#[repr(i16)]
pub enum ChannelType {
ServerText = 1,
ServerVoice = 2,
ServerCategory = 3,
DirectMessage = 4,
Group = 5,
}
impl From<i16> for ChannelType {
fn from(value: i16) -> Self {
match value {
1 => ChannelType::ServerText,
2 => ChannelType::ServerVoice,
3 => ChannelType::ServerCategory,
4 => ChannelType::DirectMessage,
5 => ChannelType::Group,
_ => ChannelType::ServerText,
}
}
}
// file: src/entity/file.rs
use serde::Serialize;
/// Primary key of a `file` row.
pub type Id = uuid::Uuid;
/// Metadata of a stored file (a `file` row).
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
pub struct File {
pub id: Id,
pub filename: String,
pub content_type: String,
/// Stored as BIGINT (NOTE(review): presumably a byte count).
pub size: i64,
}
// file: src/entity/message.rs
use crate::entity::{channel, user};
/// Primary key of a `message` row (UUIDv7, so ids sort by creation time).
pub type Id = uuid::Uuid;
/// A row of the `message` table.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Message {
pub id: Id,
pub author_id: user::Id,
pub channel_id: channel::Id,
pub content: String,
}
// file: src/entity/mod.rs
pub mod file;
pub mod channel;
pub mod message;
pub mod server;
pub mod user;
// file: src/entity/server/invite.rs
use chrono::{DateTime, Utc};
use serde::Serialize;
use crate::entity::{server, user};
/// A row of the `server_invite` table, serialized in camelCase.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInvite {
/// Primary key: the invite code itself.
pub code: String,
pub server_id: server::Id,
/// `None` once the inviting user has been deleted (ON DELETE SET NULL).
pub inviter_id: Option<user::Id>,
pub expires_at: Option<DateTime<Utc>>,
}
// file: src/entity/server/member.rs
use serde::Serialize;
use crate::entity::{server, user};
/// Primary key of a `server_member` row (distinct from the user id).
pub type Id = uuid::Uuid;
/// Membership of a user in a server (a `server_member` row), serialized in camelCase.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerMember {
pub id: Id,
pub server_id: server::Id,
pub user_id: user::Id,
/// Per-server nickname.
pub nickname: Option<String>,
/// Per-server avatar URL.
pub avatar_url: Option<String>,
}
// file: src/entity/server/role.rs
use serde::Serialize;
use crate::entity::server;
/// Primary key of a `server_role` row.
pub type Id = uuid::Uuid;
/// A role defined in a server (a `server_role` row), serialized in camelCase.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerRole {
pub id: Id,
pub server_id: server::Id,
pub name: String,
pub color: Option<String>,
pub display: bool,
/// Free-form JSONB permission document (defaults to `{}` in the database).
pub permissions: serde_json::Value,
pub position: i16,
}
// file: src/entity/server.rs
pub mod invite;
pub mod member;
pub mod role;
use crate::entity::{file, user};
/// Primary key of a `server` row.
pub type Id = uuid::Uuid;
/// A row of the `server` table.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Server {
pub id: Id,
pub owner_id: user::Id,
pub name: String,
/// Must reference an image file (enforced by a database trigger).
pub icon_id: Option<file::Id>,
}
// file: src/entity/user.rs
use std::sync::LazyLock;
use regex::Regex;
use crate::entity::file;
/// Valid usernames: one or more ASCII letters, digits, `_` or `.`.
pub static USERNAME_REGEX: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap());
/// Primary key of a `"user"` row.
pub type Id = uuid::Uuid;
/// A row of the `"user"` table.
///
/// `Debug` is implemented by hand instead of derived so that `password_hash` is never
/// written to logs when a user is printed with `{:?}`.
#[derive(Clone, sqlx::FromRow)]
pub struct User {
    pub id: Id,
    /// Must reference an image file (enforced by a database trigger).
    pub avatar_id: Option<file::Id>,
    pub username: String,
    pub display_name: Option<String>,
    pub email: String,
    pub password_hash: String,
    pub bot: bool,
    pub system: bool,
    /// Free-form JSONB settings document (defaults to `{}` in the database).
    pub settings: serde_json::Value,
}

impl std::fmt::Debug for User {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("User")
            .field("id", &self.id)
            .field("avatar_id", &self.avatar_id)
            .field("username", &self.username)
            .field("display_name", &self.display_name)
            .field("email", &self.email)
            .field("password_hash", &"<redacted>")
            .field("bot", &self.bot)
            .field("system", &self.system)
            .field("settings", &self.settings)
            .finish()
    }
}
// file: src/jwt.rs
#![allow(unused)]
use std::collections::HashSet;
use chrono::{Duration, Local, Utc};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::config;
/// JWT payload: the caller's `data` fields flattened to the top level, plus `iat`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims<T> {
#[serde(flatten)]
pub data: T,
/// Issued-at, set by `generate_jwt` in Unix MILLIseconds (not the usual JWT seconds).
pub iat: i64,
}
pub fn generate_jwt<T: Serialize>(data: T, secret: &[u8]) -> Result<String> {
let claims = Claims {
data,
iat: Utc::now().timestamp_millis(),
};
let token = jsonwebtoken::encode(
&jsonwebtoken::Header::default(),
&claims,
&jsonwebtoken::EncodingKey::from_secret(secret),
)
.map_err(|_| Error::CouldNotEncodeToken)?;
Ok(token)
}
/// Verifies `token` against `secret` with `jsonwebtoken::Validation::default()` and
/// no required registered claims, then returns the embedded `data`.
///
/// NOTE(review): `exp` is not required and `generate_jwt` never sets it, so tokens appear
/// not to expire unless `T` carries its own `exp` — confirm this is intended.
///
/// # Errors
/// [`Error::CouldNotVerifyToken`] if decoding or verification fails.
pub fn verify_jwt<T: DeserializeOwned>(token: &str, secret: &[u8]) -> Result<T> {
    // Never log the token itself: it is a bearer credential, and log output may be
    // persisted to files. The length is enough for debugging.
    tracing::debug!("verifying token ({} bytes)", token.len());
    let mut validation = jsonwebtoken::Validation::default();
    // Tokens only carry `iat` plus custom data, so no registered claim is mandatory.
    validation.set_required_spec_claims::<String>(&[]);
    let token_data = jsonwebtoken::decode::<Claims<T>>(
        token,
        &jsonwebtoken::DecodingKey::from_secret(secret),
        &validation,
    )
    .inspect_err(|err| {
        tracing::error!("Failed to decode JWT: {:?}", err);
    })
    .map_err(|_| Error::CouldNotVerifyToken)?;
    Ok(token_data.claims.data)
}
/// Result alias for JWT helpers.
pub type Result<T> = std::result::Result<T, Error>;
/// JWT failures; the underlying `jsonwebtoken` error is intentionally not exposed.
#[derive(Debug, derive_more::Error, derive_more::Display)]
pub enum Error {
/// Returned by `generate_jwt` when encoding/signing fails.
CouldNotEncodeToken,
/// Returned by `verify_jwt` when decoding/verification fails.
CouldNotVerifyToken,
}
// file: src/log.rs
use std::io;
use tracing::level_filters::LevelFilter;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::time::ChronoLocal;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
/// Installs the global tracing subscriber with two sinks:
/// a daily-rolling, ANSI-free file under `./logs` and stdout (pretty-printed in debug builds).
///
/// The level filter comes from the environment (`RUST_LOG`), defaulting to `INFO`.
/// Both returned guards must be kept alive for as long as logging is needed.
pub fn init_logging() -> io::Result<(WorkerGuard, WorkerGuard)> {
    let timer = ChronoLocal::new("%FT%H:%M:%S%.6f%:::z".to_string());

    let (file_writer, file_guard) =
        tracing_appender::non_blocking(tracing_appender::rolling::daily("./logs", "log"));
    let file_layer = tracing_subscriber::fmt::layer()
        .with_ansi(false)
        .with_timer(timer.clone())
        .with_writer(file_writer);

    let (stdout_writer, stdout_guard) = tracing_appender::non_blocking(io::stdout());
    let stdout_layer = tracing_subscriber::fmt::layer()
        .with_timer(timer)
        .with_writer(stdout_writer);
    #[cfg(debug_assertions)]
    let stdout_layer = stdout_layer.pretty();

    let filter = EnvFilter::builder()
        .with_default_directive(LevelFilter::INFO.into())
        .from_env_lossy();

    tracing_subscriber::registry()
        .with(filter)
        .with(file_layer)
        .with(stdout_layer)
        .init();

    tracing::info!("initialized logging");
    Ok((file_guard, stdout_guard))
}
// file: src/main.rs
use std::sync::Arc;
use argon2::Argon2;
use crate::database::Database;
use crate::state::AppState;
mod config;
mod database;
mod entity;
mod jwt;
mod log;
mod object_store;
mod state;
mod util;
mod web;
mod webrtc;
/// Entry point: sets up logging, connects the database and object store, then serves HTTP.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // Bound to a named variable (not `_`) so the appender guards live until `main` returns.
    let _log_guards = log::init_logging()?;

    let cfg = config::config();
    let database = Database::connect(&cfg.database).await?;
    let object_store = object_store::ObjectStore::connect(&cfg.object_store).await?;

    let state = AppState {
        database,
        object_store,
        hasher: Arc::new(Argon2::default()),
        gateway_state: Default::default(),
        voice_rooms: Default::default(),
    };
    web::run(state).await
}
// file: src/object_store.rs
use crate::config::ObjectStoreConfig;
/// Handle to the S3-compatible bucket used for uploaded files.
/// Derefs to the underlying `s3::Bucket`, so bucket methods can be called directly.
#[derive(Clone, derive_more::AsRef, derive_more::Deref)]
pub struct ObjectStore {
    inner: Box<s3::Bucket>,
}
/// Result alias for object-store setup.
pub type Result<T> = std::result::Result<T, Error>;
/// Errors raised while building credentials or talking to S3.
#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
pub enum Error {
    /// Invalid access/secret key configuration.
    #[from]
    Credentials(s3::creds::error::CredentialsError),
    /// Any error returned by the S3 client.
    #[from]
    S3(s3::error::S3Error),
}
impl ObjectStore {
    /// Connects to the configured bucket using path-style addressing, creating the bucket
    /// when it does not exist yet.
    ///
    /// # Errors
    /// Propagates credential construction failures and any S3 error from the existence
    /// check or bucket creation.
    pub async fn connect(config: &ObjectStoreConfig) -> Result<Self> {
        let region = s3::region::Region::Custom {
            region: config.region.clone(),
            endpoint: config.endpoint.origin().ascii_serialization(),
        };
        let credentials = s3::creds::Credentials::new(
            Some(&config.access_key),
            Some(&config.secret_key),
            None,
            None,
            None,
        )?;

        let existing =
            s3::bucket::Bucket::new(&config.bucket, region.clone(), credentials.clone())?
                .with_path_style();

        let inner = if existing.exists().await? {
            existing
        } else {
            s3::bucket::Bucket::create_with_path_style(
                &config.bucket,
                region,
                credentials,
                s3::BucketConfiguration::default(),
            )
            .await?
            .bucket
        };

        Ok(Self { inner })
    }
}
// file: src/state.rs
use std::collections::HashMap;
use std::sync::Arc;
use argon2::Argon2;
use tokio::sync::{RwLock, mpsc};
use uuid::Uuid;
use crate::database::Database;
use crate::object_store::ObjectStore;
use crate::web::ws::gateway::{GatewayWsState, SessionKey, event};
use crate::webrtc::WebRtcSignal;
/// Shared application state handed to every handler (cheap to clone: all heavy parts are shared).
#[derive(Clone)]
pub struct AppState {
    pub database: Database,
    pub object_store: ObjectStore,
    /// Password hasher used by registration and login.
    pub hasher: Arc<Argon2<'static>>,
    /// Currently connected gateway sessions, keyed by user id.
    pub gateway_state: Arc<GatewayState>,
    /// Signal senders of active voice rooms, keyed by room (channel) id.
    pub voice_rooms: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<WebRtcSignal>>>>,
}
/// Registry of gateway websocket connections.
#[derive(Debug, Default)]
pub struct GatewayState {
    /// Per-user gateway state; an entry exists while the user has at least one session.
    pub connected: scc::HashMap<Uuid, GatewayWsState>,
}
impl AppState {
pub async fn register_gateway_connected_user(
&self,
user_id: Uuid,
session_key: SessionKey,
event_sender: mpsc::UnboundedSender<event::Event>,
) {
self.gateway_state
.connected
.entry_async(user_id)
.await
.or_default()
.get_mut()
.instances
.insert(session_key, event_sender);
}
pub async fn unregister_gateway_connected_user(&self, user_id: Uuid, session_id: &SessionKey) {
let is_empty = {
let mut entry = self
.gateway_state
.connected
.entry_async(user_id)
.await
.or_default();
let entry = entry.get_mut();
entry.instances.remove(session_id);
entry.instances.is_empty()
};
if is_empty {
self.gateway_state.connected.remove_async(&user_id).await;
}
}
pub async fn register_voice_room(
&self,
room_id: Uuid,
sender: mpsc::UnboundedSender<WebRtcSignal>,
) {
self.voice_rooms.write().await.insert(room_id, sender);
}
pub async fn unregister_voice_room(&self, room_id: Uuid) {
self.voice_rooms.write().await.remove(&room_id);
}
}
// file: src/util/mod.rs
pub mod tristate;
use axum::extract::multipart::Field;
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
use serde::{Deserialize, Serialize};
use crate::entity;
/// Builds the public download URL `<server.hostname>/files/<file_id>`.
///
/// Returns `None` if either URL join fails.
pub fn file_id_to_url(file_id: &entity::file::Id) -> Option<String> {
    let files_base = crate::config::config()
        .server
        .hostname
        .join("files/")
        .ok()?;
    let file_url = files_base.join(&file_id.to_string()).ok()?;
    Some(file_url.to_string())
}
/// Multipart field wrapper that can be serialized (metadata only, never the contents),
/// e.g. for logging or debugging payloads. Derefs to the wrapped `FieldData`.
#[derive(Debug, derive_more::Deref)]
pub struct SerdeFieldData<T>(pub FieldData<T>);
/// Delegates parsing (including the size limit) to `FieldData` and wraps the result.
#[async_trait::async_trait]
impl<T: TryFromField> TryFromField for SerdeFieldData<T> {
    async fn try_from_field(
        field: Field<'_>,
        limit_bytes: Option<usize>,
    ) -> Result<Self, TypedMultipartError> {
        FieldData::try_from_field(field, limit_bytes).await.map(Self)
    }
}
impl<T> Serialize for SerdeFieldData<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct Metadata {
name: Option<String>,
file_name: Option<String>,
content_type: Option<String>,
}
let metadata = Metadata {
name: self.0.metadata.name.clone(),
file_name: self.0.metadata.file_name.clone(),
content_type: self.0.metadata.content_type.clone(),
};
metadata.serialize(serializer)
}
}
/// Serde `serialize_with` helper: writes a `Duration` as whole seconds (`u64`, truncated).
pub fn serialize_duration_seconds<S>(
    duration: &std::time::Duration,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    serializer.serialize_u64(duration.as_secs())
}
/// Serde `deserialize_with` helper: reads a non-negative integer number of seconds.
pub fn deserialize_duration_seconds<'de, D>(
    deserializer: D,
) -> Result<std::time::Duration, D::Error>
where
    D: serde::Deserializer<'de>,
{
    u64::deserialize(deserializer).map(std::time::Duration::from_secs)
}
/// Optional variant of `serialize_duration_seconds`: `None` is written as a serde "none".
pub fn serialize_duration_seconds_option<S>(
    duration: &Option<std::time::Duration>,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    if let Some(inner) = duration {
        serialize_duration_seconds(inner, serializer)
    } else {
        serializer.serialize_none()
    }
}
pub fn deserialize_duration_seconds_option<'de, D>(
deserializer: D,
) -> Result<Option<std::time::Duration>, D::Error>
where
D: serde::Deserializer<'de>,
{
Ok(deserialize_duration_seconds(deserializer).ok())
}
// file: src/util/tristate.rs
use std::ops::Deref;
use serde::{Deserialize, Deserializer, Serialize};
use validator::ValidateLength;
/// A field that distinguishes "set to a value", "explicitly null" and "not provided".
///
/// Intended for PATCH-style payloads: `Absent` is the `Default`, so combined with
/// `#[serde(default)]` a missing field stays `Absent`, while `null` becomes `None`.
#[derive(Debug, Clone)]
pub enum TriState<T> {
    Value(T),
    None,
    Absent,
}

impl<T> TriState<T> {
    /// Borrows the inner value through `Deref` (e.g. `TriState<String>` -> `TriState<&str>`).
    pub fn as_deref<D: ?Sized>(&self) -> TriState<&D>
    where
        T: Deref<Target = D>,
    {
        match self {
            Self::Absent => TriState::Absent,
            Self::None => TriState::None,
            Self::Value(inner) => TriState::Value(Deref::deref(inner)),
        }
    }

    /// `true` for `Value(_)`.
    pub fn is_value(&self) -> bool {
        match self {
            Self::Value(_) => true,
            _ => false,
        }
    }

    /// `true` for an explicit `None` (not for `Absent`).
    pub fn is_none(&self) -> bool {
        match self {
            Self::None => true,
            _ => false,
        }
    }

    /// `true` when the field was not provided at all.
    pub fn is_absent(&self) -> bool {
        match self {
            Self::Absent => true,
            _ => false,
        }
    }

    /// Borrowing view as an `Option`; both `None` and `Absent` map to `None`.
    pub fn as_option(&self) -> Option<&T> {
        if let Self::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Consuming view as an `Option`; both `None` and `Absent` map to `None`.
    pub fn into_option(self) -> Option<T> {
        if let Self::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

// Hand-written (not derived) so that `T` is not required to implement `Default`.
impl<T> Default for TriState<T> {
    fn default() -> Self {
        Self::Absent
    }
}
/// Serde visitor used by `TriState`'s `Deserialize` impl; carries no data, only the target type.
pub(crate) struct TriStateFieldVisitor<T> {
    marker: std::marker::PhantomData<T>,
}
/// `Value(v)` serializes as `v`, `None` as a serde "none", and `Absent` as unit.
/// In JSON, both `None` and `Absent` are rendered as `null`.
impl<T> Serialize for TriState<T>
where
    T: Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            Self::Absent => serializer.serialize_unit(),
            Self::None => serializer.serialize_none(),
            Self::Value(inner) => inner.serialize(serializer),
        }
    }
}
/// Deserializes via `deserialize_option`: null becomes `None` and a present value becomes
/// `Value`. `Absent` is only produced by `Default` (i.e. a missing field under `#[serde(default)]`).
impl<'de, T> Deserialize<'de> for TriState<T>
where
    T: Deserialize<'de>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        let visitor = TriStateFieldVisitor::<T> {
            marker: std::marker::PhantomData,
        };
        deserializer.deserialize_option(visitor)
    }
}
/// Maps serde's option/unit events onto `TriState`: none and unit both yield `TriState::None`,
/// and "some" deserializes the inner `T`.
impl<'de, T> serde::de::Visitor<'de> for TriStateFieldVisitor<T>
where
    T: Deserialize<'de>,
{
    type Value = TriState<T>;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("TriStateField<T>")
    }

    #[inline]
    fn visit_none<E>(self) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Ok(TriState::None)
    }

    #[inline]
    fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: Deserializer<'de>,
    {
        let inner = T::deserialize(deserializer)?;
        Ok(TriState::Value(inner))
    }

    #[inline]
    fn visit_unit<E>(self) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        Ok(TriState::None)
    }
}
/// Length validation only applies to a present value; `None`/`Absent` report no length,
/// so `validator` skips the length check for them.
impl<T> ValidateLength<u64> for TriState<T>
where
    T: ValidateLength<u64>,
{
    fn length(&self) -> Option<u64> {
        if let TriState::Value(inner) = self {
            inner.length()
        } else {
            None
        }
    }
}
// file: src/web/context.rs
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use crate::entity;
/// Identity of the authenticated caller; also the payload embedded in auth JWTs.
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
pub struct UserContext {
    pub user_id: entity::user::Id,
}
/// Outcome of resolving the caller's identity, stored in request extensions by the auth middleware.
pub type UserContextResult = Result<UserContext, Error>;
/// Reasons why a request has no usable `UserContext`.
#[derive(Debug, Clone, Copy, derive_more::Display)]
pub enum Error {
    /// The auth middleware did not run for this request.
    NotInRequest,
    /// No (usable) `Authorization` header was sent.
    NotInHeader,
    BadCharacters,
    /// The `Authorization` header is present but is not a valid `Bearer` credential.
    WrongTokenType,
    /// The token failed verification or its user no longer exists.
    BadToken,
    Model,
}
impl<S: Send + Sync> FromRequestParts<S> for UserContext {
type Rejection = super::error::Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
let context = parts
.extensions
.get::<UserContextResult>()
.cloned()
.ok_or(Error::NotInRequest)??;
Ok(context)
}
}
// file: src/web/entity/file.rs
use serde::Serialize;
use crate::entity::file::Id;
use crate::util;
/// API representation of an uploaded file, including its public download URL.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct File {
    pub id: Id,
    pub filename: String,
    pub content_type: String,
    /// Size as stored in the database (the upload handler records the content length in bytes).
    pub size: i64,
    /// Download URL; empty string when it could not be built.
    pub url: String,
}
/// Copies the stored file fields and derives the download URL (empty if it cannot be built).
impl From<crate::entity::file::File> for File {
    fn from(file: crate::entity::file::File) -> Self {
        let url = util::file_id_to_url(&file.id).unwrap_or_default();
        let crate::entity::file::File {
            id,
            filename,
            content_type,
            size,
            ..
        } = file;
        Self {
            id,
            filename,
            content_type,
            size,
            url,
        }
    }
}
// file: src/web/entity/message.rs
use serde::Serialize;
use crate::entity::message::Id;
use crate::entity::{channel, user};
/// API representation of a channel message with its resolved attachments.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Message {
    pub id: Id,
    pub author_id: user::Id,
    pub channel_id: channel::Id,
    pub content: String,
    /// Creation time, derived from the timestamp embedded in `id` (see `from_message_with_attachments`).
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub attachments: Vec<super::file::File>,
}
impl Message {
    /// Builds the API message from a stored message and its already-loaded attachments.
    ///
    /// `created_at` is read from the timestamp embedded in the message's UUID (presumably a
    /// v7 id — confirm against the message table default). If the id carries no timestamp,
    /// or the value is out of range, it falls back to `DateTime::default()`.
    pub fn from_message_with_attachments(
        message: crate::entity::message::Message,
        attachments: Vec<super::file::File>,
    ) -> Self {
        let created_at = message
            .id
            .get_timestamp()
            .and_then(|timestamp| {
                let (secs, nanos) = timestamp.to_unix();
                chrono::DateTime::<chrono::Utc>::from_timestamp(secs as i64, nanos)
            })
            .unwrap_or_default();

        Self {
            id: message.id,
            author_id: message.author_id,
            channel_id: message.channel_id,
            content: message.content,
            created_at,
            attachments,
        }
    }
}
// file: src/web/entity/mod.rs
pub mod file;
pub mod message;
pub mod server;
pub mod user;
// file: src/web/entity/server.rs
use serde::Serialize;
use crate::entity::server::Id;
use crate::entity::user;
use crate::util;
/// API representation of a server; the icon is exposed as a URL instead of a file id.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Server {
    pub id: Id,
    pub owner_id: user::Id,
    pub name: String,
    /// `None` when the server has no icon or the URL could not be built.
    pub icon_url: Option<String>,
}
/// Converts the stored server, turning `icon_id` into a download URL when present.
impl From<crate::entity::server::Server> for Server {
    fn from(server: crate::entity::server::Server) -> Self {
        let icon_url = server.icon_id.as_ref().and_then(util::file_id_to_url);
        Self {
            id: server.id,
            owner_id: server.owner_id,
            name: server.name,
            icon_url,
        }
    }
}
// file: src/web/entity/user.rs
use crate::entity::user;
use crate::util;
/// User view for the account owner: includes private fields (email, settings).
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct FullUser {
    pub id: user::Id,
    pub avatar_url: Option<String>,
    pub username: String,
    pub display_name: Option<String>,
    pub email: String,
    pub bot: bool,
    pub system: bool,
    pub settings: serde_json::Value,
}
/// Public user view shown to other users: no email or settings.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PartialUser {
    pub id: user::Id,
    pub avatar_url: Option<String>,
    pub username: String,
    pub display_name: Option<String>,
    pub bot: bool,
    pub system: bool,
}
/// Full (owner) view: everything except the password hash, with the avatar as a URL.
impl From<user::User> for FullUser {
    fn from(user: user::User) -> Self {
        let avatar_url = user.avatar_id.as_ref().and_then(util::file_id_to_url);
        Self {
            id: user.id,
            avatar_url,
            username: user.username,
            display_name: user.display_name,
            email: user.email,
            bot: user.bot,
            system: user.system,
            settings: user.settings,
        }
    }
}
/// Public view: drops email, settings and password hash; the avatar becomes a URL.
impl From<user::User> for PartialUser {
    fn from(user: user::User) -> Self {
        let avatar_url = user.avatar_id.as_ref().and_then(util::file_id_to_url);
        Self {
            id: user.id,
            avatar_url,
            username: user.username,
            display_name: user.display_name,
            bot: user.bot,
            system: user.system,
        }
    }
}
// file: src/web/error.rs
use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use crate::web::context;
use crate::{database, jwt, object_store};
pub type Result<T> = std::result::Result<T, Error>;
/// Every error a handler can return. It is turned into a client-facing `ClientError`
/// by `as_client_error` (invoked from the `response_map` middleware).
#[derive(Debug, derive_more::From)]
pub enum Error {
    /// Missing or invalid authentication context.
    #[from]
    Context(context::Error),
    #[from]
    Jwt(jwt::Error),
    /// Password hashing or hash parsing failure.
    #[from]
    Hash(argon2::password_hash::Error),
    #[from]
    Database(database::Error),
    #[from]
    ObjectStore(object_store::Error),
    #[from]
    Json(serde_json::error::Error),
    /// Request body rejected by axum's `Json` extractor (used via `WithRejection`).
    #[from]
    JsonRejection(axum::extract::rejection::JsonRejection),
    /// An error that is already in client-facing form.
    #[from]
    Client(ClientError),
}
/// Error shape sent to clients, serialized as
/// `{"message": "SCREAMING_SNAKE_CASE_NAME", "details": ...}` (details only for data-carrying variants).
#[derive(Debug, Clone, derive_more::From, serde::Serialize)]
#[serde(tag = "message", content = "details")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientError {
    UserAlreadyExists,
    UserDoesNotExist,
    ServerDoesNotExist,
    ChannelDoesNotExist,
    MessageDoesNotExist,
    NotAuthorized,
    WrongPassword,
    NotAllowed,
    /// Malformed JSON body; details carry the rejection reason.
    InvalidJson(JsonRejection),
    /// Request payload failed `validator` checks; details carry the field errors.
    #[from]
    ValidationFailed(validator::ValidationErrors),
    InternalServerError,
    Unknown,
}
/// Serializable summary of an axum JSON rejection.
#[derive(derive_more::Debug, Clone, serde::Serialize)]
pub struct JsonRejection {
    /// HTTP status chosen by axum for the rejection; used as the response status, not serialized.
    #[serde(skip)]
    status: StatusCode,
    /// Human-readable reason sent to the client.
    reason: String,
}
impl From<&axum::extract::rejection::JsonRejection> for JsonRejection {
    /// Uses the underlying cause of the rejection as the reason when one is available (it is
    /// usually more specific), falling back to axum's own body text.
    fn from(value: &axum::extract::rejection::JsonRejection) -> Self {
        use axum::extract::rejection::JsonRejection as AxumRejection;
        use std::error::Error as _;

        let cause = match value {
            AxumRejection::JsonDataError(inner) => inner.source(),
            AxumRejection::JsonSyntaxError(inner) => inner.source(),
            AxumRejection::MissingJsonContentType(inner) => inner.source(),
            AxumRejection::BytesRejection(inner) => inner.source(),
            _ => None,
        };
        let reason = cause
            .map(|source| source.to_string())
            .unwrap_or_else(|| value.body_text());

        Self {
            status: value.status(),
            reason,
        }
    }
}
impl Error {
pub fn as_client_error(&self) -> ClientError {
match self {
Error::Context(_) | Error::Jwt(_) => ClientError::NotAuthorized,
Error::Database(database::Error::UserAlreadyExists) => ClientError::UserAlreadyExists,
Error::Database(database::Error::UserDoesNotExists) => ClientError::UserDoesNotExist,
Error::Database(database::Error::ServerDoesNotExists) => {
ClientError::ServerDoesNotExist
},
Error::Database(database::Error::ChannelDoesNotExists) => {
ClientError::ChannelDoesNotExist
},
Error::Database(database::Error::MessageDoesNotExists) => {
ClientError::MessageDoesNotExist
},
// Error::WrongPassword => ClientError::WrongPassword,
// Error::NotAllowed => ClientError::NotAllowed,
Error::JsonRejection(e) => ClientError::InvalidJson(e.into()),
Error::Client(client_error) => client_error.clone(),
_ => ClientError::InternalServerError,
}
}
}
/// Produces a bare 500 response carrying the error in its extensions; the `response_map`
/// middleware later replaces it with the proper status and JSON body.
impl IntoResponse for Error {
    fn into_response(self) -> axum::response::Response {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            axum::Extension(Arc::new(self)),
        )
            .into_response()
    }
}
impl ClientError {
    /// HTTP status to send alongside this error; unlisted variants map to 500.
    pub fn status_code(&self) -> StatusCode {
        match self {
            Self::InvalidJson(rejection) => rejection.status,
            Self::ValidationFailed(_) => StatusCode::UNPROCESSABLE_ENTITY,
            Self::NotAllowed => StatusCode::FORBIDDEN,
            Self::NotAuthorized | Self::WrongPassword => StatusCode::UNAUTHORIZED,
            Self::UserDoesNotExist
            | Self::ServerDoesNotExist
            | Self::ChannelDoesNotExist
            | Self::MessageDoesNotExist => StatusCode::NOT_FOUND,
            Self::UserAlreadyExists => StatusCode::CONFLICT,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
// file: src/web/middleware/auth.rs
use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::{Extension, RequestExt};
use axum_extra::TypedHeader;
use axum_extra::headers::Authorization;
use axum_extra::headers::authorization::Bearer;
use axum_extra::typed_header::TypedHeaderRejectionReason;
use crate::jwt;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::{self, context};
/// Middleware that only lets a request through when `resolve_context` produced a valid
/// `UserContext`; otherwise it responds with the stored context error.
pub async fn require_context(
    Extension(context): Extension<context::UserContextResult>,
    request: Request,
    next: Next,
) -> web::error::Result<Response> {
    if let Err(reason) = context {
        return Err(reason.into());
    }
    Ok(next.run(request).await)
}
/// Middleware that resolves the caller from the bearer token and stores the
/// `UserContextResult` (success or failure) in the request extensions. It never rejects.
pub async fn resolve_context(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> impl IntoResponse {
    let resolved = get_context(&state, &mut request).await;
    request.extensions_mut().insert(resolved);
    next.run(request).await
}
/// Reads `Authorization: Bearer <token>` and resolves it into a `UserContext`.
///
/// An unparseable header yields `WrongTokenType`; a missing header (or any other rejection)
/// yields `NotInHeader`.
async fn get_context(state: &AppState, request: &mut Request) -> context::UserContextResult {
    let header = request
        .extract_parts::<TypedHeader<Authorization<Bearer>>>()
        .await;
    let bearer = match header {
        Ok(bearer) => bearer,
        Err(rejection) => {
            return Err(match rejection.reason() {
                TypedHeaderRejectionReason::Error(_) => context::Error::WrongTokenType,
                _ => context::Error::NotInHeader,
            });
        },
    };
    get_context_from_token(state, bearer.token()).await
}
pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult {
let context = jwt::verify_jwt::<UserContext>(
token,
crate::config::config().security.auth_secret.as_ref(),
)
.map_err(|_| context::Error::BadToken)?;
let _ = state
.database
.select_user_by_id(context.user_id)
.await
.map_err(|_| context::Error::BadToken)?;
Ok(context)
}
// file: src/web/middleware/mod.rs
mod auth;
mod response_map;
pub use auth::*;
pub use response_map::*;
// file: src/web/middleware/response_map.rs
use std::sync::Arc;
use axum::Json;
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::IntoResponse;
use serde_json::json;
use crate::web;
pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
let response = next.run(request).await;
let error = response.extensions().get::<Arc<web::Error>>();
if error.is_some() {
tracing::error!("{:?}", error);
}
let error_response = error.map(|error| {
let client_error = error.as_client_error();
(
client_error.status_code(),
Json(json!({"error": client_error})),
)
.into_response()
});
error_response.unwrap_or(response)
}
// file: src/web/mod.rs
mod context;
mod entity;
mod error;
mod middleware;
mod route;
pub mod ws;
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin};
pub use self::error::{Error, Result};
use crate::{config, state};
/// Binds the configured `host:port` and serves the application until Ctrl-C is received.
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
    let cfg = config::config();
    let addr = std::net::SocketAddr::from((cfg.server.host, cfg.server.port));
    let listener = tokio::net::TcpListener::bind(addr).await?;
    tracing::info!("listening on {}", addr);

    let app = router(state);
    axum::serve(listener, app)
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}
/// Builds the complete application router.
///
/// - `/gateway/ws`, `/voice/ws` and `/files/{file_id}` sit outside `/api/v1` and are not
///   wrapped by `resolve_context`.
/// - Every `/api/v1` route passes through `resolve_context`, which attaches a
///   `UserContextResult`. Routes from `protected_router` additionally require it to be `Ok`.
/// - Layers added later wrap the earlier ones, so `TraceLayer` is outermost and
///   `response_map` runs closest to the routes.
fn router(state: state::AppState) -> axum::Router {
    use axum::Router;
    use axum::routing::*;
    use self::route::*;
    // Fully permissive CORS: any origin, method and header.
    let cors = tower_http::cors::CorsLayer::new()
        .allow_origin(AllowOrigin::any())
        .allow_methods(AllowMethods::any())
        .allow_headers(AllowHeaders::any());
    Router::new()
        // websocket
        .route("/gateway/ws", get(ws::gateway::ws_handler))
        .route("/voice/ws", get(ws::voice::ws_handler))
        // file
        .route("/files/{file_id}", get(file::get))
        // api
        .nest(
            "/api/v1",
            Router::new()
                .route("/auth/login", post(auth::login))
                .route("/auth/register", post(auth::register))
                .merge(protected_router())
                // `route_layer` applies only to matched routes, after the merge above,
                // so it also covers the protected routes.
                .route_layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    middleware::resolve_context,
                )),
        )
        // middleware
        .layer(axum::middleware::from_fn(middleware::response_map))
        .layer(cors)
        .layer(tower_http::trace::TraceLayer::new_for_http())
        .with_state(state.clone())
}
/// Routes that require an authenticated caller (enforced by `require_context`).
///
/// Relies on `resolve_context` having run first. `router` merges this into `/api/v1`,
/// underneath that middleware.
fn protected_router() -> axum::Router<state::AppState> {
    use axum::Router;
    use axum::routing::*;
    use self::route::*;
    Router::new()
        // user
        .route("/users/@me", get(user::me))
        .route("/users/@me", patch(user::patch))
        .route("/users/@me/channels", get(user::channel::list))
        .route("/users/@me/channels", post(user::channel::create))
        .route("/users/{id}", get(user::get_by_id))
        // channel
        .route(
            "/channels/{channel_id}/messages",
            get(channel::message::page),
        )
        .route(
            "/channels/{channel_id}/messages",
            post(channel::message::create),
        )
        // server
        .route("/servers", get(server::list))
        .route("/servers", post(server::create))
        .route("/servers/{server_id}", get(server::get))
        .route("/servers/{server_id}", delete(server::delete))
        .route("/servers/{server_id}/channels", get(server::channel::list))
        .route(
            "/servers/{server_id}/channels",
            post(server::channel::create),
        )
        .route(
            "/servers/{server_id}/channels/{channel_id}",
            get(server::channel::get),
        )
        .route(
            "/servers/{server_id}/channels/{channel_id}",
            delete(server::channel::delete),
        )
        // invite
        .route("/servers/{server_id}/invites", post(server::invite::create))
        .route("/invites/{code}", get(server::invite::get))
        // file
        .route("/files", post(file::upload))
        // middleware
        .route_layer(axum::middleware::from_fn(middleware::require_context))
}
/// Completes when Ctrl-C is received. An error from installing the handler is ignored,
/// which also completes the future.
async fn shutdown_signal() {
    tokio::signal::ctrl_c().await.ok();
}
// file: src/web/route/auth/login.rs
use argon2::{PasswordHash, PasswordVerifier};
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use serde::{Deserialize, Serialize};
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{jwt, web};
use crate::web::entity::user::FullUser;
/// Body of `POST /api/v1/auth/login`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LoginPayload {
    username: String,
    password: String,
}
/// Successful login: the caller's full profile plus a freshly issued auth token.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LoginResponse {
    user: FullUser,
    token: String,
}
pub async fn login(
State(state): State<AppState>,
Json(payload): Json<LoginPayload>,
) -> web::Result<impl IntoResponse> {
let user = state
.database
.select_user_by_username(&payload.username)
.await?;
let password_hash = PasswordHash::new(&user.password_hash)?;
state
.hasher
.verify_password(payload.password.as_bytes(), &password_hash)
.map_err(|_| web::error::ClientError::WrongPassword)?;
let token = jwt::generate_jwt(UserContext { user_id: user.id }, crate::config::config().security.auth_secret.as_ref())?;
let response = LoginResponse {
user: user.into(),
token,
};
Ok(Json(response))
}
// file: src/web/route/auth/mod.rs
pub mod login;
pub mod register;
pub use login::*;
pub use register::*;
// file: src/web/route/auth/register.rs
use argon2::password_hash::rand_core::OsRng;
use argon2::password_hash::{PasswordHasher, SaltString};
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web;
use crate::web::entity::user::FullUser;
use crate::web::error::ClientError;
/// Body of `POST /api/v1/auth/register`, checked with `validator` before use.
#[derive(Validate, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RegisterPayload {
    /// 3–64 characters from `USERNAME_REGEX` (ASCII letters, digits, `_`, `.`).
    #[validate(regex(
        path = "crate::entity::user::USERNAME_REGEX",
        code = "invalid_username"
    ))]
    #[validate(length(min = 3, max = 64))]
    username: String,
    #[validate(length(min = 1, max = 64))]
    display_name: Option<String>,
    #[validate(email)]
    #[validate(length(min = 3, max = 128))]
    email: String,
    /// Plain-text password; only its hash is stored.
    #[validate(length(min = 3, max = 128))]
    password: String,
}
pub async fn register(
State(state): State<AppState>,
WithRejection(Json(payload), _): WithRejection<Json<RegisterPayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
payload.validate().map_err(ClientError::ValidationFailed)?;
let salt = SaltString::generate(&mut OsRng);
let password_hash = state
.hasher
.hash_password(payload.password.as_bytes(), &salt)?
.to_string();
let user = state
.database
.insert_user(
&payload.username,
payload.display_name.as_ref().map(String::as_str),
&payload.email,
&password_hash,
)
.await?;
let system_user = state.database.select_user_by_username("system").await?;
if system_user.system {
state
.database
.procedure_create_dm_channel(user.id, system_user.id)
.await?;
}
Ok(Json(FullUser::from(user)))
}
// file: src/web/route/channel/message/create.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::file::File;
use crate::web::entity::message::Message;
use crate::web::ws;
use crate::{entity, web};
/// Body of `POST /channels/{channel_id}/messages`.
/// At least one of `content` / `attachments` must be non-empty (see `validate_create_payload`).
#[derive(Debug, serde::Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
#[validate(schema(function = "validate_create_payload"))]
pub struct CreatePayload {
    /// Message text, up to 2000 characters (may be empty when attachments are given).
    #[validate(length(min = 0, max = 2000))]
    pub content: String,
    /// Ids of previously uploaded files, at most 10; defaults to none.
    #[serde(default)]
    #[validate(length(min = 0, max = 10))]
    pub attachments: Vec<entity::file::Id>,
}
/// Schema-level rule: a message needs text, attachments, or both.
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
    let has_content = !payload.content.is_empty();
    let has_attachments = !payload.attachments.is_empty();
    if has_content || has_attachments {
        Ok(())
    } else {
        Err(validator::ValidationError::new(
            "at_least_one_field_required",
        ))
    }
}
pub async fn create(
State(state): State<AppState>,
context: UserContext,
Path(channel_id): Path<entity::channel::Id>,
Json(payload): Json<CreatePayload>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
match payload.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let message = state
.database
.insert_channel_message(context.user_id, channel_id, &payload.content)
.await?;
for attachment in payload.attachments.iter() {
state
.database
.insert_message_attachment(message.id, *attachment)
.await?;
}
let attachments = state
.database
.select_message_attachments(message.id)
.await?
.into_iter()
.map(File::from)
.collect::<Vec<_>>();
let message = Message::from_message_with_attachments(message, attachments);
// TODO: check permissions
ws::gateway::util::send_message_channel(
state,
message.channel_id,
ws::gateway::event::Event::AddMessage {
channel_id,
message: message.clone(),
},
);
Ok(Json(message))
}
// file: src/web/route/channel/message/mod.rs
mod page;
mod create;
pub use page::page;
pub use create::create;
// file: src/web/route/channel/message/page.rs
use axum::Json;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use serde::Deserialize;
use tracing::error;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::message::Message;
use crate::{entity, web};
/// Query parameters for paging through a channel's messages.
#[derive(Debug, Deserialize, Validate)]
#[serde(rename_all = "camelCase")]
pub struct PageParams {
    /// Page size, 1–100 (defaults to `limit_default()`).
    #[serde(default = "limit_default")]
    #[validate(range(min = 1, max = 100))]
    pub limit: u32,
    /// Optional cursor message id, passed to the database paging query as `before`.
    #[serde(default)]
    pub before: Option<entity::message::Id>,
}
/// Page size used when the `limit` query parameter is omitted.
fn limit_default() -> u32 {
    const DEFAULT_PAGE_SIZE: u32 = 50;
    DEFAULT_PAGE_SIZE
}
pub async fn page(
State(state): State<AppState>,
context: UserContext,
Path(channel_id): Path<entity::channel::Id>,
Query(params): Query<PageParams>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
match params.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let messages_futures = state
.database
.select_channel_messages_paginated(channel_id, params.before, params.limit as i64)
.await?
.into_iter()
.map(|message| async {
let attachments = state
.database
.select_message_attachments(message.id)
.await?
.into_iter()
.map(web::entity::file::File::from)
.collect::<Vec<_>>();
Ok::<_, web::Error>(Message::from_message_with_attachments(message, attachments))
});
let messages = futures::future::try_join_all(messages_futures).await?;
Ok(Json(messages))
}
// file: src/web/route/channel/mod.rs
pub mod message;
// file: src/web/route/file/get.rs
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::{entity, object_store, web};
pub async fn get(
State(state): State<AppState>,
Path(file_id): Path<entity::file::Id>,
) -> web::Result<impl IntoResponse> {
let file = match state.database.select_file_by_id(file_id).await {
Ok(file) => file,
Err(e) => {
return Ok(axum::http::StatusCode::NOT_FOUND.into_response());
},
};
let data = match state
.object_store
.get_object_stream(&file.id.to_string())
.await
{
Ok(data) => data,
Err(s3::error::S3Error::HttpFailWithBody(403 | 404, _)) => {
let _ = state.database.delete_file_by_id(file.id).await?;
return Ok(axum::http::StatusCode::NOT_FOUND.into_response());
},
Err(e) => {
return Err(object_store::Error::from(e).into());
},
};
let headers = axum::response::AppendHeaders([
(axum::http::header::CONTENT_TYPE, file.content_type.clone()),
(
axum::http::header::CONTENT_DISPOSITION,
format!("filename=\"{}\"", file.filename),
),
]);
Ok((headers, axum::body::Body::from_stream(data.bytes)).into_response())
}
// file: src/web/route/file/mod.rs
mod get;
mod upload;
pub use get::get;
pub use upload::upload;
// file: src/web/route/file/upload.rs
use axum::Json;
use axum::body::Bytes;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_typed_multipart::{TryFromMultipart, TypedMultipart};
use validator::Validate;
use crate::state::AppState;
use crate::util::SerdeFieldData;
use crate::{object_store, web};
/// Multipart body of `POST /api/v1/files`: 1–16 parts named `files`, each limited to 50MB.
#[derive(Debug, Validate, TryFromMultipart)]
#[try_from_multipart(rename_all = "camelCase")]
pub struct UploadPayload {
    #[form_data(limit = "50MB")]
    #[validate(length(min = 1, max = 16))]
    files: Vec<SerdeFieldData<Bytes>>,
}
pub async fn upload(
State(state): State<AppState>,
TypedMultipart(payload): TypedMultipart<UploadPayload>,
) -> web::Result<impl IntoResponse> {
match payload.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let mut file_ids = Vec::new();
for file in payload.files {
let db_file = state
.database
.insert_file(
file.metadata
.file_name
.as_deref()
.unwrap_or_else(|| "unknown"),
file.metadata.content_type.as_deref().unwrap_or_default(),
file.contents.len(),
)
.await?;
state
.object_store
.put_object(&db_file.id.to_string(), &file.contents)
.await
.map_err(object_store::Error::from)?;
file_ids.push(db_file.id);
}
Ok(Json(file_ids))
}
// file: src/web/route/mod.rs
pub mod auth;
pub mod channel;
pub mod file;
pub mod server;
pub mod user;
// file: src/web/route/server/channel/create.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::error::ClientError;
use crate::web::ws;
use crate::{entity, web};
/// Body of `POST /servers/{server_id}/channels`.
#[derive(Debug, Validate, Deserialize)]
pub struct CreatePayload {
    /// Channel name, 1–32 characters.
    #[validate(length(min = 1, max = 32))]
    name: String,
    /// Must be a server channel type (text, voice or category).
    #[validate(custom(function = "validate_server_channel_type"))]
    r#type: entity::channel::ChannelType,
}
/// Accepts only channel types that can live inside a server.
fn validate_server_channel_type(
    r#type: &entity::channel::ChannelType,
) -> Result<(), validator::ValidationError> {
    use entity::channel::ChannelType;

    let allowed = matches!(
        r#type,
        ChannelType::ServerText | ChannelType::ServerVoice | ChannelType::ServerCategory
    );
    if allowed {
        Ok(())
    } else {
        Err(validator::ValidationError::new("invalid_channel_type"))
    }
}
/// `POST /servers/{server_id}/channels`: owner-only creation of a server channel.
///
/// The new channel is inserted at position 0 and broadcast as `AddServerChannel` to the server.
///
/// # Errors
/// `ValidationFailed`, `NotAllowed` for non-owners, and database errors such as
/// `ServerDoesNotExist`.
pub async fn create(
    State(state): State<AppState>,
    context: UserContext,
    Path(server_id): Path<entity::server::Id>,
    WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
    payload.validate().map_err(ClientError::ValidationFailed)?;

    // TODO: check permissions
    let server = state.database.select_server_by_id(server_id).await?;
    let is_owner = server.owner_id == context.user_id;
    if !is_owner {
        return Err(ClientError::NotAllowed.into());
    }

    let channel = state
        .database
        .insert_server_channel(&payload.name, 0, payload.r#type, server_id, None)
        .await?;

    let event = ws::gateway::event::Event::AddServerChannel {
        channel: channel.clone(),
    };
    ws::gateway::util::send_message_server(state, server_id, event);

    Ok(Json(channel))
}
// file: src/web/route/server/channel/delete.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::ws;
use crate::webrtc::WebRtcSignal;
use crate::{entity, web};
/// Deletes a channel that belongs to `server_id`.
///
/// Only the server owner may delete, and the channel must belong to that very server. After
/// the row is removed, server members receive `RemoveServerChannel`. Any live voice room for
/// the channel is then sent `WebRtcSignal::Close` on a background task.
pub async fn delete(
    State(state): State<AppState>,
    context: UserContext,
    Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>,
) -> web::Result<impl IntoResponse> {
    // TODO: check permissions
    let target = state.database.select_channel_by_id(channel_id).await?;
    let server = state.database.select_server_by_id(server_id).await?;

    if server.owner_id != context.user_id {
        return Err(web::error::ClientError::NotAllowed.into());
    }
    // Reject channels attached to another server or to no server at all.
    if target.server_id != Some(server_id) {
        return Err(web::error::ClientError::NotAllowed.into());
    }

    let deleted = state.database.delete_channel_by_id(channel_id).await?;

    let event = ws::gateway::event::Event::RemoveServerChannel {
        server_id,
        channel_id: deleted.id,
    };
    ws::gateway::util::send_message_server(state.clone(), server_id, event);

    tokio::spawn(async move {
        let rooms = state.voice_rooms.read().await;
        if let Some(room) = rooms.get(&channel_id) {
            let _ = room.send(WebRtcSignal::Close);
        }
    });

    Ok(Json(deleted))
}
// file: src/web/route/server/channel/get.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
/// Returns a single channel of `server_id`.
///
/// Channels that belong to another server, or to no server, are rejected with `NotAllowed`.
pub async fn get(
    State(state): State<AppState>,
    context: UserContext,
    Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>,
) -> web::Result<impl IntoResponse> {
    // TODO: check permissions
    // NOTE(review): `context` is not consulted yet, so membership in the server is not verified.
    let channel = state.database.select_channel_by_id(channel_id).await?;

    let belongs_to_server = channel.server_id == Some(server_id);
    if !belongs_to_server {
        return Err(web::error::ClientError::NotAllowed.into());
    }

    Ok(Json(channel))
}
// file: src/web/route/server/channel/list.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
/// Lists every channel of `server_id`.
///
/// NOTE(review): `context` is unused, so any caller that passes the `UserContext` extractor
/// can list any server's channels; a membership check is probably wanted here.
pub async fn list(
    State(state): State<AppState>,
    context: UserContext,
    Path(server_id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> {
    let server_channels = state.database.select_server_channels(server_id).await?;

    Ok(Json(server_channels))
}
// file: src/web/route/server/channel/mod.rs
mod create;
mod delete;
mod get;
mod list;
pub use create::create;
pub use delete::delete;
pub use get::get;
pub use list::list;
// file: src/web/route/server/create.rs
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use axum_typed_multipart::TryFromMultipart;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::error::ClientError;
use crate::web::ws;
use crate::{entity, web};
use crate::web::entity::server::Server;
/// Request body for creating a server; JSON keys are camelCase (e.g. `iconId`).
#[derive(Debug, Validate, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreatePayload {
    /// Server name; `validate` requires a length between 1 and 32.
    #[validate(length(min = 1, max = 32))]
    name: String,
    /// Optional id of an uploaded file to use as the server icon.
    icon_id: Option<entity::file::Id>,
}
/// Creates a server owned by the caller and makes the caller its first member.
///
/// Also creates the `@everyone` role, whose id is derived from the new server's id, and
/// assigns it to the owner. The caller's gateway sessions are then notified with `AddServer`.
///
/// NOTE(review): the inserts below are separate calls with no visible transaction. A failure
/// part-way through would leave a server without its role or owner membership.
pub async fn create(
    State(state): State<AppState>,
    context: UserContext,
    WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
    payload.validate().map_err(ClientError::ValidationFailed)?;

    let db = &state.database;

    let created = db
        .insert_server(&payload.name, payload.icon_id, context.user_id)
        .await?;

    // The `@everyone` role reuses the server id as its own id.
    let everyone = db
        .insert_server_role(
            Some(created.id.into()),
            created.id,
            "@everyone",
            None,
            false,
            serde_json::json!({}),
            0,
        )
        .await?;

    let owner_membership = db.insert_server_member(created.id, context.user_id).await?;
    db.insert_server_member_role(owner_membership.id, everyone.id)
        .await?;

    let server = Server::from(created);
    let event = ws::gateway::event::Event::AddServer {
        server: server.clone(),
    };
    ws::gateway::util::send_message(state, context.user_id, event);

    Ok(Json(server))
}
// file: src/web/route/server/delete.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::ws;
use crate::webrtc::WebRtcSignal;
use crate::{entity, web};
use crate::web::entity::server::Server;
/// Deletes a server owned by the caller.
///
/// Only the owner may delete. Member ids and channel ids are captured first, because they
/// cannot be queried once the server row is gone. Live voice rooms are closed and members
/// are notified only after the database deletion has succeeded. If the delete fails, nobody
/// is kicked out of voice and no `RemoveServer` event is sent.
pub async fn delete(
    State(state): State<AppState>,
    context: UserContext,
    Path(server_id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> {
    let server = state.database.select_server_by_id(server_id).await?;
    if server.owner_id != context.user_id {
        return Err(web::error::ClientError::NotAllowed.into());
    }

    // Snapshot membership and channels while they are still queryable.
    let members = state
        .database
        .select_server_members(server_id)
        .await?
        .iter()
        .map(|u| u.id)
        .collect::<Vec<_>>();
    let channels = state
        .database
        .select_server_channels(server_id)
        .await?
        .iter()
        .map(|c| c.id)
        .collect::<Vec<_>>();

    let server = state.database.delete_server_by_id(server_id).await?;

    // Tear down voice rooms only once the server is really gone. Doing this before the
    // delete would disconnect everyone even when the delete itself fails.
    let state_clone = state.clone();
    tokio::spawn(async move {
        let voice_rooms = state_clone.voice_rooms.read().await;
        for channel_id in channels {
            if let Some(voice_room) = voice_rooms.get(&channel_id) {
                let _ = voice_room.send(WebRtcSignal::Close);
            }
        }
    });

    ws::gateway::util::send_message_many(
        state,
        &members,
        ws::gateway::event::Event::RemoveServer {
            server_id: server.id,
        },
    );

    Ok(Json(Server::from(server)))
}
// file: src/web/route/server/get.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::{entity, web};
use crate::web::entity::server::Server;
/// Returns the public representation of the server `server_id`.
///
/// NOTE(review): this handler takes no `UserContext`, so no caller identity or membership
/// is checked.
pub async fn get(
    State(state): State<AppState>,
    Path(server_id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> {
    // TODO: check permissions
    let found = Server::from(state.database.select_server_by_id(server_id).await?);

    Ok(Json(found))
}
// file: src/web/route/server/invite/create.rs
use axum::Json;
use axum::extract::{Path, Query, State};
use axum::response::IntoResponse;
use base64::Engine;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
/// Query parameters accepted when creating a server invite.
#[derive(serde::Deserialize, Debug)]
pub struct CreateParams {
    /// Invite lifetime, given as a number of seconds (parsed by
    /// `deserialize_duration_seconds_option`). When absent, `expires_at` is stored as
    /// `None`, which presumably means the invite never expires.
    #[serde(deserialize_with = "crate::util::deserialize_duration_seconds_option")]
    #[serde(default)]
    pub expires_in: Option<std::time::Duration>,
}
pub async fn create(
State(state): State<AppState>,
context: UserContext,
Path(server_id): Path<entity::server::Id>,
Query(params): Query<CreateParams>,
) -> web::Result<impl IntoResponse> {
// TODO: check permissions
let code = {
use rand::Rng;
let mut rng = rand::rng();
let mut code = [0u8; 32];
rng.fill(&mut code);
base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(&code)
};
let expires_at = params.expires_in.map(|d| {
let now = chrono::Utc::now();
now + chrono::Duration::from_std(d).expect("valid duration")
});
let invite = state
.database
.insert_server_invite(&code, server_id, Some(context.user_id), expires_at)
.await?;
Ok(Json(invite))
}
// file: src/web/route/server/invite/get.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::ws;
use crate::{database, web};
use crate::web::entity::server::Server;
/// Redeems an invite code: adds the caller to the invite's server and returns that server.
///
/// If the caller is already a member, the server is returned and nothing else happens.
/// Otherwise the new member is assigned a role, and all server members receive an
/// `AddServerMember` event carrying the caller's public profile.
pub async fn get(
    State(state): State<AppState>,
    context: UserContext,
    Path(code): Path<String>,
) -> web::Result<impl IntoResponse> {
    // NOTE(review): no expiry check happens here; presumably `select_server_invite_by_code`
    // filters out invites whose `expires_at` has passed. Confirm.
    let invite = state.database.select_server_invite_by_code(&code).await?;
    let server = state.database.select_server_by_id(invite.server_id).await?;

    let joined = state
        .database
        .insert_server_member(invite.server_id, context.user_id)
        .await;
    let member = match joined {
        Ok(member) => member,
        Err(database::Error::MemberAlreadyExists) => return Ok(Json(Server::from(server))),
        Err(other) => return Err(other.into()),
    };

    // The server id is passed as the role id. Server creation inserts the `@everyone` role
    // with the server's id as its id, so this presumably grants `@everyone`.
    state
        .database
        .insert_server_member_role(member.id, invite.server_id)
        .await?;

    let user = state.database.select_user_by_id(context.user_id).await?;
    let event = ws::gateway::event::Event::AddServerMember {
        server_id: server.id,
        member: user.into(),
    };
    ws::gateway::util::send_message_server(state, invite.server_id, event);

    Ok(Json(Server::from(server)))
}
// file: src/web/route/server/invite/mod.rs
mod create;
mod get;
pub use create::create;
pub use get::get;
// file: src/web/route/server/list.rs
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web;
use crate::web::context::UserContext;
use crate::web::entity::server::Server;
pub async fn list(
State(state): State<AppState>,
context: UserContext,
) -> web::Result<impl IntoResponse> {
let servers = state
.database
.select_user_servers(context.user_id)
.await?
.into_iter()
.map(Server::from)
.collect::<Vec<_>>();
Ok(Json(servers))
}
// file: src/web/route/server/mod.rs
pub mod channel;
mod create;
mod delete;
mod get;
pub mod invite;
mod list;
pub use create::create;
pub use delete::delete;
pub use get::get;
pub use list::list;
// file: src/web/route/user/channel/create.rs
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::user::PartialUser;
use crate::web::route::user::channel::RecipientChannel;
use crate::web::ws;
use crate::{entity, web};
/// Request body for creating a user-level channel.
#[derive(Debug, Validate, Deserialize)]
pub struct CreatePayload {
    /// Other participants; `validate` requires 1 to 32 entries. The `create` handler makes a
    /// DM channel for exactly one recipient and a group channel otherwise.
    #[validate(length(min = 1, max = 32))]
    recipients: Vec<entity::user::Id>,
}
/// Creates a DM channel (one recipient) or a group channel (several recipients).
///
/// Returns the channel together with its recipient ids. Every channel member is sent an
/// `AddDmChannel` event.
pub async fn create(
    State(state): State<AppState>,
    context: UserContext,
    WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
    if let Err(err) = payload.validate() {
        return Err(web::error::ClientError::ValidationFailed(err).into());
    }

    // Validation guarantees at least one recipient.
    let channel_id = if let [recipient] = payload.recipients.as_slice() {
        state
            .database
            .procedure_create_dm_channel(context.user_id, *recipient)
            .await?
    } else {
        state
            .database
            .procedure_create_group_channel(context.user_id, &payload.recipients)
            .await?
    };

    let channel = state.database.select_channel_by_id(channel_id).await?;
    let recipients: Vec<_> = state
        .database
        .select_channel_recipients(channel_id)
        .await?
        .unwrap_or_default()
        .into_iter()
        .map(|user| user.id)
        .collect();

    let response = RecipientChannel {
        channel: channel.clone(),
        recipients: recipients.clone(),
    };
    let event = ws::gateway::event::Event::AddDmChannel {
        channel,
        recipients,
    };
    ws::gateway::util::send_message_channel(state, channel_id, event);

    Ok(Json(response))
}
// file: src/web/route/user/channel/list.rs
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web;
use crate::web::context::UserContext;
use crate::web::route::user::channel::RecipientChannel;
/// Lists the caller's user-level channels, each paired with its recipient ids.
///
/// A channel for which `select_channel_recipients` yields `None` is omitted from the result.
pub async fn list(
    State(state): State<AppState>,
    context: UserContext,
) -> web::Result<impl IntoResponse> {
    let channels = state.database.select_user_channels(context.user_id).await?;

    let mut recipient_channels = Vec::new();
    for channel in channels {
        let Some(users) = state.database.select_channel_recipients(channel.id).await? else {
            continue;
        };
        let recipients = users.into_iter().map(|user| user.id).collect();
        recipient_channels.push(RecipientChannel {
            channel,
            recipients,
        });
    }

    Ok(Json(recipient_channels))
}
// file: src/web/route/user/channel/mod.rs
mod create;
mod list;
pub use create::create;
pub use list::list;
use serde::Serialize;
use crate::entity::{channel, user};
/// A user-level (DM or group) channel returned together with its recipients' ids.
///
/// The channel's fields are flattened into the same JSON object as `recipients`.
#[derive(Debug, Serialize)]
#[serde(rename_all = "camelCase")]
struct RecipientChannel {
    #[serde(flatten)]
    pub channel: channel::Channel,
    /// Ids of the users that `select_channel_recipients` returns for this channel.
    pub recipients: Vec<user::Id>,
}
// file: src/web/route/user/get.rs
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web;
use crate::web::entity::user::PartialUser;
pub async fn get_by_id(
State(state): State<AppState>,
Path(user_id): Path<uuid::Uuid>,
) -> web::Result<impl IntoResponse> {
let user = state.database.select_user_by_id(user_id).await?;
Ok(Json(PartialUser::from(user)))
}
// file: src/web/route/user/me.rs
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web;
use crate::web::context::UserContext;
use crate::web::entity::user::FullUser;
pub async fn me(
State(state): State<AppState>,
context: UserContext,
) -> web::Result<impl IntoResponse> {
let user = state.database.select_user_by_id(context.user_id).await?;
Ok(Json(FullUser::from(user)))
}
// file: src/web/route/user/mod.rs
pub mod channel;
mod get;
mod me;
mod patch;
pub use get::get_by_id;
pub use me::me;
pub use patch::patch;
// file: src/web/route/user/patch.rs
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::web::entity::user::{FullUser, PartialUser};
use crate::web::ws;
use crate::{entity, util, web};
/// Request body for the `patch` handler (despite the name, it updates the current user).
///
/// Both fields default to their absent state; `validate_create_payload` rejects a body
/// where neither is present. JSON keys are camelCase (`displayName`, `avatarId`).
#[derive(Debug, Validate, Deserialize)]
#[serde(rename_all = "camelCase")]
#[validate(schema(function = "validate_create_payload"))]
pub struct CreatePayload {
    /// New display name; when a value is given, `validate` requires a length between 1 and 32.
    #[validate(length(min = 1, max = 32))]
    #[serde(default)]
    display_name: util::tristate::TriState<String>,
    /// New avatar file id.
    #[serde(default)]
    avatar_id: util::tristate::TriState<entity::file::Id>,
}
/// Schema-level check: a patch must touch at least one field.
///
/// Fails with `at_least_one_field_required` when both fields are absent.
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
    let nothing_to_update = payload.display_name.is_absent() && payload.avatar_id.is_absent();
    if nothing_to_update {
        Err(validator::ValidationError::new(
            "at_least_one_field_required",
        ))
    } else {
        Ok(())
    }
}
pub async fn patch(
State(state): State<AppState>,
context: UserContext,
WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
match payload.validate() {
Ok(_) => {},
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
};
let user = state
.database
.update_user_by_id(
context.user_id,
payload.display_name.as_deref(),
payload.avatar_id,
)
.await?;
ws::gateway::util::send_message(
state.clone(),
context.user_id,
ws::gateway::event::Event::AddUser {
user: PartialUser::from(user.clone()),
},
);
ws::gateway::util::send_message_related(
state.clone(),
context.user_id,
ws::gateway::event::Event::AddUser {
user: PartialUser::from(user.clone()),
},
);
Ok(Json(FullUser::from(user)))
}
// file: src/web/ws/error.rs
/// Result type for WebSocket handling, where `E` is the handler-specific error type.
pub type Result<T, E> = std::result::Result<T, Error<E>>;
/// Error shared by WebSocket handlers, generic over a handler-specific error `T`.
///
/// `Custom` carries the handler's own error. The other variants are transport or
/// protocol failures.
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error<T: CustomError> {
    /// Handler-specific error.
    #[from]
    Custom(T),
    /// JSON (de)serialization of a WebSocket message failed.
    #[from]
    Json(serde_json::Error),
    /// A oneshot acknowledgement channel was dropped before a reply arrived.
    #[from]
    AcknowledgementError(tokio::sync::oneshot::error::RecvError),
    /// A non-text frame, or a client message that is not valid in the current state.
    WrongMessageType,
    /// The socket was closed, or reading from / writing to it failed.
    WebSocketClosed,
    /// Catch-all for failures without a more specific variant.
    UnknownError,
}
/// Marker trait for handler-specific error types that may be carried in `Error::Custom`.
pub trait CustomError {}
// file: src/web/ws/gateway/connection.rs
use axum::extract::ws::Message as AxumMessage;
use base64::Engine as _;
use futures::{Stream, StreamExt};
use sha2::{Digest, Sha256};
use tokio::sync::mpsc;
use super::error::Error as WsError;
use super::event::Event as WsEvent;
use super::protocol::{WsClientMessage, WsServerMessage};
use super::state::{WsContext, WsState, WsUserContext};
use crate::jwt;
use crate::state::AppState;
use crate::web::ws::gateway::SessionKey;
use crate::web::ws::general::WebSocketHandler;
use crate::web::ws::util::{SendWsMessage, deserialize_ws_message};
use crate::web::ws::voice;
use crate::webrtc::WebRtcSignal;
impl WebSocketHandler for WsContext {
type ServerMessage = WsServerMessage;
type ClientMessage = WsClientMessage;
type Error = WsError;
async fn handle_stream<S>(
&mut self,
stream: S,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
app_state: &AppState,
) -> crate::web::ws::error::Result<(), Self::Error>
where
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
process_websocket_messages(self, stream, sender, app_state).await?;
Ok(())
}
async fn cleanup(&mut self, app_state: &AppState) {
if let Some(user_ctx_data) = &self.user_context {
app_state
.unregister_gateway_connected_user(
user_ctx_data.user_id,
&user_ctx_data.session_key,
)
.await;
}
drop(self.event_channel.take());
}
async fn handle_result_error(
&mut self,
error: Self::Error,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
) {
let error_ws_message = WsServerMessage::Error { code: error };
let _ = sender.send(SendWsMessage::new_no_response(error_ws_message));
}
}
/// Main receive loop of a gateway connection; returns when the connection should end.
///
/// * `Initialize`: the next client frame goes to `handle_initial_message`. On success the
///   state becomes `Connected`; on error the `?` ends the loop with that error.
/// * `Connected`: events queued for this session and incoming client frames are multiplexed
///   with a `biased` `select!`, so a pending event is polled before the socket. Events are
///   forwarded as `WsServerMessage::Event`, and the loop waits for the writer's acknowledgement.
///
/// Returns `Ok(())` only when the event channel reports closure. A stream error or
/// end-of-stream returns `WebSocketClosed`.
///
/// Span fields are recorded once, at entry, so `state`/`user_id` reflect the initial
/// (unauthenticated) context.
///
/// NOTE(review): `handle_initial_message` stores a clone of the event sender in
/// `context.event_channel`, so `event_rx.recv()` cannot return `None` while the context
/// holds it. The "Event channel closed" branch therefore looks unreachable in practice.
#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)
))]
async fn process_websocket_messages<S>(
    context: &mut WsContext,
    mut ws_stream: S,
    sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
    app_state: &AppState,
) -> crate::web::ws::error::Result<(), WsError>
where
    S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
    loop {
        match context.connection_state {
            WsState::Initialize => {
                tokio::select! {
                    biased;
                    maybe_message = ws_stream.next() => {
                        match maybe_message {
                            Some(Ok(message)) => {
                                // The first frame must authenticate; any error ends the connection.
                                handle_initial_message(context, message, sender, app_state).await?;
                                context.connection_state = WsState::Connected;
                                tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected.");
                            }
                            Some(Err(axum_ws_err)) => {
                                tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err);
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                            None => {
                                tracing::debug!("WebSocket stream ended by client during Initialize state.");
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                        }
                    }
                }
            },
            WsState::Connected => {
                // Both are populated by `handle_initial_message` before the state flips.
                let user_ctx = context
                    .user_context
                    .as_ref()
                    .expect("User context must be set in Connected state");
                let (_event_tx_ref, event_rx) = context
                    .event_channel
                    .as_mut()
                    .expect("Event channel must be set in Connected state");
                tokio::select! {
                    // Poll queued app events before reading the next client frame.
                    biased;
                    maybe_app_event = event_rx.recv() => {
                        if let Some(app_event_data) = maybe_app_event {
                            // Waits until the writer task reports the frame as sent (or failed).
                            SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?;
                        } else {
                            // Unreachable while `context.event_channel` holds its own sender (see NOTE above).
                            tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket.");
                            return Ok(());
                        }
                    }
                    maybe_ws_message = ws_stream.next() => {
                        match maybe_ws_message {
                            Some(Ok(message)) => {
                                handle_connected_message(context, message, sender, &app_state).await?;
                            }
                            Some(Err(axum_ws_err)) => {
                                tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err);
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                            None => {
                                tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state.");
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                        }
                    }
                }
            },
        }
    }
}
#[tracing::instrument(skip_all, fields(state = ?context.connection_state))]
async fn handle_initial_message(
context: &mut WsContext,
message: AxumMessage,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
app_state: &AppState,
) -> crate::web::ws::error::Result<(), WsError> {
match deserialize_ws_message(message)? {
WsClientMessage::Authenticate { token } => {
match crate::web::middleware::get_context_from_token(&app_state, &token).await {
Ok(auth_user_context) => {
let user_id = auth_user_context.user_id;
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<WsEvent>();
context.event_channel = Some((event_tx.clone(), event_rx));
let random_key_part = rand::random::<u64>();
let current_session_key: SessionKey = {
let mut hasher = Sha256::new();
hasher.update(token.as_bytes());
hasher.update(user_id.to_string().as_bytes());
hasher.update(&random_key_part.to_be_bytes());
base64::engine::general_purpose::URL_SAFE_NO_PAD
.encode(hasher.finalize())
.into()
};
context.user_context = Some(WsUserContext {
user_id,
session_key: current_session_key.clone(),
});
app_state
.register_gateway_connected_user(
user_id,
current_session_key.clone(),
event_tx,
)
.await;
SendWsMessage::send_with_response(
sender,
WsServerMessage::AuthenticateAccepted {
user_id,
session_key: current_session_key,
},
)
.await?;
Ok(())
},
Err(_auth_err) => {
tracing::warn!(token = %token, "Authentication failed for token.");
let _ = SendWsMessage::send_with_response(
sender,
WsServerMessage::AuthenticateDenied,
)
.await;
Err(WsError::AuthenticationFailed.into())
},
}
},
#[allow(unreachable_patterns)]
_ => {
tracing::warn!("Unexpected message type received during Initialize state.");
Err(crate::web::ws::error::Error::WrongMessageType)
},
}
}
/// Handles one client message on an authenticated gateway connection.
///
/// * `VoiceStateUpdate`: issues a voice JWT for the requested channel and replies with a
///   `VoiceServerUpdate` event, waiting for the writer's acknowledgement. The JWT is signed
///   with `security.voice_secret` and expires after `gateway.voice_token_lifetime`. Join
///   permissions are not checked yet (see TODO).
/// * `RequestVoiceStates`: for each channel of `server_id` that has a live voice room, asks
///   the room for its peers. It then queues one `VoiceChannelConnected` event per peer,
///   without waiting for acknowledgements. A database error maps to `UnknownError`.
/// * Any other message yields `WrongMessageType`.
///
/// NOTE(review): `RequestVoiceStates` does not check that the caller belongs to `server_id`.
#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)
))]
async fn handle_connected_message(
    context: &mut WsContext,
    message: AxumMessage,
    sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
    app_state: &AppState,
) -> crate::web::ws::error::Result<(), WsError> {
    match deserialize_ws_message(message)? {
        WsClientMessage::VoiceStateUpdate {
            server_id,
            channel_id,
        } => {
            // Without an explicit server id, the channel id is converted and used in its place
            // (presumably for voice in DM/group channels). Confirm.
            let server_id = server_id.unwrap_or_else(|| channel_id.clone().into());
            // TODO: check if can join this channel
            let claims = voice::claims::VoiceClaims {
                user_id: context
                    .user_context
                    .as_ref()
                    .expect("user context should be present")
                    .user_id,
                server_id,
                channel_id,
                exp: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime)
                    .timestamp(),
            };
            let token = jwt::generate_jwt(
                claims,
                crate::config::config().security.voice_secret.as_ref(),
            )
            .map_err(|_| WsError::TokenGenerationFailed)?;
            SendWsMessage::send_with_response(
                sender,
                WsServerMessage::Event {
                    event: WsEvent::VoiceServerUpdate {
                        server_id,
                        channel_id,
                        token,
                    },
                },
            )
            .await?;
            Ok(())
        },
        WsClientMessage::RequestVoiceStates { server_id } => {
            let channels = app_state
                .database
                .select_server_channels(server_id)
                .await
                .map_err(|_| crate::web::ws::error::Error::UnknownError)?;
            for channel in channels {
                let (tx, rx) = tokio::sync::oneshot::channel();
                // The read guard is a temporary, released at the end of this statement,
                // before awaiting the room's reply.
                let webrtc_sender =
                    { app_state.voice_rooms.read().await.get(&channel.id).cloned() };
                if let Some(voice_room) = webrtc_sender {
                    let _ = voice_room.send(WebRtcSignal::RequestPeers { response: tx });
                    // If the send failed, `tx` was dropped and this resolves to `Err`; skip the room.
                    let peers = match rx.await {
                        Ok(peers) => peers,
                        Err(_) => {
                            continue;
                        },
                    };
                    for peer in peers {
                        let _ =
                            sender.send(SendWsMessage::new_no_response(WsServerMessage::Event {
                                event: WsEvent::VoiceChannelConnected {
                                    server_id,
                                    channel_id: channel.id,
                                    user_id: peer,
                                },
                            }));
                    }
                }
            }
            Ok(())
        },
        other_message => {
            tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state.");
            Err(crate::web::ws::error::Error::WrongMessageType)
        },
    }
}
// file: src/web/ws/gateway/error.rs
use crate::web::ws::error::CustomError;
/// Result type using the gateway-specific error.
pub type Result<T> = std::result::Result<T, Error>;
/// Gateway-specific errors, reported to the client as `WsServerMessage::Error { code }`.
///
/// Serialized as SCREAMING_SNAKE_CASE strings (e.g. `"AUTHENTICATION_FAILED"`).
#[derive(Debug, derive_more::From, derive_more::Display, serde::Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Error {
    /// The `Authenticate` token was rejected.
    AuthenticationFailed,
    /// Signing the voice JWT failed.
    TokenGenerationFailed,
}
// Allows the gateway error to be carried in `ws::error::Error::Custom`.
impl CustomError for Error {}
// file: src/web/ws/gateway/event.rs
use crate::{entity, web};
/// Events pushed to gateway clients as the payload of `WsServerMessage::Event`.
///
/// Serialized adjacently tagged as `{"type": "<SCREAMING_SNAKE_CASE variant>", "data": {...}}`,
/// with camelCase field names inside `data`.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Event {
    /// A server became visible to the recipient (sent to the creator after server creation).
    #[serde(rename_all = "camelCase")]
    AddServer { server: web::entity::server::Server },
    /// A server is gone (sent to its former members on deletion).
    #[serde(rename_all = "camelCase")]
    RemoveServer { server_id: entity::server::Id },
    /// A DM/group channel was created; `recipients` are its recipients' user ids.
    #[serde(rename_all = "camelCase")]
    AddDmChannel {
        channel: entity::channel::Channel,
        recipients: Vec<entity::user::Id>,
    },
    /// A DM/group channel was removed.
    #[serde(rename_all = "camelCase")]
    RemoveDmChannel { channel_id: entity::channel::Id },
    /// A channel was added to a server.
    #[serde(rename_all = "camelCase")]
    AddServerChannel { channel: entity::channel::Channel },
    /// A channel was removed from a server.
    #[serde(rename_all = "camelCase")]
    RemoveServerChannel {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
    },
    /// A user's public profile, sent when it is added or updated.
    #[serde(rename_all = "camelCase")]
    AddUser {
        user: web::entity::user::PartialUser,
    },
    /// A user should be removed from the client's view.
    #[serde(rename_all = "camelCase")]
    RemoveUser { user_id: entity::user::Id },
    /// A user joined a server.
    #[serde(rename_all = "camelCase")]
    AddServerMember {
        server_id: entity::server::Id,
        member: web::entity::user::PartialUser,
    },
    /// A user left or was removed from a server.
    #[serde(rename_all = "camelCase")]
    RemoveServerMember {
        server_id: entity::server::Id,
        member_id: entity::user::Id,
    },
    /// A message was posted in a channel.
    #[serde(rename_all = "camelCase")]
    AddMessage {
        channel_id: entity::channel::Id,
        message: web::entity::message::Message,
    },
    /// A message was removed from a channel.
    #[serde(rename_all = "camelCase")]
    RemoveMessage {
        channel_id: entity::channel::Id,
        message_id: entity::message::Id,
    },
    /// A user is connected to a voice channel.
    #[serde(rename_all = "camelCase")]
    VoiceChannelConnected {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
        user_id: entity::user::Id,
    },
    /// A user left a voice channel.
    #[serde(rename_all = "camelCase")]
    VoiceChannelDisconnected {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
        user_id: entity::user::Id,
    },
    /// A signed voice token that lets the client join `channel_id`.
    #[serde(rename_all = "camelCase")]
    VoiceServerUpdate {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
        token: String,
    },
}
// file: src/web/ws/gateway/mod.rs
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use dashmap::DashMap;
use crate::state::AppState;
use crate::web::ws::gateway::state::{EventSender, WsContext};
use crate::web::ws::general;
mod connection;
mod error;
pub mod event;
mod protocol;
mod state;
pub mod util;
/// Opaque identifier for one gateway WebSocket session.
///
/// Serialized as a plain JSON string (`#[serde(transparent)]`).
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Hash,
    Default,
    derive_more::Display,
    serde::Serialize,
    serde::Deserialize,
    derive_more::From,
)]
#[serde(transparent)]
pub struct SessionKey(String);
/// Gateway sessions of a single user (presumably the per-user value in
/// `gateway_state.connected`), keyed by session key.
#[derive(Debug, Default)]
pub struct GatewayWsState {
    /// Event sender of each live session, used to push `Event`s to that socket.
    pub instances: DashMap<SessionKey, EventSender>,
}
/// Upgrades the HTTP request to a gateway WebSocket driven by a fresh `WsContext`.
pub async fn ws_handler(
    State(app_state): State<AppState>,
    ws: WebSocketUpgrade,
) -> crate::web::error::Result<impl IntoResponse> {
    let response = ws.on_upgrade(move |socket| {
        let context = WsContext::default();
        general::handle_websocket_connection(socket, app_state, context)
    });
    Ok(response)
}
// file: src/web/ws/gateway/protocol.rs
use super::{SessionKey, error, event};
use crate::entity;
/// Messages the gateway sends to the client.
///
/// Serialized adjacently tagged as `{"type": "<SCREAMING_SNAKE_CASE>", "data": {...}}`; the
/// unit variant has no `data`. Field names inside `data` are camelCase.
#[derive(Debug, serde::Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage {
    /// The `Authenticate` token was rejected.
    AuthenticateDenied,
    /// Authentication succeeded; `session_key` identifies this session.
    #[serde(rename_all = "camelCase")]
    AuthenticateAccepted {
        user_id: entity::user::Id,
        session_key: SessionKey,
    },
    /// A pushed gateway event.
    #[serde(rename_all = "camelCase")]
    Event {
        event: event::Event,
    },
    /// A gateway error; `code` is a SCREAMING_SNAKE_CASE error name.
    #[serde(rename_all = "camelCase")]
    Error {
        code: error::Error,
    },
}
/// Messages the client may send to the gateway, in the same adjacently tagged format.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage {
    /// Must be the first message; `token` is the user's auth token.
    #[serde(rename_all = "camelCase")]
    Authenticate { token: String },
    /// Request a voice token for `channel_id`; `server_id` may be omitted.
    #[serde(rename_all = "camelCase")]
    VoiceStateUpdate {
        server_id: Option<entity::server::Id>,
        channel_id: entity::channel::Id,
    },
    /// Ask for the current voice participants of every channel in `server_id`.
    #[serde(rename_all = "camelCase")]
    RequestVoiceStates { server_id: entity::server::Id },
}
// file: src/web/ws/gateway/state.rs
use tokio::sync::mpsc;
use super::{SessionKey, event};
use crate::entity;
/// Lifecycle phase of a gateway connection.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum WsState {
    /// Waiting for the client's `Authenticate` message.
    Initialize,
    /// Authenticated; events and client requests are processed.
    Connected,
}
/// Identity of an authenticated gateway session.
#[derive(Debug)]
pub struct WsUserContext {
    /// The authenticated user.
    pub user_id: entity::user::Id,
    pub session_key: SessionKey, // Unique key for this specific WebSocket session instance
}
/// Sending half of a session's event channel; registered with the app state per session.
pub type EventSender = mpsc::UnboundedSender<event::Event>;
/// Receiving half of a session's event channel; read by the connection loop.
pub type EventReceiver = mpsc::UnboundedReceiver<event::Event>;
/// Mutable per-connection state of a gateway WebSocket.
pub struct WsContext {
    /// Current lifecycle phase.
    pub connection_state: WsState,
    /// Authenticated user and session key; `None` until authentication succeeds.
    pub user_context: Option<WsUserContext>,
    /// This session's event sender/receiver pair; `None` until authentication succeeds.
    pub event_channel: Option<(EventSender, EventReceiver)>,
}
impl Default for WsContext {
fn default() -> Self {
Self {
connection_state: WsState::Initialize,
user_context: None,
event_channel: None,
}
}
}
// file: src/web/ws/gateway/util.rs
use crate::entity;
use crate::state::AppState;
use crate::web::ws::gateway::event;
pub fn send_message(state: AppState, user_id: entity::user::Id, message: event::Event) {
tokio::spawn(async move {
let connected_users = state.gateway_state.connected.get_async(&user_id).await;
if let Some(session) = connected_users {
for instance in session.instances.iter() {
if let Err(e) = instance.send(message.clone()) {
tracing::error!("failed to send message: {}", e);
}
}
}
});
}
/// Fans `message` out to each user in `user_ids`, one background delivery per user.
pub fn send_message_many(state: AppState, user_ids: &[entity::user::Id], message: event::Event) {
    user_ids
        .iter()
        .copied()
        .for_each(|user_id| send_message(state.clone(), user_id, message.clone()));
}
pub fn send_message_server(state: AppState, server_id: entity::server::Id, message: event::Event) {
tokio::spawn(async move {
let users = state
.database
.select_server_members(server_id)
.await
.unwrap_or_else(|_| vec![])
.iter()
.map(|u| u.id)
.collect::<Vec<_>>();
send_message_many(state, &users, message);
});
}
pub fn send_message_channel(
state: AppState,
channel_id: entity::channel::Id,
message: event::Event,
) {
tokio::spawn(async move {
let users = state
.database
.select_channel_members(channel_id)
.await
.unwrap_or_else(|_| vec![])
.iter()
.map(|u| u.id)
.collect::<Vec<_>>();
send_message_many(state, &users, message);
});
}
pub fn send_message_related(state: AppState, user_id: entity::user::Id, message: event::Event) {
tokio::spawn(async move {
let users = state
.database
.select_related_user_ids(user_id)
.await
.unwrap_or_else(|_| vec![]);
send_message_many(state, &users, message);
});
}
// file: src/web/ws/general.rs
use std::fmt::Debug;
use axum::extract::ws::WebSocket;
use futures::{Stream, StreamExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::mpsc;
use crate::state::AppState;
use crate::web::ws::error::CustomError;
use crate::web::ws::util;
use crate::web::ws::util::SendWsMessage;
/// Per-connection behaviour plugged into `handle_websocket_connection`.
///
/// Implementors own the connection's state. Outgoing messages are queued on `sender` and
/// written to the socket by a separate writer task.
pub trait WebSocketHandler {
    /// Message type serialized and sent to the client.
    type ServerMessage: Serialize + Send;
    /// Message type deserialized from the client.
    type ClientMessage: DeserializeOwned;
    /// Handler-specific error, carried in `ws::error::Error::Custom`.
    type Error: CustomError + Send + Debug;
    /// Consumes incoming socket frames until the connection ends or fails.
    async fn handle_stream<S>(
        &mut self,
        stream: S,
        sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
        app_state: &AppState,
    ) -> crate::web::ws::error::Result<(), Self::Error>
    where
        S: Stream<Item = Result<axum::extract::ws::Message, axum::Error>> + Unpin;
    /// Releases per-connection resources; runs after `handle_stream` returns.
    async fn cleanup(&mut self, app_state: &AppState);
    /// Reports a `Custom` error returned by `handle_stream` to the client.
    async fn handle_result_error(
        &mut self,
        error: Self::Error,
        sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
    );
}
/// Runs one WebSocket connection end to end with the given handler.
///
/// The sequence is:
/// 1. split the socket and start the writer task;
/// 2. let `handler` process the incoming stream;
/// 3. always run `handler.cleanup`;
/// 4. pass a `Custom` error to `handler.handle_result_error` (after cleanup) so it can notify
///    the client, and only log other errors;
/// 5. drop the internal sender, so the writer's receive loop can end and close the sink;
/// 6. await the writer task and log any join error.
#[tracing::instrument(skip_all)]
pub async fn handle_websocket_connection(
    websocket: WebSocket,
    app_state: AppState,
    mut handler: impl WebSocketHandler + 'static,
) {
    let (ws_sink, ws_stream) = websocket.split();
    let (internal_send_tx, internal_send_rx) = mpsc::unbounded_channel();
    let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx);
    let processing_result = handler
        .handle_stream(ws_stream, &internal_send_tx, &app_state)
        .await;
    handler.cleanup(&app_state).await;
    match processing_result {
        Ok(_) => {},
        Err(crate::web::ws::error::Error::Custom(err_to_report)) => {
            // Queued before the sender is dropped below, so the writer still delivers it.
            handler
                .handle_result_error(err_to_report, &internal_send_tx)
                .await;
        },
        Err(e) => {
            tracing::info!("WebSocket connection closed: {:?}", e);
        },
    }
    drop(internal_send_tx);
    if let Err(e) = writer_task.await {
        tracing::error!(
            "WebSocket writer task panicked or encountered an error: {:?}",
            e
        );
    }
}
// file: src/web/ws/mod.rs
mod error;
pub mod gateway;
mod util;
pub mod voice;
mod general;
// file: src/web/ws/util.rs
use axum::extract::ws::Message as AxumMessage;
use futures::{Sink, SinkExt};
use serde::Serialize;
use serde::de::DeserializeOwned;
use tokio::sync::{mpsc, oneshot};
use crate::web::ws::error::CustomError;
pub fn spawn_writer_task<S, T, E>(
mut ws_sink: S,
mut writer_rx: mpsc::UnboundedReceiver<SendWsMessage<T, E>>,
) -> tokio::task::JoinHandle<()>
where
S: Sink<axum::extract::ws::Message> + Unpin + Send + 'static,
T: Serialize + Send + 'static,
E: CustomError + Send + 'static,
{
tokio::spawn(async move {
while let Some(SendWsMessage {
message,
response_ch,
}) = writer_rx.recv().await
{
let send_result = match serialize_ws_message(message) {
Ok(ws_msg) => {
if ws_sink.send(ws_msg).await.is_err() {
Err(super::error::Error::WebSocketClosed)
} else {
Ok(())
}
},
Err(e) => Err(e.into()),
};
if let Some(ch) = response_ch {
let _ = ch.send(send_result);
}
}
let _ = ws_sink.close().await;
})
}
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
pub fn deserialize_ws_message<T: DeserializeOwned, E: CustomError>(
message: AxumMessage,
) -> super::error::Result<T, E> {
match message {
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from),
AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed),
_ => Err(super::error::Error::WrongMessageType),
}
}
/// Serializes a `WsServerMessage` into an Axum WebSocket message.
pub fn serialize_ws_message<T: Serialize, E: CustomError>(
message: T,
) -> super::error::Result<AxumMessage, E> {
serde_json::to_string(&message)
.map(Into::into)
.map(AxumMessage::Text)
.map_err(super::error::Error::from)
}
/// A message queued for the WebSocket writer task (see `spawn_writer_task`).
///
/// When `response_ch` is set, the writer reports on it the result of serializing
/// and sending `message`. When it is `None`, that result is discarded.
pub struct SendWsMessage<T, E: CustomError> {
    /// Protocol message to serialize as JSON and send.
    pub message: T,
    /// Optional one-shot channel that receives the writer's result for this message.
    pub response_ch: Option<oneshot::Sender<super::error::Result<(), E>>>,
}
impl<T, E: CustomError> SendWsMessage<T, E> {
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
pub async fn send_with_response(
tx: &mpsc::UnboundedSender<Self>, // Changed to reference
message: T,
) -> super::error::Result<(), E> {
let (response_tx, response_rx) = oneshot::channel();
let send_message = SendWsMessage {
message,
response_ch: Some(response_tx),
};
if tx.send(send_message).is_err() {
Err(super::error::Error::WebSocketClosed) // MPSC channel closed, writer task likely dead
} else {
// Wait for the writer task to acknowledge the send attempt.
// This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
}
}
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
pub fn new_no_response(message: T) -> Self {
Self {
message,
response_ch: None,
}
}
}
// file: src/web/ws/voice/claims.rs
use crate::entity;
/// JWT claims of a voice-connection token.
///
/// Verified with the configured voice secret when a voice WebSocket sends its
/// `Authenticate` message. The ids select the WebRTC room and are used in the
/// connect/disconnect gateway events.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VoiceClaims {
    /// User joining the voice channel; also used as the WebRTC peer id.
    pub user_id: entity::user::Id,
    /// Server the voice channel belongs to.
    pub server_id: entity::server::Id,
    /// Voice channel, and therefore WebRTC room, to join.
    pub channel_id: entity::channel::Id,
    /// Expiry; presumably the standard JWT `exp` in seconds since the Unix epoch —
    /// confirm against `jwt::verify_jwt`.
    pub exp: i64,
}
// file: src/web/ws/voice/connection.rs
use axum::extract::ws::Message as AxumMessage;
use futures::{Stream, StreamExt};
use tokio::sync::{mpsc, oneshot};
use super::error::{self, Error as WsError};
use super::protocol::{WsClientMessage, WsServerMessage};
use crate::jwt;
use crate::state::AppState;
use crate::web::ws;
use crate::web::ws::general::WebSocketHandler;
use crate::web::ws::util::{SendWsMessage, deserialize_ws_message};
use crate::web::ws::voice::claims::VoiceClaims;
use crate::web::ws::voice::protocol::WsServerMessage::SdpAnswer;
use crate::web::ws::voice::state::{WsContext, WsState};
use crate::webrtc::{Offer, OfferSignal, WebRtcSignal};
impl WebSocketHandler for WsContext {
type ServerMessage = WsServerMessage;
type ClientMessage = WsClientMessage;
type Error = WsError;
async fn handle_stream<S>(
&mut self,
stream: S,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
app_state: &AppState,
) -> ws::error::Result<(), Self::Error>
where
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
process_websocket_messages(self, stream, sender, app_state).await?;
Ok(())
}
async fn cleanup(&mut self, app_state: &AppState) {
tracing::debug!("Cleaning up WebSocket connection.");
match &self.connection_state {
WsState::Connected {
signal_channel,
server_id,
channel_id,
user_id,
..
} => {
ws::gateway::util::send_message_channel(
app_state.clone(),
*channel_id,
ws::gateway::event::Event::VoiceChannelDisconnected {
server_id: *server_id,
channel_id: *channel_id,
user_id: *user_id,
},
);
let _ = signal_channel.send(WebRtcSignal::Disconnect(user_id.clone()));
},
WsState::Initialize => {},
}
}
async fn handle_result_error(
&mut self,
error: Self::Error,
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
) {
tracing::error!("WebSocket error: {:?}", error);
}
}
#[tracing::instrument(skip_all)]
async fn process_websocket_messages<S>(
context: &mut WsContext,
mut ws_stream: S,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
app_state: &AppState,
) -> ws::error::Result<(), WsError>
where
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
loop {
match &context.connection_state {
WsState::Initialize => {
while let Some(Ok(message)) = ws_stream.next().await {
handle_initial_message(context, message, sender, &app_state).await?;
break;
}
},
WsState::Connected { signal_channel, .. } => {
let signal_channel = signal_channel.clone();
loop {
tokio::select! {
biased;
_ = signal_channel.closed() => {
tracing::debug!("Signal channel closed.");
break;
}
Some(Ok(message)) = ws_stream.next() => {
handle_connected_message(context, message, sender, &app_state).await?;
}
else => {
break;
}
}
}
return Err(ws::error::Error::WebSocketClosed);
},
}
}
}
#[tracing::instrument(skip_all)]
async fn handle_initial_message(
context: &mut WsContext,
message: AxumMessage,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, error::Error>>,
app_state: &AppState,
) -> ws::error::Result<(), error::Error> {
match deserialize_ws_message(message)? {
WsClientMessage::Authenticate { token } => match jwt::verify_jwt::<VoiceClaims>(
&token,
crate::config::config().security.voice_secret.as_ref(),
) {
Ok(claims) => {
SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateAccepted)
.await?;
let signal_channel =
ws::voice::state::get_signaling_channel(app_state, claims.channel_id).await;
context.connection_state = WsState::Connected {
signal_channel,
server_id: claims.server_id,
channel_id: claims.channel_id,
user_id: claims.user_id,
};
ws::gateway::util::send_message_channel(
app_state.clone(),
claims.channel_id,
ws::gateway::event::Event::VoiceChannelConnected {
server_id: claims.server_id,
channel_id: claims.channel_id,
user_id: claims.user_id,
},
);
Ok(())
},
Err(auth_err) => {
tracing::warn!("Authentication failed: {:?}", auth_err);
let _ =
SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateDenied)
.await;
Err(error::Error::AuthenticationFailed.into())
},
},
#[allow(unreachable_patterns)]
_ => {
tracing::warn!("Unexpected message type received during Initialize state.");
Err(ws::error::Error::WrongMessageType)
},
}
}
#[tracing::instrument(skip_all)]
async fn handle_connected_message(
context: &mut WsContext,
message: AxumMessage,
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, error::Error>>,
app_state: &AppState,
) -> ws::error::Result<(), error::Error> {
match deserialize_ws_message(message)? {
WsClientMessage::SdpOffer { sdp } => {
let (signal_channel, user_id) = match &mut context.connection_state {
WsState::Connected {
signal_channel,
user_id,
..
} => (signal_channel.clone(), *user_id),
_ => return Err(ws::error::Error::WrongMessageType),
};
let (tx, rx) = oneshot::channel();
let _ = signal_channel.send(WebRtcSignal::Offer(OfferSignal {
offer: Offer {
peer_id: user_id,
sdp_offer: sdp,
},
response: tx,
}));
let answer_signal = rx.await?;
sender
.send(SendWsMessage::new_no_response(SdpAnswer {
sdp: answer_signal.sdp_answer,
}))
.map_err(|_| ws::error::Error::WebSocketClosed)?;
Ok(())
},
other_message => {
tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state.");
Err(ws::error::Error::WrongMessageType)
},
}
}
// file: src/web/ws/voice/error.rs
use crate::web::ws::error::CustomError;
/// Result alias for voice-protocol errors.
pub type Result<T> = std::result::Result<T, Error>;
/// Errors specific to the voice WebSocket protocol.
///
/// Converted into the generic `ws::error::Error` (presumably its `Custom` variant,
/// which `handle_websocket_connection` forwards to `handle_result_error`).
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error {
    /// The `Authenticate` token failed JWT verification.
    AuthenticationFailed,
}
/// Marks this type as usable as the custom error of the generic WebSocket layer.
impl CustomError for Error {}
// file: src/web/ws/voice/mod.rs
pub mod claims;
mod connection;
mod error;
mod protocol;
mod state;
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::ws::general;
use crate::web::ws::voice::state::WsContext;
/// Axum handler that upgrades the request to a voice WebSocket.
///
/// Each upgraded socket starts in the unauthenticated `Initialize` state.
pub async fn ws_handler(
    State(app_state): State<AppState>,
    ws: WebSocketUpgrade,
) -> crate::web::error::Result<impl IntoResponse> {
    let response = ws.on_upgrade(move |socket| {
        let context = WsContext::default();
        general::handle_websocket_connection(socket, app_state, context)
    });
    Ok(response)
}
// file: src/web/ws/voice/protocol.rs
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
/// Messages the server sends on the voice WebSocket.
///
/// Serialized as adjacently tagged JSON: the variant name in SCREAMING_SNAKE_CASE
/// under `"type"`, and the variant's fields (if any) under `"data"`.
#[derive(Debug, serde::Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage {
    /// The `Authenticate` token was rejected; the connection is then terminated.
    AuthenticateDenied,
    /// The token was accepted; the client may now send `SdpOffer`.
    AuthenticateAccepted,
    /// The room's SDP answer to the client's offer.
    #[serde(rename_all = "camelCase")]
    SdpAnswer {
        sdp: RTCSessionDescription,
    },
}
/// Messages the client sends on the voice WebSocket.
///
/// Deserialized from adjacently tagged JSON: the SCREAMING_SNAKE_CASE variant name
/// under `"type"`, and its fields under `"data"`.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage {
    /// Must be the first message; carries a voice JWT (`VoiceClaims`).
    #[serde(rename_all = "camelCase")]
    Authenticate { token: String },
    /// WebRTC offer; only accepted after successful authentication.
    #[serde(rename_all = "camelCase")]
    SdpOffer { sdp: RTCSessionDescription },
}
// file: src/web/ws/voice/state.rs
use tokio::sync::oneshot;
use tokio::sync::mpsc;
use crate::entity;
use crate::state::AppState;
use crate::webrtc::{OfferSignal, WebRtcSignal};
/// Lifecycle of a voice WebSocket connection.
#[derive(Debug)]
pub enum WsState {
    /// Waiting for the client's `Authenticate` message.
    Initialize,
    /// Authenticated and bound to the WebRTC room of `channel_id`.
    Connected {
        /// Sender into the room's `webrtc_task`.
        signal_channel: mpsc::UnboundedSender<WebRtcSignal>,
        /// Server the voice channel belongs to (from the token claims).
        server_id: entity::server::Id,
        /// Voice channel joined (from the token claims).
        channel_id: entity::channel::Id,
        /// Authenticated user; also the WebRTC peer id.
        user_id: entity::user::Id,
    },
}
/// Per-connection state owned by the voice WebSocket handler.
pub struct WsContext {
    /// Current protocol state of the connection.
    pub connection_state: WsState,
}
impl Default for WsContext {
    /// A fresh, not yet authenticated connection.
    fn default() -> Self {
        let connection_state = WsState::Initialize;
        WsContext { connection_state }
    }
}
/// Returns the signalling sender of the WebRTC room for `channel_id`, starting a
/// new room task if none is registered.
///
/// A newly spawned `webrtc_task` unregisters its room when it finishes, logging
/// any error it returned.
///
/// NOTE(review): the lookup (under a read lock) and the registration are separate
/// steps. Two callers racing for the same channel can therefore both spawn a room
/// task. Which one stays registered depends on `register_voice_room` — confirm
/// whether it overwrites an existing entry.
pub async fn get_signaling_channel(
    app_state: &AppState,
    channel_id: entity::channel::Id,
) -> mpsc::UnboundedSender<WebRtcSignal> {
    let existing = app_state
        .voice_rooms
        .read()
        .await
        .get(&channel_id)
        .map(|room| room.clone());

    if let Some(room) = existing {
        return room;
    }

    let (room_tx, room_rx) = mpsc::unbounded_channel();
    let task_state = app_state.clone();
    tokio::spawn(async move {
        if let Err(err) = crate::webrtc::webrtc_task(channel_id, room_rx).await {
            tracing::error!("webrtc task error: {:?}", err);
        }
        task_state.unregister_voice_room(channel_id).await;
    });
    app_state.register_voice_room(channel_id, room_tx.clone()).await;
    room_tx
}
// file: src/webrtc/mod.rs
use std::sync::Arc;
use dashmap::DashMap;
use tracing::Instrument;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MIME_TYPE_OPUS;
use webrtc::api::{API, APIBuilder};
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
use webrtc::track::track_local::{TrackLocal, TrackLocalWriter};
use webrtc::track::track_remote::TrackRemote;
use crate::entity;
/// Peers are identified by the id of the user behind the connection.
type PeerId = entity::user::Id;
/// Each voice channel has its own room.
type RoomId = entity::channel::Id;
/// One participant's WebRTC state inside a room.
struct PeerState {
    /// User id of this participant.
    peer_id: PeerId,
    /// Server-side peer connection negotiated for this participant.
    peer_connection: Arc<webrtc::peer_connection::RTCPeerConnection>,
    /// Track on which the other participants' audio is written for this one.
    outgoing_audio_track: Arc<TrackLocalStaticRTP>,
    /// SSRC of this participant's incoming audio track; 0 until
    /// `forward_audio_track` records it.
    ssrc: u32,
}
/// Shared state of one voice room, created by `webrtc_task` and shared (via `Arc`)
/// with its per-peer tasks and connection callbacks.
struct RoomState {
    room_id: RoomId,
    /// Current participants keyed by user id.
    peers: DashMap<PeerId, PeerState>,
    /// Sent by `cleanup_peer` when the room becomes empty; stops `webrtc_task`.
    close_signal: tokio::sync::mpsc::UnboundedSender<()>,
}
/// An SDP offer from a client, tagged with the peer it belongs to.
#[derive(Debug)]
pub struct Offer {
    /// User id of the offering client.
    pub peer_id: PeerId,
    /// The client's session description.
    pub sdp_offer: RTCSessionDescription,
}
/// The room's reply to an `Offer`.
#[derive(Debug)]
pub struct AnswerSignal {
    /// Local description of the server-side peer connection, read after ICE
    /// gathering has completed.
    pub sdp_answer: RTCSessionDescription,
}
/// Commands accepted by a room's `webrtc_task`.
#[derive(Debug)]
pub enum WebRtcSignal {
    /// Negotiate a peer connection, replacing any existing one for the same peer.
    Offer(OfferSignal),
    /// Remove a peer from the room and close its connection.
    Disconnect(PeerId),
    /// Ask for the ids of the peers currently in the room.
    RequestPeers {
        response: tokio::sync::oneshot::Sender<Vec<PeerId>>,
    },
    /// Stop the room task.
    Close,
}
/// An `Offer` together with the channel on which its answer is delivered.
#[derive(Debug)]
pub struct OfferSignal {
    pub offer: Offer,
    /// Receives the answer; dropped without a value if negotiation fails.
    pub response: tokio::sync::oneshot::Sender<AnswerSignal>,
}
#[tracing::instrument(skip(signal))]
pub async fn webrtc_task(
room_id: RoomId,
signal: tokio::sync::mpsc::UnboundedReceiver<WebRtcSignal>,
) -> anyhow::Result<()> {
tracing::info!("Starting WebRTC task");
let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel();
let mut skip_timeout = false;
let state = Arc::new(RoomState {
room_id,
peers: DashMap::new(),
close_signal,
});
let mut signal = signal;
let mut media_engine = webrtc::api::media_engine::MediaEngine::default();
media_engine.register_default_codecs()?;
let mut registry = Registry::new();
registry = register_default_interceptors(registry, &mut media_engine)?;
let api = APIBuilder::new()
.with_media_engine(media_engine)
.with_interceptor_registry(registry)
.build();
let api = Arc::new(api);
loop {
tokio::select! {
biased;
_ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => {
tracing::debug!("initial timeout reached");
break;
}
_ = close_receiver.recv() => {
tracing::debug!("WebRTC task stopped");
break;
}
Some(signal) = signal.recv() => {
skip_timeout = true;
match signal {
WebRtcSignal::Offer(offer_signal) => {
let room_state = state.clone();
let api = api.clone();
tokio::spawn(async move {
if let Err(e) = handle_peer(api, room_state, offer_signal).await {
tracing::error!("error handling peer: {}", e);
}
}.instrument(tracing::Span::current()));
}
WebRtcSignal::RequestPeers { response } => {
let peers = state
.peers
.iter()
.map(|pair| pair.key().clone())
.collect::<Vec<_>>();
let _ = response.send(peers);
}
WebRtcSignal::Disconnect(peer_id) => {
tracing::debug!("received disconnect signal for peer {}", peer_id);
cleanup_peer(state.clone(), peer_id).await;
}
WebRtcSignal::Close => {
break;
}
}
}
}
}
Ok(())
}
#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id
))]
async fn handle_peer(
api: Arc<API>,
room_state: Arc<RoomState>,
offer_signal: OfferSignal,
) -> anyhow::Result<()> {
tracing::debug!("handling peer");
let config = RTCConfiguration {
..Default::default()
};
let outgoing_track = Arc::new(TrackLocalStaticRTP::new(
RTCRtpCodecCapability {
mime_type: MIME_TYPE_OPUS.to_string(),
..Default::default()
},
format!("audio-{}", offer_signal.offer.peer_id),
format!("room-{}", room_state.room_id),
));
let peer_connection = Arc::new(api.new_peer_connection(config).await?);
let rtp_sender = peer_connection
.add_track(Arc::clone(&outgoing_track) as Arc<dyn TrackLocal + Send + Sync>)
.await?;
tracing::debug!("added track to peer connection: {:?}", rtp_sender);
// Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI)
let outgoing_track_ = Arc::clone(&outgoing_track);
tokio::spawn(
async move {
let mut rtcp_buf = vec![0u8; 1500];
while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {
// Process RTCP if needed (e.g., bandwidth estimation, custom feedback)
}
tracing::debug!(
"RTCP reader loop for outgoing track {} ended",
outgoing_track_.id()
);
}
.instrument(tracing::Span::current()),
);
let room_state_ = room_state.clone();
peer_connection.on_peer_connection_state_change(Box::new(move |state| {
let room_state_ = room_state_.clone();
Box::pin(async move {
tracing::debug!("peer connection state changed: {:?}", state);
if state == RTCPeerConnectionState::Disconnected
|| state == RTCPeerConnectionState::Failed
{
tracing::debug!("peer connection closed");
cleanup_peer(room_state_, offer_signal.offer.peer_id).await;
}
})
}));
let room_state_ = room_state.clone();
peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| {
tracing::debug!("track received: {:?}", track);
let room_state_ = room_state_.clone();
Box::pin(async move {
if track.kind() == RTPCodecType::Audio {
tracing::debug!("audio track received: {:?}", track);
tokio::spawn(async move {
if let Err(e) =
forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await
{
tracing::error!("error forwarding audio track: {}", e);
}
});
} else {
tracing::warn!("received non-audio track: {:?}", track);
}
})
}));
peer_connection
.set_remote_description(offer_signal.offer.sdp_offer)
.await?;
let answer = peer_connection.create_answer(None).await?;
let mut gathering_complete = peer_connection.gathering_complete_promise().await;
peer_connection.set_local_description(answer).await?;
gathering_complete.recv().await;
tracing::debug!("ICE gathering complete");
let local_description = peer_connection
.local_description()
.await
.ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?;
let peer_state = PeerState {
peer_id: offer_signal.offer.peer_id,
peer_connection: Arc::clone(&peer_connection),
outgoing_audio_track: Arc::clone(&outgoing_track),
ssrc: 0,
};
{
if let Some((_, old_peer)) = room_state.peers.remove(&offer_signal.offer.peer_id) {
let _ = old_peer.peer_connection.close().await;
}
}
{
room_state
.peers
.insert(offer_signal.offer.peer_id, peer_state);
}
let _ = offer_signal.response.send(AnswerSignal {
sdp_answer: local_description,
});
Ok(())
}
/// Relays RTP packets from one peer's incoming audio `track` to every other peer
/// in the room.
///
/// First records the track's SSRC on the sender's `PeerState`. Then each packet
/// read from `track` is copied into the outgoing audio track of every other peer
/// currently registered in `room_state`. Write failures are logged per packet and
/// do not stop the relay. The loop ends when reading from `track` fails.
///
/// # Errors
/// Returns an error, rather than panicking inside the spawned forwarding task, if
/// `peer_id` is not (or no longer) registered in the room.
///
/// NOTE(review): each receiver has a single outgoing track, so with three or more
/// participants, packets from different senders are interleaved on one RTP stream.
/// Also confirm whether `TrackLocalStaticRTP::write_rtp` overrides the SSRC set below.
#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id
))]
async fn forward_audio_track(
    room_state: Arc<RoomState>,
    peer_id: PeerId,
    track: Arc<TrackRemote>,
) -> anyhow::Result<()> {
    let mut rtp_buf = vec![0u8; 1500];
    {
        // Scoped so the map guard is released before the first `.await`.
        let mut peer_state = room_state
            .peers
            .get_mut(&peer_id)
            .ok_or_else(|| anyhow::anyhow!("peer state not found for peer {}", peer_id))?;
        peer_state.ssrc = track.ssrc();
    }
    while let Ok((rtp_packet, _attr)) = track.read(&mut rtp_buf).await {
        // Collect targets first so no map guard is held across the writes below.
        let other_peer_tracks = room_state
            .peers
            .iter()
            .filter_map(|pair| {
                let peer_state = pair.value();
                if peer_state.peer_id != peer_id {
                    let mut rtp_packet = rtp_packet.clone();
                    rtp_packet.header.ssrc = peer_state.ssrc;
                    Some((peer_state.outgoing_audio_track.clone(), rtp_packet))
                } else {
                    None
                }
            })
            .collect::<Vec<_>>();
        if other_peer_tracks.is_empty() {
            continue;
        }
        let write_futures = other_peer_tracks
            .iter()
            .map(|(outgoing_track, packet)| outgoing_track.write_rtp(packet));
        let results = futures::future::join_all(write_futures).await;
        for result in results {
            if let Err(e) = result {
                tracing::error!("error writing RTP packet: {}", e);
            }
        }
    }
    tracing::debug!(
        "RTP Read loop ended for track {} from peer {}",
        track.id(),
        peer_id
    );
    Ok(())
}
#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))]
async fn cleanup_peer(room_state: Arc<RoomState>, peer_id: PeerId) {
tracing::debug!("cleaning up peer");
if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) {
tracing::debug!("removed peer");
let pc = Arc::clone(&peer_state.peer_connection);
tokio::spawn(async move {
if let Err(e) = pc.close().await {
if !matches!(e, webrtc::Error::ErrConnectionClosed) {
tracing::warn!("error closing peer connection: {}", e);
}
}
tracing::debug!("peer connection closed");
});
}
if room_state.peers.is_empty() {
tracing::debug!("no more peers in room, closing room");
let _ = room_state.close_signal.send(());
}
}