Skip to content
Open
66 changes: 62 additions & 4 deletions src/dstack/_internal/cli/utils/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,16 +281,38 @@ def _format_job_name(
show_deployment_num: bool,
show_replica: bool,
show_job: bool,
group_index: Optional[int] = None,
last_shown_group_index: Optional[int] = None,
) -> str:
name_parts = []
prefix = ""
if show_replica:
name_parts.append(f"replica={job.job_spec.replica_num}")
# Show group information if replica groups are used
if group_index is not None:
# Show group=X replica=Y when group changes, or just replica=Y when same group
if group_index != last_shown_group_index:
# First job in group: use 3 spaces indent
prefix = " "
name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}")
else:
# Subsequent job in same group: align "replica=" with first job's "replica="
# Calculate padding: width of " group={last_shown_group_index} "
padding_width = 3 + len(f"group={last_shown_group_index}") + 1
prefix = " " * padding_width
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
# Legacy behavior: no replica groups
prefix = " "
name_parts.append(f"replica={job.job_spec.replica_num}")
else:
prefix = " "

if show_job:
name_parts.append(f"job={job.job_spec.job_num}")
name_suffix = (
f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else ""
)
name_value = " " + (" ".join(name_parts) if name_parts else "")
name_value = prefix + (" ".join(name_parts) if name_parts else "")
name_value += name_suffix
return name_value

Expand Down Expand Up @@ -359,6 +381,17 @@ def get_runs_table(
)
merge_job_rows = len(run.jobs) == 1 and not show_deployment_num

group_name_to_index: Dict[str, int] = {}
if run.run_spec.configuration.type == "service" and hasattr(
run.run_spec.configuration, "replica_groups"
):
replica_groups = run.run_spec.configuration.replica_groups
if replica_groups:
for idx, group in enumerate(replica_groups):
# Use group name or default to "replica{idx}" if name is None
group_name = group.name or f"replica{idx}"
group_name_to_index[group_name] = idx

run_row: Dict[Union[str, int], Any] = {
"NAME": _format_run_name(run, show_deployment_num),
"SUBMITTED": format_date(run.submitted_at),
Expand All @@ -372,13 +405,35 @@ def get_runs_table(
if not merge_job_rows:
add_row_from_dict(table, run_row)

for job in run.jobs:
# Sort jobs by group index first, then by replica_num within each group
def get_job_sort_key(job: Job) -> tuple:
group_index = None
if group_name_to_index and job.job_spec.replica_group:
group_index = group_name_to_index.get(job.job_spec.replica_group)
# Use a large number for jobs without groups to put them at the end
return (group_index if group_index is not None else 999999, job.job_spec.replica_num)

sorted_jobs = sorted(run.jobs, key=get_job_sort_key)

last_shown_group_index: Optional[int] = None
for job in sorted_jobs:
latest_job_submission = job.job_submissions[-1]
status_formatted = _format_job_submission_status(latest_job_submission, verbose)

# Get group index for this job
group_index: Optional[int] = None
if group_name_to_index and job.job_spec.replica_group:
group_index = group_name_to_index.get(job.job_spec.replica_group)

job_row: Dict[Union[str, int], Any] = {
"NAME": _format_job_name(
job, latest_job_submission, show_deployment_num, show_replica, show_job
job,
latest_job_submission,
show_deployment_num,
show_replica,
show_job,
group_index=group_index,
last_shown_group_index=last_shown_group_index,
),
"STATUS": status_formatted,
"PROBES": _format_job_probes(
Expand All @@ -390,6 +445,9 @@ def get_runs_table(
"GPU": "-",
"PRICE": "-",
}
# Update last shown group index for next iteration
if group_index is not None:
last_shown_group_index = group_index
jpd = latest_job_submission.job_provisioning_data
if jpd is not None:
shared_offer: Optional[InstanceOfferWithAvailability] = None
Expand Down
229 changes: 201 additions & 28 deletions src/dstack/_internal/core/models/configurations.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
from dstack._internal.core.models.unix import UnixUser
from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
from dstack._internal.core.services import validate_dstack_resource_name
from dstack._internal.utils.common import has_duplicates, list_enum_values_for_annotation
from dstack._internal.utils.json_schema import add_extra_schema_types
from dstack._internal.utils.json_utils import (
Expand Down Expand Up @@ -610,8 +611,13 @@ def convert_ports(cls, v) -> PortMapping:
class ConfigurationWithCommandsParams(CoreModel):
commands: Annotated[CommandsList, Field(description="The shell commands to run")] = []

@root_validator
@root_validator(pre=True)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Why pre? I see a couple of negative side effects of making it a pre validator.

  • It now prevents some of the more obvious validators from running, so error messages can be misleading.

    > cat test.dstack.yml
    type: service
    port: 8000
    
    replicas:  # (!) wrong syntax, not a list
      count: 1
      commands:
      - sleep infinity
    
    > dstack apply -f test.dstack.yml
    1 validation error for ApplyConfigurationRequest
    __root__ -> ServiceConfigurationRequest -> __root__
      Either `commands` or `image` must be set (type=value_error)
  • It now runs before converting replicas to list, so, for example, passing them as a tuple doesn't work anymore:

    >>> ServiceConfiguration(type="service", port=80, replicas=(ReplicaGroup(commands=["sleep"], count=1),))
    Traceback (most recent call last):
      File "<stdin>", line 1, in <module>
      File "/home/jvstme/git/dstack/dstack/venv/lib64/python3.10/site-packages/pydantic_duality/__init__.py", line 246, in __new__
        return cls.__request__(*args, **kwargs)
      File "pydantic/main.py", line 347, in pydantic.main.BaseModel.__init__
    pydantic.error_wrappers.ValidationError: 1 validation error for ServiceConfigurationRequest
    __root__
      Either `commands` or `image` must be set (type=value_error)

def check_image_or_commands_present(cls, values):
# If replicas is list, skip validation - commands come from replica groups
replicas = values.get("replicas")
if isinstance(replicas, list):
return values

if not values.get("commands") and not values.get("image"):
raise ValueError("Either `commands` or `image` must be set")
return values
Expand Down Expand Up @@ -714,6 +720,68 @@ def schema_extra(schema: Dict[str, Any]):
)


class ReplicaGroup(CoreModel):
name: Annotated[
Optional[str],
Field(
description="The name of the replica group. If not provided, defaults to 'replica0', 'replica1', etc. based on position."
),
]
count: Annotated[
Range[int],
Field(
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
"If it's a range, the `scaling` property is required"
),
]
scaling: Annotated[
Optional[ScalingSpec],
Field(description="The auto-scaling rules. Required if `count` is set to a range"),
] = None

resources: Annotated[
ResourcesSpec,
Field(description="The resources requirements for replicas in this group"),
] = ResourcesSpec()

commands: Annotated[
CommandsList,
Field(description="The shell commands to run for replicas in this group"),
]

@validator("name")
def validate_name(cls, v: Optional[str]) -> Optional[str]:
if v is not None:
validate_dstack_resource_name(v)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) This function raises ServerClientError. Pydantic validators are supposed to raise ValueError. Pydantic detects ValueError and automatically adds useful context, such as the name and location of the field that failed validation

return v

@validator("count")
def convert_count(cls, v: Range[int]) -> Range[int]:
if v.max is None:
raise ValueError("The maximum number of replicas is required")
if v.min is None:
v.min = 0
if v.min < 0:
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
return v

@validator("commands")
def validate_commands(cls, v: CommandsList) -> CommandsList:
if not v:
raise ValueError("`commands` must be set for replica groups")
return v
Comment on lines +768 to +772
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Can replace with Field(min_items=1) in the field annotation?


@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
count = values.get("count")
if count and count.min != count.max and not scaling:
raise ValueError("When you set `count` to a range, ensure to specify `scaling`.")
if count and count.min == count.max and scaling:
raise ValueError("To use `scaling`, `count` must be set to a range.")
return values


class ServiceConfigurationParams(CoreModel):
port: Annotated[
# NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
Expand Down Expand Up @@ -755,13 +823,7 @@ class ServiceConfigurationParams(CoreModel):
SERVICE_HTTPS_DEFAULT
)
auth: Annotated[bool, Field(description="Enable the authorization")] = True
replicas: Annotated[
Range[int],
Field(
description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
"If it's a range, the `scaling` property is required"
),
] = Range[int](min=1, max=1)

scaling: Annotated[
Optional[ScalingSpec],
Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
Expand All @@ -772,6 +834,19 @@ class ServiceConfigurationParams(CoreModel):
Field(description="List of probes used to determine job health"),
] = []

replicas: Annotated[
Optional[Union[List[ReplicaGroup], Range[int]]],
Field(
description=(
"List of replica groups. Each group defines replicas with shared configuration "
"(commands, port, resources, scaling, probes, rate_limits). "
"When specified, the top-level `replicas`, `commands`, `port`, `resources`, "
"`scaling`, `probes`, and `rate_limits` are ignored. "
"Each replica group must have a unique name."
Comment on lines +841 to +845
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. This mentions properties that are not related to replica groups in the current implementation.
  2. The int and Range[int] syntaxes are worth mentioning too, I think they will remain popular and relevant.

)
),
] = None

@validator("port")
def convert_port(cls, v) -> PortMapping:
if isinstance(v, int):
Expand All @@ -786,26 +861,6 @@ def convert_model(cls, v: Optional[Union[AnyModel, str]]) -> Optional[AnyModel]:
return OpenAIChatModel(type="chat", name=v, format="openai")
return v

@validator("replicas")
def convert_replicas(cls, v: Range[int]) -> Range[int]:
if v.max is None:
raise ValueError("The maximum number of replicas is required")
if v.min is None:
v.min = 0
if v.min < 0:
raise ValueError("The minimum number of replicas must be greater than or equal to 0")
return v

@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
replicas = values.get("replicas")
if replicas and replicas.min != replicas.max and not scaling:
raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
if replicas and replicas.min == replicas.max and scaling:
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
return values

@validator("rate_limits")
def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
counts = Counter(limit.prefix for limit in v)
Expand All @@ -827,6 +882,92 @@ def validate_probes(cls, v: list[ProbeConfig]) -> list[ProbeConfig]:
raise ValueError("Probes must be unique")
return v

@validator("replicas")
def validate_replicas(
cls, v: Optional[Union[Range[int], List[ReplicaGroup]]]
) -> Optional[Union[Range[int], List[ReplicaGroup]]]:
if v is None:
return v
if isinstance(v, Range):
if v.max is None:
raise ValueError("The maximum number of replicas is required")
if v.min is None:
v.min = 0
if v.min < 0:
raise ValueError(
"The minimum number of replicas must be greater than or equal to 0"
)
return v
Comment on lines +892 to +900
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) Duplicated here and in ReplicaGroup.convert_count. Move to a utility function?


if isinstance(v, list):
if not v:
raise ValueError("`replicas` cannot be an empty list")

# Assign default names to groups without names
for index, group in enumerate(v):
if group.name is None:
group.name = f"replica{index}"

# Check for duplicate names
names = [group.name for group in v]
if len(names) != len(set(names)):
duplicates = [name for name in set(names) if names.count(name) > 1]
raise ValueError(
f"Duplicate replica group names found: {duplicates}. "
"Each replica group must have a unique name."
)
return v

@root_validator()
def validate_scaling(cls, values):
scaling = values.get("scaling")
replicas = values.get("replicas")

if isinstance(replicas, Range):
if replicas and replicas.min != replicas.max and not scaling:
raise ValueError(
"When you set `replicas` to a range, ensure to specify `scaling`."
)
if replicas and replicas.min == replicas.max and scaling:
raise ValueError("To use `scaling`, `replicas` must be set to a range.")
return values

@root_validator()
def validate_top_level_properties_with_replica_groups(cls, values):
"""
When replicas is a list of ReplicaGroup, forbid top-level scaling, commands, and resources
"""
replicas = values.get("replicas")

if not isinstance(replicas, list):
return values

scaling = values.get("scaling")
if scaling is not None:
raise ValueError(
"Top-level `scaling` is not allowed when `replicas` is a list. "
"Specify `scaling` in each replica group instead."
)

commands = values.get("commands", [])
if commands:
raise ValueError(
"Top-level `commands` is not allowed when `replicas` is a list. "
"Specify `commands` in each replica group instead."
)

resources = values.get("resources")
from dstack._internal.core.models.resources import ResourcesSpec
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) ResourcesSpec is already imported at the top of the module


default_resources = ResourcesSpec()
if resources and resources.dict() != default_resources.dict():
raise ValueError(
"Top-level `resources` is not allowed when `replicas` is a list. "
"Specify `resources` in each replica group instead."
)

return values


class ServiceConfigurationConfig(
ProfileParamsConfig,
Expand All @@ -849,6 +990,38 @@ class ServiceConfiguration(
):
type: Literal["service"] = "service"

@property
def replica_groups(self) -> List[ReplicaGroup]:
"""
Get normalized replica groups. After validation, replicas is always List[ReplicaGroup] or None.
Use this property for type-safe access in code.
"""
Comment on lines +995 to +998
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) This docstring seems to be slightly outdated. We no longer have the validator that converts replicas to Optional[List[ReplicaGroup]]

if self.replicas is None:
return [
ReplicaGroup(
name="default",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) The string "default" is hardcoded in several places in this PR. Consider moving it to a constant (except for tests — hardcoding is fine there)

count=Range[int](min=1, max=1),
commands=self.commands or [],
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(nit) or [] seems to be redundant. Same thing a few lines below

resources=self.resources,
scaling=self.scaling,
)
]
if isinstance(self.replicas, list):
return self.replicas
if isinstance(self.replicas, Range):
return [
ReplicaGroup(
name="default",
count=self.replicas,
commands=self.commands or [],
resources=self.resources,
scaling=self.scaling,
)
]
raise ValueError(
f"Invalid replicas type: {type(self.replicas)}. Expected None, Range[int], or List[ReplicaGroup]"
)


AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]

Expand Down
Loading