use axum::{ extract::{ws::WebSocket, Path, State, WebSocketUpgrade}, response::IntoResponse, }; use futures::{SinkExt, StreamExt}; use crate::{ entity, state::{AppState, WebSocketKey}, }; use super::{context, error, middlware}; #[derive(serde::Serialize, Clone)] #[serde(rename_all = "camelCase")] #[serde(tag = "type", content = "data")] pub enum Message { CreateMessage(entity::Message), UpdateChannel(entity::Channel), CreateChannel(entity::Channel), DeleteChannel { id: entity::ShortId, }, AddedUserToChannel { user_id: entity::ShortId, channel_id: entity::ShortId, }, RemovedUserFromChannel { user_id: entity::ShortId, channel_id: entity::ShortId, }, CreateSecret(entity::Secret), UpdateSecret(entity::Secret), SecretRecipientAdded { id: entity::ShortId, user_id: entity::ShortId, }, SecretRecipientDeleted { id: entity::ShortId, user_id: entity::ShortId, }, DeleteSecret { id: entity::ShortId, }, CreateNotification(entity::Notification), SeenNotification { id: entity::LongId, }, FollowUser { user_id: entity::ShortId, }, UnfollowUser { user_id: entity::ShortId, }, } pub async fn broadcast_message( message: Message, state: &AppState, predicate: impl Fn(&WebSocketKey) -> bool, ) { let connected_users = state.connected_users.read().await; let recievers = connected_users .iter() .filter_map(|(key, conn)| if predicate(key) { Some(conn) } else { None }); for reciever in recievers { _ = reciever .send(message.clone()) .inspect_err(|err| tracing::error!("Failed to send message: {}", err)); } } pub async fn ws_handler( ws: WebSocketUpgrade, State(state): State, Path(token): Path, ) -> error::Result { let context = middlware::get_context_from_token(state.clone(), &token).await?; Ok(ws.on_upgrade(|socket| handle_socket(socket, state, context))) } async fn handle_socket(websocket: WebSocket, state: AppState, context: context::Context) { let user_key = WebSocketKey { token: context.token.clone(), user_id: context.user.id, }; let (mut sender, _) = websocket.split(); let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::(); { let mut connected_users = state.connected_users.write().await; if connected_users.contains_key(&user_key) { tracing::trace!("websocket already connected: {user_key:?}"); drop(connected_users.remove(&user_key)); } connected_users.insert(user_key.clone(), tx); } tracing::trace!("websocket connected: {user_key:?}"); tokio::spawn(async move { // idk // tokio::time::sleep(std::time::Duration::from_secs(1)).await; while let Some(message) = rx.recv().await { let err = sender .send(axum::extract::ws::Message::Text( serde_json::to_string(&message) .inspect_err(|e| tracing::error!("Could not serialize message: {e}")) .unwrap(), )) .await .inspect_err(|e| tracing::error!("Could not send message: {e}")); if err.is_err() { break; } } let _ = sender.send(axum::extract::ws::Message::Close(None)).await; { let mut connected_users = state.connected_users.write().await; connected_users.remove(&user_key); tracing::trace!("websocket disconnected: {user_key:?}"); } }); }