.
This commit is contained in:
3
.cargo/config.toml
Normal file
3
.cargo/config.toml
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
[target.x86_64-unknown-linux-gnu]
|
||||||
|
linker = "/usr/bin/clang"
|
||||||
|
rustflags = ["-C", "link-arg=--ld-path=/usr/bin/mold"]
|
||||||
7
.gitignore
vendored
Normal file
7
.gitignore
vendored
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
/target
|
||||||
|
/.idea
|
||||||
|
/config.toml
|
||||||
|
/db
|
||||||
|
/logs
|
||||||
|
|
||||||
|
.env
|
||||||
4432
Cargo.lock
generated
Normal file
4432
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load Diff
37
Cargo.toml
Normal file
37
Cargo.toml
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
[package]
|
||||||
|
name = "diplom"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2024"
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
anyhow = "1.0"
|
||||||
|
argon2 = "0.5"
|
||||||
|
axum = { version = "0.8", features = ["ws", "multipart", "macros"] }
|
||||||
|
axum-extra = { version = "0.10", features = ["typed-header"] }
|
||||||
|
chrono = { version = "0.4", features = ["serde"] }
|
||||||
|
config = "0.15"
|
||||||
|
derive_more = { version = "2.0", features = ["full"] }
|
||||||
|
jsonwebtoken = "9.3"
|
||||||
|
rand_core = { version = "0.6", features = ["getrandom"] }
|
||||||
|
serde = { version = "1.0", features = ["derive"] }
|
||||||
|
serde_json = "1.0"
|
||||||
|
sqlx = { version = "0.8", features = ["postgres", "runtime-tokio", "uuid", "chrono"] }
|
||||||
|
tokio = { version = "1.45", features = ["full"] }
|
||||||
|
tower-http = { version = "0.6", features = ["cors", "trace"] }
|
||||||
|
tracing = "0.1"
|
||||||
|
tracing-appender = "0.2"
|
||||||
|
tracing-subscriber = { version = "0.3", features = ["serde", "env-filter", "chrono"] }
|
||||||
|
uuid = { version = "1.16", features = ["fast-rng", "serde", "v7"] }
|
||||||
|
url = { version = "2.5", features = ["serde"] }
|
||||||
|
validator = { version = "0.20.0", features = ["derive"] }
|
||||||
|
regex = "1.11.1"
|
||||||
|
mime = "0.3.17"
|
||||||
|
axum_typed_multipart = "0.16.2"
|
||||||
|
async-trait = "0.1.88"
|
||||||
|
futures = "0.3.31"
|
||||||
|
webrtc = "0.12.0"
|
||||||
|
dashmap = "6.1.0"
|
||||||
|
rand = "0.9.1"
|
||||||
|
sha2 = "0.10.9"
|
||||||
|
hex = "0.4.3"
|
||||||
|
base64 = "0.22.1"
|
||||||
10
docker-compose.yml
Normal file
10
docker-compose.yml
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
services:
|
||||||
|
database:
|
||||||
|
image: 'ghcr.io/craigpastro/pg_uuidv7:main'
|
||||||
|
ports:
|
||||||
|
- "15432:5432"
|
||||||
|
env_file:
|
||||||
|
- .env
|
||||||
|
volumes:
|
||||||
|
- ${PWD}/db/:/var/lib/postgresql/data/
|
||||||
|
user: "1000:1000"
|
||||||
1
migrations/20250510183101_uuidv7.down.sql
Normal file
1
migrations/20250510183101_uuidv7.down.sql
Normal file
@@ -0,0 +1 @@
|
|||||||
|
DROP EXTENSION pg_uuidv7;
|
||||||
1
migrations/20250510183101_uuidv7.up.sql
Normal file
1
migrations/20250510183101_uuidv7.up.sql
Normal file
@@ -0,0 +1 @@
|
|||||||
|
CREATE EXTENSION IF NOT EXISTS pg_uuidv7;
|
||||||
4
migrations/20250510183102_user.down.sql
Normal file
4
migrations/20250510183102_user.down.sql
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
DROP TRIGGER trg_user_relation_update ON "user_relation";
|
||||||
|
DROP FUNCTION fn_on_user_relation_update();
|
||||||
|
DROP TABLE "user_relation";
|
||||||
|
DROP TABLE "user";
|
||||||
46
migrations/20250510183102_user.up.sql
Normal file
46
migrations/20250510183102_user.up.sql
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "user"
|
||||||
|
(
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
|
||||||
|
"avatar_url" VARCHAR,
|
||||||
|
"username" VARCHAR NOT NULL UNIQUE,
|
||||||
|
"display_name" VARCHAR,
|
||||||
|
"email" VARCHAR NOT NULL,
|
||||||
|
"password_hash" VARCHAR NOT NULL,
|
||||||
|
"bot" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
"system" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
"settings" JSONB NOT NULL DEFAULT '{}'::JSONB
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "user_relation"
|
||||||
|
(
|
||||||
|
"user_id" UUID NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE,
|
||||||
|
"other_id" UUID NOT NULL CHECK (user_id <> other_id) REFERENCES "user" ("id") ON DELETE CASCADE,
|
||||||
|
"type" INT2 NOT NULL,
|
||||||
|
"created_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
"updated_at" TIMESTAMPTZ NOT NULL DEFAULT now(),
|
||||||
|
PRIMARY KEY ("user_id", "other_id")
|
||||||
|
);
|
||||||
|
|
||||||
|
-- create system account
|
||||||
|
INSERT INTO "user" ("username", "display_name", "email", "password_hash", "bot", "system")
|
||||||
|
VALUES ('system', 'System', 'system@lionarius.ru', '', TRUE, TRUE);
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION fn_on_user_relation_update()
|
||||||
|
RETURNS TRIGGER
|
||||||
|
LANGUAGE plpgsql
|
||||||
|
AS
|
||||||
|
$$
|
||||||
|
BEGIN
|
||||||
|
IF (TG_OP = 'UPDATE') THEN
|
||||||
|
NEW.updated_at := now();
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
|
|
||||||
|
CREATE TRIGGER trg_user_relation_update
|
||||||
|
BEFORE UPDATE
|
||||||
|
ON "user_relation"
|
||||||
|
FOR EACH ROW
|
||||||
|
EXECUTE FUNCTION fn_on_user_relation_update();
|
||||||
1
migrations/20250510183125_server.down.sql
Normal file
1
migrations/20250510183125_server.down.sql
Normal file
@@ -0,0 +1 @@
|
|||||||
|
DROP TABLE "server";
|
||||||
77
migrations/20250510183125_server.up.sql
Normal file
77
migrations/20250510183125_server.up.sql
Normal file
@@ -0,0 +1,77 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "server"
|
||||||
|
(
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
|
||||||
|
"owner_id" UUID NOT NULL REFERENCES "user" ("id"),
|
||||||
|
"name" VARCHAR NOT NULL,
|
||||||
|
"icon_url" VARCHAR
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "server_role"
|
||||||
|
(
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
|
||||||
|
"server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
|
||||||
|
"name" VARCHAR NOT NULL,
|
||||||
|
"color" VARCHAR,
|
||||||
|
"display" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||||
|
"permissions" JSONB NOT NULL DEFAULT '{}'::JSONB,
|
||||||
|
"position" SMALLINT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "server_member"
|
||||||
|
(
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
|
||||||
|
"server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
|
||||||
|
"user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
|
||||||
|
"nickname" VARCHAR,
|
||||||
|
"avatar_url" VARCHAR,
|
||||||
|
UNIQUE ("server_id", "user_id")
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "server_member_role"
|
||||||
|
(
|
||||||
|
"member_id" UUID NOT NULL REFERENCES "server_member" ("id") ON DELETE CASCADE,
|
||||||
|
"role_id" UUID NOT NULL REFERENCES "server_role" ("id") ON DELETE CASCADE,
|
||||||
|
PRIMARY KEY ("member_id", "role_id")
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "server_invite"
|
||||||
|
(
|
||||||
|
"code" VARCHAR NOT NULL PRIMARY KEY,
|
||||||
|
"server_id" UUID NOT NULL REFERENCES "server" ("id") ON DELETE CASCADE,
|
||||||
|
"inviter_id" UUID REFERENCES "user" ("id") ON DELETE SET NULL,
|
||||||
|
"expires_at" TIMESTAMPTZ
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION check_server_user_role_server_id()
|
||||||
|
RETURNS TRIGGER AS
|
||||||
|
$$
|
||||||
|
DECLARE
|
||||||
|
member_server_id UUID;
|
||||||
|
role_server_id UUID;
|
||||||
|
BEGIN
|
||||||
|
-- Get server_id from server_user
|
||||||
|
SELECT server_id
|
||||||
|
INTO member_server_id
|
||||||
|
FROM server_member
|
||||||
|
WHERE id = NEW.member_id;
|
||||||
|
|
||||||
|
-- Get server_id from server_role
|
||||||
|
SELECT server_id
|
||||||
|
INTO role_server_id
|
||||||
|
FROM server_role
|
||||||
|
WHERE id = NEW.role_id;
|
||||||
|
|
||||||
|
-- Check if server_ids match
|
||||||
|
IF member_server_id != role_server_id THEN
|
||||||
|
RAISE EXCEPTION 'Cannot assign role from a different server: server_user server_id (%) does not match server_role server_id (%)', member_server_id, role_server_id;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
RETURN NEW;
|
||||||
|
END;
|
||||||
|
$$ LANGUAGE plpgsql;
|
||||||
|
|
||||||
|
CREATE TRIGGER enforce_server_user_role_server_id
|
||||||
|
BEFORE INSERT OR UPDATE
|
||||||
|
ON server_member_role
|
||||||
|
FOR EACH ROW
|
||||||
|
EXECUTE FUNCTION check_server_user_role_server_id();
|
||||||
2
migrations/20250510184011_channel_message.down.sql
Normal file
2
migrations/20250510184011_channel_message.down.sql
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
DROP TABLE "message";
|
||||||
|
DROP TABLE "channel";
|
||||||
97
migrations/20250510184011_channel_message.up.sql
Normal file
97
migrations/20250510184011_channel_message.up.sql
Normal file
@@ -0,0 +1,97 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "channel"
|
||||||
|
(
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
|
||||||
|
"name" VARCHAR NOT NULL,
|
||||||
|
"type" INT2 NOT NULL,
|
||||||
|
"position" INT2 NOT NULL,
|
||||||
|
"server_id" UUID REFERENCES "server" ("id") ON DELETE CASCADE, -- only for server channels
|
||||||
|
"parent" UUID REFERENCES "channel" ("id") ON DELETE SET NULL, -- only for server channels
|
||||||
|
"owner_id" UUID REFERENCES "user" ("id") ON DELETE SET NULL -- only for group channels
|
||||||
|
);
|
||||||
|
|
||||||
|
-- only for dm or group channels
|
||||||
|
CREATE TABLE IF NOT EXISTS "channel_recipient"
|
||||||
|
(
|
||||||
|
"channel_id" UUID NOT NULL REFERENCES "channel" ("id") ON DELETE CASCADE,
|
||||||
|
"user_id" UUID NOT NULL REFERENCES "user" ("id") ON DELETE CASCADE,
|
||||||
|
PRIMARY KEY ("channel_id", "user_id")
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "message"
|
||||||
|
(
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
|
||||||
|
"author_id" UUID NOT NULL REFERENCES "user" ("id"),
|
||||||
|
"channel_id" UUID NOT NULL REFERENCES "channel" ("id"),
|
||||||
|
"content" TEXT NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
ALTER TABLE "channel"
|
||||||
|
ADD COLUMN "last_message_id" UUID REFERENCES "message" ("id") ON DELETE SET NULL;
|
||||||
|
|
||||||
|
|
||||||
|
CREATE OR REPLACE FUNCTION create_dm_channel(
|
||||||
|
user1_id UUID,
|
||||||
|
user2_id UUID,
|
||||||
|
channel_type INT2
|
||||||
|
)
|
||||||
|
RETURNS UUID
|
||||||
|
LANGUAGE plpgsql
|
||||||
|
AS
|
||||||
|
$$
|
||||||
|
DECLARE
|
||||||
|
new_channel_id UUID;
|
||||||
|
channel_name VARCHAR;
|
||||||
|
BEGIN
|
||||||
|
-- Validate parameters
|
||||||
|
IF user1_id IS NULL OR user2_id IS NULL THEN
|
||||||
|
RAISE EXCEPTION 'Both user IDs must be provided';
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
IF user1_id = user2_id THEN
|
||||||
|
RAISE EXCEPTION 'Cannot create a DM channel with the same user';
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Check if users exist
|
||||||
|
IF NOT exists (SELECT 1 FROM "user" WHERE id = user1_id) OR
|
||||||
|
NOT exists (SELECT 1 FROM "user" WHERE id = user2_id) THEN
|
||||||
|
RAISE EXCEPTION 'One or both users do not exist';
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Check if DM already exists between these users
|
||||||
|
IF exists (SELECT 1
|
||||||
|
FROM channel c
|
||||||
|
JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id
|
||||||
|
JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id
|
||||||
|
WHERE c.type = channel_type
|
||||||
|
AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2) THEN
|
||||||
|
-- Find and return the existing channel ID
|
||||||
|
SELECT c.id
|
||||||
|
INTO new_channel_id
|
||||||
|
FROM channel c
|
||||||
|
JOIN channel_recipient cr1 ON c.id = cr1.channel_id AND cr1.user_id = user1_id
|
||||||
|
JOIN channel_recipient cr2 ON c.id = cr2.channel_id AND cr2.user_id = user2_id
|
||||||
|
WHERE c.type = channel_type
|
||||||
|
AND (SELECT count(*) FROM channel_recipient WHERE channel_id = c.id) = 2
|
||||||
|
LIMIT 1;
|
||||||
|
|
||||||
|
RAISE NOTICE 'DM channel already exists between these users with ID: %', new_channel_id;
|
||||||
|
RETURN new_channel_id;
|
||||||
|
END IF;
|
||||||
|
|
||||||
|
-- Generate channel name (conventionally uses user IDs in DMs)
|
||||||
|
channel_name := concat(user1_id, '-', user2_id);
|
||||||
|
|
||||||
|
-- Create new channel
|
||||||
|
INSERT INTO "channel" ("name", "type", "position")
|
||||||
|
VALUES (channel_name, channel_type, 0)
|
||||||
|
RETURNING id INTO new_channel_id;
|
||||||
|
|
||||||
|
-- Add both users as recipients
|
||||||
|
INSERT INTO "channel_recipient" ("channel_id", "user_id")
|
||||||
|
VALUES (new_channel_id, user1_id),
|
||||||
|
(new_channel_id, user2_id);
|
||||||
|
|
||||||
|
RAISE NOTICE 'DM channel created with ID: %', new_channel_id;
|
||||||
|
RETURN new_channel_id;
|
||||||
|
END;
|
||||||
|
$$;
|
||||||
2
migrations/20250510184916_file.down.sql
Normal file
2
migrations/20250510184916_file.down.sql
Normal file
@@ -0,0 +1,2 @@
|
|||||||
|
DROP TABLE "message_attachment";
|
||||||
|
DROP TABLE "file";
|
||||||
16
migrations/20250510184916_file.up.sql
Normal file
16
migrations/20250510184916_file.up.sql
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS "file"
|
||||||
|
(
|
||||||
|
"id" UUID NOT NULL PRIMARY KEY DEFAULT uuid_generate_v7(),
|
||||||
|
"filename" VARCHAR NOT NULL,
|
||||||
|
"content_type" VARCHAR NOT NULL,
|
||||||
|
"url" VARCHAR NOT NULL,
|
||||||
|
"size" INT8 NOT NULL
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS "message_attachment"
|
||||||
|
(
|
||||||
|
"message_id" UUID NOT NULL REFERENCES "message" ON DELETE CASCADE,
|
||||||
|
"attachment_id" UUID NOT NULL REFERENCES "file" ON DELETE CASCADE,
|
||||||
|
"order" INT2 NOT NULL,
|
||||||
|
PRIMARY KEY ("message_id", "attachment_id")
|
||||||
|
);
|
||||||
7
rustfmt.toml
Normal file
7
rustfmt.toml
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
match_block_trailing_comma = true
|
||||||
|
newline_style = "Unix"
|
||||||
|
style_edition = "2024"
|
||||||
|
|
||||||
|
group_imports = "StdExternalCrate"
|
||||||
|
imports_granularity = "Module"
|
||||||
|
unstable_features = true
|
||||||
64
src/config.rs
Normal file
64
src/config.rs
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
use std::sync::OnceLock;
|
||||||
|
|
||||||
|
use serde::Deserialize;
|
||||||
|
|
||||||
|
pub fn config() -> &'static Config {
|
||||||
|
static INSTANCE: OnceLock<Config> = OnceLock::new();
|
||||||
|
|
||||||
|
INSTANCE.get_or_init(|| {
|
||||||
|
config::Config::builder()
|
||||||
|
.add_source(config::File::with_name("config"))
|
||||||
|
.add_source(config::Environment::with_prefix("DIPLOM_"))
|
||||||
|
.build()
|
||||||
|
.expect("config builder")
|
||||||
|
.try_deserialize()
|
||||||
|
.expect("config deserialize")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub port: u16,
|
||||||
|
pub jwt_secret: String,
|
||||||
|
pub database: DatabaseConfig,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(untagged)]
|
||||||
|
#[serde(deny_unknown_fields)]
|
||||||
|
pub enum DatabaseConfig {
|
||||||
|
Url {
|
||||||
|
url: url::Url,
|
||||||
|
},
|
||||||
|
Full {
|
||||||
|
driver: String,
|
||||||
|
username: String,
|
||||||
|
password: String,
|
||||||
|
host: url::Host,
|
||||||
|
port: u16,
|
||||||
|
name: String,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
impl DatabaseConfig {
|
||||||
|
pub fn url(&self) -> Option<url::Url> {
|
||||||
|
match self {
|
||||||
|
Self::Url { url } => Some(url.clone()),
|
||||||
|
Self::Full {
|
||||||
|
driver,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
host,
|
||||||
|
port,
|
||||||
|
name,
|
||||||
|
} => {
|
||||||
|
let str = format!(
|
||||||
|
"{}://{}:{}@{}:{}/{}",
|
||||||
|
driver, username, password, host, port, name
|
||||||
|
);
|
||||||
|
let url = url::Url::parse(&str).ok()?;
|
||||||
|
Some(url)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
333
src/database.rs
Normal file
333
src/database.rs
Normal file
@@ -0,0 +1,333 @@
|
|||||||
|
use crate::config::DatabaseConfig;
|
||||||
|
use crate::entity;
|
||||||
|
|
||||||
|
#[derive(Clone, derive_more::AsRef)]
|
||||||
|
pub struct Database {
|
||||||
|
pool: sqlx::PgPool,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
#[derive(Debug, derive_more::From, derive_more::Error, derive_more::Display)]
|
||||||
|
pub enum Error {
|
||||||
|
#[from]
|
||||||
|
Migrate(sqlx::migrate::MigrateError),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
Sqlx(sqlx::Error),
|
||||||
|
|
||||||
|
UserDoesNotExists,
|
||||||
|
UserAlreadyExists,
|
||||||
|
|
||||||
|
ServerDoesNotExists,
|
||||||
|
|
||||||
|
ChannelDoesNotExists,
|
||||||
|
|
||||||
|
MessageDoesNotExists,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database {
|
||||||
|
pub async fn connect(config: &DatabaseConfig) -> sqlx::Result<Self> {
|
||||||
|
static MIGRATOR: sqlx::migrate::Migrator = sqlx::migrate!();
|
||||||
|
|
||||||
|
let pool = sqlx::postgres::PgPoolOptions::new()
|
||||||
|
.connect(config.url().ok_or(sqlx::Error::BeginFailed)?.as_str())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
MIGRATOR.run(&pool).await?;
|
||||||
|
|
||||||
|
Ok(Self { pool })
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Database {
|
||||||
|
pub async fn insert_user(
|
||||||
|
&self,
|
||||||
|
username: &str,
|
||||||
|
display_name: Option<&str>,
|
||||||
|
email: &str,
|
||||||
|
password_hash: &str,
|
||||||
|
) -> Result<entity::user::User> {
|
||||||
|
let user = match sqlx::query_as!(
|
||||||
|
entity::user::User,
|
||||||
|
r#"INSERT INTO "user"("username", "display_name", "email", "password_hash") VALUES ($1, $2, $3, $4) RETURNING "user".*"#,
|
||||||
|
username,
|
||||||
|
display_name,
|
||||||
|
email,
|
||||||
|
password_hash
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await {
|
||||||
|
Ok(user) => user,
|
||||||
|
Err(sqlx::Error::Database(e)) if e.code() == Some("23505".into()) => {
|
||||||
|
return Err(Error::UserAlreadyExists);
|
||||||
|
}
|
||||||
|
Err(e) => return Err(e.into()),
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_user_by_id(&self, user_id: entity::user::Id) -> Result<entity::user::User> {
|
||||||
|
let user = sqlx::query_as!(
|
||||||
|
entity::user::User,
|
||||||
|
r#"SELECT * FROM "user" WHERE "id" = $1"#,
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or(Error::UserDoesNotExists)?;
|
||||||
|
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_users_by_ids(
|
||||||
|
&self,
|
||||||
|
user_ids: &[entity::user::Id],
|
||||||
|
) -> Result<Vec<entity::user::User>> {
|
||||||
|
let mut query_builder = sqlx::QueryBuilder::new(r#"SELECT * FROM "user" WHERE "id" IN ("#);
|
||||||
|
|
||||||
|
let mut separated = query_builder.separated(", ");
|
||||||
|
for id in user_ids.iter() {
|
||||||
|
separated.push_bind(id);
|
||||||
|
}
|
||||||
|
|
||||||
|
query_builder.push(")");
|
||||||
|
|
||||||
|
let users = query_builder.build_query_as().fetch_all(&self.pool).await?;
|
||||||
|
|
||||||
|
Ok(users)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_user_by_username(&self, username: &str) -> Result<entity::user::User> {
|
||||||
|
let user = sqlx::query_as!(
|
||||||
|
entity::user::User,
|
||||||
|
r#"SELECT * FROM "user" WHERE "username" = $1"#,
|
||||||
|
username
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or(Error::UserDoesNotExists)?;
|
||||||
|
|
||||||
|
Ok(user)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_server_by_id(
|
||||||
|
&self,
|
||||||
|
server_id: entity::server::Id,
|
||||||
|
) -> Result<entity::server::Server> {
|
||||||
|
let server = sqlx::query_as!(
|
||||||
|
entity::server::Server,
|
||||||
|
r#"SELECT * FROM "server" WHERE "id" = $1"#,
|
||||||
|
server_id
|
||||||
|
)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or(Error::ServerDoesNotExists)?;
|
||||||
|
|
||||||
|
Ok(server)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_user_servers(
|
||||||
|
&self,
|
||||||
|
user_id: entity::user::Id,
|
||||||
|
) -> Result<Vec<entity::server::Server>> {
|
||||||
|
let servers = sqlx::query_as!(
|
||||||
|
entity::server::Server,
|
||||||
|
r#"SELECT * FROM "server" WHERE "id" IN (
|
||||||
|
SELECT "server_id" FROM "server_member"
|
||||||
|
WHERE "user_id" = $1
|
||||||
|
)"#,
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(servers)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_user_channels(
|
||||||
|
&self,
|
||||||
|
user_id: entity::user::Id,
|
||||||
|
) -> Result<Vec<entity::channel::Channel>> {
|
||||||
|
// for some reason using macro overflows tokio stack
|
||||||
|
let channels = sqlx::query_as(
|
||||||
|
r#"SELECT * FROM "channel" WHERE "id" IN (
|
||||||
|
SELECT "channel_id" FROM "channel_recipient" WHERE "user_id" = $1
|
||||||
|
)"#,
|
||||||
|
)
|
||||||
|
.bind(user_id)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(channels)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert_server(
|
||||||
|
&self,
|
||||||
|
name: &str,
|
||||||
|
icon_url: Option<&str>,
|
||||||
|
owner_id: entity::user::Id,
|
||||||
|
) -> Result<entity::server::Server> {
|
||||||
|
let server = sqlx::query_as!(
|
||||||
|
entity::server::Server,
|
||||||
|
r#"INSERT INTO "server"("name", "icon_url", "owner_id") VALUES ($1, $2, $3) RETURNING "server".*"#,
|
||||||
|
name,
|
||||||
|
icon_url,
|
||||||
|
owner_id
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(server)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert_server_role(
|
||||||
|
&self,
|
||||||
|
id: Option<entity::server::role::Id>,
|
||||||
|
server_id: entity::server::Id,
|
||||||
|
name: &str,
|
||||||
|
color: Option<&str>,
|
||||||
|
display: bool,
|
||||||
|
permissions: serde_json::Value,
|
||||||
|
position: u16,
|
||||||
|
) -> Result<entity::server::role::ServerRole> {
|
||||||
|
let role = sqlx::query_as!(
|
||||||
|
entity::server::role::ServerRole,
|
||||||
|
r#"INSERT INTO "server_role"("id", "server_id", "name", "color", "display", "permissions", "position") VALUES ($1, $2, $3, $4, $5, $6, $7) RETURNING "server_role".*"#,
|
||||||
|
id,
|
||||||
|
server_id,
|
||||||
|
name,
|
||||||
|
color,
|
||||||
|
display,
|
||||||
|
permissions,
|
||||||
|
position as i16
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(role)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert_server_member(
|
||||||
|
&self,
|
||||||
|
server_id: entity::server::Id,
|
||||||
|
user_id: entity::user::Id,
|
||||||
|
) -> Result<entity::server::member::ServerMember> {
|
||||||
|
let member = sqlx::query_as!(
|
||||||
|
entity::server::member::ServerMember,
|
||||||
|
r#"INSERT INTO "server_member"("server_id", "user_id") VALUES ($1, $2) RETURNING "server_member".*"#,
|
||||||
|
server_id,
|
||||||
|
user_id
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(member)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert_server_member_role(
|
||||||
|
&self,
|
||||||
|
server_member_id: entity::server::member::Id,
|
||||||
|
server_role_id: entity::server::role::Id,
|
||||||
|
) -> Result<()> {
|
||||||
|
sqlx::query!(
|
||||||
|
r#"INSERT INTO "server_member_role"("member_id", "role_id") VALUES ($1, $2)"#,
|
||||||
|
server_member_id,
|
||||||
|
server_role_id
|
||||||
|
)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn insert_server_channel(
|
||||||
|
&self,
|
||||||
|
server_id: entity::server::Id,
|
||||||
|
name: &str,
|
||||||
|
channel_type: entity::channel::ChannelType,
|
||||||
|
position: u16,
|
||||||
|
parent: Option<entity::channel::Id>,
|
||||||
|
) -> Result<entity::channel::Channel> {
|
||||||
|
let channel = sqlx::query_as!(
|
||||||
|
entity::channel::Channel,
|
||||||
|
r#"INSERT INTO "channel"("name", "type", "position", "server_id", "parent") VALUES ($1, $2, $3, $4, $5) RETURNING "channel".*"#,
|
||||||
|
name,
|
||||||
|
channel_type as i16,
|
||||||
|
position as i16,
|
||||||
|
server_id,
|
||||||
|
parent
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_channel_by_id(
|
||||||
|
&self,
|
||||||
|
channel_id: entity::channel::Id,
|
||||||
|
) -> Result<entity::channel::Channel> {
|
||||||
|
let channel = sqlx::query_as(r#"SELECT * FROM "channel" WHERE "id" = $1"#)
|
||||||
|
.bind(channel_id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or(Error::ChannelDoesNotExists)?;
|
||||||
|
|
||||||
|
Ok(channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_channel_recipients(
|
||||||
|
&self,
|
||||||
|
channel_id: entity::channel::Id,
|
||||||
|
) -> Result<Option<Vec<entity::user::User>>> {
|
||||||
|
let recipients = sqlx::query_as!(
|
||||||
|
entity::user::User,
|
||||||
|
r#"SELECT * FROM "user" WHERE "id" IN (
|
||||||
|
SELECT "user_id" FROM "channel_recipient" WHERE "channel_id" = $1
|
||||||
|
)"#,
|
||||||
|
channel_id
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
if recipients.is_empty() {
|
||||||
|
return Ok(None);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Some(recipients))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn select_server_channels(
|
||||||
|
&self,
|
||||||
|
server_id: entity::server::Id,
|
||||||
|
) -> Result<Vec<entity::channel::Channel>> {
|
||||||
|
let channels = sqlx::query_as!(
|
||||||
|
entity::channel::Channel,
|
||||||
|
r#"SELECT * FROM "channel" WHERE "server_id" = $1"#,
|
||||||
|
server_id
|
||||||
|
)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(channels)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn procedure_create_dm_channel(
|
||||||
|
&self,
|
||||||
|
user1_id: entity::user::Id,
|
||||||
|
user2_id: entity::user::Id,
|
||||||
|
) -> Result<Option<entity::channel::Id>> {
|
||||||
|
let channel_id = sqlx::query_scalar!(
|
||||||
|
r#"SELECT create_dm_channel($1, $2, $3)"#,
|
||||||
|
user1_id,
|
||||||
|
user2_id,
|
||||||
|
entity::channel::ChannelType::DirectMessage as i16
|
||||||
|
)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(channel_id)
|
||||||
|
}
|
||||||
|
}
|
||||||
12
src/entity/attachment.rs
Normal file
12
src/entity/attachment.rs
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
pub type Id = uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
pub struct Attachment {
|
||||||
|
pub id: Id,
|
||||||
|
pub filename: String,
|
||||||
|
pub content_type: String,
|
||||||
|
pub url: String,
|
||||||
|
pub size: u64,
|
||||||
|
}
|
||||||
41
src/entity/channel.rs
Normal file
41
src/entity/channel.rs
Normal file
@@ -0,0 +1,41 @@
|
|||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::entity::{message, server, user};
|
||||||
|
|
||||||
|
pub type Id = uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Channel {
|
||||||
|
pub id: Id,
|
||||||
|
pub name: String,
|
||||||
|
pub r#type: ChannelType,
|
||||||
|
pub position: i16,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub server_id: Option<server::Id>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub parent: Option<Id>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub owner_id: Option<user::Id>,
|
||||||
|
#[serde(skip_serializing_if = "Option::is_none")]
|
||||||
|
pub last_message_id: Option<message::Id>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::Type, Serialize)]
|
||||||
|
#[non_exhaustive]
|
||||||
|
#[serde(rename_all = "snake_case")]
|
||||||
|
#[repr(i16)]
|
||||||
|
pub enum ChannelType {
|
||||||
|
ServerText = 1,
|
||||||
|
ServerVoice = 2,
|
||||||
|
ServerCategory = 3,
|
||||||
|
|
||||||
|
DirectMessage = 4,
|
||||||
|
Group = 5,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<i16> for ChannelType {
|
||||||
|
fn from(value: i16) -> Self {
|
||||||
|
value.try_into().unwrap_or(ChannelType::ServerText)
|
||||||
|
}
|
||||||
|
}
|
||||||
15
src/entity/message.rs
Normal file
15
src/entity/message.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::entity::{channel, user};
|
||||||
|
|
||||||
|
pub type Id = uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Message {
|
||||||
|
pub id: Id,
|
||||||
|
pub author_id: user::Id,
|
||||||
|
pub channel_id: channel::Id,
|
||||||
|
pub content: String,
|
||||||
|
pub timestamp: chrono::DateTime<chrono::Utc>,
|
||||||
|
}
|
||||||
5
src/entity/mod.rs
Normal file
5
src/entity/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
pub mod attachment;
|
||||||
|
pub mod channel;
|
||||||
|
pub mod message;
|
||||||
|
pub mod server;
|
||||||
|
pub mod user;
|
||||||
18
src/entity/server.rs
Normal file
18
src/entity/server.rs
Normal file
@@ -0,0 +1,18 @@
|
|||||||
|
mod invite;
|
||||||
|
pub mod member;
|
||||||
|
pub mod role;
|
||||||
|
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::entity::user;
|
||||||
|
|
||||||
|
pub type Id = uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Server {
|
||||||
|
pub id: Id,
|
||||||
|
pub owner_id: user::Id,
|
||||||
|
pub name: String,
|
||||||
|
pub icon_url: Option<String>,
|
||||||
|
}
|
||||||
13
src/entity/server/invite.rs
Normal file
13
src/entity/server/invite.rs
Normal file
@@ -0,0 +1,13 @@
|
|||||||
|
use chrono::{DateTime, Utc};
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::entity::{server, user};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ServerInvite {
|
||||||
|
pub code: String,
|
||||||
|
pub server_id: server::Id,
|
||||||
|
pub inviter_id: Option<user::Id>,
|
||||||
|
pub expires_at: Option<DateTime<Utc>>,
|
||||||
|
}
|
||||||
15
src/entity/server/member.rs
Normal file
15
src/entity/server/member.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::entity::{server, user};
|
||||||
|
|
||||||
|
pub type Id = uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ServerMember {
|
||||||
|
pub id: Id,
|
||||||
|
pub server_id: server::Id,
|
||||||
|
pub user_id: user::Id,
|
||||||
|
pub nickname: Option<String>,
|
||||||
|
pub avatar_url: Option<String>,
|
||||||
|
}
|
||||||
17
src/entity/server/role.rs
Normal file
17
src/entity/server/role.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::entity::server;
|
||||||
|
|
||||||
|
pub type Id = uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct ServerRole {
|
||||||
|
pub id: Id,
|
||||||
|
pub server_id: server::Id,
|
||||||
|
pub name: String,
|
||||||
|
pub color: Option<String>,
|
||||||
|
pub display: bool,
|
||||||
|
pub permissions: serde_json::Value,
|
||||||
|
pub position: i16,
|
||||||
|
}
|
||||||
21
src/entity/user.rs
Normal file
21
src/entity/user.rs
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
use std::sync::LazyLock;
|
||||||
|
|
||||||
|
use regex::Regex;
|
||||||
|
|
||||||
|
pub static USERNAME_REGEX: LazyLock<Regex> =
|
||||||
|
LazyLock::new(|| Regex::new(r"^[a-zA-Z0-9_.]+$").unwrap());
|
||||||
|
|
||||||
|
pub type Id = uuid::Uuid;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow)]
|
||||||
|
pub struct User {
|
||||||
|
pub id: Id,
|
||||||
|
pub avatar_url: Option<String>,
|
||||||
|
pub username: String,
|
||||||
|
pub display_name: Option<String>,
|
||||||
|
pub email: String,
|
||||||
|
pub password_hash: String,
|
||||||
|
pub bot: bool,
|
||||||
|
pub system: bool,
|
||||||
|
pub settings: serde_json::Value,
|
||||||
|
}
|
||||||
55
src/jwt.rs
Normal file
55
src/jwt.rs
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
#![allow(unused)]
|
||||||
|
|
||||||
|
use std::collections::HashSet;
|
||||||
|
|
||||||
|
use chrono::{Duration, Local, Utc};
|
||||||
|
use serde::de::DeserializeOwned;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::config;
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize, Deserialize)]
|
||||||
|
pub struct Claims<T> {
|
||||||
|
pub user_id: T,
|
||||||
|
pub iat: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn generate_jwt<T: Serialize>(user_id: T) -> Result<String> {
|
||||||
|
let claims = Claims {
|
||||||
|
user_id,
|
||||||
|
iat: Utc::now().timestamp_millis(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let token = jsonwebtoken::encode(
|
||||||
|
&jsonwebtoken::Header::default(),
|
||||||
|
&claims,
|
||||||
|
&jsonwebtoken::EncodingKey::from_secret(config::config().jwt_secret.as_ref()),
|
||||||
|
)
|
||||||
|
.map_err(|_| Error::CouldNotEncodeToken)?;
|
||||||
|
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn verify_jwt<T: DeserializeOwned>(token: &str) -> Result<T> {
|
||||||
|
tracing::debug!("verifying token: {}", token);
|
||||||
|
|
||||||
|
let mut validation = jsonwebtoken::Validation::default();
|
||||||
|
validation.required_spec_claims = HashSet::new();
|
||||||
|
|
||||||
|
let token_data = jsonwebtoken::decode::<Claims<T>>(
|
||||||
|
token,
|
||||||
|
&jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),
|
||||||
|
&validation,
|
||||||
|
)
|
||||||
|
.map_err(|_| Error::CouldNotVerifyToken)?;
|
||||||
|
|
||||||
|
Ok(token_data.claims.user_id)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
#[derive(Debug, derive_more::Error, derive_more::Display)]
|
||||||
|
pub enum Error {
|
||||||
|
CouldNotEncodeToken,
|
||||||
|
CouldNotVerifyToken,
|
||||||
|
}
|
||||||
43
src/log.rs
Normal file
43
src/log.rs
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
use std::io;
|
||||||
|
|
||||||
|
use tracing::level_filters::LevelFilter;
|
||||||
|
use tracing_appender::non_blocking::WorkerGuard;
|
||||||
|
use tracing_subscriber::EnvFilter;
|
||||||
|
use tracing_subscriber::fmt::time::ChronoLocal;
|
||||||
|
use tracing_subscriber::layer::SubscriberExt;
|
||||||
|
use tracing_subscriber::util::SubscriberInitExt;
|
||||||
|
|
||||||
|
pub fn init_logging() -> io::Result<(WorkerGuard, WorkerGuard)> {
|
||||||
|
let logger_timer = ChronoLocal::new("%FT%H:%M:%S%.6f%:::z".to_string());
|
||||||
|
|
||||||
|
let file = tracing_appender::rolling::daily("./logs", "log");
|
||||||
|
|
||||||
|
let (file, file_guard) = tracing_appender::non_blocking(file);
|
||||||
|
|
||||||
|
let file_logger = tracing_subscriber::fmt::layer()
|
||||||
|
.with_ansi(false)
|
||||||
|
.with_timer(logger_timer.clone())
|
||||||
|
.with_writer(file);
|
||||||
|
|
||||||
|
let (stdout, stdout_guard) = tracing_appender::non_blocking(io::stdout());
|
||||||
|
|
||||||
|
let stdout_logger = tracing_subscriber::fmt::layer()
|
||||||
|
.with_timer(logger_timer)
|
||||||
|
.with_writer(stdout);
|
||||||
|
#[cfg(debug_assertions)]
|
||||||
|
let stdout_logger = stdout_logger.pretty();
|
||||||
|
|
||||||
|
tracing_subscriber::registry()
|
||||||
|
.with(
|
||||||
|
EnvFilter::builder()
|
||||||
|
.with_default_directive(LevelFilter::INFO.into())
|
||||||
|
.from_env_lossy(),
|
||||||
|
)
|
||||||
|
.with(file_logger)
|
||||||
|
.with(stdout_logger)
|
||||||
|
.init();
|
||||||
|
|
||||||
|
tracing::info!("initialized logging");
|
||||||
|
|
||||||
|
Ok((file_guard, stdout_guard))
|
||||||
|
}
|
||||||
35
src/main.rs
Normal file
35
src/main.rs
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use argon2::Argon2;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
|
||||||
|
use crate::database::Database;
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
mod config;
|
||||||
|
mod database;
|
||||||
|
mod entity;
|
||||||
|
mod jwt;
|
||||||
|
mod log;
|
||||||
|
mod state;
|
||||||
|
mod util;
|
||||||
|
mod web;
|
||||||
|
mod webrtc;
|
||||||
|
|
||||||
|
#[tokio::main]
|
||||||
|
async fn main() -> anyhow::Result<()> {
|
||||||
|
let _guard = log::init_logging()?;
|
||||||
|
|
||||||
|
let database = Database::connect(&config::config().database).await?;
|
||||||
|
let state = AppState {
|
||||||
|
database,
|
||||||
|
hasher: Arc::new(Argon2::default()),
|
||||||
|
event_connected_users: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
voice_rooms: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
};
|
||||||
|
|
||||||
|
web::run(state).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
51
src/state.rs
Normal file
51
src/state.rs
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use argon2::Argon2;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tokio::sync::{RwLock, mpsc};
|
||||||
|
use uuid::Uuid;
|
||||||
|
|
||||||
|
use crate::database::Database;
|
||||||
|
use crate::web::ws::gateway::{EventWsState, message};
|
||||||
|
use crate::webrtc::OfferSignal;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AppState {
|
||||||
|
pub database: Database,
|
||||||
|
pub hasher: Arc<Argon2<'static>>,
|
||||||
|
|
||||||
|
pub event_connected_users: Arc<RwLock<HashMap<Uuid, EventWsState>>>,
|
||||||
|
|
||||||
|
pub voice_rooms: Arc<RwLock<HashMap<Uuid, mpsc::UnboundedSender<OfferSignal>>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AppState {
|
||||||
|
pub async fn register_event_connected_user(
|
||||||
|
&self,
|
||||||
|
user_id: Uuid,
|
||||||
|
session_id: String,
|
||||||
|
event_sender: mpsc::UnboundedSender<message::Event>,
|
||||||
|
) {
|
||||||
|
let mut connected_users = self.event_connected_users.write().await;
|
||||||
|
if let Some(state) = connected_users.get_mut(&user_id) {
|
||||||
|
state.connection_instance.insert(session_id, event_sender);
|
||||||
|
} else {
|
||||||
|
let state = EventWsState {
|
||||||
|
connection_instance: DashMap::new(),
|
||||||
|
};
|
||||||
|
state.connection_instance.insert(session_id, event_sender);
|
||||||
|
connected_users.insert(user_id, state);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn unregister_event_connected_user(&self, user_id: Uuid, session_id: &str) {
|
||||||
|
let mut connected_users = self.event_connected_users.write().await;
|
||||||
|
if let Some(state) = connected_users.get_mut(&user_id) {
|
||||||
|
state.connection_instance.remove(session_id);
|
||||||
|
if state.connection_instance.is_empty() {
|
||||||
|
connected_users.remove(&user_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
52
src/util.rs
Normal file
52
src/util.rs
Normal file
@@ -0,0 +1,52 @@
|
|||||||
|
use axum::extract::multipart::Field;
|
||||||
|
use axum_typed_multipart::{FieldData, TryFromField, TypedMultipartError};
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
#[derive(Debug, derive_more::Deref)]
|
||||||
|
pub struct SerdeFieldData<T>(pub FieldData<T>);
|
||||||
|
|
||||||
|
#[async_trait::async_trait]
|
||||||
|
impl<T: TryFromField> TryFromField for SerdeFieldData<T> {
|
||||||
|
async fn try_from_field(
|
||||||
|
field: Field<'_>,
|
||||||
|
limit_bytes: Option<usize>,
|
||||||
|
) -> Result<Self, TypedMultipartError> {
|
||||||
|
let field = FieldData::try_from_field(field, limit_bytes).await?;
|
||||||
|
|
||||||
|
Ok(Self(field))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<T> Serialize for SerdeFieldData<T> {
|
||||||
|
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
struct Metadata {
|
||||||
|
name: Option<String>,
|
||||||
|
file_name: Option<String>,
|
||||||
|
content_type: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
let metadata = Metadata {
|
||||||
|
name: self.0.metadata.name.clone(),
|
||||||
|
file_name: self.0.metadata.file_name.clone(),
|
||||||
|
content_type: self.0.metadata.content_type.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
metadata.serialize(serializer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn serialize_duration_seconds<S>(
|
||||||
|
duration: &std::time::Duration,
|
||||||
|
serializer: S,
|
||||||
|
) -> Result<S::Ok, S::Error>
|
||||||
|
where
|
||||||
|
S: serde::Serializer,
|
||||||
|
{
|
||||||
|
let seconds = duration.as_secs();
|
||||||
|
seconds.serialize(serializer)
|
||||||
|
}
|
||||||
35
src/web/context.rs
Normal file
35
src/web/context.rs
Normal file
@@ -0,0 +1,35 @@
|
|||||||
|
use axum::extract::FromRequestParts;
|
||||||
|
use axum::http::request::Parts;
|
||||||
|
|
||||||
|
use crate::entity;
|
||||||
|
|
||||||
|
#[derive(Debug, Copy, Clone)]
|
||||||
|
pub struct UserContext {
|
||||||
|
pub user_id: entity::user::Id,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub type UserContextResult = Result<UserContext, Error>;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Copy, derive_more::Display)]
|
||||||
|
pub enum Error {
|
||||||
|
NotInRequest,
|
||||||
|
NotInHeader,
|
||||||
|
BadCharacters,
|
||||||
|
WrongTokenType,
|
||||||
|
BadToken,
|
||||||
|
Model,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<S: Send + Sync> FromRequestParts<S> for UserContext {
|
||||||
|
type Rejection = super::error::Error;
|
||||||
|
|
||||||
|
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
|
||||||
|
let context = parts
|
||||||
|
.extensions
|
||||||
|
.get::<UserContextResult>()
|
||||||
|
.cloned()
|
||||||
|
.ok_or(Error::NotInRequest)??;
|
||||||
|
|
||||||
|
Ok(context)
|
||||||
|
}
|
||||||
|
}
|
||||||
147
src/web/error.rs
Normal file
147
src/web/error.rs
Normal file
@@ -0,0 +1,147 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::http::StatusCode;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
|
||||||
|
use crate::web::context;
|
||||||
|
use crate::{database, jwt};
|
||||||
|
|
||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
#[derive(Debug, derive_more::From)]
|
||||||
|
pub enum Error {
|
||||||
|
#[from]
|
||||||
|
Context(context::Error),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
Jwt(jwt::Error),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
Hash(argon2::password_hash::Error),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
Database(database::Error),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
Json(serde_json::error::Error),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
JsonRejection(axum::extract::rejection::JsonRejection),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
Client(ClientError),
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, derive_more::From, serde::Serialize)]
|
||||||
|
#[serde(tag = "message", content = "details")]
|
||||||
|
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||||
|
pub enum ClientError {
|
||||||
|
UserAlreadyExists,
|
||||||
|
UserDoesNotExist,
|
||||||
|
|
||||||
|
ServerDoesNotExist,
|
||||||
|
|
||||||
|
ChannelDoesNotExist,
|
||||||
|
|
||||||
|
MessageDoesNotExist,
|
||||||
|
|
||||||
|
NotAuthorized,
|
||||||
|
WrongPassword,
|
||||||
|
NotAllowed,
|
||||||
|
|
||||||
|
InvalidJson(JsonRejection),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
ValidationFailed(validator::ValidationErrors),
|
||||||
|
|
||||||
|
InternalServerError,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(derive_more::Debug, Clone, serde::Serialize)]
|
||||||
|
pub struct JsonRejection {
|
||||||
|
#[serde(skip)]
|
||||||
|
status: StatusCode,
|
||||||
|
|
||||||
|
reason: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<&axum::extract::rejection::JsonRejection> for JsonRejection {
|
||||||
|
fn from(value: &axum::extract::rejection::JsonRejection) -> Self {
|
||||||
|
use std::error::Error;
|
||||||
|
|
||||||
|
use axum::extract::rejection::JsonRejection::*;
|
||||||
|
|
||||||
|
let reason = match value {
|
||||||
|
JsonDataError(e) => e.source().map(ToString::to_string),
|
||||||
|
JsonSyntaxError(e) => e.source().map(ToString::to_string),
|
||||||
|
MissingJsonContentType(e) => e.source().map(ToString::to_string),
|
||||||
|
BytesRejection(e) => e.source().map(ToString::to_string),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
.unwrap_or_else(|| value.body_text());
|
||||||
|
|
||||||
|
Self {
|
||||||
|
status: value.status(),
|
||||||
|
reason,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
pub fn as_client_error(&self) -> ClientError {
|
||||||
|
match self {
|
||||||
|
Error::Context(_) | Error::Jwt(_) => ClientError::NotAuthorized,
|
||||||
|
Error::Database(database::Error::UserAlreadyExists) => ClientError::UserAlreadyExists,
|
||||||
|
Error::Database(database::Error::UserDoesNotExists) => ClientError::UserDoesNotExist,
|
||||||
|
|
||||||
|
Error::Database(database::Error::ServerDoesNotExists) => {
|
||||||
|
ClientError::ServerDoesNotExist
|
||||||
|
},
|
||||||
|
|
||||||
|
Error::Database(database::Error::ChannelDoesNotExists) => {
|
||||||
|
ClientError::ChannelDoesNotExist
|
||||||
|
},
|
||||||
|
|
||||||
|
Error::Database(database::Error::MessageDoesNotExists) => {
|
||||||
|
ClientError::MessageDoesNotExist
|
||||||
|
},
|
||||||
|
// Error::WrongPassword => ClientError::WrongPassword,
|
||||||
|
// Error::NotAllowed => ClientError::NotAllowed,
|
||||||
|
Error::JsonRejection(e) => ClientError::InvalidJson(e.into()),
|
||||||
|
Error::Client(client_error) => client_error.clone(),
|
||||||
|
_ => ClientError::InternalServerError,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl IntoResponse for Error {
|
||||||
|
fn into_response(self) -> axum::response::Response {
|
||||||
|
let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
|
||||||
|
|
||||||
|
response.extensions_mut().insert(Arc::new(self));
|
||||||
|
|
||||||
|
response
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl ClientError {
|
||||||
|
pub fn status_code(&self) -> StatusCode {
|
||||||
|
match self {
|
||||||
|
ClientError::UserAlreadyExists => StatusCode::CONFLICT,
|
||||||
|
ClientError::UserDoesNotExist
|
||||||
|
| ClientError::ServerDoesNotExist
|
||||||
|
| ClientError::ChannelDoesNotExist
|
||||||
|
| ClientError::MessageDoesNotExist => StatusCode::NOT_FOUND,
|
||||||
|
|
||||||
|
ClientError::NotAuthorized | ClientError::WrongPassword => StatusCode::UNAUTHORIZED,
|
||||||
|
|
||||||
|
ClientError::NotAllowed => StatusCode::FORBIDDEN,
|
||||||
|
|
||||||
|
ClientError::ValidationFailed(_) => StatusCode::UNPROCESSABLE_ENTITY,
|
||||||
|
|
||||||
|
ClientError::InvalidJson(e) => e.status,
|
||||||
|
|
||||||
|
_ => StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
63
src/web/middleware/auth.rs
Normal file
63
src/web/middleware/auth.rs
Normal file
@@ -0,0 +1,63 @@
|
|||||||
|
use axum::extract::{Request, State};
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use axum::response::{IntoResponse, Response};
|
||||||
|
use axum::{Extension, RequestExt};
|
||||||
|
use axum_extra::TypedHeader;
|
||||||
|
use axum_extra::headers::Authorization;
|
||||||
|
use axum_extra::headers::authorization::Bearer;
|
||||||
|
use axum_extra::typed_header::TypedHeaderRejectionReason;
|
||||||
|
|
||||||
|
use crate::jwt;
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web::{self, context};
|
||||||
|
|
||||||
|
pub async fn require_context(
|
||||||
|
Extension(context): Extension<context::UserContextResult>,
|
||||||
|
request: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> web::error::Result<Response> {
|
||||||
|
context?;
|
||||||
|
|
||||||
|
Ok(next.run(request).await)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn resolve_context(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
mut request: Request,
|
||||||
|
next: Next,
|
||||||
|
) -> impl IntoResponse {
|
||||||
|
let context = get_context(&state, &mut request).await;
|
||||||
|
|
||||||
|
request.extensions_mut().insert(context);
|
||||||
|
|
||||||
|
next.run(request).await
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get_context(state: &AppState, request: &mut Request) -> context::UserContextResult {
|
||||||
|
let bearer = request
|
||||||
|
.extract_parts::<TypedHeader<Authorization<Bearer>>>()
|
||||||
|
.await
|
||||||
|
.map_err(|error| match error.reason() {
|
||||||
|
TypedHeaderRejectionReason::Missing => context::Error::NotInHeader,
|
||||||
|
TypedHeaderRejectionReason::Error(_) => context::Error::WrongTokenType,
|
||||||
|
_ => context::Error::NotInHeader,
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let token = bearer.token();
|
||||||
|
|
||||||
|
let context = get_context_from_token(state, token).await?;
|
||||||
|
|
||||||
|
Ok(context)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_context_from_token(state: &AppState, token: &str) -> context::UserContextResult {
|
||||||
|
let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;
|
||||||
|
|
||||||
|
let _ = state
|
||||||
|
.database
|
||||||
|
.select_user_by_id(user_id)
|
||||||
|
.await
|
||||||
|
.map_err(|_| context::Error::BadToken)?;
|
||||||
|
|
||||||
|
Ok(context::UserContext { user_id })
|
||||||
|
}
|
||||||
5
src/web/middleware/mod.rs
Normal file
5
src/web/middleware/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
mod auth;
|
||||||
|
mod response_map;
|
||||||
|
|
||||||
|
pub use auth::*;
|
||||||
|
pub use response_map::*;
|
||||||
31
src/web/middleware/response_map.rs
Normal file
31
src/web/middleware/response_map.rs
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use axum::Json;
|
||||||
|
use axum::extract::Request;
|
||||||
|
use axum::middleware::Next;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
|
use crate::web;
|
||||||
|
|
||||||
|
pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
|
||||||
|
let response = next.run(request).await;
|
||||||
|
|
||||||
|
let error = response.extensions().get::<Arc<web::Error>>();
|
||||||
|
|
||||||
|
if error.is_some() {
|
||||||
|
tracing::error!("{:?}", error);
|
||||||
|
}
|
||||||
|
|
||||||
|
let error_response = error.map(|error| {
|
||||||
|
let client_error = error.as_client_error();
|
||||||
|
|
||||||
|
(
|
||||||
|
client_error.status_code(),
|
||||||
|
Json(json!({"error": client_error})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
|
});
|
||||||
|
|
||||||
|
error_response.unwrap_or(response)
|
||||||
|
}
|
||||||
80
src/web/mod.rs
Normal file
80
src/web/mod.rs
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
mod context;
|
||||||
|
mod error;
|
||||||
|
mod middleware;
|
||||||
|
mod route;
|
||||||
|
pub mod ws;
|
||||||
|
|
||||||
|
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin};
|
||||||
|
|
||||||
|
pub use self::error::{Error, Result};
|
||||||
|
use crate::{config, state};
|
||||||
|
|
||||||
|
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
|
||||||
|
let config = config::config();
|
||||||
|
|
||||||
|
let addr: std::net::SocketAddr = ([127, 0, 0, 1], config.port).into();
|
||||||
|
tracing::info!("listening on {}", addr);
|
||||||
|
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||||
|
|
||||||
|
axum::serve(listener, router(state))
|
||||||
|
.with_graceful_shutdown(shutdown_signal())
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn router(state: state::AppState) -> axum::Router {
|
||||||
|
use axum::Router;
|
||||||
|
use axum::routing::*;
|
||||||
|
|
||||||
|
use self::route::*;
|
||||||
|
|
||||||
|
let cors = tower_http::cors::CorsLayer::new()
|
||||||
|
.allow_origin(AllowOrigin::any())
|
||||||
|
.allow_methods(AllowMethods::any())
|
||||||
|
.allow_headers(AllowHeaders::any());
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
// websocket
|
||||||
|
.nest(
|
||||||
|
"/api/v1",
|
||||||
|
Router::new()
|
||||||
|
.route("/ws", get(ws::gateway::ws_handler))
|
||||||
|
.route("/auth/login", post(auth::login))
|
||||||
|
.route("/auth/register", post(auth::register))
|
||||||
|
.merge(protected_router())
|
||||||
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
|
state.clone(),
|
||||||
|
middleware::resolve_context,
|
||||||
|
)),
|
||||||
|
)
|
||||||
|
.layer(axum::middleware::from_fn(middleware::response_map))
|
||||||
|
.layer(cors)
|
||||||
|
.layer(tower_http::trace::TraceLayer::new_for_http())
|
||||||
|
.with_state(state.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn protected_router() -> axum::Router<state::AppState> {
|
||||||
|
use axum::Router;
|
||||||
|
use axum::routing::*;
|
||||||
|
|
||||||
|
use self::route::*;
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
// user
|
||||||
|
.route("/users/@me", get(user::me))
|
||||||
|
.route("/users/@me/channels", get(user::channel::list))
|
||||||
|
.route("/users/{id}", get(user::get_by_id))
|
||||||
|
// server
|
||||||
|
.route("/servers", get(server::list))
|
||||||
|
.route("/servers", post(server::create))
|
||||||
|
.route("/servers/{server_id}", get(server::get))
|
||||||
|
.route("/servers/{server_id}/channels", get(server::channel::list))
|
||||||
|
.route("/voice/{channel_id}/connect", post(voice::connect))
|
||||||
|
// middleware
|
||||||
|
.route_layer(axum::middleware::from_fn(middleware::require_context))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn shutdown_signal() {
|
||||||
|
_ = tokio::signal::ctrl_c().await;
|
||||||
|
}
|
||||||
49
src/web/route/auth/login.rs
Normal file
49
src/web/route/auth/login.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
use argon2::{PasswordHash, PasswordVerifier};
|
||||||
|
use axum::Json;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web::route::user::FullUser;
|
||||||
|
use crate::{jwt, web};
|
||||||
|
|
||||||
|
#[derive(Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct LoginPayload {
|
||||||
|
username: String,
|
||||||
|
password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct LoginResponse {
|
||||||
|
user: FullUser,
|
||||||
|
token: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn login(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(payload): Json<LoginPayload>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let user = state
|
||||||
|
.database
|
||||||
|
.select_user_by_username(&payload.username)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let password_hash = PasswordHash::new(&user.password_hash)?;
|
||||||
|
|
||||||
|
state
|
||||||
|
.hasher
|
||||||
|
.verify_password(payload.password.as_bytes(), &password_hash)
|
||||||
|
.map_err(|_| web::error::ClientError::WrongPassword)?;
|
||||||
|
|
||||||
|
let token = jwt::generate_jwt(user.id)?;
|
||||||
|
|
||||||
|
let response = LoginResponse {
|
||||||
|
user: user.into(),
|
||||||
|
token,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Json(response))
|
||||||
|
}
|
||||||
5
src/web/route/auth/mod.rs
Normal file
5
src/web/route/auth/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
pub mod login;
|
||||||
|
pub mod register;
|
||||||
|
|
||||||
|
pub use login::*;
|
||||||
|
pub use register::*;
|
||||||
68
src/web/route/auth/register.rs
Normal file
68
src/web/route/auth/register.rs
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
use argon2::password_hash::rand_core::OsRng;
|
||||||
|
use argon2::password_hash::{PasswordHasher, SaltString};
|
||||||
|
use axum::Json;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum_extra::extract::WithRejection;
|
||||||
|
use serde::Deserialize;
|
||||||
|
use validator::Validate;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web;
|
||||||
|
use crate::web::error::ClientError;
|
||||||
|
use crate::web::route::user::FullUser;
|
||||||
|
|
||||||
|
#[derive(Validate, Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct RegisterPayload {
|
||||||
|
#[validate(regex(
|
||||||
|
path = "crate::entity::user::USERNAME_REGEX",
|
||||||
|
code = "invalid_username"
|
||||||
|
))]
|
||||||
|
#[validate(length(min = 3, max = 64))]
|
||||||
|
username: String,
|
||||||
|
|
||||||
|
#[validate(length(min = 1, max = 64))]
|
||||||
|
display_name: Option<String>,
|
||||||
|
|
||||||
|
#[validate(email)]
|
||||||
|
#[validate(length(min = 3, max = 128))]
|
||||||
|
email: String,
|
||||||
|
|
||||||
|
#[validate(length(min = 3, max = 128))]
|
||||||
|
password: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn register(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
WithRejection(Json(payload), _): WithRejection<Json<RegisterPayload>, web::Error>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
payload.validate().map_err(ClientError::ValidationFailed)?;
|
||||||
|
|
||||||
|
let salt = SaltString::generate(&mut OsRng);
|
||||||
|
|
||||||
|
let password_hash = state
|
||||||
|
.hasher
|
||||||
|
.hash_password(payload.password.as_bytes(), &salt)?
|
||||||
|
.to_string();
|
||||||
|
|
||||||
|
let user = state
|
||||||
|
.database
|
||||||
|
.insert_user(
|
||||||
|
&payload.username,
|
||||||
|
payload.display_name.as_ref().map(String::as_str),
|
||||||
|
&payload.email,
|
||||||
|
&password_hash,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let system_user = state.database.select_user_by_username("system").await?;
|
||||||
|
if system_user.system {
|
||||||
|
state
|
||||||
|
.database
|
||||||
|
.procedure_create_dm_channel(user.id, system_user.id)
|
||||||
|
.await?;
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Json(FullUser::from(user)))
|
||||||
|
}
|
||||||
4
src/web/route/mod.rs
Normal file
4
src/web/route/mod.rs
Normal file
@@ -0,0 +1,4 @@
|
|||||||
|
pub mod auth;
|
||||||
|
pub mod server;
|
||||||
|
pub mod user;
|
||||||
|
pub mod voice;
|
||||||
17
src/web/route/server/channel/list.rs
Normal file
17
src/web/route/server/channel/list.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::extract::{Path, State};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web::context::UserContext;
|
||||||
|
use crate::{entity, web};
|
||||||
|
|
||||||
|
pub async fn list(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: UserContext,
|
||||||
|
Path(id): Path<entity::server::Id>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channels = state.database.select_server_channels(id).await?;
|
||||||
|
|
||||||
|
Ok(Json(channels))
|
||||||
|
}
|
||||||
3
src/web/route/server/channel/mod.rs
Normal file
3
src/web/route/server/channel/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
mod list;
|
||||||
|
|
||||||
|
pub use list::*;
|
||||||
83
src/web/route/server/create.rs
Normal file
83
src/web/route/server/create.rs
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::body::Bytes;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum_typed_multipart::{TryFromMultipart, TypedMultipart};
|
||||||
|
use validator::{Validate, ValidationError};
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::util::SerdeFieldData;
|
||||||
|
use crate::web;
|
||||||
|
use crate::web::context::UserContext;
|
||||||
|
use crate::web::error::ClientError;
|
||||||
|
use crate::web::ws;
|
||||||
|
|
||||||
|
#[derive(Debug, Validate, TryFromMultipart)]
|
||||||
|
#[try_from_multipart(rename_all = "camelCase")]
|
||||||
|
pub struct CreatePayload {
|
||||||
|
#[validate(length(min = 1, max = 32))]
|
||||||
|
name: String,
|
||||||
|
|
||||||
|
#[validate(custom(function = "validate_icon_content_type"))]
|
||||||
|
#[form_data(limit = "10MB")]
|
||||||
|
icon: Option<SerdeFieldData<Bytes>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
fn validate_icon_content_type(icon: &SerdeFieldData<Bytes>) -> Result<(), ValidationError> {
|
||||||
|
if let Some(content_type) = icon.metadata.content_type.as_deref() {
|
||||||
|
if !content_type.starts_with("image/") {
|
||||||
|
return Err(ValidationError::new("invalid_icon_content_type"));
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
return Err(ValidationError::new("missing_icon_content_type"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: UserContext,
|
||||||
|
TypedMultipart(payload): TypedMultipart<CreatePayload>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
payload.validate().map_err(ClientError::ValidationFailed)?;
|
||||||
|
|
||||||
|
let server = state
|
||||||
|
.database
|
||||||
|
.insert_server(&payload.name, None, context.user_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let role = state
|
||||||
|
.database
|
||||||
|
.insert_server_role(
|
||||||
|
Some(server.id.into()),
|
||||||
|
server.id,
|
||||||
|
"@everyone",
|
||||||
|
None,
|
||||||
|
false,
|
||||||
|
serde_json::json!({}),
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let member = state
|
||||||
|
.database
|
||||||
|
.insert_server_member(server.id, context.user_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
state
|
||||||
|
.database
|
||||||
|
.insert_server_member_role(member.id, role.id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
ws::gateway::util::send_message(
|
||||||
|
&state,
|
||||||
|
context.user_id,
|
||||||
|
ws::gateway::message::Event::AddServer {
|
||||||
|
server: server.clone(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(Json(server))
|
||||||
|
}
|
||||||
15
src/web/route/server/get.rs
Normal file
15
src/web/route/server/get.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::extract::{Path, State};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::{entity, web};
|
||||||
|
|
||||||
|
pub async fn get(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(id): Path<entity::server::Id>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let server = state.database.select_server_by_id(id).await?;
|
||||||
|
|
||||||
|
Ok(Json(server))
|
||||||
|
}
|
||||||
16
src/web/route/server/list.rs
Normal file
16
src/web/route/server/list.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web;
|
||||||
|
use crate::web::context::UserContext;
|
||||||
|
|
||||||
|
pub async fn list(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: UserContext,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let servers = state.database.select_user_servers(context.user_id).await?;
|
||||||
|
|
||||||
|
Ok(Json(servers))
|
||||||
|
}
|
||||||
8
src/web/route/server/mod.rs
Normal file
8
src/web/route/server/mod.rs
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
pub mod channel;
|
||||||
|
mod create;
|
||||||
|
mod get;
|
||||||
|
mod list;
|
||||||
|
|
||||||
|
pub use create::*;
|
||||||
|
pub use get::*;
|
||||||
|
pub use list::*;
|
||||||
48
src/web/route/user/channel/list.rs
Normal file
48
src/web/route/user/channel/list.rs
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use crate::entity::channel;
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web;
|
||||||
|
use crate::web::context::UserContext;
|
||||||
|
use crate::web::route::user::PartialUser;
|
||||||
|
|
||||||
|
#[derive(Debug, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct RecipientChannel {
|
||||||
|
#[serde(flatten)]
|
||||||
|
pub channel: channel::Channel,
|
||||||
|
pub recipients: Vec<PartialUser>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn list(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: UserContext,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channels = state.database.select_user_channels(context.user_id).await?;
|
||||||
|
|
||||||
|
let mut recipient_channels = Vec::new();
|
||||||
|
for channel in channels {
|
||||||
|
let recipients = state.database.select_channel_recipients(channel.id).await?;
|
||||||
|
|
||||||
|
let recipients = match recipients {
|
||||||
|
Some(recipients) => recipients
|
||||||
|
.into_iter()
|
||||||
|
.filter(|user| user.id != context.user_id)
|
||||||
|
.map(PartialUser::from)
|
||||||
|
.collect(),
|
||||||
|
None => {
|
||||||
|
continue;
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
recipient_channels.push(RecipientChannel {
|
||||||
|
channel,
|
||||||
|
recipients,
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Json(recipient_channels))
|
||||||
|
}
|
||||||
3
src/web/route/user/channel/mod.rs
Normal file
3
src/web/route/user/channel/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
mod list;
|
||||||
|
|
||||||
|
pub use list::*;
|
||||||
16
src/web/route/user/get.rs
Normal file
16
src/web/route/user/get.rs
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::extract::{Path, State};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web;
|
||||||
|
use crate::web::route::user::PartialUser;
|
||||||
|
|
||||||
|
pub async fn get_by_id(
|
||||||
|
Path(id): Path<uuid::Uuid>,
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let user = state.database.select_user_by_id(id).await?;
|
||||||
|
|
||||||
|
Ok(Json(PartialUser::from(user)))
|
||||||
|
}
|
||||||
17
src/web/route/user/me.rs
Normal file
17
src/web/route/user/me.rs
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::extract::State;
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web;
|
||||||
|
use crate::web::context::UserContext;
|
||||||
|
use crate::web::route::user::FullUser;
|
||||||
|
|
||||||
|
pub async fn me(
|
||||||
|
context: UserContext,
|
||||||
|
State(state): State<AppState>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let user = state.database.select_user_by_id(context.user_id).await?;
|
||||||
|
|
||||||
|
Ok(Json(FullUser::from(user)))
|
||||||
|
}
|
||||||
60
src/web/route/user/mod.rs
Normal file
60
src/web/route/user/mod.rs
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
pub mod channel;
|
||||||
|
mod get;
|
||||||
|
mod me;
|
||||||
|
|
||||||
|
pub use get::*;
|
||||||
|
pub use me::*;
|
||||||
|
|
||||||
|
use crate::entity::user;
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, Debug)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct FullUser {
|
||||||
|
pub id: user::Id,
|
||||||
|
pub avatar_url: Option<String>,
|
||||||
|
pub username: String,
|
||||||
|
pub display_name: Option<String>,
|
||||||
|
pub email: String,
|
||||||
|
pub bot: bool,
|
||||||
|
pub system: bool,
|
||||||
|
pub settings: serde_json::Value,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, Debug)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct PartialUser {
|
||||||
|
pub id: user::Id,
|
||||||
|
pub avatar_url: Option<String>,
|
||||||
|
pub username: String,
|
||||||
|
pub display_name: Option<String>,
|
||||||
|
pub bot: bool,
|
||||||
|
pub system: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<user::User> for FullUser {
|
||||||
|
fn from(user: user::User) -> Self {
|
||||||
|
Self {
|
||||||
|
id: user.id,
|
||||||
|
avatar_url: user.avatar_url,
|
||||||
|
username: user.username,
|
||||||
|
display_name: user.display_name,
|
||||||
|
email: user.email,
|
||||||
|
bot: user.bot,
|
||||||
|
system: user.system,
|
||||||
|
settings: user.settings,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<user::User> for PartialUser {
|
||||||
|
fn from(user: user::User) -> Self {
|
||||||
|
Self {
|
||||||
|
id: user.id,
|
||||||
|
avatar_url: user.avatar_url,
|
||||||
|
username: user.username,
|
||||||
|
display_name: user.display_name,
|
||||||
|
bot: user.bot,
|
||||||
|
system: user.system,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
91
src/web/route/voice/connect.rs
Normal file
91
src/web/route/voice/connect.rs
Normal file
@@ -0,0 +1,91 @@
|
|||||||
|
use axum::Json;
|
||||||
|
use axum::extract::{Path, State};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use axum_extra::extract::WithRejection;
|
||||||
|
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web::context::UserContext;
|
||||||
|
use crate::{entity, web};
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Payload {
|
||||||
|
sdp: RTCSessionDescription,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, serde::Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Response {
|
||||||
|
sdp: RTCSessionDescription,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn connect(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: UserContext,
|
||||||
|
Path(channel_id): Path<entity::channel::Id>,
|
||||||
|
WithRejection(Json(payload), _): WithRejection<Json<Payload>, web::Error>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
tracing::debug!("connect to voice channel: {:?}", channel_id);
|
||||||
|
|
||||||
|
let channel = state.database.select_channel_by_id(channel_id).await?;
|
||||||
|
let channel_id = channel.id;
|
||||||
|
|
||||||
|
let room_sender = {
|
||||||
|
state
|
||||||
|
.voice_rooms
|
||||||
|
.read()
|
||||||
|
.await
|
||||||
|
.get(&channel_id)
|
||||||
|
.map(|room| room.clone())
|
||||||
|
};
|
||||||
|
|
||||||
|
let room_sender = match room_sender {
|
||||||
|
Some(room) => room,
|
||||||
|
None => {
|
||||||
|
let (tx, rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
let rooms = state.voice_rooms.clone();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
crate::webrtc::webrtc_task(channel_id, rx)
|
||||||
|
.await
|
||||||
|
.unwrap_or_else(|err| {
|
||||||
|
tracing::error!("webrtc task error: {:?}", err);
|
||||||
|
});
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut rooms = rooms.write().await;
|
||||||
|
rooms.remove(&channel_id);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut rooms = state.voice_rooms.write().await;
|
||||||
|
rooms.insert(channel_id, tx.clone());
|
||||||
|
}
|
||||||
|
|
||||||
|
tx
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
let offer = crate::webrtc::Offer {
|
||||||
|
peer_id: context.user_id,
|
||||||
|
sdp_offer: payload.sdp,
|
||||||
|
};
|
||||||
|
|
||||||
|
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
|
||||||
|
let _ = room_sender.send(crate::webrtc::OfferSignal {
|
||||||
|
offer,
|
||||||
|
response: response_tx,
|
||||||
|
});
|
||||||
|
|
||||||
|
let answer = response_rx
|
||||||
|
.await
|
||||||
|
.map_err(|_| web::error::ClientError::InternalServerError)?;
|
||||||
|
|
||||||
|
let response = Response {
|
||||||
|
sdp: answer.sdp_answer,
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(Json(response))
|
||||||
|
}
|
||||||
3
src/web/route/voice/mod.rs
Normal file
3
src/web/route/voice/mod.rs
Normal file
@@ -0,0 +1,3 @@
|
|||||||
|
mod connect;
|
||||||
|
|
||||||
|
pub use connect::*;
|
||||||
340
src/web/ws/gateway/connection.rs
Normal file
340
src/web/ws/gateway/connection.rs
Normal file
@@ -0,0 +1,340 @@
|
|||||||
|
// src/web/ws/connection.rs
|
||||||
|
|
||||||
|
use std::ops::ControlFlow;
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use axum::extract::ws::{Message as AxumMessage, WebSocket};
|
||||||
|
use base64::Engine as _; // Bring trait into scope
|
||||||
|
use futures::stream::SplitStream;
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use sha2::{Digest, Sha256};
|
||||||
|
use tokio::time::Instant;
|
||||||
|
|
||||||
|
// Use items from sibling modules within `ws`
|
||||||
|
use super::error::{self, Error as WsError};
|
||||||
|
// Assuming Event type is from ws::message
|
||||||
|
use super::message::Event as WsEvent;
|
||||||
|
use super::protocol::{
|
||||||
|
SendWsMessage, WsClientMessage, WsServerMessage, deserialize_ws_message, serialize_ws_message,
|
||||||
|
};
|
||||||
|
use super::state::{WsContext, WsState, WsUserContext};
|
||||||
|
use crate::state::AppState;
|
||||||
|
|
||||||
|
/// Main handler for an individual WebSocket connection's lifecycle.
|
||||||
|
/// Spawned by Axum upon successful WebSocket upgrade.
|
||||||
|
#[tracing::instrument(skip_all, name = "ws_connection_handler")]
|
||||||
|
pub async fn handle_socket_connection(websocket: WebSocket, app_state: AppState) {
|
||||||
|
let (ws_sink, ws_stream) = websocket.split();
|
||||||
|
|
||||||
|
let (internal_send_tx, mut internal_send_rx) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
// Writer task: consumes messages from MPSC channel and sends them to the WebSocket sink.
|
||||||
|
let writer_task = tokio::spawn(async move {
|
||||||
|
let mut ws_sink_mut = ws_sink;
|
||||||
|
while let Some(SendWsMessage {
|
||||||
|
message,
|
||||||
|
response_ch,
|
||||||
|
}) = internal_send_rx.recv().await
|
||||||
|
{
|
||||||
|
let send_result = match serialize_ws_message(message) {
|
||||||
|
Ok(ws_msg) => {
|
||||||
|
if ws_sink_mut.send(ws_msg).await.is_err() {
|
||||||
|
Err(WsError::WebSocketClosed) // Send to client failed
|
||||||
|
} else {
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
},
|
||||||
|
Err(e) => Err(e), // Serialization error itself
|
||||||
|
};
|
||||||
|
|
||||||
|
if let Some(ch) = response_ch {
|
||||||
|
if ch.send(send_result).is_err() {
|
||||||
|
// Log if the receiver of the acknowledgement was dropped, though this is unlikely
|
||||||
|
// if send_with_response is awaiting it.
|
||||||
|
tracing::debug!("Failed to send acknowledgement; receiver dropped.");
|
||||||
|
}
|
||||||
|
} else if let Err(e) = send_result {
|
||||||
|
// For fire-and-forget, log critical errors (not just WebSocketClosed).
|
||||||
|
if !matches!(e, WsError::WebSocketClosed) {
|
||||||
|
tracing::warn!("Error in fire-and-forget WebSocket send: {:?}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// MPSC channel closed, attempt to gracefully close WebSocket.
|
||||||
|
if ws_sink_mut.close().await.is_err() {
|
||||||
|
tracing::debug!("Error closing WebSocket sink; connection might be already dead.");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut context = WsContext {
|
||||||
|
connection_state: WsState::Initialize,
|
||||||
|
user_context: None,
|
||||||
|
heartbeat_interval: std::time::Duration::from_secs(30), // Assuming config path
|
||||||
|
next_ping_deadline: Instant::now(), // Will be properly set before first use
|
||||||
|
event_channel: None,
|
||||||
|
};
|
||||||
|
let processing_result = process_websocket_messages(
|
||||||
|
&mut context,
|
||||||
|
ws_stream,
|
||||||
|
&internal_send_tx, // Pass as reference
|
||||||
|
&app_state,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
// --- Cleanup ---
|
||||||
|
if let Some(user_ctx_data) = &context.user_context {
|
||||||
|
app_state
|
||||||
|
.unregister_event_connected_user(user_ctx_data.user_id, &user_ctx_data.session_key)
|
||||||
|
.await;
|
||||||
|
tracing::info!(user_id = ?user_ctx_data.user_id, session_key = %user_ctx_data.session_key, "Unregistered WebSocket user.");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Drop our sender for the event channel; receiver in `process_websocket_messages` will see this.
|
||||||
|
drop(context.event_channel.take());
|
||||||
|
|
||||||
|
// If processing loop exited with an error (not a graceful close like WebSocketClosed or HeartbeatTimeout),
|
||||||
|
// try to send a final error message to the client.
|
||||||
|
if let Err(err_to_report) = &processing_result {
|
||||||
|
if !matches!(
|
||||||
|
err_to_report,
|
||||||
|
WsError::WebSocketClosed | WsError::HeartbeatTimeout
|
||||||
|
) {
|
||||||
|
tracing::warn!(
|
||||||
|
"WebSocket processing error, attempting to notify client: {:?}",
|
||||||
|
err_to_report
|
||||||
|
);
|
||||||
|
let client_err_code = err_to_report.into_client_error();
|
||||||
|
let error_ws_message = WsServerMessage::Error {
|
||||||
|
code: client_err_code,
|
||||||
|
};
|
||||||
|
// Use new_no_response for best-effort send during shutdown.
|
||||||
|
// Ignore result as internal_send_tx might already be closed if writer_task ended.
|
||||||
|
let _ = internal_send_tx.send(SendWsMessage::new_no_response(error_ws_message));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Signal writer task to stop by dropping the MPSC sender.
|
||||||
|
drop(internal_send_tx);
|
||||||
|
// Wait for the writer task to complete its shutdown.
|
||||||
|
if let Err(e) = writer_task.await {
|
||||||
|
tracing::error!(
|
||||||
|
"WebSocket writer task panicked or encountered an error: {:?}",
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
tracing::debug!(result = ?processing_result, "WebSocket connection handler finished.");
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Main loop for processing incoming WebSocket messages and outgoing application events.
|
||||||
|
/// Manages state transitions (Initialize -> Connected) and heartbeating.
|
||||||
|
#[tracing::instrument(skip_all, fields(state = ?context.connection_state, user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))]
|
||||||
|
async fn process_websocket_messages(
|
||||||
|
context: &mut WsContext,
|
||||||
|
mut ws_stream: SplitStream<WebSocket>,
|
||||||
|
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>, // Changed to reference
|
||||||
|
app_state: &AppState,
|
||||||
|
) -> error::Result<()> {
|
||||||
|
// Send initial heartbeat interval and set first deadline.
|
||||||
|
SendWsMessage::send_with_response(
|
||||||
|
sender,
|
||||||
|
WsServerMessage::HeartbeatInterval {
|
||||||
|
interval: context.heartbeat_interval,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
context.reset_deadline();
|
||||||
|
|
||||||
|
loop {
|
||||||
|
match context.connection_state {
|
||||||
|
WsState::Initialize => {
|
||||||
|
tokio::select! {
|
||||||
|
biased; // Prefer timeout check if multiple branches are ready
|
||||||
|
_ = tokio::time::sleep_until(context.next_ping_deadline) => {
|
||||||
|
tracing::warn!("Initial connection timeout (no Authenticate or Ping).");
|
||||||
|
return Err(WsError::HeartbeatTimeout);
|
||||||
|
}
|
||||||
|
maybe_message = ws_stream.next() => {
|
||||||
|
match maybe_message {
|
||||||
|
Some(Ok(message)) => {
|
||||||
|
match handle_initial_message(context, message, sender, app_state).await {
|
||||||
|
Ok(ControlFlow::Continue(())) => {},
|
||||||
|
Ok(ControlFlow::Break(new_state)) => { // Authenticated
|
||||||
|
context.connection_state = new_state;
|
||||||
|
tracing::info!(user_id = ?context.user_context.as_ref().unwrap().user_id, "User authenticated, WebSocket connected.");
|
||||||
|
},
|
||||||
|
Err(e) => { // Auth failed critically or other error
|
||||||
|
return Err(e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(Err(axum_ws_err)) => {
|
||||||
|
tracing::debug!("WebSocket stream error during Initialize: {:?}", axum_ws_err);
|
||||||
|
return Err(WsError::WebSocketClosed);
|
||||||
|
}
|
||||||
|
None => { // Stream closed by client
|
||||||
|
tracing::debug!("WebSocket stream ended by client during Initialize state.");
|
||||||
|
return Err(WsError::WebSocketClosed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
WsState::Connected => {
|
||||||
|
let user_ctx = context
|
||||||
|
.user_context
|
||||||
|
.as_ref()
|
||||||
|
.expect("User context must be set in Connected state");
|
||||||
|
let (_event_tx_ref, event_rx) = context
|
||||||
|
.event_channel
|
||||||
|
.as_mut()
|
||||||
|
.expect("Event channel must be set in Connected state");
|
||||||
|
|
||||||
|
tokio::select! {
|
||||||
|
biased;
|
||||||
|
_ = tokio::time::sleep_until(context.next_ping_deadline) => {
|
||||||
|
tracing::warn!(user_id = ?user_ctx.user_id, "Heartbeat timeout.");
|
||||||
|
return Err(WsError::HeartbeatTimeout);
|
||||||
|
}
|
||||||
|
// Listen for application events to send to the client
|
||||||
|
maybe_app_event = event_rx.recv() => {
|
||||||
|
if let Some(app_event_data) = maybe_app_event {
|
||||||
|
SendWsMessage::send_with_response(sender, WsServerMessage::Event { event: app_event_data }).await?;
|
||||||
|
// Sending an app event doesn't reset the client's ping requirement.
|
||||||
|
} else {
|
||||||
|
// Event channel closed (e.g., AppState unregistered, or system shutdown signal)
|
||||||
|
tracing::info!(user_id = ?user_ctx.user_id, "Event channel closed, closing WebSocket.");
|
||||||
|
return Ok(()); // Graceful shutdown signaled by closed event channel
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Listen for messages from the client (e.g., Ping)
|
||||||
|
maybe_ws_message = ws_stream.next() => {
|
||||||
|
match maybe_ws_message {
|
||||||
|
Some(Ok(message)) => {
|
||||||
|
handle_connected_message(context, message, sender).await?;
|
||||||
|
}
|
||||||
|
Some(Err(axum_ws_err)) => {
|
||||||
|
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream error during Connected: {:?}", axum_ws_err);
|
||||||
|
return Err(WsError::WebSocketClosed);
|
||||||
|
}
|
||||||
|
None => { // Stream closed by client
|
||||||
|
tracing::debug!(user_id = ?user_ctx.user_id, "WebSocket stream ended by client during Connected state.");
|
||||||
|
return Err(WsError::WebSocketClosed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handles messages received when the connection is in the `Initialize` state.
|
||||||
|
/// Expects `Authenticate` to transition to `Connected`, or `Ping` to stay in `Initialize`.
|
||||||
|
#[tracing::instrument(skip_all, fields(state = ?context.connection_state))]
|
||||||
|
async fn handle_initial_message(
|
||||||
|
context: &mut WsContext,
|
||||||
|
message: AxumMessage,
|
||||||
|
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>, // Changed to reference
|
||||||
|
app_state: &AppState,
|
||||||
|
) -> error::Result<ControlFlow<WsState, ()>> {
|
||||||
|
// Break(NewState) or Continue(())
|
||||||
|
match deserialize_ws_message(message)? {
|
||||||
|
WsClientMessage::Authenticate { token } => {
|
||||||
|
// IMPORTANT: Adjust the call below to your actual token validation logic.
|
||||||
|
// Assuming `get_context_from_token` returns `Result<crate::web::context::UserContext, YourAuthError>`
|
||||||
|
match crate::web::middleware::get_context_from_token(
|
||||||
|
&app_state, // Example: Pass necessary parts of AppState
|
||||||
|
&token,
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(auth_user_context) => {
|
||||||
|
// auth_user_context is `crate::web::context::UserContext`
|
||||||
|
let user_id = auth_user_context.user_id;
|
||||||
|
|
||||||
|
let (event_tx, event_rx) = tokio::sync::mpsc::unbounded_channel::<WsEvent>();
|
||||||
|
context.event_channel = Some((event_tx.clone(), event_rx));
|
||||||
|
|
||||||
|
let random_key_part = rand::random::<u64>();
|
||||||
|
let current_session_key = {
|
||||||
|
let mut hasher = Sha256::new();
|
||||||
|
hasher.update(token.as_bytes());
|
||||||
|
hasher.update(user_id.to_string().as_bytes());
|
||||||
|
hasher.update(&random_key_part.to_be_bytes());
|
||||||
|
base64::engine::general_purpose::URL_SAFE_NO_PAD.encode(hasher.finalize())
|
||||||
|
};
|
||||||
|
|
||||||
|
context.user_context = Some(WsUserContext {
|
||||||
|
user_id,
|
||||||
|
session_key: current_session_key.clone(),
|
||||||
|
});
|
||||||
|
|
||||||
|
app_state
|
||||||
|
.register_event_connected_user(
|
||||||
|
user_id,
|
||||||
|
current_session_key.clone(),
|
||||||
|
event_tx, // This is ws::state::EventSender -> mpsc::UnboundedSender<ws::message::Event>
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
SendWsMessage::send_with_response(
|
||||||
|
sender,
|
||||||
|
WsServerMessage::AuthenticateAccepted {
|
||||||
|
user_id,
|
||||||
|
session_key: current_session_key.clone(),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await?;
|
||||||
|
// Deadline is reset by the caller upon ControlFlow::Break
|
||||||
|
Ok(ControlFlow::Break(WsState::Connected))
|
||||||
|
},
|
||||||
|
Err(_auth_err) => {
|
||||||
|
tracing::warn!(token = %token, "Authentication failed for token.");
|
||||||
|
// Send AuthenticateDenied, then the connection will be closed by HeartbeatTimeout or by returning error.
|
||||||
|
// We send response to ensure client gets the denial before we might drop connection.
|
||||||
|
let _ = SendWsMessage::send_with_response(
|
||||||
|
sender,
|
||||||
|
WsServerMessage::AuthenticateDenied,
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
Err(WsError::AuthenticationFailed) // This will terminate process_websocket_messages
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
WsClientMessage::Ping => {
|
||||||
|
context.reset_deadline(); // Reset deadline on successful ping
|
||||||
|
SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await?;
|
||||||
|
Ok(ControlFlow::Continue(()))
|
||||||
|
},
|
||||||
|
// Per original code, only Authenticate and Ping are expected in Initialize.
|
||||||
|
// If WsClientMessage has other variants, this might need adjustment.
|
||||||
|
#[allow(unreachable_patterns)]
|
||||||
|
_ => {
|
||||||
|
tracing::warn!("Unexpected message type received during Initialize state.");
|
||||||
|
Err(WsError::UnexpectedMessageType)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Handles messages received when the connection is in the `Connected` state.
|
||||||
|
/// Primarily expects `Ping` messages to keep the connection alive.
|
||||||
|
#[tracing::instrument(skip_all, fields(user_id = ?context.user_context.as_ref().map(|uc| uc.user_id)))]
|
||||||
|
async fn handle_connected_message(
|
||||||
|
context: &mut WsContext, // Although not heavily used here, good for consistency and tracing
|
||||||
|
message: AxumMessage,
|
||||||
|
sender: &tokio::sync::mpsc::UnboundedSender<SendWsMessage>, // Changed to reference
|
||||||
|
) -> error::Result<()> {
|
||||||
|
match deserialize_ws_message(message)? {
|
||||||
|
WsClientMessage::Ping => {
|
||||||
|
tracing::debug!("Ping received.");
|
||||||
|
context.reset_deadline();
|
||||||
|
SendWsMessage::send_with_response(sender, WsServerMessage::Pong).await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
},
|
||||||
|
other_message => {
|
||||||
|
tracing::warn!(message_type = ?other_message, "Unexpected message type received during Connected state.");
|
||||||
|
Err(WsError::UnexpectedMessageType)
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
46
src/web/ws/gateway/error.rs
Normal file
46
src/web/ws/gateway/error.rs
Normal file
@@ -0,0 +1,46 @@
|
|||||||
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
|
|
||||||
|
#[derive(Debug, derive_more::From, derive_more::Display)]
|
||||||
|
pub enum Error {
|
||||||
|
#[from]
|
||||||
|
Axum(axum::Error),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
Json(serde_json::Error),
|
||||||
|
|
||||||
|
#[from]
|
||||||
|
AcknowledgementError(tokio::sync::oneshot::error::RecvError),
|
||||||
|
|
||||||
|
UnexpectedMessageType,
|
||||||
|
|
||||||
|
WrongMessageType,
|
||||||
|
|
||||||
|
WebSocketClosed,
|
||||||
|
|
||||||
|
HeartbeatTimeout,
|
||||||
|
AuthenticationFailed,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||||
|
pub enum ClientError {
|
||||||
|
DeserializationError,
|
||||||
|
NotAuthenticated,
|
||||||
|
AlreadyAuthenticated,
|
||||||
|
HeartbeatTimeout,
|
||||||
|
|
||||||
|
Unknown,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error {
|
||||||
|
pub fn into_client_error(&self) -> ClientError {
|
||||||
|
match self {
|
||||||
|
Error::HeartbeatTimeout => ClientError::HeartbeatTimeout,
|
||||||
|
Error::Json(_) => ClientError::DeserializationError,
|
||||||
|
Error::UnexpectedMessageType => ClientError::Unknown,
|
||||||
|
Error::WrongMessageType => ClientError::Unknown,
|
||||||
|
Error::WebSocketClosed => ClientError::Unknown,
|
||||||
|
_ => ClientError::Unknown,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
15
src/web/ws/gateway/message.rs
Normal file
15
src/web/ws/gateway/message.rs
Normal file
@@ -0,0 +1,15 @@
|
|||||||
|
use crate::entity;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, serde::Serialize)]
|
||||||
|
#[serde(tag = "type", content = "data")]
|
||||||
|
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||||
|
pub enum Event {
|
||||||
|
AddServer { server: entity::server::Server },
|
||||||
|
RemoveServer { server_id: entity::server::Id },
|
||||||
|
|
||||||
|
AddDmChannel { channel: entity::channel::Channel },
|
||||||
|
RemoveDmChannel { channel_id: entity::channel::Id },
|
||||||
|
|
||||||
|
AddServerChannel { channel: entity::channel::Channel },
|
||||||
|
RemoveServerChannel { channel_id: entity::channel::Id },
|
||||||
|
}
|
||||||
29
src/web/ws/gateway/mod.rs
Normal file
29
src/web/ws/gateway/mod.rs
Normal file
@@ -0,0 +1,29 @@
|
|||||||
|
use axum::extract::{State, WebSocketUpgrade};
|
||||||
|
use axum::response::IntoResponse;
|
||||||
|
use base64::Engine;
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
use sha2::Digest;
|
||||||
|
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web::ws::gateway::connection::handle_socket_connection;
|
||||||
|
use crate::web::ws::gateway::state::EventSender;
|
||||||
|
|
||||||
|
mod connection;
|
||||||
|
mod error;
|
||||||
|
pub mod message;
|
||||||
|
mod protocol;
|
||||||
|
mod state;
|
||||||
|
pub mod util;
|
||||||
|
|
||||||
|
#[derive(Debug, Default)]
|
||||||
|
pub struct EventWsState {
|
||||||
|
pub connection_instance: DashMap<String, EventSender>,
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ws_handler(
|
||||||
|
State(app_state): State<AppState>,
|
||||||
|
ws: WebSocketUpgrade,
|
||||||
|
) -> crate::web::error::Result<impl IntoResponse> {
|
||||||
|
Ok(ws.on_upgrade(|socket| handle_socket_connection(socket, app_state)))
|
||||||
|
}
|
||||||
99
src/web/ws/gateway/protocol.rs
Normal file
99
src/web/ws/gateway/protocol.rs
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
// src/web/ws/protocol.rs
|
||||||
|
|
||||||
|
use axum::extract::ws::Message as AxumMessage;
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::time::Duration;
|
||||||
|
|
||||||
|
use super::error::{self, ClientError, Error as WsError};
|
||||||
|
use super::message as ws_local_message; // For ws::message::Event
|
||||||
|
use crate::entity;
|
||||||
|
use crate::util as crate_root_util; // For crate::util::serialize_duration_seconds
|
||||||
|
|
||||||
|
#[derive(Debug, Serialize)]
|
||||||
|
#[serde(tag = "type", content = "data")]
|
||||||
|
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||||
|
pub enum WsServerMessage {
|
||||||
|
HeartbeatInterval {
|
||||||
|
#[serde(serialize_with = "crate_root_util::serialize_duration_seconds")]
|
||||||
|
interval: Duration,
|
||||||
|
},
|
||||||
|
AuthenticateDenied,
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
AuthenticateAccepted {
|
||||||
|
user_id: entity::user::Id,
|
||||||
|
session_key: String,
|
||||||
|
},
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
Event {
|
||||||
|
event: ws_local_message::Event, // Assumes Event is defined in ws::message
|
||||||
|
},
|
||||||
|
Pong,
|
||||||
|
Error {
|
||||||
|
code: ClientError,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Deserialize)]
|
||||||
|
#[serde(tag = "type", content = "data")]
|
||||||
|
#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
|
||||||
|
pub enum WsClientMessage {
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
Authenticate {
|
||||||
|
token: String,
|
||||||
|
},
|
||||||
|
Ping,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Deserializes an Axum WebSocket message into a `WsClientMessage`.
|
||||||
|
pub fn deserialize_ws_message(message: AxumMessage) -> error::Result<WsClientMessage> {
|
||||||
|
match message {
|
||||||
|
AxumMessage::Text(text) => serde_json::from_str(&text).map_err(WsError::from),
|
||||||
|
AxumMessage::Close(_) => Err(WsError::WebSocketClosed),
|
||||||
|
_ => Err(WsError::WrongMessageType), // e.g. Binary, Ping, Pong from axum::Message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Serializes a `WsServerMessage` into an Axum WebSocket message.
|
||||||
|
pub fn serialize_ws_message(message: WsServerMessage) -> error::Result<AxumMessage> {
|
||||||
|
serde_json::to_string(&message)
|
||||||
|
.map(Into::into)
|
||||||
|
.map(AxumMessage::Text)
|
||||||
|
.map_err(WsError::from)
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Wrapper for messages sent over an internal MPSC channel to the WebSocket writer task.
|
||||||
|
/// Includes an optional one-shot channel for acknowledgements or error reporting back from the writer.
|
||||||
|
pub struct SendWsMessage {
|
||||||
|
pub message: WsServerMessage,
|
||||||
|
pub response_ch: Option<tokio::sync::oneshot::Sender<error::Result<()>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl SendWsMessage {
|
||||||
|
/// Sends a message over the MPSC channel and awaits a response via a oneshot channel.
|
||||||
|
pub async fn send_with_response(
|
||||||
|
tx: &tokio::sync::mpsc::UnboundedSender<Self>, // Changed to reference
|
||||||
|
message: WsServerMessage,
|
||||||
|
) -> error::Result<()> {
|
||||||
|
let (response_tx, response_rx) = tokio::sync::oneshot::channel();
|
||||||
|
let send_message = SendWsMessage {
|
||||||
|
message,
|
||||||
|
response_ch: Some(response_tx),
|
||||||
|
};
|
||||||
|
|
||||||
|
if tx.send(send_message).is_err() {
|
||||||
|
Err(WsError::WebSocketClosed) // MPSC channel closed, writer task likely dead
|
||||||
|
} else {
|
||||||
|
// Wait for the writer task to acknowledge the send attempt.
|
||||||
|
// This will return Ok(Ok(())) on success, Ok(Err(e)) on write error, or Err on channel error.
|
||||||
|
response_rx.await? // Propagates RecvError into WsError::AcknowledgementError
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Creates a new message for fire-and-forget sending (no response/acknowledgement expected).
|
||||||
|
pub fn new_no_response(message: WsServerMessage) -> Self {
|
||||||
|
SendWsMessage {
|
||||||
|
message,
|
||||||
|
response_ch: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
49
src/web/ws/gateway/state.rs
Normal file
49
src/web/ws/gateway/state.rs
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
// src/web/ws/state.rs
|
||||||
|
|
||||||
|
use std::time::{Duration};
|
||||||
|
|
||||||
|
use tokio::sync::mpsc;
|
||||||
|
|
||||||
|
use super::message;
|
||||||
|
use crate::entity; // For entity::user::Id // For ws::message::Event used in EventSender/Receiver
|
||||||
|
|
||||||
|
/// Represents the current state of a single WebSocket connection.
|
||||||
|
#[derive(Debug, Eq, PartialEq, Clone, Copy)]
|
||||||
|
pub enum WsState {
|
||||||
|
Initialize, // Connection established, awaiting authentication
|
||||||
|
Connected, // Authenticated and operational
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Contextual information for an authenticated WebSocket user session.
|
||||||
|
#[derive(Debug, Clone)] // Clone might be useful
|
||||||
|
pub struct WsUserContext {
|
||||||
|
pub user_id: entity::user::Id,
|
||||||
|
pub session_key: String, // Unique key for this specific WebSocket session instance
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Sender part of an MPSC channel used to send `ws::message::Event`s to a connected client.
|
||||||
|
pub type EventSender = mpsc::UnboundedSender<message::Event>;
|
||||||
|
/// Receiver part of an MPSC channel used by a connection task to receive `ws::message::Event`s.
|
||||||
|
pub type EventReceiver = mpsc::UnboundedReceiver<message::Event>;
|
||||||
|
|
||||||
|
/// Holds the full context for a single WebSocket connection's lifecycle.
|
||||||
|
/// This struct is managed per-connection.
|
||||||
|
pub struct WsContext {
|
||||||
|
pub connection_state: WsState,
|
||||||
|
pub user_context: Option<WsUserContext>,
|
||||||
|
pub heartbeat_interval: Duration,
|
||||||
|
pub next_ping_deadline: tokio::time::Instant,
|
||||||
|
/// Channel for receiving application-specific events to be sent to this client.
|
||||||
|
/// The `EventSender` (tx) part is given to `AppState` for broadcasting.
|
||||||
|
/// The `EventReceiver` (rx) part is polled by the connection task.
|
||||||
|
pub event_channel: Option<(EventSender, EventReceiver)>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl WsContext {
|
||||||
|
/// Resets the ping deadline based on the current time and heartbeat interval.
|
||||||
|
/// This should be called after successfully receiving a ping from the client
|
||||||
|
/// or after sending a message that implies activity (like Pong or initial auth).
|
||||||
|
pub fn reset_deadline(&mut self) {
|
||||||
|
self.next_ping_deadline = tokio::time::Instant::now() + self.heartbeat_interval;
|
||||||
|
}
|
||||||
|
}
|
||||||
14
src/web/ws/gateway/util.rs
Normal file
14
src/web/ws/gateway/util.rs
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
use crate::entity;
|
||||||
|
use crate::state::AppState;
|
||||||
|
use crate::web::ws::gateway::message;
|
||||||
|
|
||||||
|
pub async fn send_message(state: &AppState, user_id: entity::user::Id, message: message::Event) {
|
||||||
|
let connected_users = state.event_connected_users.read().await;
|
||||||
|
if let Some(state) = connected_users.get(&user_id) {
|
||||||
|
for instance in state.connection_instance.iter() {
|
||||||
|
if let Err(e) = instance.send(message.clone()) {
|
||||||
|
tracing::error!("failed to send message: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
1
src/web/ws/mod.rs
Normal file
1
src/web/ws/mod.rs
Normal file
@@ -0,0 +1 @@
|
|||||||
|
pub mod gateway;
|
||||||
288
src/webrtc/mod.rs
Normal file
288
src/webrtc/mod.rs
Normal file
@@ -0,0 +1,288 @@
|
|||||||
|
use std::sync::Arc;
|
||||||
|
|
||||||
|
use dashmap::DashMap;
|
||||||
|
use tracing::Instrument;
|
||||||
|
use webrtc::api::interceptor_registry::register_default_interceptors;
|
||||||
|
use webrtc::api::media_engine::MIME_TYPE_OPUS;
|
||||||
|
use webrtc::api::{API, APIBuilder};
|
||||||
|
use webrtc::interceptor::registry::Registry;
|
||||||
|
use webrtc::peer_connection::configuration::RTCConfiguration;
|
||||||
|
use webrtc::peer_connection::peer_connection_state::RTCPeerConnectionState;
|
||||||
|
use webrtc::peer_connection::sdp::session_description::RTCSessionDescription;
|
||||||
|
use webrtc::rtp_transceiver::rtp_codec::{RTCRtpCodecCapability, RTPCodecType};
|
||||||
|
use webrtc::track::track_local::track_local_static_rtp::TrackLocalStaticRTP;
|
||||||
|
use webrtc::track::track_local::{TrackLocal, TrackLocalWriter};
|
||||||
|
use webrtc::track::track_remote::TrackRemote;
|
||||||
|
|
||||||
|
use crate::entity;
|
||||||
|
|
||||||
|
type PeerId = entity::user::Id;
|
||||||
|
type RoomId = entity::channel::Id;
|
||||||
|
|
||||||
|
struct PeerState {
|
||||||
|
peer_id: PeerId,
|
||||||
|
peer_connection: Arc<webrtc::peer_connection::RTCPeerConnection>,
|
||||||
|
outgoing_audio_track: Arc<TrackLocalStaticRTP>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct RoomState {
|
||||||
|
room_id: RoomId,
|
||||||
|
peers: DashMap<PeerId, PeerState>,
|
||||||
|
close_signal: tokio::sync::mpsc::UnboundedSender<()>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct Offer {
|
||||||
|
pub peer_id: PeerId,
|
||||||
|
pub sdp_offer: RTCSessionDescription,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct AnswerSignal {
|
||||||
|
pub sdp_answer: RTCSessionDescription,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct OfferSignal {
|
||||||
|
pub offer: Offer,
|
||||||
|
pub response: tokio::sync::oneshot::Sender<AnswerSignal>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(signal))]
|
||||||
|
pub async fn webrtc_task(
|
||||||
|
room_id: RoomId,
|
||||||
|
signal: tokio::sync::mpsc::UnboundedReceiver<OfferSignal>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
tracing::info!("Starting WebRTC task");
|
||||||
|
|
||||||
|
let (close_signal, mut close_receiver) = tokio::sync::mpsc::unbounded_channel();
|
||||||
|
|
||||||
|
let state = Arc::new(RoomState {
|
||||||
|
room_id,
|
||||||
|
peers: DashMap::new(),
|
||||||
|
close_signal,
|
||||||
|
});
|
||||||
|
|
||||||
|
let mut signal = signal;
|
||||||
|
let mut media_engine = webrtc::api::media_engine::MediaEngine::default();
|
||||||
|
media_engine.register_default_codecs()?;
|
||||||
|
let mut registry = Registry::new();
|
||||||
|
registry = register_default_interceptors(registry, &mut media_engine)?;
|
||||||
|
|
||||||
|
let api = APIBuilder::new()
|
||||||
|
.with_media_engine(media_engine)
|
||||||
|
.with_interceptor_registry(registry)
|
||||||
|
.build();
|
||||||
|
|
||||||
|
let api = Arc::new(api);
|
||||||
|
|
||||||
|
loop {
|
||||||
|
tokio::select! {
|
||||||
|
Some(signal) = signal.recv() => {
|
||||||
|
let room_state = state.clone();
|
||||||
|
let api = api.clone();
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = handle_peer(api, room_state, signal).await {
|
||||||
|
tracing::error!("error handling peer: {}", e);
|
||||||
|
}
|
||||||
|
}.instrument(tracing::Span::current()));
|
||||||
|
}
|
||||||
|
_ = close_receiver.recv() => {
|
||||||
|
tracing::debug!("WebRTC task stopped");
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(api, room_state, offer_signal), fields(peer_id = %offer_signal.offer.peer_id))]
|
||||||
|
async fn handle_peer(
|
||||||
|
api: Arc<API>,
|
||||||
|
room_state: Arc<RoomState>,
|
||||||
|
offer_signal: OfferSignal,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
tracing::debug!("handling peer");
|
||||||
|
let config = RTCConfiguration {
|
||||||
|
..Default::default()
|
||||||
|
};
|
||||||
|
|
||||||
|
let outgoing_track = Arc::new(TrackLocalStaticRTP::new(
|
||||||
|
RTCRtpCodecCapability {
|
||||||
|
mime_type: MIME_TYPE_OPUS.to_string(),
|
||||||
|
..Default::default()
|
||||||
|
},
|
||||||
|
format!("audio-{}", offer_signal.offer.peer_id),
|
||||||
|
format!("room-{}", room_state.room_id),
|
||||||
|
));
|
||||||
|
|
||||||
|
let peer_connection = Arc::new(api.new_peer_connection(config).await?);
|
||||||
|
|
||||||
|
let rtp_sender = peer_connection
|
||||||
|
.add_track(Arc::clone(&outgoing_track) as Arc<dyn TrackLocal + Send + Sync>)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
tracing::debug!("added track to peer connection: {:?}", rtp_sender);
|
||||||
|
|
||||||
|
// Read RTCP packets for the outgoing track (important for feedback like NACKs, PLI)
|
||||||
|
let outgoing_track_ = Arc::clone(&outgoing_track);
|
||||||
|
tokio::spawn(
|
||||||
|
async move {
|
||||||
|
let mut rtcp_buf = vec![0u8; 1500];
|
||||||
|
while let Ok((_, _)) = rtp_sender.read(&mut rtcp_buf).await {
|
||||||
|
// Process RTCP if needed (e.g., bandwidth estimation, custom feedback)
|
||||||
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
"RTCP reader loop for outgoing track {} ended",
|
||||||
|
outgoing_track_.id()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
.instrument(tracing::Span::current()),
|
||||||
|
);
|
||||||
|
|
||||||
|
let room_state_ = room_state.clone();
|
||||||
|
peer_connection.on_peer_connection_state_change(Box::new(move |state| {
|
||||||
|
let room_state_ = room_state_.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
tracing::debug!("peer connection state changed: {:?}", state);
|
||||||
|
if state == RTCPeerConnectionState::Disconnected
|
||||||
|
|| state == RTCPeerConnectionState::Failed
|
||||||
|
{
|
||||||
|
tracing::debug!("peer connection closed");
|
||||||
|
cleanup_peer(room_state_, offer_signal.offer.peer_id).await;
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}));
|
||||||
|
|
||||||
|
let room_state_ = room_state.clone();
|
||||||
|
peer_connection.on_track(Box::new(move |track, _receiver, _transceiver| {
|
||||||
|
tracing::debug!("track received: {:?}", track);
|
||||||
|
|
||||||
|
let room_state_ = room_state_.clone();
|
||||||
|
Box::pin(async move {
|
||||||
|
if track.kind() == RTPCodecType::Audio {
|
||||||
|
tracing::debug!("audio track received: {:?}", track);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) =
|
||||||
|
forward_audio_track(room_state_, offer_signal.offer.peer_id, track).await
|
||||||
|
{
|
||||||
|
tracing::error!("error forwarding audio track: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
tracing::warn!("received non-audio track: {:?}", track);
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}));
|
||||||
|
|
||||||
|
peer_connection
|
||||||
|
.set_remote_description(offer_signal.offer.sdp_offer)
|
||||||
|
.await?;
|
||||||
|
let answer = peer_connection.create_answer(None).await?;
|
||||||
|
|
||||||
|
let mut gathering_complete = peer_connection.gathering_complete_promise().await;
|
||||||
|
|
||||||
|
peer_connection.set_local_description(answer).await?;
|
||||||
|
|
||||||
|
gathering_complete.recv().await;
|
||||||
|
|
||||||
|
tracing::debug!("ICE gathering complete");
|
||||||
|
|
||||||
|
let local_description = peer_connection
|
||||||
|
.local_description()
|
||||||
|
.await
|
||||||
|
.ok_or_else(|| anyhow::anyhow!("failed to get local description after setting it"))?;
|
||||||
|
|
||||||
|
let peer_state = PeerState {
|
||||||
|
peer_id: offer_signal.offer.peer_id,
|
||||||
|
peer_connection: Arc::clone(&peer_connection),
|
||||||
|
outgoing_audio_track: Arc::clone(&outgoing_track),
|
||||||
|
};
|
||||||
|
|
||||||
|
{
|
||||||
|
if let Some(old_peer) = room_state
|
||||||
|
.peers
|
||||||
|
.insert(offer_signal.offer.peer_id, peer_state)
|
||||||
|
{
|
||||||
|
let _ = old_peer.peer_connection.close().await;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = offer_signal.response.send(AnswerSignal {
|
||||||
|
sdp_answer: local_description,
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(room_state, track), fields(room_id = %room_state.room_id, peer_id = %peer_id))]
|
||||||
|
async fn forward_audio_track(
|
||||||
|
room_state: Arc<RoomState>,
|
||||||
|
peer_id: PeerId,
|
||||||
|
track: Arc<TrackRemote>,
|
||||||
|
) -> anyhow::Result<()> {
|
||||||
|
let mut rtp_buf = vec![0u8; 1500];
|
||||||
|
while let Ok((rtp_packet, _attr)) = track.read(&mut rtp_buf).await {
|
||||||
|
let other_peer_tracks = room_state
|
||||||
|
.peers
|
||||||
|
.iter()
|
||||||
|
.filter_map(|pair| {
|
||||||
|
let peer_state = pair.value();
|
||||||
|
if peer_state.peer_id != peer_id {
|
||||||
|
Some(peer_state.outgoing_audio_track.clone())
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
}
|
||||||
|
})
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
if other_peer_tracks.is_empty() {
|
||||||
|
// tracing::warn!("no other peers to forward audio track to");
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
|
||||||
|
let write_futures = other_peer_tracks
|
||||||
|
.iter()
|
||||||
|
.map(|outgoing_track| outgoing_track.write_rtp(&rtp_packet));
|
||||||
|
|
||||||
|
let results = futures::future::join_all(write_futures).await;
|
||||||
|
for result in results {
|
||||||
|
if let Err(e) = result {
|
||||||
|
tracing::error!("error writing RTP packet: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracing::debug!(
|
||||||
|
"RTP Read loop ended for track {} from peer {}",
|
||||||
|
track.id(),
|
||||||
|
peer_id
|
||||||
|
);
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tracing::instrument(skip(room_state), fields(room_id = %room_state.room_id))]
|
||||||
|
async fn cleanup_peer(room_state: Arc<RoomState>, peer_id: PeerId) {
|
||||||
|
tracing::debug!("cleaning up peer");
|
||||||
|
|
||||||
|
if let Some((_, peer_state)) = room_state.peers.remove(&peer_id) {
|
||||||
|
tracing::debug!("removed peer");
|
||||||
|
let pc = Arc::clone(&peer_state.peer_connection);
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = pc.close().await {
|
||||||
|
if !matches!(e, webrtc::Error::ErrConnectionClosed) {
|
||||||
|
tracing::warn!("error closing peer connection: {}", e);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::debug!("peer connection closed");
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
if room_state.peers.is_empty() {
|
||||||
|
tracing::debug!("no more peers in room, closing room");
|
||||||
|
let _ = room_state.close_signal.send(());
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user