use std::{
io::{Error, ErrorKind},
str::FromStr,
};
use actix::{Actor, ActorContext, Context, Handler, Message, ResponseFuture};
use argon2::PasswordHash;
use base64::{prelude::BASE64_STANDARD, Engine};
use futures::TryFutureExt;
use irc_proto::Command;
use crate::{
connection::{
sasl::{AuthStrategy, SaslAborted, SaslFail, SaslStrategyUnsupported},
UserId,
},
database::verify_password,
};
/// Actor that drives a single SASL `AUTHENTICATE` exchange for one connection.
pub struct Authenticate {
    /// Mechanism chosen by the client's first `AUTHENTICATE` payload; `None` until a
    /// supported mechanism has been named.
    pub selected_strategy: Option<AuthStrategy>,
    /// Connection pool handed (cloned) to the mechanism implementation to look up or
    /// create the account being authenticated.
    pub database: sqlx::Pool<sqlx::Any>,
}
// Runs on a plain actix `Context`; the actor can be stopped via `ctx.stop()` when the
// client aborts the exchange.
impl Actor for Authenticate {
    type Context = Context<Self>;
}
impl Handler<AuthenticateMessage> for Authenticate {
    type Result = ResponseFuture<Result<AuthenticateResult, std::io::Error>>;

    /// Advances the SASL exchange by one `AUTHENTICATE` payload.
    ///
    /// * `*` aborts the exchange at any point: the actor is stopped and an abort reply is
    ///   returned.
    /// * With no mechanism selected, the payload is parsed as a mechanism name. A supported
    ///   mechanism is remembered and answered with `AUTHENTICATE +`; anything else gets an
    ///   "unsupported mechanism" reply.
    /// * Otherwise the payload is handed to the selected mechanism. The selection is
    ///   cleared before doing so, because PLAIN completes in a single step. A client whose
    ///   attempt failed can therefore retry by naming a mechanism again, instead of having
    ///   the mechanism name decoded as credentials.
    ///
    /// # Errors
    ///
    /// The returned future resolves to an error when the mechanism rejects the payload as
    /// malformed or the credential lookup fails (see [`handle_plain_authentication`]).
    fn handle(&mut self, msg: AuthenticateMessage, ctx: &mut Self::Context) -> Self::Result {
        // Abort is honoured regardless of how far the exchange has progressed.
        if msg.0 == "*" {
            ctx.stop();
            return Box::pin(futures::future::ok(AuthenticateResult::Reply(Box::new(
                SaslAborted::into_message(),
            ))));
        }

        // `take()` resets the state so the next payload after this one starts a new
        // exchange (single-step mechanisms only; a multi-step mechanism would need to
        // restore the selection until it finishes).
        let Some(selected_strategy) = self.selected_strategy.take() else {
            let message = match AuthStrategy::from_str(&msg.0) {
                Ok(strategy) => {
                    self.selected_strategy = Some(strategy);
                    irc_proto::Message {
                        tags: None,
                        prefix: None,
                        command: Command::AUTHENTICATE("+".to_string()),
                    }
                }
                Err(_) => SaslStrategyUnsupported::into_message(),
            };
            return Box::pin(futures::future::ok(AuthenticateResult::Reply(Box::new(
                message,
            ))));
        };

        match selected_strategy {
            AuthStrategy::Plain => Box::pin(
                handle_plain_authentication(msg.0, self.database.clone()).map_ok(|v| {
                    v.map_or_else(
                        || AuthenticateResult::Reply(Box::new(SaslFail::into_message())),
                        |(username, user_id)| AuthenticateResult::Done(username, user_id),
                    )
                }),
            ),
        }
    }
}
/// Verifies a SASL `PLAIN` payload (RFC 4616) against the database.
///
/// `arguments` is the base64-encoded `[authzid] NUL authcid NUL passwd` message as
/// received in an `AUTHENTICATE` command. An empty authzid means "act as the authcid".
/// A non-empty authzid must equal the authcid, because acting as another identity is
/// not supported.
///
/// Returns `Ok(Some((username, user_id)))` when the password verifies against the stored
/// hash, and `Ok(None)` when the password is wrong.
///
/// # Errors
///
/// Returns [`ErrorKind::InvalidData`] in these cases:
///
/// * the payload is not valid base64;
/// * the payload is missing a field, or the authcid or password is empty;
/// * the authzid differs from the authcid;
/// * the authcid is not UTF-8;
/// * the stored hash cannot be parsed or verified.
///
/// Returns [`ErrorKind::Other`] if the database lookup fails.
pub async fn handle_plain_authentication(
    arguments: String,
    database: sqlx::Pool<sqlx::Any>,
) -> Result<Option<(String, UserId)>, Error> {
    let arguments = BASE64_STANDARD
        .decode(&arguments)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e))?;

    let mut message = arguments.splitn(3, |f| *f == b'\0');
    let (Some(authorization_identity), Some(authentication_identity), Some(password)) =
        (message.next(), message.next(), message.next())
    else {
        return Err(Error::new(ErrorKind::InvalidData, "bad plain message"));
    };

    // RFC 4616 grammar: authcid and passwd are `1*SAFE`, i.e. never empty. Rejecting
    // them here also prevents creating an account with an empty name.
    if authentication_identity.is_empty() || password.is_empty() {
        return Err(Error::new(ErrorKind::InvalidData, "bad plain message"));
    }

    // RFC 4616: an empty authzid means the client wants to act as the authcid.
    if !authorization_identity.is_empty() && authorization_identity != authentication_identity
    {
        return Err(Error::new(ErrorKind::InvalidData, "identity mismatch"));
    }

    let username = std::str::from_utf8(authentication_identity)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e))?;

    // Propagate database failures instead of panicking the actor.
    let (user_id, password_hash) =
        crate::database::create_user_or_fetch_password_hash(&database, username, password)
            .await
            .map_err(|e| Error::new(ErrorKind::Other, format!("{e:?}")))?;

    // A malformed stored hash is a data problem, not a reason to crash.
    let password_hash = PasswordHash::new(&password_hash)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e.to_string()))?;

    match verify_password(password, &password_hash) {
        Ok(()) => Ok(Some((username.to_string(), UserId(user_id)))),
        Err(argon2::password_hash::Error::Password) => Ok(None),
        Err(e) => Err(Error::new(ErrorKind::InvalidData, e.to_string())),
    }
}
/// Outcome of handling one `AUTHENTICATE` payload.
pub enum AuthenticateResult {
    /// A protocol message produced by the exchange (mechanism acknowledgement, failure,
    /// abort, or unsupported-mechanism reply); presumably sent back to the client by the
    /// caller — confirm at the call site.
    Reply(Box<irc_proto::Message>),
    /// Authentication succeeded: the authenticated username and its user id.
    Done(String, UserId),
}
/// One `AUTHENTICATE` parameter received from the client.
///
/// Depending on the exchange state, the handler treats it as a mechanism name or as a
/// base64 payload. The literal `*` is treated as an abort request.
#[derive(Message)]
#[rtype(result = "Result<AuthenticateResult, std::io::Error>")]
pub struct AuthenticateMessage(pub String);