Give each session a connection id
Diff
Cargo.lock | 10 ++++++++++
Cargo.toml | 1 +
src/main.rs | 180 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------------
src/providers/gitlab.rs | 101 +++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------------
4 files changed, 172 insertions(+), 120 deletions(-)
@@ -575,6 +575,7 @@
"toml",
"tracing",
"tracing-subscriber",
"uuid",
]
[[package]]
@@ -1784,6 +1785,15 @@
"idna",
"matches",
"percent-encoding",
]
[[package]]
name = "uuid"
version = "1.0.0-alpha.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb3ab47baa004111b323696c6eaa2752e7356f7f77cf6b6dc7a2087368ce1ca4"
dependencies = [
"getrandom",
]
[[package]]
@@ -36,3 +36,4 @@
tokio = { version = "1.17", features = ["full"] }
tokio-util = { version = "0.7", features = ["codec"] }
toml = "0.5"
uuid = { version = "1.0.0-alpha.1", features = ["v4"] }
@@ -25,14 +25,18 @@
use clap::Parser;
use futures::Future;
use parking_lot::RwLock;
use std::{borrow::Cow, collections::HashMap, fmt::Write, net::SocketAddr, pin::Pin, sync::Arc};
use std::{
borrow::Cow, collections::HashMap, fmt::Write, net::SocketAddr, net::SocketAddrV6, pin::Pin,
str::FromStr, sync::Arc,
};
use thrussh::{
server::{Auth, Session},
ChannelId, CryptoVec,
};
use thrussh_keys::key::PublicKey;
use tokio_util::{codec::Decoder, codec::Encoder as CodecEncoder};
use tracing::{error, info};
use tracing::{error, info, info_span, Instrument, Span};
use uuid::Uuid;
const AGENT: &str = concat!(
"agent=",
@@ -113,18 +117,25 @@
for Server<U>
{
type Handler = Handler<U>;
fn new(&mut self, peer_addr: Option<SocketAddr>) -> Self::Handler {
let connection_id = Uuid::new_v4();
let peer_addr =
peer_addr.unwrap_or_else(|| SocketAddrV6::from_str("[::]:0").unwrap().into());
let span = info_span!("ssh", ?peer_addr, ?connection_id);
fn new(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler {
info!(parent: &span, "Incoming connection");
Handler {
codec: GitCodec::default(),
gitlab: Arc::clone(&self.gitlab),
user: None,
group: None,
input_bytes: BytesMut::new(),
output_bytes: BytesMut::new(),
is_git_protocol_v2: false,
metadata_cache: Arc::clone(&self.metadata_cache),
span,
packfile_cache: None,
}
}
@@ -140,6 +151,7 @@
output_bytes: BytesMut,
is_git_protocol_v2: bool,
metadata_cache: MetadataCache,
span: Span,
packfile_cache: Option<Arc<(HashOutput, Vec<PackFileEntry>)>>,
@@ -348,83 +360,100 @@
fn auth_publickey(mut self, user: &str, public_key: &PublicKey) -> Self::FutureAuth {
let fingerprint = public_key.fingerprint();
let user = user.to_string();
Box::pin(capture_errors(async move {
let mut user = self
.gitlab
.find_user_by_username_password_combo(&user)
.await?;
if user.is_none() {
user = self
let span = info_span!(parent: &self.span, "auth_publickey", ?fingerprint);
Box::pin(
capture_errors(async move {
let mut by_ssh_key = false;
let mut user = self
.gitlab
.find_user_by_ssh_key(&util::format_fingerprint(&fingerprint))
.find_user_by_username_password_combo(&user)
.await?;
}
if let Some(user) = user {
self.user = Some(user);
self.finished_auth(Auth::Accept).await
} else {
self.finished_auth(Auth::Reject).await
}
}))
if user.is_none() {
by_ssh_key = true;
user = self
.gitlab
.find_user_by_ssh_key(&util::format_fingerprint(&fingerprint))
.await?;
}
if let Some(user) = user {
info!(
"Successfully authenticated for GitLab user `{}` by {}",
&user.username,
if by_ssh_key { "SSH Key" } else { "Build Token" },
);
self.user = Some(user);
self.finished_auth(Auth::Accept).await
} else {
info!("Public key rejected");
self.finished_auth(Auth::Reject).await
}
})
.instrument(span),
)
}
fn data(mut self, channel: ChannelId, data: &[u8], mut session: Session) -> Self::FutureUnit {
self.input_bytes.extend_from_slice(data);
let span = info_span!(parent: &self.span, "data");
Box::pin(capture_errors(async move {
let (commit_hash, packfile_entries) = &*self.build_packfile().await?;
while let Some(frame) = self.codec.decode(&mut self.input_bytes)? {
if frame.command.is_empty() {
session.exit_status_request(channel, 0);
session.eof(channel);
session.close(channel);
return Ok((self, session));
}
self.input_bytes.extend_from_slice(data);
match frame.command.as_ref() {
b"command=ls-refs" => {
git_command_handlers::ls_refs::handle(
&mut self,
&mut session,
channel,
&frame.metadata,
commit_hash,
)?;
}
b"command=fetch" => {
git_command_handlers::fetch::handle(
&mut self,
&mut session,
channel,
&frame.metadata,
packfile_entries,
)?;
Box::pin(
capture_errors(async move {
let (commit_hash, packfile_entries) = &*self.build_packfile().await?;
while let Some(frame) = self.codec.decode(&mut self.input_bytes)? {
if frame.command.is_empty() {
session.exit_status_request(channel, 0);
session.eof(channel);
session.close(channel);
return Ok((self, session));
}
v => {
error!(
"Client sent unknown command, ignoring command {}",
std::str::from_utf8(v).unwrap_or("invalid utf8")
);
match frame.command.as_ref() {
b"command=ls-refs" => {
git_command_handlers::ls_refs::handle(
&mut self,
&mut session,
channel,
&frame.metadata,
commit_hash,
)?;
}
b"command=fetch" => {
git_command_handlers::fetch::handle(
&mut self,
&mut session,
channel,
&frame.metadata,
packfile_entries,
)?;
}
v => {
error!(
"Client sent unknown command, ignoring command {}",
std::str::from_utf8(v).unwrap_or("invalid utf8")
);
}
}
}
}
Ok((self, session))
}))
Ok((self, session))
})
.instrument(span),
)
}
fn env_request(
@@ -444,6 +473,8 @@
}
fn shell_request(mut self, channel: ChannelId, mut session: Session) -> Self::FutureUnit {
let span = info_span!(parent: &self.span, "shell_request");
Box::pin(capture_errors(async move {
let username = self.user()?.username.clone();
write!(
@@ -452,10 +483,11 @@
username,
env!("CARGO_PKG_NAME")
)?;
info!("Shell requested, dropping connection");
self.flush(&mut session, channel);
session.close(channel);
Ok((self, session))
}))
}).instrument(span))
}
@@ -470,9 +502,13 @@
data: &[u8],
mut session: Session,
) -> Self::FutureUnit {
let span = info_span!(parent: &self.span, "exec_request");
let data = match std::str::from_utf8(data) {
Ok(data) => data,
Err(e) => return Box::pin(capture_errors(futures::future::err(e.into()))),
Err(e) => {
return Box::pin(capture_errors(futures::future::err(e.into())).instrument(span))
}
};
let args = shlex::split(data);
@@ -532,7 +568,7 @@
self.flush(&mut session, channel);
Ok((self, session))
}))
}).instrument(span))
}
}
@@ -9,6 +9,7 @@
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::sync::Arc;
use tracing::Instrument;
pub struct Gitlab {
client: reqwest::Client,
@@ -158,54 +159,58 @@
for release in res {
let this = self.clone();
futures.push(tokio::spawn(async move {
let (project, package) = {
let mut splitter = release.links.web_path.splitn(2, "/-/packages/");
match (splitter.next(), splitter.next()) {
(Some(project), Some(package)) => (&project[1..], package),
_ => return Ok(None),
}
};
let package_path = Arc::new(GitlabCratePath {
project: utf8_percent_encode(project, NON_ALPHANUMERIC).to_string(),
package_name: utf8_percent_encode(&release.name, NON_ALPHANUMERIC)
.to_string(),
});
let package_files: Vec<GitlabPackageFilesResponse> = handle_error(
this.client
.get(format!(
"{}/projects/{}/packages/{}/package_files",
this.base_url,
utf8_percent_encode(project, NON_ALPHANUMERIC),
utf8_percent_encode(package, NON_ALPHANUMERIC),
))
.send()
.await?,
)
.await?
.json()
.await?;
let expected_file_name = format!("{}-{}.crate", release.name, release.version);
Ok::<_, anyhow::Error>(
package_files
.into_iter()
.find(|package_file| package_file.file_name == expected_file_name)
.map(move |package_file| {
(
Arc::clone(&package_path),
Release {
name: release.name,
version: release.version,
checksum: package_file.file_sha256,
},
)
}),
)
}));
futures.push(tokio::spawn(
async move {
let (project, package) = {
let mut splitter = release.links.web_path.splitn(2, "/-/packages/");
match (splitter.next(), splitter.next()) {
(Some(project), Some(package)) => (&project[1..], package),
_ => return Ok(None),
}
};
let package_path = Arc::new(GitlabCratePath {
project: utf8_percent_encode(project, NON_ALPHANUMERIC).to_string(),
package_name: utf8_percent_encode(&release.name, NON_ALPHANUMERIC)
.to_string(),
});
let package_files: Vec<GitlabPackageFilesResponse> = handle_error(
this.client
.get(format!(
"{}/projects/{}/packages/{}/package_files",
this.base_url,
utf8_percent_encode(project, NON_ALPHANUMERIC),
utf8_percent_encode(package, NON_ALPHANUMERIC),
))
.send()
.await?,
)
.await?
.json()
.await?;
let expected_file_name =
format!("{}-{}.crate", release.name, release.version);
Ok::<_, anyhow::Error>(
package_files
.into_iter()
.find(|package_file| package_file.file_name == expected_file_name)
.map(move |package_file| {
(
Arc::clone(&package_path),
Release {
name: release.name,
version: release.version,
checksum: package_file.file_sha256,
},
)
}),
)
}
.in_current_span(),
));
}
}