-
Notifications
You must be signed in to change notification settings - Fork 208
Add replica groups in dstack-service #3408
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
22c1410
5abbcad
abba7da
d974292
caa4283
1ec1d6d
7b4bc52
8c5589d
0a54e07
f4c9fdf
a0e13f6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -31,6 +31,7 @@ | |
| from dstack._internal.core.models.services import AnyModel, OpenAIChatModel | ||
| from dstack._internal.core.models.unix import UnixUser | ||
| from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point | ||
| from dstack._internal.core.services import validate_dstack_resource_name | ||
| from dstack._internal.utils.common import has_duplicates, list_enum_values_for_annotation | ||
| from dstack._internal.utils.json_schema import add_extra_schema_types | ||
| from dstack._internal.utils.json_utils import ( | ||
|
|
@@ -610,8 +611,13 @@ def convert_ports(cls, v) -> PortMapping: | |
| class ConfigurationWithCommandsParams(CoreModel): | ||
| commands: Annotated[CommandsList, Field(description="The shell commands to run")] = [] | ||
|
|
||
| @root_validator | ||
| @root_validator(pre=True) | ||
| def check_image_or_commands_present(cls, values): | ||
| # If replicas is list, skip validation - commands come from replica groups | ||
| replicas = values.get("replicas") | ||
| if isinstance(replicas, list): | ||
| return values | ||
|
|
||
| if not values.get("commands") and not values.get("image"): | ||
| raise ValueError("Either `commands` or `image` must be set") | ||
| return values | ||
|
|
@@ -714,6 +720,68 @@ def schema_extra(schema: Dict[str, Any]): | |
| ) | ||
|
|
||
|
|
||
| class ReplicaGroup(CoreModel): | ||
| name: Annotated[ | ||
| Optional[str], | ||
| Field( | ||
| description="The name of the replica group. If not provided, defaults to 'replica0', 'replica1', etc. based on position." | ||
| ), | ||
| ] | ||
jvstme marked this conversation as resolved.
Show resolved
Hide resolved
|
||
| count: Annotated[ | ||
| Range[int], | ||
| Field( | ||
| description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " | ||
| "If it's a range, the `scaling` property is required" | ||
| ), | ||
| ] | ||
| scaling: Annotated[ | ||
| Optional[ScalingSpec], | ||
| Field(description="The auto-scaling rules. Required if `count` is set to a range"), | ||
| ] = None | ||
|
|
||
| resources: Annotated[ | ||
| ResourcesSpec, | ||
| Field(description="The resources requirements for replicas in this group"), | ||
| ] = ResourcesSpec() | ||
jvstme marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| commands: Annotated[ | ||
| CommandsList, | ||
| Field(description="The shell commands to run for replicas in this group"), | ||
| ] | ||
|
|
||
| @validator("name") | ||
| def validate_name(cls, v: Optional[str]) -> Optional[str]: | ||
| if v is not None: | ||
| validate_dstack_resource_name(v) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) This function raises |
||
| return v | ||
|
|
||
| @validator("count") | ||
| def convert_count(cls, v: Range[int]) -> Range[int]: | ||
| if v.max is None: | ||
| raise ValueError("The maximum number of replicas is required") | ||
| if v.min is None: | ||
| v.min = 0 | ||
| if v.min < 0: | ||
| raise ValueError("The minimum number of replicas must be greater than or equal to 0") | ||
| return v | ||
|
|
||
| @validator("commands") | ||
| def validate_commands(cls, v: CommandsList) -> CommandsList: | ||
| if not v: | ||
| raise ValueError("`commands` must be set for replica groups") | ||
| return v | ||
|
Comment on lines
+768
to
+772
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) Can replace with |
||
|
|
||
| @root_validator() | ||
| def validate_scaling(cls, values): | ||
| scaling = values.get("scaling") | ||
| count = values.get("count") | ||
| if count and count.min != count.max and not scaling: | ||
| raise ValueError("When you set `count` to a range, ensure to specify `scaling`.") | ||
| if count and count.min == count.max and scaling: | ||
| raise ValueError("To use `scaling`, `count` must be set to a range.") | ||
| return values | ||
|
|
||
|
|
||
| class ServiceConfigurationParams(CoreModel): | ||
| port: Annotated[ | ||
| # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used. | ||
|
|
@@ -755,13 +823,7 @@ class ServiceConfigurationParams(CoreModel): | |
| SERVICE_HTTPS_DEFAULT | ||
| ) | ||
| auth: Annotated[bool, Field(description="Enable the authorization")] = True | ||
| replicas: Annotated[ | ||
| Range[int], | ||
| Field( | ||
| description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). " | ||
| "If it's a range, the `scaling` property is required" | ||
| ), | ||
| ] = Range[int](min=1, max=1) | ||
|
|
||
| scaling: Annotated[ | ||
| Optional[ScalingSpec], | ||
| Field(description="The auto-scaling rules. Required if `replicas` is set to a range"), | ||
|
|
@@ -772,6 +834,19 @@ class ServiceConfigurationParams(CoreModel): | |
| Field(description="List of probes used to determine job health"), | ||
| ] = [] | ||
|
|
||
| replicas: Annotated[ | ||
| Optional[Union[List[ReplicaGroup], Range[int]]], | ||
| Field( | ||
| description=( | ||
| "List of replica groups. Each group defines replicas with shared configuration " | ||
| "(commands, port, resources, scaling, probes, rate_limits). " | ||
| "When specified, the top-level `replicas`, `commands`, `port`, `resources`, " | ||
| "`scaling`, `probes`, and `rate_limits` are ignored. " | ||
| "Each replica group must have a unique name." | ||
|
Comment on lines
+841
to
+845
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
| ) | ||
| ), | ||
| ] = None | ||
|
|
||
| @validator("port") | ||
| def convert_port(cls, v) -> PortMapping: | ||
| if isinstance(v, int): | ||
|
|
@@ -786,26 +861,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]: | |
| return OpenAIChatModel(type="chat", name=v, format="openai") | ||
| return v | ||
|
|
||
| @validator("replicas") | ||
| def convert_replicas(cls, v: Range[int]) -> Range[int]: | ||
| if v.max is None: | ||
| raise ValueError("The maximum number of replicas is required") | ||
| if v.min is None: | ||
| v.min = 0 | ||
| if v.min < 0: | ||
| raise ValueError("The minimum number of replicas must be greater than or equal to 0") | ||
| return v | ||
|
|
||
| @root_validator() | ||
| def validate_scaling(cls, values): | ||
| scaling = values.get("scaling") | ||
| replicas = values.get("replicas") | ||
| if replicas and replicas.min != replicas.max and not scaling: | ||
| raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.") | ||
| if replicas and replicas.min == replicas.max and scaling: | ||
| raise ValueError("To use `scaling`, `replicas` must be set to a range.") | ||
| return values | ||
|
|
||
| @validator("rate_limits") | ||
| def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]: | ||
| counts = Counter(limit.prefix for limit in v) | ||
|
|
@@ -827,6 +882,92 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]: | |
| raise ValueError("Probes must be unique") | ||
| return v | ||
|
|
||
| @validator("replicas") | ||
| def validate_replicas( | ||
| cls, v: Optional[Union[Range[int], List[ReplicaGroup]]] | ||
| ) -> Optional[Union[Range[int], List[ReplicaGroup]]]: | ||
| if v is None: | ||
| return v | ||
| if isinstance(v, Range): | ||
| if v.max is None: | ||
| raise ValueError("The maximum number of replicas is required") | ||
| if v.min is None: | ||
| v.min = 0 | ||
| if v.min < 0: | ||
| raise ValueError( | ||
| "The minimum number of replicas must be greater than or equal to 0" | ||
| ) | ||
| return v | ||
|
Comment on lines
+892
to
+900
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) Duplicated here and in |
||
|
|
||
| if isinstance(v, list): | ||
| if not v: | ||
| raise ValueError("`replicas` cannot be an empty list") | ||
|
|
||
| # Assign default names to groups without names | ||
| for index, group in enumerate(v): | ||
| if group.name is None: | ||
| group.name = f"replica{index}" | ||
|
|
||
| # Check for duplicate names | ||
| names = [group.name for group in v] | ||
| if len(names) != len(set(names)): | ||
| duplicates = [name for name in set(names) if names.count(name) > 1] | ||
| raise ValueError( | ||
| f"Duplicate replica group names found: {duplicates}. " | ||
| "Each replica group must have a unique name." | ||
| ) | ||
| return v | ||
|
|
||
| @root_validator() | ||
| def validate_scaling(cls, values): | ||
| scaling = values.get("scaling") | ||
| replicas = values.get("replicas") | ||
|
|
||
| if isinstance(replicas, Range): | ||
| if replicas and replicas.min != replicas.max and not scaling: | ||
| raise ValueError( | ||
| "When you set `replicas` to a range, ensure to specify `scaling`." | ||
| ) | ||
| if replicas and replicas.min == replicas.max and scaling: | ||
| raise ValueError("To use `scaling`, `replicas` must be set to a range.") | ||
| return values | ||
jvstme marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| @root_validator() | ||
| def validate_top_level_properties_with_replica_groups(cls, values): | ||
| """ | ||
| When replicas is a list of ReplicaGroup, forbid top-level scaling, commands, and resources | ||
| """ | ||
| replicas = values.get("replicas") | ||
|
|
||
| if not isinstance(replicas, list): | ||
| return values | ||
|
|
||
| scaling = values.get("scaling") | ||
| if scaling is not None: | ||
| raise ValueError( | ||
| "Top-level `scaling` is not allowed when `replicas` is a list. " | ||
| "Specify `scaling` in each replica group instead." | ||
| ) | ||
|
|
||
| commands = values.get("commands", []) | ||
| if commands: | ||
| raise ValueError( | ||
| "Top-level `commands` is not allowed when `replicas` is a list. " | ||
| "Specify `commands` in each replica group instead." | ||
| ) | ||
|
|
||
| resources = values.get("resources") | ||
| from dstack._internal.core.models.resources import ResourcesSpec | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) |
||
|
|
||
| default_resources = ResourcesSpec() | ||
| if resources and resources.dict() != default_resources.dict(): | ||
| raise ValueError( | ||
| "Top-level `resources` is not allowed when `replicas` is a list. " | ||
| "Specify `resources` in each replica group instead." | ||
| ) | ||
|
|
||
| return values | ||
|
|
||
|
|
||
| class ServiceConfigurationConfig( | ||
| ProfileParamsConfig, | ||
|
|
@@ -849,6 +990,38 @@ class ServiceConfiguration( | |
| ): | ||
| type: Literal["service"] = "service" | ||
|
|
||
| @property | ||
| def replica_groups(self) -> List[ReplicaGroup]: | ||
| """ | ||
| Get normalized replica groups. After validation, replicas is always List[ReplicaGroup] or None. | ||
| Use this property for type-safe access in code. | ||
| """ | ||
|
Comment on lines
+995
to
+998
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) This docstring seems to be slightly outdated. We no longer have the validator that converts |
||
| if self.replicas is None: | ||
| return [ | ||
| ReplicaGroup( | ||
| name="default", | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) The string |
||
| count=Range[int](min=1, max=1), | ||
| commands=self.commands or [], | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. (nit) |
||
| resources=self.resources, | ||
| scaling=self.scaling, | ||
| ) | ||
| ] | ||
| if isinstance(self.replicas, list): | ||
| return self.replicas | ||
| if isinstance(self.replicas, Range): | ||
| return [ | ||
| ReplicaGroup( | ||
| name="default", | ||
| count=self.replicas, | ||
| commands=self.commands or [], | ||
| resources=self.resources, | ||
| scaling=self.scaling, | ||
| ) | ||
| ] | ||
| raise ValueError( | ||
| f"Invalid replicas type: {type(self.replicas)}. Expected None, Range[int], or List[ReplicaGroup]" | ||
| ) | ||
|
|
||
|
|
||
| AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration] | ||
|
|
||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
(nit) Why `pre`? I see a couple of negative side effects of making it a `pre` validator.
- It now prevents some of the more obvious validators from running, so error messages can be misleading.
- It now runs before converting `replicas` to `list`, so, for example, passing them as a tuple doesn't work anymore: