Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 82 additions & 19 deletions modelarrayio/cifti.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
write_rows_in_column_stripes as tdb_write_stripes,
write_column_names as tdb_write_column_names,
)
from .parser import add_relative_root_arg, add_output_hdf5_arg, add_cohort_arg, add_storage_args, add_backend_arg, add_output_tiledb_arg, add_tiledb_storage_args, add_scalar_columns_arg
from .parser import add_relative_root_arg, add_output_hdf5_arg, add_cohort_arg, add_storage_args, add_backend_arg, add_output_tiledb_arg, add_tiledb_storage_args, add_scalar_columns_arg, add_s3_workers_arg
from .s3_utils import is_s3_path, load_nibabel


def _cohort_to_long_dataframe(cohort_df, scalar_columns=None):
Expand Down Expand Up @@ -83,7 +84,7 @@ def extract_cifti_scalar_data(cifti_file, reference_brain_names=None):

"""

cifti = nb.load(cifti_file)
cifti = cifti_file if hasattr(cifti_file, 'get_fdata') else nb.load(cifti_file)
cifti_hdr = cifti.header
axes = [cifti_hdr.get_axis(i) for i in range(cifti.ndim)]
if len(axes) > 2:
Expand Down Expand Up @@ -134,6 +135,75 @@ def brain_names_to_dataframe(brain_names):
return greyordinate_df, structure_name_strings


def _load_cohort_cifti(cohort_long, relative_root, s3_workers):
"""Load all CIFTI scalar rows from the cohort, optionally in parallel.

The first file is always loaded serially to obtain the reference brain
structure axis used for validation. When s3_workers > 1, remaining rows
are submitted to a ThreadPoolExecutor and collected via as_completed.
Threads share memory so reference_brain_names is accessed directly with
no copying overhead.

Returns
-------
scalars : dict[str, list[np.ndarray]]
Per-scalar ordered list of 1-D subject arrays, ready for stripe-write.
reference_brain_names : np.ndarray
Brain structure names from the first file, for building greyordinate table.
"""
# Assign stable per-scalar subject indices in cohort order
scalar_subj_counter = defaultdict(int)
rows_with_idx = []
for row in cohort_long.itertuples(index=False):
subj_idx = scalar_subj_counter[row.scalar_name]
scalar_subj_counter[row.scalar_name] += 1
rows_with_idx.append((row.scalar_name, subj_idx, row.source_file))

# Load the first file serially to get the reference brain axis
first_sn, _, first_src = rows_with_idx[0]
first_path = first_src if is_s3_path(first_src) else op.join(relative_root, first_src)
first_data, reference_brain_names = extract_cifti_scalar_data(load_nibabel(first_path, cifti=True))

def _worker(job):
sn, subj_idx, src = job
arr, _ = extract_cifti_scalar_data(
load_nibabel(src, cifti=True), reference_brain_names=reference_brain_names
)
return sn, subj_idx, arr

if s3_workers > 1 and len(rows_with_idx) > 1:
results = {first_sn: {0: first_data}}
jobs = [
(sn, subj_idx, src if is_s3_path(src) else op.join(relative_root, src))
for sn, subj_idx, src in rows_with_idx[1:]
]
with ThreadPoolExecutor(max_workers=s3_workers) as pool:
futures = {pool.submit(_worker, job): job for job in jobs}
for future in tqdm(
as_completed(futures),
total=len(futures),
desc="Loading CIFTI data",
):
sn, subj_idx, arr = future.result()
results.setdefault(sn, {})[subj_idx] = arr
scalars = {
sn: [results[sn][i] for i in range(cnt)]
for sn, cnt in scalar_subj_counter.items()
}
else:
scalars = defaultdict(list)
scalars[first_sn].append(first_data)
remaining = [
(sn, subj_idx, src if is_s3_path(src) else op.join(relative_root, src))
for sn, subj_idx, src in rows_with_idx[1:]
]
for job in tqdm(remaining, desc="Loading CIFTI data"):
sn, subj_idx, arr = _worker(job)
scalars[sn].append(arr)

return scalars, reference_brain_names


def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_tdb='arraydb.tdb', relative_root='/',
storage_dtype='float32',
compression='gzip',
Expand All @@ -147,7 +217,8 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
tdb_tile_voxels=0,
tdb_target_tile_mb=2.0,
tdb_workers=None,
scalar_columns=None):
scalar_columns=None,
s3_workers=1):
"""
Load all fixeldb data.
Parameters
Expand All @@ -174,19 +245,9 @@ def write_storage(cohort_file, backend='hdf5', output_h5='fixeldb.h5', output_td
raise ValueError("Unable to derive scalar sources from cohort file.")

if backend == 'hdf5':
scalars = defaultdict(list)
last_brain_names = None
for row in tqdm(
cohort_long.itertuples(index=False),
total=cohort_long.shape[0],
desc="Loading CIFTI scalars",
):
scalar_file = op.join(relative_root, row.source_file)
cifti_data, brain_names = extract_cifti_scalar_data(
scalar_file, reference_brain_names=last_brain_names
)
last_brain_names = brain_names.copy()
scalars[row.scalar_name].append(cifti_data)
scalars, last_brain_names = _load_cohort_cifti(
cohort_long, relative_root, s3_workers
)

output_file = op.join(relative_root, output_h5)
f = h5py.File(output_file, "w")
Expand Down Expand Up @@ -293,6 +354,7 @@ def get_parser():
add_backend_arg(parser)
add_storage_args(parser)
add_tiledb_storage_args(parser)
add_s3_workers_arg(parser)
return parser


Expand All @@ -319,7 +381,8 @@ def main():
tdb_tile_voxels=args.tdb_tile_voxels,
tdb_target_tile_mb=args.tdb_target_tile_mb,
tdb_workers=args.tdb_workers,
scalar_columns=args.scalar_columns)
scalar_columns=args.scalar_columns,
s3_workers=args.s3_workers)
return status


Expand Down Expand Up @@ -395,7 +458,7 @@ def h5_to_ciftis():
parser = get_h5_to_ciftis_parser()
args = parser.parse_args()

out_cifti_dir = op.join(args.relative_root, args.output_dir) # absolute path for output dir
out_cifti_dir = op.abspath(args.output_dir) # absolute path for output dir

if op.exists(out_cifti_dir):
print("WARNING: Output directory exists")
Expand All @@ -422,7 +485,7 @@ def get_h5_to_ciftis_parser():
parser.add_argument(
"--cohort-file", "--cohort_file",
help="Path to a csv with demographic info and paths to data.",
required=True)
)
parser.add_argument(
"--relative-root", "--relative_root",
help="Root to which all paths are relative, i.e. defining the (absolute) path to root directory of index_file, directions_file, cohort_file, input_hdf5, and output_dir.",
Expand Down
4 changes: 3 additions & 1 deletion modelarrayio/h5_storage.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,9 @@ def create_empty_scalar_matrix_dataset(
shuffle=use_shuffle,
)
if sources_list is not None:
write_column_names(h5file, dataset_path, sources_list)
# dataset_path is e.g. 'scalars/FA/values'; extract the scalar name segment
scalar_name = dataset_path.split('/')[1] if dataset_path.count('/') >= 2 else dataset_path
write_column_names(h5file, scalar_name, sources_list)
return dset


Expand Down
14 changes: 14 additions & 0 deletions modelarrayio/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,20 @@ def add_tiledb_storage_args(parser):
return parser


def add_s3_workers_arg(parser):
    """Add the ``--s3-workers`` / ``--s3_workers`` option to *parser*.

    Parameters
    ----------
    parser : argparse.ArgumentParser
        Parser to extend in place.

    Returns
    -------
    argparse.ArgumentParser
        The same parser, for chaining.
    """
    parser.add_argument(
        "--s3-workers", "--s3_workers",
        type=int,
        default=1,
        help=(
            # "threads", not "processes": the loader uses a ThreadPoolExecutor.
            "Number of parallel worker threads for loading image files. "
            "Set > 1 to enable parallel downloads when cohort paths begin with s3://. "
            "Default 1 (serial)."
        ),
    )
    return parser


def add_scalar_columns_arg(parser):
parser.add_argument(
"--scalar-columns", "--scalar_columns",
Expand Down
82 changes: 82 additions & 0 deletions modelarrayio/s3_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
import gzip
import logging
import os
from io import BytesIO
from urllib.parse import urlparse

logger = logging.getLogger(__name__)


def is_s3_path(path: str) -> bool:
    """Return True when *path* denotes an S3 object (``s3://`` scheme)."""
    return str(path)[:5] == "s3://"


def _make_s3_client():
"""Create a boto3 S3 client.

Uses anonymous (unsigned) access when the environment variable
MODELARRAYIO_S3_ANON=1 is set (useful for public buckets such as
fcp-indi). Otherwise the standard boto3 credential chain is used
(env vars, ~/.aws/credentials, IAM instance profile, etc.).

Raises ImportError if boto3 is not installed.
"""
try:
import boto3
except ImportError:
raise ImportError(
"boto3 is required for s3:// paths. "
"Install with: pip install modelarrayio[s3]"
)
anon = os.environ.get("MODELARRAYIO_S3_ANON", "").lower() in ("1", "true", "yes")
if anon:
from botocore import UNSIGNED
from botocore.config import Config
return boto3.client("s3", config=Config(signature_version=UNSIGNED))
return boto3.client("s3")


def load_nibabel(path: str, *, cifti: bool = False):
    """Load a nibabel image from a local path or an s3:// URI.

    S3 objects are fetched straight into memory with ``get_object`` — no
    temporary file touches disk.  Payloads whose key ends in ``.gz`` are
    gunzipped in memory, then handed to nibabel through its ``FileHolder`` /
    ``from_file_map`` API.

    Parameters
    ----------
    path : str
        Local file path or s3:// URI.
    cifti : bool
        Pass ``True`` for CIFTI-2 files (``.dscalar.nii`` etc.) so that
        nibabel returns a ``Cifti2Image`` with proper axes. ``False``
        (default) returns a ``Nifti1Image``.

    Returns
    -------
    nibabel.FileBasedImage
    """
    import nibabel as nb

    # Local files go straight through nibabel's regular loader.
    if not is_s3_path(path):
        return nb.load(path)

    uri = urlparse(path)
    bucket, key = uri.netloc, uri.path.lstrip("/")

    logger.debug("Loading s3://%s/%s into memory", bucket, key)
    raw = _make_s3_client().get_object(Bucket=bucket, Key=key)["Body"].read()

    # Transparently decompress gzipped objects before parsing.
    if os.path.basename(key).endswith(".gz"):
        raw = gzip.decompress(raw)

    from nibabel.filebasedimages import FileHolder

    holder = FileHolder(fileobj=BytesIO(raw))
    fmap = {"header": holder, "image": holder}

    image_cls = nb.Cifti2Image if cifti else nb.Nifti1Image
    return image_cls.from_file_map(fmap)
Loading
Loading