From 0e2e088353d3a30bb3dc31482b871b8ef5c942d7 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Mon, 18 Sep 2023 02:35:04 +0100 Subject: [PATCH] Move to oxide-auth-async --- Cargo.lock | 20 ++++++++++++++++++++ jogre-server/Cargo.toml | 1 + jogre-server/src/context/oauth2.rs | 221 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------- jogre-server/src/layers/auth_required.rs | 2 +- jogre-server/src/store/rocksdb.rs | 215 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------- jogre-server/src/methods/oauth/authorize.rs | 2 +- jogre-server/src/methods/oauth/refresh.rs | 2 +- jogre-server/src/methods/oauth/token.rs | 7 +++++-- 8 files changed, 329 insertions(+), 141 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 94fdf4a..350a0c1 100644 --- a/Cargo.lock +++ a/Cargo.lock @@ -243,6 +243,12 @@ [[package]] name = "base64" +version = "0.12.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff" + +[[package]] +name = "base64" version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" @@ -924,6 +930,7 @@ "hmac", "jmap-proto", "oxide-auth", + "oxide-auth-async", "oxide-auth-axum", "rand", "rocksdb", @@ -1176,6 +1183,19 @@ "serde_json", "sha2", "subtle", + "url", +] + +[[package]] +name = "oxide-auth-async" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03c2ae7df8070b0a02bf81d67355af4ab4ae7ad5efc376d322f0239b75701ed4" +dependencies = [ + "async-trait", + "base64 0.12.3", + "chrono", + "oxide-auth", "url", ] diff --git a/jogre-server/Cargo.toml b/jogre-server/Cargo.toml index 2875260..1bf6862 100644 --- a/jogre-server/Cargo.toml +++ a/jogre-server/Cargo.toml @@ -18,6 +18,7 @@ hex = "0.4" hmac = "0.12" oxide-auth = "0.5" +oxide-auth-async = "0.1" oxide-auth-axum = "0.3" rand = "0.8" rocksdb = "0.21" diff --git a/jogre-server/src/context/oauth2.rs b/jogre-server/src/context/oauth2.rs index d409a75..3e74426 100644 --- a/jogre-server/src/context/oauth2.rs +++ a/jogre-server/src/context/oauth2.rs @@ -12,22 +12,23 @@ http::{Method, Request}, BoxError, RequestExt, }; -use futures::FutureExt; use oxide_auth::{ - endpoint::{ - Authorizer, Issuer, OwnerConsent, OwnerSolicitor, QueryParameter, Registrar, Scope, - Solicitation, WebRequest, - }, + endpoint::{OAuthError, OwnerConsent, QueryParameter, Scope, Scopes, Solicitation, WebRequest}, frontends::simple::{ endpoint, - endpoint::{Generic, Vacant}, + endpoint::{Error, ResponseCreator, Vacant}, }, primitives::{ grant::Grant, + issuer::{IssuedToken, RefreshedToken}, prelude::{AuthMap, Client, ClientMap, RandomGenerator, TokenMap}, registrar::RegisteredUrl, }, }; +use oxide_auth_async::endpoint::{ + access_token::AccessTokenFlow, authorization::AuthorizationFlow, refresh::RefreshFlow, + resource::ResourceFlow, OwnerSolicitor, +}; use oxide_auth_axum::{OAuthRequest, OAuthResponse, WebError}; use tower_cookies::Cookies; use tracing::info; @@ -41,8 +42,8 @@ pub struct OAuth2 { pub registrar: ClientMap, - pub authorizer: Mutex>, - pub issuer: Mutex>, + pub authorizer: Authorizer, + pub issuer: Issuer, pub derived_keys: Arc, pub store: Arc, } @@ -57,8 +58,8 @@ "test".parse::().unwrap(), )); - let authorizer = Mutex::new(AuthMap::new(RandomGenerator::new(16))); - let issuer = Mutex::new(TokenMap::new(RandomGenerator::new(16))); + let authorizer = Authorizer::default(); + let issuer = Issuer::default(); Self { registrar, @@ -69,54 +70,176 @@ } } - pub fn resource( + pub async fn resource( &self, request: OAuthRequest, ) -> Result>> { - self.endpoint().resource_flow().execute(request) + match ResourceFlow::prepare(self.endpoint()) { + Ok(mut flow) => flow.execute(request).await, + Err(e) => Err(Err(e)), + } } - pub fn authorize( + pub async fn authorize( &self, request: OAuthRequestWrapper, ) -> Result> { - self.endpoint().authorization_flow().execute(request) + AuthorizationFlow::prepare(self.endpoint())? + .execute(request) + .await } - pub fn token( + pub async fn token( &self, request: OAuthRequestWrapper, ) -> Result> { - self.endpoint().access_token_flow().execute(request) + AccessTokenFlow::prepare(self.endpoint())? + .execute(request) + .await } - pub fn refresh( + pub async fn refresh( &self, request: OAuthRequestWrapper, ) -> Result> { - self.endpoint().refresh_flow().execute(request) + RefreshFlow::prepare(self.endpoint())? + .execute(request) + .await } - fn endpoint( - &self, - ) -> Generic< - impl Registrar + '_, - impl Authorizer + '_, - impl Issuer + '_, - Solicitor<'_>, - Vec, - > { - Generic { + fn endpoint(&self) -> Endpoint<'_> { + Endpoint { registrar: &self.registrar, - authorizer: self.authorizer.lock().unwrap(), - issuer: self.issuer.lock().unwrap(), + authorizer: self.authorizer.clone(), + issuer: self.issuer.clone(), solicitor: Solicitor { derived_keys: &self.derived_keys, store: &self.store, }, scopes: vec![Scope::from_str("test").unwrap()], response: Vacant, + } + } +} + +pub struct Endpoint<'a> { + registrar: &'a ClientMap, + authorizer: Authorizer, + issuer: Issuer, + solicitor: Solicitor<'a>, + scopes: Vec, + response: Vacant, +} + +impl oxide_auth_async::endpoint::Endpoint for Endpoint<'_> +where + ::Response: Default, + for<'a> Solicitor<'a>: OwnerSolicitor, +{ + type Error = Error; + + fn registrar(&self) -> Option<&(dyn oxide_auth_async::primitives::Registrar + Sync)> { + Some(&self.registrar) + } + + fn authorizer_mut( + &mut self, + ) -> Option<&mut (dyn oxide_auth_async::primitives::Authorizer + Send)> { + Some(&mut self.authorizer) + } + + fn issuer_mut(&mut self) -> Option<&mut (dyn oxide_auth_async::primitives::Issuer + Send)> { + Some(&mut self.issuer) + } + + fn owner_solicitor(&mut self) -> Option<&mut (dyn OwnerSolicitor + Send)> { + Some(&mut self.solicitor) + } + + fn scopes(&mut self) -> Option<&mut dyn Scopes> { + Some(&mut self.scopes) + } + + fn response( + &mut self, + request: &mut T, + kind: oxide_auth::endpoint::Template, + ) -> Result { + Ok(self.response.create(request, kind)) + } + + fn error(&mut self, err: OAuthError) -> Self::Error { + Error::OAuth(err) + } + + fn web_error(&mut self, err: T::Error) -> Self::Error { + Error::Web(err) + } +} + +#[derive(Clone)] +pub struct Issuer { + issuer: Arc>>, +} + +impl Default for Issuer { + fn default() -> Self { + Self { + issuer: Arc::new(Mutex::new(TokenMap::new(RandomGenerator::new(16)))), + } + } +} + +#[async_trait] +impl oxide_auth_async::primitives::Issuer for Issuer { + async fn issue(&mut self, grant: Grant) -> Result { + oxide_auth::primitives::issuer::Issuer::issue(&mut self.issuer.lock().unwrap(), grant) + } + + async fn refresh(&mut self, token: &str, grant: Grant) -> Result { + oxide_auth::primitives::issuer::Issuer::refresh( + &mut self.issuer.lock().unwrap(), + token, + grant, + ) + } + + async fn recover_token(&mut self, token: &str) -> Result, ()> { + oxide_auth::primitives::issuer::Issuer::recover_token(&self.issuer.lock().unwrap(), token) + } + + async fn recover_refresh(&mut self, token: &str) -> Result, ()> { + oxide_auth::primitives::issuer::Issuer::recover_refresh(&self.issuer.lock().unwrap(), token) + } +} + +#[derive(Clone)] +pub struct Authorizer { + auth: Arc>>, +} + +impl Default for Authorizer { + fn default() -> Self { + Self { + auth: Arc::new(Mutex::new(AuthMap::new(RandomGenerator::new(16)))), } + } +} + +#[async_trait] +impl oxide_auth_async::primitives::Authorizer for Authorizer { + async fn authorize(&mut self, grant: Grant) -> Result { + oxide_auth::primitives::authorizer::Authorizer::authorize( + &mut self.auth.lock().unwrap(), + grant, + ) + } + + async fn extract(&mut self, token: &str) -> Result, ()> { + oxide_auth::primitives::authorizer::Authorizer::extract( + &mut self.auth.lock().unwrap(), + token, + ) } } @@ -125,21 +248,23 @@ store: &'a Store, } +#[async_trait] impl OwnerSolicitor for Solicitor<'_> { - fn check_consent( + async fn check_consent( &mut self, _: &mut OAuthRequest, - _: Solicitation, + _: Solicitation<'_>, ) -> OwnerConsent { unreachable!("OAuthRequest should only be used for resource requests") } } +#[async_trait] impl OwnerSolicitor for Solicitor<'_> { - fn check_consent( + async fn check_consent( &mut self, req: &mut OAuthRequestWrapper, - solicitation: Solicitation, + solicitation: Solicitation<'_>, ) -> OwnerConsent { let auth_state = if req.method == Method::GET { AuthState::Unauthenticated(None) @@ -153,9 +278,10 @@ self.store, &req.cookie_jar, &username, - &password, + password.into_owned(), &csrf_token, ) + .await } else { AuthState::Unauthenticated(Some(UnauthenticatedState::MissingUserPass)) }; @@ -187,12 +313,12 @@ } } -fn attempt_authentication( +async fn attempt_authentication( derived_keys: &DerivedKeys, store: &Store, cookies: &Cookies, username: &str, - password: &str, + password: String, csrf_token: &str, ) -> AuthState { if !CsrfToken::verify(derived_keys, cookies, csrf_token) { @@ -200,20 +326,19 @@ } // TODO: actually await here - let Some(user) = store - .get_by_username(username) - .now_or_never() - .unwrap() - .unwrap() - else { + let Some(user) = store.get_by_username(username).await.unwrap() else { return AuthState::Unauthenticated(Some(UnauthenticatedState::InvalidUserPass)); }; - if user.verify_password(password) { - AuthState::Authenticated(user.username) - } else { - AuthState::Unauthenticated(Some(UnauthenticatedState::InvalidUserPass)) - } + tokio::task::spawn_blocking(move || { + if user.verify_password(&password) { + AuthState::Authenticated(user.username) + } else { + AuthState::Unauthenticated(Some(UnauthenticatedState::InvalidUserPass)) + } + }) + .await + .unwrap() } #[derive(Template)] diff --git a/jogre-server/src/layers/auth_required.rs b/jogre-server/src/layers/auth_required.rs index 7e4e1d3..32abdb5 100644 --- a/jogre-server/src/layers/auth_required.rs +++ a/jogre-server/src/layers/auth_required.rs @@ -26,7 +26,7 @@ } }; - let grant = match state.oauth2.resource(resource_request.into()) { + let grant = match state.oauth2.resource(resource_request.into()).await { Ok(v) => v, Err(e) => { error!("Rejecting request due to it being unauthorized"); diff --git a/jogre-server/src/store/rocksdb.rs b/jogre-server/src/store/rocksdb.rs index 9259200..4ea3a56 100644 --- a/jogre-server/src/store/rocksdb.rs +++ a/jogre-server/src/store/rocksdb.rs @@ -1,4 +1,4 @@ -use std::path::PathBuf; +use std::{path::PathBuf, sync::Arc}; use axum::async_trait; use rocksdb::{IteratorMode, MergeOperands, Options, DB}; @@ -27,7 +27,7 @@ // TODO: lots of blocking on async thread pub struct RocksDb { - db: DB, + db: Arc, } impl RocksDb { @@ -50,7 +50,7 @@ ) .unwrap(); - Self { db } + Self { db: Arc::new(db) } } } @@ -114,14 +114,19 @@ type Error = Error; async fn create_account(&self, account: Account) -> Result<(), Self::Error> { - let bytes = bincode::serde::encode_to_vec(&account, BINCODE_CONFIG).unwrap(); - - let by_uuid_handle = self.db.cf_handle(ACCOUNTS_BY_UUID).unwrap(); - self.db - .put_cf(by_uuid_handle, account.id.as_bytes(), bytes) - .unwrap(); + let db = self.db.clone(); - Ok(()) + tokio::task::spawn_blocking(move || { + let bytes = bincode::serde::encode_to_vec(&account, BINCODE_CONFIG).unwrap(); + + let by_uuid_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap(); + db.put_cf(by_uuid_handle, account.id.as_bytes(), bytes) + .unwrap(); + + Ok(()) + }) + .await + .unwrap() } async fn attach_account_to_user( @@ -130,17 +135,20 @@ user: Uuid, access: AccountAccessLevel, ) -> Result<(), Self::Error> { - { - let access_handle = self.db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap(); + let db = self.db.clone(); + + tokio::task::spawn_blocking(move || { + let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap(); let mut compound_key = [0_u8; 32]; compound_key[..16].copy_from_slice(user.as_bytes()); compound_key[16..].copy_from_slice(account.as_bytes()); - self.db - .put_cf(access_handle, compound_key, (access as u8).to_be_bytes()) + db.put_cf(access_handle, compound_key, (access as u8).to_be_bytes()) .unwrap(); - } + }) + .await + .unwrap(); self.increment_seq_number_for_user(user).await.unwrap(); @@ -148,28 +156,33 @@ } async fn get_accounts_for_user(&self, user_id: Uuid) -> Result, Self::Error> { - let access_handle = self.db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap(); - let account_handle = self.db.cf_handle(ACCOUNTS_BY_UUID).unwrap(); - - Ok(self - .db - .prefix_iterator_cf(access_handle, user_id.as_bytes()) - .map(Result::unwrap) - .filter_map(|(key, _access_level)| { - let Some(account) = key.strip_prefix(user_id.as_bytes()) else { - panic!("got invalid key from rocksdb"); - }; - - let Some(account_bytes) = self.db.get_cf(account_handle, account).unwrap() else { - return None; - }; - - let (res, _): (Account, _) = - bincode::serde::decode_from_slice(&account_bytes, BINCODE_CONFIG).unwrap(); - - Some(res) - }) - .collect()) + let db = self.db.clone(); + + tokio::task::spawn_blocking(move || { + let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap(); + let account_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap(); + + Ok(db + .prefix_iterator_cf(access_handle, user_id.as_bytes()) + .map(Result::unwrap) + .filter_map(|(key, _access_level)| { + let Some(account) = key.strip_prefix(user_id.as_bytes()) else { + panic!("got invalid key from rocksdb"); + }; + + let Some(account_bytes) = db.get_cf(account_handle, account).unwrap() else { + return None; + }; + + let (res, _): (Account, _) = + bincode::serde::decode_from_slice(&account_bytes, BINCODE_CONFIG).unwrap(); + + Some(res) + }) + .collect()) + }) + .await + .unwrap() } } @@ -178,78 +191,104 @@ type Error = Error; async fn increment_seq_number_for_user(&self, user: Uuid) -> Result<(), Self::Error> { - let seq_handle = self.db.cf_handle(USER_SEQ_NUMBER).unwrap(); - self.db - .merge_cf(seq_handle, user.as_bytes(), "INCR") - .unwrap(); - Ok(()) + let db = self.db.clone(); + + tokio::task::spawn_blocking(move || { + let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap(); + db.merge_cf(seq_handle, user.as_bytes(), "INCR").unwrap(); + Ok(()) + }) + .await + .unwrap() } async fn fetch_seq_number_for_user(&self, user: Uuid) -> Result { - let seq_handle = self.db.cf_handle(USER_SEQ_NUMBER).unwrap(); + let db = self.db.clone(); - let Some(bytes) = self.db.get_pinned_cf(seq_handle, user.as_bytes()).unwrap() else { - return Ok(0); - }; + tokio::task::spawn_blocking(move || { + let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap(); - let mut val = [0_u8; std::mem::size_of::()]; - val.copy_from_slice(&bytes); + let Some(bytes) = db.get_pinned_cf(seq_handle, user.as_bytes()).unwrap() else { + return Ok(0); + }; - Ok(u64::from_be_bytes(val)) + let mut val = [0_u8; std::mem::size_of::()]; + val.copy_from_slice(&bytes); + + Ok(u64::from_be_bytes(val)) + }) + .await + .unwrap() } async fn has_any_users(&self) -> Result { - let by_uuid_handle = self.db.cf_handle(USER_BY_UUID_CF).unwrap(); - Ok(self - .db - .full_iterator_cf(by_uuid_handle, IteratorMode::Start) - .next() - .is_some()) + let db = self.db.clone(); + + tokio::task::spawn_blocking(move || { + let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap(); + Ok(db + .full_iterator_cf(by_uuid_handle, IteratorMode::Start) + .next() + .is_some()) + }) + .await + .unwrap() } async fn create_user(&self, user: User) -> Result<(), Self::Error> { - let bytes = bincode::serde::encode_to_vec(&user, BINCODE_CONFIG).unwrap(); - - let by_uuid_handle = self.db.cf_handle(USER_BY_UUID_CF).unwrap(); - self.db - .put_cf(by_uuid_handle, user.id.as_bytes(), bytes) - .unwrap(); + let db = self.db.clone(); - let by_username_handle = self.db.cf_handle(USER_BY_USERNAME_CF).unwrap(); - self.db - .put_cf( + tokio::task::spawn_blocking(move || { + let bytes = bincode::serde::encode_to_vec(&user, BINCODE_CONFIG).unwrap(); + + let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap(); + db.put_cf(by_uuid_handle, user.id.as_bytes(), bytes) + .unwrap(); + + let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap(); + db.put_cf( by_username_handle, user.username.as_bytes(), user.id.as_bytes(), ) .unwrap(); - Ok(()) + Ok(()) + }) + .await + .unwrap() } async fn get_by_username(&self, username: &str) -> Result, Error> { - let uuid = { - let by_username_handle = self.db.cf_handle(USER_BY_USERNAME_CF).unwrap(); - self.db.get_pinned_cf(by_username_handle, username).unwrap() - }; - - let Some(uuid) = uuid else { - return Ok(None); - }; - - let user_bytes = { - let by_uuid_handle = self.db.cf_handle(USER_BY_UUID_CF).unwrap(); - self.db.get_pinned_cf(by_uuid_handle, &uuid).unwrap() - }; - - let Some(user_bytes) = user_bytes else { - return Ok(None); - }; - - Ok(Some( - bincode::serde::decode_from_slice(&user_bytes, BINCODE_CONFIG) - .unwrap() - .0, - )) + let db = self.db.clone(); + let username = username.to_string(); + + tokio::task::spawn_blocking(move || { + let uuid = { + let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap(); + db.get_pinned_cf(by_username_handle, username).unwrap() + }; + + let Some(uuid) = uuid else { + return Ok(None); + }; + + let user_bytes = { + let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap(); + db.get_pinned_cf(by_uuid_handle, &uuid).unwrap() + }; + + let Some(user_bytes) = user_bytes else { + return Ok(None); + }; + + Ok(Some( + bincode::serde::decode_from_slice(&user_bytes, BINCODE_CONFIG) + .unwrap() + .0, + )) + }) + .await + .unwrap() } } diff --git a/jogre-server/src/methods/oauth/authorize.rs b/jogre-server/src/methods/oauth/authorize.rs index 2b393d0..a36e07c 100644 --- a/jogre-server/src/methods/oauth/authorize.rs +++ a/jogre-server/src/methods/oauth/authorize.rs @@ -6,7 +6,6 @@ use crate::context::{oauth2::OAuthRequestWrapper, Context}; -#[allow(clippy::unused_async)] pub async fn handle( State(context): State>, request: OAuthRequestWrapper, @@ -14,5 +13,6 @@ context .oauth2 .authorize(request) + .await .map_err(endpoint::Error::pack) } diff --git a/jogre-server/src/methods/oauth/refresh.rs b/jogre-server/src/methods/oauth/refresh.rs index 9ae918f..0d2e28b 100644 --- a/jogre-server/src/methods/oauth/refresh.rs +++ a/jogre-server/src/methods/oauth/refresh.rs @@ -6,7 +6,6 @@ use crate::context::{oauth2::OAuthRequestWrapper, Context}; -#[allow(clippy::unused_async)] pub async fn handle( State(context): State>, request: OAuthRequestWrapper, @@ -14,5 +13,6 @@ context .oauth2 .refresh(request) + .await .map_err(endpoint::Error::pack) } diff --git a/jogre-server/src/methods/oauth/token.rs b/jogre-server/src/methods/oauth/token.rs index 7ce9c20..a382d38 100644 --- a/jogre-server/src/methods/oauth/token.rs +++ a/jogre-server/src/methods/oauth/token.rs @@ -6,10 +6,13 @@ use crate::context::{oauth2::OAuthRequestWrapper, Context}; -#[allow(clippy::unused_async)] pub async fn handle( State(context): State>, request: OAuthRequestWrapper, ) -> Result { - context.oauth2.token(request).map_err(endpoint::Error::pack) + context + .oauth2 + .token(request) + .await + .map_err(endpoint::Error::pack) } -- rgit 0.1.3