Move to oxide-auth-async
Diff
Cargo.lock | 20 ++++++++++++++++++++
jogre-server/Cargo.toml | 1 +
jogre-server/src/context/oauth2.rs | 221 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------
jogre-server/src/layers/auth_required.rs | 2 +-
jogre-server/src/store/rocksdb.rs | 215 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--------------
jogre-server/src/methods/oauth/authorize.rs | 2 +-
jogre-server/src/methods/oauth/refresh.rs | 2 +-
jogre-server/src/methods/oauth/token.rs | 7 +++++--
8 files changed, 329 insertions(+), 141 deletions(-)
@@ -243,6 +243,12 @@
[[package]]
name = "base64"
version = "0.12.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3441f0f7b02788e948e47f457ca01f1d7e6d92c693bc132c22b087d3141c03ff"
[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8"
@@ -924,6 +930,7 @@
"hmac",
"jmap-proto",
"oxide-auth",
"oxide-auth-async",
"oxide-auth-axum",
"rand",
"rocksdb",
@@ -1176,6 +1183,19 @@
"serde_json",
"sha2",
"subtle",
"url",
]
[[package]]
name = "oxide-auth-async"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "03c2ae7df8070b0a02bf81d67355af4ab4ae7ad5efc376d322f0239b75701ed4"
dependencies = [
"async-trait",
"base64 0.12.3",
"chrono",
"oxide-auth",
"url",
]
@@ -18,6 +18,7 @@
hex = "0.4"
hmac = "0.12"
oxide-auth = "0.5"
oxide-auth-async = "0.1"
oxide-auth-axum = "0.3"
rand = "0.8"
rocksdb = "0.21"
@@ -12,22 +12,23 @@
http::{Method, Request},
BoxError, RequestExt,
};
use futures::FutureExt;
use oxide_auth::{
endpoint::{
Authorizer, Issuer, OwnerConsent, OwnerSolicitor, QueryParameter, Registrar, Scope,
Solicitation, WebRequest,
},
endpoint::{OAuthError, OwnerConsent, QueryParameter, Scope, Scopes, Solicitation, WebRequest},
frontends::simple::{
endpoint,
endpoint::{Generic, Vacant},
endpoint::{Error, ResponseCreator, Vacant},
},
primitives::{
grant::Grant,
issuer::{IssuedToken, RefreshedToken},
prelude::{AuthMap, Client, ClientMap, RandomGenerator, TokenMap},
registrar::RegisteredUrl,
},
};
use oxide_auth_async::endpoint::{
access_token::AccessTokenFlow, authorization::AuthorizationFlow, refresh::RefreshFlow,
resource::ResourceFlow, OwnerSolicitor,
};
use oxide_auth_axum::{OAuthRequest, OAuthResponse, WebError};
use tower_cookies::Cookies;
use tracing::info;
@@ -41,8 +42,8 @@
pub struct OAuth2 {
pub registrar: ClientMap,
pub authorizer: Mutex<AuthMap<RandomGenerator>>,
pub issuer: Mutex<TokenMap<RandomGenerator>>,
pub authorizer: Authorizer,
pub issuer: Issuer,
pub derived_keys: Arc<DerivedKeys>,
pub store: Arc<Store>,
}
@@ -57,8 +58,8 @@
"test".parse::<Scope>().unwrap(),
));
let authorizer = Mutex::new(AuthMap::new(RandomGenerator::new(16)));
let issuer = Mutex::new(TokenMap::new(RandomGenerator::new(16)));
let authorizer = Authorizer::default();
let issuer = Issuer::default();
Self {
registrar,
@@ -69,54 +70,176 @@
}
}
pub fn resource(
pub async fn resource(
&self,
request: OAuthRequest,
) -> Result<Grant, Result<OAuthResponse, endpoint::Error<OAuthRequest>>> {
self.endpoint().resource_flow().execute(request)
match ResourceFlow::prepare(self.endpoint()) {
Ok(mut flow) => flow.execute(request).await,
Err(e) => Err(Err(e)),
}
}
pub fn authorize(
pub async fn authorize(
&self,
request: OAuthRequestWrapper,
) -> Result<OAuthResponse, endpoint::Error<OAuthRequestWrapper>> {
self.endpoint().authorization_flow().execute(request)
AuthorizationFlow::prepare(self.endpoint())?
.execute(request)
.await
}
pub fn token(
pub async fn token(
&self,
request: OAuthRequestWrapper,
) -> Result<OAuthResponse, endpoint::Error<OAuthRequestWrapper>> {
self.endpoint().access_token_flow().execute(request)
AccessTokenFlow::prepare(self.endpoint())?
.execute(request)
.await
}
pub fn refresh(
pub async fn refresh(
&self,
request: OAuthRequestWrapper,
) -> Result<OAuthResponse, endpoint::Error<OAuthRequestWrapper>> {
self.endpoint().refresh_flow().execute(request)
RefreshFlow::prepare(self.endpoint())?
.execute(request)
.await
}
fn endpoint(
&self,
) -> Generic<
impl Registrar + '_,
impl Authorizer + '_,
impl Issuer + '_,
Solicitor<'_>,
Vec<Scope>,
> {
Generic {
fn endpoint(&self) -> Endpoint<'_> {
Endpoint {
registrar: &self.registrar,
authorizer: self.authorizer.lock().unwrap(),
issuer: self.issuer.lock().unwrap(),
authorizer: self.authorizer.clone(),
issuer: self.issuer.clone(),
solicitor: Solicitor {
derived_keys: &self.derived_keys,
store: &self.store,
},
scopes: vec![Scope::from_str("test").unwrap()],
response: Vacant,
}
}
}
pub struct Endpoint<'a> {
registrar: &'a ClientMap,
authorizer: Authorizer,
issuer: Issuer,
solicitor: Solicitor<'a>,
scopes: Vec<Scope>,
response: Vacant,
}
impl<T: WebRequest + Send> oxide_auth_async::endpoint::Endpoint<T> for Endpoint<'_>
where
<T as WebRequest>::Response: Default,
for<'a> Solicitor<'a>: OwnerSolicitor<T>,
{
type Error = Error<T>;
fn registrar(&self) -> Option<&(dyn oxide_auth_async::primitives::Registrar + Sync)> {
Some(&self.registrar)
}
fn authorizer_mut(
&mut self,
) -> Option<&mut (dyn oxide_auth_async::primitives::Authorizer + Send)> {
Some(&mut self.authorizer)
}
fn issuer_mut(&mut self) -> Option<&mut (dyn oxide_auth_async::primitives::Issuer + Send)> {
Some(&mut self.issuer)
}
fn owner_solicitor(&mut self) -> Option<&mut (dyn OwnerSolicitor<T> + Send)> {
Some(&mut self.solicitor)
}
fn scopes(&mut self) -> Option<&mut dyn Scopes<T>> {
Some(&mut self.scopes)
}
fn response(
&mut self,
request: &mut T,
kind: oxide_auth::endpoint::Template,
) -> Result<T::Response, Self::Error> {
Ok(self.response.create(request, kind))
}
fn error(&mut self, err: OAuthError) -> Self::Error {
Error::OAuth(err)
}
fn web_error(&mut self, err: T::Error) -> Self::Error {
Error::Web(err)
}
}
#[derive(Clone)]
pub struct Issuer {
issuer: Arc<Mutex<TokenMap<RandomGenerator>>>,
}
impl Default for Issuer {
fn default() -> Self {
Self {
issuer: Arc::new(Mutex::new(TokenMap::new(RandomGenerator::new(16)))),
}
}
}
#[async_trait]
impl oxide_auth_async::primitives::Issuer for Issuer {
async fn issue(&mut self, grant: Grant) -> Result<IssuedToken, ()> {
oxide_auth::primitives::issuer::Issuer::issue(&mut self.issuer.lock().unwrap(), grant)
}
async fn refresh(&mut self, token: &str, grant: Grant) -> Result<RefreshedToken, ()> {
oxide_auth::primitives::issuer::Issuer::refresh(
&mut self.issuer.lock().unwrap(),
token,
grant,
)
}
async fn recover_token(&mut self, token: &str) -> Result<Option<Grant>, ()> {
oxide_auth::primitives::issuer::Issuer::recover_token(&self.issuer.lock().unwrap(), token)
}
async fn recover_refresh(&mut self, token: &str) -> Result<Option<Grant>, ()> {
oxide_auth::primitives::issuer::Issuer::recover_refresh(&self.issuer.lock().unwrap(), token)
}
}
#[derive(Clone)]
pub struct Authorizer {
auth: Arc<Mutex<AuthMap<RandomGenerator>>>,
}
impl Default for Authorizer {
fn default() -> Self {
Self {
auth: Arc::new(Mutex::new(AuthMap::new(RandomGenerator::new(16)))),
}
}
}
#[async_trait]
impl oxide_auth_async::primitives::Authorizer for Authorizer {
async fn authorize(&mut self, grant: Grant) -> Result<String, ()> {
oxide_auth::primitives::authorizer::Authorizer::authorize(
&mut self.auth.lock().unwrap(),
grant,
)
}
async fn extract(&mut self, token: &str) -> Result<Option<Grant>, ()> {
oxide_auth::primitives::authorizer::Authorizer::extract(
&mut self.auth.lock().unwrap(),
token,
)
}
}
@@ -125,21 +248,23 @@
store: &'a Store,
}
#[async_trait]
impl OwnerSolicitor<OAuthRequest> for Solicitor<'_> {
fn check_consent(
async fn check_consent(
&mut self,
_: &mut OAuthRequest,
_: Solicitation,
_: Solicitation<'_>,
) -> OwnerConsent<OAuthResponse> {
unreachable!("OAuthRequest should only be used for resource requests")
}
}
#[async_trait]
impl OwnerSolicitor<OAuthRequestWrapper> for Solicitor<'_> {
fn check_consent(
async fn check_consent(
&mut self,
req: &mut OAuthRequestWrapper,
solicitation: Solicitation,
solicitation: Solicitation<'_>,
) -> OwnerConsent<OAuthResponse> {
let auth_state = if req.method == Method::GET {
AuthState::Unauthenticated(None)
@@ -153,9 +278,10 @@
self.store,
&req.cookie_jar,
&username,
&password,
password.into_owned(),
&csrf_token,
)
.await
} else {
AuthState::Unauthenticated(Some(UnauthenticatedState::MissingUserPass))
};
@@ -187,12 +313,12 @@
}
}
fn attempt_authentication(
async fn attempt_authentication(
derived_keys: &DerivedKeys,
store: &Store,
cookies: &Cookies,
username: &str,
password: &str,
password: String,
csrf_token: &str,
) -> AuthState {
if !CsrfToken::verify(derived_keys, cookies, csrf_token) {
@@ -200,20 +326,19 @@
}
let Some(user) = store
.get_by_username(username)
.now_or_never()
.unwrap()
.unwrap()
else {
let Some(user) = store.get_by_username(username).await.unwrap() else {
return AuthState::Unauthenticated(Some(UnauthenticatedState::InvalidUserPass));
};
if user.verify_password(password) {
AuthState::Authenticated(user.username)
} else {
AuthState::Unauthenticated(Some(UnauthenticatedState::InvalidUserPass))
}
tokio::task::spawn_blocking(move || {
if user.verify_password(&password) {
AuthState::Authenticated(user.username)
} else {
AuthState::Unauthenticated(Some(UnauthenticatedState::InvalidUserPass))
}
})
.await
.unwrap()
}
#[derive(Template)]
@@ -26,7 +26,7 @@
}
};
let grant = match state.oauth2.resource(resource_request.into()) {
let grant = match state.oauth2.resource(resource_request.into()).await {
Ok(v) => v,
Err(e) => {
error!("Rejecting request due to it being unauthorized");
@@ -1,4 +1,4 @@
use std::path::PathBuf;
use std::{path::PathBuf, sync::Arc};
use axum::async_trait;
use rocksdb::{IteratorMode, MergeOperands, Options, DB};
@@ -27,7 +27,7 @@
pub struct RocksDb {
db: DB,
db: Arc<DB>,
}
impl RocksDb {
@@ -50,7 +50,7 @@
)
.unwrap();
Self { db }
Self { db: Arc::new(db) }
}
}
@@ -114,14 +114,19 @@
type Error = Error;
async fn create_account(&self, account: Account) -> Result<(), Self::Error> {
let bytes = bincode::serde::encode_to_vec(&account, BINCODE_CONFIG).unwrap();
let by_uuid_handle = self.db.cf_handle(ACCOUNTS_BY_UUID).unwrap();
self.db
.put_cf(by_uuid_handle, account.id.as_bytes(), bytes)
.unwrap();
let db = self.db.clone();
Ok(())
tokio::task::spawn_blocking(move || {
let bytes = bincode::serde::encode_to_vec(&account, BINCODE_CONFIG).unwrap();
let by_uuid_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap();
db.put_cf(by_uuid_handle, account.id.as_bytes(), bytes)
.unwrap();
Ok(())
})
.await
.unwrap()
}
async fn attach_account_to_user(
@@ -130,17 +135,20 @@
user: Uuid,
access: AccountAccessLevel,
) -> Result<(), Self::Error> {
{
let access_handle = self.db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();
let mut compound_key = [0_u8; 32];
compound_key[..16].copy_from_slice(user.as_bytes());
compound_key[16..].copy_from_slice(account.as_bytes());
self.db
.put_cf(access_handle, compound_key, (access as u8).to_be_bytes())
db.put_cf(access_handle, compound_key, (access as u8).to_be_bytes())
.unwrap();
}
})
.await
.unwrap();
self.increment_seq_number_for_user(user).await.unwrap();
@@ -148,28 +156,33 @@
}
async fn get_accounts_for_user(&self, user_id: Uuid) -> Result<Vec<Account>, Self::Error> {
let access_handle = self.db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();
let account_handle = self.db.cf_handle(ACCOUNTS_BY_UUID).unwrap();
Ok(self
.db
.prefix_iterator_cf(access_handle, user_id.as_bytes())
.map(Result::unwrap)
.filter_map(|(key, _access_level)| {
let Some(account) = key.strip_prefix(user_id.as_bytes()) else {
panic!("got invalid key from rocksdb");
};
let Some(account_bytes) = self.db.get_cf(account_handle, account).unwrap() else {
return None;
};
let (res, _): (Account, _) =
bincode::serde::decode_from_slice(&account_bytes, BINCODE_CONFIG).unwrap();
Some(res)
})
.collect())
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();
let account_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap();
Ok(db
.prefix_iterator_cf(access_handle, user_id.as_bytes())
.map(Result::unwrap)
.filter_map(|(key, _access_level)| {
let Some(account) = key.strip_prefix(user_id.as_bytes()) else {
panic!("got invalid key from rocksdb");
};
let Some(account_bytes) = db.get_cf(account_handle, account).unwrap() else {
return None;
};
let (res, _): (Account, _) =
bincode::serde::decode_from_slice(&account_bytes, BINCODE_CONFIG).unwrap();
Some(res)
})
.collect())
})
.await
.unwrap()
}
}
@@ -178,78 +191,104 @@
type Error = Error;
async fn increment_seq_number_for_user(&self, user: Uuid) -> Result<(), Self::Error> {
let seq_handle = self.db.cf_handle(USER_SEQ_NUMBER).unwrap();
self.db
.merge_cf(seq_handle, user.as_bytes(), "INCR")
.unwrap();
Ok(())
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap();
db.merge_cf(seq_handle, user.as_bytes(), "INCR").unwrap();
Ok(())
})
.await
.unwrap()
}
async fn fetch_seq_number_for_user(&self, user: Uuid) -> Result<u64, Self::Error> {
let seq_handle = self.db.cf_handle(USER_SEQ_NUMBER).unwrap();
let db = self.db.clone();
let Some(bytes) = self.db.get_pinned_cf(seq_handle, user.as_bytes()).unwrap() else {
return Ok(0);
};
tokio::task::spawn_blocking(move || {
let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap();
let mut val = [0_u8; std::mem::size_of::<u64>()];
val.copy_from_slice(&bytes);
let Some(bytes) = db.get_pinned_cf(seq_handle, user.as_bytes()).unwrap() else {
return Ok(0);
};
Ok(u64::from_be_bytes(val))
let mut val = [0_u8; std::mem::size_of::<u64>()];
val.copy_from_slice(&bytes);
Ok(u64::from_be_bytes(val))
})
.await
.unwrap()
}
async fn has_any_users(&self) -> Result<bool, Self::Error> {
let by_uuid_handle = self.db.cf_handle(USER_BY_UUID_CF).unwrap();
Ok(self
.db
.full_iterator_cf(by_uuid_handle, IteratorMode::Start)
.next()
.is_some())
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
Ok(db
.full_iterator_cf(by_uuid_handle, IteratorMode::Start)
.next()
.is_some())
})
.await
.unwrap()
}
async fn create_user(&self, user: User) -> Result<(), Self::Error> {
let bytes = bincode::serde::encode_to_vec(&user, BINCODE_CONFIG).unwrap();
let by_uuid_handle = self.db.cf_handle(USER_BY_UUID_CF).unwrap();
self.db
.put_cf(by_uuid_handle, user.id.as_bytes(), bytes)
.unwrap();
let db = self.db.clone();
let by_username_handle = self.db.cf_handle(USER_BY_USERNAME_CF).unwrap();
self.db
.put_cf(
tokio::task::spawn_blocking(move || {
let bytes = bincode::serde::encode_to_vec(&user, BINCODE_CONFIG).unwrap();
let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
db.put_cf(by_uuid_handle, user.id.as_bytes(), bytes)
.unwrap();
let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap();
db.put_cf(
by_username_handle,
user.username.as_bytes(),
user.id.as_bytes(),
)
.unwrap();
Ok(())
Ok(())
})
.await
.unwrap()
}
async fn get_by_username(&self, username: &str) -> Result<Option<User>, Error> {
let uuid = {
let by_username_handle = self.db.cf_handle(USER_BY_USERNAME_CF).unwrap();
self.db.get_pinned_cf(by_username_handle, username).unwrap()
};
let Some(uuid) = uuid else {
return Ok(None);
};
let user_bytes = {
let by_uuid_handle = self.db.cf_handle(USER_BY_UUID_CF).unwrap();
self.db.get_pinned_cf(by_uuid_handle, &uuid).unwrap()
};
let Some(user_bytes) = user_bytes else {
return Ok(None);
};
Ok(Some(
bincode::serde::decode_from_slice(&user_bytes, BINCODE_CONFIG)
.unwrap()
.0,
))
let db = self.db.clone();
let username = username.to_string();
tokio::task::spawn_blocking(move || {
let uuid = {
let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap();
db.get_pinned_cf(by_username_handle, username).unwrap()
};
let Some(uuid) = uuid else {
return Ok(None);
};
let user_bytes = {
let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
db.get_pinned_cf(by_uuid_handle, &uuid).unwrap()
};
let Some(user_bytes) = user_bytes else {
return Ok(None);
};
Ok(Some(
bincode::serde::decode_from_slice(&user_bytes, BINCODE_CONFIG)
.unwrap()
.0,
))
})
.await
.unwrap()
}
}
@@ -6,7 +6,6 @@
use crate::context::{oauth2::OAuthRequestWrapper, Context};
#[allow(clippy::unused_async)]
pub async fn handle(
State(context): State<Arc<Context>>,
request: OAuthRequestWrapper,
@@ -14,5 +13,6 @@
context
.oauth2
.authorize(request)
.await
.map_err(endpoint::Error::pack)
}
@@ -6,7 +6,6 @@
use crate::context::{oauth2::OAuthRequestWrapper, Context};
#[allow(clippy::unused_async)]
pub async fn handle(
State(context): State<Arc<Context>>,
request: OAuthRequestWrapper,
@@ -14,5 +13,6 @@
context
.oauth2
.refresh(request)
.await
.map_err(endpoint::Error::pack)
}
@@ -6,10 +6,13 @@
use crate::context::{oauth2::OAuthRequestWrapper, Context};
#[allow(clippy::unused_async)]
pub async fn handle(
State(context): State<Arc<Context>>,
request: OAuthRequestWrapper,
) -> Result<OAuthResponse, WebError> {
context.oauth2.token(request).map_err(endpoint::Error::pack)
context
.oauth2
.token(request)
.await
.map_err(endpoint::Error::pack)
}