use hmac::{digest::FixedOutput, Hmac, Mac};
use sha3::Sha3_256;
use tower_cookies::{
cookie::{time::Duration, CookieBuilder, SameSite},
Cookies,
};
use tracing::warn;
use crate::context::DerivedKeys;
type HmacSha3 = Hmac<Sha3_256>;
const CSRF_TOKEN_COOKIE_NAME: &str = "csrf_token";
#[derive(Copy, Clone)]
pub struct CsrfToken {
signed: [u8; 32],
unsigned: u128,
}
impl CsrfToken {
pub fn new(derived_keys: &DerivedKeys) -> Self {
let unsigned = rand::random::<u128>();
let mut hmac = HmacSha3::new_from_slice(&derived_keys.csrf_hmac_key).unwrap();
hmac.update(&unsigned.to_be_bytes());
let signed = hmac.finalize_fixed().into();
Self { signed, unsigned }
}
pub fn write_cookie(&self, cookies: &Cookies) {
cookies.add(
CookieBuilder::new(CSRF_TOKEN_COOKIE_NAME, hex::encode(self.signed))
.http_only(true)
.max_age(Duration::hours(24))
.same_site(SameSite::Strict)
.finish(),
);
}
#[must_use]
pub fn verify(derived_keys: &DerivedKeys, cookies: &Cookies, form_value: &str) -> bool {
let Some(cookie) = cookies.get(CSRF_TOKEN_COOKIE_NAME) else {
warn!("Missing CSRF token");
return false;
};
let form_value = match hex::decode(form_value) {
Ok(v) => v,
Err(error) => {
warn!(?error, "Invalid form CSRF token");
return false;
}
};
let cookie_token = match hex::decode(cookie.value()) {
Ok(v) => v,
Err(error) => {
warn!(?error, "Invalid cookie CSRF token");
return false;
}
};
let mut hmac = HmacSha3::new_from_slice(&derived_keys.csrf_hmac_key).unwrap();
hmac.update(&form_value);
match hmac.verify_slice(&cookie_token) {
Ok(()) => true,
Err(error) => {
warn!(?error, "CSRF form value and cookie mismatch");
false
}
}
}
pub fn form_value(&self) -> String {
hex::encode(self.unsigned.to_be_bytes())
}
}