1
0
Files
nir/src/web/ws/mod.rs
2024-05-21 05:44:46 +03:00

135 lines
3.7 KiB
Rust

use axum::{
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
response::IntoResponse,
};
use futures::{SinkExt, StreamExt};
use crate::{
entity,
state::{AppState, WebSocketKey},
};
use super::{context, error, middlware};
/// Server-to-client event pushed to connected WebSocket clients.
///
/// Serialized with adjacent tagging: `{"type": "<variantInCamelCase>", "data": <payload>}`.
///
/// `rename_all` applies to the variant tags only; the field names of the struct-like
/// variants (e.g. `user_id`, `channel_id`) are emitted exactly as spelled here.
/// NOTE(review): confirm clients expect snake_case keys inside those payloads.
#[derive(Clone, serde::Serialize)]
#[serde(tag = "type", content = "data", rename_all = "camelCase")]
pub enum Message {
    // --- messages ---
    CreateMessage(entity::Message),

    // --- channels ---
    UpdateChannel(entity::Channel),
    CreateChannel(entity::Channel),
    DeleteChannel { id: entity::ShortId },
    AddedUserToChannel { user_id: entity::ShortId, channel_id: entity::ShortId },
    RemovedUserFromChannel { user_id: entity::ShortId, channel_id: entity::ShortId },

    // --- secrets ---
    CreateSecret(entity::Secret),
    UpdateSecret(entity::Secret),
    SecretRecipientAdded { id: entity::ShortId, user_id: entity::ShortId },
    SecretRecipientDeleted { id: entity::ShortId, user_id: entity::ShortId },
    DeleteSecret { id: entity::ShortId },

    // --- notifications ---
    CreateNotification(entity::Notification),
    SeenNotification { id: entity::LongId },

    // --- follows ---
    FollowUser { user_id: entity::ShortId },
    UnfollowUser { user_id: entity::ShortId },
}
pub async fn broadcast_message(
message: Message,
state: &AppState,
predicate: impl Fn(&WebSocketKey) -> bool,
) {
let connected_users = state.connected_users.read().await;
let recievers =
connected_users
.iter()
.filter_map(|(key, conn)| if predicate(key) { Some(conn) } else { None });
for reciever in recievers {
_ = reciever
.send(message.clone())
.inspect_err(|err| tracing::error!("Failed to send message: {}", err));
}
}
/// Axum handler: resolves a request context from the `token` path segment, then
/// upgrades the connection and hands the socket to `handle_socket`.
///
/// # Errors
/// Returns whatever error `middlware::get_context_from_token` yields; no upgrade
/// is performed in that case.
///
/// NOTE(review): the token travels in the URL path, so it may end up in access logs.
pub async fn ws_handler(
    ws: WebSocketUpgrade,
    State(state): State<AppState>,
    Path(token): Path<String>,
) -> error::Result<impl IntoResponse> {
    let ctx = middlware::get_context_from_token(state.clone(), &token).await?;
    let response = ws.on_upgrade(move |socket| handle_socket(socket, state, ctx));
    Ok(response)
}
async fn handle_socket(websocket: WebSocket, state: AppState, context: context::Context) {
let user_key = WebSocketKey {
token: context.token.clone(),
user_id: context.user.id,
};
let (mut sender, _) = websocket.split();
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
{
let mut connected_users = state.connected_users.write().await;
if connected_users.contains_key(&user_key) {
tracing::trace!("websocket already connected: {user_key:?}");
drop(connected_users.remove(&user_key));
}
connected_users.insert(user_key.clone(), tx);
}
tracing::trace!("websocket connected: {user_key:?}");
tokio::spawn(async move {
// idk
// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
while let Some(message) = rx.recv().await {
let err = sender
.send(axum::extract::ws::Message::Text(
serde_json::to_string(&message)
.inspect_err(|e| tracing::error!("Could not serialize message: {e}"))
.unwrap(),
))
.await
.inspect_err(|e| tracing::error!("Could not send message: {e}"));
if err.is_err() {
break;
}
}
let _ = sender.send(axum::extract::ws::Message::Close(None)).await;
{
let mut connected_users = state.connected_users.write().await;
connected_users.remove(&user_key);
tracing::trace!("websocket disconnected: {user_key:?}");
}
});
}