diff --git a/Cargo.lock b/Cargo.lock index 014c5f5..4bdbd12 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -186,10 +186,12 @@ version = "0.1.0" dependencies = [ "anyhow", "async-trait", + "bytes", "russh", "russh-keys", "thiserror", "tokio", + "tokio-util", "tracing", "tracing-subscriber", ] @@ -1117,6 +1119,7 @@ dependencies = [ "futures-sink", "pin-project-lite", "tokio", + "tracing", ] [[package]] diff --git a/crates/cli/Cargo.toml b/crates/cli/Cargo.toml index c4968d6..9b11a11 100644 --- a/crates/cli/Cargo.toml +++ b/crates/cli/Cargo.toml @@ -6,6 +6,7 @@ edition = "2021" [dependencies] anyhow = "1.0.68" async-trait = "0.1.63" +bytes = "1.3.0" russh = "0.36.0" russh-keys = "0.24.0" thiserror = "1.0.38" @@ -14,5 +15,6 @@ thiserror = "1.0.38" # serde = { version = "1", features = ["derive"] } # serde_json = "1.0.87" tokio = { version = "1", features = ["full"] } +tokio-util = { version = "0.7.4", features = ["codec"] } tracing = "0.1.37" tracing-subscriber = "0.3.16" diff --git a/crates/cli/src/codec.rs b/crates/cli/src/codec.rs new file mode 100644 index 0000000..8782911 --- /dev/null +++ b/crates/cli/src/codec.rs @@ -0,0 +1,138 @@ +use bytes::Buf; +use tokio_util::codec::{Decoder, Encoder}; + +use crate::error::AppError; + +/// git protocol encoder/decoder +struct ChunkCodec; + +const CHUNK_LENGTH_BYTES: usize = 4; + +fn hex_char_value(byte: u8) -> Option { + match byte { + b'0'..=b'9' => Some(byte - b'0'), + b'a'..=b'f' => Some(byte - b'a' + 10), + b'A'..=b'F' => Some(byte - b'A' + 10), + _ => None, + } +} + +impl Decoder for ChunkCodec { + type Item = Vec; + type Error = AppError; + + fn decode(&mut self, buf: &mut bytes::BytesMut) -> Result, Self::Error> { + if buf.len() < CHUNK_LENGTH_BYTES { + return Ok(None); + } + // read the length of the chunk + let chunk_len = (buf[0..CHUNK_LENGTH_BYTES]).iter().try_fold(0, |value, &byte| { + let char_value = hex_char_value(byte)?; + Some(value << 4 | char_value as usize) + }).ok_or_else(|| { + AppError::Anyhow(anyhow::anyhow!("invalid chunk length")) + })?; + tracing::info!(?chunk_len, "decode"); + + if chunk_len == 0 { + // TODO: end of stream? + return Ok(None); + } + + // the length includes the length bytes themselves, so subtract them + let chunk_len = chunk_len.checked_sub(CHUNK_LENGTH_BYTES).ok_or_else(|| { + AppError::Anyhow(anyhow::anyhow!("invalid chunk length")) + })?; + + // check if the entire chunk is in the buffer + if buf.len() < chunk_len + CHUNK_LENGTH_BYTES { + return Ok(None); + } + + // skip the length, get the chunk + let chunk: Vec = buf.iter().skip(CHUNK_LENGTH_BYTES).take(chunk_len).copied().collect(); + // remove the chunk from the buffer + buf.advance(chunk_len + CHUNK_LENGTH_BYTES); + + Ok(Some(chunk)) + } +} + +impl Encoder> for ChunkCodec { + type Error = AppError; + + fn encode(&mut self, item: Vec, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { + let chunk_len = item.len() + CHUNK_LENGTH_BYTES; + let chunk_len_hex = format!("{chunk_len:04x}"); + dst.extend_from_slice(chunk_len_hex.as_bytes()); + dst.extend_from_slice(&item); + Ok(()) + } +} + +struct TextChunkCodec; + +impl Decoder for TextChunkCodec { + type Item = String; + type Error = AppError; + + fn decode(&mut self, buf: &mut bytes::BytesMut) -> Result, Self::Error> { + let chunk = ChunkCodec.decode(buf)?; + if let Some(chunk) = chunk { + let mut chunk = String::from_utf8(chunk)?; + + // Remove any trailing newlines as they are not needed + if chunk.ends_with('\n') { + chunk.pop(); + } + + Ok(Some(chunk)) + } else { + Ok(None) + } + } +} + +impl Encoder for TextChunkCodec { + type Error = AppError; + + fn encode(&mut self, item: String, dst: &mut bytes::BytesMut) -> Result<(), Self::Error> { + ChunkCodec.encode(item.into_bytes(), dst) + } +} + +#[cfg(test)] +mod tests { + use tokio_util::codec::{Decoder, Encoder}; + use crate::codec::{ChunkCodec, TextChunkCodec}; + + #[tokio::test] + async fn encode_strings() { + let mut codec = TextChunkCodec; + let mut buf = bytes::BytesMut::new(); + let chunk_contents = "cded0bbfe0b0a2c44a823d7bca226555f98200cd refs/heads/main\0report-status report-status-v2 delete-refs side-band-64k quiet atomic ofs-delta object-format=sha1 agent=git/2.38.1\n"; + codec.encode(chunk_contents.to_string(), &mut buf).unwrap(); + + let mut expected = bytes::BytesMut::new(); + let expected_string = "00b1cded0bbfe0b0a2c44a823d7bca226555f98200cd refs/heads/main\0report-status report-status-v2 delete-refs side-band-64k quiet atomic ofs-delta object-format=sha1 agent=git/2.38.1\n"; + expected.extend_from_slice(expected_string.as_bytes()); + + assert_eq!(buf, expected); + } + + #[tokio::test] + async fn decode_strings() { + let mut codec = TextChunkCodec; + let mut buf = bytes::BytesMut::new(); + let chunk_contents = "cded0bbfe0b0a2c44a823d7bca226555f98200cd refs/heads/main\0report-status report-status-v2 delete-refs side-band-64k quiet atomic ofs-delta object-format=sha1 agent=git/2.38.1\n"; + codec.encode(chunk_contents.to_string(), &mut buf).unwrap(); + + let decoded = codec.decode(&mut buf).unwrap().unwrap(); + + // Our decoder removes any trailing newlines, so we need to do the same + let mut expected = chunk_contents.to_string(); + expected.pop(); + + assert_eq!(decoded, expected); + } +} diff --git a/crates/cli/src/error.rs b/crates/cli/src/error.rs index 37b12f3..cf557f4 100644 --- a/crates/cli/src/error.rs +++ b/crates/cli/src/error.rs @@ -8,6 +8,8 @@ pub enum AppError { Russh(#[from] russh::Error), #[error(transparent)] Io(#[from] std::io::Error), + #[error(transparent)] + FromUtf8(#[from] std::string::FromUtf8Error), // Application specific errors #[error("no data directory specified")] diff --git a/crates/cli/src/main.rs b/crates/cli/src/main.rs index d52651a..c712897 100644 --- a/crates/cli/src/main.rs +++ b/crates/cli/src/main.rs @@ -1,12 +1,60 @@ use russh::server::{Auth, Session}; use russh_keys::PublicKeyBase64; use std::{collections::HashMap, net::SocketAddr, sync::Arc, path::PathBuf}; -use tokio::sync::Mutex; +use tokio::{ + macros::support::Pin, + io::{AsyncRead, AsyncWrite}, + sync::Mutex, +}; +mod codec; mod error; use error::AppResult; use tracing::info; +/// A thin wrapper around tokio::process::Child that implements AsyncRead +/// and AsyncWrite on top of the child's stdout and stdin. +struct ChildProcess { + inner: tokio::process::Child, +} + +impl AsyncRead for ChildProcess { + fn poll_read( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &mut tokio::io::ReadBuf<'_>, + ) -> std::task::Poll> { + Pin::new(&mut self.inner.stdout.as_mut().unwrap()).poll_read(cx, buf) + } +} + +impl AsyncWrite for ChildProcess { + fn poll_write( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + buf: &[u8], + ) -> std::task::Poll> { + info!(?buf, "poll_write"); + Pin::new(&mut self.inner.stdin.as_mut().unwrap()).poll_write(cx, buf) + } + + fn poll_flush( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + info!("poll_flush"); + Pin::new(&mut self.inner.stdin.as_mut().unwrap()).poll_flush(cx) + } + + fn poll_shutdown( + mut self: std::pin::Pin<&mut Self>, + cx: &mut std::task::Context<'_>, + ) -> std::task::Poll> { + info!("poll_shutdown"); + Pin::new(&mut self.inner.stdin.as_mut().unwrap()).poll_shutdown(cx) + } +} + #[derive(Clone, Debug)] struct SshServer { data_dir: PathBuf, @@ -34,18 +82,65 @@ impl SshSession { } } + async fn get_channel(&mut self, channel_id: russh::ChannelId) -> russh::Channel { + let mut clients = self.clients.lock().await; + clients.remove(&channel_id).unwrap() + } + /// Respond with one line for each reference we currently have /// The first line also haas a list of the server's capabilities /// The data is transmitted in chunks. /// Each chunk starts with a 4 character hex value specifying the length of the chunk (including the 4 character hex value) /// Chunks usually contain a single line of data and a trailing linefeed - async fn receive_pack(&self, args: Vec<&str>) -> AppResult<()> { + #[tracing::instrument(skip(self, args))] + async fn receive_pack(&mut self, channel_id: russh::ChannelId, args: Vec<&str>) -> AppResult<()> { info!(?args, ?self.data_dir, "git-receive-pack"); - // TODO: First, determine the repository name and path + // First, determine the repository name and path + // We need to clean up the text from the url and make it a relative path to the data directory + let repo_name = args[0].replace('\'', "").trim_start_matches('/').to_string(); + let repo_path = self.data_dir.join(repo_name); + info!(?repo_path); + + // Next, we need to create the repository if it doesn't exist + if !repo_path.exists() { + // assume a `git` command is available to create the repository + tokio::process::Command::new("git") + .arg("init") + .arg("--bare") + .arg(&repo_path) + .output() + .await?; + } + + let channel = self.get_channel(channel_id).await; + let _stream = channel.into_stream(); + + // invoke git-receive-pack + // send the output to the channel + let _child = tokio::process::Command::new("git") + .arg("receive-pack") + .arg(&repo_path) + .stdin(std::process::Stdio::piped()) + .stdout(std::process::Stdio::piped()) + .spawn()?; + + // tokio::io::copy_bidirectional(&mut child_process, &mut stream).await?; + + // collect stdout + // let mut stdout = child.stdout.unwrap(); + // let mut output = Vec::new(); + // tokio::io::copy(&mut stdout, &mut output).await?; + // info!(?output); - // TODO: Is it enough to just invoke the command from the proper directory? Ok(()) } + + async fn upload_pack(&mut self, _channel_id: russh::ChannelId, args: Vec<&str>) -> AppResult<()> { + info!(?args, ?self.data_dir, "git-upload-pack"); + + todo!() + } + } #[async_trait::async_trait] @@ -116,7 +211,12 @@ impl russh::server::Handler for SshSession { match parse_command(&command_str) { Some(("git-receive-pack", args)) => { - let _res = self.receive_pack(args).await?; + let _res = self.receive_pack(channel_id, args).await?; + + Ok((self, session)) + } + Some(("git-upload-pack", args)) => { + let _res = self.upload_pack(channel_id, args).await?; Ok((self, session)) } @@ -135,6 +235,7 @@ async fn main() -> AppResult<()> { // first arg: the directory to store repositories in let data_dir = std::env::args().nth(1).ok_or(error::AppError::NoDataDir)?; + // TODO: ensure data_dir exists let config = russh::server::Config { auth_rejection_time: std::time::Duration::from_secs(3), diff --git a/flake-parts/cargo.nix b/flake-parts/cargo.nix index 653b585..552848d 100644 --- a/flake-parts/cargo.nix +++ b/flake-parts/cargo.nix @@ -17,6 +17,7 @@ fenix-toolchain bacon rustfmt + cargo-nextest # misc ];