diff --git a/tests/perf/microbenchmarks/reads/test_reads.py b/tests/perf/microbenchmarks/reads/test_reads.py index 13a0a49b3..324938e94 100644 --- a/tests/perf/microbenchmarks/reads/test_reads.py +++ b/tests/perf/microbenchmarks/reads/test_reads.py @@ -300,44 +300,46 @@ def target_wrapper(*args, **kwargs): ) -def _download_files_worker(files_to_download, other_params, chunks, bucket_type): - # For regional buckets, a new client must be created for each process. - # For zonal, the same is done for consistency. +# --- Global Variables for Worker Process --- +worker_loop = None +worker_client = None +worker_json_client = None + + +def _worker_init(bucket_type): + """Initializes a persistent event loop and client for each worker process.""" + global worker_loop, worker_client, worker_json_client if bucket_type == "zonal": - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - client = loop.run_until_complete(create_client()) - try: - # download_files_using_mrd_multi_coro returns max latency of coros - result = download_files_using_mrd_multi_coro( - loop, client, files_to_download, other_params, chunks - ) - finally: - tasks = asyncio.all_tasks(loop=loop) - for task in tasks: - task.cancel() - loop.run_until_complete(asyncio.gather(*tasks, return_exceptions=True)) - loop.close() - return result + worker_loop = asyncio.new_event_loop() + asyncio.set_event_loop(worker_loop) + worker_client = worker_loop.run_until_complete(create_client()) else: # regional from google.cloud import storage - json_client = storage.Client() + worker_json_client = storage.Client() + + +def _download_files_worker(files_to_download, other_params, chunks, bucket_type): + if bucket_type == "zonal": + # The loop and client are already initialized in _worker_init. + # download_files_using_mrd_multi_coro returns max latency of coros + return download_files_using_mrd_multi_coro( + worker_loop, worker_client, files_to_download, other_params, chunks + ) + else: # regional # download_files_using_json_multi_threaded returns max latency of threads return download_files_using_json_multi_threaded( - None, json_client, files_to_download, other_params, chunks + None, worker_json_client, files_to_download, other_params, chunks ) -def download_files_mp_mc_wrapper(files_names, params, chunks, bucket_type): - num_processes = params.num_processes +def download_files_mp_mc_wrapper(pool, files_names, params, chunks, bucket_type): num_coros = params.num_coros # This is n, number of files per process # Distribute filenames to processes filenames_per_process = [ files_names[i : i + num_coros] for i in range(0, len(files_names), num_coros) ] - args = [ ( filenames, @@ -348,10 +350,7 @@ def download_files_mp_mc_wrapper(files_names, params, chunks, bucket_type): for filenames in filenames_per_process ] - ctx = multiprocessing.get_context("spawn") - with ctx.Pool(processes=num_processes) as pool: - results = pool.starmap(_download_files_worker, args) - + results = pool.starmap(_download_files_worker, args) return max(results) @@ -386,10 +385,16 @@ def test_downloads_multi_proc_multi_coro( logging.info("randomizing chunks") random.shuffle(chunks) + ctx = multiprocessing.get_context("spawn") + pool = ctx.Pool( + processes=params.num_processes, + initializer=_worker_init, + initargs=(params.bucket_type,), + ) output_times = [] def target_wrapper(*args, **kwargs): - result = download_files_mp_mc_wrapper(*args, **kwargs) + result = download_files_mp_mc_wrapper(pool, *args, **kwargs) output_times.append(result) return output_times @@ -407,6 +412,8 @@ def target_wrapper(*args, **kwargs): ), ) finally: + pool.close() + pool.join() publish_benchmark_extra_info(benchmark, params, true_times=output_times) publish_resource_metrics(benchmark, m) diff --git a/tests/system/test_zonal.py b/tests/system/test_zonal.py index 8019156dd..eb9df582c 100644 --- a/tests/system/test_zonal.py +++ b/tests/system/test_zonal.py @@ -264,7 +264,7 @@ async def _run(): event_loop.run_until_complete(_run()) - +@pytest.mark.skip(reason='Flaky test b/478129078') def test_mrd_open_with_read_handle(event_loop, grpc_client_direct): object_name = f"test_read_handl-{str(uuid.uuid4())[:4]}"