From bc61804d5b2e58ff064ffeeaee6224efd1278001 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Sat, 17 Sep 2022 13:55:59 +0100 Subject: [PATCH] Allow a configurable header to override the user's IP --- chartered-web/src/config.rs | 1 + chartered-web/src/main.rs | 4 +++- book/src/guide/config-reference.md | 7 +++++++ chartered-web/src/middleware/ip.rs | 94 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++ chartered-web/src/middleware/logging.rs | 14 ++++++++------ chartered-web/src/middleware/mod.rs | 1 + chartered-web/src/middleware/rate_limit.rs | 7 ++----- chartered-web/src/endpoints/web_api/auth/mod.rs | 10 ++++++---- chartered-web/src/endpoints/web_api/auth/openid.rs | 6 ++++-- chartered-web/src/endpoints/web_api/auth/password.rs | 6 ++++-- 10 files changed, 130 insertions(+), 20 deletions(-) diff --git a/chartered-web/src/config.rs b/chartered-web/src/config.rs index a64428d..ebd1ad0 100644 --- a/chartered-web/src/config.rs +++ a/chartered-web/src/config.rs @@ -25,6 +25,7 @@ pub database_uri: String, pub storage_uri: String, pub frontend_base_uri: Url, + pub trusted_ip_header: Option, pub auth: AuthConfig, #[serde(deserialize_with = "deserialize_encryption_key")] pub encryption_key: ChaCha20Poly1305Key, diff --git a/chartered-web/src/main.rs b/chartered-web/src/main.rs index 1083c5b..84ca16c 100644 --- a/chartered-web/src/main.rs +++ a/chartered-web/src/main.rs @@ -6,6 +6,7 @@ mod endpoints; mod middleware; +use crate::middleware::ip::AddIp; use crate::middleware::rate_limit::RateLimit; use axum::{ http::{header, Method}, @@ -127,7 +128,8 @@ .layer(Extension(Arc::new(config.create_oidc_clients().await?))) .layer(Extension(Arc::new(config.get_file_system().await?))) .layer(Extension(config.clone())) - .layer(Extension(http_client)); + .layer(Extension(http_client)) + .layer(AddIp::new(config.trusted_ip_header.clone())); info!("HTTP server listening on {}", bind_address); diff --git a/book/src/guide/config-reference.md b/book/src/guide/config-reference.md index 8a9541b..6909506 100644 --- a/book/src/guide/config-reference.md +++ a/book/src/guide/config-reference.md @@ -81,6 +81,7 @@ storage_uri = "s3://s3-eu-west-1.amazonaws.com/my-cool-crate-store/" # or file:///var/lib/chartered frontend_base_uri = "http://localhost:5173/" +trusted_ip_header = "x-forwarded-for" [auth.password] enabled = true # enables password auth @@ -118,6 +119,12 @@ - Type: `string` The base URL at which the frontend is being hosted. + +#### `trusted_ip_header` +- Type: `string` +- Default: null + +Allows a header to override the socket address as the end user's IP address #### `[auth.password]` The `[auth.password]` table controls the username/password-based authentication method. diff --git a/chartered-web/src/middleware/ip.rs b/chartered-web/src/middleware/ip.rs new file mode 100644 index 0000000..25d8515 100644 --- /dev/null +++ a/chartered-web/src/middleware/ip.rs @@ -1,0 +1,94 @@ +//! Adds the user's IP address to the request, taking into account the `trusted_ip_header` config +//! value. + +use axum::{ + body::BoxBody, + extract::{self, FromRequest, RequestParts}, + http::{Request, Response}, +}; +use futures::future::BoxFuture; +use tower::{Layer, Service}; + +use std::{ + net::IpAddr, + str::FromStr, + sync::Arc, + task::{Context, Poll}, +}; + +#[derive(Clone)] +pub struct AddIp { + trusted_ip_header: Option>, +} + +impl AddIp { + pub fn new(trusted_ip_header: Option) -> Self { + Self { + trusted_ip_header: trusted_ip_header.map(Arc::from), + } + } +} + +impl Layer for AddIp { + type Service = AddIpService; + + fn layer(&self, inner: S) -> Self::Service { + AddIpService { + inner, + trusted_ip_header: self.trusted_ip_header.clone(), + } + } +} + +#[derive(Clone)] +pub struct AddIpService { + inner: S, + trusted_ip_header: Option>, +} + +impl Service> for AddIpService +where + S: Service, Response = Response> + Clone + Send + 'static, + S::Future: Send + 'static, + ReqBody: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, req: Request) -> Self::Future { + // ensure we take the instance that has already been poll_ready'd + let clone = self.clone(); + let mut this = std::mem::replace(self, clone); + + Box::pin(async move { + let mut req = RequestParts::new(req); + + let mut ip = None; + + if let Some(trusted_ip_header) = this.trusted_ip_header.as_deref() { + ip = req + .headers() + .get(trusted_ip_header) + .and_then(|v| v.to_str().ok()) + .and_then(|v| IpAddr::from_str(v).ok()); + } + + // no trusted ip header, fallback to the socket address + if ip.is_none() { + ip = extract::ConnectInfo::::from_request(&mut req) + .await + .map(|v| v.0.ip()) + .ok(); + } + + req.extensions_mut().insert(ip.unwrap()); + + this.inner.call(req.try_into_request().unwrap()).await + }) + } +} diff --git a/chartered-web/src/middleware/logging.rs b/chartered-web/src/middleware/logging.rs index 3425a25..e4d3ccf 100644 --- a/chartered-web/src/middleware/logging.rs +++ a/chartered-web/src/middleware/logging.rs @@ -7,12 +7,14 @@ use futures::future::BoxFuture; use once_cell::sync::Lazy; use regex::Regex; +use tower::Service; +use tracing::{error, info, Instrument}; + use std::{ fmt::Debug, + net::IpAddr, task::{Context, Poll}, }; -use tower::Service; -use tracing::{error, info, Instrument}; pub trait GenericError: std::error::Error + Debug + Send + Sync {} @@ -54,9 +56,9 @@ let uri = replace_sensitive_path(req.uri().path()); let mut req = RequestParts::new(req); - let socket_addr = extract::ConnectInfo::::from_request(&mut req) + let ip_addr = extract::Extension::::from_request(&mut req) .await - .map_or_else(|_| "0.0.0.0:0".parse().unwrap(), |v| v.0); + .map_or_else(|_| "0.0.0.0".parse().unwrap(), |v| v.0); // this is infallible because of the type of S::Error let response = inner.call(req.try_into_request().unwrap()).await?; @@ -64,7 +66,7 @@ if response.status().is_server_error() { error!( "{ip} - \"{method} {uri}\" {status} {duration:?} \"{user_agent}\" \"{error:?}\"", - ip = socket_addr, + ip = ip_addr, method = method, uri = uri, status = response.status().as_u16(), @@ -81,7 +83,7 @@ } else { info!( "{ip} - \"{method} {uri}\" {status} {duration:?} \"{user_agent}\" \"{error:?}\"", - ip = socket_addr, + ip = ip_addr, method = method, uri = uri, status = response.status().as_u16(), diff --git a/chartered-web/src/middleware/mod.rs b/chartered-web/src/middleware/mod.rs index 0016701..1355ee5 100644 --- a/chartered-web/src/middleware/mod.rs +++ a/chartered-web/src/middleware/mod.rs @@ -1,4 +1,5 @@ pub mod cargo_auth; +pub mod ip; pub mod logging; pub mod rate_limit; pub mod web_auth; diff --git a/chartered-web/src/middleware/rate_limit.rs b/chartered-web/src/middleware/rate_limit.rs index 8ec2d6f..9e26084 100644 --- a/chartered-web/src/middleware/rate_limit.rs +++ a/chartered-web/src/middleware/rate_limit.rs @@ -90,14 +90,11 @@ Box::pin(async move { let mut req = RequestParts::new(req); - let socket_addr = extract::ConnectInfo::::from_request(&mut req) + let ip_addr = extract::Extension::::from_request(&mut req) .await .map(|v| v.0); - if let Ok(socket_addr) = socket_addr { - // TODO: cloudflare? - let addr = socket_addr.ip(); - + if let Ok(addr) = ip_addr { if let Err(_e) = this.governor.check_key_n(&addr, this.cost) { return Ok(Response::builder() .status(StatusCode::TOO_MANY_REQUESTS) diff --git a/chartered-web/src/endpoints/web_api/auth/mod.rs b/chartered-web/src/endpoints/web_api/auth/mod.rs index de7ff37..178d306 100644 --- a/chartered-web/src/endpoints/web_api/auth/mod.rs +++ a/chartered-web/src/endpoints/web_api/auth/mod.rs @@ -1,8 +1,10 @@ -use axum::handler::Handler; +use crate::middleware::rate_limit::RateLimit; + use axum::{ extract, + handler::Handler, routing::{get, post}, - Router, + Extension, Router, }; use chartered_db::{ users::{User, UserSession}, @@ -11,7 +13,7 @@ }; use serde::Serialize; -use crate::middleware::rate_limit::RateLimit; +use std::net::IpAddr; pub mod extend; pub mod logout; @@ -68,7 +70,7 @@ db: ConnectionPool, user: User, user_agent: Option>, - extract::ConnectInfo(addr): extract::ConnectInfo, + Extension(addr): Extension, ) -> Result { let user_agent = if let Some(extract::TypedHeader(user_agent)) = user_agent { Some(user_agent.as_str().to_string()) diff --git a/chartered-web/src/endpoints/web_api/auth/openid.rs b/chartered-web/src/endpoints/web_api/auth/openid.rs index 4f87eb3..affc9db 100644 --- a/chartered-web/src/endpoints/web_api/auth/openid.rs +++ a/chartered-web/src/endpoints/web_api/auth/openid.rs @@ -1,8 +1,9 @@ //! Methods for `OpenID` Connect authentication, we allow the frontend to list all the available and //! enabled providers so they can show them to the frontend and provide methods for actually doing //! the authentication. use crate::config::{Config, OidcClient, OidcClients}; + use axum::{extract, Json}; use chacha20poly1305::{aead::Aead, ChaCha20Poly1305, KeyInit, Nonce as ChaCha20Poly1305Nonce}; use chartered_db::{users::User, ConnectionPool}; @@ -12,9 +13,10 @@ }; use openid::{Options, Token, Userinfo}; use serde::{Deserialize, Serialize}; -use std::sync::Arc; use thiserror::Error; use url::Url; + +use std::{net::IpAddr, sync::Arc}; pub type Nonce = [u8; 16]; @@ -82,7 +84,7 @@ extract::Extension(db): extract::Extension, extract::Extension(http_client): extract::Extension, user_agent: Option>, - addr: extract::ConnectInfo, + addr: extract::Extension, ) -> Result, Error> { // decrypt the state that we created in `begin_oidc` and parse it as json let state: State = serde_json::from_slice(&decrypt_url_safe(¶ms.state, &config)?)?; diff --git a/chartered-web/src/endpoints/web_api/auth/password.rs b/chartered-web/src/endpoints/web_api/auth/password.rs index c7a0032..7747b5f 100644 --- a/chartered-web/src/endpoints/web_api/auth/password.rs +++ a/chartered-web/src/endpoints/web_api/auth/password.rs @@ -1,11 +1,13 @@ //! Password-based authentication, including registration and login. use crate::config::Config; + use axum::{extract, Json}; use chartered_db::{users::User, ConnectionPool}; use serde::{Deserialize, Serialize}; -use std::sync::Arc; use thiserror::Error; + +use std::{net::IpAddr, sync::Arc}; pub async fn handle_register( extract::Extension(config): extract::Extension>, @@ -35,7 +37,7 @@ extract::Extension(db): extract::Extension, extract::Json(req): extract::Json, user_agent: Option>, - addr: extract::ConnectInfo, + addr: extract::Extension, ) -> Result, LoginError> { // some basic validation before we attempt a login if !config.auth.password.enabled { -- rgit 0.1.3