//! Adds the user's IP address to the request, taking into account the `trusted_ip_header` config //! value. use axum::{ body::BoxBody, extract::{self, FromRequest, RequestParts}, http::{Request, Response}, }; use futures::future::BoxFuture; use tower::{Layer, Service}; use std::{ net::IpAddr, str::FromStr, sync::Arc, task::{Context, Poll}, }; #[derive(Clone)] pub struct AddIp { trusted_ip_header: Option>, } impl AddIp { pub fn new(trusted_ip_header: Option) -> Self { Self { trusted_ip_header: trusted_ip_header.map(Arc::from), } } } impl Layer for AddIp { type Service = AddIpService; fn layer(&self, inner: S) -> Self::Service { AddIpService { inner, trusted_ip_header: self.trusted_ip_header.clone(), } } } #[derive(Clone)] pub struct AddIpService { inner: S, trusted_ip_header: Option>, } impl Service> for AddIpService where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, ReqBody: Send + 'static, { type Response = S::Response; type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.inner.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { // ensure we take the instance that has already been poll_ready'd let clone = self.clone(); let mut this = std::mem::replace(self, clone); Box::pin(async move { let mut req = RequestParts::new(req); let mut ip = None; if let Some(trusted_ip_header) = this.trusted_ip_header.as_deref() { ip = req .headers() .get(trusted_ip_header) .and_then(|v| v.to_str().ok()) .and_then(|v| IpAddr::from_str(v).ok()); } // no trusted ip header, fallback to the socket address if ip.is_none() { ip = extract::ConnectInfo::::from_request(&mut req) .await .map(|v| v.0.ip()) .ok(); } req.extensions_mut().insert(ip.unwrap()); this.inner.call(req.try_into_request().unwrap()).await }) } }