From a53af5234f406e3b8a3a8597cd69f860a62621e2 Mon Sep 17 00:00:00 2001 From: Brian Caswell Date: Thu, 30 Nov 2023 22:59:37 -0500 Subject: [PATCH] add support for partial range headers Ref: https://learn.microsoft.com/en-us/rest/api/storageservices/specifying-the-range-header-for-blob-service-operations --- sdk/core/src/request_options/range.rs | 96 +++++++++---------- sdk/storage_blobs/examples/partial_range.rs | 45 +++++++++ .../src/blob/operations/get_blob.rs | 37 +++---- sdk/storage_blobs/src/blob/page_range_list.rs | 6 +- sdk/storage_blobs/src/options/ba512_range.rs | 17 ++-- 5 files changed, 120 insertions(+), 81 deletions(-) create mode 100644 sdk/storage_blobs/examples/partial_range.rs diff --git a/sdk/core/src/request_options/range.rs b/sdk/core/src/request_options/range.rs index 1cae81a457..d450c51d04 100644 --- a/sdk/core/src/request_options/range.rs +++ b/sdk/core/src/request_options/range.rs @@ -2,67 +2,66 @@ use crate::error::{Error, ErrorKind, ResultExt}; use crate::headers::{self, AsHeaders, HeaderName, HeaderValue}; use std::convert::From; use std::fmt; +use std::ops::{Range as StdRange, RangeFrom}; use std::str::FromStr; -#[derive(Debug, Copy, Clone, PartialEq, Eq)] -pub struct Range { - pub start: u64, - pub end: u64, +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Range { + Range(StdRange), + RangeFrom(RangeFrom), } impl Range { pub fn new(start: u64, end: u64) -> Range { - Range { start, end } + (start..end).into() } - pub fn len(&self) -> u64 { - self.end - self.start + fn optional_len(&self) -> Option { + match self { + Range::Range(r) => Some(r.end - r.start), + Range::RangeFrom(_) => None, + } } +} - pub fn is_empty(&self) -> bool { - self.end == self.start +impl From> for Range { + fn from(r: StdRange) -> Self { + Self::Range(r) } } -impl AsHeaders for Range { - type Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>; - - fn as_headers(&self) -> Self::Iter { - let mut headers = vec![(headers::MS_RANGE, format!("{self}").into())]; - if self.len() < 1024 * 1024 * 4 { - headers.push(( - headers::RANGE_GET_CONTENT_CRC64, - HeaderValue::from_static("true"), - )); - } - headers.into_iter() +impl From> for Range { + fn from(r: RangeFrom) -> Self { + Self::RangeFrom(r) } } -impl From> for Range { - fn from(r: std::ops::Range) -> Self { - Self { - start: r.start, - end: r.end, - } +impl From> for Range { + fn from(r: StdRange) -> Self { + (r.start as u64..r.end as u64).into() } } -impl From> for Range { - fn from(r: std::ops::Range) -> Self { - Self { - start: r.start as u64, - end: r.end as u64, - } +impl From> for Range { + fn from(r: RangeFrom) -> Self { + (r.start as u64..).into() } } -impl From> for Range { - fn from(r: std::ops::Range) -> Self { - Self { - start: r.start as u64, - end: r.end as u64, +impl AsHeaders for Range { + type Iter = std::vec::IntoIter<(HeaderName, HeaderValue)>; + + fn as_headers(&self) -> Self::Iter { + let mut headers = vec![(headers::MS_RANGE, format!("{self}").into())]; + if let Some(len) = self.optional_len() { + if len < 1024 * 1024 * 4 { + headers.push(( + headers::RANGE_GET_CONTENT_CRC64, + HeaderValue::from_static("true"), + )); + } } + headers.into_iter() } } @@ -82,16 +81,16 @@ impl FromStr for Range { let cp_start = v[0].parse::().map_kind(ErrorKind::DataConversion)?; let cp_end = v[1].parse::().map_kind(ErrorKind::DataConversion)? + 1; - Ok(Range { - start: cp_start, - end: cp_end, - }) + Ok((cp_start..cp_end).into()) } } impl fmt::Display for Range { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "bytes={}-{}", self.start, self.end - 1) + match self { + Range::Range(r) => write!(f, "bytes={}-{}", r.start, r.end - 1), + Range::RangeFrom(r) => write!(f, "bytes={}-", r.start), + } } } @@ -102,9 +101,7 @@ mod test { #[test] fn test_range_parse() { let range = "1000/2000".parse::().unwrap(); - - assert_eq!(range.start, 1000); - assert_eq!(range.end, 2001); + assert_eq!(range, Range::new(1000, 2001)); } #[test] @@ -119,13 +116,8 @@ mod test { #[test] fn test_range_display() { - let range = Range { - start: 100, - end: 501, - }; - + let range = Range::new(100, 501); let txt = format!("{range}"); - assert_eq!(txt, "bytes=100-500"); } } diff --git a/sdk/storage_blobs/examples/partial_range.rs b/sdk/storage_blobs/examples/partial_range.rs new file mode 100644 index 0000000000..1be6056294 --- /dev/null +++ b/sdk/storage_blobs/examples/partial_range.rs @@ -0,0 +1,45 @@ +use azure_storage::prelude::*; +use azure_storage_blobs::prelude::*; +use futures::stream::StreamExt; +use uuid::Uuid; + +#[tokio::main] +async fn main() -> azure_core::Result<()> { + env_logger::init(); + + // First we retrieve the account name and access key from environment variables. + let account = + std::env::var("STORAGE_ACCOUNT").expect("Set env variable STORAGE_ACCOUNT first!"); + let access_key = + std::env::var("STORAGE_ACCESS_KEY").expect("Set env variable STORAGE_ACCESS_KEY first!"); + + let container_name = format!("range-example-{}", Uuid::new_v4()); + let blob_name = format!("blob-{}.txt", Uuid::new_v4()); + + let storage_credentials = StorageCredentials::access_key(account.clone(), access_key); + let container_client = + BlobServiceClient::new(account, storage_credentials).container_client(container_name); + container_client.create().await?; + + let blob_client = container_client.blob_client(&blob_name); + + let buf = "0123456789".repeat(100); + + blob_client.put_block_blob(buf.clone()).await?; + + let range = 3usize..; + let mut stream = blob_client.get().range(range.clone()).into_stream(); + + let mut data: Vec = vec![]; + while let Some(value) = stream.next().await { + let value = value?.data.collect().await?; + println!("{}", value.len()); + data.extend(&value); + } + let value = String::from_utf8(data)?; + assert_eq!(&buf[range.clone()], value); + + container_client.delete().await?; + + Ok(()) +} diff --git a/sdk/storage_blobs/src/blob/operations/get_blob.rs b/sdk/storage_blobs/src/blob/operations/get_blob.rs index 15801d5875..7746cddee9 100644 --- a/sdk/storage_blobs/src/blob/operations/get_blob.rs +++ b/sdk/storage_blobs/src/blob/operations/get_blob.rs @@ -31,9 +31,10 @@ impl GetBlobBuilder { let range = match continuation { Some(range) => range, - None => { - initial_range(this.chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE), this.range) - } + None => initial_range( + this.chunk_size.unwrap_or(DEFAULT_CHUNK_SIZE), + this.range.clone(), + ), }; this.blob_versioning.append_to_url_query(&mut url); @@ -105,17 +106,18 @@ impl GetBlobResponse { impl Continuable for GetBlobResponse { type Continuation = Range; fn continuation(&self) -> Option { - self.remaining_range + self.remaining_range.clone() } } // calculate the first Range for use at the beginning of the Pageable. fn initial_range(chunk_size: u64, request_range: Option) -> Range { match request_range { - Some(range) => { - let len = std::cmp::min(range.len(), chunk_size); - Range::new(range.start, range.start + len) + Some(Range::Range(x)) => { + let len = std::cmp::min(x.end - x.start, chunk_size); + (x.start..x.start + len).into() } + Some(Range::RangeFrom(x)) => (x.start..x.start + chunk_size).into(), None => Range::new(0, chunk_size), } } @@ -152,19 +154,22 @@ fn remaining_range( // if the response said the end of the blob was downloaded, we're done // Note, we add + 1, as we don't need to re-fetch the last // byte of the previous request. - if content_range.end() + 1 >= requested_range.end { - return None; - } + let after = content_range.end() + 1; - // if the user specified range is smaller than the blob, truncate the - // requested range. Note, we add + 1, as we don't need to re-fetch the last - // byte of the previous request. - let start = content_range.end() + 1; - let remaining_size = requested_range.end - start; + let remaining_size = match requested_range { + Range::Range(x) => { + if after >= x.end { + return None; + } + x.end - after + } + // no requested end + Range::RangeFrom(_) => after, + }; let size = std::cmp::min(remaining_size, chunk_size); - Some(Range::new(start, start + size)) + Some(Range::new(after, after + size)) } #[cfg(test)] diff --git a/sdk/storage_blobs/src/blob/page_range_list.rs b/sdk/storage_blobs/src/blob/page_range_list.rs index 1d0a1a85ee..9653b63860 100644 --- a/sdk/storage_blobs/src/blob/page_range_list.rs +++ b/sdk/storage_blobs/src/blob/page_range_list.rs @@ -66,10 +66,8 @@ mod test { let prl = PageRangeList::try_from_xml(page_list).unwrap(); assert!(prl.ranges.len() == 2); - assert!(prl.ranges[0].start == 0); - assert!(prl.ranges[0].end == 511); - assert!(prl.ranges[1].start == 1024); - assert!(prl.ranges[1].end == 1535); + assert!(prl.ranges[0] == Range::new(0, 511)); + assert!(prl.ranges[1] == Range::new(1024, 1535)); let page_list = ""; let prl = PageRangeList::try_from_xml(page_list).unwrap(); diff --git a/sdk/storage_blobs/src/options/ba512_range.rs b/sdk/storage_blobs/src/options/ba512_range.rs index 6a8f3e15df..044a81cdd3 100644 --- a/sdk/storage_blobs/src/options/ba512_range.rs +++ b/sdk/storage_blobs/src/options/ba512_range.rs @@ -45,10 +45,7 @@ impl BA512Range { impl From for Range { fn from(range: BA512Range) -> Self { - Self { - start: range.start(), - end: range.end(), - } + (range.start()..range.end()).into() } } @@ -56,7 +53,12 @@ impl TryFrom for BA512Range { type Error = Error; fn try_from(r: Range) -> azure_core::Result { - BA512Range::new(r.start, r.end) + match r { + Range::Range(r) => BA512Range::new(r.start, r.end), + Range::RangeFrom(r) => Err(Error::with_message(ErrorKind::DataConversion, || { + format!("error converting RangeFrom<{:?}> into BA512Range", r) + })), + } } } @@ -109,10 +111,7 @@ impl fmt::Display for BA512Range { impl<'a> From<&'a BA512Range> for Range { fn from(ba: &'a BA512Range) -> Range { - Range { - start: ba.start(), - end: ba.end(), - } + (ba.start()..ba.end()).into() } }