Add a max retention for stored messaged used for replay
Diff
Cargo.lock | 26 ++++++++++++++++++++++++++
Cargo.toml | 3 ++-
config.toml | 2 ++
migrations/2023010814480_initial-schema.sql | 13 ++++++-------
src/config.rs | 18 ++++++++++++++++--
src/main.rs | 3 +++
src/persistence.rs | 99 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++----------------
7 files changed, 119 insertions(+), 45 deletions(-)
@@ -676,6 +676,15 @@
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
[[package]]
name = "humantime"
version = "1.3.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df004cfca50ef23c36850aaaa59ad52cc70d0e90243c3c7737a4dd32dc7a3c4f"
dependencies = [
"quick-error",
]
[[package]]
name = "iana-time-zone"
version = "0.1.53"
source = "registry+https://github.com/rust-lang/crates.io-index"
@@ -1090,6 +1099,12 @@
dependencies = [
"unicode-ident",
]
[[package]]
name = "quick-error"
version = "1.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a1d01941d82fa2ab50be1e79e6714289dd7cde78eba4c074bc5a4374f650dfe0"
[[package]]
name = "quote"
@@ -1248,6 +1263,16 @@
checksum = "bb7d1f0d3021d347a83e556fc4683dea2ea09d87bccdf88ff5c12545d89d5efb"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde-humantime"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c367b5dafa12cef19c554638db10acde90d5e9acea2b80e1ad98b00f88068f7d"
dependencies = [
"humantime",
"serde",
]
[[package]]
@@ -1553,6 +1578,7 @@
"itertools",
"rand",
"serde",
"serde-humantime",
"sqlx",
"tokio",
"tokio-stream",
@@ -9,15 +9,16 @@
actix = "0.13"
actix-rt = "2.7"
anyhow = "1.0"
argon2 = "0.4"
base64 = "0.21.0-rc.1"
bytes = "1.3"
const_format = "0.2"
chrono = "0.4"
clap = { version = "4.0", features = ["cargo", "derive", "std", "suggestions", "color"] }
futures = "0.3"
argon2 = "0.4"
rand = "0.8"
serde = { version = "1.0", features = ["derive"] }
serde-humantime = "0.1"
sqlx = { version = "0.6", features = ["runtime-actix-rustls", "sqlite", "any"] }
tracing = "0.1"
tracing-subscriber = { version = "0.3", features = ["env-filter", "json"] }
@@ -1,6 +1,8 @@
listen-address = "[::]:6667"
database-uri = "sqlite://titanircd.db"
max-message-replay-since = "1d"
client-threads = 1
channel-threads = 1
@@ -15,11 +15,11 @@
CREATE TABLE channel_messages (
channel INT NOT NULL,
idx INT NOT NULL,
timestamp INT NOT NULL,
sender VARCHAR(255),
message VARCHAR(255),
FOREIGN KEY(channel) REFERENCES channels(id),
PRIMARY KEY(channel, idx)
PRIMARY KEY(channel, timestamp)
);
CREATE TABLE channel_users (
@@ -27,10 +27,9 @@
user INT NOT NULL,
permissions INT NOT NULL DEFAULT 0,
in_channel BOOLEAN DEFAULT false,
last_seen_message_idx INT,
last_seen_message_timestamp INT,
FOREIGN KEY(user) REFERENCES users(id),
FOREIGN KEY(channel) REFERENCES channels(id)
FOREIGN KEY(channel) REFERENCES channels(id),
PRIMARY KEY(channel, user)
);
CREATE UNIQUE INDEX channel_user ON channel_users(channel, user);
@@ -1,4 +1,4 @@
use std::{net::SocketAddr, str::FromStr};
use std::{net::SocketAddr, str::FromStr, time::Duration};
use clap::Parser;
use serde::Deserialize;
@@ -19,12 +19,19 @@
pub listen_address: SocketAddr,
pub database_uri: String,
pub motd: Option<String>,
#[serde(
default = "Config::default_max_message_replay_since",
with = "serde_humantime"
)]
pub max_message_replay_since: Duration,
#[serde(default = "Config::default_client_threads")]
pub client_threads: usize,
#[serde(default = "Config::default_channel_threads")]
pub channel_threads: usize,
}
@@ -38,6 +45,11 @@
#[must_use]
const fn default_channel_threads() -> usize {
1
}
#[must_use]
const fn default_max_message_replay_since() -> Duration {
Duration::from_secs(24 * 60 * 60)
}
}
@@ -81,9 +81,12 @@
let persistence_addr = {
let database = database.clone();
let config = opts.config.clone();
Supervisor::start_in_arbiter(&server_arbiter.handle(), move |_ctx| Persistence {
database,
max_message_replay_since: config.max_message_replay_since,
last_seen_clock: 0,
})
};
@@ -1,8 +1,9 @@
pub mod events;
use std::time::Duration;
use actix::{AsyncContext, Context, Handler, ResponseFuture, WrapFuture};
use chrono::Utc;
use itertools::Itertools;
use tracing::instrument;
@@ -14,8 +15,25 @@
pub struct Persistence {
pub database: sqlx::Pool<sqlx::Any>,
pub max_message_replay_since: Duration,
pub last_seen_clock: i64,
}
impl Persistence {
fn monotonically_increasing_id(&mut self) -> i64 {
let now = Utc::now().timestamp_nanos();
self.last_seen_clock = if now <= self.last_seen_clock {
self.last_seen_clock + 1
} else {
now
};
self.last_seen_clock
}
}
impl actix::Supervised for Persistence {}
impl actix::Actor for Persistence {
@@ -25,8 +43,9 @@
ctx.run_interval(Duration::from_secs(300), |this, ctx| {
let database = this.database.clone();
let max_message_replay_since = this.max_message_replay_since;
ctx.spawn(truncate_seen_messages(database).into_actor(this));
ctx.spawn(truncate_seen_messages(database, max_message_replay_since).into_actor(this));
});
}
}
@@ -135,31 +154,30 @@
fn handle(&mut self, msg: ChannelMessage, _ctx: &mut Self::Context) -> Self::Result {
let conn = self.database.clone();
let timestamp = self.monotonically_increasing_id();
Box::pin(async move {
let (idx,): (i64,) = sqlx::query_as(
"INSERT INTO channel_messages (channel, idx, sender, message)
VALUES (?, COALESCE((SELECT MAX(idx) + 1 FROM channel_messages WHERE channel = ?), 0), ?, ?)
RETURNING idx",
sqlx::query(
"INSERT INTO channel_messages (channel, timestamp, sender, message) VALUES (?, ?, ?, ?)",
)
.bind(msg.channel_id.0)
.bind(msg.channel_id.0)
.bind(timestamp)
.bind(msg.sender)
.bind(msg.message)
.fetch_one(&conn)
.execute(&conn)
.await
.unwrap();
if !msg.receivers.is_empty() {
let query = format!(
"UPDATE channel_users
SET last_seen_message_idx = ?
SET last_seen_message_timestamp = ?
WHERE channel = ?
AND user IN ({})",
msg.receivers.iter().map(|_| "?").join(",")
);
let mut query = sqlx::query(&query).bind(idx).bind(msg.channel_id.0);
let mut query = sqlx::query(&query).bind(timestamp).bind(msg.channel_id.0);
for receiver in msg.receivers {
query = query.bind(receiver.0);
}
@@ -176,6 +194,8 @@
#[instrument(parent = &msg.span, skip_all)]
fn handle(&mut self, msg: FetchUnseenMessages, _ctx: &mut Self::Context) -> Self::Result {
let conn = self.database.clone();
let max_message_reply_since =
Utc::now() - chrono::Duration::from_std(self.max_message_replay_since).unwrap();
Box::pin(async move {
@@ -185,22 +205,19 @@
SELECT sender, message
FROM channel_messages
WHERE channel = (SELECT id FROM channel)
AND idx > MAX(
(
SELECT MAX(0, MAX(idx) - 500)
FROM channel_messages
WHERE channel = (SELECT id FROM channel)
),
(
SELECT last_seen_message_idx
FROM channel_users
WHERE channel = (SELECT id FROM channel)
AND user = ?
)
AND timestamp > MAX(
?,
COALESCE((
SELECT last_seen_message_timestamp
FROM channel_users
WHERE channel = (SELECT id FROM channel)
AND user = ?
), 0)
)
ORDER BY idx ASC",
ORDER BY timestamp ASC",
)
.bind(msg.channel_name.to_string())
.bind(&msg.channel_name)
.bind(max_message_reply_since.timestamp_nanos())
.bind(msg.user_id.0)
.fetch_all(&conn)
.await
@@ -211,11 +228,13 @@
}
}
pub async fn truncate_seen_messages(db: sqlx::Pool<sqlx::Any>) {
pub async fn truncate_seen_messages(db: sqlx::Pool<sqlx::Any>, max_replay_since: Duration) {
let messages = sqlx::query_as::<_, (i64, i64)>(
"SELECT channel, MIN(last_seen_message_idx)
"SELECT channel, MIN(last_seen_message_timestamp)
FROM channel_users
GROUP BY channel",
)
@@ -223,13 +242,25 @@
.await
.unwrap();
for (channel, min_seen_id) in messages {
sqlx::query("DELETE FROM channel_messages WHERE channel = ? AND idx < ?")
.bind(channel)
.bind(min_seen_id)
.execute(&db)
.await
.unwrap();
let max_replay_since = Utc::now() - chrono::Duration::from_std(max_replay_since).unwrap();
for (channel, min_seen_timestamp) in messages {
let mut tx = db.begin().await.unwrap();
let remove_before = std::cmp::max(min_seen_timestamp, max_replay_since.timestamp_nanos());
sqlx::query(
"DELETE FROM channel_messages
WHERE channel = ?
AND timestamp <= ?",
)
.bind(channel)
.bind(remove_before)
.execute(&mut tx)
.await
.unwrap();
tx.commit().await.unwrap();
}
}