From 7fabb4a08d7e2fcf729633b5cfe3589612c2fbf9 Mon Sep 17 00:00:00 2001 From: Jordan Doyle Date: Thu, 01 Sep 2022 02:10:52 +0100 Subject: [PATCH] First-class support for GitHub authentication --- Cargo.lock | 131 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-- chartered-web/Cargo.toml | 1 + chartered-web/src/config.rs | 52 ++++++++++++++++++++++++++++++++++++++++++++++++++-- chartered-web/src/main.rs | 10 ++++++++-- chartered-web/src/endpoints/web_api/auth/openid.rs | 162 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++------- 5 files changed, 303 insertions(+), 53 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 6b57aad..d950d4d 100644 --- a/Cargo.lock +++ a/Cargo.lock @@ -300,7 +300,7 @@ "http", "http-body", "hyper", - "hyper-rustls", + "hyper-rustls 0.22.1", "lazy_static", "pin-project-lite", "tokio", @@ -741,6 +741,7 @@ "hex", "nom", "nom-bytes", + "oauth2", "once_cell", "openid", "rand", @@ -935,7 +936,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c1a816186fa68d9e426e3cb4ae4dff1fcd8e4a2c34b781bf7a822574a0d0aac8" dependencies = [ - "sct", + "sct 0.6.1", ] [[package]] @@ -1249,8 +1250,10 @@ checksum = "4eb1a864a501629691edf6c15a593b7a51eebaa1e8468e9ddc623de7c9b58ec6" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi 0.11.0+wasi-snapshot-preview1", + "wasm-bindgen", ] [[package]] @@ -1417,11 +1420,24 @@ "futures-util", "hyper", "log", - "rustls", + "rustls 0.19.1", "rustls-native-certs", + "tokio", + "tokio-rustls 0.22.0", + "webpki 0.21.4", +] + +[[package]] +name = "hyper-rustls" +version = "0.23.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d87c48c02e0dc5e3b849a2041db3029fd066650f8f717c07bf8ed78ccb895cac" +dependencies = [ + "http", + "hyper", + "rustls 0.20.6", "tokio", - "tokio-rustls", - "webpki", + "tokio-rustls 0.23.4", ] [[package]] @@ -1821,6 +1837,26 @@ checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" dependencies = [ "libc", +] + +[[package]] +name = "oauth2" +version = "4.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d62c436394991641b970a92e23e8eeb4eb9bca74af4f5badc53bcd568daadbd" +dependencies = [ + "base64", + "chrono", + "getrandom", + "http", + "rand", + "reqwest", + "serde", + "serde_json", + "serde_path_to_error", + "sha2 0.10.2", + "thiserror", + "url", ] [[package]] @@ -2175,6 +2211,7 @@ "http", "http-body", "hyper", + "hyper-rustls 0.23.0", "hyper-tls", "ipnet", "js-sys", @@ -2184,16 +2221,20 @@ "native-tls", "percent-encoding", "pin-project-lite", + "rustls 0.20.6", + "rustls-pemfile", "serde", "serde_json", "serde_urlencoded", "tokio", "tokio-native-tls", + "tokio-rustls 0.23.4", "tower-service", "url", "wasm-bindgen", "wasm-bindgen-futures", "web-sys", + "webpki-roots", "winreg", ] @@ -2228,10 +2269,22 @@ checksum = "35edb675feee39aec9c99fa5ff985081995a06d594114ae14cbe797ad7b7a6d7" dependencies = [ "base64", + "log", + "ring", + "sct 0.6.1", + "webpki 0.21.4", +] + +[[package]] +name = "rustls" +version = "0.20.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5aab8ee6c7097ed6057f43c187a62418d0c05a4bd5f18b3571db50ee0f9ce033" +dependencies = [ "log", "ring", - "sct", - "webpki", + "sct 0.7.0", + "webpki 0.22.0", ] [[package]] @@ -2241,9 +2294,18 @@ checksum = "5a07b7c1885bd8ed3831c289b7870b13ef46fe0e856d288c30d9cc17d75a2092" dependencies = [ "openssl-probe", - "rustls", + "rustls 0.19.1", "schannel", "security-framework", +] + +[[package]] +name = "rustls-pemfile" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0864aeff53f8c05aa08d86e5ef839d3dfcf07aeba2db32f12db0ef716e87bd55" +dependencies = [ + "base64", ] [[package]] @@ -2291,6 +2353,16 @@ version = "0.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b362b83898e0e69f38515b82ee15aa80636befe47c3b6d3d89a911e78fc228ce" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "sct" +version = "0.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d53dcdb7c9f8158937a7981b48accfd39a43af418591a5d008c7b22b5e1b7ca4" dependencies = [ "ring", "untrusted", @@ -2354,6 +2426,15 @@ "indexmap", "itoa", "ryu", + "serde", +] + +[[package]] +name = "serde_path_to_error" +version = "0.1.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "184c643044780f7ceb59104cef98a5a6f12cb2288a7bc701ab93a362b49fd47d" +dependencies = [ "serde", ] @@ -2707,10 +2788,21 @@ version = "0.22.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bc6844de72e57df1980054b38be3a9f4702aba4858be64dd700181a8a6d0e1b6" +dependencies = [ + "rustls 0.19.1", + "tokio", + "webpki 0.21.4", +] + +[[package]] +name = "tokio-rustls" +version = "0.23.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c43ee83903113e03984cb9e5cebe6c04a5116269e900e3ddba8f068a62adda59" dependencies = [ - "rustls", + "rustls 0.20.6", "tokio", - "webpki", + "webpki 0.22.0", ] [[package]] @@ -3112,9 +3204,28 @@ version = "0.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b8e38c0608262c46d4a56202ebabdeb094cef7e560ca7a226c6bf055188aa4ea" +dependencies = [ + "ring", + "untrusted", +] + +[[package]] +name = "webpki" +version = "0.22.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f095d78192e208183081cc07bc5515ef55216397af48b873e5edcd72637fa1bd" dependencies = [ "ring", "untrusted", +] + +[[package]] +name = "webpki-roots" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1c760f0d366a6c24a02ed7816e23e691f5d92291f94d15e836006fd11b04daf" +dependencies = [ + "webpki 0.22.0", ] [[package]] diff --git a/chartered-web/Cargo.toml b/chartered-web/Cargo.toml index 39d1054..ef875f9 100644 --- a/chartered-web/Cargo.toml +++ a/chartered-web/Cargo.toml @@ -23,6 +23,7 @@ hex = "0.4" nom = "7" nom-bytes = { git = "https://github.com/w4/nom-bytes" } +oauth2 = "4.2" once_cell = "1.8" openid = "0.10" rand = "0.8" diff --git a/chartered-web/src/config.rs b/chartered-web/src/config.rs index d15e120..a64428d 100644 --- a/chartered-web/src/config.rs +++ a/chartered-web/src/config.rs @@ -1,5 +1,6 @@ use chacha20poly1305::Key as ChaCha20Poly1305Key; use chartered_fs::FileSystem; +use oauth2::{AuthUrl, ClientId, ClientSecret, RedirectUrl, TokenUrl}; use openid::DiscoveredClient; use serde::{de::Error as SerdeDeError, Deserialize}; use std::collections::HashMap; @@ -17,8 +18,6 @@ Parse(#[from] url::ParseError), } -pub type OidcClients = HashMap; - #[derive(Deserialize, Debug)] #[serde(deny_unknown_fields)] pub struct Config { @@ -39,16 +38,15 @@ } pub async fn create_oidc_clients(&self) -> Result { - Ok(futures::future::try_join_all( + let mut clients: OidcClients = futures::future::try_join_all( self.auth .oauth .iter() .filter(|(_, config)| config.enabled) .map(|(name, config)| async move { - let redirect = self.frontend_base_uri.join("auth/login/oauth")?; + let redirect = self.frontend_base_uri.join("login/oauth")?; - Ok::<_, Error>(( - name.to_string(), + let client = Box::new( DiscoveredClient::discover( config.client_id.to_string(), config.client_secret.to_string(), @@ -56,18 +54,41 @@ config.discovery_uri.clone(), ) .await?, - )) + ); + + Ok::<_, Error>((name.to_string(), OidcClient::Discovered(client))) }), ) .await? .into_iter() - .collect()) + .collect(); + + if let Some(github) = self.auth.github.clone() { + let redirect = self.frontend_base_uri.join("login/oauth")?; + + let client = Box::new( + oauth2::basic::BasicClient::new( + github.client_id, + Some(github.client_secret), + AuthUrl::new("https://github.com/login/oauth/authorize".to_string())?, + Some(TokenUrl::new( + "https://github.com/login/oauth/access_token".to_string(), + )?), + ) + .set_redirect_uri(RedirectUrl::from_url(redirect)), + ); + + clients.insert("github".to_string(), OidcClient::GitHub(client)); + } + + Ok(clients) } } #[derive(Deserialize, Default, Debug)] pub struct AuthConfig { pub password: PasswordAuthConfig, + pub github: Option, #[serde(flatten)] pub oauth: HashMap, } @@ -75,7 +96,15 @@ #[derive(Deserialize, Default, Debug)] #[serde(deny_unknown_fields)] pub struct PasswordAuthConfig { + pub enabled: bool, +} + +#[derive(Deserialize, Clone, Debug)] +#[serde(deny_unknown_fields)] +pub struct GitHubConfig { pub enabled: bool, + pub client_id: ClientId, + pub client_secret: ClientSecret, } #[derive(Deserialize, Debug)] @@ -84,6 +113,13 @@ pub discovery_uri: Url, pub client_id: String, pub client_secret: String, +} + +pub type OidcClients = HashMap; + +pub enum OidcClient { + Discovered(Box), + GitHub(Box), } fn deserialize_encryption_key<'de, D: serde::Deserializer<'de>>( diff --git a/chartered-web/src/main.rs b/chartered-web/src/main.rs index 16587ac..edfb60e 100644 --- a/chartered-web/src/main.rs +++ a/chartered-web/src/main.rs @@ -11,7 +11,7 @@ routing::get, Extension, Router, }; -use clap::Parser; +use clap::{crate_name, crate_version, Parser}; use std::{fmt::Formatter, path::PathBuf, sync::Arc}; use thiserror::Error; use tower::ServiceBuilder; @@ -64,6 +64,9 @@ .into_inner(); let config = Arc::new(config); + let http_client = reqwest::Client::builder() + .user_agent(format!("{}/{}", crate_name!(), crate_version!())) + .build()?; let app = Router::new() .route("/", get(hello_world)) @@ -111,7 +114,8 @@ .layer(Extension(pool)) .layer(Extension(Arc::new(config.create_oidc_clients().await?))) .layer(Extension(Arc::new(config.get_file_system().await?))) - .layer(Extension(config.clone())); + .layer(Extension(config.clone())) + .layer(Extension(http_client)); info!("HTTP server listening on {}", bind_address); @@ -137,6 +141,8 @@ ServerSpawn(Box), #[error("Failed to build CORS header: {0}")] Cors(axum::http::header::InvalidHeaderValue), + #[error("Failed to initialise reqwest client: {0}")] + Reqwest(#[from] reqwest::Error), } impl std::fmt::Debug for InitError { diff --git a/chartered-web/src/endpoints/web_api/auth/openid.rs b/chartered-web/src/endpoints/web_api/auth/openid.rs index 60cfa04..07fdb74 100644 --- a/chartered-web/src/endpoints/web_api/auth/openid.rs +++ a/chartered-web/src/endpoints/web_api/auth/openid.rs @@ -1,15 +1,20 @@ //! Methods for `OpenID` Connect authentication, we allow the frontend to list all the available and //! enabled providers so they can show them to the frontend and provide methods for actually doing //! the authentication. -use crate::config::{Config, OidcClients}; +use crate::config::{Config, OidcClient, OidcClients}; use axum::{extract, Json}; use chacha20poly1305::{aead::Aead, ChaCha20Poly1305, KeyInit, Nonce as ChaCha20Poly1305Nonce}; use chartered_db::{users::User, ConnectionPool}; -use openid::{Options, Token}; +use oauth2::{ + basic::BasicErrorResponseType, AuthorizationCode, CsrfToken, RequestTokenError, Scope, + StandardErrorResponse, TokenResponse, +}; +use openid::{Options, Token, Userinfo}; use serde::{Deserialize, Serialize}; use std::sync::Arc; use thiserror::Error; +use url::Url; pub type Nonce = [u8; 16]; @@ -42,16 +47,27 @@ let nonce = rand::random::(); let state = serde_json::to_vec(&State { provider, nonce })?; + let state = encrypt_url_safe(&state, &config)?; + + let redirect_url = match client { + OidcClient::Discovered(client) => client.auth_url(&Options { + scope: Some("openid email profile".into()), + nonce: Some(base64::encode_config(&nonce, base64::URL_SAFE_NO_PAD)), + state: Some(state), + ..Options::default() + }), + OidcClient::GitHub(client) => { + client + .authorize_url(move || CsrfToken::new(state)) + .add_scope(Scope::new("read:user".to_string())) + .add_scope(Scope::new("user:email".to_string())) + .url() + .0 + } + }; - let auth_url = client.auth_url(&Options { - scope: Some("openid email profile".into()), - nonce: Some(base64::encode_config(&nonce, base64::URL_SAFE_NO_PAD)), - state: Some(encrypt_url_safe(&state, &config)?), - ..Options::default() - }); - Ok(Json(BeginResponse { - redirect_url: auth_url.to_string(), + redirect_url: redirect_url.to_string(), })) } @@ -62,6 +78,7 @@ extract::Extension(config): extract::Extension>, extract::Extension(oidc_clients): extract::Extension>, extract::Extension(db): extract::Extension, + extract::Extension(http_client): extract::Extension, user_agent: Option>, addr: extract::ConnectInfo, ) -> Result, Error> { @@ -73,41 +90,109 @@ let client = oidc_clients .get(&state.provider) .ok_or(Error::UnknownOauthProvider)?; - - let mut token: Token = client.request_token(¶ms.code).await?.into(); - - if let Some(id_token) = token.id_token.as_mut() { - // ensure the id_token is valid, checking `exp`, etc. - client.decode_token(id_token)?; - - // ensure the nonce in the returned id_token is the same as the one we sent out encrypted - // with the original request - let nonce = base64::encode_config(state.nonce, base64::URL_SAFE_NO_PAD); - client.validate_token(id_token, Some(nonce.as_str()), None)?; - } else { - // the provider didn't send us back a id_token - return Err(Error::MissingToken); - } - // get some basic info from the provider using the claims we requested in `begin_oidc` - let userinfo = client.request_userinfo(&token).await?; + let user = match client { + OidcClient::Discovered(client) => { + let mut token: Token = client.request_token(¶ms.code).await?.into(); + + if let Some(id_token) = token.id_token.as_mut() { + // ensure the id_token is valid, checking `exp`, etc. + client.decode_token(id_token)?; + + // ensure the nonce in the returned id_token is the same as the one we sent out encrypted + // with the original request + let nonce = base64::encode_config(state.nonce, base64::URL_SAFE_NO_PAD); + client.validate_token(id_token, Some(nonce.as_str()), None)?; + } else { + // the provider didn't send us back a id_token + return Err(Error::MissingToken); + } + + // get some basic info from the provider using the claims we requested in `begin_oidc` + UserIr::from(client.request_userinfo(&token).await?) + } + OidcClient::GitHub(client) => { + let token_result = client + .exchange_code(AuthorizationCode::new(params.code)) + .request_async(oauth2::reqwest::async_http_client) + .await?; + + eprintln!("{}", token_result.access_token().secret()); + + let res: GitHubUserResponse = http_client + .get("https://api.github.com/user") + .bearer_auth(token_result.access_token().secret()) + .header("Accept", "application/vnd.github+json") + .send() + .await? + .json() + .await?; + + UserIr::from(res) + } + }; let user = User::find_or_create( db.clone(), // we're using `provider:uid` as the format for OIDC logins, this is fine to create // without a password because (1) password auth rejects blank passwords and (2) password // auth also rejects any usernames with a `:` in. - format!("{}:{}", state.provider, userinfo.sub.unwrap()), - userinfo.name, - userinfo.nickname, - userinfo.email, - userinfo.profile, - userinfo.picture, + format!("{}:{}", state.provider, user.id), + user.name, + user.nick, + user.email, + user.profile_url, + user.avatar_url, ) .await?; // request looks good, log the user in! Ok(Json(super::login(db, user, user_agent, addr).await?)) +} + +pub struct UserIr { + id: String, + name: Option, + nick: Option, + email: Option, + profile_url: Option, + avatar_url: Option, +} + +impl From for UserIr { + fn from(v: GitHubUserResponse) -> Self { + UserIr { + id: v.id.to_string(), + name: Some(v.name.unwrap_or_else(|| v.login.to_string())), + nick: Some(v.login), + email: Some(v.email), + profile_url: v.html_url, + avatar_url: Some(v.avatar_url), + } + } +} + +impl From for UserIr { + fn from(v: Userinfo) -> Self { + UserIr { + id: v.sub.unwrap(), + name: v.name, + nick: v.nickname, + email: v.email, + profile_url: v.profile, + avatar_url: v.picture, + } + } +} + +#[derive(Deserialize)] +pub struct GitHubUserResponse { + id: u64, + login: String, + avatar_url: Url, + html_url: Option, + name: Option, + email: String, } const NONCE_LEN: usize = 12; @@ -182,6 +267,16 @@ Base64(#[from] base64::DecodeError), #[error("Missing id_token")] MissingToken, + #[error("Failed to request profile from OAuth provider")] + FetchProfile(#[from] reqwest::Error), + #[error("Failed to request token from OAuth provider")] + RequestOAuthToken( + #[from] + RequestTokenError< + oauth2::reqwest::Error, + StandardErrorResponse, + >, + ), } impl Error { @@ -190,6 +285,7 @@ match self { Self::Database(e) => e.status_code(), + Self::FetchProfile(_) | Self::RequestOAuthToken(_) => StatusCode::BAD_GATEWAY, _ => StatusCode::INTERNAL_SERVER_ERROR, } } -- rgit 0.1.3