From 150368f3bc01d79c4d6ecc0b12aa755e5a7bf3dc Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Wed, 11 Jan 2023 00:04:25 +0000 Subject: [PATCH] Reserve user nicks per user and prevent nick changes to other user's reserved nicks --- migrations/2023010814480_initial-schema.sql | 6 ++++++ src/client.rs | 56 +++++++++++++++++++++++++++++++++++--------------------- src/connection.rs | 23 +++++++++++++++++++++-- src/database/mod.rs | 21 +++++++++++++++++++++ src/main.rs | 2 +- src/persistence.rs | 26 +++++++++++++++++++++++++- src/persistence/events.rs | 7 +++++++ 7 files changed, 116 insertions(+), 25 deletions(-) diff --git a/migrations/2023010814480_initial-schema.sql b/migrations/2023010814480_initial-schema.sql index 303ffe5..8d76446 100644 --- a/migrations/2023010814480_initial-schema.sql +++ b/migrations/2023010814480_initial-schema.sql @@ -6,6 +6,12 @@ CREATE TABLE users ( CREATE UNIQUE INDEX users_username ON users(username); +CREATE TABLE user_nicks ( + nick VARCHAR(255) NOT NULL PRIMARY KEY, + user INTEGER NOT NULL, + FOREIGN KEY(user) REFERENCES users(id) +); + CREATE TABLE channels ( id INTEGER PRIMARY KEY, name VARCHAR(255) NOT NULL diff --git a/src/client.rs b/src/client.rs index 4e32da1..f96afd1 100644 --- a/src/client.rs +++ b/src/client.rs @@ -20,7 +20,7 @@ use crate::{ UserNickChange, UserNickChangeInternal, }, persistence::{ - events::{FetchUnseenMessages, FetchUserChannels}, + events::{FetchUnseenMessages, FetchUserChannels, ReserveNick}, Persistence, }, server::Server, @@ -271,30 +271,44 @@ impl Handler for Client { } impl Handler for Client { - type Result = (); + type Result = ResponseActFuture; #[instrument(parent = &msg.span, skip_all)] - fn handle(&mut self, msg: UserNickChangeInternal, ctx: &mut Self::Context) -> Self::Result { - // alert the server to the nick change (we'll receive this event back so the user - // gets the notification too) - self.server.do_send(UserNickChange { - client: ctx.address(), - connection: self.connection.clone(), - new_nick: msg.new_nick.clone(), - span: Span::current(), - }); + fn handle(&mut self, msg: UserNickChangeInternal, _ctx: &mut Self::Context) -> Self::Result { + self.persistence + .send(ReserveNick { + user_id: self.connection.user_id, + nick: msg.new_nick.clone(), + }) + .into_actor(self) + .map(|res, this, ctx| { + if !res.unwrap() { + // TODO: send notification to user to say the nick isn't available + return; + } - for channel in self.channels.values() { - channel.do_send(UserNickChange { - client: ctx.address(), - connection: self.connection.clone(), - new_nick: msg.new_nick.clone(), - span: Span::current(), - }); - } + // alert the server to the nick change (we'll receive this event back so the user + // gets the notification too) + this.server.do_send(UserNickChange { + client: ctx.address(), + connection: this.connection.clone(), + new_nick: msg.new_nick.clone(), + span: Span::current(), + }); + + for channel in this.channels.values() { + channel.do_send(UserNickChange { + client: ctx.address(), + connection: this.connection.clone(), + new_nick: msg.new_nick.clone(), + span: Span::current(), + }); + } - // updates our nick locally - self.connection.nick = msg.new_nick; + // updates our nick locally + this.connection.nick = msg.new_nick; + }) + .boxed_local() } } diff --git a/src/connection.rs b/src/connection.rs index 4c6b339..556aedf 100644 --- a/src/connection.rs +++ b/src/connection.rs @@ -4,7 +4,7 @@ use std::{ str::FromStr, }; -use actix::io::FramedWrite; +use actix::{io::FramedWrite, Addr}; use argon2::PasswordHash; use base64::{prelude::BASE64_STANDARD, Engine}; use const_format::concatcp; @@ -19,7 +19,10 @@ use tokio::{ use tokio_util::codec::FramedRead; use tracing::{instrument, warn}; -use crate::database::verify_password; +use crate::{ + database::verify_password, + persistence::{events::ReserveNick, Persistence}, +}; pub type MessageStream = FramedRead, irc_proto::IrcCodec>; pub type MessageSink = FramedWrite, irc_proto::IrcCodec>; @@ -91,6 +94,7 @@ pub async fn negotiate_client_connection( s: &mut MessageStream, write: &mut tokio_util::codec::FramedWrite, IrcCodec>, host: SocketAddr, + persistence: &Addr, database: sqlx::Pool, ) -> Result, ProtocolError> { let mut request = ConnectionRequest { @@ -190,6 +194,21 @@ pub async fn negotiate_client_connection( if let Some(user_id) = user_id { initiated.user_id.0 = user_id; + let reserved_nick = persistence + .send(ReserveNick { + user_id: initiated.user_id, + nick: initiated.nick.clone(), + }) + .await + .map_err(|e| ProtocolError::Io(Error::new(ErrorKind::InvalidData, e)))?; + + if !reserved_nick { + return Err(ProtocolError::Io(Error::new( + ErrorKind::InvalidData, + "nick is already in use by another user", + ))); + } + Ok(Some(initiated)) } else { Err(ProtocolError::Io(Error::new( diff --git a/src/database/mod.rs b/src/database/mod.rs index 366e0ea..16aac39 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,6 +1,8 @@ use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier}; use rand::rngs::OsRng; +use crate::connection::UserId; + /// Attempts creation of a new user, returning the password of the user. /// /// The returned password _is not_ guaranteed to be the password that was just set. @@ -26,6 +28,25 @@ pub async fn create_user_or_fetch_password_hash( .await } +pub async fn reserve_nick( + conn: &sqlx::Pool, + nick: &str, + user_id: UserId, +) -> Result { + let (owning_user,): (i64,) = sqlx::query_as( + "INSERT INTO user_nicks (nick, user) + VALUES (?, ?) + ON CONFLICT(nick) DO UPDATE SET nick = nick + RETURNING user", + ) + .bind(nick) + .bind(user_id.0) + .fetch_one(conn) + .await?; + + Ok(owning_user == user_id.0) +} + /// Compares a password to a hash stored in the database. pub fn verify_password(password: &[u8], hash: &PasswordHash) -> argon2::password_hash::Result<()> { Argon2::default().verify_password(password, hash) diff --git a/src/main.rs b/src/main.rs index 6bd8a3f..bf5e2c0 100644 --- a/src/main.rs +++ b/src/main.rs @@ -147,7 +147,7 @@ async fn start_tcp_acceptor_loop( // ensure we have all the details required to actually connect the client to the server // (ie. we have a nick, user, etc) - let Some(connection) = connection::negotiate_client_connection(&mut read, &mut write, addr, database).await.unwrap() else { + let Some(connection) = connection::negotiate_client_connection(&mut read, &mut write, addr, &persistence, database).await.unwrap() else { error!("Failed to fully handshake with client, dropping connection"); return; }; diff --git a/src/persistence.rs b/src/persistence.rs index b18f655..00c9411 100644 --- a/src/persistence.rs +++ b/src/persistence.rs @@ -11,7 +11,7 @@ use crate::{ channel::permissions::Permission, persistence::events::{ ChannelCreated, ChannelJoined, ChannelMessage, ChannelParted, FetchUnseenMessages, - FetchUserChannelPermissions, FetchUserChannels, SetUserChannelPermissions, + FetchUserChannelPermissions, FetchUserChannels, ReserveNick, SetUserChannelPermissions, }, }; @@ -281,6 +281,30 @@ impl Handler for Persistence { } } +impl Handler for Persistence { + type Result = ResponseFuture; + + fn handle(&mut self, msg: ReserveNick, _ctx: &mut Self::Context) -> Self::Result { + let database = self.database.clone(); + + Box::pin(async move { + let (owning_user,): (i64,) = sqlx::query_as( + "INSERT INTO user_nicks (nick, user) + VALUES (?, ?) + ON CONFLICT(nick) DO UPDATE SET nick = nick + RETURNING user", + ) + .bind(msg.nick) + .bind(msg.user_id.0) + .fetch_one(&database) + .await + .unwrap(); + + owning_user == msg.user_id.0 + }) + } +} + /// Remove any messages from the messages table whenever they've been seen by all users /// or have passed their retention period /// . diff --git a/src/persistence/events.rs b/src/persistence/events.rs index 266c432..2fbdd95 100644 --- a/src/persistence/events.rs +++ b/src/persistence/events.rs @@ -66,3 +66,10 @@ pub struct FetchUnseenMessages { pub user_id: UserId, pub span: Span, } + +#[derive(Message)] +#[rtype(result = "bool")] +pub struct ReserveNick { + pub user_id: UserId, + pub nick: String, +} -- libgit2 1.7.2