Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,14 @@ class ListPipelineJobsResponse:
class PipelineRunsApiService_Sql:
_PIPELINE_NAME_EXTRA_DATA_KEY = "pipeline_name"
_DEFAULT_PAGE_SIZE: Final[int] = 10
_SYSTEM_KEY_RESERVED_MSG = (
"Annotation keys starting with "
f"{filter_query_sql.SYSTEM_KEY_PREFIX!r} are reserved for system use."
)

def _fail_if_changing_system_annotation(self, *, key: str) -> None:
if key.startswith(filter_query_sql.SYSTEM_KEY_PREFIX):
raise errors.ApiValidationError(self._SYSTEM_KEY_RESERVED_MSG)

def create(
self,
Expand Down Expand Up @@ -105,6 +113,19 @@ def create(
},
)
session.add(pipeline_run)
# Mirror created_by into the annotations table so it's searchable
# via filter_query like any other annotation.
if created_by is not None:
# Flush to populate pipeline_run.id (server-generated) before inserting the annotation FK.
# TODO: Use ORM relationship instead of explicit flush + manual FK assignment.
session.flush()
session.add(
bts.PipelineRunAnnotation(
pipeline_run_id=pipeline_run.id,
key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY,
value=created_by,
)
)
session.commit()

session.refresh(pipeline_run)
Expand Down Expand Up @@ -295,6 +316,7 @@ def set_annotation(
user_name: str | None = None,
skip_user_check: bool = False,
):
self._fail_if_changing_system_annotation(key=key)
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
Expand All @@ -317,6 +339,7 @@ def delete_annotation(
user_name: str | None = None,
skip_user_check: bool = False,
):
self._fail_if_changing_system_annotation(key=key)
pipeline_run = session.get(bts.PipelineRun, id)
if not pipeline_run:
raise errors.ItemNotFoundError(f"Pipeline run {id} not found.")
Expand Down
59 changes: 59 additions & 0 deletions cloud_pipelines_backend/database_ops.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import sqlalchemy
from sqlalchemy import orm

from . import backend_types_sql as bts
from . import filter_query_sql


def create_db_engine_and_migrate_db(
Expand Down Expand Up @@ -83,3 +85,60 @@ def migrate_db(db_engine: sqlalchemy.Engine):
if index.name == bts.PipelineRunAnnotation._IX_ANNOTATION_RUN_ID_KEY_VALUE:
index.create(db_engine, checkfirst=True)
break

_backfill_pipeline_run_created_by_annotations(db_engine=db_engine)


def _is_pipeline_run_annotation_key_already_backfilled(
*,
session: orm.Session,
key: str,
) -> bool:
"""Return True if at least one annotation with the given key exists."""
return session.query(
sqlalchemy.exists(
sqlalchemy.select(sqlalchemy.literal(1))
.select_from(bts.PipelineRunAnnotation)
.where(
bts.PipelineRunAnnotation.key == key,
)
)
).scalar()


def _backfill_pipeline_run_created_by_annotations(
*,
db_engine: sqlalchemy.Engine,
) -> None:
"""Copy pipeline_run.created_by into pipeline_run_annotation so
annotation-based search works for created_by.

The check and insert run in a single session/transaction to avoid
TOCTOU races between concurrent startup processes.

Skips entirely if any created_by annotation key already exists (i.e. the
write-path is populating them, so the backfill has already run or is
no longer needed).
"""
with orm.Session(db_engine) as session:
if _is_pipeline_run_annotation_key_already_backfilled(
session=session,
key=filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY,
):
return

stmt = sqlalchemy.insert(bts.PipelineRunAnnotation).from_select(
["pipeline_run_id", "key", "value"],
sqlalchemy.select(
bts.PipelineRun.id,
sqlalchemy.literal(
filter_query_sql.PipelineRunAnnotationSystemKey.CREATED_BY
),
bts.PipelineRun.created_by,
).where(
bts.PipelineRun.created_by.isnot(None),
bts.PipelineRun.created_by != "",
),
)
session.execute(stmt)
session.commit()
33 changes: 29 additions & 4 deletions cloud_pipelines_backend/filter_query_models.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

from abc import abstractmethod
from typing import Annotated

import pydantic
Expand Down Expand Up @@ -58,21 +59,45 @@ def _at_least_one_time_bound(self) -> TimeRange:
# --- Predicate wrapper models (one field each) ---


class KeyExistsPredicate(_BaseModel):
class KeyPredicateBase(_BaseModel):
"""Base for predicates that target an annotation key."""

@property
@abstractmethod
def key(self) -> str: ...
Comment on lines +64 to +67
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's try to just use

key: str?

Then the derived classes should have it automatically (no need to override with a property.)

AFAIK, key cannot be a property, since that might not be recognized by Pydantic.



class KeyExistsPredicate(KeyPredicateBase):
key_exists: KeyExists

@property
def key(self) -> str:
return self.key_exists.key


class ValueContainsPredicate(_BaseModel):
class ValueContainsPredicate(KeyPredicateBase):
value_contains: ValueContains

@property
def key(self) -> str:
return self.value_contains.key

class ValueInPredicate(_BaseModel):

class ValueInPredicate(KeyPredicateBase):
value_in: ValueIn

@property
def key(self) -> str:
return self.value_in.key


class ValueEqualsPredicate(_BaseModel):
class ValueEqualsPredicate(KeyPredicateBase):
value_equals: ValueEquals

@property
def key(self) -> str:
return self.value_equals.key


class TimeRangePredicate(_BaseModel):
time_range: TimeRange
Expand Down
126 changes: 120 additions & 6 deletions cloud_pipelines_backend/filter_query_sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import base64
import json
import enum
from typing import Any, Final

import sqlalchemy as sql
Expand All @@ -8,6 +9,21 @@
from . import errors
from . import filter_query_models

SYSTEM_KEY_PREFIX: Final[str] = "system/"


class PipelineRunAnnotationSystemKey(enum.StrEnum):
CREATED_BY = f"{SYSTEM_KEY_PREFIX}pipeline_run.created_by"


SYSTEM_KEY_SUPPORTED_PREDICATES: dict[PipelineRunAnnotationSystemKey, set[type]] = {
PipelineRunAnnotationSystemKey.CREATED_BY: {
filter_query_models.KeyExistsPredicate,
filter_query_models.ValueEqualsPredicate,
filter_query_models.ValueInPredicate,
},
}

# ---------------------------------------------------------------------------
# Page-token helpers
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -44,6 +60,78 @@ def _resolve_filter_value(
return filter, filter_query, offset


# ---------------------------------------------------------------------------
# PipelineRunAnnotationSystemKey validation and resolution
# ---------------------------------------------------------------------------


def _check_predicate_allowed(*, predicate: filter_query_models.Predicate) -> None:
"""Raise if a system key is used with an unsupported predicate type."""
if not isinstance(predicate, filter_query_models.KeyPredicateBase):
return
key = predicate.key

try:
system_key = PipelineRunAnnotationSystemKey(key)
except ValueError:
return

supported = SYSTEM_KEY_SUPPORTED_PREDICATES.get(system_key, set())
if type(predicate) not in supported:
raise errors.ApiValidationError(
f"Predicate {type(predicate).__name__} is not supported "
f"for system key {system_key!r}. "
f"Supported: {[t.__name__ for t in supported]}"
)


def _resolve_system_key_value(
*,
key: str,
value: str,
current_user: str | None,
) -> str:
"""Resolve special placeholder values for system keys."""
if key == PipelineRunAnnotationSystemKey.CREATED_BY and value == "me":
return current_user if current_user is not None else ""
return value


def _maybe_resolve_system_values(
*,
predicate: filter_query_models.ValueEqualsPredicate,
current_user: str | None,
) -> filter_query_models.ValueEqualsPredicate:
"""Resolve special values in a ValueEqualsPredicate."""
key = predicate.value_equals.key
value = predicate.value_equals.value
resolved = _resolve_system_key_value(
key=key,
value=value,
current_user=current_user,
)
if resolved != value:
return filter_query_models.ValueEqualsPredicate(
value_equals=filter_query_models.ValueEquals(key=key, value=resolved)
)
return predicate


def _validate_and_resolve_predicate(
*,
predicate: filter_query_models.Predicate,
current_user: str | None,
) -> filter_query_models.Predicate:
"""Validate system key support, then resolve special values."""
_check_predicate_allowed(predicate=predicate)
if isinstance(predicate, filter_query_models.ValueEqualsPredicate):
return _maybe_resolve_system_values(
predicate=predicate,
current_user=current_user,
)
return predicate


# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -79,7 +167,12 @@ def build_list_filters(

if filter_query_value:
parsed = filter_query_models.FilterQuery.model_validate_json(filter_query_value)
where_clauses.append(filter_query_to_where_clause(filter_query=parsed))
where_clauses.append(
filter_query_to_where_clause(
filter_query=parsed,
current_user=current_user,
)
)

next_page_token = _encode_page_token(
page_token_dict={
Expand All @@ -95,10 +188,13 @@ def build_list_filters(
def filter_query_to_where_clause(
*,
filter_query: filter_query_models.FilterQuery,
current_user: str | None = None,
) -> sql.ColumnElement:
predicates = filter_query.and_ or filter_query.or_
is_and = filter_query.and_ is not None
clauses = [_predicate_to_clause(predicate=p) for p in predicates]
clauses = [
_predicate_to_clause(predicate=p, current_user=current_user) for p in predicates
]
return sql.and_(*clauses) if is_and else sql.or_(*clauses)


Expand Down Expand Up @@ -163,17 +259,35 @@ def _build_filter_where_clauses(

def _predicate_to_clause(
*,
predicate,
predicate: filter_query_models.Predicate,
current_user: str | None = None,
) -> sql.ColumnElement:
predicate = _validate_and_resolve_predicate(
predicate=predicate,
current_user=current_user,
)

match predicate:
case filter_query_models.AndPredicate():
return sql.and_(
*[_predicate_to_clause(predicate=p) for p in predicate.and_]
*[
_predicate_to_clause(predicate=p, current_user=current_user)
for p in predicate.and_
]
)
case filter_query_models.OrPredicate():
return sql.or_(*[_predicate_to_clause(predicate=p) for p in predicate.or_])
return sql.or_(
*[
_predicate_to_clause(predicate=p, current_user=current_user)
for p in predicate.or_
]
)
case filter_query_models.NotPredicate():
return sql.not_(_predicate_to_clause(predicate=predicate.not_))
return sql.not_(
_predicate_to_clause(
predicate=predicate.not_, current_user=current_user
)
)
case filter_query_models.KeyExistsPredicate():
return _key_exists_to_clause(predicate=predicate)
case filter_query_models.ValueEqualsPredicate():
Expand Down
Loading