Skip to content

Commit aba4c6b

Browse files
authored
feat(zb-experimental): implement close in AsyncMultiRangeDownloader (#1555)
feat(zb-experimental): implement `close` in AsyncMultiRangeDownloader
1 parent 9c8856a commit aba4c6b

File tree

2 files changed

+149
-16
lines changed

2 files changed

+149
-16
lines changed

google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py

Lines changed: 50 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -71,16 +71,23 @@ class AsyncMultiRangeDownloader:
7171
client, bucket_name="chandrasiri-rs", object_name="test_open9"
7272
)
7373
my_buff1 = open('my_fav_file.txt', 'wb')
74+
my_buff1 = open('my_fav_file.txt', 'wb')
7475
my_buff2 = BytesIO()
7576
my_buff3 = BytesIO()
7677
my_buff4 = any_object_which_provides_BytesIO_like_interface()
78+
results_arr, error_obj = await mrd.download_ranges(
79+
my_buff4 = any_object_which_provides_BytesIO_like_interface()
7780
results_arr, error_obj = await mrd.download_ranges(
7881
[
82+
# (start_byte, bytes_to_read, writeable_buffer)
7983
# (start_byte, bytes_to_read, writeable_buffer)
8084
(0, 100, my_buff1),
8185
(100, 20, my_buff2),
8286
(200, 123, my_buff3),
8387
(300, 789, my_buff4),
88+
(100, 20, my_buff2),
89+
(200, 123, my_buff3),
90+
(300, 789, my_buff4),
8491
]
8592
)
8693
if error_obj:
@@ -94,6 +101,17 @@ class AsyncMultiRangeDownloader:
94101
for result in results_arr:
95102
print("downloaded bytes", result)
96103
104+
if error_obj:
105+
print("Error occurred: ")
106+
print(error_obj)
107+
print(
108+
"please issue call to `download_ranges` with updated"
109+
"`read_ranges` based on diff of (bytes_requested - bytes_written)"
110+
)
111+
112+
for result in results_arr:
113+
print("downloaded bytes", result)
114+
97115
98116
"""
99117

@@ -165,7 +183,8 @@ def __init__(
165183
self.object_name = object_name
166184
self.generation_number = generation_number
167185
self.read_handle = read_handle
168-
self.read_obj_str: _AsyncReadObjectStream = None
186+
self.read_obj_str: Optional[_AsyncReadObjectStream] = None
187+
self._is_stream_open: bool = False
169188

170189
async def open(self) -> None:
171190
"""Opens the bidi-gRPC connection to read from the object.
@@ -176,14 +195,19 @@ async def open(self) -> None:
176195
"Opening" constitutes fetching object metadata such as generation number
177196
and read handle and sets them as attributes if not already set.
178197
"""
179-
self.read_obj_str = _AsyncReadObjectStream(
180-
client=self.client,
181-
bucket_name=self.bucket_name,
182-
object_name=self.object_name,
183-
generation_number=self.generation_number,
184-
read_handle=self.read_handle,
185-
)
198+
if self._is_stream_open:
199+
raise ValueError("Underlying bidi-gRPC stream is already open")
200+
201+
if self.read_obj_str is None:
202+
self.read_obj_str = _AsyncReadObjectStream(
203+
client=self.client,
204+
bucket_name=self.bucket_name,
205+
object_name=self.object_name,
206+
generation_number=self.generation_number,
207+
read_handle=self.read_handle,
208+
)
186209
await self.read_obj_str.open()
210+
self._is_stream_open = True
187211
if self.generation_number is None:
188212
self.generation_number = self.read_obj_str.generation_number
189213
self.read_handle = self.read_obj_str.read_handle
@@ -206,11 +230,15 @@ async def download_ranges(
206230
to a requested range.
207231
208232
"""
233+
209234
if len(read_ranges) > 1000:
210235
raise ValueError(
211236
"Invalid input - length of read_ranges cannot be more than 1000"
212237
)
213238

239+
if not self._is_stream_open:
240+
raise ValueError("Underlying bidi-gRPC stream is not open")
241+
214242
read_id_to_writable_buffer_dict = {}
215243
results = []
216244
for i in range(0, len(read_ranges), _MAX_READ_RANGES_PER_BIDI_READ_REQUEST):
@@ -255,4 +283,18 @@ async def download_ranges(
255283
del read_id_to_writable_buffer_dict[
256284
object_data_range.read_range.read_id
257285
]
286+
258287
return results
288+
289+
async def close(self):
290+
"""
291+
Closes the underlying bidi-gRPC connection.
292+
"""
293+
if not self._is_stream_open:
294+
raise ValueError("Underlying bidi-gRPC stream is not open")
295+
await self.read_obj_str.close()
296+
self._is_stream_open = False
297+
298+
@property
299+
def is_stream_open(self) -> bool:
300+
return self._is_stream_open

tests/unit/asyncio/test_async_multi_range_downloader.py

Lines changed: 99 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,14 @@
3030

3131

3232
class TestAsyncMultiRangeDownloader:
33+
def create_read_ranges(self, num_ranges):
34+
ranges = []
35+
for i in range(num_ranges):
36+
ranges.append(
37+
_storage_v2.ReadRange(read_offset=i, read_length=1, read_id=i)
38+
)
39+
return ranges
40+
3341
# helper method
3442
@pytest.mark.asyncio
3543
async def _make_mock_mrd(
@@ -76,13 +84,24 @@ async def test_create_mrd(
7684
read_handle=_TEST_READ_HANDLE,
7785
)
7886

87+
mrd.read_obj_str.open.assert_called_once()
88+
# Assert
89+
mock_cls_async_read_object_stream.assert_called_once_with(
90+
client=mock_grpc_client,
91+
bucket_name=_TEST_BUCKET_NAME,
92+
object_name=_TEST_OBJECT_NAME,
93+
generation_number=_TEST_GENERATION_NUMBER,
94+
read_handle=_TEST_READ_HANDLE,
95+
)
96+
7997
mrd.read_obj_str.open.assert_called_once()
8098

8199
assert mrd.client == mock_grpc_client
82100
assert mrd.bucket_name == _TEST_BUCKET_NAME
83101
assert mrd.object_name == _TEST_OBJECT_NAME
84102
assert mrd.generation_number == _TEST_GENERATION_NUMBER
85103
assert mrd.read_handle == _TEST_READ_HANDLE
104+
assert mrd.is_stream_open
86105

87106
@mock.patch(
88107
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
@@ -131,14 +150,6 @@ async def test_download_ranges(
131150
assert results[0].bytes_written == 18
132151
assert buffer.getvalue() == b"these_are_18_chars"
133152

134-
def create_read_ranges(self, num_ranges):
135-
ranges = []
136-
for i in range(num_ranges):
137-
ranges.append(
138-
_storage_v2.ReadRange(read_offset=i, read_length=1, read_id=i)
139-
)
140-
return ranges
141-
142153
@mock.patch(
143154
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
144155
)
@@ -160,3 +171,83 @@ async def test_downloading_ranges_with_more_than_1000_should_throw_error(
160171
str(exc.value)
161172
== "Invalid input - length of read_ranges cannot be more than 1000"
162173
)
174+
175+
@mock.patch(
176+
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
177+
)
178+
@mock.patch(
179+
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
180+
)
181+
@pytest.mark.asyncio
182+
async def test_opening_mrd_more_than_once_should_throw_error(
183+
self, mock_grpc_client, mock_cls_async_read_object_stream
184+
):
185+
# Arrange
186+
mrd = await self._make_mock_mrd(
187+
mock_grpc_client, mock_cls_async_read_object_stream
188+
) # mock mrd is already opened
189+
190+
# Act + Assert
191+
with pytest.raises(ValueError) as exc:
192+
await mrd.open()
193+
194+
# Assert
195+
assert str(exc.value) == "Underlying bidi-gRPC stream is already open"
196+
197+
@mock.patch(
198+
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader._AsyncReadObjectStream"
199+
)
200+
@mock.patch(
201+
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
202+
)
203+
@pytest.mark.asyncio
204+
async def test_close_mrd(self, mock_grpc_client, mock_cls_async_read_object_stream):
205+
# Arrange
206+
mrd = await self._make_mock_mrd(
207+
mock_grpc_client, mock_cls_async_read_object_stream
208+
) # mock mrd is already opened
209+
mrd.read_obj_str.close = AsyncMock()
210+
211+
# Act
212+
await mrd.close()
213+
214+
# Assert
215+
assert not mrd.is_stream_open
216+
217+
@mock.patch(
218+
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
219+
)
220+
@pytest.mark.asyncio
221+
async def test_close_mrd_not_opened_should_throw_error(self, mock_grpc_client):
222+
# Arrange
223+
mrd = AsyncMultiRangeDownloader(
224+
mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
225+
)
226+
227+
# Act + Assert
228+
with pytest.raises(ValueError) as exc:
229+
await mrd.close()
230+
231+
# Assert
232+
assert str(exc.value) == "Underlying bidi-gRPC stream is not open"
233+
assert not mrd.is_stream_open
234+
235+
@mock.patch(
236+
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
237+
)
238+
@pytest.mark.asyncio
239+
async def test_downloading_without_opening_should_throw_error(
240+
self, mock_grpc_client
241+
):
242+
# Arrange
243+
mrd = AsyncMultiRangeDownloader(
244+
mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
245+
)
246+
247+
# Act + Assert
248+
with pytest.raises(ValueError) as exc:
249+
await mrd.download_ranges([(0, 18, BytesIO())])
250+
251+
# Assert
252+
assert str(exc.value) == "Underlying bidi-gRPC stream is not open"
253+
assert not mrd.is_stream_open

0 commit comments

Comments
 (0)