Add basic tests for each command
Diff
Cargo.lock | 98 ++++++++++++++++++++++++++++++++++++++++++++++++++-
pisshoff-server/Cargo.toml | 3 +-
pisshoff-server/src/command.rs | 30 ++++++++++-----
pisshoff-server/src/command/echo.rs | 57 ++++++++++++++++++++++++-----
pisshoff-server/src/command/exit.rs | 50 +++++++++++++++++++++-----
pisshoff-server/src/command/ls.rs | 64 ++++++++++++++++++++++++++++-----
pisshoff-server/src/command/pwd.rs | 49 ++++++++++++++++++++-----
pisshoff-server/src/command/scp.rs | 116 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-----
pisshoff-server/src/command/snapshots/pisshoff_server__command__scp__test__works.snap | 22 +++++++++++-
pisshoff-server/src/command/uname.rs | 16 ++++----
pisshoff-server/src/command/whoami.rs | 49 ++++++++++++++++++++-----
pisshoff-server/src/server.rs | 156 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
pisshoff-server/src/subsystem/mod.rs | 4 +-
pisshoff-server/src/subsystem/sftp.rs | 4 +-
pisshoff-server/src/subsystem/shell.rs | 4 +-
15 files changed, 611 insertions(+), 111 deletions(-)
@@ -394,6 +394,12 @@ dependencies = [
]
[[package]]
name = "difflib"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6184e33543162437515c2e2b48714794e37845ec9851711914eec9d308f6ebe8"
[[package]]
name = "digest"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -434,6 +440,12 @@ dependencies = [
]
[[package]]
name = "downcast"
version = "0.11.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1435fa1053d8b2fbbe9be7e97eca7f33d37b28409959813daefc1446a14247f1"
[[package]]
name = "either"
version = "1.8.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -498,6 +510,15 @@ dependencies = [
]
[[package]]
name = "float-cmp"
version = "0.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "98de4bbd547a563b716d8dfa9aad1cb19bfab00f4fa09a6a4ed21dbcf44ce9c4"
dependencies = [
"num-traits",
]
[[package]]
name = "form_urlencoded"
version = "1.2.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -507,6 +528,12 @@ dependencies = [
]
[[package]]
name = "fragile"
version = "2.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6c2141d6d6c8512188a7891b4b01590a45f6dac67afb4f255c4124dbb86d4eaa"
[[package]]
name = "futures"
version = "0.3.28"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -700,6 +727,7 @@ dependencies = [
"console",
"lazy_static",
"linked-hash-map",
"regex",
"similar",
"yaml-rust",
]
@@ -869,6 +897,33 @@ dependencies = [
]
[[package]]
name = "mockall"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4c84490118f2ee2d74570d114f3d0493cbf02790df303d2707606c3e14e07c96"
dependencies = [
"cfg-if",
"downcast",
"fragile",
"lazy_static",
"mockall_derive",
"predicates",
"predicates-tree",
]
[[package]]
name = "mockall_derive"
version = "0.11.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "22ce75669015c4f47b289fd4d4f56e894e4c96003ffdf3ac51313126f94c6cbb"
dependencies = [
"cfg-if",
"proc-macro2",
"quote",
"syn 1.0.109",
]
[[package]]
name = "nix"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -893,6 +948,12 @@ dependencies = [
]
[[package]]
name = "normalize-line-endings"
version = "0.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "61807f77802ff30975e01f4f071c8ba10c022052f98b3294119f3e615d13e5be"
[[package]]
name = "nu-ansi-term"
version = "0.46.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1065,6 +1126,7 @@ dependencies = [
"futures",
"insta",
"itertools",
"mockall",
"nix",
"nom",
"parking_lot",
@@ -1160,6 +1222,36 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de"
[[package]]
name = "predicates"
version = "2.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59230a63c37f3e18569bdb90e4a89cbf5bf8b06fea0b84e65ea10cc4df47addd"
dependencies = [
"difflib",
"float-cmp",
"itertools",
"normalize-line-endings",
"predicates-core",
"regex",
]
[[package]]
name = "predicates-core"
version = "1.0.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b794032607612e7abeb4db69adb4e33590fa6cf1149e95fd7cb00e634b92f174"
[[package]]
name = "predicates-tree"
version = "1.0.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "368ba315fb8c5052ab692e68a0eefec6ec57b23a36959c14496f0b0df2c0cecf"
dependencies = [
"predicates-core",
"termtree",
]
[[package]]
name = "proc-macro-error"
version = "1.0.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1592,6 +1684,12 @@ dependencies = [
]
[[package]]
name = "termtree"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3369f5ac52d5eb6ab48c6b4ffdc8efbcad6b89c765749064ba298f2c68a16a76"
[[package]]
name = "test-case"
version = "3.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -32,5 +32,6 @@ tracing-subscriber = { version = "0.3", features = ["env-filter"] }
uuid = { version = "1.3", features = ["v4", "serde"] }
[dev-dependencies]
insta = "1.29"
mockall = "0.11"
insta = { version = "1.29", features = ["filters"] }
test-case = "3.1"
@@ -6,18 +6,20 @@ mod scp;
mod uname;
mod whoami;
use crate::server::Connection;
use crate::server::{ConnectionState, ThrusshSession};
use async_trait::async_trait;
use itertools::Either;
use std::fmt::Debug;
use thrussh::{server::Session, ChannelId};
#[derive(Debug)]
pub enum CommandResult<T> {
ReadStdin(T),
Exit(u32),
Close(u32),
}
impl<T> CommandResult<T> {
impl<T: Debug> CommandResult<T> {
fn map<N>(self, f: fn(T) -> N) -> CommandResult<N> {
match self {
Self::ReadStdin(val) => CommandResult::ReadStdin(f(val)),
@@ -25,23 +27,31 @@ impl<T> CommandResult<T> {
Self::Close(v) => CommandResult::Close(v),
}
}
#[cfg(test)]
pub fn unwrap_stdin(self) -> T {
match self {
Self::ReadStdin(val) => val,
v => panic!("got {v:?}, expected ReadStdin"),
}
}
}
#[async_trait]
pub trait Command: Sized {
async fn new(
connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
connection: &mut ConnectionState,
params: &[String],
channel: ChannelId,
session: &mut Session,
session: &mut S,
) -> CommandResult<Self>;
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
self,
connection: &mut Connection,
connection: &mut ConnectionState,
channel: ChannelId,
data: &[u8],
session: &mut Session,
session: &mut S,
) -> CommandResult<Self>;
}
@@ -54,7 +64,7 @@ macro_rules! define_commands {
impl ConcreteCommand {
pub async fn new(
connection: &mut Connection,
connection: &mut ConnectionState,
params: &[String],
channel: ChannelId,
session: &mut Session,
@@ -78,7 +88,7 @@ macro_rules! define_commands {
pub async fn stdin(
self,
connection: &mut Connection,
connection: &mut ConnectionState,
channel: ChannelId,
data: &[u8],
session: &mut Session,
@@ -1,34 +1,75 @@
use crate::{
command::{Command, CommandResult},
server::Connection,
server::{ConnectionState, ThrusshSession},
};
use async_trait::async_trait;
use itertools::Itertools;
use thrussh::{server::Session, ChannelId};
use thrussh::ChannelId;
#[derive(Debug, Clone)]
pub struct Echo {}
#[async_trait]
impl Command for Echo {
async fn new(
_connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
_connection: &mut ConnectionState,
params: &[String],
channel: ChannelId,
session: &mut Session,
session: &mut S,
) -> CommandResult<Self> {
session.data(channel, format!("{}\n", params.iter().join(" ")).into());
CommandResult::Exit(0)
}
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
self,
_connection: &mut Connection,
_connection: &mut ConnectionState,
_channel: ChannelId,
_data: &[u8],
_session: &mut Session,
_session: &mut S,
) -> CommandResult<Self> {
CommandResult::Exit(0)
}
}
#[cfg(test)]
mod test {
use crate::{
command::{echo::Echo, Command, CommandResult},
server::{
test::{fake_channel_id, predicate::eq_string},
ConnectionState, MockThrusshSession,
},
};
use mockall::predicate::always;
use test_case::test_case;
#[test_case(&[], "\n"; "no parameters")]
#[test_case(&["hello"], "hello\n"; "single parameter")]
#[test_case(&["hello", "world"], "hello world\n"; "multiple parameters")]
#[tokio::test]
async fn test(params: &[&str], output: &'static str) {
let mut session = MockThrusshSession::default();
session
.expect_data()
.once()
.with(always(), eq_string(output))
.returning(|_, _| ());
let out = Echo::new(
&mut ConnectionState::mock(),
params
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.as_slice(),
fake_channel_id(),
&mut session,
)
.await;
assert!(matches!(out, CommandResult::Exit(0)), "{out:?}");
}
}
@@ -1,21 +1,21 @@
use crate::{
command::{Command, CommandResult},
server::Connection,
server::{ConnectionState, ThrusshSession},
};
use async_trait::async_trait;
use std::str::FromStr;
use thrussh::{server::Session, ChannelId};
use thrussh::ChannelId;
#[derive(Debug, Clone)]
pub struct Exit {}
#[async_trait]
impl Command for Exit {
async fn new(
_connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
_connection: &mut ConnectionState,
params: &[String],
_channel: ChannelId,
_session: &mut Session,
_session: &mut S,
) -> CommandResult<Self> {
let exit_status = params
.get(0)
@@ -26,13 +26,47 @@ impl Command for Exit {
CommandResult::Close(exit_status)
}
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
self,
_connection: &mut Connection,
_connection: &mut ConnectionState,
_channel: ChannelId,
_data: &[u8],
_session: &mut Session,
_session: &mut S,
) -> CommandResult<Self> {
CommandResult::Exit(0)
}
}
#[cfg(test)]
mod test {
use crate::{
command::{exit::Exit, Command, CommandResult},
server::{test::fake_channel_id, ConnectionState, MockThrusshSession},
};
use test_case::test_case;
#[test_case(&[], 0; "no parameters")]
#[test_case(&["3"], 3; "with parameter")]
#[test_case(&["invalid"], 2; "invalid parameter")]
#[tokio::test]
async fn test(params: &[&str], expected_exit_code: u32) {
let mut session = MockThrusshSession::default();
let out = Exit::new(
&mut ConnectionState::mock(),
params
.iter()
.map(ToString::to_string)
.collect::<Vec<_>>()
.as_slice(),
fake_channel_id(),
&mut session,
)
.await;
assert!(
matches!(out, CommandResult::Close(v) if v == expected_exit_code),
"{out:?}"
);
}
}
@@ -1,21 +1,21 @@
use crate::{
command::{Command, CommandResult},
server::Connection,
server::{ConnectionState, ThrusshSession},
};
use async_trait::async_trait;
use std::fmt::Write;
use thrussh::{server::Session, ChannelId};
use thrussh::ChannelId;
#[derive(Debug, Clone)]
pub struct Ls {}
#[async_trait]
impl Command for Ls {
async fn new(
connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
connection: &mut ConnectionState,
params: &[String],
channel: ChannelId,
session: &mut Session,
session: &mut S,
) -> CommandResult<Self> {
let resp = if params.is_empty() {
connection.file_system().ls(None).join(" ")
@@ -46,13 +46,61 @@ impl Command for Ls {
CommandResult::Exit(0)
}
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
self,
_connection: &mut Connection,
_connection: &mut ConnectionState,
_channel: ChannelId,
_data: &[u8],
_session: &mut Session,
_session: &mut S,
) -> CommandResult<Self> {
CommandResult::Exit(0)
}
}
#[cfg(test)]
mod test {
use crate::{
command::{ls::Ls, Command, CommandResult},
server::{
test::{fake_channel_id, predicate::eq_string},
ConnectionState, MockThrusshSession,
},
};
use mockall::predicate::always;
#[tokio::test]
async fn empty_pwd() {
let mut session = MockThrusshSession::default();
let out = Ls::new(
&mut ConnectionState::mock(),
[].as_slice(),
fake_channel_id(),
&mut session,
)
.await;
assert!(matches!(out, CommandResult::Exit(0)), "{out:?}");
}
#[tokio::test]
async fn multiple_empty_directories() {
let mut session = MockThrusshSession::default();
session
.expect_data()
.once()
.with(always(), eq_string("a:\n\nb:\n"))
.returning(|_, _| ());
let out = Ls::new(
&mut ConnectionState::mock(),
["a".to_string(), "b".to_string()].as_slice(),
fake_channel_id(),
&mut session,
)
.await;
assert!(matches!(out, CommandResult::Exit(0)), "{out:?}");
}
}
@@ -1,20 +1,20 @@
use crate::{
command::{Command, CommandResult},
server::Connection,
server::{ConnectionState, ThrusshSession},
};
use async_trait::async_trait;
use thrussh::{server::Session, ChannelId};
use thrussh::ChannelId;
#[derive(Debug, Clone)]
pub struct Pwd {}
#[async_trait]
impl Command for Pwd {
async fn new(
connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
connection: &mut ConnectionState,
_params: &[String],
channel: ChannelId,
session: &mut Session,
session: &mut S,
) -> CommandResult<Self> {
session.data(
channel,
@@ -24,13 +24,46 @@ impl Command for Pwd {
CommandResult::Exit(0)
}
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
self,
_connection: &mut Connection,
_connection: &mut ConnectionState,
_channel: ChannelId,
_data: &[u8],
_session: &mut Session,
_session: &mut S,
) -> CommandResult<Self> {
CommandResult::Exit(0)
}
}
#[cfg(test)]
mod test {
use crate::{
command::{pwd::Pwd, Command, CommandResult},
server::{
test::{fake_channel_id, predicate::eq_string},
ConnectionState, MockThrusshSession,
},
};
use mockall::predicate::always;
#[tokio::test]
async fn works() {
let mut session = MockThrusshSession::default();
session
.expect_data()
.once()
.with(always(), eq_string("/root\n"))
.returning(|_, _| ());
let out = Pwd::new(
&mut ConnectionState::mock(),
[].as_slice(),
fake_channel_id(),
&mut session,
)
.await;
assert!(matches!(out, CommandResult::Exit(0)), "{out:?}");
}
}
@@ -1,6 +1,6 @@
use crate::{
command::{Arg, Command, CommandResult},
server::Connection,
server::{ConnectionState, ThrusshSession},
};
use async_trait::async_trait;
use bytes::{Buf, BytesMut};
@@ -12,7 +12,7 @@ use nom::{
};
use pisshoff_types::audit::{AuditLogAction, WriteFileEvent};
use std::{path::PathBuf, str::FromStr};
use thrussh::{server::Session, ChannelId};
use thrussh::ChannelId;
use tracing::warn;
const HELP: &str = "usage: scp [-346ABCOpqRrsTv] [-c cipher] [-D sftp_server_path] [-F ssh_config]
@@ -33,11 +33,11 @@ pub struct Scp {
#[async_trait]
impl Command for Scp {
async fn new(
_connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
_connection: &mut ConnectionState,
params: &[String],
channel: ChannelId,
session: &mut Session,
session: &mut S,
) -> CommandResult<Self> {
let mut path = None;
let mut transfer = false;
@@ -80,12 +80,12 @@ impl Command for Scp {
})
}
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
mut self,
connection: &mut Connection,
connection: &mut ConnectionState,
channel: ChannelId,
data: &[u8],
session: &mut Session,
session: &mut S,
) -> CommandResult<Self> {
self.pending_data.extend_from_slice(data);
@@ -173,7 +173,7 @@ enum State {
AwaitingSeparator,
}
#[derive(Debug)]
#[derive(Debug, PartialEq, Eq)]
#[allow(dead_code)]
enum Receive<'a> {
FileCopy {
@@ -278,3 +278,101 @@ impl<'a> Receive<'a> {
}
}
}
#[cfg(test)]
mod test {
use crate::{
command::{scp::Scp, Command},
server::{
test::{fake_channel_id, predicate::eq_string},
ConnectionState, MockThrusshSession,
},
};
use insta::assert_debug_snapshot;
use mockall::predicate::always;
mod packet_parser {
use crate::command::scp::Receive;
#[test]
fn file_copy() {
let (_, actual) = Receive::parse(b"C0777 1234 test.txt\n").unwrap();
let expected = Receive::FileCopy {
mode: "0777",
length: 1234,
file_name: "test.txt",
};
assert_eq!(actual, expected);
}
#[test]
fn directory_copy() {
let (_, actual) = Receive::parse(b"D0777 1234 test\n").unwrap();
let expected = Receive::DirectoryCopy {
mode: "0777",
length: 1234,
directory_name: "test",
};
assert_eq!(actual, expected);
}
#[test]
fn end_directory() {
let (_, actual) = Receive::parse(b"E\n").unwrap();
let expected = Receive::EndDirectory;
assert_eq!(actual, expected);
}
#[test]
fn access_time() {
let (_, actual) = Receive::parse(b"T123 444 555 666\n").unwrap();
let expected = Receive::AccessTime {
modified_time: 123,
modified_time_micros: 444,
access_time: 555,
access_time_micros: 666,
};
assert_eq!(actual, expected);
}
}
#[tokio::test]
async fn works() {
let mut session = MockThrusshSession::default();
let mut state = ConnectionState::mock();
session
.expect_data()
.with(always(), eq_string("\0"))
.returning(|_, _| ());
let out = Scp::new(
&mut state,
["-t".to_string(), "hello".to_string()].as_slice(),
fake_channel_id(),
&mut session,
)
.await
.unwrap_stdin();
let _out = out
.stdin(
&mut state,
fake_channel_id(),
b"C0777 11 hello.txt\nhello world\0",
&mut session,
)
.await
.unwrap_stdin();
insta::with_settings!({filters => vec![
(r#"\bstart_offset: [^,]+"#, "start_offset: [stripped]")
]}, {
assert_debug_snapshot!(state.audit_log());
});
}
}
@@ -0,0 +1,22 @@
---
source: pisshoff-server/src/command/scp.rs
expression: state.audit_log()
---
AuditLog {
connection_id: 01020304-0506-0708-090a-0b0c0d0e0f10,
peer_address: Some(
127.0.0.1:1234,
),
environment_variables: [],
events: [
AuditLogEvent {
start_offset: [stripped],
action: WriteFile(
WriteFileEvent {
path: "hello/hello.txt",
content: b"hello world",
},
),
},
],
}
@@ -1,10 +1,10 @@
use crate::{
command::{Arg, Command, CommandResult},
server::Connection,
server::{ConnectionState, ThrusshSession},
};
use async_trait::async_trait;
use bitflags::bitflags;
use thrussh::{server::Session, ChannelId};
use thrussh::ChannelId;
bitflags! {
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
@@ -56,11 +56,11 @@ pub struct Uname {}
#[async_trait]
impl Command for Uname {
async fn new(
_connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
_connection: &mut ConnectionState,
params: &[String],
channel: ChannelId,
session: &mut Session,
session: &mut S,
) -> CommandResult<Self> {
let (out, exit_code) = execute(params);
@@ -68,12 +68,12 @@ impl Command for Uname {
CommandResult::Exit(exit_code)
}
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
self,
_connection: &mut Connection,
_connection: &mut ConnectionState,
_channel: ChannelId,
_data: &[u8],
_session: &mut Session,
_session: &mut S,
) -> CommandResult<Self> {
CommandResult::Exit(0)
}
@@ -1,32 +1,65 @@
use crate::{
command::{Command, CommandResult},
server::Connection,
server::{ConnectionState, ThrusshSession},
};
use async_trait::async_trait;
use thrussh::{server::Session, ChannelId};
use thrussh::ChannelId;
#[derive(Debug, Clone)]
pub struct Whoami {}
#[async_trait]
impl Command for Whoami {
async fn new(
connection: &mut Connection,
async fn new<S: ThrusshSession + Send>(
connection: &mut ConnectionState,
_params: &[String],
channel: ChannelId,
session: &mut Session,
session: &mut S,
) -> CommandResult<Self> {
session.data(channel, format!("{}\n", connection.username()).into());
CommandResult::Exit(0)
}
async fn stdin(
async fn stdin<S: ThrusshSession + Send>(
self,
_connection: &mut Connection,
_connection: &mut ConnectionState,
_channel: ChannelId,
_data: &[u8],
_session: &mut Session,
_session: &mut S,
) -> CommandResult<Self> {
CommandResult::Exit(0)
}
}
#[cfg(test)]
mod test {
use crate::{
command::{whoami::Whoami, Command, CommandResult},
server::{
test::{fake_channel_id, predicate::eq_string},
ConnectionState, MockThrusshSession,
},
};
use mockall::predicate::always;
#[tokio::test]
async fn works() {
let mut session = MockThrusshSession::default();
session
.expect_data()
.once()
.with(always(), eq_string("root\n"))
.returning(|_, _| ());
let out = Whoami::new(
&mut ConnectionState::mock(),
[].as_slice(),
fake_channel_id(),
&mut session,
)
.await;
assert!(matches!(out, CommandResult::Exit(0)), "{out:?}");
}
}
@@ -27,7 +27,7 @@ use std::{
};
use thrussh::{
server::{Auth, Response, Session},
ChannelId, Pty, Sig,
ChannelId, CryptoVec, Pty, Sig,
};
use thrussh_keys::key::PublicKey;
use tokio::sync::mpsc::UnboundedSender;
@@ -69,29 +69,51 @@ impl thrussh::server::Server for Server {
Connection {
span: info_span!("connection", ?peer_addr, %connection_id),
server: self.clone(),
audit_log: AuditLog {
connection_id,
host: Cow::Borrowed(self.hostname),
peer_address: peer_addr,
..AuditLog::default()
state: ConnectionState {
audit_log: AuditLog {
connection_id,
host: Cow::Borrowed(self.hostname),
peer_address: peer_addr,
..AuditLog::default()
},
username: None,
file_system: None,
},
username: None,
file_system: None,
subsystem: HashMap::new(),
}
}
}
pub struct Connection {
span: Span,
server: Server,
pub struct ConnectionState {
audit_log: AuditLog,
username: Option<String>,
file_system: Option<FileSystem>,
subsystem: HashMap<ChannelId, Arc<Mutex<Subsystem>>>,
}
impl Connection {
impl ConnectionState {
#[cfg(test)]
pub fn mock() -> Self {
use std::net::{IpAddr, Ipv4Addr};
ConnectionState {
audit_log: AuditLog {
connection_id: uuid::Uuid::from_bytes([
1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16,
]),
host: Cow::Borrowed("hello world"),
peer_address: Some(SocketAddr::new(
IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)),
1234,
)),
..AuditLog::default()
},
username: None,
file_system: None,
}
}
}
impl ConnectionState {
pub fn username(&self) -> &str {
self.username.as_deref().unwrap_or("root")
}
@@ -107,9 +129,18 @@ impl Connection {
pub fn audit_log(&mut self) -> &mut AuditLog {
&mut self.audit_log
}
}
pub struct Connection {
span: Span,
server: Server,
state: ConnectionState,
subsystem: HashMap<ChannelId, Arc<Mutex<Subsystem>>>,
}
impl Connection {
fn try_login(&mut self, user: &str, password: &str) -> bool {
self.username = Some(user.to_string());
self.state.username = Some(user.to_string());
let res = if self
.server
@@ -131,12 +162,14 @@ impl Connection {
false
};
self.audit_log.push_action(AuditLogAction::LoginAttempt(
LoginAttemptEvent::UsernamePassword {
username: Box::from(user),
password: Box::from(password),
},
));
self.state
.audit_log
.push_action(AuditLogAction::LoginAttempt(
LoginAttemptEvent::UsernamePassword {
username: Box::from(user),
password: Box::from(password),
},
));
res
}
@@ -200,7 +233,8 @@ impl thrussh::server::Handler for Connection {
let kind = public_key.name();
let fingerprint = public_key.fingerprint();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::LoginAttempt(LoginAttemptEvent::PublicKey {
kind: Cow::Borrowed(kind),
fingerprint: Box::from(fingerprint),
@@ -285,7 +319,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "channel_open_x11");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::OpenX11(OpenX11Event {
originator_address: Box::from(originator_address),
originator_port,
@@ -307,7 +342,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "channel_open_direct_tcpip");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::OpenDirectTcpIp(OpenDirectTcpIpEvent {
host_to_connect: Box::from(host_to_connect),
port_to_connect,
@@ -332,10 +368,14 @@ impl thrussh::server::Handler for Connection {
match &mut *subsystem {
Subsystem::Shell(ref mut inner) => {
inner.data(&mut self, channel, &data, &mut session).await;
inner
.data(&mut self.state, channel, &data, &mut session)
.await;
}
Subsystem::Sftp(ref mut inner) => {
inner.data(&mut self, channel, &data, &mut session).await;
inner
.data(&mut self.state, channel, &data, &mut session)
.await;
}
}
@@ -367,7 +407,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "window_adjusted");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::WindowAdjusted(WindowAdjustedEvent {
new_size: new_window_size,
}));
@@ -396,7 +437,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "pty_request");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::PtyRequest(PtyRequestEvent {
term: Box::from(term),
col_width,
@@ -428,7 +470,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "x11_request");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::X11Request(X11RequestEvent {
single_connection,
x11_auth_protocol: Box::from(x11_auth_protocol),
@@ -450,7 +493,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "env_request");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.environment_variables
.push((Box::from(variable_name), Box::from(variable_value)));
@@ -462,7 +506,9 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "shell_request");
let _entered = span.enter();
self.audit_log.push_action(AuditLogAction::ShellRequested);
self.state
.audit_log
.push_action(AuditLogAction::ShellRequested);
let shell = Shell::new(true, channel, &mut session);
self.subsystem
@@ -485,7 +531,9 @@ impl thrussh::server::Handler for Connection {
async move {
let mut shell = Shell::new(false, channel, &mut session);
shell.data(&mut self, channel, &data, &mut session).await;
shell
.data(&mut self.state, channel, &data, &mut session)
.await;
self.subsystem
.insert(channel, Arc::new(Mutex::new(Subsystem::Shell(shell))));
@@ -506,7 +554,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "subsystem_request");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::SubsystemRequest(SubsystemRequestEvent {
name: Box::from(name),
}));
@@ -539,7 +588,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "window_change_request");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::WindowChangeRequest(
WindowChangeRequestEvent {
col_width,
@@ -562,7 +612,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "signal");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::Signal(SignalEvent {
name: format!("{signal_name:?}").into(),
}));
@@ -574,7 +625,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "tcpip_forward");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::TcpIpForward(TcpIpForwardEvent {
address: Box::from(address),
port,
@@ -594,7 +646,8 @@ impl thrussh::server::Handler for Connection {
let span = info_span!(parent: &self.span, "cancel_tcpip_forward");
let _entered = span.enter();
self.audit_log
self.state
.audit_log
.push_action(AuditLogAction::CancelTcpIpForward(TcpIpForwardEvent {
address: Box::from(address),
port,
@@ -616,7 +669,7 @@ impl Drop for Connection {
let _res = self
.server
.audit_send
.send(std::mem::take(&mut self.audit_log));
.send(std::mem::take(&mut self.state.audit_log));
}
}
@@ -626,6 +679,17 @@ pub enum Subsystem {
Sftp(subsystem::sftp::Sftp),
}
#[cfg_attr(test, mockall::automock)]
pub trait ThrusshSession {
fn data(&mut self, channel: ChannelId, data: CryptoVec);
}
impl ThrusshSession for Session {
fn data(&mut self, channel: ChannelId, data: CryptoVec) {
Session::data(self, channel, data);
}
}
type HandlerResult<T> = Result<T, <Connection as thrussh::server::Handler>::Error>;
type HandlerFuture<T> = ServerFuture<
<Connection as thrussh::server::Handler>::Error,
@@ -669,3 +733,21 @@ impl<T, E, F: Future<Output = Result<T, E>> + Unpin> Future for ServerFuture<E,
Pin::new(&mut self.0).poll(cx)
}
}
#[cfg(test)]
pub mod test {
use thrussh::ChannelId;
pub fn fake_channel_id() -> ChannelId {
unsafe { std::mem::transmute(0_u32) }
}
pub mod predicate {
use mockall::{predicate, Predicate};
use thrussh::CryptoVec;
pub fn eq_string(s: &str) -> impl Predicate<CryptoVec> + '_ {
predicate::function(|v: &CryptoVec| &**v == s.as_bytes())
}
}
}
@@ -1,4 +1,4 @@
use crate::server::Connection;
use crate::server::ConnectionState;
use async_trait::async_trait;
use thrussh::server::Session;
use thrussh::ChannelId;
@@ -12,7 +12,7 @@ pub trait Subsystem {
async fn data(
&mut self,
connection: &mut Connection,
connection: &mut ConnectionState,
channel: ChannelId,
data: &[u8],
session: &mut Session,
@@ -1,4 +1,4 @@
use crate::{server::Connection, subsystem::Subsystem};
use crate::{server::ConnectionState, subsystem::Subsystem};
use async_trait::async_trait;
use bytes::Bytes;
use nom::{
@@ -29,7 +29,7 @@ impl Subsystem for Sftp {
#[allow(clippy::too_many_lines)]
async fn data(
&mut self,
connection: &mut Connection,
connection: &mut ConnectionState,
channel: ChannelId,
data: &[u8],
session: &mut Session,
@@ -1,6 +1,6 @@
use crate::{
command::{CommandResult, ConcreteCommand},
server::Connection,
server::ConnectionState,
subsystem::Subsystem,
};
use async_trait::async_trait;
@@ -47,7 +47,7 @@ impl Subsystem for Shell {
async fn data(
&mut self,
connection: &mut Connection,
connection: &mut ConnectionState,
channel: ChannelId,
data: &[u8],
session: &mut Session,