1
0
This commit is contained in:
2024-05-20 12:34:46 +03:00
parent 18f2160c27
commit 3bcc2b5c5c
35 changed files with 1019 additions and 60 deletions

114
Cargo.lock generated
View File

@@ -119,6 +119,8 @@ checksum = "3a6c9af12842a67734c9a2e355436e5d03b22383ed60cf13cd0c18fbfe3dcbcf"
dependencies = [
"async-trait",
"axum-core",
"axum-macros",
"base64 0.21.7",
"bytes",
"futures-util",
"http",
@@ -137,8 +139,10 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha1",
"sync_wrapper 1.0.1",
"tokio",
"tokio-tungstenite",
"tower",
"tower-layer",
"tower-service",
@@ -166,6 +170,18 @@ dependencies = [
"tracing",
]
[[package]]
name = "axum-macros"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "00c055ee2d014ae5981ce1016374e8213682aa14d9bf40e48ab48b5f3ef20eaa"
dependencies = [
"heck",
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "backtrace"
version = "0.3.71"
@@ -341,6 +357,12 @@ dependencies = [
"typenum",
]
[[package]]
name = "data-encoding"
version = "2.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e962a19be5cfc3f3bf6dd8f61eb50107f356ad6270fbb3ed41476571db78be5"
[[package]]
name = "deranged"
version = "0.3.11"
@@ -465,6 +487,21 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "futures"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.30"
@@ -503,6 +540,23 @@ dependencies = [
"parking_lot 0.11.2",
]
[[package]]
name = "futures-io"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
[[package]]
name = "futures-macro"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn 2.0.60",
]
[[package]]
name = "futures-sink"
version = "0.3.30"
@@ -521,9 +575,13 @@ version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
@@ -931,11 +989,13 @@ dependencies = [
"chrono",
"derive_more",
"figment",
"futures",
"jsonwebtoken",
"serde",
"serde_json",
"sqlx",
"tokio",
"tower-http",
"tracing",
"tracing-appender",
"tracing-subscriber",
@@ -1845,6 +1905,18 @@ dependencies = [
"tokio",
]
[[package]]
name = "tokio-tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c83b561d025642014097b66e6c1bb422783339e0909e4429cde4749d1990bc38"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "toml"
version = "0.8.12"
@@ -1895,6 +1967,23 @@ dependencies = [
"tracing",
]
[[package]]
name = "tower-http"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e9cd434a998747dd2c4276bc96ee2e0c7a2eadf3cae88e52be55a05fa9053f5"
dependencies = [
"bitflags 2.5.0",
"bytes",
"http",
"http-body",
"http-body-util",
"pin-project-lite",
"tower-layer",
"tower-service",
"tracing",
]
[[package]]
name = "tower-layer"
version = "0.3.2"
@@ -1982,6 +2071,25 @@ dependencies = [
"tracing-log",
]
[[package]]
name = "tungstenite"
version = "0.21.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ef1a641ea34f399a848dea702823bbecfb4c486f911735368f1f137cb8257e1"
dependencies = [
"byteorder",
"bytes",
"data-encoding",
"http",
"httparse",
"log",
"rand",
"sha1",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "typenum"
version = "1.17.0"
@@ -2053,6 +2161,12 @@ dependencies = [
"percent-encoding",
]
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "valuable"
version = "0.1.0"

View File

@@ -5,10 +5,11 @@
[dependencies]
anyhow = "1.0.82"
axum = "0.7.5"
axum = { version = "0.7.5", features = ["macros", "ws"] }
chrono = { version = "0.4.38", features = ["serde"] }
derive_more = "0.99.17"
figment = { version = "0.10.18", features = ["env", "toml"] }
futures = "0.3.30"
jsonwebtoken = "9.3.0"
serde = { version = "1.0.198", features = ["derive"] }
serde_json = "1.0.116"
@@ -20,6 +21,7 @@ derive_more = "0.99.17"
"sqlite"
] }
tokio = { version = "1.37.0", features = ["full"] }
tower-http = { version = "0.5.2", features = ["cors", "trace"] }
tracing = "0.1.40"
tracing-appender = "0.2.3"
tracing-subscriber = { version = "0.3.18", features = ["chrono", "env-filter"] }

View File

@@ -1,6 +1,7 @@
CREATE TABLE IF NOT EXISTS users (
CREATE TABLE IF NOT EXISTS user (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username VARCHAR UNIQUE,
password VARCHAR,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
avatar VARCHAR
);

View File

@@ -0,0 +1,27 @@
CREATE TABLE IF NOT EXISTS `channel` (
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`name` VARCHAR,
`last_message_id` INTEGER,
`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE TABLE IF NOT EXISTS `channel_user` (
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`channel_id` INTEGER NOT NULL,
`user_id` INTEGER NOT NULL,
`admin` BOOLEAN NOT NULL DEFAULT 0,
FOREIGN KEY(`channel_id`) REFERENCES `channel`(`id`) ON DELETE CASCADE,
FOREIGN KEY(`user_id`) REFERENCES `user`(`id`) ON DELETE CASCADE
);
CREATE TABLE IF NOT EXISTS `message` (
`id` INTEGER PRIMARY KEY AUTOINCREMENT,
`channel_id` INTEGER NOT NULL,
`author_id` INTEGER,
`content` TEXT,
`created_at` DATETIME DEFAULT CURRENT_TIMESTAMP,
FOREIGN KEY(`channel_id`) REFERENCES `channel`(`id`) ON DELETE CASCADE,
FOREIGN KEY(`author_id`) REFERENCES `user`(`id`) ON DELETE
SET
NULL
);

View File

@@ -0,0 +1,6 @@
CREATE TABLE IF NOT EXISTS tokens (
token TEXT NOT NULL PRIMARY KEY,
user_id INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP,
expires_at DATETIME
);

View File

@@ -1,7 +1,10 @@
use derive_more::{Display, Error, From};
use sqlx::migrate::Migrator;
use crate::{config, entity};
use crate::{
config,
entity::{self, Channel},
};
static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
@@ -25,9 +28,9 @@ impl Database {
Ok(Self { pool })
}
pub async fn get_user_by_id(&self, id: entity::ShortId) -> Result<entity::User> {
let user = sqlx::query_as("SELECT * FROM users WHERE id = $1")
.bind(id)
pub async fn get_user_by_id(&self, user_id: entity::ShortId) -> Result<entity::User> {
let user = sqlx::query_as("SELECT * FROM user WHERE id = $1")
.bind(user_id)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserDoesNotExists)?;
@@ -36,7 +39,7 @@ impl Database {
}
pub async fn get_user_by_username(&self, username: &str) -> Result<entity::User> {
let user = sqlx::query_as("SELECT * FROM users WHERE username = $1")
let user = sqlx::query_as("SELECT * FROM user WHERE username = $1")
.bind(username)
.fetch_optional(&self.pool)
.await?
@@ -51,9 +54,8 @@ impl Database {
return Err(Error::UserAlreadyExists);
}
let id = sqlx::query_scalar(
"INSERT INTO users(username, password) VALUES ($1, $2) RETURNING id",
)
let id =
sqlx::query_scalar("INSERT INTO user(username, password) VALUES ($1, $2) RETURNING id")
.bind(username)
.bind(password)
.fetch_one(&self.pool)
@@ -63,6 +65,213 @@ impl Database {
Ok(user)
}
pub async fn create_token(
&self,
id: i32,
token: &str,
expires_at: chrono::DateTime<chrono::Utc>,
) -> Result<entity::Token> {
sqlx::query("INSERT INTO tokens(user_id, token, expires_at) VALUES ($1, $2, $3)")
.bind(id)
.bind(token)
.bind(expires_at)
.execute(&self.pool)
.await?;
self.get_token(token).await
}
pub async fn get_token(&self, token: &str) -> Result<entity::Token> {
let token = sqlx::query_as("SELECT * FROM tokens WHERE token = $1")
.bind(token)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::UserDoesNotExists)?;
Ok(token)
}
pub async fn get_channel_by_id(&self, channel_id: entity::ShortId) -> Result<entity::Channel> {
let channel = sqlx::query_as("SELECT * FROM channel WHERE id = $1")
.bind(channel_id)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::ChannelDoesNotExists)?;
Ok(channel)
}
pub async fn get_all_user_channels(
&self,
user_id: entity::ShortId,
) -> Result<Vec<entity::Channel>> {
let channels = sqlx::query_as("SELECT channel.* FROM channel INNER JOIN channel_user ON channel.id = channel_user.channel_id WHERE user_id = $1")
.bind(user_id)
.fetch_all(&self.pool)
.await?;
Ok(channels)
}
pub async fn update_last_message_channel(
&self,
channel_id: entity::ShortId,
message_id: entity::LongId,
) -> Result<Channel> {
sqlx::query("UPDATE channel SET last_message_id = $1 WHERE id = $2")
.bind(message_id)
.bind(channel_id)
.execute(&self.pool)
.await?;
self.get_channel_by_id(channel_id).await
}
pub async fn add_user_to_channel(
&self,
user_id: entity::ShortId,
channel_id: entity::ShortId,
admin: bool,
) -> Result<()> {
sqlx::query("INSERT INTO channel_user(user_id, channel_id, admin) VALUES ($1, $2, $3)")
.bind(user_id)
.bind(channel_id)
.bind(admin)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn remove_user_from_channel(
&self,
user_id: entity::ShortId,
channel_id: entity::ShortId,
) -> Result<()> {
sqlx::query("DELETE FROM channel_user WHERE user_id = $1 AND channel_id = $2")
.bind(user_id)
.bind(channel_id)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn create_channel(
&self,
user_id: entity::ShortId,
name: &str,
) -> Result<entity::Channel> {
let id = sqlx::query_scalar("INSERT INTO channel(name) VALUES ($1) RETURNING id")
.bind(name)
.fetch_one(&self.pool)
.await?;
self.add_user_to_channel(user_id, id, true).await?;
self.get_channel_by_id(id).await
}
pub async fn delete_channel(&self, channel_id: entity::ShortId) -> Result<()> {
sqlx::query("DELETE FROM channel WHERE id = $1")
.bind(channel_id)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn set_channel_last_message_id(
&self,
channel_id: entity::ShortId,
message_id: entity::LongId,
) -> Result<()> {
sqlx::query("UPDATE channel SET last_message_id = $1 WHERE id = $2")
.bind(message_id)
.bind(channel_id)
.execute(&self.pool)
.await?;
Ok(())
}
pub async fn get_channel_users(
&self,
channel_id: entity::ShortId,
) -> Result<Vec<(entity::User, bool)>> {
let user_ids =
sqlx::query_as("SELECT user_id, admin FROM channel_user WHERE channel_id = $1")
.bind(channel_id)
.fetch_all(&self.pool)
.await?;
let users = user_ids.iter().map(|(user_id, admin)| async move {
let user = self.get_user_by_id(*user_id).await;
user.map(|u| (u, *admin))
});
let users = futures::future::try_join_all(users).await?;
Ok(users)
}
pub async fn get_message_by_id(&self, message_id: entity::LongId) -> Result<entity::Message> {
let message = sqlx::query_as("SELECT * FROM message WHERE id = $1")
.bind(message_id)
.fetch_optional(&self.pool)
.await?
.ok_or(Error::MessageDoesNotExists)?;
Ok(message)
}
pub async fn create_message(
&self,
channel_id: entity::ShortId,
user_id: entity::ShortId,
content: &str,
) -> Result<entity::Message> {
let id = sqlx::query_scalar(
"INSERT INTO message(channel_id, author_id, content) VALUES ($1, $2, $3) RETURNING id",
)
.bind(channel_id)
.bind(user_id)
.bind(content)
.fetch_one(&self.pool)
.await?;
self.set_channel_last_message_id(channel_id, id).await?;
self.get_message_by_id(id).await
}
pub async fn get_messages_by_channel_id(
&self,
channel_id: entity::ShortId,
limit: i64,
before_id: Option<entity::LongId>,
) -> Result<Vec<entity::Message>> {
let messages = match before_id {
Some(before_id) => sqlx::query_as::<_, entity::Message>(
"SELECT * FROM message WHERE channel_id = $1 AND id < $2 ORDER BY id DESC LIMIT $3",
)
.bind(channel_id)
.bind(before_id)
.bind(limit)
.fetch_all(&self.pool)
.await?,
None => {
sqlx::query_as::<_, entity::Message>(
"SELECT * FROM message WHERE channel_id = $1 ORDER BY id DESC LIMIT $2",
)
.bind(channel_id)
.bind(limit)
.fetch_all(&self.pool)
.await?
}
};
Ok(messages)
}
}
pub type Result<T> = std::result::Result<T, Error>;
@@ -75,4 +284,8 @@ pub enum Error {
Sqlx(sqlx::Error),
UserDoesNotExists,
UserAlreadyExists,
ChannelDoesNotExists,
MessageDoesNotExists,
}

10
src/entity/channel.rs Normal file
View File

@@ -0,0 +1,10 @@
use serde::Serialize;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Channel {
pub id: super::ShortId,
pub name: String,
pub last_message_id: Option<super::LongId>,
pub created_at: chrono::DateTime<chrono::Utc>,
}

View File

@@ -1,9 +1,12 @@
use chrono::Utc;
use serde::Serialize;
#[derive(Clone, sqlx::FromRow)]
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Message {
pub id: super::LongId,
pub user_id: super::ShortId,
pub channel_id: super::ShortId,
pub author_id: super::ShortId,
pub content: String,
pub created_at: chrono::DateTime<Utc>,
}

View File

@@ -1,13 +1,17 @@
#![allow(unused)]
mod channel;
mod log;
mod message;
mod secret;
mod token;
mod user;
pub use channel::Channel;
pub use log::Log;
pub use message::Message;
pub use secret::Secret;
pub use token::Token;
pub use user::User;
pub type ShortId = i32;

View File

@@ -1,4 +1,6 @@
#[derive(Clone, sqlx::FromRow)]
use serde::Serialize;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
pub struct Secret {
pub id: super::ShortId,
pub user_id: super::ShortId,

12
src/entity/token.rs Normal file
View File

@@ -0,0 +1,12 @@
use serde::Serialize;
use super::ShortId;
#[derive(Debug, Clone, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct Token {
pub token: String,
pub user_id: ShortId,
pub created_at: chrono::DateTime<chrono::Utc>,
pub expires_at: chrono::DateTime<chrono::Utc>,
}

View File

@@ -1,9 +1,13 @@
use chrono::Utc;
use serde::Serialize;
#[derive(Clone, sqlx::FromRow)]
#[derive(Debug, Clone, Hash, sqlx::FromRow, Serialize)]
#[serde(rename_all = "camelCase")]
pub struct User {
pub id: super::ShortId,
pub username: String,
#[serde(skip_serializing)]
pub password: String,
pub avatar: Option<String>,
pub created_at: chrono::DateTime<Utc>,
}

View File

@@ -11,9 +11,9 @@ pub struct Claims {
pub user_id: i32,
}
pub fn generate_jwt(user_id: i32) -> Result<String> {
pub fn generate_jwt(user_id: i32, expires_at: chrono::DateTime<chrono::Utc>) -> Result<String> {
let claims = Claims {
exp: (Local::now() + Duration::days(1)).timestamp() as usize,
exp: expires_at.timestamp() as usize,
user_id,
};
@@ -28,6 +28,7 @@ pub fn generate_jwt(user_id: i32) -> Result<String> {
}
pub fn verify_jwt(token: &str) -> Result<i32> {
tracing::debug!("Verifying token: {}", token);
let token_data = jsonwebtoken::decode::<Claims>(
token,
&jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),

View File

@@ -14,7 +14,10 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let _guard = log::initialize()?;
let database = Database::init().await?;
let context = AppState { database };
let context = AppState {
database,
connected_users: Default::default(),
};
web::run(context).await?;

View File

@@ -1,6 +1,17 @@
use crate::database::Database;
use std::{collections::HashMap, sync::Arc};
use tokio::sync::{mpsc::UnboundedSender, RwLock};
use crate::{database::Database, entity::ShortId, web::ws};
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct WebSocketKey {
pub user_id: ShortId,
pub token: String,
}
#[derive(Clone)]
pub struct AppState {
pub database: Database,
pub connected_users: Arc<RwLock<HashMap<WebSocketKey, UnboundedSender<ws::Message>>>>,
}

View File

@@ -2,9 +2,10 @@ use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
use crate::entity;
#[derive(Clone)]
#[derive(Debug, Clone, Hash)]
pub struct Context {
pub user_id: entity::ShortId,
pub user: entity::User,
pub token: String,
}
#[async_trait]
@@ -15,9 +16,9 @@ impl<S: Send + Sync> FromRequestParts<S> for Context {
parts
.extensions
.get::<ContextResult>()
.ok_or(super::Error::ContextError(Error::NotInRequest))?
.ok_or(super::Error::Context(Error::NotInRequest))?
.clone()
.map_err(super::Error::ContextError)
.map_err(super::Error::Context)
}
}

View File

@@ -12,18 +12,41 @@ pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, From)]
pub enum Error {
#[from]
ContextError(context::Error),
Context(context::Error),
#[from]
DatabaseError(database::Error),
Database(database::Error),
#[from]
JWT(jwt::Error),
Jwt(jwt::Error),
WrongPassword,
NotAllowed,
}
impl Error {
pub fn as_client_error(&self) -> ClientError {
ClientError::HahaError
pub fn as_client_error_and_status(&self) -> (ClientError, StatusCode) {
match self {
Error::Context(_) | Error::Jwt(_) => {
(ClientError::NotAuthorized, StatusCode::UNAUTHORIZED)
}
Error::Database(database::Error::UserAlreadyExists) => {
(ClientError::UserAlreadyExists, StatusCode::CONFLICT)
}
Error::Database(database::Error::UserDoesNotExists) => {
(ClientError::UserDoesNotExists, StatusCode::NOT_FOUND)
}
Error::Database(database::Error::ChannelDoesNotExists) => {
(ClientError::ChannelDoesNotExists, StatusCode::NOT_FOUND)
}
Error::Database(database::Error::MessageDoesNotExists) => {
(ClientError::MessageDoesNotExists, StatusCode::NOT_FOUND)
}
Error::WrongPassword => (ClientError::WrongPassword, StatusCode::UNAUTHORIZED),
Error::NotAllowed => (ClientError::NotAllowed, StatusCode::FORBIDDEN),
_ => (
ClientError::InternalServerError,
StatusCode::INTERNAL_SERVER_ERROR,
),
}
}
}
@@ -39,5 +62,22 @@ impl IntoResponse for Error {
#[derive(Debug, Error, Display)]
pub enum ClientError {
HahaError,
#[display(fmt = "Not authorized")]
NotAuthorized,
#[display(fmt = "Wrong password")]
WrongPassword,
#[display(fmt = "Not allowed")]
NotAllowed,
#[display(fmt = "User already exists")]
UserAlreadyExists,
#[display(fmt = "User does not exists")]
UserDoesNotExists,
#[display(fmt = "Channel does not exists")]
ChannelDoesNotExists,
#[display(fmt = "Message does not exists")]
MessageDoesNotExists,
#[display(fmt = "Internal server error")]
InternalServerError,
}

View File

@@ -2,11 +2,11 @@ use axum::{
extract::{Request, State},
http::{header, HeaderMap},
middleware::Next,
response::{IntoResponse, Response},
response::Response,
Extension,
};
use crate::{
database::Database,
jwt,
state::AppState,
web::{
@@ -16,7 +16,7 @@ use crate::{
};
pub async fn require_context(
context: ContextResult,
Extension(context): Extension<ContextResult>,
request: Request,
next: Next,
) -> web::Result<Response> {
@@ -46,12 +46,17 @@ async fn get_context(state: AppState, headers: &HeaderMap) -> context::ContextRe
return Err(context::Error::WrongTokenType);
}
let token = &token[7..];
get_context_from_token(state, token).await
}
pub async fn get_context_from_token(state: AppState, token: &str) -> context::ContextResult {
let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;
let _user = state
let user = state
.database
.get_user_by_id(user_id)
.await
.map_err(|_| context::Error::Model)?;
Ok(context::Context { user_id })
Ok(context::Context { user, token: token.to_owned() })
}

View File

@@ -1,6 +1,5 @@
mod auth;
mod response_map;
pub use auth::require_context;
pub use auth::resolve_context;
pub use response_map::response_map;
pub use auth::*;
pub use response_map::*;

View File

@@ -1,6 +1,7 @@
use std::sync::Arc;
use axum::{extract::Request, middleware::Next, response::IntoResponse};
use axum::{extract::Request, middleware::Next, response::IntoResponse, Json};
use serde_json::json;
use crate::web;
@@ -9,10 +10,18 @@ pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
let error = response.extensions().get::<Arc<web::Error>>();
let error_response = error.map(|e| {
let client_error = e.as_client_error();
if error.is_some() {
tracing::error!("{:?}", error);
}
client_error.to_string().into_response()
let error_response = error.map(|e| {
let client_error = e.as_client_error_and_status();
(
client_error.1,
Json(json!({"error": client_error.0.to_string()})),
)
.into_response()
});
error_response.unwrap_or(response)

View File

@@ -1,11 +1,13 @@
use crate::{config, state};
use crate::{config, state, web::routes::message};
mod context;
mod error;
pub mod middlware;
pub mod routes;
pub mod ws;
pub use error::{Error, Result};
use tower_http::cors::{AllowHeaders, AllowMethods, AllowOrigin};
pub async fn run(state: state::AppState) -> anyhow::Result<()> {
let config = config::config();
@@ -22,18 +24,66 @@ pub async fn run(state: state::AppState) -> anyhow::Result<()> {
}
fn router(state: state::AppState) -> axum::Router {
axum::Router::new()
.route("/user/login", axum::routing::post(routes::user::login))
.route(
"/user/register",
axum::routing::post(routes::user::register),
)
.with_state(state.clone())
.layer(axum::middleware::from_fn_with_state(
use self::routes::user;
use axum::routing::*;
use axum::Router;
let cors = tower_http::cors::CorsLayer::new()
.allow_origin(AllowOrigin::any())
.allow_methods(AllowMethods::any())
.allow_headers(AllowHeaders::any());
Router::new()
// websocket
.route("/ws/:token", get(ws::ws_handler))
// unprotected
.route("/user/login", post(user::login))
.route("/user/register", post(user::register))
// protected
.nest("/", protected_router())
.layer(tower_http::trace::TraceLayer::new_for_http())
.route_layer(axum::middleware::from_fn_with_state(
state.clone(),
middlware::resolve_context,
))
.layer(axum::middleware::from_fn(middlware::response_map))
.layer(cors)
.with_state(state.clone())
}
fn protected_router() -> axum::Router<state::AppState> {
use self::routes::channel;
use self::routes::user;
use axum::routing::*;
use axum::Router;
Router::new()
// user
.route("/user", get(user::get_by_username))
.route("/user/me", get(user::get_self))
.route("/user/:user_id", get(user::get_by_id))
// channel
.route("/channel", get(channel::get_all_user_channels))
.route("/channel/:channel_id", get(channel::get_by_id))
.route("/channel", post(channel::create_channel))
.route("/channel/:channel_id", delete(channel::delete_channel))
.route(
"/channel/:channel_id/user/:user_id",
post(channel::add_user_to_channel),
)
.route(
"/channel/:channel_id/user/:user_id",
delete(channel::remove_user_from_channel),
)
// message
.route("/message/:message_id", get(message::get_by_id))
.route("/message", post(message::create))
.route(
"/channel/:channel_id/message",
get(message::get_messages_by_channel_id),
)
// middleware
.route_layer(axum::middleware::from_fn(middlware::require_context))
}
async fn shutdown_signal() {

View File

@@ -0,0 +1,34 @@
use axum::{extract::State, response::IntoResponse, Json};
use crate::{
state::AppState,
web::{self, ws},
};
pub async fn create_channel(
State(state): State<AppState>,
context: web::context::Context,
Json(body): Json<CreateChannel>,
) -> web::Result<impl IntoResponse> {
let channel = state
.database
.create_channel(context.user.id, &body.name)
.await?;
let channel_users = state.database.get_channel_users(channel.id).await?;
ws::broadcast_message(ws::Message::CreateChannel(channel.clone()), &state, |key| {
channel_users
.iter()
.any(|(user, _admin)| user.id == key.user_id)
})
.await;
Ok(Json(channel))
}
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateChannel {
name: String,
}

View File

@@ -0,0 +1,40 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
};
use crate::{
entity::ShortId,
state::AppState,
web::{self, ws},
};
pub async fn delete_channel(
State(state): State<AppState>,
Path(channel_id): Path<ShortId>,
context: web::context::Context,
) -> web::Result<impl IntoResponse> {
let channel_users = state.database.get_channel_users(channel_id).await?;
if !channel_users
.iter()
.any(|(user, admin)| user.id == context.user.id && *admin)
{
return Err(web::Error::NotAllowed);
}
state.database.delete_channel(channel_id).await?;
ws::broadcast_message(
ws::Message::DeleteChannel { id: channel_id },
&state,
|key| {
channel_users
.iter()
.any(|(user, _admin)| user.id == key.user_id)
},
)
.await;
Ok(())
}

View File

@@ -0,0 +1,37 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
Json,
};
use crate::{entity::ShortId, state::AppState, web};
pub async fn get_by_id(
State(state): State<AppState>,
Path(channel_id): Path<ShortId>,
context: web::context::Context,
) -> web::Result<impl IntoResponse> {
let channel = state.database.get_channel_by_id(channel_id).await?;
let channel_users = state.database.get_channel_users(channel_id).await?;
if !channel_users
.into_iter()
.any(|(user, _admin)| user.id == context.user.id)
{
return Err(web::Error::NotAllowed);
}
Ok(Json(channel))
}
pub async fn get_all_user_channels(
State(state): State<AppState>,
context: web::context::Context,
) -> web::Result<impl IntoResponse> {
let channels = state
.database
.get_all_user_channels(context.user.id)
.await?;
Ok(Json(channels))
}

View File

@@ -0,0 +1,9 @@
mod create;
mod delete;
mod get;
mod user;
pub use create::*;
pub use delete::*;
pub use get::*;
pub use user::*;

View File

@@ -0,0 +1,50 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
};
use crate::{entity::ShortId, state::AppState, web};
pub async fn add_user_to_channel(
State(state): State<AppState>,
Path((channel_id, user_id)): Path<(ShortId, ShortId)>,
context: web::context::Context,
) -> web::Result<impl IntoResponse> {
let channel_users = state.database.get_channel_users(channel_id).await?;
if !channel_users
.into_iter()
.any(|(user, admin)| user.id == context.user.id && admin)
{
return Err(web::Error::NotAllowed);
}
state
.database
.add_user_to_channel(user_id, channel_id, false)
.await?;
Ok(())
}
pub async fn remove_user_from_channel(
State(state): State<AppState>,
Path((channel_id, user_id)): Path<(ShortId, ShortId)>,
context: web::context::Context,
) -> web::Result<impl IntoResponse> {
let channel_users = state.database.get_channel_users(channel_id).await?;
if !channel_users
.into_iter()
.any(|(user, admin)| user.id == context.user.id && admin)
{
return Err(web::Error::NotAllowed);
}
state
.database
.remove_user_from_channel(user_id, channel_id)
.await?;
Ok(())
}

View File

@@ -0,0 +1,54 @@
use axum::{extract::State, response::IntoResponse, Json};
use crate::{
entity::ShortId,
state::AppState,
web::{self, ws},
};
pub async fn create(
State(state): State<AppState>,
context: web::context::Context,
Json(body): Json<CreateMessage>,
) -> web::Result<impl IntoResponse> {
let channel_users = state.database.get_channel_users(body.channel_id).await?;
if !channel_users
.iter()
.any(|(user, _admin)| user.id == context.user.id)
{
return Err(web::Error::NotAllowed);
}
let message = state
.database
.create_message(body.channel_id, context.user.id, &body.content)
.await?;
let channel = state
.database
.update_last_message_channel(body.channel_id, message.id)
.await?;
ws::broadcast_message(ws::Message::CreateMessage(message.clone()), &state, |key| {
channel_users
.iter()
.any(|(user, _admin)| user.id == key.user_id)
})
.await;
ws::broadcast_message(ws::Message::UpdateChannel(channel), &state, |key| {
channel_users
.iter()
.any(|(user, _admin)| user.id == key.user_id)
})
.await;
Ok(Json(message))
}
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct CreateMessage {
content: String,
channel_id: ShortId,
}

View File

@@ -0,0 +1,61 @@
use axum::{
extract::{Path, Query, State},
response::IntoResponse,
Json,
};
use crate::{
entity::{LongId, ShortId},
state::AppState,
web,
};
pub async fn get_by_id(
State(state): State<AppState>,
Path(message_id): Path<LongId>,
context: web::context::Context,
) -> web::Result<impl IntoResponse> {
let message = state.database.get_message_by_id(message_id).await?;
if message.author_id != context.user.id {
let channel_users = state.database.get_channel_users(message.channel_id).await?;
if !channel_users
.into_iter()
.any(|(user, _admin)| user.id == context.user.id)
{
return Err(web::Error::NotAllowed);
}
}
Ok(Json(message))
}
pub async fn get_messages_by_channel_id(
State(state): State<AppState>,
Path(channel_id): Path<ShortId>,
context: web::context::Context,
Query(body): Query<GetByChannelId>,
) -> web::Result<impl IntoResponse> {
let channel_users = state.database.get_channel_users(channel_id).await?;
if !channel_users
.into_iter()
.any(|(user, _admin)| user.id == context.user.id)
{
return Err(web::Error::NotAllowed);
}
let messages = state
.database
.get_messages_by_channel_id(channel_id, body.limit.unwrap_or(50), body.before_id)
.await?;
Ok(Json(messages))
}
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetByChannelId {
before_id: Option<LongId>,
limit: Option<i64>,
}

View File

@@ -0,0 +1,5 @@
mod create;
mod get;
pub use create::*;
pub use get::*;

View File

@@ -1 +1,3 @@
pub mod channel;
pub mod message;
pub mod user;

View File

@@ -0,0 +1,40 @@
use axum::{
extract::{Path, State},
response::IntoResponse,
Json,
};
use crate::{entity::ShortId, state::AppState, web};
pub async fn get_self(
State(state): State<AppState>,
context: web::context::Context,
) -> web::Result<impl IntoResponse> {
let user = state.database.get_user_by_id(context.user.id).await?;
Ok(Json(user))
}
pub async fn get_by_id(
State(state): State<AppState>,
Path(user_id): Path<ShortId>,
) -> web::Result<impl IntoResponse> {
let user = state.database.get_user_by_id(user_id).await?;
Ok(Json(user))
}
pub async fn get_by_username(
State(state): State<AppState>,
Json(body): Json<GetByUsernameRequest>,
) -> web::Result<impl IntoResponse> {
let user = state.database.get_user_by_username(&body.username).await?;
Ok(Json(user))
}
#[derive(serde::Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct GetByUsernameRequest {
pub username: String,
}

View File

@@ -1,6 +1,6 @@
use axum::{extract::State, response::IntoResponse, Json};
use chrono::{Duration, Utc};
use serde::Deserialize;
use serde_json::json;
use crate::{jwt, state::AppState, web};
@@ -17,14 +17,20 @@ pub async fn login(
return Err(web::Error::WrongPassword);
}
let token = jwt::generate_jwt(user.id)?;
let expires_at = Utc::now() + Duration::hours(24);
Ok(Json(json!({
"token": token
})))
let token = jwt::generate_jwt(user.id, expires_at)?;
let token = state
.database
.create_token(user.id, &token, expires_at)
.await?;
Ok(Json(token))
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct LoginPayload {
username: String,
password: String,

View File

@@ -1,5 +1,7 @@
mod get;
mod login;
mod register;
pub use get::*;
pub use login::login;
pub use register::register;

View File

@@ -21,6 +21,7 @@ pub async fn register(
}
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct RegisterPayload {
username: String,
password: String,

101
src/web/ws/mod.rs Normal file
View File

@@ -0,0 +1,101 @@
use axum::{
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use crate::{
entity,
state::{AppState, WebSocketKey},
};
use super::{context, error, middlware};
#[derive(serde::Serialize, Clone)]
#[serde(rename_all = "camelCase")]
#[serde(tag = "type", content = "data")]
pub enum Message {
CreateMessage(entity::Message),
UpdateChannel(entity::Channel),
CreateChannel(entity::Channel),
DeleteChannel { id: entity::ShortId },
}
pub async fn broadcast_message(
message: Message,
state: &AppState,
predicate: impl Fn(&WebSocketKey) -> bool,
) {
let connected_users = state.connected_users.read().await;
let recievers =
connected_users
.iter()
.filter_map(|(key, conn)| if predicate(key) { Some(conn) } else { None });
for reciever in recievers {
_ = reciever
.send(message.clone())
.inspect_err(|err| tracing::error!("Failed to send message: {}", err));
}
}
pub async fn ws_handler(
ws: WebSocketUpgrade,
State(state): State<AppState>,
Path(token): Path<String>,
) -> error::Result<impl IntoResponse> {
let context = middlware::get_context_from_token(state.clone(), &token).await?;
Ok(ws.on_upgrade(|socket| handle_socket(socket, state, context)))
}
async fn handle_socket(websocket: WebSocket, state: AppState, context: context::Context) {
let user_key = WebSocketKey {
token: context.token.clone(),
user_id: context.user.id,
};
let (mut sender, _) = websocket.split();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
{
let mut connected_users = state.connected_users.write().await;
if connected_users.contains_key(&user_key) {
tracing::trace!("websocket already connected: {user_key:?}");
drop(connected_users.remove(&user_key));
}
connected_users.insert(user_key.clone(), tx);
}
tracing::trace!("websocket connected: {user_key:?}");
tokio::spawn(async move {
// idk
// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
while let Some(message) = rx.recv().await {
let err = sender
.send(axum::extract::ws::Message::Text(
serde_json::to_string(&message)
.inspect_err(|e| tracing::error!("Could not serialize message: {e}"))
.unwrap(),
))
.await
.inspect_err(|e| tracing::error!("Could not send message: {e}"));
if err.is_err() {
break;
}
}
let _ = sender.send(axum::extract::ws::Message::Close(None)).await;
{
let mut connected_users = state.connected_users.write().await;
connected_users.remove(&user_key);
tracing::trace!("websocket disconnected: {user_key:?}");
}
});
}