use axum::{
body::{boxed, Body, BoxBody},
extract::{self, FromRequest, RequestParts},
http::{Request, Response, StatusCode},
};
use chartered_db::{users::User, ConnectionPool};
use futures::future::BoxFuture;
use std::{
collections::HashMap,
sync::Arc,
task::{Context, Poll},
};
use tower::Service;
use crate::endpoints::ErrorResponse;
/// Tower middleware that authenticates a request before handing it to the wrapped service `S`.
///
/// The session key is read from the `key` path parameter and resolved against the database via
/// `User::find_by_session_key`. Requests whose key does not resolve to a session, or whose
/// session has no `user_ssh_key_id`, are answered with `401 Unauthorized` and never reach the
/// wrapped service. For accepted requests the user and session are made available to downstream
/// handlers as request extensions, each wrapped in an `Arc`.
///
/// The wrapped service is public so that layers can construct the middleware directly:
/// `CargoAuthMiddleware(inner)`.
#[derive(Clone, Debug)]
pub struct CargoAuthMiddleware<S>(pub S);
impl<S, ReqBody> Service<Request<ReqBody>> for CargoAuthMiddleware<S>
where
S: Service<Request<ReqBody>, Response = Response<BoxBody>> + Clone + Send + 'static,
S::Future: Send + 'static,
ReqBody: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.0.poll_ready(cx)
}
fn call(&mut self, req: Request<ReqBody>) -> Self::Future {
let clone = self.0.clone();
let mut inner = std::mem::replace(&mut self.0, clone);
Box::pin(async move {
let mut req = RequestParts::new(req);
let params = extract::Path::<HashMap<String, String>>::from_request(&mut req)
.await
.unwrap();
let key = params.get("key").map(String::as_str).unwrap_or_default();
let db = req.extensions().get::<ConnectionPool>().unwrap().clone();
let (session, user) = match User::find_by_session_key(db, String::from(key))
.await
.unwrap()
{
Some((session, user)) => (Arc::new(session), Arc::new(user)),
None => {
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(boxed(Body::from(
serde_json::to_vec(&ErrorResponse {
error: Some("Expired auth token".into()),
})
.unwrap(),
)))
.unwrap())
}
};
if session.user_ssh_key_id.is_none() {
return Ok(Response::builder()
.status(StatusCode::UNAUTHORIZED)
.body(boxed(Body::from(
serde_json::to_vec(&ErrorResponse {
error: Some("Invalid auth token".into()),
})
.unwrap(),
)))
.unwrap());
}
req.extensions_mut().insert(user);
req.extensions_mut().insert(session);
let response: Response<BoxBody> = inner.call(req.try_into_request().unwrap()).await?;
Ok(response)
})
}
}