Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 8 additions & 7 deletions cloud_pipelines_backend/api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,26 +182,27 @@ def list(
include_pipeline_names: bool = False,
include_execution_stats: bool = False,
) -> ListPipelineJobsResponse:
where_clauses, offset, next_token = filter_query_sql.build_list_filters(
where_clauses = filter_query_sql.build_list_filters(
filter_value=filter,
filter_query_value=filter_query,
page_token_value=page_token,
cursor_value=page_token,
current_user=current_user,
page_size=_DEFAULT_PAGE_SIZE,
)

pipeline_runs = list(
session.scalars(
sql.select(bts.PipelineRun)
.where(*where_clauses)
.order_by(bts.PipelineRun.created_at.desc())
.offset(offset)
.order_by(
bts.PipelineRun.created_at.desc(),
bts.PipelineRun.id.desc(),
)
.limit(_DEFAULT_PAGE_SIZE)
).all()
)

next_page_token = (
next_token.encode() if len(pipeline_runs) >= _DEFAULT_PAGE_SIZE else None
next_page_token = filter_query_sql.maybe_next_page_token(
rows=pipeline_runs, page_size=_DEFAULT_PAGE_SIZE
)

return ListPipelineJobsResponse(
Expand Down
6 changes: 6 additions & 0 deletions cloud_pipelines_backend/backend_types_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
IX_EXECUTION_NODE_CACHE_KEY: Final[str] = (
"ix_execution_node_container_execution_cache_key"
)
IX_PR_CREATED_AT_DESC_ID_DESC: Final[str] = "ix_pr_created_at_desc_id_desc"
IX_ANNOTATION_RUN_ID_KEY_VALUE: Final[str] = (
"ix_pipeline_run_annotation_run_id_key_value"
)
Expand Down Expand Up @@ -167,6 +168,11 @@ class PipelineRun(_TableBase):
created_by,
created_at.desc(),
),
sql.Index(
IX_PR_CREATED_AT_DESC_ID_DESC,
created_at.desc(),
id.desc(),
),
)


Expand Down
5 changes: 5 additions & 0 deletions cloud_pipelines_backend/database_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,11 @@ def migrate_db(db_engine: sqlalchemy.Engine):
index.create(db_engine, checkfirst=True)
break

for index in bts.PipelineRun.__table__.indexes:
if index.name == bts.IX_PR_CREATED_AT_DESC_ID_DESC:
index.create(db_engine, checkfirst=True)
break

backfill_created_by_annotations(db_engine=db_engine)
backfill_pipeline_name_annotations(db_engine=db_engine)

Expand Down
4 changes: 4 additions & 0 deletions cloud_pipelines_backend/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,7 @@ class MutuallyExclusiveFilterError(ApiValidationError):

class InvalidAnnotationKeyError(ApiValidationError):
pass


class InvalidPageTokenError(ApiValidationError):
pass
98 changes: 61 additions & 37 deletions cloud_pipelines_backend/filter_query_sql.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import base64
import dataclasses
import datetime
import json
import enum
Expand Down Expand Up @@ -39,26 +37,58 @@ class SystemKey(enum.StrEnum):
}

# ---------------------------------------------------------------------------
# PageToken
# Cursor encode / decode
# ---------------------------------------------------------------------------

CURSOR_SEPARATOR: Final[str] = "~"

@dataclasses.dataclass(kw_only=True)
class PageToken:
offset: int = 0
filter: str | None = None
filter_query: str | None = None

def encode(self) -> str:
return base64.b64encode(
json.dumps(dataclasses.asdict(self)).encode("utf-8")
).decode("utf-8")
def encode_cursor(created_at: datetime.datetime, run_id: str) -> str:
"""Encode the last row's position as a tilde-separated cursor string.

@classmethod
def decode(cls, token: str | None) -> "PageToken":
if not token:
return cls()
return cls(**json.loads(base64.b64decode(token)))
The created_at from PipelineRun is naive UTC (no UtcDateTime decorator on
this column). We stamp it as UTC here so the cursor string is
timezone-explicit for readability and correctness.
decode_cursor() normalizes back to naive UTC for DB comparison.
"""
if created_at.tzinfo is None:
created_at = created_at.replace(tzinfo=datetime.timezone.utc)
return f"{created_at.isoformat()}{CURSOR_SEPARATOR}{run_id}"


def decode_cursor(cursor: str | None) -> tuple[datetime.datetime, str] | None:
"""Parse a tilde-separated cursor string into (created_at, run_id).

Returns None for empty/missing cursors. Raises InvalidPageTokenError
for unrecognized formats (e.g. legacy base64 tokens).
"""
if not cursor:
return None
if CURSOR_SEPARATOR not in cursor:
raise errors.InvalidPageTokenError(
f"Unrecognized page_token format. "
f"Expected 'created_at~id' cursor. token={cursor[:20]}... (truncated)"
)
# maxsplit=1: split on first ~ only, so run_id can safely contain ~
created_at_str, run_id = cursor.split(CURSOR_SEPARATOR, 1)
created_at = datetime.datetime.fromisoformat(created_at_str)
# Normalize to naive UTC to match DB storage format (PipelineRun.created_at
# is plain DateTime, not UtcDateTime -- stores/returns naive datetimes).
if created_at.tzinfo is not None:
created_at = created_at.astimezone(datetime.timezone.utc).replace(tzinfo=None)
return created_at, run_id


def maybe_next_page_token(
*,
rows: list[bts.PipelineRun],
page_size: int,
) -> str | None:
"""Return a cursor token for the next page, or None if this is the last page."""
if len(rows) < page_size:
return None
last = rows[page_size - 1]
return encode_cursor(last.created_at, last.id)


# ---------------------------------------------------------------------------
Expand Down Expand Up @@ -159,26 +189,15 @@ def build_list_filters(
*,
filter_value: str | None,
filter_query_value: str | None,
page_token_value: str | None,
cursor_value: str | None,
current_user: str | None,
page_size: int,
) -> tuple[list[sql.ColumnElement], int, PageToken]:
"""Resolve pagination token, legacy filter, and filter_query into WHERE clauses.

Returns (where_clauses, offset, next_page_token).
"""
) -> list[sql.ColumnElement]:
"""Build WHERE clauses from filters and cursor."""
if filter_value and filter_query_value:
raise errors.MutuallyExclusiveFilterError(
"Cannot use both 'filter' and 'filter_query'. Use one or the other."
)

page_token = PageToken.decode(page_token_value)
offset = page_token.offset
filter_value = page_token.filter if page_token_value else filter_value
filter_query_value = (
page_token.filter_query if page_token_value else filter_query_value
)

if filter_value:
filter_query_value = _convert_legacy_filter_to_filter_query(
filter_value=filter_value,
Expand All @@ -194,13 +213,18 @@ def build_list_filters(
)
)

next_page_token = PageToken(
offset=offset + page_size,
filter=None,
filter_query=filter_query_value,
)
cursor = decode_cursor(cursor_value)
if cursor:
cursor_created_at, cursor_id = cursor
where_clauses.append(
sql.tuple_(bts.PipelineRun.created_at, bts.PipelineRun.id)
< sql.tuple_(
sql.literal(cursor_created_at),
sql.literal(cursor_id),
)
)

return where_clauses, offset, next_page_token
return where_clauses


def filter_query_to_where_clause(
Expand Down
72 changes: 68 additions & 4 deletions tests/test_api_server_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,6 +182,7 @@ def test_list_pagination(self, session_factory, service):
)
assert len(page1.pipeline_runs) == 10
assert page1.next_page_token is not None
assert "~" in page1.next_page_token

with session_factory() as session:
page2 = service.list(
Expand All @@ -191,6 +192,70 @@ def test_list_pagination(self, session_factory, service):
assert len(page2.pipeline_runs) == 2
assert page2.next_page_token is None

def test_list_cursor_pagination_order(self, session_factory, service):
for i in range(5):
_create_run(
session_factory,
service,
root_task=_make_task_spec(f"pipeline-{i}"),
)

with session_factory() as session:
result = service.list(session=session)

dates = [r.created_at for r in result.pipeline_runs]
assert dates == sorted(dates, reverse=True)

def test_list_cursor_pagination_no_overlap(self, session_factory, service):
for i in range(12):
_create_run(
session_factory,
service,
root_task=_make_task_spec(f"pipeline-{i}"),
)

with session_factory() as session:
page1 = service.list(session=session)
with session_factory() as session:
page2 = service.list(session=session, page_token=page1.next_page_token)
page1_ids = {r.id for r in page1.pipeline_runs}
page2_ids = {r.id for r in page2.pipeline_runs}
assert page1_ids.isdisjoint(page2_ids)

def test_list_cursor_pagination_stable_under_inserts(
self, session_factory, service
):
for i in range(12):
_create_run(
session_factory,
service,
root_task=_make_task_spec(f"pipeline-{i}"),
)

with session_factory() as session:
page1 = service.list(session=session)
page1_ids = {r.id for r in page1.pipeline_runs}

_create_run(
session_factory,
service,
root_task=_make_task_spec("pipeline-new"),
)

with session_factory() as session:
page2 = service.list(session=session, page_token=page1.next_page_token)
page2_ids = {r.id for r in page2.pipeline_runs}
assert page1_ids.isdisjoint(page2_ids)
assert len(page2.pipeline_runs) == 2

def test_list_invalid_page_token_raises(self, session_factory, service):
"""page_token without ~ raises InvalidPageTokenError (422)."""
with session_factory() as session:
with pytest.raises(
errors.InvalidPageTokenError, match="Unrecognized page_token"
):
service.list(session=session, page_token="not-a-cursor")

def test_list_filter_unsupported(self, session_factory, service):
with session_factory() as session:
with pytest.raises(NotImplementedError, match="Unsupported filter"):
Expand Down Expand Up @@ -1254,7 +1319,7 @@ def test_list_filter_query_time_range_offset_timezone(
returned_ids = {r.id for r in result.pipeline_runs}
assert returned_ids == {run_b.id, run_c.id}

def test_pagination_preserves_filter_query(self, session_factory, service):
def test_pagination_with_filter_query(self, session_factory, service):
for _ in range(12):
run = _create_run(
session_factory,
Expand All @@ -1278,14 +1343,13 @@ def test_pagination_preserves_filter_query(self, session_factory, service):
)
assert len(page1.pipeline_runs) == 10
assert page1.next_page_token is not None

decoded = filter_query_sql.PageToken.decode(page1.next_page_token)
assert decoded.filter_query == fq
assert "~" in page1.next_page_token

with session_factory() as session:
page2 = service.list(
session=session,
page_token=page1.next_page_token,
filter_query=fq,
)
assert len(page2.pipeline_runs) == 2
assert page2.next_page_token is None
Expand Down
Loading