//! Logs each and every request out in a format similar to that of Apache's logs. use std::{ fmt::Debug, net::SocketAddr, task::{Context, Poll}, time::Instant, }; use axum::{ extract, http::{HeaderValue, Method, Request, Response}, }; use futures::future::{Future, FutureExt, Join, Map, Ready}; use tokio::task::futures::TaskLocalFuture; use tower_service::Service; use tracing::{error, info, instrument::Instrumented, Instrument, Span}; use uuid::Uuid; use super::UnwrapInfallible; pub trait GenericError: std::error::Error + Debug + Send + Sync {} #[derive(Clone)] pub struct LoggingMiddleware(pub S); impl Service> for LoggingMiddleware where S: Service, Response = Response, Error = std::convert::Infallible> + Clone + Send + 'static, S::Future: Send + 'static, S::Response: Default + Debug, ReqBody: Send + Debug + 'static, ResBody: Default + Send + 'static, { type Response = S::Response; type Error = S::Error; type Future = Map< Join>, Ready>, fn((::Output, PendingLogMessage)) -> ::Output, >; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { let request_id = Uuid::new_v4(); let span = tracing::info_span!("web", "request_id" = request_id.to_string().as_str()); let log_message = PendingLogMessage { span: span.clone(), request_id, ip: req .extensions() .get::>() .map_or_else(|| "0.0.0.0:0".parse().unwrap(), |v| v.0), method: req.method().clone(), uri: req.uri().path().to_string(), start: Instant::now(), user_agent: req.headers().get(axum::http::header::USER_AGENT).cloned(), }; futures::future::join( REQ_TIMESTAMP.scope(log_message.start, self.0.call(req).instrument(span)), futures::future::ready(log_message), ) .map(|(response, pending_log_message)| { let mut response = response.unwrap_infallible(); pending_log_message.log(&response); response.headers_mut().insert( "X-Request-ID", HeaderValue::try_from(pending_log_message.request_id.to_string()).unwrap(), ); Ok(response) }) } } tokio::task_local! { pub static REQ_TIMESTAMP: Instant; } pub struct PendingLogMessage { span: Span, request_id: Uuid, ip: SocketAddr, method: Method, uri: String, start: Instant, user_agent: Option, } impl PendingLogMessage { pub fn log(&self, response: &Response) { let _enter = self.span.enter(); if response.status().is_server_error() { error!( "{ip} - \"{method} {uri}\" {status} {duration:?} \"{user_agent}\" \"{error:?}\"", ip = self.ip, method = self.method, uri = self.uri, status = response.status().as_u16(), duration = self.start.elapsed(), user_agent = self .user_agent .as_ref() .and_then(|v| v.to_str().ok()) .unwrap_or("unknown"), error = match response.extensions().get::>() { Some(e) => Err(e), None => Ok(()), } ); } else { info!( "{ip} - \"{method} {uri}\" {status} {duration:?} \"{user_agent}\" \"{error:?}\"", ip = self.ip, method = self.method, uri = self.uri, status = response.status().as_u16(), duration = self.start.elapsed(), user_agent = self .user_agent .as_ref() .and_then(|v| v.to_str().ok()) .unwrap_or("unknown"), error = match response.extensions().get::>() { Some(e) => Err(e), None => Ok(()), } ); } } }