🏡 index : ~doyle/titanirc.git

author Jordan Doyle <jordan@doyle.la> 2024-01-30 20:52:52.0 +00:00:00
committer Jordan Doyle <jordan@doyle.la> 2024-01-30 20:52:52.0 +00:00:00
commit
80c1733739c7c33002b458089514676529ddfbeb [patch]
tree
0879e9f28a86309063166c291a93445382672ac8
parent
85756a6a8fd501482cde14bc263d338a05b16dd1
download
80c1733739c7c33002b458089514676529ddfbeb.tar.gz

Introduce IntoProtocol trait for cleaning up error handling



Diff

 src/channel.rs          |  9 ++----
 src/channel/response.rs | 36 ++++++++++++++---------
 src/client.rs           | 64 +++++++++++++---------------------------
 src/messages.rs         |  4 ++-
 src/server.rs           |  9 +++---
 src/server/response.rs  | 80 ++++++++++++++++++++++++++++++++------------------
 6 files changed, 108 insertions(+), 94 deletions(-)

diff --git a/src/channel.rs b/src/channel.rs
index 377443c..36816e4 100644
--- a/src/channel.rs
+++ b/src/channel.rs
@@ -32,7 +32,7 @@ use crate::{
        events::{FetchAllUserChannelPermissions, SetUserChannelPermissions},
        Persistence,
    },
    server::Server,
    server::{response::IntoProtocol, Server},
};

#[derive(Copy, Clone)]
@@ -441,7 +441,7 @@ impl Handler<ChannelJoin> for Channel {
        }

        // send the channel's topic to the joining user
        for message in ChannelTopic::new(self).into_messages(self.name.to_string(), true) {
        for message in ChannelTopic::new(self, true).into_messages(&self.name) {
            msg.client.do_send(Broadcast {
                message,
                span: Span::current(),
@@ -497,8 +497,7 @@ impl Handler<ChannelUpdateTopic> for Channel {
        });

        for (client, connection) in &self.clients {
            for message in ChannelTopic::new(self).into_messages(connection.nick.to_string(), false)
            {
            for message in ChannelTopic::new(self, false).into_messages(&connection.nick) {
                client.do_send(Broadcast {
                    message,
                    span: Span::current(),
@@ -569,7 +568,7 @@ impl Handler<ChannelFetchTopic> for Channel {

    #[instrument(parent = &msg.span, skip_all)]
    fn handle(&mut self, msg: ChannelFetchTopic, _ctx: &mut Self::Context) -> Self::Result {
        MessageResult(ChannelTopic::new(self))
        MessageResult(ChannelTopic::new(self, msg.skip_on_none))
    }
}

diff --git a/src/channel/response.rs b/src/channel/response.rs
index c8a4bc1..5fc927b 100644
--- a/src/channel/response.rs
+++ b/src/channel/response.rs
@@ -4,25 +4,29 @@ use itertools::Itertools;
use crate::{
    channel::{permissions::Permission, Channel, CurrentChannelTopic},
    connection::InitiatedConnection,
    server::response::IntoProtocol,
    SERVER_NAME,
};

pub struct ChannelTopic {
    pub channel_name: String,
    pub topic: Option<CurrentChannelTopic>,
    pub skip_on_none: bool,
}

impl ChannelTopic {
    #[must_use]
    pub fn new(channel: &Channel) -> Self {
    pub fn new(channel: &Channel, skip_on_none: bool) -> Self {
        Self {
            channel_name: channel.name.to_string(),
            topic: channel.topic.clone(),
            skip_on_none,
        }
    }
}

    #[must_use]
    pub fn into_messages(self, for_user: String, skip_on_none: bool) -> Vec<Message> {
impl IntoProtocol for ChannelTopic {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        if let Some(topic) = self.topic {
            vec![
                Message {
@@ -43,7 +47,7 @@ impl ChannelTopic {
                    command: Command::Response(
                        Response::RPL_TOPICWHOTIME,
                        vec![
                            for_user,
                            for_user.to_string(),
                            self.channel_name.to_string(),
                            topic.set_by,
                            topic.set_time.timestamp().to_string(),
@@ -51,13 +55,17 @@ impl ChannelTopic {
                    ),
                },
            ]
        } else if !skip_on_none {
        } else if !self.skip_on_none {
            vec![Message {
                tags: None,
                prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
                command: Command::Response(
                    Response::RPL_NOTOPIC,
                    vec![for_user, self.channel_name, "No topic is set".to_string()],
                    vec![
                        for_user.to_string(),
                        self.channel_name,
                        "No topic is set".to_string(),
                    ],
                ),
            }]
        } else {
@@ -83,9 +91,10 @@ impl ChannelWhoList {
                .collect(),
        }
    }
}

    #[must_use]
    pub fn into_messages(self, for_user: &str) -> Vec<Message> {
impl IntoProtocol for ChannelWhoList {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        let mut out = Vec::with_capacity(self.nick_list.len());

        for (perm, conn) in self.nick_list {
@@ -233,18 +242,17 @@ pub enum ChannelJoinRejectionReason {
    Banned,
}

impl ChannelJoinRejectionReason {
    #[must_use]
    pub fn into_message(self) -> Message {
impl IntoProtocol for ChannelJoinRejectionReason {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        match self {
            Self::Banned => Message {
            Self::Banned => vec![Message {
                tags: None,
                prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
                command: Command::Response(
                    Response::ERR_BANNEDFROMCHAN,
                    vec!["Cannot join channel (+b)".to_string()],
                    vec![for_user.to_string(), "Cannot join channel (+b)".to_string()],
                ),
            },
            }],
        }
    }
}
diff --git a/src/client.rs b/src/client.rs
index 7afc711..b00f71c 100644
--- a/src/client.rs
+++ b/src/client.rs
@@ -36,7 +36,7 @@ use crate::{
        Persistence,
    },
    server::{
        response::{NoSuchNick, WhoList},
        response::{IntoProtocol, WhoList},
        Server,
    },
    SERVER_NAME,
@@ -159,10 +159,9 @@ impl Client {
        ctx: &mut Context<Self>,
        channel: &Addr<Channel>,
        message: M,
        map: impl FnOnce(M::Result, &Self) -> Vec<Message> + 'static,
    ) where
        M: actix::Message + Send + 'static,
        M::Result: Send,
        M::Result: Send + IntoProtocol,
        Channel: Handler<M>,
        <Channel as Actor>::Context: ToEnvelope<Channel, M>,
    {
@@ -170,21 +169,17 @@ impl Client {
            .send(message)
            .into_actor(self)
            .map(move |result, ref mut this, _ctx| {
                for message in (map)(result.unwrap(), this) {
                for message in result.unwrap().into_messages(&this.connection.nick) {
                    this.writer.write(message);
                }
            });
        ctx.spawn(fut);
    }

    fn server_send_map_write<M>(
        &self,
        ctx: &mut Context<Self>,
        message: M,
        map: impl FnOnce(M::Result, &Self) -> Vec<Message> + 'static,
    ) where
    fn server_send_map_write<M>(&self, ctx: &mut Context<Self>, message: M)
    where
        M: actix::Message + Send + 'static,
        M::Result: Send,
        M::Result: Send + IntoProtocol,
        Server: Handler<M>,
        <Server as Actor>::Context: ToEnvelope<Server, M>,
    {
@@ -193,7 +188,7 @@ impl Client {
                .send(message)
                .into_actor(self)
                .map(move |result, ref mut this, _ctx| {
                    for message in (map)(result.unwrap(), this) {
                    for message in result.unwrap().into_messages(&this.connection.nick) {
                        this.writer.write(message);
                    }
                });
@@ -304,7 +299,7 @@ impl Handler<ForceDisconnect> for Client {

    fn handle(&mut self, _msg: ForceDisconnect, ctx: &mut Self::Context) -> Self::Result {
        ctx.stop();
        MessageResult(true)
        MessageResult(Ok(()))
    }
}

@@ -450,7 +445,9 @@ impl Handler<JoinChannelRequest> for Client {
                    Ok(v) => v,
                    Err(error) => {
                        error!(?error, "User failed to join channel");
                        this.writer.write(error.into_message());
                        for m in error.into_messages(&this.connection.nick) {
                            this.writer.write(m);
                        }
                        continue;
                    }
                };
@@ -718,8 +715,10 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
                    self.channel_send_map_write(
                        ctx,
                        channel,
                        ChannelFetchTopic { span },
                        |res, this| res.into_messages(this.connection.nick.to_string(), false),
                        ChannelFetchTopic {
                            span,
                            skip_on_none: false,
                        },
                    );
                }
            }
@@ -740,9 +739,7 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
            }
            Command::LIST(_, _) => {
                let span = Span::current();
                self.server_send_map_write(ctx, ChannelList { span }, |res, this| {
                    res.into_messages(this.connection.nick.to_string())
                });
                self.server_send_map_write(ctx, ChannelList { span });
            }
            Command::INVITE(nick, channel) => {
                let Some(channel) = self.channels.get(&channel) else {
@@ -800,15 +797,11 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
            }
            Command::MOTD(_) => {
                let span = Span::current();
                self.server_send_map_write(ctx, ServerFetchMotd { span }, |res, this| {
                    res.into_messages(this.connection.nick.to_string())
                });
                self.server_send_map_write(ctx, ServerFetchMotd { span });
            }
            Command::LUSERS(_, _) => {
                let span = Span::current();
                self.server_send_map_write(ctx, ServerListUsers { span }, |res, this| {
                    res.into_messages(&this.connection.nick)
                });
                self.server_send_map_write(ctx, ServerListUsers { span });
            }
            Command::VERSION(_) => {
                self.writer.write(Message {
@@ -843,9 +836,7 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
            }
            Command::ADMIN(_) => {
                let span = Span::current();
                self.server_send_map_write(ctx, ServerAdminInfo { span }, |res, this| {
                    res.into_messages(&this.connection.nick)
                });
                self.server_send_map_write(ctx, ServerAdminInfo { span });
            }
            Command::INFO(_) => {
                static INFO_STR: &str = include_str!("../text/info.txt");
@@ -874,15 +865,11 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
            }
            Command::WHO(Some(query), _) => {
                let span = Span::current();
                self.server_send_map_write(ctx, FetchWhoList { span, query }, |res, this| {
                    res.into_messages(&this.connection.nick)
                });
                self.server_send_map_write(ctx, FetchWhoList { span, query });
            }
            Command::WHOIS(Some(query), _) => {
                let span = Span::current();
                self.server_send_map_write(ctx, FetchWhois { span, query }, |res, this| {
                    res.into_messages(&this.connection.nick)
                });
                self.server_send_map_write(ctx, FetchWhois { span, query });
            }
            Command::WHOWAS(_, _, _) => {}
            Command::KILL(nick, comment) => {
@@ -936,16 +923,9 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
                    ctx,
                    ForceDisconnect {
                        span,
                        user: user.to_string(),
                        user,
                        comment,
                    },
                    move |res, this| {
                        if res {
                            vec![]
                        } else {
                            NoSuchNick { nick: user }.into_messages(&this.connection.nick)
                        }
                    },
                );
            }
            Command::AUTHENTICATE(_) => {
diff --git a/src/messages.rs b/src/messages.rs
index ab1aec6..dcacfbc 100644
--- a/src/messages.rs
+++ b/src/messages.rs
@@ -7,6 +7,7 @@ use crate::{
    channel::Channel,
    client::Client,
    connection::{InitiatedConnection, UserId},
    server::response::NoSuchNick,
};

/// Sent when a user is connecting to the server.
@@ -38,7 +39,7 @@ pub struct KillUser {
}

#[derive(Message, Clone)]
#[rtype(result = "bool")]
#[rtype(result = "Result<(), NoSuchNick>")]
pub struct ForceDisconnect {
    pub span: Span,
    pub user: String,
@@ -151,6 +152,7 @@ pub struct FetchUserPermission {
#[rtype(result = "super::channel::response::ChannelTopic")]
pub struct ChannelFetchTopic {
    pub span: Span,
    pub skip_on_none: bool,
}

/// Retrieves the WHO list for the channel.
diff --git a/src/server.rs b/src/server.rs
index c3c386d..c762b0c 100644
--- a/src/server.rs
+++ b/src/server.rs
@@ -31,7 +31,7 @@ use crate::{
        UserNickChangeInternal, Wallops,
    },
    persistence::Persistence,
    server::response::{AdminInfo, ListUsers, Motd, WhoList, Whois},
    server::response::{AdminInfo, IntoProtocol, ListUsers, Motd, NoSuchNick, WhoList, Whois},
    SERVER_NAME,
};

@@ -124,7 +124,7 @@ impl Handler<UserConnected> for Server {
            });
        }

        for message in Motd::new(self).into_messages(msg.connection.nick.clone()) {
        for message in Motd::new(self).into_messages(&msg.connection.nick) {
            msg.handle.do_send(Broadcast {
                span: Span::current(),
                message,
@@ -313,9 +313,9 @@ impl Handler<ForceDisconnect> for Server {
    fn handle(&mut self, msg: ForceDisconnect, _ctx: &mut Self::Context) -> Self::Result {
        if let Some((handle, _)) = self.clients.iter().find(|(_, v)| v.nick == msg.user) {
            handle.do_send(msg);
            MessageResult(true)
            MessageResult(Ok(()))
        } else {
            MessageResult(false)
            MessageResult(Err(NoSuchNick { nick: msg.user }))
        }
    }
}
@@ -371,6 +371,7 @@ impl Handler<ChannelList> for Server {
            .map(|channel| {
                let fetch_topic = channel.send(ChannelFetchTopic {
                    span: Span::current(),
                    skip_on_none: true,
                });

                let fetch_members = channel.send(ChannelMemberList {
diff --git a/src/server/response.rs b/src/server/response.rs
index 3f75c04..8c38ce7 100644
--- a/src/server/response.rs
+++ b/src/server/response.rs
@@ -11,9 +11,8 @@ pub struct Whois {
    pub channels: Vec<(Permission, String)>,
}

impl Whois {
    #[must_use]
    pub fn into_messages(self, for_user: &str) -> Vec<Message> {
impl IntoProtocol for Whois {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        macro_rules! msg {
            ($response:ident, $($payload:expr),*) => {

@@ -88,9 +87,8 @@ pub struct NoSuchNick {
    pub nick: String,
}

impl NoSuchNick {
    #[must_use]
    pub fn into_messages(self, for_user: &str) -> Vec<Message> {
impl IntoProtocol for NoSuchNick {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        vec![Message {
            tags: None,
            prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
@@ -108,9 +106,8 @@ pub struct WhoList {
    pub query: String,
}

impl WhoList {
    #[must_use]
    pub fn into_messages(self, for_user: &str) -> Vec<Message> {
impl IntoProtocol for WhoList {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        let mut out: Vec<_> = self
            .list
            .into_iter()
@@ -140,9 +137,8 @@ pub struct AdminInfo {
    pub email: String,
}

impl AdminInfo {
    #[must_use]
    pub fn into_messages(self, for_user: &str) -> Vec<Message> {
impl IntoProtocol for AdminInfo {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        macro_rules! msg {
            ($response:ident, $($payload:expr),*) => {

@@ -177,9 +173,9 @@ pub struct ListUsers {
    pub channels_formed: usize,
}

impl ListUsers {
impl IntoProtocol for ListUsers {
    #[must_use]
    pub fn into_messages(self, for_user: &str) -> Vec<Message> {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        macro_rules! msg {
            ($response:ident, $($payload:expr),*) => {

@@ -253,11 +249,15 @@ impl Motd {
            motd: server.config.motd.clone(),
        }
    }
}

impl IntoProtocol for Motd {
    #[must_use]
    pub fn into_messages(self, for_user: String) -> Vec<Message> {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        let mut out = Vec::new();

        if let Some(motd) = self.motd {
            let mut motd_messages = vec![Message {
            out.push(Message {
                tags: None,
                prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
                command: Command::Response(
@@ -267,9 +267,9 @@ impl Motd {
                        format!("- {SERVER_NAME} Message of the day -"),
                    ],
                ),
            }];
            });

            motd_messages.extend(motd.trim().split('\n').map(|v| Message {
            out.extend(motd.trim().split('\n').map(|v| Message {
                tags: None,
                prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
                command: Command::Response(
@@ -278,26 +278,26 @@ impl Motd {
                ),
            }));

            motd_messages.push(Message {
            out.push(Message {
                tags: None,
                prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
                command: Command::Response(
                    Response::RPL_ENDOFMOTD,
                    vec![for_user, "End of /MOTD command.".to_string()],
                    vec![for_user.to_string(), "End of /MOTD command.".to_string()],
                ),
            });

            motd_messages
        } else {
            vec![Message {
            out.push(Message {
                tags: None,
                prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
                command: Command::Response(
                    Response::ERR_NOMOTD,
                    vec![for_user, "MOTD File is missing".to_string()],
                    vec![for_user.to_string(), "MOTD File is missing".to_string()],
                ),
            }]
            });
        }

        out
    }
}

@@ -306,9 +306,9 @@ pub struct ChannelList {
    pub members: Vec<ChannelListItem>,
}

impl ChannelList {
impl IntoProtocol for ChannelList {
    #[must_use]
    pub fn into_messages(self, for_user: String) -> Vec<Message> {
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        let mut messages = Vec::with_capacity(self.members.len() + 2);

        messages.push(Message {
@@ -345,7 +345,7 @@ impl ChannelList {
            prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
            command: Command::Response(
                Response::RPL_LISTEND,
                vec![for_user, "End of /LIST".to_string()],
                vec![for_user.to_string(), "End of /LIST".to_string()],
            ),
        });

@@ -358,3 +358,27 @@ pub struct ChannelListItem {
    pub client_count: usize,
    pub topic: Option<String>,
}

pub trait IntoProtocol {
    #[must_use]
    fn into_messages(self, for_user: &str) -> Vec<Message>;
}

impl IntoProtocol for () {
    fn into_messages(self, _for_user: &str) -> Vec<Message> {
        vec![]
    }
}

impl<T, E> IntoProtocol for Result<T, E>
where
    T: IntoProtocol,
    E: IntoProtocol,
{
    fn into_messages(self, for_user: &str) -> Vec<Message> {
        match self {
            Ok(v) => v.into_messages(for_user),
            Err(e) => e.into_messages(for_user),
        }
    }
}