From 170a7949efb7cd3705c717f9c3ec1c030fbcf958 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Thu, 13 Jul 2023 01:11:14 +0100 Subject: [PATCH] Add basic tests for each command --- Cargo.lock | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ pisshoff-server/Cargo.toml | 3 ++- pisshoff-server/src/command.rs | 30 ++++++++++++++++++++---------- pisshoff-server/src/command/echo.rs | 57 +++++++++++++++++++++++++++++++++++++++++++++++++-------- pisshoff-server/src/command/exit.rs | 50 ++++++++++++++++++++++++++++++++++++++++++-------- pisshoff-server/src/command/ls.rs | 64 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------- pisshoff-server/src/command/pwd.rs | 49 +++++++++++++++++++++++++++++++++++++++++-------- pisshoff-server/src/command/scp.rs | 116 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------- pisshoff-server/src/command/snapshots/pisshoff_server__command__scp__test__works.snap | 22 ++++++++++++++++++++++ pisshoff-server/src/command/uname.rs | 16 ++++++++-------- pisshoff-server/src/command/whoami.rs | 49 +++++++++++++++++++++++++++++++++++++++++-------- pisshoff-server/src/server.rs | 156 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------------------------------------- pisshoff-server/src/subsystem/mod.rs | 4 ++-- pisshoff-server/src/subsystem/sftp.rs | 4 ++-- pisshoff-server/src/subsystem/shell.rs | 4 ++-- 15 files changed, 611 insertions(+), 111 deletions(-) create mode 100644 pisshoff-server/src/command/snapshots/pisshoff_server__command__scp__test__works.snap diff --git a/Cargo.lock b/Cargo.lock index 2cce4d9..c7cc857 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -394,6 +394,12 @@ dependencies = [ ] [[package]] +name = "difflib" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8" + +[[package]] name = "digest" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -434,6 +440,12 @@ dependencies = [ ] [[package]] +name = "downcast" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1" + +[[package]] name = "either" version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -498,6 +510,15 @@ dependencies = [ ] [[package]] +name = "float-cmp" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4" +dependencies = [ + "num-traits", +] + +[[package]] name = "form_urlencoded" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -507,6 +528,12 @@ dependencies = [ ] [[package]] +name = "fragile" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa" + +[[package]] name = "futures" version = "0.3.28" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -700,6 +727,7 @@ dependencies = [ "console", "lazy_static", "linked-hash-map", + "regex", "similar", "yaml-rust", ] @@ -869,6 +897,33 @@ dependencies = [ ] [[package]] +name = "mockall" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4c84490118f2ee2d74570d114f3d0493cbf02790df303d2707606c3e14e07c96" +dependencies = [ + "cfg-if", + "downcast", + "fragile", + "lazy_static", + "mockall_derive", + "predicates", + "predicates-tree", +] + +[[package]] +name = "mockall_derive" +version = "0.11.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "22ce75669015c4f47b289fd4d4f56e894e4c96003ffdf3ac51313126f94c6cbb" +dependencies = [ + "cfg-if", + "proc-macro2", + "quote", + "syn 1.0.109", +] + +[[package]] name = "nix" version = "0.26.2" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -893,6 +948,12 @@ dependencies = [ ] [[package]] +name = "normalize-line-endings" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be" + +[[package]] name = "nu-ansi-term" version = "0.46.0" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1065,6 +1126,7 @@ dependencies = [ "futures", "insta", "itertools", + "mockall", "nix", "nom", "parking_lot", @@ -1160,6 +1222,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] +name = "predicates" +version = "2.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "59230a63c37f3e18569bdb90e4a89cbf5bf8b06fea0b84e65ea10cc4df47addd" +dependencies = [ + "difflib", + "float-cmp", + "itertools", + "normalize-line-endings", + "predicates-core", + "regex", +] + +[[package]] +name = "predicates-core" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174" + +[[package]] +name = "predicates-tree" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf" +dependencies = [ + "predicates-core", + "termtree", +] + +[[package]] name = "proc-macro-error" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" @@ -1592,6 +1684,12 @@ dependencies = [ ] [[package]] +name = "termtree" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76" + +[[package]] name = "test-case" version = "3.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/pisshoff-server/Cargo.toml b/pisshoff-server/Cargo.toml index 3400422..78ba335 100644 --- a/pisshoff-server/Cargo.toml +++ b/pisshoff-server/Cargo.toml @@ -32,5 +32,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] } uuid = { version = "1.3", features = ["v4", "serde"] } [dev-dependencies] -insta = "1.29" +mockall = "0.11" +insta = { version = "1.29", features = ["filters"] } test-case = "3.1" diff --git a/pisshoff-server/src/command.rs b/pisshoff-server/src/command.rs index 04ecea1..b02d8c0 100644 --- a/pisshoff-server/src/command.rs +++ b/pisshoff-server/src/command.rs @@ -6,18 +6,20 @@ mod scp; mod uname; mod whoami; -use crate::server::Connection; +use crate::server::{ConnectionState, ThrusshSession}; use async_trait::async_trait; use itertools::Either; +use std::fmt::Debug; use thrussh::{server::Session, ChannelId}; +#[derive(Debug)] pub enum CommandResult { ReadStdin(T), Exit(u32), Close(u32), } -impl CommandResult { +impl CommandResult { fn map(self, f: fn(T) -> N) -> CommandResult { match self { Self::ReadStdin(val) => CommandResult::ReadStdin(f(val)), @@ -25,23 +27,31 @@ impl CommandResult { Self::Close(v) => CommandResult::Close(v), } } + + #[cfg(test)] + pub fn unwrap_stdin(self) -> T { + match self { + Self::ReadStdin(val) => val, + v => panic!("got {v:?}, expected ReadStdin"), + } + } } #[async_trait] pub trait Command: Sized { - async fn new( - connection: &mut Connection, + async fn new( + connection: &mut ConnectionState, params: &[String], channel: ChannelId, - session: &mut Session, + session: &mut S, ) -> CommandResult; - async fn stdin( + async fn stdin( self, - connection: &mut Connection, + connection: &mut ConnectionState, channel: ChannelId, data: &[u8], - session: &mut Session, + session: &mut S, ) -> CommandResult; } @@ -54,7 +64,7 @@ macro_rules! define_commands { impl ConcreteCommand { pub async fn new( - connection: &mut Connection, + connection: &mut ConnectionState, params: &[String], channel: ChannelId, session: &mut Session, @@ -78,7 +88,7 @@ macro_rules! define_commands { pub async fn stdin( self, - connection: &mut Connection, + connection: &mut ConnectionState, channel: ChannelId, data: &[u8], session: &mut Session, diff --git a/pisshoff-server/src/command/echo.rs b/pisshoff-server/src/command/echo.rs index 777c1f2..4986a6c 100644 --- a/pisshoff-server/src/command/echo.rs +++ b/pisshoff-server/src/command/echo.rs @@ -1,34 +1,75 @@ use crate::{ command::{Command, CommandResult}, - server::Connection, + server::{ConnectionState, ThrusshSession}, }; use async_trait::async_trait; use itertools::Itertools; -use thrussh::{server::Session, ChannelId}; +use thrussh::ChannelId; #[derive(Debug, Clone)] pub struct Echo {} #[async_trait] impl Command for Echo { - async fn new( - _connection: &mut Connection, + async fn new( + _connection: &mut ConnectionState, params: &[String], channel: ChannelId, - session: &mut Session, + session: &mut S, ) -> CommandResult { session.data(channel, format!("{}\n", params.iter().join(" ")).into()); CommandResult::Exit(0) } - async fn stdin( + async fn stdin( self, - _connection: &mut Connection, + _connection: &mut ConnectionState, _channel: ChannelId, _data: &[u8], - _session: &mut Session, + _session: &mut S, ) -> CommandResult { CommandResult::Exit(0) } } + +#[cfg(test)] +mod test { + use crate::{ + command::{echo::Echo, Command, CommandResult}, + server::{ + test::{fake_channel_id, predicate::eq_string}, + ConnectionState, MockThrusshSession, + }, + }; + use mockall::predicate::always; + use test_case::test_case; + + #[test_case(&[], "\n"; "no parameters")] + #[test_case(&["hello"], "hello\n"; "single parameter")] + #[test_case(&["hello", "world"], "hello world\n"; "multiple parameters")] + #[tokio::test] + async fn test(params: &[&str], output: &'static str) { + let mut session = MockThrusshSession::default(); + + session + .expect_data() + .once() + .with(always(), eq_string(output)) + .returning(|_, _| ()); + + let out = Echo::new( + &mut ConnectionState::mock(), + params + .iter() + .map(ToString::to_string) + .collect::>() + .as_slice(), + fake_channel_id(), + &mut session, + ) + .await; + + assert!(matches!(out, CommandResult::Exit(0)), "{out:?}"); + } +} diff --git a/pisshoff-server/src/command/exit.rs b/pisshoff-server/src/command/exit.rs index af4fa2f..620dd6a 100644 --- a/pisshoff-server/src/command/exit.rs +++ b/pisshoff-server/src/command/exit.rs @@ -1,21 +1,21 @@ use crate::{ command::{Command, CommandResult}, - server::Connection, + server::{ConnectionState, ThrusshSession}, }; use async_trait::async_trait; use std::str::FromStr; -use thrussh::{server::Session, ChannelId}; +use thrussh::ChannelId; #[derive(Debug, Clone)] pub struct Exit {} #[async_trait] impl Command for Exit { - async fn new( - _connection: &mut Connection, + async fn new( + _connection: &mut ConnectionState, params: &[String], _channel: ChannelId, - _session: &mut Session, + _session: &mut S, ) -> CommandResult { let exit_status = params .get(0) @@ -26,13 +26,47 @@ impl Command for Exit { CommandResult::Close(exit_status) } - async fn stdin( + async fn stdin( self, - _connection: &mut Connection, + _connection: &mut ConnectionState, _channel: ChannelId, _data: &[u8], - _session: &mut Session, + _session: &mut S, ) -> CommandResult { CommandResult::Exit(0) } } + +#[cfg(test)] +mod test { + use crate::{ + command::{exit::Exit, Command, CommandResult}, + server::{test::fake_channel_id, ConnectionState, MockThrusshSession}, + }; + use test_case::test_case; + + #[test_case(&[], 0; "no parameters")] + #[test_case(&["3"], 3; "with parameter")] + #[test_case(&["invalid"], 2; "invalid parameter")] + #[tokio::test] + async fn test(params: &[&str], expected_exit_code: u32) { + let mut session = MockThrusshSession::default(); + + let out = Exit::new( + &mut ConnectionState::mock(), + params + .iter() + .map(ToString::to_string) + .collect::>() + .as_slice(), + fake_channel_id(), + &mut session, + ) + .await; + + assert!( + matches!(out, CommandResult::Close(v) if v == expected_exit_code), + "{out:?}" + ); + } +} diff --git a/pisshoff-server/src/command/ls.rs b/pisshoff-server/src/command/ls.rs index 73aacb0..e51cbcf 100644 --- a/pisshoff-server/src/command/ls.rs +++ b/pisshoff-server/src/command/ls.rs @@ -1,21 +1,21 @@ use crate::{ command::{Command, CommandResult}, - server::Connection, + server::{ConnectionState, ThrusshSession}, }; use async_trait::async_trait; use std::fmt::Write; -use thrussh::{server::Session, ChannelId}; +use thrussh::ChannelId; #[derive(Debug, Clone)] pub struct Ls {} #[async_trait] impl Command for Ls { - async fn new( - connection: &mut Connection, + async fn new( + connection: &mut ConnectionState, params: &[String], channel: ChannelId, - session: &mut Session, + session: &mut S, ) -> CommandResult { let resp = if params.is_empty() { connection.file_system().ls(None).join(" ") @@ -46,13 +46,61 @@ impl Command for Ls { CommandResult::Exit(0) } - async fn stdin( + async fn stdin( self, - _connection: &mut Connection, + _connection: &mut ConnectionState, _channel: ChannelId, _data: &[u8], - _session: &mut Session, + _session: &mut S, ) -> CommandResult { CommandResult::Exit(0) } } + +#[cfg(test)] +mod test { + use crate::{ + command::{ls::Ls, Command, CommandResult}, + server::{ + test::{fake_channel_id, predicate::eq_string}, + ConnectionState, MockThrusshSession, + }, + }; + use mockall::predicate::always; + + #[tokio::test] + async fn empty_pwd() { + let mut session = MockThrusshSession::default(); + + let out = Ls::new( + &mut ConnectionState::mock(), + [].as_slice(), + fake_channel_id(), + &mut session, + ) + .await; + + assert!(matches!(out, CommandResult::Exit(0)), "{out:?}"); + } + + #[tokio::test] + async fn multiple_empty_directories() { + let mut session = MockThrusshSession::default(); + + session + .expect_data() + .once() + .with(always(), eq_string("a:\n\nb:\n")) + .returning(|_, _| ()); + + let out = Ls::new( + &mut ConnectionState::mock(), + ["a".to_string(), "b".to_string()].as_slice(), + fake_channel_id(), + &mut session, + ) + .await; + + assert!(matches!(out, CommandResult::Exit(0)), "{out:?}"); + } +} diff --git a/pisshoff-server/src/command/pwd.rs b/pisshoff-server/src/command/pwd.rs index 6b68362..39f7586 100644 --- a/pisshoff-server/src/command/pwd.rs +++ b/pisshoff-server/src/command/pwd.rs @@ -1,20 +1,20 @@ use crate::{ command::{Command, CommandResult}, - server::Connection, + server::{ConnectionState, ThrusshSession}, }; use async_trait::async_trait; -use thrussh::{server::Session, ChannelId}; +use thrussh::ChannelId; #[derive(Debug, Clone)] pub struct Pwd {} #[async_trait] impl Command for Pwd { - async fn new( - connection: &mut Connection, + async fn new( + connection: &mut ConnectionState, _params: &[String], channel: ChannelId, - session: &mut Session, + session: &mut S, ) -> CommandResult { session.data( channel, @@ -24,13 +24,46 @@ impl Command for Pwd { CommandResult::Exit(0) } - async fn stdin( + async fn stdin( self, - _connection: &mut Connection, + _connection: &mut ConnectionState, _channel: ChannelId, _data: &[u8], - _session: &mut Session, + _session: &mut S, ) -> CommandResult { CommandResult::Exit(0) } } + +#[cfg(test)] +mod test { + use crate::{ + command::{pwd::Pwd, Command, CommandResult}, + server::{ + test::{fake_channel_id, predicate::eq_string}, + ConnectionState, MockThrusshSession, + }, + }; + use mockall::predicate::always; + + #[tokio::test] + async fn works() { + let mut session = MockThrusshSession::default(); + + session + .expect_data() + .once() + .with(always(), eq_string("/root\n")) + .returning(|_, _| ()); + + let out = Pwd::new( + &mut ConnectionState::mock(), + [].as_slice(), + fake_channel_id(), + &mut session, + ) + .await; + + assert!(matches!(out, CommandResult::Exit(0)), "{out:?}"); + } +} diff --git a/pisshoff-server/src/command/scp.rs b/pisshoff-server/src/command/scp.rs index c465c7d..7ded331 100644 --- a/pisshoff-server/src/command/scp.rs +++ b/pisshoff-server/src/command/scp.rs @@ -1,6 +1,6 @@ use crate::{ command::{Arg, Command, CommandResult}, - server::Connection, + server::{ConnectionState, ThrusshSession}, }; use async_trait::async_trait; use bytes::{Buf, BytesMut}; @@ -12,7 +12,7 @@ use nom::{ }; use pisshoff_types::audit::{AuditLogAction, WriteFileEvent}; use std::{path::PathBuf, str::FromStr}; -use thrussh::{server::Session, ChannelId}; +use thrussh::ChannelId; use tracing::warn; const HELP: &str = "usage: scp [-346ABCOpqRrsTv] [-c cipher] [-D sftp_server_path] [-F ssh_config] @@ -33,11 +33,11 @@ pub struct Scp { #[async_trait] impl Command for Scp { - async fn new( - _connection: &mut Connection, + async fn new( + _connection: &mut ConnectionState, params: &[String], channel: ChannelId, - session: &mut Session, + session: &mut S, ) -> CommandResult { let mut path = None; let mut transfer = false; @@ -80,12 +80,12 @@ impl Command for Scp { }) } - async fn stdin( + async fn stdin( mut self, - connection: &mut Connection, + connection: &mut ConnectionState, channel: ChannelId, data: &[u8], - session: &mut Session, + session: &mut S, ) -> CommandResult { self.pending_data.extend_from_slice(data); @@ -173,7 +173,7 @@ enum State { AwaitingSeparator, } -#[derive(Debug)] +#[derive(Debug, PartialEq, Eq)] #[allow(dead_code)] enum Receive<'a> { FileCopy { @@ -278,3 +278,101 @@ impl<'a> Receive<'a> { } } } + +#[cfg(test)] +mod test { + use crate::{ + command::{scp::Scp, Command}, + server::{ + test::{fake_channel_id, predicate::eq_string}, + ConnectionState, MockThrusshSession, + }, + }; + use insta::assert_debug_snapshot; + use mockall::predicate::always; + + mod packet_parser { + use crate::command::scp::Receive; + + #[test] + fn file_copy() { + let (_, actual) = Receive::parse(b"C0777 1234 test.txt\n").unwrap(); + let expected = Receive::FileCopy { + mode: "0777", + length: 1234, + file_name: "test.txt", + }; + + assert_eq!(actual, expected); + } + + #[test] + fn directory_copy() { + let (_, actual) = Receive::parse(b"D0777 1234 test\n").unwrap(); + let expected = Receive::DirectoryCopy { + mode: "0777", + length: 1234, + directory_name: "test", + }; + + assert_eq!(actual, expected); + } + + #[test] + fn end_directory() { + let (_, actual) = Receive::parse(b"E\n").unwrap(); + let expected = Receive::EndDirectory; + + assert_eq!(actual, expected); + } + + #[test] + fn access_time() { + let (_, actual) = Receive::parse(b"T123 444 555 666\n").unwrap(); + let expected = Receive::AccessTime { + modified_time: 123, + modified_time_micros: 444, + access_time: 555, + access_time_micros: 666, + }; + + assert_eq!(actual, expected); + } + } + + #[tokio::test] + async fn works() { + let mut session = MockThrusshSession::default(); + let mut state = ConnectionState::mock(); + + session + .expect_data() + .with(always(), eq_string("\0")) + .returning(|_, _| ()); + + let out = Scp::new( + &mut state, + ["-t".to_string(), "hello".to_string()].as_slice(), + fake_channel_id(), + &mut session, + ) + .await + .unwrap_stdin(); + + let _out = out + .stdin( + &mut state, + fake_channel_id(), + b"C0777 11 hello.txt\nhello world\0", + &mut session, + ) + .await + .unwrap_stdin(); + + insta::with_settings!({filters => vec![ + (r#"\bstart_offset: [^,]+"#, "start_offset: [stripped]") + ]}, { + assert_debug_snapshot!(state.audit_log()); + }); + } +} diff --git a/pisshoff-server/src/command/snapshots/pisshoff_server__command__scp__test__works.snap b/pisshoff-server/src/command/snapshots/pisshoff_server__command__scp__test__works.snap new file mode 100644 index 0000000..691f8b6 --- /dev/null +++ b/pisshoff-server/src/command/snapshots/pisshoff_server__command__scp__test__works.snap @@ -0,0 +1,22 @@ +--- +source: pisshoff-server/src/command/scp.rs +expression: state.audit_log() +--- +AuditLog { + connection_id: 01020304-0506-0708-090a-0b0c0d0e0f10, + peer_address: Some( + 127.0.0.1:1234, + ), + environment_variables: [], + events: [ + AuditLogEvent { + start_offset: [stripped], + action: WriteFile( + WriteFileEvent { + path: "hello/hello.txt", + content: b"hello world", + }, + ), + }, + ], +} diff --git a/pisshoff-server/src/command/uname.rs b/pisshoff-server/src/command/uname.rs index bd27e00..25b2310 100644 --- a/pisshoff-server/src/command/uname.rs +++ b/pisshoff-server/src/command/uname.rs @@ -1,10 +1,10 @@ use crate::{ command::{Arg, Command, CommandResult}, - server::Connection, + server::{ConnectionState, ThrusshSession}, }; use async_trait::async_trait; use bitflags::bitflags; -use thrussh::{server::Session, ChannelId}; +use thrussh::ChannelId; bitflags! { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)] @@ -56,11 +56,11 @@ pub struct Uname {} #[async_trait] impl Command for Uname { - async fn new( - _connection: &mut Connection, + async fn new( + _connection: &mut ConnectionState, params: &[String], channel: ChannelId, - session: &mut Session, + session: &mut S, ) -> CommandResult { let (out, exit_code) = execute(params); @@ -68,12 +68,12 @@ impl Command for Uname { CommandResult::Exit(exit_code) } - async fn stdin( + async fn stdin( self, - _connection: &mut Connection, + _connection: &mut ConnectionState, _channel: ChannelId, _data: &[u8], - _session: &mut Session, + _session: &mut S, ) -> CommandResult { CommandResult::Exit(0) } diff --git a/pisshoff-server/src/command/whoami.rs b/pisshoff-server/src/command/whoami.rs index 6937fa6..d92110c 100644 --- a/pisshoff-server/src/command/whoami.rs +++ b/pisshoff-server/src/command/whoami.rs @@ -1,32 +1,65 @@ use crate::{ command::{Command, CommandResult}, - server::Connection, + server::{ConnectionState, ThrusshSession}, }; use async_trait::async_trait; -use thrussh::{server::Session, ChannelId}; +use thrussh::ChannelId; #[derive(Debug, Clone)] pub struct Whoami {} #[async_trait] impl Command for Whoami { - async fn new( - connection: &mut Connection, + async fn new( + connection: &mut ConnectionState, _params: &[String], channel: ChannelId, - session: &mut Session, + session: &mut S, ) -> CommandResult { session.data(channel, format!("{}\n", connection.username()).into()); CommandResult::Exit(0) } - async fn stdin( + async fn stdin( self, - _connection: &mut Connection, + _connection: &mut ConnectionState, _channel: ChannelId, _data: &[u8], - _session: &mut Session, + _session: &mut S, ) -> CommandResult { CommandResult::Exit(0) } } + +#[cfg(test)] +mod test { + use crate::{ + command::{whoami::Whoami, Command, CommandResult}, + server::{ + test::{fake_channel_id, predicate::eq_string}, + ConnectionState, MockThrusshSession, + }, + }; + use mockall::predicate::always; + + #[tokio::test] + async fn works() { + let mut session = MockThrusshSession::default(); + + session + .expect_data() + .once() + .with(always(), eq_string("root\n")) + .returning(|_, _| ()); + + let out = Whoami::new( + &mut ConnectionState::mock(), + [].as_slice(), + fake_channel_id(), + &mut session, + ) + .await; + + assert!(matches!(out, CommandResult::Exit(0)), "{out:?}"); + } +} diff --git a/pisshoff-server/src/server.rs b/pisshoff-server/src/server.rs index e390492..1663a0d 100644 --- a/pisshoff-server/src/server.rs +++ b/pisshoff-server/src/server.rs @@ -27,7 +27,7 @@ use std::{ }; use thrussh::{ server::{Auth, Response, Session}, - ChannelId, Pty, Sig, + ChannelId, CryptoVec, Pty, Sig, }; use thrussh_keys::key::PublicKey; use tokio::sync::mpsc::UnboundedSender; @@ -69,29 +69,51 @@ impl thrussh::server::Server for Server { Connection { span: info_span!("connection", ?peer_addr, %connection_id), server: self.clone(), - audit_log: AuditLog { - connection_id, - host: Cow::Borrowed(self.hostname), - peer_address: peer_addr, - ..AuditLog::default() + state: ConnectionState { + audit_log: AuditLog { + connection_id, + host: Cow::Borrowed(self.hostname), + peer_address: peer_addr, + ..AuditLog::default() + }, + username: None, + file_system: None, }, - username: None, - file_system: None, subsystem: HashMap::new(), } } } -pub struct Connection { - span: Span, - server: Server, +pub struct ConnectionState { audit_log: AuditLog, username: Option, file_system: Option, - subsystem: HashMap>>, } -impl Connection { +impl ConnectionState { + #[cfg(test)] + pub fn mock() -> Self { + use std::net::{IpAddr, Ipv4Addr}; + + ConnectionState { + audit_log: AuditLog { + connection_id: uuid::Uuid::from_bytes([ + 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, + ]), + host: Cow::Borrowed("hello world"), + peer_address: Some(SocketAddr::new( + IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), + 1234, + )), + ..AuditLog::default() + }, + username: None, + file_system: None, + } + } +} + +impl ConnectionState { pub fn username(&self) -> &str { self.username.as_deref().unwrap_or("root") } @@ -107,9 +129,18 @@ impl Connection { pub fn audit_log(&mut self) -> &mut AuditLog { &mut self.audit_log } +} + +pub struct Connection { + span: Span, + server: Server, + state: ConnectionState, + subsystem: HashMap>>, +} +impl Connection { fn try_login(&mut self, user: &str, password: &str) -> bool { - self.username = Some(user.to_string()); + self.state.username = Some(user.to_string()); let res = if self .server @@ -131,12 +162,14 @@ impl Connection { false }; - self.audit_log.push_action(AuditLogAction::LoginAttempt( - LoginAttemptEvent::UsernamePassword { - username: Box::from(user), - password: Box::from(password), - }, - )); + self.state + .audit_log + .push_action(AuditLogAction::LoginAttempt( + LoginAttemptEvent::UsernamePassword { + username: Box::from(user), + password: Box::from(password), + }, + )); res } @@ -200,7 +233,8 @@ impl thrussh::server::Handler for Connection { let kind = public_key.name(); let fingerprint = public_key.fingerprint(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::LoginAttempt(LoginAttemptEvent::PublicKey { kind: Cow::Borrowed(kind), fingerprint: Box::from(fingerprint), @@ -285,7 +319,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "channel_open_x11"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::OpenX11(OpenX11Event { originator_address: Box::from(originator_address), originator_port, @@ -307,7 +342,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "channel_open_direct_tcpip"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::OpenDirectTcpIp(OpenDirectTcpIpEvent { host_to_connect: Box::from(host_to_connect), port_to_connect, @@ -332,10 +368,14 @@ impl thrussh::server::Handler for Connection { match &mut *subsystem { Subsystem::Shell(ref mut inner) => { - inner.data(&mut self, channel, &data, &mut session).await; + inner + .data(&mut self.state, channel, &data, &mut session) + .await; } Subsystem::Sftp(ref mut inner) => { - inner.data(&mut self, channel, &data, &mut session).await; + inner + .data(&mut self.state, channel, &data, &mut session) + .await; } } @@ -367,7 +407,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "window_adjusted"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::WindowAdjusted(WindowAdjustedEvent { new_size: new_window_size, })); @@ -396,7 +437,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "pty_request"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::PtyRequest(PtyRequestEvent { term: Box::from(term), col_width, @@ -428,7 +470,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "x11_request"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::X11Request(X11RequestEvent { single_connection, x11_auth_protocol: Box::from(x11_auth_protocol), @@ -450,7 +493,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "env_request"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .environment_variables .push((Box::from(variable_name), Box::from(variable_value))); @@ -462,7 +506,9 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "shell_request"); let _entered = span.enter(); - self.audit_log.push_action(AuditLogAction::ShellRequested); + self.state + .audit_log + .push_action(AuditLogAction::ShellRequested); let shell = Shell::new(true, channel, &mut session); self.subsystem @@ -485,7 +531,9 @@ impl thrussh::server::Handler for Connection { async move { let mut shell = Shell::new(false, channel, &mut session); - shell.data(&mut self, channel, &data, &mut session).await; + shell + .data(&mut self.state, channel, &data, &mut session) + .await; self.subsystem .insert(channel, Arc::new(Mutex::new(Subsystem::Shell(shell)))); @@ -506,7 +554,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "subsystem_request"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::SubsystemRequest(SubsystemRequestEvent { name: Box::from(name), })); @@ -539,7 +588,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "window_change_request"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::WindowChangeRequest( WindowChangeRequestEvent { col_width, @@ -562,7 +612,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "signal"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::Signal(SignalEvent { name: format!("{signal_name:?}").into(), })); @@ -574,7 +625,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "tcpip_forward"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::TcpIpForward(TcpIpForwardEvent { address: Box::from(address), port, @@ -594,7 +646,8 @@ impl thrussh::server::Handler for Connection { let span = info_span!(parent: &self.span, "cancel_tcpip_forward"); let _entered = span.enter(); - self.audit_log + self.state + .audit_log .push_action(AuditLogAction::CancelTcpIpForward(TcpIpForwardEvent { address: Box::from(address), port, @@ -616,7 +669,7 @@ impl Drop for Connection { let _res = self .server .audit_send - .send(std::mem::take(&mut self.audit_log)); + .send(std::mem::take(&mut self.state.audit_log)); } } @@ -626,6 +679,17 @@ pub enum Subsystem { Sftp(subsystem::sftp::Sftp), } +#[cfg_attr(test, mockall::automock)] +pub trait ThrusshSession { + fn data(&mut self, channel: ChannelId, data: CryptoVec); +} + +impl ThrusshSession for Session { + fn data(&mut self, channel: ChannelId, data: CryptoVec) { + Session::data(self, channel, data); + } +} + type HandlerResult = Result::Error>; type HandlerFuture = ServerFuture< ::Error, @@ -669,3 +733,21 @@ impl> + Unpin> Future for ServerFuture ChannelId { + unsafe { std::mem::transmute(0_u32) } + } + + pub mod predicate { + use mockall::{predicate, Predicate}; + use thrussh::CryptoVec; + + pub fn eq_string(s: &str) -> impl Predicate + '_ { + predicate::function(|v: &CryptoVec| &**v == s.as_bytes()) + } + } +} diff --git a/pisshoff-server/src/subsystem/mod.rs b/pisshoff-server/src/subsystem/mod.rs index 395a06e..457c92f 100644 --- a/pisshoff-server/src/subsystem/mod.rs +++ b/pisshoff-server/src/subsystem/mod.rs @@ -1,4 +1,4 @@ -use crate::server::Connection; +use crate::server::ConnectionState; use async_trait::async_trait; use thrussh::server::Session; use thrussh::ChannelId; @@ -12,7 +12,7 @@ pub trait Subsystem { async fn data( &mut self, - connection: &mut Connection, + connection: &mut ConnectionState, channel: ChannelId, data: &[u8], session: &mut Session, diff --git a/pisshoff-server/src/subsystem/sftp.rs b/pisshoff-server/src/subsystem/sftp.rs index 83d5370..aaa6c7d 100644 --- a/pisshoff-server/src/subsystem/sftp.rs +++ b/pisshoff-server/src/subsystem/sftp.rs @@ -1,4 +1,4 @@ -use crate::{server::Connection, subsystem::Subsystem}; +use crate::{server::ConnectionState, subsystem::Subsystem}; use async_trait::async_trait; use bytes::Bytes; use nom::{ @@ -29,7 +29,7 @@ impl Subsystem for Sftp { #[allow(clippy::too_many_lines)] async fn data( &mut self, - connection: &mut Connection, + connection: &mut ConnectionState, channel: ChannelId, data: &[u8], session: &mut Session, diff --git a/pisshoff-server/src/subsystem/shell.rs b/pisshoff-server/src/subsystem/shell.rs index effd21e..86b0059 100644 --- a/pisshoff-server/src/subsystem/shell.rs +++ b/pisshoff-server/src/subsystem/shell.rs @@ -1,6 +1,6 @@ use crate::{ command::{CommandResult, ConcreteCommand}, - server::Connection, + server::ConnectionState, subsystem::Subsystem, }; use async_trait::async_trait; @@ -47,7 +47,7 @@ impl Subsystem for Shell { async fn data( &mut self, - connection: &mut Connection, + connection: &mut ConnectionState, channel: ChannelId, data: &[u8], session: &mut Session, -- libgit2 1.7.2