5063 lines
140 KiB
Plaintext
5063 lines
140 KiB
Plaintext
|
|
// file: migrations/20250510180101_uuidv7.sql
|
|
-- Provides uuid_generate_v7(), used as the primary-key default by the tables below.
CREATE EXTENSION IF NOT EXISTS "pg_uuidv7";
|
|
// file: migrations/20250510182916_file.sql
|
|
-- Metadata for stored files; referenced by user avatars, server icons and message attachments.
CREATE TABLE IF NOT EXISTS "file"
(
    "id"           UUID              NOT NULL DEFAULT uuid_generate_v7(),
    "filename"     CHARACTER VARYING NOT NULL,
    "content_type" CHARACTER VARYING NOT NULL,
    -- 64-bit integer; presumably the size in bytes (TODO confirm with the upload code)
    "size"         BIGINT            NOT NULL,
    PRIMARY KEY ("id")
);
|
|
|
|
|
|
// file: migrations/20250510183102_user.sql
|
|
-- Accounts. username is unique; the avatar reference is cleared if its file is deleted.
CREATE TABLE IF NOT EXISTS "user"
(
    "id"            UUID              NOT NULL DEFAULT uuid_generate_v7(),
    "avatar_id"     UUID,
    "username"      CHARACTER VARYING NOT NULL,
    "display_name"  CHARACTER VARYING,
    "email"         CHARACTER VARYING NOT NULL,
    "password_hash" CHARACTER VARYING NOT NULL,
    "bot"           BOOLEAN           NOT NULL DEFAULT FALSE,
    "system"        BOOLEAN           NOT NULL DEFAULT FALSE,
    "settings"      JSONB             NOT NULL DEFAULT '{}'::jsonb,
    PRIMARY KEY ("id"),
    UNIQUE ("username"),
    FOREIGN KEY ("avatar_id") REFERENCES "file" ("id") ON DELETE SET NULL
);
|
|
|
|
-- Relation rows between two users, keyed by the ordered pair (user_id, other_id).
-- "type" is an application-defined INT2 code. Both sides cascade on user deletion.
CREATE TABLE IF NOT EXISTS "user_relation"
(
    "user_id"    UUID        NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    "other_id"   UUID        NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    "type"       INT2        NOT NULL,
    "created_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
    "updated_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
    PRIMARY KEY ("user_id", "other_id"),
    -- A user cannot be related to themselves. The rule spans two columns, so it is
    -- declared once as a named table constraint rather than repeated per column.
    CONSTRAINT "user_relation_not_self_check" CHECK ("user_id" <> "other_id")
);
|
|
|
|
-- Seed the built-in system account, flagged as both a bot and a system user.
-- password_hash is stored as an empty string.
INSERT INTO "user"
    ("username", "display_name", "email", "password_hash", "bot", "system")
VALUES
    ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE);
|
|
|
|
-- Trigger function: rejects a row whose avatar_id points at a file whose
-- content_type does not start with 'image/' (or at no file at all).
-- A NULL avatar_id is always accepted.
CREATE OR REPLACE FUNCTION check_avatar_is_image()
    RETURNS TRIGGER
    LANGUAGE plpgsql
AS
$$
DECLARE
    avatar_mime VARCHAR;
BEGIN
    IF NEW.avatar_id IS NOT NULL THEN
        SELECT f.content_type
        INTO avatar_mime
        FROM file AS f
        WHERE f.id = NEW.avatar_id;

        -- avatar_mime stays NULL when no file row matches.
        IF avatar_mime IS NULL OR avatar_mime NOT LIKE 'image/%' THEN
            RAISE EXCEPTION 'avatar_id must reference a file with content_type starting with image/';
        END IF;
    END IF;

    RETURN NEW;
END;
$$;
|
|
|
|
-- Validate the avatar when a user is inserted or when avatar_id is assigned.
-- "UPDATE OF avatar_id" skips the file lookup for unrelated updates
-- (settings, display_name, ...), which cannot change the avatar's validity.
-- NOTE: despite its name, this trigger checks the user's avatar; the name is kept
-- so any later DROP/ALTER TRIGGER referencing it keeps working.
CREATE TRIGGER trigger_check_icon_is_image
    BEFORE INSERT OR UPDATE OF avatar_id
    ON "user"
    FOR EACH ROW
EXECUTE FUNCTION check_avatar_is_image();
|
|
|
|
-- Trigger function: on UPDATE, stamps updated_at with now() (the transaction's start time).
-- For any other trigger event the row is passed through unchanged.
CREATE OR REPLACE FUNCTION fn_on_user_relation_update()
    RETURNS TRIGGER AS
$$
BEGIN
    IF TG_OP <> 'UPDATE' THEN
        RETURN NEW;
    END IF;

    NEW.updated_at := now();
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
|
|
|
|
-- Refresh user_relation.updated_at whenever a relation row is updated.
CREATE TRIGGER trg_user_relation_update
    BEFORE UPDATE ON "user_relation"
    FOR EACH ROW
    EXECUTE FUNCTION fn_on_user_relation_update();
|
|
// file: migrations/20250510183125_server.sql
|
|
-- Servers. owner_id has no ON DELETE action, so deleting a user who still owns
-- a server is rejected; the icon reference is cleared if its file is deleted.
CREATE TABLE IF NOT EXISTS "server"
(
    "id"       UUID              NOT NULL DEFAULT uuid_generate_v7(),
    "owner_id" UUID              NOT NULL,
    "name"     CHARACTER VARYING NOT NULL,
    "icon_id"  UUID,
    PRIMARY KEY ("id"),
    FOREIGN KEY ("owner_id") REFERENCES "user" ("id"),
    FOREIGN KEY ("icon_id") REFERENCES "file" ("id") ON DELETE SET NULL
);
|
|
|
|
-- Roles defined inside a server; deleted together with their server.
CREATE TABLE IF NOT EXISTS "server_role"
(
    "id"          UUID              NOT NULL DEFAULT uuid_generate_v7(),
    "server_id"   UUID              NOT NULL,
    "name"        CHARACTER VARYING NOT NULL,
    "color"       CHARACTER VARYING,
    "display"     BOOLEAN           NOT NULL DEFAULT FALSE,
    "permissions" JSONB             NOT NULL DEFAULT '{}'::jsonb,
    "position"    INT2              NOT NULL,
    PRIMARY KEY ("id"),
    FOREIGN KEY ("server_id") REFERENCES "server" ("id") ON DELETE CASCADE
);
|
|
|
|
-- A user's membership in a server: at most one row per (server, user) pair,
-- removed when either the server or the user is deleted.
CREATE TABLE IF NOT EXISTS "server_member"
(
    "id"         UUID              NOT NULL DEFAULT uuid_generate_v7(),
    "server_id"  UUID              NOT NULL,
    "user_id"    UUID              NOT NULL,
    "nickname"   CHARACTER VARYING,
    "avatar_url" CHARACTER VARYING,
    PRIMARY KEY ("id"),
    FOREIGN KEY ("server_id") REFERENCES "server" ("id") ON DELETE CASCADE,
    FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE,
    UNIQUE ("server_id", "user_id")
);
|
|
|
|
-- Many-to-many link between server members and server roles.
-- That both sides belong to the same server is enforced separately by a trigger.
CREATE TABLE IF NOT EXISTS "server_member_role"
(
    "member_id" UUID NOT NULL,
    "role_id"   UUID NOT NULL,
    PRIMARY KEY ("member_id", "role_id"),
    FOREIGN KEY ("member_id") REFERENCES "server_member" ("id") ON DELETE CASCADE,
    FOREIGN KEY ("role_id") REFERENCES "server_role" ("id") ON DELETE CASCADE
);
|
|
|
|
-- Invite codes for a server. inviter_id is cleared if the inviting user is deleted.
-- A NULL expires_at presumably means "never expires" (TODO confirm in application code).
CREATE TABLE IF NOT EXISTS "server_invite"
(
    "code"       CHARACTER VARYING        NOT NULL,
    "server_id"  UUID                     NOT NULL,
    "inviter_id" UUID,
    "expires_at" TIMESTAMP WITH TIME ZONE,
    PRIMARY KEY ("code"),
    FOREIGN KEY ("server_id") REFERENCES "server" ("id") ON DELETE CASCADE,
    FOREIGN KEY ("inviter_id") REFERENCES "user" ("id") ON DELETE SET NULL
);
|
|
|
|
-- Trigger function: rejects a row whose icon_id points at a file whose
-- content_type does not start with 'image/' (or at no file at all).
-- A NULL icon_id is always accepted.
CREATE OR REPLACE FUNCTION check_icon_is_image()
    RETURNS TRIGGER
    LANGUAGE plpgsql
AS
$$
DECLARE
    icon_mime VARCHAR;
BEGIN
    IF NEW.icon_id IS NOT NULL THEN
        SELECT f.content_type
        INTO icon_mime
        FROM file AS f
        WHERE f.id = NEW.icon_id;

        -- icon_mime stays NULL when no file row matches.
        IF icon_mime IS NULL OR icon_mime NOT LIKE 'image/%' THEN
            RAISE EXCEPTION 'icon_id must reference a file with content_type starting with image/';
        END IF;
    END IF;

    RETURN NEW;
END;
$$;
|
|
|
|
-- Validate the server icon when a server is inserted or when icon_id is assigned.
-- "UPDATE OF icon_id" skips the file lookup for unrelated updates (e.g. renames),
-- which cannot change the icon's validity.
CREATE TRIGGER trigger_check_icon_is_image
    BEFORE INSERT OR UPDATE OF icon_id
    ON "server"
    FOR EACH ROW
EXECUTE FUNCTION check_icon_is_image();
|
|
|
|
-- Trigger function for server_member_role: refuses to link a member and a role
-- that belong to different servers.
CREATE OR REPLACE FUNCTION check_server_user_role_server_id()
    RETURNS TRIGGER AS
$$
DECLARE
    member_server_id UUID;
    role_server_id   UUID;
BEGIN
    -- Server of the member (NULL if no server_member row matches).
    SELECT server_id
    INTO member_server_id
    FROM server_member
    WHERE id = NEW.member_id;

    -- Server of the role (NULL if no server_role row matches).
    SELECT server_id
    INTO role_server_id
    FROM server_role
    WHERE id = NEW.role_id;

    -- IS DISTINCT FROM is NULL-safe: a plain `!=` evaluates to NULL (and so skips
    -- the check) as soon as either side is missing.
    IF member_server_id IS DISTINCT FROM role_server_id THEN
        RAISE EXCEPTION 'Cannot assign role from a different server: server_user server_id (%) does not match server_role server_id (%)', member_server_id, role_server_id;
    END IF;

    RETURN NEW;
END;
$$ LANGUAGE plpgsql;
|
|
|
|
-- Run the same-server check for every inserted or updated member/role link.
CREATE TRIGGER enforce_server_user_role_server_id
    BEFORE INSERT OR UPDATE ON "server_member_role"
    FOR EACH ROW
    EXECUTE FUNCTION check_server_user_role_server_id();
|
|
// file: migrations/20250510184011_channel_message.sql
|
|
-- Channels of every kind; which optional columns are filled depends on "type".
CREATE TABLE IF NOT EXISTS "channel"
(
    "id"        UUID              NOT NULL DEFAULT uuid_generate_v7(),
    "name"      CHARACTER VARYING NOT NULL,
    "type"      SMALLINT          NOT NULL,
    "position"  SMALLINT          NOT NULL,
    "server_id" UUID, -- only for server channels
    "parent"    UUID, -- only for server channels
    "owner_id"  UUID, -- only for group channels
    PRIMARY KEY ("id"),
    FOREIGN KEY ("server_id") REFERENCES "server" ("id") ON DELETE CASCADE,
    FOREIGN KEY ("parent") REFERENCES "channel" ("id") ON DELETE SET NULL,
    FOREIGN KEY ("owner_id") REFERENCES "user" ("id") ON DELETE SET NULL
);
|
|
|
|
-- Explicit recipients of DM and group channels (not used for server channels).
-- Rows disappear when either the channel or the user is deleted.
CREATE TABLE IF NOT EXISTS "channel_recipient"
(
    "channel_id" UUID NOT NULL,
    "user_id"    UUID NOT NULL,
    PRIMARY KEY ("channel_id", "user_id"),
    FOREIGN KEY ("channel_id") REFERENCES "channel" ("id") ON DELETE CASCADE,
    FOREIGN KEY ("user_id") REFERENCES "user" ("id") ON DELETE CASCADE
);
|
|
|
|
-- Messages posted by author_id in channel_id; deleted with their channel or author.
CREATE TABLE IF NOT EXISTS "message"
(
    "id"         UUID NOT NULL DEFAULT uuid_generate_v7(),
    "author_id"  UUID NOT NULL,
    "channel_id" UUID NOT NULL,
    "content"    TEXT NOT NULL,
    PRIMARY KEY ("id"),
    FOREIGN KEY ("author_id") REFERENCES "user" ("id") ON DELETE CASCADE,
    FOREIGN KEY ("channel_id") REFERENCES "channel" ("id") ON DELETE CASCADE
);
|
|
|
|
-- Files attached to a message. "order" must stay quoted because ORDER is reserved.
CREATE TABLE IF NOT EXISTS "message_attachment"
(
    "message_id" UUID     NOT NULL,
    "file_id"    UUID     NOT NULL,
    "order"      SMALLINT NOT NULL,
    PRIMARY KEY ("message_id", "file_id"),
    FOREIGN KEY ("message_id") REFERENCES "message" ("id") ON DELETE CASCADE,
    FOREIGN KEY ("file_id") REFERENCES "file" ("id") ON DELETE CASCADE
);
|
|
|
|
-- Added only after "message" exists, because "channel" and "message" reference each other.
-- Presumably tracks the most recent message (TODO confirm where it is maintained).
ALTER TABLE "channel"
    ADD COLUMN "last_message_id" UUID
        REFERENCES "message" ("id")
        ON DELETE SET NULL;
|
|
|
|
|
|
-- Returns the id of the channel of type channel_type whose recipients are exactly
-- user1_id and user2_id, creating it (with both users as recipients) if none exists.
-- Raises when either id is NULL, both ids are equal, or either user does not exist.
-- NOTE(review): two concurrent calls for the same pair can both miss the lookup and
-- create duplicate channels; serialize on the user pair if that matters.
CREATE OR REPLACE FUNCTION create_dm_channel(
    user1_id UUID,
    user2_id UUID,
    channel_type INT2
)
    RETURNS UUID
    LANGUAGE plpgsql
AS
$$
DECLARE
    new_channel_id UUID;
    channel_name   VARCHAR;
BEGIN
    -- Validate parameters
    IF user1_id IS NULL OR user2_id IS NULL THEN
        RAISE EXCEPTION 'Both user IDs must be provided';
    END IF;

    IF user1_id = user2_id THEN
        RAISE EXCEPTION 'Cannot create a DM channel with the same user';
    END IF;

    -- Check if users exist
    IF NOT exists (SELECT 1 FROM "user" WHERE id = user1_id) OR
       NOT exists (SELECT 1 FROM "user" WHERE id = user2_id) THEN
        RAISE EXCEPTION 'One or both users do not exist';
    END IF;

    -- Look up an existing channel of this type whose recipients are exactly these two
    -- users. One query both tests for and fetches it; FOUND reports whether a row matched.
    SELECT c.id
    INTO new_channel_id
    FROM channel AS c
    INNER JOIN channel_recipient AS cr1
        ON cr1.channel_id = c.id AND cr1.user_id = user1_id
    INNER JOIN channel_recipient AS cr2
        ON cr2.channel_id = c.id AND cr2.user_id = user2_id
    WHERE c.type = channel_type
      AND (SELECT count(*) FROM channel_recipient AS cr WHERE cr.channel_id = c.id) = 2
    LIMIT 1;

    IF FOUND THEN
        RAISE NOTICE 'DM channel already exists between these users with ID: %', new_channel_id;
        RETURN new_channel_id;
    END IF;

    -- Name the channel after both user ids, in the order given.
    channel_name := concat(user1_id, '-', user2_id);

    INSERT INTO "channel" ("name", "type", "position")
    VALUES (channel_name, channel_type, 0)
    RETURNING id INTO new_channel_id;

    -- Both users become the channel's only recipients.
    INSERT INTO "channel_recipient" ("channel_id", "user_id")
    VALUES (new_channel_id, user1_id),
           (new_channel_id, user2_id);

    RAISE NOTICE 'DM channel created with ID: %', new_channel_id;
    RETURN new_channel_id;
END;
$$;
|
|
|
|
-- Creates a new group channel owned by p_creator_id and returns its id.
-- Members are the distinct non-NULL ids in p_recipient_user_ids plus the creator
-- (whether or not the caller already included them). At least two distinct members
-- are required. No check is made for an existing group with the same members.
CREATE OR REPLACE FUNCTION create_group_channel(
    p_creator_id UUID,           -- The user initiating the group creation (will be owner)
    p_recipient_user_ids UUID[], -- User IDs for the group; may or may not include the creator
    p_group_channel_type INT2    -- The channel type identifier for groups
)
    RETURNS UUID
    LANGUAGE plpgsql
AS
$$
DECLARE
    new_channel_id              UUID;
    final_channel_name          VARCHAR := 'Group';
    unique_sorted_recipient_ids UUID[];
    num_recipients              INT;
BEGIN
    IF p_creator_id IS NULL THEN
        RAISE EXCEPTION 'Creator user ID must be provided.';
    END IF;

    IF p_recipient_user_ids IS NULL OR array_length(p_recipient_user_ids, 1) IS NULL THEN
        RAISE EXCEPTION 'Recipient user IDs array must be provided and not empty.';
    END IF;

    -- Merge the creator into the recipient set so every member is inserted exactly once.
    -- Inserting the creator separately would violate the (channel_id, user_id) primary key
    -- whenever the array already contains them. NULL elements are dropped.
    SELECT array_agg(DISTINCT u ORDER BY u)
    INTO unique_sorted_recipient_ids
    FROM unnest(array_append(p_recipient_user_ids, p_creator_id)) AS u
    WHERE u IS NOT NULL;

    num_recipients := coalesce(array_length(unique_sorted_recipient_ids, 1), 0);

    -- A group needs the creator plus at least one other member.
    IF num_recipients < 2 THEN
        RAISE EXCEPTION 'Group channels (type %) must have at least 2 recipients. Found %.', p_group_channel_type, num_recipients;
    END IF;

    INSERT INTO "channel" ("name", "type", "position", "owner_id", "server_id", "parent")
    VALUES (final_channel_name,
            p_group_channel_type,
            0,            -- Default position
            p_creator_id,
            NULL,         -- Not a server channel
            NULL          -- Not a nested server channel
           )
    RETURNING id INTO new_channel_id;

    INSERT INTO "channel_recipient" ("channel_id", "user_id")
    SELECT new_channel_id, r_id
    FROM unnest(unique_sorted_recipient_ids) AS r_id;

    RAISE NOTICE 'Group channel (type %) named "%" created with ID: % by owner % for recipients: %',
        p_group_channel_type, final_channel_name, new_channel_id, p_creator_id, unique_sorted_recipient_ids;

    RETURN new_channel_id;
END;
$$;
|
|
|
|
// file: migrations/20250517190855_util.sql
|
|
-- Returns the distinct ids of users connected to target_user_id through a user_relation
-- row (in either direction), a shared server, or a shared DM/group channel.
-- The target itself is excluded from the server and channel branches.
CREATE OR REPLACE FUNCTION get_users_that_can_see_user(target_user_id UUID)
    RETURNS TABLE (user_id UUID)
    LANGUAGE sql
AS
$$
    -- Relations where the target is on the other_id side
    SELECT rel_in.user_id
    FROM user_relation AS rel_in
    WHERE rel_in.other_id = target_user_id

    UNION

    -- Relations where the target is on the user_id side
    SELECT rel_out.other_id
    FROM user_relation AS rel_out
    WHERE rel_out.user_id = target_user_id

    UNION

    -- Members of any server the target belongs to
    SELECT peer.user_id
    FROM server_member AS peer
    INNER JOIN server_member AS self_member
        ON self_member.server_id = peer.server_id
    WHERE self_member.user_id = target_user_id
      AND peer.user_id != target_user_id

    UNION

    -- Recipients of any DM/group channel the target is in
    SELECT co_recipient.user_id
    FROM channel_recipient AS co_recipient
    INNER JOIN channel_recipient AS self_recipient
        ON self_recipient.channel_id = co_recipient.channel_id
    WHERE self_recipient.user_id = target_user_id
      AND co_recipient.user_id != target_user_id;
$$;
|
|
// file: src/config.rs
|
|
use std::sync::OnceLock;
|
|
|
|
use serde::Deserialize;
|
|
|
|
pub fn config() -> &'static Config {
|
|
static INSTANCE: OnceLock<Config> = OnceLock::new();
|
|
|
|
INSTANCE.get_or_init(|| {
|
|
config::Config::builder()
|
|
.add_source(config::File::with_name("config"))
|
|
.add_source(config::Environment::with_prefix("DIPLOM_"))
|
|
.build()
|
|
.expect("config builder")
|
|
.try_deserialize()
|
|
.expect("config deserialize")
|
|
})
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct Config {
|
|
pub server: ServerConfig,
|
|
pub security: SecurityConfig,
|
|
pub gateway: GatewayConfig,
|
|
pub database: DatabaseConfig,
|
|
pub object_store: ObjectStoreConfig,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct ServerConfig {
|
|
pub hostname: url::Url,
|
|
pub host: std::net::Ipv4Addr,
|
|
pub port: u16,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct SecurityConfig {
|
|
pub auth_secret: String,
|
|
pub voice_secret: String,
|
|
}
|
|
|
|
#[derive(Deserialize)]
|
|
pub struct GatewayConfig {
|
|
#[serde(deserialize_with = "crate::util::deserialize_duration_seconds")]
|
|
pub voice_token_lifetime: std::time::Duration,
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[serde(untagged)]
|
|
#[serde(deny_unknown_fields)]
|
|
pub enum DatabaseConfig {
|
|
Url {
|
|
url: url::Url,
|
|
},
|
|
Full {
|
|
driver: String,
|
|
username: String,
|
|
password: String,
|
|
host: url::Host,
|
|
port: u16,
|
|
name: String,
|
|
},
|
|
}
|
|
|
|
#[derive(Debug, Deserialize)]
|
|
#[serde(deny_unknown_fields)]
|
|
pub struct ObjectStoreConfig {
|
|
pub endpoint: url::Url,
|
|
pub region: String,
|
|
pub bucket: String,
|
|
pub access_key: String,
|
|
pub secret_key: String,
|
|
}
|
|
|
|
impl DatabaseConfig {
|
|
pub fn url(&self) -> Option<url::Url> {
|
|
match self {
|
|
Self::Url { url } => Some(url.clone()),
|
|
Self::Full {
|
|
driver,
|
|
username,
|
|
password,
|
|
host,
|
|
port,
|
|
name,
|
|
} => {
|
|
let str = format!(
|
|
"{}://{}:{}@{}:{}/{}",
|
|
driver, username, password, host, port, name
|
|
);
|
|
let url = url::Url::parse(&str).ok()?;
|
|
Some(url)
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/database.rs
|
|
use crate::config::DatabaseConfig;
|
|
use crate::entity;
|
|
use crate::util::tristate::TriState;
|
|
|
|
#[derive(Clone, derive_more::AsRef)]
|
|
pub struct Database {
|
|
pool: sqlx::PgPool,
|
|
}
|
|
|
|
pub type Result<T> = std::result::Result<T, Error>;
|
|
|
|
#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
|
|
pub enum Error {
|
|
#[from]
|
|
Migrate(sqlx::migrate::MigrateError),
|
|
|
|
#[from]
|
|
Sqlx(sqlx::Error),
|
|
|
|
UserDoesNotExists,
|
|
UserAlreadyExists,
|
|
|
|
ServerDoesNotExists,
|
|
|
|
MemberAlreadyExists,
|
|
|
|
ChannelDoesNotExists,
|
|
|
|
InviteDoesNotExists,
|
|
|
|
MessageDoesNotExists,
|
|
|
|
FileDoesNotExists,
|
|
}
|
|
|
|
impl Database {
|
|
pub async fn connect(config: &DatabaseConfig) -> sqlx::Result<Self> {
|
|
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
|
|
|
|
let pool = sqlx::postgres::PgPoolOptions::new()
|
|
.connect(config.url().ok_or(sqlx::Error::BeginFailed)?.as_str())
|
|
.await?;
|
|
|
|
MIGRATOR.run(&pool).await?;
|
|
|
|
Ok(Self { pool })
|
|
}
|
|
}
|
|
|
|
impl Database {
|
|
pub async fn insert_user(
|
|
&self,
|
|
username: &str,
|
|
display_name: Option<&str>,
|
|
email: &str,
|
|
password_hash: &str,
|
|
) -> Result<entity::user::User> {
|
|
let user = match sqlx::query_as!(
|
|
entity::user::User,
|
|
r#"INSERT INTO "user"("username", "display_name", "email", "password_hash") VALUES ($1, $2, $3, $4) RETURNING "user".*"#,
|
|
username,
|
|
display_name,
|
|
email,
|
|
password_hash
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await {
|
|
Ok(user) => user,
|
|
Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
|
|
return Err(Error::UserAlreadyExists);
|
|
}
|
|
Err(e) => return Err(e.into()),
|
|
};
|
|
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn select_user_by_id(&self, user_id: entity::user::Id) -> Result<entity::user::User> {
|
|
let user = sqlx::query_as!(
|
|
entity::user::User,
|
|
r#"SELECT * FROM "user" WHERE "id" = $1"#,
|
|
user_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::UserDoesNotExists)?;
|
|
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn update_user_by_id(
|
|
&self,
|
|
user_id: entity::user::Id,
|
|
display_name: TriState<&str>,
|
|
avatar_id: TriState<entity::file::Id>,
|
|
) -> Result<entity::user::User> {
|
|
let mut query_builder = sqlx::QueryBuilder::new(r#"UPDATE "user" SET "#);
|
|
|
|
let mut separated = query_builder.separated(", ");
|
|
|
|
if !display_name.is_absent() {
|
|
separated.push(r#""display_name" = "#);
|
|
separated.push_bind_unseparated(display_name.into_option());
|
|
}
|
|
|
|
if !avatar_id.is_absent() {
|
|
separated.push(r#""avatar_id" = "#);
|
|
separated.push_bind_unseparated(avatar_id.into_option());
|
|
}
|
|
|
|
query_builder.push(" WHERE \"id\" = ");
|
|
query_builder.push_bind(user_id);
|
|
query_builder.push(" RETURNING \"user\".*");
|
|
|
|
let user = query_builder.build_query_as().fetch_one(&self.pool).await?;
|
|
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn select_users_by_ids(
|
|
&self,
|
|
user_ids: &[entity::user::Id],
|
|
) -> Result<Vec<entity::user::User>> {
|
|
let mut query_builder = sqlx::QueryBuilder::new(r#"SELECT * FROM "user" WHERE "id" IN ("#);
|
|
|
|
let mut separated = query_builder.separated(", ");
|
|
for id in user_ids.iter() {
|
|
separated.push_bind(id);
|
|
}
|
|
|
|
query_builder.push(")");
|
|
|
|
let users = query_builder.build_query_as().fetch_all(&self.pool).await?;
|
|
|
|
Ok(users)
|
|
}
|
|
|
|
pub async fn select_user_by_username(&self, username: &str) -> Result<entity::user::User> {
|
|
let user = sqlx::query_as!(
|
|
entity::user::User,
|
|
r#"SELECT * FROM "user" WHERE "username" = $1"#,
|
|
username
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::UserDoesNotExists)?;
|
|
|
|
Ok(user)
|
|
}
|
|
|
|
pub async fn select_server_by_id(
|
|
&self,
|
|
server_id: entity::server::Id,
|
|
) -> Result<entity::server::Server> {
|
|
let server = sqlx::query_as!(
|
|
entity::server::Server,
|
|
r#"SELECT * FROM "server" WHERE "id" = $1"#,
|
|
server_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::ServerDoesNotExists)?;
|
|
|
|
Ok(server)
|
|
}
|
|
|
|
pub async fn select_user_servers(
|
|
&self,
|
|
user_id: entity::user::Id,
|
|
) -> Result<Vec<entity::server::Server>> {
|
|
let servers = sqlx::query_as!(
|
|
entity::server::Server,
|
|
r#"SELECT * FROM "server" WHERE "id" IN (
|
|
SELECT "server_id" FROM "server_member"
|
|
WHERE "user_id" = $1
|
|
)"#,
|
|
user_id
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(servers)
|
|
}
|
|
|
|
pub async fn select_server_members(
|
|
&self,
|
|
server_id: entity::server::Id,
|
|
) -> Result<Vec<entity::user::User>> {
|
|
let users = sqlx::query_as!(
|
|
entity::user::User,
|
|
r#"SELECT * FROM "user" WHERE "id" IN (
|
|
SELECT "user_id" FROM "server_member" WHERE "server_id" = $1
|
|
)"#,
|
|
server_id
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(users)
|
|
}
|
|
|
|
pub async fn select_channel_members(
|
|
&self,
|
|
channel_id: entity::channel::Id,
|
|
) -> Result<Vec<entity::user::User>> {
|
|
let users = sqlx::query_as!(
|
|
entity::user::User,
|
|
r#"SELECT * FROM "user" WHERE "id" IN (
|
|
SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1
|
|
UNION SELECT "user_id" FROM "server_member" WHERE "server_id" IN (
|
|
SELECT "server_id" FROM "channel" WHERE "id" = $1
|
|
)
|
|
)"#,
|
|
channel_id
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(users)
|
|
}
|
|
|
|
pub async fn select_user_channels(
|
|
&self,
|
|
user_id: entity::user::Id,
|
|
) -> Result<Vec<entity::channel::Channel>> {
|
|
let channels = sqlx::query_as!(
|
|
entity::channel::Channel,
|
|
r#"SELECT * FROM "channel" WHERE "id" IN (
|
|
SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1
|
|
)"#,
|
|
user_id
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(channels)
|
|
}
|
|
|
|
pub async fn insert_server(
|
|
&self,
|
|
name: &str,
|
|
icon_id: Option<entity::file::Id>,
|
|
owner_id: entity::user::Id,
|
|
) -> Result<entity::server::Server> {
|
|
let server = sqlx::query_as!(
|
|
entity::server::Server,
|
|
r#"INSERT INTO "server"("name", "icon_id", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#,
|
|
name,
|
|
icon_id,
|
|
owner_id
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(server)
|
|
}
|
|
|
|
pub async fn insert_server_role(
|
|
&self,
|
|
id: Option<entity::server::role::Id>,
|
|
server_id: entity::server::Id,
|
|
name: &str,
|
|
color: Option<&str>,
|
|
display: bool,
|
|
permissions: serde_json::Value,
|
|
position: u16,
|
|
) -> Result<entity::server::role::ServerRole> {
|
|
let role = sqlx::query_as!(
|
|
entity::server::role::ServerRole,
|
|
r#"INSERT INTO "server_role"("id", "server_id", "name", "color", "display", "permissions", "position") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING "server_role".*"#,
|
|
id,
|
|
server_id,
|
|
name,
|
|
color,
|
|
display,
|
|
permissions,
|
|
position as i16
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(role)
|
|
}
|
|
|
|
pub async fn insert_server_member(
|
|
&self,
|
|
server_id: entity::server::Id,
|
|
user_id: entity::user::Id,
|
|
) -> Result<entity::server::member::ServerMember> {
|
|
let member = match sqlx::query_as!(
|
|
entity::server::member::ServerMember,
|
|
r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#,
|
|
server_id,
|
|
user_id
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await {
|
|
Ok(member) => member,
|
|
Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
|
|
return Err(Error::MemberAlreadyExists);
|
|
}
|
|
Err(e) => return Err(e.into()),
|
|
};
|
|
|
|
Ok(member)
|
|
}
|
|
|
|
pub async fn insert_server_member_role(
|
|
&self,
|
|
server_member_id: entity::server::member::Id,
|
|
server_role_id: entity::server::role::Id,
|
|
) -> Result<()> {
|
|
sqlx::query!(
|
|
r#"INSERT INTO "server_member_role"("member_id", "role_id") VALUES ($1, $2)"#,
|
|
server_member_id,
|
|
server_role_id
|
|
)
|
|
.execute(&self.pool)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn select_channel_by_id(
|
|
&self,
|
|
channel_id: entity::channel::Id,
|
|
) -> Result<entity::channel::Channel> {
|
|
let channel = sqlx::query_as!(
|
|
entity::channel::Channel,
|
|
r#"SELECT * FROM "channel" WHERE "id" = $1"#,
|
|
channel_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::ChannelDoesNotExists)?;
|
|
|
|
Ok(channel)
|
|
}
|
|
|
|
pub async fn select_channel_recipients(
|
|
&self,
|
|
channel_id: entity::channel::Id,
|
|
) -> Result<Option<Vec<entity::user::User>>> {
|
|
let recipients = sqlx::query_as!(
|
|
entity::user::User,
|
|
r#"SELECT * FROM "user" WHERE "id" IN (
|
|
SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1
|
|
)"#,
|
|
channel_id
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
if recipients.is_empty() {
|
|
return Ok(None);
|
|
}
|
|
|
|
Ok(Some(recipients))
|
|
}
|
|
|
|
pub async fn select_server_channels(
|
|
&self,
|
|
server_id: entity::server::Id,
|
|
) -> Result<Vec<entity::channel::Channel>> {
|
|
let channels = sqlx::query_as!(
|
|
entity::channel::Channel,
|
|
r#"SELECT * FROM "channel" WHERE "server_id" = $1"#,
|
|
server_id
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(channels)
|
|
}
|
|
|
|
pub async fn delete_server_by_id(
|
|
&self,
|
|
server_id: entity::server::Id,
|
|
) -> Result<entity::server::Server> {
|
|
let server = sqlx::query_as!(
|
|
entity::server::Server,
|
|
r#"DELETE FROM "server" WHERE "id" = $1 RETURNING "server".*"#,
|
|
server_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::ServerDoesNotExists)?;
|
|
|
|
Ok(server)
|
|
}
|
|
|
|
pub async fn delete_channel_by_id(
|
|
&self,
|
|
channel_id: entity::channel::Id,
|
|
) -> Result<entity::channel::Channel> {
|
|
let channel = sqlx::query_as!(
|
|
entity::channel::Channel,
|
|
r#"DELETE FROM "channel" WHERE "id" = $1 RETURNING "channel".*"#,
|
|
channel_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::ChannelDoesNotExists)?;
|
|
|
|
Ok(channel)
|
|
}
|
|
|
|
pub async fn insert_server_channel(
|
|
&self,
|
|
name: &str,
|
|
position: u16,
|
|
r#type: entity::channel::ChannelType,
|
|
server_id: entity::server::Id,
|
|
parent: Option<entity::channel::Id>,
|
|
) -> Result<entity::channel::Channel> {
|
|
let channel = sqlx::query_as!(
|
|
entity::channel::Channel,
|
|
r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#,
|
|
name,
|
|
r#type as i16,
|
|
position as i16,
|
|
server_id,
|
|
parent
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(channel)
|
|
}
|
|
|
|
pub async fn insert_server_invite(
|
|
&self,
|
|
code: &str,
|
|
server_id: entity::server::Id,
|
|
inviter_id: Option<entity::user::Id>,
|
|
expires_at: Option<chrono::DateTime<chrono::Utc>>,
|
|
) -> Result<entity::server::invite::ServerInvite> {
|
|
let invite = sqlx::query_as!(
|
|
entity::server::invite::ServerInvite,
|
|
r#"INSERT INTO "server_invite"("code", "server_id", "inviter_id", "expires_at") VALUES ($1, $2, $3, $4) RETURNING "server_invite".*"#,
|
|
code,
|
|
server_id,
|
|
inviter_id,
|
|
expires_at
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(invite)
|
|
}
|
|
|
|
pub async fn select_server_invite_by_code(
|
|
&self,
|
|
code: &str,
|
|
) -> Result<entity::server::invite::ServerInvite> {
|
|
let invite = sqlx::query_as!(
|
|
entity::server::invite::ServerInvite,
|
|
r#"SELECT * FROM "server_invite" WHERE "code" = $1"#,
|
|
code
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::InviteDoesNotExists)?;
|
|
|
|
Ok(invite)
|
|
}
|
|
|
|
pub async fn delete_server_invite_by_code(
|
|
&self,
|
|
code: &str,
|
|
) -> Result<Option<entity::server::invite::ServerInvite>> {
|
|
let invite = sqlx::query_as!(
|
|
entity::server::invite::ServerInvite,
|
|
r#"DELETE FROM "server_invite" WHERE "code" = $1 RETURNING "server_invite".*"#,
|
|
code
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?;
|
|
|
|
Ok(invite)
|
|
}
|
|
|
|
pub async fn select_channel_messages_paginated(
|
|
&self,
|
|
channel_id: entity::channel::Id,
|
|
before: Option<entity::message::Id>,
|
|
limit: i64,
|
|
) -> Result<Vec<entity::message::Message>> {
|
|
let messages = sqlx::query_as!(
|
|
entity::message::Message,
|
|
r#"SELECT * FROM "message" WHERE "channel_id" = $1 AND ($2::uuid IS NULL OR "id" < $2::uuid) ORDER BY "id" DESC LIMIT $3"#,
|
|
channel_id,
|
|
before,
|
|
limit
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(messages)
|
|
}
|
|
|
|
pub async fn insert_channel_message(
|
|
&self,
|
|
user_id: entity::user::Id,
|
|
channel_id: entity::channel::Id,
|
|
content: &str,
|
|
) -> Result<entity::message::Message> {
|
|
let message = sqlx::query_as!(
|
|
entity::message::Message,
|
|
r#"INSERT INTO "message"("channel_id", "author_id", "content") VALUES ($1, $2, $3) RETURNING "message".*"#,
|
|
channel_id,
|
|
user_id,
|
|
content
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(message)
|
|
}
|
|
|
|
pub async fn select_file_by_id(&self, file_id: entity::file::Id) -> Result<entity::file::File> {
|
|
let file = sqlx::query_as!(
|
|
entity::file::File,
|
|
r#"SELECT * FROM "file" WHERE "id" = $1"#,
|
|
file_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::FileDoesNotExists)?;
|
|
|
|
Ok(file)
|
|
}
|
|
|
|
pub async fn delete_file_by_id(&self, file_id: entity::file::Id) -> Result<entity::file::File> {
|
|
let file = sqlx::query_as!(
|
|
entity::file::File,
|
|
r#"DELETE FROM "file" WHERE "id" = $1 RETURNING "file".*"#,
|
|
file_id
|
|
)
|
|
.fetch_optional(&self.pool)
|
|
.await?
|
|
.ok_or(Error::FileDoesNotExists)?;
|
|
|
|
Ok(file)
|
|
}
|
|
|
|
pub async fn insert_file(
|
|
&self,
|
|
filename: &str,
|
|
content_type: &str,
|
|
size: usize,
|
|
) -> Result<entity::file::File> {
|
|
let file = sqlx::query_as!(
|
|
entity::file::File,
|
|
r#"INSERT INTO "file"("filename", "content_type", "size") VALUES ($1, $2, $3) RETURNING "file".*"#,
|
|
filename,
|
|
content_type,
|
|
size as i64
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?;
|
|
|
|
Ok(file)
|
|
}
|
|
|
|
pub async fn insert_message_attachment(
|
|
&self,
|
|
message_id: entity::message::Id,
|
|
file_id: entity::file::Id,
|
|
) -> Result<()> {
|
|
sqlx::query!(
|
|
r#"INSERT INTO "message_attachment"("message_id", "file_id", "order") VALUES ($1, $2, 0)"#,
|
|
message_id,
|
|
file_id
|
|
)
|
|
.execute(&self.pool)
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn select_message_attachments(
|
|
&self,
|
|
message_id: entity::message::Id,
|
|
) -> Result<Vec<entity::file::File>> {
|
|
let attachments = sqlx::query_as!(
|
|
entity::file::File,
|
|
r#"SELECT * FROM "file" WHERE "id" IN (
|
|
SELECT "file_id" FROM "message_attachment" WHERE "message_id" = $1
|
|
)"#,
|
|
message_id
|
|
)
|
|
.fetch_all(&self.pool)
|
|
.await?;
|
|
|
|
Ok(attachments)
|
|
}
|
|
|
|
pub async fn select_related_user_ids(
|
|
&self,
|
|
user_id: entity::user::Id,
|
|
) -> Result<Vec<entity::user::Id>> {
|
|
#[derive(sqlx::FromRow)]
|
|
struct UserId {
|
|
user_id: entity::user::Id,
|
|
}
|
|
|
|
let user_ids =
|
|
sqlx::query_as::<_, UserId>(r#"SELECT * FROM get_users_that_can_see_user($1)"#)
|
|
.bind(user_id)
|
|
.fetch_all(&self.pool)
|
|
.await?
|
|
.into_iter()
|
|
.map(|row| row.user_id)
|
|
.collect();
|
|
|
|
Ok(user_ids)
|
|
}
|
|
|
|
pub async fn procedure_create_dm_channel(
|
|
&self,
|
|
user1_id: entity::user::Id,
|
|
user2_id: entity::user::Id,
|
|
) -> Result<entity::channel::Id> {
|
|
let channel_id = sqlx::query_scalar!(
|
|
r#"SELECT create_dm_channel($1, $2, $3)"#,
|
|
user1_id,
|
|
user2_id,
|
|
entity::channel::ChannelType::DirectMessage as i16
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?
|
|
.expect("channel_id is null");
|
|
|
|
Ok(channel_id)
|
|
}
|
|
|
|
pub async fn procedure_create_group_channel(
|
|
&self,
|
|
creator_id: entity::user::Id,
|
|
users: &[entity::user::Id],
|
|
) -> Result<entity::channel::Id> {
|
|
let channel_id = sqlx::query_scalar!(
|
|
r#"SELECT create_group_channel($1, $2, $3)"#,
|
|
creator_id,
|
|
users,
|
|
entity::channel::ChannelType::DirectMessage as i16
|
|
)
|
|
.fetch_one(&self.pool)
|
|
.await?
|
|
.expect("channel_id is null");
|
|
|
|
Ok(channel_id)
|
|
}
|
|
}
|
|
|
|
// file: src/entity/channel.rs
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::entity::{message, server, user};
|
|
|
|
/// Identifier of a channel.
pub type Id = uuid::Uuid;

/// A channel row as loaded from the database, also serialized to clients with
/// camelCase field names. Optional fields are omitted from JSON when `None`.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Channel {
    pub id: Id,
    pub name: String,
    /// Kind of channel (server text/voice/category, direct message, group).
    pub r#type: ChannelType,
    /// Sort position. NOTE(review): presumably relative to sibling channels — confirm.
    pub position: i16,
    /// Owning server; presumably `None` for DM and group channels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_id: Option<server::Id>,
    /// Parent channel; presumably a server category — confirm.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent: Option<Id>,
    /// Owning user, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub owner_id: Option<user::Id>,
    /// Id of the most recent message in the channel, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_message_id: Option<message::Id>,
}
|
|
|
|
/// Kind of channel.
///
/// The database stores it as its `i16` discriminant (`#[repr(i16)]`; callers cast
/// with `as i16`). It is serialized to clients as a snake_case string.
#[derive(Debug, Clone, sqlx::Type, Serialize, Deserialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
#[repr(i16)]
pub enum ChannelType {
    // Server channels.
    ServerText = 1,
    ServerVoice = 2,
    ServerCategory = 3,

    // Private channels (presumably not attached to a server — confirm).
    DirectMessage = 4,
    Group = 5,
}
|
|
|
|
impl From<i16> for ChannelType {
|
|
fn from(value: i16) -> Self {
|
|
match value {
|
|
1 => ChannelType::ServerText,
|
|
2 => ChannelType::ServerVoice,
|
|
3 => ChannelType::ServerCategory,
|
|
4 => ChannelType::DirectMessage,
|
|
5 => ChannelType::Group,
|
|
_ => ChannelType::ServerText,
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/entity/file.rs
|
|
use serde::Serialize;
|
|
|
|
/// Identifier of a stored file.
pub type Id = uuid::Uuid;

/// Metadata row for an uploaded file (mirrors the `file` table columns).
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
pub struct File {
    pub id: Id,
    /// File name, presumably as supplied by the uploader.
    pub filename: String,
    /// MIME type (e.g. avatar/icon checks require an `image/` prefix).
    pub content_type: String,
    /// Size; presumably in bytes — confirm against the upload handler.
    pub size: i64,
}
|
|
|
|
// file: src/entity/message.rs
|
|
use crate::entity::{channel, user};
|
|
|
|
/// Identifier of a message. The API layer derives the message's creation
/// time from the timestamp embedded in this UUID.
pub type Id = uuid::Uuid;

/// A message row as loaded from the database.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Message {
    pub id: Id,
    pub author_id: user::Id,
    pub channel_id: channel::Id,
    pub content: String,
}
|
|
|
|
// file: src/entity/mod.rs
|
|
pub mod file;
|
|
pub mod channel;
|
|
pub mod message;
|
|
pub mod server;
|
|
pub mod user;
|
|
|
|
// file: src/entity/server/invite.rs
|
|
use chrono::{DateTime, Utc};
|
|
use serde::Serialize;
|
|
|
|
use crate::entity::{server, user};
|
|
|
|
/// An invite that lets users join a server, serialized with camelCase names.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInvite {
    /// The invite code (looked up by `GET /invites/{code}`).
    pub code: String,
    pub server_id: server::Id,
    /// User who created the invite, if known.
    pub inviter_id: Option<user::Id>,
    /// Expiry instant; `None` presumably means the invite never expires.
    pub expires_at: Option<DateTime<Utc>>,
}
|
|
|
|
// file: src/entity/server/member.rs
|
|
use serde::Serialize;
|
|
|
|
use crate::entity::{server, user};
|
|
|
|
/// Identifier of a server membership.
pub type Id = uuid::Uuid;

/// A user's membership in a server, with optional per-server overrides.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerMember {
    pub id: Id,
    pub server_id: server::Id,
    pub user_id: user::Id,
    /// Server-specific nickname, if set.
    pub nickname: Option<String>,
    /// Server-specific avatar URL, if set.
    pub avatar_url: Option<String>,
}
|
|
|
|
// file: src/entity/server/role.rs
|
|
use serde::Serialize;
|
|
|
|
use crate::entity::server;
|
|
|
|
/// Identifier of a server role.
pub type Id = uuid::Uuid;

/// A role defined within a server.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerRole {
    pub id: Id,
    pub server_id: server::Id,
    pub name: String,
    /// Display colour; presumably a hex string — confirm.
    pub color: Option<String>,
    /// Whether the role is displayed (exact UI meaning not visible here).
    pub display: bool,
    /// Permission set as raw JSON; schema not defined in this file.
    pub permissions: serde_json::Value,
    /// Ordering position among the server's roles.
    pub position: i16,
}
|
|
|
|
// file: src/entity/server.rs
|
|
pub mod invite;
|
|
pub mod member;
|
|
pub mod role;
|
|
|
|
use crate::entity::{file, user};
|
|
|
|
/// Identifier of a server.
pub type Id = uuid::Uuid;

/// A server row as loaded from the database.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct Server {
    pub id: Id,
    pub owner_id: user::Id,
    pub name: String,
    /// Stored file used as the server icon, if any.
    pub icon_id: Option<file::Id>,
}
|
|
|
|
// file: src/entity/user.rs
|
|
use std::sync::LazyLock;
|
|
|
|
use regex::Regex;
|
|
use crate::entity::file;
|
|
|
|
/// Valid usernames: one or more ASCII letters, digits, `_` or `.`.
pub static USERNAME_REGEX: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap());

/// Identifier of a user.
pub type Id = uuid::Uuid;

/// A full user row, including credentials. It is not `Serialize`; the API
/// exposes the `FullUser` / `PartialUser` views instead.
/// NOTE(review): `Debug` prints `password_hash`; avoid `{:?}`-logging this type.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct User {
    pub id: Id,
    /// Stored file used as the avatar (must be an image, enforced by a DB trigger).
    pub avatar_id: Option<file::Id>,
    pub username: String,
    pub display_name: Option<String>,
    pub email: String,
    /// Password hash string, parsed with `PasswordHash::new` at login
    /// (empty for the seeded system account).
    pub password_hash: String,
    pub bot: bool,
    pub system: bool,
    /// Free-form per-user settings (JSONB column, defaults to `{}`).
    pub settings: serde_json::Value,
}
|
|
|
|
// file: src/jwt.rs
|
|
#![allow(unused)]
|
|
|
|
use std::collections::HashSet;
|
|
|
|
use chrono::{Duration, Local, Utc};
|
|
use serde::de::DeserializeOwned;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::config;
|
|
|
|
/// JWT claim set: the caller-supplied payload's fields flattened alongside `iat`.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims<T> {
    #[serde(flatten)]
    pub data: T,
    /// Issued-at timestamp, as set by `generate_jwt`.
    pub iat: i64,
}
|
|
|
|
pub fn generate_jwt<T: Serialize>(data: T, secret: &[u8]) -> Result<String> {
|
|
let claims = Claims {
|
|
data,
|
|
iat: Utc::now().timestamp_millis(),
|
|
};
|
|
|
|
let token = jsonwebtoken::encode(
|
|
&jsonwebtoken::Header::default(),
|
|
&claims,
|
|
&jsonwebtoken::EncodingKey::from_secret(secret),
|
|
)
|
|
.map_err(|_| Error::CouldNotEncodeToken)?;
|
|
|
|
Ok(token)
|
|
}
|
|
|
|
pub fn verify_jwt<T: DeserializeOwned>(token: &str, secret: &[u8]) -> Result<T> {
|
|
tracing::debug!("verifying token: {}", token);
|
|
|
|
let mut validation = jsonwebtoken::Validation::default();
|
|
validation.set_required_spec_claims::<String>(&[]);
|
|
|
|
let token_data = jsonwebtoken::decode::<Claims<T>>(
|
|
token,
|
|
&jsonwebtoken::DecodingKey::from_secret(secret),
|
|
&validation,
|
|
)
|
|
.inspect_err(|err| {
|
|
tracing::error!("Failed to decode JWT: {:?}", err);
|
|
})
|
|
.map_err(|_| Error::CouldNotVerifyToken)?;
|
|
|
|
Ok(token_data.claims.data)
|
|
}
|
|
|
|
/// Result alias for JWT operations.
pub type Result<T> = std::result::Result<T, Error>;

/// Failures while signing or verifying tokens. The underlying `jsonwebtoken`
/// error is not carried in the variant.
#[derive(Debug, derive_more::Error, derive_more::Display)]
pub enum Error {
    CouldNotEncodeToken,
    CouldNotVerifyToken,
}
|
|
|
|
// file: src/log.rs
|
|
use std::io;
|
|
|
|
use tracing::level_filters::LevelFilter;
|
|
use tracing_appender::non_blocking::WorkerGuard;
|
|
use tracing_subscriber::EnvFilter;
|
|
use tracing_subscriber::fmt::time::ChronoLocal;
|
|
use tracing_subscriber::layer::SubscriberExt;
|
|
use tracing_subscriber::util::SubscriberInitExt;
|
|
|
|
pub fn init_logging() -> io::Result<(WorkerGuard, WorkerGuard)> {
|
|
let logger_timer = ChronoLocal::new("%FT%H:%M:%S%.6f%:::z".to_string());
|
|
|
|
let file = tracing_appender::rolling::daily("./logs", "log");
|
|
|
|
let (file, file_guard) = tracing_appender::non_blocking(file);
|
|
|
|
let file_logger = tracing_subscriber::fmt::layer()
|
|
.with_ansi(false)
|
|
.with_timer(logger_timer.clone())
|
|
.with_writer(file);
|
|
|
|
let (stdout, stdout_guard) = tracing_appender::non_blocking(io::stdout());
|
|
|
|
let stdout_logger = tracing_subscriber::fmt::layer()
|
|
.with_timer(logger_timer)
|
|
.with_writer(stdout);
|
|
#[cfg(debug_assertions)]
|
|
let stdout_logger = stdout_logger.pretty();
|
|
|
|
tracing_subscriber::registry()
|
|
.with(
|
|
EnvFilter::builder()
|
|
.with_default_directive(LevelFilter::INFO.into())
|
|
.from_env_lossy(),
|
|
)
|
|
.with(file_logger)
|
|
.with(stdout_logger)
|
|
.init();
|
|
|
|
tracing::info!("initialized logging");
|
|
|
|
Ok((file_guard, stdout_guard))
|
|
}
|
|
|
|
// file: src/main.rs
|
|
use std::sync::Arc;
|
|
|
|
use argon2::Argon2;
|
|
|
|
use crate::database::Database;
|
|
use crate::state::AppState;
|
|
|
|
mod config;
|
|
mod database;
|
|
mod entity;
|
|
mod jwt;
|
|
mod log;
|
|
mod object_store;
|
|
mod state;
|
|
mod util;
|
|
mod web;
|
|
mod webrtc;
|
|
|
|
#[tokio::main]
|
|
async fn main() -> anyhow::Result<()> {
|
|
let _guard = log::init_logging()?;
|
|
|
|
let database = Database::connect(&config::config().database).await?;
|
|
let object_store = object_store::ObjectStore::connect(&config::config().object_store).await?;
|
|
let state = AppState {
|
|
database,
|
|
object_store,
|
|
hasher: Arc::new(Argon2::default()),
|
|
gateway_state: Default::default(),
|
|
voice_rooms: Default::default(),
|
|
};
|
|
|
|
web::run(state).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
// file: src/object_store.rs
|
|
use crate::config::ObjectStoreConfig;
|
|
|
|
/// Handle to the configured S3-compatible bucket.
///
/// Derefs to the boxed [`s3::Bucket`], so bucket methods can be called directly.
#[derive(Clone, derive_more::AsRef, derive_more::Deref)]
pub struct ObjectStore {
    inner: Box<s3::Bucket>,
}
|
|
|
|
/// Result alias for object-store operations.
pub type Result<T> = std::result::Result<T, Error>;

/// Errors raised while setting up or talking to the object store.
#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
pub enum Error {
    /// Building the S3 credentials failed.
    #[from]
    Credentials(s3::creds::error::CredentialsError),

    /// Any S3 request or client error.
    #[from]
    S3(s3::error::S3Error),
}
|
|
|
|
impl ObjectStore {
|
|
pub async fn connect(config: &ObjectStoreConfig) -> Result<Self> {
|
|
let region = s3::region::Region::Custom {
|
|
region: config.region.clone(),
|
|
endpoint: config.endpoint.origin().ascii_serialization(),
|
|
};
|
|
|
|
let credentials = s3::creds::Credentials::new(
|
|
Some(&config.access_key),
|
|
Some(&config.secret_key),
|
|
None,
|
|
None,
|
|
None,
|
|
)?;
|
|
|
|
let mut bucket =
|
|
s3::bucket::Bucket::new(&config.bucket, region.clone(), credentials.clone())?
|
|
.with_path_style();
|
|
|
|
if !bucket.exists().await? {
|
|
bucket = s3::bucket::Bucket::create_with_path_style(
|
|
&config.bucket,
|
|
region,
|
|
credentials,
|
|
s3::BucketConfiguration::default(),
|
|
)
|
|
.await?
|
|
.bucket;
|
|
}
|
|
|
|
Ok(Self { inner: bucket })
|
|
}
|
|
}
|
|
|
|
// file: src/state.rs
|
|
use std::collections::HashMap;
|
|
use std::sync::Arc;
|
|
|
|
use argon2::Argon2;
|
|
use tokio::sync::{RwLock, mpsc};
|
|
use uuid::Uuid;
|
|
|
|
use crate::database::Database;
|
|
use crate::object_store::ObjectStore;
|
|
use crate::web::ws::gateway::{GatewayWsState, SessionKey, event};
|
|
use crate::webrtc::WebRtcSignal;
|
|
|
|
/// Application state shared with every handler through axum's `State`.
///
/// Cloning it shares the `Arc` members between clones.
#[derive(Clone)]
pub struct AppState {
    pub database: Database,
    pub object_store: ObjectStore,
    /// Argon2 instance used to hash passwords (register) and verify them (login).
    pub hasher: Arc<Argon2<'static>>,

    /// Users currently connected to the gateway WebSocket, with their sessions.
    pub gateway_state: Arc<GatewayState>,

    /// Active voice rooms keyed by room id, each with a sender for WebRTC signals.
    pub voice_rooms: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<WebRtcSignal>>>>,
}
|
|
|
|
/// Registry of gateway connections.
#[derive(Debug, Default)]
pub struct GatewayState {
    /// Per-user state keyed by user id. Each value holds that user's live
    /// sessions in `instances`.
    pub connected: scc::HashMap<Uuid, GatewayWsState>,
}
|
|
|
|
impl AppState {
|
|
pub async fn register_gateway_connected_user(
|
|
&self,
|
|
user_id: Uuid,
|
|
session_key: SessionKey,
|
|
event_sender: mpsc::UnboundedSender<event::Event>,
|
|
) {
|
|
self.gateway_state
|
|
.connected
|
|
.entry_async(user_id)
|
|
.await
|
|
.or_default()
|
|
.get_mut()
|
|
.instances
|
|
.insert(session_key, event_sender);
|
|
}
|
|
|
|
pub async fn unregister_gateway_connected_user(&self, user_id: Uuid, session_id: &SessionKey) {
|
|
let is_empty = {
|
|
let mut entry = self
|
|
.gateway_state
|
|
.connected
|
|
.entry_async(user_id)
|
|
.await
|
|
.or_default();
|
|
|
|
let entry = entry.get_mut();
|
|
|
|
entry.instances.remove(session_id);
|
|
entry.instances.is_empty()
|
|
};
|
|
|
|
if is_empty {
|
|
self.gateway_state.connected.remove_async(&user_id).await;
|
|
}
|
|
}
|
|
|
|
pub async fn register_voice_room(
|
|
&self,
|
|
room_id: Uuid,
|
|
sender: mpsc::UnboundedSender<WebRtcSignal>,
|
|
) {
|
|
self.voice_rooms.write().await.insert(room_id, sender);
|
|
}
|
|
|
|
pub async fn unregister_voice_room(&self, room_id: Uuid) {
|
|
self.voice_rooms.write().await.remove(&room_id);
|
|
}
|
|
}
|
|
|
|
// file: src/util/mod.rs
|
|
pub mod tristate;
|
|
|
|
use axum::extract::multipart::Field;
|
|
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::entity;
|
|
|
|
pub fn file_id_to_url(file_id: &entity::file::Id) -> Option<String> {
|
|
Some(
|
|
crate::config::config()
|
|
.server
|
|
.hostname
|
|
.join("files/")
|
|
.ok()?
|
|
.join(&file_id.to_string())
|
|
.ok()?
|
|
.to_string(),
|
|
)
|
|
}
|
|
|
|
/// Multipart [`FieldData`] wrapper that implements `Serialize` (metadata only)
/// and derefs to the wrapped `FieldData`.
#[derive(Debug, derive_more::Deref)]
pub struct SerdeFieldData<T>(pub FieldData<T>);
|
|
|
|
#[async_trait::async_trait]
|
|
impl<T: TryFromField> TryFromField for SerdeFieldData<T> {
|
|
async fn try_from_field(
|
|
field: Field<'_>,
|
|
limit_bytes: Option<usize>,
|
|
) -> Result<Self, TypedMultipartError> {
|
|
let field = FieldData::try_from_field(field, limit_bytes).await?;
|
|
|
|
Ok(Self(field))
|
|
}
|
|
}
|
|
|
|
impl<T> Serialize for SerdeFieldData<T> {
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: serde::Serializer,
|
|
{
|
|
#[derive(Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct Metadata {
|
|
name: Option<String>,
|
|
file_name: Option<String>,
|
|
content_type: Option<String>,
|
|
}
|
|
|
|
let metadata = Metadata {
|
|
name: self.0.metadata.name.clone(),
|
|
file_name: self.0.metadata.file_name.clone(),
|
|
content_type: self.0.metadata.content_type.clone(),
|
|
};
|
|
|
|
metadata.serialize(serializer)
|
|
}
|
|
}
|
|
|
|
pub fn serialize_duration_seconds<S>(
|
|
duration: &std::time::Duration,
|
|
serializer: S,
|
|
) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: serde::Serializer,
|
|
{
|
|
let seconds = duration.as_secs();
|
|
seconds.serialize(serializer)
|
|
}
|
|
|
|
pub fn deserialize_duration_seconds<'de, D>(
|
|
deserializer: D,
|
|
) -> Result<std::time::Duration, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
let seconds = u64::deserialize(deserializer)?;
|
|
Ok(std::time::Duration::from_secs(seconds))
|
|
}
|
|
|
|
pub fn serialize_duration_seconds_option<S>(
|
|
duration: &Option<std::time::Duration>,
|
|
serializer: S,
|
|
) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: serde::Serializer,
|
|
{
|
|
match duration {
|
|
Some(duration) => serialize_duration_seconds(duration, serializer),
|
|
None => serializer.serialize_none(),
|
|
}
|
|
}
|
|
|
|
pub fn deserialize_duration_seconds_option<'de, D>(
|
|
deserializer: D,
|
|
) -> Result<Option<std::time::Duration>, D::Error>
|
|
where
|
|
D: serde::Deserializer<'de>,
|
|
{
|
|
Ok(deserialize_duration_seconds(deserializer).ok())
|
|
}
|
|
|
|
// file: src/util/tristate.rs
|
|
use std::ops::Deref;
|
|
|
|
use serde::{Deserialize, Deserializer, Serialize};
|
|
use validator::ValidateLength;
|
|
|
|
/// A field with three states: a concrete value, an explicit `null`, or not
/// present at all.
#[derive(Debug, Clone)]
pub enum TriState<T> {
    /// A concrete value was supplied.
    Value(T),
    /// The field was explicitly `null`.
    None,
    /// The field was not present (the `Default`).
    Absent,
}

impl<T> TriState<T> {
    /// Borrows the inner value through `Deref`, e.g. `TriState<String>` to
    /// `TriState<&str>`.
    pub fn as_deref<D: ?Sized>(&self) -> TriState<&D>
    where
        T: Deref<Target = D>,
    {
        match self {
            Self::Value(inner) => TriState::Value(&**inner),
            Self::None => TriState::None,
            Self::Absent => TriState::Absent,
        }
    }

    /// `true` for [`TriState::Value`].
    pub fn is_value(&self) -> bool {
        matches!(self, Self::Value(_))
    }

    /// `true` for an explicit `null` ([`TriState::None`]).
    pub fn is_none(&self) -> bool {
        matches!(self, Self::None)
    }

    /// `true` when the field was not present ([`TriState::Absent`]).
    pub fn is_absent(&self) -> bool {
        matches!(self, Self::Absent)
    }

    /// Borrows the value, collapsing `None` and `Absent` to `Option::None`.
    pub fn as_option(&self) -> Option<&T> {
        if let Self::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }

    /// Takes the value, collapsing `None` and `Absent` to `Option::None`.
    pub fn into_option(self) -> Option<T> {
        if let Self::Value(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

/// A field that was never mentioned defaults to [`TriState::Absent`].
impl<T> Default for TriState<T> {
    fn default() -> Self {
        Self::Absent
    }
}
|
|
|
|
pub(crate) struct TriStateFieldVisitor<T> {
|
|
marker: std::marker::PhantomData<T>,
|
|
}
|
|
|
|
impl<T> Serialize for TriState<T>
|
|
where
|
|
T: Serialize,
|
|
{
|
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
|
where
|
|
S: serde::Serializer,
|
|
{
|
|
match self {
|
|
TriState::Value(v) => v.serialize(serializer),
|
|
TriState::None => serializer.serialize_none(),
|
|
TriState::Absent => serializer.serialize_unit(),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<'de, T> Deserialize<'de> for TriState<T>
|
|
where
|
|
T: Deserialize<'de>,
|
|
{
|
|
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
deserializer.deserialize_option(TriStateFieldVisitor::<T> {
|
|
marker: std::marker::PhantomData,
|
|
})
|
|
}
|
|
}
|
|
impl<'de, T> serde::de::Visitor<'de> for TriStateFieldVisitor<T>
|
|
where
|
|
T: Deserialize<'de>,
|
|
{
|
|
type Value = TriState<T>;
|
|
|
|
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
|
formatter.write_str("TriStateField<T>")
|
|
}
|
|
|
|
#[inline]
|
|
fn visit_none<E>(self) -> Result<TriState<T>, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(TriState::None)
|
|
}
|
|
|
|
#[inline]
|
|
fn visit_some<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
|
|
where
|
|
D: Deserializer<'de>,
|
|
{
|
|
T::deserialize(deserializer).map(TriState::Value)
|
|
}
|
|
|
|
#[inline]
|
|
fn visit_unit<E>(self) -> Result<TriState<T>, E>
|
|
where
|
|
E: serde::de::Error,
|
|
{
|
|
Ok(TriState::None)
|
|
}
|
|
}
|
|
|
|
impl<T> ValidateLength<u64> for TriState<T>
|
|
where
|
|
T: ValidateLength<u64>,
|
|
{
|
|
fn length(&self) -> Option<u64> {
|
|
match self {
|
|
TriState::Value(v) => v.length(),
|
|
TriState::None => None,
|
|
TriState::Absent => None,
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/web/context.rs
|
|
use axum::extract::FromRequestParts;
|
|
use axum::http::request::Parts;
|
|
|
|
use crate::entity;
|
|
|
|
/// Identity of the authenticated caller.
///
/// It is also the payload signed into auth JWTs at login and verified by
/// `get_context_from_token`.
#[derive(Debug, Copy, Clone, serde::Serialize, serde::Deserialize)]
pub struct UserContext {
    pub user_id: entity::user::Id,
}
|
|
|
|
/// Outcome of resolving the caller's context. The `resolve_context`
/// middleware stores it in the request extensions.
pub type UserContextResult = Result<UserContext, Error>;

/// Reasons the user context could not be resolved.
#[derive(Debug, Clone, Copy, derive_more::Display)]
pub enum Error {
    /// No context result was stored on the request (middleware not applied).
    NotInRequest,
    /// No usable `Authorization` header was present.
    NotInHeader,
    /// NOTE(review): not produced by the auth middleware visible here.
    BadCharacters,
    /// The `Authorization` header could not be parsed as a bearer token.
    WrongTokenType,
    /// The token failed verification, or its user could not be loaded.
    BadToken,
    /// NOTE(review): not produced by the auth middleware visible here.
    Model,
}
|
|
|
|
impl<S: Send + Sync> FromRequestParts<S> for UserContext {
|
|
type Rejection = super::error::Error;
|
|
|
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
|
let context = parts
|
|
.extensions
|
|
.get::<UserContextResult>()
|
|
.cloned()
|
|
.ok_or(Error::NotInRequest)??;
|
|
|
|
Ok(context)
|
|
}
|
|
}
|
|
|
|
// file: src/web/entity/file.rs
|
|
use serde::Serialize;
|
|
|
|
use crate::entity::file::Id;
|
|
use crate::util;
|
|
|
|
/// API representation of a stored file, including its public download URL.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct File {
    pub id: Id,
    pub filename: String,
    pub content_type: String,
    pub size: i64,
    /// Download URL; an empty string if it could not be built.
    pub url: String,
}
|
|
|
|
impl From<crate::entity::file::File> for File {
|
|
fn from(file: crate::entity::file::File) -> Self {
|
|
Self {
|
|
id: file.id,
|
|
filename: file.filename,
|
|
content_type: file.content_type,
|
|
size: file.size,
|
|
url: util::file_id_to_url(&file.id).unwrap_or_default(),
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/web/entity/message.rs
|
|
use serde::Serialize;
|
|
|
|
use crate::entity::message::Id;
|
|
use crate::entity::{channel, user};
|
|
|
|
/// API representation of a message, with its creation time and attachments.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Message {
    pub id: Id,
    pub author_id: user::Id,
    pub channel_id: channel::Id,
    pub content: String,
    /// Derived from the timestamp embedded in `id`.
    pub created_at: chrono::DateTime<chrono::Utc>,
    pub attachments: Vec<super::file::File>,
}
|
|
|
|
impl Message {
|
|
pub fn from_message_with_attachments(
|
|
message: crate::entity::message::Message,
|
|
attachments: Vec<super::file::File>,
|
|
) -> Self {
|
|
Self {
|
|
id: message.id,
|
|
author_id: message.author_id,
|
|
channel_id: message.channel_id,
|
|
content: message.content,
|
|
created_at: message
|
|
.id
|
|
.get_timestamp()
|
|
.as_ref()
|
|
.map(uuid::Timestamp::to_unix)
|
|
.map(|(secs, nsecs)| {
|
|
chrono::DateTime::<chrono::Utc>::from_timestamp(secs as i64, nsecs)
|
|
})
|
|
.flatten()
|
|
.unwrap_or_default(),
|
|
attachments,
|
|
}
|
|
}
|
|
}
|
|
// file: src/web/entity/mod.rs
|
|
pub mod file;
|
|
pub mod message;
|
|
pub mod server;
|
|
pub mod user;
|
|
|
|
// file: src/web/entity/server.rs
|
|
use serde::Serialize;
|
|
|
|
use crate::entity::server::Id;
|
|
use crate::entity::user;
|
|
use crate::util;
|
|
|
|
/// API representation of a server; the stored icon file id is exposed as a URL.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Server {
    pub id: Id,
    pub owner_id: user::Id,
    pub name: String,
    /// Public URL of the server icon, if one is set and the URL can be built.
    pub icon_url: Option<String>,
}
|
|
|
|
impl From<crate::entity::server::Server> for Server {
|
|
fn from(server: crate::entity::server::Server) -> Self {
|
|
Self {
|
|
id: server.id,
|
|
owner_id: server.owner_id,
|
|
name: server.name,
|
|
icon_url: server.icon_id.as_ref().map(util::file_id_to_url).flatten(),
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/web/entity/user.rs
|
|
use crate::entity::user;
|
|
use crate::util;
|
|
|
|
/// The user's own profile as returned to them; includes private fields such as
/// `email` and `settings`, but never the password hash.
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct FullUser {
    pub id: user::Id,
    pub avatar_url: Option<String>,
    pub username: String,
    pub display_name: Option<String>,
    pub email: String,
    pub bot: bool,
    pub system: bool,
    pub settings: serde_json::Value,
}

/// Public view of a user, safe to show to other users (no email or settings).
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct PartialUser {
    pub id: user::Id,
    pub avatar_url: Option<String>,
    pub username: String,
    pub display_name: Option<String>,
    pub bot: bool,
    pub system: bool,
}
|
|
|
|
impl From<user::User> for FullUser {
|
|
fn from(user: user::User) -> Self {
|
|
Self {
|
|
id: user.id,
|
|
avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(),
|
|
username: user.username,
|
|
display_name: user.display_name,
|
|
email: user.email,
|
|
bot: user.bot,
|
|
system: user.system,
|
|
settings: user.settings,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl From<user::User> for PartialUser {
|
|
fn from(user: user::User) -> Self {
|
|
Self {
|
|
id: user.id,
|
|
avatar_url: user.avatar_id.as_ref().map(util::file_id_to_url).flatten(),
|
|
username: user.username,
|
|
display_name: user.display_name,
|
|
bot: user.bot,
|
|
system: user.system,
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/web/error.rs
|
|
use std::sync::Arc;
|
|
|
|
use axum::http::StatusCode;
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::web::context;
|
|
use crate::{database, jwt, object_store};
|
|
|
|
/// Result alias for HTTP handlers.
pub type Result<T> = std::result::Result<T, Error>;

/// Any error a handler can produce.
///
/// `IntoResponse` stashes it in the response extensions. `as_client_error`
/// maps it to the [`ClientError`] actually sent to the client.
#[derive(Debug, derive_more::From)]
pub enum Error {
    #[from]
    Context(context::Error),

    #[from]
    Jwt(jwt::Error),

    #[from]
    Hash(argon2::password_hash::Error),

    #[from]
    Database(database::Error),

    #[from]
    ObjectStore(object_store::Error),

    #[from]
    Json(serde_json::error::Error),

    #[from]
    JsonRejection(axum::extract::rejection::JsonRejection),

    #[from]
    Client(ClientError),
}
|
|
|
|
/// Error reported to API clients.
///
/// It is adjacently tagged: `{"message": "<SCREAMING_SNAKE_NAME>"}`, plus a
/// `"details"` key for variants that carry data.
#[derive(Debug, Clone, derive_more::From, serde::Serialize)]
#[serde(tag = "message", content = "details")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientError {
    UserAlreadyExists,
    UserDoesNotExist,

    ServerDoesNotExist,

    ChannelDoesNotExist,

    MessageDoesNotExist,

    NotAuthorized,
    WrongPassword,
    NotAllowed,

    /// The request body was not acceptable JSON; details carry the reason.
    InvalidJson(JsonRejection),

    /// Input validation failed; details carry the per-field errors.
    #[from]
    ValidationFailed(validator::ValidationErrors),

    InternalServerError,

    Unknown,
}
|
|
|
|
/// Serializable summary of axum's JSON extractor rejection.
#[derive(derive_more::Debug, Clone, serde::Serialize)]
pub struct JsonRejection {
    /// HTTP status reported by axum; used by `status_code` but not serialized.
    #[serde(skip)]
    status: StatusCode,

    /// Human-readable reason sent to the client.
    reason: String,
}
|
|
|
|
/// Captures the status and a readable reason from axum's JSON rejection.
///
/// For the four known rejection kinds, the reason is the underlying error's
/// `source()` message, which is more specific. Otherwise, or when there is no
/// source, axum's `body_text()` is used.
impl From<&axum::extract::rejection::JsonRejection> for JsonRejection {
    fn from(value: &axum::extract::rejection::JsonRejection) -> Self {
        use axum::extract::rejection::JsonRejection as AxumRejection;
        use std::error::Error as _;

        let underlying = match value {
            AxumRejection::JsonDataError(inner) => inner.source(),
            AxumRejection::JsonSyntaxError(inner) => inner.source(),
            AxumRejection::MissingJsonContentType(inner) => inner.source(),
            AxumRejection::BytesRejection(inner) => inner.source(),
            _ => None,
        };

        let reason = match underlying {
            Some(source) => source.to_string(),
            None => value.body_text(),
        };

        Self {
            status: value.status(),
            reason,
        }
    }
}
|
|
|
|
impl Error {
|
|
pub fn as_client_error(&self) -> ClientError {
|
|
match self {
|
|
Error::Context(_) | Error::Jwt(_) => ClientError::NotAuthorized,
|
|
Error::Database(database::Error::UserAlreadyExists) => ClientError::UserAlreadyExists,
|
|
Error::Database(database::Error::UserDoesNotExists) => ClientError::UserDoesNotExist,
|
|
|
|
Error::Database(database::Error::ServerDoesNotExists) => {
|
|
ClientError::ServerDoesNotExist
|
|
},
|
|
|
|
Error::Database(database::Error::ChannelDoesNotExists) => {
|
|
ClientError::ChannelDoesNotExist
|
|
},
|
|
|
|
Error::Database(database::Error::MessageDoesNotExists) => {
|
|
ClientError::MessageDoesNotExist
|
|
},
|
|
// Error::WrongPassword => ClientError::WrongPassword,
|
|
// Error::NotAllowed => ClientError::NotAllowed,
|
|
Error::JsonRejection(e) => ClientError::InvalidJson(e.into()),
|
|
Error::Client(client_error) => client_error.clone(),
|
|
_ => ClientError::InternalServerError,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl IntoResponse for Error {
|
|
fn into_response(self) -> axum::response::Response {
|
|
let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
|
|
|
response.extensions_mut().insert(Arc::new(self));
|
|
|
|
response
|
|
}
|
|
}
|
|
|
|
impl ClientError {
|
|
pub fn status_code(&self) -> StatusCode {
|
|
match self {
|
|
ClientError::UserAlreadyExists => StatusCode::CONFLICT,
|
|
ClientError::UserDoesNotExist
|
|
| ClientError::ServerDoesNotExist
|
|
| ClientError::ChannelDoesNotExist
|
|
| ClientError::MessageDoesNotExist => StatusCode::NOT_FOUND,
|
|
|
|
ClientError::NotAuthorized | ClientError::WrongPassword => StatusCode::UNAUTHORIZED,
|
|
|
|
ClientError::NotAllowed => StatusCode::FORBIDDEN,
|
|
|
|
ClientError::ValidationFailed(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
|
|
|
ClientError::InvalidJson(e) => e.status,
|
|
|
|
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/web/middleware/auth.rs
|
|
use axum::extract::{Request, State};
|
|
use axum::middleware::Next;
|
|
use axum::response::{IntoResponse, Response};
|
|
use axum::{Extension, RequestExt};
|
|
use axum_extra::TypedHeader;
|
|
use axum_extra::headers::Authorization;
|
|
use axum_extra::headers::authorization::Bearer;
|
|
use axum_extra::typed_header::TypedHeaderRejectionReason;
|
|
|
|
use crate::jwt;
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::{self, context};
|
|
|
|
pub async fn require_context(
|
|
Extension(context): Extension<context::UserContextResult>,
|
|
request: Request,
|
|
next: Next,
|
|
) -> web::error::Result<Response> {
|
|
context?;
|
|
|
|
Ok(next.run(request).await)
|
|
}
|
|
|
|
pub async fn resolve_context(
|
|
State(state): State<AppState>,
|
|
mut request: Request,
|
|
next: Next,
|
|
) -> impl IntoResponse {
|
|
let context = get_context(&state, &mut request).await;
|
|
|
|
request.extensions_mut().insert(context);
|
|
|
|
next.run(request).await
|
|
}
|
|
|
|
async fn get_context(state: &AppState, request: &mut Request) -> context::UserContextResult {
|
|
let bearer = request
|
|
.extract_parts::<TypedHeader<Authorization<Bearer>>>()
|
|
.await
|
|
.map_err(|error| match error.reason() {
|
|
TypedHeaderRejectionReason::Missing => context::Error::NotInHeader,
|
|
TypedHeaderRejectionReason::Error(_) => context::Error::WrongTokenType,
|
|
_ => context::Error::NotInHeader,
|
|
})?;
|
|
|
|
let token = bearer.token();
|
|
|
|
let context = get_context_from_token(state, token).await?;
|
|
|
|
Ok(context)
|
|
}
|
|
|
|
pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult {
|
|
let context = jwt::verify_jwt::<UserContext>(
|
|
token,
|
|
crate::config::config().security.auth_secret.as_ref(),
|
|
)
|
|
.map_err(|_| context::Error::BadToken)?;
|
|
|
|
let _ = state
|
|
.database
|
|
.select_user_by_id(context.user_id)
|
|
.await
|
|
.map_err(|_| context::Error::BadToken)?;
|
|
|
|
Ok(context)
|
|
}
|
|
|
|
// file: src/web/middleware/mod.rs
|
|
mod auth;
|
|
mod response_map;
|
|
|
|
pub use auth::*;
|
|
pub use response_map::*;
|
|
|
|
// file: src/web/middleware/response_map.rs
|
|
use std::sync::Arc;
|
|
|
|
use axum::Json;
|
|
use axum::extract::Request;
|
|
use axum::middleware::Next;
|
|
use axum::response::IntoResponse;
|
|
use serde_json::json;
|
|
|
|
use crate::web;
|
|
|
|
pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
|
|
let response = next.run(request).await;
|
|
|
|
let error = response.extensions().get::<Arc<web::Error>>();
|
|
|
|
if error.is_some() {
|
|
tracing::error!("{:?}", error);
|
|
}
|
|
|
|
let error_response = error.map(|error| {
|
|
let client_error = error.as_client_error();
|
|
|
|
(
|
|
client_error.status_code(),
|
|
Json(json!({"error": client_error})),
|
|
)
|
|
.into_response()
|
|
});
|
|
|
|
error_response.unwrap_or(response)
|
|
}
|
|
|
|
// file: src/web/mod.rs
|
|
mod context;
|
|
mod entity;
|
|
mod error;
|
|
mod middleware;
|
|
mod route;
|
|
pub mod ws;
|
|
|
|
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin};
|
|
|
|
pub use self::error::{Error, Result};
|
|
use crate::{config, state};
|
|
|
|
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
|
|
let config = config::config();
|
|
|
|
let addr: std::net::SocketAddr = (config.server.host, config.server.port).into();
|
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
|
tracing::info!("listening on {}", addr);
|
|
|
|
axum::serve(listener, router(state))
|
|
.with_graceful_shutdown(shutdown_signal())
|
|
.await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
/// Builds the full application router.
///
/// Routes:
/// - `/gateway/ws`, `/voice/ws`: WebSocket endpoints.
/// - `/files/{file_id}`: file download.
/// - `/api/v1`: login/register plus the protected routes. Every matched route
///   under it first runs `resolve_context`, which stores the caller's context
///   result for later extractors and middleware.
///
/// Global layers wrap inside-out (the last `.layer` is outermost): tracing,
/// then CORS, then `response_map`. Error bodies produced by `response_map`
/// therefore still pass through CORS and tracing.
/// NOTE(review): CORS allows any origin, method and header.
fn router(state: state::AppState) -> axum::Router {
    use axum::Router;
    use axum::routing::*;

    use self::route::*;

    let cors = tower_http::cors::CorsLayer::new()
        .allow_origin(AllowOrigin::any())
        .allow_methods(AllowMethods::any())
        .allow_headers(AllowHeaders::any());

    Router::new()
        // websocket
        .route("/gateway/ws", get(ws::gateway::ws_handler))
        .route("/voice/ws", get(ws::voice::ws_handler))
        // file
        .route("/files/{file_id}", get(file::get))
        // api
        .nest(
            "/api/v1",
            Router::new()
                .route("/auth/login", post(auth::login))
                .route("/auth/register", post(auth::register))
                .merge(protected_router())
                // `route_layer`: only wraps routes registered above, not the 404 fallback.
                .route_layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    middleware::resolve_context,
                )),
        )
        // middleware
        .layer(axum::middleware::from_fn(middleware::response_map))
        .layer(cors)
        .layer(tower_http::trace::TraceLayer::new_for_http())
        .with_state(state.clone())
}
|
|
|
|
fn protected_router() -> axum::Router<state::AppState> {
|
|
use axum::Router;
|
|
use axum::routing::*;
|
|
|
|
use self::route::*;
|
|
|
|
Router::new()
|
|
// user
|
|
.route("/users/@me", get(user::me))
|
|
.route("/users/@me", patch(user::patch))
|
|
.route("/users/@me/channels", get(user::channel::list))
|
|
.route("/users/@me/channels", post(user::channel::create))
|
|
.route("/users/{id}", get(user::get_by_id))
|
|
// channel
|
|
.route(
|
|
"/channels/{channel_id}/messages",
|
|
get(channel::message::page),
|
|
)
|
|
.route(
|
|
"/channels/{channel_id}/messages",
|
|
post(channel::message::create),
|
|
)
|
|
// server
|
|
.route("/servers", get(server::list))
|
|
.route("/servers", post(server::create))
|
|
.route("/servers/{server_id}", get(server::get))
|
|
.route("/servers/{server_id}", delete(server::delete))
|
|
.route("/servers/{server_id}/channels", get(server::channel::list))
|
|
.route(
|
|
"/servers/{server_id}/channels",
|
|
post(server::channel::create),
|
|
)
|
|
.route(
|
|
"/servers/{server_id}/channels/{channel_id}",
|
|
get(server::channel::get),
|
|
)
|
|
.route(
|
|
"/servers/{server_id}/channels/{channel_id}",
|
|
delete(server::channel::delete),
|
|
)
|
|
// invite
|
|
.route("/servers/{server_id}/invites", post(server::invite::create))
|
|
.route("/invites/{code}", get(server::invite::get))
|
|
// file
|
|
.route("/files", post(file::upload))
|
|
// middleware
|
|
.route_layer(axum::middleware::from_fn(middleware::require_context))
|
|
}
|
|
|
|
async fn shutdown_signal() {
|
|
_ = tokio::signal::ctrl_c().await;
|
|
}
|
|
|
|
// file: src/web/route/auth/login.rs
|
|
use argon2::{PasswordHash, PasswordVerifier};
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
use serde::{Deserialize, Serialize};
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::{jwt, web};
|
|
use crate::web::entity::user::FullUser;
|
|
|
|
#[derive(Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct LoginPayload {
|
|
username: String,
|
|
password: String,
|
|
}
|
|
|
|
#[derive(Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct LoginResponse {
|
|
user: FullUser,
|
|
token: String,
|
|
}
|
|
|
|
pub async fn login(
|
|
State(state): State<AppState>,
|
|
Json(payload): Json<LoginPayload>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let user = state
|
|
.database
|
|
.select_user_by_username(&payload.username)
|
|
.await?;
|
|
|
|
let password_hash = PasswordHash::new(&user.password_hash)?;
|
|
|
|
state
|
|
.hasher
|
|
.verify_password(payload.password.as_bytes(), &password_hash)
|
|
.map_err(|_| web::error::ClientError::WrongPassword)?;
|
|
|
|
let token = jwt::generate_jwt(UserContext { user_id: user.id }, crate::config::config().security.auth_secret.as_ref())?;
|
|
|
|
let response = LoginResponse {
|
|
user: user.into(),
|
|
token,
|
|
};
|
|
|
|
Ok(Json(response))
|
|
}
|
|
|
|
// file: src/web/route/auth/mod.rs
|
|
pub mod login;
|
|
pub mod register;
|
|
|
|
pub use login::*;
|
|
pub use register::*;
|
|
|
|
// file: src/web/route/auth/register.rs
|
|
use argon2::password_hash::rand_core::OsRng;
|
|
use argon2::password_hash::{PasswordHasher, SaltString};
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
use axum_extra::extract::WithRejection;
|
|
use serde::Deserialize;
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web;
|
|
use crate::web::entity::user::FullUser;
|
|
use crate::web::error::ClientError;
|
|
|
|
#[derive(Validate, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct RegisterPayload {
|
|
#[validate(regex(
|
|
path = "crate::entity::user::USERNAME_REGEX",
|
|
code = "invalid_username"
|
|
))]
|
|
#[validate(length(min = 3, max = 64))]
|
|
username: String,
|
|
|
|
#[validate(length(min = 1, max = 64))]
|
|
display_name: Option<String>,
|
|
|
|
#[validate(email)]
|
|
#[validate(length(min = 3, max = 128))]
|
|
email: String,
|
|
|
|
#[validate(length(min = 3, max = 128))]
|
|
password: String,
|
|
}
|
|
|
|
pub async fn register(
|
|
State(state): State<AppState>,
|
|
WithRejection(Json(payload), _): WithRejection<Json<RegisterPayload>, web::Error>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
payload.validate().map_err(ClientError::ValidationFailed)?;
|
|
|
|
let salt = SaltString::generate(&mut OsRng);
|
|
|
|
let password_hash = state
|
|
.hasher
|
|
.hash_password(payload.password.as_bytes(), &salt)?
|
|
.to_string();
|
|
|
|
let user = state
|
|
.database
|
|
.insert_user(
|
|
&payload.username,
|
|
payload.display_name.as_ref().map(String::as_str),
|
|
&payload.email,
|
|
&password_hash,
|
|
)
|
|
.await?;
|
|
|
|
let system_user = state.database.select_user_by_username("system").await?;
|
|
if system_user.system {
|
|
state
|
|
.database
|
|
.procedure_create_dm_channel(user.id, system_user.id)
|
|
.await?;
|
|
}
|
|
|
|
Ok(Json(FullUser::from(user)))
|
|
}
|
|
|
|
// file: src/web/route/channel/message/create.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::entity::file::File;
|
|
use crate::web::entity::message::Message;
|
|
use crate::web::ws;
|
|
use crate::{entity, web};
|
|
|
|
#[derive(Debug, serde::Deserialize, Validate)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[validate(schema(function = "validate_create_payload"))]
|
|
pub struct CreatePayload {
|
|
#[validate(length(min = 0, max = 2000))]
|
|
pub content: String,
|
|
|
|
#[serde(default)]
|
|
#[validate(length(min = 0, max = 10))]
|
|
pub attachments: Vec<entity::file::Id>,
|
|
}
|
|
|
|
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
|
|
if payload.content.is_empty() && payload.attachments.is_empty() {
|
|
return Err(validator::ValidationError::new(
|
|
"at_least_one_field_required",
|
|
));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn create(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path(channel_id): Path<entity::channel::Id>,
|
|
Json(payload): Json<CreatePayload>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
// TODO: check permissions
|
|
match payload.validate() {
|
|
Ok(_) => {},
|
|
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
|
|
};
|
|
|
|
let message = state
|
|
.database
|
|
.insert_channel_message(context.user_id, channel_id, &payload.content)
|
|
.await?;
|
|
|
|
for attachment in payload.attachments.iter() {
|
|
state
|
|
.database
|
|
.insert_message_attachment(message.id, *attachment)
|
|
.await?;
|
|
}
|
|
|
|
let attachments = state
|
|
.database
|
|
.select_message_attachments(message.id)
|
|
.await?
|
|
.into_iter()
|
|
.map(File::from)
|
|
.collect::<Vec<_>>();
|
|
|
|
let message = Message::from_message_with_attachments(message, attachments);
|
|
|
|
// TODO: check permissions
|
|
ws::gateway::util::send_message_channel(
|
|
state,
|
|
message.channel_id,
|
|
ws::gateway::event::Event::AddMessage {
|
|
channel_id,
|
|
message: message.clone(),
|
|
},
|
|
);
|
|
|
|
Ok(Json(message))
|
|
}
|
|
|
|
// file: src/web/route/channel/message/mod.rs
|
|
mod page;
|
|
mod create;
|
|
|
|
pub use page::page;
|
|
pub use create::create;
|
|
// file: src/web/route/channel/message/page.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, Query, State};
|
|
use axum::response::IntoResponse;
|
|
use serde::Deserialize;
|
|
use tracing::error;
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::entity::message::Message;
|
|
use crate::{entity, web};
|
|
|
|
#[derive(Debug, Deserialize, Validate)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct PageParams {
|
|
#[serde(default = "limit_default")]
|
|
#[validate(range(min = 1, max = 100))]
|
|
pub limit: u32,
|
|
#[serde(default)]
|
|
pub before: Option<entity::message::Id>,
|
|
}
|
|
|
|
/// Page size used when the `limit` query parameter is absent.
fn limit_default() -> u32 {
    const DEFAULT_PAGE_SIZE: u32 = 50;
    DEFAULT_PAGE_SIZE
}
|
|
|
|
pub async fn page(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path(channel_id): Path<entity::channel::Id>,
|
|
Query(params): Query<PageParams>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
// TODO: check permissions
|
|
match params.validate() {
|
|
Ok(_) => {},
|
|
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
|
|
};
|
|
|
|
let messages_futures = state
|
|
.database
|
|
.select_channel_messages_paginated(channel_id, params.before, params.limit as i64)
|
|
.await?
|
|
.into_iter()
|
|
.map(|message| async {
|
|
let attachments = state
|
|
.database
|
|
.select_message_attachments(message.id)
|
|
.await?
|
|
.into_iter()
|
|
.map(web::entity::file::File::from)
|
|
.collect::<Vec<_>>();
|
|
|
|
Ok::<_, web::Error>(Message::from_message_with_attachments(message, attachments))
|
|
});
|
|
|
|
let messages = futures::future::try_join_all(messages_futures).await?;
|
|
|
|
Ok(Json(messages))
|
|
}
|
|
|
|
// file: src/web/route/channel/mod.rs
|
|
pub mod message;
|
|
|
|
// file: src/web/route/file/get.rs
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::{entity, object_store, web};
|
|
|
|
pub async fn get(
|
|
State(state): State<AppState>,
|
|
Path(file_id): Path<entity::file::Id>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let file = match state.database.select_file_by_id(file_id).await {
|
|
Ok(file) => file,
|
|
Err(e) => {
|
|
return Ok(axum::http::StatusCode::NOT_FOUND.into_response());
|
|
},
|
|
};
|
|
|
|
let data = match state
|
|
.object_store
|
|
.get_object_stream(&file.id.to_string())
|
|
.await
|
|
{
|
|
Ok(data) => data,
|
|
Err(s3::error::S3Error::HttpFailWithBody(403 | 404, _)) => {
|
|
let _ = state.database.delete_file_by_id(file.id).await?;
|
|
|
|
return Ok(axum::http::StatusCode::NOT_FOUND.into_response());
|
|
},
|
|
Err(e) => {
|
|
return Err(object_store::Error::from(e).into());
|
|
},
|
|
};
|
|
|
|
let headers = axum::response::AppendHeaders([
|
|
(axum::http::header::CONTENT_TYPE, file.content_type.clone()),
|
|
(
|
|
axum::http::header::CONTENT_DISPOSITION,
|
|
format!("filename=\"{}\"", file.filename),
|
|
),
|
|
]);
|
|
|
|
Ok((headers, axum::body::Body::from_stream(data.bytes)).into_response())
|
|
}
|
|
|
|
// file: src/web/route/file/mod.rs
|
|
mod get;
|
|
mod upload;
|
|
|
|
pub use get::get;
|
|
pub use upload::upload;
|
|
|
|
// file: src/web/route/file/upload.rs
|
|
use axum::Json;
|
|
use axum::body::Bytes;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
use axum_typed_multipart::{TryFromMultipart, TypedMultipart};
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::util::SerdeFieldData;
|
|
use crate::{object_store, web};
|
|
|
|
#[derive(Debug, Validate, TryFromMultipart)]
|
|
#[try_from_multipart(rename_all = "camelCase")]
|
|
pub struct UploadPayload {
|
|
#[form_data(limit = "50MB")]
|
|
#[validate(length(min = 1, max = 16))]
|
|
files: Vec<SerdeFieldData<Bytes>>,
|
|
}
|
|
|
|
pub async fn upload(
|
|
State(state): State<AppState>,
|
|
TypedMultipart(payload): TypedMultipart<UploadPayload>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
match payload.validate() {
|
|
Ok(_) => {},
|
|
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
|
|
};
|
|
|
|
let mut file_ids = Vec::new();
|
|
|
|
for file in payload.files {
|
|
let db_file = state
|
|
.database
|
|
.insert_file(
|
|
file.metadata
|
|
.file_name
|
|
.as_deref()
|
|
.unwrap_or_else(|| "unknown"),
|
|
file.metadata.content_type.as_deref().unwrap_or_default(),
|
|
file.contents.len(),
|
|
)
|
|
.await?;
|
|
|
|
state
|
|
.object_store
|
|
.put_object(&db_file.id.to_string(), &file.contents)
|
|
.await
|
|
.map_err(object_store::Error::from)?;
|
|
|
|
file_ids.push(db_file.id);
|
|
}
|
|
|
|
Ok(Json(file_ids))
|
|
}
|
|
|
|
// file: src/web/route/mod.rs
|
|
pub mod auth;
|
|
pub mod channel;
|
|
pub mod file;
|
|
pub mod server;
|
|
pub mod user;
|
|
|
|
// file: src/web/route/server/channel/create.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
use axum_extra::extract::WithRejection;
|
|
use serde::Deserialize;
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::error::ClientError;
|
|
use crate::web::ws;
|
|
use crate::{entity, web};
|
|
|
|
#[derive(Debug, Validate, Deserialize)]
|
|
pub struct CreatePayload {
|
|
#[validate(length(min = 1, max = 32))]
|
|
name: String,
|
|
|
|
#[validate(custom(function = "validate_server_channel_type"))]
|
|
r#type: entity::channel::ChannelType,
|
|
}
|
|
|
|
fn validate_server_channel_type(
|
|
r#type: &entity::channel::ChannelType,
|
|
) -> Result<(), validator::ValidationError> {
|
|
match r#type {
|
|
entity::channel::ChannelType::ServerText => Ok(()),
|
|
entity::channel::ChannelType::ServerVoice => Ok(()),
|
|
entity::channel::ChannelType::ServerCategory => Ok(()),
|
|
_ => Err(validator::ValidationError::new("invalid_channel_type")),
|
|
}
|
|
}
|
|
|
|
pub async fn create(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path(server_id): Path<entity::server::Id>,
|
|
WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
payload.validate().map_err(ClientError::ValidationFailed)?;
|
|
|
|
// TODO: check permissions
|
|
let server = state.database.select_server_by_id(server_id).await?;
|
|
|
|
if server.owner_id != context.user_id {
|
|
return Err(web::error::ClientError::NotAllowed.into());
|
|
}
|
|
|
|
let channel = state
|
|
.database
|
|
.insert_server_channel(&payload.name, 0, payload.r#type, server_id, None)
|
|
.await?;
|
|
|
|
ws::gateway::util::send_message_server(
|
|
state,
|
|
server_id,
|
|
ws::gateway::event::Event::AddServerChannel {
|
|
channel: channel.clone(),
|
|
},
|
|
);
|
|
|
|
Ok(Json(channel))
|
|
}
|
|
|
|
// file: src/web/route/server/channel/delete.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::ws;
|
|
use crate::webrtc::WebRtcSignal;
|
|
use crate::{entity, web};
|
|
|
|
pub async fn delete(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
// TODO: check permissions
|
|
|
|
let channel = state.database.select_channel_by_id(channel_id).await?;
|
|
let server = state.database.select_server_by_id(server_id).await?;
|
|
|
|
if server.owner_id != context.user_id {
|
|
return Err(web::error::ClientError::NotAllowed.into());
|
|
}
|
|
|
|
if let Some(channel_server_id) = channel.server_id {
|
|
if channel_server_id != server_id {
|
|
return Err(web::error::ClientError::NotAllowed.into());
|
|
}
|
|
} else {
|
|
return Err(web::error::ClientError::NotAllowed.into());
|
|
}
|
|
|
|
let channel = state.database.delete_channel_by_id(channel_id).await?;
|
|
|
|
ws::gateway::util::send_message_server(
|
|
state.clone(),
|
|
server_id,
|
|
ws::gateway::event::Event::RemoveServerChannel {
|
|
server_id: server_id.clone(),
|
|
channel_id: channel.id.clone(),
|
|
},
|
|
);
|
|
|
|
tokio::spawn(async move {
|
|
let voice_rooms = state.voice_rooms.read().await;
|
|
if let Some(voice_room) = voice_rooms.get(&channel_id) {
|
|
let _ = voice_room.send(WebRtcSignal::Close);
|
|
}
|
|
});
|
|
|
|
Ok(Json(channel))
|
|
}
|
|
|
|
// file: src/web/route/server/channel/get.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::{entity, web};
|
|
|
|
pub async fn get(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path((server_id, channel_id)): Path<(entity::server::Id, entity::channel::Id)>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
// TODO: check permissions
|
|
|
|
let channel = state.database.select_channel_by_id(channel_id).await?;
|
|
|
|
if let Some(channel_server_id) = channel.server_id {
|
|
if channel_server_id != server_id {
|
|
return Err(web::error::ClientError::NotAllowed.into());
|
|
}
|
|
} else {
|
|
return Err(web::error::ClientError::NotAllowed.into());
|
|
}
|
|
|
|
Ok(Json(channel))
|
|
}
|
|
|
|
// file: src/web/route/server/channel/list.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::{entity, web};
|
|
|
|
pub async fn list(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path(server_id): Path<entity::server::Id>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let channels = state.database.select_server_channels(server_id).await?;
|
|
|
|
Ok(Json(channels))
|
|
}
|
|
|
|
// file: src/web/route/server/channel/mod.rs
|
|
mod create;
|
|
mod delete;
|
|
mod get;
|
|
mod list;
|
|
|
|
pub use create::create;
|
|
pub use delete::delete;
|
|
pub use get::get;
|
|
pub use list::list;
|
|
|
|
// file: src/web/route/server/create.rs
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
use axum_extra::extract::WithRejection;
|
|
use axum_typed_multipart::TryFromMultipart;
|
|
use serde::Deserialize;
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::error::ClientError;
|
|
use crate::web::ws;
|
|
use crate::{entity, web};
|
|
use crate::web::entity::server::Server;
|
|
|
|
#[derive(Debug, Validate, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
pub struct CreatePayload {
|
|
#[validate(length(min = 1, max = 32))]
|
|
name: String,
|
|
|
|
icon_id: Option<entity::file::Id>,
|
|
}
|
|
|
|
pub async fn create(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
payload.validate().map_err(ClientError::ValidationFailed)?;
|
|
|
|
let server = state
|
|
.database
|
|
.insert_server(&payload.name, payload.icon_id, context.user_id)
|
|
.await?;
|
|
|
|
let role = state
|
|
.database
|
|
.insert_server_role(
|
|
Some(server.id.into()),
|
|
server.id,
|
|
"@everyone",
|
|
None,
|
|
false,
|
|
serde_json::json!({}),
|
|
0,
|
|
)
|
|
.await?;
|
|
|
|
let member = state
|
|
.database
|
|
.insert_server_member(server.id, context.user_id)
|
|
.await?;
|
|
|
|
state
|
|
.database
|
|
.insert_server_member_role(member.id, role.id)
|
|
.await?;
|
|
|
|
let server = Server::from(server);
|
|
|
|
ws::gateway::util::send_message(
|
|
state,
|
|
context.user_id,
|
|
ws::gateway::event::Event::AddServer {
|
|
server: server.clone(),
|
|
},
|
|
);
|
|
|
|
Ok(Json(server))
|
|
}
|
|
|
|
// file: src/web/route/server/delete.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::ws;
|
|
use crate::webrtc::WebRtcSignal;
|
|
use crate::{entity, web};
|
|
use crate::web::entity::server::Server;
|
|
|
|
pub async fn delete(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path(server_id): Path<entity::server::Id>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let server = state.database.select_server_by_id(server_id).await?;
|
|
|
|
if server.owner_id != context.user_id {
|
|
return Err(web::error::ClientError::NotAllowed.into());
|
|
}
|
|
|
|
let members = state
|
|
.database
|
|
.select_server_members(server_id)
|
|
.await?
|
|
.iter()
|
|
.map(|u| u.id)
|
|
.collect::<Vec<_>>();
|
|
|
|
let channels = state
|
|
.database
|
|
.select_server_channels(server_id)
|
|
.await?
|
|
.iter()
|
|
.map(|c| c.id)
|
|
.collect::<Vec<_>>();
|
|
|
|
let state_clone = state.clone();
|
|
tokio::spawn(async move {
|
|
let voice_rooms = state_clone.voice_rooms.read().await;
|
|
for channel_id in channels {
|
|
if let Some(voice_room) = voice_rooms.get(&channel_id) {
|
|
let _ = voice_room.send(WebRtcSignal::Close);
|
|
}
|
|
}
|
|
});
|
|
|
|
let server = state.database.delete_server_by_id(server_id).await?;
|
|
|
|
ws::gateway::util::send_message_many(
|
|
state.clone(),
|
|
&members,
|
|
ws::gateway::event::Event::RemoveServer {
|
|
server_id: server.id,
|
|
},
|
|
);
|
|
|
|
Ok(Json(Server::from(server)))
|
|
}
|
|
|
|
// file: src/web/route/server/get.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::{entity, web};
|
|
use crate::web::entity::server::Server;
|
|
|
|
pub async fn get(
|
|
State(state): State<AppState>,
|
|
Path(server_id): Path<entity::server::Id>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
// TODO: check permissions
|
|
|
|
let server = state.database.select_server_by_id(server_id).await?;
|
|
|
|
Ok(Json(Server::from(server)))
|
|
}
|
|
|
|
// file: src/web/route/server/invite/create.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, Query, State};
|
|
use axum::response::IntoResponse;
|
|
use base64::Engine;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::{entity, web};
|
|
|
|
#[derive(serde::Deserialize, Debug)]
|
|
pub struct CreateParams {
|
|
#[serde(deserialize_with = "crate::util::deserialize_duration_seconds_option")]
|
|
#[serde(default)]
|
|
pub expires_in: Option<std::time::Duration>,
|
|
}
|
|
|
|
pub async fn create(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path(server_id): Path<entity::server::Id>,
|
|
Query(params): Query<CreateParams>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
// TODO: check permissions
|
|
|
|
let code = {
|
|
use rand::Rng;
|
|
|
|
let mut rng = rand::rng();
|
|
let mut code = [0u8; 32];
|
|
rng.fill(&mut code);
|
|
base64::prelude::BASE64_URL_SAFE_NO_PAD.encode(&code)
|
|
};
|
|
|
|
let expires_at = params.expires_in.map(|d| {
|
|
let now = chrono::Utc::now();
|
|
now + chrono::Duration::from_std(d).expect("valid duration")
|
|
});
|
|
|
|
let invite = state
|
|
.database
|
|
.insert_server_invite(&code, server_id, Some(context.user_id), expires_at)
|
|
.await?;
|
|
|
|
Ok(Json(invite))
|
|
}
|
|
|
|
// file: src/web/route/server/invite/get.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::ws;
|
|
use crate::{database, web};
|
|
use crate::web::entity::server::Server;
|
|
|
|
pub async fn get(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
Path(code): Path<String>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let invite = state.database.select_server_invite_by_code(&code).await?;
|
|
let server = state.database.select_server_by_id(invite.server_id).await?;
|
|
|
|
let member = match state
|
|
.database
|
|
.insert_server_member(invite.server_id, context.user_id)
|
|
.await
|
|
{
|
|
Ok(member) => member,
|
|
Err(database::Error::MemberAlreadyExists) => return Ok(Json(Server::from(server))),
|
|
Err(e) => return Err(e.into()),
|
|
};
|
|
|
|
state
|
|
.database
|
|
.insert_server_member_role(member.id, invite.server_id)
|
|
.await?;
|
|
|
|
let user = state.database.select_user_by_id(context.user_id).await?;
|
|
|
|
ws::gateway::util::send_message_server(
|
|
state,
|
|
invite.server_id,
|
|
ws::gateway::event::Event::AddServerMember {
|
|
server_id: server.id,
|
|
member: user.into(),
|
|
},
|
|
);
|
|
|
|
Ok(Json(Server::from(server)))
|
|
}
|
|
|
|
// file: src/web/route/server/invite/mod.rs
|
|
mod create;
|
|
mod get;
|
|
|
|
pub use create::create;
|
|
pub use get::get;
|
|
|
|
// file: src/web/route/server/list.rs
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::entity::server::Server;
|
|
|
|
pub async fn list(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let servers = state
|
|
.database
|
|
.select_user_servers(context.user_id)
|
|
.await?
|
|
.into_iter()
|
|
.map(Server::from)
|
|
.collect::<Vec<_>>();
|
|
|
|
Ok(Json(servers))
|
|
}
|
|
|
|
// file: src/web/route/server/mod.rs
|
|
pub mod channel;
|
|
mod create;
|
|
mod delete;
|
|
mod get;
|
|
pub mod invite;
|
|
mod list;
|
|
|
|
pub use create::create;
|
|
pub use delete::delete;
|
|
pub use get::get;
|
|
pub use list::list;
|
|
|
|
// file: src/web/route/user/channel/create.rs
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
use axum_extra::extract::WithRejection;
|
|
use serde::Deserialize;
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::entity::user::PartialUser;
|
|
use crate::web::route::user::channel::RecipientChannel;
|
|
use crate::web::ws;
|
|
use crate::{entity, web};
|
|
|
|
#[derive(Debug, Validate, Deserialize)]
|
|
pub struct CreatePayload {
|
|
#[validate(length(min = 1, max = 32))]
|
|
recipients: Vec<entity::user::Id>,
|
|
}
|
|
|
|
pub async fn create(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
match payload.validate() {
|
|
Ok(_) => {},
|
|
Err(err) => {
|
|
return Err(web::error::ClientError::ValidationFailed(err).into());
|
|
},
|
|
}
|
|
|
|
let channel_id = match payload.recipients.len() {
|
|
1 => {
|
|
let recipient = payload.recipients[0];
|
|
state
|
|
.database
|
|
.procedure_create_dm_channel(context.user_id, recipient)
|
|
.await?
|
|
},
|
|
_ => {
|
|
state
|
|
.database
|
|
.procedure_create_group_channel(context.user_id, &payload.recipients)
|
|
.await?
|
|
},
|
|
};
|
|
|
|
let channel = state.database.select_channel_by_id(channel_id).await?;
|
|
|
|
let recipients = state
|
|
.database
|
|
.select_channel_recipients(channel_id)
|
|
.await?
|
|
.unwrap_or_default()
|
|
.into_iter()
|
|
.map(|user| user.id)
|
|
.collect::<Vec<_>>();
|
|
|
|
let recipient_channels = RecipientChannel {
|
|
channel: channel.clone(),
|
|
recipients: recipients.clone(),
|
|
};
|
|
|
|
ws::gateway::util::send_message_channel(
|
|
state,
|
|
channel_id,
|
|
ws::gateway::event::Event::AddDmChannel {
|
|
channel,
|
|
recipients,
|
|
},
|
|
);
|
|
|
|
Ok(Json(recipient_channels))
|
|
}
|
|
|
|
// file: src/web/route/user/channel/list.rs
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::route::user::channel::RecipientChannel;
|
|
|
|
pub async fn list(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let channels = state.database.select_user_channels(context.user_id).await?;
|
|
|
|
let mut recipient_channels = Vec::new();
|
|
for channel in channels {
|
|
let recipients = state.database.select_channel_recipients(channel.id).await?;
|
|
|
|
let recipients = match recipients {
|
|
Some(recipients) => recipients
|
|
.into_iter()
|
|
.map(|user| user.id)
|
|
.collect(),
|
|
None => {
|
|
continue;
|
|
},
|
|
};
|
|
|
|
recipient_channels.push(RecipientChannel {
|
|
channel,
|
|
recipients,
|
|
});
|
|
}
|
|
|
|
Ok(Json(recipient_channels))
|
|
}
|
|
|
|
// file: src/web/route/user/channel/mod.rs
|
|
mod create;
|
|
mod list;
|
|
|
|
pub use create::create;
|
|
pub use list::list;
|
|
use serde::Serialize;
|
|
|
|
use crate::entity::{channel, user};
|
|
|
|
#[derive(Debug, Serialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
struct RecipientChannel {
|
|
#[serde(flatten)]
|
|
pub channel: channel::Channel,
|
|
pub recipients: Vec<user::Id>,
|
|
}
|
|
|
|
// file: src/web/route/user/get.rs
|
|
use axum::Json;
|
|
use axum::extract::{Path, State};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web;
|
|
use crate::web::entity::user::PartialUser;
|
|
|
|
pub async fn get_by_id(
|
|
State(state): State<AppState>,
|
|
Path(user_id): Path<uuid::Uuid>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let user = state.database.select_user_by_id(user_id).await?;
|
|
|
|
Ok(Json(PartialUser::from(user)))
|
|
}
|
|
|
|
// file: src/web/route/user/me.rs
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::entity::user::FullUser;
|
|
|
|
pub async fn me(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
) -> web::Result<impl IntoResponse> {
|
|
let user = state.database.select_user_by_id(context.user_id).await?;
|
|
|
|
Ok(Json(FullUser::from(user)))
|
|
}
|
|
|
|
// file: src/web/route/user/mod.rs
|
|
pub mod channel;
|
|
mod get;
|
|
mod me;
|
|
mod patch;
|
|
|
|
pub use get::get_by_id;
|
|
pub use me::me;
|
|
pub use patch::patch;
|
|
|
|
|
|
// file: src/web/route/user/patch.rs
|
|
use axum::Json;
|
|
use axum::extract::State;
|
|
use axum::response::IntoResponse;
|
|
use axum_extra::extract::WithRejection;
|
|
use serde::Deserialize;
|
|
use validator::Validate;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::context::UserContext;
|
|
use crate::web::entity::user::{FullUser, PartialUser};
|
|
use crate::web::ws;
|
|
use crate::{entity, util, web};
|
|
|
|
#[derive(Debug, Validate, Deserialize)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[validate(schema(function = "validate_create_payload"))]
|
|
pub struct CreatePayload {
|
|
#[validate(length(min = 1, max = 32))]
|
|
#[serde(default)]
|
|
display_name: util::tristate::TriState<String>,
|
|
|
|
#[serde(default)]
|
|
avatar_id: util::tristate::TriState<entity::file::Id>,
|
|
}
|
|
|
|
fn validate_create_payload(payload: &CreatePayload) -> Result<(), validator::ValidationError> {
|
|
if payload.display_name.is_absent() && payload.avatar_id.is_absent() {
|
|
return Err(validator::ValidationError::new(
|
|
"at_least_one_field_required",
|
|
));
|
|
}
|
|
|
|
Ok(())
|
|
}
|
|
|
|
pub async fn patch(
|
|
State(state): State<AppState>,
|
|
context: UserContext,
|
|
WithRejection(Json(payload), _): WithRejection<Json<CreatePayload>, web::Error>,
|
|
) -> web::Result<impl IntoResponse> {
|
|
match payload.validate() {
|
|
Ok(_) => {},
|
|
Err(e) => return Err(web::error::ClientError::ValidationFailed(e).into()),
|
|
};
|
|
|
|
let user = state
|
|
.database
|
|
.update_user_by_id(
|
|
context.user_id,
|
|
payload.display_name.as_deref(),
|
|
payload.avatar_id,
|
|
)
|
|
.await?;
|
|
|
|
ws::gateway::util::send_message(
|
|
state.clone(),
|
|
context.user_id,
|
|
ws::gateway::event::Event::AddUser {
|
|
user: PartialUser::from(user.clone()),
|
|
},
|
|
);
|
|
|
|
ws::gateway::util::send_message_related(
|
|
state.clone(),
|
|
context.user_id,
|
|
ws::gateway::event::Event::AddUser {
|
|
user: PartialUser::from(user.clone()),
|
|
},
|
|
);
|
|
|
|
Ok(Json(FullUser::from(user)))
|
|
}
|
|
|
|
// file: src/web/ws/error.rs
|
|
pub type Result<T, E> = std::result::Result<T, Error<E>>;
|
|
|
|
#[derive(Debug, derive_more::From, derive_more::Display)]
|
|
pub enum Error<T: CustomError> {
|
|
#[from]
|
|
Custom(T),
|
|
|
|
#[from]
|
|
Json(serde_json::Error),
|
|
|
|
#[from]
|
|
AcknowledgementError(tokio::sync::oneshot::error::RecvError),
|
|
|
|
WrongMessageType,
|
|
|
|
WebSocketClosed,
|
|
|
|
UnknownError,
|
|
}
|
|
|
|
pub trait CustomError {}
|
|
|
|
// file: src/web/ws/gateway/connection.rs
|
|
use axum::extract::ws::Message as AxumMessage;
|
|
use base64::Engine as _;
|
|
use futures::{Stream, StreamExt};
|
|
use sha2::{Digest, Sha256};
|
|
use tokio::sync::mpsc;
|
|
|
|
use super::error::Error as WsError;
|
|
use super::event::Event as WsEvent;
|
|
use super::protocol::{WsClientMessage, WsServerMessage};
|
|
use super::state::{WsContext, WsState, WsUserContext};
|
|
use crate::jwt;
|
|
use crate::state::AppState;
|
|
use crate::web::ws::gateway::SessionKey;
|
|
use crate::web::ws::general::WebSocketHandler;
|
|
use crate::web::ws::util::{SendWsMessage, deserialize_ws_message};
|
|
use crate::web::ws::voice;
|
|
use crate::webrtc::WebRtcSignal;
|
|
|
|
impl WebSocketHandler for WsContext {
|
|
type ServerMessage = WsServerMessage;
|
|
type ClientMessage = WsClientMessage;
|
|
type Error = WsError;
|
|
|
|
async fn handle_stream<S>(
|
|
&mut self,
|
|
stream: S,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
|
|
app_state: &AppState,
|
|
) -> crate::web::ws::error::Result<(), Self::Error>
|
|
where
|
|
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
|
|
{
|
|
process_websocket_messages(self, stream, sender, app_state).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn cleanup(&mut self, app_state: &AppState) {
|
|
if let Some(user_ctx_data) = &self.user_context {
|
|
app_state
|
|
.unregister_gateway_connected_user(
|
|
user_ctx_data.user_id,
|
|
&user_ctx_data.session_key,
|
|
)
|
|
.await;
|
|
}
|
|
|
|
drop(self.event_channel.take());
|
|
}
|
|
|
|
async fn handle_result_error(
|
|
&mut self,
|
|
error: Self::Error,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
|
|
) {
|
|
let error_ws_message = WsServerMessage::Error { code: error };
|
|
|
|
let _ = sender.send(SendWsMessage::new_no_response(error_ws_message));
|
|
}
|
|
}
|
|
|
|
/// Main receive loop for a gateway connection: a two-state machine over `WsState`.
///
/// * `Initialize`: reads one client frame and passes it to `handle_initial_message`; on
///   success the state becomes `Connected`.
/// * `Connected`: multiplexes application events for this session (forwarded to the client
///   as `WsServerMessage::Event`) with incoming client frames (`handle_connected_message`).
///
/// Returns `Ok(())` only when the event channel yields `None`. A stream error or end of
/// stream yields `Error::WebSocketClosed`. Any handler error is propagated with `?`.
#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))]
async fn process_websocket_messages<S>(
    context: &mut WsContext,
    mut ws_stream: S,
    sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
    app_state: &AppState,
) -> crate::web::ws::error::Result<(), WsError>
where
    S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
{
    loop {
        match context.connection_state {
            WsState::Initialize => {
                tokio::select! {
                    biased;
                    maybe_message = ws_stream.next() => {
                        match maybe_message {
                            Some(Ok(message)) => {
                                // Any error (bad token, wrong message) ends the connection via `?`.
                                handle_initial_message(context, message, sender, app_state).await?;
                                context.connection_state = WsState::Connected;
                                tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected.");
                            }
                            Some(Err(axum_ws_err)) => {
                                tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err);
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                            None => {
                                tracing::debug!("WebSocket stream ended by client during Initialize state.");
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                        }
                    }
                }
            },
            WsState::Connected => {
                let user_ctx = context
                    .user_context
                    .as_ref()
                    .expect("User context must be set in Connected state");
                let (_event_tx_ref, event_rx) = context
                    .event_channel
                    .as_mut()
                    .expect("Event channel must be set in Connected state");

                tokio::select! {
                    // `biased`: pending app events are polled before the next client frame.
                    biased;
                    maybe_app_event = event_rx.recv() => {
                        if let Some(app_event_data) = maybe_app_event {
                            SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?;
                        } else {
                            // NOTE(review): `context.event_channel` also stores a clone of the
                            // sender (set in `handle_initial_message`), so this `None` branch
                            // cannot fire while that clone is held. TODO confirm whether
                            // closing the session from the app side is meant to work.
                            tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket.");
                            return Ok(());
                        }
                    }
                    maybe_ws_message = ws_stream.next() => {
                        match maybe_ws_message {
                            Some(Ok(message)) => {
                                handle_connected_message(context, message, sender, &app_state).await?;
                            }
                            Some(Err(axum_ws_err)) => {
                                tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err);
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                            None => {
                                tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state.");
                                return Err(crate::web::ws::error::Error::WebSocketClosed);
                            }
                        }
                    }
                }
            },
        }
    }
}
|
|
|
|
#[tracing::instrument(skip_all, fields(state = ?context.connection_state))]
|
|
async fn handle_initial_message(
|
|
context: &mut WsContext,
|
|
message: AxumMessage,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
|
|
app_state: &AppState,
|
|
) -> crate::web::ws::error::Result<(), WsError> {
|
|
match deserialize_ws_message(message)? {
|
|
WsClientMessage::Authenticate { token } => {
|
|
match crate::web::middleware::get_context_from_token(&app_state, &token).await {
|
|
Ok(auth_user_context) => {
|
|
let user_id = auth_user_context.user_id;
|
|
|
|
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<WsEvent>();
|
|
context.event_channel = Some((event_tx.clone(), event_rx));
|
|
|
|
let random_key_part = rand::random::<u64>();
|
|
let current_session_key: SessionKey = {
|
|
let mut hasher = Sha256::new();
|
|
hasher.update(token.as_bytes());
|
|
hasher.update(user_id.to_string().as_bytes());
|
|
hasher.update(&random_key_part.to_be_bytes());
|
|
base64::engine::general_purpose::URL_SAFE_NO_PAD
|
|
.encode(hasher.finalize())
|
|
.into()
|
|
};
|
|
|
|
context.user_context = Some(WsUserContext {
|
|
user_id,
|
|
session_key: current_session_key.clone(),
|
|
});
|
|
|
|
app_state
|
|
.register_gateway_connected_user(
|
|
user_id,
|
|
current_session_key.clone(),
|
|
event_tx,
|
|
)
|
|
.await;
|
|
|
|
SendWsMessage::send_with_response(
|
|
sender,
|
|
WsServerMessage::AuthenticateAccepted {
|
|
user_id,
|
|
session_key: current_session_key,
|
|
},
|
|
)
|
|
.await?;
|
|
Ok(())
|
|
},
|
|
Err(_auth_err) => {
|
|
tracing::warn!(token = %token, "Authentication failed for token.");
|
|
let _ = SendWsMessage::send_with_response(
|
|
sender,
|
|
WsServerMessage::AuthenticateDenied,
|
|
)
|
|
.await;
|
|
Err(WsError::AuthenticationFailed.into())
|
|
},
|
|
}
|
|
},
|
|
#[allow(unreachable_patterns)]
|
|
_ => {
|
|
tracing::warn!("Unexpected message type received during Initialize state.");
|
|
Err(crate::web::ws::error::Error::WrongMessageType)
|
|
},
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)
|
|
))]
|
|
async fn handle_connected_message(
|
|
context: &mut WsContext,
|
|
message: AxumMessage,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
|
|
app_state: &AppState,
|
|
) -> crate::web::ws::error::Result<(), WsError> {
|
|
match deserialize_ws_message(message)? {
|
|
WsClientMessage::VoiceStateUpdate {
|
|
server_id,
|
|
channel_id,
|
|
} => {
|
|
let server_id = server_id.unwrap_or_else(|| channel_id.clone().into());
|
|
// TODO: check if can join this channel
|
|
|
|
let claims = voice::claims::VoiceClaims {
|
|
user_id: context
|
|
.user_context
|
|
.as_ref()
|
|
.expect("user context should be present")
|
|
.user_id,
|
|
server_id,
|
|
channel_id,
|
|
exp: (chrono::Utc::now() + crate::config::config().gateway.voice_token_lifetime)
|
|
.timestamp(),
|
|
};
|
|
|
|
let token = jwt::generate_jwt(
|
|
claims,
|
|
crate::config::config().security.voice_secret.as_ref(),
|
|
)
|
|
.map_err(|_| WsError::TokenGenerationFailed)?;
|
|
|
|
SendWsMessage::send_with_response(
|
|
sender,
|
|
WsServerMessage::Event {
|
|
event: WsEvent::VoiceServerUpdate {
|
|
server_id,
|
|
channel_id,
|
|
token,
|
|
},
|
|
},
|
|
)
|
|
.await?;
|
|
|
|
Ok(())
|
|
},
|
|
WsClientMessage::RequestVoiceStates { server_id } => {
|
|
let channels = app_state
|
|
.database
|
|
.select_server_channels(server_id)
|
|
.await
|
|
.map_err(|_| crate::web::ws::error::Error::UnknownError)?;
|
|
|
|
for channel in channels {
|
|
let (tx, rx) = tokio::sync::oneshot::channel();
|
|
|
|
let webrtc_sender =
|
|
{ app_state.voice_rooms.read().await.get(&channel.id).cloned() };
|
|
|
|
if let Some(voice_room) = webrtc_sender {
|
|
let _ = voice_room.send(WebRtcSignal::RequestPeers { response: tx });
|
|
|
|
let peers = match rx.await {
|
|
Ok(peers) => peers,
|
|
Err(_) => {
|
|
continue;
|
|
},
|
|
};
|
|
|
|
for peer in peers {
|
|
let _ =
|
|
sender.send(SendWsMessage::new_no_response(WsServerMessage::Event {
|
|
event: WsEvent::VoiceChannelConnected {
|
|
server_id,
|
|
channel_id: channel.id,
|
|
user_id: peer,
|
|
},
|
|
}));
|
|
}
|
|
}
|
|
}
|
|
|
|
Ok(())
|
|
},
|
|
other_message => {
|
|
tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state.");
|
|
Err(crate::web::ws::error::Error::WrongMessageType)
|
|
},
|
|
}
|
|
}
|
|
|
|
// file: src/web/ws/gateway/error.rs
|
|
use crate::web::ws::error::CustomError;
|
|
|
|
/// Result alias for gateway-specific errors.
pub type Result<T> = std::result::Result<T, Error>;

/// Gateway-specific errors. They are sent to clients as the `code` of a
/// `WsServerMessage::Error` frame, serialized as SCREAMING_SNAKE_CASE variant names.
#[derive(Debug, derive_more::From, derive_more::Display, serde::Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Error {
    /// The `Authenticate` token was rejected.
    AuthenticationFailed,
    /// Signing a voice JWT failed.
    TokenGenerationFailed,
}

impl CustomError for Error {}
|
|
|
|
// file: src/web/ws/gateway/event.rs
|
|
use crate::{entity, web};
|
|
|
|
/// Application events pushed to gateway sessions inside `WsServerMessage::Event`.
///
/// Serialized as `{"type": ..., "data": {...}}`, with SCREAMING_SNAKE_CASE type names and
/// camelCase field names.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Event {
    #[serde(rename_all = "camelCase")]
    AddServer { server: web::entity::server::Server },

    #[serde(rename_all = "camelCase")]
    RemoveServer { server_id: entity::server::Id },

    #[serde(rename_all = "camelCase")]
    AddDmChannel {
        channel: entity::channel::Channel,
        recipients: Vec<entity::user::Id>,
    },

    #[serde(rename_all = "camelCase")]
    RemoveDmChannel { channel_id: entity::channel::Id },

    #[serde(rename_all = "camelCase")]
    AddServerChannel { channel: entity::channel::Channel },

    #[serde(rename_all = "camelCase")]
    RemoveServerChannel {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
    },

    #[serde(rename_all = "camelCase")]
    AddUser {
        user: web::entity::user::PartialUser,
    },

    #[serde(rename_all = "camelCase")]
    RemoveUser { user_id: entity::user::Id },

    #[serde(rename_all = "camelCase")]
    AddServerMember {
        server_id: entity::server::Id,
        member: web::entity::user::PartialUser,
    },

    #[serde(rename_all = "camelCase")]
    RemoveServerMember {
        server_id: entity::server::Id,
        member_id: entity::user::Id,
    },

    #[serde(rename_all = "camelCase")]
    AddMessage {
        channel_id: entity::channel::Id,
        message: web::entity::message::Message,
    },

    #[serde(rename_all = "camelCase")]
    RemoveMessage {
        channel_id: entity::channel::Id,
        message_id: entity::message::Id,
    },

    /// `user_id` is in the voice room of `channel_id`. Broadcast to channel members when a
    /// voice socket authenticates, and sent per peer in reply to `RequestVoiceStates`.
    #[serde(rename_all = "camelCase")]
    VoiceChannelConnected {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
        user_id: entity::user::Id,
    },

    /// `user_id` left the voice room; broadcast when a voice socket is cleaned up.
    #[serde(rename_all = "camelCase")]
    VoiceChannelDisconnected {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
        user_id: entity::user::Id,
    },

    /// Reply to `VoiceStateUpdate`; `token` is the signed voice JWT used to authenticate
    /// on the voice socket.
    #[serde(rename_all = "camelCase")]
    VoiceServerUpdate {
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
        token: String,
    },
}
|
|
|
|
// file: src/web/ws/gateway/mod.rs
|
|
use axum::extract::{State, WebSocketUpgrade};
|
|
use axum::response::IntoResponse;
|
|
use dashmap::DashMap;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::ws::gateway::state::{EventSender, WsContext};
|
|
use crate::web::ws::general;
|
|
|
|
mod connection;
|
|
mod error;
|
|
pub mod event;
|
|
mod protocol;
|
|
mod state;
|
|
pub mod util;
|
|
|
|
/// Opaque identifier of one gateway WebSocket session.
///
/// Serialized transparently as a plain string. `derive_more::From` lets it be built from
/// the inner `String`.
#[derive(
    Debug,
    Clone,
    Eq,
    PartialEq,
    Hash,
    Default,
    derive_more::Display,
    serde::Serialize,
    serde::Deserialize,
    derive_more::From,
)]
#[serde(transparent)]
pub struct SessionKey(String);
|
|
|
|
/// Live gateway sessions, keyed by session key, each with the sender feeding its event loop.
///
/// `send_message` iterates `instances` of the entry looked up by user id, so this is
/// presumably the per-user value of the connected-users map (TODO confirm in `AppState`).
#[derive(Debug, Default)]
pub struct GatewayWsState {
    pub instances: DashMap<SessionKey, EventSender>,
}
|
|
|
|
pub async fn ws_handler(
|
|
State(app_state): State<AppState>,
|
|
ws: WebSocketUpgrade,
|
|
) -> crate::web::error::Result<impl IntoResponse> {
|
|
Ok(ws.on_upgrade(|socket| {
|
|
general::handle_websocket_connection(socket, app_state, WsContext::default())
|
|
}))
|
|
}
|
|
|
|
// file: src/web/ws/gateway/protocol.rs
|
|
use super::{SessionKey, error, event};
|
|
use crate::entity;
|
|
|
|
/// Frames the gateway sends to clients, serialized as `{"type": ..., "data": ...}` with
/// SCREAMING_SNAKE_CASE type names.
#[derive(Debug, serde::Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage {
    /// The `Authenticate` token was rejected.
    AuthenticateDenied,

    /// Authentication succeeded; `session_key` identifies this connection.
    #[serde(rename_all = "camelCase")]
    AuthenticateAccepted {
        user_id: entity::user::Id,
        session_key: SessionKey,
    },

    /// An application event pushed to this session.
    #[serde(rename_all = "camelCase")]
    Event {
        event: event::Event,
    },

    /// A gateway-level error reported before the connection closes.
    #[serde(rename_all = "camelCase")]
    Error {
        code: error::Error,
    },
}
|
|
|
|
/// Frames clients send to the gateway, tagged the same way as `WsServerMessage`.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage {
    /// The only message accepted before the session is authenticated.
    #[serde(rename_all = "camelCase")]
    Authenticate { token: String },

    /// Requests a voice token for `channel_id`. When `server_id` is omitted, the handler
    /// derives it from the channel id.
    #[serde(rename_all = "camelCase")]
    VoiceStateUpdate {
        server_id: Option<entity::server::Id>,
        channel_id: entity::channel::Id,
    },

    /// Requests the current voice participants of every channel of `server_id`.
    #[serde(rename_all = "camelCase")]
    RequestVoiceStates { server_id: entity::server::Id },
}
|
|
|
|
// file: src/web/ws/gateway/state.rs
|
|
use tokio::sync::mpsc;
|
|
|
|
use super::{SessionKey, event};
|
|
use crate::entity;
|
|
|
|
/// Lifecycle of a gateway connection.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum WsState {
    /// Waiting for the client's `Authenticate` message.
    Initialize,
    /// Authenticated; events and client requests are being exchanged.
    Connected,
}
|
|
|
|
/// Identity of an authenticated gateway session.
#[derive(Debug)]
pub struct WsUserContext {
    pub user_id: entity::user::Id,
    /// Unique key for this specific WebSocket session instance.
    pub session_key: SessionKey,
}
|
|
|
|
/// Sending half of a session's event channel (registered with `AppState`).
pub type EventSender = mpsc::UnboundedSender<event::Event>;
/// Receiving half of a session's event channel (drained by the connection loop).
pub type EventReceiver = mpsc::UnboundedReceiver<event::Event>;
|
|
|
|
/// Per-connection state of a gateway socket.
pub struct WsContext {
    pub connection_state: WsState,
    /// Set once authentication succeeds.
    pub user_context: Option<WsUserContext>,
    /// Set once authentication succeeds; cleared in `cleanup`.
    pub event_channel: Option<(EventSender, EventReceiver)>,
}
|
|
|
|
impl Default for WsContext {
|
|
fn default() -> Self {
|
|
Self {
|
|
connection_state: WsState::Initialize,
|
|
user_context: None,
|
|
event_channel: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/web/ws/gateway/util.rs
|
|
use crate::entity;
|
|
use crate::state::AppState;
|
|
use crate::web::ws::gateway::event;
|
|
|
|
pub fn send_message(state: AppState, user_id: entity::user::Id, message: event::Event) {
|
|
tokio::spawn(async move {
|
|
let connected_users = state.gateway_state.connected.get_async(&user_id).await;
|
|
if let Some(session) = connected_users {
|
|
for instance in session.instances.iter() {
|
|
if let Err(e) = instance.send(message.clone()) {
|
|
tracing::error!("failed to send message: {}", e);
|
|
}
|
|
}
|
|
}
|
|
});
|
|
}
|
|
|
|
pub fn send_message_many(state: AppState, user_ids: &[entity::user::Id], message: event::Event) {
|
|
for id in user_ids.iter() {
|
|
send_message(state.clone(), *id, message.clone());
|
|
}
|
|
}
|
|
|
|
pub fn send_message_server(state: AppState, server_id: entity::server::Id, message: event::Event) {
|
|
tokio::spawn(async move {
|
|
let users = state
|
|
.database
|
|
.select_server_members(server_id)
|
|
.await
|
|
.unwrap_or_else(|_| vec![])
|
|
.iter()
|
|
.map(|u| u.id)
|
|
.collect::<Vec<_>>();
|
|
|
|
send_message_many(state, &users, message);
|
|
});
|
|
}
|
|
|
|
pub fn send_message_channel(
|
|
state: AppState,
|
|
channel_id: entity::channel::Id,
|
|
message: event::Event,
|
|
) {
|
|
tokio::spawn(async move {
|
|
let users = state
|
|
.database
|
|
.select_channel_members(channel_id)
|
|
.await
|
|
.unwrap_or_else(|_| vec![])
|
|
.iter()
|
|
.map(|u| u.id)
|
|
.collect::<Vec<_>>();
|
|
|
|
send_message_many(state, &users, message);
|
|
});
|
|
}
|
|
|
|
pub fn send_message_related(state: AppState, user_id: entity::user::Id, message: event::Event) {
|
|
tokio::spawn(async move {
|
|
let users = state
|
|
.database
|
|
.select_related_user_ids(user_id)
|
|
.await
|
|
.unwrap_or_else(|_| vec![]);
|
|
|
|
send_message_many(state, &users, message);
|
|
});
|
|
}
|
|
|
|
// file: src/web/ws/general.rs
|
|
use std::fmt::Debug;
|
|
|
|
use axum::extract::ws::WebSocket;
|
|
use futures::{Stream, StreamExt};
|
|
use serde::Serialize;
|
|
use serde::de::DeserializeOwned;
|
|
use tokio::sync::mpsc;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::ws::error::CustomError;
|
|
use crate::web::ws::util;
|
|
use crate::web::ws::util::SendWsMessage;
|
|
|
|
/// Protocol-specific behaviour plugged into `handle_websocket_connection`.
pub trait WebSocketHandler {
    /// Outgoing frame type; serialized to JSON by the writer task.
    type ServerMessage: Serialize + Send;
    /// Incoming frame type; deserialized from JSON text frames.
    type ClientMessage: DeserializeOwned;
    /// Handler-specific error carried in `Error::Custom`.
    type Error: CustomError + Send + Debug;

    /// Consumes incoming frames until the connection ends. Outgoing frames are queued on
    /// `sender`, which feeds the writer task.
    async fn handle_stream<S>(
        &mut self,
        stream: S,
        sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
        app_state: &AppState,
    ) -> crate::web::ws::error::Result<(), Self::Error>
    where
        S: Stream<Item = Result<axum::extract::ws::Message, axum::Error>> + Unpin;

    /// Releases per-connection resources. Called after `handle_stream` returns, whatever
    /// the outcome.
    async fn cleanup(&mut self, app_state: &AppState);

    /// Called after `cleanup` when `handle_stream` failed with `Error::Custom`. The writer
    /// is still running at that point, so a final frame may be queued on `sender`.
    async fn handle_result_error(
        &mut self,
        error: Self::Error,
        sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
    );
}
|
|
|
|
#[tracing::instrument(skip_all)]
|
|
pub async fn handle_websocket_connection(
|
|
websocket: WebSocket,
|
|
app_state: AppState,
|
|
mut handler: impl WebSocketHandler + 'static,
|
|
) {
|
|
let (ws_sink, ws_stream) = websocket.split();
|
|
|
|
let (internal_send_tx, internal_send_rx) = mpsc::unbounded_channel();
|
|
|
|
let writer_task = util::spawn_writer_task(ws_sink, internal_send_rx);
|
|
|
|
let processing_result = handler
|
|
.handle_stream(ws_stream, &internal_send_tx, &app_state)
|
|
.await;
|
|
|
|
handler.cleanup(&app_state).await;
|
|
|
|
match processing_result {
|
|
Ok(_) => {},
|
|
Err(crate::web::ws::error::Error::Custom(err_to_report)) => {
|
|
handler
|
|
.handle_result_error(err_to_report, &internal_send_tx)
|
|
.await;
|
|
},
|
|
Err(e) => {
|
|
tracing::info!("WebSocket connection closed: {:?}", e);
|
|
},
|
|
}
|
|
|
|
drop(internal_send_tx);
|
|
if let Err(e) = writer_task.await {
|
|
tracing::error!(
|
|
"WebSocket writer task panicked or encountered an error: {:?}",
|
|
e
|
|
);
|
|
}
|
|
}
|
|
|
|
// file: src/web/ws/mod.rs
|
|
mod error;
|
|
pub mod gateway;
|
|
mod util;
|
|
pub mod voice;
|
|
mod general;
|
|
|
|
// file: src/web/ws/util.rs
|
|
use axum::extract::ws::Message as AxumMessage;
|
|
use futures::{Sink, SinkExt};
|
|
use serde::Serialize;
|
|
use serde::de::DeserializeOwned;
|
|
use tokio::sync::{mpsc, oneshot};
|
|
|
|
use crate::web::ws::error::CustomError;
|
|
|
|
pub fn spawn_writer_task<S, T, E>(
|
|
mut ws_sink: S,
|
|
mut writer_rx: mpsc::UnboundedReceiver<SendWsMessage<T, E>>,
|
|
) -> tokio::task::JoinHandle<()>
|
|
where
|
|
S: Sink<axum::extract::ws::Message> + Unpin + Send + 'static,
|
|
T: Serialize + Send + 'static,
|
|
E: CustomError + Send + 'static,
|
|
{
|
|
tokio::spawn(async move {
|
|
while let Some(SendWsMessage {
|
|
message,
|
|
response_ch,
|
|
}) = writer_rx.recv().await
|
|
{
|
|
let send_result = match serialize_ws_message(message) {
|
|
Ok(ws_msg) => {
|
|
if ws_sink.send(ws_msg).await.is_err() {
|
|
Err(super::error::Error::WebSocketClosed)
|
|
} else {
|
|
Ok(())
|
|
}
|
|
},
|
|
Err(e) => Err(e.into()),
|
|
};
|
|
|
|
if let Some(ch) = response_ch {
|
|
let _ = ch.send(send_result);
|
|
}
|
|
}
|
|
|
|
let _ = ws_sink.close().await;
|
|
})
|
|
}
|
|
|
|
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
|
|
pub fn deserialize_ws_message<T: DeserializeOwned, E: CustomError>(
|
|
message: AxumMessage,
|
|
) -> super::error::Result<T, E> {
|
|
match message {
|
|
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(super::error::Error::from),
|
|
AxumMessage::Close(_) => Err(super::error::Error::WebSocketClosed),
|
|
_ => Err(super::error::Error::WrongMessageType),
|
|
}
|
|
}
|
|
|
|
/// Serializes a `WsServerMessage` into an Axum WebSocket message.
|
|
pub fn serialize_ws_message<T: Serialize, E: CustomError>(
|
|
message: T,
|
|
) -> super::error::Result<AxumMessage, E> {
|
|
serde_json::to_string(&message)
|
|
.map(Into::into)
|
|
.map(AxumMessage::Text)
|
|
.map_err(super::error::Error::from)
|
|
}
|
|
|
|
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
/// Includes an optional one-shot channel on which the writer reports whether the frame
/// was serialized and written (`Ok(())`) or why it was not.
pub struct SendWsMessage<T, E: CustomError> {
    pub message: T,
    pub response_ch: Option<oneshot::Sender<super::error::Result<(), E>>>,
}
|
|
|
|
impl<T, E: CustomError> SendWsMessage<T, E> {
|
|
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
|
|
pub async fn send_with_response(
|
|
tx: &mpsc::UnboundedSender<Self>, // Changed to reference
|
|
message: T,
|
|
) -> super::error::Result<(), E> {
|
|
let (response_tx, response_rx) = oneshot::channel();
|
|
let send_message = SendWsMessage {
|
|
message,
|
|
response_ch: Some(response_tx),
|
|
};
|
|
|
|
if tx.send(send_message).is_err() {
|
|
Err(super::error::Error::WebSocketClosed) // MPSC channel closed, writer task likely dead
|
|
} else {
|
|
// Wait for the writer task to acknowledge the send attempt.
|
|
// This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
|
|
response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
|
|
}
|
|
}
|
|
|
|
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
|
|
pub fn new_no_response(message: T) -> Self {
|
|
Self {
|
|
message,
|
|
response_ch: None,
|
|
}
|
|
}
|
|
}
|
|
|
|
// file: src/web/ws/voice/claims.rs
|
|
use crate::entity;
|
|
|
|
/// JWT claims of a voice token. The gateway issues it on `VoiceStateUpdate`, and the voice
/// socket verifies it on `Authenticate`.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct VoiceClaims {
    pub user_id: entity::user::Id,
    pub server_id: entity::server::Id,
    pub channel_id: entity::channel::Id,
    /// Expiry as a Unix timestamp in seconds (set from `chrono::DateTime::timestamp`).
    pub exp: i64,
}
|
|
|
|
// file: src/web/ws/voice/connection.rs
|
|
use axum::extract::ws::Message as AxumMessage;
|
|
use futures::{Stream, StreamExt};
|
|
use tokio::sync::{mpsc, oneshot};
|
|
|
|
use super::error::{self, Error as WsError};
|
|
use super::protocol::{WsClientMessage, WsServerMessage};
|
|
use crate::jwt;
|
|
use crate::state::AppState;
|
|
use crate::web::ws;
|
|
use crate::web::ws::general::WebSocketHandler;
|
|
use crate::web::ws::util::{SendWsMessage, deserialize_ws_message};
|
|
use crate::web::ws::voice::claims::VoiceClaims;
|
|
use crate::web::ws::voice::protocol::WsServerMessage::SdpAnswer;
|
|
use crate::web::ws::voice::state::{WsContext, WsState};
|
|
use crate::webrtc::{Offer, OfferSignal, WebRtcSignal};
|
|
|
|
impl WebSocketHandler for WsContext {
|
|
type ServerMessage = WsServerMessage;
|
|
type ClientMessage = WsClientMessage;
|
|
type Error = WsError;
|
|
|
|
async fn handle_stream<S>(
|
|
&mut self,
|
|
stream: S,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
|
|
app_state: &AppState,
|
|
) -> ws::error::Result<(), Self::Error>
|
|
where
|
|
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
|
|
{
|
|
process_websocket_messages(self, stream, sender, app_state).await?;
|
|
|
|
Ok(())
|
|
}
|
|
|
|
async fn cleanup(&mut self, app_state: &AppState) {
|
|
tracing::debug!("Cleaning up WebSocket connection.");
|
|
|
|
match &self.connection_state {
|
|
WsState::Connected {
|
|
signal_channel,
|
|
server_id,
|
|
channel_id,
|
|
user_id,
|
|
..
|
|
} => {
|
|
ws::gateway::util::send_message_channel(
|
|
app_state.clone(),
|
|
*channel_id,
|
|
ws::gateway::event::Event::VoiceChannelDisconnected {
|
|
server_id: *server_id,
|
|
channel_id: *channel_id,
|
|
user_id: *user_id,
|
|
},
|
|
);
|
|
|
|
let _ = signal_channel.send(WebRtcSignal::Disconnect(user_id.clone()));
|
|
},
|
|
WsState::Initialize => {},
|
|
}
|
|
}
|
|
|
|
async fn handle_result_error(
|
|
&mut self,
|
|
error: Self::Error,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<Self::ServerMessage, Self::Error>>,
|
|
) {
|
|
tracing::error!("WebSocket error: {:?}", error);
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument(skip_all)]
|
|
async fn process_websocket_messages<S>(
|
|
context: &mut WsContext,
|
|
mut ws_stream: S,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, WsError>>,
|
|
app_state: &AppState,
|
|
) -> ws::error::Result<(), WsError>
|
|
where
|
|
S: Stream<Item = Result<AxumMessage, axum::Error>> + Unpin,
|
|
{
|
|
loop {
|
|
match &context.connection_state {
|
|
WsState::Initialize => {
|
|
while let Some(Ok(message)) = ws_stream.next().await {
|
|
handle_initial_message(context, message, sender, &app_state).await?;
|
|
break;
|
|
}
|
|
},
|
|
WsState::Connected { signal_channel, .. } => {
|
|
let signal_channel = signal_channel.clone();
|
|
loop {
|
|
tokio::select! {
|
|
biased;
|
|
_ = signal_channel.closed() => {
|
|
tracing::debug!("Signal channel closed.");
|
|
break;
|
|
}
|
|
Some(Ok(message)) = ws_stream.next() => {
|
|
handle_connected_message(context, message, sender, &app_state).await?;
|
|
}
|
|
else => {
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
|
|
return Err(ws::error::Error::WebSocketClosed);
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument(skip_all)]
|
|
async fn handle_initial_message(
|
|
context: &mut WsContext,
|
|
message: AxumMessage,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, error::Error>>,
|
|
app_state: &AppState,
|
|
) -> ws::error::Result<(), error::Error> {
|
|
match deserialize_ws_message(message)? {
|
|
WsClientMessage::Authenticate { token } => match jwt::verify_jwt::<VoiceClaims>(
|
|
&token,
|
|
crate::config::config().security.voice_secret.as_ref(),
|
|
) {
|
|
Ok(claims) => {
|
|
SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateAccepted)
|
|
.await?;
|
|
|
|
let signal_channel =
|
|
ws::voice::state::get_signaling_channel(app_state, claims.channel_id).await;
|
|
|
|
context.connection_state = WsState::Connected {
|
|
signal_channel,
|
|
server_id: claims.server_id,
|
|
channel_id: claims.channel_id,
|
|
user_id: claims.user_id,
|
|
};
|
|
|
|
ws::gateway::util::send_message_channel(
|
|
app_state.clone(),
|
|
claims.channel_id,
|
|
ws::gateway::event::Event::VoiceChannelConnected {
|
|
server_id: claims.server_id,
|
|
channel_id: claims.channel_id,
|
|
user_id: claims.user_id,
|
|
},
|
|
);
|
|
|
|
Ok(())
|
|
},
|
|
Err(auth_err) => {
|
|
tracing::warn!("Authentication failed: {:?}", auth_err);
|
|
|
|
let _ =
|
|
SendWsMessage::send_with_response(sender, WsServerMessage::AuthenticateDenied)
|
|
.await;
|
|
Err(error::Error::AuthenticationFailed.into())
|
|
},
|
|
},
|
|
#[allow(unreachable_patterns)]
|
|
_ => {
|
|
tracing::warn!("Unexpected message type received during Initialize state.");
|
|
Err(ws::error::Error::WrongMessageType)
|
|
},
|
|
}
|
|
}
|
|
|
|
#[tracing::instrument(skip_all)]
|
|
async fn handle_connected_message(
|
|
context: &mut WsContext,
|
|
message: AxumMessage,
|
|
sender: &mpsc::UnboundedSender<SendWsMessage<WsServerMessage, error::Error>>,
|
|
app_state: &AppState,
|
|
) -> ws::error::Result<(), error::Error> {
|
|
match deserialize_ws_message(message)? {
|
|
WsClientMessage::SdpOffer { sdp } => {
|
|
let (signal_channel, user_id) = match &mut context.connection_state {
|
|
WsState::Connected {
|
|
signal_channel,
|
|
user_id,
|
|
..
|
|
} => (signal_channel.clone(), *user_id),
|
|
_ => return Err(ws::error::Error::WrongMessageType),
|
|
};
|
|
|
|
let (tx, rx) = oneshot::channel();
|
|
|
|
let _ = signal_channel.send(WebRtcSignal::Offer(OfferSignal {
|
|
offer: Offer {
|
|
peer_id: user_id,
|
|
sdp_offer: sdp,
|
|
},
|
|
response: tx,
|
|
}));
|
|
|
|
let answer_signal = rx.await?;
|
|
|
|
sender
|
|
.send(SendWsMessage::new_no_response(SdpAnswer {
|
|
sdp: answer_signal.sdp_answer,
|
|
}))
|
|
.map_err(|_| ws::error::Error::WebSocketClosed)?;
|
|
|
|
Ok(())
|
|
},
|
|
other_message => {
|
|
tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state.");
|
|
Err(ws::error::Error::WrongMessageType)
|
|
},
|
|
}
|
|
}
|
|
|
|
// file: src/web/ws/voice/error.rs
|
|
use crate::web::ws::error::CustomError;
|
|
|
|
/// Result alias for voice-socket-specific errors.
pub type Result<T> = std::result::Result<T, Error>;

/// Voice-socket-specific errors (carried in `ws::error::Error::Custom`).
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error {
    /// The voice JWT presented in `Authenticate` failed verification.
    AuthenticationFailed,
}

impl CustomError for Error {}
|
|
|
|
// file: src/web/ws/voice/mod.rs
|
|
pub mod claims;
|
|
mod connection;
|
|
mod error;
|
|
mod protocol;
|
|
mod state;
|
|
|
|
use axum::extract::{State, WebSocketUpgrade};
|
|
use axum::response::IntoResponse;
|
|
|
|
use crate::state::AppState;
|
|
use crate::web::ws::general;
|
|
use crate::web::ws::voice::state::WsContext;
|
|
|
|
pub async fn ws_handler(
|
|
State(app_state): State<AppState>,
|
|
ws: WebSocketUpgrade,
|
|
) -> crate::web::error::Result<impl IntoResponse> {
|
|
Ok(ws.on_upgrade(|socket| {
|
|
general::handle_websocket_connection(socket, app_state, WsContext::default())
|
|
}))
|
|
}
|
|
|
|
// file: src/web/ws/voice/protocol.rs
|
|
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
|
|
|
/// Frames the voice socket sends to clients, serialized as `{"type": ..., "data": ...}`.
#[derive(Debug, serde::Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage {
    /// The voice JWT was rejected.
    AuthenticateDenied,

    /// The voice JWT was accepted.
    AuthenticateAccepted,

    /// The server's SDP answer to the client's `SdpOffer`.
    #[serde(rename_all = "camelCase")]
    SdpAnswer {
        sdp: RTCSessionDescription,
    },
}
|
|
|
|
/// Frames clients send to the voice socket.
#[derive(Debug, serde::Deserialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage {
    /// First message; `token` is the voice JWT issued by the gateway.
    #[serde(rename_all = "camelCase")]
    Authenticate { token: String },

    /// SDP offer forwarded to the room's WebRTC task.
    #[serde(rename_all = "camelCase")]
    SdpOffer { sdp: RTCSessionDescription },
}
|
|
|
|
// file: src/web/ws/voice/state.rs
|
|
use tokio::sync::oneshot;
|
|
use tokio::sync::mpsc;
|
|
|
|
use crate::entity;
|
|
use crate::state::AppState;
|
|
use crate::webrtc::{OfferSignal, WebRtcSignal};
|
|
|
|
/// Lifecycle of a voice connection.
#[derive(Debug)]
pub enum WsState {
    /// Waiting for `Authenticate`.
    Initialize,
    /// Authenticated and attached to a voice room.
    Connected {
        /// Sender into the room's WebRTC task.
        signal_channel: mpsc::UnboundedSender<WebRtcSignal>,
        server_id: entity::server::Id,
        channel_id: entity::channel::Id,
        user_id: entity::user::Id,
    },
}
|
|
|
|
/// Per-connection state of a voice socket.
pub struct WsContext {
    pub connection_state: WsState,
}
|
|
|
|
impl Default for WsContext {
|
|
fn default() -> Self {
|
|
Self {
|
|
connection_state: WsState::Initialize,
|
|
}
|
|
}
|
|
}
|
|
|
|
pub async fn get_signaling_channel(
|
|
app_state: &AppState,
|
|
channel_id: entity::channel::Id,
|
|
) -> mpsc::UnboundedSender<WebRtcSignal> {
|
|
let room_sender = {
|
|
app_state
|
|
.voice_rooms
|
|
.read()
|
|
.await
|
|
.get(&channel_id)
|
|
.map(|room| room.clone())
|
|
};
|
|
|
|
match room_sender {
|
|
Some(room) => room,
|
|
None => {
|
|
let (tx, rx) = mpsc::unbounded_channel();
|
|
|
|
let app_state_ = app_state.clone();
|
|
tokio::spawn(async move {
|
|
crate::webrtc::webrtc_task(channel_id, rx)
|
|
.await
|
|
.unwrap_or_else(|err| {
|
|
tracing::error!("webrtc task error: {:?}", err);
|
|
});
|
|
|
|
{
|
|
app_state_.unregister_voice_room(channel_id).await;
|
|
}
|
|
});
|
|
|
|
{
|
|
app_state.register_voice_room(channel_id, tx.clone()).await;
|
|
}
|
|
|
|
tx
|
|
},
|
|
}
|
|
}
|
|
|
|
// file: src/webrtc/mod.rs
|
|
use std::sync::Arc;
|
|
|
|
use dashmap::DashMap;
|
|
use tracing::Instrument;
|
|
use webrtc::api::interceptor_registry::register_default_interceptors;
|
|
use webrtc::api::media_engine::MIME_TYPE_OPUS;
|
|
use webrtc::api::{API, APIBuilder};
|
|
use webrtc::interceptor::registry::Registry;
|
|
use webrtc::peer_connection::configuration::RTCConfiguration;
|
|
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
|
|
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
|
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
|
|
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
|
|
use webrtc::track::track_local::{TrackLocal, TrackLocalWriter};
|
|
use webrtc::track::track_remote::TrackRemote;
|
|
|
|
use crate::entity;
|
|
|
|
/// A peer in a voice room is identified by its user id.
type PeerId = entity::user::Id;
/// A voice room is identified by its channel id.
type RoomId = entity::channel::Id;
|
|
|
|
/// Per-peer state stored in `RoomState::peers`.
struct PeerState {
    peer_id: PeerId,
    /// Server-side peer connection for this user.
    peer_connection: Arc<webrtc::peer_connection::RTCPeerConnection>,
    /// Local RTP audio track for this peer; presumably carries audio forwarded from other
    /// peers (TODO confirm in `handle_peer`).
    outgoing_audio_track: Arc<TrackLocalStaticRTP>,
    /// RTP SSRC associated with this peer; exact role is set in `handle_peer` (TODO confirm).
    ssrc: u32,
}
|
|
|
|
/// Shared state of one voice room, created by `webrtc_task` and shared via `Arc` with the
/// per-peer tasks it spawns.
struct RoomState {
    room_id: RoomId,
    /// Connected peers keyed by user id; `RequestPeers` replies with these keys.
    peers: DashMap<PeerId, PeerState>,
    /// Sending `()` here makes `webrtc_task` leave its loop.
    close_signal: tokio::sync::mpsc::UnboundedSender<()>,
}
|
|
|
|
/// An SDP offer from one peer.
#[derive(Debug)]
pub struct Offer {
    pub peer_id: PeerId,
    pub sdp_offer: RTCSessionDescription,
}
|
|
|
|
/// The room's SDP answer to an `Offer`, delivered through `OfferSignal::response`.
#[derive(Debug)]
pub struct AnswerSignal {
    pub sdp_answer: RTCSessionDescription,
}
|
|
|
|
/// Commands accepted by a room's `webrtc_task`.
#[derive(Debug)]
pub enum WebRtcSignal {
    /// Negotiate a peer connection; handled in a spawned `handle_peer` task.
    Offer(OfferSignal),
    /// Remove a peer from the room (`cleanup_peer`).
    Disconnect(PeerId),
    /// Reply with the ids of the peers currently in the room.
    RequestPeers {
        response: tokio::sync::oneshot::Sender<Vec<PeerId>>,
    },
    /// Stop the room task.
    Close,
}
|
|
|
|
/// An offer together with the channel on which the room sends its answer.
#[derive(Debug)]
pub struct OfferSignal {
    pub offer: Offer,
    pub response: tokio::sync::oneshot::Sender<AnswerSignal>,
}
|
|
|
|
/// Signalling loop for a single voice room (`room_id`), fed by `signal`.
///
/// It builds one webrtc `API` (default codecs plus default interceptors) shared by every
/// peer, then handles `WebRtcSignal`s until one of these happens:
/// * no signal arrives during the first 10 seconds,
/// * something is sent on the room's `close_signal`,
/// * `WebRtcSignal::Close` is received.
///
/// Errors are returned only from the media-engine / interceptor setup (the `?`s below).
#[tracing::instrument(skip(signal))]
pub async fn webrtc_task(
    room_id: RoomId,
    signal: tokio::sync::mpsc::UnboundedReceiver<WebRtcSignal>,
) -> anyhow::Result<()> {
    tracing::info!("Starting WebRTC task");

    // The sender is stored in `RoomState` so peer tasks can presumably request shutdown
    // (TODO confirm in `handle_peer` / `cleanup_peer`).
    let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel();

    // Flipped on the first received signal; disables the idle timeout below.
    let mut skip_timeout = false;

    let state = Arc::new(RoomState {
        room_id,
        peers: DashMap::new(),
        close_signal,
    });

    let mut signal = signal;
    let mut media_engine = webrtc::api::media_engine::MediaEngine::default();
    media_engine.register_default_codecs()?;
    let mut registry = Registry::new();
    registry = register_default_interceptors(registry, &mut media_engine)?;

    let api = APIBuilder::new()
        .with_media_engine(media_engine)
        .with_interceptor_registry(registry)
        .build();

    let api = Arc::new(api);

    loop {
        tokio::select! {
            biased;
            // The sleep is re-created each iteration. Before the first signal, the only other
            // branch that can complete (close) exits the loop, so this acts as a 10 s timeout
            // for a room nobody joins.
            _ = tokio::time::sleep(std::time::Duration::from_secs(10)), if !skip_timeout => {
                tracing::debug!("initial timeout reached");
                break;
            }
            _ = close_receiver.recv() => {
                tracing::debug!("WebRTC task stopped");
                break;
            }
            // NOTE(review): if every `signal` sender is dropped, this branch is disabled and
            // only `close_receiver` (or the initial timeout) can end the loop.
            Some(signal) = signal.recv() => {
                skip_timeout = true;
                match signal {
                    WebRtcSignal::Offer(offer_signal) => {
                        let room_state = state.clone();
                        let api = api.clone();

                        // Negotiation runs concurrently so the loop keeps serving signals.
                        tokio::spawn(async move {
                            if let Err(e) = handle_peer(api, room_state, offer_signal).await {
                                tracing::error!("error handling peer: {}", e);
                            }
                        }.instrument(tracing::Span::current()));
                    }
                    WebRtcSignal::RequestPeers { response } => {
                        let peers = state
                            .peers
                            .iter()
                            .map(|pair| pair.key().clone())
                            .collect::<Vec<_>>();

                        let _ = response.send(peers);
                    }
                    WebRtcSignal::Disconnect(peer_id) => {
                        tracing::debug!("received disconnect signal for peer {}", peer_id);
                        cleanup_peer(state.clone(), peer_id).await;
                    }
                    WebRtcSignal::Close => {
                        break;
                    }
                }
            }
        }
    }

    Ok(())
}
|
|
|
|
#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id
|
|
))]
|
|
async fn handle_peer(
|
|
api: Arc<API>,
|
|
room_state: Arc<RoomState>,
|
|
offer_signal: OfferSignal,
|
|
) -> anyhow::Result<()> {
|
|
tracing::debug!("handling peer");
|
|
|
|
let config = RTCConfiguration {
|
|
..Default::default()
|
|
};
|
|
|
|
let outgoing_track = Arc::new(TrackLocalStaticRTP::new(
|
|
RTCRtpCodecCapability {
|
|
mime_type: MIME_TYPE_OPUS.to_string(),
|
|
..Default::default()
|
|
},
|
|
format!("audio-{}", offer_signal.offer.peer_id),
|
|
format!("room-{}", room_state.room_id),
|
|
));
|
|
|
|
let peer_connection = Arc::new(api.new_peer_connection(config).await?);
|
|
|
|
let rtp_sender = peer_connection
|
|
.add_track(Arc::clone(&outgoing_track) as Arc<dyn TrackLocal + Send + Sync>)
|
|
.await?;
|
|
|
|
tracing::debug!("added track to peer connection: {:?}", rtp_sender);
|
|
|
|
// Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI)
|
|
let outgoing_track_ = Arc::clone(&outgoing_track);
|
|
tokio::spawn(
|
|
async move {
|
|
let mut rtcp_buf = vec![0u8; 1500];
|
|
while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {
|
|
// Process RTCP if needed (e.g., bandwidth estimation, custom feedback)
|
|
}
|
|
tracing::debug!(
|
|
"RTCP reader loop for outgoing track {} ended",
|
|
outgoing_track_.id()
|
|
);
|
|
}
|
|
.instrument(tracing::Span::current()),
|
|
);
|
|
|
|
let room_state_ = room_state.clone();
|
|
peer_connection.on_peer_connection_state_change(Box::new(move |state| {
|
|
let room_state_ = room_state_.clone();
|
|
Box::pin(async move {
|
|
tracing::debug!("peer connection state changed: {:?}", state);
|
|
if state == RTCPeerConnectionState::Disconnected
|
|
|| state == RTCPeerConnectionState::Failed
|
|
{
|
|
tracing::debug!("peer connection closed");
|
|
cleanup_peer(room_state_, offer_signal.offer.peer_id).await;
|
|
}
|
|
})
|
|
}));
|
|
|
|
let room_state_ = room_state.clone();
|
|
peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| {
|
|
tracing::debug!("track received: {:?}", track);
|
|
|
|
let room_state_ = room_state_.clone();
|
|
Box::pin(async move {
|
|
if track.kind() == RTPCodecType::Audio {
|
|
tracing::debug!("audio track received: {:?}", track);
|
|
tokio::spawn(async move {
|
|
if let Err(e) =
|
|
forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await
|
|
{
|
|
tracing::error!("error forwarding audio track: {}", e);
|
|
}
|
|
});
|
|
} else {
|
|
tracing::warn!("received non-audio track: {:?}", track);
|
|
}
|
|
})
|
|
}));
|
|
|
|
peer_connection
|
|
.set_remote_description(offer_signal.offer.sdp_offer)
|
|
.await?;
|
|
let answer = peer_connection.create_answer(None).await?;
|
|
|
|
let mut gathering_complete = peer_connection.gathering_complete_promise().await;
|
|
|
|
peer_connection.set_local_description(answer).await?;
|
|
|
|
gathering_complete.recv().await;
|
|
|
|
tracing::debug!("ICE gathering complete");
|
|
|
|
let local_description = peer_connection
|
|
.local_description()
|
|
.await
|
|
.ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?;
|
|
|
|
let peer_state = PeerState {
|
|
peer_id: offer_signal.offer.peer_id,
|
|
peer_connection: Arc::clone(&peer_connection),
|
|
outgoing_audio_track: Arc::clone(&outgoing_track),
|
|
ssrc: 0,
|
|
};
|
|
|
|
{
|
|
if let Some((_, old_peer)) = room_state.peers.remove(&offer_signal.offer.peer_id) {
|
|
let _ = old_peer.peer_connection.close().await;
|
|
}
|
|
}
|
|
|
|
{
|
|
room_state
|
|
.peers
|
|
.insert(offer_signal.offer.peer_id, peer_state);
|
|
}
|
|
|
|
let _ = offer_signal.response.send(AnswerSignal {
|
|
sdp_answer: local_description,
|
|
});
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id
|
|
))]
|
|
async fn forward_audio_track(
|
|
room_state: Arc<RoomState>,
|
|
peer_id: PeerId,
|
|
track: Arc<TrackRemote>,
|
|
) -> anyhow::Result<()> {
|
|
let mut rtp_buf = vec![0u8; 1500];
|
|
|
|
{
|
|
let mut peer_state = room_state
|
|
.peers
|
|
.get_mut(&peer_id)
|
|
.expect("peer state not found");
|
|
|
|
peer_state.ssrc = track.ssrc();
|
|
}
|
|
|
|
while let Ok((rtp_packet, _attr)) = track.read(&mut rtp_buf).await {
|
|
let other_peer_tracks = room_state
|
|
.peers
|
|
.iter()
|
|
.filter_map(|pair| {
|
|
let peer_state = pair.value();
|
|
if peer_state.peer_id != peer_id {
|
|
let mut rtp_packet = rtp_packet.clone();
|
|
rtp_packet.header.ssrc = peer_state.ssrc;
|
|
|
|
Some((peer_state.outgoing_audio_track.clone(), rtp_packet))
|
|
} else {
|
|
None
|
|
}
|
|
})
|
|
.collect::<Vec<_>>();
|
|
|
|
if other_peer_tracks.is_empty() {
|
|
// tracing::warn!("no other peers to forward audio track to");
|
|
continue;
|
|
}
|
|
|
|
let write_futures = other_peer_tracks
|
|
.iter()
|
|
.map(|(outgoing_track, packet)| outgoing_track.write_rtp(&packet));
|
|
|
|
let results = futures::future::join_all(write_futures).await;
|
|
for result in results {
|
|
if let Err(e) = result {
|
|
tracing::error!("error writing RTP packet: {}", e);
|
|
}
|
|
}
|
|
}
|
|
tracing::debug!(
|
|
"RTP Read loop ended for track {} from peer {}",
|
|
track.id(),
|
|
peer_id
|
|
);
|
|
|
|
Ok(())
|
|
}
|
|
|
|
#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))]
|
|
async fn cleanup_peer(room_state: Arc<RoomState>, peer_id: PeerId) {
|
|
tracing::debug!("cleaning up peer");
|
|
|
|
if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) {
|
|
tracing::debug!("removed peer");
|
|
let pc = Arc::clone(&peer_state.peer_connection);
|
|
tokio::spawn(async move {
|
|
if let Err(e) = pc.close().await {
|
|
if !matches!(e, webrtc::Error::ErrConnectionClosed) {
|
|
tracing::warn!("error closing peer connection: {}", e);
|
|
}
|
|
}
|
|
|
|
tracing::debug!("peer connection closed");
|
|
});
|
|
}
|
|
|
|
if room_state.peers.is_empty() {
|
|
tracing::debug!("no more peers in room, closing room");
|
|
let _ = room_state.close_signal.send(());
|
|
}
|
|
}
|