Use hostmasks for channel user mode setting
Diff
migrations/2023010814480_initial-schema.sql | 9 ++-
src/channel.rs | 98 ++++++++++++++--------------
src/channel/permissions.rs | 25 ++++++-
src/channel/response.rs | 6 +-
src/client.rs | 8 +--
src/connection.rs | 10 ++-
src/host_mask.rs | 101 +++++++++++++++++++++++++++--
src/messages.rs | 3 +-
src/persistence.rs | 29 +++-----
src/persistence/events.rs | 7 +--
src/server.rs | 3 +-
11 files changed, 210 insertions(+), 89 deletions(-)
@@ -32,7 +32,6 @@ CREATE TABLE channel_messages (
CREATE TABLE channel_users (
channel INT NOT NULL,
user INT NOT NULL,
permissions INT NOT NULL DEFAULT 0,
in_channel BOOLEAN DEFAULT false,
last_seen_message_timestamp INT,
FOREIGN KEY(user) REFERENCES users(id),
@@ -41,6 +40,14 @@ CREATE TABLE channel_users (
PRIMARY KEY(channel, user)
);
CREATE TABLE channel_permissions (
channel INT NOT NULL,
mask VARCHAR(255),
permissions INT NOT NULL DEFAULT 0,
FOREIGN KEY(channel) REFERENCES channels(id),
PRIMARY KEY(channel, mask)
);
CREATE TABLE private_messages (
timestamp INT NOT NULL PRIMARY KEY,
sender VARCHAR(255) NOT NULL,
@@ -21,7 +21,8 @@ use crate::{
},
},
client::Client,
connection::{Capability, InitiatedConnection, UserId},
connection::{Capability, InitiatedConnection},
host_mask::{HostMask, HostMaskMap},
messages::{
Broadcast, ChannelFetchTopic, ChannelFetchWhoList, ChannelInvite, ChannelJoin,
ChannelKickUser, ChannelMemberList, ChannelMessage, ChannelPart, ChannelSetMode,
@@ -43,7 +44,7 @@ pub struct ChannelId(pub i64);
pub struct Channel {
pub name: String,
pub server: Addr<Server>,
pub permissions: HashMap<UserId, Permission>,
pub permissions: HostMaskMap<Permission>,
pub clients: HashMap<Addr<Client>, InitiatedConnection>,
pub topic: Option<CurrentChannelTopic>,
pub persistence: Addr<Persistence>,
@@ -95,10 +96,12 @@ impl Supervised for Channel {}
impl Channel {
#[must_use]
pub fn get_user_permissions(&self, user_id: UserId) -> Permission {
pub fn get_user_permissions(&self, host_mask: &HostMask<'_>) -> Permission {
self.permissions
.get(&user_id)
.get(host_mask)
.into_iter()
.copied()
.max()
.unwrap_or(Permission::Normal)
}
}
@@ -139,7 +142,7 @@ impl Handler<FetchUserPermission> for Channel {
type Result = MessageResult<FetchUserPermission>;
fn handle(&mut self, msg: FetchUserPermission, _ctx: &mut Self::Context) -> Self::Result {
MessageResult(self.get_user_permissions(msg.user))
MessageResult(self.get_user_permissions(&msg.host_mask))
}
}
@@ -166,7 +169,10 @@ impl Handler<ChannelMessage> for Channel {
return;
};
if !self.get_user_permissions(sender.user_id).can_chatter() {
if !self
.get_user_permissions(&sender.to_host_mask())
.can_chatter()
{
msg.client.do_send(Broadcast {
message: Message {
tags: None,
@@ -251,15 +257,22 @@ impl Handler<ChannelSetMode> for Channel {
};
if let Ok(user_mode) = Permission::try_from(channel_mode) {
let Some(affected_nick) = arg else {
let Some(affected_mask) = arg else {
error!("No user given");
continue;
};
let Ok(affected_mask) = HostMask::try_from(affected_mask.as_str()) else {
error!("Invalid mask");
continue;
};
ctx.notify(SetUserMode {
requester: client.clone(),
add,
affected_nick,
affected_mask: affected_mask.into_owned(),
user_mode,
span: Span::current(),
});
@@ -280,20 +293,10 @@ impl Handler<SetUserMode> for Channel {
#[instrument(parent = &msg.span, skip_all)]
fn handle(&mut self, msg: SetUserMode, ctx: &mut Self::Context) -> Self::Result {
let permissions = self.get_user_permissions(msg.requester.user_id);
let Some((_, affected_user)) = self
.clients
.iter()
.find(|(_, connection)| connection.nick == msg.affected_nick)
else {
error!("Unknown user to set perms on");
return;
};
let permissions = self.get_user_permissions(&msg.requester.to_host_mask());
let affected_user_perms = self.get_user_permissions(affected_user.user_id);
let affected_user_perms = self.get_user_permissions(&msg.affected_mask);
let new_affected_user_perms = if msg.add {
@@ -319,35 +322,28 @@ impl Handler<SetUserMode> for Channel {
self.permissions
.insert(affected_user.user_id, new_affected_user_perms);
.insert(&msg.affected_mask, new_affected_user_perms);
self.persistence.do_send(SetUserChannelPermissions {
channel_id: self.channel_id,
user_id: affected_user.user_id,
mask: msg.affected_mask.clone().into_owned(),
permissions: new_affected_user_perms,
});
let all_connected_for_user_id = self
.clients
.values()
.filter(|connection| connection.user_id == affected_user.user_id);
for connection in all_connected_for_user_id {
let Some(mode) = msg
.user_mode
.into_mode(msg.add, connection.nick.to_string())
else {
continue;
};
let Some(mode) = msg
.user_mode
.into_mode(msg.add, msg.affected_mask.to_string())
else {
return;
};
ctx.notify(Broadcast {
message: Message {
tags: None,
prefix: Some(connection.to_nick()),
command: Command::ChannelMODE(self.name.to_string(), vec![mode.clone()]),
},
span: Span::current(),
});
}
ctx.notify(Broadcast {
message: Message {
tags: None,
prefix: Some(msg.requester.to_nick()),
command: Command::ChannelMODE(self.name.to_string(), vec![mode]),
},
span: Span::current(),
});
}
}
@@ -383,8 +379,10 @@ impl Handler<ChannelJoin> for Channel {
let mut permissions = self
.permissions
.get(&msg.connection.user_id)
.get(&msg.connection.to_host_mask())
.into_iter()
.copied()
.max()
.unwrap_or(Permission::Normal);
if !permissions.can_join() {
@@ -405,11 +403,13 @@ impl Handler<ChannelJoin> for Channel {
permissions = Permission::Founder;
self.permissions.insert(msg.connection.user_id, permissions);
let username_mask = HostMask::new("*", &msg.connection.user, "*");
self.permissions.insert(&username_mask, permissions);
self.persistence.do_send(SetUserChannelPermissions {
channel_id: self.channel_id,
user_id: msg.connection.user_id,
mask: username_mask.into_owned(),
permissions,
});
}
@@ -478,7 +478,7 @@ impl Handler<ChannelUpdateTopic> for Channel {
debug!(msg.topic, "User is attempting to update channel topic");
if !self
.get_user_permissions(client_info.user_id)
.get_user_permissions(&client_info.to_host_mask())
.can_set_topic()
{
error!("User attempted to set channel topic without privileges");
@@ -517,7 +517,7 @@ impl Handler<ChannelKickUser> for Channel {
return;
};
if !self.get_user_permissions(kicker.user_id).can_kick() {
if !self.get_user_permissions(&kicker.to_host_mask()).can_kick() {
error!("Kicker can not kick people from the channel");
msg.client.do_send(Broadcast {
message: MissingPrivileges(kicker.to_nick(), self.name.to_string()).into_message(),
@@ -700,7 +700,7 @@ pub struct CurrentChannelTopic {
pub struct SetUserMode {
requester: InitiatedConnection,
add: bool,
affected_nick: String,
affected_mask: HostMask<'static>,
user_mode: Permission,
span: Span,
}
@@ -1,3 +1,5 @@
use std::cmp::Ordering;
use anyhow::anyhow;
use irc_proto::{ChannelMode, Mode};
@@ -12,6 +14,23 @@ pub enum Permission {
Founder = i16::MAX,
}
impl PartialOrd for Permission {
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for Permission {
fn cmp(&self, other: &Self) -> Ordering {
match (*self as i16, *other as i16) {
(-1, 0) => Ordering::Less,
(0, -1) => Ordering::Greater,
_ => (*self as i16).cmp(&(*other as i16)),
}
}
}
impl TryFrom<ChannelMode> for Permission {
type Error = anyhow::Error;
@@ -33,12 +52,12 @@ impl Permission {
#[must_use]
pub fn into_mode(self, add: bool, nick: String) -> Option<Mode<ChannelMode>> {
pub fn into_mode(self, add: bool, mask: String) -> Option<Mode<ChannelMode>> {
<Option<ChannelMode>>::from(self).map(|v| {
if add {
Mode::Plus(v, Some(nick))
Mode::Plus(v, Some(mask))
} else {
Mode::Minus(v, Some(nick))
Mode::Minus(v, Some(mask))
}
})
}
@@ -87,7 +87,7 @@ impl ChannelWhoList {
nick_list: channel
.clients
.values()
.map(|v| (channel.get_user_permissions(v.user_id), v.clone()))
.map(|v| (channel.get_user_permissions(&v.to_host_mask()), v.clone()))
.collect(),
}
}
@@ -109,7 +109,7 @@ impl IntoProtocol for ChannelWhoList {
for_user.to_string(),
self.channel_name.to_string(),
conn.user,
conn.host.to_string(),
conn.cloak.to_string(),
SERVER_NAME.to_string(),
conn.nick,
format!("{presence}{}", perm.into_prefix()), @@ -137,7 +137,7 @@ impl ChannelNamesList {
nick_list: channel
.clients
.values()
.map(|v| (channel.get_user_permissions(v.user_id), v.clone()))
.map(|v| (channel.get_user_permissions(&v.to_host_mask()), v.clone()))
.collect(),
}
}
@@ -270,19 +270,17 @@ impl Handler<ConnectedChannels> for Client {
#[instrument(parent = &msg.span, skip_all)]
fn handle(&mut self, msg: ConnectedChannels, _ctx: &mut Self::Context) -> Self::Result {
let span = Span::current();
let user_id = self.connection.user_id;
let host_mask = self.connection.to_host_mask().into_owned();
let fut = self.channels.iter().map(move |(channel_name, handle)| {
let span = span.clone();
let channel_name = channel_name.to_string();
let handle = handle.clone();
let host_mask = host_mask.clone();
async move {
let permission = handle
.send(FetchUserPermission {
span,
user: user_id,
})
.send(FetchUserPermission { span, host_mask })
.await
.unwrap();
@@ -29,6 +29,7 @@ use crate::{
authenticate::{Authenticate, AuthenticateMessage, AuthenticateResult},
sasl::{AuthStrategy, ConnectionSuccess, SaslSuccess},
},
host_mask::HostMask,
persistence::{events::ReserveNick, Persistence},
};
@@ -52,6 +53,7 @@ pub struct ConnectionRequest {
#[derive(Clone, Debug)]
pub struct InitiatedConnection {
pub host: SocketAddr,
pub cloak: String,
pub nick: String,
pub user: String,
pub mode: UserMode,
@@ -68,9 +70,14 @@ impl InitiatedConnection {
Prefix::Nickname(
self.nick.to_string(),
self.user.to_string(),
self.host.ip().to_string(),
self.cloak.to_string(),
)
}
#[must_use]
pub fn to_host_mask(&self) -> HostMask<'_> {
HostMask::new(&self.nick, &self.user, &self.cloak)
}
}
impl TryFrom<ConnectionRequest> for InitiatedConnection {
@@ -91,6 +98,7 @@ impl TryFrom<ConnectionRequest> for InitiatedConnection {
Ok(Self {
host,
cloak: host.ip().to_string(),
nick,
user,
mode: UserMode::empty(),
@@ -1,7 +1,16 @@
use std::{
borrow::Cow,
collections::HashMap,
fmt::{Display, Formatter},
io::{Error, ErrorKind},
str::FromStr,
};
use sqlx::{
database::{HasArguments, HasValueRef},
encode::IsNull,
error::BoxDynError,
Database, Decode, Encode, Type,
};
@@ -22,6 +31,11 @@ impl<T> HostMaskMap<T> {
}
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.children.is_empty()
}
pub fn insert(&mut self, mask: &HostMask<'_>, value: T) {
@@ -89,6 +103,18 @@ impl<T> HostMaskMap<T> {
}
}
impl<'a, T> FromIterator<(HostMask<'a>, T)> for HostMaskMap<T> {
fn from_iter<I: IntoIterator<Item = (HostMask<'a>, T)>>(iter: I) -> Self {
let mut out = Self::new();
for (k, v) in iter {
out.insert(&k, v);
}
out
}
}
impl<T> Default for HostMaskMap<T> {
fn default() -> Self {
Self::new()
@@ -153,6 +179,7 @@ impl Matcher {
}
}
#[derive(Clone, Debug)]
pub struct HostMask<'a> {
nick: Cow<'a, str>,
username: Cow<'a, str>,
@@ -161,6 +188,15 @@ pub struct HostMask<'a> {
impl<'a> HostMask<'a> {
#[must_use]
pub const fn new(nick: &'a str, username: &'a str, host: &'a str) -> Self {
Self {
nick: Cow::Borrowed(nick),
username: Cow::Borrowed(username),
host: Cow::Borrowed(host),
}
}
#[must_use]
pub fn as_borrowed(&'a self) -> Self {
Self {
nick: Cow::Borrowed(self.nick.as_ref()),
@@ -168,18 +204,71 @@ impl<'a> HostMask<'a> {
host: Cow::Borrowed(self.host.as_ref()),
}
}
#[must_use]
pub fn into_owned(self) -> HostMask<'static> {
HostMask {
nick: Cow::Owned(self.nick.into_owned()),
username: Cow::Owned(self.username.into_owned()),
host: Cow::Owned(self.host.into_owned()),
}
}
}
impl Display for HostMask<'_> {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{}!{}@{}", self.nick, self.username, self.host)
}
}
impl<'a, DB> Type<DB> for HostMask<'a>
where
String: Type<DB>,
DB: Database,
{
fn type_info() -> DB::TypeInfo {
String::type_info()
}
fn compatible(ty: &DB::TypeInfo) -> bool {
String::compatible(ty)
}
}
impl<'a, 'q, DB> Encode<'q, DB> for HostMask<'a>
where
String: Encode<'q, DB>,
DB: Database,
{
fn encode_by_ref(&self, buf: &mut <DB as HasArguments<'q>>::ArgumentBuffer) -> IsNull {
self.to_string().encode(buf)
}
}
impl<'r, DB> Decode<'r, DB> for HostMask<'static>
where
&'r str: Decode<'r, DB>,
DB: Database,
{
fn decode(value: <DB as HasValueRef<'r>>::ValueRef) -> Result<Self, BoxDynError> {
Ok(<&'r str as Decode<'r, DB>>::decode(value)?.parse()?)
}
}
impl FromStr for HostMask<'static> {
type Err = Error;
fn from_str(s: &str) -> Result<Self, Self::Err> {
HostMask::try_from(s).map(HostMask::into_owned)
}
}
impl<'a> TryFrom<&'a str> for HostMask<'a> {
type Error = Error;
fn try_from(rest: &'a str) -> Result<Self, Self::Error> {
let (nick, rest) = rest
.split_once('!')
.ok_or_else(|| Error::new(ErrorKind::Other, "missing nick separator"))?;
let (username, host) = rest
.split_once('@')
.ok_or_else(|| Error::new(ErrorKind::Other, "missing host separator"))?;
let (nick, rest) = rest.split_once('!').unwrap_or((rest, ""));
let (username, host) = rest.split_once('@').unwrap_or(("*", "*"));
let is_invalid = |v: &str| {
(v.contains('*') && !v.ends_with('*'))
@@ -7,6 +7,7 @@ use crate::{
channel::Channel,
client::Client,
connection::{InitiatedConnection, UserId},
host_mask::HostMask,
server::response::NoSuchNick,
};
@@ -144,7 +145,7 @@ pub struct ChannelMemberList {
#[rtype(result = "crate::channel::permissions::Permission")]
pub struct FetchUserPermission {
pub span: Span,
pub user: UserId,
pub host_mask: HostMask<'static>,
}
@@ -1,6 +1,6 @@
pub mod events;
use std::{collections::HashMap, time::Duration};
use std::time::Duration;
use actix::{AsyncContext, Context, Handler, ResponseFuture, WrapFuture};
use chrono::{DateTime, TimeZone, Utc};
@@ -10,6 +10,7 @@ use tracing::instrument;
use crate::{
channel::permissions::Permission,
connection::UserId,
host_mask::{HostMask, HostMaskMap},
messages::MessageKind,
persistence::events::{
ChannelCreated, ChannelJoined, ChannelMessage, ChannelParted,
@@ -91,13 +92,12 @@ impl Handler<ChannelJoined> for Persistence {
Box::pin(async move {
sqlx::query(
"INSERT INTO channel_users (channel, user, permissions, in_channel)
VALUES (?, ?, ?, ?)
"INSERT INTO channel_users (channel, user, in_channel)
VALUES (?, ?, ?)
ON CONFLICT(channel, user) DO UPDATE SET in_channel = excluded.in_channel",
)
.bind(msg.channel_id.0)
.bind(msg.user_id.0)
.bind(0i32)
.bind(true)
.execute(&conn)
.await
@@ -131,7 +131,7 @@ impl Handler<ChannelParted> for Persistence {
}
impl Handler<FetchAllUserChannelPermissions> for Persistence {
type Result = ResponseFuture<HashMap<UserId, Permission>>;
type Result = ResponseFuture<HostMaskMap<Permission>>;
fn handle(
&mut self,
@@ -141,9 +141,9 @@ impl Handler<FetchAllUserChannelPermissions> for Persistence {
let conn = self.database.clone();
Box::pin(async move {
sqlx::query_as::<_, (UserId, Permission)>(
"SELECT user, permissions
FROM channel_users
sqlx::query_as::<_, (HostMask, Permission)>(
"SELECT mask, permissions
FROM channel_permissions
WHERE channel = ?",
)
.bind(msg.channel_id.0)
@@ -164,14 +164,13 @@ impl Handler<SetUserChannelPermissions> for Persistence {
Box::pin(async move {
sqlx::query(
"UPDATE channel_users
SET permissions = ?
WHERE user = ?
AND channel = ?",
"INSERT INTO channel_permissions (channel, mask, permissions)
VALUES (?, ?, ?)
ON CONFLICT(channel, mask) DO UPDATE SET permissions = excluded.permissions",
)
.bind(msg.permissions)
.bind(msg.user_id.0)
.bind(msg.channel_id.0)
.bind(msg.mask)
.bind(msg.permissions)
.execute(&conn)
.await
.unwrap();
@@ -397,7 +396,7 @@ impl Handler<ReserveNick> for Persistence {
pub async fn truncate_seen_messages(db: sqlx::Pool<sqlx::Any>, max_replay_since: Duration) {
let messages = sqlx::query_as::<_, (i64, i64)>(
"SELECT channel, MIN(last_seen_message_timestamp)
"SELECT channel, COALESCE(MIN(last_seen_message_timestamp), 0)
FROM channel_users
GROUP BY channel",
)
@@ -1,5 +1,3 @@
use std::collections::HashMap;
use actix::Message;
use chrono::{DateTime, Utc};
use tracing::Span;
@@ -7,6 +5,7 @@ use tracing::Span;
use crate::{
channel::{permissions::Permission, ChannelId},
connection::UserId,
host_mask::{HostMask, HostMaskMap},
messages::MessageKind,
};
@@ -40,7 +39,7 @@ pub struct FetchUserChannels {
}
#[derive(Message)]
#[rtype(result = "HashMap<UserId, Permission>")]
#[rtype(result = "HostMaskMap<Permission>")]
pub struct FetchAllUserChannelPermissions {
pub channel_id: ChannelId,
}
@@ -49,7 +48,7 @@ pub struct FetchAllUserChannelPermissions {
#[rtype(result = "()")]
pub struct SetUserChannelPermissions {
pub channel_id: ChannelId,
pub user_id: UserId,
pub mask: HostMask<'static>,
pub permissions: Permission,
}
@@ -23,6 +23,7 @@ use crate::{
client::Client,
config::Config,
connection::{InitiatedConnection, UserMode},
host_mask::HostMaskMap,
messages::{
Broadcast, ChannelFetchTopic, ChannelFetchWhoList, ChannelJoin, ChannelList,
ChannelMemberList, ClientAway, ConnectedChannels, FetchClientByNick, FetchWhoList,
@@ -200,7 +201,7 @@ impl Handler<ChannelJoin> for Server {
Supervisor::start_in_arbiter(&arbiter, move |_ctx| Channel {
name: channel_name,
permissions: HashMap::new(),
permissions: HostMaskMap::new(),
clients: HashMap::new(),
topic: None,
server,