Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 6 additions & 5 deletions api_server_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,12 @@

import fastapi

from cloud_pipelines_backend import api_router
from cloud_pipelines_backend import database_ops
from cloud_pipelines_backend.instrumentation import api_tracing
from cloud_pipelines_backend.instrumentation import contextual_logging
from cloud_pipelines_backend.instrumentation import otel_tracing
from cloud_pipelines_backend import api_router, database_ops
from cloud_pipelines_backend.instrumentation import (
api_tracing,
contextual_logging,
otel_tracing,
)

app = fastapi.FastAPI(
title="Cloud Pipelines API",
Expand Down
16 changes: 6 additions & 10 deletions cloud_pipelines_backend/api_router.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,16 @@
from collections import abc
import contextlib
import dataclasses
import typing
import typing_extensions
from collections import abc

import fastapi
import sqlalchemy
from sqlalchemy import orm
import starlette.types
import typing_extensions
from sqlalchemy import orm


from . import api_server_sql
from . import backend_types_sql
from . import api_server_sql, backend_types_sql, database_ops, errors
from . import component_library_api_server as components_api
from . import database_ops
from . import errors
from .instrumentation import contextual_logging

if typing.TYPE_CHECKING:
Expand Down Expand Up @@ -137,7 +133,7 @@ def get_user_name(
def user_has_admin_permission(
user_details: typing.Annotated[UserDetails, get_user_details_dependency],
):
return user_details.permissions.get("admin") == True
return user_details.permissions.get("admin")

user_has_admin_permission_dependency = fastapi.Depends(user_has_admin_permission)

Expand Down Expand Up @@ -401,7 +397,7 @@ def get_current_user(
permissions = list(
permission
for permission, is_granted in (user_details.permissions or {}).items()
if is_granted == True
if is_granted
)
return GetUserResponse(
id=user_details.name,
Expand Down
58 changes: 24 additions & 34 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,13 @@
import json
import logging
import typing
from typing import Any, Optional
from typing import Any

if typing.TYPE_CHECKING:
from cloud_pipelines.orchestration.storage_providers import (
interfaces as storage_provider_interfaces,
)

from .launchers import interfaces as launcher_interfaces


Expand All @@ -26,8 +27,8 @@ def _get_current_time() -> datetime.datetime:
return datetime.datetime.now(tz=datetime.timezone.utc)


from . import component_structures as structures
from . import backend_types_sql as bts
from . import component_structures as structures
from . import errors
from .errors import ItemNotFoundError

Expand Down Expand Up @@ -77,9 +78,9 @@ def create(
session: orm.Session,
root_task: structures.TaskSpec,
# Component library to avoid repeating component specs inside task specs
components: Optional[list[structures.ComponentReference]] = None,
components: list[structures.ComponentReference] | None = None,
# Arbitrary metadata. Can be used to specify user.
annotations: Optional[dict[str, Any]] = None,
annotations: dict[str, Any] | None = None,
created_by: str | None = None,
) -> PipelineRunResponse:
# TODO: Validate the pipeline spec
Expand All @@ -89,7 +90,6 @@ def create(
pipeline_name = root_task.component_ref.spec.name

with session.begin():

root_execution_node = _recursively_create_all_executions_and_artifacts_root(
session=session,
root_task_spec=root_task,
Expand Down Expand Up @@ -201,7 +201,7 @@ def list(
if value:
where_clauses.append(bts.PipelineRun.created_by == value)
else:
where_clauses.append(bts.PipelineRun.created_by == None)
where_clauses.append(bts.PipelineRun.created_by.is_(None))
else:
raise NotImplementedError(f"Unsupported filter {filter}.")
pipeline_runs = list(
Expand Down Expand Up @@ -275,7 +275,7 @@ def _calculate_execution_status_stats(
bts.ExecutionToAncestorExecutionLink.ancestor_execution_id
== root_execution_id
)
.where(bts.ExecutionNode.container_execution_status != None)
.where(bts.ExecutionNode.container_execution_status.isnot(None))
.group_by(
bts.ExecutionNode.container_execution_status,
)
Expand Down Expand Up @@ -386,7 +386,7 @@ def _calculate_hash(s: str) -> str:

def _split_type_spec(
type_spec: structures.TypeSpecType | None,
) -> typing.Tuple[str | None, dict[str, Any] | None]:
) -> tuple[str | None, dict[str, Any] | None]:
if type_spec is None:
return None, None
if isinstance(type_spec, str):
Expand Down Expand Up @@ -520,7 +520,6 @@ class GetContainerExecutionLogResponse:


class ExecutionNodesApiService_Sql:

def get(self, session: orm.Session, id: bts.IdType) -> GetExecutionInfoResponse:
execution_node = session.get(bts.ExecutionNode, id)
if execution_node is None:
Expand Down Expand Up @@ -631,7 +630,7 @@ def get_graph_execution_state(
ExecutionNode_Descendant.id
== bts.ExecutionToAncestorExecutionLink.execution_id,
)
.where(ExecutionNode_Descendant.container_execution_status != None)
.where(ExecutionNode_Descendant.container_execution_status.isnot(None))
.group_by(
ExecutionNode_Child.id,
ExecutionNode_Descendant.container_execution_status,
Expand All @@ -644,7 +643,7 @@ def get_graph_execution_state(
sql.func.count().label("count"),
)
.where(ExecutionNode_Child.parent_execution_id == id)
.where(ExecutionNode_Child.container_execution_status != None)
.where(ExecutionNode_Child.container_execution_status.isnot(None))
.group_by(
ExecutionNode_Child.id,
ExecutionNode_Child.container_execution_status,
Expand Down Expand Up @@ -790,7 +789,7 @@ def get_container_execution_log(
log_uri=container_execution.log_uri,
storage_provider=storage_provider,
)
except:
except Exception:
# Do not raise exception if the execution is in SYSTEM_ERROR state
# We want to return the system error exception.
if (
Expand All @@ -801,11 +800,11 @@ def get_container_execution_log(
elif container_execution.status == bts.ContainerExecutionStatus.RUNNING:
if not container_launcher:
raise ApiServiceError(
f"Reading log of an unfinished container requires `container_launcher`."
"Reading log of an unfinished container requires `container_launcher`."
)
if not container_execution.launcher_data:
raise ApiServiceError(
f"Execution does not have container launcher data."
"Execution does not have container launcher data."
)

launched_container = (
Expand Down Expand Up @@ -833,11 +832,11 @@ def stream_container_execution_log(
container_execution = execution.container_execution
if not container_execution:
raise ApiServiceError(
f"Execution does not have container execution information."
"Execution does not have container execution information."
)
if not container_execution.launcher_data:
raise ApiServiceError(
f"Execution does not have container launcher information."
"Execution does not have container launcher information."
)
if container_execution.status == bts.ContainerExecutionStatus.RUNNING:
launched_container = (
Expand Down Expand Up @@ -882,7 +881,7 @@ def _read_container_execution_log_from_uri(

if "://" not in log_uri:
# Consider the URL to be an absolute local path (`/path` or `C:\path` or `C:/path`)
with open(log_uri, "r") as reader:
with open(log_uri) as reader:
return reader.read()
elif log_uri.startswith("gs://"):
# TODO: Switch to using storage providers.
Expand Down Expand Up @@ -966,7 +965,6 @@ class GetArtifactSignedUrlResponse:


class ArtifactNodesApiService_Sql:

def get(self, session: orm.Session, id: bts.IdType) -> GetArtifactInfoResponse:
artifact_node = session.get(bts.ArtifactNode, id)
if artifact_node is None:
Expand All @@ -990,14 +988,14 @@ def get_signed_artifact_url(
if not artifact_data.uri:
raise ValueError(f"Artifact node with {id=} does not have artifact URI.")
if artifact_data.is_dir:
raise ValueError(f"Cannot generate signer URL for a directory artifact.")
raise ValueError("Cannot generate signer URL for a directory artifact.")
if not artifact_data.uri.startswith("gs://"):
raise ValueError(
f"The get_signed_artifact_url method only supports Google Cloud Storage URIs, but got {artifact_data.uri=}."
)

from google.cloud import storage
from google import auth
from google.cloud import storage

# Avoiding error: "you need a private key to sign credentials."
# "the credentials you are currently using <class 'google.auth.compute_engine.credentials.Credentials'> just contains a token.
Expand Down Expand Up @@ -1040,7 +1038,6 @@ class ListSecretsResponse:


class SecretsApiService:

def create_secret(
self,
*,
Expand All @@ -1053,7 +1050,7 @@ def create_secret(
) -> SecretInfoResponse:
secret_name = secret_name.strip()
if not secret_name:
raise ApiServiceError(f"Secret name must not be empty.")
raise ApiServiceError("Secret name must not be empty.")
return self._create_or_update_secret(
session=session,
user_id=user_id,
Expand Down Expand Up @@ -1166,9 +1163,7 @@ def list_secrets(
# No. Decided to first do topological sort and then 1-stage generation.


_ArtifactNodeOrDynamicDataType = typing.Union[
bts.ArtifactNode, structures.DynamicDataArgument
]
_ArtifactNodeOrDynamicDataType = bts.ArtifactNode | structures.DynamicDataArgument


def _recursively_create_all_executions_and_artifacts_root(
Expand Down Expand Up @@ -1262,7 +1257,6 @@ def _recursively_create_all_executions_and_artifacts(

# FIX: Handle ExecutionNode.constant_arguments
# We do not touch root_task_spec.arguments. We use graph_input_artifact_nodes instead
constant_input_artifacts: dict[str, bts.ArtifactData] = {}
input_artifact_nodes = dict(input_artifact_nodes)
for input_spec in root_component_spec.inputs or []:
input_artifact_node = input_artifact_nodes.get(input_spec.name)
Expand Down Expand Up @@ -1490,13 +1484,11 @@ def _toposort_tasks(
dependencies[argument.task_output.task_id] = True
if argument.task_output.task_id not in tasks:
raise TypeError(
'Argument "{}" references non-existing task.'.format(
argument
)
f'Argument "{argument}" references non-existing task.'
)

# Topologically sorting tasks to detect cycles
task_dependents = {k: {} for k in task_dependencies.keys()}
task_dependents = {k: {} for k in task_dependencies}
for task_id, dependencies in task_dependencies.items():
for dependency in dependencies:
task_dependents[dependency][task_id] = True
Expand All @@ -1515,7 +1507,7 @@ def process_task(task_id):
task_number_of_remaining_dependencies[dependent_task] -= 1
process_task(dependent_task)

for task_id in task_dependencies.keys():
for task_id in task_dependencies:
process_task(task_id)
if len(sorted_tasks) != len(task_dependencies):
tasks_with_unsatisfied_dependencies = {
Expand All @@ -1526,9 +1518,7 @@ def process_task(task_id):
key=lambda task_id: tasks_with_unsatisfied_dependencies[task_id],
)
raise ValueError(
'Task "{}" has cyclical dependency.'.format(
task_with_minimal_number_of_unsatisfied_dependencies
)
f'Task "{task_with_minimal_number_of_unsatisfied_dependencies}" has cyclical dependency.'
)

return sorted_tasks
Expand Down
2 changes: 1 addition & 1 deletion cloud_pipelines_backend/backend_types_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def generate_unique_id() -> str:
nanoseconds = time.time_ns()
milliseconds = nanoseconds // 1_000_000

return ("%012x" % milliseconds) + random_bytes.hex()
return f"{milliseconds:012x}" + random_bytes.hex()


id_column = orm.mapped_column(
Expand Down
22 changes: 9 additions & 13 deletions cloud_pipelines_backend/component_library_api_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,12 @@
import hashlib
from typing import Any

import yaml
import sqlalchemy as sql
import yaml
from sqlalchemy import orm

from . import backend_types_sql as bts
from . import errors
from . import component_structures
from . import component_structures, errors


def calculate_digest_for_component_text(text: str) -> str:
Expand Down Expand Up @@ -192,7 +191,7 @@ def list(
if digest:
query = query.filter(PublishedComponentRow.digest == digest)
if not include_deprecated:
query = query.filter(PublishedComponentRow.deprecated == False)
query = query.filter(PublishedComponentRow.deprecated.is_(False))
if name_substring:
query = query.filter(
PublishedComponentRow.name.icontains(name_substring, autoescape=True)
Expand Down Expand Up @@ -243,7 +242,7 @@ def publish(
digest = component_ref.digest
if not (digest and component_text):
raise ValueError(
f"Component text is missing, cannot get component by digest (or digest is missing). Currently we cannot get component by URL for security reasons (you can get text from url yourself before publishing)."
"Component text is missing, cannot get component by digest (or digest is missing). Currently we cannot get component by URL for security reasons (you can get text from url yourself before publishing)."
)

component_spec = load_component_spec_from_text_and_validate(component_text)
Expand Down Expand Up @@ -372,7 +371,7 @@ def make_empty_user_library(user_name: str) -> "ComponentLibraryRow":
user_name = id.partition(":")[2]
else:
raise ValueError(
f"make_empty_user_library only supports user component libraries."
"make_empty_user_library only supports user component libraries."
)

name = f"{user_name} components"
Expand Down Expand Up @@ -443,7 +442,7 @@ def list(
# TODO: Implement filtering by user, URL
# TODO: Implement visibility/access control
query = sql.select(ComponentLibraryRow).filter(
ComponentLibraryRow.hide_from_search == False
ComponentLibraryRow.hide_from_search.is_(False)
)
if name_substring:
query = query.filter(
Expand Down Expand Up @@ -520,14 +519,12 @@ def _prepare_new_library_and_publish_components(
service = PublishedComponentService()
session.rollback()
component_count = 0
for (
component_ref
) in ComponentLibraryService._recursively_iterate_over_all_component_refs_in_library_folder(
for component_ref in ComponentLibraryService._recursively_iterate_over_all_component_refs_in_library_folder(
library.root_folder
):
if not component_ref.text:
# TODO: Support publishing component from URL
raise ValueError(f"Currently every library component must have text.")
raise ValueError("Currently every library component must have text.")
digest = calculate_digest_for_component_text(component_ref.text)
if publish_components:
try:
Expand Down Expand Up @@ -669,8 +666,7 @@ def _recursively_iterate_over_all_component_refs_in_library_folder(
yield from ComponentLibraryService._recursively_iterate_over_all_component_refs_in_library_folder(
child_folder
)
for component_ref in library_folder.components or []:
yield component_ref
yield from library_folder.components or []


### UserService
Expand Down
Loading