From 5b037b89baeb4e29d500643a32323e69616c57f8 Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Fri, 2 Feb 2024 10:31:15 +0100 Subject: [PATCH 1/2] feat: allow disambiguating environments --- src/cli/run.rs | 13 ++- src/lib.rs | 3 +- src/task/mod.rs | 5 +- src/task/task_environment.rs | 191 +++++++++++++++++++++++++++++++++++ src/task/task_graph.rs | 169 ++++++------------------------- tests/common/mod.rs | 14 +-- 6 files changed, 245 insertions(+), 150 deletions(-) create mode 100644 src/task/task_environment.rs diff --git a/src/cli/run.rs b/src/cli/run.rs index e502d0bdf..8e8318c34 100644 --- a/src/cli/run.rs +++ b/src/cli/run.rs @@ -10,7 +10,10 @@ use rattler_conda_types::Platform; use crate::activation::get_environment_variables; use crate::project::errors::UnsupportedPlatformError; -use crate::task::{ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, TaskGraph}; +use crate::task::{ + ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, SearchEnvironments, + TaskGraph, +}; use crate::{Project, UpdateLockFileOptions}; use crate::environment::LockFileDerivedData; @@ -77,12 +80,12 @@ pub async fn execute(args: Args) -> miette::Result<()> { tracing::debug!("Task parsed from run command: {:?}", task_args); // Construct a task graph from the input arguments - let task_graph = TaskGraph::from_cmd_args( + let search_environment = SearchEnvironments::from_opt_env( &project, - task_args, - Some(Platform::current()), explicit_environment.clone(), - )?; + Some(Platform::current()), + ); + let task_graph = TaskGraph::from_cmd_args(&project, &search_environment, task_args)?; // Traverse the task graph in topological order and execute each individual task. let mut task_idx = 0; diff --git a/src/lib.rs b/src/lib.rs index 27b1bb26a..72baabfdc 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -28,7 +28,8 @@ pub use project::{ DependencyType, Project, SpecType, }; pub use task::{ - CmdArgs, ExecutableTask, RunOutput, Task, TaskExecutionError, TaskGraph, TaskGraphError, + CmdArgs, ExecutableTask, FindTaskError, FindTaskSource, RunOutput, SearchEnvironments, Task, + TaskDisambiguation, TaskExecutionError, TaskGraph, TaskGraphError, }; use rattler_networking::retry_policies::ExponentialBackoff; diff --git a/src/task/mod.rs b/src/task/mod.rs index 731a21692..5e38e411f 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -7,12 +7,14 @@ use std::path::{Path, PathBuf}; mod error; mod executable_task; +mod task_environment; mod task_graph; pub use executable_task::{ ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, RunOutput, TaskExecutionError, }; +pub use task_environment::{FindTaskError, FindTaskSource, SearchEnvironments, TaskDisambiguation}; pub use task_graph::{TaskGraph, TaskGraphError, TaskId, TaskNode}; /// Represents different types of scripts @@ -22,7 +24,8 @@ pub enum Task { Plain(String), Execute(Execute), Alias(Alias), - // We don't what a way for the deserializer to except a custom task, as they are meant for tasks given in the command line. + // We want a way for the deserializer to except a custom task, as they are meant for tasks + // given in the command line. #[serde(skip)] Custom(Custom), } diff --git a/src/task/task_environment.rs b/src/task/task_environment.rs new file mode 100644 index 000000000..14c927c37 --- /dev/null +++ b/src/task/task_environment.rs @@ -0,0 +1,191 @@ +use crate::project::Environment; +use crate::task::error::{AmbiguousTaskError, MissingTaskError}; +use crate::{Project, Task}; +use itertools::Itertools; +use miette::Diagnostic; +use rattler_conda_types::Platform; +use thiserror::Error; + +/// Defines where the task was defined when looking for a task. +#[derive(Debug, Clone)] +pub enum FindTaskSource<'p> { + CmdArgs, + DependsOn(String, &'p Task), +} + +pub type TaskAndEnvironment<'p> = (Environment<'p>, &'p Task); + +pub trait TaskDisambiguation<'p> { + fn disambiguate(&self, task: &AmbiguousTask<'p>) -> Option>; +} + +#[derive(Default)] +pub struct NoDisambiguation; +pub struct DisambiguateFn(Fn); + +impl<'p> TaskDisambiguation<'p> for NoDisambiguation { + fn disambiguate(&self, _task: &AmbiguousTask<'p>) -> Option> { + None + } +} + +impl<'p, F: Fn(&AmbiguousTask<'p>) -> Option>> TaskDisambiguation<'p> + for DisambiguateFn +{ + fn disambiguate(&self, task: &AmbiguousTask<'p>) -> Option> { + self.0(task) + } +} + +/// An object to help with searching for tasks. +pub struct SearchEnvironments<'p, D: TaskDisambiguation<'p> = NoDisambiguation> { + pub project: &'p Project, + pub explicit_environment: Option>, + pub platform: Option, + pub disambiguate: D, +} + +/// Information about an task that was found when searching for a task +pub struct AmbiguousTask<'p> { + pub task_name: String, + pub depended_on_by: Option<(String, &'p Task)>, + pub environments: Vec>, +} + +impl<'p> From> for AmbiguousTaskError { + fn from(value: AmbiguousTask<'p>) -> Self { + Self { + task_name: value.task_name, + environments: value + .environments + .into_iter() + .map(|env| env.0.name().clone()) + .collect(), + } + } +} + +#[derive(Debug, Diagnostic, Error)] +pub enum FindTaskError { + #[error(transparent)] + MissingTask(MissingTaskError), + + #[error(transparent)] + AmbiguousTask(AmbiguousTaskError), +} + +impl<'p> SearchEnvironments<'p, NoDisambiguation> { + // Determine which environments we are allowed to check for tasks. + // + // If the user specified an environment, look for tasks in the main environment and the + // user specified environment. + // + // If the user did not specify an environment, look for tasks in any environment. + pub fn from_opt_env( + project: &'p Project, + explicit_environment: Option>, + platform: Option, + ) -> Self { + Self { + project, + explicit_environment, + platform, + disambiguate: NoDisambiguation, + } + } +} + +impl<'p, D: TaskDisambiguation<'p>> SearchEnvironments<'p, D> { + /// Returns a new `SearchEnvironments` with the given disambiguation function. + pub fn with_disambiguate_fn) -> Option>>( + self, + func: F, + ) -> SearchEnvironments<'p, DisambiguateFn> { + SearchEnvironments { + project: self.project, + explicit_environment: self.explicit_environment, + platform: self.platform, + disambiguate: DisambiguateFn(func), + } + } + + /// Finds the task with the given name or returns an error that explains why the task could not + /// be found. + pub fn find_task( + &self, + name: &str, + source: FindTaskSource<'p>, + ) -> Result, FindTaskError> { + // If the task was specified on the command line and there is no explicit environment and + // the task is only defined in the default feature, use the default environment. + if matches!(source, FindTaskSource::CmdArgs) && self.explicit_environment.is_none() { + if let Some(task) = self + .project + .manifest + .default_feature() + .targets + .resolve(self.platform) + .find_map(|target| target.tasks.get(name)) + { + // None of the other environments can have this task. Otherwise, its still + // ambiguous. + if !self + .project + .environments() + .into_iter() + .flat_map(|env| env.features(false).collect_vec()) + .flat_map(|feature| feature.targets.resolve(self.platform)) + .any(|target| target.tasks.contains_key(name)) + { + return Ok((self.project.default_environment(), task)); + } + } + } + + // If an explicit environment was specified, only look for tasks in that environment and + // the default environment. + let environments = if let Some(explicit_environment) = &self.explicit_environment { + vec![explicit_environment.clone()] + } else { + self.project.environments() + }; + + // Find all the task and environment combinations + let include_default_feature = true; + let mut tasks = Vec::new(); + for env in environments.iter() { + if let Some(task) = env + .tasks(self.platform, include_default_feature) + .ok() + .and_then(|tasks| tasks.get(name).copied()) + { + tasks.push((env.clone(), task)); + } + } + + match tasks.len() { + 0 => Err(FindTaskError::MissingTask(MissingTaskError { + task_name: name.to_string(), + })), + 1 => { + let (env, task) = tasks.remove(0); + Ok((env.clone(), task)) + } + _ => { + let ambiguous_task = AmbiguousTask { + task_name: name.to_string(), + depended_on_by: match source { + FindTaskSource::DependsOn(dep, task) => Some((dep, task)), + _ => None, + }, + environments: tasks, + }; + + match self.disambiguate.disambiguate(&ambiguous_task) { + Some(env) => Ok(env), + None => Err(FindTaskError::AmbiguousTask(ambiguous_task.into())), + } + } + } + } +} diff --git a/src/task/task_graph.rs b/src/task/task_graph.rs index 89c7ed55c..074a383d4 100644 --- a/src/task/task_graph.rs +++ b/src/task/task_graph.rs @@ -1,12 +1,12 @@ use crate::project::Environment; use crate::task::error::AmbiguousTaskError; +use crate::task::task_environment::{FindTaskError, FindTaskSource, SearchEnvironments}; +use crate::task::TaskDisambiguation; use crate::{ task::{error::MissingTaskError, CmdArgs, Custom, Task}, Project, }; -use itertools::Itertools; use miette::Diagnostic; -use rattler_conda_types::Platform; use std::{ borrow::Cow, collections::{HashMap, HashSet}, @@ -75,137 +75,19 @@ impl<'p> Index for TaskGraph<'p> { } } -/// Defines where the task was defined when looking for a task. -#[derive(Debug, Copy, Clone, Eq, PartialEq)] -enum FindTaskSource { - CmdArgs, - DependsOn, -} - -/// An object to help with searching for tasks. -struct SearchEnvironments<'p> { - pub project: &'p Project, - pub explicit_environment: Option>, - pub platform: Option, -} - -#[derive(Debug, Diagnostic, Error)] -pub enum FindTaskError { - #[error(transparent)] - MissingTask(MissingTaskError), - - #[error(transparent)] - AmbiguousTask(AmbiguousTaskError), -} - -impl<'p> SearchEnvironments<'p> { - // Determine which environments we are allowed to check for tasks. - // - // If the user specified an environment, look for tasks in the main environment and the - // user specified environment. - // - // If the user did not specify an environment, look for tasks in any environment. - pub fn from_opt_env( - project: &'p Project, - explicit_environment: Option>, - platform: Option, - ) -> Self { - Self { - project, - explicit_environment, - platform, - } - } - - /// Finds the task with the given name or returns an error that explains why the task could not - /// be found. - pub fn find_task( - &self, - name: &str, - source: FindTaskSource, - ) -> Result<(Environment<'p>, &'p Task), FindTaskError> { - // If the task was specified on the command line and there is no explicit environment and - // the task is only defined in the default feature, use the default environment. - if source == FindTaskSource::CmdArgs && self.explicit_environment.is_none() { - if let Some(task) = self - .project - .manifest - .default_feature() - .targets - .resolve(self.platform) - .find_map(|target| target.tasks.get(name)) - { - // None of the other environments can have this task. Otherwise, its still - // ambiguous. - if !self - .project - .environments() - .into_iter() - .flat_map(|env| env.features(false).collect_vec()) - .flat_map(|feature| feature.targets.resolve(self.platform)) - .any(|target| target.tasks.contains_key(name)) - { - return Ok((self.project.default_environment(), task)); - } - } - } - - // If an explicit environment was specified, only look for tasks in that environment and - // the default environment. - let environments = if let Some(explicit_environment) = &self.explicit_environment { - vec![explicit_environment.clone()] - } else { - self.project.environments() - }; - - // Find all the task and environment combinations - let include_default_feature = true; - let mut tasks = Vec::new(); - for env in environments { - if let Some(task) = env - .tasks(self.platform, include_default_feature) - .ok() - .and_then(|tasks| tasks.get(name).copied()) - { - tasks.push((env, task)); - } - } - - match tasks.len() { - 0 => Err(FindTaskError::MissingTask(MissingTaskError { - task_name: name.to_string(), - })), - 1 => { - let (env, task) = tasks.remove(0); - Ok((env.clone(), task)) - } - _ => Err(FindTaskError::AmbiguousTask(AmbiguousTaskError { - task_name: name.to_string(), - environments: tasks - .into_iter() - .map(|(env, _)| env.name().clone()) - .collect(), - })), - } - } -} - impl<'p> TaskGraph<'p> { pub fn project(&self) -> &'p Project { self.project } /// Constructs a new [`TaskGraph`] from a list of command line arguments. - pub fn from_cmd_args( + pub fn from_cmd_args>( project: &'p Project, + search_envs: &SearchEnvironments<'p, D>, args: Vec, - platform: Option, - environment: Option>, ) -> Result { let mut args = args; - let search_envs = SearchEnvironments::from_opt_env(project, environment, platform); - if let Some(name) = args.first() { match search_envs.find_task(name, FindTaskSource::CmdArgs) { Err(FindTaskError::MissingTask(_)) => {} @@ -260,9 +142,9 @@ impl<'p> TaskGraph<'p> { } /// Constructs a new instance of a [`TaskGraph`] from a root task. - fn from_root( + fn from_root>( project: &'p Project, - search_environments: SearchEnvironments<'p>, + search_environments: &SearchEnvironments<'p, D>, root: TaskNode<'p>, ) -> Result { let mut task_name_to_node: HashMap = @@ -285,16 +167,29 @@ impl<'p> TaskGraph<'p> { } // Find the task in the project - let (task_env, task_dependency) = - match search_environments.find_task(&dependency, FindTaskSource::DependsOn) { - Err(FindTaskError::MissingTask(err)) => { - return Err(TaskGraphError::MissingTask(err)) - } - Err(FindTaskError::AmbiguousTask(err)) => { - return Err(TaskGraphError::AmbiguousTask(err)) - } - Ok(result) => result, - }; + let node = &nodes[next_node_to_visit]; + let (task_env, task_dependency) = match search_environments.find_task( + &dependency, + FindTaskSource::DependsOn( + node.name + .clone() + .expect("only named tasks can have dependencies"), + match &node.task { + Cow::Borrowed(task) => task, + Cow::Owned(_) => { + unreachable!("only named tasks can have dependencies") + } + }, + ), + ) { + Err(FindTaskError::MissingTask(err)) => { + return Err(TaskGraphError::MissingTask(err)) + } + Err(FindTaskError::AmbiguousTask(err)) => { + return Err(TaskGraphError::AmbiguousTask(err)) + } + Ok(result) => result, + }; // Add the node to the graph let task_id = TaskId(nodes.len()); @@ -365,6 +260,7 @@ pub enum TaskGraphError { #[cfg(test)] mod test { + use crate::task::task_environment::SearchEnvironments; use crate::task::task_graph::TaskGraph; use crate::Project; use rattler_conda_types::Platform; @@ -377,11 +273,12 @@ mod test { ) -> Vec { let project = Project::from_str(Path::new(""), project_str).unwrap(); + let search_envs = SearchEnvironments::from_opt_env(&project, None, platform); + let graph = TaskGraph::from_cmd_args( &project, + &search_envs, run_args.into_iter().map(|arg| arg.to_string()).collect(), - platform, - None, ) .unwrap(); diff --git a/tests/common/mod.rs b/tests/common/mod.rs index bce5b9e9b..6f50ce669 100644 --- a/tests/common/mod.rs +++ b/tests/common/mod.rs @@ -14,8 +14,8 @@ use pixi::{ project, run, task::{self, AddArgs, AliasArgs}, }, - consts, EnvironmentName, ExecutableTask, Project, RunOutput, TaskGraph, TaskGraphError, - UpdateLockFileOptions, + consts, EnvironmentName, ExecutableTask, Project, RunOutput, SearchEnvironments, TaskGraph, + TaskGraphError, UpdateLockFileOptions, }; use rattler_conda_types::{MatchSpec, Platform}; @@ -247,13 +247,13 @@ impl PixiControl { .await?; // Create a task graph from the command line arguments. - let task_graph = TaskGraph::from_cmd_args( + let search_env = SearchEnvironments::from_opt_env( &project, - args.task, - Some(Platform::current()), explicit_environment, - ) - .map_err(RunError::TaskGraphError)?; + Some(Platform::current()), + ); + let task_graph = TaskGraph::from_cmd_args(&project, &search_env, args.task) + .map_err(RunError::TaskGraphError)?; // Iterate over all tasks in the graph and execute them. let mut task_env = None; From 2c9e5348bdf191f6a4b26ea13c87772cbabd35da Mon Sep 17 00:00:00 2001 From: Bas Zalmstra Date: Fri, 2 Feb 2024 11:25:40 +0100 Subject: [PATCH 2/2] feat: task ambiguation --- Cargo.lock | 26 ++++++++++++++++++++++++++ Cargo.toml | 1 + src/cli/run.rs | 37 ++++++++++++++++++++++++++++++++++--- src/task/mod.rs | 5 ++++- 4 files changed, 65 insertions(+), 4 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index ff3129a5f..5ee1d04bb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -945,6 +945,19 @@ dependencies = [ "syn 1.0.109", ] +[[package]] +name = "dialoguer" +version = "0.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bce805d770f407bc62102fca7c2c64ceef2fbcb2b8bd19d2765ce093980de" +dependencies = [ + "console", + "shell-words", + "tempfile", + "thiserror", + "zeroize", +] + [[package]] name = "digest" version = "0.10.7" @@ -2660,6 +2673,7 @@ dependencies = [ "clap_complete", "console", "deno_task_shell", + "dialoguer", "dirs", "dunce", "flate2", @@ -3867,6 +3881,12 @@ dependencies = [ "lazy_static", ] +[[package]] +name = "shell-words" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24188a676b6ae68c3b2cb3a01be17fbf7240ce009799bb56d5b1409051e78fde" + [[package]] name = "shlex" version = "1.3.0" @@ -5053,6 +5073,12 @@ dependencies = [ "zvariant", ] +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + [[package]] name = "zip" version = "0.6.6" diff --git a/Cargo.toml b/Cargo.toml index 6b941b5b4..b9b9764ab 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,6 +26,7 @@ clap-verbosity-flag = "2.1.2" clap_complete = "4.4.9" console = { version = "0.15.8", features = ["windows-console-colors"] } deno_task_shell = "0.14.3" +dialoguer = "0.11.0" dirs = "5.0.1" dunce = "1.0.4" flate2 = "1.0.28" diff --git a/src/cli/run.rs b/src/cli/run.rs index 8e8318c34..104ee10c3 100644 --- a/src/cli/run.rs +++ b/src/cli/run.rs @@ -1,9 +1,11 @@ use std::collections::hash_map::Entry; use std::collections::HashSet; +use std::convert::identity; use std::str::FromStr; use std::{collections::HashMap, path::PathBuf, string::String}; use clap::Parser; +use dialoguer::theme::ColorfulTheme; use itertools::Itertools; use miette::{miette, Context, Diagnostic}; use rattler_conda_types::Platform; @@ -11,8 +13,8 @@ use rattler_conda_types::Platform; use crate::activation::get_environment_variables; use crate::project::errors::UnsupportedPlatformError; use crate::task::{ - ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, SearchEnvironments, - TaskGraph, + AmbiguousTask, ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, + SearchEnvironments, TaskAndEnvironment, TaskGraph, }; use crate::{Project, UpdateLockFileOptions}; @@ -84,7 +86,9 @@ pub async fn execute(args: Args) -> miette::Result<()> { &project, explicit_environment.clone(), Some(Platform::current()), - ); + ) + .with_disambiguate_fn(disambiguate_task_interactive); + let task_graph = TaskGraph::from_cmd_args(&project, &search_environment, task_args)?; // Traverse the task graph in topological order and execute each individual task. @@ -257,3 +261,30 @@ async fn execute_task<'p>( Ok(()) } + +/// Called to disambiguate between environments to run a task in. +fn disambiguate_task_interactive<'p>( + problem: &AmbiguousTask<'p>, +) -> Option> { + let environment_names = problem + .environments + .iter() + .map(|(env, _)| env.name()) + .collect_vec(); + dialoguer::Select::with_theme(&ColorfulTheme::default()) + .with_prompt(format!( + "The task '{}' {}can be run in multiple environments.\n\nPlease select an environment to run the task in:", + problem.task_name, + if let Some(dependency) = &problem.depended_on_by { + format!("(depended on by '{}') ", dependency.0) + } else { + String::new() + } + )) + .report(false) + .items(&environment_names) + .default(0) + .interact_opt() + .map_or(None, identity) + .map(|idx| problem.environments[idx].clone()) +} diff --git a/src/task/mod.rs b/src/task/mod.rs index 5e38e411f..5e2ac872c 100644 --- a/src/task/mod.rs +++ b/src/task/mod.rs @@ -14,7 +14,10 @@ pub use executable_task::{ ExecutableTask, FailedToParseShellScript, InvalidWorkingDirectory, RunOutput, TaskExecutionError, }; -pub use task_environment::{FindTaskError, FindTaskSource, SearchEnvironments, TaskDisambiguation}; +pub use task_environment::{ + AmbiguousTask, FindTaskError, FindTaskSource, SearchEnvironments, TaskAndEnvironment, + TaskDisambiguation, +}; pub use task_graph::{TaskGraph, TaskGraphError, TaskId, TaskNode}; /// Represents different types of scripts