use std::{path::PathBuf, sync::Arc};

use axum::async_trait;
use rocksdb::{IteratorMode, MergeOperands, Options, DB};
use serde::Deserialize;
use uuid::Uuid;

use crate::store::{Account, AccountAccessLevel, AccountProvider, User, UserProvider};

#[derive(Debug)]
pub enum Error {}

// Column family names.
const USER_BY_USERNAME_CF: &str = "users_by_username";
const USER_BY_UUID_CF: &str = "users_by_uuid";
const USER_SEQ_NUMBER: &str = "users_seq_number";
const ACCOUNTS_BY_UUID: &str = "accounts_by_uuid";
const ACCOUNTS_ACCESS_BY_USER: &str = "accounts_access_by_user";

const BINCODE_CONFIG: bincode::config::Configuration = bincode::config::standard();

#[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Config {
    path: PathBuf,
}

// TODO: lots of blocking on async thread
pub struct RocksDb {
    db: Arc<DB>,
}

impl RocksDb {
    pub fn new(config: Config) -> Self {
        let mut db_options = Options::default();
        db_options.create_if_missing(true);
        db_options.set_merge_operator_associative("test operator", rocksdb_merger);
        db_options.create_missing_column_families(true);

        let db = DB::open_cf_with_opts(
            &db_options,
            config.path,
            [
                (USER_BY_USERNAME_CF, db_options.clone()),
                (USER_BY_UUID_CF, db_options.clone()),
                (ACCOUNTS_BY_UUID, db_options.clone()),
                (ACCOUNTS_ACCESS_BY_USER, db_options.clone()),
                (USER_SEQ_NUMBER, db_options.clone()),
            ],
        )
        .unwrap();

        Self { db: Arc::new(db) }
    }
}

#[allow(clippy::unnecessary_wraps)] // rocksdb api restriction
fn rocksdb_merger(
    _new_key: &[u8],
    existing_val: Option<&[u8]>,
    operands: &MergeOperands,
) -> Option<Vec<u8>> {
    let mut new_val = existing_val.map(<[u8]>::to_vec).unwrap_or_default();

    for operand in operands {
        let (operation, operand) = MergeOperation::parse(operand);

        match operation {
            Some(MergeOperation::Increment) => {
                // Missing value: start the counter at zero (big-endian u64).
                if new_val.is_empty() {
                    new_val.extend_from_slice(&0_u64.to_be_bytes());
                }

                // Increment in place, carrying from the least significant
                // byte upwards.
                let mut carry = true;
                for byte in new_val.iter_mut().rev() {
                    if carry {
                        *byte = byte.wrapping_add(1);
                        carry = *byte == 0;
                    } else {
                        break;
                    }
                }

                // The whole counter overflowed; wrap around to zero.
                if carry {
                    new_val.fill(0);
                }
            }
            None => {
                panic!("unknown operand: {operand:?}");
            }
        }
    }

    Some(new_val)
}

enum MergeOperation {
    Increment,
}

impl MergeOperation {
    pub fn parse(v: &[u8]) -> (Option<Self>, &[u8]) {
        if v == b"INCR" {
            (Some(Self::Increment), &[])
        } else {
            (None, v)
        }
    }
}

#[async_trait]
impl AccountProvider for RocksDb {
    type Error = Error;

    async fn create_account(&self, account: Account) -> Result<(), Self::Error> {
        let db = self.db.clone();

        tokio::task::spawn_blocking(move || {
            let bytes = bincode::serde::encode_to_vec(&account, BINCODE_CONFIG).unwrap();

            let by_uuid_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap();
            db.put_cf(by_uuid_handle, account.id.as_bytes(), bytes)
                .unwrap();

            Ok(())
        })
        .await
        .unwrap()
    }

    async fn attach_account_to_user(
        &self,
        account: Uuid,
        user: Uuid,
        access: AccountAccessLevel,
    ) -> Result<(), Self::Error> {
        let db = self.db.clone();

        tokio::task::spawn_blocking(move || {
            let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();

            // Compound key: 16 bytes of user UUID followed by 16 bytes of
            // account UUID, so a user's accounts share a common key prefix.
            let mut compound_key = [0_u8; 32];
            compound_key[..16].copy_from_slice(user.as_bytes());
            compound_key[16..].copy_from_slice(account.as_bytes());

            db.put_cf(access_handle, compound_key, (access as u8).to_be_bytes())
                .unwrap();
        })
        .await
        .unwrap();

        self.increment_seq_number_for_user(user).await.unwrap();

        Ok(())
    }

    async fn get_accounts_for_user(&self, user_id: Uuid) -> Result<Vec<Account>, Self::Error> {
        let db = self.db.clone();

        tokio::task::spawn_blocking(move || {
            let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();
            let account_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap();

            Ok(db
                .prefix_iterator_cf(access_handle, user_id.as_bytes())
                .map(Result::unwrap)
                .filter_map(|(key, _access_level)| {
                    let Some(account) = key.strip_prefix(user_id.as_bytes()) else {
                        panic!("got invalid key from rocksdb");
                    };

                    let Some(account_bytes) = db.get_cf(account_handle, account).unwrap() else {
                        return None;
                    };

                    let (res, _): (Account, _) =
                        bincode::serde::decode_from_slice(&account_bytes, BINCODE_CONFIG).unwrap();

                    Some(res)
                })
                .collect())
        })
        .await
        .unwrap()
    }
}

#[async_trait]
impl UserProvider for RocksDb {
    type Error = Error;

    async fn increment_seq_number_for_user(&self, user: Uuid) -> Result<(), Self::Error> {
        let db = self.db.clone();

        tokio::task::spawn_blocking(move || {
            let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap();
            db.merge_cf(seq_handle, user.as_bytes(), "INCR").unwrap();
            Ok(())
        })
        .await
        .unwrap()
    }

    async fn fetch_seq_number_for_user(&self, user: Uuid) -> Result<u64, Self::Error> {
        let db = self.db.clone();

        tokio::task::spawn_blocking(move || {
            let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap();

            let Some(bytes) = db.get_pinned_cf(seq_handle, user.as_bytes()).unwrap() else {
                return Ok(0);
            };

            let mut val = [0_u8; std::mem::size_of::<u64>()];
            val.copy_from_slice(&bytes);

            Ok(u64::from_be_bytes(val))
        })
        .await
        .unwrap()
    }

    async fn has_any_users(&self) -> Result<bool, Self::Error> {
        let db = self.db.clone();

        tokio::task::spawn_blocking(move || {
            let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();

            Ok(db
                .full_iterator_cf(by_uuid_handle, IteratorMode::Start)
                .next()
                .is_some())
        })
        .await
        .unwrap()
    }

    async fn create_user(&self, user: User) -> Result<(), Self::Error> {
        let db = self.db.clone();

        tokio::task::spawn_blocking(move || {
            let bytes = bincode::serde::encode_to_vec(&user, BINCODE_CONFIG).unwrap();

            let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
            db.put_cf(by_uuid_handle, user.id.as_bytes(), bytes)
                .unwrap();

            // Secondary index: username -> user UUID.
            let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap();
            db.put_cf(
                by_username_handle,
                user.username.as_bytes(),
                user.id.as_bytes(),
            )
            .unwrap();

            Ok(())
        })
        .await
        .unwrap()
    }

    async fn get_by_username(&self, username: &str) -> Result<Option<User>, Error> {
        let db = self.db.clone();
        let username = username.to_string();

        tokio::task::spawn_blocking(move || {
            let uuid = {
                let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap();
                db.get_pinned_cf(by_username_handle, username).unwrap()
            };

            let Some(uuid) = uuid else {
                return Ok(None);
            };

            let user_bytes = {
                let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
                db.get_pinned_cf(by_uuid_handle, &uuid).unwrap()
            };

            let Some(user_bytes) = user_bytes else {
                return Ok(None);
            };

            Ok(Some(
                bincode::serde::decode_from_slice(&user_bytes, BINCODE_CONFIG)
                    .unwrap()
                    .0,
            ))
        })
        .await
        .unwrap()
    }
}
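
// A minimal usage sketch, not part of the store implementation above: it
// assumes `tempfile` is available as a dev-dependency, that tokio's `macros`
// and `rt` features are enabled for `#[tokio::test]`, and that the `uuid`
// crate has its `v4` feature on. It only exercises the sequence-number merge
// path, which needs no `Account` or `User` values.
#[cfg(test)]
mod tests {
    use uuid::Uuid;

    use crate::store::UserProvider;

    use super::*;

    #[tokio::test]
    async fn seq_number_starts_at_zero_and_increments() {
        let dir = tempfile::tempdir().unwrap();
        let db = RocksDb::new(Config {
            path: dir.path().to_path_buf(),
        });

        let user = Uuid::new_v4();

        // No merges recorded yet: the key is absent and fetch falls back to 0.
        assert_eq!(db.fetch_seq_number_for_user(user).await.unwrap(), 0);

        // Each call issues an "INCR" merge operand; the merge operator folds
        // them into a big-endian u64 counter when the key is read back.
        db.increment_seq_number_for_user(user).await.unwrap();
        db.increment_seq_number_for_user(user).await.unwrap();
        assert_eq!(db.fetch_seq_number_for_user(user).await.unwrap(), 2);
    }
}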