use crate::config::Config;
use axum::{extract, Json};
use chartered_db::{users::User, ConnectionPool};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use thiserror::Error;
/// Handles a password-based registration request.
///
/// The request is rejected before touching the database when password
/// authentication is disabled in the configuration, when the username fails
/// [`validate_username`], or when the password is shorter than six characters.
/// Otherwise the password is hashed with bcrypt at `bcrypt::DEFAULT_COST` and
/// the username/hash pair is handed to `User::register`.
///
/// # Errors
///
/// * [`RegisterError::PasswordAuthDisabled`] - `config.auth.password.enabled` is `false`.
/// * [`RegisterError::InvalidUsername`] - the username fails [`validate_username`].
/// * [`RegisterError::PasswordRequirementNotMet`] - the password has fewer than six characters.
/// * [`RegisterError::UsernameTaken`] - `User::register` reported `chartered_db::Error::UsernameTaken`.
/// * [`RegisterError::Bcrypt`] - hashing the password failed.
/// * [`RegisterError::Database`] - any other error returned by `User::register`.
pub async fn handle_register(
    extract::Extension(config): extract::Extension<Arc<Config>>,
    extract::Extension(db): extract::Extension<ConnectionPool>,
    extract::Json(req): extract::Json<RegisterRequest>,
) -> Result<Json<RegisterResponse>, RegisterError> {
    /// Minimum password length, in characters, matching the text of
    /// `RegisterError::PasswordRequirementNotMet`.
    const MIN_PASSWORD_CHARS: usize = 6;

    if !config.auth.password.enabled {
        return Err(RegisterError::PasswordAuthDisabled);
    }
    if !validate_username(&req.username) {
        return Err(RegisterError::InvalidUsername);
    }
    // `String::len` counts UTF-8 bytes, so a short password made of multi-byte
    // characters would slip through; count characters instead. `take` bounds
    // the scan so an oversized password is not walked in full.
    if req.password.chars().take(MIN_PASSWORD_CHARS).count() < MIN_PASSWORD_CHARS {
        return Err(RegisterError::PasswordRequirementNotMet);
    }

    // NOTE(review): bcrypt hashing runs synchronously inside this async handler.
    let password_hash = bcrypt::hash(&req.password, bcrypt::DEFAULT_COST)?;

    match User::register(db, req.username, password_hash).await {
        Ok(_) => Ok(Json(RegisterResponse { success: true })),
        // Surface a dedicated variant so the status-code mapping does not treat
        // a duplicate username as an internal database failure.
        Err(chartered_db::Error::UsernameTaken) => Err(RegisterError::UsernameTaken),
        Err(e) => Err(e.into()),
    }
}
/// Handles a password-based login request.
///
/// After the cheap request checks pass, the user is looked up by username, the
/// supplied password is verified against the stored bcrypt hash, and on success
/// the result of `super::login` is returned as JSON.
///
/// # Errors
///
/// * [`LoginError::PasswordAuthDisabled`] - `config.auth.password.enabled` is `false`.
/// * [`LoginError::UnknownUser`] - the username fails [`validate_username`], no
///   such user exists, or the user has no password stored.
/// * [`LoginError::InvalidPassword`] - the password is empty or does not match the hash.
/// * [`LoginError::Bcrypt`] - bcrypt verification itself failed.
/// * [`LoginError::Database`] - the user lookup or `super::login` returned a database error.
pub async fn handle_login(
    extract::Extension(config): extract::Extension<Arc<Config>>,
    extract::Extension(db): extract::Extension<ConnectionPool>,
    extract::Json(req): extract::Json<LoginRequest>,
    user_agent: Option<extract::TypedHeader<headers::UserAgent>>,
    addr: extract::ConnectInfo<std::net::SocketAddr>,
) -> Result<Json<super::LoginResponse>, LoginError> {
    // Request-level checks that need no database access.
    if !config.auth.password.enabled {
        return Err(LoginError::PasswordAuthDisabled);
    }
    if !validate_username(&req.username) {
        return Err(LoginError::UnknownUser);
    }
    if req.password.is_empty() {
        return Err(LoginError::InvalidPassword);
    }

    let account = match User::find_by_username(db.clone(), req.username).await? {
        Some(found) => found,
        None => return Err(LoginError::UnknownUser),
    };

    // A user without a stored password hash is reported the same way as a
    // missing user.
    let stored_hash = match account.password.as_deref() {
        Some(hash) => hash,
        None => return Err(LoginError::UnknownUser),
    };

    let matches = bcrypt::verify(&req.password, stored_hash)?;
    if !matches {
        return Err(LoginError::InvalidPassword);
    }

    let response = super::login(db, account, user_agent, addr).await?;
    Ok(Json(response))
}
/// Returns `true` when `username` is non-empty and contains no `:` character.
///
/// NOTE(review): the reason `:` is forbidden is not visible here; presumably it
/// is used as a separator elsewhere. Confirm against callers.
pub fn validate_username(username: &str) -> bool {
    if username.is_empty() {
        return false;
    }
    username.chars().all(|ch| ch != ':')
}
/// JSON request body deserialized by [`handle_register`].
#[derive(Deserialize)]
pub struct RegisterRequest {
/// Requested username; checked with [`validate_username`] before registration.
username: String,
/// Plaintext password; it is hashed with bcrypt before being passed to `User::register`.
password: String,
}
/// JSON request body deserialized by [`handle_login`].
#[derive(Deserialize)]
pub struct LoginRequest {
/// Username of the account to log in to; used for the `User::find_by_username` lookup.
username: String,
/// Plaintext password; verified against the user's stored bcrypt hash.
password: String,
}
/// JSON response body returned by [`handle_register`] on success.
#[derive(Serialize)]
pub struct RegisterResponse {
/// Set to `true` by [`handle_register`] when the user was registered.
success: bool,
}
/// Errors that [`handle_register`] can return.
#[derive(Error, Debug)]
pub enum RegisterError {
/// An error from `chartered_db` other than the username being taken.
#[error("Failed to query database")]
Database(#[from] chartered_db::Error),
/// bcrypt failed while hashing the new password.
#[error("Failed to hash password")]
Bcrypt(#[from] bcrypt::BcryptError),
/// The requested username was rejected by [`validate_username`].
#[error("Username is invalid")]
InvalidUsername,
/// `User::register` reported that the username already exists.
#[error("Username already taken")]
UsernameTaken,
/// Password authentication is turned off in the configuration.
#[error("Password authentication is disabled")]
PasswordAuthDisabled,
/// The supplied password is shorter than the minimum length.
#[error("Password must be at least 6 characters long")]
PasswordRequirementNotMet,
}
impl RegisterError {
    /// Returns the HTTP status code associated with this error.
    ///
    /// Database and bcrypt failures map to `500 Internal Server Error`,
    /// rejected usernames or passwords map to `400 Bad Request`, and disabled
    /// password authentication maps to `403 Forbidden`.
    pub fn status_code(&self) -> axum::http::StatusCode {
        use axum::http::StatusCode;

        match self {
            // Feature switched off in configuration.
            Self::PasswordAuthDisabled => StatusCode::FORBIDDEN,
            // Problems with the data the client submitted.
            Self::InvalidUsername => StatusCode::BAD_REQUEST,
            Self::UsernameTaken => StatusCode::BAD_REQUEST,
            Self::PasswordRequirementNotMet => StatusCode::BAD_REQUEST,
            // Server-side failures.
            Self::Database(_) => StatusCode::INTERNAL_SERVER_ERROR,
            Self::Bcrypt(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
// Invokes the `define_error_response!` macro (defined outside this view) for `RegisterError`.
// NOTE(review): presumably this generates the HTTP response conversion using
// `RegisterError::status_code` — confirm against the macro definition.
define_error_response!(RegisterError);
/// Errors that [`handle_login`] can return.
///
/// `UnknownUser` and `InvalidPassword` deliberately share the same display
/// text, so the message alone does not reveal which of the two occurred.
#[derive(Error, Debug)]
pub enum LoginError {
/// An error from `chartered_db` during the user lookup or the login itself.
#[error("Failed to query database")]
Database(#[from] chartered_db::Error),
/// bcrypt returned an error.
// NOTE(review): in `handle_login` this arises from `bcrypt::verify`, so the
// "hash" wording of the message is slightly off — confirm before changing.
#[error("Failed to hash password")]
Bcrypt(#[from] bcrypt::BcryptError),
/// The username is invalid, no such user exists, or the user has no stored password.
#[error("Invalid username/password")]
UnknownUser,
/// The password was empty or did not match the stored hash.
#[error("Invalid username/password")]
InvalidPassword,
/// Password authentication is turned off in the configuration.
#[error("Password authentication is disabled")]
PasswordAuthDisabled,
}
impl LoginError {
    /// Returns the HTTP status code associated with this error.
    ///
    /// Database and bcrypt failures map to `500 Internal Server Error`; every
    /// other variant maps to `403 Forbidden`.
    pub fn status_code(&self) -> axum::http::StatusCode {
        use axum::http::StatusCode;

        match self {
            // Authentication was refused.
            Self::PasswordAuthDisabled => StatusCode::FORBIDDEN,
            Self::UnknownUser => StatusCode::FORBIDDEN,
            Self::InvalidPassword => StatusCode::FORBIDDEN,
            // Server-side failures.
            Self::Database(_) => StatusCode::INTERNAL_SERVER_ERROR,
            Self::Bcrypt(_) => StatusCode::INTERNAL_SERVER_ERROR,
        }
    }
}
// Invokes the `define_error_response!` macro (defined outside this view) for `LoginError`.
// NOTE(review): presumably this generates the HTTP response conversion using
// `LoginError::status_code` — confirm against the macro definition.
define_error_response!(LoginError);