diff --git a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py index 1d1c63efd..5981ff7fc 100644 --- a/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py +++ b/google/cloud/storage/_experimental/asyncio/async_multi_range_downloader.py @@ -209,6 +209,16 @@ def __init__( self._download_ranges_id_to_pending_read_ids = {} self.persisted_size: Optional[int] = None # updated after opening the stream + async def __aenter__(self): + """Opens the underlying bidi-gRPC connection to read from the object.""" + await self.open() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """Closes the underlying bidi-gRPC connection.""" + if self.is_stream_open: + await self.close() + def _on_open_error(self, exc): """Extracts routing token and read handle on redirect error during open.""" routing_token, read_handle = _handle_redirect(exc) diff --git a/tests/system/test_zonal.py b/tests/system/test_zonal.py index 40c84a2fb..8019156dd 100644 --- a/tests/system/test_zonal.py +++ b/tests/system/test_zonal.py @@ -116,19 +116,19 @@ async def _run(): assert object_metadata.size == object_size assert int(object_metadata.checksums.crc32c) == object_checksum - mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name) buffer = BytesIO() - await mrd.open() - # (0, 0) means read the whole object - await mrd.download_ranges([(0, 0, buffer)]) - await mrd.close() + async with AsyncMultiRangeDownloader( + grpc_client, _ZONAL_BUCKET, object_name + ) as mrd: + # (0, 0) means read the whole object + await mrd.download_ranges([(0, 0, buffer)]) + assert mrd.persisted_size == object_size + assert buffer.getvalue() == object_data - assert mrd.persisted_size == object_size # Clean up; use json client (i.e. `storage_client` fixture) to delete. blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name)) del writer - del mrd gc.collect() event_loop.run_until_complete(_run()) diff --git a/tests/unit/asyncio/test_async_multi_range_downloader.py b/tests/unit/asyncio/test_async_multi_range_downloader.py index 2f0600f8d..5a8d6c6f1 100644 --- a/tests/unit/asyncio/test_async_multi_range_downloader.py +++ b/tests/unit/asyncio/test_async_multi_range_downloader.py @@ -401,3 +401,44 @@ async def test_download_ranges_raises_on_checksum_mismatch( assert "Checksum mismatch" in str(exc_info.value) mock_checksum_class.assert_called_once_with(test_data) + + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.open", + new_callable=AsyncMock, + ) + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.close", + new_callable=AsyncMock, + ) + @mock.patch( + "google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client" + ) + @pytest.mark.asyncio + async def test_async_context_manager_calls_open_and_close( + self, mock_grpc_client, mock_close, mock_open + ): + # Arrange + mrd = AsyncMultiRangeDownloader( + mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME + ) + + # To simulate the behavior of open and close changing the stream state + async def open_side_effect(): + mrd._is_stream_open = True + + async def close_side_effect(): + mrd._is_stream_open = False + + mock_open.side_effect = open_side_effect + mock_close.side_effect = close_side_effect + mrd._is_stream_open = False + + # Act + async with mrd as downloader: + # Assert + mock_open.assert_called_once() + assert downloader == mrd + assert mrd.is_stream_open + + mock_close.assert_called_once() + assert not mrd.is_stream_open