Merged
63 changes: 35 additions & 28 deletions tests/perf/microbenchmarks/reads/test_reads.py
@@ -300,44 +300,46 @@ def target_wrapper(*args, **kwargs):
     )
 
 
-def _download_files_worker(files_to_download, other_params, chunks, bucket_type):
-    # For regional buckets, a new client must be created for each process.
-    # For zonal, the same is done for consistency.
+# --- Global Variables for Worker Process ---
+worker_loop = None
+worker_client = None
+worker_json_client = None
+
+
+def _worker_init(bucket_type):
+    """Initializes a persistent event loop and client for each worker process."""
+    global worker_loop, worker_client, worker_json_client
     if bucket_type == "zonal":
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        client = loop.run_until_complete(create_client())
-        try:
-            # download_files_using_mrd_multi_coro returns max latency of coros
-            result = download_files_using_mrd_multi_coro(
-                loop, client, files_to_download, other_params, chunks
-            )
-        finally:
-            tasks = asyncio.all_tasks(loop=loop)
-            for task in tasks:
-                task.cancel()
-            loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True))
-            loop.close()
-        return result
+        worker_loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(worker_loop)
+        worker_client = worker_loop.run_until_complete(create_client())
     else:  # regional
         from google.cloud import storage
 
-        json_client = storage.Client()
+        worker_json_client = storage.Client()
Comment on lines +309 to +319
Contributor (severity: medium):

The resources created in this initializer (worker_loop, worker_client, worker_json_client) are not being cleaned up. This can lead to resource leaks, such as open network connections.

While terminating worker processes will cause the OS to reclaim these resources, it's better to perform a graceful shutdown. You can use the atexit module to register cleanup functions that will be called when the worker process exits.

I suggest updating _worker_init to register cleanup functions. Please also add import atexit at the top of the file.

def _worker_init(bucket_type):
    """Initializes a persistent event loop and client for each worker process."""
    global worker_loop, worker_client, worker_json_client
    import atexit

    if bucket_type == "zonal":
        worker_loop = asyncio.new_event_loop()
        asyncio.set_event_loop(worker_loop)
        worker_client = worker_loop.run_until_complete(create_client())

        def _cleanup_zonal():
            # Ensure resources are cleaned up when the worker process exits.
            if worker_client and worker_loop and not worker_loop.is_closed():
                try:
                    worker_loop.run_until_complete(worker_client.close())
                finally:
                    worker_loop.close()

        atexit.register(_cleanup_zonal)
    else:  # regional
        from google.cloud import storage

        worker_json_client = storage.Client()

        def _cleanup_regional():
            # Ensure resources are cleaned up when the worker process exits.
            if worker_json_client:
                worker_json_client.close()

        atexit.register(_cleanup_regional)

Collaborator (Author):

json_client doesn't have .close(), and the same is the case with grpc_client.

Collaborator (Author):

My bad, json_client has .close(), and the gRPC GAPIC client has one as well; we need to expose it. Filed a bug: b/479135274.
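
For reference, a minimal sketch of what exposing close() on the wrapper might look like; the AsyncStorageClient name, the _gapic_client attribute, and the transport.close() delegation are illustrative assumptions, not this repo's actual API:

    class AsyncStorageClient:
        """Hypothetical wrapper around a GAPIC-generated async storage client."""

        def __init__(self, gapic_client):
            self._gapic_client = gapic_client  # assumed attribute name

        async def close(self):
            # Delegate to the underlying GAPIC transport so gRPC channels
            # shut down gracefully instead of being reclaimed by the OS.
            await self._gapic_client.transport.close()

With that exposed, the zonal cleanup in the atexit suggestion above (worker_loop.run_until_complete(worker_client.close())) would work for the gRPC path as well.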



+def _download_files_worker(files_to_download, other_params, chunks, bucket_type):
+    if bucket_type == "zonal":
+        # The loop and client are already initialized in _worker_init.
+        # download_files_using_mrd_multi_coro returns max latency of coros
+        return download_files_using_mrd_multi_coro(
+            worker_loop, worker_client, files_to_download, other_params, chunks
+        )
     else:  # regional
         # download_files_using_json_multi_threaded returns max latency of threads
         return download_files_using_json_multi_threaded(
-            None, json_client, files_to_download, other_params, chunks
+            None, worker_json_client, files_to_download, other_params, chunks
         )


-def download_files_mp_mc_wrapper(files_names, params, chunks, bucket_type):
-    num_processes = params.num_processes
+def download_files_mp_mc_wrapper(pool, files_names, params, chunks, bucket_type):
     num_coros = params.num_coros  # This is n, number of files per process
 
     # Distribute filenames to processes
     filenames_per_process = [
         files_names[i : i + num_coros] for i in range(0, len(files_names), num_coros)
     ]
 
     args = [
         (
             filenames,
@@ -348,10 +348,7 @@ def download_files_mp_mc_wrapper(files_names, params, chunks, bucket_type):
         for filenames in filenames_per_process
     ]
 
-    ctx = multiprocessing.get_context("spawn")
-    with ctx.Pool(processes=num_processes) as pool:
-        results = pool.starmap(_download_files_worker, args)
-
+    results = pool.starmap(_download_files_worker, args)
     return max(results)


@@ -386,10 +385,16 @@ def test_downloads_multi_proc_multi_coro(
     logging.info("randomizing chunks")
     random.shuffle(chunks)
 
+    ctx = multiprocessing.get_context("spawn")
+    pool = ctx.Pool(
+        processes=params.num_processes,
+        initializer=_worker_init,
+        initargs=(params.bucket_type,),
+    )
     output_times = []
 
     def target_wrapper(*args, **kwargs):
-        result = download_files_mp_mc_wrapper(*args, **kwargs)
+        result = download_files_mp_mc_wrapper(pool, *args, **kwargs)
         output_times.append(result)
         return output_times

@@ -407,6 +412,8 @@ def target_wrapper(*args, **kwargs):
             ),
         )
     finally:
+        pool.close()
+        pool.join()
         publish_benchmark_extra_info(benchmark, params, true_times=output_times)
         publish_resource_metrics(benchmark, m)

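For context, this change adopts the standard Pool initializer pattern: expensive per-process state (the event loop and client) is built once per worker and reused across tasks instead of being rebuilt on every call. A minimal standalone sketch of the pattern, with illustrative names not taken from the repo:

    import multiprocessing

    worker_state = None  # one copy per worker process


    def _init(tag):
        # Runs once in each worker process when the pool starts.
        global worker_state
        worker_state = f"client-for-{tag}"


    def _task(i):
        # Reuses per-process state instead of rebuilding it per task.
        return f"{worker_state}:{i}"


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        with ctx.Pool(processes=2, initializer=_init, initargs=("zonal",)) as pool:
            print(pool.map(_task, range(4)))

Because the pool now outlives a single call to download_files_mp_mc_wrapper, the test owns its lifecycle and must close and join it in the finally block, as the diff does.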
2 changes: 1 addition & 1 deletion tests/system/test_zonal.py
@@ -264,7 +264,7 @@ async def _run():

     event_loop.run_until_complete(_run())
 
-
+@pytest.mark.skip(reason='Flaky test b/478129078')
 def test_mrd_open_with_read_handle(event_loop, grpc_client_direct):
     object_name = f"test_read_handl-{str(uuid.uuid4())[:4]}"
