From 18d9211dd79eaf6affbd334aa00d031993496b31 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Wed, 31 Jan 2024 03:23:59 +0000 Subject: [PATCH] Implement ban listing --- src/channel.rs | 29 +++++++++++++++++++++++------ src/channel/response.rs | 47 +++++++++++++++++++++++++++++++++++++++++++++++ src/client.rs | 14 +++++++++----- src/host_mask.rs | 40 ++++++++++++++++++++++++++++++++++++++++ src/messages.rs | 2 +- src/server/response.rs | 6 ++++++ 6 files changed, 126 insertions(+), 12 deletions(-) diff --git a/src/channel.rs b/src/channel.rs index 09722f6..26b6473 100644 --- a/src/channel.rs +++ b/src/channel.rs @@ -16,8 +16,8 @@ use crate::{ channel::{ permissions::Permission, response::{ - ChannelInviteResult, ChannelJoinRejectionReason, ChannelNamesList, ChannelTopic, - ChannelWhoList, MissingPrivileges, + BanList, ChannelInviteResult, ChannelJoinRejectionReason, ChannelNamesList, + ChannelTopic, ChannelWhoList, MissingPrivileges, ModeList, }, }, client::Client, @@ -241,12 +241,12 @@ impl Handler for Channel { } impl Handler for Channel { - type Result = (); + type Result = MessageResult; #[instrument(parent = &msg.span, skip_all)] fn handle(&mut self, msg: ChannelSetMode, ctx: &mut Self::Context) -> Self::Result { let Some(client) = self.clients.get(&msg.client).cloned() else { - return; + return MessageResult(None); }; for mode in msg.modes { @@ -258,9 +258,24 @@ impl Handler for Channel { if let Ok(user_mode) = Permission::try_from(channel_mode) { let Some(affected_mask) = arg else { - // TODO: return error to caller + if add && matches!(user_mode, Permission::Ban) { + // list is readable and the user didn't supply a mask, so + // return the list + let bans = BanList { + channel: self.name.to_string(), + list: self + .permissions + .iter() + .filter(|(_, v)| matches!(v, Permission::Ban)) + .map(|(k, _)| k) + .collect(), + }; + + return MessageResult(Some(ModeList::Ban(bans))); + } + error!("No user given"); - continue; + break; }; let Ok(affected_mask) = HostMask::try_from(affected_mask.as_str()) else { @@ -280,6 +295,8 @@ impl Handler for Channel { // TODO } } + + MessageResult(None) } } diff --git a/src/channel/response.rs b/src/channel/response.rs index 6662f4c..0ef451f 100644 --- a/src/channel/response.rs +++ b/src/channel/response.rs @@ -1,3 +1,5 @@ +use std::iter::once; + use irc_proto::{Command, Message, Prefix, Response}; use itertools::Itertools; @@ -124,6 +126,51 @@ impl IntoProtocol for ChannelWhoList { } } +pub enum ModeList { + Ban(BanList), +} + +impl IntoProtocol for ModeList { + fn into_messages(self, for_user: &str) -> Vec { + match self { + Self::Ban(l) => l.into_messages(for_user), + } + } +} + +pub struct BanList { + pub channel: String, + pub list: Vec, +} + +impl IntoProtocol for BanList { + fn into_messages(self, for_user: &str) -> Vec { + self.list + .into_iter() + .map(|mask| Message { + tags: None, + prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), + command: Command::Response( + Response::RPL_BANLIST, + vec![for_user.to_string(), self.channel.to_string(), mask], + ), + }) + .chain(once(Message { + tags: None, + prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())), + command: Command::Response( + Response::RPL_ENDOFBANLIST, + vec![ + for_user.to_string(), + self.channel.to_string(), + "End of channel ban list".to_string(), + ], + ), + })) + .collect() + } +} + pub struct ChannelNamesList { pub channel_name: String, pub nick_list: Vec<(Permission, InitiatedConnection)>, diff --git a/src/client.rs b/src/client.rs index def7985..192f31c 100644 --- a/src/client.rs +++ b/src/client.rs @@ -690,11 +690,15 @@ impl StreamHandler> for Client { return; }; - channel.do_send(ChannelSetMode { - span: Span::current(), - client: ctx.address(), - modes, - }); + self.channel_send_map_write( + ctx, + channel, + ChannelSetMode { + span: Span::current(), + client: ctx.address(), + modes, + }, + ); } Command::TOPIC(channel, topic) => { let Some(channel) = self.channels.get(&channel) else { diff --git a/src/host_mask.rs b/src/host_mask.rs index 0782f99..8d17463 100644 --- a/src/host_mask.rs +++ b/src/host_mask.rs @@ -3,9 +3,11 @@ use std::{ collections::HashMap, fmt::{Display, Formatter}, io::{Error, ErrorKind}, + iter::once, str::FromStr, }; +use itertools::Either; use sqlx::{ database::{HasArguments, HasValueRef}, encode::IsNull, @@ -31,6 +33,36 @@ impl HostMaskMap { } } + pub fn iter(&self) -> impl Iterator { + self.iter_inner(String::new(), self.matcher) + } + + fn iter_inner(&self, s: String, last_seen: Matcher) -> impl Iterator { + self.children + .iter() + .flat_map(move |(k, v)| { + let (k, next_matcher) = match k { + Key::Wildcard => ( + format!("{s}*{}", last_seen.splitter()), + last_seen.next().unwrap_or(last_seen), + ), + Key::EndOfString => ( + format!("{s}{}", last_seen.splitter()), + last_seen.next().unwrap_or(last_seen), + ), + Key::Char(c) => (format!("{s}{c}"), last_seen), + }; + + match v { + Node::Match(v) => Either::Left(once((k, v))), + Node::Inner(v) => Either::Right(v.iter_inner(k, next_matcher)), + } + }) + // TODO + .collect::>() + .into_iter() + } + #[must_use] pub fn is_empty(&self) -> bool { self.children.is_empty() @@ -177,6 +209,14 @@ impl Matcher { Self::Host => None, } } + + const fn splitter(self) -> &'static str { + match self { + Self::Nick => "!", + Self::Username => "@", + Self::Host => "", + } + } } #[derive(Clone, Debug)] diff --git a/src/messages.rs b/src/messages.rs index 59212c0..99349b4 100644 --- a/src/messages.rs +++ b/src/messages.rs @@ -165,7 +165,7 @@ pub struct ChannelFetchWhoList { /// Sets the given modes on a channel. #[derive(Message)] -#[rtype(result = "()")] +#[rtype(result = "Option")] pub struct ChannelSetMode { pub span: Span, pub client: Addr, diff --git a/src/server/response.rs b/src/server/response.rs index 2e5d3f4..6c062b0 100644 --- a/src/server/response.rs +++ b/src/server/response.rs @@ -407,6 +407,12 @@ impl IntoProtocol for () { } } +impl IntoProtocol for Option { + fn into_messages(self, for_user: &str) -> Vec { + self.map_or_else(Vec::new, |v| v.into_messages(for_user)) + } +} + impl IntoProtocol for Result where T: IntoProtocol, -- libgit2 1.7.2