diff --git a/databusclient/api/download.py b/databusclient/api/download.py
index ac55faa..4af27c4 100644
--- a/databusclient/api/download.py
+++ b/databusclient/api/download.py
@@ -12,6 +12,54 @@
     get_databus_id_parts_from_file_url,
 )
 
+from databusclient.extensions.webdav import compute_sha256_and_length
+
+def _extract_checksum_from_node(node) -> str | None:
+    """
+    Try to extract a 64-char hex checksum from a JSON-LD file node.
+    Handles these common shapes:
+      - checksum or sha256sum fields given as plain strings
+      - checksum fields given as a dict with an '@value' key
+      - nested values (recursively search strings for a 64-char hex)
+    """
+
+    def find_in_value(v):
+        if isinstance(v, str):
+            s = v.strip()
+            if len(s) == 64 and all(c in "0123456789abcdefABCDEF" for c in s):
+                return s
+        if isinstance(v, dict):
+            # common JSON-LD value object
+            if "@value" in v and isinstance(v["@value"], str):
+                res = find_in_value(v["@value"])
+                if res:
+                    return res
+            # try all nested dict values
+            for vv in v.values():
+                res = find_in_value(vv)
+                if res:
+                    return res
+        if isinstance(v, list):
+            for item in v:
+                res = find_in_value(item)
+                if res:
+                    return res
+        return None
+
+    # direct keys to try first
+    for key in ("checksum", "sha256sum", "sha256", "databus:checksum"):
+        if key in node:
+            res = find_in_value(node[key])
+            if res:
+                return res
+
+    # fallback: search all values recursively for a 64-char hex string
+    for v in node.values():
+        res = find_in_value(v)
+        if res:
+            return res
+    return None
+
 
 # Hosts that require Vault token based authentication. Central source of truth.
 VAULT_REQUIRED_HOSTS = {
@@ -32,6 +80,8 @@ def _download_file(
     databus_key=None,
     auth_url=None,
     client_id=None,
+    validate_checksum: bool = False,
+    expected_checksum: str | None = None,
 ) -> None:
     """
     Download a file from the internet with a progress bar using tqdm.
@@ -183,6 +233,27 @@ def _download_file(
     if total_size_in_bytes != 0 and progress_bar.n != total_size_in_bytes:
         raise IOError("Downloaded size does not match Content-Length header")
+
+    # --- 6. Optional checksum validation ---
+    if validate_checksum:
+        # reuse compute_sha256_and_length from the webdav extension
+        try:
+            actual, _ = compute_sha256_and_length(filename)
+        except OSError as e:  # IOError is an alias of OSError on Python 3
+            print(f"WARNING: error computing checksum for {filename}: {e}")
+            actual = None
+
+        if expected_checksum is None:
+            print(f"WARNING: no expected checksum available for {filename}; skipping validation")
+        elif actual is None:
+            print(f"WARNING: could not compute checksum for {filename}; skipping validation")
+        else:
+            if actual.lower() != expected_checksum.lower():
+                try: os.remove(filename)  # delete the corrupted file
+                except OSError: pass
+                raise IOError(
+                    f"Checksum mismatch for {filename}: expected {expected_checksum}, got {actual}"
+                )
 
 
 def _download_files(
     urls: List[str],
@@ -191,6 +262,8 @@ def _download_files(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
+    checksums: dict | None = None,
 ) -> None:
     """
     Download multiple files from the databus.
@@ -204,6 +277,9 @@ def _download_files(
     - client_id: Client ID for token exchange
     """
     for url in urls:
+        expected = None
+        if checksums:
+            expected = checksums.get(url)
         _download_file(
             url=url,
             localDir=localDir,
@@ -211,6 +287,8 @@ def _download_files(
             databus_key=databus_key,
             auth_url=auth_url,
             client_id=client_id,
+            validate_checksum=validate_checksum,
+            expected_checksum=expected,
         )
 
 
@@ -358,6 +436,7 @@ def _download_collection(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download all files in a databus collection.
@@ -375,6 +454,44 @@
     file_urls = _get_file_download_urls_from_sparql_query(
         endpoint, query, databus_key=databus_key
     )
+
+    # If checksum validation is requested, try to build a url -> checksum mapping
+    # by fetching the Version JSON-LD for each file's version. Files are grouped
+    # by their version URI to avoid fetching the same metadata repeatedly.
+    checksums: dict = {}
+    if validate_checksum:
+        # Map version_uri -> list of file urls
+        versions_map: dict = {}
+        for fu in file_urls:
+            try:
+                h, acc, grp, art, ver, _ = get_databus_id_parts_from_file_url(fu)
+            except Exception:
+                continue
+            if ver is None:
+                continue
+            if h is None or acc is None or grp is None or art is None:
+                continue
+            version_uri = f"https://{h}/{acc}/{grp}/{art}/{ver}"
+            versions_map.setdefault(version_uri, []).append(fu)
+
+        # Fetch each version's JSON-LD once and extract checksums for its files
+        for version_uri, urls_in_version in versions_map.items():
+            try:
+                json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
+                jd = json.loads(json_str)
+                graph = jd.get("@graph", [])
+                for node in graph:
+                    if node.get("@type") == "Part":
+                        file_uri = node.get("file")
+                        if not isinstance(file_uri, str):
+                            continue
+                        expected = _extract_checksum_from_node(node)
+                        if expected and file_uri in urls_in_version:
+                            checksums[file_uri] = expected
+            except Exception:
+                # Best effort: if fetching a version's metadata fails, skip it
+                continue
+
     _download_files(
         list(file_urls),
         localDir,
@@ -382,6 +499,8 @@
         databus_key=databus_key,
         auth_url=auth_url,
         client_id=client_id,
+        validate_checksum=validate_checksum,
+        checksums=checksums or None,
     )
 
 
@@ -392,6 +511,7 @@ def _download_version(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download all files in a databus artifact version.
@@ -406,6 +526,22 @@
     """
     json_str = fetch_databus_jsonld(uri, databus_key=databus_key)
     file_urls = _get_file_download_urls_from_artifact_jsonld(json_str)
+    # build a url -> checksum mapping from the JSON-LD when available
+    checksums: dict = {}
+    try:
+        json_dict = json.loads(json_str)
+        graph = json_dict.get("@graph", [])
+        for node in graph:
+            if node.get("@type") == "Part":
+                file_uri = node.get("file")
+                if not isinstance(file_uri, str):
+                    continue
+                expected = _extract_checksum_from_node(node)
+                if expected:
+                    checksums[file_uri] = expected
+    except Exception:
+        checksums = {}
+
     _download_files(
         file_urls,
         localDir,
@@ -413,6 +549,8 @@
         databus_key=databus_key,
         auth_url=auth_url,
         client_id=client_id,
+        validate_checksum=validate_checksum,
+        checksums=checksums,
     )
 
 
@@ -424,6 +562,7 @@ def _download_artifact(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download files in a databus artifact.
@@ -445,6 +584,22 @@
         print(f"Downloading version: {version_uri}")
         json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
         file_urls = _get_file_download_urls_from_artifact_jsonld(json_str)
+        # extract checksums for this version
+        checksums: dict = {}
+        try:
+            jd = json.loads(json_str)
+            graph = jd.get("@graph", [])
+            for node in graph:
+                if node.get("@type") == "Part":
+                    file_uri = node.get("file")
+                    if not isinstance(file_uri, str):
+                        continue
+                    expected = _extract_checksum_from_node(node)
+                    if expected:
+                        checksums[file_uri] = expected
+        except Exception:
+            checksums = {}
+
         _download_files(
             file_urls,
             localDir,
@@ -452,6 +607,8 @@
             databus_key=databus_key,
             auth_url=auth_url,
             client_id=client_id,
+            validate_checksum=validate_checksum,
+            checksums=checksums,
         )
 
 
@@ -527,6 +684,7 @@ def _download_group(
     databus_key: str = None,
     auth_url: str = None,
     client_id: str = None,
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download files in a databus group.
@@ -552,6 +710,7 @@
             databus_key=databus_key,
             auth_url=auth_url,
             client_id=client_id,
+            validate_checksum=validate_checksum,
        )
 
 
@@ -598,6 +757,7 @@
     all_versions=None,
     auth_url="https://auth.dbpedia.org/realms/dbpedia/protocol/openid-connect/token",
     client_id="vault-token-exchange",
+    validate_checksum: bool = False,
 ) -> None:
     """
     Download datasets from databus.
@@ -638,9 +798,27 @@
                 databus_key,
                 auth_url,
                 client_id,
+                validate_checksum=validate_checksum,
             )
         elif file is not None:
             print(f"Downloading file: {databusURI}")
+            # Try to fetch the expected checksum from the parent Version metadata
+            expected = None
+            if validate_checksum:
+                try:
+                    version_uri = f"https://{host}/{account}/{group}/{artifact}/{version}"
+                    json_str = fetch_databus_jsonld(version_uri, databus_key=databus_key)
+                    json_dict = json.loads(json_str)
+                    graph = json_dict.get("@graph", [])
+                    for node in graph:
+                        if node.get("file") == databusURI or node.get("@id") == databusURI:
+                            expected = _extract_checksum_from_node(node)
+                            if expected:
+                                break
+                except Exception as e:
+                    print(f"WARNING: Could not fetch checksum for single file: {e}")
+
+            # Call the worker to download the single file (passing the expected checksum)
             _download_file(
                 databusURI,
                 localDir,
@@ -648,6 +826,8 @@
                 databus_key=databus_key,
                 auth_url=auth_url,
                 client_id=client_id,
+                validate_checksum=validate_checksum,
+                expected_checksum=expected,
             )
         elif version is not None:
             print(f"Downloading version: {databusURI}")
@@ -658,6 +838,7 @@
                 databus_key=databus_key,
                 auth_url=auth_url,
                 client_id=client_id,
+                validate_checksum=validate_checksum,
             )
         elif artifact is not None:
             print(
@@ -671,6 +852,7 @@
                 databus_key=databus_key,
                 auth_url=auth_url,
                 client_id=client_id,
+                validate_checksum=validate_checksum,
            )
         elif group is not None and group != "collections":
             print(
@@ -684,6 +866,7 @@
                 databus_key=databus_key,
                 auth_url=auth_url,
                 client_id=client_id,
+                validate_checksum=validate_checksum,
             )
         elif account is not None:
             print("accountId not supported yet")  # TODO
@@ -697,6 +880,8 @@
         # query as argument
         else:
             print("QUERY {}", databusURI.replace("\n", " "))
+            if validate_checksum:
+                print("WARNING: Checksum validation is not supported for user-defined SPARQL queries.")
             if uri_endpoint is None:
                 # endpoint is required for queries (--databus)
                 raise ValueError("No endpoint given for query")
             res = _get_file_download_urls_from_sparql_query(
@@ -709,4 +894,5 @@
                 databus_key=databus_key,
                 auth_url=auth_url,
                 client_id=client_id,
+                validate_checksum=validate_checksum,
             )
diff --git a/databusclient/cli.py b/databusclient/cli.py
index 069408e..420530d 100644
--- a/databusclient/cli.py
+++ b/databusclient/cli.py
@@ -158,6 +158,11 @@ def deploy(
     show_default=True,
     help="Client ID for token exchange",
 )
+@click.option(
+    "--validate-checksum",
+    is_flag=True,
+    help="Validate checksums of downloaded files",
+)
 def download(
     databusuris: List[str],
     localdir,
@@ -167,7 +172,8 @@ def download(
     all_versions,
     authurl,
     clientid,
-):
+    validate_checksum,
+):
     """
     Download datasets from databus, optionally using vault access if vault options are provided.
     """
@@ -181,7 +187,8 @@ def download(
             all_versions=all_versions,
             auth_url=authurl,
             client_id=clientid,
-        )
+            validate_checksum=validate_checksum,
+        )
     except DownloadAuthError as e:
         raise click.ClickException(str(e))
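
Example usage (a minimal sketch, not part of the patch): the CLI gains a
--validate-checksum flag, and the Python API gains the matching keyword
argument. The snippet assumes download() takes the databus URIs and the
target directory as its first two arguments, mirroring the CLI wiring above;
the URI is a hypothetical placeholder.

    from databusclient.api.download import download

    # Downloads every file of the version and verifies each file's SHA-256
    # against the checksum extracted from the Version JSON-LD; on mismatch
    # the file is deleted and an IOError is raised.
    download(
        ["https://databus.dbpedia.org/<account>/<group>/<artifact>/<version>"],
        "./downloads",
        validate_checksum=True,
    )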