diff --git a/config.toml b/config.toml index 853f43c..3919151 100644 --- a/config.toml +++ b/config.toml @@ -1,5 +1,6 @@ -jwt_secret = "secret" -port = 1234 +jwt_secret = "secret" +notifer_timer = 5 +port = 1234 [database] max_connections = 5 diff --git a/migrations/20240423082838_user_table.sql b/migrations/20240423082838_user_table.sql index 1f441b9..3361772 100644 --- a/migrations/20240423082838_user_table.sql +++ b/migrations/20240423082838_user_table.sql @@ -1,7 +1,8 @@ CREATE TABLE IF NOT EXISTS user ( - id INTEGER PRIMARY KEY AUTOINCREMENT, - username VARCHAR UNIQUE, - password VARCHAR, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - avatar VARCHAR -); \ No newline at end of file + `id` INTEGER PRIMARY KEY AUTOINCREMENT, + `avatar` VARCHAR, + `username` VARCHAR UNIQUE, + `password` VARCHAR, + `last_seen` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); diff --git a/migrations/20240506141114_channel.sql b/migrations/20240506141114_channel.sql index 863a508..c5786ad 100644 --- a/migrations/20240506141114_channel.sql +++ b/migrations/20240506141114_channel.sql @@ -17,11 +17,12 @@ CREATE TABLE IF NOT EXISTS `channel_user` ( CREATE TABLE IF NOT EXISTS `message` ( `id` INTEGER PRIMARY KEY AUTOINCREMENT, `channel_id` INTEGER NOT NULL, - `author_id` INTEGER, - `content` TEXT, - `created_at` DATETIME DEFAULT CURRENT_TIMESTAMP, + `author_id` INTEGER NOT NULL, + `content` TEXT NOT NULL, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + `system` BOOLEAN NOT NULL DEFAULT 0, FOREIGN KEY(`channel_id`) REFERENCES `channel`(`id`) ON DELETE CASCADE, FOREIGN KEY(`author_id`) REFERENCES `user`(`id`) ON DELETE SET NULL -); \ No newline at end of file +); diff --git a/migrations/20240518191103_tokens.sql b/migrations/20240518191103_tokens.sql index 5c56361..4d00c71 100644 --- a/migrations/20240518191103_tokens.sql +++ b/migrations/20240518191103_tokens.sql @@ -1,6 +1,6 @@ CREATE TABLE IF NOT EXISTS tokens ( - token TEXT NOT NULL PRIMARY KEY, - user_id INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE, - created_at DATETIME DEFAULT CURRENT_TIMESTAMP, - expires_at DATETIME + `token` TEXT NOT NULL PRIMARY KEY, + `user_id` INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE, + `created_at` DATETIME DEFAULT CURRENT_TIMESTAMP, + `expires_at` DATETIME NOT NULL ); \ No newline at end of file diff --git a/migrations/20240520132841_followers.sql b/migrations/20240520132841_followers.sql new file mode 100644 index 0000000..141f4f1 --- /dev/null +++ b/migrations/20240520132841_followers.sql @@ -0,0 +1,7 @@ +CREATE TABLE IF NOT EXISTS user_follow ( + `user_id` INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE, + `follow_id` INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP, + PRIMARY KEY (user_id, follow_id), + CHECK (user_id <> follow_id) +); \ No newline at end of file diff --git a/migrations/20240520203022_secrets.sql b/migrations/20240520203022_secrets.sql new file mode 100644 index 0000000..0133aa8 --- /dev/null +++ b/migrations/20240520203022_secrets.sql @@ -0,0 +1,15 @@ +CREATE TABLE IF NOT EXISTS secret ( + `id` INTEGER PRIMARY KEY AUTOINCREMENT, + `name` TEXT NOT NULL, + `content` TEXT NOT NULL, + `user_id` INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE, + `timeout_seconds` INTEGER NOT NULL, + `expired` BOOLEAN NOT NULL DEFAULT 0, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +CREATE TABLE IF NOT EXISTS secret_recipient ( + `secret_id` INTEGER NOT NULL REFERENCES secret(id) ON DELETE CASCADE, + `user_id` INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE, + PRIMARY KEY (secret_id, user_id) +); \ No newline at end of file diff --git a/migrations/20240520205451_notifications.sql b/migrations/20240520205451_notifications.sql new file mode 100644 index 0000000..cfb9a5a --- /dev/null +++ b/migrations/20240520205451_notifications.sql @@ -0,0 +1,8 @@ +CREATE TABLE IF NOT EXISTS notification ( + `id` INTEGER PRIMARY KEY AUTOINCREMENT, + `user_id` INTEGER NOT NULL REFERENCES user(id) ON DELETE CASCADE, + `title` TEXT NOT NULL, + `body` TEXT NOT NULL, + `seen` BOOLEAN NOT NULL DEFAULT 0, + `created_at` DATETIME NOT NULL DEFAULT CURRENT_TIMESTAMP +); \ No newline at end of file diff --git a/src/config.rs b/src/config.rs index 673122c..3d7dd1e 100644 --- a/src/config.rs +++ b/src/config.rs @@ -21,6 +21,7 @@ pub fn config() -> &'static Config { pub struct Config { pub port: u16, pub jwt_secret: String, + pub notifer_timer: u64, pub database: DatabaseConfig, } diff --git a/src/database.rs b/src/database.rs index 5da6343..1e89e29 100644 --- a/src/database.rs +++ b/src/database.rs @@ -66,6 +66,24 @@ impl Database { Ok(user) } + pub async fn update_user_last_seen(&self, user_id: entity::ShortId) -> Result<()> { + sqlx::query("UPDATE user SET last_seen = CURRENT_TIMESTAMP WHERE id = $1") + .bind(user_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn activate_all_secrets(&self, user_id: entity::ShortId) -> Result<()> { + sqlx::query("UPDATE secret SET expired = false WHERE user_id = $1") + .bind(user_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + pub async fn create_token( &self, id: i32, @@ -272,6 +290,259 @@ impl Database { Ok(messages) } + + pub async fn get_followed_users(&self, user_id: entity::ShortId) -> Result> { + let users = sqlx::query_as("SELECT user.* FROM user_follow JOIN user ON user.id = user_follow.follow_id WHERE user_id = $1") + .bind(user_id) + .fetch_all(&self.pool) + .await?; + + Ok(users) + } + + pub async fn follow_user( + &self, + user_id: entity::ShortId, + follow_id: entity::ShortId, + ) -> Result<()> { + sqlx::query("INSERT INTO user_follow(user_id, follow_id) VALUES ($1, $2)") + .bind(user_id) + .bind(follow_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn unfollow_user( + &self, + user_id: entity::ShortId, + follow_id: entity::ShortId, + ) -> Result<()> { + sqlx::query("DELETE FROM user_follow WHERE user_id = $1 AND follow_id = $2") + .bind(user_id) + .bind(follow_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn search_user_by_username( + &self, + username: &str, + limit: i64, + offset: i64, + ) -> Result> { + let users = sqlx::query_as( + "SELECT * FROM user WHERE username LIKE $1 ORDER BY username LIMIT $2 OFFSET $3", + ) + .bind(format!("{}%", username)) + .bind(limit) + .bind(offset) + .fetch_all(&self.pool) + .await?; + + Ok(users) + } + + pub async fn get_channel_user_permissions( + &self, + user_id: entity::ShortId, + channel_id: entity::ShortId, + ) -> Result { + let permissions = + sqlx::query_as("SELECT * FROM channel_user WHERE user_id = $1 AND channel_id = $2") + .bind(user_id) + .bind(channel_id) + .fetch_one(&self.pool) + .await?; + + Ok(permissions) + } + + pub async fn get_secret_by_id(&self, secret_id: entity::ShortId) -> Result { + let secret = sqlx::query_as("SELECT * FROM secret WHERE id = $1") + .bind(secret_id) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::MessageDoesNotExists)?; + + Ok(secret) + } + + pub async fn get_all_user_secrets( + &self, + user_id: entity::ShortId, + ) -> Result> { + let secrets = sqlx::query_as("SELECT * FROM secret WHERE user_id = $1") + .bind(user_id) + .fetch_all(&self.pool) + .await?; + + Ok(secrets) + } + + pub async fn get_active_all_secrets(&self) -> Result> { + let secrets = sqlx::query_as("SELECT * FROM secret WHERE expired = false") + .fetch_all(&self.pool) + .await?; + + Ok(secrets) + } + + pub async fn update_secret( + &self, + secret_id: entity::ShortId, + name: &str, + content: &str, + timeout_seconds: i32, + ) -> Result { + sqlx::query( + "UPDATE secret SET name = $1, content = $2, timeout_seconds = $3 WHERE id = $4", + ) + .bind(name) + .bind(content) + .bind(timeout_seconds) + .bind(secret_id) + .execute(&self.pool) + .await?; + + self.get_secret_by_id(secret_id).await + } + + pub async fn create_secret( + &self, + user_id: entity::ShortId, + name: &str, + content: &str, + timeout_seconds: i32, + ) -> Result { + let id = sqlx::query_scalar( + "INSERT INTO secret(user_id, name, content, timeout_seconds) VALUES ($1, $2, $3, $4) RETURNING id", + ) + .bind(user_id) + .bind(name) + .bind(content) + .bind(timeout_seconds) + .fetch_one(&self.pool) + .await?; + + self.get_secret_by_id(id).await + } + + pub async fn add_secret_recipient( + &self, + secret_id: entity::ShortId, + user_id: entity::ShortId, + ) -> Result<()> { + sqlx::query("INSERT INTO secret_recipient(secret_id, user_id) VALUES ($1, $2)") + .bind(secret_id) + .bind(user_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn delete_secret(&self, secret_id: entity::ShortId) -> Result<()> { + sqlx::query("DELETE FROM secret WHERE id = $1") + .bind(secret_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn expire_secret(&self, secret_id: entity::ShortId) -> Result<()> { + sqlx::query("UPDATE secret SET expired = true WHERE id = $1") + .bind(secret_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn delete_secret_recipient( + &self, + secret_id: entity::ShortId, + user_id: entity::ShortId, + ) -> Result<()> { + sqlx::query("DELETE FROM secret_recipient WHERE secret_id = $1 AND user_id = $2") + .bind(secret_id) + .bind(user_id) + .execute(&self.pool) + .await?; + + Ok(()) + } + + pub async fn get_secret_recipients( + &self, + secret_id: entity::ShortId, + ) -> Result> { + let users = sqlx::query_as("SELECT user.* FROM user INNER JOIN secret_recipient ON user.id = secret_recipient.user_id WHERE secret_id = $1") + .bind(secret_id) + .fetch_all(&self.pool) + .await?; + + Ok(users) + } + + pub async fn get_all_notifications( + &self, + user_id: entity::ShortId, + ) -> Result> { + let notifications = sqlx::query_as("SELECT * FROM notification WHERE user_id = $1") + .bind(user_id) + .fetch_all(&self.pool) + .await?; + + Ok(notifications) + } + + pub async fn get_notification_by_id( + &self, + notification_id: entity::LongId, + ) -> Result { + let notification = sqlx::query_as("SELECT * FROM notification WHERE id = $1") + .bind(notification_id) + .fetch_optional(&self.pool) + .await? + .ok_or(Error::MessageDoesNotExists)?; + + Ok(notification) + } + + pub async fn create_notification( + &self, + user_id: entity::ShortId, + title: &str, + body: &str, + ) -> Result { + let id = sqlx::query_scalar( + "INSERT INTO notification(user_id, title, body) VALUES ($1, $2, $3) RETURNING id", + ) + .bind(user_id) + .bind(title) + .bind(body) + .fetch_one(&self.pool) + .await?; + + self.get_notification_by_id(id).await + } + + pub async fn seen_notification( + &self, + notification_id: entity::LongId, + ) -> Result { + sqlx::query("UPDATE notification SET seen = true WHERE id = $1") + .bind(notification_id) + .execute(&self.pool) + .await?; + + self.get_notification_by_id(notification_id).await + } } pub type Result = std::result::Result; diff --git a/src/entity/message.rs b/src/entity/message.rs index 3578a5d..0e08ed4 100644 --- a/src/entity/message.rs +++ b/src/entity/message.rs @@ -8,5 +8,6 @@ pub struct Message { pub channel_id: super::ShortId, pub author_id: super::ShortId, pub content: String, + pub system: bool, pub created_at: chrono::DateTime, } diff --git a/src/entity/mod.rs b/src/entity/mod.rs index c023445..2cc4a25 100644 --- a/src/entity/mod.rs +++ b/src/entity/mod.rs @@ -3,16 +3,18 @@ mod channel; mod log; mod message; +mod notification; mod secret; mod token; mod user; -pub use channel::Channel; -pub use log::Log; -pub use message::Message; -pub use secret::Secret; -pub use token::Token; -pub use user::User; +pub use channel::*; +pub use log::*; +pub use message::*; +pub use notification::*; +pub use secret::*; +pub use token::*; +pub use user::*; pub type ShortId = i32; pub type LongId = i64; diff --git a/src/entity/notification.rs b/src/entity/notification.rs new file mode 100644 index 0000000..17b3fae --- /dev/null +++ b/src/entity/notification.rs @@ -0,0 +1,12 @@ +use serde::Serialize; + +#[derive(Debug, Clone, sqlx::FromRow, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct Notification { + pub id: super::LongId, + pub user_id: super::ShortId, + pub title: String, + pub body: String, + pub seen: bool, + pub created_at: chrono::DateTime, +} diff --git a/src/entity/secret.rs b/src/entity/secret.rs index 5d2ec88..1916a8a 100644 --- a/src/entity/secret.rs +++ b/src/entity/secret.rs @@ -1,9 +1,13 @@ use serde::Serialize; #[derive(Debug, Clone, sqlx::FromRow, Serialize)] +#[serde(rename_all = "camelCase")] pub struct Secret { pub id: super::ShortId, pub user_id: super::ShortId, - pub title: String, + pub name: String, pub content: String, + pub timeout_seconds: i32, + pub expired: bool, + pub created_at: chrono::DateTime, } diff --git a/src/entity/user.rs b/src/entity/user.rs index e21bc00..b93dafe 100644 --- a/src/entity/user.rs +++ b/src/entity/user.rs @@ -5,9 +5,18 @@ use serde::Serialize; #[serde(rename_all = "camelCase")] pub struct User { pub id: super::ShortId, + pub avatar: Option, pub username: String, #[serde(skip_serializing)] pub password: String, - pub avatar: Option, + pub last_seen: chrono::DateTime, pub created_at: chrono::DateTime, } + +#[derive(Debug, Clone, sqlx::FromRow, Serialize)] +#[serde(rename_all = "camelCase")] +pub struct ChannelPermisions { + pub user_id: super::ShortId, + pub channel_id: super::ShortId, + pub admin: bool, +} diff --git a/src/main.rs b/src/main.rs index 1808b57..e1ac060 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,6 +6,7 @@ mod database; mod entity; mod jwt; mod log; +mod notifier; mod state; mod web; @@ -19,6 +20,8 @@ async fn main() -> Result<(), Box> { connected_users: Default::default(), }; + tokio::spawn(notifier::run(context.clone())); + web::run(context).await?; Ok(()) diff --git a/src/notifier/mod.rs b/src/notifier/mod.rs new file mode 100644 index 0000000..441918c --- /dev/null +++ b/src/notifier/mod.rs @@ -0,0 +1,77 @@ +use std::collections::{HashMap, HashSet}; + +use chrono::{DateTime, Duration}; + +use crate::{config, entity, state, web::ws}; + +pub async fn run(context: state::AppState) { + loop { + tokio::time::sleep(std::time::Duration::from_secs( + config::config().notifer_timer, + )) + .await; + + let current_time = chrono::Utc::now(); + + let all_secrets = context + .database + .get_active_all_secrets() + .await + .expect("Could not get all secrets"); + + let mut cached_users: HashMap = HashMap::new(); + + for secret in all_secrets { + let user = match cached_users.get(&secret.user_id) { + Some(user) => user.clone(), + None => { + let user = context + .database + .get_user_by_id(secret.user_id) + .await + .expect("Could not get user"); + + cached_users.insert(secret.user_id, user.clone()); + + user + } + }; + + let last_seen = user.last_seen; + let timeout = Duration::seconds(secret.timeout_seconds as i64); + + if last_seen + timeout < current_time { + context + .database + .expire_secret(secret.id) + .await + .expect("Could not expire secret"); + + let recipients = context + .database + .get_secret_recipients(secret.id) + .await + .expect("Could not get recipients"); + + for recipient in recipients { + let notification = context + .database + .create_notification( + recipient.id, + &format!("Secret of {}", user.username), + &secret.content, + ) + .await + .expect("Could not create notification"); + + ws::broadcast_message( + ws::Message::CreateNotification(notification.clone()), + &context, + |key| key.user_id == recipient.id, + ) + .await; + } + } + } + } +} diff --git a/src/web/error.rs b/src/web/error.rs index 0314226..6ab6099 100644 --- a/src/web/error.rs +++ b/src/web/error.rs @@ -18,34 +18,29 @@ pub enum Error { #[from] Jwt(jwt::Error), + #[from] + Client(ClientError), + WrongPassword, NotAllowed, } impl Error { - pub fn as_client_error_and_status(&self) -> (ClientError, StatusCode) { + pub fn as_client_error(&self) -> ClientError { match self { - Error::Context(_) | Error::Jwt(_) => { - (ClientError::NotAuthorized, StatusCode::UNAUTHORIZED) - } - Error::Database(database::Error::UserAlreadyExists) => { - (ClientError::UserAlreadyExists, StatusCode::CONFLICT) - } - Error::Database(database::Error::UserDoesNotExists) => { - (ClientError::UserDoesNotExists, StatusCode::NOT_FOUND) - } + Error::Context(_) | Error::Jwt(_) => ClientError::NotAuthorized, + Error::Database(database::Error::UserAlreadyExists) => ClientError::UserAlreadyExists, + Error::Database(database::Error::UserDoesNotExists) => ClientError::UserDoesNotExists, Error::Database(database::Error::ChannelDoesNotExists) => { - (ClientError::ChannelDoesNotExists, StatusCode::NOT_FOUND) + ClientError::ChannelDoesNotExists } Error::Database(database::Error::MessageDoesNotExists) => { - (ClientError::MessageDoesNotExists, StatusCode::NOT_FOUND) + ClientError::MessageDoesNotExists } - Error::WrongPassword => (ClientError::WrongPassword, StatusCode::UNAUTHORIZED), - Error::NotAllowed => (ClientError::NotAllowed, StatusCode::FORBIDDEN), - _ => ( - ClientError::InternalServerError, - StatusCode::INTERNAL_SERVER_ERROR, - ), + Error::WrongPassword => ClientError::WrongPassword, + Error::NotAllowed => ClientError::NotAllowed, + Error::Client(client_error) => *client_error, + _ => ClientError::InternalServerError, } } } @@ -60,7 +55,7 @@ impl IntoResponse for Error { } } -#[derive(Debug, Error, Display)] +#[derive(Debug, Clone, Copy, Error, Display)] pub enum ClientError { #[display(fmt = "Not authorized")] NotAuthorized, @@ -78,6 +73,26 @@ pub enum ClientError { #[display(fmt = "Message does not exists")] MessageDoesNotExists, + #[display(fmt = "Cannot unfollow")] + CannotUnfollow, + #[display(fmt = "Cannot follow")] + CannotFollow, + #[display(fmt = "Internal server error")] InternalServerError, } + +impl ClientError { + pub fn status_code(&self) -> StatusCode { + match self { + ClientError::NotAuthorized | ClientError::WrongPassword => StatusCode::UNAUTHORIZED, + ClientError::CannotUnfollow | ClientError::CannotFollow => StatusCode::BAD_REQUEST, + ClientError::UserAlreadyExists => StatusCode::CONFLICT, + ClientError::UserDoesNotExists + | ClientError::ChannelDoesNotExists + | ClientError::MessageDoesNotExists => StatusCode::NOT_FOUND, + ClientError::NotAllowed => StatusCode::FORBIDDEN, + _ => StatusCode::INTERNAL_SERVER_ERROR, + } + } +} diff --git a/src/web/middlware/auth.rs b/src/web/middlware/auth.rs index 4040d44..d7d5fc9 100644 --- a/src/web/middlware/auth.rs +++ b/src/web/middlware/auth.rs @@ -30,7 +30,21 @@ pub async fn resolve_context( mut request: Request, next: Next, ) -> Response { - let context = get_context(state, request.headers()).await; + let context = get_context(state.clone(), request.headers()).await; + + if let Ok(ref context) = context { + let _ = state + .database + .update_user_last_seen(context.user.id) + .await + .inspect_err(|err| tracing::error!("could not update last seen: {}", err)); + + let _ = state + .database + .activate_all_secrets(context.user.id) + .await + .inspect_err(|err| tracing::error!("could not activate all secrets: {}", err)); + } request.extensions_mut().insert(context); @@ -52,11 +66,24 @@ async fn get_context(state: AppState, headers: &HeaderMap) -> context::ContextRe pub async fn get_context_from_token(state: AppState, token: &str) -> context::ContextResult { let user_id = jwt::verify_jwt(token).map_err(|_| context::Error::BadToken)?; + let token = state + .database + .get_token(token) + .await + .map_err(|_| context::Error::BadToken)?; + + if token.user_id != user_id { + return Err(context::Error::BadToken); + } + let user = state .database .get_user_by_id(user_id) .await .map_err(|_| context::Error::Model)?; - Ok(context::Context { user, token: token.to_owned() }) + Ok(context::Context { + user, + token: token.token, + }) } diff --git a/src/web/middlware/response_map.rs b/src/web/middlware/response_map.rs index e7c23f2..41b03b7 100644 --- a/src/web/middlware/response_map.rs +++ b/src/web/middlware/response_map.rs @@ -15,11 +15,11 @@ pub async fn response_map(request: Request, next: Next) -> impl IntoResponse { } let error_response = error.map(|e| { - let client_error = e.as_client_error_and_status(); + let client_error = e.as_client_error(); ( - client_error.1, - Json(json!({"error": client_error.0.to_string()})), + client_error.status_code(), + Json(json!({"error": client_error.to_string()})), ) .into_response() }); diff --git a/src/web/mod.rs b/src/web/mod.rs index 1cc70c1..2f169cf 100644 --- a/src/web/mod.rs +++ b/src/web/mod.rs @@ -1,4 +1,7 @@ -use crate::{config, state, web::routes::message}; +use crate::{ + config, state, + web::routes::{message, notification, secret}, +}; mod context; mod error; @@ -55,18 +58,24 @@ fn protected_router() -> axum::Router { use self::routes::channel; use self::routes::user; use axum::routing::*; - use axum::Router; Router::new() // user .route("/user", get(user::get_by_username)) .route("/user/me", get(user::get_self)) .route("/user/:user_id", get(user::get_by_id)) + .route("/user/search", get(user::search_by_username)) + // follow + .route("/user/:user_id/follow", get(user::is_following_user)) + .route("/user/:user_id/follow", post(user::follow_user)) + .route("/user/:user_id/follow", delete(user::unfollow_user)) + .route("/user/me/follow", get(user::get_followed_users)) // channel .route("/channel", get(channel::get_all_user_channels)) .route("/channel/:channel_id", get(channel::get_by_id)) .route("/channel", post(channel::create_channel)) .route("/channel/:channel_id", delete(channel::delete_channel)) + .route("/channel/:channel_id/users", get(user::get_by_channel_id)) .route( "/channel/:channel_id/user/:user_id", post(channel::add_user_to_channel), @@ -75,6 +84,10 @@ fn protected_router() -> axum::Router { "/channel/:channel_id/user/:user_id", delete(channel::remove_user_from_channel), ) + .route( + "/channel/:channel_id/users/:user_id/permissions", + get(channel::get_channel_user_permissions), + ) // message .route("/message/:message_id", get(message::get_by_id)) .route("/message", post(message::create)) @@ -82,6 +95,34 @@ fn protected_router() -> axum::Router { "/channel/:channel_id/message", get(message::get_messages_by_channel_id), ) + // secret + .route("/secret/:secret_id", get(secret::get_by_id)) + .route("/secret", get(secret::get_all_self)) + .route("/secret", post(secret::create)) + .route("/secret/:secret_id", put(secret::update_by_id)) + .route("/secret/:secret_id", delete(secret::delete_by_id)) + .route( + "/secret/:secret_id/recipients", + get(secret::get_recipients_by_id), + ) + .route( + "/secret/:secret_id/recipients/:user_id", + post(secret::add_recipient_by_id), + ) + .route( + "/secret/:secret_id/recipients/:user_id", + delete(secret::delete_recipient_by_id), + ) + // notification + .route( + "/notification/:notification_id", + get(notification::get_by_id), + ) + .route("/notification", get(notification::get_all_self)) + .route( + "/notification/:notification_id", + post(notification::seen_by_id), + ) // middleware .route_layer(axum::middleware::from_fn(middlware::require_context)) } diff --git a/src/web/routes/channel/user.rs b/src/web/routes/channel/user.rs index d5214aa..3dd0e90 100644 --- a/src/web/routes/channel/user.rs +++ b/src/web/routes/channel/user.rs @@ -1,9 +1,14 @@ use axum::{ extract::{Path, State}, response::IntoResponse, + Json, }; -use crate::{entity::ShortId, state::AppState, web}; +use crate::{ + entity::ShortId, + state::AppState, + web::{self, ws}, +}; pub async fn add_user_to_channel( State(state): State, @@ -12,9 +17,16 @@ pub async fn add_user_to_channel( ) -> web::Result { let channel_users = state.database.get_channel_users(channel_id).await?; + if channel_users + .iter() + .any(|(user, _admin)| user.id == user_id) + { + return Err(web::Error::NotAllowed); + } + if !channel_users - .into_iter() - .any(|(user, admin)| user.id == context.user.id && admin) + .iter() + .any(|(user, _admin)| user.id == context.user.id) { return Err(web::Error::NotAllowed); } @@ -24,6 +36,27 @@ pub async fn add_user_to_channel( .add_user_to_channel(user_id, channel_id, false) .await?; + let channel = state.database.get_channel_by_id(channel_id).await?; + + ws::broadcast_message(ws::Message::CreateChannel(channel), &state, |key| { + key.user_id == user_id + }) + .await; + + ws::broadcast_message( + ws::Message::AddedUserToChannel { + user_id, + channel_id, + }, + &state, + |key| { + channel_users + .iter() + .any(|(user, _)| user.id == key.user_id || user.id == user_id) + }, + ) + .await; + Ok(()) } @@ -35,8 +68,8 @@ pub async fn remove_user_from_channel( let channel_users = state.database.get_channel_users(channel_id).await?; if !channel_users - .into_iter() - .any(|(user, admin)| user.id == context.user.id && admin) + .iter() + .any(|(user, admin)| user.id == context.user.id && *admin) { return Err(web::Error::NotAllowed); } @@ -46,5 +79,34 @@ pub async fn remove_user_from_channel( .remove_user_from_channel(user_id, channel_id) .await?; + ws::broadcast_message( + ws::Message::DeleteChannel { id: channel_id }, + &state, + |key| key.user_id == user_id, + ) + .await; + + ws::broadcast_message( + ws::Message::RemovedUserFromChannel { + user_id, + channel_id, + }, + &state, + |key| channel_users.iter().any(|(user, _)| user.id == key.user_id), + ) + .await; + Ok(()) } + +pub async fn get_channel_user_permissions( + State(state): State, + Path((channel_id, user_id)): Path<(ShortId, ShortId)>, +) -> web::Result { + Ok(Json( + state + .database + .get_channel_user_permissions(channel_id, user_id) + .await?, + )) +} diff --git a/src/web/routes/mod.rs b/src/web/routes/mod.rs index e7c2185..80b2b61 100644 --- a/src/web/routes/mod.rs +++ b/src/web/routes/mod.rs @@ -1,3 +1,5 @@ pub mod channel; pub mod message; +pub mod notification; +pub mod secret; pub mod user; diff --git a/src/web/routes/notification.rs b/src/web/routes/notification.rs new file mode 100644 index 0000000..57da433 --- /dev/null +++ b/src/web/routes/notification.rs @@ -0,0 +1,67 @@ +use axum::{ + extract::{Path, State}, + response::IntoResponse, +}; + +use crate::{ + entity::{LongId, ShortId}, + state::AppState, + web::{self, ws}, +}; + +pub async fn get_all_self( + State(state): State, + context: web::context::Context, +) -> web::Result { + let notifications = state + .database + .get_all_notifications(context.user.id) + .await?; + + Ok(axum::response::Json(notifications)) +} + +pub async fn get_by_id( + State(state): State, + Path(notification_id): Path, + context: web::context::Context, +) -> web::Result { + let notification = state + .database + .get_notification_by_id(notification_id) + .await?; + + if notification.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + Ok(axum::response::Json(notification)) +} + +pub async fn seen_by_id( + State(state): State, + Path(notification_id): Path, + context: web::context::Context, +) -> web::Result { + let notification = state + .database + .get_notification_by_id(notification_id) + .await?; + + if notification.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + let notification = state.database.seen_notification(notification_id).await?; + + ws::broadcast_message( + ws::Message::SeenNotification { + id: notification.id, + }, + &state, + |key| key.user_id == notification.user_id, + ) + .await; + + Ok(axum::response::Json(notification)) +} diff --git a/src/web/routes/secret/mod.rs b/src/web/routes/secret/mod.rs new file mode 100644 index 0000000..302ed73 --- /dev/null +++ b/src/web/routes/secret/mod.rs @@ -0,0 +1,336 @@ +use axum::{ + extract::{Path, State}, + response::IntoResponse, + Json, +}; + +use crate::{ + entity::{LongId, ShortId}, + state::AppState, + web::{self, ws}, +}; + +pub async fn get_by_id( + State(state): State, + Path(secret_id): Path, + context: web::context::Context, +) -> web::Result { + let secret = state.database.get_secret_by_id(secret_id).await?; + + if secret.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + Ok(Json(secret)) +} + +pub async fn get_all_self( + State(state): State, + context: web::context::Context, +) -> web::Result { + let secrets = state.database.get_all_user_secrets(context.user.id).await?; + + Ok(Json(secrets)) +} + +pub async fn create( + State(state): State, + context: web::context::Context, + Json(payload): Json, +) -> web::Result { + let secret = state + .database + .create_secret( + context.user.id, + &payload.name, + &payload.content, + payload.timeout_seconds, + ) + .await?; + + let recipients = state + .database + .get_secret_recipients(secret.id) + .await? + .into_iter() + .map(|r| r.id) + .collect::>(); + + // for recipient_id in &payload.recipients { + // state + // .database + // .add_secret_recipient(secret.id, *recipient_id) + // .await?; + // } + + let to_add = payload + .recipients + .iter() + .filter(|id| !recipients.contains(id)) + .cloned() + .collect::>(); + + let to_delete = recipients + .iter() + .filter(|id| !to_add.contains(id)) + .cloned() + .collect::>(); + + for recipient_id in &to_add { + state + .database + .add_secret_recipient(secret.id, *recipient_id) + .await?; + } + + for recipient_id in &to_delete { + state + .database + .delete_secret_recipient(secret.id, *recipient_id) + .await?; + } + + ws::broadcast_message(ws::Message::CreateSecret(secret.clone()), &state, |key| { + key.user_id == secret.user_id + }) + .await; + + for recipient_id in to_add { + ws::broadcast_message( + ws::Message::SecretRecipientAdded { + id: secret.id, + user_id: recipient_id, + }, + &state, + |key| key.user_id == context.user.id, + ) + .await; + } + + for recipient_id in to_delete { + ws::broadcast_message( + ws::Message::SecretRecipientDeleted { + id: secret.id, + user_id: recipient_id, + }, + &state, + |key| key.user_id == context.user.id, + ) + .await; + } + + Ok(Json(secret)) +} + +pub async fn update_by_id( + State(state): State, + context: web::context::Context, + Path(secret_id): Path, + Json(payload): Json, +) -> web::Result { + let secret = state.database.get_secret_by_id(secret_id).await?; + + if secret.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + let secret = state + .database + .update_secret( + secret_id, + &payload.name, + &payload.content, + payload.timeout_seconds, + ) + .await?; + + let recipients = state + .database + .get_secret_recipients(secret.id) + .await? + .into_iter() + .map(|r| r.id) + .collect::>(); + + let to_add = payload + .recipients + .iter() + .filter(|id| !recipients.contains(id)) + .cloned() + .collect::>(); + + let to_delete = recipients + .iter() + .filter(|id| !payload.recipients.contains(id)) + .cloned() + .collect::>(); + + tracing::debug!( + "payload: {:?}, to_add: {to_add:?}, to_delete: {to_delete:?}, recipients: {recipients:?}", + payload.recipients + ); + + for recipient_id in &to_add { + state + .database + .add_secret_recipient(secret.id, *recipient_id) + .await?; + } + + for recipient_id in &to_delete { + state + .database + .delete_secret_recipient(secret.id, *recipient_id) + .await?; + } + + ws::broadcast_message(ws::Message::UpdateSecret(secret.clone()), &state, |key| { + key.user_id == secret.user_id + }) + .await; + + for recipient_id in to_add { + ws::broadcast_message( + ws::Message::SecretRecipientAdded { + id: secret.id, + user_id: recipient_id, + }, + &state, + |key| key.user_id == context.user.id, + ) + .await; + } + + for recipient_id in to_delete { + ws::broadcast_message( + ws::Message::SecretRecipientDeleted { + id: secret.id, + user_id: recipient_id, + }, + &state, + |key| key.user_id == context.user.id, + ) + .await; + } + + Ok(Json(secret)) +} + +pub async fn delete_by_id( + State(state): State, + context: web::context::Context, + Path(secret_id): Path, +) -> web::Result { + let secret = state.database.get_secret_by_id(secret_id).await?; + + if secret.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + state.database.delete_secret(secret_id).await?; + + ws::broadcast_message(ws::Message::DeleteSecret { id: secret.id }, &state, |key| { + key.user_id == secret.user_id + }) + .await; + + Ok(()) +} + +pub async fn add_recipient_by_id( + State(state): State, + context: web::context::Context, + Path((secret_id, recipient_id)): Path<(ShortId, ShortId)>, +) -> web::Result { + let secret = state.database.get_secret_by_id(secret_id).await?; + + if secret.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + let recipient = state.database.get_user_by_id(recipient_id).await?; + + let recipients = state.database.get_secret_recipients(secret_id).await?; + + if !recipients.iter().any(|r| r.id == recipient.id) { + state + .database + .add_secret_recipient(secret_id, recipient.id) + .await?; + } + + ws::broadcast_message( + ws::Message::SecretRecipientAdded { + id: secret.id, + user_id: recipient.id, + }, + &state, + |key| key.user_id == secret.user_id, + ) + .await; + + Ok(()) +} + +pub async fn delete_recipient_by_id( + State(state): State, + context: web::context::Context, + Path((secret_id, recipient_id)): Path<(ShortId, ShortId)>, +) -> web::Result { + let secret = state.database.get_secret_by_id(secret_id).await?; + + if secret.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + state + .database + .delete_secret_recipient(secret_id, recipient_id) + .await?; + + ws::broadcast_message( + ws::Message::SecretRecipientDeleted { + id: secret.id, + user_id: recipient_id, + }, + &state, + |key| key.user_id == secret.user_id, + ) + .await; + + Ok(()) +} + +pub async fn get_recipients_by_id( + State(state): State, + context: web::context::Context, + Path(secret_id): Path, +) -> web::Result { + let secret = state.database.get_secret_by_id(secret_id).await?; + + if secret.user_id != context.user.id { + return Err(web::Error::NotAllowed); + } + + let recipients = state.database.get_secret_recipients(secret_id).await?; + + Ok(Json(recipients)) +} + +#[derive(serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct CreateSecretRequest { + pub name: String, + pub content: String, + pub timeout_seconds: i32, + pub recipients: Vec, +} + +#[derive(serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct UpdateSecretRequest { + pub name: String, + pub content: String, + pub timeout_seconds: i32, + pub recipients: Vec, +} diff --git a/src/web/routes/user/follow.rs b/src/web/routes/user/follow.rs new file mode 100644 index 0000000..b3f1b69 --- /dev/null +++ b/src/web/routes/user/follow.rs @@ -0,0 +1,109 @@ +use axum::{ + extract::{Path, State}, + response::IntoResponse, + Json, +}; + +use crate::{ + entity::ShortId, + state::AppState, + web::{self, error::ClientError, ws}, +}; + +pub async fn is_following_user( + State(state): State, + context: web::context::Context, + Path(user_id): Path, +) -> web::Result { + let is_following = state + .database + .get_followed_users(context.user.id) + .await? + .into_iter() + .map(|u| u.id) + .any(|id| id == user_id); + + Ok(Json(is_following)) +} + +pub async fn follow_user( + State(state): State, + context: web::context::Context, + Path(user_id): Path, +) -> web::Result { + if user_id == context.user.id { + return Err(ClientError::CannotFollow.into()); + } + + let is_following = state + .database + .get_followed_users(context.user.id) + .await? + .into_iter() + .map(|u| u.id) + .any(|id| id == user_id); + + if is_following { + return Err(ClientError::CannotFollow.into()); + } + + state.database.follow_user(context.user.id, user_id).await?; + + ws::broadcast_message( + ws::Message::FollowUser { + user_id, + }, + &state, + |key| key.user_id == context.user.id, + ) + .await; + + Ok(()) +} + +pub async fn unfollow_user( + State(state): State, + context: web::context::Context, + Path(user_id): Path, +) -> web::Result { + if user_id == context.user.id { + return Err(ClientError::CannotUnfollow.into()); + } + + let is_following = state + .database + .get_followed_users(context.user.id) + .await? + .into_iter() + .map(|u| u.id) + .any(|id| id == user_id); + + if !is_following { + return Err(ClientError::CannotUnfollow.into()); + } + + state + .database + .unfollow_user(context.user.id, user_id) + .await?; + + ws::broadcast_message( + ws::Message::UnfollowUser { + user_id, + }, + &state, + |key| key.user_id == context.user.id, + ) + .await; + + Ok(()) +} + +pub async fn get_followed_users( + State(state): State, + context: web::context::Context, +) -> web::Result { + Ok(Json( + state.database.get_followed_users(context.user.id).await?, + )) +} diff --git a/src/web/routes/user/get.rs b/src/web/routes/user/get.rs index a453e9b..147c698 100644 --- a/src/web/routes/user/get.rs +++ b/src/web/routes/user/get.rs @@ -1,5 +1,5 @@ use axum::{ - extract::{Path, State}, + extract::{Path, Query, State}, response::IntoResponse, Json, }; @@ -33,8 +33,53 @@ pub async fn get_by_username( Ok(Json(user)) } +pub async fn search_by_username( + State(state): State, + Query(body): Query, +) -> web::Result { + if body.query.is_empty() { + return Ok(Json(vec![])); + } + + let users = state + .database + .search_user_by_username( + &body.query, + body.limit.unwrap_or(50).min(50), + body.offset.unwrap_or(0), + ) + .await?; + + Ok(Json(users)) +} + +pub async fn get_by_channel_id( + State(state): State, + Path(channel_id): Path, +) -> web::Result { + let users = state + .database + .get_channel_users(channel_id) + .await? + .into_iter() + .map(|(u, _admin)| u) + .collect::>(); + + Ok(Json(users)) +} + #[derive(serde::Deserialize)] #[serde(rename_all = "camelCase")] pub struct GetByUsernameRequest { pub username: String, } + +#[derive(serde::Deserialize)] +#[serde(rename_all = "camelCase")] +pub struct SearchByUsernameRequest { + pub query: String, + #[serde(default)] + pub limit: Option, + #[serde(default)] + pub offset: Option, +} diff --git a/src/web/routes/user/mod.rs b/src/web/routes/user/mod.rs index 5494ee3..b7e4394 100644 --- a/src/web/routes/user/mod.rs +++ b/src/web/routes/user/mod.rs @@ -1,7 +1,9 @@ +mod follow; mod get; mod login; mod register; +pub use follow::*; pub use get::*; -pub use login::login; -pub use register::register; +pub use login::*; +pub use register::*; diff --git a/src/web/routes/user/register.rs b/src/web/routes/user/register.rs index ecdb4fe..cdc10c4 100644 --- a/src/web/routes/user/register.rs +++ b/src/web/routes/user/register.rs @@ -1,6 +1,5 @@ use axum::{extract::State, response::IntoResponse, Json}; use serde::Deserialize; -use serde_json::json; use crate::{state::AppState, web}; @@ -13,11 +12,7 @@ pub async fn register( .create_user(&payload.username, &payload.password) .await?; - Ok(Json(json!({ - "id": user.id, - "username": user.username, - "created_at": user.created_at - }))) + Ok(Json(user)) } #[derive(Deserialize)] diff --git a/src/web/ws/mod.rs b/src/web/ws/mod.rs index 04fcbed..7fd5a11 100644 --- a/src/web/ws/mod.rs +++ b/src/web/ws/mod.rs @@ -18,7 +18,40 @@ pub enum Message { CreateMessage(entity::Message), UpdateChannel(entity::Channel), CreateChannel(entity::Channel), - DeleteChannel { id: entity::ShortId }, + DeleteChannel { + id: entity::ShortId, + }, + AddedUserToChannel { + user_id: entity::ShortId, + channel_id: entity::ShortId, + }, + RemovedUserFromChannel { + user_id: entity::ShortId, + channel_id: entity::ShortId, + }, + CreateSecret(entity::Secret), + UpdateSecret(entity::Secret), + SecretRecipientAdded { + id: entity::ShortId, + user_id: entity::ShortId, + }, + SecretRecipientDeleted { + id: entity::ShortId, + user_id: entity::ShortId, + }, + DeleteSecret { + id: entity::ShortId, + }, + CreateNotification(entity::Notification), + SeenNotification { + id: entity::LongId, + }, + FollowUser { + user_id: entity::ShortId, + }, + UnfollowUser { + user_id: entity::ShortId, + }, } pub async fn broadcast_message(