#![allow(clippy::iter_without_into_iter)]
mod authenticate;
pub mod sasl;
use std::{
fmt::{Display, Formatter},
io::{Error, ErrorKind},
net::SocketAddr,
str::FromStr,
time::Duration,
};
use actix::{io::FramedWrite, Actor, Addr};
use bitflags::bitflags;
use chrono::Utc;
use const_format::concatcp;
use futures::{SinkExt, TryStreamExt};
use hickory_resolver::TokioAsyncResolver;
use irc_proto::{
error::ProtocolError, CapSubCommand, Command, IrcCodec, Message, Prefix, Response,
};
use sha2::digest::{FixedOutput, Update};
use tokio::{
io::{ReadHalf, WriteHalf},
net::TcpStream,
};
use tokio_util::codec::FramedRead;
use tracing::{instrument, warn};
use crate::{
connection::{
authenticate::{Authenticate, AuthenticateMessage, AuthenticateResult},
sasl::{AuthStrategy, ConnectionSuccess, SaslSuccess},
},
host_mask::HostMask,
keys::Keys,
persistence::{events::ReserveNick, Persistence},
};
pub type MessageStream = FramedRead<ReadHalf<TcpStream>, irc_proto::IrcCodec>;
pub type MessageSink = FramedWrite<Message, WriteHalf<TcpStream>, irc_proto::IrcCodec>;
/// Newtype around the `i64` identifier of a user.
///
/// Marked `#[sqlx(transparent)]`, so sqlx treats it as the wrapped `i64` when
/// encoding and decoding.
#[derive(Copy, Clone, Debug, Hash, PartialEq, Eq, sqlx::Type)]
#[sqlx(transparent)]
pub struct UserId(pub i64);
/// Registration details collected while a client is still negotiating its connection.
///
/// Every field starts out empty and is filled in as the client sends commands;
/// [`InitiatedConnection::new`] succeeds once all of the optional fields are `Some`.
#[derive(Default)]
pub struct ConnectionRequest {
// Peer address of the connecting socket.
host: Option<SocketAddr>,
// Nickname supplied by the client's `NICK` command.
nick: Option<String>,
// Username reported by a completed `AUTHENTICATE` exchange.
user: Option<String>,
// Real name supplied by the client's `USER` command.
real_name: Option<String>,
// User id reported by a completed `AUTHENTICATE` exchange.
user_id: Option<UserId>,
// Capabilities the client asked for via `CAP REQ`.
capabilities: Capability,
}
/// A client connection that has supplied everything required by
/// [`InitiatedConnection::new`].
#[derive(Clone, Debug)]
pub struct InitiatedConnection {
/// Socket address the client connected from.
pub host: SocketAddr,
/// Reverse-DNS name of `host`, when one has been resolved; `None` on construction.
pub resolved_host: Option<String>,
/// `cloaked-` followed by 12 hex characters derived from the client IP and a salt;
/// used as the host part of the prefix and host mask instead of the real address.
pub cloak: String,
/// Nickname taken from the connection request.
pub nick: String,
/// Username taken from the connection request.
pub user: String,
/// User mode flags; empty on construction.
pub mode: UserMode,
/// Real name taken from the connection request.
pub real_name: String,
/// Identifier of the user taken from the connection request.
pub user_id: UserId,
/// Capabilities carried over from the connection request.
pub capabilities: Capability,
/// Away message, if any; `None` on construction.
pub away: Option<String>,
/// UTC timestamp captured when this value was constructed.
pub at: chrono::DateTime<Utc>,
}
impl InitiatedConnection {
    /// Promotes a [`ConnectionRequest`] into an [`InitiatedConnection`].
    ///
    /// The cloak is `cloaked-` followed by the first 12 hex characters of the SHA-256
    /// digest computed over the canonical textual form of the client's IP address and
    /// then `keys.ip_salt`. The new connection starts with no resolved host, an empty
    /// user mode, no away message and the current UTC time as its timestamp.
    ///
    /// # Errors
    ///
    /// Returns the request unchanged when any of `host`, `nick`, `user`, `real_name`
    /// or `user_id` is still `None`.
    pub fn new(value: ConnectionRequest, keys: &Keys) -> Result<Self, ConnectionRequest> {
        let (host, nick, user, real_name, user_id, capabilities) = match value {
            ConnectionRequest {
                host: Some(host),
                nick: Some(nick),
                user: Some(user),
                real_name: Some(real_name),
                user_id: Some(user_id),
                capabilities,
            } => (host, nick, user, real_name, user_id, capabilities),
            incomplete => return Err(incomplete),
        };

        let canonical_ip = host.ip().to_canonical().to_string();
        let digest = sha2::Sha256::default()
            .chain(canonical_ip)
            .chain(keys.ip_salt)
            .finalize_fixed();

        // Only a short prefix of the digest is exposed in the cloak.
        let mut digest_hex = hex::encode(digest);
        digest_hex.truncate(12);
        let cloak = format!("cloaked-{digest_hex}");

        Ok(Self {
            host,
            resolved_host: None,
            cloak,
            nick,
            user,
            mode: UserMode::empty(),
            real_name,
            user_id,
            capabilities,
            away: None,
            at: Utc::now(),
        })
    }

    /// Builds the `nick!user@cloak` message prefix for this connection.
    #[must_use]
    pub fn to_nick(&self) -> Prefix {
        let Self {
            nick, user, cloak, ..
        } = self;

        Prefix::Nickname(nick.clone(), user.clone(), cloak.clone())
    }

    /// Builds a [`HostMask`] borrowing this connection's nick, user and cloak.
    #[must_use]
    pub fn to_host_mask(&self) -> HostMask<'_> {
        let Self {
            nick, user, cloak, ..
        } = self;

        HostMask::new(nick, user, cloak)
    }
}
/// Drives the pre-registration handshake with a newly connected client.
///
/// Messages are read from `s` and folded into a [`ConnectionRequest`] until
/// [`InitiatedConnection::new`] accepts it, or until the stream ends:
///
/// - `PASS` is accepted and ignored.
/// - `NICK` sets the nickname.
/// - `USER` sets the real name only; the username comes from `AUTHENTICATE`.
/// - `CAP LS` / `CAP LIST` are answered with the space-joined [`Capability::SUPPORTED`].
/// - `CAP REQ` is answered with `ACK` or `NAK`; the requested capabilities are only
///   applied when the whole request is acknowledged.
/// - `AUTHENTICATE` is forwarded to a freshly started [`Authenticate`] actor.
/// - Anything else is logged and skipped.
///
/// Once the request is complete, a reverse DNS lookup of the client's IP is attempted
/// for at most 250ms, the success message is written, and the nick is reserved via
/// `persistence`.
///
/// Returns `Ok(None)` when the client's stream ends before negotiation completes.
///
/// # Errors
///
/// Returns a [`ProtocolError`] when reading from or writing to the client fails, when
/// the authenticator reports an error or either actor mailbox is unavailable, or when
/// the nick could not be reserved for the authenticated user (in which case
/// `ERR_NICKNAMEINUSE` is sent to the client first).
#[instrument(skip_all)]
pub async fn negotiate_client_connection(
    s: &mut MessageStream,
    write: &mut tokio_util::codec::FramedWrite<WriteHalf<TcpStream>, IrcCodec>,
    host: SocketAddr,
    persistence: &Addr<Persistence>,
    database: sqlx::Pool<sqlx::Any>,
    resolver: &TokioAsyncResolver,
    keys: &Keys,
) -> Result<Option<InitiatedConnection>, ProtocolError> {
    let mut request = ConnectionRequest {
        host: Some(host),
        ..ConnectionRequest::default()
    };

    // The pool is owned by this function and not used again, so it is moved into the
    // actor rather than cloned.
    let authenticate_handle = Authenticate {
        selected_strategy: None,
        database,
    }
    .start();

    let initiated = loop {
        // A closed stream ends negotiation without a connection.
        let Some(msg) = s.try_next().await? else {
            break None;
        };

        #[allow(clippy::match_same_arms)]
        match msg.command {
            Command::PASS(_) => {}
            Command::NICK(nick) => request.nick = Some(nick),
            Command::USER(_user, _mode, real_name) => {
                // The username from `USER` is ignored; `request.user` is only ever set
                // by a successful `AUTHENTICATE` exchange below.
                request.real_name = Some(real_name);
            }
            Command::CAP(_, CapSubCommand::LIST | CapSubCommand::LS, _, _) => {
                // Propagate write failures like every other send in this function
                // instead of panicking the connection task.
                write
                    .send(Message {
                        tags: None,
                        prefix: None,
                        command: Command::CAP(
                            Some("*".to_string()),
                            CapSubCommand::LS,
                            None,
                            Some(Capability::SUPPORTED.join(" ")),
                        ),
                    })
                    .await?;
            }
            Command::CAP(_, CapSubCommand::REQ, Some(arguments), None) => {
                // A capability request is all-or-nothing: collect the requested set
                // separately and only apply it when every token was recognised, so a
                // NAK leaves the client's capabilities untouched.
                let mut requested = Capability::empty();
                let mut acked = true;

                for argument in arguments.split(' ') {
                    // `sasl` is advertised but carries no flag of its own.
                    if argument == "sasl" {
                        continue;
                    }

                    match Capability::from_str(argument) {
                        Ok(capability) => requested |= capability,
                        Err(()) => acked = false,
                    }
                }

                if acked {
                    request.capabilities |= requested;
                }

                write
                    .send(AcknowledgedCapabilities(arguments, acked).into_message())
                    .await?;
            }
            Command::AUTHENTICATE(msg) => {
                // The outer `?` surfaces a mailbox failure as a protocol error (same
                // mapping as the `ReserveNick` call below); the inner `?` surfaces the
                // authenticator's own error.
                let result = authenticate_handle
                    .send(AuthenticateMessage(msg))
                    .await
                    .map_err(|e| ProtocolError::Io(Error::new(ErrorKind::InvalidData, e)))??;

                match result {
                    AuthenticateResult::Reply(v) => {
                        write.send(*v).await?;
                    }
                    AuthenticateResult::Done(username, user_id) => {
                        request.user = Some(username);
                        request.user_id = Some(user_id);
                        write.send(SaslSuccess::into_message()).await?;
                    }
                }
            }
            _ => {
                warn!(?msg, "Client sent unknown command during negotiation");
            }
        };

        // Try to finish after every message; an incomplete request is handed back so
        // accumulation can continue.
        match InitiatedConnection::new(std::mem::take(&mut request), keys) {
            Ok(v) => break Some(v),
            Err(v) => {
                request = v;
            }
        }
    };

    let Some(mut initiated) = initiated else {
        return Ok(None);
    };

    // Best-effort reverse DNS: lookup failures and timeouts leave `resolved_host` unset.
    if let Ok(Ok(v)) = tokio::time::timeout(
        Duration::from_millis(250),
        resolver.reverse_lookup(host.ip().to_canonical()),
    )
    .await
    {
        initiated.resolved_host = v
            .iter()
            .next()
            .map(|v| v.to_utf8().trim_end_matches('.').to_string());
    }

    write
        .send(ConnectionSuccess(initiated.clone()).into_message())
        .await?;

    let reserved_nick = persistence
        .send(ReserveNick {
            user_id: initiated.user_id,
            nick: initiated.nick.clone(),
        })
        .await
        .map_err(|e| ProtocolError::Io(Error::new(ErrorKind::InvalidData, e)))?;

    if !reserved_nick {
        write
            .send(NickNotOwnedByUser(initiated.nick).into_message())
            .await?;

        return Err(ProtocolError::Io(Error::new(
            ErrorKind::InvalidData,
            "nick is already in use by another user",
        )));
    }

    Ok(Some(initiated))
}
/// Reply sent when the wrapped nickname could not be reserved for the connecting user.
pub struct NickNotOwnedByUser(pub String);

impl NickNotOwnedByUser {
    /// Converts this reply into an `ERR_NICKNAMEINUSE` numeric carrying the nickname
    /// followed by a human-readable explanation.
    #[must_use]
    pub fn into_message(self) -> Message {
        let Self(nick) = self;
        let params = vec![nick, String::from("Nickname is already in use")];

        Message {
            tags: None,
            prefix: None,
            command: Command::Response(Response::ERR_NICKNAMEINUSE, params),
        }
    }
}
/// Reply to a `CAP REQ`: the capability list as the client sent it, and whether the
/// request is acknowledged (`true`) or rejected (`false`).
pub struct AcknowledgedCapabilities(String, bool);

impl AcknowledgedCapabilities {
    /// Converts this reply into a `CAP * ACK` or `CAP * NAK` message that echoes the
    /// requested capability list.
    #[must_use]
    pub fn into_message(self) -> Message {
        let Self(capabilities, accepted) = self;

        let verdict = if accepted {
            CapSubCommand::ACK
        } else {
            CapSubCommand::NAK
        };

        Message {
            tags: None,
            prefix: None,
            command: Command::CAP(Some(String::from("*")), verdict, None, Some(capabilities)),
        }
    }
}
bitflags! {
    /// Set of optional protocol capabilities enabled for a connection, one bit each.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
    pub struct Capability: u32 {
        /// Flag parsed from the `userhost-in-names` capability token.
        const USERHOST_IN_NAMES = 0x0000_0001;
        /// Flag parsed from the `server-time` capability token.
        const SERVER_TIME = 0x0000_0002;
    }

    /// Set of user mode flags, one bit each.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Default)]
    pub struct UserMode: u32 {
        /// Flag rendered as `w` in the mode string.
        const WALLOPS = 0x0000_0001;
        /// Flag rendered as `o` in the mode string.
        const OPER = 0x0000_0002;
    }
}
impl Display for UserMode {
    /// Renders the mode string: a leading `+`, then `w` if `WALLOPS` is set, then `o`
    /// if `OPER` is set. An empty set renders as just `+`.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        f.write_str("+")?;

        // Order of this table fixes the order of letters in the output.
        for (flag, letter) in [(Self::WALLOPS, "w"), (Self::OPER, "o")] {
            if self.contains(flag) {
                f.write_str(letter)?;
            }
        }

        Ok(())
    }
}
impl Capability {
/// Capability tokens advertised to clients (joined with spaces in the `CAP LS`
/// reply): the two flag-backed capabilities plus `sasl=` followed by
/// `AuthStrategy::SUPPORTED`.
pub const SUPPORTED: &'static [&'static str] = &[
"userhost-in-names",
"server-time",
concatcp!("sasl=", AuthStrategy::SUPPORTED),
];
}
impl FromStr for Capability {
    type Err = ();

    /// Parses a single capability token into its flag.
    ///
    /// Matching is exact; any token other than `userhost-in-names` or `server-time`
    /// yields `Err(())`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let known = [
            ("userhost-in-names", Self::USERHOST_IN_NAMES),
            ("server-time", Self::SERVER_TIME),
        ];

        known
            .into_iter()
            .find(|(token, _)| *token == s)
            .map(|(_, capability)| capability)
            .ok_or(())
    }
}