Source code for ivadomed.scripts.extract_small_dataset

#!/usr/bin/env python

import shutil
import argparse
import numpy as np
import pandas as pd
from ivadomed import utils as imed_utils
from pathlib import Path
from typing import List
from loguru import logger

# Subject folder names that extract_small_dataset() filters out of the candidate
# list, so they are never selected even when present in the input dataset.
EXCLUDED_SUBJECT = ["sub-mniPilot1"]


def get_parser():
    """Build the command-line parser of the small-dataset extraction script.

    Registered options, in order: ``-i/--input`` and ``-o/--output`` (both
    required), ``-n/--number`` (default 1), ``-c/--contrasts`` (no default),
    ``-s/--seed`` (default -1) and ``-d/--derivatives`` (default 1).

    No ``type=`` converter is registered, so values given on the command line
    reach the caller as strings; only the untouched defaults are integers.

    Returns:
        argparse.ArgumentParser: The configured parser.
    """
    metavar = imed_utils.Metavar

    # (flags, keyword options) pairs, registered in this exact order so that the
    # generated usage/help text lists the options the same way.
    option_specs = [
        (("-i", "--input"),
         dict(required=True, help="Input BIDS folder.", metavar=metavar.file)),
        (("-n", "--number"),
         dict(required=False, default=1, help="Number of subjects.", metavar=metavar.int)),
        (("-c", "--contrasts"),
         dict(required=False, help="Contrast list.", metavar=metavar.list)),
        (("-o", "--output"),
         dict(required=True, help="Output BIDS Folder.", metavar=metavar.file)),
        (("-s", "--seed"),
         dict(required=False, default=-1, metavar=metavar.int,
              help="""Set np.random.RandomState to ensure reproducibility: the same
                                subjects will be selected if the script is run several times on the
                                same dataset. Set to -1 (default) otherwise.""")),
        (("-d", "--derivatives"),
         dict(dest="derivatives", default=1, metavar=metavar.int,
              help="""If true, include derivatives/labels content.
                                1 = true, 0 = false""")),
    ]

    cli = argparse.ArgumentParser()
    for flags, options in option_specs:
        cli.add_argument(*flags, **options)
    return cli


def is_good_contrast(fname, good_contrast_list):
    """Tell whether a file name refers to one of the wanted contrasts.

    The check is a plain substring search: a contrast ``c`` matches as soon as
    ``"_" + c`` occurs anywhere in ``fname`` (so ``"T2"`` also matches a name
    containing ``"_T2star"``).

    Args:
        fname (str): File name or path to inspect.
        good_contrast_list (list): Contrast names to look for, e.g. ``["T1w", "T2w"]``.

    Returns:
        bool: True if at least one contrast matches, False otherwise (including
        when ``good_contrast_list`` is empty).
    """
    return any("_" + contrast in fname for contrast in good_contrast_list)


def remove_some_contrasts(folder, subject_list, good_contrast_list):
    """Delete, in each subject's ``anat`` folder, the files of unwanted contrasts.

    Every entry of ``<folder>/<subject>/anat`` is listed for all subjects first;
    only once the full listing succeeded are the entries rejected by
    ``is_good_contrast`` unlinked.

    Args:
        folder (str): Root folder containing one sub-folder per subject.
        subject_list (list): Names of the subject folders to clean.
        good_contrast_list (list): Contrasts to keep; anything else is deleted.
    """
    # Phase 1: gather every entry before touching the disk.
    candidates = [
        entry
        for subject in subject_list
        for entry in Path(folder, subject, "anat").iterdir()
    ]

    # Phase 2: keep only the entries that match none of the wanted contrasts.
    unwanted = [
        entry for entry in candidates
        if not is_good_contrast(str(entry), good_contrast_list)
    ]

    # Phase 3: remove them.
    for entry in unwanted:
        entry.unlink()


def extract_small_dataset(input, output, n=10, contrast_list=None, include_derivatives=True, seed=-1):
    """Extract small BIDS dataset from a larger BIDS dataset.

    Example::

         ivadomed_extract_small_dataset -i path/to/BIDS/dataset -o path/of/small/BIDS/dataset \\
            -n 10 -c T1w,T2w -d 0 -s 1234

    Args:
        input (str): Input BIDS folder. Flag: ``--input``, ``-i``
        output (str): Output folder. Flag: ``--output``, ``-o``
        n (int): Number of subjects in the output folder. Flag: ``--number``, ``-n``
        contrast_list (list): List of image contrasts to include. If set to None, then all
            available contrasts are included. Flag: ``--contrasts``, ``-c``
        include_derivatives (bool): If True, derivatives/labels/ content is also copied,
            only the raw images otherwise. Flag: ``--derivatives``, ``-d``
        seed (int): Set np.random.RandomState to ensure reproducibility: the same subjects
            will be selected if the function is run several times on the same dataset. If
            set to -1, each function run is independent. Flag: ``--seed``, ``-s``.

    Raises:
        ValueError: If ``n`` is larger than the number of eligible subjects (subjects are
            drawn without replacement), or if no eligible subject is found.
        AssertionError: If a selected subject has no folder in the input dataset or,
            when ``include_derivatives`` is True, in ``derivatives/labels``.
    """
    # Create output folders
    Path(output).mkdir(parents=True, exist_ok=True)
    if include_derivatives:
        out_derivatives = Path(output, "derivatives", "labels")
        out_derivatives.mkdir(parents=True, exist_ok=True)
        in_derivatives = Path(input, "derivatives", "labels")

    # Get subject list. Directory iteration order is arbitrary, so the names are sorted:
    # otherwise the same seed could pick different subjects from one run/machine to another.
    subject_list = sorted(
        s.name for s in Path(input).iterdir()
        if s.name.startswith("sub-") and s.is_dir() and s.name not in EXCLUDED_SUBJECT
    )

    # Randomly select subjects. Both the seeded and the unseeded draw sample WITHOUT
    # replacement: a subject drawn twice would make the second copytree() below fail
    # because its destination folder already exists.
    rng = np.random.RandomState(seed) if seed != -1 else np.random
    subject_random_list = [str(s) for s in rng.choice(subject_list, n, replace=False)]

    # Loop across subjects
    for subject in subject_random_list:
        logger.debug(f"\nSubject: {subject}")

        # Copy images
        in_subj_folder = Path(input, subject)
        out_subj_folder = Path(output, subject)
        assert in_subj_folder.is_dir(), f"Missing subject folder: {in_subj_folder}"
        logger.debug(f"\tCopying {in_subj_folder} to {out_subj_folder}.")
        shutil.copytree(str(in_subj_folder), str(out_subj_folder))

        # Remove dwi data
        out_subj_dwi = out_subj_folder / "dwi"
        if out_subj_dwi.is_dir():
            shutil.rmtree(str(out_subj_dwi))

        # Copy labels
        if include_derivatives:
            in_subj_derivatives = Path(in_derivatives, subject)
            out_subj_derivatives = Path(out_derivatives, subject)
            assert in_subj_derivatives.is_dir(), f"Missing derivatives folder: {in_subj_derivatives}"
            logger.debug(f"\tCopying {in_subj_derivatives} to {out_subj_derivatives}.")
            shutil.copytree(str(in_subj_derivatives), str(out_subj_derivatives))

            # Remove dwi data. out_subj_derivatives already ends with the subject name,
            # so "dwi" is looked up directly below it.
            out_subj_derivatives_dwi = out_subj_derivatives / "dwi"
            if out_subj_derivatives_dwi.is_dir():
                shutil.rmtree(str(out_subj_derivatives_dwi))

    # Keep only the requested contrasts, in the images and (if copied) in the labels
    if contrast_list:
        remove_some_contrasts(output, subject_random_list, contrast_list)
        if include_derivatives:
            remove_some_contrasts(str(out_derivatives), subject_random_list, contrast_list)

    # Copy dataset_description.json
    in_dataset_json = Path(input, "dataset_description.json")
    out_dataset_json = Path(output, "dataset_description.json")
    shutil.copyfile(str(in_dataset_json), str(out_dataset_json))

    # Copy participants.json if it exists
    in_participants_json = Path(input, "participants.json")
    if in_participants_json.is_file():
        out_participants_json = Path(output, "participants.json")
        shutil.copyfile(str(in_participants_json), str(out_participants_json))

    # Copy participants.tsv, keeping only the rows of the selected subjects
    in_participants_tsv = Path(input, "participants.tsv")
    out_participants_tsv = Path(output, "participants.tsv")
    df = pd.read_csv(str(in_participants_tsv), sep='\t')
    df = df[df.participant_id.isin(subject_random_list)]
    df.to_csv(str(out_participants_tsv), sep='\t', index=False)
def main(args=None):
    """Command-line entry point: parse the options and run the extraction.

    Args:
        args: Forwarded as-is to ``imed_utils.get_arguments`` together with the
            parser (presumably a list of command-line tokens, with None meaning
            "use the process arguments" -- confirm against that helper).
    """
    imed_utils.init_ivadomed()
    parser = get_parser()
    parsed = imed_utils.get_arguments(parser, args)

    # "-c T1w,T2w" -> ["T1w", "T2w"]; option absent -> keep every contrast.
    contrast_list = parsed.contrasts.split(",") if parsed.contrasts is not None else None

    # Option values may arrive as strings, hence the explicit conversions;
    # "-d" is a 0/1 flag turned into a boolean.
    extract_small_dataset(
        parsed.input,
        parsed.output,
        n=int(parsed.number),
        contrast_list=contrast_list,
        include_derivatives=bool(int(parsed.derivatives)),
        seed=int(parsed.seed),
    )


if __name__ == '__main__':
    main()