z
This commit is contained in:
114
Cargo.lock
generated
114
Cargo.lock
generated
@@ -119,6 +119,8 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
|
|||||||
dependencies = [
|
dependencies = [
|
||||||
"async-trait",
|
"async-trait",
|
||||||
"axum-core",
|
"axum-core",
|
||||||
|
"axum-macros",
|
||||||
|
"base64 0.21.7",
|
||||||
"bytes",
|
"bytes",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"http",
|
"http",
|
||||||
@@ -137,8 +139,10 @@ dependencies = [
|
|||||||
"serde_json",
|
"serde_json",
|
||||||
"serde_path_to_error",
|
"serde_path_to_error",
|
||||||
"serde_urlencoded",
|
"serde_urlencoded",
|
||||||
|
"sha1",
|
||||||
"sync_wrapper 1.0.1",
|
"sync_wrapper 1.0.1",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-tungstenite",
|
||||||
"tower",
|
"tower",
|
||||||
"tower-layer",
|
"tower-layer",
|
||||||
"tower-service",
|
"tower-service",
|
||||||
@@ -166,6 +170,18 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "axum-macros"
|
||||||
|
version = "0.4.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "00c055ee2d014ae5981ce1016374e8213682aa14d9bf40e48ab48b5f3ef20eaa"
|
||||||
|
dependencies = [
|
||||||
|
"heck",
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "backtrace"
|
name = "backtrace"
|
||||||
version = "0.3.71"
|
version = "0.3.71"
|
||||||
@@ -341,6 +357,12 @@ dependencies = [
|
|||||||
"typenum",
|
"typenum",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "data-encoding"
|
||||||
|
version = "2.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "deranged"
|
name = "deranged"
|
||||||
version = "0.3.11"
|
version = "0.3.11"
|
||||||
@@ -465,6 +487,21 @@ dependencies = [
|
|||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures"
|
||||||
|
version = "0.3.30"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
|
||||||
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
|
"futures-core",
|
||||||
|
"futures-executor",
|
||||||
|
"futures-io",
|
||||||
|
"futures-sink",
|
||||||
|
"futures-task",
|
||||||
|
"futures-util",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-channel"
|
name = "futures-channel"
|
||||||
version = "0.3.30"
|
version = "0.3.30"
|
||||||
@@ -503,6 +540,23 @@ dependencies = [
|
|||||||
"parking_lot 0.11.2",
|
"parking_lot 0.11.2",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-io"
|
||||||
|
version = "0.3.30"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "futures-macro"
|
||||||
|
version = "0.3.30"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
|
||||||
|
dependencies = [
|
||||||
|
"proc-macro2",
|
||||||
|
"quote",
|
||||||
|
"syn 2.0.60",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "futures-sink"
|
name = "futures-sink"
|
||||||
version = "0.3.30"
|
version = "0.3.30"
|
||||||
@@ -521,9 +575,13 @@ version = "0.3.30"
|
|||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
|
"futures-channel",
|
||||||
"futures-core",
|
"futures-core",
|
||||||
|
"futures-io",
|
||||||
|
"futures-macro",
|
||||||
"futures-sink",
|
"futures-sink",
|
||||||
"futures-task",
|
"futures-task",
|
||||||
|
"memchr",
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
"pin-utils",
|
"pin-utils",
|
||||||
"slab",
|
"slab",
|
||||||
@@ -931,11 +989,13 @@ dependencies = [
|
|||||||
"chrono",
|
"chrono",
|
||||||
"derive_more",
|
"derive_more",
|
||||||
"figment",
|
"figment",
|
||||||
|
"futures",
|
||||||
"jsonwebtoken",
|
"jsonwebtoken",
|
||||||
"serde",
|
"serde",
|
||||||
"serde_json",
|
"serde_json",
|
||||||
"sqlx",
|
"sqlx",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tower-http",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-appender",
|
"tracing-appender",
|
||||||
"tracing-subscriber",
|
"tracing-subscriber",
|
||||||
@@ -1845,6 +1905,18 @@ dependencies = [
|
|||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-tungstenite"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||||
|
dependencies = [
|
||||||
|
"futures-util",
|
||||||
|
"log",
|
||||||
|
"tokio",
|
||||||
|
"tungstenite",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "toml"
|
name = "toml"
|
||||||
version = "0.8.12"
|
version = "0.8.12"
|
||||||
@@ -1895,6 +1967,23 @@ dependencies = [
|
|||||||
"tracing",
|
"tracing",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tower-http"
|
||||||
|
version = "0.5.2"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags 2.5.0",
|
||||||
|
"bytes",
|
||||||
|
"http",
|
||||||
|
"http-body",
|
||||||
|
"http-body-util",
|
||||||
|
"pin-project-lite",
|
||||||
|
"tower-layer",
|
||||||
|
"tower-service",
|
||||||
|
"tracing",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tower-layer"
|
name = "tower-layer"
|
||||||
version = "0.3.2"
|
version = "0.3.2"
|
||||||
@@ -1982,6 +2071,25 @@ dependencies = [
|
|||||||
"tracing-log",
|
"tracing-log",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tungstenite"
|
||||||
|
version = "0.21.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||||
|
dependencies = [
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"data-encoding",
|
||||||
|
"http",
|
||||||
|
"httparse",
|
||||||
|
"log",
|
||||||
|
"rand",
|
||||||
|
"sha1",
|
||||||
|
"thiserror",
|
||||||
|
"url",
|
||||||
|
"utf-8",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "typenum"
|
name = "typenum"
|
||||||
version = "1.17.0"
|
version = "1.17.0"
|
||||||
@@ -2053,6 +2161,12 @@ dependencies = [
|
|||||||
"percent-encoding",
|
"percent-encoding",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "utf-8"
|
||||||
|
version = "0.7.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "valuable"
|
name = "valuable"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
|
|||||||
@@ -5,10 +5,11 @@
|
|||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
anyhow = "1.0.82"
|
anyhow = "1.0.82"
|
||||||
axum = "0.7.5"
|
axum = { version = "0.7.5", features = ["macros", "ws"] }
|
||||||
chrono = { version = "0.4.38", features = ["serde"] }
|
chrono = { version = "0.4.38", features = ["serde"] }
|
||||||
derive_more = "0.99.17"
|
derive_more = "0.99.17"
|
||||||
figment = { version = "0.10.18", features = ["env", "toml"] }
|
figment = { version = "0.10.18", features = ["env", "toml"] }
|
||||||
|
futures = "0.3.30"
|
||||||
jsonwebtoken = "9.3.0"
|
jsonwebtoken = "9.3.0"
|
||||||
serde = { version = "1.0.198", features = ["derive"] }
|
serde = { version = "1.0.198", features = ["derive"] }
|
||||||
serde_json = "1.0.116"
|
serde_json = "1.0.116"
|
||||||
@@ -20,6 +21,7 @@ derive_more = "0.99.17"
|
|||||||
"sqlite"
|
"sqlite"
|
||||||
] }
|
] }
|
||||||
tokio = { version = "1.37.0", features = ["full"] }
|
tokio = { version = "1.37.0", features = ["full"] }
|
||||||
|
tower-http = { version = "0.5.2", features = ["cors", "trace"] }
|
||||||
tracing = "0.1.40"
|
tracing = "0.1.40"
|
||||||
tracing-appender = "0.2.3"
|
tracing-appender = "0.2.3"
|
||||||
tracing-subscriber = { version = "0.3.18", features = ["chrono", "env-filter"] }
|
tracing-subscriber = { version = "0.3.18", features = ["chrono", "env-filter"] }
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
CREATE TABLE IF NOT EXISTS users (
|
CREATE TABLE IF NOT EXISTS user (
|
||||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
username VARCHAR UNIQUE,
|
username VARCHAR UNIQUE,
|
||||||
password VARCHAR,
|
password VARCHAR,
|
||||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
avatar VARCHAR
|
||||||
);
|
);
|
||||||
27
migrations/20240506141114_channel.sql
Normal file
27
migrations/20240506141114_channel.sql
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS `channel` (
|
||||||
|
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
`name` VARCHAR,
|
||||||
|
`last_message_id` INTEGER,
|
||||||
|
`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `channel_user` (
|
||||||
|
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
`channel_id` INTEGER NOT NULL,
|
||||||
|
`user_id` INTEGER NOT NULL,
|
||||||
|
`admin` BOOLEAN NOT NULL DEFAULT 0,
|
||||||
|
FOREIGN KEY(`channel_id`) REFERENCES `channel`(`id`) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY(`user_id`) REFERENCES `user`(`id`) ON DELETE CASCADE
|
||||||
|
);
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS `message` (
|
||||||
|
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
`channel_id` INTEGER NOT NULL,
|
||||||
|
`author_id` INTEGER,
|
||||||
|
`content` TEXT,
|
||||||
|
`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
FOREIGN KEY(`channel_id`) REFERENCES `channel`(`id`) ON DELETE CASCADE,
|
||||||
|
FOREIGN KEY(`author_id`) REFERENCES `user`(`id`) ON DELETE
|
||||||
|
SET
|
||||||
|
NULL
|
||||||
|
);
|
||||||
6
migrations/20240518191103_tokens.sql
Normal file
6
migrations/20240518191103_tokens.sql
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
CREATE TABLE IF NOT EXISTS tokens (
|
||||||
|
token TEXT NOT NULL PRIMARY KEY,
|
||||||
|
user_id INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE,
|
||||||
|
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||||
|
expires_at DATETIME
|
||||||
|
);
|
||||||
229
src/database.rs
229
src/database.rs
@@ -1,7 +1,10 @@
|
|||||||
use derive_more::{Display, Error, From};
|
use derive_more::{Display, Error, From};
|
||||||
use sqlx::migrate::Migrator;
|
use sqlx::migrate::Migrator;
|
||||||
|
|
||||||
use crate::{config, entity};
|
use crate::{
|
||||||
|
config,
|
||||||
|
entity::{self, Channel},
|
||||||
|
};
|
||||||
|
|
||||||
static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
|
static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
|
||||||
|
|
||||||
@@ -25,9 +28,9 @@ impl Database {
|
|||||||
Ok(Self { pool })
|
Ok(Self { pool })
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_user_by_id(&self, id: entity::ShortId) -> Result<entity::User> {
|
pub async fn get_user_by_id(&self, user_id: entity::ShortId) -> Result<entity::User> {
|
||||||
let user = sqlx::query_as("SELECT * FROM users WHERE id = $1")
|
let user = sqlx::query_as("SELECT * FROM user WHERE id = $1")
|
||||||
.bind(id)
|
.bind(user_id)
|
||||||
.fetch_optional(&self.pool)
|
.fetch_optional(&self.pool)
|
||||||
.await?
|
.await?
|
||||||
.ok_or(Error::UserDoesNotExists)?;
|
.ok_or(Error::UserDoesNotExists)?;
|
||||||
@@ -36,7 +39,7 @@ impl Database {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_user_by_username(&self, username: &str) -> Result<entity::User> {
|
pub async fn get_user_by_username(&self, username: &str) -> Result<entity::User> {
|
||||||
let user = sqlx::query_as("SELECT * FROM users WHERE username = $1")
|
let user = sqlx::query_as("SELECT * FROM user WHERE username = $1")
|
||||||
.bind(username)
|
.bind(username)
|
||||||
.fetch_optional(&self.pool)
|
.fetch_optional(&self.pool)
|
||||||
.await?
|
.await?
|
||||||
@@ -51,9 +54,8 @@ impl Database {
|
|||||||
return Err(Error::UserAlreadyExists);
|
return Err(Error::UserAlreadyExists);
|
||||||
}
|
}
|
||||||
|
|
||||||
let id = sqlx::query_scalar(
|
let id =
|
||||||
"INSERT INTO users(username, password) VALUES ($1, $2) RETURNING id",
|
sqlx::query_scalar("INSERT INTO user(username, password) VALUES ($1, $2) RETURNING id")
|
||||||
)
|
|
||||||
.bind(username)
|
.bind(username)
|
||||||
.bind(password)
|
.bind(password)
|
||||||
.fetch_one(&self.pool)
|
.fetch_one(&self.pool)
|
||||||
@@ -63,6 +65,213 @@ impl Database {
|
|||||||
|
|
||||||
Ok(user)
|
Ok(user)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
pub async fn create_token(
|
||||||
|
&self,
|
||||||
|
id: i32,
|
||||||
|
token: &str,
|
||||||
|
expires_at: chrono::DateTime<chrono::Utc>,
|
||||||
|
) -> Result<entity::Token> {
|
||||||
|
sqlx::query("INSERT INTO tokens(user_id, token, expires_at) VALUES ($1, $2, $3)")
|
||||||
|
.bind(id)
|
||||||
|
.bind(token)
|
||||||
|
.bind(expires_at)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.get_token(token).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_token(&self, token: &str) -> Result<entity::Token> {
|
||||||
|
let token = sqlx::query_as("SELECT * FROM tokens WHERE token = $1")
|
||||||
|
.bind(token)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or(Error::UserDoesNotExists)?;
|
||||||
|
|
||||||
|
Ok(token)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_channel_by_id(&self, channel_id: entity::ShortId) -> Result<entity::Channel> {
|
||||||
|
let channel = sqlx::query_as("SELECT * FROM channel WHERE id = $1")
|
||||||
|
.bind(channel_id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or(Error::ChannelDoesNotExists)?;
|
||||||
|
|
||||||
|
Ok(channel)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_all_user_channels(
|
||||||
|
&self,
|
||||||
|
user_id: entity::ShortId,
|
||||||
|
) -> Result<Vec<entity::Channel>> {
|
||||||
|
let channels = sqlx::query_as("SELECT channel.* FROM channel INNER JOIN channel_user ON channel.id = channel_user.channel_id WHERE user_id = $1")
|
||||||
|
.bind(user_id)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(channels)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn update_last_message_channel(
|
||||||
|
&self,
|
||||||
|
channel_id: entity::ShortId,
|
||||||
|
message_id: entity::LongId,
|
||||||
|
) -> Result<Channel> {
|
||||||
|
sqlx::query("UPDATE channel SET last_message_id = $1 WHERE id = $2")
|
||||||
|
.bind(message_id)
|
||||||
|
.bind(channel_id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.get_channel_by_id(channel_id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn add_user_to_channel(
|
||||||
|
&self,
|
||||||
|
user_id: entity::ShortId,
|
||||||
|
channel_id: entity::ShortId,
|
||||||
|
admin: bool,
|
||||||
|
) -> Result<()> {
|
||||||
|
sqlx::query("INSERT INTO channel_user(user_id, channel_id, admin) VALUES ($1, $2, $3)")
|
||||||
|
.bind(user_id)
|
||||||
|
.bind(channel_id)
|
||||||
|
.bind(admin)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove_user_from_channel(
|
||||||
|
&self,
|
||||||
|
user_id: entity::ShortId,
|
||||||
|
channel_id: entity::ShortId,
|
||||||
|
) -> Result<()> {
|
||||||
|
sqlx::query("DELETE FROM channel_user WHERE user_id = $1 AND channel_id = $2")
|
||||||
|
.bind(user_id)
|
||||||
|
.bind(channel_id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_channel(
|
||||||
|
&self,
|
||||||
|
user_id: entity::ShortId,
|
||||||
|
name: &str,
|
||||||
|
) -> Result<entity::Channel> {
|
||||||
|
let id = sqlx::query_scalar("INSERT INTO channel(name) VALUES ($1) RETURNING id")
|
||||||
|
.bind(name)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.add_user_to_channel(user_id, id, true).await?;
|
||||||
|
self.get_channel_by_id(id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn delete_channel(&self, channel_id: entity::ShortId) -> Result<()> {
|
||||||
|
sqlx::query("DELETE FROM channel WHERE id = $1")
|
||||||
|
.bind(channel_id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn set_channel_last_message_id(
|
||||||
|
&self,
|
||||||
|
channel_id: entity::ShortId,
|
||||||
|
message_id: entity::LongId,
|
||||||
|
) -> Result<()> {
|
||||||
|
sqlx::query("UPDATE channel SET last_message_id = $1 WHERE id = $2")
|
||||||
|
.bind(message_id)
|
||||||
|
.bind(channel_id)
|
||||||
|
.execute(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_channel_users(
|
||||||
|
&self,
|
||||||
|
channel_id: entity::ShortId,
|
||||||
|
) -> Result<Vec<(entity::User, bool)>> {
|
||||||
|
let user_ids =
|
||||||
|
sqlx::query_as("SELECT user_id, admin FROM channel_user WHERE channel_id = $1")
|
||||||
|
.bind(channel_id)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let users = user_ids.iter().map(|(user_id, admin)| async move {
|
||||||
|
let user = self.get_user_by_id(*user_id).await;
|
||||||
|
user.map(|u| (u, *admin))
|
||||||
|
});
|
||||||
|
|
||||||
|
let users = futures::future::try_join_all(users).await?;
|
||||||
|
|
||||||
|
Ok(users)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_message_by_id(&self, message_id: entity::LongId) -> Result<entity::Message> {
|
||||||
|
let message = sqlx::query_as("SELECT * FROM message WHERE id = $1")
|
||||||
|
.bind(message_id)
|
||||||
|
.fetch_optional(&self.pool)
|
||||||
|
.await?
|
||||||
|
.ok_or(Error::MessageDoesNotExists)?;
|
||||||
|
|
||||||
|
Ok(message)
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn create_message(
|
||||||
|
&self,
|
||||||
|
channel_id: entity::ShortId,
|
||||||
|
user_id: entity::ShortId,
|
||||||
|
content: &str,
|
||||||
|
) -> Result<entity::Message> {
|
||||||
|
let id = sqlx::query_scalar(
|
||||||
|
"INSERT INTO message(channel_id, author_id, content) VALUES ($1, $2, $3) RETURNING id",
|
||||||
|
)
|
||||||
|
.bind(channel_id)
|
||||||
|
.bind(user_id)
|
||||||
|
.bind(content)
|
||||||
|
.fetch_one(&self.pool)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
self.set_channel_last_message_id(channel_id, id).await?;
|
||||||
|
self.get_message_by_id(id).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_messages_by_channel_id(
|
||||||
|
&self,
|
||||||
|
channel_id: entity::ShortId,
|
||||||
|
limit: i64,
|
||||||
|
before_id: Option<entity::LongId>,
|
||||||
|
) -> Result<Vec<entity::Message>> {
|
||||||
|
let messages = match before_id {
|
||||||
|
Some(before_id) => sqlx::query_as::<_, entity::Message>(
|
||||||
|
"SELECT * FROM message WHERE channel_id = $1 AND id < $2 ORDER BY id DESC LIMIT $3",
|
||||||
|
)
|
||||||
|
.bind(channel_id)
|
||||||
|
.bind(before_id)
|
||||||
|
.bind(limit)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?,
|
||||||
|
None => {
|
||||||
|
sqlx::query_as::<_, entity::Message>(
|
||||||
|
"SELECT * FROM message WHERE channel_id = $1 ORDER BY id DESC LIMIT $2",
|
||||||
|
)
|
||||||
|
.bind(channel_id)
|
||||||
|
.bind(limit)
|
||||||
|
.fetch_all(&self.pool)
|
||||||
|
.await?
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
Ok(messages)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub type Result<T> = std::result::Result<T, Error>;
|
pub type Result<T> = std::result::Result<T, Error>;
|
||||||
@@ -75,4 +284,8 @@ pub enum Error {
|
|||||||
Sqlx(sqlx::Error),
|
Sqlx(sqlx::Error),
|
||||||
UserDoesNotExists,
|
UserDoesNotExists,
|
||||||
UserAlreadyExists,
|
UserAlreadyExists,
|
||||||
|
|
||||||
|
ChannelDoesNotExists,
|
||||||
|
|
||||||
|
MessageDoesNotExists,
|
||||||
}
|
}
|
||||||
|
|||||||
10
src/entity/channel.rs
Normal file
10
src/entity/channel.rs
Normal file
@@ -0,0 +1,10 @@
|
|||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Channel {
|
||||||
|
pub id: super::ShortId,
|
||||||
|
pub name: String,
|
||||||
|
pub last_message_id: Option<super::LongId>,
|
||||||
|
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||||
|
}
|
||||||
@@ -1,9 +1,12 @@
|
|||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
#[derive(Clone, sqlx::FromRow)]
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct Message {
|
pub struct Message {
|
||||||
pub id: super::LongId,
|
pub id: super::LongId,
|
||||||
pub user_id: super::ShortId,
|
pub channel_id: super::ShortId,
|
||||||
|
pub author_id: super::ShortId,
|
||||||
pub content: String,
|
pub content: String,
|
||||||
pub created_at: chrono::DateTime<Utc>,
|
pub created_at: chrono::DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,13 +1,17 @@
|
|||||||
#![allow(unused)]
|
#![allow(unused)]
|
||||||
|
|
||||||
|
mod channel;
|
||||||
mod log;
|
mod log;
|
||||||
mod message;
|
mod message;
|
||||||
mod secret;
|
mod secret;
|
||||||
|
mod token;
|
||||||
mod user;
|
mod user;
|
||||||
|
|
||||||
|
pub use channel::Channel;
|
||||||
pub use log::Log;
|
pub use log::Log;
|
||||||
pub use message::Message;
|
pub use message::Message;
|
||||||
pub use secret::Secret;
|
pub use secret::Secret;
|
||||||
|
pub use token::Token;
|
||||||
pub use user::User;
|
pub use user::User;
|
||||||
|
|
||||||
pub type ShortId = i32;
|
pub type ShortId = i32;
|
||||||
|
|||||||
@@ -1,4 +1,6 @@
|
|||||||
#[derive(Clone, sqlx::FromRow)]
|
use serde::Serialize;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
pub struct Secret {
|
pub struct Secret {
|
||||||
pub id: super::ShortId,
|
pub id: super::ShortId,
|
||||||
pub user_id: super::ShortId,
|
pub user_id: super::ShortId,
|
||||||
|
|||||||
12
src/entity/token.rs
Normal file
12
src/entity/token.rs
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
|
use super::ShortId;
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct Token {
|
||||||
|
pub token: String,
|
||||||
|
pub user_id: ShortId,
|
||||||
|
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||||
|
pub expires_at: chrono::DateTime<chrono::Utc>,
|
||||||
|
}
|
||||||
@@ -1,9 +1,13 @@
|
|||||||
use chrono::Utc;
|
use chrono::Utc;
|
||||||
|
use serde::Serialize;
|
||||||
|
|
||||||
#[derive(Clone, sqlx::FromRow)]
|
#[derive(Debug, Clone, Hash, sqlx::FromRow, Serialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct User {
|
pub struct User {
|
||||||
pub id: super::ShortId,
|
pub id: super::ShortId,
|
||||||
pub username: String,
|
pub username: String,
|
||||||
|
#[serde(skip_serializing)]
|
||||||
pub password: String,
|
pub password: String,
|
||||||
|
pub avatar: Option<String>,
|
||||||
pub created_at: chrono::DateTime<Utc>,
|
pub created_at: chrono::DateTime<Utc>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ pub struct Claims {
|
|||||||
pub user_id: i32,
|
pub user_id: i32,
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn generate_jwt(user_id: i32) -> Result<String> {
|
pub fn generate_jwt(user_id: i32, expires_at: chrono::DateTime<chrono::Utc>) -> Result<String> {
|
||||||
let claims = Claims {
|
let claims = Claims {
|
||||||
exp: (Local::now() + Duration::days(1)).timestamp() as usize,
|
exp: expires_at.timestamp() as usize,
|
||||||
user_id,
|
user_id,
|
||||||
};
|
};
|
||||||
|
|
||||||
@@ -28,6 +28,7 @@ pub fn generate_jwt(user_id: i32) -> Result<String> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
pub fn verify_jwt(token: &str) -> Result<i32> {
|
pub fn verify_jwt(token: &str) -> Result<i32> {
|
||||||
|
tracing::debug!("Verifying token: {}", token);
|
||||||
let token_data = jsonwebtoken::decode::<Claims>(
|
let token_data = jsonwebtoken::decode::<Claims>(
|
||||||
token,
|
token,
|
||||||
&jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),
|
&jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),
|
||||||
|
|||||||
@@ -14,7 +14,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
|||||||
let _guard = log::initialize()?;
|
let _guard = log::initialize()?;
|
||||||
|
|
||||||
let database = Database::init().await?;
|
let database = Database::init().await?;
|
||||||
let context = AppState { database };
|
let context = AppState {
|
||||||
|
database,
|
||||||
|
connected_users: Default::default(),
|
||||||
|
};
|
||||||
|
|
||||||
web::run(context).await?;
|
web::run(context).await?;
|
||||||
|
|
||||||
|
|||||||
13
src/state.rs
13
src/state.rs
@@ -1,6 +1,17 @@
|
|||||||
use crate::database::Database;
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
|
use tokio::sync::{mpsc::UnboundedSender, RwLock};
|
||||||
|
|
||||||
|
use crate::{database::Database, entity::ShortId, web::ws};
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||||
|
pub struct WebSocketKey {
|
||||||
|
pub user_id: ShortId,
|
||||||
|
pub token: String,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct AppState {
|
pub struct AppState {
|
||||||
pub database: Database,
|
pub database: Database,
|
||||||
|
pub connected_users: Arc<RwLock<HashMap<WebSocketKey, UnboundedSender<ws::Message>>>>,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,9 +2,10 @@ use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
|
|||||||
|
|
||||||
use crate::entity;
|
use crate::entity;
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Debug, Clone, Hash)]
|
||||||
pub struct Context {
|
pub struct Context {
|
||||||
pub user_id: entity::ShortId,
|
pub user: entity::User,
|
||||||
|
pub token: String,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
@@ -15,9 +16,9 @@ impl<S: Send + Sync> FromRequestParts<S> for Context {
|
|||||||
parts
|
parts
|
||||||
.extensions
|
.extensions
|
||||||
.get::<ContextResult>()
|
.get::<ContextResult>()
|
||||||
.ok_or(super::Error::ContextError(Error::NotInRequest))?
|
.ok_or(super::Error::Context(Error::NotInRequest))?
|
||||||
.clone()
|
.clone()
|
||||||
.map_err(super::Error::ContextError)
|
.map_err(super::Error::Context)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,18 +12,41 @@ pub type Result<T> = std::result::Result<T, Error>;
|
|||||||
#[derive(Debug, From)]
|
#[derive(Debug, From)]
|
||||||
pub enum Error {
|
pub enum Error {
|
||||||
#[from]
|
#[from]
|
||||||
ContextError(context::Error),
|
Context(context::Error),
|
||||||
#[from]
|
#[from]
|
||||||
DatabaseError(database::Error),
|
Database(database::Error),
|
||||||
#[from]
|
#[from]
|
||||||
JWT(jwt::Error),
|
Jwt(jwt::Error),
|
||||||
|
|
||||||
WrongPassword,
|
WrongPassword,
|
||||||
|
NotAllowed,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Error {
|
impl Error {
|
||||||
pub fn as_client_error(&self) -> ClientError {
|
pub fn as_client_error_and_status(&self) -> (ClientError, StatusCode) {
|
||||||
ClientError::HahaError
|
match self {
|
||||||
|
Error::Context(_) | Error::Jwt(_) => {
|
||||||
|
(ClientError::NotAuthorized, StatusCode::UNAUTHORIZED)
|
||||||
|
}
|
||||||
|
Error::Database(database::Error::UserAlreadyExists) => {
|
||||||
|
(ClientError::UserAlreadyExists, StatusCode::CONFLICT)
|
||||||
|
}
|
||||||
|
Error::Database(database::Error::UserDoesNotExists) => {
|
||||||
|
(ClientError::UserDoesNotExists, StatusCode::NOT_FOUND)
|
||||||
|
}
|
||||||
|
Error::Database(database::Error::ChannelDoesNotExists) => {
|
||||||
|
(ClientError::ChannelDoesNotExists, StatusCode::NOT_FOUND)
|
||||||
|
}
|
||||||
|
Error::Database(database::Error::MessageDoesNotExists) => {
|
||||||
|
(ClientError::MessageDoesNotExists, StatusCode::NOT_FOUND)
|
||||||
|
}
|
||||||
|
Error::WrongPassword => (ClientError::WrongPassword, StatusCode::UNAUTHORIZED),
|
||||||
|
Error::NotAllowed => (ClientError::NotAllowed, StatusCode::FORBIDDEN),
|
||||||
|
_ => (
|
||||||
|
ClientError::InternalServerError,
|
||||||
|
StatusCode::INTERNAL_SERVER_ERROR,
|
||||||
|
),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,5 +62,22 @@ impl IntoResponse for Error {
|
|||||||
|
|
||||||
#[derive(Debug, Error, Display)]
|
#[derive(Debug, Error, Display)]
|
||||||
pub enum ClientError {
|
pub enum ClientError {
|
||||||
HahaError,
|
#[display(fmt = "Not authorized")]
|
||||||
|
NotAuthorized,
|
||||||
|
#[display(fmt = "Wrong password")]
|
||||||
|
WrongPassword,
|
||||||
|
#[display(fmt = "Not allowed")]
|
||||||
|
NotAllowed,
|
||||||
|
|
||||||
|
#[display(fmt = "User already exists")]
|
||||||
|
UserAlreadyExists,
|
||||||
|
#[display(fmt = "User does not exists")]
|
||||||
|
UserDoesNotExists,
|
||||||
|
#[display(fmt = "Channel does not exists")]
|
||||||
|
ChannelDoesNotExists,
|
||||||
|
#[display(fmt = "Message does not exists")]
|
||||||
|
MessageDoesNotExists,
|
||||||
|
|
||||||
|
#[display(fmt = "Internal server error")]
|
||||||
|
InternalServerError,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,11 @@ use axum::{
|
|||||||
extract::{Request, State},
|
extract::{Request, State},
|
||||||
http::{header, HeaderMap},
|
http::{header, HeaderMap},
|
||||||
middleware::Next,
|
middleware::Next,
|
||||||
response::{IntoResponse, Response},
|
response::Response,
|
||||||
|
Extension,
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
database::Database,
|
|
||||||
jwt,
|
jwt,
|
||||||
state::AppState,
|
state::AppState,
|
||||||
web::{
|
web::{
|
||||||
@@ -16,7 +16,7 @@ use crate::{
|
|||||||
};
|
};
|
||||||
|
|
||||||
pub async fn require_context(
|
pub async fn require_context(
|
||||||
context: ContextResult,
|
Extension(context): Extension<ContextResult>,
|
||||||
request: Request,
|
request: Request,
|
||||||
next: Next,
|
next: Next,
|
||||||
) -> web::Result<Response> {
|
) -> web::Result<Response> {
|
||||||
@@ -46,12 +46,17 @@ async fn get_context(state: AppState, headers: &HeaderMap) -> context::ContextRe
|
|||||||
return Err(context::Error::WrongTokenType);
|
return Err(context::Error::WrongTokenType);
|
||||||
}
|
}
|
||||||
let token = &token[7..];
|
let token = &token[7..];
|
||||||
|
|
||||||
|
get_context_from_token(state, token).await
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_context_from_token(state: AppState, token: &str) -> context::ContextResult {
|
||||||
let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;
|
let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;
|
||||||
let _user = state
|
let user = state
|
||||||
.database
|
.database
|
||||||
.get_user_by_id(user_id)
|
.get_user_by_id(user_id)
|
||||||
.await
|
.await
|
||||||
.map_err(|_| context::Error::Model)?;
|
.map_err(|_| context::Error::Model)?;
|
||||||
|
|
||||||
Ok(context::Context { user_id })
|
Ok(context::Context { user, token: token.to_owned() })
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,5 @@
|
|||||||
mod auth;
|
mod auth;
|
||||||
mod response_map;
|
mod response_map;
|
||||||
|
|
||||||
pub use auth::require_context;
|
pub use auth::*;
|
||||||
pub use auth::resolve_context;
|
pub use response_map::*;
|
||||||
pub use response_map::response_map;
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use axum::{extract::Request, middleware::Next, response::IntoResponse};
|
use axum::{extract::Request, middleware::Next, response::IntoResponse, Json};
|
||||||
|
use serde_json::json;
|
||||||
|
|
||||||
use crate::web;
|
use crate::web;
|
||||||
|
|
||||||
@@ -9,10 +10,18 @@ pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
|
|||||||
|
|
||||||
let error = response.extensions().get::<Arc<web::Error>>();
|
let error = response.extensions().get::<Arc<web::Error>>();
|
||||||
|
|
||||||
let error_response = error.map(|e| {
|
if error.is_some() {
|
||||||
let client_error = e.as_client_error();
|
tracing::error!("{:?}", error);
|
||||||
|
}
|
||||||
|
|
||||||
client_error.to_string().into_response()
|
let error_response = error.map(|e| {
|
||||||
|
let client_error = e.as_client_error_and_status();
|
||||||
|
|
||||||
|
(
|
||||||
|
client_error.1,
|
||||||
|
Json(json!({"error": client_error.0.to_string()})),
|
||||||
|
)
|
||||||
|
.into_response()
|
||||||
});
|
});
|
||||||
|
|
||||||
error_response.unwrap_or(response)
|
error_response.unwrap_or(response)
|
||||||
|
|||||||
@@ -1,11 +1,13 @@
|
|||||||
use crate::{config, state};
|
use crate::{config, state, web::routes::message};
|
||||||
|
|
||||||
mod context;
|
mod context;
|
||||||
mod error;
|
mod error;
|
||||||
pub mod middlware;
|
pub mod middlware;
|
||||||
pub mod routes;
|
pub mod routes;
|
||||||
|
pub mod ws;
|
||||||
|
|
||||||
pub use error::{Error, Result};
|
pub use error::{Error, Result};
|
||||||
|
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin};
|
||||||
|
|
||||||
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
|
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
|
||||||
let config = config::config();
|
let config = config::config();
|
||||||
@@ -22,18 +24,66 @@ pub async fn run(state: state::AppState) -> anyhow::Result<()> {
|
|||||||
}
|
}
|
||||||
|
|
||||||
fn router(state: state::AppState) -> axum::Router {
|
fn router(state: state::AppState) -> axum::Router {
|
||||||
axum::Router::new()
|
use self::routes::user;
|
||||||
.route("/user/login", axum::routing::post(routes::user::login))
|
use axum::routing::*;
|
||||||
.route(
|
use axum::Router;
|
||||||
"/user/register",
|
|
||||||
axum::routing::post(routes::user::register),
|
let cors = tower_http::cors::CorsLayer::new()
|
||||||
)
|
.allow_origin(AllowOrigin::any())
|
||||||
.with_state(state.clone())
|
.allow_methods(AllowMethods::any())
|
||||||
.layer(axum::middleware::from_fn_with_state(
|
.allow_headers(AllowHeaders::any());
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
// websocket
|
||||||
|
.route("/ws/:token", get(ws::ws_handler))
|
||||||
|
// unprotected
|
||||||
|
.route("/user/login", post(user::login))
|
||||||
|
.route("/user/register", post(user::register))
|
||||||
|
// protected
|
||||||
|
.nest("/", protected_router())
|
||||||
|
.layer(tower_http::trace::TraceLayer::new_for_http())
|
||||||
|
.route_layer(axum::middleware::from_fn_with_state(
|
||||||
state.clone(),
|
state.clone(),
|
||||||
middlware::resolve_context,
|
middlware::resolve_context,
|
||||||
))
|
))
|
||||||
.layer(axum::middleware::from_fn(middlware::response_map))
|
.layer(axum::middleware::from_fn(middlware::response_map))
|
||||||
|
.layer(cors)
|
||||||
|
.with_state(state.clone())
|
||||||
|
}
|
||||||
|
|
||||||
|
fn protected_router() -> axum::Router<state::AppState> {
|
||||||
|
use self::routes::channel;
|
||||||
|
use self::routes::user;
|
||||||
|
use axum::routing::*;
|
||||||
|
use axum::Router;
|
||||||
|
|
||||||
|
Router::new()
|
||||||
|
// user
|
||||||
|
.route("/user", get(user::get_by_username))
|
||||||
|
.route("/user/me", get(user::get_self))
|
||||||
|
.route("/user/:user_id", get(user::get_by_id))
|
||||||
|
// channel
|
||||||
|
.route("/channel", get(channel::get_all_user_channels))
|
||||||
|
.route("/channel/:channel_id", get(channel::get_by_id))
|
||||||
|
.route("/channel", post(channel::create_channel))
|
||||||
|
.route("/channel/:channel_id", delete(channel::delete_channel))
|
||||||
|
.route(
|
||||||
|
"/channel/:channel_id/user/:user_id",
|
||||||
|
post(channel::add_user_to_channel),
|
||||||
|
)
|
||||||
|
.route(
|
||||||
|
"/channel/:channel_id/user/:user_id",
|
||||||
|
delete(channel::remove_user_from_channel),
|
||||||
|
)
|
||||||
|
// message
|
||||||
|
.route("/message/:message_id", get(message::get_by_id))
|
||||||
|
.route("/message", post(message::create))
|
||||||
|
.route(
|
||||||
|
"/channel/:channel_id/message",
|
||||||
|
get(message::get_messages_by_channel_id),
|
||||||
|
)
|
||||||
|
// middleware
|
||||||
|
.route_layer(axum::middleware::from_fn(middlware::require_context))
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shutdown_signal() {
|
async fn shutdown_signal() {
|
||||||
|
|||||||
34
src/web/routes/channel/create.rs
Normal file
34
src/web/routes/channel/create.rs
Normal file
@@ -0,0 +1,34 @@
|
|||||||
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
state::AppState,
|
||||||
|
web::{self, ws},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub async fn create_channel(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: web::context::Context,
|
||||||
|
Json(body): Json<CreateChannel>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channel = state
|
||||||
|
.database
|
||||||
|
.create_channel(context.user.id, &body.name)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let channel_users = state.database.get_channel_users(channel.id).await?;
|
||||||
|
|
||||||
|
ws::broadcast_message(ws::Message::CreateChannel(channel.clone()), &state, |key| {
|
||||||
|
channel_users
|
||||||
|
.iter()
|
||||||
|
.any(|(user, _admin)| user.id == key.user_id)
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(Json(channel))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct CreateChannel {
|
||||||
|
name: String,
|
||||||
|
}
|
||||||
40
src/web/routes/channel/delete.rs
Normal file
40
src/web/routes/channel/delete.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
entity::ShortId,
|
||||||
|
state::AppState,
|
||||||
|
web::{self, ws},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub async fn delete_channel(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(channel_id): Path<ShortId>,
|
||||||
|
context: web::context::Context,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||||
|
|
||||||
|
if !channel_users
|
||||||
|
.iter()
|
||||||
|
.any(|(user, admin)| user.id == context.user.id && *admin)
|
||||||
|
{
|
||||||
|
return Err(web::Error::NotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
state.database.delete_channel(channel_id).await?;
|
||||||
|
|
||||||
|
ws::broadcast_message(
|
||||||
|
ws::Message::DeleteChannel { id: channel_id },
|
||||||
|
&state,
|
||||||
|
|key| {
|
||||||
|
channel_users
|
||||||
|
.iter()
|
||||||
|
.any(|(user, _admin)| user.id == key.user_id)
|
||||||
|
},
|
||||||
|
)
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
37
src/web/routes/channel/get.rs
Normal file
37
src/web/routes/channel/get.rs
Normal file
@@ -0,0 +1,37 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{entity::ShortId, state::AppState, web};
|
||||||
|
|
||||||
|
pub async fn get_by_id(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(channel_id): Path<ShortId>,
|
||||||
|
context: web::context::Context,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channel = state.database.get_channel_by_id(channel_id).await?;
|
||||||
|
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||||
|
|
||||||
|
if !channel_users
|
||||||
|
.into_iter()
|
||||||
|
.any(|(user, _admin)| user.id == context.user.id)
|
||||||
|
{
|
||||||
|
return Err(web::Error::NotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Json(channel))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_all_user_channels(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: web::context::Context,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channels = state
|
||||||
|
.database
|
||||||
|
.get_all_user_channels(context.user.id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(channels))
|
||||||
|
}
|
||||||
9
src/web/routes/channel/mod.rs
Normal file
9
src/web/routes/channel/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
|||||||
|
mod create;
|
||||||
|
mod delete;
|
||||||
|
mod get;
|
||||||
|
mod user;
|
||||||
|
|
||||||
|
pub use create::*;
|
||||||
|
pub use delete::*;
|
||||||
|
pub use get::*;
|
||||||
|
pub use user::*;
|
||||||
50
src/web/routes/channel/user.rs
Normal file
50
src/web/routes/channel/user.rs
Normal file
@@ -0,0 +1,50 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{entity::ShortId, state::AppState, web};
|
||||||
|
|
||||||
|
pub async fn add_user_to_channel(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path((channel_id, user_id)): Path<(ShortId, ShortId)>,
|
||||||
|
context: web::context::Context,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||||
|
|
||||||
|
if !channel_users
|
||||||
|
.into_iter()
|
||||||
|
.any(|(user, admin)| user.id == context.user.id && admin)
|
||||||
|
{
|
||||||
|
return Err(web::Error::NotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
state
|
||||||
|
.database
|
||||||
|
.add_user_to_channel(user_id, channel_id, false)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn remove_user_from_channel(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path((channel_id, user_id)): Path<(ShortId, ShortId)>,
|
||||||
|
context: web::context::Context,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||||
|
|
||||||
|
if !channel_users
|
||||||
|
.into_iter()
|
||||||
|
.any(|(user, admin)| user.id == context.user.id && admin)
|
||||||
|
{
|
||||||
|
return Err(web::Error::NotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
state
|
||||||
|
.database
|
||||||
|
.remove_user_from_channel(user_id, channel_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
54
src/web/routes/message/create.rs
Normal file
54
src/web/routes/message/create.rs
Normal file
@@ -0,0 +1,54 @@
|
|||||||
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
entity::ShortId,
|
||||||
|
state::AppState,
|
||||||
|
web::{self, ws},
|
||||||
|
};
|
||||||
|
|
||||||
|
pub async fn create(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: web::context::Context,
|
||||||
|
Json(body): Json<CreateMessage>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channel_users = state.database.get_channel_users(body.channel_id).await?;
|
||||||
|
if !channel_users
|
||||||
|
.iter()
|
||||||
|
.any(|(user, _admin)| user.id == context.user.id)
|
||||||
|
{
|
||||||
|
return Err(web::Error::NotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let message = state
|
||||||
|
.database
|
||||||
|
.create_message(body.channel_id, context.user.id, &body.content)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
let channel = state
|
||||||
|
.database
|
||||||
|
.update_last_message_channel(body.channel_id, message.id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
ws::broadcast_message(ws::Message::CreateMessage(message.clone()), &state, |key| {
|
||||||
|
channel_users
|
||||||
|
.iter()
|
||||||
|
.any(|(user, _admin)| user.id == key.user_id)
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
ws::broadcast_message(ws::Message::UpdateChannel(channel), &state, |key| {
|
||||||
|
channel_users
|
||||||
|
.iter()
|
||||||
|
.any(|(user, _admin)| user.id == key.user_id)
|
||||||
|
})
|
||||||
|
.await;
|
||||||
|
|
||||||
|
Ok(Json(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct CreateMessage {
|
||||||
|
content: String,
|
||||||
|
channel_id: ShortId,
|
||||||
|
}
|
||||||
61
src/web/routes/message/get.rs
Normal file
61
src/web/routes/message/get.rs
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, Query, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
entity::{LongId, ShortId},
|
||||||
|
state::AppState,
|
||||||
|
web,
|
||||||
|
};
|
||||||
|
|
||||||
|
pub async fn get_by_id(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(message_id): Path<LongId>,
|
||||||
|
context: web::context::Context,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let message = state.database.get_message_by_id(message_id).await?;
|
||||||
|
if message.author_id != context.user.id {
|
||||||
|
let channel_users = state.database.get_channel_users(message.channel_id).await?;
|
||||||
|
|
||||||
|
if !channel_users
|
||||||
|
.into_iter()
|
||||||
|
.any(|(user, _admin)| user.id == context.user.id)
|
||||||
|
{
|
||||||
|
return Err(web::Error::NotAllowed);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
Ok(Json(message))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_messages_by_channel_id(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(channel_id): Path<ShortId>,
|
||||||
|
context: web::context::Context,
|
||||||
|
Query(body): Query<GetByChannelId>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||||
|
|
||||||
|
if !channel_users
|
||||||
|
.into_iter()
|
||||||
|
.any(|(user, _admin)| user.id == context.user.id)
|
||||||
|
{
|
||||||
|
return Err(web::Error::NotAllowed);
|
||||||
|
}
|
||||||
|
|
||||||
|
let messages = state
|
||||||
|
.database
|
||||||
|
.get_messages_by_channel_id(channel_id, body.limit.unwrap_or(50), body.before_id)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(messages))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GetByChannelId {
|
||||||
|
before_id: Option<LongId>,
|
||||||
|
limit: Option<i64>,
|
||||||
|
}
|
||||||
5
src/web/routes/message/mod.rs
Normal file
5
src/web/routes/message/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
|||||||
|
mod create;
|
||||||
|
mod get;
|
||||||
|
|
||||||
|
pub use create::*;
|
||||||
|
pub use get::*;
|
||||||
@@ -1 +1,3 @@
|
|||||||
|
pub mod channel;
|
||||||
|
pub mod message;
|
||||||
pub mod user;
|
pub mod user;
|
||||||
|
|||||||
40
src/web/routes/user/get.rs
Normal file
40
src/web/routes/user/get.rs
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{Path, State},
|
||||||
|
response::IntoResponse,
|
||||||
|
Json,
|
||||||
|
};
|
||||||
|
|
||||||
|
use crate::{entity::ShortId, state::AppState, web};
|
||||||
|
|
||||||
|
pub async fn get_self(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
context: web::context::Context,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let user = state.database.get_user_by_id(context.user.id).await?;
|
||||||
|
|
||||||
|
Ok(Json(user))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_by_id(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(user_id): Path<ShortId>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let user = state.database.get_user_by_id(user_id).await?;
|
||||||
|
|
||||||
|
Ok(Json(user))
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn get_by_username(
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Json(body): Json<GetByUsernameRequest>,
|
||||||
|
) -> web::Result<impl IntoResponse> {
|
||||||
|
let user = state.database.get_user_by_username(&body.username).await?;
|
||||||
|
|
||||||
|
Ok(Json(user))
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(serde::Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
pub struct GetByUsernameRequest {
|
||||||
|
pub username: String,
|
||||||
|
}
|
||||||
@@ -1,6 +1,6 @@
|
|||||||
use axum::{extract::State, response::IntoResponse, Json};
|
use axum::{extract::State, response::IntoResponse, Json};
|
||||||
|
use chrono::{Duration, Utc};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::json;
|
|
||||||
|
|
||||||
use crate::{jwt, state::AppState, web};
|
use crate::{jwt, state::AppState, web};
|
||||||
|
|
||||||
@@ -17,14 +17,20 @@ pub async fn login(
|
|||||||
return Err(web::Error::WrongPassword);
|
return Err(web::Error::WrongPassword);
|
||||||
}
|
}
|
||||||
|
|
||||||
let token = jwt::generate_jwt(user.id)?;
|
let expires_at = Utc::now() + Duration::hours(24);
|
||||||
|
|
||||||
Ok(Json(json!({
|
let token = jwt::generate_jwt(user.id, expires_at)?;
|
||||||
"token": token
|
|
||||||
})))
|
let token = state
|
||||||
|
.database
|
||||||
|
.create_token(user.id, &token, expires_at)
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(Json(token))
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct LoginPayload {
|
pub struct LoginPayload {
|
||||||
username: String,
|
username: String,
|
||||||
password: String,
|
password: String,
|
||||||
|
|||||||
@@ -1,5 +1,7 @@
|
|||||||
|
mod get;
|
||||||
mod login;
|
mod login;
|
||||||
mod register;
|
mod register;
|
||||||
|
|
||||||
|
pub use get::*;
|
||||||
pub use login::login;
|
pub use login::login;
|
||||||
pub use register::register;
|
pub use register::register;
|
||||||
|
|||||||
@@ -21,6 +21,7 @@ pub async fn register(
|
|||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
pub struct RegisterPayload {
|
pub struct RegisterPayload {
|
||||||
username: String,
|
username: String,
|
||||||
password: String,
|
password: String,
|
||||||
|
|||||||
101
src/web/ws/mod.rs
Normal file
101
src/web/ws/mod.rs
Normal file
@@ -0,0 +1,101 @@
|
|||||||
|
use axum::{
|
||||||
|
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
|
||||||
|
response::IntoResponse,
|
||||||
|
};
|
||||||
|
use futures::{SinkExt, StreamExt};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
entity,
|
||||||
|
state::{AppState, WebSocketKey},
|
||||||
|
};
|
||||||
|
|
||||||
|
use super::{context, error, middlware};
|
||||||
|
|
||||||
|
#[derive(serde::Serialize, Clone)]
|
||||||
|
#[serde(rename_all = "camelCase")]
|
||||||
|
#[serde(tag = "type", content = "data")]
|
||||||
|
pub enum Message {
|
||||||
|
CreateMessage(entity::Message),
|
||||||
|
UpdateChannel(entity::Channel),
|
||||||
|
CreateChannel(entity::Channel),
|
||||||
|
DeleteChannel { id: entity::ShortId },
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn broadcast_message(
|
||||||
|
message: Message,
|
||||||
|
state: &AppState,
|
||||||
|
predicate: impl Fn(&WebSocketKey) -> bool,
|
||||||
|
) {
|
||||||
|
let connected_users = state.connected_users.read().await;
|
||||||
|
|
||||||
|
let recievers =
|
||||||
|
connected_users
|
||||||
|
.iter()
|
||||||
|
.filter_map(|(key, conn)| if predicate(key) { Some(conn) } else { None });
|
||||||
|
|
||||||
|
for reciever in recievers {
|
||||||
|
_ = reciever
|
||||||
|
.send(message.clone())
|
||||||
|
.inspect_err(|err| tracing::error!("Failed to send message: {}", err));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
pub async fn ws_handler(
|
||||||
|
ws: WebSocketUpgrade,
|
||||||
|
State(state): State<AppState>,
|
||||||
|
Path(token): Path<String>,
|
||||||
|
) -> error::Result<impl IntoResponse> {
|
||||||
|
let context = middlware::get_context_from_token(state.clone(), &token).await?;
|
||||||
|
|
||||||
|
Ok(ws.on_upgrade(|socket| handle_socket(socket, state, context)))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn handle_socket(websocket: WebSocket, state: AppState, context: context::Context) {
|
||||||
|
let user_key = WebSocketKey {
|
||||||
|
token: context.token.clone(),
|
||||||
|
user_id: context.user.id,
|
||||||
|
};
|
||||||
|
let (mut sender, _) = websocket.split();
|
||||||
|
|
||||||
|
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut connected_users = state.connected_users.write().await;
|
||||||
|
if connected_users.contains_key(&user_key) {
|
||||||
|
tracing::trace!("websocket already connected: {user_key:?}");
|
||||||
|
|
||||||
|
drop(connected_users.remove(&user_key));
|
||||||
|
}
|
||||||
|
connected_users.insert(user_key.clone(), tx);
|
||||||
|
}
|
||||||
|
|
||||||
|
tracing::trace!("websocket connected: {user_key:?}");
|
||||||
|
|
||||||
|
tokio::spawn(async move {
|
||||||
|
// idk
|
||||||
|
// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||||
|
while let Some(message) = rx.recv().await {
|
||||||
|
let err = sender
|
||||||
|
.send(axum::extract::ws::Message::Text(
|
||||||
|
serde_json::to_string(&message)
|
||||||
|
.inspect_err(|e| tracing::error!("Could not serialize message: {e}"))
|
||||||
|
.unwrap(),
|
||||||
|
))
|
||||||
|
.await
|
||||||
|
.inspect_err(|e| tracing::error!("Could not send message: {e}"));
|
||||||
|
|
||||||
|
if err.is_err() {
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
let _ = sender.send(axum::extract::ws::Message::Close(None)).await;
|
||||||
|
|
||||||
|
{
|
||||||
|
let mut connected_users = state.connected_users.write().await;
|
||||||
|
connected_users.remove(&user_key);
|
||||||
|
|
||||||
|
tracing::trace!("websocket disconnected: {user_key:?}");
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user