From f2959023893ae422dd929a487ad664700bc9c833 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Sun, 25 Jun 2023 15:03:37 +0100 Subject: [PATCH] Fix clean shutdowns with long-running clients --- src/audit.rs | 14 +++++++++----- src/main.rs | 24 +++++++++++++++++++----- 2 files changed, 28 insertions(+), 10 deletions(-) diff --git a/src/audit.rs b/src/audit.rs index e2ae288..b340708 100644 --- a/src/audit.rs +++ b/src/audit.rs @@ -11,7 +11,7 @@ use time::OffsetDateTime; use tokio::{ fs::OpenOptions, io::{AsyncWriteExt, BufWriter}, - sync::watch, + sync::{oneshot, watch}, task::JoinHandle, }; use tracing::{debug, info}; @@ -20,6 +20,7 @@ use uuid::Uuid; pub fn start_audit_writer( config: Arc, mut reload: watch::Receiver<()>, + mut shutdown_recv: oneshot::Receiver<()>, ) -> ( tokio::sync::mpsc::UnboundedSender, JoinHandle>, @@ -39,9 +40,9 @@ pub fn start_audit_writer( let mut writer = open_writer().await?; let mut shutdown = false; - loop { + while !shutdown { tokio::select! { - log = recv.recv(), if !shutdown => { + log = recv.recv() => { match log { Some(log) => { let log = serde_json::to_vec(&log) @@ -54,11 +55,14 @@ pub fn start_audit_writer( } } } - _ = tokio::time::sleep(Duration::from_secs(5)), if !writer.buffer().is_empty() && !shutdown => { + _ = &mut shutdown_recv => { + shutdown = true; + } + _ = tokio::time::sleep(Duration::from_secs(5)), if !writer.buffer().is_empty() => { debug!("Flushing audits to disk"); writer.flush().await?; } - Ok(()) = reload.changed(), if !shutdown => { + Ok(()) = reload.changed() => { info!("Flushing audits to disk"); writer.flush().await?; diff --git a/src/main.rs b/src/main.rs index 09b9fa4..1c45609 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,10 @@ use clap::Parser; use futures::FutureExt; use std::sync::Arc; use thrussh::MethodSet; -use tokio::{signal::unix::SignalKind, sync::watch}; +use tokio::{ + signal::unix::SignalKind, + sync::{oneshot, watch}, +}; use tracing::{error, info}; use tracing_subscriber::EnvFilter; @@ -49,24 +52,26 @@ async fn run() -> anyhow::Result<()> { }); let (reload_send, reload_recv) = watch::channel(()); + let (shutdown_send, shutdown_recv) = oneshot::channel(); - let (audit_send, audit_handle) = audit::start_audit_writer(args.config.clone(), reload_recv); + let (audit_send, audit_handle) = + audit::start_audit_writer(args.config.clone(), reload_recv, shutdown_recv); let mut audit_handle = audit_handle.fuse(); let server = Server::new(args.config.clone(), audit_send); let listen_address = args.config.listen_address.to_string(); + // TODO: needs clean shutdowns on clients let fut = thrussh::server::run(thrussh_config, &listen_address, server); + let shutdown_watcher = watch_for_shutdown(shutdown_send); let reload_watcher = watch_for_reloads(reload_send); tokio::select! { res = fut => res?, res = &mut audit_handle => res??, + res = shutdown_watcher => res?, res = reload_watcher => res?, - _ = tokio::signal::ctrl_c() => { - info!("Received ctrl-c, initiating shutdown"); - } } info!("Finishing audit log writes"); @@ -76,6 +81,15 @@ async fn run() -> anyhow::Result<()> { Ok(()) } +async fn watch_for_shutdown(send: oneshot::Sender<()>) -> Result<(), anyhow::Error> { + tokio::signal::ctrl_c().await?; + info!("Received ctrl-c, initiating shutdown"); + + let _res = send.send(()); + + Ok(()) +} + async fn watch_for_reloads(send: watch::Sender<()>) -> Result<(), anyhow::Error> { let mut signal = tokio::signal::unix::signal(SignalKind::hangup())?; -- libgit2 1.7.2