//! Check the API key embedded in the path is valid otherwise returns a 401 for authenticated //! endpoints. use axum::{ body::{boxed, Body, BoxBody}, extract::{self, FromRequest, RequestParts}, http::{Request, Response, StatusCode}, }; use chartered_db::{users::User, ConnectionPool}; use futures::future::BoxFuture; use std::{ collections::HashMap, sync::Arc, task::{Context, Poll}, }; use tower::Service; use crate::endpoints::ErrorResponse; #[derive(Clone)] pub struct CargoAuthMiddleware(pub S); impl Service> for CargoAuthMiddleware where S: Service, Response = Response> + Clone + Send + 'static, S::Future: Send + 'static, ReqBody: Send + 'static, { type Response = S::Response; type Error = S::Error; type Future = BoxFuture<'static, Result>; fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { self.0.poll_ready(cx) } fn call(&mut self, req: Request) -> Self::Future { // best practice is to clone the inner service like this // see https://github.com/tower-rs/tower/issues/547 for details let clone = self.0.clone(); let mut inner = std::mem::replace(&mut self.0, clone); Box::pin(async move { let mut req = RequestParts::new(req); // extracts all parameters from the path so we can get the API key which should // always be named key let params = extract::Path::>::from_request(&mut req) .await .unwrap(); let key = params.get("key").map(String::as_str).unwrap_or_default(); // grab the ConnectionPool from the extensions created when we initialised the // server let db = req.extensions().get::().unwrap().clone(); // grab the UserSession that's currently being used for this request and the User that // owns the key, otherwise return a 401 if the key doesn't exist let (session, user) = match User::find_by_session_key(db, String::from(key)) .await .unwrap() { Some((session, user)) => (Arc::new(session), Arc::new(user)), None => { return Ok(Response::builder() .status(StatusCode::UNAUTHORIZED) .body(boxed(Body::from( serde_json::to_vec(&ErrorResponse { error: Some("Expired auth token".into()), }) .unwrap(), ))) .unwrap()) } }; if session.user_ssh_key_id.is_none() { // Web sessions can't be used for the Cargo API return Ok(Response::builder() .status(StatusCode::UNAUTHORIZED) .body(boxed(Body::from( serde_json::to_vec(&ErrorResponse { error: Some("Invalid auth token".into()), }) .unwrap(), ))) .unwrap()); } // insert both the user and the session into extensions so handlers can // get their hands on them req.extensions_mut().insert(user); req.extensions_mut().insert(session); // calls handlers/other middleware and drives the request to response let response: Response = inner.call(req.try_into_request().unwrap()).await?; Ok(response) }) } }