This commit is contained in:
2025-05-14 06:47:10 +03:00
commit f04f5e6e41
64 changed files with 7281 additions and 0 deletions

3
.cargo/config.toml Normal file
View File

@@ -0,0 +1,3 @@
# Faster Linux link times: drive linking through clang and use mold as the linker.
[target.x86_64-unknown-linux-gnu]
linker = "/usr/bin/clang"
# --ld-path requires a reasonably recent clang; mold must exist at this absolute path.
rustflags = ["-C", "link-arg=--ld-path=/usr/bin/mold"]

7
.gitignore vendored Normal file
View File

@@ -0,0 +1,7 @@
# Cargo build artifacts.
/target
# JetBrains IDE metadata.
/.idea
# Local runtime configuration file (loaded by src/config.rs).
/config.toml
# Postgres data directory mounted by docker-compose.
/db
# Rolling log files written by tracing-appender (src/log.rs).
/logs
# Environment variables for docker-compose / the database.
.env

4432
Cargo.lock generated Normal file

File diff suppressed because it is too large Load Diff

37
Cargo.toml Normal file
View File

@@ -0,0 +1,37 @@
[package]
name = "diplom"
version = "0.1.0"
edition = "2024"

[dependencies]
# Errors, password hashing, and the web framework.
anyhow = "1.0"
argon2 = "0.5"
axum = { version = "0.8", features = ["ws", "multipart", "macros"] }
axum-extra = { version = "0.10", features = ["typed-header"] }
chrono = { version = "0.4", features = ["serde"] }
config = "0.15"
derive_more = { version = "2.0", features = ["full"] }
jsonwebtoken = "9.3"
# NOTE(review): rand_core 0.6 predates rand 0.9 (which is built on rand_core 0.9);
# confirm both major lines are intentionally present (argon2 0.5 pulls rand_core 0.6).
rand_core = { version = "0.6", features = ["getrandom"] }
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
# Compile-time-checked SQL against Postgres.
sqlx = { version = "0.8", features = ["postgres", "runtime-tokio", "uuid", "chrono"] }
tokio = { version = "1.45", features = ["full"] }
tower-http = { version = "0.6", features = ["cors", "trace"] }
# Structured logging with daily file rotation.
tracing = "0.1"
tracing-appender = "0.2"
tracing-subscriber = { version = "0.3", features = ["serde", "env-filter", "chrono"] }
# v7 UUIDs (time-ordered) are used as primary keys throughout the schema.
uuid = { version = "1.16", features = ["fast-rng", "serde", "v7"] }
url = { version = "2.5", features = ["serde"] }
validator = { version = "0.20.0", features = ["derive"] }
regex = "1.11.1"
mime = "0.3.17"
axum_typed_multipart = "0.16.2"
async-trait = "0.1.88"
futures = "0.3.31"
# Voice chat: server-side WebRTC plus shared state containers.
webrtc = "0.12.0"
dashmap = "6.1.0"
rand = "0.9.1"
sha2 = "0.10.9"
hex = "0.4.3"
base64 = "0.22.1"

10
docker-compose.yml Normal file
View File

@@ -0,0 +1,10 @@
# Development database: Postgres image with the pg_uuidv7 extension preinstalled
# (required by the migrations' uuid_generate_v7() defaults).
services:
  database:
    image: 'ghcr.io/craigpastro/pg_uuidv7:main'
    ports:
      # Host 15432 -> container 5432, avoiding a clash with any local Postgres.
      - "15432:5432"
    env_file:
      # Supplies POSTGRES_* credentials; the file is gitignored.
      - .env
    volumes:
      # Persist cluster data under ./db on the host.
      - ${PWD}/db/:/var/lib/postgresql/data/
    # Run as the host user so ./db files are not root-owned.
    user: "1000:1000"

View File

@@ -0,0 +1 @@
-- Revert migration: remove the UUIDv7 generation extension.
DROP EXTENSION pg_uuidv7;

View File

@@ -0,0 +1 @@
-- Provides uuid_generate_v7(), used as the DEFAULT for every primary key.
CREATE EXTENSION IF NOT EXISTS pg_uuidv7;

View File

@@ -0,0 +1,4 @@
-- Revert migration: drop trigger and function first, then tables in
-- dependency order (user_relation references "user").
DROP TRIGGER trg_user_relation_update ON "user_relation";
DROP FUNCTION fn_on_user_relation_update();
DROP TABLE "user_relation";
DROP TABLE "user";

View File

@@ -0,0 +1,46 @@
-- Accounts. "user" is quoted throughout because it is a reserved word.
CREATE TABLE IF NOT EXISTS "user"
(
    "id"            UUID    NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
    "avatar_url"    VARCHAR,
    "username"      VARCHAR NOT NULL UNIQUE,
    "display_name"  VARCHAR,
    -- NOTE(review): email is NOT NULL but not UNIQUE — confirm duplicates are allowed.
    "email"         VARCHAR NOT NULL,
    "password_hash" VARCHAR NOT NULL,
    "bot"           BOOLEAN NOT NULL             DEFAULT FALSE,
    "system"        BOOLEAN NOT NULL             DEFAULT FALSE,
    "settings"      JSONB   NOT NULL             DEFAULT '{}'::JSONB
);

-- Directed pairwise relation (friend request, block, ...); "type" encodes the kind.
CREATE TABLE IF NOT EXISTS "user_relation"
(
    "user_id"    UUID        NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE,
    "other_id"   UUID        NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE,
    "type"       INT2        NOT NULL,
    "created_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
    "updated_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
    PRIMARY KEY ("user_id", "other_id")
);

-- create system account
INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system")
VALUES ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE);

-- Keep updated_at current on every UPDATE of user_relation.
CREATE OR REPLACE FUNCTION fn_on_user_relation_update()
    RETURNS TRIGGER
    LANGUAGE plpgsql
AS
$$
BEGIN
    IF (TG_OP = 'UPDATE') THEN
        NEW.updated_at := now();
    END IF;
    RETURN NEW;
END;
$$;

CREATE TRIGGER trg_user_relation_update
    BEFORE UPDATE
    ON "user_relation"
    FOR EACH ROW
EXECUTE FUNCTION fn_on_user_relation_update();

View File

@@ -0,0 +1 @@
-- Revert migration. NOTE(review): only "server" is dropped; server_role,
-- server_member, server_member_role and server_invite are left behind —
-- confirm this down-migration is intentionally partial.
DROP TABLE "server";

View File

@@ -0,0 +1,77 @@
-- Guild-like servers owned by a user.
CREATE TABLE IF NOT EXISTS "server"
(
    "id"       UUID    NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
    -- No ON DELETE action: deleting an owner with servers will fail.
    "owner_id" UUID    NOT NULL REFERENCES "user" ("id"),
    "name"     VARCHAR NOT NULL,
    "icon_url" VARCHAR
);

-- Roles within a server; "permissions" is a free-form JSON document.
CREATE TABLE IF NOT EXISTS "server_role"
(
    "id"          UUID     NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
    "server_id"   UUID     NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
    "name"        VARCHAR  NOT NULL,
    "color"       VARCHAR,
    "display"     BOOLEAN  NOT NULL             DEFAULT FALSE,
    "permissions" JSONB    NOT NULL             DEFAULT '{}'::JSONB,
    "position"    SMALLINT NOT NULL
);

-- Membership of a user in a server, with per-server profile overrides.
CREATE TABLE IF NOT EXISTS "server_member"
(
    "id"         UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
    "server_id"  UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
    "user_id"    UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    "nickname"   VARCHAR,
    "avatar_url" VARCHAR,
    UNIQUE ("server_id", "user_id")
);

-- Role assignments; cross-server consistency is enforced by the trigger below.
CREATE TABLE IF NOT EXISTS "server_member_role"
(
    "member_id" UUID NOT NULL REFERENCES "server_member" ("id") ON DELETE CASCADE,
    "role_id"   UUID NOT NULL REFERENCES "server_role" ("id") ON DELETE CASCADE,
    PRIMARY KEY ("member_id", "role_id")
);

-- Invite codes; inviter is kept nullable so invites survive account deletion.
CREATE TABLE IF NOT EXISTS "server_invite"
(
    "code"       VARCHAR NOT NULL PRIMARY KEY,
    "server_id"  UUID    NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
    "inviter_id" UUID REFERENCES "user" ("id") ON DELETE SET NULL,
    "expires_at" TIMESTAMPTZ
);

-- Reject role assignments where the member and the role belong to different servers.
CREATE OR REPLACE FUNCTION check_server_user_role_server_id()
    RETURNS TRIGGER AS
$$
DECLARE
    member_server_id UUID;
    role_server_id   UUID;
BEGIN
    -- Get server_id from server_user
    SELECT server_id
    INTO member_server_id
    FROM server_member
    WHERE id = NEW.member_id;
    -- Get server_id from server_role
    SELECT server_id
    INTO role_server_id
    FROM server_role
    WHERE id = NEW.role_id;
    -- Check if server_ids match
    IF member_server_id != role_server_id THEN
        RAISE EXCEPTION 'Cannot assign role from a different server: server_user server_id (%) does not match server_role server_id (%)', member_server_id, role_server_id;
    END IF;
    RETURN NEW;
END;
$$ LANGUAGE plpgsql;

CREATE TRIGGER enforce_server_user_role_server_id
    BEFORE INSERT OR UPDATE
    ON server_member_role
    FOR EACH ROW
EXECUTE FUNCTION check_server_user_role_server_id();

View File

@@ -0,0 +1,2 @@
-- Revert migration: message references channel, so drop it first.
DROP TABLE "message";
DROP TABLE "channel";

View File

@@ -0,0 +1,97 @@
-- Channels of all kinds (server text/voice/category, DM, group).
-- "type" values correspond to the Rust ChannelType enum in src/entity/channel.rs.
CREATE TABLE IF NOT EXISTS "channel"
(
    "id"        UUID    NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
    "name"      VARCHAR NOT NULL,
    "type"      INT2    NOT NULL,
    "position"  INT2    NOT NULL,
    "server_id" UUID REFERENCES "server" ("id") ON DELETE CASCADE,  -- only for server channels
    "parent"    UUID REFERENCES "channel" ("id") ON DELETE SET NULL, -- only for server channels
    "owner_id"  UUID REFERENCES "user" ("id") ON DELETE SET NULL     -- only for group channels
);

-- only for dm or group channels
CREATE TABLE IF NOT EXISTS "channel_recipient"
(
    "channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE,
    "user_id"    UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
    PRIMARY KEY ("channel_id", "user_id")
);

CREATE TABLE IF NOT EXISTS "message"
(
    "id"         UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
    "author_id"  UUID NOT NULL REFERENCES "user" ("id"),
    "channel_id" UUID NOT NULL REFERENCES "channel" ("id"),
    "content"    TEXT NOT NULL
);

-- Added via ALTER because channel and message reference each other.
ALTER TABLE "channel"
    ADD COLUMN "last_message_id" UUID REFERENCES "message" ("id") ON DELETE SET NULL;

-- Creates (or returns the existing) two-person channel of the given type.
-- Returns the channel id in both cases.
CREATE OR REPLACE FUNCTION create_dm_channel(
    user1_id UUID,
    user2_id UUID,
    channel_type INT2
)
    RETURNS UUID
    LANGUAGE plpgsql
AS
$$
DECLARE
    new_channel_id UUID;
    channel_name   VARCHAR;
BEGIN
    -- Validate parameters
    IF user1_id IS NULL OR user2_id IS NULL THEN
        RAISE EXCEPTION 'Both user IDs must be provided';
    END IF;
    IF user1_id = user2_id THEN
        RAISE EXCEPTION 'Cannot create a DM channel with the same user';
    END IF;
    -- Check if users exist
    IF NOT exists (SELECT 1 FROM "user" WHERE id = user1_id) OR
       NOT exists (SELECT 1 FROM "user" WHERE id = user2_id) THEN
        RAISE EXCEPTION 'One or both users do not exist';
    END IF;
    -- Check if DM already exists between these users
    -- (both users are recipients and the channel has exactly two recipients).
    IF exists (SELECT 1
               FROM channel c
                        JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id
                        JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id
               WHERE c.type = channel_type
                 AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2) THEN
        -- Find and return the existing channel ID
        SELECT c.id
        INTO new_channel_id
        FROM channel c
                 JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id
                 JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id
        WHERE c.type = channel_type
          AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2
        LIMIT 1;
        RAISE NOTICE 'DM channel already exists between these users with ID: %', new_channel_id;
        RETURN new_channel_id;
    END IF;
    -- Generate channel name (conventionally uses user IDs in DMs)
    channel_name := concat(user1_id, '-', user2_id);
    -- Create new channel
    INSERT INTO "channel" ("name", "type", "position")
    VALUES (channel_name, channel_type, 0)
    RETURNING id INTO new_channel_id;
    -- Add both users as recipients
    INSERT INTO "channel_recipient" ("channel_id", "user_id")
    VALUES (new_channel_id, user1_id),
           (new_channel_id, user2_id);
    RAISE NOTICE 'DM channel created with ID: %', new_channel_id;
    RETURN new_channel_id;
END;
$$;

View File

@@ -0,0 +1,2 @@
-- Revert migration: message_attachment references file, so drop it first.
DROP TABLE "message_attachment";
DROP TABLE "file";

View File

@@ -0,0 +1,16 @@
-- Uploaded files; "size" is in bytes (INT8, i.e. signed 64-bit).
CREATE TABLE IF NOT EXISTS "file"
(
    "id"           UUID    NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
    "filename"     VARCHAR NOT NULL,
    "content_type" VARCHAR NOT NULL,
    "url"          VARCHAR NOT NULL,
    "size"         INT8    NOT NULL
);

-- Ordered attachments of a message; "order" gives position within the message.
CREATE TABLE IF NOT EXISTS "message_attachment"
(
    "message_id"    UUID NOT NULL REFERENCES "message" ON DELETE CASCADE,
    "attachment_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE,
    "order"         INT2 NOT NULL,
    PRIMARY KEY ("message_id", "attachment_id")
);

7
rustfmt.toml Normal file
View File

@@ -0,0 +1,7 @@
# Project formatting rules. group_imports / imports_granularity are
# nightly-only options, hence unstable_features below.
match_block_trailing_comma = true
newline_style = "Unix"
style_edition = "2024"
group_imports = "StdExternalCrate"
imports_granularity = "Module"
unstable_features = true

64
src/config.rs Normal file
View File

@@ -0,0 +1,64 @@
use std::sync::OnceLock;
use serde::Deserialize;
/// Returns the process-wide configuration, loading it lazily on first access.
///
/// Sources, later overriding earlier: a `config.*` file in the working
/// directory, then environment variables prefixed with `DIPLOM_`.
///
/// # Panics
/// Panics if the configuration cannot be read or deserialized — the
/// application cannot run without it.
pub fn config() -> &'static Config {
    static INSTANCE: OnceLock<Config> = OnceLock::new();
    INSTANCE.get_or_init(|| {
        config::Config::builder()
            .add_source(config::File::with_name("config"))
            // BUG FIX: config-rs appends the `_` separator to the prefix
            // itself, so the prefix must not carry a trailing underscore.
            // `with_prefix("DIPLOM_")` only matched variables named
            // `DIPLOM__*`; `"DIPLOM"` matches the intended `DIPLOM_*`.
            .add_source(config::Environment::with_prefix("DIPLOM"))
            .build()
            .expect("config builder")
            .try_deserialize()
            .expect("config deserialize")
    })
}
/// Top-level application configuration, deserialized by [`config()`].
#[derive(Deserialize)]
pub struct Config {
    /// TCP port the HTTP server listens on.
    pub port: u16,
    /// HMAC secret used to sign and verify JWTs (see src/jwt.rs).
    pub jwt_secret: String,
    pub database: DatabaseConfig,
}
/// Database connection settings: either a ready-made URL or its parts.
///
/// NOTE(review): serde documents that `deny_unknown_fields` is not supported
/// together with `untagged` — extra keys may be silently accepted; confirm.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
#[serde(deny_unknown_fields)]
pub enum DatabaseConfig {
    /// A complete connection URL, e.g. `postgres://user:pass@host:5432/db`.
    Url {
        url: url::Url,
    },
    /// Individual components assembled into a URL by [`DatabaseConfig::url`].
    Full {
        driver: String,
        username: String,
        password: String,
        host: url::Host,
        port: u16,
        name: String,
    },
}
impl DatabaseConfig {
    /// Produces a connection URL for either variant.
    ///
    /// Returns `None` when the string assembled from the `Full` parts does
    /// not parse as a valid URL (e.g. credentials with reserved characters).
    pub fn url(&self) -> Option<url::Url> {
        match self {
            Self::Url { url } => Some(url.clone()),
            Self::Full {
                driver,
                username,
                password,
                host,
                port,
                name,
            } => {
                let rendered =
                    format!("{driver}://{username}:{password}@{host}:{port}/{name}");
                url::Url::parse(&rendered).ok()
            },
        }
    }
}

333
src/database.rs Normal file
View File

@@ -0,0 +1,333 @@
use crate::config::DatabaseConfig;
use crate::entity;
/// Thin wrapper around a shared Postgres connection pool.
/// Cloning is cheap: `PgPool` is internally reference-counted.
#[derive(Clone, derive_more::AsRef)]
pub struct Database {
    pool: sqlx::PgPool,
}

pub type Result<T> = std::result::Result<T, Error>;

/// Database-layer errors; the `*DoesNotExists` / `AlreadyExists` variants are
/// mapped to client-facing errors in src/web/error.rs.
#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
pub enum Error {
    #[from]
    Migrate(sqlx::migrate::MigrateError),
    #[from]
    Sqlx(sqlx::Error),
    UserDoesNotExists,
    UserAlreadyExists,
    ServerDoesNotExists,
    ChannelDoesNotExists,
    MessageDoesNotExists,
}
impl Database {
    /// Connects to Postgres using the configured URL and runs all embedded
    /// migrations (compiled in via `sqlx::migrate!`).
    ///
    /// NOTE(review): a missing/invalid URL is reported as
    /// `sqlx::Error::BeginFailed`, which is a misleading stand-in — consider a
    /// dedicated configuration error.
    pub async fn connect(config: &DatabaseConfig) -> sqlx::Result<Self> {
        static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
        let pool = sqlx::postgres::PgPoolOptions::new()
            .connect(config.url().ok_or(sqlx::Error::BeginFailed)?.as_str())
            .await?;
        MIGRATOR.run(&pool).await?;
        Ok(Self { pool })
    }
}
/// Query layer: each method is a single compile-time-checked statement
/// (`sqlx::query_as!`) unless noted otherwise.
impl Database {
    /// Inserts a new user row and returns it.
    ///
    /// Maps Postgres unique-violation (SQLSTATE 23505) to
    /// [`Error::UserAlreadyExists`]; only `username` is UNIQUE in the schema.
    pub async fn insert_user(
        &self,
        username: &str,
        display_name: Option<&str>,
        email: &str,
        password_hash: &str,
    ) -> Result<entity::user::User> {
        let user = match sqlx::query_as!(
            entity::user::User,
            r#"INSERT INTO "user"("username", "display_name", "email", "password_hash") VALUES ($1, $2, $3, $4) RETURNING "user".*"#,
            username,
            display_name,
            email,
            password_hash
        )
        .fetch_one(&self.pool)
        .await {
            Ok(user) => user,
            // 23505 = unique_violation.
            Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
                return Err(Error::UserAlreadyExists);
            }
            Err(e) => return Err(e.into()),
        };
        Ok(user)
    }

    /// Fetches a user by primary key; `Err(UserDoesNotExists)` when absent.
    pub async fn select_user_by_id(&self, user_id: entity::user::Id) -> Result<entity::user::User> {
        let user = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "id" = $1"#,
            user_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::UserDoesNotExists)?;
        Ok(user)
    }

    /// Fetches all users whose id appears in `user_ids` (missing ids are
    /// silently skipped).
    ///
    /// NOTE(review): an empty slice renders `IN ()`, which is invalid SQL and
    /// will error at runtime — callers must pass a non-empty slice.
    pub async fn select_users_by_ids(
        &self,
        user_ids: &[entity::user::Id],
    ) -> Result<Vec<entity::user::User>> {
        let mut query_builder = sqlx::QueryBuilder::new(r#"SELECT * FROM "user" WHERE "id" IN ("#);
        let mut separated = query_builder.separated(", ");
        for id in user_ids.iter() {
            separated.push_bind(id);
        }
        query_builder.push(")");
        let users = query_builder.build_query_as().fetch_all(&self.pool).await?;
        Ok(users)
    }

    /// Fetches a user by unique username; `Err(UserDoesNotExists)` when absent.
    pub async fn select_user_by_username(&self, username: &str) -> Result<entity::user::User> {
        let user = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "username" = $1"#,
            username
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::UserDoesNotExists)?;
        Ok(user)
    }

    /// Fetches a server by primary key; `Err(ServerDoesNotExists)` when absent.
    pub async fn select_server_by_id(
        &self,
        server_id: entity::server::Id,
    ) -> Result<entity::server::Server> {
        let server = sqlx::query_as!(
            entity::server::Server,
            r#"SELECT * FROM "server" WHERE "id" = $1"#,
            server_id
        )
        .fetch_optional(&self.pool)
        .await?
        .ok_or(Error::ServerDoesNotExists)?;
        Ok(server)
    }

    /// All servers the user is a member of (via server_member).
    pub async fn select_user_servers(
        &self,
        user_id: entity::user::Id,
    ) -> Result<Vec<entity::server::Server>> {
        let servers = sqlx::query_as!(
            entity::server::Server,
            r#"SELECT * FROM "server" WHERE "id" IN (
                SELECT "server_id" FROM "server_member"
                WHERE "user_id" = $1
            )"#,
            user_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(servers)
    }

    /// All DM/group channels the user is a recipient of.
    pub async fn select_user_channels(
        &self,
        user_id: entity::user::Id,
    ) -> Result<Vec<entity::channel::Channel>> {
        // for some reason using macro overflows tokio stack
        // NOTE(review): the likely culprit is the recursive
        // `From<i16> for ChannelType` impl in entity::channel — verify and
        // switch back to the checked macro once fixed.
        let channels = sqlx::query_as(
            r#"SELECT * FROM "channel" WHERE "id" IN (
                SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1
            )"#,
        )
        .bind(user_id)
        .fetch_all(&self.pool)
        .await?;
        Ok(channels)
    }

    /// Creates a server row owned by `owner_id` and returns it.
    pub async fn insert_server(
        &self,
        name: &str,
        icon_url: Option<&str>,
        owner_id: entity::user::Id,
    ) -> Result<entity::server::Server> {
        let server = sqlx::query_as!(
            entity::server::Server,
            r#"INSERT INTO "server"("name", "icon_url", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#,
            name,
            icon_url,
            owner_id
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(server)
    }

    /// Creates a role; `id` lets callers pin a specific UUID (e.g. a default
    /// role), otherwise the column default generates one.
    pub async fn insert_server_role(
        &self,
        id: Option<entity::server::role::Id>,
        server_id: entity::server::Id,
        name: &str,
        color: Option<&str>,
        display: bool,
        permissions: serde_json::Value,
        position: u16,
    ) -> Result<entity::server::role::ServerRole> {
        let role = sqlx::query_as!(
            entity::server::role::ServerRole,
            r#"INSERT INTO "server_role"("id", "server_id", "name", "color", "display", "permissions", "position") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING "server_role".*"#,
            id,
            server_id,
            name,
            color,
            display,
            permissions,
            // Column is SMALLINT; values above i16::MAX would wrap here.
            position as i16
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(role)
    }

    /// Adds `user_id` as a member of `server_id` and returns the membership row.
    pub async fn insert_server_member(
        &self,
        server_id: entity::server::Id,
        user_id: entity::user::Id,
    ) -> Result<entity::server::member::ServerMember> {
        let member = sqlx::query_as!(
            entity::server::member::ServerMember,
            r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#,
            server_id,
            user_id
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(member)
    }

    /// Assigns a role to a member; the DB trigger rejects cross-server pairs.
    pub async fn insert_server_member_role(
        &self,
        server_member_id: entity::server::member::Id,
        server_role_id: entity::server::role::Id,
    ) -> Result<()> {
        sqlx::query!(
            r#"INSERT INTO "server_member_role"("member_id", "role_id") VALUES ($1, $2)"#,
            server_member_id,
            server_role_id
        )
        .execute(&self.pool)
        .await?;
        Ok(())
    }

    /// Creates a channel belonging to a server (text/voice/category).
    pub async fn insert_server_channel(
        &self,
        server_id: entity::server::Id,
        name: &str,
        channel_type: entity::channel::ChannelType,
        position: u16,
        parent: Option<entity::channel::Id>,
    ) -> Result<entity::channel::Channel> {
        let channel = sqlx::query_as!(
            entity::channel::Channel,
            r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#,
            name,
            channel_type as i16,
            position as i16,
            server_id,
            parent
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(channel)
    }

    /// Fetches a channel by primary key; `Err(ChannelDoesNotExists)` when absent.
    /// Uses the unchecked query form for the same reason as
    /// [`Self::select_user_channels`].
    pub async fn select_channel_by_id(
        &self,
        channel_id: entity::channel::Id,
    ) -> Result<entity::channel::Channel> {
        let channel = sqlx::query_as(r#"SELECT * FROM "channel" WHERE "id" = $1"#)
            .bind(channel_id)
            .fetch_optional(&self.pool)
            .await?
            .ok_or(Error::ChannelDoesNotExists)?;
        Ok(channel)
    }

    /// Recipients of a DM/group channel; `None` when the channel has no
    /// recipient rows (i.e. it is a server channel or does not exist).
    pub async fn select_channel_recipients(
        &self,
        channel_id: entity::channel::Id,
    ) -> Result<Option<Vec<entity::user::User>>> {
        let recipients = sqlx::query_as!(
            entity::user::User,
            r#"SELECT * FROM "user" WHERE "id" IN (
                SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1
            )"#,
            channel_id
        )
        .fetch_all(&self.pool)
        .await?;
        if recipients.is_empty() {
            return Ok(None);
        }
        Ok(Some(recipients))
    }

    /// All channels of a server.
    pub async fn select_server_channels(
        &self,
        server_id: entity::server::Id,
    ) -> Result<Vec<entity::channel::Channel>> {
        let channels = sqlx::query_as!(
            entity::channel::Channel,
            r#"SELECT * FROM "channel" WHERE "server_id" = $1"#,
            server_id
        )
        .fetch_all(&self.pool)
        .await?;
        Ok(channels)
    }

    /// Calls the `create_dm_channel` SQL function, which creates (or finds the
    /// existing) DM channel between the two users and returns its id.
    /// The `Option` reflects the scalar's nullability in sqlx's view.
    pub async fn procedure_create_dm_channel(
        &self,
        user1_id: entity::user::Id,
        user2_id: entity::user::Id,
    ) -> Result<Option<entity::channel::Id>> {
        let channel_id = sqlx::query_scalar!(
            r#"SELECT create_dm_channel($1, $2, $3)"#,
            user1_id,
            user2_id,
            entity::channel::ChannelType::DirectMessage as i16
        )
        .fetch_one(&self.pool)
        .await?;
        Ok(channel_id)
    }
}

12
src/entity/attachment.rs Normal file
View File

@@ -0,0 +1,12 @@
use serde::Serialize;
pub type Id = uuid::Uuid;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
pub struct Attachment {
pub id: Id,
pub filename: String,
pub content_type: String,
pub url: String,
pub size: u64,
}

41
src/entity/channel.rs Normal file
View File

@@ -0,0 +1,41 @@
use serde::Serialize;
use crate::entity::{message, server, user};
pub type Id = uuid::Uuid;
/// A channel row: server text/voice/category, direct message, or group.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Channel {
    pub id: Id,
    pub name: String,
    pub r#type: ChannelType,
    /// Ordering within its parent/server list.
    pub position: i16,
    /// Set only for server channels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub server_id: Option<server::Id>,
    /// Parent category channel; set only for server channels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent: Option<Id>,
    /// Set only for group channels.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub owner_id: Option<user::Id>,
    /// Most recent message, if any (column added via ALTER in the migration).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub last_message_id: Option<message::Id>,
}
/// Channel kind discriminant, stored in the database as INT2.
#[derive(Debug, Clone, sqlx::Type, Serialize)]
#[non_exhaustive]
#[serde(rename_all = "snake_case")]
#[repr(i16)]
pub enum ChannelType {
    ServerText = 1,
    ServerVoice = 2,
    ServerCategory = 3,
    DirectMessage = 4,
    Group = 5,
}

impl From<i16> for ChannelType {
    /// Maps a raw database value to a variant, treating unknown values as
    /// `ServerText` (the previous fallback behavior).
    ///
    /// BUG FIX: the original body was `value.try_into().unwrap_or(ServerText)`.
    /// That `try_into` resolves to the std blanket impl
    /// `TryFrom<U> for T where U: Into<T>`, which is built on *this* `From`
    /// impl — so `from` called itself forever and overflowed the stack the
    /// first time a channel row was decoded (the "macro overflows tokio
    /// stack" symptom noted in src/database.rs).
    fn from(value: i16) -> Self {
        match value {
            1 => ChannelType::ServerText,
            2 => ChannelType::ServerVoice,
            3 => ChannelType::ServerCategory,
            4 => ChannelType::DirectMessage,
            5 => ChannelType::Group,
            _ => ChannelType::ServerText,
        }
    }
}

15
src/entity/message.rs Normal file
View File

@@ -0,0 +1,15 @@
use serde::Serialize;
use crate::entity::{channel, user};
pub type Id = uuid::Uuid;

/// A chat message.
///
/// NOTE(review): the "message" table created in this commit's migration has
/// no `timestamp` column — confirm a later migration adds it, otherwise
/// `SELECT *` row mapping for this struct will fail.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Message {
    pub id: Id,
    pub author_id: user::Id,
    pub channel_id: channel::Id,
    pub content: String,
    pub timestamp: chrono::DateTime<chrono::Utc>,
}

5
src/entity/mod.rs Normal file
View File

@@ -0,0 +1,5 @@
//! Database entity types, one module per table (group).
pub mod attachment;
pub mod channel;
pub mod message;
pub mod server;
pub mod user;

18
src/entity/server.rs Normal file
View File

@@ -0,0 +1,18 @@
// NOTE(review): `invite` is private while member/role are public — confirm
// `ServerInvite` is intentionally not reachable from outside this module.
mod invite;
pub mod member;
pub mod role;

use serde::Serialize;

use crate::entity::user;

pub type Id = uuid::Uuid;

/// A guild-like server, mirroring the "server" table.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Server {
    pub id: Id,
    pub owner_id: user::Id,
    pub name: String,
    pub icon_url: Option<String>,
}

View File

@@ -0,0 +1,13 @@
use chrono::{DateTime, Utc};
use serde::Serialize;
use crate::entity::{server, user};
/// An invite code for a server, mirroring the "server_invite" table.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerInvite {
    /// The invite code itself — the table's primary key.
    pub code: String,
    pub server_id: server::Id,
    /// `None` once the inviting user is deleted (FK is ON DELETE SET NULL).
    pub inviter_id: Option<user::Id>,
    /// `None` means the invite never expires.
    pub expires_at: Option<DateTime<Utc>>,
}

View File

@@ -0,0 +1,15 @@
use serde::Serialize;
use crate::entity::{server, user};
pub type Id = uuid::Uuid;

/// A user's membership in a server with per-server profile overrides,
/// mirroring the "server_member" table.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerMember {
    pub id: Id,
    pub server_id: server::Id,
    pub user_id: user::Id,
    /// Per-server display-name override, if set.
    pub nickname: Option<String>,
    /// Per-server avatar override, if set.
    pub avatar_url: Option<String>,
}

17
src/entity/server/role.rs Normal file
View File

@@ -0,0 +1,17 @@
use serde::Serialize;
use crate::entity::server;
pub type Id = uuid::Uuid;

/// A role within a server, mirroring the "server_role" table.
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct ServerRole {
    pub id: Id,
    pub server_id: server::Id,
    pub name: String,
    pub color: Option<String>,
    /// Whether members with this role are shown separately.
    pub display: bool,
    /// Free-form permission document (JSONB in the schema).
    pub permissions: serde_json::Value,
    /// Ordering of roles within the server.
    pub position: i16,
}

21
src/entity/user.rs Normal file
View File

@@ -0,0 +1,21 @@
use std::sync::LazyLock;
use regex::Regex;
/// Allowed username characters: ASCII letters, digits, underscore, and dot.
pub static USERNAME_REGEX: LazyLock<Regex> =
    LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap());

pub type Id = uuid::Uuid;

/// An account row, mirroring the "user" table.
///
/// Deliberately not `Serialize`: it carries `password_hash` and `email`,
/// which must never reach a client response as-is.
#[derive(Debug, Clone, sqlx::FromRow)]
pub struct User {
    pub id: Id,
    pub avatar_url: Option<String>,
    pub username: String,
    pub display_name: Option<String>,
    pub email: String,
    pub password_hash: String,
    pub bot: bool,
    pub system: bool,
    /// Free-form client settings (JSONB in the schema).
    pub settings: serde_json::Value,
}

55
src/jwt.rs Normal file
View File

@@ -0,0 +1,55 @@
#![allow(unused)]
use std::collections::HashSet;
use chrono::{Duration, Local, Utc};
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use crate::config;
/// JWT claim set: the authenticated user's id plus issue time.
#[derive(Debug, Serialize, Deserialize)]
pub struct Claims<T> {
    // Generic so the caller picks the id type (a UUID in this codebase).
    pub user_id: T,
    // Issued-at in *milliseconds* since the Unix epoch — non-standard
    // (RFC 7519 `iat` is seconds); kept consistent with generate_jwt.
    pub iat: i64,
}
/// Signs a JWT carrying `user_id`, using the configured HMAC secret
/// (default HS256 header).
///
/// # Errors
/// [`Error::CouldNotEncodeToken`] if encoding fails.
pub fn generate_jwt<T: Serialize>(user_id: T) -> Result<String> {
    let claims = Claims {
        user_id,
        // NOTE(review): milliseconds rather than the RFC 7519 seconds
        // convention, and no `exp` claim is set — tokens never expire.
        // Confirm both are intended.
        iat: Utc::now().timestamp_millis(),
    };
    let token = jsonwebtoken::encode(
        &jsonwebtoken::Header::default(),
        &claims,
        &jsonwebtoken::EncodingKey::from_secret(config::config().jwt_secret.as_ref()),
    )
    .map_err(|_| Error::CouldNotEncodeToken)?;
    Ok(token)
}
/// Verifies a JWT produced by [`generate_jwt`] and returns the embedded
/// `user_id` claim.
///
/// # Errors
/// [`Error::CouldNotVerifyToken`] if the signature is invalid or the claims
/// fail to deserialize.
pub fn verify_jwt<T: DeserializeOwned>(token: &str) -> Result<T> {
    // SECURITY FIX: the raw token is a bearer credential and must not be
    // written to logs; the previous debug line logged it verbatim.
    tracing::debug!("verifying jwt");
    let mut validation = jsonwebtoken::Validation::default();
    // No spec claims are required: our tokens carry only `user_id` and `iat`.
    // NOTE(review): with no `exp`, tokens never expire — confirm intended.
    validation.required_spec_claims = HashSet::new();
    let token_data = jsonwebtoken::decode::<Claims<T>>(
        token,
        &jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),
        &validation,
    )
    .map_err(|_| Error::CouldNotVerifyToken)?;
    Ok(token_data.claims.user_id)
}
pub type Result<T> = std::result::Result<T, Error>;

/// JWT failures. Deliberately coarse: the underlying jsonwebtoken error is
/// discarded so no key/claim detail leaks to callers.
#[derive(Debug, derive_more::Error, derive_more::Display)]
pub enum Error {
    CouldNotEncodeToken,
    CouldNotVerifyToken,
}

43
src/log.rs Normal file
View File

@@ -0,0 +1,43 @@
use std::io;
use tracing::level_filters::LevelFilter;
use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::EnvFilter;
use tracing_subscriber::fmt::time::ChronoLocal;
use tracing_subscriber::layer::SubscriberExt;
use tracing_subscriber::util::SubscriberInitExt;
/// Installs the global tracing subscriber: a plain daily-rotated file layer
/// under ./logs plus an ANSI stdout layer (pretty-printed in debug builds).
///
/// Returns the two non-blocking writer guards; the caller must keep them
/// alive for the lifetime of the program or buffered log lines are dropped.
///
/// Filter level defaults to INFO and is overridable via `RUST_LOG`.
pub fn init_logging() -> io::Result<(WorkerGuard, WorkerGuard)> {
    let logger_timer = ChronoLocal::new("%FT%H:%M:%S%.6f%:::z".to_string());
    let file = tracing_appender::rolling::daily("./logs", "log");
    let (file, file_guard) = tracing_appender::non_blocking(file);
    let file_logger = tracing_subscriber::fmt::layer()
        // No ANSI escape codes in files.
        .with_ansi(false)
        .with_timer(logger_timer.clone())
        .with_writer(file);
    let (stdout, stdout_guard) = tracing_appender::non_blocking(io::stdout());
    let stdout_logger = tracing_subscriber::fmt::layer()
        .with_timer(logger_timer)
        .with_writer(stdout);
    #[cfg(debug_assertions)]
    let stdout_logger = stdout_logger.pretty();
    tracing_subscriber::registry()
        .with(
            EnvFilter::builder()
                .with_default_directive(LevelFilter::INFO.into())
                .from_env_lossy(),
        )
        .with(file_logger)
        .with(stdout_logger)
        .init();
    tracing::info!("initialized logging");
    Ok((file_guard, stdout_guard))
}

35
src/main.rs Normal file
View File

@@ -0,0 +1,35 @@
use std::collections::HashMap;
use std::sync::Arc;
use argon2::Argon2;
use tokio::sync::RwLock;
use crate::database::Database;
use crate::state::AppState;
mod config;
mod database;
mod entity;
mod jwt;
mod log;
mod state;
mod util;
mod web;
mod webrtc;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    // The guards must stay alive until shutdown so buffered log lines flush.
    let _guard = log::init_logging()?;
    // Connecting also runs all embedded migrations.
    let database = Database::connect(&config::config().database).await?;
    let state = AppState {
        database,
        hasher: Arc::new(Argon2::default()),
        // user id -> active websocket sessions for the event gateway.
        event_connected_users: Arc::new(RwLock::new(HashMap::new())),
        // channel id -> signal sender for the WebRTC voice room.
        voice_rooms: Arc::new(RwLock::new(HashMap::new())),
    };
    // Serves HTTP until the server exits.
    web::run(state).await?;
    Ok(())
}

51
src/state.rs Normal file
View File

@@ -0,0 +1,51 @@
use std::collections::HashMap;
use std::sync::Arc;
use argon2::Argon2;
use dashmap::DashMap;
use tokio::sync::{RwLock, mpsc};
use uuid::Uuid;
use crate::database::Database;
use crate::web::ws::gateway::{EventWsState, message};
use crate::webrtc::OfferSignal;
/// Shared application state handed to every handler; cheap to clone
/// (everything is behind an `Arc`).
#[derive(Clone)]
pub struct AppState {
    pub database: Database,
    /// Argon2 password hasher, shared so parameters are configured once.
    pub hasher: Arc<Argon2<'static>>,
    /// Per-user event-gateway state; each user may hold several sessions.
    pub event_connected_users: Arc<RwLock<HashMap<Uuid, EventWsState>>>,
    /// Per-room sender used to push WebRTC offer signals to the voice task.
    pub voice_rooms: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<OfferSignal>>>>,
}
impl AppState {
    /// Records an event-gateway sender for one of the user's websocket
    /// sessions, creating the per-user state on first connection.
    pub async fn register_event_connected_user(
        &self,
        user_id: Uuid,
        session_id: String,
        event_sender: mpsc::UnboundedSender<message::Event>,
    ) {
        let mut connected_users = self.event_connected_users.write().await;
        connected_users
            .entry(user_id)
            .or_insert_with(|| EventWsState {
                connection_instance: DashMap::new(),
            })
            .connection_instance
            .insert(session_id, event_sender);
    }

    /// Removes one session's sender; once the user has no sessions left,
    /// the per-user state is dropped entirely.
    pub async fn unregister_event_connected_user(&self, user_id: Uuid, session_id: &str) {
        let mut connected_users = self.event_connected_users.write().await;
        let Some(state) = connected_users.get_mut(&user_id) else {
            return;
        };
        state.connection_instance.remove(session_id);
        if state.connection_instance.is_empty() {
            connected_users.remove(&user_id);
        }
    }
}

52
src/util.rs Normal file
View File

@@ -0,0 +1,52 @@
use axum::extract::multipart::Field;
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
use serde::Serialize;
/// Newtype over `axum_typed_multipart::FieldData` that adds a `Serialize`
/// impl (the wrapped type has none); derefs to the inner `FieldData`.
#[derive(Debug, derive_more::Deref)]
pub struct SerdeFieldData<T>(pub FieldData<T>);

// Delegates extraction to the inner FieldData and wraps the result.
#[async_trait::async_trait]
impl<T: TryFromField> TryFromField for SerdeFieldData<T> {
    async fn try_from_field(
        field: Field<'_>,
        limit_bytes: Option<usize>,
    ) -> Result<Self, TypedMultipartError> {
        let field = FieldData::try_from_field(field, limit_bytes).await?;
        Ok(Self(field))
    }
}
/// Serializes only the multipart *metadata* (field name, file name, content
/// type) as a camelCase object — the field contents themselves are omitted.
impl<T> Serialize for SerdeFieldData<T> {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // Local mirror of the metadata so we control the serialized shape.
        #[derive(Serialize)]
        #[serde(rename_all = "camelCase")]
        struct Metadata {
            name: Option<String>,
            file_name: Option<String>,
            content_type: Option<String>,
        }
        let metadata = Metadata {
            name: self.0.metadata.name.clone(),
            file_name: self.0.metadata.file_name.clone(),
            content_type: self.0.metadata.content_type.clone(),
        };
        metadata.serialize(serializer)
    }
}
/// serde field helper: serializes a [`std::time::Duration`] as its whole
/// number of seconds (for use with `#[serde(serialize_with = ...)]`).
pub fn serialize_duration_seconds<S>(
    duration: &std::time::Duration,
    serializer: S,
) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    duration.as_secs().serialize(serializer)
}

35
src/web/context.rs Normal file
View File

@@ -0,0 +1,35 @@
use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use crate::entity;
/// The authenticated caller, resolved from the Authorization header by the
/// middleware in src/web/middleware.
#[derive(Debug, Copy, Clone)]
pub struct UserContext {
    pub user_id: entity::user::Id,
}

pub type UserContextResult = Result<UserContext, Error>;

/// Reasons context resolution can fail; mapped to NOT_AUTHORIZED upstream.
#[derive(Debug, Clone, Copy, derive_more::Display)]
pub enum Error {
    /// The resolve middleware never ran for this request.
    NotInRequest,
    /// No Authorization header present.
    NotInHeader,
    BadCharacters,
    /// Header present but not a Bearer token.
    WrongTokenType,
    /// Token failed verification or its user no longer exists.
    BadToken,
    Model,
}
/// Axum extractor: reads the `UserContextResult` the resolve middleware
/// stored in request extensions, rejecting when absent or an error.
impl<S: Send + Sync> FromRequestParts<S> for UserContext {
    type Rejection = super::error::Error;

    async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
        let context = parts
            .extensions
            .get::<UserContextResult>()
            .cloned()
            // First `?` covers a missing extension, second the stored error.
            .ok_or(Error::NotInRequest)??;
        Ok(context)
    }
}

147
src/web/error.rs Normal file
View File

@@ -0,0 +1,147 @@
use std::sync::Arc;
use axum::http::StatusCode;
use axum::response::IntoResponse;
use crate::web::context;
use crate::{database, jwt};
pub type Result<T> = std::result::Result<T, Error>;

/// Server-side error: the full, internal view. Converted to the client-safe
/// [`ClientError`] via [`Error::as_client_error`] before leaving the server.
#[derive(Debug, derive_more::From)]
pub enum Error {
    #[from]
    Context(context::Error),
    #[from]
    Jwt(jwt::Error),
    #[from]
    Hash(argon2::password_hash::Error),
    #[from]
    Database(database::Error),
    #[from]
    Json(serde_json::error::Error),
    #[from]
    JsonRejection(axum::extract::rejection::JsonRejection),
    /// Already client-facing; passed through unchanged.
    #[from]
    Client(ClientError),
}
/// The error shape sent to clients, serialized as
/// `{"message": "SCREAMING_SNAKE_CASE", "details": ...}`.
#[derive(Debug, Clone, derive_more::From, serde::Serialize)]
#[serde(tag = "message", content = "details")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientError {
    UserAlreadyExists,
    UserDoesNotExist,
    ServerDoesNotExist,
    ChannelDoesNotExist,
    MessageDoesNotExist,
    NotAuthorized,
    WrongPassword,
    NotAllowed,
    /// Carries the sanitized rejection details in `details`.
    InvalidJson(JsonRejection),
    #[from]
    ValidationFailed(validator::ValidationErrors),
    /// Catch-all; intentionally carries no internal detail.
    InternalServerError,
}
/// Client-safe summary of an axum JSON rejection: the reason text is sent,
/// the status code is kept server-side to drive the HTTP response.
#[derive(derive_more::Debug, Clone, serde::Serialize)]
pub struct JsonRejection {
    #[serde(skip)]
    status: StatusCode,
    reason: String,
}
impl From<&axum::extract::rejection::JsonRejection> for JsonRejection {
    fn from(value: &axum::extract::rejection::JsonRejection) -> Self {
        use std::error::Error;

        use axum::extract::rejection::JsonRejection::*;
        // Prefer the underlying source error's message (more specific);
        // fall back to axum's generic body text.
        let reason = match value {
            JsonDataError(e) => e.source().map(ToString::to_string),
            JsonSyntaxError(e) => e.source().map(ToString::to_string),
            MissingJsonContentType(e) => e.source().map(ToString::to_string),
            BytesRejection(e) => e.source().map(ToString::to_string),
            // JsonRejection is #[non_exhaustive]; cover future variants.
            _ => None,
        }
        .unwrap_or_else(|| value.body_text());
        Self {
            status: value.status(),
            reason,
        }
    }
}
impl Error {
    /// Maps the internal error to the client-safe variant; anything without
    /// an explicit mapping collapses to `InternalServerError` so no internal
    /// detail leaks.
    pub fn as_client_error(&self) -> ClientError {
        match self {
            Error::Context(_) | Error::Jwt(_) => ClientError::NotAuthorized,
            Error::Database(database::Error::UserAlreadyExists) => ClientError::UserAlreadyExists,
            Error::Database(database::Error::UserDoesNotExists) => ClientError::UserDoesNotExist,
            Error::Database(database::Error::ServerDoesNotExists) => {
                ClientError::ServerDoesNotExist
            },
            Error::Database(database::Error::ChannelDoesNotExists) => {
                ClientError::ChannelDoesNotExist
            },
            Error::Database(database::Error::MessageDoesNotExists) => {
                ClientError::MessageDoesNotExist
            },
            // Error::WrongPassword => ClientError::WrongPassword,
            // Error::NotAllowed => ClientError::NotAllowed,
            Error::JsonRejection(e) => ClientError::InvalidJson(e.into()),
            Error::Client(client_error) => client_error.clone(),
            _ => ClientError::InternalServerError,
        }
    }
}
/// Stores the error in the response extensions rather than rendering it;
/// the response-mapping middleware reads it back and builds the client body.
impl IntoResponse for Error {
    fn into_response(self) -> axum::response::Response {
        let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
        response.extensions_mut().insert(Arc::new(self));
        response
    }
}
impl ClientError {
    /// The HTTP status to pair with this error in the response.
    pub fn status_code(&self) -> StatusCode {
        match self {
            ClientError::UserAlreadyExists => StatusCode::CONFLICT,
            ClientError::UserDoesNotExist
            | ClientError::ServerDoesNotExist
            | ClientError::ChannelDoesNotExist
            | ClientError::MessageDoesNotExist => StatusCode::NOT_FOUND,
            ClientError::NotAuthorized | ClientError::WrongPassword => StatusCode::UNAUTHORIZED,
            ClientError::NotAllowed => StatusCode::FORBIDDEN,
            ClientError::ValidationFailed(_) => StatusCode::UNPROCESSABLE_ENTITY,
            // The JSON rejection keeps the status axum originally chose.
            ClientError::InvalidJson(e) => e.status,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}

View File

@@ -0,0 +1,63 @@
use axum::extract::{Request, State};
use axum::middleware::Next;
use axum::response::{IntoResponse, Response};
use axum::{Extension, RequestExt};
use axum_extra::TypedHeader;
use axum_extra::headers::Authorization;
use axum_extra::headers::authorization::Bearer;
use axum_extra::typed_header::TypedHeaderRejectionReason;
use crate::jwt;
use crate::state::AppState;
use crate::web::{self, context};
/// Gate for protected routes: rejects the request unless `resolve_context`
/// previously stored a *successful* user context in the request extensions.
pub async fn require_context(
    Extension(context): Extension<context::UserContextResult>,
    request: Request,
    next: Next,
) -> web::error::Result<Response> {
    // Propagate any stored authentication failure before running the route.
    let _authenticated = context?;
    Ok(next.run(request).await)
}
/// Resolves the (possibly failing) user context from the request's bearer
/// token and attaches the result to the request extensions so downstream
/// middleware (`require_context`) and handlers can consume it.
pub async fn resolve_context(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> impl IntoResponse {
    let resolved = get_context(&state, &mut request).await;
    request.extensions_mut().insert(resolved);
    next.run(request).await
}
async fn get_context(state: &AppState, request: &mut Request) -> context::UserContextResult {
let bearer = request
.extract_parts::<TypedHeader<Authorization<Bearer>>>()
.await
.map_err(|error| match error.reason() {
TypedHeaderRejectionReason::Missing => context::Error::NotInHeader,
TypedHeaderRejectionReason::Error(_) => context::Error::WrongTokenType,
_ => context::Error::NotInHeader,
})?;
let token = bearer.token();
let context = get_context_from_token(state, token).await?;
Ok(context)
}
/// Validates a JWT and confirms the referenced user still exists; every
/// failure mode is reported uniformly as `BadToken`.
pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult {
    let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;

    // Reject tokens for users that have since been deleted; only existence
    // matters here, the record itself is discarded.
    state
        .database
        .select_user_by_id(user_id)
        .await
        .map_err(|_| context::Error::BadToken)?;

    Ok(context::UserContext { user_id })
}

View File

@@ -0,0 +1,5 @@
mod auth;
mod response_map;
pub use auth::*;
pub use response_map::*;

View File

@@ -0,0 +1,31 @@
use std::sync::Arc;
use axum::Json;
use axum::extract::Request;
use axum::middleware::Next;
use axum::response::IntoResponse;
use serde_json::json;
use crate::web;
pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
let response = next.run(request).await;
let error = response.extensions().get::<Arc<web::Error>>();
if error.is_some() {
tracing::error!("{:?}", error);
}
let error_response = error.map(|error| {
let client_error = error.as_client_error();
(
client_error.status_code(),
Json(json!({"error": client_error})),
)
.into_response()
});
error_response.unwrap_or(response)
}

80
src/web/mod.rs Normal file
View File

@@ -0,0 +1,80 @@
mod context;
mod error;
mod middleware;
mod route;
pub mod ws;
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin};
pub use self::error::{Error, Result};
use crate::{config, state};
/// Binds the HTTP listener on localhost at the configured port and serves the
/// application router until a shutdown signal (Ctrl-C) arrives.
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
    let config = config::config();
    let address = std::net::SocketAddr::from(([127, 0, 0, 1], config.port));
    tracing::info!("listening on {}", address);

    let listener = tokio::net::TcpListener::bind(address).await?;
    axum::serve(listener, router(state))
        .with_graceful_shutdown(shutdown_signal())
        .await?;

    Ok(())
}
// Builds the full application router: public auth/ws routes plus the
// protected subtree, all under /api/v1.
fn router(state: state::AppState) -> axum::Router {
    use axum::Router;
    use axum::routing::*;

    use self::route::*;

    // Fully permissive CORS (any origin/method/header).
    // NOTE(review): fine for development; confirm whether this should be
    // restricted for production deployments.
    let cors = tower_http::cors::CorsLayer::new()
        .allow_origin(AllowOrigin::any())
        .allow_methods(AllowMethods::any())
        .allow_headers(AllowHeaders::any());

    Router::new()
        // websocket
        .nest(
            "/api/v1",
            Router::new()
                .route("/ws", get(ws::gateway::ws_handler))
                .route("/auth/login", post(auth::login))
                .route("/auth/register", post(auth::register))
                .merge(protected_router())
                // Runs for every /api/v1 route: resolves (but does not
                // require) the caller's auth context.
                .route_layer(axum::middleware::from_fn_with_state(
                    state.clone(),
                    middleware::resolve_context,
                )),
        )
        // Layers added later wrap earlier ones, so requests flow
        // trace -> cors -> response_map -> routes; response_map therefore
        // sees (and rewrites) every route's final response.
        .layer(axum::middleware::from_fn(middleware::response_map))
        .layer(cors)
        .layer(tower_http::trace::TraceLayer::new_for_http())
        .with_state(state.clone())
}
/// Routes that require an authenticated user; `require_context` rejects any
/// request whose context was not successfully resolved upstream.
fn protected_router() -> axum::Router<state::AppState> {
    use axum::Router;
    use axum::routing::*;

    use self::route::*;

    Router::new()
        // user
        .route("/users/@me", get(user::me))
        .route("/users/@me/channels", get(user::channel::list))
        .route("/users/{id}", get(user::get_by_id))
        // server
        .route("/servers", get(server::list))
        .route("/servers", post(server::create))
        .route("/servers/{server_id}", get(server::get))
        .route("/servers/{server_id}/channels", get(server::channel::list))
        // voice
        .route("/voice/{channel_id}/connect", post(voice::connect))
        // middleware
        .route_layer(axum::middleware::from_fn(middleware::require_context))
}
/// Resolves once Ctrl-C (SIGINT) is received; errors from installing the
/// signal handler are deliberately ignored.
async fn shutdown_signal() {
    let _ = tokio::signal::ctrl_c().await;
}

View File

@@ -0,0 +1,49 @@
use argon2::{PasswordHash, PasswordVerifier};
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use serde::{Deserialize, Serialize};
use crate::state::AppState;
use crate::web::route::user::FullUser;
use crate::{jwt, web};
// Credentials submitted to `POST /auth/login`.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LoginPayload {
    username: String,
    // Plaintext password; compared against the stored Argon2 hash.
    password: String,
}

// Successful login reply: the caller's full user record plus a fresh JWT.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
pub struct LoginResponse {
    user: FullUser,
    token: String,
}
/// `POST /auth/login` — verifies the supplied credentials and issues a JWT.
///
/// NOTE(review): an unknown username and a wrong password surface as
/// different errors, which permits account enumeration — confirm whether
/// that matters for this deployment.
pub async fn login(
    State(state): State<AppState>,
    Json(payload): Json<LoginPayload>,
) -> web::Result<impl IntoResponse> {
    let user = state
        .database
        .select_user_by_username(&payload.username)
        .await?;

    // Parse the stored hash string, then check the candidate password.
    let parsed_hash = PasswordHash::new(&user.password_hash)?;
    state
        .hasher
        .verify_password(payload.password.as_bytes(), &parsed_hash)
        .map_err(|_| web::error::ClientError::WrongPassword)?;

    let token = jwt::generate_jwt(user.id)?;
    Ok(Json(LoginResponse {
        user: user.into(),
        token,
    }))
}

View File

@@ -0,0 +1,5 @@
pub mod login;
pub mod register;
pub use login::*;
pub use register::*;

View File

@@ -0,0 +1,68 @@
use argon2::password_hash::rand_core::OsRng;
use argon2::password_hash::{PasswordHasher, SaltString};
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use serde::Deserialize;
use validator::Validate;
use crate::state::AppState;
use crate::web;
use crate::web::error::ClientError;
use crate::web::route::user::FullUser;
// Payload for `POST /auth/register`; validated before any database work.
#[derive(Validate, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RegisterPayload {
    // Must satisfy both the project-wide username regex and the length bounds.
    #[validate(regex(
        path = "crate::entity::user::USERNAME_REGEX",
        code = "invalid_username"
    ))]
    #[validate(length(min = 3, max = 64))]
    username: String,
    // Optional human-readable name; no format restriction beyond length.
    #[validate(length(min = 1, max = 64))]
    display_name: Option<String>,
    #[validate(email)]
    #[validate(length(min = 3, max = 128))]
    email: String,
    // Plaintext password; hashed (see `register`) before storage.
    #[validate(length(min = 3, max = 128))]
    password: String,
}
/// `POST /auth/register` — validates the payload, stores the new user with a
/// salted password hash, and opens a DM channel with the "system" user.
pub async fn register(
    State(state): State<AppState>,
    WithRejection(Json(payload), _): WithRejection<Json<RegisterPayload>, web::Error>,
) -> web::Result<impl IntoResponse> {
    payload.validate().map_err(ClientError::ValidationFailed)?;

    // Per-user random salt; only the encoded hash string is persisted.
    let salt = SaltString::generate(&mut OsRng);
    let password_hash = state
        .hasher
        .hash_password(payload.password.as_bytes(), &salt)?
        .to_string();

    let user = state
        .database
        .insert_user(
            &payload.username,
            payload.display_name.as_deref(),
            &payload.email,
            &password_hash,
        )
        .await?;

    // NOTE(review): registration hard-fails if the "system" account is absent;
    // confirm that seed data always creates it, or make this step best-effort.
    let system_user = state.database.select_user_by_username("system").await?;
    if system_user.system {
        state
            .database
            .procedure_create_dm_channel(user.id, system_user.id)
            .await?;
    }

    Ok(Json(FullUser::from(user)))
}

4
src/web/route/mod.rs Normal file
View File

@@ -0,0 +1,4 @@
pub mod auth;
pub mod server;
pub mod user;
pub mod voice;

View File

@@ -0,0 +1,17 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
/// `GET /servers/{server_id}/channels` — lists a server's channels.
///
/// NOTE(review): the authenticated context is extracted (auth is enforced)
/// but membership in the server is never checked — confirm whether a
/// member/permission gate belongs here.
pub async fn list(
    State(state): State<AppState>,
    // Underscore-prefixed: extraction still runs, but the value is unused.
    _context: UserContext,
    Path(id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> {
    let channels = state.database.select_server_channels(id).await?;
    Ok(Json(channels))
}

View File

@@ -0,0 +1,3 @@
mod list;
pub use list::*;

View File

@@ -0,0 +1,83 @@
use axum::Json;
use axum::body::Bytes;
use axum::extract::State;
use axum::response::IntoResponse;
use axum_typed_multipart::{TryFromMultipart, TypedMultipart};
use validator::{Validate, ValidationError};
use crate::state::AppState;
use crate::util::SerdeFieldData;
use crate::web;
use crate::web::context::UserContext;
use crate::web::error::ClientError;
use crate::web::ws;
// Multipart payload for `POST /servers`.
#[derive(Debug, Validate, TryFromMultipart)]
#[try_from_multipart(rename_all = "camelCase")]
pub struct CreatePayload {
    #[validate(length(min = 1, max = 32))]
    name: String,
    // Optional server icon; must declare an image/* content type and is
    // capped at 10 MB by the multipart extractor.
    #[validate(custom(function = "validate_icon_content_type"))]
    #[form_data(limit = "10MB")]
    icon: Option<SerdeFieldData<Bytes>>,
}
/// Accepts the icon only when it declares an `image/*` content type;
/// a missing content type is its own distinct validation error.
fn validate_icon_content_type(icon: &SerdeFieldData<Bytes>) -> Result<(), ValidationError> {
    match icon.metadata.content_type.as_deref() {
        Some(content_type) if content_type.starts_with("image/") => Ok(()),
        Some(_) => Err(ValidationError::new("invalid_icon_content_type")),
        None => Err(ValidationError::new("missing_icon_content_type")),
    }
}
// `POST /servers` — creates a server, its default role, and makes the caller
// the first member, then notifies the caller's gateway connections.
pub async fn create(
    State(state): State<AppState>,
    context: UserContext,
    TypedMultipart(payload): TypedMultipart<CreatePayload>,
) -> web::Result<impl IntoResponse> {
    payload.validate().map_err(ClientError::ValidationFailed)?;
    // NOTE(review): `payload.icon` is validated above but never stored —
    // `insert_server` receives `None` for the icon. Confirm whether icon
    // persistence is still TODO.
    let server = state
        .database
        .insert_server(&payload.name, None, context.user_id)
        .await?;
    // Default "@everyone" role: no color, not hoisted, empty permissions,
    // position 0.
    let role = state
        .database
        .insert_server_role(
            Some(server.id.into()),
            server.id,
            "@everyone",
            None,
            false,
            serde_json::json!({}),
            0,
        )
        .await?;
    // The creator becomes the first member and receives the default role.
    let member = state
        .database
        .insert_server_member(server.id, context.user_id)
        .await?;
    state
        .database
        .insert_server_member_role(member.id, role.id)
        .await?;
    // Push the new server to the creator's open gateway connections.
    ws::gateway::util::send_message(
        &state,
        context.user_id,
        ws::gateway::message::Event::AddServer {
            server: server.clone(),
        },
    )
    .await;
    Ok(Json(server))
}

View File

@@ -0,0 +1,15 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::{entity, web};
/// `GET /servers/{server_id}` — fetches a single server by id.
pub async fn get(
    State(state): State<AppState>,
    Path(server_id): Path<entity::server::Id>,
) -> web::Result<impl IntoResponse> {
    Ok(Json(state.database.select_server_by_id(server_id).await?))
}

View File

@@ -0,0 +1,16 @@
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web;
use crate::web::context::UserContext;
/// `GET /servers` — lists every server the authenticated user belongs to.
pub async fn list(
    State(state): State<AppState>,
    context: UserContext,
) -> web::Result<impl IntoResponse> {
    Ok(Json(
        state.database.select_user_servers(context.user_id).await?,
    ))
}

View File

@@ -0,0 +1,8 @@
pub mod channel;
mod create;
mod get;
mod list;
pub use create::*;
pub use get::*;
pub use list::*;

View File

@@ -0,0 +1,48 @@
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use serde::Serialize;
use crate::entity::channel;
use crate::state::AppState;
use crate::web;
use crate::web::context::UserContext;
use crate::web::route::user::PartialUser;
// A DM channel together with its recipients (excluding the requesting user);
// the channel's own fields are flattened into the JSON object.
#[derive(Debug, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct RecipientChannel {
    #[serde(flatten)]
    pub channel: channel::Channel,
    pub recipients: Vec<PartialUser>,
}
pub async fn list(
State(state): State<AppState>,
context: UserContext,
) -> web::Result<impl IntoResponse> {
let channels = state.database.select_user_channels(context.user_id).await?;
let mut recipient_channels = Vec::new();
for channel in channels {
let recipients = state.database.select_channel_recipients(channel.id).await?;
let recipients = match recipients {
Some(recipients) => recipients
.into_iter()
.filter(|user| user.id != context.user_id)
.map(PartialUser::from)
.collect(),
None => {
continue;
},
};
recipient_channels.push(RecipientChannel {
channel,
recipients,
});
}
Ok(Json(recipient_channels))
}

View File

@@ -0,0 +1,3 @@
mod list;
pub use list::*;

16
src/web/route/user/get.rs Normal file
View File

@@ -0,0 +1,16 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web;
use crate::web::route::user::PartialUser;
/// `GET /users/{id}` — public (partial) profile of any user.
pub async fn get_by_id(
    Path(id): Path<uuid::Uuid>,
    State(state): State<AppState>,
) -> web::Result<impl IntoResponse> {
    Ok(Json(PartialUser::from(
        state.database.select_user_by_id(id).await?,
    )))
}

17
src/web/route/user/me.rs Normal file
View File

@@ -0,0 +1,17 @@
use axum::Json;
use axum::extract::State;
use axum::response::IntoResponse;
use crate::state::AppState;
use crate::web;
use crate::web::context::UserContext;
use crate::web::route::user::FullUser;
/// `GET /users/@me` — the authenticated user's own full record
/// (includes private fields such as email and settings).
pub async fn me(
    context: UserContext,
    State(state): State<AppState>,
) -> web::Result<impl IntoResponse> {
    Ok(Json(FullUser::from(
        state.database.select_user_by_id(context.user_id).await?,
    )))
}

60
src/web/route/user/mod.rs Normal file
View File

@@ -0,0 +1,60 @@
pub mod channel;
mod get;
mod me;
pub use get::*;
pub use me::*;
use crate::entity::user;
// Private view of a user: includes email and settings. Returned only to the
// account owner (`/users/@me`, login/register responses).
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct FullUser {
    pub id: user::Id,
    pub avatar_url: Option<String>,
    pub username: String,
    pub display_name: Option<String>,
    pub email: String,
    pub bot: bool,
    pub system: bool,
    pub settings: serde_json::Value,
}

// Public view of a user: no email or settings. Safe to return for any user
// (`/users/{id}`, channel recipients).
#[derive(serde::Serialize, Debug)]
#[serde(rename_all = "camelCase")]
pub struct PartialUser {
    pub id: user::Id,
    pub avatar_url: Option<String>,
    pub username: String,
    pub display_name: Option<String>,
    pub bot: bool,
    pub system: bool,
}
// Field-by-field projection; drops nothing except the password hash.
impl From<user::User> for FullUser {
    fn from(user: user::User) -> Self {
        Self {
            id: user.id,
            avatar_url: user.avatar_url,
            username: user.username,
            display_name: user.display_name,
            email: user.email,
            bot: user.bot,
            system: user.system,
            settings: user.settings,
        }
    }
}

// Field-by-field projection; additionally drops email and settings.
impl From<user::User> for PartialUser {
    fn from(user: user::User) -> Self {
        Self {
            id: user.id,
            avatar_url: user.avatar_url,
            username: user.username,
            display_name: user.display_name,
            bot: user.bot,
            system: user.system,
        }
    }
}

View File

@@ -0,0 +1,91 @@
use axum::Json;
use axum::extract::{Path, State};
use axum::response::IntoResponse;
use axum_extra::extract::WithRejection;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use crate::state::AppState;
use crate::web::context::UserContext;
use crate::{entity, web};
// Client's WebRTC SDP offer for joining a voice channel.
#[derive(Debug, serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Payload {
    sdp: RTCSessionDescription,
}

// Server's SDP answer completing the offer/answer exchange.
#[derive(Debug, serde::Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Response {
    sdp: RTCSessionDescription,
}
/// `POST /voice/{channel_id}/connect` — joins (lazily creating) the voice
/// room for a channel and performs the WebRTC offer/answer exchange.
pub async fn connect(
    State(state): State<AppState>,
    context: UserContext,
    Path(channel_id): Path<entity::channel::Id>,
    WithRejection(Json(payload), _): WithRejection<Json<Payload>, web::Error>,
) -> web::Result<impl IntoResponse> {
    tracing::debug!("connect to voice channel: {:?}", channel_id);

    // Ensure the channel exists before spinning up any room machinery.
    let channel = state.database.select_channel_by_id(channel_id).await?;
    let channel_id = channel.id;

    // Look up / create the room under a single write lock so that two
    // concurrent callers cannot both create a room for the same channel, and
    // insert the sender *before* spawning the task so the task's cleanup
    // (`remove`) cannot race ahead of the insertion and leave a stale entry.
    let room_sender = {
        let mut rooms = state.voice_rooms.write().await;
        if let Some(existing) = rooms.get(&channel_id) {
            existing.clone()
        } else {
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
            rooms.insert(channel_id, tx.clone());

            let rooms_handle = state.voice_rooms.clone();
            tokio::spawn(async move {
                crate::webrtc::webrtc_task(channel_id, rx)
                    .await
                    .unwrap_or_else(|err| {
                        tracing::error!("webrtc task error: {:?}", err);
                    });
                // Drop the registration once the room task finishes.
                rooms_handle.write().await.remove(&channel_id);
            });

            tx
        }
    };

    let offer = crate::webrtc::Offer {
        peer_id: context.user_id,
        sdp_offer: payload.sdp,
    };

    let (response_tx, response_rx) = tokio::sync::oneshot::channel();
    // If the room task already exited, this send fails; the oneshot below
    // then errors and is surfaced as an internal server error.
    let _ = room_sender.send(crate::webrtc::OfferSignal {
        offer,
        response: response_tx,
    });

    let answer = response_rx
        .await
        .map_err(|_| web::error::ClientError::InternalServerError)?;

    Ok(Json(Response {
        sdp: answer.sdp_answer,
    }))
}

View File

@@ -0,0 +1,3 @@
mod connect;
pub use connect::*;

View File

@@ -0,0 +1,340 @@
// src/web/ws/connection.rs
use std::ops::ControlFlow;
use std::time::Duration;
use axum::extract::ws::{Message as AxumMessage, WebSocket};
use base64::Engine as _; // Bring trait into scope
use futures::stream::SplitStream;
use futures::{SinkExt, StreamExt};
use sha2::{Digest, Sha256};
use tokio::time::Instant;
// Use items from sibling modules within `ws`
use super::error::{self, Error as WsError};
// Assuming Event type is from ws::message
use super::message::Event as WsEvent;
use super::protocol::{
SendWsMessage, WsClientMessage, WsServerMessage, deserialize_ws_message, serialize_ws_message,
};
use super::state::{WsContext, WsState, WsUserContext};
use crate::state::AppState;
/// Main handler for an individual WebSocket connection's lifecycle.
/// Spawned by Axum upon successful WebSocket upgrade.
///
/// Splits the socket: a dedicated writer task drains an MPSC channel into the
/// sink, while `process_websocket_messages` drives the read side. Once the
/// read loop returns, this function performs all cleanup: unregistering the
/// session, best-effort error notification, and joining the writer task.
#[tracing::instrument(skip_all, name = "ws_connection_handler")]
pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) {
    let (ws_sink, ws_stream) = websocket.split();
    let (internal_send_tx, mut internal_send_rx) = tokio::sync::mpsc::unbounded_channel();
    // Writer task: consumes messages from MPSC channel and sends them to the WebSocket sink.
    let writer_task = tokio::spawn(async move {
        let mut ws_sink_mut = ws_sink;
        while let Some(SendWsMessage {
            message,
            response_ch,
        }) = internal_send_rx.recv().await
        {
            let send_result = match serialize_ws_message(message) {
                Ok(ws_msg) => {
                    if ws_sink_mut.send(ws_msg).await.is_err() {
                        Err(WsError::WebSocketClosed) // Send to client failed
                    } else {
                        Ok(())
                    }
                },
                Err(e) => Err(e), // Serialization error itself
            };
            if let Some(ch) = response_ch {
                if ch.send(send_result).is_err() {
                    // Log if the receiver of the acknowledgement was dropped, though this is unlikely
                    // if send_with_response is awaiting it.
                    tracing::debug!("Failed to send acknowledgement; receiver dropped.");
                }
            } else if let Err(e) = send_result {
                // For fire-and-forget, log critical errors (not just WebSocketClosed).
                if !matches!(e, WsError::WebSocketClosed) {
                    tracing::warn!("Error in fire-and-forget WebSocket send: {:?}", e);
                }
            }
        }
        // MPSC channel closed, attempt to gracefully close WebSocket.
        if ws_sink_mut.close().await.is_err() {
            tracing::debug!("Error closing WebSocket sink; connection might be already dead.");
        }
    });
    let mut context = WsContext {
        connection_state: WsState::Initialize,
        user_context: None,
        heartbeat_interval: std::time::Duration::from_secs(30), // Assuming config path
        next_ping_deadline: Instant::now(), // Will be properly set before first use
        event_channel: None,
    };
    let processing_result = process_websocket_messages(
        &mut context,
        ws_stream,
        &internal_send_tx,
        &app_state,
    )
    .await;
    // --- Cleanup ---
    // Remove this session from the global registry first so no further events
    // are routed to a dying connection.
    if let Some(user_ctx_data) = &context.user_context {
        app_state
            .unregister_event_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key)
            .await;
        tracing::info!(user_id = ?user_ctx_data.user_id, session_key = %user_ctx_data.session_key, "Unregistered WebSocket user.");
    }
    // Drop our sender for the event channel; receiver in `process_websocket_messages` will see this.
    drop(context.event_channel.take());
    // If processing loop exited with an error (not a graceful close like WebSocketClosed or HeartbeatTimeout),
    // try to send a final error message to the client.
    if let Err(err_to_report) = &processing_result {
        if !matches!(
            err_to_report,
            WsError::WebSocketClosed | WsError::HeartbeatTimeout
        ) {
            tracing::warn!(
                "WebSocket processing error, attempting to notify client: {:?}",
                err_to_report
            );
            let client_err_code = err_to_report.into_client_error();
            let error_ws_message = WsServerMessage::Error {
                code: client_err_code,
            };
            // Use new_no_response for best-effort send during shutdown.
            // Ignore result as internal_send_tx might already be closed if writer_task ended.
            let _ = internal_send_tx.send(SendWsMessage::new_no_response(error_ws_message));
        }
    }
    // Signal writer task to stop by dropping the MPSC sender.
    drop(internal_send_tx);
    // Wait for the writer task to complete its shutdown.
    if let Err(e) = writer_task.await {
        tracing::error!(
            "WebSocket writer task panicked or encountered an error: {:?}",
            e
        );
    }
    tracing::debug!(result = ?processing_result, "WebSocket connection handler finished.");
}
/// Main loop for processing incoming WebSocket messages and outgoing application events.
/// Manages state transitions (Initialize -> Connected) and heartbeating.
///
/// Returns `Ok(())` only on graceful shutdown (event channel closed); every
/// other exit path is an error the caller may report to the client.
#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))]
async fn process_websocket_messages(
    context: &mut WsContext,
    mut ws_stream: SplitStream<WebSocket>,
    sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>,
    app_state: &AppState,
) -> error::Result<()> {
    // Send initial heartbeat interval and set first deadline.
    SendWsMessage::send_with_response(
        sender,
        WsServerMessage::HeartbeatInterval {
            interval: context.heartbeat_interval,
        },
    )
    .await?;
    context.reset_deadline();
    loop {
        match context.connection_state {
            // Not yet authenticated: only Authenticate/Ping are accepted, and
            // the same deadline that drives heartbeats bounds the auth window.
            WsState::Initialize => {
                tokio::select! {
                    biased; // Prefer timeout check if multiple branches are ready
                    _ = tokio::time::sleep_until(context.next_ping_deadline) => {
                        tracing::warn!("Initial connection timeout (no Authenticate or Ping).");
                        return Err(WsError::HeartbeatTimeout);
                    }
                    maybe_message = ws_stream.next() => {
                        match maybe_message {
                            Some(Ok(message)) => {
                                match handle_initial_message(context, message, sender, app_state).await {
                                    Ok(ControlFlow::Continue(())) => {},
                                    Ok(ControlFlow::Break(new_state)) => { // Authenticated
                                        context.connection_state = new_state;
                                        tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected.");
                                    },
                                    Err(e) => { // Auth failed critically or other error
                                        return Err(e);
                                    }
                                }
                            }
                            Some(Err(axum_ws_err)) => {
                                tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err);
                                return Err(WsError::WebSocketClosed);
                            }
                            None => { // Stream closed by client
                                tracing::debug!("WebSocket stream ended by client during Initialize state.");
                                return Err(WsError::WebSocketClosed);
                            }
                        }
                    }
                }
            },
            // Authenticated: forward application events out, accept pings in,
            // and enforce the heartbeat deadline.
            WsState::Connected => {
                let user_ctx = context
                    .user_context
                    .as_ref()
                    .expect("User context must be set in Connected state");
                let (_event_tx_ref, event_rx) = context
                    .event_channel
                    .as_mut()
                    .expect("Event channel must be set in Connected state");
                tokio::select! {
                    biased;
                    _ = tokio::time::sleep_until(context.next_ping_deadline) => {
                        tracing::warn!(user_id = ?user_ctx.user_id, "Heartbeat timeout.");
                        return Err(WsError::HeartbeatTimeout);
                    }
                    // Listen for application events to send to the client
                    maybe_app_event = event_rx.recv() => {
                        if let Some(app_event_data) = maybe_app_event {
                            SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?;
                            // Sending an app event doesn't reset the client's ping requirement.
                        } else {
                            // Event channel closed (e.g., AppState unregistered, or system shutdown signal)
                            tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket.");
                            return Ok(()); // Graceful shutdown signaled by closed event channel
                        }
                    }
                    // Listen for messages from the client (e.g., Ping)
                    maybe_ws_message = ws_stream.next() => {
                        match maybe_ws_message {
                            Some(Ok(message)) => {
                                handle_connected_message(context, message, sender).await?;
                            }
                            Some(Err(axum_ws_err)) => {
                                tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err);
                                return Err(WsError::WebSocketClosed);
                            }
                            None => { // Stream closed by client
                                tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state.");
                                return Err(WsError::WebSocketClosed);
                            }
                        }
                    }
                }
            },
        }
    }
}
/// Handles messages received when the connection is in the `Initialize` state.
/// Expects `Authenticate` to transition to `Connected`, or `Ping` to stay in `Initialize`.
///
/// Returns `ControlFlow::Break(new_state)` when the state machine should
/// advance, `ControlFlow::Continue(())` to remain in `Initialize`.
#[tracing::instrument(skip_all, fields(state = ?context.connection_state))]
async fn handle_initial_message(
    context: &mut WsContext,
    message: AxumMessage,
    sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>,
    app_state: &AppState,
) -> error::Result<ControlFlow<WsState, ()>> {
    match deserialize_ws_message(message)? {
        WsClientMessage::Authenticate { token } => {
            match crate::web::middleware::get_context_from_token(app_state, &token).await {
                Ok(auth_user_context) => {
                    let user_id = auth_user_context.user_id;
                    let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<WsEvent>();
                    context.event_channel = Some((event_tx.clone(), event_rx));

                    // Session key = base64url(SHA-256(token || user_id || random)).
                    // The random component keeps keys unique per connection even
                    // when the same token authenticates twice.
                    let random_key_part = rand::random::<u64>();
                    let current_session_key = {
                        let mut hasher = Sha256::new();
                        hasher.update(token.as_bytes());
                        hasher.update(user_id.to_string().as_bytes());
                        hasher.update(&random_key_part.to_be_bytes());
                        base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize())
                    };

                    context.user_context = Some(WsUserContext {
                        user_id,
                        session_key: current_session_key.clone(),
                    });

                    app_state
                        .register_event_connected_user(
                            user_id,
                            current_session_key.clone(),
                            event_tx,
                        )
                        .await;

                    SendWsMessage::send_with_response(
                        sender,
                        WsServerMessage::AuthenticateAccepted {
                            user_id,
                            session_key: current_session_key.clone(),
                        },
                    )
                    .await?;

                    // Deadline is reset by the caller upon ControlFlow::Break.
                    Ok(ControlFlow::Break(WsState::Connected))
                },
                Err(_auth_err) => {
                    // SECURITY: do not log the raw token — it is a bearer credential.
                    tracing::warn!("Authentication failed for token.");
                    // Best-effort: let the client learn of the denial before the
                    // connection is torn down by the returned error.
                    let _ = SendWsMessage::send_with_response(
                        sender,
                        WsServerMessage::AuthenticateDenied,
                    )
                    .await;
                    Err(WsError::AuthenticationFailed)
                },
            }
        },
        WsClientMessage::Ping => {
            // A ping before authentication just keeps the auth window open.
            context.reset_deadline();
            SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await?;
            Ok(ControlFlow::Continue(()))
        },
        // Only Authenticate and Ping exist today; kept for forward compatibility.
        #[allow(unreachable_patterns)]
        _ => {
            tracing::warn!("Unexpected message type received during Initialize state.");
            Err(WsError::UnexpectedMessageType)
        },
    }
}
/// Handles client messages after authentication; only `Ping` is legal in the
/// `Connected` state — it refreshes the heartbeat deadline and is answered
/// with `Pong`.
#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))]
async fn handle_connected_message(
    context: &mut WsContext,
    message: AxumMessage,
    sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>,
) -> error::Result<()> {
    let parsed = deserialize_ws_message(message)?;

    if matches!(parsed, WsClientMessage::Ping) {
        tracing::debug!("Ping received.");
        context.reset_deadline();
        return SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await;
    }

    tracing::warn!(message_type = ?parsed, "Unexpected message type received during Connected state.");
    Err(WsError::UnexpectedMessageType)
}

View File

@@ -0,0 +1,46 @@
pub type Result<T> = std::result::Result<T, Error>;

// Internal errors of the WebSocket gateway.
#[derive(Debug, derive_more::From, derive_more::Display)]
pub enum Error {
    #[from]
    Axum(axum::Error),
    // (De)serialization of a protocol message failed.
    #[from]
    Json(serde_json::Error),
    // The writer task dropped the acknowledgement channel before replying.
    #[from]
    AcknowledgementError(tokio::sync::oneshot::error::RecvError),
    // A known message arrived in a state where it is not allowed.
    UnexpectedMessageType,
    // A non-text frame (binary/ping/pong) was received.
    WrongMessageType,
    WebSocketClosed,
    HeartbeatTimeout,
    AuthenticationFailed,
}

// Coarse error codes exposed to WebSocket clients in `Error` frames.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum ClientError {
    DeserializationError,
    NotAuthenticated,
    AlreadyAuthenticated,
    HeartbeatTimeout,
    Unknown,
}
impl Error {
    /// Maps an internal WebSocket error onto the coarse code reported to the
    /// client; everything without a dedicated code becomes `Unknown`.
    pub fn into_client_error(&self) -> ClientError {
        match self {
            Error::HeartbeatTimeout => ClientError::HeartbeatTimeout,
            Error::Json(_) => ClientError::DeserializationError,
            // UnexpectedMessageType, WrongMessageType, WebSocketClosed and all
            // transport-level errors share the generic code.
            _ => ClientError::Unknown,
        }
    }
}

View File

@@ -0,0 +1,15 @@
use crate::entity;
// Server-pushed gateway events, serialized as `{"type": ..., "data": ...}`
// with SCREAMING_SNAKE_CASE type tags.
#[derive(Debug, Clone, serde::Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum Event {
    AddServer { server: entity::server::Server },
    RemoveServer { server_id: entity::server::Id },
    AddDmChannel { channel: entity::channel::Channel },
    RemoveDmChannel { channel_id: entity::channel::Id },
    AddServerChannel { channel: entity::channel::Channel },
    RemoveServerChannel { channel_id: entity::channel::Id },
}

29
src/web/ws/gateway/mod.rs Normal file
View File

@@ -0,0 +1,29 @@
use axum::extract::{State, WebSocketUpgrade};
use axum::response::IntoResponse;
use base64::Engine;
use dashmap::DashMap;
use futures::{SinkExt, StreamExt};
use sha2::Digest;
use crate::state::AppState;
use crate::web::ws::gateway::connection::handle_socket_connection;
use crate::web::ws::gateway::state::EventSender;
mod connection;
mod error;
pub mod message;
mod protocol;
mod state;
pub mod util;
// Global registry of live gateway connections.
#[derive(Debug, Default)]
pub struct EventWsState {
    // session_key -> sender half of that connection's event channel.
    pub connection_instance: DashMap<String, EventSender>,
}
/// Upgrades `GET /api/v1/ws` to a WebSocket and hands the socket off to the
/// per-connection handler task.
pub async fn ws_handler(
    State(app_state): State<AppState>,
    ws: WebSocketUpgrade,
) -> crate::web::error::Result<impl IntoResponse> {
    let response = ws.on_upgrade(move |socket| handle_socket_connection(socket, app_state));
    Ok(response)
}

View File

@@ -0,0 +1,99 @@
// src/web/ws/protocol.rs
use axum::extract::ws::Message as AxumMessage;
use serde::{Deserialize, Serialize};
use std::time::Duration;
use super::error::{self, ClientError, Error as WsError};
use super::message as ws_local_message; // For ws::message::Event
use crate::entity;
use crate::util as crate_root_util; // For crate::util::serialize_duration_seconds
// Messages the server sends to the client, tagged `{"type": ..., "data": ...}`.
#[derive(Debug, Serialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsServerMessage {
    // First message after connect: how often the client must ping.
    HeartbeatInterval {
        #[serde(serialize_with = "crate_root_util::serialize_duration_seconds")]
        interval: Duration,
    },
    AuthenticateDenied,
    #[serde(rename_all = "camelCase")]
    AuthenticateAccepted {
        user_id: entity::user::Id,
        // Unique per-connection session identifier.
        session_key: String,
    },
    // Application event pushed to the client.
    #[serde(rename_all = "camelCase")]
    Event {
        event: ws_local_message::Event, // Assumes Event is defined in ws::message
    },
    Pong,
    // Final error report sent (best-effort) before the connection closes.
    Error {
        code: ClientError,
    },
}
// Messages the client may send, tagged `{"type": ..., "data": ...}`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", content = "data")]
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
pub enum WsClientMessage {
    // Must be the first meaningful message; carries the JWT.
    #[serde(rename_all = "camelCase")]
    Authenticate {
        token: String,
    },
    // Heartbeat; expected at least once per advertised interval.
    Ping,
}
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
///
/// Only text frames carry protocol messages; a close frame maps to
/// `WebSocketClosed`, and every other frame kind is rejected.
pub fn deserialize_ws_message(message: AxumMessage) -> error::Result<WsClientMessage> {
    match message {
        AxumMessage::Text(text) => Ok(serde_json::from_str(&text)?),
        AxumMessage::Close(_) => Err(WsError::WebSocketClosed),
        // Binary / Ping / Pong frames are not part of this protocol.
        _ => Err(WsError::WrongMessageType),
    }
}
/// Serializes a `WsServerMessage` into an Axum text frame.
pub fn serialize_ws_message(message: WsServerMessage) -> error::Result<AxumMessage> {
    let json = serde_json::to_string(&message)?;
    Ok(AxumMessage::Text(json.into()))
}
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer.
pub struct SendWsMessage {
    pub message: WsServerMessage,
    // `Some` => writer reports the send outcome back; `None` => fire-and-forget.
    pub response_ch: Option<tokio::sync::oneshot::Sender<error::Result<()>>>,
}
impl SendWsMessage {
    /// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
    ///
    /// Errors when the writer task is gone (`WebSocketClosed`), when the
    /// writer reports a serialization/send failure, or when the
    /// acknowledgement channel is dropped.
    pub async fn send_with_response(
        tx: &tokio::sync::mpsc::UnboundedSender<Self>,
        message: WsServerMessage,
    ) -> error::Result<()> {
        let (response_tx, response_rx) = tokio::sync::oneshot::channel();
        let send_message = SendWsMessage {
            message,
            response_ch: Some(response_tx),
        };
        if tx.send(send_message).is_err() {
            Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead
        } else {
            // Wait for the writer task to acknowledge the send attempt.
            // This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
            response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
        }
    }
    /// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
    pub fn new_no_response(message: WsServerMessage) -> Self {
        SendWsMessage {
            message,
            response_ch: None,
        }
    }
}

View File

@@ -0,0 +1,49 @@
// src/web/ws/state.rs
use std::time::{Duration};
use tokio::sync::mpsc;
use super::message;
use crate::entity; // For entity::user::Id // For ws::message::Event used in EventSender/Receiver
/// Represents the current state of a single WebSocket connection.
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
pub enum WsState {
    /// Connection established, awaiting authentication.
    Initialize,
    /// Authenticated and operational.
    Connected,
}
/// Contextual information for an authenticated WebSocket user session.
#[derive(Debug, Clone)] // Clone might be useful
pub struct WsUserContext {
    /// The authenticated user this connection belongs to.
    pub user_id: entity::user::Id,
    /// Unique key for this specific WebSocket session instance
    /// (distinguishes multiple concurrent connections of the same user).
    pub session_key: String,
}
/// Sender part of an MPSC channel used to send `ws::message::Event`s to a connected client.
pub type EventSender = mpsc::UnboundedSender<message::Event>;
/// Receiver part of an MPSC channel used by a connection task to receive `ws::message::Event`s.
pub type EventReceiver = mpsc::UnboundedReceiver<message::Event>;
/// Holds the full context for a single WebSocket connection's lifecycle.
/// This struct is managed per-connection.
pub struct WsContext {
    /// `Initialize` until authentication succeeds, then `Connected`.
    pub connection_state: WsState,
    /// Populated only after successful authentication.
    pub user_context: Option<WsUserContext>,
    /// Expected spacing of client heartbeats; used to compute the deadline.
    pub heartbeat_interval: Duration,
    /// Absolute time by which the next client ping must arrive
    /// (advanced via `reset_deadline`).
    pub next_ping_deadline: tokio::time::Instant,
    /// Channel for receiving application-specific events to be sent to this client.
    /// The `EventSender` (tx) part is given to `AppState` for broadcasting.
    /// The `EventReceiver` (rx) part is polled by the connection task.
    pub event_channel: Option<(EventSender, EventReceiver)>,
}
impl WsContext {
    /// Pushes the ping deadline one heartbeat interval forward from now.
    ///
    /// Call this whenever the client shows signs of life (a received ping,
    /// a successful authentication, a sent Pong) so the connection is not
    /// reaped as idle.
    pub fn reset_deadline(&mut self) {
        let now = tokio::time::Instant::now();
        self.next_ping_deadline = now + self.heartbeat_interval;
    }
}

View File

@@ -0,0 +1,14 @@
use crate::entity;
use crate::state::AppState;
use crate::web::ws::gateway::message;
/// Delivers `message` to every live WebSocket connection of `user_id`.
///
/// Unknown users (no entry in the connected-users map) are silently
/// ignored; per-connection send failures are logged and do not stop
/// delivery to the remaining connections.
pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: message::Event) {
    let connected = state.event_connected_users.read().await;
    let Some(user_state) = connected.get(&user_id) else {
        return;
    };
    for sender in user_state.connection_instance.iter() {
        if let Err(e) = sender.send(message.clone()) {
            tracing::error!("failed to send message: {}", e);
        }
    }
}

1
src/web/ws/mod.rs Normal file
View File

@@ -0,0 +1 @@
pub mod gateway;

288
src/webrtc/mod.rs Normal file
View File

@@ -0,0 +1,288 @@
use std::sync::Arc;
use dashmap::DashMap;
use tracing::Instrument;
use webrtc::api::interceptor_registry::register_default_interceptors;
use webrtc::api::media_engine::MIME_TYPE_OPUS;
use webrtc::api::{API, APIBuilder};
use webrtc::interceptor::registry::Registry;
use webrtc::peer_connection::configuration::RTCConfiguration;
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
use webrtc::track::track_local::{TrackLocal, TrackLocalWriter};
use webrtc::track::track_remote::TrackRemote;
use crate::entity;
// Domain ids double as signaling ids: a "peer" is a user, a "room" is a channel.
type PeerId = entity::user::Id;
type RoomId = entity::channel::Id;

/// Per-peer bookkeeping held by a room.
struct PeerState {
    peer_id: PeerId,
    /// The live connection; kept so a superseding offer or `cleanup_peer`
    /// can close it.
    peer_connection: Arc<webrtc::peer_connection::RTCPeerConnection>,
    /// Track on which audio from the *other* peers is written to this peer.
    outgoing_audio_track: Arc<TrackLocalStaticRTP>,
}

/// Shared state of one audio room, created by `webrtc_task`.
struct RoomState {
    room_id: RoomId,
    /// Connected peers keyed by peer id; a concurrent map because peers are
    /// inserted/removed from spawned per-peer tasks.
    peers: DashMap<PeerId, PeerState>,
    /// Fired by `cleanup_peer` when the room empties; tells the room task
    /// to shut down.
    close_signal: tokio::sync::mpsc::UnboundedSender<()>,
}

/// An SDP offer from a specific peer wanting to join a room.
#[derive(Debug)]
pub struct Offer {
    pub peer_id: PeerId,
    pub sdp_offer: RTCSessionDescription,
}

/// The SDP answer produced in response to an `Offer`.
#[derive(Debug)]
pub struct AnswerSignal {
    pub sdp_answer: RTCSessionDescription,
}

/// An offer bundled with the oneshot channel on which the answer is
/// returned to the signaling side.
#[derive(Debug)]
pub struct OfferSignal {
    pub offer: Offer,
    pub response: tokio::sync::oneshot::Sender<AnswerSignal>,
}
/// Per-room WebRTC worker: accepts SDP offers for `room_id` from `signal`,
/// spawns a `handle_peer` negotiation for each, and runs until the room's
/// close signal fires (sent by `cleanup_peer` when the last peer leaves).
///
/// # Errors
/// Returns an error only if the WebRTC API object cannot be built
/// (codec/interceptor registration failure).
#[tracing::instrument(skip(signal))]
pub async fn webrtc_task(
    room_id: RoomId,
    mut signal: tokio::sync::mpsc::UnboundedReceiver<OfferSignal>,
) -> anyhow::Result<()> {
    tracing::info!("Starting WebRTC task");
    let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel();
    let state = Arc::new(RoomState {
        room_id,
        peers: DashMap::new(),
        close_signal,
    });

    // One shared API object for the whole room: default codecs plus the
    // default interceptor chain (NACK, RTCP reports, TWCC).
    let mut media_engine = webrtc::api::media_engine::MediaEngine::default();
    media_engine.register_default_codecs()?;
    let mut registry = Registry::new();
    registry = register_default_interceptors(registry, &mut media_engine)?;
    let api = Arc::new(
        APIBuilder::new()
            .with_media_engine(media_engine)
            .with_interceptor_registry(registry)
            .build(),
    );

    loop {
        tokio::select! {
            // NOTE(review): the `Some(..)` pattern disables this branch once
            // the sender side is dropped, after which the task waits solely
            // on the close signal. If the offer channel closes before any
            // peer ever joined, no close signal is ever sent and this task
            // lingers forever — confirm the caller keeps the sender alive
            // until the room is torn down.
            Some(signal) = signal.recv() => {
                let room_state = state.clone();
                let api = api.clone();
                // Each offer negotiates concurrently; a failed negotiation
                // only affects that peer.
                tokio::spawn(async move {
                    if let Err(e) = handle_peer(api, room_state, signal).await {
                        tracing::error!("error handling peer: {}", e);
                    }
                }.instrument(tracing::Span::current()));
            }
            _ = close_receiver.recv() => {
                tracing::debug!("WebRTC task stopped");
                break;
            }
        }
    }
    Ok(())
}
/// Negotiates one peer connection for an incoming offer: creates the peer's
/// outgoing audio track, wires state/track callbacks, answers the SDP offer
/// (after full ICE gathering), and registers the peer in the room —
/// replacing and closing any prior connection for the same peer id.
#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id))]
async fn handle_peer(
    api: Arc<API>,
    room_state: Arc<RoomState>,
    offer_signal: OfferSignal,
) -> anyhow::Result<()> {
    tracing::debug!("handling peer");
    // Default configuration: no ICE servers are set here — presumably the
    // deployment relies on host candidates only; TODO confirm.
    let config = RTCConfiguration {
        ..Default::default()
    };
    // Opus track carrying the *other* peers' audio to this peer
    // (written by `forward_audio_track` on their behalf).
    let outgoing_track = Arc::new(TrackLocalStaticRTP::new(
        RTCRtpCodecCapability {
            mime_type: MIME_TYPE_OPUS.to_string(),
            ..Default::default()
        },
        format!("audio-{}", offer_signal.offer.peer_id),
        format!("room-{}", room_state.room_id),
    ));
    let peer_connection = Arc::new(api.new_peer_connection(config).await?);
    let rtp_sender = peer_connection
        .add_track(Arc::clone(&outgoing_track) as Arc<dyn TrackLocal + Send + Sync>)
        .await?;
    tracing::debug!("added track to peer connection: {:?}", rtp_sender);
    // Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI)
    let outgoing_track_ = Arc::clone(&outgoing_track);
    tokio::spawn(
        async move {
            let mut rtcp_buf = vec![0u8; 1500];
            while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {
                // Process RTCP if needed (e.g., bandwidth estimation, custom feedback)
            }
            tracing::debug!(
                "RTCP reader loop for outgoing track {} ended",
                outgoing_track_.id()
            );
        }
        .instrument(tracing::Span::current()),
    );
    // Drop the peer from the room once its connection dies.
    let room_state_ = room_state.clone();
    peer_connection.on_peer_connection_state_change(Box::new(move |state| {
        let room_state_ = room_state_.clone();
        Box::pin(async move {
            tracing::debug!("peer connection state changed: {:?}", state);
            if state == RTCPeerConnectionState::Disconnected
                || state == RTCPeerConnectionState::Failed
            {
                tracing::debug!("peer connection closed");
                cleanup_peer(room_state_, offer_signal.offer.peer_id).await;
            }
        })
    }));
    // Forward this peer's incoming audio to the rest of the room; anything
    // that is not audio is logged and dropped.
    let room_state_ = room_state.clone();
    peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| {
        tracing::debug!("track received: {:?}", track);
        let room_state_ = room_state_.clone();
        Box::pin(async move {
            if track.kind() == RTPCodecType::Audio {
                tracing::debug!("audio track received: {:?}", track);
                tokio::spawn(async move {
                    if let Err(e) =
                        forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await
                    {
                        tracing::error!("error forwarding audio track: {}", e);
                    }
                });
            } else {
                tracing::warn!("received non-audio track: {:?}", track);
            }
        })
    }));
    // Standard answerer flow. ICE gathering is awaited to completion BEFORE
    // reading the local description, so the answer already contains all
    // candidates (non-trickle ICE).
    peer_connection
        .set_remote_description(offer_signal.offer.sdp_offer)
        .await?;
    let answer = peer_connection.create_answer(None).await?;
    let mut gathering_complete = peer_connection.gathering_complete_promise().await;
    peer_connection.set_local_description(answer).await?;
    gathering_complete.recv().await;
    tracing::debug!("ICE gathering complete");
    let local_description = peer_connection
        .local_description()
        .await
        .ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?;
    let peer_state = PeerState {
        peer_id: offer_signal.offer.peer_id,
        peer_connection: Arc::clone(&peer_connection),
        outgoing_audio_track: Arc::clone(&outgoing_track),
    };
    {
        // A re-offer from the same peer id replaces the map entry; the
        // superseded connection is closed (errors deliberately ignored).
        if let Some(old_peer) = room_state
            .peers
            .insert(offer_signal.offer.peer_id, peer_state)
        {
            let _ = old_peer.peer_connection.close().await;
        }
    }
    // Hand the gathered answer back to the signaling side; a dropped
    // receiver is ignored on purpose.
    let _ = offer_signal.response.send(AnswerSignal {
        sdp_answer: local_description,
    });
    Ok(())
}
/// Fans out RTP audio from `peer_id`'s remote track to every other peer's
/// outgoing track in the room. Runs until the remote track stops producing
/// packets (read error / track ended), then returns `Ok`.
#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id))]
async fn forward_audio_track(
    room_state: Arc<RoomState>,
    peer_id: PeerId,
    track: Arc<TrackRemote>,
) -> anyhow::Result<()> {
    let mut read_buf = vec![0u8; 1500];
    loop {
        let Ok((rtp_packet, _attr)) = track.read(&mut read_buf).await else {
            break;
        };

        // Snapshot the destination tracks (everyone but the sender) so the
        // DashMap guards are not held across the awaited writes.
        let destinations: Vec<_> = room_state
            .peers
            .iter()
            .filter(|entry| entry.value().peer_id != peer_id)
            .map(|entry| entry.value().outgoing_audio_track.clone())
            .collect();
        if destinations.is_empty() {
            continue;
        }

        // Write to all destinations concurrently; failures are logged per
        // track and do not abort the forwarding loop.
        let writes = destinations
            .iter()
            .map(|dst_track| dst_track.write_rtp(&rtp_packet));
        for outcome in futures::future::join_all(writes).await {
            if let Err(e) = outcome {
                tracing::error!("error writing RTP packet: {}", e);
            }
        }
    }
    tracing::debug!(
        "RTP Read loop ended for track {} from peer {}",
        track.id(),
        peer_id
    );
    Ok(())
}
/// Removes `peer_id` from the room, closes its connection in the
/// background, and — if the room is now empty — fires the room's close
/// signal so `webrtc_task` shuts down.
#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))]
async fn cleanup_peer(room_state: Arc<RoomState>, peer_id: PeerId) {
    tracing::debug!("cleaning up peer");
    if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) {
        tracing::debug!("removed peer");
        let pc = Arc::clone(&peer_state.peer_connection);
        // Close in a detached task so a slow close doesn't block cleanup.
        tokio::spawn(async move {
            if let Err(e) = pc.close().await {
                // Already-closed is the expected benign case; only warn on
                // anything else.
                if !matches!(e, webrtc::Error::ErrConnectionClosed) {
                    tracing::warn!("error closing peer connection: {}", e);
                }
            }
            tracing::debug!("peer connection closed");
        });
    }
    // NOTE(review): this emptiness check runs even when the peer was not
    // found, and a new peer could be inserted between the `remove` above
    // and this check (TOCTOU) — the room would then be closed while a peer
    // is still connected. Confirm whether signaling guarantees no join can
    // race a final leave.
    if room_state.peers.is_empty() {
        tracing::debug!("no more peers in room, closing room");
        let _ = room_state.close_signal.send(());
    }
}