Implement ban listing
Diff
src/channel.rs | 29 +++++++++++++++++++++++------
src/channel/response.rs | 47 +++++++++++++++++++++++++++++++++++++++++++++++
src/client.rs | 14 +++++++++-----
src/host_mask.rs | 40 ++++++++++++++++++++++++++++++++++++++++
src/messages.rs | 2 +-
src/server/response.rs | 6 ++++++
6 files changed, 126 insertions(+), 12 deletions(-)
@@ -16,8 +16,8 @@ use crate::{
channel::{
permissions::Permission,
response::{
ChannelInviteResult, ChannelJoinRejectionReason, ChannelNamesList, ChannelTopic,
ChannelWhoList, MissingPrivileges,
BanList, ChannelInviteResult, ChannelJoinRejectionReason, ChannelNamesList,
ChannelTopic, ChannelWhoList, MissingPrivileges, ModeList,
},
},
client::Client,
@@ -241,12 +241,12 @@ impl Handler<ChannelFetchWhoList> for Channel {
}
impl Handler<ChannelSetMode> for Channel {
type Result = ();
type Result = MessageResult<ChannelSetMode>;
#[instrument(parent = &msg.span, skip_all)]
fn handle(&mut self, msg: ChannelSetMode, ctx: &mut Self::Context) -> Self::Result {
let Some(client) = self.clients.get(&msg.client).cloned() else {
return;
return MessageResult(None);
};
for mode in msg.modes {
@@ -258,9 +258,24 @@ impl Handler<ChannelSetMode> for Channel {
if let Ok(user_mode) = Permission::try_from(channel_mode) {
let Some(affected_mask) = arg else {
if add && matches!(user_mode, Permission::Ban) {
let bans = BanList {
channel: self.name.to_string(),
list: self
.permissions
.iter()
.filter(|(_, v)| matches!(v, Permission::Ban))
.map(|(k, _)| k)
.collect(),
};
return MessageResult(Some(ModeList::Ban(bans)));
}
error!("No user given");
continue;
break;
};
let Ok(affected_mask) = HostMask::try_from(affected_mask.as_str()) else {
@@ -280,6 +295,8 @@ impl Handler<ChannelSetMode> for Channel {
}
}
MessageResult(None)
}
}
@@ -1,3 +1,5 @@
use std::iter::once;
use irc_proto::{Command, Message, Prefix, Response};
use itertools::Itertools;
@@ -124,6 +126,51 @@ impl IntoProtocol for ChannelWhoList {
}
}
pub enum ModeList {
Ban(BanList),
}
impl IntoProtocol for ModeList {
fn into_messages(self, for_user: &str) -> Vec<Message> {
match self {
Self::Ban(l) => l.into_messages(for_user),
}
}
}
pub struct BanList {
pub channel: String,
pub list: Vec<String>,
}
impl IntoProtocol for BanList {
fn into_messages(self, for_user: &str) -> Vec<Message> {
self.list
.into_iter()
.map(|mask| Message {
tags: None,
prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
command: Command::Response(
Response::RPL_BANLIST,
vec![for_user.to_string(), self.channel.to_string(), mask],
),
})
.chain(once(Message {
tags: None,
prefix: Some(Prefix::ServerName(SERVER_NAME.to_string())),
command: Command::Response(
Response::RPL_ENDOFBANLIST,
vec![
for_user.to_string(),
self.channel.to_string(),
"End of channel ban list".to_string(),
],
),
}))
.collect()
}
}
pub struct ChannelNamesList {
pub channel_name: String,
pub nick_list: Vec<(Permission, InitiatedConnection)>,
@@ -690,11 +690,15 @@ impl StreamHandler<Result<irc_proto::Message, ProtocolError>> for Client {
return;
};
channel.do_send(ChannelSetMode {
span: Span::current(),
client: ctx.address(),
modes,
});
self.channel_send_map_write(
ctx,
channel,
ChannelSetMode {
span: Span::current(),
client: ctx.address(),
modes,
},
);
}
Command::TOPIC(channel, topic) => {
let Some(channel) = self.channels.get(&channel) else {
@@ -3,9 +3,11 @@ use std::{
collections::HashMap,
fmt::{Display, Formatter},
io::{Error, ErrorKind},
iter::once,
str::FromStr,
};
use itertools::Either;
use sqlx::{
database::{HasArguments, HasValueRef},
encode::IsNull,
@@ -31,6 +33,36 @@ impl<T> HostMaskMap<T> {
}
}
pub fn iter(&self) -> impl Iterator<Item = (String, &T)> {
self.iter_inner(String::new(), self.matcher)
}
fn iter_inner(&self, s: String, last_seen: Matcher) -> impl Iterator<Item = (String, &T)> {
self.children
.iter()
.flat_map(move |(k, v)| {
let (k, next_matcher) = match k {
Key::Wildcard => (
format!("{s}*{}", last_seen.splitter()),
last_seen.next().unwrap_or(last_seen),
),
Key::EndOfString => (
format!("{s}{}", last_seen.splitter()),
last_seen.next().unwrap_or(last_seen),
),
Key::Char(c) => (format!("{s}{c}"), last_seen),
};
match v {
Node::Match(v) => Either::Left(once((k, v))),
Node::Inner(v) => Either::Right(v.iter_inner(k, next_matcher)),
}
})
.collect::<Vec<_>>()
.into_iter()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.children.is_empty()
@@ -177,6 +209,14 @@ impl Matcher {
Self::Host => None,
}
}
const fn splitter(self) -> &'static str {
match self {
Self::Nick => "!",
Self::Username => "@",
Self::Host => "",
}
}
}
#[derive(Clone, Debug)]
@@ -165,7 +165,7 @@ pub struct ChannelFetchWhoList {
#[derive(Message)]
#[rtype(result = "()")]
#[rtype(result = "Option<super::channel::response::ModeList>")]
pub struct ChannelSetMode {
pub span: Span,
pub client: Addr<Client>,
@@ -407,6 +407,12 @@ impl IntoProtocol for () {
}
}
impl<T: IntoProtocol> IntoProtocol for Option<T> {
fn into_messages(self, for_user: &str) -> Vec<Message> {
self.map_or_else(Vec::new, |v| v.into_messages(for_user))
}
}
impl<T, E> IntoProtocol for Result<T, E>
where
T: IntoProtocol,