Allow SASL authentication before the registration process has finished
Diff
src/client.rs | 2 +-
src/connection.rs | 365 ++++++++++++++++++++++++++++++++++++++++++--------------------------------------
src/connection/authenticate.rs | 136 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
src/connection/sasl.rs | 144 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
4 files changed, 336 insertions(+), 311 deletions(-)
@@ -12,7 +12,7 @@
use crate::{
channel::Channel,
connection::{InitiatedConnection, MessageSink, SaslAlreadyAuthenticated},
connection::{sasl::SaslAlreadyAuthenticated, InitiatedConnection, MessageSink},
messages::{
Broadcast, ChannelFetchTopic, ChannelInvite, ChannelJoin, ChannelKickUser, ChannelList,
ChannelMemberList, ChannelMessage, ChannelPart, ChannelSetMode, ChannelUpdateTopic,
@@ -1,17 +1,15 @@
mod authenticate;
pub mod sasl;
use std::{
io::{Error, ErrorKind},
net::SocketAddr,
str::FromStr,
};
use actix::{io::FramedWrite, Addr};
use argon2::PasswordHash;
use base64::{prelude::BASE64_STANDARD, Engine};
use actix::{io::FramedWrite, Actor, Addr};
use const_format::concatcp;
use futures::{SinkExt, TryStreamExt};
use irc_proto::{
error::ProtocolError, CapSubCommand, Command, IrcCodec, Message, Prefix, Response,
};
use irc_proto::{error::ProtocolError, CapSubCommand, Command, IrcCodec, Message, Prefix};
use tokio::{
io::{ReadHalf, WriteHalf},
net::TcpStream,
@@ -20,7 +18,10 @@
use tracing::{instrument, warn};
use crate::{
database::verify_password,
connection::{
authenticate::{Authenticate, AuthenticateMessage, AuthenticateResult},
sasl::{AuthStrategy, ConnectionSuccess, SaslSuccess},
},
persistence::{events::ReserveNick, Persistence},
};
@@ -39,6 +40,7 @@
user: Option<String>,
mode: Option<String>,
real_name: Option<String>,
user_id: Option<UserId>,
}
#[derive(Clone)]
@@ -72,6 +74,7 @@
user: Some(user),
mode: Some(mode),
real_name: Some(real_name),
user_id: Some(user_id),
} = value else {
return Err(value);
};
@@ -82,7 +85,7 @@
user,
mode,
real_name,
user_id: UserId(0),
user_id,
})
}
}
@@ -102,7 +105,11 @@
..ConnectionRequest::default()
};
let mut capabilities_requested = false;
let authenticate_handle = Authenticate {
selected_strategy: None,
database: database.clone(),
}
.start();
@@ -116,14 +123,12 @@
match msg.command {
Command::PASS(_) => {}
Command::NICK(nick) => request.nick = Some(nick),
Command::USER(user, mode, real_name) => {
request.user = Some(user);
Command::USER(_user, mode, real_name) => {
request.mode = Some(mode);
request.real_name = Some(real_name);
}
Command::CAP(_, CapSubCommand::LIST | CapSubCommand::LS, _, _) => {
capabilities_requested = true;
write
.send(Message {
tags: None,
@@ -137,6 +142,27 @@
})
.await
.unwrap();
}
Command::CAP(_, CapSubCommand::REQ, Some(arguments), None) => {
write
.send(AcknowledgedCapabilities(arguments).into_message())
.await?;
}
Command::AUTHENTICATE(msg) => {
match authenticate_handle
.send(AuthenticateMessage(msg))
.await
.unwrap()?
{
AuthenticateResult::Reply(v) => {
write.send(*v).await?;
}
AuthenticateResult::Done(username, user_id) => {
request.user = Some(username);
request.user_id = Some(user_id);
write.send(SaslSuccess::into_message()).await?;
}
}
}
_ => {
warn!(?msg, "Client sent unknown command during negotiation");
@@ -154,178 +180,30 @@
let Some(mut initiated) = initiated else {
return Ok(None);
};
if !capabilities_requested {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
"capabilities not requested by client, so SASL authentication can not be performed",
)));
}
let mut user_id = None;
while let Some(msg) = s.try_next().await? {
match msg.command {
Command::CAP(_, CapSubCommand::REQ, Some(arguments), None) => {
write
.send(AcknowledgedCapabilities(arguments).into_message())
.await?;
}
Command::CAP(_, CapSubCommand::END, _, _) => {
break;
}
Command::AUTHENTICATE(strategy) => {
user_id =
start_authenticate_flow(s, write, &initiated, strategy, &database).await?;
}
_ => {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
format!("client sent non-cap message during negotiation {msg:?}"),
)))
}
}
}
if let Some(user_id) = user_id {
initiated.user_id.0 = user_id;
let reserved_nick = persistence
.send(ReserveNick {
user_id: initiated.user_id,
nick: initiated.nick.clone(),
})
.await
.map_err(|e| ProtocolError::Io(Error::new(ErrorKind::InvalidData, e)))?;
if !reserved_nick {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
"nick is already in use by another user",
)));
}
Ok(Some(initiated))
} else {
Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
"user has not authenticated",
)))
}
}
async fn start_authenticate_flow(
s: &mut MessageStream,
write: &mut tokio_util::codec::FramedWrite<WriteHalf<TcpStream>, IrcCodec>,
connection: &InitiatedConnection,
strategy: String,
database: &sqlx::Pool<sqlx::Any>,
) -> Result<Option<i64>, ProtocolError> {
let Ok(auth_strategy) = AuthStrategy::from_str(&strategy) else {
write.send(SaslStrategyUnsupported(connection.nick.to_string()).into_message())
.await?;
let Some(initiated) = initiated else {
return Ok(None);
};
write
.send(Message {
tags: None,
prefix: None,
command: Command::AUTHENTICATE("+".to_string()),
})
.send(ConnectionSuccess(initiated.clone()).into_message())
.await?;
while let Some(msg) = s.try_next().await? {
let Command::AUTHENTICATE(arguments) = msg.command else {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
format!("client sent invalid message during authentication {msg:?}"),
)));
};
if arguments == "*" {
write
.send(SaslAborted(connection.nick.to_string()).into_message())
.await?;
break;
}
let user_id = match auth_strategy {
AuthStrategy::Plain => {
handle_plain_authentication(&arguments, connection, database).await?
}
};
if user_id.is_some() {
for message in SaslSuccess(connection.clone()).into_messages() {
write.send(message).await?;
}
return Ok(user_id);
}
write
.send(SaslFail(connection.nick.to_string()).into_message())
.await?;
}
Ok(None)
}
pub async fn handle_plain_authentication(
arguments: &str,
connection: &InitiatedConnection,
database: &sqlx::Pool<sqlx::Any>,
) -> Result<Option<i64>, Error> {
let arguments = BASE64_STANDARD
.decode(arguments)
.map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
let mut message = arguments.splitn(3, |f| *f == b'\0');
let (Some(authorization_identity), Some(authentication_identity), Some(password)) = (message.next(), message.next(), message.next()) else {
return Err(Error::new(ErrorKind::InvalidData, "bad plain message"));
};
let reserved_nick = persistence
.send(ReserveNick {
user_id: initiated.user_id,
nick: initiated.nick.clone(),
})
.await
.map_err(|e| ProtocolError::Io(Error::new(ErrorKind::InvalidData, e)))?;
if authorization_identity != connection.user.as_bytes()
|| authentication_identity != connection.user.as_bytes()
{
return Err(Error::new(ErrorKind::InvalidData, "username mismatch"));
if !reserved_nick {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
"nick is already in use by another user",
)));
}
let (user_id, password_hash) =
crate::database::create_user_or_fetch_password_hash(database, &connection.user, password)
.await
.unwrap();
let password_hash = PasswordHash::new(&password_hash).unwrap();
match verify_password(password, &password_hash) {
Ok(()) => Ok(Some(user_id)),
Err(argon2::password_hash::Error::Password) => Ok(None),
Err(e) => Err(Error::new(ErrorKind::InvalidData, e.to_string())),
}
Ok(Some(initiated))
}
@@ -342,139 +220,6 @@
CapSubCommand::ACK,
None,
Some(self.0),
),
}
}
}
#[derive(Copy, Clone, Debug)]
pub enum AuthStrategy {
Plain,
}
impl AuthStrategy {
pub const SUPPORTED: &'static str = "PLAIN";
}
impl FromStr for AuthStrategy {
type Err = std::io::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"PLAIN" => Ok(Self::Plain),
_ => Err(Error::new(ErrorKind::InvalidData, "unknown auth strategy")),
}
}
}
pub struct SaslAlreadyAuthenticated(pub String);
impl SaslAlreadyAuthenticated {
#[must_use]
pub fn into_message(self) -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::ERR_SASLALREADY,
vec![
self.0,
"You have already authenticated using SASL".to_string(),
],
),
}
}
}
pub struct SaslStrategyUnsupported(String);
impl SaslStrategyUnsupported {
#[must_use]
pub fn into_message(self) -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::RPL_SASLMECHS,
vec![
self.0,
AuthStrategy::SUPPORTED.to_string(),
"are available SASL mechanisms".to_string(),
],
),
}
}
}
pub struct SaslSuccess(InitiatedConnection);
impl SaslSuccess {
#[must_use]
pub fn into_messages(self) -> [Message; 2] {
[
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::RPL_SASLSUCCESS,
vec![
self.0.nick.to_string(),
"SASL authentication successful".to_string(),
],
),
},
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::RPL_LOGGEDIN,
vec![
self.0.nick.to_string(),
self.0.to_nick().to_string(),
self.0.user.to_string(),
format!("You are now logged in as {}", self.0.user),
],
),
},
]
}
}
pub struct SaslFail(String);
impl SaslFail {
#[must_use]
pub fn into_message(self) -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::ERR_SASLFAIL,
vec![self.0, "SASL authentication failed".to_string()],
),
}
}
}
pub struct SaslAborted(String);
impl SaslAborted {
#[must_use]
pub fn into_message(self) -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::ERR_SASLABORT,
vec![self.0, "SASL authentication aborted".to_string()],
),
}
}
@@ -1,0 +1,136 @@
use std::{
io::{Error, ErrorKind},
str::FromStr,
};
use actix::{Actor, ActorContext, Context, Handler, Message, ResponseFuture};
use argon2::PasswordHash;
use base64::{prelude::BASE64_STANDARD, Engine};
use futures::TryFutureExt;
use irc_proto::Command;
use crate::{
connection::{
sasl::{AuthStrategy, SaslAborted, SaslFail, SaslStrategyUnsupported},
UserId,
},
database::verify_password,
};
pub struct Authenticate {
pub selected_strategy: Option<AuthStrategy>,
pub database: sqlx::Pool<sqlx::Any>,
}
impl Actor for Authenticate {
type Context = Context<Self>;
}
impl Handler<AuthenticateMessage> for Authenticate {
type Result = ResponseFuture<Result<AuthenticateResult, std::io::Error>>;
fn handle(&mut self, msg: AuthenticateMessage, ctx: &mut Self::Context) -> Self::Result {
let Some(selected_strategy) = self.selected_strategy else {
let message = match AuthStrategy::from_str(&msg.0) {
Ok(strategy) => {
self.selected_strategy = Some(strategy);
irc_proto::Message {
tags: None,
prefix: None,
command: Command::AUTHENTICATE("+".to_string()),
}
}
Err(_) => SaslStrategyUnsupported::into_message(),
};
return Box::pin(futures::future::ok(AuthenticateResult::Reply(Box::new(message))));
};
if msg.0 == "*" {
ctx.stop();
return Box::pin(futures::future::ok(AuthenticateResult::Reply(Box::new(
SaslAborted::into_message(),
))));
}
match selected_strategy {
AuthStrategy::Plain => Box::pin(
handle_plain_authentication(msg.0, self.database.clone()).map_ok(|v| {
v.map_or_else(
|| AuthenticateResult::Reply(Box::new(SaslFail::into_message())),
|(username, user_id)| AuthenticateResult::Done(username, user_id),
)
}),
),
}
}
}
pub async fn handle_plain_authentication(
arguments: String,
database: sqlx::Pool<sqlx::Any>,
) -> Result<Option<(String, UserId)>, Error> {
let arguments = BASE64_STANDARD
.decode(&arguments)
.map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
let mut message = arguments.splitn(3, |f| *f == b'\0');
let (Some(authorization_identity), Some(authentication_identity), Some(password)) = (message.next(), message.next(), message.next()) else {
return Err(Error::new(ErrorKind::InvalidData, "bad plain message"));
};
if authorization_identity != authentication_identity {
return Err(Error::new(ErrorKind::InvalidData, "identity mismatch"));
}
let authorization_identity = std::str::from_utf8(authentication_identity)
.map_err(|e| Error::new(ErrorKind::InvalidData, e))?;
let (user_id, password_hash) = crate::database::create_user_or_fetch_password_hash(
&database,
authorization_identity,
password,
)
.await
.unwrap();
let password_hash = PasswordHash::new(&password_hash).unwrap();
match verify_password(password, &password_hash) {
Ok(()) => Ok(Some((authorization_identity.to_string(), UserId(user_id)))),
Err(argon2::password_hash::Error::Password) => Ok(None),
Err(e) => Err(Error::new(ErrorKind::InvalidData, e.to_string())),
}
}
pub enum AuthenticateResult {
Reply(Box<irc_proto::Message>),
Done(String, UserId),
}
#[derive(Message)]
#[rtype(result = "Result<AuthenticateResult, std::io::Error>")]
pub struct AuthenticateMessage(pub String);
@@ -1,0 +1,144 @@
use std::{
io::{Error, ErrorKind},
str::FromStr,
};
use irc_proto::{Command, Message, Response};
use crate::connection::InitiatedConnection;
#[derive(Copy, Clone, Debug)]
pub enum AuthStrategy {
Plain,
}
impl AuthStrategy {
pub const SUPPORTED: &'static str = "PLAIN";
}
impl FromStr for AuthStrategy {
type Err = std::io::Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s {
"PLAIN" => Ok(Self::Plain),
_ => Err(Error::new(ErrorKind::InvalidData, "unknown auth strategy")),
}
}
}
pub struct SaslAlreadyAuthenticated(pub String);
impl SaslAlreadyAuthenticated {
#[must_use]
pub fn into_message(self) -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::ERR_SASLALREADY,
vec![
self.0,
"You have already authenticated using SASL".to_string(),
],
),
}
}
}
pub struct SaslStrategyUnsupported;
impl SaslStrategyUnsupported {
#[must_use]
pub fn into_message() -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::RPL_SASLMECHS,
vec![
AuthStrategy::SUPPORTED.to_string(),
"are available SASL mechanisms".to_string(),
],
),
}
}
}
pub struct SaslSuccess;
impl SaslSuccess {
#[must_use]
pub fn into_message() -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::RPL_SASLSUCCESS,
vec!["SASL authentication successful".to_string()],
),
}
}
}
pub struct ConnectionSuccess(pub InitiatedConnection);
impl ConnectionSuccess {
#[must_use]
pub fn into_message(self) -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::RPL_LOGGEDIN,
vec![
self.0.nick.to_string(),
self.0.to_nick().to_string(),
self.0.user.to_string(),
format!("You are now logged in as {}", self.0.user),
],
),
}
}
}
pub struct SaslFail;
impl SaslFail {
#[must_use]
pub fn into_message() -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::ERR_SASLFAIL,
vec!["SASL authentication failed".to_string()],
),
}
}
}
pub struct SaslAborted;
impl SaslAborted {
#[must_use]
pub fn into_message() -> Message {
Message {
tags: None,
prefix: None,
command: Command::Response(
Response::ERR_SASLABORT,
vec!["SASL authentication aborted".to_string()],
),
}
}
}