Allow a configurable header to override the user's IP
Diff
chartered-web/src/config.rs | 1 +
chartered-web/src/main.rs | 4 +++-
book/src/guide/config-reference.md | 7 +++++++
chartered-web/src/middleware/ip.rs | 94 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
chartered-web/src/middleware/logging.rs | 14 ++++++++------
chartered-web/src/middleware/mod.rs | 1 +
chartered-web/src/middleware/rate_limit.rs | 7 ++-----
chartered-web/src/endpoints/web_api/auth/mod.rs | 10 ++++++----
chartered-web/src/endpoints/web_api/auth/openid.rs | 6 ++++--
chartered-web/src/endpoints/web_api/auth/password.rs | 6 ++++--
10 files changed, 130 insertions(+), 20 deletions(-)
@@ -25,6 +25,7 @@
pub database_uri: String,
pub storage_uri: String,
pub frontend_base_uri: Url,
pub trusted_ip_header: Option<String>,
pub auth: AuthConfig,
#[serde(deserialize_with = "deserialize_encryption_key")]
pub encryption_key: ChaCha20Poly1305Key,
@@ -6,6 +6,7 @@
mod endpoints;
mod middleware;
use crate::middleware::ip::AddIp;
use crate::middleware::rate_limit::RateLimit;
use axum::{
http::{header, Method},
@@ -127,7 +128,8 @@
.layer(Extension(Arc::new(config.create_oidc_clients().await?)))
.layer(Extension(Arc::new(config.get_file_system().await?)))
.layer(Extension(config.clone()))
.layer(Extension(http_client));
.layer(Extension(http_client))
.layer(AddIp::new(config.trusted_ip_header.clone()));
info!("HTTP server listening on {}", bind_address);
@@ -81,6 +81,7 @@
storage_uri = "s3://s3-eu-west-1.amazonaws.com/my-cool-crate-store/" # or file:///var/lib/chartered
frontend_base_uri = "http://localhost:5173/"
trusted_ip_header = "x-forwarded-for"
[auth.password]
enabled = true # enables password auth
@@ -118,6 +119,12 @@
- Type: `string`
The base URL at which the frontend is being hosted.
#### `trusted_ip_header`
- Type: `string`
- Default: null
Allows a header to override the socket address as the end user's IP address
#### `[auth.password]`
The `[auth.password]` table controls the username/password-based authentication method.
@@ -1,0 +1,94 @@
use axum::{
body::BoxBody,
extract::{self, FromRequest, RequestParts},
http::{Request, Response},
};
use futures::future::BoxFuture;
use tower::{Layer, Service};
use std::{
net::IpAddr,
str::FromStr,
sync::Arc,
task::{Context, Poll},
};
#[derive(Clone)]
pub struct AddIp {
trusted_ip_header: Option<Arc<str>>,
}
impl AddIp {
pub fn new(trusted_ip_header: Option<String>) -> Self {
Self {
trusted_ip_header: trusted_ip_header.map(Arc::from),
}
}
}
impl<S> Layer<S> for AddIp {
type Service = AddIpService<S>;
fn layer(&self, inner: S) -> Self::Service {
AddIpService {
inner,
trusted_ip_header: self.trusted_ip_header.clone(),
}
}
}
#[derive(Clone)]
pub struct AddIpService<S> {
inner: S,
trusted_ip_header: Option<Arc<str>>,
}
impl<S, ReqBody> Service<Request<ReqBody>> for AddIpService<S>
where
S: Service<Request<ReqBody>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
ReqBody: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let clone = self.clone();
let mut this = std::mem::replace(self, clone);
Box::pin(async move {
let mut req = RequestParts::new(req);
let mut ip = None;
if let Some(trusted_ip_header) = this.trusted_ip_header.as_deref() {
ip = req
.headers()
.get(trusted_ip_header)
.and_then(|v| v.to_str().ok())
.and_then(|v| IpAddr::from_str(v).ok());
}
if ip.is_none() {
ip = extract::ConnectInfo::<std::net::SocketAddr>::from_request(&mut req)
.await
.map(|v| v.0.ip())
.ok();
}
req.extensions_mut().insert(ip.unwrap());
this.inner.call(req.try_into_request().unwrap()).await
})
}
}
@@ -7,12 +7,14 @@
use futures::future::BoxFuture;
use once_cell::sync::Lazy;
use regex::Regex;
use tower::Service;
use tracing::{error, info, Instrument};
use std::{
fmt::Debug,
net::IpAddr,
task::{Context, Poll},
};
use tower::Service;
use tracing::{error, info, Instrument};
pub trait GenericError: std::error::Error + Debug + Send + Sync {}
@@ -54,9 +56,9 @@
let uri = replace_sensitive_path(req.uri().path());
let mut req = RequestParts::new(req);
let socket_addr = extract::ConnectInfo::<std::net::SocketAddr>::from_request(&mut req)
let ip_addr = extract::Extension::<IpAddr>::from_request(&mut req)
.await
.map_or_else(|_| "0.0.0.0:0".parse().unwrap(), |v| v.0);
.map_or_else(|_| "0.0.0.0".parse().unwrap(), |v| v.0);
let response = inner.call(req.try_into_request().unwrap()).await?;
@@ -64,7 +66,7 @@
if response.status().is_server_error() {
error!(
"{ip} - \"{method} {uri}\" {status} {duration:?} \"{user_agent}\" \"{error:?}\"",
ip = socket_addr,
ip = ip_addr,
method = method,
uri = uri,
status = response.status().as_u16(),
@@ -81,7 +83,7 @@
} else {
info!(
"{ip} - \"{method} {uri}\" {status} {duration:?} \"{user_agent}\" \"{error:?}\"",
ip = socket_addr,
ip = ip_addr,
method = method,
uri = uri,
status = response.status().as_u16(),
@@ -1,4 +1,5 @@
pub mod cargo_auth;
pub mod ip;
pub mod logging;
pub mod rate_limit;
pub mod web_auth;
@@ -90,14 +90,11 @@
Box::pin(async move {
let mut req = RequestParts::new(req);
let socket_addr = extract::ConnectInfo::<std::net::SocketAddr>::from_request(&mut req)
let ip_addr = extract::Extension::<IpAddr>::from_request(&mut req)
.await
.map(|v| v.0);
if let Ok(socket_addr) = socket_addr {
let addr = socket_addr.ip();
if let Ok(addr) = ip_addr {
if let Err(_e) = this.governor.check_key_n(&addr, this.cost) {
return Ok(Response::builder()
.status(StatusCode::TOO_MANY_REQUESTS)
@@ -1,8 +1,10 @@
use axum::handler::Handler;
use crate::middleware::rate_limit::RateLimit;
use axum::{
extract,
handler::Handler,
routing::{get, post},
Router,
Extension, Router,
};
use chartered_db::{
users::{User, UserSession},
@@ -11,7 +13,7 @@
};
use serde::Serialize;
use crate::middleware::rate_limit::RateLimit;
use std::net::IpAddr;
pub mod extend;
pub mod logout;
@@ -68,7 +70,7 @@
db: ConnectionPool,
user: User,
user_agent: Option<extract::TypedHeader<headers::UserAgent>>,
extract::ConnectInfo(addr): extract::ConnectInfo<std::net::SocketAddr>,
Extension(addr): Extension<IpAddr>,
) -> Result<LoginResponse, chartered_db::Error> {
let user_agent = if let Some(extract::TypedHeader(user_agent)) = user_agent {
Some(user_agent.as_str().to_string())
@@ -1,8 +1,9 @@
use crate::config::{Config, OidcClient, OidcClients};
use axum::{extract, Json};
use chacha20poly1305::{aead::Aead, ChaCha20Poly1305, KeyInit, Nonce as ChaCha20Poly1305Nonce};
use chartered_db::{users::User, ConnectionPool};
@@ -12,9 +13,10 @@
};
use openid::{Options, Token, Userinfo};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
use url::Url;
use std::{net::IpAddr, sync::Arc};
pub type Nonce = [u8; 16];
@@ -82,7 +84,7 @@
extract::Extension(db): extract::Extension<ConnectionPool>,
extract::Extension(http_client): extract::Extension<reqwest::Client>,
user_agent: Option<extract::TypedHeader<headers::UserAgent>>,
addr: extract::ConnectInfo<std::net::SocketAddr>,
addr: extract::Extension<IpAddr>,
) -> Result<Json<super::LoginResponse>, Error> {
let state: State = serde_json::from_slice(&decrypt_url_safe(¶ms.state, &config)?)?;
@@ -1,11 +1,13 @@
use crate::config::Config;
use axum::{extract, Json};
use chartered_db::{users::User, ConnectionPool};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
use std::{net::IpAddr, sync::Arc};
pub async fn handle_register(
extract::Extension(config): extract::Extension<Arc<Config>>,
@@ -35,7 +37,7 @@
extract::Extension(db): extract::Extension<ConnectionPool>,
extract::Json(req): extract::Json<LoginRequest>,
user_agent: Option<extract::TypedHeader<headers::UserAgent>>,
addr: extract::ConnectInfo<std::net::SocketAddr>,
addr: extract::Extension<IpAddr>,
) -> Result<Json<super::LoginResponse>, LoginError> {
if !config.auth.password.enabled {