diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py
index 68dc828f7..267ed8206 100644
--- a/src/dstack/_internal/cli/utils/run.py
+++ b/src/dstack/_internal/cli/utils/run.py
@@ -281,16 +281,38 @@ def _format_job_name(
     show_deployment_num: bool,
     show_replica: bool,
     show_job: bool,
+    group_index: Optional[int] = None,
+    last_shown_group_index: Optional[int] = None,
 ) -> str:
     name_parts = []
+    prefix = ""
     if show_replica:
-        name_parts.append(f"replica={job.job_spec.replica_num}")
+        # Show group information if replica groups are used
+        if group_index is not None:
+            # Show group=X replica=Y when the group changes, or just replica=Y within the same group
+            if group_index != last_shown_group_index:
+                # First job in the group: use a 3-space indent
+                prefix = "   "
+                name_parts.append(f"group={group_index} replica={job.job_spec.replica_num}")
+            else:
+                # Subsequent job in the same group: align "replica=" with the first job's "replica="
+                # Padding width: len("   group={last_shown_group_index} ")
+                padding_width = 3 + len(f"group={last_shown_group_index}") + 1
+                prefix = " " * padding_width
+                name_parts.append(f"replica={job.job_spec.replica_num}")
+        else:
+            # Legacy behavior: no replica groups
+            prefix = "   "
+            name_parts.append(f"replica={job.job_spec.replica_num}")
+    else:
+        prefix = "   "
+
     if show_job:
         name_parts.append(f"job={job.job_spec.job_num}")
     name_suffix = (
         f" deployment={latest_job_submission.deployment_num}" if show_deployment_num else ""
     )
-    name_value = "   " + (" ".join(name_parts) if name_parts else "")
+    name_value = prefix + (" ".join(name_parts) if name_parts else "")
     name_value += name_suffix
     return name_value
@@ -359,6 +381,17 @@ def get_runs_table(
     )
     merge_job_rows = len(run.jobs) == 1 and not show_deployment_num
 
+    group_name_to_index: Dict[str, int] = {}
+    if run.run_spec.configuration.type == "service" and hasattr(
+        run.run_spec.configuration, "replica_groups"
+    ):
+        replica_groups = run.run_spec.configuration.replica_groups
+        if replica_groups:
+            for idx, group in enumerate(replica_groups):
+                # Use the group name, or default to "replica{idx}" if the name is None
+                group_name = group.name or f"replica{idx}"
+                group_name_to_index[group_name] = idx
+
     run_row: Dict[Union[str, int], Any] = {
         "NAME": _format_run_name(run, show_deployment_num),
         "SUBMITTED": format_date(run.submitted_at),
@@ -372,13 +405,35 @@ def get_runs_table(
         if not merge_job_rows:
             add_row_from_dict(table, run_row)
 
-        for job in run.jobs:
+        # Sort jobs by group index first, then by replica_num within each group
+        def get_job_sort_key(job: Job) -> tuple:
+            group_index = None
+            if group_name_to_index and job.job_spec.replica_group:
+                group_index = group_name_to_index.get(job.job_spec.replica_group)
+            # Use a large number for jobs without groups to put them at the end
+            return (group_index if group_index is not None else 999999, job.job_spec.replica_num)
+
+        sorted_jobs = sorted(run.jobs, key=get_job_sort_key)
+
+        last_shown_group_index: Optional[int] = None
+        for job in sorted_jobs:
             latest_job_submission = job.job_submissions[-1]
             status_formatted = _format_job_submission_status(latest_job_submission, verbose)
+            # Get the group index for this job
+            group_index: Optional[int] = None
+            if group_name_to_index and job.job_spec.replica_group:
+                group_index = group_name_to_index.get(job.job_spec.replica_group)
+
             job_row: Dict[Union[str, int], Any] = {
                 "NAME": _format_job_name(
-                    job, latest_job_submission, show_deployment_num, show_replica, show_job
+                    job,
+                    latest_job_submission,
+                    show_deployment_num,
+                    show_replica,
+                    show_job,
+                    group_index=group_index,
+                    last_shown_group_index=last_shown_group_index,
                 ),
                 "STATUS": status_formatted,
                 "PROBES": _format_job_probes(
@@ -390,6 +445,9 @@ def get_runs_table(
                 "GPU": "-",
                 "PRICE": "-",
             }
+            # Update the last shown group index for the next iteration
+            if group_index is not None:
+                last_shown_group_index = group_index
             jpd = latest_job_submission.job_provisioning_data
             if jpd is not None:
                 shared_offer: Optional[InstanceOfferWithAvailability] = None
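For reference, a minimal standalone sketch of the alignment math used by `_format_job_name` above; the group and replica numbers are hypothetical inputs, not values from this PR:

```python
# Standalone sketch of the NAME-column alignment (hypothetical inputs).
def format_name(group_index, last_shown_group_index, replica_num):
    if group_index != last_shown_group_index:
        return "   " + f"group={group_index} replica={replica_num}"
    # Pad by len("   group=<last> ") so "replica=" lines up with the row above.
    padding = 3 + len(f"group={last_shown_group_index}") + 1
    return " " * padding + f"replica={replica_num}"

print(repr(format_name(0, None, 0)))  # '   group=0 replica=0'
print(repr(format_name(0, 0, 1)))     # '           replica=1' (aligned)
```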
diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py
index 9c4415556..fcc01a9a4 100644
--- a/src/dstack/_internal/core/models/configurations.py
+++ b/src/dstack/_internal/core/models/configurations.py
@@ -31,6 +31,7 @@ from dstack._internal.core.models.services import AnyModel, OpenAIChatModel
 from dstack._internal.core.models.unix import UnixUser
 from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point
+from dstack._internal.core.services import validate_dstack_resource_name
 from dstack._internal.utils.common import has_duplicates, list_enum_values_for_annotation
 from dstack._internal.utils.json_schema import add_extra_schema_types
 from dstack._internal.utils.json_utils import (
@@ -612,6 +613,11 @@ class ConfigurationWithCommandsParams(CoreModel):
     @root_validator
     def check_image_or_commands_present(cls, values):
+        # If `replicas` is a list, skip this validation — commands come from the replica groups
+        replicas = values.get("replicas")
+        if isinstance(replicas, list):
+            return values
+
         if not values.get("commands") and not values.get("image"):
             raise ValueError("Either `commands` or `image` must be set")
         return values
@@ -714,6 +720,68 @@ def schema_extra(schema: Dict[str, Any]):
     )
 
 
+class ReplicaGroup(CoreModel):
+    name: Annotated[
+        Optional[str],
+        Field(
+            description="The name of the replica group. If not provided, defaults to 'replica0', 'replica1', etc., based on position."
+        ),
+    ]
+    count: Annotated[
+        Range[int],
+        Field(
+            description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
+            "If it's a range, the `scaling` property is required"
+        ),
+    ]
+    scaling: Annotated[
+        Optional[ScalingSpec],
+        Field(description="The auto-scaling rules. Required if `count` is set to a range"),
+    ] = None
+    resources: Annotated[
+        ResourcesSpec,
+        Field(description="The resource requirements for replicas in this group"),
+    ] = ResourcesSpec()
+    commands: Annotated[
+        CommandsList,
+        Field(description="The shell commands to run for replicas in this group"),
+    ]
+
+    @validator("name")
+    def validate_name(cls, v: Optional[str]) -> Optional[str]:
+        if v is not None:
+            validate_dstack_resource_name(v)
+        return v
+
+    @validator("count")
+    def convert_count(cls, v: Range[int]) -> Range[int]:
+        if v.max is None:
+            raise ValueError("The maximum number of replicas is required")
+        if v.min is None:
+            v.min = 0
+        if v.min < 0:
+            raise ValueError("The minimum number of replicas must be greater than or equal to 0")
+        return v
+
+    @validator("commands")
+    def validate_commands(cls, v: CommandsList) -> CommandsList:
+        if not v:
+            raise ValueError("`commands` must be set for replica groups")
+        return v
+
+    @root_validator()
+    def validate_scaling(cls, values):
+        scaling = values.get("scaling")
+        count = values.get("count")
+        if count and count.min != count.max and not scaling:
+            raise ValueError("When you set `count` to a range, make sure to specify `scaling`.")
+        if count and count.min == count.max and scaling:
+            raise ValueError("To use `scaling`, `count` must be set to a range.")
+        return values
+
+
 class ServiceConfigurationParams(CoreModel):
     port: Annotated[
         # NOTE: it's a PortMapping for historical reasons. Only `port.container_port` is used.
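The new model can be exercised directly. A minimal sketch, assuming `ReplicaGroup`, `ScalingSpec`, and `Range` import as below; the group names and values are illustrative:

```python
from dstack._internal.core.models.configurations import ReplicaGroup, ScalingSpec
from dstack._internal.core.models.resources import Range

# A fixed-size group: a plain count, no `scaling` allowed.
cpu = ReplicaGroup(name="cpu", count=Range[int](min=2, max=2), commands=["python serve.py"])

# An autoscaled group: a range requires `scaling`.
gpu = ReplicaGroup(
    name="gpu",
    count=Range[int](min=0, max=4),
    scaling=ScalingSpec(metric="rps", target=10),
    commands=["python serve.py --gpu"],
)
```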
@@ -755,13 +823,7 @@ class ServiceConfigurationParams(CoreModel):
             SERVICE_HTTPS_DEFAULT
         )
     auth: Annotated[bool, Field(description="Enable the authorization")] = True
-    replicas: Annotated[
-        Range[int],
-        Field(
-            description="The number of replicas. Can be a number (e.g. `2`) or a range (`0..4` or `1..8`). "
-            "If it's a range, the `scaling` property is required"
-        ),
-    ] = Range[int](min=1, max=1)
+
     scaling: Annotated[
         Optional[ScalingSpec],
         Field(description="The auto-scaling rules. Required if `replicas` is set to a range"),
@@ -772,6 +834,19 @@ class ServiceConfigurationParams(CoreModel):
         Field(description="List of probes used to determine job health"),
     ] = []
 
+    replicas: Annotated[
+        Optional[Union[List[ReplicaGroup], Range[int]]],
+        Field(
+            description=(
+                "List of replica groups, or a replica count/range for a single default group. "
+                "Each group defines replicas with a shared configuration (commands, resources, scaling). "
+                "When a list is specified, the top-level `commands`, `resources`, and `scaling` "
+                "must be set per group instead. Each replica group must have a unique name."
+            )
+        ),
+    ] = None
+
     @validator("port")
     def convert_port(cls, v) -> PortMapping:
         if isinstance(v, int):
@@ -786,26 +861,6 @@ class ServiceConfigurationParams(CoreModel):
             return OpenAIChatModel(type="chat", name=v, format="openai")
         return v
 
-    @validator("replicas")
-    def convert_replicas(cls, v: Range[int]) -> Range[int]:
-        if v.max is None:
-            raise ValueError("The maximum number of replicas is required")
-        if v.min is None:
-            v.min = 0
-        if v.min < 0:
-            raise ValueError("The minimum number of replicas must be greater than or equal to 0")
-        return v
-
-    @root_validator()
-    def validate_scaling(cls, values):
-        scaling = values.get("scaling")
-        replicas = values.get("replicas")
-        if replicas and replicas.min != replicas.max and not scaling:
-            raise ValueError("When you set `replicas` to a range, ensure to specify `scaling`.")
-        if replicas and replicas.min == replicas.max and scaling:
-            raise ValueError("To use `scaling`, `replicas` must be set to a range.")
-        return values
-
     @validator("rate_limits")
     def validate_rate_limits(cls, v: list[RateLimit]) -> list[RateLimit]:
         counts = Counter(limit.prefix for limit in v)
@@ -827,6 +882,91 @@ class ServiceConfigurationParams(CoreModel):
             raise ValueError("Probes must be unique")
         return v
 
+    @validator("replicas")
+    def validate_replicas(
+        cls, v: Optional[Union[Range[int], List[ReplicaGroup]]]
+    ) -> Optional[Union[Range[int], List[ReplicaGroup]]]:
+        if v is None:
+            return v
+        if isinstance(v, Range):
+            if v.max is None:
+                raise ValueError("The maximum number of replicas is required")
+            if v.min is None:
+                v.min = 0
+            if v.min < 0:
+                raise ValueError(
+                    "The minimum number of replicas must be greater than or equal to 0"
+                )
+            return v
+        if isinstance(v, list):
+            if not v:
+                raise ValueError("`replicas` cannot be an empty list")
+            # Assign default names to groups without names
+            for index, group in enumerate(v):
+                if group.name is None:
+                    group.name = f"replica{index}"
+            # Check for duplicate names
+            names = [group.name for group in v]
+            if len(names) != len(set(names)):
+                duplicates = [name for name in set(names) if names.count(name) > 1]
+                raise ValueError(
+                    f"Duplicate replica group names found: {duplicates}. "
+                    "Each replica group must have a unique name."
+                )
+        return v
+
+    @root_validator()
+    def validate_scaling(cls, values):
+        scaling = values.get("scaling")
+        replicas = values.get("replicas")
+        if isinstance(replicas, Range):
+            if replicas.min != replicas.max and not scaling:
+                raise ValueError(
+                    "When you set `replicas` to a range, make sure to specify `scaling`."
+                )
+            if replicas.min == replicas.max and scaling:
+                raise ValueError("To use `scaling`, `replicas` must be set to a range.")
+        return values
+
+    @root_validator()
+    def validate_top_level_properties_with_replica_groups(cls, values):
+        """
+        When `replicas` is a list of ReplicaGroup, forbid top-level `scaling`, `commands`, and `resources`.
+        """
+        replicas = values.get("replicas")
+        if not isinstance(replicas, list):
+            return values
+
+        scaling = values.get("scaling")
+        if scaling is not None:
+            raise ValueError(
+                "Top-level `scaling` is not allowed when `replicas` is a list. "
+                "Specify `scaling` in each replica group instead."
+            )
+
+        commands = values.get("commands", [])
+        if commands:
+            raise ValueError(
+                "Top-level `commands` is not allowed when `replicas` is a list. "
+                "Specify `commands` in each replica group instead."
+            )
+
+        resources = values.get("resources")
+        default_resources = ResourcesSpec()
+        if resources and resources.dict() != default_resources.dict():
+            raise ValueError(
+                "Top-level `resources` is not allowed when `replicas` is a list. "
+                "Specify `resources` in each replica group instead."
+            )
+
+        return values
+
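A sketch of the resulting validation behavior, assuming the models import as in the previous sketch (the `port` value is illustrative): unnamed groups get positional defaults, duplicates are rejected, and top-level `scaling`/`commands`/`resources` conflict with the list form:

```python
from dstack._internal.core.models.configurations import ReplicaGroup, ServiceConfiguration
from dstack._internal.core.models.resources import Range

cfg = ServiceConfiguration(
    port=8000,
    replicas=[
        ReplicaGroup(count=Range[int](min=1, max=1), commands=["a"]),
        ReplicaGroup(count=Range[int](min=1, max=1), commands=["b"]),
    ],
)
assert [g.name for g in cfg.replicas] == ["replica0", "replica1"]  # defaults assigned

# Top-level commands combined with a group list should raise a ValueError:
# ServiceConfiguration(port=8000, commands=["x"], replicas=[...])
```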
 class ServiceConfigurationConfig(
     ProfileParamsConfig,
@@ -849,6 +989,38 @@ class ServiceConfiguration(
 ):
     type: Literal["service"] = "service"
 
+    @property
+    def replica_groups(self) -> List[ReplicaGroup]:
+        """
+        Get the normalized replica groups. After validation, `replicas` is either
+        None, a Range[int], or a List[ReplicaGroup]; this property normalizes all
+        three shapes to a list for type-safe access in code.
+        """
+        if self.replicas is None:
+            return [
+                ReplicaGroup(
+                    name="default",
+                    count=Range[int](min=1, max=1),
+                    commands=self.commands or [],
+                    resources=self.resources,
+                    scaling=self.scaling,
+                )
+            ]
+        if isinstance(self.replicas, list):
+            return self.replicas
+        if isinstance(self.replicas, Range):
+            return [
+                ReplicaGroup(
+                    name="default",
+                    count=self.replicas,
+                    commands=self.commands or [],
+                    resources=self.resources,
+                    scaling=self.scaling,
+                )
+            ]
+        raise ValueError(
+            f"Invalid replicas type: {type(self.replicas)}. Expected None, Range[int], or List[ReplicaGroup]"
+        )
+
 
 AnyRunConfiguration = Union[DevEnvironmentConfiguration, TaskConfiguration, ServiceConfiguration]
diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py
index a966bc34a..7e8d89f0e 100644
--- a/src/dstack/_internal/core/models/runs.py
+++ b/src/dstack/_internal/core/models/runs.py
@@ -253,6 +253,7 @@ class JobSpec(CoreModel):
     job_num: int
     job_name: str
     jobs_per_replica: int = 1  # default value for backward compatibility
+    replica_group: str = "default"
     app_specs: Optional[List[AppSpec]]
     user: Optional[UnixUser] = None  # default value for backward compatibility
     commands: List[str]
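The property keeps older call sites working. A sketch of the normalization, with illustrative values and the same assumed imports as above:

```python
legacy = ServiceConfiguration(
    port=8000, commands=["python app.py"], replicas=Range[int](min=2, max=2)
)
(group,) = legacy.replica_groups              # a single synthetic "default" group
assert group.name == "default" and group.count.min == 2
assert group.commands == ["python app.py"]    # inherits the top-level commands
```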
diff --git a/src/dstack/_internal/server/background/tasks/process_runs.py b/src/dstack/_internal/server/background/tasks/process_runs.py
index af2dcee8d..62363a6f2 100644
--- a/src/dstack/_internal/server/background/tasks/process_runs.py
+++ b/src/dstack/_internal/server/background/tasks/process_runs.py
@@ -1,5 +1,6 @@
 import asyncio
 import datetime
+import json
 from typing import List, Optional, Set, Tuple
 
 from sqlalchemy import and_, or_, select
@@ -8,6 +9,7 @@
 import dstack._internal.server.services.services.autoscalers as autoscalers
 from dstack._internal.core.errors import ServerError
+from dstack._internal.core.models.configurations import ReplicaGroup
 from dstack._internal.core.models.profiles import RetryEvent, StopCriteria
 from dstack._internal.core.models.runs import (
     Job,
@@ -47,6 +49,7 @@
     is_replica_registered,
     retry_run_replica_jobs,
     scale_run_replicas,
+    scale_run_replicas_per_group,
 )
 from dstack._internal.server.services.secrets import get_project_secrets_mapping
 from dstack._internal.server.services.services import update_service_desired_replica_count
@@ -192,9 +195,10 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
         logger.debug("%s: retrying run is not yet ready for resubmission", fmt(run_model))
         return
 
-    run_model.desired_replica_count = 1
     if run.run_spec.configuration.type == "service":
-        run_model.desired_replica_count = run.run_spec.configuration.replicas.min or 0
+        run_model.desired_replica_count = sum(
+            group.count.min or 0 for group in (run.run_spec.configuration.replica_groups or [])
+        )
         await update_service_desired_replica_count(
             session,
             run_model,
@@ -203,12 +207,23 @@ async def _process_pending_run(session: AsyncSession, run_model: RunModel):
             last_scaled_at=None,
         )
-        if run_model.desired_replica_count == 0:
-            # stay zero scaled
-            return
+        if run_model.desired_replica_count == 0:
+            # stay zero scaled
+            return
 
-    await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count)
-    switch_run_status(session, run_model, RunStatus.SUBMITTED)
+        # Scale per group: even a single replica is normalized to replica groups.
+        replicas: List[ReplicaGroup] = run.run_spec.configuration.replica_groups or []
+        counts = (
+            json.loads(run_model.desired_replica_counts)
+            if run_model.desired_replica_counts
+            else {}
+        )
+        await scale_run_replicas_per_group(session, run_model, replicas, counts)
+    else:
+        run_model.desired_replica_count = 1
+        await scale_run_replicas(session, run_model, replicas_diff=run_model.desired_replica_count)
+
+    switch_run_status(session=session, run_model=run_model, new_status=RunStatus.SUBMITTED)
 
 
 def _retrying_run_ready_for_resubmission(run_model: RunModel, run: Run) -> bool:
@@ -444,6 +459,32 @@ async def _handle_run_replicas(
         # FIXME: should only include scaling events, not retries and deployments
         last_scaled_at=max((r.timestamp for r in replicas_info), default=None),
     )
+    replicas: List[ReplicaGroup] = run_spec.configuration.replica_groups or []
+    if replicas:
+        counts = (
+            json.loads(run_model.desired_replica_counts)
+            if run_model.desired_replica_counts
+            else {}
+        )
+        await scale_run_replicas_per_group(session, run_model, replicas, counts)
+
+        # Bump deployment_num in place for jobs that do not require redeployment
+        await _update_jobs_to_new_deployment_in_place(
+            session=session,
+            run_model=run_model,
+            run_spec=run_spec,
+            replicas=replicas,
+        )
+        # Then roll out each group independently
+        for group in replicas:
+            await _handle_rolling_deployment_for_group(
+                session=session,
+                run_model=run_model,
+                group=group,
+                run_spec=run_spec,
+                desired_replica_counts=counts,
+            )
+        return
 
     max_replica_count = run_model.desired_replica_count
     if _has_out_of_date_replicas(run_model):
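For clarity, a sketch of the persisted per-group state that drives this logic; the JSON shape mirrors `RunModel.desired_replica_counts`, with illustrative values:

```python
import json

desired_replica_counts = json.dumps({"cpu": 2, "gpu": 1})  # stored as Text on the run
counts = json.loads(desired_replica_counts)
total = sum(counts.values())  # kept in sync with run_model.desired_replica_count
assert total == 3
```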
@@ -509,7 +550,10 @@ async def _handle_run_replicas(
 
 
 async def _update_jobs_to_new_deployment_in_place(
-    session: AsyncSession, run_model: RunModel, run_spec: RunSpec
+    session: AsyncSession,
+    run_model: RunModel,
+    run_spec: RunSpec,
+    replicas: Optional[List[ReplicaGroup]] = None,
 ) -> None:
     """
     Bump deployment_num for jobs that do not require redeployment.
@@ -518,16 +562,26 @@ async def _update_jobs_to_new_deployment_in_place(
         session=session,
         project=run_model.project,
     )
+
     for replica_num, job_models in group_jobs_by_replica_latest(run_model.jobs):
         if all(j.status.is_finished() for j in job_models):
             continue
         if all(j.deployment_num == run_model.deployment_num for j in job_models):
             continue
+
+        # Determine which group this replica belongs to
+        replica_group_name = None
+        if replicas:
+            job_spec = JobSpec.__response__.parse_raw(job_models[0].job_spec_data)
+            replica_group_name = job_spec.replica_group
+
         # FIXME: Handle getting image configuration errors or skip it.
         new_job_specs = await get_job_specs_from_run_spec(
             run_spec=run_spec,
             secrets=secrets,
             replica_num=replica_num,
+            replica_group_name=replica_group_name,
         )
         assert len(new_job_specs) == len(job_models), (
             "Changing the number of jobs within a replica is not yet supported"
@@ -543,8 +597,15 @@ async def _update_jobs_to_new_deployment_in_place(
             job_model.deployment_num = run_model.deployment_num
 
 
-def _has_out_of_date_replicas(run: RunModel) -> bool:
+def _has_out_of_date_replicas(run: RunModel, group_filter: Optional[str] = None) -> bool:
     for job in run.jobs:
+        # Filter jobs by group if specified
+        if group_filter is not None:
+            job_spec = JobSpec.__response__.parse_raw(job.job_spec_data)
+            # Treat a missing group as "default" for backward compatibility
+            job_replica_group = job_spec.replica_group or "default"
+            if job_replica_group != group_filter:
+                continue
         if job.deployment_num < run.deployment_num and not (
             job.status.is_finished() or job.termination_reason == JobTerminationReason.SCALED_DOWN
         ):
@@ -607,3 +668,111 @@ def _should_stop_on_master_done(run: Run) -> bool:
         if is_master_job(job) and job.job_submissions[-1].status == JobStatus.DONE:
             return True
     return False
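A sketch of the backward-compatible group matching used above (a pure function with illustrative inputs):

```python
def matches_group(job_replica_group, group_filter):
    # Jobs created before this change carry no group; treat them as "default".
    return (job_replica_group or "default") == group_filter

assert matches_group(None, "default")
assert not matches_group("gpu", "default")
```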
+
+
+async def _handle_rolling_deployment_for_group(
+    session: AsyncSession,
+    run_model: RunModel,
+    group: ReplicaGroup,
+    run_spec: RunSpec,
+    desired_replica_counts: dict,
+) -> None:
+    """
+    Handle rolling deployment for a single replica group.
+    """
+    # Imported locally to avoid a circular import with the replicas module
+    from dstack._internal.server.services.runs.replicas import (
+        _build_replica_lists,
+        scale_run_replicas_for_group,
+    )
+
+    group_desired = desired_replica_counts.get(group.name, group.count.min or 0)
+
+    # Check whether the group has out-of-date replicas
+    if not _has_out_of_date_replicas(run_model, group_filter=group.name):
+        return  # the group is up-to-date
+
+    # Calculate the max replica count (allow surge during deployment)
+    group_max_replica_count = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE
+
+    # Count non-terminated replicas for this group only
+    non_terminated_replica_count = len(
+        {
+            j.replica_num
+            for j in run_model.jobs
+            if not j.status.is_finished()
+            and group.name is not None
+            and _job_belongs_to_group(job=j, group_name=group.name)
+        }
+    )
+
+    # Start new up-to-date replicas if needed
+    if non_terminated_replica_count < group_max_replica_count:
+        active_replicas, inactive_replicas = _build_replica_lists(
+            run_model=run_model,
+            jobs=run_model.jobs,
+            group_filter=group.name,
+        )
+        await scale_run_replicas_for_group(
+            session=session,
+            run_model=run_model,
+            group=group,
+            replicas_diff=group_max_replica_count - non_terminated_replica_count,
+            run_spec=run_spec,
+            active_replicas=active_replicas,
+            inactive_replicas=inactive_replicas,
+        )
+
+    # Stop out-of-date replicas that are not registered
+    replicas_to_stop_count = 0
+    for _, jobs in group_jobs_by_replica_latest(run_model.jobs):
+        job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data)
+        if job_spec.replica_group != group.name:
+            continue
+        # Check whether the replica is out-of-date and not registered
+        if (
+            any(j.deployment_num < run_model.deployment_num for j in jobs)
+            and any(
+                j.status not in [JobStatus.TERMINATING] + JobStatus.finished_statuses()
+                for j in jobs
+            )
+            and not is_replica_registered(jobs)
+        ):
+            replicas_to_stop_count += 1
+
+    # Stop excessive registered out-of-date replicas
+    non_terminating_registered_replicas_count = 0
+    for _, jobs in group_jobs_by_replica_latest(run_model.jobs):
+        # Filter by group
+        job_spec = JobSpec.__response__.parse_raw(jobs[0].job_spec_data)
+        if job_spec.replica_group != group.name:
+            continue
+        if is_replica_registered(jobs) and all(j.status != JobStatus.TERMINATING for j in jobs):
+            non_terminating_registered_replicas_count += 1
+
+    replicas_to_stop_count += max(0, non_terminating_registered_replicas_count - group_desired)
+
+    if replicas_to_stop_count > 0:
+        # Rebuild the lists to get the current state
+        active_replicas, inactive_replicas = _build_replica_lists(
+            run_model=run_model,
+            jobs=run_model.jobs,
+            group_filter=group.name,
+        )
+        await scale_run_replicas_for_group(
+            session=session,
+            run_model=run_model,
+            group=group,
+            replicas_diff=-replicas_to_stop_count,
+            run_spec=run_spec,
+            active_replicas=active_replicas,
+            inactive_replicas=inactive_replicas,
+        )
+
+
+def _job_belongs_to_group(job: JobModel, group_name: str) -> bool:
+    job_spec = JobSpec.__response__.parse_raw(job.job_spec_data)
+    return job_spec.replica_group == group_name
diff --git a/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py
new file mode 100644
index 000000000..af4611b3c
--- /dev/null
+++ b/src/dstack/_internal/server/migrations/versions/706e0acc3a7d_add_runmodel_desired_replica_counts.py
@@ -0,0 +1,26 @@
+"""add runmodel desired_replica_counts
+
+Revision ID: 706e0acc3a7d
+Revises: 903c91e24634
+Create Date: 2025-12-18 10:54:13.508297
+
+"""
+
+import sqlalchemy as sa
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "706e0acc3a7d"
+down_revision = "903c91e24634"
+branch_labels = None
+depends_on = None
+
+
+def upgrade() -> None:
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.add_column(sa.Column("desired_replica_counts", sa.Text(), nullable=True))
+
+
+def downgrade() -> None:
+    with op.batch_alter_table("runs", schema=None) as batch_op:
+        batch_op.drop_column("desired_replica_counts")
diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py
index 5274d9ebf..72170f3c9 100644
--- a/src/dstack/_internal/server/models.py
+++ b/src/dstack/_internal/server/models.py
@@ -405,7 +405,7 @@ class RunModel(BaseModel):
     priority: Mapped[int] = mapped_column(Integer, default=0)
     deployment_num: Mapped[int] = mapped_column(Integer)
     desired_replica_count: Mapped[int] = mapped_column(Integer)
-
+    desired_replica_counts: Mapped[Optional[str]] = mapped_column(Text, nullable=True)
     jobs: Mapped[List["JobModel"]] = relationship(
         back_populates="run", lazy="selectin", order_by="[JobModel.replica_num, JobModel.job_num]"
     )
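A sketch of the surge arithmetic used by the per-group rolling deployment; `ROLLING_DEPLOYMENT_MAX_SURGE` is the existing constant in `process_runs`, and its value here is assumed for illustration:

```python
ROLLING_DEPLOYMENT_MAX_SURGE = 1   # assumed value, for illustration only

group_desired = 3                  # desired replicas for the group
non_terminated = 3                 # current non-terminated replicas in the group
group_max = group_desired + ROLLING_DEPLOYMENT_MAX_SURGE
to_start = max(0, group_max - non_terminated)
assert to_start == 1               # one new up-to-date replica may surge in
```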
diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py
index 1ed3c5f99..75ab91696 100644
--- a/src/dstack/_internal/server/services/jobs/__init__.py
+++ b/src/dstack/_internal/server/services/jobs/__init__.py
@@ -98,7 +98,10 @@ def switch_job_status(
 
 
 async def get_jobs_from_run_spec(
-    run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+    run_spec: RunSpec,
+    secrets: Dict[str, str],
+    replica_num: int,
+    replica_group_name: Optional[str] = None,
 ) -> List[Job]:
     return [
         Job(job_spec=s, job_submissions=[])
@@ -106,14 +109,20 @@ async def get_jobs_from_run_spec(
             run_spec=run_spec,
             secrets=secrets,
             replica_num=replica_num,
+            replica_group_name=replica_group_name,
         )
     ]
 
 
 async def get_job_specs_from_run_spec(
-    run_spec: RunSpec, secrets: Dict[str, str], replica_num: int
+    run_spec: RunSpec,
+    secrets: Dict[str, str],
+    replica_num: int,
+    replica_group_name: Optional[str] = None,
 ) -> List[JobSpec]:
-    job_configurator = _get_job_configurator(run_spec=run_spec, secrets=secrets)
+    job_configurator = _get_job_configurator(
+        run_spec=run_spec, secrets=secrets, replica_group_name=replica_group_name
+    )
     job_specs = await job_configurator.get_job_specs(replica_num=replica_num)
     return job_specs
 
@@ -241,10 +250,14 @@ def is_master_job(job: Job) -> bool:
     return job.job_spec.job_num == 0
 
 
-def _get_job_configurator(run_spec: RunSpec, secrets: Dict[str, str]) -> JobConfigurator:
+def _get_job_configurator(
+    run_spec: RunSpec, secrets: Dict[str, str], replica_group_name: Optional[str] = None
+) -> JobConfigurator:
     configuration_type = RunConfigurationType(run_spec.configuration.type)
     configurator_class = _configuration_type_to_configurator_class_map[configuration_type]
-    return configurator_class(run_spec=run_spec, secrets=secrets)
+    return configurator_class(
+        run_spec=run_spec, secrets=secrets, replica_group_name=replica_group_name
+    )
 
 
 _job_configurator_classes = [
diff --git a/src/dstack/_internal/server/services/jobs/configurators/base.py b/src/dstack/_internal/server/services/jobs/configurators/base.py
index 0e770b2e9..d934fd8cf 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/base.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/base.py
@@ -90,9 +90,11 @@ def __init__(
         self,
         run_spec: RunSpec,
         secrets: Optional[Dict[str, str]] = None,
+        replica_group_name: Optional[str] = None,
     ):
         self.run_spec = run_spec
         self.secrets = secrets or {}
+        self.replica_group_name = replica_group_name
 
     async def get_job_specs(self, replica_num: int) -> List[JobSpec]:
         job_spec = await self._get_job_spec(replica_num=replica_num, job_num=0, jobs_per_replica=1)
@@ -150,6 +152,7 @@ async def _get_job_spec(
             job_num=job_num,
             job_name=f"{self.run_spec.run_name}-{job_num}-{replica_num}",
             jobs_per_replica=jobs_per_replica,
+            replica_group=self.replica_group_name or "default",
             app_specs=self._app_specs(),
             commands=await self._commands(),
             env=self._env(),
@@ -298,9 +301,15 @@ def _registry_auth(self) -> Optional[RegistryAuth]:
         return self.run_spec.configuration.registry_auth
 
     def _requirements(self, jobs_per_replica: int) -> Requirements:
+        resources = self.run_spec.configuration.resources
+        if self.run_spec.configuration.type == "service":
+            # Use the group-specific resources when a replica group is set
+            for group in self.run_spec.configuration.replica_groups:
+                if group.name == self.replica_group_name:
+                    resources = group.resources
+                    break
         spot_policy = self._spot_policy()
         return Requirements(
-            resources=self.run_spec.configuration.resources,
+            resources=resources,
             max_price=self.run_spec.merged_profile.max_price,
             spot=None if spot_policy == SpotPolicy.AUTO else (spot_policy == SpotPolicy.SPOT),
             reservation=self.run_spec.merged_profile.reservation,
diff --git a/src/dstack/_internal/server/services/jobs/configurators/service.py b/src/dstack/_internal/server/services/jobs/configurators/service.py
index be15c4b23..6b5aa8c2d 100644
--- a/src/dstack/_internal/server/services/jobs/configurators/service.py
+++ b/src/dstack/_internal/server/services/jobs/configurators/service.py
@@ -10,6 +10,9 @@ class ServiceJobConfigurator(JobConfigurator):
     def _shell_commands(self) -> List[str]:
         assert self.run_spec.configuration.type == "service"
+        # Prefer the group-specific commands when a replica group is set
+        for group in self.run_spec.configuration.replica_groups:
+            if group.name == self.replica_group_name:
+                return group.commands
         return self.run_spec.configuration.commands
 
     def _default_single_branch(self) -> bool:
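A condensed sketch of how a configurator resolves group-specific settings, mirroring `_requirements` and `_shell_commands` above (simplified; it falls back to the top-level values when no group matches):

```python
def resolve(configuration, replica_group_name):
    # Simplified stand-in for the configurator lookups above.
    commands = configuration.commands
    resources = configuration.resources
    for group in configuration.replica_groups:
        if group.name == replica_group_name:
            commands, resources = group.commands, group.resources
            break
    return commands, resources
```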
diff --git a/src/dstack/_internal/server/services/runs/__init__.py b/src/dstack/_internal/server/services/runs/__init__.py
index 5773403cf..4b705fdc6 100644
--- a/src/dstack/_internal/server/services/runs/__init__.py
+++ b/src/dstack/_internal/server/services/runs/__init__.py
@@ -486,12 +486,8 @@ async def submit_run(
     submitted_at = common_utils.get_current_datetime()
     initial_status = RunStatus.SUBMITTED
-    initial_replicas = 1
     if run_spec.merged_profile.schedule is not None:
         initial_status = RunStatus.PENDING
-        initial_replicas = 0
-    elif run_spec.configuration.type == "service":
-        initial_replicas = run_spec.configuration.replicas.min or 0
 
     run_model = RunModel(
         id=uuid.uuid4(),
@@ -519,12 +515,46 @@ async def submit_run(
     if run_spec.configuration.type == "service":
         await services.register_service(session, run_model, run_spec)
+        service_config = run_spec.configuration
 
-    for replica_num in range(initial_replicas):
+        global_replica_num = 0  # global counter across all groups for a unique replica_num
+        for replica_group in service_config.replica_groups:
+            if run_spec.merged_profile.schedule is not None:
+                group_initial_replicas = 0
+            else:
+                group_initial_replicas = replica_group.count.min or 0
+
+            # Each replica in this group gets the same group-specific configuration
+            for group_replica_num in range(group_initial_replicas):
+                jobs = await get_jobs_from_run_spec(
+                    run_spec=run_spec,
+                    secrets=secrets,
+                    replica_num=global_replica_num,
+                    replica_group_name=replica_group.name,
+                )
+                for job in jobs:
+                    job_model = create_job_model_for_new_submission(
+                        run_model=run_model,
+                        job=job,
+                        status=JobStatus.SUBMITTED,
+                    )
+                    session.add(job_model)
+                    events.emit(
+                        session,
+                        f"Job created on run submission. Status: {job_model.status.upper()}",
+                        actor=events.SystemActor(),
+                        targets=[
+                            events.Target.from_model(job_model),
+                        ],
+                    )
+                global_replica_num += 1
+    else:
         jobs = await get_jobs_from_run_spec(
             run_spec=run_spec,
             secrets=secrets,
-            replica_num=replica_num,
+            replica_num=0,
         )
         for job in jobs:
             job_model = create_job_model_for_new_submission(
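A sketch of the global replica numbering across groups (illustrative counts): two groups with minimum counts 2 and 1 produce `replica_num` values 0, 1, 2:

```python
group_mins = [("cpu", 2), ("gpu", 1)]   # illustrative (name, count.min) pairs
assignments, n = [], 0
for name, count in group_mins:
    for _ in range(count):
        assignments.append((n, name))   # each replica gets a unique global number
        n += 1
assert assignments == [(0, "cpu"), (1, "cpu"), (2, "gpu")]
```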
diff --git a/src/dstack/_internal/server/services/runs/plan.py b/src/dstack/_internal/server/services/runs/plan.py
index a5b20b15b..10aba7d25 100644
--- a/src/dstack/_internal/server/services/runs/plan.py
+++ b/src/dstack/_internal/server/services/runs/plan.py
@@ -78,57 +78,117 @@ async def get_job_plans(
         run_spec.run_name = "dry-run"
 
     secrets = await get_project_secrets_mapping(session=session, project=project)
-    jobs = await get_jobs_from_run_spec(
-        run_spec=run_spec,
-        secrets=secrets,
-        replica_num=0,
-    )
-    volumes = await get_job_configured_volumes(
-        session=session,
-        project=project,
-        run_spec=run_spec,
-        job_num=0,
-    )
-    candidate_fleet_models = await _select_candidate_fleet_models(
-        session=session,
-        project=project,
-        run_model=None,
-        run_spec=run_spec,
-    )
-    fleet_model, instance_offers, backend_offers = await find_optimal_fleet_with_offers(
-        project=project,
-        fleet_models=candidate_fleet_models,
-        run_model=None,
-        run_spec=run_spec,
-        job=jobs[0],
-        master_job_provisioning_data=None,
-        volumes=volumes,
-        exclude_not_available=False,
-    )
-    if _should_force_non_fleet_offers(run_spec) or (
-        FeatureFlags.AUTOCREATED_FLEETS_ENABLED and profile.fleets is None and fleet_model is None
-    ):
-        # Keep the old behavior returning all offers irrespective of fleets.
-        # Needed for supporting offers with autocreated fleets flow (and for `dstack offer`).
-        instance_offers, backend_offers = await _get_non_fleet_offers(
+
+    job_plans = []
+
+    if run_spec.configuration.type == "service":
+        for replica_group in run_spec.configuration.replica_groups:
+            jobs = await get_jobs_from_run_spec(
+                run_spec=run_spec,
+                secrets=secrets,
+                replica_num=0,
+                replica_group_name=replica_group.name,
+            )
+            volumes = await get_job_configured_volumes(
+                session=session,
+                project=project,
+                run_spec=run_spec,
+                job_num=0,
+            )
+            candidate_fleet_models = await _select_candidate_fleet_models(
+                session=session,
+                project=project,
+                run_model=None,
+                run_spec=run_spec,
+            )
+            fleet_model, instance_offers, backend_offers = await find_optimal_fleet_with_offers(
+                project=project,
+                fleet_models=candidate_fleet_models,
+                run_model=None,
+                run_spec=run_spec,
+                job=jobs[0],
+                master_job_provisioning_data=None,
+                volumes=volumes,
+                exclude_not_available=False,
+            )
+            if _should_force_non_fleet_offers(run_spec) or (
+                FeatureFlags.AUTOCREATED_FLEETS_ENABLED
+                and profile.fleets is None
+                and fleet_model is None
+            ):
+                # Keep the old behavior returning all offers irrespective of fleets.
+                # Needed for supporting offers with autocreated fleets flow (and for `dstack offer`).
+                instance_offers, backend_offers = await _get_non_fleet_offers(
+                    session=session,
+                    project=project,
+                    profile=profile,
+                    run_spec=run_spec,
+                    job=jobs[0],
+                    volumes=volumes,
+                )
+
+            for job in jobs:
+                job_plan = _get_job_plan(
+                    instance_offers=instance_offers,
+                    backend_offers=backend_offers,
+                    profile=profile,
+                    job=job,
+                    max_offers=max_offers,
+                )
+                job_plans.append(job_plan)
+    else:
+        jobs = await get_jobs_from_run_spec(
+            run_spec=run_spec,
+            secrets=secrets,
+            replica_num=0,
+        )
+        volumes = await get_job_configured_volumes(
             session=session,
             project=project,
-            profile=profile,
+            run_spec=run_spec,
+            job_num=0,
+        )
+        candidate_fleet_models = await _select_candidate_fleet_models(
+            session=session,
+            project=project,
+            run_model=None,
+            run_spec=run_spec,
+        )
+        fleet_model, instance_offers, backend_offers = await find_optimal_fleet_with_offers(
+            project=project,
+            fleet_models=candidate_fleet_models,
+            run_model=None,
             run_spec=run_spec,
             job=jobs[0],
+            master_job_provisioning_data=None,
             volumes=volumes,
+            exclude_not_available=False,
         )
+        if _should_force_non_fleet_offers(run_spec) or (
+            FeatureFlags.AUTOCREATED_FLEETS_ENABLED
+            and profile.fleets is None
+            and fleet_model is None
+        ):
+            # Keep the old behavior returning all offers irrespective of fleets.
+            # Needed for supporting offers with autocreated fleets flow (and for `dstack offer`).
+            instance_offers, backend_offers = await _get_non_fleet_offers(
+                session=session,
+                project=project,
+                profile=profile,
+                run_spec=run_spec,
+                job=jobs[0],
+                volumes=volumes,
+            )
 
-    job_plans = []
-    for job in jobs:
-        job_plan = _get_job_plan(
-            instance_offers=instance_offers,
-            backend_offers=backend_offers,
-            profile=profile,
-            job=job,
-            max_offers=max_offers,
-        )
-        job_plans.append(job_plan)
+        for job in jobs:
+            job_plan = _get_job_plan(
+                instance_offers=instance_offers,
+                backend_offers=backend_offers,
+                profile=profile,
+                job=job,
+                max_offers=max_offers,
+            )
+            job_plans.append(job_plan)
 
     run_spec.run_name = run_name
     return job_plans
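The dry run now computes offers per group, so a service with two groups yields one set of job plans per group. A sketch of the resulting shape (group names illustrative, not the actual API objects):

```python
def plan_groups(group_names):
    # One planning pass (and one JobPlan per job) per replica group.
    return [f"plan:{name}" for name in group_names]

assert plan_groups(["cpu", "gpu"]) == ["plan:cpu", "plan:gpu"]
```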
diff --git a/src/dstack/_internal/server/services/runs/replicas.py b/src/dstack/_internal/server/services/runs/replicas.py
index 43065d96d..35c421c63 100644
--- a/src/dstack/_internal/server/services/runs/replicas.py
+++ b/src/dstack/_internal/server/services/runs/replicas.py
@@ -1,8 +1,9 @@
-from typing import List
+from typing import Dict, List, Optional, Tuple
 
 from sqlalchemy.ext.asyncio import AsyncSession
 
-from dstack._internal.core.models.runs import JobStatus, JobTerminationReason, RunSpec
+from dstack._internal.core.models.configurations import ReplicaGroup
+from dstack._internal.core.models.runs import JobSpec, JobStatus, JobTerminationReason, RunSpec
 from dstack._internal.server.models import JobModel, RunModel
 from dstack._internal.server.services import events
 from dstack._internal.server.services.jobs import (
@@ -11,7 +12,10 @@
     switch_job_status,
 )
 from dstack._internal.server.services.logging import fmt
-from dstack._internal.server.services.runs import create_job_model_for_new_submission, logger
+from dstack._internal.server.services.runs import (
+    create_job_model_for_new_submission,
+    logger,
+)
 from dstack._internal.server.services.secrets import get_project_secrets_mapping
 
@@ -23,10 +27,17 @@ async def retry_run_replica_jobs(
         session=session,
         project=run_model.project,
     )
+
+    # Determine the replica group from an existing job
+    run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
+    job_spec = JobSpec.parse_raw(latest_jobs[0].job_spec_data)
+    replica_group_name = job_spec.replica_group
+
     new_jobs = await get_jobs_from_run_spec(
-        run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
+        run_spec=run_spec,
         secrets=secrets,
         replica_num=latest_jobs[0].replica_num,
+        replica_group_name=replica_group_name,
    )
     assert len(new_jobs) == len(latest_jobs), (
         "Changing the number of jobs within a replica is not yet supported"
@@ -64,7 +75,6 @@ def is_replica_registered(jobs: list[JobModel]) -> bool:
 
 async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replicas_diff: int):
     if replicas_diff == 0:
-        # nothing to do
         return
 
     logger.info(
@@ -74,14 +84,48 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
         abs(replicas_diff),
     )
 
+    active_replicas, inactive_replicas = _build_replica_lists(run_model, run_model.jobs)
+    run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
+
+    if replicas_diff < 0:
+        _scale_down_replicas(active_replicas, abs(replicas_diff))
+    else:
+        await _scale_up_replicas(
+            session,
+            run_model,
+            active_replicas,
+            inactive_replicas,
+            replicas_diff,
+            run_spec,
+            group_name=None,
+        )
+
+
+def _build_replica_lists(
+    run_model: RunModel,
+    jobs: List[JobModel],
+    group_filter: Optional[str] = None,
+) -> Tuple[
+    List[Tuple[int, bool, int, List[JobModel]]], List[Tuple[int, bool, int, List[JobModel]]]
+]:  # lists of (importance, is_out_of_date, replica_num, jobs)
     active_replicas = []
     inactive_replicas = []
-    for replica_num, replica_jobs in group_jobs_by_replica_latest(run_model.jobs):
+    for replica_num, replica_jobs in group_jobs_by_replica_latest(jobs):
+        # Filter by group if specified
+        if group_filter is not None:
+            try:
+                job_spec = JobSpec.parse_raw(replica_jobs[0].job_spec_data)
+                if job_spec.replica_group != group_filter:
+                    continue
+            except Exception:
+                continue
+
         statuses = set(job.status for job in replica_jobs)
         deployment_num = replica_jobs[0].deployment_num  # same for all jobs
         is_out_of_date = deployment_num < run_model.deployment_num
+
         if {JobStatus.TERMINATING, *JobStatus.finished_statuses()} & statuses:
             # if there are any terminating or finished jobs, the replica is inactive
             inactive_replicas.append((0, is_out_of_date, replica_num, replica_jobs))
@@ -98,42 +142,67 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
             # all jobs are running and ready, the replica is active and has the importance of 3
             active_replicas.append((3, is_out_of_date, replica_num, replica_jobs))
 
-    # sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc)
+    # Sort by is_out_of_date (up-to-date first), importance (desc), and replica_num (asc)
     active_replicas.sort(key=lambda r: (r[1], -r[0], r[2]))
 
-    run_spec = RunSpec.__response__.parse_raw(run_model.run_spec)
-    if replicas_diff < 0:
-        for _, _, _, replica_jobs in reversed(active_replicas[-abs(replicas_diff) :]):
-            # scale down the less important replicas first
-            for job in replica_jobs:
-                if job.status.is_finished() or job.status == JobStatus.TERMINATING:
-                    continue
-                job.status = JobStatus.TERMINATING
-                job.termination_reason = JobTerminationReason.SCALED_DOWN
-                # background task will process the job later
-    else:
-        scheduled_replicas = 0
+    return active_replicas, inactive_replicas
+
 
-        # rerun inactive replicas
-        for _, _, _, replica_jobs in inactive_replicas:
-            if scheduled_replicas == replicas_diff:
-                break
-            await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
-            scheduled_replicas += 1
+def _scale_down_replicas(
+    active_replicas: List[Tuple[int, bool, int, List[JobModel]]],
+    count: int,
+) -> None:
+    """Scale down by terminating the least important replicas."""
+    if count <= 0:
+        return
+    for _, _, _, replica_jobs in reversed(active_replicas[-count:]):
+        for job in replica_jobs:
+            if job.status.is_finished() or job.status == JobStatus.TERMINATING:
+                continue
+            job.status = JobStatus.TERMINATING
+            job.termination_reason = JobTerminationReason.SCALED_DOWN
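A sketch of the ordering contract for the replica tuples (illustrative data): up-to-date, more important, lower-numbered replicas sort first, so scale-down terminates from the tail:

```python
# (importance, is_out_of_date, replica_num, jobs)
replicas = [(3, False, 0, []), (1, False, 2, []), (3, True, 1, [])]
replicas.sort(key=lambda r: (r[1], -r[0], r[2]))
assert [r[2] for r in replicas] == [0, 2, 1]  # the out-of-date replica is terminated first
```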
+
+
+async def _scale_up_replicas(
+    session: AsyncSession,
+    run_model: RunModel,
+    active_replicas: List[Tuple[int, bool, int, List[JobModel]]],
+    inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]],
+    replicas_diff: int,
+    run_spec: RunSpec,
+    group_name: Optional[str] = None,
+) -> None:
+    """Scale up by retrying inactive replicas and creating new ones."""
+    if replicas_diff <= 0:
+        return
+
+    scheduled_replicas = 0
+
+    # Retry inactive replicas first
+    for _, _, _, replica_jobs in inactive_replicas:
+        if scheduled_replicas == replicas_diff:
+            break
+        await retry_run_replica_jobs(session, run_model, replica_jobs, only_failed=False)
+        scheduled_replicas += 1
+
+    # Create new replicas
+    if scheduled_replicas < replicas_diff:
         secrets = await get_project_secrets_mapping(
             session=session,
             project=run_model.project,
         )
-        for replica_num in range(
-            len(active_replicas) + scheduled_replicas, len(active_replicas) + replicas_diff
-        ):
-            # FIXME: Handle getting image configuration errors or skip it.
+        max_replica_num = max((job.replica_num for job in run_model.jobs), default=-1)
+
+        new_replicas_needed = replicas_diff - scheduled_replicas
+        for i in range(new_replicas_needed):
+            new_replica_num = max_replica_num + 1 + i
             jobs = await get_jobs_from_run_spec(
                 run_spec=run_spec,
                 secrets=secrets,
-                replica_num=replica_num,
+                replica_num=new_replica_num,
+                replica_group_name=group_name,
             )
             for job in jobs:
                 job_model = create_job_model_for_new_submission(
@@ -148,3 +217,89 @@ async def scale_run_replicas(session: AsyncSession, run_model: RunModel, replica
                     actor=events.SystemActor(),
                     targets=[events.Target.from_model(job_model)],
                 )
+                # Append to run_model.jobs so that, when processing later replica groups in the
+                # same transaction, run_model.jobs includes jobs from previously processed groups.
+                run_model.jobs.append(job_model)
+
+
+async def scale_run_replicas_per_group(
+    session: AsyncSession,
+    run_model: RunModel,
+    replicas: List[ReplicaGroup],
+    desired_replica_counts: Dict[str, int],
+) -> None:
+    """Scale each replica group independently."""
+    if not replicas:
+        return
+
+    for group in replicas:
+        if group.name is None:
+            continue
+        group_desired = desired_replica_counts.get(group.name, group.count.min or 0)
+
+        # Build replica lists filtered by this group
+        active_replicas, inactive_replicas = _build_replica_lists(
+            run_model=run_model, jobs=run_model.jobs, group_filter=group.name
+        )
+
+        # Count active replicas
+        active_group_count = len(active_replicas)
+        group_diff = group_desired - active_group_count
+
+        if group_diff != 0:
+            # Check whether a rolling deployment is in progress for this group.
+            # Imported locally to avoid a circular import with process_runs.
+            from dstack._internal.server.background.tasks.process_runs import (
+                _has_out_of_date_replicas,
+            )
+
+            group_has_out_of_date = _has_out_of_date_replicas(run_model, group_filter=group.name)
+
+            # During a rolling deployment, don't scale down old replicas here;
+            # let the rolling deployment handle stopping them.
+            if group_diff < 0 and group_has_out_of_date:
+                continue
+            await scale_run_replicas_for_group(
+                session=session,
+                run_model=run_model,
+                group=group,
+                replicas_diff=group_diff,
+                run_spec=RunSpec.__response__.parse_raw(run_model.run_spec),
+                active_replicas=active_replicas,
+                inactive_replicas=inactive_replicas,
+            )
+
+
+async def scale_run_replicas_for_group(
+    session: AsyncSession,
+    run_model: RunModel,
+    group: ReplicaGroup,
+    replicas_diff: int,
+    run_spec: RunSpec,
+    active_replicas: List[Tuple[int, bool, int, List[JobModel]]],
+    inactive_replicas: List[Tuple[int, bool, int, List[JobModel]]],
+) -> None:
+    """Scale a specific replica group up or down."""
+    if replicas_diff == 0:
+        return
+
+    logger.info(
+        "%s: scaling %s %s replica(s) for group '%s'",
+        fmt(run_model),
+        "UP" if replicas_diff > 0 else "DOWN",
+        abs(replicas_diff),
+        group.name,
+    )
+
+    if replicas_diff < 0:
+        _scale_down_replicas(active_replicas, abs(replicas_diff))
+    else:
+        await _scale_up_replicas(
+            session=session,
+            run_model=run_model,
+            active_replicas=active_replicas,
+            inactive_replicas=inactive_replicas,
+            replicas_diff=replicas_diff,
+            run_spec=run_spec,
+            group_name=group.name,
+        )
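A sketch of the per-group scaling decision (illustrative counts); each group's diff is computed independently from its desired and active counts:

```python
desired = {"cpu": 2, "gpu": 0}
active = {"cpu": 1, "gpu": 1}
diffs = {name: desired[name] - active[name] for name in desired}
assert diffs == {"cpu": 1, "gpu": -1}  # scale cpu up by one, gpu down by one
```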
diff --git a/src/dstack/_internal/server/services/runs/spec.py b/src/dstack/_internal/server/services/runs/spec.py
index 73b6d9fc7..f6fba1f59 100644
--- a/src/dstack/_internal/server/services/runs/spec.py
+++ b/src/dstack/_internal/server/services/runs/spec.py
@@ -88,7 +88,10 @@ def validate_run_spec_and_set_defaults(
             f"Maximum utilization_policy.time_window is {settings.SERVER_METRICS_RUNNING_TTL_SECONDS}s"
         )
     if isinstance(run_spec.configuration, ServiceConfiguration):
-        if run_spec.merged_profile.schedule and run_spec.configuration.replicas.min == 0:
+        # Check whether any group can scale to zero
+        if run_spec.merged_profile.schedule and any(
+            group.count.min == 0 for group in (run_spec.configuration.replica_groups or [])
+        ):
             raise ServerClientError(
                 "Scheduled services with autoscaling to zero are not supported"
             )
@@ -149,11 +152,10 @@ def get_nodes_required_num(run_spec: RunSpec) -> int:
     nodes_required_num = 1
     if run_spec.configuration.type == "task":
         nodes_required_num = run_spec.configuration.nodes
-    elif (
-        run_spec.configuration.type == "service"
-        and run_spec.configuration.replicas.min is not None
-    ):
-        nodes_required_num = run_spec.configuration.replicas.min
+    elif run_spec.configuration.type == "service":
+        nodes_required_num = sum(
+            group.count.min or 0 for group in (run_spec.configuration.replica_groups or [])
+        )
     return nodes_required_num
diff --git a/src/dstack/_internal/server/services/services/__init__.py b/src/dstack/_internal/server/services/services/__init__.py
index 39e8e98c6..d73aa3906 100644
--- a/src/dstack/_internal/server/services/services/__init__.py
+++ b/src/dstack/_internal/server/services/services/__init__.py
@@ -2,6 +2,7 @@
 Application logic related to `type: service` runs.
 """
 
+import json
 import uuid
 from datetime import datetime
 from typing import Optional
@@ -145,7 +146,12 @@ def _register_service_in_server(run_model: RunModel, run_spec: RunSpec) -> Servi
             "The `https` configuration property is not applicable when running services without a gateway."
             " Please configure a gateway or remove the `https` property from the service configuration"
         )
-    if run_spec.configuration.replicas.min != run_spec.configuration.replicas.max:
+    # Check whether any group has autoscaling (min != max)
+    has_autoscaling = any(
+        group.count.min != group.count.max
+        for group in (run_spec.configuration.replica_groups or [])
+    )
+    if has_autoscaling:
         raise ServerClientError(
             "Auto-scaling is not supported when running services without a gateway."
             " Please configure a gateway or set `replicas` to a fixed value in the service configuration"
         )
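Both checks above reduce to simple aggregations over the groups. A sketch with illustrative `(min, max)` counts per group:

```python
groups = [(2, 2), (0, 4)]                            # (count.min, count.max) per group
nodes_required = sum(g[0] for g in groups)           # 2: the sum of group minimums
has_autoscaling = any(g[0] != g[1] for g in groups)  # True: 0..4 is a range
```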
" Please configure a gateway or set `replicas` to a fixed value in the service configuration" @@ -303,13 +309,24 @@ async def update_service_desired_replica_count( configuration: ServiceConfiguration, last_scaled_at: Optional[datetime], ) -> None: - scaler = get_service_scaler(configuration) stats = None if run_model.gateway_id is not None: conn = await get_or_add_gateway_connection(session, run_model.gateway_id) stats = await conn.get_stats(run_model.project.name, run_model.run_name) - run_model.desired_replica_count = scaler.get_desired_count( - current_desired_count=run_model.desired_replica_count, - stats=stats, - last_scaled_at=last_scaled_at, + replica_groups = configuration.replica_groups + desired_replica_counts = {} + total = 0 + prev_counts = ( + json.loads(run_model.desired_replica_counts) if run_model.desired_replica_counts else {} ) + for group in replica_groups: + scaler = get_service_scaler(group.count, group.scaling) + group_desired = scaler.get_desired_count( + current_desired_count=prev_counts.get(group.name, group.count.min or 0), + stats=stats, + last_scaled_at=last_scaled_at, + ) + desired_replica_counts[group.name] = group_desired + total += group_desired + run_model.desired_replica_counts = json.dumps(desired_replica_counts) + run_model.desired_replica_count = total diff --git a/src/dstack/_internal/server/services/services/autoscalers.py b/src/dstack/_internal/server/services/services/autoscalers.py index cd6d06e58..641d2cee4 100644 --- a/src/dstack/_internal/server/services/services/autoscalers.py +++ b/src/dstack/_internal/server/services/services/autoscalers.py @@ -6,7 +6,8 @@ from pydantic import BaseModel import dstack._internal.utils.common as common_utils -from dstack._internal.core.models.configurations import ServiceConfiguration +from dstack._internal.core.models.configurations import ScalingSpec +from dstack._internal.core.models.resources import Range from dstack._internal.proxy.gateway.schemas.stats import PerWindowStats @@ -119,21 +120,21 @@ def get_desired_count( return new_desired_count -def get_service_scaler(conf: ServiceConfiguration) -> BaseServiceScaler: - assert conf.replicas.min is not None - assert conf.replicas.max is not None - if conf.scaling is None: +def get_service_scaler(count: Range[int], scaling: Optional[ScalingSpec]) -> BaseServiceScaler: + assert count.min is not None + assert count.max is not None + if scaling is None: return ManualScaler( - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, + min_replicas=count.min, + max_replicas=count.max, ) - if conf.scaling.metric == "rps": + if scaling.metric == "rps": return RPSAutoscaler( # replicas count validated by configuration model - min_replicas=conf.replicas.min, - max_replicas=conf.replicas.max, - target=conf.scaling.target, - scale_up_delay=conf.scaling.scale_up_delay, - scale_down_delay=conf.scaling.scale_down_delay, + min_replicas=count.min, + max_replicas=count.max, + target=scaling.target, + scale_up_delay=scaling.scale_up_delay, + scale_down_delay=scaling.scale_down_delay, ) - raise ValueError(f"No scaler found for scaling parameters {conf.scaling}") + raise ValueError(f"No scaler found for scaling parameters {scaling}") diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 5f5037c79..96adbd996 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -254,6 +254,7 @@ def get_dev_env_run_plan_dict( "replica_num": 0, "job_num": 0, 
"jobs_per_replica": 1, + "replica_group": "default", "single_branch": False, "max_duration": None, "stop_duration": 300, @@ -487,6 +488,7 @@ def get_dev_env_run_dict( "replica_num": 0, "job_num": 0, "jobs_per_replica": 1, + "replica_group": "default", "single_branch": False, "max_duration": None, "stop_duration": 300,