Give user access to its user id and modify queries to use it
Diff
src/channel.rs | 6 +++---
src/client.rs | 17 +++++++----------
src/connection.rs | 55 ++++++++++++++++++++++++++++++++++---------------------
src/persistence.rs | 44 ++++++++++++++++++++++++++++----------------
src/database/mod.rs | 56 +++++++++++++++++++++++---------------------------------
src/persistence/events.rs | 12 +++++++-----
6 files changed, 85 insertions(+), 105 deletions(-)
@@ -90,7 +90,7 @@
channel_name: self.name.to_string(),
sender: nick.to_string(),
message: msg.message.to_string(),
receivers: self.clients.values().map(|v| v.user.to_string()).collect(),
receivers: self.clients.values().map(|v| v.user_id).collect(),
});
for client in self.clients.keys() {
@@ -144,7 +144,7 @@
self.persistence
.do_send(crate::persistence::events::ChannelJoined {
channel_name: self.name.to_string(),
username: msg.connection.user.to_string(),
user_id: msg.connection.user_id,
span: msg.span.clone(),
});
@@ -282,7 +282,7 @@
self.persistence
.do_send(crate::persistence::events::ChannelParted {
channel_name: self.name.to_string(),
username: client_info.user.to_string(),
user_id: client_info.user_id,
span: msg.span.clone(),
});
@@ -81,7 +81,7 @@
ctx.spawn(
self.persistence
.send(FetchUserChannels {
username: self.connection.user.to_string(),
user_id: self.connection.user_id,
span: Span::current(),
})
.into_actor(self)
@@ -185,7 +185,7 @@
let channel_messages_fut = self.persistence.send(FetchUnseenMessages {
channel_name: channel_name.to_string(),
username: self.connection.user.to_string(),
user_id: self.connection.user_id,
span: Span::current(),
});
@@ -203,18 +203,15 @@
let fut = wrap_future::<_, Self>(
futures::future::join_all(futures.into_iter()).instrument(Span::current()),
)
.map(|result, this, ctx| {
.map(|result, this, _ctx| {
for (channel_name, handle, messages) in result {
this.channels.insert(channel_name.clone(), handle);
for (source, message) in messages {
ctx.notify(Broadcast {
message: Message {
tags: None,
prefix: Some(Prefix::new_from_str(&source)),
command: Command::PRIVMSG(channel_name.clone(), message),
},
span: this.span.clone(),
this.writer.write(Message {
tags: None,
prefix: Some(Prefix::new_from_str(&source)),
command: Command::PRIVMSG(channel_name.clone(), message),
});
}
}
@@ -25,6 +25,9 @@
pub const SUPPORTED_CAPABILITIES: &[&str] = &[concatcp!("sasl=", AuthStrategy::SUPPORTED)];
#[derive(Copy, Clone, Debug)]
pub struct UserId(pub i64);
#[derive(Default)]
pub struct ConnectionRequest {
nick: Option<String>,
@@ -39,6 +42,7 @@
pub user: String,
pub mode: String,
pub real_name: String,
pub user_id: UserId,
}
impl InitiatedConnection {
@@ -70,6 +74,7 @@
user,
mode,
real_name,
user_id: UserId(0),
})
}
}
@@ -136,7 +141,7 @@
let Some(initiated) = initiated else {
let Some(mut initiated) = initiated else {
return Ok(None);
};
@@ -147,7 +152,7 @@
)));
}
let mut has_authenticated = false;
let mut user_id = None;
while let Some(msg) = s.try_next().await? {
@@ -161,7 +166,7 @@
break;
}
Command::AUTHENTICATE(strategy) => {
has_authenticated =
user_id =
start_authenticate_flow(s, write, &initiated, strategy, &database).await?;
}
_ => {
@@ -172,8 +177,10 @@
}
}
}
if let Some(user_id) = user_id {
initiated.user_id.0 = user_id;
if has_authenticated {
Ok(Some(initiated))
} else {
Err(ProtocolError::Io(Error::new(
@@ -193,11 +200,11 @@
connection: &InitiatedConnection,
strategy: String,
database: &sqlx::Pool<sqlx::Any>,
) -> Result<bool, ProtocolError> {
) -> Result<Option<i64>, ProtocolError> {
let Ok(auth_strategy) = AuthStrategy::from_str(&strategy) else {
write.send(SaslStrategyUnsupported(connection.nick.to_string()).into_message())
.await?;
return Ok(false);
return Ok(None);
};
@@ -226,7 +233,7 @@
break;
}
let authenticated = match auth_strategy {
let user_id = match auth_strategy {
AuthStrategy::Plain => {
@@ -234,12 +241,12 @@
}
};
if authenticated {
if user_id.is_some() {
for message in SaslSuccess(connection.clone()).into_messages() {
write.send(message).await?;
}
return Ok(true);
return Ok(user_id);
}
write
@@ -247,18 +254,20 @@
.await?;
}
Ok(false)
Ok(None)
}
pub async fn handle_plain_authentication(
arguments: &str,
connection: &InitiatedConnection,
database: &sqlx::Pool<sqlx::Any>,
) -> Result<bool, Error> {
) -> Result<Option<i64>, Error> {
let arguments = BASE64_STANDARD
.decode(arguments)
.map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
@@ -277,26 +286,16 @@
}
let password_hash = crate::database::fetch_password_hash(database, &connection.user)
.await
.unwrap();
let password_hash = password_hash
.as_deref()
.map(PasswordHash::new)
.transpose()
.unwrap();
let Some(password_hash) = password_hash else {
crate::database::create_user(database, &connection.user, password).await.unwrap();
return Ok(true);
};
let (user_id, password_hash) =
crate::database::create_user_or_fetch_password_hash(database, &connection.user, password)
.await
.unwrap();
let password_hash = PasswordHash::new(&password_hash).unwrap();
match verify_password(password, &password_hash) {
Ok(()) => Ok(true),
Err(argon2::password_hash::Error::Password) => Ok(false),
Ok(()) => Ok(Some(user_id)),
Err(argon2::password_hash::Error::Password) => Ok(None),
Err(e) => Err(Error::new(ErrorKind::InvalidData, e.to_string())),
}
}
@@ -48,11 +48,11 @@
Box::pin(async move {
sqlx::query(
"INSERT INTO channel_users (channel, user, permissions, in_channel)
VALUES ((SELECT id FROM channels WHERE name = ?), (SELECT id FROM users WHERE username = ?), ?, ?)
ON CONFLICT(channel, user) DO UPDATE SET in_channel = excluded.in_channel"
VALUES ((SELECT id FROM channels WHERE name = ?), ?, ?, ?)
ON CONFLICT(channel, user) DO UPDATE SET in_channel = excluded.in_channel",
)
.bind(msg.channel_name)
.bind(msg.username)
.bind(msg.user_id.0)
.bind(0i32)
.bind(true)
.execute(&conn)
@@ -75,10 +75,10 @@
"UPDATE channel_users
SET in_channel = false
WHERE channel = (SELECT id FROM channels WHERE name = ?)
AND user = (SELECT id FROM users WHERE username = ?)",
AND user = ?",
)
.bind(msg.channel_name)
.bind(msg.username)
.bind(msg.user_id.0)
.execute(&conn)
.await
.unwrap();
@@ -101,7 +101,7 @@
WHERE user = (SELECT id FROM users WHERE username = ?)
AND in_channel = true",
)
.bind(msg.username)
.bind(msg.user_id.0)
.fetch_all(&conn)
.await
.unwrap()
@@ -137,20 +137,22 @@
.await
.unwrap();
let query = format!(
"UPDATE channel_users
SET last_seen_message_idx = ?
WHERE channel = ?
AND user IN (SELECT id FROM users WHERE username IN ({}))",
msg.receivers.iter().map(|_| "?").join(",")
);
if !msg.receivers.is_empty() {
let query = format!(
"UPDATE channel_users
SET last_seen_message_idx = ?
WHERE channel = ?
AND user IN ({})",
msg.receivers.iter().map(|_| "?").join(",")
);
let mut query = sqlx::query(&query).bind(idx).bind(channel);
for receiver in msg.receivers {
query = query.bind(receiver);
}
let mut query = sqlx::query(&query).bind(idx).bind(channel);
for receiver in msg.receivers {
query = query.bind(receiver.0);
}
query.execute(&conn).await.unwrap();
query.execute(&conn).await.unwrap();
}
})
}
}
@@ -180,13 +182,13 @@
SELECT last_seen_message_idx
FROM channel_users
WHERE channel = (SELECT id FROM channel)
AND user = (SELECT id FROM users WHERE username = ?)
AND user = ?
)
)
ORDER BY idx DESC",
ORDER BY idx ASC",
)
.bind(msg.channel_name.to_string())
.bind(msg.username.to_string())
.bind(msg.user_id.0)
.fetch_all(&conn)
.await
.unwrap();
@@ -1,49 +1,29 @@
use argon2::{password_hash::SaltString, Argon2, PasswordHash, PasswordHasher, PasswordVerifier};
use rand::rngs::OsRng;
use sqlx::{database::HasArguments, Database, Encode, Executor, FromRow, IntoArguments, Type};
pub async fn fetch_password_hash<'a, E: Executor<'a>>(
conn: E,
username: &'a str,
) -> Result<Option<String>, sqlx::Error>
where
for<'b> &'b str: Type<E::Database> + Encode<'b, E::Database>,
<E::Database as HasArguments<'a>>::Arguments: IntoArguments<'a, E::Database>,
for<'b> (String,): FromRow<'b, <E::Database as Database>::Row>,
{
let res = sqlx::query_as("SELECT password FROM users WHERE username = ?")
.bind(username)
.fetch_optional(conn)
.await?
.map(|(v,)| v);
Ok(res)
}
pub async fn create_user<'a, E: Executor<'a>>(
conn: E,
username: &'a str,
pub async fn create_user_or_fetch_password_hash(
conn: &sqlx::Pool<sqlx::Any>,
username: &str,
password: &[u8],
) -> Result<(), sqlx::Error>
where
for<'b> &'b str: Type<E::Database> + Encode<'b, E::Database>,
for<'b> String: Type<E::Database> + Encode<'b, E::Database>,
<E::Database as HasArguments<'a>>::Arguments: IntoArguments<'a, E::Database>,
{
let salt = SaltString::generate(&mut OsRng);
) -> Result<(i64, String), sqlx::Error> {
let password_hash = Argon2::default()
.hash_password(password, &salt)
.hash_password(password, &SaltString::generate(&mut OsRng))
.unwrap()
.to_string();
sqlx::query("INSERT INTO users (username, password) VALUES (?, ?)")
.bind(username)
.bind(password_hash)
.execute(conn)
.await
.map(|_| ())
sqlx::query_as(
"INSERT INTO users (username, password)
VALUES (?, ?)
ON CONFLICT(username) DO UPDATE SET username = username
RETURNING id, password",
)
.bind(username)
.bind(password_hash)
.fetch_one(conn)
.await
}
@@ -1,6 +1,8 @@
use actix::Message;
use tracing::Span;
use crate::connection::UserId;
#[derive(Message)]
#[rtype(result = "()")]
pub struct ChannelCreated {
@@ -11,7 +13,7 @@
#[rtype(result = "()")]
pub struct ChannelJoined {
pub channel_name: String,
pub username: String,
pub user_id: UserId,
pub span: Span,
}
@@ -19,14 +21,14 @@
#[rtype(result = "()")]
pub struct ChannelParted {
pub channel_name: String,
pub username: String,
pub user_id: UserId,
pub span: Span,
}
#[derive(Message)]
#[rtype(result = "Vec<String>")]
pub struct FetchUserChannels {
pub username: String,
pub user_id: UserId,
pub span: Span,
}
@@ -36,13 +38,13 @@
pub channel_name: String,
pub sender: String,
pub message: String,
pub receivers: Vec<String>,
pub receivers: Vec<UserId>,
}
#[derive(Message)]
#[rtype(result = "Vec<(String, String)>")]
pub struct FetchUnseenMessages {
pub channel_name: String,
pub username: String,
pub user_id: UserId,
pub span: Span,
}