diff --git a/cloud_pipelines_backend/api_server_sql.py b/cloud_pipelines_backend/api_server_sql.py index 0e4cb1d..101eec6 100644 --- a/cloud_pipelines_backend/api_server_sql.py +++ b/cloud_pipelines_backend/api_server_sql.py @@ -182,26 +182,27 @@ def list( include_pipeline_names: bool = False, include_execution_stats: bool = False, ) -> ListPipelineJobsResponse: - where_clauses, offset, next_token = filter_query_sql.build_list_filters( + where_clauses = filter_query_sql.build_list_filters( filter_value=filter, filter_query_value=filter_query, - page_token_value=page_token, + cursor_value=page_token, current_user=current_user, - page_size=_DEFAULT_PAGE_SIZE, ) pipeline_runs = list( session.scalars( sql.select(bts.PipelineRun) .where(*where_clauses) - .order_by(bts.PipelineRun.created_at.desc()) - .offset(offset) + .order_by( + bts.PipelineRun.created_at.desc(), + bts.PipelineRun.id.desc(), + ) .limit(_DEFAULT_PAGE_SIZE) ).all() ) - next_page_token = ( - next_token.encode() if len(pipeline_runs) >= _DEFAULT_PAGE_SIZE else None + next_page_token = filter_query_sql.maybe_next_page_token( + rows=pipeline_runs, page_size=_DEFAULT_PAGE_SIZE ) return ListPipelineJobsResponse( diff --git a/cloud_pipelines_backend/backend_types_sql.py b/cloud_pipelines_backend/backend_types_sql.py index feb54e1..99a020e 100644 --- a/cloud_pipelines_backend/backend_types_sql.py +++ b/cloud_pipelines_backend/backend_types_sql.py @@ -13,6 +13,7 @@ IX_EXECUTION_NODE_CACHE_KEY: Final[str] = ( "ix_execution_node_container_execution_cache_key" ) +IX_PR_CREATED_AT_DESC_ID_DESC: Final[str] = "ix_pr_created_at_desc_id_desc" IX_ANNOTATION_RUN_ID_KEY_VALUE: Final[str] = ( "ix_pipeline_run_annotation_run_id_key_value" ) @@ -167,6 +168,11 @@ class PipelineRun(_TableBase): created_by, created_at.desc(), ), + sql.Index( + IX_PR_CREATED_AT_DESC_ID_DESC, + created_at.desc(), + id.desc(), + ), ) diff --git a/cloud_pipelines_backend/database_ops.py b/cloud_pipelines_backend/database_ops.py index 348256d..9504584 100644 --- a/cloud_pipelines_backend/database_ops.py +++ b/cloud_pipelines_backend/database_ops.py @@ -92,6 +92,11 @@ def migrate_db(db_engine: sqlalchemy.Engine): index.create(db_engine, checkfirst=True) break + for index in bts.PipelineRun.__table__.indexes: + if index.name == bts.IX_PR_CREATED_AT_DESC_ID_DESC: + index.create(db_engine, checkfirst=True) + break + backfill_created_by_annotations(db_engine=db_engine) backfill_pipeline_name_annotations(db_engine=db_engine) diff --git a/cloud_pipelines_backend/errors.py b/cloud_pipelines_backend/errors.py index 23c7dd9..379b617 100644 --- a/cloud_pipelines_backend/errors.py +++ b/cloud_pipelines_backend/errors.py @@ -22,3 +22,7 @@ class MutuallyExclusiveFilterError(ApiValidationError): class InvalidAnnotationKeyError(ApiValidationError): pass + + +class InvalidPageTokenError(ApiValidationError): + pass diff --git a/cloud_pipelines_backend/filter_query_sql.py b/cloud_pipelines_backend/filter_query_sql.py index f5c6138..bc4de5a 100644 --- a/cloud_pipelines_backend/filter_query_sql.py +++ b/cloud_pipelines_backend/filter_query_sql.py @@ -1,5 +1,3 @@ -import base64 -import dataclasses import datetime import json import enum @@ -39,26 +37,58 @@ class SystemKey(enum.StrEnum): } # --------------------------------------------------------------------------- -# PageToken +# Cursor encode / decode # --------------------------------------------------------------------------- +CURSOR_SEPARATOR: Final[str] = "~" -@dataclasses.dataclass(kw_only=True) -class PageToken: - offset: int = 0 - filter: str | None = None - filter_query: str | None = None - def encode(self) -> str: - return base64.b64encode( - json.dumps(dataclasses.asdict(self)).encode("utf-8") - ).decode("utf-8") +def encode_cursor(created_at: datetime.datetime, run_id: str) -> str: + """Encode the last row's position as a tilde-separated cursor string. - @classmethod - def decode(cls, token: str | None) -> "PageToken": - if not token: - return cls() - return cls(**json.loads(base64.b64decode(token))) + The created_at from PipelineRun is naive UTC (no UtcDateTime decorator on + this column). We stamp it as UTC here so the cursor string is + timezone-explicit for readability and correctness. + decode_cursor() normalizes back to naive UTC for DB comparison. + """ + if created_at.tzinfo is None: + created_at = created_at.replace(tzinfo=datetime.timezone.utc) + return f"{created_at.isoformat()}{CURSOR_SEPARATOR}{run_id}" + + +def decode_cursor(cursor: str | None) -> tuple[datetime.datetime, str] | None: + """Parse a tilde-separated cursor string into (created_at, run_id). + + Returns None for empty/missing cursors. Raises InvalidPageTokenError + for unrecognized formats (e.g. legacy base64 tokens). + """ + if not cursor: + return None + if CURSOR_SEPARATOR not in cursor: + raise errors.InvalidPageTokenError( + f"Unrecognized page_token format. " + f"Expected 'created_at~id' cursor. token={cursor[:20]}... (truncated)" + ) + # maxsplit=1: split on first ~ only, so run_id can safely contain ~ + created_at_str, run_id = cursor.split(CURSOR_SEPARATOR, 1) + created_at = datetime.datetime.fromisoformat(created_at_str) + # Normalize to naive UTC to match DB storage format (PipelineRun.created_at + # is plain DateTime, not UtcDateTime -- stores/returns naive datetimes). + if created_at.tzinfo is not None: + created_at = created_at.astimezone(datetime.timezone.utc).replace(tzinfo=None) + return created_at, run_id + + +def maybe_next_page_token( + *, + rows: list[bts.PipelineRun], + page_size: int, +) -> str | None: + """Return a cursor token for the next page, or None if this is the last page.""" + if len(rows) < page_size: + return None + last = rows[page_size - 1] + return encode_cursor(last.created_at, last.id) # --------------------------------------------------------------------------- @@ -159,26 +189,15 @@ def build_list_filters( *, filter_value: str | None, filter_query_value: str | None, - page_token_value: str | None, + cursor_value: str | None, current_user: str | None, - page_size: int, -) -> tuple[list[sql.ColumnElement], int, PageToken]: - """Resolve pagination token, legacy filter, and filter_query into WHERE clauses. - - Returns (where_clauses, offset, next_page_token). - """ +) -> list[sql.ColumnElement]: + """Build WHERE clauses from filters and cursor.""" if filter_value and filter_query_value: raise errors.MutuallyExclusiveFilterError( "Cannot use both 'filter' and 'filter_query'. Use one or the other." ) - page_token = PageToken.decode(page_token_value) - offset = page_token.offset - filter_value = page_token.filter if page_token_value else filter_value - filter_query_value = ( - page_token.filter_query if page_token_value else filter_query_value - ) - if filter_value: filter_query_value = _convert_legacy_filter_to_filter_query( filter_value=filter_value, @@ -194,13 +213,18 @@ def build_list_filters( ) ) - next_page_token = PageToken( - offset=offset + page_size, - filter=None, - filter_query=filter_query_value, - ) + cursor = decode_cursor(cursor_value) + if cursor: + cursor_created_at, cursor_id = cursor + where_clauses.append( + sql.tuple_(bts.PipelineRun.created_at, bts.PipelineRun.id) + < sql.tuple_( + sql.literal(cursor_created_at), + sql.literal(cursor_id), + ) + ) - return where_clauses, offset, next_page_token + return where_clauses def filter_query_to_where_clause( diff --git a/tests/test_api_server_sql.py b/tests/test_api_server_sql.py index 16b6210..d7bb898 100644 --- a/tests/test_api_server_sql.py +++ b/tests/test_api_server_sql.py @@ -182,6 +182,7 @@ def test_list_pagination(self, session_factory, service): ) assert len(page1.pipeline_runs) == 10 assert page1.next_page_token is not None + assert "~" in page1.next_page_token with session_factory() as session: page2 = service.list( @@ -191,6 +192,70 @@ def test_list_pagination(self, session_factory, service): assert len(page2.pipeline_runs) == 2 assert page2.next_page_token is None + def test_list_cursor_pagination_order(self, session_factory, service): + for i in range(5): + _create_run( + session_factory, + service, + root_task=_make_task_spec(f"pipeline-{i}"), + ) + + with session_factory() as session: + result = service.list(session=session) + + dates = [r.created_at for r in result.pipeline_runs] + assert dates == sorted(dates, reverse=True) + + def test_list_cursor_pagination_no_overlap(self, session_factory, service): + for i in range(12): + _create_run( + session_factory, + service, + root_task=_make_task_spec(f"pipeline-{i}"), + ) + + with session_factory() as session: + page1 = service.list(session=session) + with session_factory() as session: + page2 = service.list(session=session, page_token=page1.next_page_token) + page1_ids = {r.id for r in page1.pipeline_runs} + page2_ids = {r.id for r in page2.pipeline_runs} + assert page1_ids.isdisjoint(page2_ids) + + def test_list_cursor_pagination_stable_under_inserts( + self, session_factory, service + ): + for i in range(12): + _create_run( + session_factory, + service, + root_task=_make_task_spec(f"pipeline-{i}"), + ) + + with session_factory() as session: + page1 = service.list(session=session) + page1_ids = {r.id for r in page1.pipeline_runs} + + _create_run( + session_factory, + service, + root_task=_make_task_spec("pipeline-new"), + ) + + with session_factory() as session: + page2 = service.list(session=session, page_token=page1.next_page_token) + page2_ids = {r.id for r in page2.pipeline_runs} + assert page1_ids.isdisjoint(page2_ids) + assert len(page2.pipeline_runs) == 2 + + def test_list_invalid_page_token_raises(self, session_factory, service): + """page_token without ~ raises InvalidPageTokenError (422).""" + with session_factory() as session: + with pytest.raises( + errors.InvalidPageTokenError, match="Unrecognized page_token" + ): + service.list(session=session, page_token="not-a-cursor") + def test_list_filter_unsupported(self, session_factory, service): with session_factory() as session: with pytest.raises(NotImplementedError, match="Unsupported filter"): @@ -1254,7 +1319,7 @@ def test_list_filter_query_time_range_offset_timezone( returned_ids = {r.id for r in result.pipeline_runs} assert returned_ids == {run_b.id, run_c.id} - def test_pagination_preserves_filter_query(self, session_factory, service): + def test_pagination_with_filter_query(self, session_factory, service): for _ in range(12): run = _create_run( session_factory, @@ -1278,14 +1343,13 @@ def test_pagination_preserves_filter_query(self, session_factory, service): ) assert len(page1.pipeline_runs) == 10 assert page1.next_page_token is not None - - decoded = filter_query_sql.PageToken.decode(page1.next_page_token) - assert decoded.filter_query == fq + assert "~" in page1.next_page_token with session_factory() as session: page2 = service.list( session=session, page_token=page1.next_page_token, + filter_query=fq, ) assert len(page2.pipeline_runs) == 2 assert page2.next_page_token is None diff --git a/tests/test_filter_query_sql.py b/tests/test_filter_query_sql.py index 5809612..ac4bad4 100644 --- a/tests/test_filter_query_sql.py +++ b/tests/test_filter_query_sql.py @@ -1,3 +1,4 @@ +import datetime import json import pydantic @@ -479,38 +480,108 @@ def test_time_range_naive_datetime_rejected(self): ) -class TestPageToken: - def test_decode_none(self): - token = filter_query_sql.PageToken.decode(None) - assert token.offset == 0 - assert token.filter is None - assert token.filter_query is None - - def test_encode_decode_roundtrip(self): - original = filter_query_sql.PageToken( - offset=20, - filter="created_by:alice", - filter_query='{"and": [{"key_exists": {"key": "team"}}]}', - ) - encoded = original.encode() - decoded = filter_query_sql.PageToken.decode(encoded) - assert decoded.offset == 20 - assert decoded.filter == "created_by:alice" - assert decoded.filter_query == '{"and": [{"key_exists": {"key": "team"}}]}' - - def test_decode_with_filter_query(self): - fq_json = '{"or": [{"value_equals": {"key": "env", "value": "prod"}}]}' - original = filter_query_sql.PageToken(offset=10, filter_query=fq_json) - decoded = filter_query_sql.PageToken.decode(original.encode()) - assert decoded.filter_query == fq_json - assert decoded.filter is None - assert decoded.offset == 10 - - def test_decode_empty_string(self): - token = filter_query_sql.PageToken.decode("") - assert token.offset == 0 - assert token.filter is None - assert token.filter_query is None +class TestCursorEncodeDecode: + def test_encode_cursor_roundtrip(self): + naive_dt = datetime.datetime(2024, 2, 1, 9, 0, 0) + run_id = "018d8fff0000aaaabbbb" + cursor = filter_query_sql.encode_cursor(naive_dt, run_id) + decoded = filter_query_sql.decode_cursor(cursor) + assert decoded is not None + assert decoded == (naive_dt, run_id) + + def test_encode_cursor_stamps_utc(self): + naive_dt = datetime.datetime(2024, 2, 1, 9, 0, 0) + cursor = filter_query_sql.encode_cursor(naive_dt, "abc123") + assert "+00:00" in cursor + + def test_encode_cursor_already_aware(self): + aware_dt = datetime.datetime(2024, 2, 1, 9, 0, 0, tzinfo=datetime.timezone.utc) + cursor = filter_query_sql.encode_cursor(aware_dt, "abc123") + assert "+00:00" in cursor + decoded = filter_query_sql.decode_cursor(cursor) + assert decoded is not None + assert decoded == (datetime.datetime(2024, 2, 1, 9, 0, 0), "abc123") + + def test_decode_cursor_none(self): + assert filter_query_sql.decode_cursor(None) is None + + def test_decode_cursor_empty_string(self): + assert filter_query_sql.decode_cursor("") is None + + def test_decode_cursor_no_tilde_raises(self): + """Token without ~ raises InvalidPageTokenError.""" + with pytest.raises( + errors.InvalidPageTokenError, match="Unrecognized page_token" + ): + filter_query_sql.decode_cursor("not-a-cursor") + + def test_decode_cursor_tilde_separated(self): + cursor = "2024-02-01T09:00:00+00:00~018d8fff0000" + result = filter_query_sql.decode_cursor(cursor) + assert result is not None + created_at, run_id = result + assert created_at == datetime.datetime(2024, 2, 1, 9, 0, 0) + assert created_at.tzinfo is None + assert run_id == "018d8fff0000" + + def test_decode_cursor_naive_fallback(self): + """Cursor without timezone parses correctly as naive.""" + cursor = "2024-02-01T09:00:00~abc123" + result = filter_query_sql.decode_cursor(cursor) + assert result is not None + created_at, run_id = result + assert created_at == datetime.datetime(2024, 2, 1, 9, 0, 0) + assert created_at.tzinfo is None + assert run_id == "abc123" + + +class _FakeRow: + """Minimal stand-in for bts.PipelineRun with only the fields cursor logic needs.""" + + def __init__(self, *, created_at: datetime.datetime, id: str): + self.created_at = created_at + self.id = id + + +class TestMaybeNextPageToken: + def test_returns_none_when_fewer_than_page_size(self): + rows = [_FakeRow(created_at=datetime.datetime(2024, 1, 1), id="a")] + assert filter_query_sql.maybe_next_page_token(rows=rows, page_size=10) is None + + def test_returns_none_when_empty(self): + assert filter_query_sql.maybe_next_page_token(rows=[], page_size=10) is None + + def test_returns_cursor_when_full_page(self): + rows = [ + _FakeRow( + created_at=datetime.datetime(2024, 1, 1, 12 - i, 0, 0), id=f"id-{i}" + ) + for i in range(10) + ] + token = filter_query_sql.maybe_next_page_token(rows=rows, page_size=10) + assert token is not None + assert "~" in token + decoded = filter_query_sql.decode_cursor(token) + assert decoded is not None + assert decoded == (rows[-1].created_at, rows[-1].id) + + def test_returns_cursor_at_page_boundary(self): + """Even with more rows than page_size, cursor points to the page_size-th row.""" + rows = [ + _FakeRow(created_at=datetime.datetime(2024, 1, 1), id=f"id-{i}") + for i in range(15) + ] + token = filter_query_sql.maybe_next_page_token(rows=rows, page_size=10) + assert token is not None + decoded = filter_query_sql.decode_cursor(token) + assert decoded == (rows[9].created_at, rows[9].id) + + def test_returns_none_at_exact_boundary_minus_one(self): + rows = [ + _FakeRow(created_at=datetime.datetime(2024, 1, 1), id=f"id-{i}") + for i in range(9) + ] + assert filter_query_sql.maybe_next_page_token(rows=rows, page_size=10) is None class TestConvertLegacyFilterToFilterQuery: @@ -558,18 +629,13 @@ def test_text_search_raises(self): class TestBuildListFilters: def test_no_filters(self): - clauses, offset, next_token = filter_query_sql.build_list_filters( + clauses = filter_query_sql.build_list_filters( filter_value=None, filter_query_value=None, - page_token_value=None, + cursor_value=None, current_user=None, - page_size=10, ) assert clauses == [] - assert offset == 0 - assert next_token.offset == 10 - assert next_token.filter is None - assert next_token.filter_query is None def test_mutual_exclusivity_raises(self): with pytest.raises( @@ -578,99 +644,117 @@ def test_mutual_exclusivity_raises(self): filter_query_sql.build_list_filters( filter_value="created_by:alice", filter_query_value='{"and": [{"key_exists": {"key": "team"}}]}', - page_token_value=None, + cursor_value=None, current_user=None, - page_size=10, ) def test_legacy_filter_produces_annotation_clause(self): - clauses, offset, next_token = filter_query_sql.build_list_filters( + clauses = filter_query_sql.build_list_filters( filter_value="created_by:alice", filter_query_value=None, - page_token_value=None, + cursor_value=None, current_user=None, - page_size=10, ) assert len(clauses) == 1 - compiled = _compile(clauses[0]) - assert "EXISTS" in compiled.upper() - assert "pipeline_run_annotation" in compiled - assert offset == 0 - assert next_token.filter is None - assert next_token.filter_query is not None + assert _compile(clauses[0]) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id " + "FROM pipeline_run_annotation, pipeline_run " + "WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id " + "AND pipeline_run_annotation.\"key\" = 'system/pipeline_run.created_by' " + "AND pipeline_run_annotation.value = 'alice')" + ) def test_filter_query_produces_clauses(self): fq = '{"and": [{"key_exists": {"key": "team"}}]}' - clauses, offset, next_token = filter_query_sql.build_list_filters( + clauses = filter_query_sql.build_list_filters( filter_value=None, filter_query_value=fq, - page_token_value=None, + cursor_value=None, current_user=None, - page_size=10, ) assert len(clauses) == 1 - compiled = _compile(clauses[0]) - assert "EXISTS" in compiled.upper() - assert next_token.filter_query == fq + assert _compile(clauses[0]) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id " + "FROM pipeline_run_annotation, pipeline_run " + "WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id " + "AND pipeline_run_annotation.\"key\" = 'team')" + ) - def test_page_token_with_legacy_filter_converts(self): - token = filter_query_sql.PageToken( - offset=20, - filter="created_by:alice", + def test_cursor_where_clause_generated(self): + cursor = filter_query_sql.encode_cursor( + datetime.datetime(2024, 2, 1, 9, 0, 0), "018d8fff" ) - clauses, offset, next_token = filter_query_sql.build_list_filters( + clauses = filter_query_sql.build_list_filters( filter_value=None, filter_query_value=None, - page_token_value=token.encode(), + cursor_value=cursor, current_user=None, - page_size=10, ) - assert offset == 20 assert len(clauses) == 1 - compiled = _compile(clauses[0]) - assert "EXISTS" in compiled.upper() - assert next_token.offset == 30 - assert next_token.filter is None - assert next_token.filter_query is not None - - def test_page_token_restores_filter_query(self): - fq = '{"and": [{"key_exists": {"key": "env"}}]}' - token = filter_query_sql.PageToken(offset=10, filter_query=fq) - clauses, offset, next_token = filter_query_sql.build_list_filters( + assert _compile(clauses[0]) == ( + "(pipeline_run.created_at, pipeline_run.id) " + "< ('2024-02-01 09:00:00.000000', '018d8fff')" + ) + + def test_cursor_where_clause_not_generated_page1(self): + clauses = filter_query_sql.build_list_filters( filter_value=None, filter_query_value=None, - page_token_value=token.encode(), + cursor_value=None, current_user=None, - page_size=5, ) - assert offset == 10 - assert len(clauses) == 1 - assert next_token.offset == 15 - assert next_token.filter_query == fq + assert clauses == [] - def test_page_size_reflected_in_next_token(self): - _, _, next_token = filter_query_sql.build_list_filters( + def test_cursor_with_filter_query(self): + fq = '{"and": [{"key_exists": {"key": "team"}}]}' + cursor = filter_query_sql.encode_cursor( + datetime.datetime(2024, 2, 1, 9, 0, 0), "018d8fff" + ) + clauses = filter_query_sql.build_list_filters( filter_value=None, - filter_query_value=None, - page_token_value=None, + filter_query_value=fq, + cursor_value=cursor, current_user=None, - page_size=25, ) - assert next_token.offset == 25 + assert len(clauses) == 2 + assert _compile(clauses[0]) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id " + "FROM pipeline_run_annotation, pipeline_run " + "WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id " + "AND pipeline_run_annotation.\"key\" = 'team')" + ) + assert _compile(clauses[1]) == ( + "(pipeline_run.created_at, pipeline_run.id) " + "< ('2024-02-01 09:00:00.000000', '018d8fff')" + ) - def test_created_by_me_resolved_in_next_token(self): - clauses, offset, next_token = filter_query_sql.build_list_filters( + def test_invalid_cursor_raises(self): + """cursor_value without ~ raises InvalidPageTokenError.""" + with pytest.raises( + errors.InvalidPageTokenError, match="Unrecognized page_token" + ): + filter_query_sql.build_list_filters( + filter_value=None, + filter_query_value=None, + cursor_value="not-a-cursor", + current_user=None, + ) + + def test_created_by_me_resolves(self): + clauses = filter_query_sql.build_list_filters( filter_value="created_by:me", filter_query_value=None, - page_token_value=None, + cursor_value=None, current_user="bob@example.com", - page_size=10, ) assert len(clauses) == 1 - assert next_token.filter is None - assert next_token.filter_query is not None - parsed_fq = json.loads(next_token.filter_query) - assert parsed_fq["and"][0]["value_equals"]["value"] == "me" + assert _compile(clauses[0]) == ( + "EXISTS (SELECT pipeline_run_annotation.pipeline_run_id " + "FROM pipeline_run_annotation, pipeline_run " + "WHERE pipeline_run_annotation.pipeline_run_id = pipeline_run.id " + "AND pipeline_run_annotation.\"key\" = 'system/pipeline_run.created_by' " + "AND pipeline_run_annotation.value = 'bob@example.com')" + ) class TestSystemKeyValidation: