From 8e10b50a26112df884c36dc503dfaae2b0c0abd4 Mon Sep 17 00:00:00 2001 From: Andrew Gazelka Date: Sat, 1 Apr 2023 23:44:47 -0700 Subject: [PATCH] update to nightly (until unsize stabilized) --- .github/workflows/rust-test.yml | 6 +- Cargo.lock | 39 +++++- code/derive-build/src/lib.rs | 59 +++++---- code/derive-discriminant/src/lib.rs | 154 +++++++++++++++--------- code/derive-discriminant/tests/macro.rs | 2 + code/executor/Cargo.toml | 7 ++ code/executor/src/command.rs | 35 ++++++ code/executor/src/command/bash.rs | 26 ++++ code/executor/src/command/zsh.rs | 58 +++++++++ code/executor/src/main.rs | 35 ++++++ code/openai/src/lib.rs | 24 ++-- code/utils/src/lib.rs | 1 + code/utils/src/str.rs | 44 +++++++ 13 files changed, 390 insertions(+), 100 deletions(-) create mode 100644 code/executor/src/command.rs create mode 100644 code/executor/src/command/bash.rs create mode 100644 code/executor/src/command/zsh.rs create mode 100644 code/utils/src/str.rs diff --git a/.github/workflows/rust-test.yml b/.github/workflows/rust-test.yml index 11f4054..a333b91 100644 --- a/.github/workflows/rust-test.yml +++ b/.github/workflows/rust-test.yml @@ -49,11 +49,7 @@ jobs: -A clippy::match-bool -# - uses: taiki-e/install-action@cargo-llvm-cov - - - name: Install cargo-llvm-cov - run: cargo install --git https://github.com/andrewgazelka/cargo-llvm-cov --branch region-coverage-codecov cargo-llvm-cov - + - uses: taiki-e/install-action@cargo-llvm-cov - uses: taiki-e/install-action@nextest - name: Collect coverage data diff --git a/Cargo.lock b/Cargo.lock index 782d3d9..f1b3ed6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -36,12 +36,29 @@ dependencies = [ "num-traits", ] +[[package]] +name = "async-trait" +version = "0.1.68" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9ccdd8f2a161be9bd5c023df56f1b2a0bd1d83872ae53b71a84a12c9bf6e842" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.12", +] + [[package]] name = "autocfg" version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" +[[package]] +name = "base64" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" + [[package]] name = "base64" version = "0.21.0" @@ -213,6 +230,15 @@ dependencies = [ [[package]] name = "executor" version = "0.1.0" +dependencies = [ + "anyhow", + "async-trait", + "derive-build", + "derive-discriminant", + "ron", + "tokio", + "utils", +] [[package]] name = "fastrand" @@ -849,7 +875,7 @@ version = "0.11.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "27b71749df584b7f4cac2c426c127a7c785a5106cc98f7a8feb044115f0fa254" dependencies = [ - "base64", + "base64 0.21.0", "bytes", "encoding_rs", "futures-core", @@ -882,6 +908,17 @@ dependencies = [ "winreg", ] +[[package]] +name = "ron" +version = "0.8.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "300a51053b1cb55c80b7a9fde4120726ddf25ca241a1cbb926626f62fb136bff" +dependencies = [ + "base64 0.13.1", + "bitflags", + "serde", +] + [[package]] name = "rustc_version" version = "0.4.0" diff --git a/code/derive-build/src/lib.rs b/code/derive-build/src/lib.rs index cd1545e..65c0879 100644 --- a/code/derive-build/src/lib.rs +++ b/code/derive-build/src/lib.rs @@ -2,10 +2,10 @@ extern crate proc_macro; use inflector::string::singularize::to_singular; use proc_macro::TokenStream; -use quote::quote; -use syn::{parse_macro_input, DeriveInput, Path, Type, TypePath}; +use quote::{quote}; +use syn::{parse_macro_input, DeriveInput, Meta, Path, Type, TypePath}; -#[proc_macro_derive(Build, attributes(required))] +#[proc_macro_derive(Build, attributes(required, default))] pub fn build_macro_derive(input: TokenStream) -> TokenStream { let ast = parse_macro_input!(input as DeriveInput); impl_build_macro(&ast) @@ -50,6 +50,28 @@ fn impl_build_macro(ast: &DeriveInput) -> TokenStream { let name = &ast.ident; let (required_fields, optional_fields) = partition_fields(&ast.data); + let (optional_fields, optional_defaults): (Vec<_>, Vec<_>) = optional_fields + .iter() + .map(|field| { + let default_value = field + .attrs + .iter() + .find(|attr| attr.path().is_ident("default")) + .map(|attr| { + let Meta::NameValue(v)= &attr.meta else { + panic!("only named values allowed for default attribute") + }; + + let v= &v.value; + + quote!(#v) + }) + .unwrap_or_else(|| quote! { Default::default() }); + + (field, default_value) + }) + .unzip(); + let required_params = required_fields.iter().map(|field| { let field_name = &field.ident; let field_type = &field.ty; @@ -116,28 +138,21 @@ fn impl_build_macro(ast: &DeriveInput) -> TokenStream { } }); - let expanded = match required_params.len() == 0 { - true => quote! { - impl #name { - pub fn new() -> Self { - Default::default() - } + let optional_field_idents = optional_fields.iter().map(|field| &field.ident); - #(#optional_methods)* - } - }, - false => quote! { - impl #name { - pub fn new(#(#required_params),*) -> Self { - Self { - #(#required_assignments,)* - ..Default::default() - } + let expanded = quote! { + impl #name { + pub fn new(#(#required_params),*) -> Self { + Self { + #(#required_assignments,)* + #( + #optional_field_idents: #optional_defaults, + )* } - - #(#optional_methods)* } - }, + + #(#optional_methods)* + } }; TokenStream::from(expanded) diff --git a/code/derive-discriminant/src/lib.rs b/code/derive-discriminant/src/lib.rs index 0ac0070..730b94c 100644 --- a/code/derive-discriminant/src/lib.rs +++ b/code/derive-discriminant/src/lib.rs @@ -12,80 +12,118 @@ pub fn discriminant_derive(input: TokenStream) -> TokenStream { fn impl_discriminant_macro(ast: DeriveInput) -> TokenStream { let name = &ast.ident; - let attrs = ast.attrs; - - if let Data::Enum(data_enum) = ast.data { - let variant_impls = data_enum.variants.into_iter().map(|variant| { - let variant_name = &variant.ident; - let fields = &variant.fields; - - match fields { - Fields::Unit => { - quote! { - impl From<#variant_name> for #name { - fn from(value: #variant_name) -> Self { - Self::#variant_name - } + + // all non-doc attributes + let global_attrs: Vec<_> = ast + .attrs + .into_iter() + .filter(|attr| !attr.path().is_ident("doc")) + .collect(); + + let Data::Enum(data_enum) = ast.data else { + panic!("Discriminant can only be derived for enums"); + }; + + let variant_names: Vec<_> = data_enum + .variants + .iter() + .map(|variant| &variant.ident) + .collect(); + + // implementation for the .cast() method to cast into a trait object + // this requires nightly + let cast = quote! { + impl #name { + fn cast(self) -> Box where #(#variant_names: ::core::marker::Unsize),* { + let value = self; + // TODO: use a singular match expression + #( + let value = match #variant_names::try_from(value) { + Ok(v) => { + let x = Box::new(v); + return x; } + Err(v) => v, + }; + )* - impl std::convert::TryFrom<#name> for #variant_name { - type Error = (); + unreachable!(); + } + } + }; - fn try_from(value: #name) -> Result { - if let #name::#variant_name = value { - Ok(#variant_name) - } else { - Err(()) - } - } + let variant_impls = data_enum.variants.into_iter().map(|variant| { + let variant_name = &variant.ident; + let fields = &variant.fields; + let variant_attrs = variant.attrs; + + match fields { + Fields::Unit => { + quote! { + impl From<#variant_name> for #name { + fn from(value: #variant_name) -> Self { + Self::#variant_name } + } + + impl std::convert::TryFrom<#name> for #variant_name { + type Error = #name; - #(#attrs)* - struct #variant_name; + fn try_from(value: #name) -> Result { + if let #name::#variant_name = value { + Ok(#variant_name) + } else { + Err(value) + } + } } + + #(#global_attrs)* + #(#variant_attrs)* + struct #variant_name; } - _ => { - let field_name = fields.iter().map(|field| &field.ident).collect::>(); - let field_type = fields.iter().map(|field| &field.ty).collect::>(); - - quote! { - impl From<#variant_name> for #name { - fn from(value: #variant_name) -> Self { - Self::#variant_name { - #(#field_name: value.#field_name),* - } + } + _ => { + let field_name = fields.iter().map(|field| &field.ident).collect::>(); + let field_type = fields.iter().map(|field| &field.ty).collect::>(); + + quote! { + impl From<#variant_name> for #name { + fn from(value: #variant_name) -> Self { + Self::#variant_name { + #(#field_name: value.#field_name),* } } + } + + impl std::convert::TryFrom<#name> for #variant_name { + type Error = #name; - impl std::convert::TryFrom<#name> for #variant_name { - type Error = (); - - fn try_from(value: #name) -> Result { - if let #name::#variant_name { #(#field_name),* } = value { - Ok(#variant_name { - #(#field_name),* - }) - } else { - Err(()) - } + fn try_from(value: #name) -> Result { + if let #name::#variant_name { #(#field_name),* } = value { + Ok(#variant_name { + #(#field_name),* + }) + } else { + Err(value) } } + } - #(#attrs)* - struct #variant_name { - #(#field_name: #field_type),* - } + #(#global_attrs)* + #(#variant_attrs)* + struct #variant_name { + #(#field_name: #field_type),* } } } - }); + } + }); - let output = quote! { - #(#variant_impls)* - }; + let output = quote! { + #(#variant_impls)* + #cast + }; - TokenStream::from(output) - } else { - panic!("Discriminant can only be derived for enums.") - } + TokenStream::from(output) } diff --git a/code/derive-discriminant/tests/macro.rs b/code/derive-discriminant/tests/macro.rs index 312f93f..7374555 100644 --- a/code/derive-discriminant/tests/macro.rs +++ b/code/derive-discriminant/tests/macro.rs @@ -1,3 +1,5 @@ +#![feature(unsize)] + use derive_discriminant::Discriminant; #[derive(Discriminant)] diff --git a/code/executor/Cargo.toml b/code/executor/Cargo.toml index 56a3215..5ad832a 100644 --- a/code/executor/Cargo.toml +++ b/code/executor/Cargo.toml @@ -4,3 +4,10 @@ version = "0.1.0" edition = "2021" [dependencies] +anyhow = "1.0.70" +async-trait = "0.1.68" +derive-discriminant.workspace = true +derive-build.workspace = true +ron = "0.8.0" +tokio = { version = "1.27.0", features = ["full"] } +utils.workspace = true diff --git a/code/executor/src/command.rs b/code/executor/src/command.rs new file mode 100644 index 0000000..3800478 --- /dev/null +++ b/code/executor/src/command.rs @@ -0,0 +1,35 @@ +//! Commands are executed as such +//! +//! ```text +//! {cmd header} +//! {args} +//! ``` +//! +//! where {cmd data} is one line of RON +//! but {args} can be several lines + +use async_trait::async_trait; +use derive_discriminant::Discriminant; + +use crate::Ctx; + +mod bash; +mod zsh; + +/// The command we are executing +#[derive(Discriminant)] +enum Cmd { + /// a zsh script to execute + Zsh, + Bash, +} + +#[async_trait] +trait Command { + async fn execute(&self, ctx: Ctx, input: &str) -> anyhow::Result; +} + +fn this_requires_unsize() { + let cmd1: Box = Cmd::Zsh.cast(); + let cmd2: Box = Cmd::Bash.cast(); +} diff --git a/code/executor/src/command/bash.rs b/code/executor/src/command/bash.rs new file mode 100644 index 0000000..010739b --- /dev/null +++ b/code/executor/src/command/bash.rs @@ -0,0 +1,26 @@ +use anyhow::{ensure, Context}; +use async_trait::async_trait; +use utils::str::StringExt; + +use crate::{ + command::{Bash, Command}, + Ctx, +}; + +#[async_trait] +impl Command for Bash { + async fn execute(&self, _exec: Ctx, input: &str) -> anyhow::Result { + let output = tokio::process::Command::new("bash") + .arg("-c") + .arg(input) + .output() + .await?; + + ensure!(output.status.success(), "bash command failed"); + + let mut output = String::from_utf8(output.stdout).context("could not parse to UTF-8")?; + output.trim_end_in_place(); // remove trailing newline + + Ok(output) + } +} diff --git a/code/executor/src/command/zsh.rs b/code/executor/src/command/zsh.rs new file mode 100644 index 0000000..0927442 --- /dev/null +++ b/code/executor/src/command/zsh.rs @@ -0,0 +1,58 @@ +use anyhow::{ensure, Context}; +use async_trait::async_trait; +use utils::str::StringExt; + +use crate::{ + command::{Command, Zsh}, + Ctx, +}; + +#[async_trait] +impl Command for Zsh { + async fn execute(&self, exec: Ctx, input: &str) -> anyhow::Result { + let output = tokio::process::Command::new("zsh") + .arg("-c") + .arg(input) + .output() + .await?; + + ensure!(output.status.success(), "zsh command failed"); + + let mut output = String::from_utf8(output.stdout).context("could not parse to UTF-8")?; + output.trim_end_in_place(); // remove trailing newline + + Ok(output) + } +} + +#[cfg(test)] +mod tests { + use crate::{command::Command, Ctx}; + + #[tokio::test] + async fn test_oneline() -> anyhow::Result<()> { + let exec = Ctx::default(); + let cmd = super::Zsh; + + let output = cmd.execute(exec, "echo hello there").await?; + + assert_eq!(output, "hello there"); + + Ok(()) + } + + #[tokio::test] + async fn test_multiline() -> anyhow::Result<()> { + let exec = Ctx::default(); + let cmd = super::Zsh; + + let input = r#"echo hello + echo there"#; + + let output = cmd.execute(exec, input).await?; + + assert_eq!(output, "hello\nthere"); + + Ok(()) + } +} diff --git a/code/executor/src/main.rs b/code/executor/src/main.rs index af883e0..437a058 100644 --- a/code/executor/src/main.rs +++ b/code/executor/src/main.rs @@ -1,4 +1,39 @@ +#![feature(unsize)] + +use std::sync::Arc; + +mod command; + + /// TODO: add executor functionality (running zsh cmd, cd, etc, and modifying files) fn main() { println!("Hello, world!"); } + +type Ctx = Arc; + +#[derive(Default)] +struct Inner { + +} + +struct Executor { + ctx: Ctx, +} + +impl Executor { + fn new() -> Self { + Self { + ctx: Default::default(), + } + } + + fn run(&self, input: &str) -> anyhow::Result { + // let cmd = command::Cmd::try_from(input)?; + // let output = cmd.execute(self.ctx, input)?; + // Ok(output) + todo!() + } +} + + diff --git a/code/openai/src/lib.rs b/code/openai/src/lib.rs index 83ab346..6ebbb79 100644 --- a/code/openai/src/lib.rs +++ b/code/openai/src/lib.rs @@ -217,7 +217,7 @@ const fn empty(input: &[T]) -> bool { input.is_empty() } -#[derive(Serialize, Debug, Build)] +#[derive(Debug, Build, Serialize)] pub struct ChatRequest { pub model: ChatModel, pub messages: Vec, @@ -228,6 +228,7 @@ pub struct ChatRequest { /// /// OpenAI generally recommend altering this or top_p but not both. #[serde(skip_serializing_if = "real_is_one")] + #[default = 1.0] pub temperature: f64, /// An alternative to sampling with temperature, called nucleus sampling, where the model @@ -236,16 +237,24 @@ pub struct ChatRequest { /// /// OpenAI generally recommends altering this or temperature but not both. #[serde(skip_serializing_if = "real_is_one")] + #[default = 1.0] pub top_p: f64, /// How many chat completion choices to generate for each input message. #[serde(skip_serializing_if = "int_is_one")] + #[default = 1] pub n: u32, #[serde(skip_serializing_if = "empty", rename = "stop")] pub stop_at: Vec, } +impl Default for ChatRequest { + fn default() -> Self { + Self::new() + } +} + impl<'a> From<&'a str> for ChatRequest { fn from(input: &'a str) -> Self { Self { @@ -281,19 +290,6 @@ impl From<[Msg; N]> for ChatRequest { } } -impl Default for ChatRequest { - fn default() -> Self { - Self { - model: ChatModel::default(), - messages: vec![], - temperature: 1.0, - top_p: 1.0, - n: 1, - stop_at: Vec::new(), - } - } -} - #[derive(Serialize, Deserialize, Debug)] pub struct ChatChoice { pub message: Msg, diff --git a/code/utils/src/lib.rs b/code/utils/src/lib.rs index e7fcc7f..367a40e 100644 --- a/code/utils/src/lib.rs +++ b/code/utils/src/lib.rs @@ -4,6 +4,7 @@ use tokio::task::JoinSet; use tokio_stream::wrappers::ReceiverStream; pub mod discretize; +pub mod str; // pub type SyncBoxStream<'a, T> = Pin + Send + Sync + 'a>>; pub type Stream = futures_util::stream::BoxStream<'static, T>; diff --git a/code/utils/src/str.rs b/code/utils/src/str.rs new file mode 100644 index 0000000..947d20e --- /dev/null +++ b/code/utils/src/str.rs @@ -0,0 +1,44 @@ +mod sealed { + pub trait Sealed {} +} + +pub trait StringExt: sealed::Sealed { + /// trim the string in place + fn trim_end_in_place(&mut self); +} + +impl sealed::Sealed for String {} + +impl StringExt for String { + fn trim_end_in_place(&mut self) { + self.truncate(self.trim_end().len()); + } +} + +#[cfg(test)] +mod tests { + use crate::str::StringExt; + + #[test] + fn test_trim_end_in_place() { + let mut s = "hello there".to_string(); + s.trim_end_in_place(); + assert_eq!(s, "hello there"); + + let mut s = "".to_string(); + s.trim_end_in_place(); + assert_eq!(s, ""); + + let mut s = " ".to_string(); + s.trim_end_in_place(); + assert_eq!(s, ""); + + let mut s = "hello there ".to_string(); + s.trim_end_in_place(); + assert_eq!(s, "hello there"); + + let mut s = " hello there ".to_string(); + s.trim_end_in_place(); + assert_eq!(s, " hello there"); + } +}