Skip to content

Commit

Permalink
bug-fix: ensure task Send safety for download function (#568)
Browse files Browse the repository at this point in the history
  • Loading branch information
adrian-kong authored Sep 25, 2023
1 parent 7f75774 commit 676d88b
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 16 deletions.
4 changes: 2 additions & 2 deletions crates/esthri/src/aws_sdk.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ use aws_sdk_s3::types::{CompletedMultipartUpload, CompletedPart};
use aws_smithy_types_convert::date_time::DateTimeExt;
use bytes::Bytes;
use chrono::{DateTime, Utc};
use futures::Stream;
use futures::{Stream, TryStreamExt};

use crate::{Error, Result};

Expand Down Expand Up @@ -132,7 +132,7 @@ pub struct GetObjectResponse {

impl GetObjectResponse {
pub fn into_stream(self) -> impl Stream<Item = Result<Bytes>> {
futures::TryStreamExt::map_err(self.stream, |e| Error::ByteStreamError(e.to_string()))
TryStreamExt::map_err(self.stream, |e| Error::ByteStreamError(e.to_string()))
}
}

Expand Down
24 changes: 10 additions & 14 deletions crates/esthri/src/ops/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -127,18 +127,18 @@ async fn download_file(
} else {
let dest = &Arc::new(dest.take_std_file().await);
let part_size = obj_info.size;
let stream = download_unordered_streaming_helper(s3, bucket, key, obj_info.parts)
.map_ok(|(part, mut chunks)| async move {
let limit = Config::global().concurrent_writer_tasks();
download_unordered_streaming_helper(s3, bucket, key, obj_info.parts)
.try_for_each_concurrent(limit, |(part, mut chunks)| async move {
let mut offset = (part - 1) * part_size;
while let Some(buf) = chunks.try_next().await? {
let len = buf.len();
write_all_at(Arc::clone(dest), buf, offset as u64).await?;
write_all_at(dest.clone(), buf, offset as u64).await?;
offset += len as i64;
}
Result::Ok(())
Ok(())
})
.try_buffer_unordered(Config::global().concurrent_writer_tasks());
stream.try_collect().await?;
.await?;
};

// If we're trying to download into a directory, assemble the path for the user
Expand Down Expand Up @@ -211,19 +211,16 @@ async fn init_download_dir(path: &Path) -> Result<PathBuf> {
#[cfg(unix)]
async fn write_all_at(file: Arc<std::fs::File>, buf: Bytes, offset: u64) -> Result<()> {
use std::os::unix::prelude::FileExt;

tokio::task::spawn_blocking(move || {
file.write_all_at(&buf, offset)?;
Result::Ok(())
Ok(())
})
.await??;
Ok(())
.await?
}

#[cfg(windows)]
async fn write_all_at(file: Arc<std::fs::File>, buf: Bytes, offset: u64) -> Result<()> {
use std::os::windows::prelude::FileExt;

tokio::task::spawn_blocking(move || {
let (mut offset, mut length) = (offset, buf.len());
let mut buffer_offset = 0;
Expand All @@ -235,8 +232,7 @@ async fn write_all_at(file: Arc<std::fs::File>, buf: Bytes, offset: u64) -> Resu
offset += write_size as u64;
buffer_offset += write_size;
}
Result::Ok(())
Ok(())
})
.await??;
Ok(())
.await?
}

0 comments on commit 676d88b

Please sign in to comment.