1
0
This commit is contained in:
2024-04-23 12:08:37 +03:00
parent 95719b299e
commit caae24d260
23 changed files with 363 additions and 61 deletions

35
Cargo.lock generated
View File

@@ -271,6 +271,12 @@ dependencies = [
"windows-targets 0.52.5", "windows-targets 0.52.5",
] ]
[[package]]
name = "convert_case"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6245d59a3e82a7fc217c5828a6692dbc6dfb63a0c8c90495621f7b9d79704a0e"
[[package]] [[package]]
name = "core-foundation-sys" name = "core-foundation-sys"
version = "0.8.6" version = "0.8.6"
@@ -344,6 +350,19 @@ dependencies = [
"powerfmt", "powerfmt",
] ]
[[package]]
name = "derive_more"
version = "0.99.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4fb810d30a7c1953f91334de7244731fc3f3c10d7fe163338a35b9f640960321"
dependencies = [
"convert_case",
"proc-macro2",
"quote",
"rustc_version",
"syn 1.0.109",
]
[[package]] [[package]]
name = "digest" name = "digest"
version = "0.10.7" version = "0.10.7"
@@ -910,6 +929,7 @@ dependencies = [
"anyhow", "anyhow",
"axum", "axum",
"chrono", "chrono",
"derive_more",
"figment", "figment",
"jsonwebtoken", "jsonwebtoken",
"serde", "serde",
@@ -1320,6 +1340,15 @@ version = "0.1.23"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76" checksum = "d626bb9dae77e28219937af045c257c28bfd3f69333c512553507f5f9798cb76"
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]] [[package]]
name = "rustls" name = "rustls"
version = "0.20.9" version = "0.20.9"
@@ -1369,6 +1398,12 @@ dependencies = [
"untrusted 0.9.0", "untrusted 0.9.0",
] ]
[[package]]
name = "semver"
version = "1.0.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca"
[[package]] [[package]]
name = "serde" name = "serde"
version = "1.0.198" version = "1.0.198"

View File

@@ -7,6 +7,7 @@
anyhow = "1.0.82" anyhow = "1.0.82"
axum = "0.7.5" axum = "0.7.5"
chrono = { version = "0.4.38", features = ["serde"] } chrono = { version = "0.4.38", features = ["serde"] }
derive_more = "0.99.17"
figment = { version = "0.10.18", features = ["env", "toml"] } figment = { version = "0.10.18", features = ["env", "toml"] }
jsonwebtoken = "9.3.0" jsonwebtoken = "9.3.0"
serde = { version = "1.0.198", features = ["derive"] } serde = { version = "1.0.198", features = ["derive"] }

3
build.rs Normal file
View File

@@ -0,0 +1,3 @@
fn main() {
println!("cargo:rerun-if-changed=migrations");
}

View File

@@ -1,5 +1,5 @@
jwt_secret = "secret" jwt_secret = "secret"
port = 6666 port = 1234
[database] [database]
max_connections = 5 max_connections = 5

View File

@@ -0,0 +1,6 @@
CREATE TABLE IF NOT EXISTS users (
id INTEGER PRIMARY KEY AUTOINCREMENT,
username VARCHAR UNIQUE,
password VARCHAR,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);

View File

@@ -1,6 +0,0 @@
use crate::database::Database;
#[derive(Clone)]
pub struct Context<D: Database> {
pub database: D,
}

View File

@@ -1,20 +1,17 @@
use derive_more::{Display, Error, From};
use sqlx::migrate::Migrator;
use crate::{config, entity}; use crate::{config, entity};
pub trait Database { static MIGRATOR: Migrator = sqlx::migrate!("./migrations");
async fn init() -> anyhow::Result<Self>
where
Self: Sized;
async fn get_user_by_id(&self, id: entity::ShortId) -> anyhow::Result<Option<entity::User>>; #[derive(Clone)]
pub struct Database {
async fn get_user_by_login(&self, login: &str) -> anyhow::Result<Option<entity::User>>; pool: sqlx::AnyPool,
} }
impl Database for sqlx::AnyPool { impl Database {
async fn init() -> anyhow::Result<Self> pub async fn init() -> Result<Self> {
where
Self: Sized,
{
let config = config::config(); let config = config::config();
let pool = sqlx::any::AnyPoolOptions::new() let pool = sqlx::any::AnyPoolOptions::new()
@@ -23,24 +20,59 @@ impl Database for sqlx::AnyPool {
.await .await
.inspect_err(|e| tracing::error!("Could not connect to database: {e}"))?; .inspect_err(|e| tracing::error!("Could not connect to database: {e}"))?;
Ok(pool) MIGRATOR.run(&pool).await?;
Ok(Self { pool })
} }
async fn get_user_by_id(&self, id: entity::ShortId) -> anyhow::Result<Option<entity::User>> { pub async fn get_user_by_id(&self, id: entity::ShortId) -> Result<entity::User> {
let user = sqlx::query_as("SELECT * FROM users WHERE id = $1") let user = sqlx::query_as("SELECT * FROM users WHERE id = $1")
.bind(id) .bind(id)
.fetch_optional(self) .fetch_optional(&self.pool)
.await?; .await?
.ok_or(Error::UserDoesNotExists)?;
Ok(user) Ok(user)
} }
async fn get_user_by_login(&self, login: &str) -> anyhow::Result<Option<entity::User>> { pub async fn get_user_by_username(&self, username: &str) -> Result<entity::User> {
let user = sqlx::query_as("SELECT * FROM users WHERE login = $1") let user = sqlx::query_as("SELECT * FROM users WHERE username = $1")
.bind(login) .bind(username)
.fetch_optional(self) .fetch_optional(&self.pool)
.await?
.ok_or(Error::UserDoesNotExists)?;
Ok(user)
}
pub async fn create_user(&self, username: &str, password: &str) -> Result<entity::User> {
let user = self.get_user_by_username(username).await;
if user.is_ok() {
return Err(Error::UserAlreadyExists);
}
let id = sqlx::query_scalar(
"INSERT INTO users(username, password) VALUES ($1, $2) RETURNING id",
)
.bind(username)
.bind(password)
.fetch_one(&self.pool)
.await?; .await?;
let user = self.get_user_by_id(id).await?;
Ok(user) Ok(user)
} }
} }
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, From, Error, Display)]
pub enum Error {
#[from]
Migrate(sqlx::migrate::MigrateError),
#[from]
Sqlx(sqlx::Error),
UserDoesNotExists,
UserAlreadyExists,
}

View File

@@ -1,7 +1,9 @@
#[derive(sqlx::FromRow)] use chrono::Utc;
#[derive(Clone, sqlx::FromRow)]
pub struct User { pub struct User {
pub id: super::ShortId, pub id: super::ShortId,
pub username: String, pub username: String,
pub password: String, pub password: String,
pub created_at: chrono::NaiveDateTime, pub created_at: chrono::DateTime<Utc>,
} }

View File

@@ -11,7 +11,7 @@ pub struct Claims {
pub user_id: i32, pub user_id: i32,
} }
pub fn generate_jwt(user_id: i32) -> anyhow::Result<String> { pub fn generate_jwt(user_id: i32) -> Result<String> {
let claims = Claims { let claims = Claims {
exp: (Local::now() + Duration::days(1)).timestamp() as usize, exp: (Local::now() + Duration::days(1)).timestamp() as usize,
user_id, user_id,
@@ -21,17 +21,27 @@ pub fn generate_jwt(user_id: i32) -> anyhow::Result<String> {
&jsonwebtoken::Header::default(), &jsonwebtoken::Header::default(),
&claims, &claims,
&jsonwebtoken::EncodingKey::from_secret(config::config().jwt_secret.as_ref()), &jsonwebtoken::EncodingKey::from_secret(config::config().jwt_secret.as_ref()),
)?; )
.map_err(|_| Error::CouldNotEncodeToken)?;
Ok(token) Ok(token)
} }
pub fn verify_jwt(token: &str) -> anyhow::Result<i32> { pub fn verify_jwt(token: &str) -> Result<i32> {
let token_data = jsonwebtoken::decode::<Claims>( let token_data = jsonwebtoken::decode::<Claims>(
token, token,
&jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()), &jsonwebtoken::DecodingKey::from_secret(config::config().jwt_secret.as_ref()),
&jsonwebtoken::Validation::default(), &jsonwebtoken::Validation::default(),
)?; )
.map_err(|_| Error::CouldNotVerifyToken)?;
Ok(token_data.claims.user_id) Ok(token_data.claims.user_id)
} }
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, derive_more::Error, derive_more::Display)]
pub enum Error {
CouldNotEncodeToken,
CouldNotVerifyToken,
}

View File

@@ -1,6 +1,7 @@
use std::{fs, io}; use std::{fs, io};
use chrono::Local; use chrono::Local;
use tracing::level_filters::LevelFilter;
use tracing_appender::non_blocking::WorkerGuard; use tracing_appender::non_blocking::WorkerGuard;
use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt}; use tracing_subscriber::{layer::SubscriberExt, util::SubscriberInitExt};
@@ -24,7 +25,11 @@ pub fn initialize() -> io::Result<LogGuard> {
tracing_subscriber::registry() tracing_subscriber::registry()
.with(file_logger) .with(file_logger)
.with(stdout_logger) .with(stdout_logger)
.with(tracing_subscriber::EnvFilter::from_default_env()) .with(
tracing_subscriber::EnvFilter::builder()
.with_default_directive(LevelFilter::TRACE.into())
.from_env_lossy(),
)
.init(); .init();
Ok(LogGuard { Ok(LogGuard {

View File

@@ -1,19 +1,20 @@
use database::Database; use database::Database;
use state::AppState;
mod config; mod config;
mod context;
mod database; mod database;
mod entity; mod entity;
mod jwt; mod jwt;
mod log; mod log;
mod state;
mod web; mod web;
#[tokio::main] #[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> { async fn main() -> Result<(), Box<dyn std::error::Error>> {
let _guard = log::initialize()?; let _guard = log::initialize()?;
let database = sqlx::any::AnyPool::init().await?; let database = Database::init().await?;
let context = context::Context { database }; let context = AppState { database };
web::run(context).await?; web::run(context).await?;

6
src/state.rs Normal file
View File

@@ -0,0 +1,6 @@
use crate::database::Database;
#[derive(Clone)]
pub struct AppState {
pub database: Database,
}

34
src/web/context.rs Normal file
View File

@@ -0,0 +1,34 @@
use axum::{async_trait, extract::FromRequestParts, http::request::Parts};
use crate::entity;
#[derive(Clone)]
pub struct Context {
pub user_id: entity::ShortId,
}
#[async_trait]
impl<S: Send + Sync> FromRequestParts<S> for Context {
type Rejection = super::Error;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> super::Result<Self> {
parts
.extensions
.get::<ContextResult>()
.ok_or(super::Error::ContextError(Error::NotInRequest))?
.clone()
.map_err(super::Error::ContextError)
}
}
pub type ContextResult = std::result::Result<Context, Error>;
#[derive(Debug, Clone, Copy)]
pub enum Error {
NotInRequest,
NotInHeader,
BadCharacters,
WrongTokenType,
BadToken,
Model,
}

View File

@@ -0,0 +1,43 @@
use std::sync::Arc;
use axum::{http::StatusCode, response::IntoResponse};
use derive_more::{Display, Error, From};
use crate::{database, jwt};
use super::context;
pub type Result<T> = std::result::Result<T, Error>;
#[derive(Debug, From)]
pub enum Error {
#[from]
ContextError(context::Error),
#[from]
DatabaseError(database::Error),
#[from]
JWT(jwt::Error),
WrongPassword,
}
impl Error {
pub fn as_client_error(&self) -> ClientError {
ClientError::HahaError
}
}
impl IntoResponse for Error {
fn into_response(self) -> axum::response::Response {
let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response();
response.extensions_mut().insert(Arc::new(self));
response
}
}
#[derive(Debug, Error, Display)]
pub enum ClientError {
HahaError,
}

View File

@@ -1,5 +1,57 @@
use axum::{extract::Request, middleware::Next, response::IntoResponse}; use axum::{
extract::{Request, State},
http::{header, HeaderMap},
middleware::Next,
response::{IntoResponse, Response},
};
use crate::{
database::Database,
jwt,
state::AppState,
web::{
self,
context::{self, ContextResult},
},
};
pub async fn require_context(
context: ContextResult,
request: Request,
next: Next,
) -> web::Result<Response> {
context?;
Ok(next.run(request).await)
}
pub async fn resolve_context(
State(state): State<AppState>,
mut request: Request,
next: Next,
) -> Response {
let context = get_context(state, request.headers()).await;
request.extensions_mut().insert(context);
pub async fn auth(request: Request, next: Next) -> impl IntoResponse {
next.run(request).await next.run(request).await
} }
async fn get_context(state: AppState, headers: &HeaderMap) -> context::ContextResult {
let header = headers
.get(header::AUTHORIZATION)
.ok_or(context::Error::NotInHeader)?;
let token = header.to_str().map_err(|_| context::Error::BadCharacters)?;
if !token.starts_with("Bearer ") {
return Err(context::Error::WrongTokenType);
}
let token = &token[7..];
let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?;
let _user = state
.database
.get_user_by_id(user_id)
.await
.map_err(|_| context::Error::Model)?;
Ok(context::Context { user_id })
}

View File

@@ -1,5 +1,6 @@
mod auth; mod auth;
mod response_map; mod response_map;
pub use auth::auth; pub use auth::require_context;
pub use auth::resolve_context;
pub use response_map::response_map; pub use response_map::response_map;

View File

@@ -1,5 +1,19 @@
use std::sync::Arc;
use axum::{extract::Request, middleware::Next, response::IntoResponse}; use axum::{extract::Request, middleware::Next, response::IntoResponse};
use crate::web;
pub async fn response_map(request: Request, next: Next) -> impl IntoResponse { pub async fn response_map(request: Request, next: Next) -> impl IntoResponse {
next.run(request).await let response = next.run(request).await;
let error = response.extensions().get::<Arc<web::Error>>();
let error_response = error.map(|e| {
let client_error = e.as_client_error();
client_error.to_string().into_response()
});
error_response.unwrap_or(response)
} }

View File

@@ -1,31 +1,38 @@
use crate::{config, context, database}; use crate::{config, state};
pub mod error; mod context;
mod error;
pub mod middlware; pub mod middlware;
pub mod routes; pub mod routes;
pub async fn run<D: database::Database + Clone + Send + Sync + 'static>( pub use error::{Error, Result};
context: context::Context<D>,
) -> anyhow::Result<()> { pub async fn run(state: state::AppState) -> anyhow::Result<()> {
let config = config::config(); let config = config::config();
let addr: std::net::SocketAddr = ([0, 0, 0, 0], config.port).into(); let addr: std::net::SocketAddr = ([0, 0, 0, 0], config.port).into();
tracing::info!("Listening on {}", addr); tracing::info!("Listening on {}", addr);
let listener = tokio::net::TcpListener::bind(addr).await?; let listener = tokio::net::TcpListener::bind(addr).await?;
axum::serve(listener, router(context)) axum::serve(listener, router(state))
.with_graceful_shutdown(shutdown_signal()) .with_graceful_shutdown(shutdown_signal())
.await?; .await?;
Ok(()) Ok(())
} }
fn router<D: database::Database + Clone + Send + Sync + 'static>( fn router(state: state::AppState) -> axum::Router {
context: context::Context<D>,
) -> axum::Router {
axum::Router::new() axum::Router::new()
.route("/login", axum::routing::post(routes::login)) .route("/user/login", axum::routing::post(routes::user::login))
.with_state(context) .route(
"/user/register",
axum::routing::post(routes::user::register),
)
.with_state(state.clone())
.layer(axum::middleware::from_fn_with_state(
state.clone(),
middlware::resolve_context,
))
.layer(axum::middleware::from_fn(middlware::response_map)) .layer(axum::middleware::from_fn(middlware::response_map))
} }

View File

@@ -1,5 +0,0 @@
use axum::response::IntoResponse;
pub async fn login() -> impl IntoResponse {
"login"
}

View File

@@ -1,3 +1 @@
mod login; pub mod user;
pub use login::login;

View File

@@ -0,0 +1,31 @@
use axum::{extract::State, response::IntoResponse, Json};
use serde::Deserialize;
use serde_json::json;
use crate::{jwt, state::AppState, web};
pub async fn login(
State(state): State<AppState>,
Json(payload): Json<LoginPayload>,
) -> web::Result<impl IntoResponse> {
let user = state
.database
.get_user_by_username(&payload.username)
.await?;
if user.password != payload.password {
return Err(web::Error::WrongPassword);
}
let token = jwt::generate_jwt(user.id)?;
Ok(Json(json!({
"token": token
})))
}
#[derive(Deserialize)]
pub struct LoginPayload {
username: String,
password: String,
}

View File

@@ -0,0 +1,5 @@
mod login;
mod register;
pub use login::login;
pub use register::register;

View File

@@ -0,0 +1,27 @@
use axum::{extract::State, response::IntoResponse, Json};
use serde::Deserialize;
use serde_json::json;
use crate::{state::AppState, web};
pub async fn register(
State(state): State<AppState>,
Json(payload): Json<RegisterPayload>,
) -> web::Result<impl IntoResponse> {
let user = state
.database
.create_user(&payload.username, &payload.password)
.await?;
Ok(Json(json!({
"id": user.id,
"username": user.username,
"created_at": user.created_at
})))
}
#[derive(Deserialize)]
pub struct RegisterPayload {
username: String,
password: String,
}