diff --git a/src/executorlib/standalone/interactive/communication.py b/src/executorlib/standalone/interactive/communication.py index 68c6379a..9f2c132d 100644 --- a/src/executorlib/standalone/interactive/communication.py +++ b/src/executorlib/standalone/interactive/communication.py @@ -1,6 +1,7 @@ import logging import sys from socket import gethostname +from time import time from typing import Any, Callable, Optional import cloudpickle @@ -67,20 +68,27 @@ def send_dict(self, input_dict: dict): self._logger.warning("Send dictionary of size: " + str(sys.getsizeof(data))) self._socket.send(data) - def receive_dict(self) -> dict: + def receive_dict(self, timeout: Optional[int] = None) -> dict: """ Receive a dictionary from a connected client process. + Args: + timeout (int, optional): Time out for waiting for a message on socket in seconds. If None is provided, the + default time out set during initialization is used. + Returns: dict: dictionary with response received from the connected client """ response_lst: list[tuple[Any, int]] = [] + time_start = time() while len(response_lst) == 0: response_lst = self._poller.poll(self._time_out_ms) if not self._spawner.poll(): raise ExecutorlibSocketError( "SocketInterface crashed during execution." ) + if timeout is not None and (time() - time_start) > timeout: + raise TimeoutError("SocketInterface reached timeout.") data = self._socket.recv(zmq.NOBLOCK) if self._logger is not None: self._logger.warning( @@ -92,19 +100,23 @@ def receive_dict(self) -> dict: else: raise output["error"] - def send_and_receive_dict(self, input_dict: dict) -> dict: + def send_and_receive_dict( + self, input_dict: dict, timeout: Optional[int] = None + ) -> dict: """ Combine both the send_dict() and receive_dict() function in a single call. Args: input_dict (dict): dictionary of commands to be communicated. The key "shutdown" is reserved to stop the connected client from listening. + timeout (int, optional): Time out for waiting for a message on socket in seconds. If None is provided, the + default time out set during initialization is used. Returns: dict: dictionary with response received from the connected client """ self.send_dict(input_dict=input_dict) - return self.receive_dict() + return self.receive_dict(timeout=timeout) def bind_to_random_port(self) -> int: """ diff --git a/src/executorlib/task_scheduler/file/shared.py b/src/executorlib/task_scheduler/file/shared.py index 177e28cd..019c3a59 100644 --- a/src/executorlib/task_scheduler/file/shared.py +++ b/src/executorlib/task_scheduler/file/shared.py @@ -2,6 +2,7 @@ import os import queue from concurrent.futures import Future +from time import time from typing import Any, Callable, Optional from executorlib.standalone.command import get_cache_execute_command @@ -81,6 +82,7 @@ def execute_tasks_h5( process_dict: dict = {} cache_dir_dict: dict = {} file_name_dict: dict = {} + timeout_dict: dict = {} while True: task_dict = None with contextlib.suppress(queue.Empty): @@ -97,19 +99,22 @@ def execute_tasks_h5( for key, value in memory_dict.items() if not value.done() } + _check_timeout(timeout_dict=timeout_dict, memory_dict=memory_dict) if ( terminate_function is not None and terminate_function == terminate_subprocess ): - for task in process_dict.values(): - terminate_function(task=task) + for task_key, task in process_dict.items(): + if task_key not in timeout_dict: + terminate_function(task=task) elif terminate_function is not None: - for queue_id in process_dict.values(): - terminate_function( - queue_id=queue_id, - config_directory=pysqa_config_directory, - backend=backend, - ) + for task_key, queue_id in process_dict.items(): + if task_key not in timeout_dict: + terminate_function( + queue_id=queue_id, + config_directory=pysqa_config_directory, + backend=backend, + ) future_queue.task_done() future_queue.join() break @@ -123,9 +128,9 @@ def execute_tasks_h5( task_resource_dict.update( {k: v for k, v in resource_dict.items() if k not in task_resource_dict} ) - cache_key = task_resource_dict.pop("cache_key", None) - cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory")) - error_log_file = task_resource_dict.pop("error_log_file", None) + cache_key, cache_directory, error_log_file, timeout = ( + _get_resource_parameters(task_resource_dict=task_resource_dict) + ) task_key, data_dict = serialize_funct( fn=task_dict["fn"], fn_args=task_args, @@ -170,6 +175,8 @@ def execute_tasks_h5( backend=backend, cache_directory=cache_directory, ) + if timeout is not None: + timeout_dict[task_key] = time() + timeout file_name_dict[task_key] = os.path.join( cache_directory, task_key + "_o.h5" ) @@ -186,6 +193,7 @@ def execute_tasks_h5( for key, value in memory_dict.items() if not value.done() } + _check_timeout(timeout_dict=timeout_dict, memory_dict=memory_dict) def _check_task_output( @@ -259,3 +267,26 @@ def _convert_args_and_kwargs( else: task_kwargs[key] = arg return task_args, task_kwargs, future_wait_key_lst + + +def _check_timeout(timeout_dict: dict, memory_dict: dict) -> None: + if ( + len(timeout_dict) > 0 + and all(time() > timeout for timeout in timeout_dict.values()) + and all(key in timeout_dict for key in memory_dict) + ): + for key, future in memory_dict.items(): + if key in timeout_dict: + future.set_exception( + TimeoutError("Task execution exceeded the specified timeout.") + ) + + +def _get_resource_parameters( + task_resource_dict: dict, +) -> tuple[Optional[str], str, Optional[str], Optional[int]]: + cache_key = task_resource_dict.pop("cache_key", None) + cache_directory = os.path.abspath(task_resource_dict.pop("cache_directory")) + error_log_file = task_resource_dict.pop("error_log_file", None) + timeout = task_resource_dict.pop("timeout", None) + return cache_key, cache_directory, error_log_file, timeout diff --git a/src/executorlib/task_scheduler/interactive/blockallocation.py b/src/executorlib/task_scheduler/interactive/blockallocation.py index d093b77d..5c281616 100644 --- a/src/executorlib/task_scheduler/interactive/blockallocation.py +++ b/src/executorlib/task_scheduler/interactive/blockallocation.py @@ -72,6 +72,7 @@ def __init__( executor_kwargs["future_queue"] = self._future_queue executor_kwargs["spawner"] = spawner executor_kwargs["queue_join_on_shutdown"] = False + timeout = executor_kwargs.pop("timeout", None) self._process_kwargs = executor_kwargs self._max_workers = max_workers self_id = random.getrandbits(128) @@ -85,6 +86,7 @@ def __init__( | { "worker_id": worker_id, "stop_function": lambda: _interrupt_bootup_dict[self_id], + "timeout": timeout, }, ) for worker_id in range(self._max_workers) @@ -211,6 +213,7 @@ def _execute_multiple_tasks( worker_id: Optional[int] = None, stop_function: Optional[Callable] = None, restart_limit: int = 0, + timeout: Optional[int] = None, **kwargs, ) -> None: """ @@ -239,6 +242,8 @@ def _execute_multiple_tasks( distribution. stop_function (Callable): Function to stop the interface. restart_limit (int): The maximum number of restarting worker processes. + timeout (int, optional): Time out for waiting for a message on socket in seconds. If None is provided, the + default time out set during initialization is used. """ interface = interface_bootup( command_lst=get_interactive_execute_command( @@ -283,6 +288,8 @@ def _execute_multiple_tasks( f.set_exception(exception=interface_initialization_exception) else: # The interface failed during the execution + if timeout is not None: + task_dict["timeout"] = timeout interface.status = execute_task_dict( task_dict=task_dict, future_obj=f, diff --git a/src/executorlib/task_scheduler/interactive/onetoone.py b/src/executorlib/task_scheduler/interactive/onetoone.py index d303ea94..95549c6e 100644 --- a/src/executorlib/task_scheduler/interactive/onetoone.py +++ b/src/executorlib/task_scheduler/interactive/onetoone.py @@ -190,6 +190,8 @@ def _wrap_execute_task_in_separate_process( dictionary containing the future objects and the number of cores they require """ resource_dict = task_dict.pop("resource_dict").copy() + if "timeout" in resource_dict: + task_dict["timeout"] = resource_dict.pop("timeout") f = task_dict.pop("future") if "cores" not in resource_dict or ( resource_dict["cores"] == 1 and executor_kwargs["cores"] >= 1 diff --git a/src/executorlib/task_scheduler/interactive/shared.py b/src/executorlib/task_scheduler/interactive/shared.py index e4084222..7b7cd73b 100644 --- a/src/executorlib/task_scheduler/interactive/shared.py +++ b/src/executorlib/task_scheduler/interactive/shared.py @@ -102,7 +102,12 @@ def _execute_task_without_cache( bool: True if the task was submitted successfully, False otherwise. """ try: - future_obj.set_result(interface.send_and_receive_dict(input_dict=task_dict)) + future_obj.set_result( + interface.send_and_receive_dict( + input_dict=task_dict, + timeout=_get_timeout_from_task_dict(task_dict=task_dict), + ) + ) except Exception as thread_exception: if isinstance(thread_exception, ExecutorlibSocketError): return False @@ -143,7 +148,10 @@ def _execute_task_with_cache( if file_name not in get_cache_files(cache_directory=cache_directory): try: time_start = time.time() - result = interface.send_and_receive_dict(input_dict=task_dict) + result = interface.send_and_receive_dict( + input_dict=task_dict, + timeout=_get_timeout_from_task_dict(task_dict=task_dict), + ) data_dict["output"] = result data_dict["runtime"] = time.time() - time_start dump(file_name=file_name, data_dict=data_dict) @@ -157,3 +165,20 @@ def _execute_task_with_cache( _, _, result = get_output(file_name=file_name) future_obj.set_result(result) return True + + +def _get_timeout_from_task_dict(task_dict: dict) -> Optional[int]: + """ + Extract timeout value from the task_dict if present. + + Args: + task_dict (dict): task submitted to the executor as dictionary. This dictionary has the following keys + {"fn": Callable, "args": (), "kwargs": {}, "resource_dict": {}} + + Returns: + Optional[int]: timeout value if present in the resource_dict, None otherwise. + """ + if "timeout" in task_dict: + return task_dict.pop("timeout") + else: + return None diff --git a/tests/test_cache_fileexecutor_serial.py b/tests/test_cache_fileexecutor_serial.py index ce19e2fa..a92e82ca 100644 --- a/tests/test_cache_fileexecutor_serial.py +++ b/tests/test_cache_fileexecutor_serial.py @@ -31,6 +31,11 @@ def get_error(a): raise ValueError(a) +def reply(i): + sleep(1) + return i + + @unittest.skipIf( skip_h5py_test, "h5py is not installed, so the h5py tests are skipped." ) @@ -57,6 +62,13 @@ def test_executor_dependence_mixed(self): self.assertEqual(fs2.result(), 4) self.assertTrue(fs2.done()) + def test_executor_timeout(self): + with FileTaskScheduler(execute_function=execute_in_subprocess) as exe: + fs1 = exe.submit(reply, 2, resource_dict={"timeout": 0.01}) + with self.assertRaises(TimeoutError): + fs1.result() + self.assertTrue(fs1.done()) + def test_create_file_executor_error(self): with self.assertRaises(TypeError): create_file_executor() diff --git a/tests/test_mpiexecspawner.py b/tests/test_mpiexecspawner.py index 2f0dceac..4ecb1113 100644 --- a/tests/test_mpiexecspawner.py +++ b/tests/test_mpiexecspawner.py @@ -77,6 +77,18 @@ def test_pympiexecutor_two_workers(self): self.assertTrue(fs_1.done()) self.assertTrue(fs_2.done()) + def test_pympiexecutor_timeout(self): + with BlockAllocationTaskScheduler( + max_workers=2, + executor_kwargs={"timeout": 0.01}, + spawner=MpiExecSpawner, + ) as exe: + cloudpickle_register(ind=1) + fs_1 = exe.submit(sleep_one, 1) + with self.assertRaises(TimeoutError): + fs_1.result() + self.assertTrue(fs_1.done()) + def test_max_workers(self): with BlockAllocationTaskScheduler( max_workers=2, diff --git a/tests/test_singlenodeexecutor_noblock.py b/tests/test_singlenodeexecutor_noblock.py index b0606412..92e948a6 100644 --- a/tests/test_singlenodeexecutor_noblock.py +++ b/tests/test_singlenodeexecutor_noblock.py @@ -10,6 +10,11 @@ def calc(i): return i +def reply(i): + sleep(1) + return i + + def resource_dict(resource_dict): return resource_dict @@ -70,6 +75,17 @@ def test_meta_executor_single(self): self.assertTrue(fs_1.done()) self.assertTrue(fs_2.done()) + def test_time_out(self): + with SingleNodeExecutor( + max_cores=1, + block_allocation=False, + ) as exe: + cloudpickle_register(ind=1) + fs_1 = exe.submit(reply, 1, resource_dict={"timeout": 0.01}) + with self.assertRaises(TimeoutError): + fs_1.result() + self.assertTrue(fs_1.done()) + def test_errors(self): with self.assertRaises(TypeError): SingleNodeExecutor( @@ -120,6 +136,18 @@ def test_init_function(self): worker_id = exe.submit(get_worker_id, resource_dict={}).result() self.assertEqual(worker_id, 0) + def test_time_out(self): + with SingleNodeExecutor( + max_cores=1, + block_allocation=True, + resource_dict={"timeout": 0.01}, + ) as exe: + cloudpickle_register(ind=1) + fs_1 = exe.submit(reply, 1) + with self.assertRaises(TimeoutError): + fs_1.result() + self.assertTrue(fs_1.done()) + def test_init_function_two_workers(self): with SingleNodeExecutor( max_cores=2, diff --git a/tests/test_standalone_interactive_communication.py b/tests/test_standalone_interactive_communication.py index c1303068..04f4f990 100644 --- a/tests/test_standalone_interactive_communication.py +++ b/tests/test_standalone_interactive_communication.py @@ -27,10 +27,16 @@ def calc(i): return np.array(i**2) +def reply(i): + sleep(1) + return i + + class BrokenSpawner(MpiExecSpawner): def bootup(self, command_lst: list[str], stop_function: Optional[Callable] = None,): return False + class TestInterface(unittest.TestCase): @unittest.skipIf( skip_mpi4py_test, "mpi4py is not installed, so the mpi4py tests are skipped." @@ -96,6 +102,36 @@ def test_interface_serial_without_debug(self): ) interface.shutdown(wait=True) + def test_interface_serial_with_timeout(self): + cloudpickle_register(ind=1) + task_dict = {"fn": calc, "args": (), "kwargs": {"i": 2}} + interface = SocketInterface( + spawner=MpiExecSpawner(cwd=None, cores=1, openmpi_oversubscribe=False), + log_obj_size=False, + ) + interface.bootup( + command_lst=[ + sys.executable, + os.path.abspath( + os.path.join( + __file__, + "..", + "..", + "src", + "executorlib", + "backend", + "interactive_serial.py", + ) + ), + "--zmqport", + str(interface.bind_to_random_port()), + ] + ) + self.assertTrue(interface.status) + with self.assertRaises(TimeoutError): + interface.send_and_receive_dict(input_dict=task_dict, timeout=0.01) + interface.shutdown(wait=True) + def test_interface_serial_with_debug(self): cloudpickle_register(ind=1) task_dict = {"fn": calc, "args": (), "kwargs": {"i": 2}}