135 lines
3.7 KiB
Rust
135 lines
3.7 KiB
Rust
use axum::{
|
|
extract::{ws::WebSocket, Path, State, WebSocketUpgrade},
|
|
response::IntoResponse,
|
|
};
|
|
use futures::{SinkExt, StreamExt};
|
|
|
|
use crate::{
|
|
entity,
|
|
state::{AppState, WebSocketKey},
|
|
};
|
|
|
|
use super::{context, error, middlware};
|
|
|
|
#[derive(serde::Serialize, Clone)]
|
|
#[serde(rename_all = "camelCase")]
|
|
#[serde(tag = "type", content = "data")]
|
|
pub enum Message {
|
|
CreateMessage(entity::Message),
|
|
UpdateChannel(entity::Channel),
|
|
CreateChannel(entity::Channel),
|
|
DeleteChannel {
|
|
id: entity::ShortId,
|
|
},
|
|
AddedUserToChannel {
|
|
user_id: entity::ShortId,
|
|
channel_id: entity::ShortId,
|
|
},
|
|
RemovedUserFromChannel {
|
|
user_id: entity::ShortId,
|
|
channel_id: entity::ShortId,
|
|
},
|
|
CreateSecret(entity::Secret),
|
|
UpdateSecret(entity::Secret),
|
|
SecretRecipientAdded {
|
|
id: entity::ShortId,
|
|
user_id: entity::ShortId,
|
|
},
|
|
SecretRecipientDeleted {
|
|
id: entity::ShortId,
|
|
user_id: entity::ShortId,
|
|
},
|
|
DeleteSecret {
|
|
id: entity::ShortId,
|
|
},
|
|
CreateNotification(entity::Notification),
|
|
SeenNotification {
|
|
id: entity::LongId,
|
|
},
|
|
FollowUser {
|
|
user_id: entity::ShortId,
|
|
},
|
|
UnfollowUser {
|
|
user_id: entity::ShortId,
|
|
},
|
|
}
|
|
|
|
pub async fn broadcast_message(
|
|
message: Message,
|
|
state: &AppState,
|
|
predicate: impl Fn(&WebSocketKey) -> bool,
|
|
) {
|
|
let connected_users = state.connected_users.read().await;
|
|
|
|
let recievers =
|
|
connected_users
|
|
.iter()
|
|
.filter_map(|(key, conn)| if predicate(key) { Some(conn) } else { None });
|
|
|
|
for reciever in recievers {
|
|
_ = reciever
|
|
.send(message.clone())
|
|
.inspect_err(|err| tracing::error!("Failed to send message: {}", err));
|
|
}
|
|
}
|
|
|
|
pub async fn ws_handler(
|
|
ws: WebSocketUpgrade,
|
|
State(state): State<AppState>,
|
|
Path(token): Path<String>,
|
|
) -> error::Result<impl IntoResponse> {
|
|
let context = middlware::get_context_from_token(state.clone(), &token).await?;
|
|
|
|
Ok(ws.on_upgrade(|socket| handle_socket(socket, state, context)))
|
|
}
|
|
|
|
async fn handle_socket(websocket: WebSocket, state: AppState, context: context::Context) {
|
|
let user_key = WebSocketKey {
|
|
token: context.token.clone(),
|
|
user_id: context.user.id,
|
|
};
|
|
let (mut sender, _) = websocket.split();
|
|
|
|
let (tx, mut rx) = tokio::sync::mpsc::unbounded_channel::<Message>();
|
|
|
|
{
|
|
let mut connected_users = state.connected_users.write().await;
|
|
if connected_users.contains_key(&user_key) {
|
|
tracing::trace!("websocket already connected: {user_key:?}");
|
|
|
|
drop(connected_users.remove(&user_key));
|
|
}
|
|
connected_users.insert(user_key.clone(), tx);
|
|
}
|
|
|
|
tracing::trace!("websocket connected: {user_key:?}");
|
|
|
|
tokio::spawn(async move {
|
|
// idk
|
|
// tokio::time::sleep(std::time::Duration::from_secs(1)).await;
|
|
while let Some(message) = rx.recv().await {
|
|
let err = sender
|
|
.send(axum::extract::ws::Message::Text(
|
|
serde_json::to_string(&message)
|
|
.inspect_err(|e| tracing::error!("Could not serialize message: {e}"))
|
|
.unwrap(),
|
|
))
|
|
.await
|
|
.inspect_err(|e| tracing::error!("Could not send message: {e}"));
|
|
|
|
if err.is_err() {
|
|
break;
|
|
}
|
|
}
|
|
|
|
let _ = sender.send(axum::extract::ws::Message::Close(None)).await;
|
|
|
|
{
|
|
let mut connected_users = state.connected_users.write().await;
|
|
connected_users.remove(&user_key);
|
|
|
|
tracing::trace!("websocket disconnected: {user_key:?}");
|
|
}
|
|
});
|
|
}
|