Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: extract run exports using object #1114

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,13 @@ mod post_process;
pub mod rebuild;
#[cfg(feature = "recipe-generation")]
pub mod recipe_generator;
mod run_exports;
mod unix;
pub mod upload;
mod windows;

mod package_cache_reporter;

use std::{
collections::{BTreeMap, HashMap},
env::current_dir,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use rattler_conda_types::RepoDataRecord;

/// A reporter that makes it easy to show the progress of updating the package
/// cache.
#[derive(Clone)]
pub struct PackageCacheReporter {
inner: Arc<Mutex<PackageCacheReporterInner>>,
}
Expand Down
1 change: 0 additions & 1 deletion src/render/mod.rs
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
#![allow(missing_docs)]
//! Render the dependencies to a final recipe

mod package_cache_reporter;
pub mod pin;
pub mod resolved_dependencies;
mod run_exports;
Expand Down
51 changes: 17 additions & 34 deletions src/render/resolved_dependencies.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,10 @@ use std::{

use indicatif::{HumanBytes, MultiProgress, ProgressBar};
use rattler::install::Placement;
use rattler_cache::package_cache::{CacheKey, PackageCache, PackageCacheError};
use rattler_cache::package_cache::PackageCache;
use rattler_conda_types::{
package::{PackageFile, RunExportsJson},
version_spec::ParseVersionSpecError,
MatchSpec, PackageName, PackageRecord, ParseStrictness, Platform, RepoDataRecord,
StringMatcher, VersionSpec,
package::RunExportsJson, version_spec::ParseVersionSpecError, MatchSpec, PackageName,
PackageRecord, ParseStrictness, Platform, RepoDataRecord, StringMatcher, VersionSpec,
};
use reqwest_middleware::ClientWithMiddleware;
use serde::{Deserialize, Serialize};
Expand All @@ -25,12 +23,13 @@ use url::Url;
use super::pin::PinError;
use crate::{
metadata::{build_reindexed_channels, BuildConfiguration, Output},
package_cache_reporter::PackageCacheReporter,
recipe::parser::{Dependency, Requirements},
render::{
package_cache_reporter::PackageCacheReporter,
pin::PinArgs,
solver::{install_packages, solve_environment},
},
run_exports::{RunExportExtractor, RunExportExtractorError},
tool_configuration,
tool_configuration::Configuration,
};
Expand Down Expand Up @@ -421,7 +420,7 @@ pub enum ResolveError {
DependencyResolutionError(#[from] anyhow::Error),

#[error("Could not collect run exports")]
CouldNotCollectRunExports(#[from] PackageCacheError),
CouldNotCollectRunExports(#[from] RunExportExtractorError),

#[error("Could not parse version spec: {0}")]
VersionSpecParseError(#[from] ParseVersionSpecError),
Expand Down Expand Up @@ -554,11 +553,11 @@ async fn amend_run_exports(
multi_progress: MultiProgress,
progress_prefix: impl Into<Cow<'static, str>>,
top_level_pb: Option<ProgressBar>,
) -> Result<(), PackageCacheError> {
) -> Result<(), RunExportExtractorError> {
let max_concurrent_requests = Arc::new(Semaphore::new(50));
let (tx, mut rx) = mpsc::channel(50);

let mut progress = PackageCacheReporter::new(
let progress = PackageCacheReporter::new(
multi_progress,
top_level_pb.map_or(Placement::End, Placement::After),
)
Expand All @@ -570,39 +569,23 @@ async fn amend_run_exports(
continue;
}

let progress_reporter = Arc::new(progress.add(pkg));
let extractor = RunExportExtractor::default()
.with_max_concurrent_requests(max_concurrent_requests.clone())
.with_client(client.clone())
.with_package_cache(package_cache.clone(), progress.clone());

let cache_key = CacheKey::from(&pkg.package_record);
let client = client.clone();
let url = pkg.url.clone();
let max_concurrent_requests = max_concurrent_requests.clone();
let tx = tx.clone();
let package_cache = package_cache.clone();
let record = pkg.clone();
tokio::spawn(async move {
let _permit = max_concurrent_requests
.acquire_owned()
.await
.expect("semaphore error");
let result = match package_cache
.get_or_fetch_from_url(cache_key, url, client, Some(progress_reporter))
.await
{
Ok(package_dir) => {
let run_exports =
RunExportsJson::from_package_directory(package_dir.path()).ok();
Ok((pkg_idx, run_exports))
}
Err(e) => Err(e),
};
let _ = tx.send(result).await;
let result = extractor.extract(&record).await;
let _ = tx.send((pkg_idx, result)).await;
});
}

drop(tx);

while let Some(result) = rx.recv().await {
let (pkg_idx, run_exports) = result?;
records[pkg_idx].package_record.run_exports = run_exports;
while let Some((pkg_idx, run_exports)) = rx.recv().await {
records[pkg_idx].package_record.run_exports = run_exports?;
}

Ok(())
Expand Down
106 changes: 106 additions & 0 deletions src/run_exports.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::sync::Arc;

use futures::future::OptionFuture;
use rattler_cache::package_cache::{CacheKey, PackageCache, PackageCacheError};
use rattler_conda_types::{
package::{PackageFile, RunExportsJson},
RepoDataRecord,
};
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;
use tokio::sync::Semaphore;

use crate::package_cache_reporter::PackageCacheReporter;

/// An object that can help extract run export information from a package.
///
/// This object can be configured with multiple sources and it will do its best
/// to find the run exports as fast as possible using the available resources.
#[derive(Default)]
pub struct RunExportExtractor {
max_concurrent_requests: Option<Arc<Semaphore>>,
package_cache: Option<(PackageCache, PackageCacheReporter)>,
client: Option<ClientWithMiddleware>,
}

#[derive(Debug, Error)]
pub enum RunExportExtractorError {
#[error(transparent)]
PackageCache(#[from] PackageCacheError),

#[error("the operation was cancelled")]
Cancelled,
}

impl RunExportExtractor {
/// Sets the maximum number of concurrent requests that the extractor can
/// make.
pub fn with_max_concurrent_requests(self, max_concurrent_requests: Arc<Semaphore>) -> Self {
Self {
max_concurrent_requests: Some(max_concurrent_requests),
..self
}
}

/// Set the package cache that the extractor can use as well as a reporter
/// to allow progress reporting.
pub fn with_package_cache(
self,
package_cache: PackageCache,
reporter: PackageCacheReporter,
) -> Self {
Self {
package_cache: Some((package_cache, reporter)),
..self
}
}

/// Sets the download client that the extractor can use.
pub fn with_client(self, client: ClientWithMiddleware) -> Self {
Self {
client: Some(client),
..self
}
}

/// Extracts the run exports from a package. Returns `None` if no run
/// exports are found.
pub async fn extract(
mut self,
record: &RepoDataRecord,
) -> Result<Option<RunExportsJson>, RunExportExtractorError> {
self.extract_into_package_cache(record).await
}

/// Extract the run exports from a package by downloading it to the cache
/// and then reading the run_exports.json file.
async fn extract_into_package_cache(
&mut self,
record: &RepoDataRecord,
) -> Result<Option<RunExportsJson>, RunExportExtractorError> {
let Some((package_cache, mut package_cache_reporter)) = self.package_cache.clone() else {
return Ok(None);
};
let Some(client) = self.client.clone() else {
return Ok(None);
};

let progress_reporter = package_cache_reporter.add(record);
let cache_key = CacheKey::from(&record.package_record);
let url = record.url.clone();
let max_concurrent_requests = self.max_concurrent_requests.clone();

let _permit = OptionFuture::from(max_concurrent_requests.map(Semaphore::acquire_owned))
.await
.transpose()
.expect("semaphore error");

match package_cache
.get_or_fetch_from_url(cache_key, url, client, Some(Arc::new(progress_reporter)))
.await
{
Ok(package_dir) => Ok(RunExportsJson::from_package_directory(package_dir.path()).ok()),
Err(e) => Err(e.into()),
}
}
}
Loading