Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions server/mergin/sync/public_api_v2_controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@
from .storages.disk import move_to_tmp, save_to_file
from .utils import get_device_id, get_ip, get_user_agent, get_chunk_location
from .workspace import WorkspaceRole
from ..utils import parse_order_params
from ..utils import parse_order_params, get_schema_fields_map


@auth_required
Expand Down Expand Up @@ -437,11 +437,15 @@ def list_workspace_projects(workspace_id, page, per_page, order_params=None, q=N
projects = projects.filter(Project.name.ilike(f"%{q}%"))

if order_params:
order_by_params = parse_order_params(Project, order_params)
schema_map = get_schema_fields_map(ProjectSchemaV2)
order_by_params = parse_order_params(
Project, order_params, field_map=schema_map
)
projects = projects.order_by(*order_by_params)

result = projects.paginate(page, per_page).items
total = projects.paginate(page, per_page).total
pagination = projects.paginate(page=page, per_page=per_page)
result = pagination.items
total = pagination.total

data = ProjectSchemaV2(many=True).dump(result)
return jsonify(projects=data, count=total, page=page, per_page=per_page), 200
11 changes: 11 additions & 0 deletions server/mergin/tests/test_public_api_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,17 @@ def test_list_workspace_projects(client):
url + f"?page={page}&per_page={per_page}&q=1&order_params=created DESC"
)
assert response.json["projects"][0]["name"] == "project_10"
# using field name instead column names for sorting
p4 = Project.query.filter(Project.name == project_name).first()
p4.disk_usage = 1234567
db.session.commit()
response = client.get(url + f"?page=1&per_page=10&order_params=size DESC")
resp_data = json.loads(response.data)
assert resp_data["projects"][0]["name"] == project_name

# invalid order param
response = client.get(url + f"?page=1&per_page=10&order_params=invalid DESC")
assert response.status_code == 200

# no permissions to workspace
user2 = add_user("user", "password")
Expand Down
27 changes: 26 additions & 1 deletion server/mergin/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,15 @@
import json
import pytest
from flask import url_for, current_app
from marshmallow import Schema, fields
from sqlalchemy import desc
import os
from unittest.mock import patch
from pathvalidate import sanitize_filename
from pygeodiff import GeoDiff
from pathlib import PureWindowsPath

from ..utils import save_diagnostic_log_file
from ..utils import save_diagnostic_log_file, get_schema_fields_map

from ..sync.utils import (
is_reserved_word,
Expand Down Expand Up @@ -297,3 +298,27 @@ def test_save_diagnostic_log_file(client, app):
with open(saved_file_path, "r") as f:
content = f.read()
assert content == body.decode("utf-8")


def test_get_schema_fields_map():
"""Test that schema map correctly resolves DB attributes, keeps all fields, and ignores virtual fields."""

# dummy schema for testing
class TestSchema(Schema):
# standard field -> map 'name': 'name'
name = fields.String()
# aliased field -> map 'size': 'disk_usage
size = fields.Integer(attribute="disk_usage")
# virtual fields -> skip
version = fields.Function(lambda obj: "v1")
role = fields.Method("get_role")
# excluded field - set to None in schema inheritance -> skip
hidden_field = None

schema_map = get_schema_fields_map(TestSchema)

expected_map = {
"name": "name",
"size": "disk_usage",
}
assert schema_map == expected_map
53 changes: 45 additions & 8 deletions server/mergin/utils.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
# Copyright (C) Lutra Consulting Limited
#
# SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-MerginMaps-Commercial
import logging

import math
from collections import namedtuple
from datetime import datetime, timedelta, timezone
from enum import Enum
import os
from flask import current_app
from flask_sqlalchemy import Model
from marshmallow import Schema, fields
from pathvalidate import sanitize_filename
from sqlalchemy import Column, JSON
from sqlalchemy.sql.elements import UnaryExpression
from typing import Optional

from typing import Optional, Type

OrderParam = namedtuple("OrderParam", "name direction")

Expand All @@ -33,7 +35,7 @@ def split_order_param(order_param: str) -> Optional[OrderParam]:


def get_order_param(
cls: Model, order_param: OrderParam, json_sort: dict = None
cls: Model, order_param: OrderParam, json_sort: dict = None, field_map: dict = None
) -> Optional[UnaryExpression]:
"""Return order by clause parameter for SQL query

Expand All @@ -43,15 +45,22 @@ def get_order_param(
:type order_param: OrderParam
:param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None
:type json_sort: dict
:param field_map: mapping for translating public field names to internal DB columns, e.g. '{"size": "disk_usage"}'
:type field_map: dict
"""
# translate field name to column name
db_column_name = order_param.name
if field_map and order_param.name in field_map:
db_column_name = field_map[order_param.name]
# find candidate for nested json sort
if "." in order_param.name:
col, attr = order_param.name.split(".")
if "." in db_column_name:
col, attr = db_column_name.split(".")
else:
col = order_param.name
col = db_column_name
attr = None
order_attr = cls.__table__.c.get(col, None)
if not isinstance(order_attr, Column):
logging.warning("Ignoring invalid order parameter.")
return
# sort by key in JSON field
if attr:
Expand Down Expand Up @@ -80,7 +89,9 @@ def get_order_param(
return order_attr.desc()


def parse_order_params(cls: Model, order_params: str, json_sort: dict = None):
def parse_order_params(
cls: Model, order_params: str, json_sort: dict = None, field_map: dict = None
) -> list[UnaryExpression]:
"""Convert order parameters in query string to list of order by clauses.

:param cls: Db model class
Expand All @@ -89,6 +100,8 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None):
:type order_params: str
:param json_sort: type mapping for sort by json field, e.g. '{"storage": "int"}', defaults to None
:type json_sort: dict
:param field_map: mapping response fields to database column names, e.g. '{"size": "disk_usage"}'
:type field_map: dict

:rtype: List[Column]
"""
Expand All @@ -97,7 +110,7 @@ def parse_order_params(cls: Model, order_params: str, json_sort: dict = None):
order_param = split_order_param(p)
if not order_param:
continue
order_attr = get_order_param(cls, order_param, json_sort)
order_attr = get_order_param(cls, order_param, json_sort, field_map)
if order_attr is not None:
order_by_params.append(order_attr)
return order_by_params
Expand Down Expand Up @@ -135,3 +148,27 @@ def save_diagnostic_log_file(app: str, username: str, body: bytes) -> str:
f.write(content)

return file_name


def get_schema_fields_map(schema: Type[Schema]) -> dict:
"""
Creates a mapping of schema field names to corresponding DB columns.
This allows sorting by the API field name (e.g. 'size') while
actually sorting by the database column (e.g. 'disk_usage').
"""
mapping = {}
for name, field in schema._declared_fields.items():
# some fields could have been overridden with None to be excluded
if not field:
continue
# skip virtual fields as DB cannot sort by them
if isinstance(
field, (fields.Function, fields.Method, fields.Nested, fields.List)
):
continue
if field.attribute:
mapping[name] = field.attribute
# keep the map complete
else:
mapping[name] = name
return mapping
Loading