From 906785f6f1c9ce827992d4c5b4efc4a4e09be111 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Wed, 31 Jan 2024 00:35:52 +0000 Subject: [PATCH] Add HostMaskMap --- src/host_mask.rs | 343 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ src/lib.rs | 18 ++++++++++++++++++ src/main.rs | 19 ++++--------------- 3 files changed, 365 insertions(+), 15 deletions(-) create mode 100644 src/host_mask.rs create mode 100644 src/lib.rs diff --git a/src/host_mask.rs b/src/host_mask.rs new file mode 100644 index 0000000..7513f4e --- /dev/null +++ b/src/host_mask.rs @@ -0,0 +1,343 @@ +use std::{ + borrow::Cow, + collections::HashMap, + io::{Error, ErrorKind}, +}; + +/// A map of `HostMask`s to `T`, implemented as a prefix trie with three +/// sections with support for wildcards. +#[derive(Debug)] +pub struct HostMaskMap { + children: HashMap>, + matcher: Matcher, +} + +impl HostMaskMap { + /// Instantiates a new `HostMaskMap` with a top level capacity of 0. + #[must_use] + pub fn new() -> Self { + Self { + children: HashMap::new(), + matcher: Matcher::Nick, + } + } + + /// Inserts a new mask into the tree with the given `value`. This function operates + /// in `O(m)` average time complexity + pub fn insert(&mut self, mask: &HostMask<'_>, value: T) { + let mut next_mask = mask.as_borrowed(); + + let key = match self.matcher { + Matcher::Nick => take_next_char(&mask.nick, &mut next_mask.nick), + Matcher::Username => take_next_char(&mask.username, &mut next_mask.username), + Matcher::Host => take_next_char(&mask.host, &mut next_mask.host), + }; + + let key = match key { + Some('*') => Key::Wildcard, + Some(c) => Key::Char(c), + None => Key::EndOfString, + }; + + if key.is_end() && self.matcher.next().is_none() { + self.children.insert(key, Node::Match(value)); + } else { + let node = self.children.entry(key).or_insert_with(|| { + Node::Inner(Self { + children: HashMap::new(), + matcher: if key.is_end() { + self.matcher.next().expect("guarded by if") + } else { + self.matcher + }, + }) + }); + + match node { + Node::Match(_) => unreachable!("stored hostmask has less parts than a!b@c"), + Node::Inner(map) => map.insert(&next_mask, value), + }; + } + } + + /// Fetches all the matches within the trie that match the input. This function returns + /// any exact matches as well as any wildcard matches. This function operates in `O(m)` + /// average time complexity. + #[must_use] + pub fn get(&self, mask: &HostMask<'_>) -> Vec<&T> { + self.get_inner(mask, Vec::new()) + } + + fn get_inner<'a>(&'a self, mask: &HostMask<'_>, mut out: Vec<&'a T>) -> Vec<&'a T> { + let mut next_mask = mask.as_borrowed(); + + let key = match self.matcher { + Matcher::Nick => take_next_char(&mask.nick, &mut next_mask.nick), + Matcher::Username => take_next_char(&mask.username, &mut next_mask.username), + Matcher::Host => take_next_char(&mask.host, &mut next_mask.host), + }; + + if let Some(next) = self.children.get(&key.map_or(Key::EndOfString, Key::Char)) { + out = traverse(out, next, &next_mask); + } + + if let Some(wildcard) = self.children.get(&Key::Wildcard) { + out = traverse(out, wildcard, &next_mask); + } + + out + } +} + +impl Default for HostMaskMap { + fn default() -> Self { + Self::new() + } +} + +/// Traverses the trie, appending any matches into `out` before returning. +fn traverse<'a, T>(mut out: Vec<&'a T>, node: &'a Node, mask: &HostMask<'_>) -> Vec<&'a T> { + match node { + Node::Match(v) => { + out.push(v); + } + Node::Inner(child) => { + out = child.get_inner(mask, out); + } + } + + out +} + +/// Takes a single character from `v` and updates `next` to the remaining input. +fn take_next_char<'a>(v: &'a str, next: &mut Cow<'a, str>) -> Option { + let mut chars = v.chars(); + let c = chars.next(); + *next = Cow::Borrowed(chars.as_str()); + c +} + +#[derive(Hash, PartialEq, Eq, Debug, Copy, Clone)] +enum Key { + Wildcard, + EndOfString, + Char(char), +} + +impl Key { + const fn is_end(self) -> bool { + !matches!(self, Self::Char(_)) + } +} + +#[derive(Debug)] +enum Node { + Match(T), + Inner(HostMaskMap), +} + +#[derive(Copy, Clone, Debug)] +enum Matcher { + Nick, + Username, + Host, +} + +impl Matcher { + const fn next(self) -> Option { + match self { + Self::Nick => Some(Self::Username), + Self::Username => Some(Self::Host), + Self::Host => None, + } + } +} + +pub struct HostMask<'a> { + nick: Cow<'a, str>, + username: Cow<'a, str>, + host: Cow<'a, str>, +} + +impl<'a> HostMask<'a> { + #[must_use] + pub fn as_borrowed(&'a self) -> Self { + Self { + nick: Cow::Borrowed(self.nick.as_ref()), + username: Cow::Borrowed(self.username.as_ref()), + host: Cow::Borrowed(self.host.as_ref()), + } + } +} + +impl<'a> TryFrom<&'a str> for HostMask<'a> { + type Error = Error; + + fn try_from(rest: &'a str) -> Result { + let (nick, rest) = rest + .split_once('!') + .ok_or_else(|| Error::new(ErrorKind::Other, "missing nick separator"))?; + let (username, host) = rest + .split_once('@') + .ok_or_else(|| Error::new(ErrorKind::Other, "missing host separator"))?; + + let is_invalid = |v: &str| { + (v.contains('*') && !v.ends_with('*')) + || v.chars().filter(|&c| c == '*').count() > 1 + || v.is_empty() + }; + + if is_invalid(nick) { + return Err(Error::new(ErrorKind::Other, "invalid nick")); + } else if is_invalid(username) { + return Err(Error::new(ErrorKind::Other, "invalid username")); + } else if is_invalid(host) { + return Err(Error::new(ErrorKind::Other, "invalid hostname")); + } + + Ok(Self { + nick: Cow::Borrowed(nick), + username: Cow::Borrowed(username), + host: Cow::Borrowed(host), + }) + } +} + +#[cfg(test)] +mod test { + use crate::host_mask::{HostMask, HostMaskMap}; + + #[test] + fn wildcard_middle_of_string_unsupported() { + assert!(HostMask::try_from("aa*a!bbbb@cccc").is_err()); + } + + #[test] + fn multiple_wildcards_unsupported() { + assert!(HostMask::try_from("a**!bbb@cccc").is_err()); + } + + #[test] + fn empty_key_unsupported() { + assert!(HostMask::try_from("a!@cccc").is_err()); + } + + #[test] + fn test_insert_and_get_no_wildcard() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa!bbbb@cccc".try_into().unwrap(), 10); + + let retrieved = map.get(&"aaaa!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 1); + assert_eq!(*retrieved[0], 10); + } + + #[test] + fn test_insert_with_wildcard_and_get_exact() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa!*@*".try_into().unwrap(), 20); + + let retrieved = map.get(&"aaaa!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 1); + assert_eq!(*retrieved[0], 20); + } + + #[test] + fn test_insert_with_wildcard_and_get_wildcard() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa!*@*".try_into().unwrap(), 30); + + let retrieved = map.get(&"aaaa!*@*".try_into().unwrap()); + assert_eq!(retrieved.len(), 1); + assert_eq!(*retrieved[0], 30); + } + + #[test] + fn test_insert_multiple_and_get_one() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa!bbbb@cccc".try_into().unwrap(), 40); + map.insert(&"xxxx!yyyy@zzzz".try_into().unwrap(), 50); + + let retrieved = map.get(&"aaaa!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 1); + assert_eq!(*retrieved[0], 40); + } + + #[test] + fn test_insert_and_get_multiple_matches() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa!*@*".try_into().unwrap(), 60); + map.insert(&"*!bbbb@cccc".try_into().unwrap(), 70); + + let retrieved = map.get(&"aaaa!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 2); + assert!(retrieved.contains(&&60)); + assert!(retrieved.contains(&&70)); + } + + #[test] + fn test_get_no_match() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa!bbbb@cccc".try_into().unwrap(), 80); + + let retrieved = map.get(&"xxxx!yyyy@zzzz".try_into().unwrap()); + assert_eq!(retrieved.len(), 0); + } + + #[test] + fn test_insert_with_partial_wildcard_and_get_exact() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa*!bbbb@cccc".try_into().unwrap(), 100); + + let retrieved = map.get(&"aaaa1234!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 1); + assert_eq!(*retrieved[0], 100); + } + + #[test] + fn test_insert_with_partial_wildcard_and_no_match() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa*!bbbb@cccc".try_into().unwrap(), 110); + + let retrieved = map.get(&"aaab!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 0); + } + + #[test] + fn test_insert_multiple_partial_wildcards_and_get_one() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa*!bbbb@cccc".try_into().unwrap(), 120); + map.insert(&"xxxx*!yyyy@zzzz".try_into().unwrap(), 130); + + let retrieved = map.get(&"aaaa123!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 1); + assert_eq!(*retrieved[0], 120); + } + + #[test] + fn test_insert_with_multiple_wildcard_styles_and_get_match() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa*!bbbb@cccc".try_into().unwrap(), 140); + map.insert(&"xxxx!*@*".try_into().unwrap(), 150); + + let retrieved = map.get(&"aaaa123!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 1); + assert_eq!(*retrieved[0], 140); + + let retrieved2 = map.get(&"xxxx!testyyyy@zzzz".try_into().unwrap()); + assert_eq!(retrieved2.len(), 1); + assert_eq!(*retrieved2[0], 150); + } + + #[test] + fn test_insert_with_partial_wildcard_and_get_multiple_matches() { + let mut map = HostMaskMap::new(); + map.insert(&"aaaa*!bbbb@cccc".try_into().unwrap(), 160); + map.insert(&"aaaa*!bbbb@ccc*".try_into().unwrap(), 170); + + let retrieved = map.get(&"aaaa1234!bbbb@cccc".try_into().unwrap()); + assert_eq!(retrieved.len(), 2); + assert!(retrieved.contains(&&160)); + assert!(retrieved.contains(&&170)); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..b583bd1 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,18 @@ +#![deny(clippy::nursery, clippy::pedantic)] +#![allow( + clippy::module_name_repetitions, + clippy::missing_panics_doc, + clippy::missing_errors_doc +)] + +pub mod channel; +pub mod client; +pub mod config; +pub mod connection; +pub mod database; +pub mod host_mask; +pub mod messages; +pub mod persistence; +pub mod server; + +pub const SERVER_NAME: &str = "my.cool.server"; diff --git a/src/main.rs b/src/main.rs index c706d24..da80296 100644 --- a/src/main.rs +++ b/src/main.rs @@ -15,6 +15,10 @@ use futures::SinkExt; use irc_proto::{Command, IrcCodec, Message}; use rand::seq::SliceRandom; use sqlx::migrate::Migrator; +use titanircd::{ + client::Client, config::Args, connection, messages::UserConnected, persistence::Persistence, + server::Server, +}; use tokio::{ io::WriteHalf, net::{TcpListener, TcpStream}, @@ -24,21 +28,6 @@ use tokio_util::codec::FramedRead; use tracing::{error, info, info_span, Instrument}; use tracing_subscriber::EnvFilter; -use crate::{ - client::Client, config::Args, messages::UserConnected, persistence::Persistence, server::Server, -}; - -pub mod channel; -pub mod client; -pub mod config; -pub mod connection; -pub mod database; -pub mod messages; -pub mod persistence; -pub mod server; - -pub const SERVER_NAME: &str = "my.cool.server"; - static MIGRATOR: Migrator = sqlx::migrate!(); #[actix_rt::main] -- libgit2 1.7.2