From 46ab04be24499a412f14ae21d822b6275fb27e73 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Sun, 8 Jan 2023 17:03:27 +0000 Subject: [PATCH] Connect the user back to any channels they were previously connected to --- src/client.rs | 18 ++++++++++++++++++ src/main.rs | 7 ++++++- src/persistence.rs | 28 +++++++++++++++++++++++++++- src/persistence/events.rs | 7 +++++++ 4 files changed, 58 insertions(+), 2 deletions(-) diff --git a/src/client.rs b/src/client.rs index 722718b..fcac020 100644 --- a/src/client.rs +++ b/src/client.rs @@ -19,6 +19,7 @@ use crate::{ ServerDisconnect, ServerFetchMotd, UserKickedFromChannel, UserNickChange, UserNickChangeInternal, }, + persistence::{events::FetchUserChannels, Persistence}, server::Server, SERVER_NAME, }; @@ -44,6 +45,8 @@ pub struct Client { /// The reason the client is leaving the server, whether this is set by the server or the user /// is decided by graceful_shutdown pub server_leave_reason: Option, + /// Actor for persisting state to the datastore. + pub persistence: Addr, /// The connection span to group all logs for the same connection pub span: Span, } @@ -71,6 +74,21 @@ impl Actor for Client { command: Command::PING(SERVER_NAME.to_string(), None), }); }); + + ctx.spawn( + self.persistence + .send(FetchUserChannels { + username: self.connection.user.to_string(), + span: Span::current(), + }) + .into_actor(self) + .map(move |res, this, ctx| { + ctx.notify(JoinChannelRequest { + channels: res.unwrap(), + span: this.span.clone(), + }); + }), + ); } /// Called when the actor is shutting down, either gracefully by the client or forcefully diff --git a/src/main.rs b/src/main.rs index f551b49..709406b 100644 --- a/src/main.rs +++ b/src/main.rs @@ -79,7 +79,7 @@ async fn main() -> anyhow::Result<()> { let server_arbiter = Arbiter::new(); - let persistence = { + let persistence_addr = { let database = database.clone(); Supervisor::start_in_arbiter(&server_arbiter.handle(), move |_ctx| Persistence { @@ -87,6 +87,7 @@ async fn main() -> anyhow::Result<()> { }) }; + let persistence = persistence_addr.clone(); let server = Supervisor::start_in_arbiter(&server_arbiter.handle(), move |_ctx| Server { channels: HashMap::default(), clients: HashMap::default(), @@ -100,6 +101,7 @@ async fn main() -> anyhow::Result<()> { actix_rt::spawn(start_tcp_acceptor_loop( listener, database, + persistence_addr, server, client_threads, )); @@ -117,6 +119,7 @@ async fn main() -> anyhow::Result<()> { async fn start_tcp_acceptor_loop( listener: TcpListener, database: sqlx::Pool, + persistence: Addr, server: Addr, client_threads: usize, ) { @@ -131,6 +134,7 @@ async fn start_tcp_acceptor_loop( let database = database.clone(); let server = server.clone(); let client_arbiters = client_arbiters.clone(); + let persistence = persistence.clone(); actix_rt::spawn(async move { // split the stream into its read and write halves and setup codecs @@ -170,6 +174,7 @@ async fn start_tcp_acceptor_loop( graceful_shutdown: false, server_leave_reason: None, span, + persistence, } }) }; diff --git a/src/persistence.rs b/src/persistence.rs index 8bcaa3d..d0bb723 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -3,7 +3,7 @@ pub mod events; use actix::{Context, Handler, ResponseFuture}; use tracing::instrument; -use crate::persistence::events::{ChannelCreated, ChannelJoined, ChannelParted}; +use crate::persistence::events::{ChannelCreated, ChannelJoined, ChannelParted, FetchUserChannels}; /// Takes events destined for other actors and persists them to the database. pub struct Persistence { @@ -81,3 +81,29 @@ impl Handler for Persistence { }) } } + +impl Handler for Persistence { + type Result = ResponseFuture>; + + fn handle(&mut self, msg: FetchUserChannels, _ctx: &mut Self::Context) -> Self::Result { + let conn = self.database.clone(); + + Box::pin(async move { + sqlx::query_as( + "SELECT channels.name + FROM channel_users + INNER JOIN channels + ON channels.id = channel_users.channel + WHERE user = (SELECT id FROM users WHERE username = ?) + AND in_channel = true", + ) + .bind(msg.username) + .fetch_all(&conn) + .await + .unwrap() + .into_iter() + .map(|(v,)| v) + .collect() + }) + } +} diff --git a/src/persistence/events.rs b/src/persistence/events.rs index beaeb53..7f0f634 100644 --- a/src/persistence/events.rs +++ b/src/persistence/events.rs @@ -22,3 +22,10 @@ pub struct ChannelParted { pub username: String, pub span: Span, } + +#[derive(Message)] +#[rtype(result = "Vec")] +pub struct FetchUserChannels { + pub username: String, + pub span: Span, +} -- libgit2 1.7.2