From 636411f8e0cbab37ba3676f045e8f1434a29189f Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Sun, 8 Jan 2023 20:38:23 +0000 Subject: [PATCH] Give user access to its user id and modify queries to use it --- src/channel.rs | 6 +++--- src/client.rs | 17 +++++++---------- src/connection.rs | 55 +++++++++++++++++++++++++++---------------------------- src/database/mod.rs | 56 ++++++++++++++++++-------------------------------------- src/persistence.rs | 46 ++++++++++++++++++++++++---------------------- src/persistence/events.rs | 12 +++++++----- 6 files changed, 86 insertions(+), 106 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index 26afbee..c5a6c81 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -90,7 +90,7 @@ impl Handler for Channel { channel_name: self.name.to_string(), sender: nick.to_string(), message: msg.message.to_string(), - receivers: self.clients.values().map(|v| v.user.to_string()).collect(), + receivers: self.clients.values().map(|v| v.user_id).collect(), }); for client in self.clients.keys() { @@ -144,7 +144,7 @@ impl Handler for Channel { self.persistence .do_send(crate::persistence::events::ChannelJoined { channel_name: self.name.to_string(), - username: msg.connection.user.to_string(), + user_id: msg.connection.user_id, span: msg.span.clone(), }); @@ -282,7 +282,7 @@ impl Handler for Channel { self.persistence .do_send(crate::persistence::events::ChannelParted { channel_name: self.name.to_string(), - username: client_info.user.to_string(), + user_id: client_info.user_id, span: msg.span.clone(), }); diff --git a/src/client.rs b/src/client.rs index 1d015f6..7639e55 100644 --- a/src/client.rs +++ b/src/client.rs @@ -81,7 +81,7 @@ impl Actor for Client { ctx.spawn( self.persistence .send(FetchUserChannels { - username: self.connection.user.to_string(), + user_id: self.connection.user_id, span: Span::current(), }) .into_actor(self) @@ -185,7 +185,7 @@ impl Handler for Client { let channel_messages_fut = self.persistence.send(FetchUnseenMessages { channel_name: channel_name.to_string(), - username: self.connection.user.to_string(), + user_id: self.connection.user_id, span: Span::current(), }); @@ -203,18 +203,15 @@ impl Handler for Client { let fut = wrap_future::<_, Self>( futures::future::join_all(futures.into_iter()).instrument(Span::current()), ) - .map(|result, this, ctx| { + .map(|result, this, _ctx| { for (channel_name, handle, messages) in result { this.channels.insert(channel_name.clone(), handle); for (source, message) in messages { - ctx.notify(Broadcast { - message: Message { - tags: None, - prefix: Some(Prefix::new_from_str(&source)), - command: Command::PRIVMSG(channel_name.clone(), message), - }, - span: this.span.clone(), + this.writer.write(Message { + tags: None, + prefix: Some(Prefix::new_from_str(&source)), + command: Command::PRIVMSG(channel_name.clone(), message), }); } } diff --git a/src/connection.rs b/src/connection.rs index 6817b0d..07d8bf9 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -25,6 +25,9 @@ pub type MessageSink = FramedWrite, irc_proto::Irc pub const SUPPORTED_CAPABILITIES: &[&str] = &[concatcp!("sasl=", AuthStrategy::SUPPORTED)]; +#[derive(Copy, Clone, Debug)] +pub struct UserId(pub i64); + #[derive(Default)] pub struct ConnectionRequest { nick: Option, @@ -39,6 +42,7 @@ pub struct InitiatedConnection { pub user: String, pub mode: String, pub real_name: String, + pub user_id: UserId, } impl InitiatedConnection { @@ -70,6 +74,7 @@ impl TryFrom for InitiatedConnection { user, mode, real_name, + user_id: UserId(0), }) } } @@ -136,7 +141,7 @@ pub async fn negotiate_client_connection( // if the user closed the connection before the connection was fully established, // return back early - let Some(initiated) = initiated else { + let Some(mut initiated) = initiated else { return Ok(None); }; @@ -147,7 +152,7 @@ pub async fn negotiate_client_connection( ))); } - let mut has_authenticated = false; + let mut user_id = None; // start negotiating capabilities with the client while let Some(msg) = s.try_next().await? { @@ -161,7 +166,7 @@ pub async fn negotiate_client_connection( break; } Command::AUTHENTICATE(strategy) => { - has_authenticated = + user_id = start_authenticate_flow(s, write, &initiated, strategy, &database).await?; } _ => { @@ -173,7 +178,9 @@ pub async fn negotiate_client_connection( } } - if has_authenticated { + if let Some(user_id) = user_id { + initiated.user_id.0 = user_id; + Ok(Some(initiated)) } else { Err(ProtocolError::Io(Error::new( @@ -193,11 +200,11 @@ async fn start_authenticate_flow( connection: &InitiatedConnection, strategy: String, database: &sqlx::Pool, -) -> Result { +) -> Result, ProtocolError> { let Ok(auth_strategy) = AuthStrategy::from_str(&strategy) else { write.send(SaslStrategyUnsupported(connection.nick.to_string()).into_message()) .await?; - return Ok(false); + return Ok(None); }; // tell the client to go ahead with their authentication @@ -226,7 +233,7 @@ async fn start_authenticate_flow( break; } - let authenticated = match auth_strategy { + let user_id = match auth_strategy { AuthStrategy::Plain => { // TODO: this needs to deal with the case where the full arguments can be split over // multiple messages @@ -234,12 +241,12 @@ async fn start_authenticate_flow( } }; - if authenticated { + if user_id.is_some() { for message in SaslSuccess(connection.clone()).into_messages() { write.send(message).await?; } - return Ok(true); + return Ok(user_id); } write @@ -247,18 +254,20 @@ async fn start_authenticate_flow( .await?; } - Ok(false) + Ok(None) } /// Attempts to handle an `AUTHENTICATE` command for the `PLAIN` authentication method. /// /// This will parse the full message, ensure that the identity is correct and compare the hashes /// to what we have stored in the database. +/// +/// This function will return the authenticated user id, or none if the password was incorrect. pub async fn handle_plain_authentication( arguments: &str, connection: &InitiatedConnection, database: &sqlx::Pool, -) -> Result { +) -> Result, Error> { let arguments = BASE64_STANDARD .decode(arguments) .map_err(|e| Error::new(ErrorKind::InvalidData, e))?; @@ -277,26 +286,16 @@ pub async fn handle_plain_authentication( } // lookup the user's password based on the USER command they sent earlier - let password_hash = crate::database::fetch_password_hash(database, &connection.user) - .await - .unwrap(); - let password_hash = password_hash - .as_deref() - .map(PasswordHash::new) - .transpose() - .unwrap(); - let Some(password_hash) = password_hash else { - // this is a new user, so we'll create an account for them - // TODO: we need to deal with races here, right now we'll just error out on dup - crate::database::create_user(database, &connection.user, password).await.unwrap(); - - return Ok(true); - }; + let (user_id, password_hash) = + crate::database::create_user_or_fetch_password_hash(database, &connection.user, password) + .await + .unwrap(); + let password_hash = PasswordHash::new(&password_hash).unwrap(); // check the user's password match verify_password(password, &password_hash) { - Ok(()) => Ok(true), - Err(argon2::password_hash::Error::Password) => Ok(false), + Ok(()) => Ok(Some(user_id)), + Err(argon2::password_hash::Error::Password) => Ok(None), Err(e) => Err(Error::new(ErrorKind::InvalidData, e.to_string())), } } diff --git a/src/database/mod.rs b/src/database/mod.rs index 0fb5ba7..366e0ea 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,49 +1,29 @@ use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; use rand::rngs::OsRng; -use sqlx::{database::HasArguments, Database, Encode, Executor, FromRow, IntoArguments, Type}; -/// Fetches the given user's password from the database. -pub async fn fetch_password_hash<'a, E: Executor<'a>>( - conn: E, - username: &'a str, -) -> Result, sqlx::Error> -where - for<'b> &'b str: Type + Encode<'b, E::Database>, - >::Arguments: IntoArguments<'a, E::Database>, - for<'b> (String,): FromRow<'b, ::Row>, -{ - let res = sqlx::query_as("SELECT password FROM users WHERE username = ?") - .bind(username) - .fetch_optional(conn) - .await? - .map(|(v,)| v); - - Ok(res) -} - -/// Creates a new user, returning an error if the user already exists. -pub async fn create_user<'a, E: Executor<'a>>( - conn: E, - username: &'a str, +/// Attempts creation of a new user, returning the password of the user. +/// +/// The returned password _is not_ guaranteed to be the password that was just set. +pub async fn create_user_or_fetch_password_hash( + conn: &sqlx::Pool, + username: &str, password: &[u8], -) -> Result<(), sqlx::Error> -where - for<'b> &'b str: Type + Encode<'b, E::Database>, - for<'b> String: Type + Encode<'b, E::Database>, - >::Arguments: IntoArguments<'a, E::Database>, -{ - let salt = SaltString::generate(&mut OsRng); +) -> Result<(i64, String), sqlx::Error> { let password_hash = Argon2::default() - .hash_password(password, &salt) + .hash_password(password, &SaltString::generate(&mut OsRng)) .unwrap() .to_string(); - sqlx::query("INSERT INTO users (username, password) VALUES (?, ?)") - .bind(username) - .bind(password_hash) - .execute(conn) - .await - .map(|_| ()) + sqlx::query_as( + "INSERT INTO users (username, password) + VALUES (?, ?) + ON CONFLICT(username) DO UPDATE SET username = username + RETURNING id, password", + ) + .bind(username) + .bind(password_hash) + .fetch_one(conn) + .await } /// Compares a password to a hash stored in the database. diff --git a/src/persistence.rs b/src/persistence.rs index 90d7041..20d2fb8 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -48,11 +48,11 @@ impl Handler for Persistence { Box::pin(async move { sqlx::query( "INSERT INTO channel_users (channel, user, permissions, in_channel) - VALUES ((SELECT id FROM channels WHERE name = ?), (SELECT id FROM users WHERE username = ?), ?, ?) - ON CONFLICT(channel, user) DO UPDATE SET in_channel = excluded.in_channel" + VALUES ((SELECT id FROM channels WHERE name = ?), ?, ?, ?) + ON CONFLICT(channel, user) DO UPDATE SET in_channel = excluded.in_channel", ) .bind(msg.channel_name) - .bind(msg.username) + .bind(msg.user_id.0) .bind(0i32) .bind(true) .execute(&conn) @@ -75,10 +75,10 @@ impl Handler for Persistence { "UPDATE channel_users SET in_channel = false WHERE channel = (SELECT id FROM channels WHERE name = ?) - AND user = (SELECT id FROM users WHERE username = ?)", + AND user = ?", ) .bind(msg.channel_name) - .bind(msg.username) + .bind(msg.user_id.0) .execute(&conn) .await .unwrap(); @@ -101,7 +101,7 @@ impl Handler for Persistence { WHERE user = (SELECT id FROM users WHERE username = ?) AND in_channel = true", ) - .bind(msg.username) + .bind(msg.user_id.0) .fetch_all(&conn) .await .unwrap() @@ -137,20 +137,22 @@ impl Handler for Persistence { .await .unwrap(); - let query = format!( - "UPDATE channel_users - SET last_seen_message_idx = ? - WHERE channel = ? - AND user IN (SELECT id FROM users WHERE username IN ({}))", - msg.receivers.iter().map(|_| "?").join(",") - ); - - let mut query = sqlx::query(&query).bind(idx).bind(channel); - for receiver in msg.receivers { - query = query.bind(receiver); + if !msg.receivers.is_empty() { + let query = format!( + "UPDATE channel_users + SET last_seen_message_idx = ? + WHERE channel = ? + AND user IN ({})", + msg.receivers.iter().map(|_| "?").join(",") + ); + + let mut query = sqlx::query(&query).bind(idx).bind(channel); + for receiver in msg.receivers { + query = query.bind(receiver.0); + } + + query.execute(&conn).await.unwrap(); } - - query.execute(&conn).await.unwrap(); }) } } @@ -180,13 +182,13 @@ impl Handler for Persistence { SELECT last_seen_message_idx FROM channel_users WHERE channel = (SELECT id FROM channel) - AND user = (SELECT id FROM users WHERE username = ?) + AND user = ? ) ) - ORDER BY idx DESC", + ORDER BY idx ASC", ) .bind(msg.channel_name.to_string()) - .bind(msg.username.to_string()) + .bind(msg.user_id.0) .fetch_all(&conn) .await .unwrap(); diff --git a/src/persistence/events.rs b/src/persistence/events.rs index 8e69bb2..e064511 100644 --- a/src/persistence/events.rs +++ b/src/persistence/events.rs @@ -1,6 +1,8 @@ use actix::Message; use tracing::Span; +use crate::connection::UserId; + #[derive(Message)] #[rtype(result = "()")] pub struct ChannelCreated { @@ -11,7 +13,7 @@ pub struct ChannelCreated { #[rtype(result = "()")] pub struct ChannelJoined { pub channel_name: String, - pub username: String, + pub user_id: UserId, pub span: Span, } @@ -19,14 +21,14 @@ pub struct ChannelJoined { #[rtype(result = "()")] pub struct ChannelParted { pub channel_name: String, - pub username: String, + pub user_id: UserId, pub span: Span, } #[derive(Message)] #[rtype(result = "Vec")] pub struct FetchUserChannels { - pub username: String, + pub user_id: UserId, pub span: Span, } @@ -36,13 +38,13 @@ pub struct ChannelMessage { pub channel_name: String, pub sender: String, pub message: String, - pub receivers: Vec, + pub receivers: Vec, } #[derive(Message)] #[rtype(result = "Vec<(String, String)>")] pub struct FetchUnseenMessages { pub channel_name: String, - pub username: String, + pub user_id: UserId, pub span: Span, } -- libgit2 1.7.2