First-class support for GitHub authentication
Diff
Cargo.lock | 131 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++--
chartered-web/Cargo.toml | 1 +
chartered-web/src/config.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++--
chartered-web/src/main.rs | 10 ++++++++--
chartered-web/src/endpoints/web_api/auth/openid.rs | 162 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
5 files changed, 303 insertions(+), 53 deletions(-)
@@ -300,7 +300,7 @@
"http",
"http-body",
"hyper",
"hyper-rustls",
"hyper-rustls 0.22.1",
"lazy_static",
"pin-project-lite",
"tokio",
@@ -741,6 +741,7 @@
"hex",
"nom",
"nom-bytes",
"oauth2",
"once_cell",
"openid",
"rand",
@@ -935,7 +936,7 @@
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1a816186fa68d9e426e3cb4ae4dff1fcd8e4a2c34b781bf7a822574a0d0aac8"
dependencies = [
"sct",
"sct 0.6.1",
]
[[package]]
@@ -1249,8 +1250,10 @@
checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6"
dependencies = [
"cfg-if",
"js-sys",
"libc",
"wasi 0.11.0+wasi-snapshot-preview1",
"wasm-bindgen",
]
[[package]]
@@ -1417,11 +1420,24 @@
"futures-util",
"hyper",
"log",
"rustls",
"rustls 0.19.1",
"rustls-native-certs",
"tokio",
"tokio-rustls 0.22.0",
"webpki 0.21.4",
]
[[package]]
name = "hyper-rustls"
version = "0.23.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac"
dependencies = [
"http",
"hyper",
"rustls 0.20.6",
"tokio",
"tokio-rustls",
"webpki",
"tokio-rustls 0.23.4",
]
[[package]]
@@ -1821,6 +1837,26 @@
checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44"
dependencies = [
"libc",
]
[[package]]
name = "oauth2"
version = "4.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d62c436394991641b970a92e23e8eeb4eb9bca74af4f5badc53bcd568daadbd"
dependencies = [
"base64",
"chrono",
"getrandom",
"http",
"rand",
"reqwest",
"serde",
"serde_json",
"serde_path_to_error",
"sha2 0.10.2",
"thiserror",
"url",
]
[[package]]
@@ -2175,6 +2211,7 @@
"http",
"http-body",
"hyper",
"hyper-rustls 0.23.0",
"hyper-tls",
"ipnet",
"js-sys",
@@ -2184,16 +2221,20 @@
"native-tls",
"percent-encoding",
"pin-project-lite",
"rustls 0.20.6",
"rustls-pemfile",
"serde",
"serde_json",
"serde_urlencoded",
"tokio",
"tokio-native-tls",
"tokio-rustls 0.23.4",
"tower-service",
"url",
"wasm-bindgen",
"wasm-bindgen-futures",
"web-sys",
"webpki-roots",
"winreg",
]
@@ -2228,10 +2269,22 @@
checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7"
dependencies = [
"base64",
"log",
"ring",
"sct 0.6.1",
"webpki 0.21.4",
]
[[package]]
name = "rustls"
version = "0.20.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033"
dependencies = [
"log",
"ring",
"sct",
"webpki",
"sct 0.7.0",
"webpki 0.22.0",
]
[[package]]
@@ -2241,9 +2294,18 @@
checksum = "5a07b7c1885bd8ed3831c289b7870b13ef46fe0e856d288c30d9cc17d75a2092"
dependencies = [
"openssl-probe",
"rustls",
"rustls 0.19.1",
"schannel",
"security-framework",
]
[[package]]
name = "rustls-pemfile"
version = "1.0.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0864aeff53f8c05aa08d86e5ef839d3dfcf07aeba2db32f12db0ef716e87bd55"
dependencies = [
"base64",
]
[[package]]
@@ -2291,6 +2353,16 @@
version = "0.6.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "sct"
version = "0.7.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4"
dependencies = [
"ring",
"untrusted",
@@ -2354,6 +2426,15 @@
"indexmap",
"itoa",
"ryu",
"serde",
]
[[package]]
name = "serde_path_to_error"
version = "0.1.8"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "184c643044780f7ceb59104cef98a5a6f12cb2288a7bc701ab93a362b49fd47d"
dependencies = [
"serde",
]
@@ -2707,10 +2788,21 @@
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6"
dependencies = [
"rustls 0.19.1",
"tokio",
"webpki 0.21.4",
]
[[package]]
name = "tokio-rustls"
version = "0.23.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59"
dependencies = [
"rustls",
"rustls 0.20.6",
"tokio",
"webpki",
"webpki 0.22.0",
]
[[package]]
@@ -3112,9 +3204,28 @@
version = "0.21.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "webpki"
version = "0.22.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd"
dependencies = [
"ring",
"untrusted",
]
[[package]]
name = "webpki-roots"
version = "0.22.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f1c760f0d366a6c24a02ed7816e23e691f5d92291f94d15e836006fd11b04daf"
dependencies = [
"webpki 0.22.0",
]
[[package]]
@@ -23,6 +23,7 @@
hex = "0.4"
nom = "7"
nom-bytes = { git = "https://github.com/w4/nom-bytes" }
oauth2 = "4.2"
once_cell = "1.8"
openid = "0.10"
rand = "0.8"
@@ -1,5 +1,6 @@
use chacha20poly1305::Key as ChaCha20Poly1305Key;
use chartered_fs::FileSystem;
use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl};
use openid::DiscoveredClient;
use serde::{de::Error as SerdeDeError, Deserialize};
use std::collections::HashMap;
@@ -17,8 +18,6 @@
Parse(#[from] url::ParseError),
}
pub type OidcClients = HashMap<String, DiscoveredClient>;
#[derive(Deserialize, Debug)]
#[serde(deny_unknown_fields)]
pub struct Config {
@@ -39,16 +38,15 @@
}
pub async fn create_oidc_clients(&self) -> Result<OidcClients, Error> {
Ok(futures::future::try_join_all(
let mut clients: OidcClients = futures::future::try_join_all(
self.auth
.oauth
.iter()
.filter(|(_, config)| config.enabled)
.map(|(name, config)| async move {
let redirect = self.frontend_base_uri.join("auth/login/oauth")?;
let redirect = self.frontend_base_uri.join("login/oauth")?;
Ok::<_, Error>((
name.to_string(),
let client = Box::new(
DiscoveredClient::discover(
config.client_id.to_string(),
config.client_secret.to_string(),
@@ -56,18 +54,41 @@
config.discovery_uri.clone(),
)
.await?,
))
);
Ok::<_, Error>((name.to_string(), OidcClient::Discovered(client)))
}),
)
.await?
.into_iter()
.collect())
.collect();
if let Some(github) = self.auth.github.clone() {
let redirect = self.frontend_base_uri.join("login/oauth")?;
let client = Box::new(
oauth2::basic::BasicClient::new(
github.client_id,
Some(github.client_secret),
AuthUrl::new("https://github.com/login/oauth/authorize".to_string())?,
Some(TokenUrl::new(
"https://github.com/login/oauth/access_token".to_string(),
)?),
)
.set_redirect_uri(RedirectUrl::from_url(redirect)),
);
clients.insert("github".to_string(), OidcClient::GitHub(client));
}
Ok(clients)
}
}
#[derive(Deserialize, Default, Debug)]
pub struct AuthConfig {
pub password: PasswordAuthConfig,
pub github: Option<GitHubConfig>,
#[serde(flatten)]
pub oauth: HashMap<String, OAuthConfig>,
}
@@ -75,7 +96,15 @@
#[derive(Deserialize, Default, Debug)]
#[serde(deny_unknown_fields)]
pub struct PasswordAuthConfig {
pub enabled: bool,
}
#[derive(Deserialize, Clone, Debug)]
#[serde(deny_unknown_fields)]
pub struct GitHubConfig {
pub enabled: bool,
pub client_id: ClientId,
pub client_secret: ClientSecret,
}
#[derive(Deserialize, Debug)]
@@ -84,6 +113,13 @@
pub discovery_uri: Url,
pub client_id: String,
pub client_secret: String,
}
pub type OidcClients = HashMap<String, OidcClient>;
pub enum OidcClient {
Discovered(Box<DiscoveredClient>),
GitHub(Box<oauth2::basic::BasicClient>),
}
fn deserialize_encryption_key<'de, D: serde::Deserializer<'de>>(
@@ -11,7 +11,7 @@
routing::get,
Extension, Router,
};
use clap::Parser;
use clap::{crate_name, crate_version, Parser};
use std::{fmt::Formatter, path::PathBuf, sync::Arc};
use thiserror::Error;
use tower::ServiceBuilder;
@@ -64,6 +64,9 @@
.into_inner();
let config = Arc::new(config);
let http_client = reqwest::Client::builder()
.user_agent(format!("{}/{}", crate_name!(), crate_version!()))
.build()?;
let app = Router::new()
.route("/", get(hello_world))
@@ -111,7 +114,8 @@
.layer(Extension(pool))
.layer(Extension(Arc::new(config.create_oidc_clients().await?)))
.layer(Extension(Arc::new(config.get_file_system().await?)))
.layer(Extension(config.clone()));
.layer(Extension(config.clone()))
.layer(Extension(http_client));
info!("HTTP server listening on {}", bind_address);
@@ -137,6 +141,8 @@
ServerSpawn(Box<dyn std::error::Error>),
#[error("Failed to build CORS header: {0}")]
Cors(axum::http::header::InvalidHeaderValue),
#[error("Failed to initialise reqwest client: {0}")]
Reqwest(#[from] reqwest::Error),
}
impl std::fmt::Debug for InitError {
@@ -1,15 +1,20 @@
use crate::config::{Config, OidcClients};
use crate::config::{Config, OidcClient, OidcClients};
use axum::{extract, Json};
use chacha20poly1305::{aead::Aead, ChaCha20Poly1305, KeyInit, Nonce as ChaCha20Poly1305Nonce};
use chartered_db::{users::User, ConnectionPool};
use openid::{Options, Token};
use oauth2::{
basic::BasicErrorResponseType, AuthorizationCode, CsrfToken, RequestTokenError, Scope,
StandardErrorResponse, TokenResponse,
};
use openid::{Options, Token, Userinfo};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
use url::Url;
pub type Nonce = [u8; 16];
@@ -42,16 +47,27 @@
let nonce = rand::random::<Nonce>();
let state = serde_json::to_vec(&State { provider, nonce })?;
let state = encrypt_url_safe(&state, &config)?;
let redirect_url = match client {
OidcClient::Discovered(client) => client.auth_url(&Options {
scope: Some("openid email profile".into()),
nonce: Some(base64::encode_config(&nonce, base64::URL_SAFE_NO_PAD)),
state: Some(state),
..Options::default()
}),
OidcClient::GitHub(client) => {
client
.authorize_url(move || CsrfToken::new(state))
.add_scope(Scope::new("read:user".to_string()))
.add_scope(Scope::new("user:email".to_string()))
.url()
.0
}
};
let auth_url = client.auth_url(&Options {
scope: Some("openid email profile".into()),
nonce: Some(base64::encode_config(&nonce, base64::URL_SAFE_NO_PAD)),
state: Some(encrypt_url_safe(&state, &config)?),
..Options::default()
});
Ok(Json(BeginResponse {
redirect_url: auth_url.to_string(),
redirect_url: redirect_url.to_string(),
}))
}
@@ -62,6 +78,7 @@
extract::Extension(config): extract::Extension<Arc<Config>>,
extract::Extension(oidc_clients): extract::Extension<Arc<OidcClients>>,
extract::Extension(db): extract::Extension<ConnectionPool>,
extract::Extension(http_client): extract::Extension<reqwest::Client>,
user_agent: Option<extract::TypedHeader<headers::UserAgent>>,
addr: extract::ConnectInfo<std::net::SocketAddr>,
) -> Result<Json<super::LoginResponse>, Error> {
@@ -73,41 +90,109 @@
let client = oidc_clients
.get(&state.provider)
.ok_or(Error::UnknownOauthProvider)?;
let mut token: Token = client.request_token(¶ms.code).await?.into();
if let Some(id_token) = token.id_token.as_mut() {
client.decode_token(id_token)?;
let nonce = base64::encode_config(state.nonce, base64::URL_SAFE_NO_PAD);
client.validate_token(id_token, Some(nonce.as_str()), None)?;
} else {
return Err(Error::MissingToken);
}
let userinfo = client.request_userinfo(&token).await?;
let user = match client {
OidcClient::Discovered(client) => {
let mut token: Token = client.request_token(¶ms.code).await?.into();
if let Some(id_token) = token.id_token.as_mut() {
client.decode_token(id_token)?;
let nonce = base64::encode_config(state.nonce, base64::URL_SAFE_NO_PAD);
client.validate_token(id_token, Some(nonce.as_str()), None)?;
} else {
return Err(Error::MissingToken);
}
UserIr::from(client.request_userinfo(&token).await?)
}
OidcClient::GitHub(client) => {
let token_result = client
.exchange_code(AuthorizationCode::new(params.code))
.request_async(oauth2::reqwest::async_http_client)
.await?;
eprintln!("{}", token_result.access_token().secret());
let res: GitHubUserResponse = http_client
.get("https://api.github.com/user")
.bearer_auth(token_result.access_token().secret())
.header("Accept", "application/vnd.github+json")
.send()
.await?
.json()
.await?;
UserIr::from(res)
}
};
let user = User::find_or_create(
db.clone(),
format!("{}:{}", state.provider, userinfo.sub.unwrap()),
userinfo.name,
userinfo.nickname,
userinfo.email,
userinfo.profile,
userinfo.picture,
format!("{}:{}", state.provider, user.id),
user.name,
user.nick,
user.email,
user.profile_url,
user.avatar_url,
)
.await?;
Ok(Json(super::login(db, user, user_agent, addr).await?))
}
pub struct UserIr {
id: String,
name: Option<String>,
nick: Option<String>,
email: Option<String>,
profile_url: Option<Url>,
avatar_url: Option<Url>,
}
impl From<GitHubUserResponse> for UserIr {
fn from(v: GitHubUserResponse) -> Self {
UserIr {
id: v.id.to_string(),
name: Some(v.name.unwrap_or_else(|| v.login.to_string())),
nick: Some(v.login),
email: Some(v.email),
profile_url: v.html_url,
avatar_url: Some(v.avatar_url),
}
}
}
impl From<Userinfo> for UserIr {
fn from(v: Userinfo) -> Self {
UserIr {
id: v.sub.unwrap(),
name: v.name,
nick: v.nickname,
email: v.email,
profile_url: v.profile,
avatar_url: v.picture,
}
}
}
#[derive(Deserialize)]
pub struct GitHubUserResponse {
id: u64,
login: String,
avatar_url: Url,
html_url: Option<Url>,
name: Option<String>,
email: String,
}
const NONCE_LEN: usize = 12;
@@ -182,6 +267,16 @@
Base64(#[from] base64::DecodeError),
#[error("Missing id_token")]
MissingToken,
#[error("Failed to request profile from OAuth provider")]
FetchProfile(#[from] reqwest::Error),
#[error("Failed to request token from OAuth provider")]
RequestOAuthToken(
#[from]
RequestTokenError<
oauth2::reqwest::Error<reqwest::Error>,
StandardErrorResponse<BasicErrorResponseType>,
>,
),
}
impl Error {
@@ -190,6 +285,7 @@
match self {
Self::Database(e) => e.status_code(),
Self::FetchProfile(_) | Self::RequestOAuthToken(_) => StatusCode::BAD_GATEWAY,
_ => StatusCode::INTERNAL_SERVER_ERROR,
}
}