use std::{path::PathBuf, sync::Arc};
use axum::async_trait;
use rocksdb::{IteratorMode, MergeOperands, Options, DB};
use serde::Deserialize;
use uuid::Uuid;
use crate::store::{Account, AccountAccessLevel, AccountProvider, User, UserProvider};
/// Error type named by the `AccountProvider` and `UserProvider` implementations in this file.
///
/// The enum has no variants, so a value of this type can never be constructed:
/// the provider methods below `unwrap()` storage and (de)serialisation results
/// (panicking on failure) rather than reporting them through this type.
#[derive(Debug)]
pub enum Error {}
/// Column family mapping a username's bytes to the id bytes of the user who owns it.
const USER_BY_USERNAME_CF: &str = "users_by_username";
/// Column family mapping a user's id bytes to the bincode-encoded `User`.
const USER_BY_UUID_CF: &str = "users_by_uuid";
/// Column family mapping a user's id bytes to a counter that is advanced with
/// `INCR` merge operands and read back as a big-endian `u64`.
const USER_SEQ_NUMBER: &str = "users_seq_number";
/// Column family mapping an account's id bytes to the bincode-encoded `Account`.
const ACCOUNTS_BY_UUID: &str = "accounts_by_uuid";
/// Column family keyed by 32 bytes (user id followed by account id) whose value
/// is the access level stored as a single byte.
const ACCOUNTS_ACCESS_BY_USER: &str = "accounts_access_by_user";
/// bincode configuration shared by every encode and decode call in this file.
const BINCODE_CONFIG: bincode::config::Configuration = bincode::config::standard();
/// Settings for the RocksDB-backed store, deserialised with kebab-case field names.
#[derive(Deserialize)]
#[serde(rename_all = "kebab-case")]
pub struct Config {
/// Filesystem path passed to RocksDB when the database is opened.
path: PathBuf,
}
/// Store backed by a single RocksDB database; implements `AccountProvider` and
/// `UserProvider` further down in this file.
pub struct RocksDb {
/// Shared database handle; every operation moves a clone of this `Arc` into a
/// `spawn_blocking` task.
db: Arc<DB>,
}
impl RocksDb {
    /// Opens the database at `config.path` with all column families this store uses.
    ///
    /// The database and any missing column families are created on demand, and
    /// every column family shares the same options, including the
    /// `rocksdb_merger` merge operator.
    ///
    /// # Panics
    ///
    /// Panics if RocksDB fails to open the database.
    pub fn new(config: Config) -> Self {
        let mut shared_opts = Options::default();
        shared_opts.create_if_missing(true);
        shared_opts.set_merge_operator_associative("test operator", rocksdb_merger);
        shared_opts.create_missing_column_families(true);

        // Pair every column family name with its own copy of the shared options.
        let column_families = [
            USER_BY_USERNAME_CF,
            USER_BY_UUID_CF,
            ACCOUNTS_BY_UUID,
            ACCOUNTS_ACCESS_BY_USER,
            USER_SEQ_NUMBER,
        ]
        .map(|name| (name, shared_opts.clone()));

        let handle = DB::open_cf_with_opts(&shared_opts, config.path, column_families).unwrap();
        Self {
            db: Arc::new(handle),
        }
    }
}
/// Merge operator for counters stored as big-endian unsigned integers.
///
/// Each operand is folded, in order, into `existing_val` (or into a fresh
/// 8-byte zero when there is no value yet):
///
/// * `INCR` adds one;
/// * an 8-byte operand adds the big-endian `u64` it encodes.
///
/// The second form is what makes the operator safe to register as
/// associative. RocksDB then also uses this function for partial merges,
/// calling it with `existing_val == None` to pre-combine several operands, and
/// later feeds the returned bytes back in as a single operand. A run of `INCR`s
/// combined that way comes back as an 8-byte count and must be understood as
/// "add that many" instead of being rejected.
///
/// Additions wrap around at the width of the stored value.
///
/// # Panics
///
/// Panics if an operand is neither `INCR` nor exactly 8 bytes long.
// The `Option` return type is required by rocksdb's merge callback signature.
#[allow(clippy::unnecessary_wraps)]
fn rocksdb_merger(
    _new_key: &[u8],
    existing_val: Option<&[u8]>,
    operands: &MergeOperands,
) -> Option<Vec<u8>> {
    let mut new_val = existing_val.map(<[u8]>::to_vec).unwrap_or_default();
    for operand in operands {
        let (operation, operand) = MergeOperation::parse(operand);
        let delta = match operation {
            Some(MergeOperation::Increment) => 1,
            Some(MergeOperation::Add(delta)) => delta,
            None => panic!("unknown operand: {operand:?}"),
        };
        // A missing (or empty) value counts as zero.
        if new_val.is_empty() {
            new_val.extend_from_slice(&0_u64.to_be_bytes());
        }
        add_big_endian(&mut new_val, delta);
    }
    Some(new_val)
}

/// Adds `delta` to the big-endian unsigned integer held in `value`, wrapping
/// around when the result does not fit in `value.len()` bytes.
fn add_big_endian(value: &mut [u8], delta: u64) {
    // `u128` leaves headroom for `u64::MAX` plus the byte added at each step.
    let mut pending = u128::from(delta);
    for byte in value.iter_mut().rev() {
        if pending == 0 {
            break;
        }
        pending += u128::from(*byte);
        let [low, ..] = pending.to_le_bytes();
        *byte = low;
        pending >>= 8;
    }
}

/// Operation encoded in a single merge operand.
enum MergeOperation {
    /// Add one to the counter; encoded as the ASCII bytes `INCR`.
    Increment,
    /// Add the given amount to the counter; encoded as 8 big-endian bytes.
    /// This is the shape of an already-combined run of operands.
    Add(u64),
}

impl MergeOperation {
    /// Decodes one merge operand.
    ///
    /// Returns the recognised operation together with the unconsumed bytes:
    /// an empty slice on success, or `None` and the whole input when the
    /// operand is not understood.
    pub fn parse(v: &[u8]) -> (Option<MergeOperation>, &[u8]) {
        if v == b"INCR" {
            (Some(Self::Increment), &[])
        } else if let Ok(raw) = <[u8; 8]>::try_from(v) {
            (Some(Self::Add(u64::from_be_bytes(raw))), &[])
        } else {
            (None, v)
        }
    }
}
#[async_trait]
impl AccountProvider for RocksDb {
    type Error = Error;

    /// Stores `account`, bincode-encoded, under its id in the accounts column family.
    ///
    /// # Panics
    ///
    /// Panics if encoding fails, the column family is missing, the write
    /// fails, or the blocking task cannot be joined.
    async fn create_account(&self, account: Account) -> Result<(), Self::Error> {
        let db = self.db.clone();
        tokio::task::spawn_blocking(move || {
            let bytes = bincode::serde::encode_to_vec(&account, BINCODE_CONFIG).unwrap();
            let by_uuid_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap();
            db.put_cf(by_uuid_handle, account.id.as_bytes(), bytes)
                .unwrap();
            Ok(())
        })
        .await
        .unwrap()
    }

    /// Records that `user` may access `account` at level `access`, then bumps
    /// the user's sequence number.
    ///
    /// The access entry is keyed by the user id followed by the account id, so
    /// all of a user's entries share the user id as a key prefix.
    ///
    /// # Panics
    ///
    /// Panics if the column family is missing, the write fails, or the
    /// blocking task cannot be joined.
    async fn attach_account_to_user(
        &self,
        account: Uuid,
        user: Uuid,
        access: AccountAccessLevel,
    ) -> Result<(), Self::Error> {
        let db = self.db.clone();
        tokio::task::spawn_blocking(move || {
            let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();
            let mut compound_key = [0_u8; 32];
            compound_key[..16].copy_from_slice(user.as_bytes());
            compound_key[16..].copy_from_slice(account.as_bytes());
            db.put_cf(access_handle, compound_key, (access as u8).to_be_bytes())
                .unwrap();
        })
        .await
        .unwrap();
        self.increment_seq_number_for_user(user).await
    }

    /// Returns every account `user_id` has an access entry for.
    ///
    /// Access entries that point at an account which is no longer stored are
    /// skipped.
    ///
    /// # Panics
    ///
    /// Panics if a column family is missing, a read fails, a stored account
    /// cannot be decoded, or the blocking task cannot be joined.
    async fn get_accounts_for_user(&self, user_id: Uuid) -> Result<Vec<Account>, Self::Error> {
        let db = self.db.clone();
        tokio::task::spawn_blocking(move || {
            let access_handle = db.cf_handle(ACCOUNTS_ACCESS_BY_USER).unwrap();
            let account_handle = db.cf_handle(ACCOUNTS_BY_UUID).unwrap();
            Ok(db
                .prefix_iterator_cf(access_handle, user_id.as_bytes())
                .map(Result::unwrap)
                // Without a prefix extractor on the column family the iterator
                // only *starts* at the prefix and then runs on into other
                // users' keys, so stop at the first key that no longer matches.
                .take_while(|(key, _access_level)| key.starts_with(user_id.as_bytes()))
                .filter_map(|(key, _access_level)| {
                    // The remainder of the key after the user id is the account id.
                    let account = key.strip_prefix(user_id.as_bytes())?;
                    let account_bytes = db.get_cf(account_handle, account).unwrap()?;
                    let (res, _): (Account, _) =
                        bincode::serde::decode_from_slice(&account_bytes, BINCODE_CONFIG).unwrap();
                    Some(res)
                })
                .collect())
        })
        .await
        .unwrap()
    }
}
#[async_trait]
impl UserProvider for RocksDb {
type Error = Error;
async fn increment_seq_number_for_user(&self, user: Uuid) -> Result<(), Self::Error> {
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap();
db.merge_cf(seq_handle, user.as_bytes(), "INCR").unwrap();
Ok(())
})
.await
.unwrap()
}
async fn fetch_seq_number_for_user(&self, user: Uuid) -> Result<u64, Self::Error> {
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let seq_handle = db.cf_handle(USER_SEQ_NUMBER).unwrap();
let Some(bytes) = db.get_pinned_cf(seq_handle, user.as_bytes()).unwrap() else {
return Ok(0);
};
let mut val = [0_u8; std::mem::size_of::<u64>()];
val.copy_from_slice(&bytes);
Ok(u64::from_be_bytes(val))
})
.await
.unwrap()
}
async fn has_any_users(&self) -> Result<bool, Self::Error> {
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
Ok(db
.full_iterator_cf(by_uuid_handle, IteratorMode::Start)
.next()
.is_some())
})
.await
.unwrap()
}
async fn create_user(&self, user: User) -> Result<(), Self::Error> {
let db = self.db.clone();
tokio::task::spawn_blocking(move || {
let bytes = bincode::serde::encode_to_vec(&user, BINCODE_CONFIG).unwrap();
let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
db.put_cf(by_uuid_handle, user.id.as_bytes(), bytes)
.unwrap();
let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap();
db.put_cf(
by_username_handle,
user.username.as_bytes(),
user.id.as_bytes(),
)
.unwrap();
Ok(())
})
.await
.unwrap()
}
async fn get_by_username(&self, username: &str) -> Result<Option<User>, Error> {
let db = self.db.clone();
let username = username.to_string();
tokio::task::spawn_blocking(move || {
let uuid = {
let by_username_handle = db.cf_handle(USER_BY_USERNAME_CF).unwrap();
db.get_pinned_cf(by_username_handle, username).unwrap()
};
let Some(uuid) = uuid else {
return Ok(None);
};
let user_bytes = {
let by_uuid_handle = db.cf_handle(USER_BY_UUID_CF).unwrap();
db.get_pinned_cf(by_uuid_handle, &uuid).unwrap()
};
let Some(user_bytes) = user_bytes else {
return Ok(None);
};
Ok(Some(
bincode::serde::decode_from_slice(&user_bytes, BINCODE_CONFIG)
.unwrap()
.0,
))
})
.await
.unwrap()
}
}