From caae24d26034f1f4796633277384b3dc06fea968 Mon Sep 17 00:00:00 2001 From: lionarius Date: Tue, 23 Apr 2024 12:08:37 +0300 Subject: [PATCH] . --- Cargo.lock | 35 +++++++++++ Cargo.toml | 1 + build.rs | 3 + config.toml | 2 +- migrations/20240423082838_user_table.sql | 6 ++ src/context.rs | 6 -- src/database.rs | 74 +++++++++++++++++------- src/entity/user.rs | 6 +- src/jwt.rs | 18 ++++-- src/log.rs | 7 ++- src/main.rs | 7 ++- src/state.rs | 6 ++ src/web/context.rs | 34 +++++++++++ src/web/error.rs | 43 ++++++++++++++ src/web/middlware/auth.rs | 56 +++++++++++++++++- src/web/middlware/mod.rs | 3 +- src/web/middlware/response_map.rs | 16 ++++- src/web/mod.rs | 29 ++++++---- src/web/routes/login.rs | 5 -- src/web/routes/mod.rs | 4 +- src/web/routes/user/login.rs | 31 ++++++++++ src/web/routes/user/mod.rs | 5 ++ src/web/routes/user/register.rs | 27 +++++++++ 23 files changed, 363 insertions(+), 61 deletions(-) create mode 100644 build.rs create mode 100644 migrations/20240423082838_user_table.sql delete mode 100644 src/context.rs create mode 100644 src/state.rs create mode 100644 src/web/context.rs delete mode 100644 src/web/routes/login.rs create mode 100644 src/web/routes/user/login.rs create mode 100644 src/web/routes/user/mod.rs create mode 100644 src/web/routes/user/register.rs diff --git a/Cargo.lock b/Cargo.lock index 3784379..c4747ee 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -271,6 +271,12 @@ dependencies = [ "windows-targets 0.52.5", ] +[[package]] +name = "convert_case" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e" + [[package]] name = "core-foundation-sys" version = "0.8.6" @@ -344,6 +350,19 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derive_more" +version = "0.99.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321" +dependencies = [ + "convert_case", + "proc-macro2", + "quote", + "rustc_version", + "syn 1.0.109", +] + [[package]] name = "digest" version = "0.10.7" @@ -910,6 +929,7 @@ dependencies = [ "anyhow", "axum", "chrono", + "derive_more", "figment", "jsonwebtoken", "serde", @@ -1320,6 +1340,15 @@ version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustls" version = "0.20.9" @@ -1369,6 +1398,12 @@ dependencies = [ "untrusted 0.9.0", ] +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + [[package]] name = "serde" version = "1.0.198" diff --git a/Cargo.toml b/Cargo.toml index d49b62d..c8450f1 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -7,6 +7,7 @@ anyhow = "1.0.82" axum = "0.7.5" chrono = { version = "0.4.38", features = ["serde"] } +derive_more = "0.99.17" figment = { version = "0.10.18", features = ["env", "toml"] } jsonwebtoken = "9.3.0" serde = { version = "1.0.198", features = ["derive"] } diff --git a/build.rs b/build.rs new file mode 100644 index 0000000..3a8149e --- /dev/null +++ b/build.rs @@ -0,0 +1,3 @@ +fn main() { + println!("cargo:rerun-if-changed=migrations"); +} diff --git a/config.toml b/config.toml index 1442037..853f43c 100644 --- a/config.toml +++ b/config.toml @@ -1,5 +1,5 @@ jwt_secret = "secret" -port = 6666 +port = 1234 [database] max_connections = 5 diff --git a/migrations/20240423082838_user_table.sql b/migrations/20240423082838_user_table.sql new file mode 100644 index 0000000..eb42a66 --- /dev/null +++ b/migrations/20240423082838_user_table.sql @@ -0,0 +1,6 @@ +CREATE TABLE IF NOT EXISTS users ( + id INTEGER PRIMARY KEY AUTOINCREMENT, + username VARCHAR UNIQUE, + password VARCHAR, + created_at DATETIME DEFAULT CURRENT_TIMESTAMP +); \ No newline at end of file diff --git a/src/context.rs b/src/context.rs deleted file mode 100644 index c4ab855..0000000 --- a/src/context.rs +++ /dev/null @@ -1,6 +0,0 @@ -use crate::database::Database; - -#[derive(Clone)] -pub struct Context { - pub database: D, -} diff --git a/src/database.rs b/src/database.rs index 0d946cf..6258102 100644 --- a/src/database.rs +++ b/src/database.rs @@ -1,20 +1,17 @@ +use derive_more::{Display, Error, From}; +use sqlx::migrate::Migrator; + use crate::{config, entity}; -pub trait Database { - async fn init() -> anyhow::Result - where - Self: Sized; +static MIGRATOR: Migrator = sqlx::migrate!("./migrations"); - async fn get_user_by_id(&self, id: entity::ShortId) -> anyhow::Result>; - - async fn get_user_by_login(&self, login: &str) -> anyhow::Result>; +#[derive(Clone)] +pub struct Database { + pool: sqlx::AnyPool, } -impl Database for sqlx::AnyPool { - async fn init() -> anyhow::Result - where - Self: Sized, - { +impl Database { + pub async fn init() -> Result { let config = config::config(); let pool = sqlx::any::AnyPoolOptions::new() @@ -23,24 +20,59 @@ impl Database for sqlx::AnyPool { .await .inspect_err(|e| tracing::error!("Could not connect to database: {e}"))?; - Ok(pool) + MIGRATOR.run(&pool).await?; + + Ok(Self { pool }) } - async fn get_user_by_id(&self, id: entity::ShortId) -> anyhow::Result> { + pub async fn get_user_by_id(&self, id: entity::ShortId) -> Result { let user = sqlx::query_as("SELECT * FROM users WHERE id = $1") .bind(id) - .fetch_optional(self) - .await?; + .fetch_optional(&self.pool) + .await? + .ok_or(Error::UserDoesNotExists)?; Ok(user) } - async fn get_user_by_login(&self, login: &str) -> anyhow::Result> { - let user = sqlx::query_as("SELECT * FROM users WHERE login = $1") - .bind(login) - .fetch_optional(self) - .await?; + pub async fn get_user_by_username(&self, username: &str) -> Result { + let user = sqlx::query_as("SELECT * FROM users WHERE username = $1") + .bind(username) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::UserDoesNotExists)?; + + Ok(user) + } + + pub async fn create_user(&self, username: &str, password: &str) -> Result { + let user = self.get_user_by_username(username).await; + if user.is_ok() { + return Err(Error::UserAlreadyExists); + } + + let id = sqlx::query_scalar( + "INSERT INTO users(username, password) VALUES ($1, $2) RETURNING id", + ) + .bind(username) + .bind(password) + .fetch_one(&self.pool) + .await?; + + let user = self.get_user_by_id(id).await?; Ok(user) } } + +pub type Result = std::result::Result; + +#[derive(Debug, From, Error, Display)] +pub enum Error { + #[from] + Migrate(sqlx::migrate::MigrateError), + #[from] + Sqlx(sqlx::Error), + UserDoesNotExists, + UserAlreadyExists, +} diff --git a/src/entity/user.rs b/src/entity/user.rs index e1e1d81..3bac233 100644 --- a/src/entity/user.rs +++ b/src/entity/user.rs @@ -1,7 +1,9 @@ -#[derive(sqlx::FromRow)] +use chrono::Utc; + +#[derive(Clone, sqlx::FromRow)] pub struct User { pub id: super::ShortId, pub username: String, pub password: String, - pub created_at: chrono::NaiveDateTime, + pub created_at: chrono::DateTime, } diff --git a/src/jwt.rs b/src/jwt.rs index 2d36617..1d4fe85 100644 --- a/src/jwt.rs +++ b/src/jwt.rs @@ -11,7 +11,7 @@ pub struct Claims { pub user_id: i32, } -pub fn generate_jwt(user_id: i32) -> anyhow::Result { +pub fn generate_jwt(user_id: i32) -> Result { let claims = Claims { exp: (Local::now() + Duration::days(1)).timestamp() as usize, user_id, @@ -21,17 +21,27 @@ pub fn generate_jwt(user_id: i32) -> anyhow::Result { &jsonwebtoken::Header::default(), &claims, &jsonwebtoken::EncodingKey::from_secret(config::config().jwt_secret.as_ref()), - )?; + ) + .map_err(|_| Error::CouldNotEncodeToken)?; Ok(token) } -pub fn verify_jwt(token: &str) -> anyhow::Result { +pub fn verify_jwt(token: &str) -> Result { let token_data = jsonwebtoken::decode::( token, &jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()), &jsonwebtoken::Validation::default(), - )?; + ) + .map_err(|_| Error::CouldNotVerifyToken)?; Ok(token_data.claims.user_id) } + +pub type Result = std::result::Result; + +#[derive(Debug, derive_more::Error, derive_more::Display)] +pub enum Error { + CouldNotEncodeToken, + CouldNotVerifyToken, +} diff --git a/src/log.rs b/src/log.rs index a444d44..fff7274 100644 --- a/src/log.rs +++ b/src/log.rs @@ -1,6 +1,7 @@ use std::{fs, io}; use chrono::Local; +use tracing::level_filters::LevelFilter; use tracing_appender::non_blocking::WorkerGuard; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; @@ -24,7 +25,11 @@ pub fn initialize() -> io::Result { tracing_subscriber::registry() .with(file_logger) .with(stdout_logger) - .with(tracing_subscriber::EnvFilter::from_default_env()) + .with( + tracing_subscriber::EnvFilter::builder() + .with_default_directive(LevelFilter::TRACE.into()) + .from_env_lossy(), + ) .init(); Ok(LogGuard { diff --git a/src/main.rs b/src/main.rs index 9ff8c7a..25e9ae7 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,19 +1,20 @@ use database::Database; +use state::AppState; mod config; -mod context; mod database; mod entity; mod jwt; mod log; +mod state; mod web; #[tokio::main] async fn main() -> Result<(), Box> { let _guard = log::initialize()?; - let database = sqlx::any::AnyPool::init().await?; - let context = context::Context { database }; + let database = Database::init().await?; + let context = AppState { database }; web::run(context).await?; diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..1b6fd0b --- /dev/null +++ b/src/state.rs @@ -0,0 +1,6 @@ +use crate::database::Database; + +#[derive(Clone)] +pub struct AppState { + pub database: Database, +} diff --git a/src/web/context.rs b/src/web/context.rs new file mode 100644 index 0000000..85d8ca9 --- /dev/null +++ b/src/web/context.rs @@ -0,0 +1,34 @@ +use axum::{async_trait, extract::FromRequestParts, http::request::Parts}; + +use crate::entity; + +#[derive(Clone)] +pub struct Context { + pub user_id: entity::ShortId, +} + +#[async_trait] +impl FromRequestParts for Context { + type Rejection = super::Error; + + async fn from_request_parts(parts: &mut Parts, _state: &S) -> super::Result { + parts + .extensions + .get::() + .ok_or(super::Error::ContextError(Error::NotInRequest))? + .clone() + .map_err(super::Error::ContextError) + } +} + +pub type ContextResult = std::result::Result; + +#[derive(Debug, Clone, Copy)] +pub enum Error { + NotInRequest, + NotInHeader, + BadCharacters, + WrongTokenType, + BadToken, + Model, +} diff --git a/src/web/error.rs b/src/web/error.rs index e69de29..46d92c6 100644 --- a/src/web/error.rs +++ b/src/web/error.rs @@ -0,0 +1,43 @@ +use std::sync::Arc; + +use axum::{http::StatusCode, response::IntoResponse}; +use derive_more::{Display, Error, From}; + +use crate::{database, jwt}; + +use super::context; + +pub type Result = std::result::Result; + +#[derive(Debug, From)] +pub enum Error { + #[from] + ContextError(context::Error), + #[from] + DatabaseError(database::Error), + #[from] + JWT(jwt::Error), + + WrongPassword, +} + +impl Error { + pub fn as_client_error(&self) -> ClientError { + ClientError::HahaError + } +} + +impl IntoResponse for Error { + fn into_response(self) -> axum::response::Response { + let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response(); + + response.extensions_mut().insert(Arc::new(self)); + + response + } +} + +#[derive(Debug, Error, Display)] +pub enum ClientError { + HahaError, +} diff --git a/src/web/middlware/auth.rs b/src/web/middlware/auth.rs index 3da585a..3661fe8 100644 --- a/src/web/middlware/auth.rs +++ b/src/web/middlware/auth.rs @@ -1,5 +1,57 @@ -use axum::{extract::Request, middleware::Next, response::IntoResponse}; +use axum::{ + extract::{Request, State}, + http::{header, HeaderMap}, + middleware::Next, + response::{IntoResponse, Response}, +}; + +use crate::{ + database::Database, + jwt, + state::AppState, + web::{ + self, + context::{self, ContextResult}, + }, +}; + +pub async fn require_context( + context: ContextResult, + request: Request, + next: Next, +) -> web::Result { + context?; + + Ok(next.run(request).await) +} + +pub async fn resolve_context( + State(state): State, + mut request: Request, + next: Next, +) -> Response { + let context = get_context(state, request.headers()).await; + + request.extensions_mut().insert(context); -pub async fn auth(request: Request, next: Next) -> impl IntoResponse { next.run(request).await } + +async fn get_context(state: AppState, headers: &HeaderMap) -> context::ContextResult { + let header = headers + .get(header::AUTHORIZATION) + .ok_or(context::Error::NotInHeader)?; + let token = header.to_str().map_err(|_| context::Error::BadCharacters)?; + if !token.starts_with("Bearer ") { + return Err(context::Error::WrongTokenType); + } + let token = &token[7..]; + let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?; + let _user = state + .database + .get_user_by_id(user_id) + .await + .map_err(|_| context::Error::Model)?; + + Ok(context::Context { user_id }) +} diff --git a/src/web/middlware/mod.rs b/src/web/middlware/mod.rs index 307d299..53eacf5 100644 --- a/src/web/middlware/mod.rs +++ b/src/web/middlware/mod.rs @@ -1,5 +1,6 @@ mod auth; mod response_map; -pub use auth::auth; +pub use auth::require_context; +pub use auth::resolve_context; pub use response_map::response_map; diff --git a/src/web/middlware/response_map.rs b/src/web/middlware/response_map.rs index 28bc1eb..cefe04d 100644 --- a/src/web/middlware/response_map.rs +++ b/src/web/middlware/response_map.rs @@ -1,5 +1,19 @@ +use std::sync::Arc; + use axum::{extract::Request, middleware::Next, response::IntoResponse}; +use crate::web; + pub async fn response_map(request: Request, next: Next) -> impl IntoResponse { - next.run(request).await + let response = next.run(request).await; + + let error = response.extensions().get::>(); + + let error_response = error.map(|e| { + let client_error = e.as_client_error(); + + client_error.to_string().into_response() + }); + + error_response.unwrap_or(response) } diff --git a/src/web/mod.rs b/src/web/mod.rs index c601c62..e9482d1 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,31 +1,38 @@ -use crate::{config, context, database}; +use crate::{config, state}; -pub mod error; +mod context; +mod error; pub mod middlware; pub mod routes; -pub async fn run( - context: context::Context, -) -> anyhow::Result<()> { +pub use error::{Error, Result}; + +pub async fn run(state: state::AppState) -> anyhow::Result<()> { let config = config::config(); let addr: std::net::SocketAddr = ([0, 0, 0, 0], config.port).into(); tracing::info!("Listening on {}", addr); let listener = tokio::net::TcpListener::bind(addr).await?; - axum::serve(listener, router(context)) + axum::serve(listener, router(state)) .with_graceful_shutdown(shutdown_signal()) .await?; Ok(()) } -fn router( - context: context::Context, -) -> axum::Router { +fn router(state: state::AppState) -> axum::Router { axum::Router::new() - .route("/login", axum::routing::post(routes::login)) - .with_state(context) + .route("/user/login", axum::routing::post(routes::user::login)) + .route( + "/user/register", + axum::routing::post(routes::user::register), + ) + .with_state(state.clone()) + .layer(axum::middleware::from_fn_with_state( + state.clone(), + middlware::resolve_context, + )) .layer(axum::middleware::from_fn(middlware::response_map)) } diff --git a/src/web/routes/login.rs b/src/web/routes/login.rs deleted file mode 100644 index 579da41..0000000 --- a/src/web/routes/login.rs +++ /dev/null @@ -1,5 +0,0 @@ -use axum::response::IntoResponse; - -pub async fn login() -> impl IntoResponse { - "login" -} diff --git a/src/web/routes/mod.rs b/src/web/routes/mod.rs index 02b9c33..22d12a3 100644 --- a/src/web/routes/mod.rs +++ b/src/web/routes/mod.rs @@ -1,3 +1 @@ -mod login; - -pub use login::login; \ No newline at end of file +pub mod user; diff --git a/src/web/routes/user/login.rs b/src/web/routes/user/login.rs new file mode 100644 index 0000000..eab55b4 --- /dev/null +++ b/src/web/routes/user/login.rs @@ -0,0 +1,31 @@ +use axum::{extract::State, response::IntoResponse, Json}; +use serde::Deserialize; +use serde_json::json; + +use crate::{jwt, state::AppState, web}; + +pub async fn login( + State(state): State, + Json(payload): Json, +) -> web::Result { + let user = state + .database + .get_user_by_username(&payload.username) + .await?; + + if user.password != payload.password { + return Err(web::Error::WrongPassword); + } + + let token = jwt::generate_jwt(user.id)?; + + Ok(Json(json!({ + "token": token + }))) +} + +#[derive(Deserialize)] +pub struct LoginPayload { + username: String, + password: String, +} diff --git a/src/web/routes/user/mod.rs b/src/web/routes/user/mod.rs new file mode 100644 index 0000000..6de95f2 --- /dev/null +++ b/src/web/routes/user/mod.rs @@ -0,0 +1,5 @@ +mod login; +mod register; + +pub use login::login; +pub use register::register; diff --git a/src/web/routes/user/register.rs b/src/web/routes/user/register.rs new file mode 100644 index 0000000..6ec4af1 --- /dev/null +++ b/src/web/routes/user/register.rs @@ -0,0 +1,27 @@ +use axum::{extract::State, response::IntoResponse, Json}; +use serde::Deserialize; +use serde_json::json; + +use crate::{state::AppState, web}; + +pub async fn register( + State(state): State, + Json(payload): Json, +) -> web::Result { + let user = state + .database + .create_user(&payload.username, &payload.password) + .await?; + + Ok(Json(json!({ + "id": user.id, + "username": user.username, + "created_at": user.created_at + }))) +} + +#[derive(Deserialize)] +pub struct RegisterPayload { + username: String, + password: String, +}