🏡 index : ~doyle/titanirc.git

author Jordan Doyle <jordan@doyle.la> 2024-01-31 23:30:20.0 +00:00:00
committer Jordan Doyle <jordan@doyle.la> 2024-01-31 23:30:20.0 +00:00:00
commit
92e5af86ddad5fbbb223cbdba90b4affdcab3489 [patch]
tree
c0237f2928b00f070f8523a47b00748713e03a0f
parent
18d9211dd79eaf6affbd334aa00d031993496b31
download
92e5af86ddad5fbbb223cbdba90b4affdcab3489.tar.gz

Implement GLINE



Diff

 Cargo.lock                                          |  10 +-
 Cargo.toml                                          |   2 +-
 migrations/20240131220401_add_server_bans_table.sql |   9 +-
 src/client.rs                                       |  72 ++++++--
 src/host_mask.rs                                    |  61 ++++++-
 src/lib.rs                                          |   1 +-
 src/main.rs                                         |  22 +-
 src/messages.rs                                     |  25 +++-
 src/persistence.rs                                  |  69 ++++++-
 src/persistence/events.rs                           |  32 +++-
 src/proto.rs                                        | 196 +++++++++++++++++++++-
 src/server.rs                                       | 163 +++++++++++++++--
 src/server/response.rs                              |  64 ++++++-
 13 files changed, 691 insertions(+), 35 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 160b8e2..7d821b4 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -907,6 +907,12 @@ dependencies = [
]

[[package]]
name = "humantime"
version = "2.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a3a5bfb195931eeb336b2a7b4d761daec841b97f947d34394601737a7bba5e4"

[[package]]
name = "iana-time-zone"
version = "0.1.59"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1586,7 +1592,7 @@ version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c367b5dafa12cef19c554638db10acde90d5e9acea2b80e1ad98b00f88068f7d"
dependencies = [
 "humantime",
 "humantime 1.3.0",
 "serde",
]

@@ -2051,6 +2057,7 @@ dependencies = [
 "futures",
 "hex",
 "hickory-resolver",
 "humantime 2.1.0",
 "irc-proto",
 "itertools",
 "rand",
@@ -2058,6 +2065,7 @@ dependencies = [
 "serde-humantime",
 "sha2",
 "sqlx",
 "thiserror",
 "tokio",
 "tokio-stream",
 "tokio-util",
diff --git a/Cargo.toml b/Cargo.toml
index 5bfd9cd..2ecaf61 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -18,12 +18,14 @@ chrono = "0.4"
clap = { version = "4.1", features = ["cargo", "derive", "std", "suggestions", "color"] }
futures = "0.3"
hex = "0.4"
humantime = "2.1"
hickory-resolver = { version = "0.24", features = ["tokio-runtime", "system-config"] }
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde-humantime = "0.1"
sha2 = "0.10    "
sqlx = { version = "0.7", features = ["runtime-tokio-rustls", "sqlite", "any"] }
thiserror = "1.0"
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
toml = "0.8"
diff --git a/migrations/20240131220401_add_server_bans_table.sql b/migrations/20240131220401_add_server_bans_table.sql
new file mode 100644
index 0000000..aee55fd
--- /dev/null
+++ b/migrations/20240131220401_add_server_bans_table.sql
@@ -0,0 +1,9 @@
CREATE TABLE server_bans (
    mask VARCHAR(255) NOT NULL,
    requester INT NOT NULL,
    reason VARCHAR(255) NOT NULL,
    created_timestamp INT NOT NULL,
    expires_timestamp INT,
    FOREIGN KEY(requester) REFERENCES users,
    PRIMARY KEY(mask)
);
diff --git a/src/client.rs b/src/client.rs
index 192f31c..1230076 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -24,9 +24,10 @@ use crate::{
        Broadcast, ChannelFetchTopic, ChannelFetchWhoList, ChannelInvite, ChannelJoin,
        ChannelKickUser, ChannelList, ChannelMemberList, ChannelMessage, ChannelPart,
        ChannelSetMode, ChannelUpdateTopic, ClientAway, ConnectedChannels, FetchClientDetails,
        FetchUserPermission, FetchWhoList, FetchWhois, ForceDisconnect, KillUser, MessageKind,
        PrivateMessage, ServerAdminInfo, ServerDisconnect, ServerFetchMotd, ServerListUsers,
        UserKickedFromChannel, UserNickChange, UserNickChangeInternal, Wallops,
        FetchUserPermission, FetchWhoList, FetchWhois, ForceDisconnect, Gline, KillUser, ListGline,
        MessageKind, PrivateMessage, RemoveGline, ServerAdminInfo, ServerDisconnect,
        ServerFetchMotd, ServerListUsers, UserKickedFromChannel, UserNickChange,
        UserNickChangeInternal, Wallops,
    },
    persistence::{
        events::{
@@ -35,6 +36,7 @@ use crate::{
        },
        Persistence,
    },
    proto::LocalCommand,
    server::{
        response::{IntoProtocol, WhoList},
        Server,
@@ -941,21 +943,55 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
            Command::BATCH(_, _, _) => {}
            Command::CHGHOST(_, _) => {}
            Command::Response(_, _) => {}
            v => self.writer.write(Message {
                tags: None,
                prefix: Some(Prefix::new_from_str(&self.connection.nick)),
                command: Command::Response(
                    Response::ERR_UNKNOWNCOMMAND,
                    vec![
                        String::from(&v)
                            .split_whitespace()
                            .next()
                            .unwrap_or_default()
                            .to_string(),
                        "Unknown command".to_string(),
                    ],
                ),
            }),
            Command::Raw(command, args) => self.handle_custom_command(ctx, command, args),
            _ => {
                for m in crate::proto::Error::UnknownCommand.into_messages(&self.connection.nick) {
                    self.writer.write(m);
                }
            }
        }
    }
}

impl Client {
    fn handle_custom_command(
        &mut self,
        ctx: &mut Context<Self>,
        command: String,
        args: Vec<String>,
    ) {
        match LocalCommand::try_from((command, args)) {
            Ok(LocalCommand::Gline(mask, duration, reason))
                if self.connection.mode.contains(UserMode::OPER) =>
            {
                self.server_send_map_write(
                    ctx,
                    Gline {
                        requester: self.connection.clone(),
                        mask,
                        duration,
                        reason,
                    },
                );
            }
            Ok(LocalCommand::RemoveGline(mask))
                if self.connection.mode.contains(UserMode::OPER) =>
            {
                self.server_send_map_write(ctx, RemoveGline { mask });
            }
            Ok(LocalCommand::ListGline) if self.connection.mode.contains(UserMode::OPER) => {
                self.server_send_map_write(ctx, ListGline);
            }
            Err(e) => {
                for m in e.into_messages(&self.connection.nick) {
                    self.writer.write(m);
                }
            }
            _ => {
                for m in crate::proto::Error::UnknownCommand.into_messages(&self.connection.nick) {
                    self.writer.write(m);
                }
            }
        }
    }
}
diff --git a/src/host_mask.rs b/src/host_mask.rs
index 8d17463..2cbd45a 100644
--- a/src/host_mask.rs
+++ b/src/host_mask.rs
@@ -106,6 +106,35 @@ impl<T> HostMaskMap<T> {
        }
    }

    pub fn remove(&mut self, mask: &HostMask<'_>) -> bool {
        let mut next_mask = mask.as_borrowed();

        let key = match self.matcher {
            Matcher::Nick => take_next_char(&mask.nick, &mut next_mask.nick),
            Matcher::Username => take_next_char(&mask.username, &mut next_mask.username),
            Matcher::Host => take_next_char(&mask.host, &mut next_mask.host),
        };

        let key = match key {
            Some('*') => Key::Wildcard,
            Some(c) => Key::Char(c),
            None => Key::EndOfString,
        };

        if key.is_end() && self.matcher.next().is_none() {
            self.children.remove(&key).is_some()
        } else {
            let Some(node) = self.children.get_mut(&key) else {
                return false;
            };

            match node {
                Node::Match(_) => unreachable!("stored hostmask has less parts than a!b@c"),
                Node::Inner(map) => map.remove(&next_mask),
            }
        }
    }

    /// Fetches all the matches within the trie that match the input. This function returns
    /// any exact matches as well as any wildcard matches. This function operates in `O(m)`
    /// average time complexity.
@@ -219,7 +248,7 @@ impl Matcher {
    }
}

#[derive(Clone, Debug)]
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct HostMask<'a> {
    nick: Cow<'a, str>,
    username: Cow<'a, str>,
@@ -337,6 +366,36 @@ mod test {
    use crate::host_mask::{HostMask, HostMaskMap};

    #[test]
    fn from_iter() {
        let map = [
            ("aaa*!bbb@cccc".try_into().unwrap(), 10),
            ("aaab!ccc@dddd".try_into().unwrap(), 10),
        ]
        .into_iter()
        .collect::<HostMaskMap<_>>();

        let retrieved = map.get(&"aaaa!bbb@cccc".try_into().unwrap());
        assert_eq!(retrieved.len(), 1);
        assert_eq!(*retrieved[0], 10);

        let retrieved = map.get(&"aaab!ccc@dddd".try_into().unwrap());
        assert_eq!(retrieved.len(), 1);
        assert_eq!(*retrieved[0], 10);
    }

    #[test]
    fn iter() {
        let mut map = HostMaskMap::new();
        map.insert(&"aaaa!*@*".try_into().unwrap(), 30);
        map.insert(&"bbbb!a@b".try_into().unwrap(), 40);

        let retrieved = map.iter().collect::<Vec<_>>();
        assert_eq!(retrieved.len(), 2);
        assert!(retrieved.contains(&("aaaa!*@*".to_string(), &30)));
        assert!(retrieved.contains(&("bbbb!a@b".to_string(), &40)));
    }

    #[test]
    fn wildcard_middle_of_string_unsupported() {
        assert!(HostMask::try_from("aa*a!bbbb@cccc").is_err());
    }
diff --git a/src/lib.rs b/src/lib.rs
index b6c93ae..3430195 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -14,6 +14,7 @@ pub mod host_mask;
pub mod keys;
pub mod messages;
pub mod persistence;
pub mod proto;
pub mod server;

pub const SERVER_NAME: &str = "my.cool.server";
diff --git a/src/main.rs b/src/main.rs
index a99ece4..0ce11a6 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -17,8 +17,14 @@ use irc_proto::{Command, IrcCodec, Message};
use rand::seq::SliceRandom;
use sqlx::migrate::Migrator;
use titanircd::{
    client::Client, config::Args, connection, keys::Keys, messages::UserConnected,
    persistence::Persistence, server::Server,
    client::Client,
    config::Args,
    connection,
    host_mask::HostMaskMap,
    keys::Keys,
    messages::{UserConnected, ValidateConnection},
    persistence::Persistence,
    server::{response::ConnectionValidated, Server},
};
use tokio::{
    io::WriteHalf,
@@ -86,6 +92,7 @@ async fn main() -> anyhow::Result<()> {
        config: opts.config,
        persistence,
        max_clients: 0,
        bans: HostMaskMap::new(),
    });

    let listener = TcpListener::bind(listen_address).await?;
@@ -165,6 +172,17 @@ async fn start_tcp_acceptor_loop(
                }
            };

            match server.send(ValidateConnection(connection.clone())).await.unwrap() {
                ConnectionValidated::Allowed => {}
                ConnectionValidated::Reject(reason) => {
                    let command = Command::ERROR(reason.to_string());
                    if let Err(error) = write.send(Message { tags: None, prefix: None, command, }).await {
                        error!(%error, "Failed to send error message to client, forcefully closing connection.");
                    }
                    return;
                }
            }

            // spawn the client's actor
            let handle = {
                let server = server.clone();
diff --git a/src/messages.rs b/src/messages.rs
index 99349b4..d2d7d4e 100644
--- a/src/messages.rs
+++ b/src/messages.rs
@@ -1,3 +1,5 @@
use std::time::Duration;

use actix::{Addr, Message};
use anyhow::Result;
use irc_proto::{ChannelMode, Mode};
@@ -172,6 +174,29 @@ pub struct ChannelSetMode {
    pub modes: Vec<Mode<ChannelMode>>,
}

#[derive(Message)]
#[rtype(result = "()")]
pub struct Gline {
    pub requester: InitiatedConnection,
    pub mask: HostMask<'static>,
    pub duration: Option<Duration>,
    pub reason: Option<String>,
}

#[derive(Message)]
#[rtype(result = "()")]
pub struct RemoveGline {
    pub mask: HostMask<'static>,
}

#[derive(Message)]
#[rtype(result = "Vec<super::server::response::ServerBan>")]
pub struct ListGline;

#[derive(Message)]
#[rtype(result = "super::server::response::ConnectionValidated")]
pub struct ValidateConnection(pub InitiatedConnection);

/// Attempts to kick a user from a channel.
#[derive(Message)]
#[rtype(result = "()")]
diff --git a/src/persistence.rs b/src/persistence.rs
index 3df3267..706ab4a 100644
--- a/src/persistence.rs
+++ b/src/persistence.rs
@@ -15,8 +15,8 @@ use crate::{
    persistence::events::{
        ChannelCreated, ChannelJoined, ChannelMessage, ChannelParted,
        FetchAllUserChannelPermissions, FetchUnseenChannelMessages, FetchUnseenPrivateMessages,
        FetchUserChannels, FetchUserIdByNick, PrivateMessage, ReserveNick,
        SetUserChannelPermissions,
        FetchUserChannels, FetchUserIdByNick, PrivateMessage, ReserveNick, ServerBan,
        ServerListBan, ServerListBanEntry, ServerRemoveBan, SetUserChannelPermissions,
    },
};

@@ -390,6 +390,71 @@ impl Handler<ReserveNick> for Persistence {
    }
}

impl Handler<ServerBan> for Persistence {
    type Result = ResponseFuture<()>;

    fn handle(&mut self, msg: ServerBan, _ctx: &mut Self::Context) -> Self::Result {
        let database = self.database.clone();

        Box::pin(async move {
            sqlx::query(
                "INSERT INTO server_bans
                 (mask, requester, reason, created_timestamp, expires_timestamp)
                 VALUES (?, ?, ?, ?, ?)",
            )
            .bind(msg.mask)
            .bind(msg.requester)
            .bind(msg.reason)
            .bind(msg.created.timestamp_nanos_opt().unwrap())
            .bind(msg.expires.map(|v| v.timestamp_nanos_opt().unwrap()))
            .execute(&database)
            .await
            .unwrap();
        })
    }
}

impl Handler<ServerRemoveBan> for Persistence {
    type Result = ResponseFuture<()>;

    fn handle(&mut self, msg: ServerRemoveBan, _ctx: &mut Self::Context) -> Self::Result {
        let database = self.database.clone();

        Box::pin(async move {
            sqlx::query("DELETE FROM server_bans WHERE mask = ?")
                .bind(msg.mask)
                .execute(&database)
                .await
                .unwrap();
        })
    }
}

impl Handler<ServerListBan> for Persistence {
    type Result = ResponseFuture<Vec<ServerListBanEntry>>;

    fn handle(&mut self, _msg: ServerListBan, _ctx: &mut Self::Context) -> Self::Result {
        let database = self.database.clone();

        Box::pin(async move {
            sqlx::query_as(
                "SELECT
                   users.username AS requester,
                   server_bans.mask,
                   server_bans.reason,
                   server_bans.created_timestamp,
                   server_bans.expires_timestamp
                 FROM server_bans
                 INNER JOIN users
                   ON server_bans.requester = users.id",
            )
            .fetch_all(&database)
            .await
            .unwrap()
        })
    }
}

/// Remove any messages from the messages table whenever they've been seen by all users
/// or have passed their retention period
/// .
diff --git a/src/persistence/events.rs b/src/persistence/events.rs
index 4823091..4e69901 100644
--- a/src/persistence/events.rs
+++ b/src/persistence/events.rs
@@ -1,5 +1,6 @@
use actix::Message;
use chrono::{DateTime, Utc};
use sqlx::FromRow;
use tracing::Span;

use crate::{
@@ -98,3 +99,34 @@ pub struct ReserveNick {
    pub user_id: UserId,
    pub nick: String,
}

#[derive(Message)]
#[rtype(result = "()")]
pub struct ServerBan {
    pub mask: HostMask<'static>,
    pub requester: UserId,
    pub reason: String,
    pub created: DateTime<Utc>,
    pub expires: Option<DateTime<Utc>>,
}

#[derive(Message)]
#[rtype(result = "()")]
pub struct ServerRemoveBan {
    pub mask: HostMask<'static>,
}

#[derive(Message)]
#[rtype(result = "Vec<ServerListBanEntry>")]
pub struct ServerListBan;

#[derive(Message, FromRow)]
#[rtype(result = "()")]
pub struct ServerListBanEntry {
    pub mask: HostMask<'static>,
    pub requester: String,
    pub reason: String,
    // timestamp in nanos. todo: sqlx datetime<utc>
    pub created_timestamp: i64,
    pub expires_timestamp: Option<i64>,
}
diff --git a/src/proto.rs b/src/proto.rs
new file mode 100644
index 0000000..2d3c832
--- /dev/null
+++ b/src/proto.rs
@@ -0,0 +1,196 @@
use std::{convert::identity, str::FromStr, time::Duration};

use irc_proto::{Command, Message, Prefix, Response};
use thiserror::Error;

use crate::{host_mask::HostMask, server::response::IntoProtocol, SERVER_NAME};

#[derive(Debug, Clone, PartialEq, Eq)]
pub enum LocalCommand {
    ListGline,
    /// Unbans a hostmask
    RemoveGline(HostMask<'static>),
    /// Bans a hostmask from the network for the given duration with the given message
    Gline(HostMask<'static>, Option<Duration>, Option<String>),
}

impl TryFrom<(String, Vec<String>)> for LocalCommand {
    type Error = Error;

    fn try_from((command, args): (String, Vec<String>)) -> Result<Self, Self::Error> {
        match command.as_str() {
            "GLINE" if args.is_empty() => Ok(Self::ListGline),
            "GLINE" if args.len() == 1 && args[0].starts_with('-') => parse1(
                Self::RemoveGline,
                args,
                required(truncate_first_character(parse_host_mask)),
            ),
            "GLINE" => parse3(
                Self::Gline,
                args,
                required(parse_host_mask),
                opt(parse_duration),
                opt(wrap_ok(identity)),
            ),
            _ => Err(Error::UnknownCommand),
        }
    }
}

#[derive(Debug, Error)]
pub enum Error {
    #[error("unknown command")]
    UnknownCommand,
    #[error("missing argument")]
    MissingArgument,
    #[error("invalid duration: {0}")]
    InvalidDuration(humantime::DurationError),
    #[error("invalid host mask: {0}")]
    InvalidHostMask(std::io::Error),
    #[error("too many arguments")]
    TooManyArguments,
}

impl IntoProtocol for Error {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        vec![Message {
            tags: None,
            prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
            command: Command::Response(
                Response::ERR_UNKNOWNCOMMAND,
                vec![
                    for_user.to_string(),
                    "command".to_string(), // TODO
                    "Unknown command".to_string(),
                ],
            ),
        }]
    }
}

fn opt<T>(
    transform: impl FnOnce(String) -> Result<T, Error>,
) -> impl FnOnce(Option<String>) -> Result<Option<T>, Error> {
    move |v| v.map(transform).transpose()
}

fn required<T>(
    transform: impl FnOnce(String) -> Result<T, Error>,
) -> impl FnOnce(Option<String>) -> Result<T, Error> {
    move |v| v.ok_or(Error::MissingArgument).and_then(transform)
}

/// Truncates the first character from the first argument and calls the inner transform function.
fn truncate_first_character<T>(
    transform: fn(String) -> Result<T, Error>,
) -> impl Fn(String) -> Result<T, Error> {
    move |mut v| {
        v.remove(0);
        (transform)(v)
    }
}

/// Parses a host mask argument
#[allow(clippy::needless_pass_by_value)]
fn parse_host_mask(v: String) -> Result<HostMask<'static>, Error> {
    HostMask::from_str(&v).map_err(Error::InvalidHostMask)
}

/// Parses a humantime duration
#[allow(clippy::needless_pass_by_value)]
fn parse_duration(v: String) -> Result<Duration, Error> {
    humantime::parse_duration(&v).map_err(Error::InvalidDuration)
}

/// Takes a string argument as-is
fn wrap_ok<T>(transform: fn(String) -> T) -> impl Fn(String) -> Result<T, Error> {
    move |v| Ok((transform)(v))
}

/// Parses a single argument from `args`, transforming it using `t1`
/// and returns a `LocalCommand`.
fn parse1<T1>(
    out: fn(T1) -> LocalCommand,
    args: Vec<String>,
    t1: impl FnOnce(Option<String>) -> Result<T1, Error>,
) -> Result<LocalCommand, Error> {
    if args.len() > 1 {
        return Err(Error::TooManyArguments);
    }

    let mut i = args.into_iter();
    Ok((out)(t1(i.next())?))
}

/// Parses three arguments from `args`, transforming them using `t1`, `t2` and `t3`
/// and returns a `LocalCommand`.
fn parse3<T1, T2, T3>(
    out: fn(T1, T2, T3) -> LocalCommand,
    args: Vec<String>,
    t1: impl FnOnce(Option<String>) -> Result<T1, Error>,
    t2: impl FnOnce(Option<String>) -> Result<T2, Error>,
    t3: impl FnOnce(Option<String>) -> Result<T3, Error>,
) -> Result<LocalCommand, Error> {
    if args.len() > 3 {
        return Err(Error::TooManyArguments);
    }

    let mut i = args.into_iter();
    Ok((out)(t1(i.next())?, t2(i.next())?, t3(i.next())?))
}

#[cfg(test)]
mod test {
    use std::time::Duration;

    use crate::proto::{Error, LocalCommand};

    #[test]
    fn remove_gline() {
        let command =
            LocalCommand::try_from(("GLINE".to_string(), vec!["-aaa!bbb@ccc".to_string()]))
                .unwrap();
        assert_eq!(
            command,
            LocalCommand::RemoveGline("aaa!bbb@ccc".try_into().unwrap())
        );
    }

    #[test]
    fn gline() {
        let command = LocalCommand::try_from((
            "GLINE".to_string(),
            vec![
                "aaa!bbb@ccc".to_string(),
                "1d".to_string(),
                "comment".to_string(),
            ],
        ))
        .unwrap();
        assert_eq!(
            command,
            LocalCommand::Gline(
                "aaa!bbb@ccc".try_into().unwrap(),
                Some(Duration::from_secs(86_400)),
                Some("comment".to_string())
            )
        );
    }

    #[test]
    fn too_many_arguments() {
        let command = LocalCommand::try_from((
            "GLINE".to_string(),
            vec![
                "aaa!bbb@ccc".to_string(),
                "1d".to_string(),
                "comment".to_string(),
                "toomany".to_string(),
            ],
        ));
        assert!(
            matches!(command, Err(Error::TooManyArguments)),
            "{command:?}"
        );
    }
}
diff --git a/src/server.rs b/src/server.rs
index 8b705b3..49b6377 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -1,12 +1,13 @@
pub mod response;

use std::{borrow::Cow, collections::HashMap};
use std::{borrow::Cow, collections::HashMap, time::Duration};

use actix::{
    Actor, Addr, AsyncContext, Context, Handler, MessageResult, ResponseFuture, Supervised,
    Supervisor,
    Actor, ActorContext, ActorFuture, ActorFutureExt, Addr, AsyncContext, Context, Handler,
    MessageResult, ResponseFuture, Supervised, Supervisor, WrapFuture,
};
use actix_rt::Arbiter;
use chrono::Utc;
use clap::crate_version;
use futures::{
    future,
@@ -16,23 +17,28 @@ use futures::{
use irc_proto::{Command, Message, Prefix, Response};
use rand::seq::SliceRandom;
use tokio_stream::StreamExt;
use tracing::{debug, instrument, warn, Span};
use tracing::{debug, error, info, instrument, warn, Span};

use crate::{
    channel::{permissions::Permission, Channel, ChannelId},
    client::Client,
    config::Config,
    connection::{InitiatedConnection, UserMode},
    host_mask::HostMaskMap,
    host_mask::{HostMask, HostMaskMap},
    messages::{
        Broadcast, ChannelFetchTopic, ChannelFetchWhoList, ChannelJoin, ChannelList,
        ChannelMemberList, ClientAway, ConnectedChannels, FetchClientByNick, FetchWhoList,
        FetchWhois, ForceDisconnect, KillUser, MessageKind, PrivateMessage, ServerAdminInfo,
        ServerDisconnect, ServerFetchMotd, ServerListUsers, UserConnected, UserNickChange,
        UserNickChangeInternal, Wallops,
        FetchWhois, ForceDisconnect, Gline, KillUser, ListGline, MessageKind, PrivateMessage,
        RemoveGline, ServerAdminInfo, ServerDisconnect, ServerFetchMotd, ServerListUsers,
        UserConnected, UserNickChange, UserNickChangeInternal, ValidateConnection, Wallops,
    },
    persistence::{
        events::{ServerBan, ServerRemoveBan},
        Persistence,
    },
    server::response::{
        AdminInfo, ConnectionValidated, IntoProtocol, ListUsers, Motd, NoSuchNick, WhoList, Whois,
    },
    persistence::Persistence,
    server::response::{AdminInfo, IntoProtocol, ListUsers, Motd, NoSuchNick, WhoList, Whois},
    SERVER_NAME,
};

@@ -44,6 +50,7 @@ pub struct Server {
    pub max_clients: usize,
    pub config: Config,
    pub persistence: Addr<Persistence>,
    pub bans: HostMaskMap<response::ServerBan>,
}

impl Supervised for Server {}
@@ -65,6 +72,24 @@ impl Handler<UserNickChangeInternal> for Server {
    }
}

impl Handler<ValidateConnection> for Server {
    type Result = MessageResult<ValidateConnection>;

    #[allow(clippy::option_if_let_else)]
    fn handle(&mut self, msg: ValidateConnection, _ctx: &mut Self::Context) -> Self::Result {
        MessageResult(
            if let Some(ban) = self.bans.get(&msg.0.to_host_mask()).into_iter().next() {
                ConnectionValidated::Reject(format!(
                    "G-lined: {}",
                    ban.reason.as_deref().unwrap_or("no reason given")
                ))
            } else {
                ConnectionValidated::Allowed
            },
        )
    }
}

/// Received when a user connects to the server, and sends them the server preamble
impl Handler<UserConnected> for Server {
    type Result = ();
@@ -472,6 +497,124 @@ impl Handler<PrivateMessage> for Server {
    }
}

impl Handler<Gline> for Server {
    type Result = ();

    fn handle(&mut self, msg: Gline, _ctx: &mut Self::Context) -> Self::Result {
        let created = Utc::now();
        let expires = msg.duration.map(|v| created + v);

        // TODO: return ack msg
        self.bans.insert(
            &msg.mask,
            response::ServerBan {
                mask: msg.mask.clone(),
                requester: msg.requester.user.to_string(),
                reason: msg.reason.clone(),
                created,
                expires,
            },
        );

        // TODO: stop looping over all users
        let comment = format!(
            "G-lined: {}",
            msg.reason.as_deref().unwrap_or("no reason given")
        );
        for (handle, user) in &self.clients {
            if !self.bans.get(&user.to_host_mask()).is_empty() {
                handle.do_send(KillUser {
                    span: Span::current(),
                    killer: msg.requester.nick.to_string(),
                    comment: comment.to_string(),
                    killed: user.nick.to_string(),
                });
            }
        }

        self.persistence.do_send(ServerBan {
            mask: msg.mask,
            requester: msg.requester.user_id,
            reason: msg.reason.unwrap_or_default(),
            created,
            expires,
        });
    }
}

impl Handler<RemoveGline> for Server {
    type Result = ();

    fn handle(&mut self, msg: RemoveGline, _ctx: &mut Self::Context) -> Self::Result {
        // TODO: return ack msg
        self.bans.remove(&msg.mask);

        self.persistence.do_send(ServerRemoveBan { mask: msg.mask });
    }
}

impl Handler<ListGline> for Server {
    type Result = MessageResult<ListGline>;

    fn handle(&mut self, _msg: ListGline, _ctx: &mut Self::Context) -> Self::Result {
        MessageResult(self.bans.iter().map(|(_, v)| v.clone()).collect())
    }
}

impl Actor for Server {
    type Context = Context<Self>;

    fn started(&mut self, ctx: &mut Self::Context) {
        ctx.wait(self.load_server_ban_list());
        ctx.run_interval(Duration::from_secs(30), Self::remove_expired_bans);
    }
}

impl Server {
    fn load_server_ban_list(&mut self) -> impl ActorFuture<Self, Output = ()> + 'static {
        self.persistence
            .send(crate::persistence::events::ServerListBan)
            .into_actor(self)
            .map(|res, this, ctx| match res {
                Ok(bans) => {
                    this.bans = bans
                        .into_iter()
                        .map(|v| (v.mask.clone(), v.into()))
                        .collect();
                }
                Err(error) => {
                    error!(%error, "Failed to fetch bans");
                    ctx.terminate();
                }
            })
    }

    fn remove_expired_bans(&mut self, _ctx: &mut Context<Self>) {
        let mut expired = Vec::new();

        for (mask, ban) in self.bans.iter() {
            let Some(expires_at) = ban.expires else {
                continue;
            };

            if expires_at > Utc::now() {
                continue;
            }

            let Ok(mask) = HostMask::try_from(mask.as_str()) else {
                continue;
            };

            expired.push(mask.into_owned());
        }

        for mask in expired {
            info!("Removing expired ban on {mask}");

            self.bans.remove(&mask);
            self.persistence.do_send(ServerRemoveBan {
                mask: mask.into_owned(),
            });
        }
    }
}
diff --git a/src/server/response.rs b/src/server/response.rs
index 6c062b0..8734bfe 100644
--- a/src/server/response.rs
+++ b/src/server/response.rs
@@ -1,8 +1,10 @@
use chrono::{DateTime, TimeZone, Utc};
use irc_proto::{Command, Message, Prefix, Response};
use itertools::Itertools;

use crate::{
    channel::permissions::Permission, connection::InitiatedConnection, server::Server, SERVER_NAME,
    channel::permissions::Permission, connection::InitiatedConnection, host_mask::HostMask,
    persistence::events::ServerListBanEntry, server::Server, SERVER_NAME,
};

pub struct Whois {
@@ -396,6 +398,58 @@ pub struct ChannelListItem {
    pub topic: Option<String>,
}

#[derive(Clone, Debug)]
pub struct ServerBan {
    pub mask: HostMask<'static>,
    pub requester: String,
    pub reason: Option<String>,
    pub created: DateTime<Utc>,
    pub expires: Option<DateTime<Utc>>,
}

impl From<ServerListBanEntry> for ServerBan {
    fn from(value: ServerListBanEntry) -> Self {
        Self {
            mask: value.mask,
            requester: value.requester,
            reason: Some(value.reason).filter(|v| !v.is_empty()),
            created: Utc.timestamp_nanos(value.created_timestamp),
            expires: value.expires_timestamp.map(|v| Utc.timestamp_nanos(v)),
        }
    }
}

impl IntoProtocol for ServerBan {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        vec![Message {
            tags: None,
            prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
            command: Command::Raw(
                "216".to_string(),
                vec![
                    for_user.to_string(),
                    format!(
                        "{} by {} ({}), created {}, expires {}",
                        self.mask,
                        self.requester,
                        self.reason.as_deref().unwrap_or("no reason given"),
                        self.created,
                        self.expires
                            .map(|v| v.to_string())
                            .as_deref()
                            .unwrap_or("never")
                    ),
                ],
            ),
        }]
    }
}

pub enum ConnectionValidated {
    Allowed,
    Reject(String),
}

pub trait IntoProtocol {
    #[must_use]
    fn into_messages(self, for_user: &str) -> Vec<Message>;
@@ -425,3 +479,11 @@ where
        }
    }
}

impl<T: IntoProtocol> IntoProtocol for Vec<T> {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        self.into_iter()
            .flat_map(|v| v.into_messages(for_user))
            .collect()
    }
}