z
This commit is contained in:
114
Cargo.lock
generated
114
Cargo.lock
generated
@@ -119,6 +119,8 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
|
||||
dependencies = [
|
||||
"async-trait",
|
||||
"axum-core",
|
||||
"axum-macros",
|
||||
"base64 0.21.7",
|
||||
"bytes",
|
||||
"futures-util",
|
||||
"http",
|
||||
@@ -137,8 +139,10 @@ dependencies = [
|
||||
"serde_json",
|
||||
"serde_path_to_error",
|
||||
"serde_urlencoded",
|
||||
"sha1",
|
||||
"sync_wrapper 1.0.1",
|
||||
"tokio",
|
||||
"tokio-tungstenite",
|
||||
"tower",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
@@ -166,6 +170,18 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "axum-macros"
|
||||
version = "0.4.1"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "00c055ee2d014ae5981ce1016374e8213682aa14d9bf40e48ab48b5f3ef20eaa"
|
||||
dependencies = [
|
||||
"heck",
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "backtrace"
|
||||
version = "0.3.71"
|
||||
@@ -341,6 +357,12 @@ dependencies = [
|
||||
"typenum",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "data-encoding"
|
||||
version = "2.5.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5"
|
||||
|
||||
[[package]]
|
||||
name = "deranged"
|
||||
version = "0.3.11"
|
||||
@@ -465,6 +487,21 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-executor",
|
||||
"futures-io",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"futures-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-channel"
|
||||
version = "0.3.30"
|
||||
@@ -503,6 +540,23 @@ dependencies = [
|
||||
"parking_lot 0.11.2",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-io"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
|
||||
|
||||
[[package]]
|
||||
name = "futures-macro"
|
||||
version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
|
||||
dependencies = [
|
||||
"proc-macro2",
|
||||
"quote",
|
||||
"syn 2.0.60",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "futures-sink"
|
||||
version = "0.3.30"
|
||||
@@ -521,9 +575,13 @@ version = "0.3.30"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
|
||||
dependencies = [
|
||||
"futures-channel",
|
||||
"futures-core",
|
||||
"futures-io",
|
||||
"futures-macro",
|
||||
"futures-sink",
|
||||
"futures-task",
|
||||
"memchr",
|
||||
"pin-project-lite",
|
||||
"pin-utils",
|
||||
"slab",
|
||||
@@ -931,11 +989,13 @@ dependencies = [
|
||||
"chrono",
|
||||
"derive_more",
|
||||
"figment",
|
||||
"futures",
|
||||
"jsonwebtoken",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"sqlx",
|
||||
"tokio",
|
||||
"tower-http",
|
||||
"tracing",
|
||||
"tracing-appender",
|
||||
"tracing-subscriber",
|
||||
@@ -1845,6 +1905,18 @@ dependencies = [
|
||||
"tokio",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tokio-tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
|
||||
dependencies = [
|
||||
"futures-util",
|
||||
"log",
|
||||
"tokio",
|
||||
"tungstenite",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "toml"
|
||||
version = "0.8.12"
|
||||
@@ -1895,6 +1967,23 @@ dependencies = [
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-http"
|
||||
version = "0.5.2"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
|
||||
dependencies = [
|
||||
"bitflags 2.5.0",
|
||||
"bytes",
|
||||
"http",
|
||||
"http-body",
|
||||
"http-body-util",
|
||||
"pin-project-lite",
|
||||
"tower-layer",
|
||||
"tower-service",
|
||||
"tracing",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tower-layer"
|
||||
version = "0.3.2"
|
||||
@@ -1982,6 +2071,25 @@ dependencies = [
|
||||
"tracing-log",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "tungstenite"
|
||||
version = "0.21.0"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
|
||||
dependencies = [
|
||||
"byteorder",
|
||||
"bytes",
|
||||
"data-encoding",
|
||||
"http",
|
||||
"httparse",
|
||||
"log",
|
||||
"rand",
|
||||
"sha1",
|
||||
"thiserror",
|
||||
"url",
|
||||
"utf-8",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "typenum"
|
||||
version = "1.17.0"
|
||||
@@ -2053,6 +2161,12 @@ dependencies = [
|
||||
"percent-encoding",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "utf-8"
|
||||
version = "0.7.6"
|
||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
|
||||
|
||||
[[package]]
|
||||
name = "valuable"
|
||||
version = "0.1.0"
|
||||
|
||||
@@ -5,10 +5,11 @@
|
||||
|
||||
[dependencies]
|
||||
anyhow = "1.0.82"
|
||||
axum = "0.7.5"
|
||||
axum = { version = "0.7.5", features = ["macros", "ws"] }
|
||||
chrono = { version = "0.4.38", features = ["serde"] }
|
||||
derive_more = "0.99.17"
|
||||
derive_more = "0.99.17"
|
||||
figment = { version = "0.10.18", features = ["env", "toml"] }
|
||||
futures = "0.3.30"
|
||||
jsonwebtoken = "9.3.0"
|
||||
serde = { version = "1.0.198", features = ["derive"] }
|
||||
serde_json = "1.0.116"
|
||||
@@ -20,6 +21,7 @@ derive_more = "0.99.17"
|
||||
"sqlite"
|
||||
] }
|
||||
tokio = { version = "1.37.0", features = ["full"] }
|
||||
tower-http = { version = "0.5.2", features = ["cors", "trace"] }
|
||||
tracing = "0.1.40"
|
||||
tracing-appender = "0.2.3"
|
||||
tracing-subscriber = { version = "0.3.18", features = ["chrono", "env-filter"] }
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
CREATE TABLE IF NOT EXISTS users (
|
||||
CREATE TABLE IF NOT EXISTS user (
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
username VARCHAR UNIQUE,
|
||||
password VARCHAR,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
avatar VARCHAR
|
||||
);
|
||||
27
migrations/20240506141114_channel.sql
Normal file
27
migrations/20240506141114_channel.sql
Normal file
@@ -0,0 +1,27 @@
|
||||
CREATE TABLE IF NOT EXISTS `channel` (
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
`name` VARCHAR,
|
||||
`last_message_id` INTEGER,
|
||||
`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `channel_user` (
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
`channel_id` INTEGER NOT NULL,
|
||||
`user_id` INTEGER NOT NULL,
|
||||
`admin` BOOLEAN NOT NULL DEFAULT 0,
|
||||
FOREIGN KEY(`channel_id`) REFERENCES `channel`(`id`) ON DELETE CASCADE,
|
||||
FOREIGN KEY(`user_id`) REFERENCES `user`(`id`) ON DELETE CASCADE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS `message` (
|
||||
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
`channel_id` INTEGER NOT NULL,
|
||||
`author_id` INTEGER,
|
||||
`content` TEXT,
|
||||
`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
FOREIGN KEY(`channel_id`) REFERENCES `channel`(`id`) ON DELETE CASCADE,
|
||||
FOREIGN KEY(`author_id`) REFERENCES `user`(`id`) ON DELETE
|
||||
SET
|
||||
NULL
|
||||
);
|
||||
6
migrations/20240518191103_tokens.sql
Normal file
6
migrations/20240518191103_tokens.sql
Normal file
@@ -0,0 +1,6 @@
|
||||
CREATE TABLE IF NOT EXISTS tokens (
|
||||
token TEXT NOT NULL PRIMARY KEY,
|
||||
user_id INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE,
|
||||
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
|
||||
expires_at DATETIME
|
||||
);
|
||||
237
src/database.rs
237
src/database.rs
@@ -1,7 +1,10 @@
|
||||
use derive_more::{Display, Error, From};
|
||||
use sqlx::migrate::Migrator;
|
||||
|
||||
use crate::{config, entity};
|
||||
use crate::{
|
||||
config,
|
||||
entity::{self, Channel},
|
||||
};
|
||||
|
||||
static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
|
||||
|
||||
@@ -25,9 +28,9 @@ impl Database {
|
||||
Ok(Self { pool })
|
||||
}
|
||||
|
||||
pub async fn get_user_by_id(&self, id: entity::ShortId) -> Result<entity::User> {
|
||||
let user = sqlx::query_as("SELECT * FROM users WHERE id = $1")
|
||||
.bind(id)
|
||||
pub async fn get_user_by_id(&self, user_id: entity::ShortId) -> Result<entity::User> {
|
||||
let user = sqlx::query_as("SELECT * FROM user WHERE id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or(Error::UserDoesNotExists)?;
|
||||
@@ -36,7 +39,7 @@ impl Database {
|
||||
}
|
||||
|
||||
pub async fn get_user_by_username(&self, username: &str) -> Result<entity::User> {
|
||||
let user = sqlx::query_as("SELECT * FROM users WHERE username = $1")
|
||||
let user = sqlx::query_as("SELECT * FROM user WHERE username = $1")
|
||||
.bind(username)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
@@ -51,18 +54,224 @@ impl Database {
|
||||
return Err(Error::UserAlreadyExists);
|
||||
}
|
||||
|
||||
let id = sqlx::query_scalar(
|
||||
"INSERT INTO users(username, password) VALUES ($1, $2) RETURNING id",
|
||||
)
|
||||
.bind(username)
|
||||
.bind(password)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
let id =
|
||||
sqlx::query_scalar("INSERT INTO user(username, password) VALUES ($1, $2) RETURNING id")
|
||||
.bind(username)
|
||||
.bind(password)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
let user = self.get_user_by_id(id).await?;
|
||||
|
||||
Ok(user)
|
||||
}
|
||||
|
||||
pub async fn create_token(
|
||||
&self,
|
||||
id: i32,
|
||||
token: &str,
|
||||
expires_at: chrono::DateTime<chrono::Utc>,
|
||||
) -> Result<entity::Token> {
|
||||
sqlx::query("INSERT INTO tokens(user_id, token, expires_at) VALUES ($1, $2, $3)")
|
||||
.bind(id)
|
||||
.bind(token)
|
||||
.bind(expires_at)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
self.get_token(token).await
|
||||
}
|
||||
|
||||
pub async fn get_token(&self, token: &str) -> Result<entity::Token> {
|
||||
let token = sqlx::query_as("SELECT * FROM tokens WHERE token = $1")
|
||||
.bind(token)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or(Error::UserDoesNotExists)?;
|
||||
|
||||
Ok(token)
|
||||
}
|
||||
|
||||
pub async fn get_channel_by_id(&self, channel_id: entity::ShortId) -> Result<entity::Channel> {
|
||||
let channel = sqlx::query_as("SELECT * FROM channel WHERE id = $1")
|
||||
.bind(channel_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or(Error::ChannelDoesNotExists)?;
|
||||
|
||||
Ok(channel)
|
||||
}
|
||||
|
||||
pub async fn get_all_user_channels(
|
||||
&self,
|
||||
user_id: entity::ShortId,
|
||||
) -> Result<Vec<entity::Channel>> {
|
||||
let channels = sqlx::query_as("SELECT channel.* FROM channel INNER JOIN channel_user ON channel.id = channel_user.channel_id WHERE user_id = $1")
|
||||
.bind(user_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(channels)
|
||||
}
|
||||
|
||||
pub async fn update_last_message_channel(
|
||||
&self,
|
||||
channel_id: entity::ShortId,
|
||||
message_id: entity::LongId,
|
||||
) -> Result<Channel> {
|
||||
sqlx::query("UPDATE channel SET last_message_id = $1 WHERE id = $2")
|
||||
.bind(message_id)
|
||||
.bind(channel_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
self.get_channel_by_id(channel_id).await
|
||||
}
|
||||
|
||||
pub async fn add_user_to_channel(
|
||||
&self,
|
||||
user_id: entity::ShortId,
|
||||
channel_id: entity::ShortId,
|
||||
admin: bool,
|
||||
) -> Result<()> {
|
||||
sqlx::query("INSERT INTO channel_user(user_id, channel_id, admin) VALUES ($1, $2, $3)")
|
||||
.bind(user_id)
|
||||
.bind(channel_id)
|
||||
.bind(admin)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_user_from_channel(
|
||||
&self,
|
||||
user_id: entity::ShortId,
|
||||
channel_id: entity::ShortId,
|
||||
) -> Result<()> {
|
||||
sqlx::query("DELETE FROM channel_user WHERE user_id = $1 AND channel_id = $2")
|
||||
.bind(user_id)
|
||||
.bind(channel_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn create_channel(
|
||||
&self,
|
||||
user_id: entity::ShortId,
|
||||
name: &str,
|
||||
) -> Result<entity::Channel> {
|
||||
let id = sqlx::query_scalar("INSERT INTO channel(name) VALUES ($1) RETURNING id")
|
||||
.bind(name)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
self.add_user_to_channel(user_id, id, true).await?;
|
||||
self.get_channel_by_id(id).await
|
||||
}
|
||||
|
||||
pub async fn delete_channel(&self, channel_id: entity::ShortId) -> Result<()> {
|
||||
sqlx::query("DELETE FROM channel WHERE id = $1")
|
||||
.bind(channel_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn set_channel_last_message_id(
|
||||
&self,
|
||||
channel_id: entity::ShortId,
|
||||
message_id: entity::LongId,
|
||||
) -> Result<()> {
|
||||
sqlx::query("UPDATE channel SET last_message_id = $1 WHERE id = $2")
|
||||
.bind(message_id)
|
||||
.bind(channel_id)
|
||||
.execute(&self.pool)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_channel_users(
|
||||
&self,
|
||||
channel_id: entity::ShortId,
|
||||
) -> Result<Vec<(entity::User, bool)>> {
|
||||
let user_ids =
|
||||
sqlx::query_as("SELECT user_id, admin FROM channel_user WHERE channel_id = $1")
|
||||
.bind(channel_id)
|
||||
.fetch_all(&self.pool)
|
||||
.await?;
|
||||
|
||||
let users = user_ids.iter().map(|(user_id, admin)| async move {
|
||||
let user = self.get_user_by_id(*user_id).await;
|
||||
user.map(|u| (u, *admin))
|
||||
});
|
||||
|
||||
let users = futures::future::try_join_all(users).await?;
|
||||
|
||||
Ok(users)
|
||||
}
|
||||
|
||||
pub async fn get_message_by_id(&self, message_id: entity::LongId) -> Result<entity::Message> {
|
||||
let message = sqlx::query_as("SELECT * FROM message WHERE id = $1")
|
||||
.bind(message_id)
|
||||
.fetch_optional(&self.pool)
|
||||
.await?
|
||||
.ok_or(Error::MessageDoesNotExists)?;
|
||||
|
||||
Ok(message)
|
||||
}
|
||||
|
||||
pub async fn create_message(
|
||||
&self,
|
||||
channel_id: entity::ShortId,
|
||||
user_id: entity::ShortId,
|
||||
content: &str,
|
||||
) -> Result<entity::Message> {
|
||||
let id = sqlx::query_scalar(
|
||||
"INSERT INTO message(channel_id, author_id, content) VALUES ($1, $2, $3) RETURNING id",
|
||||
)
|
||||
.bind(channel_id)
|
||||
.bind(user_id)
|
||||
.bind(content)
|
||||
.fetch_one(&self.pool)
|
||||
.await?;
|
||||
|
||||
self.set_channel_last_message_id(channel_id, id).await?;
|
||||
self.get_message_by_id(id).await
|
||||
}
|
||||
|
||||
pub async fn get_messages_by_channel_id(
|
||||
&self,
|
||||
channel_id: entity::ShortId,
|
||||
limit: i64,
|
||||
before_id: Option<entity::LongId>,
|
||||
) -> Result<Vec<entity::Message>> {
|
||||
let messages = match before_id {
|
||||
Some(before_id) => sqlx::query_as::<_, entity::Message>(
|
||||
"SELECT * FROM message WHERE channel_id = $1 AND id < $2 ORDER BY id DESC LIMIT $3",
|
||||
)
|
||||
.bind(channel_id)
|
||||
.bind(before_id)
|
||||
.bind(limit)
|
||||
.fetch_all(&self.pool)
|
||||
.await?,
|
||||
None => {
|
||||
sqlx::query_as::<_, entity::Message>(
|
||||
"SELECT * FROM message WHERE channel_id = $1 ORDER BY id DESC LIMIT $2",
|
||||
)
|
||||
.bind(channel_id)
|
||||
.bind(limit)
|
||||
.fetch_all(&self.pool)
|
||||
.await?
|
||||
}
|
||||
};
|
||||
|
||||
Ok(messages)
|
||||
}
|
||||
}
|
||||
|
||||
pub type Result<T> = std::result::Result<T, Error>;
|
||||
@@ -75,4 +284,8 @@ pub enum Error {
|
||||
Sqlx(sqlx::Error),
|
||||
UserDoesNotExists,
|
||||
UserAlreadyExists,
|
||||
|
||||
ChannelDoesNotExists,
|
||||
|
||||
MessageDoesNotExists,
|
||||
}
|
||||
|
||||
10
src/entity/channel.rs
Normal file
10
src/entity/channel.rs
Normal file
@@ -0,0 +1,10 @@
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Channel {
|
||||
pub id: super::ShortId,
|
||||
pub name: String,
|
||||
pub last_message_id: Option<super::LongId>,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
@@ -1,9 +1,12 @@
|
||||
use chrono::Utc;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Clone, sqlx::FromRow)]
|
||||
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Message {
|
||||
pub id: super::LongId,
|
||||
pub user_id: super::ShortId,
|
||||
pub channel_id: super::ShortId,
|
||||
pub author_id: super::ShortId,
|
||||
pub content: String,
|
||||
pub created_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
#![allow(unused)]
|
||||
|
||||
mod channel;
|
||||
mod log;
|
||||
mod message;
|
||||
mod secret;
|
||||
mod token;
|
||||
mod user;
|
||||
|
||||
pub use channel::Channel;
|
||||
pub use log::Log;
|
||||
pub use message::Message;
|
||||
pub use secret::Secret;
|
||||
pub use token::Token;
|
||||
pub use user::User;
|
||||
|
||||
pub type ShortId = i32;
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
#[derive(Clone, sqlx::FromRow)]
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||
pub struct Secret {
|
||||
pub id: super::ShortId,
|
||||
pub user_id: super::ShortId,
|
||||
|
||||
12
src/entity/token.rs
Normal file
12
src/entity/token.rs
Normal file
@@ -0,0 +1,12 @@
|
||||
use serde::Serialize;
|
||||
|
||||
use super::ShortId;
|
||||
|
||||
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct Token {
|
||||
pub token: String,
|
||||
pub user_id: ShortId,
|
||||
pub created_at: chrono::DateTime<chrono::Utc>,
|
||||
pub expires_at: chrono::DateTime<chrono::Utc>,
|
||||
}
|
||||
@@ -1,9 +1,13 @@
|
||||
use chrono::Utc;
|
||||
use serde::Serialize;
|
||||
|
||||
#[derive(Clone, sqlx::FromRow)]
|
||||
#[derive(Debug, Clone, Hash, sqlx::FromRow, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct User {
|
||||
pub id: super::ShortId,
|
||||
pub username: String,
|
||||
#[serde(skip_serializing)]
|
||||
pub password: String,
|
||||
pub avatar: Option<String>,
|
||||
pub created_at: chrono::DateTime<Utc>,
|
||||
}
|
||||
|
||||
@@ -11,9 +11,9 @@ pub struct Claims {
|
||||
pub user_id: i32,
|
||||
}
|
||||
|
||||
pub fn generate_jwt(user_id: i32) -> Result<String> {
|
||||
pub fn generate_jwt(user_id: i32, expires_at: chrono::DateTime<chrono::Utc>) -> Result<String> {
|
||||
let claims = Claims {
|
||||
exp: (Local::now() + Duration::days(1)).timestamp() as usize,
|
||||
exp: expires_at.timestamp() as usize,
|
||||
user_id,
|
||||
};
|
||||
|
||||
@@ -28,6 +28,7 @@ pub fn generate_jwt(user_id: i32) -> Result<String> {
|
||||
}
|
||||
|
||||
pub fn verify_jwt(token: &str) -> Result<i32> {
|
||||
tracing::debug!("Verifying token: {}", token);
|
||||
let token_data = jsonwebtoken::decode::<Claims>(
|
||||
token,
|
||||
&jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),
|
||||
|
||||
@@ -14,7 +14,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
|
||||
let _guard = log::initialize()?;
|
||||
|
||||
let database = Database::init().await?;
|
||||
let context = AppState { database };
|
||||
let context = AppState {
|
||||
database,
|
||||
connected_users: Default::default(),
|
||||
};
|
||||
|
||||
web::run(context).await?;
|
||||
|
||||
|
||||
13
src/state.rs
13
src/state.rs
@@ -1,6 +1,17 @@
|
||||
use crate::database::Database;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use tokio::sync::{mpsc::UnboundedSender, RwLock};
|
||||
|
||||
use crate::{database::Database, entity::ShortId, web::ws};
|
||||
|
||||
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
|
||||
pub struct WebSocketKey {
|
||||
pub user_id: ShortId,
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct AppState {
|
||||
pub database: Database,
|
||||
pub connected_users: Arc<RwLock<HashMap<WebSocketKey, UnboundedSender<ws::Message>>>>,
|
||||
}
|
||||
|
||||
@@ -2,9 +2,10 @@ use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
|
||||
|
||||
use crate::entity;
|
||||
|
||||
#[derive(Clone)]
|
||||
#[derive(Debug, Clone, Hash)]
|
||||
pub struct Context {
|
||||
pub user_id: entity::ShortId,
|
||||
pub user: entity::User,
|
||||
pub token: String,
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
@@ -15,9 +16,9 @@ impl<S: Send + Sync> FromRequestParts<S> for Context {
|
||||
parts
|
||||
.extensions
|
||||
.get::<ContextResult>()
|
||||
.ok_or(super::Error::ContextError(Error::NotInRequest))?
|
||||
.ok_or(super::Error::Context(Error::NotInRequest))?
|
||||
.clone()
|
||||
.map_err(super::Error::ContextError)
|
||||
.map_err(super::Error::Context)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -12,18 +12,41 @@ pub type Result<T> = std::result::Result<T, Error>;
|
||||
#[derive(Debug, From)]
|
||||
pub enum Error {
|
||||
#[from]
|
||||
ContextError(context::Error),
|
||||
Context(context::Error),
|
||||
#[from]
|
||||
DatabaseError(database::Error),
|
||||
Database(database::Error),
|
||||
#[from]
|
||||
JWT(jwt::Error),
|
||||
Jwt(jwt::Error),
|
||||
|
||||
WrongPassword,
|
||||
NotAllowed,
|
||||
}
|
||||
|
||||
impl Error {
|
||||
pub fn as_client_error(&self) -> ClientError {
|
||||
ClientError::HahaError
|
||||
pub fn as_client_error_and_status(&self) -> (ClientError, StatusCode) {
|
||||
match self {
|
||||
Error::Context(_) | Error::Jwt(_) => {
|
||||
(ClientError::NotAuthorized, StatusCode::UNAUTHORIZED)
|
||||
}
|
||||
Error::Database(database::Error::UserAlreadyExists) => {
|
||||
(ClientError::UserAlreadyExists, StatusCode::CONFLICT)
|
||||
}
|
||||
Error::Database(database::Error::UserDoesNotExists) => {
|
||||
(ClientError::UserDoesNotExists, StatusCode::NOT_FOUND)
|
||||
}
|
||||
Error::Database(database::Error::ChannelDoesNotExists) => {
|
||||
(ClientError::ChannelDoesNotExists, StatusCode::NOT_FOUND)
|
||||
}
|
||||
Error::Database(database::Error::MessageDoesNotExists) => {
|
||||
(ClientError::MessageDoesNotExists, StatusCode::NOT_FOUND)
|
||||
}
|
||||
Error::WrongPassword => (ClientError::WrongPassword, StatusCode::UNAUTHORIZED),
|
||||
Error::NotAllowed => (ClientError::NotAllowed, StatusCode::FORBIDDEN),
|
||||
_ => (
|
||||
ClientError::InternalServerError,
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -39,5 +62,22 @@ impl IntoResponse for Error {
|
||||
|
||||
#[derive(Debug, Error, Display)]
|
||||
pub enum ClientError {
|
||||
HahaError,
|
||||
#[display(fmt = "Not authorized")]
|
||||
NotAuthorized,
|
||||
#[display(fmt = "Wrong password")]
|
||||
WrongPassword,
|
||||
#[display(fmt = "Not allowed")]
|
||||
NotAllowed,
|
||||
|
||||
#[display(fmt = "User already exists")]
|
||||
UserAlreadyExists,
|
||||
#[display(fmt = "User does not exists")]
|
||||
UserDoesNotExists,
|
||||
#[display(fmt = "Channel does not exists")]
|
||||
ChannelDoesNotExists,
|
||||
#[display(fmt = "Message does not exists")]
|
||||
MessageDoesNotExists,
|
||||
|
||||
#[display(fmt = "Internal server error")]
|
||||
InternalServerError,
|
||||
}
|
||||
|
||||
@@ -2,11 +2,11 @@ use axum::{
|
||||
extract::{Request, State},
|
||||
http::{header, HeaderMap},
|
||||
middleware::Next,
|
||||
response::{IntoResponse, Response},
|
||||
response::Response,
|
||||
Extension,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
database::Database,
|
||||
jwt,
|
||||
state::AppState,
|
||||
web::{
|
||||
@@ -16,7 +16,7 @@ use crate::{
|
||||
};
|
||||
|
||||
pub async fn require_context(
|
||||
context: ContextResult,
|
||||
Extension(context): Extension<ContextResult>,
|
||||
request: Request,
|
||||
next: Next,
|
||||
) -> web::Result<Response> {
|
||||
@@ -46,12 +46,17 @@ async fn get_context(state: AppState, headers: &HeaderMap) -> context::ContextRe
|
||||
return Err(context::Error::WrongTokenType);
|
||||
}
|
||||
let token = &token[7..];
|
||||
|
||||
get_context_from_token(state, token).await
|
||||
}
|
||||
|
||||
pub async fn get_context_from_token(state: AppState, token: &str) -> context::ContextResult {
|
||||
let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;
|
||||
let _user = state
|
||||
let user = state
|
||||
.database
|
||||
.get_user_by_id(user_id)
|
||||
.await
|
||||
.map_err(|_| context::Error::Model)?;
|
||||
|
||||
Ok(context::Context { user_id })
|
||||
Ok(context::Context { user, token: token.to_owned() })
|
||||
}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
mod auth;
|
||||
mod response_map;
|
||||
|
||||
pub use auth::require_context;
|
||||
pub use auth::resolve_context;
|
||||
pub use response_map::response_map;
|
||||
pub use auth::*;
|
||||
pub use response_map::*;
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use axum::{extract::Request, middleware::Next, response::IntoResponse};
|
||||
use axum::{extract::Request, middleware::Next, response::IntoResponse, Json};
|
||||
use serde_json::json;
|
||||
|
||||
use crate::web;
|
||||
|
||||
@@ -9,10 +10,18 @@ pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
|
||||
|
||||
let error = response.extensions().get::<Arc<web::Error>>();
|
||||
|
||||
let error_response = error.map(|e| {
|
||||
let client_error = e.as_client_error();
|
||||
if error.is_some() {
|
||||
tracing::error!("{:?}", error);
|
||||
}
|
||||
|
||||
client_error.to_string().into_response()
|
||||
let error_response = error.map(|e| {
|
||||
let client_error = e.as_client_error_and_status();
|
||||
|
||||
(
|
||||
client_error.1,
|
||||
Json(json!({"error": client_error.0.to_string()})),
|
||||
)
|
||||
.into_response()
|
||||
});
|
||||
|
||||
error_response.unwrap_or(response)
|
||||
|
||||
@@ -1,11 +1,13 @@
|
||||
use crate::{config, state};
|
||||
use crate::{config, state, web::routes::message};
|
||||
|
||||
mod context;
|
||||
mod error;
|
||||
pub mod middlware;
|
||||
pub mod routes;
|
||||
pub mod ws;
|
||||
|
||||
pub use error::{Error, Result};
|
||||
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin};
|
||||
|
||||
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
|
||||
let config = config::config();
|
||||
@@ -22,18 +24,66 @@ pub async fn run(state: state::AppState) -> anyhow::Result<()> {
|
||||
}
|
||||
|
||||
fn router(state: state::AppState) -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/user/login", axum::routing::post(routes::user::login))
|
||||
.route(
|
||||
"/user/register",
|
||||
axum::routing::post(routes::user::register),
|
||||
)
|
||||
.with_state(state.clone())
|
||||
.layer(axum::middleware::from_fn_with_state(
|
||||
use self::routes::user;
|
||||
use axum::routing::*;
|
||||
use axum::Router;
|
||||
|
||||
let cors = tower_http::cors::CorsLayer::new()
|
||||
.allow_origin(AllowOrigin::any())
|
||||
.allow_methods(AllowMethods::any())
|
||||
.allow_headers(AllowHeaders::any());
|
||||
|
||||
Router::new()
|
||||
// websocket
|
||||
.route("/ws/:token", get(ws::ws_handler))
|
||||
// unprotected
|
||||
.route("/user/login", post(user::login))
|
||||
.route("/user/register", post(user::register))
|
||||
// protected
|
||||
.nest("/", protected_router())
|
||||
.layer(tower_http::trace::TraceLayer::new_for_http())
|
||||
.route_layer(axum::middleware::from_fn_with_state(
|
||||
state.clone(),
|
||||
middlware::resolve_context,
|
||||
))
|
||||
.layer(axum::middleware::from_fn(middlware::response_map))
|
||||
.layer(cors)
|
||||
.with_state(state.clone())
|
||||
}
|
||||
|
||||
fn protected_router() -> axum::Router<state::AppState> {
|
||||
use self::routes::channel;
|
||||
use self::routes::user;
|
||||
use axum::routing::*;
|
||||
use axum::Router;
|
||||
|
||||
Router::new()
|
||||
// user
|
||||
.route("/user", get(user::get_by_username))
|
||||
.route("/user/me", get(user::get_self))
|
||||
.route("/user/:user_id", get(user::get_by_id))
|
||||
// channel
|
||||
.route("/channel", get(channel::get_all_user_channels))
|
||||
.route("/channel/:channel_id", get(channel::get_by_id))
|
||||
.route("/channel", post(channel::create_channel))
|
||||
.route("/channel/:channel_id", delete(channel::delete_channel))
|
||||
.route(
|
||||
"/channel/:channel_id/user/:user_id",
|
||||
post(channel::add_user_to_channel),
|
||||
)
|
||||
.route(
|
||||
"/channel/:channel_id/user/:user_id",
|
||||
delete(channel::remove_user_from_channel),
|
||||
)
|
||||
// message
|
||||
.route("/message/:message_id", get(message::get_by_id))
|
||||
.route("/message", post(message::create))
|
||||
.route(
|
||||
"/channel/:channel_id/message",
|
||||
get(message::get_messages_by_channel_id),
|
||||
)
|
||||
// middleware
|
||||
.route_layer(axum::middleware::from_fn(middlware::require_context))
|
||||
}
|
||||
|
||||
async fn shutdown_signal() {
|
||||
|
||||
34
src/web/routes/channel/create.rs
Normal file
34
src/web/routes/channel/create.rs
Normal file
@@ -0,0 +1,34 @@
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
|
||||
use crate::{
|
||||
state::AppState,
|
||||
web::{self, ws},
|
||||
};
|
||||
|
||||
pub async fn create_channel(
|
||||
State(state): State<AppState>,
|
||||
context: web::context::Context,
|
||||
Json(body): Json<CreateChannel>,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channel = state
|
||||
.database
|
||||
.create_channel(context.user.id, &body.name)
|
||||
.await?;
|
||||
|
||||
let channel_users = state.database.get_channel_users(channel.id).await?;
|
||||
|
||||
ws::broadcast_message(ws::Message::CreateChannel(channel.clone()), &state, |key| {
|
||||
channel_users
|
||||
.iter()
|
||||
.any(|(user, _admin)| user.id == key.user_id)
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(Json(channel))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateChannel {
|
||||
name: String,
|
||||
}
|
||||
40
src/web/routes/channel/delete.rs
Normal file
40
src/web/routes/channel/delete.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
entity::ShortId,
|
||||
state::AppState,
|
||||
web::{self, ws},
|
||||
};
|
||||
|
||||
pub async fn delete_channel(
|
||||
State(state): State<AppState>,
|
||||
Path(channel_id): Path<ShortId>,
|
||||
context: web::context::Context,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||
|
||||
if !channel_users
|
||||
.iter()
|
||||
.any(|(user, admin)| user.id == context.user.id && *admin)
|
||||
{
|
||||
return Err(web::Error::NotAllowed);
|
||||
}
|
||||
|
||||
state.database.delete_channel(channel_id).await?;
|
||||
|
||||
ws::broadcast_message(
|
||||
ws::Message::DeleteChannel { id: channel_id },
|
||||
&state,
|
||||
|key| {
|
||||
channel_users
|
||||
.iter()
|
||||
.any(|(user, _admin)| user.id == key.user_id)
|
||||
},
|
||||
)
|
||||
.await;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
37
src/web/routes/channel/get.rs
Normal file
37
src/web/routes/channel/get.rs
Normal file
@@ -0,0 +1,37 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
Json,
|
||||
};
|
||||
|
||||
use crate::{entity::ShortId, state::AppState, web};
|
||||
|
||||
pub async fn get_by_id(
|
||||
State(state): State<AppState>,
|
||||
Path(channel_id): Path<ShortId>,
|
||||
context: web::context::Context,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channel = state.database.get_channel_by_id(channel_id).await?;
|
||||
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||
|
||||
if !channel_users
|
||||
.into_iter()
|
||||
.any(|(user, _admin)| user.id == context.user.id)
|
||||
{
|
||||
return Err(web::Error::NotAllowed);
|
||||
}
|
||||
|
||||
Ok(Json(channel))
|
||||
}
|
||||
|
||||
pub async fn get_all_user_channels(
|
||||
State(state): State<AppState>,
|
||||
context: web::context::Context,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channels = state
|
||||
.database
|
||||
.get_all_user_channels(context.user.id)
|
||||
.await?;
|
||||
|
||||
Ok(Json(channels))
|
||||
}
|
||||
9
src/web/routes/channel/mod.rs
Normal file
9
src/web/routes/channel/mod.rs
Normal file
@@ -0,0 +1,9 @@
|
||||
mod create;
|
||||
mod delete;
|
||||
mod get;
|
||||
mod user;
|
||||
|
||||
pub use create::*;
|
||||
pub use delete::*;
|
||||
pub use get::*;
|
||||
pub use user::*;
|
||||
50
src/web/routes/channel/user.rs
Normal file
50
src/web/routes/channel/user.rs
Normal file
@@ -0,0 +1,50 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
};
|
||||
|
||||
use crate::{entity::ShortId, state::AppState, web};
|
||||
|
||||
pub async fn add_user_to_channel(
|
||||
State(state): State<AppState>,
|
||||
Path((channel_id, user_id)): Path<(ShortId, ShortId)>,
|
||||
context: web::context::Context,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||
|
||||
if !channel_users
|
||||
.into_iter()
|
||||
.any(|(user, admin)| user.id == context.user.id && admin)
|
||||
{
|
||||
return Err(web::Error::NotAllowed);
|
||||
}
|
||||
|
||||
state
|
||||
.database
|
||||
.add_user_to_channel(user_id, channel_id, false)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn remove_user_from_channel(
|
||||
State(state): State<AppState>,
|
||||
Path((channel_id, user_id)): Path<(ShortId, ShortId)>,
|
||||
context: web::context::Context,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||
|
||||
if !channel_users
|
||||
.into_iter()
|
||||
.any(|(user, admin)| user.id == context.user.id && admin)
|
||||
{
|
||||
return Err(web::Error::NotAllowed);
|
||||
}
|
||||
|
||||
state
|
||||
.database
|
||||
.remove_user_from_channel(user_id, channel_id)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
54
src/web/routes/message/create.rs
Normal file
54
src/web/routes/message/create.rs
Normal file
@@ -0,0 +1,54 @@
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
|
||||
use crate::{
|
||||
entity::ShortId,
|
||||
state::AppState,
|
||||
web::{self, ws},
|
||||
};
|
||||
|
||||
pub async fn create(
|
||||
State(state): State<AppState>,
|
||||
context: web::context::Context,
|
||||
Json(body): Json<CreateMessage>,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channel_users = state.database.get_channel_users(body.channel_id).await?;
|
||||
if !channel_users
|
||||
.iter()
|
||||
.any(|(user, _admin)| user.id == context.user.id)
|
||||
{
|
||||
return Err(web::Error::NotAllowed);
|
||||
}
|
||||
|
||||
let message = state
|
||||
.database
|
||||
.create_message(body.channel_id, context.user.id, &body.content)
|
||||
.await?;
|
||||
|
||||
let channel = state
|
||||
.database
|
||||
.update_last_message_channel(body.channel_id, message.id)
|
||||
.await?;
|
||||
|
||||
ws::broadcast_message(ws::Message::CreateMessage(message.clone()), &state, |key| {
|
||||
channel_users
|
||||
.iter()
|
||||
.any(|(user, _admin)| user.id == key.user_id)
|
||||
})
|
||||
.await;
|
||||
|
||||
ws::broadcast_message(ws::Message::UpdateChannel(channel), &state, |key| {
|
||||
channel_users
|
||||
.iter()
|
||||
.any(|(user, _admin)| user.id == key.user_id)
|
||||
})
|
||||
.await;
|
||||
|
||||
Ok(Json(message))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct CreateMessage {
|
||||
content: String,
|
||||
channel_id: ShortId,
|
||||
}
|
||||
61
src/web/routes/message/get.rs
Normal file
61
src/web/routes/message/get.rs
Normal file
@@ -0,0 +1,61 @@
|
||||
use axum::{
|
||||
extract::{Path, Query, State},
|
||||
response::IntoResponse,
|
||||
Json,
|
||||
};
|
||||
|
||||
use crate::{
|
||||
entity::{LongId, ShortId},
|
||||
state::AppState,
|
||||
web,
|
||||
};
|
||||
|
||||
pub async fn get_by_id(
|
||||
State(state): State<AppState>,
|
||||
Path(message_id): Path<LongId>,
|
||||
context: web::context::Context,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let message = state.database.get_message_by_id(message_id).await?;
|
||||
if message.author_id != context.user.id {
|
||||
let channel_users = state.database.get_channel_users(message.channel_id).await?;
|
||||
|
||||
if !channel_users
|
||||
.into_iter()
|
||||
.any(|(user, _admin)| user.id == context.user.id)
|
||||
{
|
||||
return Err(web::Error::NotAllowed);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Json(message))
|
||||
}
|
||||
|
||||
pub async fn get_messages_by_channel_id(
|
||||
State(state): State<AppState>,
|
||||
Path(channel_id): Path<ShortId>,
|
||||
context: web::context::Context,
|
||||
Query(body): Query<GetByChannelId>,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let channel_users = state.database.get_channel_users(channel_id).await?;
|
||||
|
||||
if !channel_users
|
||||
.into_iter()
|
||||
.any(|(user, _admin)| user.id == context.user.id)
|
||||
{
|
||||
return Err(web::Error::NotAllowed);
|
||||
}
|
||||
|
||||
let messages = state
|
||||
.database
|
||||
.get_messages_by_channel_id(channel_id, body.limit.unwrap_or(50), body.before_id)
|
||||
.await?;
|
||||
|
||||
Ok(Json(messages))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GetByChannelId {
|
||||
before_id: Option<LongId>,
|
||||
limit: Option<i64>,
|
||||
}
|
||||
5
src/web/routes/message/mod.rs
Normal file
5
src/web/routes/message/mod.rs
Normal file
@@ -0,0 +1,5 @@
|
||||
mod create;
|
||||
mod get;
|
||||
|
||||
pub use create::*;
|
||||
pub use get::*;
|
||||
@@ -1 +1,3 @@
|
||||
pub mod channel;
|
||||
pub mod message;
|
||||
pub mod user;
|
||||
|
||||
40
src/web/routes/user/get.rs
Normal file
40
src/web/routes/user/get.rs
Normal file
@@ -0,0 +1,40 @@
|
||||
use axum::{
|
||||
extract::{Path, State},
|
||||
response::IntoResponse,
|
||||
Json,
|
||||
};
|
||||
|
||||
use crate::{entity::ShortId, state::AppState, web};
|
||||
|
||||
pub async fn get_self(
|
||||
State(state): State<AppState>,
|
||||
context: web::context::Context,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let user = state.database.get_user_by_id(context.user.id).await?;
|
||||
|
||||
Ok(Json(user))
|
||||
}
|
||||
|
||||
pub async fn get_by_id(
|
||||
State(state): State<AppState>,
|
||||
Path(user_id): Path<ShortId>,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let user = state.database.get_user_by_id(user_id).await?;
|
||||
|
||||
Ok(Json(user))
|
||||
}
|
||||
|
||||
pub async fn get_by_username(
|
||||
State(state): State<AppState>,
|
||||
Json(body): Json<GetByUsernameRequest>,
|
||||
) -> web::Result<impl IntoResponse> {
|
||||
let user = state.database.get_user_by_username(&body.username).await?;
|
||||
|
||||
Ok(Json(user))
|
||||
}
|
||||
|
||||
#[derive(serde::Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct GetByUsernameRequest {
|
||||
pub username: String,
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
use axum::{extract::State, response::IntoResponse, Json};
|
||||
use chrono::{Duration, Utc};
|
||||
use serde::Deserialize;
|
||||
use serde_json::json;
|
||||
|
||||
use crate::{jwt, state::AppState, web};
|
||||
|
||||
@@ -17,14 +17,20 @@ pub async fn login(
|
||||
return Err(web::Error::WrongPassword);
|
||||
}
|
||||
|
||||
let token = jwt::generate_jwt(user.id)?;
|
||||
let expires_at = Utc::now() + Duration::hours(24);
|
||||
|
||||
Ok(Json(json!({
|
||||
"token": token
|
||||
})))
|
||||
let token = jwt::generate_jwt(user.id, expires_at)?;
|
||||
|
||||
let token = state
|
||||
.database
|
||||
.create_token(user.id, &token, expires_at)
|
||||
.await?;
|
||||
|
||||
Ok(Json(token))
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct LoginPayload {
|
||||
username: String,
|
||||
password: String,
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
mod get;
|
||||
mod login;
|
||||
mod register;
|
||||
|
||||
pub use get::*;
|
||||
pub use login::login;
|
||||
pub use register::register;
|
||||
|
||||
@@ -21,6 +21,7 @@ pub async fn register(
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
pub struct RegisterPayload {
|
||||
username: String,
|
||||
password: String,
|
||||
|
||||
101
src/web/ws/mod.rs
Normal file
101
src/web/ws/mod.rs
Normal file
@@ -0,0 +1,101 @@
|
||||
use axum::{
|
||||
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use futures::{SinkExt, StreamExt};
|
||||
|
||||
use crate::{
|
||||
entity,
|
||||
state::{AppState, WebSocketKey},
|
||||
};
|
||||
|
||||
use super::{context, error, middlware};
|
||||
|
||||
#[derive(serde::Serialize, Clone)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
#[serde(tag = "type", content = "data")]
|
||||
pub enum Message {
|
||||
CreateMessage(entity::Message),
|
||||
UpdateChannel(entity::Channel),
|
||||
CreateChannel(entity::Channel),
|
||||
DeleteChannel { id: entity::ShortId },
|
||||
}
|
||||
|
||||
pub async fn broadcast_message(
|
||||
message: Message,
|
||||
state: &AppState,
|
||||
predicate: impl Fn(&WebSocketKey) -> bool,
|
||||
) {
|
||||
let connected_users = state.connected_users.read().await;
|
||||
|
||||
let recievers =
|
||||
connected_users
|
||||
.iter()
|
||||
.filter_map(|(key, conn)| if predicate(key) { Some(conn) } else { None });
|
||||
|
||||
for reciever in recievers {
|
||||
_ = reciever
|
||||
.send(message.clone())
|
||||
.inspect_err(|err| tracing::error!("Failed to send message: {}", err));
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
State(state): State<AppState>,
|
||||
Path(token): Path<String>,
|
||||
) -> error::Result<impl IntoResponse> {
|
||||
let context = middlware::get_context_from_token(state.clone(), &token).await?;
|
||||
|
||||
Ok(ws.on_upgrade(|socket| handle_socket(socket, state, context)))
|
||||
}
|
||||
|
||||
async fn handle_socket(websocket: WebSocket, state: AppState, context: context::Context) {
|
||||
let user_key = WebSocketKey {
|
||||
token: context.token.clone(),
|
||||
user_id: context.user.id,
|
||||
};
|
||||
let (mut sender, _) = websocket.split();
|
||||
|
||||
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
|
||||
|
||||
{
|
||||
let mut connected_users = state.connected_users.write().await;
|
||||
if connected_users.contains_key(&user_key) {
|
||||
tracing::trace!("websocket already connected: {user_key:?}");
|
||||
|
||||
drop(connected_users.remove(&user_key));
|
||||
}
|
||||
connected_users.insert(user_key.clone(), tx);
|
||||
}
|
||||
|
||||
tracing::trace!("websocket connected: {user_key:?}");
|
||||
|
||||
tokio::spawn(async move {
|
||||
// idk
|
||||
// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
||||
while let Some(message) = rx.recv().await {
|
||||
let err = sender
|
||||
.send(axum::extract::ws::Message::Text(
|
||||
serde_json::to_string(&message)
|
||||
.inspect_err(|e| tracing::error!("Could not serialize message: {e}"))
|
||||
.unwrap(),
|
||||
))
|
||||
.await
|
||||
.inspect_err(|e| tracing::error!("Could not send message: {e}"));
|
||||
|
||||
if err.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let _ = sender.send(axum::extract::ws::Message::Close(None)).await;
|
||||
|
||||
{
|
||||
let mut connected_users = state.connected_users.write().await;
|
||||
connected_users.remove(&user_key);
|
||||
|
||||
tracing::trace!("websocket disconnected: {user_key:?}");
|
||||
}
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user