#![deny(clippy::nursery, clippy::pedantic)]
#![allow(
clippy::module_name_repetitions,
clippy::missing_panics_doc,
clippy::missing_errors_doc
)]
use std::{collections::HashMap, str::FromStr, sync::Arc};
use actix::{io::FramedWrite, Actor, Addr, AsyncContext, Supervisor};
use actix_rt::{Arbiter, System};
use bytes::BytesMut;
use clap::Parser;
use futures::SinkExt;
use hickory_resolver::AsyncResolver;
use irc_proto::{Command, IrcCodec, Message};
use rand::seq::SliceRandom;
use sqlx::migrate::Migrator;
use titanircd::{
client::Client,
config::Args,
connection,
host_mask::HostMaskMap,
keys::Keys,
messages::{UserConnected, ValidateConnection},
persistence::Persistence,
server::{response::ConnectionValidated, Server},
};
use tokio::{
io::WriteHalf,
net::{TcpListener, TcpStream},
time::Instant,
};
use tokio_util::codec::FramedRead;
use tracing::{error, info, info_span, Instrument};
use tracing_subscriber::EnvFilter;
/// Database migrations produced by `sqlx::migrate!()`; `main` runs them against the
/// connection pool at startup, before any actor is started.
static MIGRATOR: Migrator = sqlx::migrate!();
/// Process entry point.
///
/// Parses the command line, configures logging, connects to the database and runs the
/// migrations, starts the `Persistence` and `Server` actors on a dedicated arbiter, spawns
/// the TCP acceptor loop and then waits for Ctrl-C before stopping the actix system.
///
/// # Errors
///
/// Returns an error if the database URI cannot be parsed, the pool cannot connect, the
/// migrations fail, the keys cannot be loaded, the listen address cannot be bound, or the
/// Ctrl-C signal handler cannot be awaited.
#[actix_rt::main]
async fn main() -> anyhow::Result<()> {
    let opts: Args = Args::parse();

    // Translate the verbosity count into a log filter. Anything at or above two selects
    // `trace`, so passing extra `-v` flags never *reduces* the amount of logging.
    std::env::set_var(
        "RUST_LOG",
        match opts.verbose {
            1 => "debug",
            level if level >= 2 => "trace",
            _ => "info",
        },
    );

    let subscriber = tracing_subscriber::fmt()
        .with_env_filter(EnvFilter::from_default_env())
        .pretty();
    subscriber.init();

    sqlx::any::install_default_drivers();

    let database = sqlx::Pool::connect_with(sqlx::any::AnyConnectOptions::from_str(
        &opts.config.database_uri,
    )?)
    .await?;

    MIGRATOR.run(&database).await?;

    let keys = Arc::new(Keys::new(&database).await?);

    // Copy these out now: `opts.config` is moved into the `Server` actor below.
    let listen_address = opts.config.listen_address;
    let client_threads = opts.config.client_threads;

    // The persistence and server actors share one arbiter.
    let server_arbiter = Arbiter::new();

    let persistence_addr = {
        let database = database.clone();
        let config = opts.config.clone();

        Supervisor::start_in_arbiter(&server_arbiter.handle(), move |_ctx| Persistence {
            database,
            max_message_replay_since: config.max_message_replay_since,
            last_seen_clock: 0,
        })
    };

    let persistence = persistence_addr.clone();
    let server = Supervisor::start_in_arbiter(&server_arbiter.handle(), move |_ctx| Server {
        channels: HashMap::default(),
        clients: HashMap::default(),
        // Must be built before `opts.config` is moved into the `config` field.
        channel_arbiters: build_arbiters(opts.config.channel_threads),
        config: opts.config,
        persistence,
        max_clients: 0,
        bans: HostMaskMap::new(),
    });

    let listener = TcpListener::bind(listen_address).await?;

    actix_rt::spawn(start_tcp_acceptor_loop(
        listener,
        database,
        persistence_addr,
        server,
        client_threads,
        keys,
    ));

    info!("Server listening on {}", listen_address);

    // Run until interrupted, then shut the whole actix system down.
    tokio::signal::ctrl_c().await?;
    System::current().stop();

    Ok(())
}
async fn start_tcp_acceptor_loop(
listener: TcpListener,
database: sqlx::Pool<sqlx::Any>,
persistence: Addr<Persistence>,
server: Addr<Server>,
client_threads: usize,
keys: Arc<Keys>,
) {
let client_arbiters = Arc::new(build_arbiters(client_threads));
let resolver = Arc::new(AsyncResolver::tokio_from_system_conf().unwrap());
while let Ok((stream, addr)) = listener.accept().await {
let span = info_span!("connection", %addr);
let _entered = span.clone().entered();
info!("Accepted connection");
let database = database.clone();
let server = server.clone();
let client_arbiters = client_arbiters.clone();
let persistence = persistence.clone();
let resolver = resolver.clone();
let keys = keys.clone();
actix_rt::spawn(async move {
let (read, writer) = tokio::io::split(stream);
let mut read = FramedRead::new(read, irc_codec());
let mut write = tokio_util::codec::FramedWrite::new(writer, irc_codec());
let connection = match connection::negotiate_client_connection(&mut read, &mut write, addr, &persistence, database, &resolver, &keys).await {
Ok(Some(v)) => v,
Ok(None) => {
error!("Failed to fully handshake with client, dropping connection");
let command = Command::ERROR("You must use SASL to connect to this server".to_string());
if let Err(error) = write.send(Message { tags: None, prefix: None, command, }).await {
error!(%error, "Failed to send error message to client, forcefully closing connection.");
}
return;
}
Err(error) => {
error!(%error, "An error occurred whilst handshaking with client");
let command = Command::ERROR(error.to_string());
if let Err(error) = write.send(Message { tags: None, prefix: None, command, }).await {
error!(%error, "Failed to send error message to client, forcefully closing connection.");
}
return;
}
};
match server.send(ValidateConnection(connection.clone())).await.unwrap() {
ConnectionValidated::Allowed => {}
ConnectionValidated::Reject(reason) => {
let command = Command::ERROR(reason.to_string());
if let Err(error) = write.send(Message { tags: None, prefix: None, command, }).await {
error!(%error, "Failed to send error message to client, forcefully closing connection.");
}
return;
}
}
let handle = {
let server = server.clone();
let arbiter = client_arbiters.choose(&mut rand::thread_rng()).map_or_else(Arbiter::current, Arbiter::handle);
let span = span.clone();
let connection = connection.clone();
Client::start_in_arbiter(&arbiter, move |ctx| {
let (stream, codec, buffer) = unpack_writer(write);
let writer = FramedWrite::from_buffer(stream, codec, buffer, ctx);
ctx.add_stream(read);
Client {
writer,
connection,
server,
channels: HashMap::new(),
last_active: Instant::now(),
graceful_shutdown: false,
server_leave_reason: None,
span,
persistence,
}
})
};
server.do_send(UserConnected { handle, connection, span });
}.instrument(info_span!("negotiation")));
}
}
/// Takes a `tokio_util` framed writer apart, returning its underlying stream, its codec and
/// the contents of its write buffer.
///
/// The acceptor loop uses the three pieces to rebuild the writer as an actix `FramedWrite`
/// via `FramedWrite::from_buffer`.
#[must_use]
pub fn unpack_writer(
    mut writer: tokio_util::codec::FramedWrite<WriteHalf<TcpStream>, IrcCodec>,
) -> (WriteHalf<TcpStream>, IrcCodec, BytesMut) {
    // Split off everything currently held in the write buffer.
    let pending = writer.write_buffer_mut().split();

    // The encoder is only reachable by mutable reference, so swap a fresh codec in to
    // obtain ownership of the one the writer was using.
    let mut codec = irc_codec();
    std::mem::swap(writer.encoder_mut(), &mut codec);

    // Consuming the writer must come last: it gives up the stream and drops the rest.
    (writer.into_inner(), codec, pending)
}
/// Builds a new `IrcCodec` configured with the `"utf8"` encoding label.
///
/// Panics if `IrcCodec::new` returns an error for that label.
#[must_use]
pub fn irc_codec() -> IrcCodec {
    const ENCODING_LABEL: &str = "utf8";

    IrcCodec::new(ENCODING_LABEL).unwrap()
}
/// Creates `count` new arbiters and returns them in creation order.
///
/// Returns an empty vector when `count` is zero.
#[must_use]
pub fn build_arbiters(count: usize) -> Vec<Arbiter> {
    (0..count).map(|_| Arbiter::new()).collect()
}