diff --git a/tests/unit_tests/instances/test_instances.py b/tests/unit_tests/instances/test_instances.py index 53654d3..b9dbc0a 100644 --- a/tests/unit_tests/instances/test_instances.py +++ b/tests/unit_tests/instances/test_instances.py @@ -1,9 +1,10 @@ +import copy import json import pytest import responses -from verda.constants import Actions, ErrorCodes, Locations +from verda.constants import Actions, ErrorCodes, InstanceStatus, Locations from verda.exceptions import APIException from verda.instances import Instance, InstancesService, OSVolume @@ -333,6 +334,61 @@ def test_create_instance_attached_os_volume_successful(self, instances_service, assert responses.assert_call_count(endpoint, 1) is True assert responses.assert_call_count(url, 1) is True + @pytest.mark.parametrize( + ('wait_for_status', 'expected_status', 'expected_get_instance_call_count'), + [ + (None, InstanceStatus.ORDERED, 1), + (InstanceStatus.ORDERED, InstanceStatus.ORDERED, 1), + (InstanceStatus.PROVISIONING, InstanceStatus.PROVISIONING, 2), + (lambda status: status != InstanceStatus.ORDERED, InstanceStatus.PROVISIONING, 2), + (InstanceStatus.RUNNING, InstanceStatus.RUNNING, 3), + ], + ) + def test_create_wait_for_status( + self, + instances_service, + endpoint, + wait_for_status, + expected_status, + expected_get_instance_call_count, + ): + # arrange - add response mock + # create instance + responses.add(responses.POST, endpoint, body=INSTANCE_ID, status=200) + # First get instance by id - ordered + get_instance_url = endpoint + '/' + INSTANCE_ID + payload = copy.deepcopy(PAYLOAD[0]) + payload['status'] = InstanceStatus.ORDERED + responses.add(responses.GET, get_instance_url, json=payload, status=200) + # Second get instance by id - provisioning + payload = copy.deepcopy(PAYLOAD[0]) + payload['status'] = InstanceStatus.PROVISIONING + responses.add(responses.GET, get_instance_url, json=payload, status=200) + # Third get instance by id - running + payload = copy.deepcopy(PAYLOAD[0]) + payload['status'] = InstanceStatus.RUNNING + responses.add(responses.GET, get_instance_url, json=payload, status=200) + + # act + instance = instances_service.create( + instance_type=INSTANCE_TYPE, + image=OS_VOLUME_ID, + hostname=INSTANCE_HOSTNAME, + description=INSTANCE_DESCRIPTION, + wait_for_status=wait_for_status, + max_interval=0, + max_wait_time=1, + ) + + # assert + assert isinstance(instance, Instance) + assert instance.id == INSTANCE_ID + assert instance.status == expected_status + assert responses.assert_call_count(endpoint, 1) is True + assert ( + responses.assert_call_count(get_instance_url, expected_get_instance_call_count) is True + ) + def test_create_instance_failed(self, instances_service, endpoint): # arrange - add response mock responses.add( diff --git a/verda/instances/_instances.py b/verda/instances/_instances.py index 5fd3a48..a32add8 100644 --- a/verda/instances/_instances.py +++ b/verda/instances/_instances.py @@ -1,5 +1,6 @@ import itertools import time +from collections.abc import Callable from dataclasses import dataclass from typing import Literal @@ -150,6 +151,7 @@ def create( pricing: Pricing | None = None, coupon: str | None = None, *, + wait_for_status: str | Callable[[str], bool] | None = lambda s: s != InstanceStatus.ORDERED, max_wait_time: float = 180, initial_interval: float = 0.5, max_interval: float = 5, @@ -172,6 +174,7 @@ def create( contract: Optional contract type for the instance. pricing: Optional pricing model for the instance. coupon: Optional coupon code for discounts. + wait_for_status: Status to wait for the instance to reach, or callable that returns True when the desired status is reached. Default to any status other than ORDERED. If None, no wait is performed. max_wait_time: Maximum total wait for the instance to start provisioning, in seconds (default: 180) initial_interval: Initial interval, in seconds (default: 0.5) max_interval: The longest single delay allowed between retries, in seconds (default: 5) @@ -203,12 +206,18 @@ def create( payload['pricing'] = pricing id = self._http_client.post(INSTANCES_ENDPOINT, json=payload).text + if wait_for_status is None: + return self.get_by_id(id) + # Wait for instance to enter provisioning state with timeout # TODO(shamrin) extract backoff logic, _clusters module has the same code deadline = time.monotonic() + max_wait_time for i in itertools.count(): instance = self.get_by_id(id) - if instance.status != InstanceStatus.ORDERED: + if callable(wait_for_status): + if wait_for_status(instance.status): + return instance + elif instance.status == wait_for_status: return instance now = time.monotonic()