🏡 index : ~doyle/gitlab-cargo-shim.git

author Jordan Doyle <jordan@doyle.la> 2022-03-13 3:01:37.0 +00:00:00
committer Jordan Doyle <jordan@doyle.la> 2022-03-13 3:01:37.0 +00:00:00
commit
be1b834ed9718c36bf26af3e6ffe69a731823ed1 [patch]
tree
6301878412f481d33e8745df84d36f6ef10d1b11
parent
46ca918befbc28693db6a2b1ff474648742fca44
download
be1b834ed9718c36bf26af3e6ffe69a731823ed1.tar.gz

Give each session a connection id



Diff

 Cargo.lock              |  10 +++-
 Cargo.toml              |   1 +-
 src/main.rs             | 180 +++++++++++++++++++++++++++++--------------------
 src/providers/gitlab.rs | 101 ++++++++++++++-------------
 4 files changed, 172 insertions(+), 120 deletions(-)

diff --git a/Cargo.lock b/Cargo.lock
index 5b8015b..3eec481 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -575,6 +575,7 @@ dependencies = [
 "toml",
 "tracing",
 "tracing-subscriber",
 "uuid",
]

[[package]]
@@ -1787,6 +1788,15 @@ dependencies = [
]

[[package]]
name = "uuid"
version = "1.0.0-alpha.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb3ab47baa004111b323696c6eaa2752e7356f7f77cf6b6dc7a2087368ce1ca4"
dependencies = [
 "getrandom",
]

[[package]]
name = "valuable"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
diff --git a/Cargo.toml b/Cargo.toml
index 31b4662..0a83231 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -36,3 +36,4 @@ time = "0.3"
tokio = { version = "1.17", features = ["full"] }
tokio-util = { version = "0.7", features = ["codec"] }
toml = "0.5"
uuid = { version = "1.0.0-alpha.1", features = ["v4"] }
diff --git a/src/main.rs b/src/main.rs
index 1a7485b..a9d07cd 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -25,14 +25,18 @@ use bytes::{BufMut, Bytes, BytesMut};
use clap::Parser;
use futures::Future;
use parking_lot::RwLock;
use std::{borrow::Cow, collections::HashMap, fmt::Write, net::SocketAddr, pin::Pin, sync::Arc};
use std::{
    borrow::Cow, collections::HashMap, fmt::Write, net::SocketAddr, net::SocketAddrV6, pin::Pin,
    str::FromStr, sync::Arc,
};
use thrussh::{
    server::{Auth, Session},
    ChannelId, CryptoVec,
};
use thrussh_keys::key::PublicKey;
use tokio_util::{codec::Decoder, codec::Encoder as CodecEncoder};
use tracing::{error, info};
use tracing::{error, info, info_span, Instrument, Span};
use uuid::Uuid;

const AGENT: &str = concat!(
    "agent=",
@@ -114,17 +118,24 @@ impl<U: UserProvider + PackageProvider + Send + Sync + 'static> thrussh::server:
{
    type Handler = Handler<U>;

    fn new(&mut self, _peer_addr: Option<SocketAddr>) -> Self::Handler {
    fn new(&mut self, peer_addr: Option<SocketAddr>) -> Self::Handler {
        let connection_id = Uuid::new_v4();
        let peer_addr =
            peer_addr.unwrap_or_else(|| SocketAddrV6::from_str("[::]:0").unwrap().into());
        let span = info_span!("ssh", ?peer_addr, ?connection_id);

        info!(parent: &span, "Incoming connection");

        Handler {
            codec: GitCodec::default(),
            gitlab: Arc::clone(&self.gitlab),
            user: None,
            group: None,
            // fetcher_future: None,
            input_bytes: BytesMut::new(),
            output_bytes: BytesMut::new(),
            is_git_protocol_v2: false,
            metadata_cache: Arc::clone(&self.metadata_cache),
            span,
            packfile_cache: None,
        }
    }
@@ -140,6 +151,7 @@ pub struct Handler<U: UserProvider + PackageProvider + Send + Sync + 'static> {
    output_bytes: BytesMut,
    is_git_protocol_v2: bool,
    metadata_cache: MetadataCache,
    span: Span,
    // Cache of the packfile generated for this user in case it's requested
    // more than once
    packfile_cache: Option<Arc<(HashOutput, Vec<PackFileEntry>)>>,
@@ -348,83 +360,100 @@ impl<'a, U: UserProvider + PackageProvider + Send + Sync + 'static> thrussh::ser
    fn auth_publickey(mut self, user: &str, public_key: &PublicKey) -> Self::FutureAuth {
        let fingerprint = public_key.fingerprint();
        let user = user.to_string();

        Box::pin(capture_errors(async move {
            // username:password combo is used by CI to authenticate to us,
            // it does not allow users to authenticate directly. it's
            // technically the SSH username that contains both the username
            // and password as we don't want an interactive prompt or
            // anything like that
            let mut user = self
                .gitlab
                .find_user_by_username_password_combo(&user)
                .await?;

            // if there was no username:password combo given we'll lookup
            // the user by the SSH key they're connecting to us with
            if user.is_none() {
                user = self
        let span = info_span!(parent: &self.span, "auth_publickey", ?fingerprint);

        Box::pin(
            capture_errors(async move {
                // username:password combo is used by CI to authenticate to us,
                // it does not allow users to authenticate directly. it's
                // technically the SSH username that contains both the username
                // and password as we don't want an interactive prompt or
                // anything like that
                let mut by_ssh_key = false;
                let mut user = self
                    .gitlab
                    .find_user_by_ssh_key(&util::format_fingerprint(&fingerprint))
                    .find_user_by_username_password_combo(&user)
                    .await?;
            }

            if let Some(user) = user {
                self.user = Some(user);
                self.finished_auth(Auth::Accept).await
            } else {
                self.finished_auth(Auth::Reject).await
            }
        }))
                // if there was no username:password combo given we'll lookup
                // the user by the SSH key they're connecting to us with
                if user.is_none() {
                    by_ssh_key = true;
                    user = self
                        .gitlab
                        .find_user_by_ssh_key(&util::format_fingerprint(&fingerprint))
                        .await?;
                }

                if let Some(user) = user {
                    info!(
                        "Successfully authenticated for GitLab user `{}` by {}",
                        &user.username,
                        if by_ssh_key { "SSH Key" } else { "Build Token" },
                    );
                    self.user = Some(user);
                    self.finished_auth(Auth::Accept).await
                } else {
                    info!("Public key rejected");
                    self.finished_auth(Auth::Reject).await
                }
            })
            .instrument(span),
        )
    }

    fn data(mut self, channel: ChannelId, data: &[u8], mut session: Session) -> Self::FutureUnit {
        self.input_bytes.extend_from_slice(data);
        let span = info_span!(parent: &self.span, "data");

        Box::pin(capture_errors(async move {
            // build the packfile we're going to send to the user
            let (commit_hash, packfile_entries) = &*self.build_packfile().await?;

            while let Some(frame) = self.codec.decode(&mut self.input_bytes)? {
                // if the client flushed without giving us a command, we're expected to close
                // the connection or else the client will just hang
                if frame.command.is_empty() {
                    session.exit_status_request(channel, 0);
                    session.eof(channel);
                    session.close(channel);
                    return Ok((self, session));
                }
        self.input_bytes.extend_from_slice(data);

                match frame.command.as_ref() {
                    b"command=ls-refs" => {
                        git_command_handlers::ls_refs::handle(
                            &mut self,
                            &mut session,
                            channel,
                            &frame.metadata,
                            commit_hash,
                        )?;
                    }
                    b"command=fetch" => {
                        git_command_handlers::fetch::handle(
                            &mut self,
                            &mut session,
                            channel,
                            &frame.metadata,
                            packfile_entries,
                        )?;
        Box::pin(
            capture_errors(async move {
                // build the packfile we're going to send to the user
                let (commit_hash, packfile_entries) = &*self.build_packfile().await?;

                while let Some(frame) = self.codec.decode(&mut self.input_bytes)? {
                    // if the client flushed without giving us a command, we're expected to close
                    // the connection or else the client will just hang
                    if frame.command.is_empty() {
                        session.exit_status_request(channel, 0);
                        session.eof(channel);
                        session.close(channel);
                        return Ok((self, session));
                    }
                    v => {
                        error!(
                            "Client sent unknown command, ignoring command {}",
                            std::str::from_utf8(v).unwrap_or("invalid utf8")
                        );

                    match frame.command.as_ref() {
                        b"command=ls-refs" => {
                            git_command_handlers::ls_refs::handle(
                                &mut self,
                                &mut session,
                                channel,
                                &frame.metadata,
                                commit_hash,
                            )?;
                        }
                        b"command=fetch" => {
                            git_command_handlers::fetch::handle(
                                &mut self,
                                &mut session,
                                channel,
                                &frame.metadata,
                                packfile_entries,
                            )?;
                        }
                        v => {
                            error!(
                                "Client sent unknown command, ignoring command {}",
                                std::str::from_utf8(v).unwrap_or("invalid utf8")
                            );
                        }
                    }
                }
            }

            Ok((self, session))
        }))
                Ok((self, session))
            })
            .instrument(span),
        )
    }

    fn env_request(
@@ -444,6 +473,8 @@ impl<'a, U: UserProvider + PackageProvider + Send + Sync + 'static> thrussh::ser
    }

    fn shell_request(mut self, channel: ChannelId, mut session: Session) -> Self::FutureUnit {
        let span = info_span!(parent: &self.span, "shell_request");

        Box::pin(capture_errors(async move {
            let username = self.user()?.username.clone();
            write!(
@@ -452,10 +483,11 @@ impl<'a, U: UserProvider + PackageProvider + Send + Sync + 'static> thrussh::ser
                username,
                env!("CARGO_PKG_NAME")
            )?;
            info!("Shell requested, dropping connection");
            self.flush(&mut session, channel);
            session.close(channel);
            Ok((self, session))
        }))
        }).instrument(span))
    }

    /// Initially when setting up the SSH connection, the remote Git client will send us an
@@ -470,9 +502,13 @@ impl<'a, U: UserProvider + PackageProvider + Send + Sync + 'static> thrussh::ser
        data: &[u8],
        mut session: Session,
    ) -> Self::FutureUnit {
        let span = info_span!(parent: &self.span, "exec_request");

        let data = match std::str::from_utf8(data) {
            Ok(data) => data,
            Err(e) => return Box::pin(capture_errors(futures::future::err(e.into()))),
            Err(e) => {
                return Box::pin(capture_errors(futures::future::err(e.into())).instrument(span))
            }
        };
        // parses the given args in the same fashion as a POSIX shell
        let args = shlex::split(data);
@@ -532,7 +568,7 @@ impl<'a, U: UserProvider + PackageProvider + Send + Sync + 'static> thrussh::ser
            self.flush(&mut session, channel);

            Ok((self, session))
        }))
        }).instrument(span))
    }
}

diff --git a/src/providers/gitlab.rs b/src/providers/gitlab.rs
index 9d27d10..74bc949 100644
--- a/src/providers/gitlab.rs
+++ b/src/providers/gitlab.rs
@@ -9,6 +9,7 @@ use reqwest::header;
use serde::{Deserialize, Serialize};
use std::borrow::Cow;
use std::sync::Arc;
use tracing::Instrument;

pub struct Gitlab {
    client: reqwest::Client,
@@ -158,54 +159,58 @@ impl super::PackageProvider for Gitlab {
            for release in res {
                let this = self.clone();

                futures.push(tokio::spawn(async move {
                    let (project, package) = {
                        let mut splitter = release.links.web_path.splitn(2, "/-/packages/");
                        match (splitter.next(), splitter.next()) {
                            (Some(project), Some(package)) => (&project[1..], package),
                            _ => return Ok(None),
                        }
                    };

                    let package_path = Arc::new(GitlabCratePath {
                        project: utf8_percent_encode(project, NON_ALPHANUMERIC).to_string(),
                        package_name: utf8_percent_encode(&release.name, NON_ALPHANUMERIC)
                            .to_string(),
                    });

                    let package_files: Vec<GitlabPackageFilesResponse> = handle_error(
                        this.client
                            .get(format!(
                                "{}/projects/{}/packages/{}/package_files",
                                this.base_url,
                                utf8_percent_encode(project, NON_ALPHANUMERIC),
                                utf8_percent_encode(package, NON_ALPHANUMERIC),
                            ))
                            .send()
                            .await?,
                    )
                    .await?
                    .json()
                    .await?;

                    let expected_file_name = format!("{}-{}.crate", release.name, release.version);

                    Ok::<_, anyhow::Error>(
                        package_files
                            .into_iter()
                            .find(|package_file| package_file.file_name == expected_file_name)
                            .map(move |package_file| {
                                (
                                    Arc::clone(&package_path),
                                    Release {
                                        name: release.name,
                                        version: release.version,
                                        checksum: package_file.file_sha256,
                                    },
                                )
                            }),
                    )
                }));
                futures.push(tokio::spawn(
                    async move {
                        let (project, package) = {
                            let mut splitter = release.links.web_path.splitn(2, "/-/packages/");
                            match (splitter.next(), splitter.next()) {
                                (Some(project), Some(package)) => (&project[1..], package),
                                _ => return Ok(None),
                            }
                        };

                        let package_path = Arc::new(GitlabCratePath {
                            project: utf8_percent_encode(project, NON_ALPHANUMERIC).to_string(),
                            package_name: utf8_percent_encode(&release.name, NON_ALPHANUMERIC)
                                .to_string(),
                        });

                        let package_files: Vec<GitlabPackageFilesResponse> = handle_error(
                            this.client
                                .get(format!(
                                    "{}/projects/{}/packages/{}/package_files",
                                    this.base_url,
                                    utf8_percent_encode(project, NON_ALPHANUMERIC),
                                    utf8_percent_encode(package, NON_ALPHANUMERIC),
                                ))
                                .send()
                                .await?,
                        )
                        .await?
                        .json()
                        .await?;

                        let expected_file_name =
                            format!("{}-{}.crate", release.name, release.version);

                        Ok::<_, anyhow::Error>(
                            package_files
                                .into_iter()
                                .find(|package_file| package_file.file_name == expected_file_name)
                                .map(move |package_file| {
                                    (
                                        Arc::clone(&package_path),
                                        Release {
                                            name: release.name,
                                            version: release.version,
                                            checksum: package_file.file_sha256,
                                        },
                                    )
                                }),
                        )
                    }
                    .in_current_span(),
                ));
            }
        }