diff --git a/src/executorlib/executor/flux.py b/src/executorlib/executor/flux.py index 9682ef9a..ebe13995 100644 --- a/src/executorlib/executor/flux.py +++ b/src/executorlib/executor/flux.py @@ -13,23 +13,16 @@ check_wait_on_shutdown, validate_number_of_cores, ) +from executorlib.standalone.validate import ( + validate_resource_dict, + validate_resource_dict_with_optional_keys, +) from executorlib.task_scheduler.interactive.blockallocation import ( BlockAllocationTaskScheduler, ) from executorlib.task_scheduler.interactive.dependency import DependencyTaskScheduler from executorlib.task_scheduler.interactive.onetoone import OneProcessTaskScheduler -try: - from executorlib.standalone.validate import ( - validate_resource_dict, - validate_resource_dict_with_optional_keys, - ) -except ImportError: - from executorlib.task_scheduler.base import validate_resource_dict - from executorlib.task_scheduler.base import ( - validate_resource_dict as validate_resource_dict_with_optional_keys, - ) - class FluxJobExecutor(BaseExecutor): """ diff --git a/src/executorlib/executor/single.py b/src/executorlib/executor/single.py index 03045330..d034c24f 100644 --- a/src/executorlib/executor/single.py +++ b/src/executorlib/executor/single.py @@ -12,23 +12,16 @@ validate_number_of_cores, ) from executorlib.standalone.interactive.spawner import MpiExecSpawner +from executorlib.standalone.validate import ( + validate_resource_dict, + validate_resource_dict_with_optional_keys, +) from executorlib.task_scheduler.interactive.blockallocation import ( BlockAllocationTaskScheduler, ) from executorlib.task_scheduler.interactive.dependency import DependencyTaskScheduler from executorlib.task_scheduler.interactive.onetoone import OneProcessTaskScheduler -try: - from executorlib.standalone.validate import ( - validate_resource_dict, - validate_resource_dict_with_optional_keys, - ) -except ImportError: - from executorlib.task_scheduler.base import validate_resource_dict - from executorlib.task_scheduler.base import ( - validate_resource_dict as validate_resource_dict_with_optional_keys, - ) - class SingleNodeExecutor(BaseExecutor): """ diff --git a/src/executorlib/executor/slurm.py b/src/executorlib/executor/slurm.py index 256f49a3..1af3ab3a 100644 --- a/src/executorlib/executor/slurm.py +++ b/src/executorlib/executor/slurm.py @@ -10,6 +10,10 @@ check_wait_on_shutdown, validate_number_of_cores, ) +from executorlib.standalone.validate import ( + validate_resource_dict, + validate_resource_dict_with_optional_keys, +) from executorlib.task_scheduler.interactive.blockallocation import ( BlockAllocationTaskScheduler, ) @@ -20,17 +24,6 @@ validate_max_workers, ) -try: - from executorlib.standalone.validate import ( - validate_resource_dict, - validate_resource_dict_with_optional_keys, - ) -except ImportError: - from executorlib.task_scheduler.base import validate_resource_dict - from executorlib.task_scheduler.base import ( - validate_resource_dict as validate_resource_dict_with_optional_keys, - ) - class SlurmClusterExecutor(BaseExecutor): """ diff --git a/src/executorlib/standalone/validate.py b/src/executorlib/standalone/validate.py index e61e3a6e..c389fed6 100644 --- a/src/executorlib/standalone/validate.py +++ b/src/executorlib/standalone/validate.py @@ -1,7 +1,16 @@ import warnings from typing import Optional -from pydantic import BaseModel, Extra +try: + from pydantic import BaseModel, Extra + + HAS_PYDANTIC = True +except ImportError: + from dataclasses import dataclass + + BaseModel = object + Extra = None + HAS_PYDANTIC = False class ResourceDictValidation(BaseModel): @@ -17,8 +26,22 @@ class ResourceDictValidation(BaseModel): priority: Optional[int] = None slurm_cmd_args: Optional[list[str]] = None - class Config: - extra = Extra.forbid + if HAS_PYDANTIC: + + class Config: + extra = Extra.forbid + + +if not HAS_PYDANTIC: + ResourceDictValidation = dataclass(ResourceDictValidation) # type: ignore + + +def _get_accepted_keys(class_type) -> list[str]: + if hasattr(class_type, "model_fields"): + return list(class_type.model_fields.keys()) + elif hasattr(class_type, "__dataclass_fields__"): + return list(class_type.__dataclass_fields__.keys()) + raise TypeError("Unsupported class type for validation") def validate_resource_dict(resource_dict: dict) -> None: @@ -26,7 +49,7 @@ def validate_resource_dict(resource_dict: dict) -> None: def validate_resource_dict_with_optional_keys(resource_dict: dict) -> None: - accepted_keys = ResourceDictValidation.model_fields.keys() + accepted_keys = _get_accepted_keys(class_type=ResourceDictValidation) optional_lst = [key for key in resource_dict if key not in accepted_keys] validate_dict = { key: value for key, value in resource_dict.items() if key in accepted_keys diff --git a/tests/unit/standalone/test_validate.py b/tests/unit/standalone/test_validate.py index 03802faf..f66a7bd5 100644 --- a/tests/unit/standalone/test_validate.py +++ b/tests/unit/standalone/test_validate.py @@ -12,63 +12,55 @@ skip_pydantic_test = True -class TestValidateImport(unittest.TestCase): - def test_single_node_executor(self): +class TestValidateFallback(unittest.TestCase): + def test_validate_resource_dict_fallback(self): with patch.dict('sys.modules', {'pydantic': None}): if 'executorlib.standalone.validate' in sys.modules: del sys.modules['executorlib.standalone.validate'] - if 'executorlib.executor.single' in sys.modules: - del sys.modules['executorlib.executor.single'] - import executorlib.executor.single - importlib.reload(executorlib.executor.single) + from executorlib.standalone.validate import validate_resource_dict, ResourceDictValidation + from dataclasses import is_dataclass - from executorlib.executor.single import validate_resource_dict - - source_file = inspect.getfile(validate_resource_dict) - if os.name == 'nt': - self.assertTrue(source_file.endswith('task_scheduler\\base.py')) - else: - self.assertTrue(source_file.endswith('task_scheduler/base.py')) - self.assertIsNone(validate_resource_dict({"any": "thing"})) + self.assertTrue(is_dataclass(ResourceDictValidation)) - def test_flux_job_executor(self): - with patch.dict('sys.modules', {'pydantic': None}): - if 'executorlib.standalone.validate' in sys.modules: - del sys.modules['executorlib.standalone.validate'] - if 'executorlib.executor.flux' in sys.modules: - del sys.modules['executorlib.executor.flux'] + # Valid dict + self.assertIsNone(validate_resource_dict({"cores": 1})) - import executorlib.executor.flux - importlib.reload(executorlib.executor.flux) + # Invalid dict (extra key) + with self.assertRaises(TypeError): + validate_resource_dict({"invalid_key": 1}) - from executorlib.executor.flux import validate_resource_dict - - source_file = inspect.getfile(validate_resource_dict) - if os.name == 'nt': - self.assertTrue(source_file.endswith('task_scheduler\\base.py')) - else: - self.assertTrue(source_file.endswith('task_scheduler/base.py')) - self.assertIsNone(validate_resource_dict({"any": "thing"})) - - def test_slurm_job_executor(self): + def test_validate_resource_dict_with_optional_keys_fallback(self): with patch.dict('sys.modules', {'pydantic': None}): if 'executorlib.standalone.validate' in sys.modules: del sys.modules['executorlib.standalone.validate'] - if 'executorlib.executor.slurm' in sys.modules: - del sys.modules['executorlib.executor.slurm'] - import executorlib.executor.slurm - importlib.reload(executorlib.executor.slurm) + from executorlib.standalone.validate import validate_resource_dict_with_optional_keys + + # Valid dict with optional keys + with self.assertWarns(UserWarning): + validate_resource_dict_with_optional_keys({"cores": 1, "optional_key": 2}) + + def test_get_accepted_keys(self): + from executorlib.standalone.validate import _get_accepted_keys, ResourceDictValidation - from executorlib.executor.slurm import validate_resource_dict - - source_file = inspect.getfile(validate_resource_dict) - if os.name == 'nt': - self.assertTrue(source_file.endswith('task_scheduler\\base.py')) - else: - self.assertTrue(source_file.endswith('task_scheduler/base.py')) - self.assertIsNone(validate_resource_dict({"any": "thing"})) + accepted_keys = _get_accepted_keys(ResourceDictValidation) + expected_keys = [ + "cores", + "threads_per_core", + "gpus_per_core", + "cwd", + "cache_key", + "num_nodes", + "exclusive", + "error_log_file", + "run_time_limit", + "priority", + "slurm_cmd_args" + ] + self.assertEqual(set(accepted_keys), set(expected_keys)) + with self.assertRaises(TypeError): + _get_accepted_keys(int) @unittest.skipIf(skip_pydantic_test, "pydantic is not installed") @@ -81,4 +73,4 @@ def dummy_function(i): with SingleNodeExecutor() as exe: with self.assertRaises(ValidationError): - exe.submit(dummy_function, 5, resource_dict={"any": "thing"}) \ No newline at end of file + exe.submit(dummy_function, 5, resource_dict={"any": "thing"})