use std::{
io::{Error, ErrorKind},
net::SocketAddr,
str::FromStr,
};
use actix::{io::FramedWrite, Addr};
use argon2::PasswordHash;
use base64::{prelude::BASE64_STANDARD, Engine};
use const_format::concatcp;
use futures::{SinkExt, TryStreamExt};
use irc_proto::{
error::ProtocolError, CapSubCommand, Command, IrcCodec, Message, Prefix, Response,
};
use tokio::{
io::{ReadHalf, WriteHalf},
net::TcpStream,
};
use tokio_util::codec::FramedRead;
use tracing::{instrument, warn};
use crate::{
database::verify_password,
persistence::{events::ReserveNick, Persistence},
};
/// Inbound half of a client connection: a `FramedRead` over the TCP read half
/// using `irc_proto::IrcCodec`.
pub type MessageStream = FramedRead<ReadHalf<TcpStream>, irc_proto::IrcCodec>;
/// Outbound half of a client connection: an actix `FramedWrite` for [`Message`]
/// values over the TCP write half, using `irc_proto::IrcCodec`.
pub type MessageSink = FramedWrite<Message, WriteHalf<TcpStream>, irc_proto::IrcCodec>;
/// Capabilities advertised to clients in the `CAP LS` reply sent by
/// [`negotiate_client_connection`]. Currently a single entry: `sasl=` followed
/// by [`AuthStrategy::SUPPORTED`].
pub const SUPPORTED_CAPABILITIES: &[&str] = &[concatcp!("sasl=", AuthStrategy::SUPPORTED)];
/// Newtype around the `i64` identifier of a user.
///
/// Within this file it is filled in by [`negotiate_client_connection`] from the
/// id returned by SASL authentication; until then a connection carries
/// `UserId(0)`.
#[derive(Copy, Clone, Debug)]
pub struct UserId(pub i64);
/// Registration details collected so far from a connecting client.
///
/// Every field starts as `None` (via `Default`) and is filled in while
/// [`negotiate_client_connection`] reads the client's messages. Once all
/// fields are `Some`, the request converts into an [`InitiatedConnection`].
#[derive(Default)]
pub struct ConnectionRequest {
/// Remote socket address of the client.
host: Option<SocketAddr>,
/// Value of the client's `NICK` command.
nick: Option<String>,
/// First argument of the client's `USER` command.
user: Option<String>,
/// Second argument of the client's `USER` command.
mode: Option<String>,
/// Third argument of the client's `USER` command.
real_name: Option<String>,
}
/// A client connection for which every registration field is known.
///
/// Built from a complete [`ConnectionRequest`]; see the `TryFrom` impl.
#[derive(Clone)]
pub struct InitiatedConnection {
/// Remote socket address of the client.
pub host: SocketAddr,
/// Nickname supplied through `NICK`.
pub nick: String,
/// First argument of the client's `USER` command.
pub user: String,
/// Second argument of the client's `USER` command.
pub mode: String,
/// Third argument of the client's `USER` command.
pub real_name: String,
/// Identifier of the authenticated user. The `TryFrom<ConnectionRequest>`
/// conversion initialises it to `UserId(0)`; [`negotiate_client_connection`]
/// overwrites it after successful authentication.
pub user_id: UserId,
}
impl InitiatedConnection {
    /// Builds the `nick!user@host` style prefix for this connection.
    ///
    /// The host component is the textual IP address of the peer; the port of
    /// the socket address is not included.
    #[must_use]
    pub fn to_nick(&self) -> Prefix {
        let nickname = self.nick.clone();
        let username = self.user.clone();
        let hostname = self.host.ip().to_string();

        Prefix::Nickname(nickname, username, hostname)
    }
}
impl TryFrom<ConnectionRequest> for InitiatedConnection {
    /// An incomplete request is handed back unchanged so the caller can keep
    /// filling it in.
    type Error = ConnectionRequest;

    /// Converts a request into an [`InitiatedConnection`] once every field is
    /// populated.
    ///
    /// The resulting connection carries the placeholder `UserId(0)`.
    fn try_from(value: ConnectionRequest) -> Result<Self, Self::Error> {
        match value {
            // All five fields present: the request is complete.
            ConnectionRequest {
                host: Some(host),
                nick: Some(nick),
                user: Some(user),
                mode: Some(mode),
                real_name: Some(real_name),
            } => Ok(Self {
                host,
                nick,
                user,
                mode,
                real_name,
                user_id: UserId(0),
            }),
            // At least one field is still missing: give the request back.
            incomplete => Err(incomplete),
        }
    }
}
#[instrument(skip_all)]
pub async fn negotiate_client_connection(
s: &mut MessageStream,
write: &mut tokio_util::codec::FramedWrite<WriteHalf<TcpStream>, IrcCodec>,
host: SocketAddr,
persistence: &Addr<Persistence>,
database: sqlx::Pool<sqlx::Any>,
) -> Result<Option<InitiatedConnection>, ProtocolError> {
let mut request = ConnectionRequest {
host: Some(host),
..ConnectionRequest::default()
};
let mut capabilities_requested = false;
let initiated = loop {
let Some(msg) = s.try_next().await? else {
break None;
};
#[allow(clippy::match_same_arms)]
match msg.command {
Command::PASS(_) => {}
Command::NICK(nick) => request.nick = Some(nick),
Command::USER(user, mode, real_name) => {
request.user = Some(user);
request.mode = Some(mode);
request.real_name = Some(real_name);
}
Command::CAP(_, CapSubCommand::LIST | CapSubCommand::LS, _, _) => {
capabilities_requested = true;
write
.send(Message {
tags: None,
prefix: None,
command: Command::CAP(
Some("*".to_string()),
CapSubCommand::LS,
None,
Some(SUPPORTED_CAPABILITIES.join(" ")),
),
})
.await
.unwrap();
}
_ => {
warn!(?msg, "Client sent unknown command during negotiation");
}
};
match InitiatedConnection::try_from(std::mem::take(&mut request)) {
Ok(v) => break Some(v),
Err(v) => {
request = v;
}
}
};
let Some(mut initiated) = initiated else {
return Ok(None);
};
if !capabilities_requested {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
"capabilities not requested by client, so SASL authentication can not be performed",
)));
}
let mut user_id = None;
while let Some(msg) = s.try_next().await? {
match msg.command {
Command::CAP(_, CapSubCommand::REQ, Some(arguments), None) => {
write
.send(AcknowledgedCapabilities(arguments).into_message())
.await?;
}
Command::CAP(_, CapSubCommand::END, _, _) => {
break;
}
Command::AUTHENTICATE(strategy) => {
user_id =
start_authenticate_flow(s, write, &initiated, strategy, &database).await?;
}
_ => {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
format!("client sent non-cap message during negotiation {msg:?}"),
)))
}
}
}
if let Some(user_id) = user_id {
initiated.user_id.0 = user_id;
let reserved_nick = persistence
.send(ReserveNick {
user_id: initiated.user_id,
nick: initiated.nick.clone(),
})
.await
.map_err(|e| ProtocolError::Io(Error::new(ErrorKind::InvalidData, e)))?;
if !reserved_nick {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
"nick is already in use by another user",
)));
}
Ok(Some(initiated))
} else {
Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
"user has not authenticated",
)))
}
}
async fn start_authenticate_flow(
s: &mut MessageStream,
write: &mut tokio_util::codec::FramedWrite<WriteHalf<TcpStream>, IrcCodec>,
connection: &InitiatedConnection,
strategy: String,
database: &sqlx::Pool<sqlx::Any>,
) -> Result<Option<i64>, ProtocolError> {
let Ok(auth_strategy) = AuthStrategy::from_str(&strategy) else {
write.send(SaslStrategyUnsupported(connection.nick.to_string()).into_message())
.await?;
return Ok(None);
};
write
.send(Message {
tags: None,
prefix: None,
command: Command::AUTHENTICATE("+".to_string()),
})
.await?;
while let Some(msg) = s.try_next().await? {
let Command::AUTHENTICATE(arguments) = msg.command else {
return Err(ProtocolError::Io(Error::new(
ErrorKind::InvalidData,
format!("client sent invalid message during authentication {msg:?}"),
)));
};
if arguments == "*" {
write
.send(SaslAborted(connection.nick.to_string()).into_message())
.await?;
break;
}
let user_id = match auth_strategy {
AuthStrategy::Plain => {
handle_plain_authentication(&arguments, connection, database).await?
}
};
if user_id.is_some() {
for message in SaslSuccess(connection.clone()).into_messages() {
write.send(message).await?;
}
return Ok(user_id);
}
write
.send(SaslFail(connection.nick.to_string()).into_message())
.await?;
}
Ok(None)
}
/// Verifies a SASL `PLAIN` payload against the user database.
///
/// `arguments` is the base64 payload of the client's `AUTHENTICATE` message.
/// Once decoded it must consist of three NUL-separated fields: authorization
/// identity, authentication identity and password. Both identities must equal
/// `connection.user`.
///
/// # Returns
///
/// * `Ok(Some(user_id))` – the password verified against the stored hash.
/// * `Ok(None)` – the password did not match.
///
/// # Errors
///
/// Returns an [`Error`] when:
///
/// * `arguments` is not valid base64,
/// * the decoded payload has fewer than three NUL-separated fields,
/// * either identity differs from `connection.user`,
/// * the database lookup fails,
/// * the stored password hash cannot be parsed,
/// * password verification fails for any reason other than a mismatch.
pub async fn handle_plain_authentication(
    arguments: &str,
    connection: &InitiatedConnection,
    database: &sqlx::Pool<sqlx::Any>,
) -> Result<Option<i64>, Error> {
    let arguments = BASE64_STANDARD
        .decode(arguments)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e))?;

    // At most three pieces, so any NUL bytes after the second separator stay
    // part of the password.
    let mut message = arguments.splitn(3, |f| *f == b'\0');
    let (Some(authorization_identity), Some(authentication_identity), Some(password)) =
        (message.next(), message.next(), message.next())
    else {
        return Err(Error::new(ErrorKind::InvalidData, "bad plain message"));
    };

    // NOTE(review): an empty authorization identity is rejected here; confirm
    // whether clients that omit it need to be supported.
    if authorization_identity != connection.user.as_bytes()
        || authentication_identity != connection.user.as_bytes()
    {
        return Err(Error::new(ErrorKind::InvalidData, "username mismatch"));
    }

    // A database failure is an operational error, not a reason to panic the
    // connection task, so surface it to the caller.
    let (user_id, password_hash) =
        crate::database::create_user_or_fetch_password_hash(database, &connection.user, password)
            .await
            .map_err(|e| Error::new(ErrorKind::Other, e.to_string()))?;

    // Likewise, a stored hash that fails to parse is reported rather than
    // unwrapped.
    let password_hash = PasswordHash::new(&password_hash)
        .map_err(|e| Error::new(ErrorKind::InvalidData, e.to_string()))?;

    match verify_password(password, &password_hash) {
        Ok(()) => Ok(Some(user_id)),
        // A wrong password is an expected outcome, not an error.
        Err(argon2::password_hash::Error::Password) => Ok(None),
        Err(e) => Err(Error::new(ErrorKind::InvalidData, e.to_string())),
    }
}
/// Reply acknowledging a client's `CAP REQ`; wraps the capability list that is
/// echoed back.
pub struct AcknowledgedCapabilities(String);

impl AcknowledgedCapabilities {
    /// Converts the acknowledgement into a `CAP * ACK` message carrying the
    /// wrapped capability list.
    #[must_use]
    pub fn into_message(self) -> Message {
        let Self(capabilities) = self;
        let target = Some("*".to_string());
        let command = Command::CAP(target, CapSubCommand::ACK, None, Some(capabilities));

        Message {
            tags: None,
            prefix: None,
            command,
        }
    }
}
/// SASL mechanisms the server can authenticate with.
#[derive(Copy, Clone, Debug)]
pub enum AuthStrategy {
    /// The `PLAIN` mechanism.
    Plain,
}

impl AuthStrategy {
    /// Mechanism names as advertised to clients.
    pub const SUPPORTED: &'static str = "PLAIN";
}

impl FromStr for AuthStrategy {
    type Err = std::io::Error;

    /// Parses a mechanism name exactly as sent by the client.
    ///
    /// The comparison is case-sensitive; anything other than `PLAIN` yields an
    /// `InvalidData` error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        if s == "PLAIN" {
            return Ok(Self::Plain);
        }

        Err(Error::new(ErrorKind::InvalidData, "unknown auth strategy"))
    }
}
/// Error reply for a client that tries to authenticate again after already
/// doing so. The wrapped string becomes the first parameter of the reply
/// (presumably the client's nick, as with the other SASL replies here).
pub struct SaslAlreadyAuthenticated(pub String);

impl SaslAlreadyAuthenticated {
    /// Builds the `ERR_SASLALREADY` message.
    #[must_use]
    pub fn into_message(self) -> Message {
        let Self(target) = self;
        let params = vec![
            target,
            "You have already authenticated using SASL".to_string(),
        ];

        Message {
            tags: None,
            prefix: None,
            command: Command::Response(Response::ERR_SASLALREADY, params),
        }
    }
}
/// Reply sent when the client asks for a SASL mechanism the server does not
/// know; wraps the nick the reply is addressed to.
pub struct SaslStrategyUnsupported(String);

impl SaslStrategyUnsupported {
    /// Builds the `RPL_SASLMECHS` message listing [`AuthStrategy::SUPPORTED`].
    #[must_use]
    pub fn into_message(self) -> Message {
        let Self(target) = self;
        let params = vec![
            target,
            AuthStrategy::SUPPORTED.to_string(),
            "are available SASL mechanisms".to_string(),
        ];

        Message {
            tags: None,
            prefix: None,
            command: Command::Response(Response::RPL_SASLMECHS, params),
        }
    }
}
/// Replies sent after a successful SASL authentication; wraps the connection
/// that authenticated.
pub struct SaslSuccess(InitiatedConnection);

impl SaslSuccess {
    /// Builds the two success messages, in send order: `RPL_SASLSUCCESS`
    /// followed by `RPL_LOGGEDIN`.
    #[must_use]
    pub fn into_messages(self) -> [Message; 2] {
        let Self(connection) = self;

        let sasl_success = Message {
            tags: None,
            prefix: None,
            command: Command::Response(
                Response::RPL_SASLSUCCESS,
                vec![
                    connection.nick.clone(),
                    "SASL authentication successful".to_string(),
                ],
            ),
        };

        // Parameters: nick, full prefix, account name, human-readable text.
        let logged_in = Message {
            tags: None,
            prefix: None,
            command: Command::Response(
                Response::RPL_LOGGEDIN,
                vec![
                    connection.nick.clone(),
                    connection.to_nick().to_string(),
                    connection.user.clone(),
                    format!("You are now logged in as {}", connection.user),
                ],
            ),
        };

        [sasl_success, logged_in]
    }
}
/// Reply sent when a SASL attempt fails verification; wraps the nick the
/// reply is addressed to.
pub struct SaslFail(String);

impl SaslFail {
    /// Builds the `ERR_SASLFAIL` message.
    #[must_use]
    pub fn into_message(self) -> Message {
        let Self(target) = self;
        let params = vec![target, "SASL authentication failed".to_string()];

        Message {
            tags: None,
            prefix: None,
            command: Command::Response(Response::ERR_SASLFAIL, params),
        }
    }
}
/// Reply sent when the client cancels a SASL exchange; wraps the nick the
/// reply is addressed to.
pub struct SaslAborted(String);

impl SaslAborted {
    /// Builds the `ERR_SASLABORT` message.
    #[must_use]
    pub fn into_message(self) -> Message {
        let Self(target) = self;
        let params = vec![target, "SASL authentication aborted".to_string()];

        Message {
            tags: None,
            prefix: None,
            command: Command::Response(Response::ERR_SASLABORT, params),
        }
    }
}