🏡 index : ~doyle/titanirc.git

author Jordan Doyle <jordan@doyle.la> 2023-01-08 20:38:23.0 +00:00:00
committer Jordan Doyle <jordan@doyle.la> 2023-01-08 20:38:40.0 +00:00:00
commit
636411f8e0cbab37ba3676f045e8f1434a29189f [patch]
tree
cd5dccb51528c8ca990fcf6b9a846f19373c15fb
parent
511f35bcd6fc181b5db9e62dfcc96923285d5343
download
636411f8e0cbab37ba3676f045e8f1434a29189f.tar.gz

Give user access to its user id and modify queries to use it



Diff

 src/channel.rs            |  6 ++---
 src/client.rs             | 17 ++++++---------
 src/connection.rs         | 55 +++++++++++++++++++++++------------------------
 src/database/mod.rs       | 56 +++++++++++++++---------------------------------
 src/persistence.rs        | 46 ++++++++++++++++++++-------------------
 src/persistence/events.rs | 12 +++++-----
 6 files changed, 86 insertions(+), 106 deletions(-)

diff --git a/src/channel.rs b/src/channel.rs
index 26afbee..c5a6c81 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -90,7 +90,7 @@ impl Handler<ChannelMessage> for Channel {
                channel_name: self.name.to_string(),
                sender: nick.to_string(),
                message: msg.message.to_string(),
                receivers: self.clients.values().map(|v| v.user.to_string()).collect(),
                receivers: self.clients.values().map(|v| v.user_id).collect(),
            });

        for client in self.clients.keys() {
@@ -144,7 +144,7 @@ impl Handler<ChannelJoin> for Channel {
        self.persistence
            .do_send(crate::persistence::events::ChannelJoined {
                channel_name: self.name.to_string(),
                username: msg.connection.user.to_string(),
                user_id: msg.connection.user_id,
                span: msg.span.clone(),
            });

@@ -282,7 +282,7 @@ impl Handler<ChannelPart> for Channel {
        self.persistence
            .do_send(crate::persistence::events::ChannelParted {
                channel_name: self.name.to_string(),
                username: client_info.user.to_string(),
                user_id: client_info.user_id,
                span: msg.span.clone(),
            });

diff --git a/src/client.rs b/src/client.rs
index 1d015f6..7639e55 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -81,7 +81,7 @@ impl Actor for Client {
        ctx.spawn(
            self.persistence
                .send(FetchUserChannels {
                    username: self.connection.user.to_string(),
                    user_id: self.connection.user_id,
                    span: Span::current(),
                })
                .into_actor(self)
@@ -185,7 +185,7 @@ impl Handler<JoinChannelRequest> for Client {

            let channel_messages_fut = self.persistence.send(FetchUnseenMessages {
                channel_name: channel_name.to_string(),
                username: self.connection.user.to_string(),
                user_id: self.connection.user_id,
                span: Span::current(),
            });

@@ -203,18 +203,15 @@ impl Handler<JoinChannelRequest> for Client {
        let fut = wrap_future::<_, Self>(
            futures::future::join_all(futures.into_iter()).instrument(Span::current()),
        )
        .map(|result, this, ctx| {
        .map(|result, this, _ctx| {
            for (channel_name, handle, messages) in result {
                this.channels.insert(channel_name.clone(), handle);

                for (source, message) in messages {
                    ctx.notify(Broadcast {
                        message: Message {
                            tags: None,
                            prefix: Some(Prefix::new_from_str(&source)),
                            command: Command::PRIVMSG(channel_name.clone(), message),
                        },
                        span: this.span.clone(),
                    this.writer.write(Message {
                        tags: None,
                        prefix: Some(Prefix::new_from_str(&source)),
                        command: Command::PRIVMSG(channel_name.clone(), message),
                    });
                }
            }
diff --git a/src/connection.rs b/src/connection.rs
index 6817b0d..07d8bf9 100644
--- a/src/connection.rs
+++ b/src/connection.rs
@@ -25,6 +25,9 @@ pub type MessageSink = FramedWrite<Message, WriteHalf<TcpStream>, irc_proto::Irc

pub const SUPPORTED_CAPABILITIES: &[&str] = &[concatcp!("sasl=", AuthStrategy::SUPPORTED)];

#[derive(Copy, Clone, Debug)]
pub struct UserId(pub i64);

#[derive(Default)]
pub struct ConnectionRequest {
    nick: Option<String>,
@@ -39,6 +42,7 @@ pub struct InitiatedConnection {
    pub user: String,
    pub mode: String,
    pub real_name: String,
    pub user_id: UserId,
}

impl InitiatedConnection {
@@ -70,6 +74,7 @@ impl TryFrom<ConnectionRequest> for InitiatedConnection {
            user,
            mode,
            real_name,
            user_id: UserId(0),
        })
    }
}
@@ -136,7 +141,7 @@ pub async fn negotiate_client_connection(

    // if the user closed the connection before the connection was fully established,
    // return back early
    let Some(initiated) = initiated else {
    let Some(mut initiated) = initiated else {
        return Ok(None);
    };

@@ -147,7 +152,7 @@ pub async fn negotiate_client_connection(
        )));
    }

    let mut has_authenticated = false;
    let mut user_id = None;

    // start negotiating capabilities with the client
    while let Some(msg) = s.try_next().await? {
@@ -161,7 +166,7 @@ pub async fn negotiate_client_connection(
                break;
            }
            Command::AUTHENTICATE(strategy) => {
                has_authenticated =
                user_id =
                    start_authenticate_flow(s, write, &initiated, strategy, &database).await?;
            }
            _ => {
@@ -173,7 +178,9 @@ pub async fn negotiate_client_connection(
        }
    }

    if has_authenticated {
    if let Some(user_id) = user_id {
        initiated.user_id.0 = user_id;

        Ok(Some(initiated))
    } else {
        Err(ProtocolError::Io(Error::new(
@@ -193,11 +200,11 @@ async fn start_authenticate_flow(
    connection: &InitiatedConnection,
    strategy: String,
    database: &sqlx::Pool<sqlx::Any>,
) -> Result<bool, ProtocolError> {
) -> Result<Option<i64>, ProtocolError> {
    let Ok(auth_strategy) = AuthStrategy::from_str(&strategy) else {
        write.send(SaslStrategyUnsupported(connection.nick.to_string()).into_message())
            .await?;
        return Ok(false);
        return Ok(None);
    };

    // tell the client to go ahead with their authentication
@@ -226,7 +233,7 @@ async fn start_authenticate_flow(
            break;
        }

        let authenticated = match auth_strategy {
        let user_id = match auth_strategy {
            AuthStrategy::Plain => {
                // TODO: this needs to deal with the case where the full arguments can be split over
                //  multiple messages
@@ -234,12 +241,12 @@ async fn start_authenticate_flow(
            }
        };

        if authenticated {
        if user_id.is_some() {
            for message in SaslSuccess(connection.clone()).into_messages() {
                write.send(message).await?;
            }

            return Ok(true);
            return Ok(user_id);
        }

        write
@@ -247,18 +254,20 @@ async fn start_authenticate_flow(
            .await?;
    }

    Ok(false)
    Ok(None)
}

/// Attempts to handle an `AUTHENTICATE` command for the `PLAIN` authentication method.
///
/// This will parse the full message, ensure that the identity is correct and compare the hashes
/// to what we have stored in the database.
///
/// This function will return the authenticated user id, or none if the password was incorrect.
pub async fn handle_plain_authentication(
    arguments: &str,
    connection: &InitiatedConnection,
    database: &sqlx::Pool<sqlx::Any>,
) -> Result<bool, Error> {
) -> Result<Option<i64>, Error> {
    let arguments = BASE64_STANDARD
        .decode(arguments)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
@@ -277,26 +286,16 @@ pub async fn handle_plain_authentication(
    }

    // lookup the user's password based on the USER command they sent earlier
    let password_hash = crate::database::fetch_password_hash(database, &connection.user)
        .await
        .unwrap();
    let password_hash = password_hash
        .as_deref()
        .map(PasswordHash::new)
        .transpose()
        .unwrap();
    let Some(password_hash) = password_hash else {
        // this is a new user, so we'll create an account for them
        // TODO: we need to deal with races here, right now we'll just error out on dup
        crate::database::create_user(database, &connection.user, password).await.unwrap();

        return Ok(true);
    };
    let (user_id, password_hash) =
        crate::database::create_user_or_fetch_password_hash(database, &connection.user, password)
            .await
            .unwrap();
    let password_hash = PasswordHash::new(&password_hash).unwrap();

    // check the user's password
    match verify_password(password, &password_hash) {
        Ok(()) => Ok(true),
        Err(argon2::password_hash::Error::Password) => Ok(false),
        Ok(()) => Ok(Some(user_id)),
        Err(argon2::password_hash::Error::Password) => Ok(None),
        Err(e) => Err(Error::new(ErrorKind::InvalidData, e.to_string())),
    }
}
diff --git a/src/database/mod.rs b/src/database/mod.rs
index 0fb5ba7..366e0ea 100644
--- a/src/database/mod.rs
+++ b/src/database/mod.rs
@@ -1,49 +1,29 @@
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use rand::rngs::OsRng;
use sqlx::{database::HasArguments, Database, Encode, Executor, FromRow, IntoArguments, Type};

/// Fetches the given user's password from the database.
pub async fn fetch_password_hash<'a, E: Executor<'a>>(
    conn: E,
    username: &'a str,
) -> Result<Option<String>, sqlx::Error>
where
    for<'b> &'b str: Type<E::Database> + Encode<'b, E::Database>,
    <E::Database as HasArguments<'a>>::Arguments: IntoArguments<'a, E::Database>,
    for<'b> (String,): FromRow<'b, <E::Database as Database>::Row>,
{
    let res = sqlx::query_as("SELECT password FROM users WHERE username = ?")
        .bind(username)
        .fetch_optional(conn)
        .await?
        .map(|(v,)| v);

    Ok(res)
}

/// Creates a new user, returning an error if the user already exists.
pub async fn create_user<'a, E: Executor<'a>>(
    conn: E,
    username: &'a str,
/// Attempts creation of a new user, returning the password of the user.
///
/// The returned password _is not_ guaranteed to be the password that was just set.
pub async fn create_user_or_fetch_password_hash(
    conn: &sqlx::Pool<sqlx::Any>,
    username: &str,
    password: &[u8],
) -> Result<(), sqlx::Error>
where
    for<'b> &'b str: Type<E::Database> + Encode<'b, E::Database>,
    for<'b> String: Type<E::Database> + Encode<'b, E::Database>,
    <E::Database as HasArguments<'a>>::Arguments: IntoArguments<'a, E::Database>,
{
    let salt = SaltString::generate(&mut OsRng);
) -> Result<(i64, String), sqlx::Error> {
    let password_hash = Argon2::default()
        .hash_password(password, &salt)
        .hash_password(password, &SaltString::generate(&mut OsRng))
        .unwrap()
        .to_string();

    sqlx::query("INSERT INTO users (username, password) VALUES (?, ?)")
        .bind(username)
        .bind(password_hash)
        .execute(conn)
        .await
        .map(|_| ())
    sqlx::query_as(
        "INSERT INTO users (username, password)
         VALUES (?, ?)
         ON CONFLICT(username) DO UPDATE SET username = username
         RETURNING id, password",
    )
    .bind(username)
    .bind(password_hash)
    .fetch_one(conn)
    .await
}

/// Compares a password to a hash stored in the database.
diff --git a/src/persistence.rs b/src/persistence.rs
index 90d7041..20d2fb8 100644
--- a/src/persistence.rs
+++ b/src/persistence.rs
@@ -48,11 +48,11 @@ impl Handler<ChannelJoined> for Persistence {
        Box::pin(async move {
            sqlx::query(
                "INSERT INTO channel_users (channel, user, permissions, in_channel)
                 VALUES ((SELECT id FROM channels WHERE name = ?), (SELECT id FROM users WHERE username = ?), ?, ?)
                 ON CONFLICT(channel, user) DO UPDATE SET in_channel = excluded.in_channel"
                 VALUES ((SELECT id FROM channels WHERE name = ?), ?, ?, ?)
                 ON CONFLICT(channel, user) DO UPDATE SET in_channel = excluded.in_channel",
            )
            .bind(msg.channel_name)
            .bind(msg.username)
            .bind(msg.user_id.0)
            .bind(0i32)
            .bind(true)
            .execute(&conn)
@@ -75,10 +75,10 @@ impl Handler<ChannelParted> for Persistence {
                "UPDATE channel_users
                 SET in_channel = false
                 WHERE channel = (SELECT id FROM channels WHERE name = ?)
                   AND user = (SELECT id FROM users WHERE username = ?)",
                   AND user = ?",
            )
            .bind(msg.channel_name)
            .bind(msg.username)
            .bind(msg.user_id.0)
            .execute(&conn)
            .await
            .unwrap();
@@ -101,7 +101,7 @@ impl Handler<FetchUserChannels> for Persistence {
                  WHERE user = (SELECT id FROM users WHERE username = ?)
                    AND in_channel = true",
            )
            .bind(msg.username)
            .bind(msg.user_id.0)
            .fetch_all(&conn)
            .await
            .unwrap()
@@ -137,20 +137,22 @@ impl Handler<ChannelMessage> for Persistence {
            .await
            .unwrap();

            let query = format!(
                "UPDATE channel_users
                 SET last_seen_message_idx = ?
                 WHERE channel = ?
                   AND user IN (SELECT id FROM users WHERE username IN ({}))",
                msg.receivers.iter().map(|_| "?").join(",")
            );

            let mut query = sqlx::query(&query).bind(idx).bind(channel);
            for receiver in msg.receivers {
                query = query.bind(receiver);
            if !msg.receivers.is_empty() {
                let query = format!(
                    "UPDATE channel_users
                     SET last_seen_message_idx = ?
                     WHERE channel = ?
                       AND user IN ({})",
                    msg.receivers.iter().map(|_| "?").join(",")
                );

                let mut query = sqlx::query(&query).bind(idx).bind(channel);
                for receiver in msg.receivers {
                    query = query.bind(receiver.0);
                }

                query.execute(&conn).await.unwrap();
            }

            query.execute(&conn).await.unwrap();
        })
    }
}
@@ -180,13 +182,13 @@ impl Handler<FetchUnseenMessages> for Persistence {
                            SELECT last_seen_message_idx
                            FROM channel_users
                            WHERE channel = (SELECT id FROM channel)
                              AND user = (SELECT id FROM users WHERE username = ?)
                              AND user = ?
                        )
                    )
                 ORDER BY idx DESC",
                 ORDER BY idx ASC",
            )
            .bind(msg.channel_name.to_string())
            .bind(msg.username.to_string())
            .bind(msg.user_id.0)
            .fetch_all(&conn)
            .await
            .unwrap();
diff --git a/src/persistence/events.rs b/src/persistence/events.rs
index 8e69bb2..e064511 100644
--- a/src/persistence/events.rs
+++ b/src/persistence/events.rs
@@ -1,6 +1,8 @@
use actix::Message;
use tracing::Span;

use crate::connection::UserId;

#[derive(Message)]
#[rtype(result = "()")]
pub struct ChannelCreated {
@@ -11,7 +13,7 @@ pub struct ChannelCreated {
#[rtype(result = "()")]
pub struct ChannelJoined {
    pub channel_name: String,
    pub username: String,
    pub user_id: UserId,
    pub span: Span,
}

@@ -19,14 +21,14 @@ pub struct ChannelJoined {
#[rtype(result = "()")]
pub struct ChannelParted {
    pub channel_name: String,
    pub username: String,
    pub user_id: UserId,
    pub span: Span,
}

#[derive(Message)]
#[rtype(result = "Vec<String>")]
pub struct FetchUserChannels {
    pub username: String,
    pub user_id: UserId,
    pub span: Span,
}

@@ -36,13 +38,13 @@ pub struct ChannelMessage {
    pub channel_name: String,
    pub sender: String,
    pub message: String,
    pub receivers: Vec<String>,
    pub receivers: Vec<UserId>,
}

#[derive(Message)]
#[rtype(result = "Vec<(String, String)>")]
pub struct FetchUnseenMessages {
    pub channel_name: String,
    pub username: String,
    pub user_id: UserId,
    pub span: Span,
}