#![deny(clippy::pedantic)] #![deny(rust_2018_idioms)] #![allow(clippy::module_name_repetitions)] mod config; mod endpoints; mod middleware; use axum::{ handler::get, http::{header, Method}, AddExtensionLayer, Router, }; use clap::Parser; use std::{fmt::Formatter, path::PathBuf, sync::Arc}; use thiserror::Error; use tower::ServiceBuilder; use tower_http::cors::{CorsLayer, Origin}; use tracing::info; use url::Url; #[derive(Parser)] #[clap(version = clap::crate_version!(), author = clap::crate_authors!())] pub struct Opts { #[clap(short, long, parse(from_occurrences))] verbose: i32, #[clap(short, long)] config: PathBuf, } #[allow(clippy::unused_async)] async fn hello_world() -> &'static str { "hello, world!" } // there's some sort of issue with monomorphization of axum routes // which causes compile times to increase exponentially with every // new route, the workaround is to box the router down to a // dynamically dispatched version with every new route. macro_rules! axum_box_after_every_route { (Router::new() $(.nest($nest_path:expr, $nest_svc:expr$(,)?))* $(.route($route_path:expr, $route_svc:expr$(,)?))* ) => { Router::new() $( .nest($nest_path, $nest_svc) .boxed() )* $( .route($route_path, $route_svc) .boxed() )* }; } pub(crate) use axum_box_after_every_route; #[tokio::main] #[allow(clippy::semicolon_if_nothing_returned)] // lint breaks with tokio::main async fn main() -> Result<(), InitError> { // parse CLI arguments let opts: Opts = Opts::parse(); // overrides the RUST_LOG variable to our own value based on the // amount of `-v`s that were passed when calling the service std::env::set_var( "RUST_LOG", match opts.verbose { 1 => "debug", 2 => "trace", _ => "info", }, ); let config: config::Config = toml::from_slice(&std::fs::read(&opts.config)?)?; // initialise logging/tracing tracing_subscriber::fmt::init(); let bind_address = config.bind_address; let pool = chartered_db::init(&config.database_uri)?; // the base stack of middleware that is applied to _all_ routes let middleware_stack = ServiceBuilder::new() .layer_fn(middleware::logging::LoggingMiddleware) .into_inner(); let config = Arc::new(config); let app = Router::new() .route("/", get(hello_world)) .nest( "/a/:key/web/v1", endpoints::web_api::authenticated_routes().layer( ServiceBuilder::new() .layer_fn(crate::middleware::auth::AuthMiddleware) .into_inner(), ), ) .nest("/a/-/web/v1", endpoints::web_api::unauthenticated_routes()) .nest( "/a/:key/o/:organisation/api/v1", endpoints::cargo_api::routes().layer( ServiceBuilder::new() .layer_fn(crate::middleware::auth::AuthMiddleware) .into_inner(), ), ) .layer(middleware_stack) .layer( CorsLayer::new() .allow_methods(vec![ Method::GET, Method::POST, Method::PATCH, Method::DELETE, Method::PUT, Method::OPTIONS, ]) .allow_headers(vec![header::CONTENT_TYPE, header::USER_AGENT]) .allow_origin(Origin::predicate({ let config = config.clone(); move |url, _| { url.to_str() .ok() .and_then(|url| Url::parse(url).ok()) .map(|url| url.host_str() == config.frontend_base_uri.host_str()) .unwrap_or_default() } })) .allow_credentials(false), ) .layer(AddExtensionLayer::new(pool)) .layer(AddExtensionLayer::new(Arc::new( config.create_oidc_clients().await?, ))) .layer(AddExtensionLayer::new(Arc::new( config.get_file_system().await?, ))) .layer(AddExtensionLayer::new(config.clone())); info!("HTTP server listening on {}", bind_address); axum::Server::bind(&bind_address) .serve(app.into_make_service_with_connect_info::()) .await .map_err(|e| InitError::ServerSpawn(Box::new(e)))?; Ok(()) } #[derive(Error)] pub enum InitError { #[error("Failed to read configuration: {0}")] ConfigRead(#[from] std::io::Error), #[error("Failed to parse configuration: {0}")] ConfigParse(#[from] toml::de::Error), #[error("Configuration error: {0}")] Config(#[from] config::Error), #[error("Database error: {0}")] Database(#[from] chartered_db::Error), #[error("Failed to spawn HTTP server: {0}")] ServerSpawn(Box), #[error("Failed to build CORS header: {0}")] Cors(axum::http::header::InvalidHeaderValue), } impl std::fmt::Debug for InitError { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { write!(f, "{}", self) } }