Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -209,6 +209,16 @@ def __init__(
self._download_ranges_id_to_pending_read_ids = {}
self.persisted_size: Optional[int] = None # updated after opening the stream

async def __aenter__(self):
"""Opens the underlying bidi-gRPC connection to read from the object."""
await self.open()
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Closes the underlying bidi-gRPC connection."""
if self.is_stream_open:
await self.close()

def _on_open_error(self, exc):
"""Extracts routing token and read handle on redirect error during open."""
routing_token, read_handle = _handle_redirect(exc)
Expand Down
14 changes: 7 additions & 7 deletions tests/system/test_zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,19 +116,19 @@ async def _run():
assert object_metadata.size == object_size
assert int(object_metadata.checksums.crc32c) == object_checksum

mrd = AsyncMultiRangeDownloader(grpc_client, _ZONAL_BUCKET, object_name)
buffer = BytesIO()
await mrd.open()
# (0, 0) means read the whole object
await mrd.download_ranges([(0, 0, buffer)])
await mrd.close()
async with AsyncMultiRangeDownloader(
grpc_client, _ZONAL_BUCKET, object_name
) as mrd:
# (0, 0) means read the whole object
await mrd.download_ranges([(0, 0, buffer)])
assert mrd.persisted_size == object_size

assert buffer.getvalue() == object_data
assert mrd.persisted_size == object_size

# Clean up; use json client (i.e. `storage_client` fixture) to delete.
blobs_to_delete.append(storage_client.bucket(_ZONAL_BUCKET).blob(object_name))
del writer
del mrd
gc.collect()

event_loop.run_until_complete(_run())
Expand Down
41 changes: 41 additions & 0 deletions tests/unit/asyncio/test_async_multi_range_downloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,3 +401,44 @@ async def test_download_ranges_raises_on_checksum_mismatch(

assert "Checksum mismatch" in str(exc_info.value)
mock_checksum_class.assert_called_once_with(test_data)

@mock.patch(
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.open",
new_callable=AsyncMock,
)
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_multi_range_downloader.AsyncMultiRangeDownloader.close",
new_callable=AsyncMock,
)
@mock.patch(
"google.cloud.storage._experimental.asyncio.async_grpc_client.AsyncGrpcClient.grpc_client"
)
@pytest.mark.asyncio
async def test_async_context_manager_calls_open_and_close(
self, mock_grpc_client, mock_close, mock_open
):
# Arrange
mrd = AsyncMultiRangeDownloader(
mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
)

# To simulate the behavior of open and close changing the stream state
async def open_side_effect():
mrd._is_stream_open = True

async def close_side_effect():
mrd._is_stream_open = False

mock_open.side_effect = open_side_effect
mock_close.side_effect = close_side_effect
mrd._is_stream_open = False

# Act
async with mrd as downloader:
# Assert
mock_open.assert_called_once()
assert downloader == mrd
assert mrd.is_stream_open

mock_close.assert_called_once()
assert not mrd.is_stream_open
Comment on lines +417 to +444
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This is a good test for the happy path of the async context manager. To make it more robust, it would be beneficial to also test the behavior in edge cases.

Specifically, consider adding tests for:

  1. When open() fails (e.g., raises an exception). In this scenario, close() should not be called.
  2. When an exception is raised from within the async with block. In this case, close() should still be called to ensure cleanup.

Here are some examples of how you could structure these tests:

Test for open() failure:

@mock.patch("...")
@pytest.mark.asyncio
async def test_context_manager_no_close_on_open_failure(self, mock_grpc_client, mock_close, mock_open):
    mock_open.side_effect = ValueError("Failed to open")
    mrd = AsyncMultiRangeDownloader(
        mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
    )
    
    with pytest.raises(ValueError, match="Failed to open"):
        async with mrd:
            pytest.fail("This block should not be executed.")

    mock_open.assert_called_once()
    mock_close.assert_not_called()

Test for exception within the block:

@mock.patch("...")
@pytest.mark.asyncio
async def test_context_manager_closes_on_exception(self, mock_grpc_client, mock_close, mock_open):
    mrd = AsyncMultiRangeDownloader(
        mock_grpc_client, _TEST_BUCKET_NAME, _TEST_OBJECT_NAME
    )
    # set up side effects for open/close as in the existing test
    async def open_side_effect():
        mrd._is_stream_open = True
    mock_open.side_effect = open_side_effect
    
    with pytest.raises(RuntimeError, match="Oops"):
        async with mrd:
            raise RuntimeError("Oops")

    mock_open.assert_called_once()
    mock_close.assert_called_once()