Source code for ivadomed.scripts.extract_small_dataset

#!/usr/bin/env python

import os
import shutil
import argparse
import numpy as np
import pandas as pd
from ivadomed import utils as imed_utils

# Subject folder names that extract_small_dataset filters out of the candidate
# subject list, even when they are present in the input dataset.
EXCLUDED_SUBJECT = ["sub-mniPilot1"]


def get_parser():
    """Build the command-line argument parser of this script.

    Returns:
        argparse.ArgumentParser: Parser declaring the ``--input``, ``--number``,
        ``--contrasts``, ``--output``, ``--seed`` and ``--derivatives`` options.
    """
    cli = argparse.ArgumentParser()

    # Source dataset (mandatory).
    cli.add_argument(
        "-i", "--input",
        metavar=imed_utils.Metavar.file,
        help="Input BIDS folder.",
        required=True,
    )

    # Size and content of the extracted subset.
    cli.add_argument(
        "-n", "--number",
        metavar=imed_utils.Metavar.int,
        help="Number of subjects.",
        default=1,
        required=False,
    )
    cli.add_argument(
        "-c", "--contrasts",
        metavar=imed_utils.Metavar.list,
        help="Contrast list.",
        required=False,
    )

    # Destination dataset (mandatory).
    cli.add_argument(
        "-o", "--output",
        metavar=imed_utils.Metavar.file,
        help="Output BIDS Folder.",
        required=True,
    )

    # Random selection control.
    cli.add_argument(
        "-s", "--seed",
        metavar=imed_utils.Metavar.int,
        help="""Set np.random.RandomState to ensure reproducibility: the same
                                subjects will be selected if the script is run several times on the
                                same dataset. Set to -1 (default) otherwise.""",
        default=-1,
        required=False,
    )

    # Whether the labels are extracted along with the raw images.
    cli.add_argument(
        "-d", "--derivatives",
        metavar=imed_utils.Metavar.int,
        help="""If true, include derivatives/labels content.
                                1 = true, 0 = false""",
        default=1,
        dest="derivatives",
    )
    return cli


def is_good_contrast(fname, good_contrast_list):
    """Tell whether a file name refers to one of the wanted contrasts.

    A contrast matches when the text ``"_" + contrast`` occurs anywhere in
    ``fname``.

    Args:
        fname (str): File name (or path) to inspect.
        good_contrast_list (list): Contrast names to look for, e.g. ``["T1w", "T2w"]``.

    Returns:
        bool: True if at least one contrast matches, False otherwise (including
        when ``good_contrast_list`` is empty).
    """
    return any(("_" + contrast) in fname for contrast in good_contrast_list)


def remove_some_contrasts(folder, subject_list, good_contrast_list):
    """Delete, for each subject, the ``anat`` files whose contrast is not wanted.

    Every file found in ``<folder>/<subject>/anat`` is kept only if
    ``is_good_contrast`` accepts its file name; all other files are removed
    from disk.

    Args:
        folder (str): Folder containing one sub-folder per subject.
        subject_list (list): Names of the subject sub-folders to process.
        good_contrast_list (list): Contrast names to keep.
    """
    for subject in subject_list:
        anat_dir = os.path.join(folder, subject, "anat")
        # A subject may have no "anat" sub-folder: there is nothing to filter
        # then, and listing the missing folder would raise.
        if not os.path.isdir(anat_dir):
            continue
        for fname in os.listdir(anat_dir):
            # Match on the file name only: a "_<contrast>" fragment in a parent
            # folder name must not protect every file from removal.
            if not is_good_contrast(fname, good_contrast_list):
                os.remove(os.path.join(anat_dir, fname))


def _copy_subject_without_dwi(src, dst):
    """Copy the subject folder ``src`` to ``dst``, then drop its ``dwi`` sub-folder."""
    assert os.path.isdir(src)
    print("\tCopying {} to {}.".format(src, dst))
    shutil.copytree(src, dst)
    # Remove dwi data
    dwi_folder = os.path.join(dst, "dwi")
    if os.path.isdir(dwi_folder):
        shutil.rmtree(dwi_folder)


def extract_small_dataset(input, output, n=10, contrast_list=None, include_derivatives=True, seed=-1):
    """Extract small BIDS dataset from a larger BIDS dataset.

    Example::

         ivadomed_extract_small_dataset -i path/to/BIDS/dataset -o path/of/small/BIDS/dataset \
            -n 10 -c T1w,T2w -d 0 -s 1234

    Args:
        input (str): Input BIDS folder. Flag: ``--input``, ``-i``
        output (str): Output folder. Flag: ``--output``, ``-o``
        n (int): Number of subjects in the output folder. Flag: ``--number``, ``-n``
        contrast_list (list): List of image contrasts to include. If set to None, then all
            available contrasts are included. Flag: ``--contrasts``, ``-c``
        include_derivatives (bool): If True, derivatives/labels/ content is also copied,
            only the raw images otherwise. Flag: ``--derivatives``, ``-d``
        seed (int): Set np.random.RandomState to ensure reproducibility: the same subjects
            will be selected if the function is run several times on the same dataset. If
            set to -1, each function run is independent. Flag: ``--seed``, ``-s``.

    Raises:
        ValueError: If ``n`` is larger than the number of eligible subjects.
    """
    # Create output folders
    os.makedirs(output, exist_ok=True)
    if include_derivatives:
        iderivatives = os.path.join(input, "derivatives", "labels")
        oderivatives = os.path.join(output, "derivatives", "labels")
        os.makedirs(oderivatives, exist_ok=True)

    # Get subject list. It is sorted because os.listdir() returns entries in
    # arbitrary order, which would make a seeded selection depend on the
    # filesystem instead of on the dataset content.
    subject_list = sorted(
        s for s in os.listdir(input)
        if s.startswith("sub-")
        and os.path.isdir(os.path.join(input, s))
        and s not in EXCLUDED_SUBJECT
    )
    if n > len(subject_list):
        raise ValueError("Cannot select {} subjects: only {} available in {}."
                         .format(n, len(subject_list), input))

    # Randomly select subjects. Sampling is done WITHOUT replacement in both
    # branches: a subject drawn twice would make shutil.copytree fail on the
    # already existing output folder.
    if seed != -1:
        # Reproducibility
        rng = np.random.RandomState(seed)
        selection = rng.choice(subject_list, n, replace=False)
    else:
        selection = np.random.choice(subject_list, n, replace=False)
    subject_random_list = [str(s) for s in selection]

    # Loop across subjects
    for subject in subject_random_list:
        print("\nSubject: {}".format(subject))
        # Copy images
        _copy_subject_without_dwi(os.path.join(input, subject),
                                  os.path.join(output, subject))
        # Copy labels
        if include_derivatives:
            _copy_subject_without_dwi(os.path.join(iderivatives, subject),
                                      os.path.join(oderivatives, subject))

    # Keep only the requested contrasts
    if contrast_list:
        remove_some_contrasts(output, subject_random_list, contrast_list)
        if include_derivatives:
            remove_some_contrasts(oderivatives, subject_random_list, contrast_list)

    # Copy dataset_description.json
    shutil.copyfile(os.path.join(input, "dataset_description.json"),
                    os.path.join(output, "dataset_description.json"))

    # Copy participants.json if it exists
    iparticipantsjson = os.path.join(input, "participants.json")
    if os.path.isfile(iparticipantsjson):
        shutil.copyfile(iparticipantsjson, os.path.join(output, "participants.json"))

    # Copy participants.tsv, keeping only the rows of the selected subjects
    df = pd.read_csv(os.path.join(input, "participants.tsv"), sep='\t')
    df = df[df.participant_id.isin(subject_random_list)]
    df.to_csv(os.path.join(output, "participants.tsv"), sep='\t', index=False)
def main(args=None):
    """Command-line entry point: parse the options and run the extraction.

    Args:
        args (list): Arguments forwarded to ``imed_utils.get_arguments``
            together with the parser. Defaults to None.
    """
    imed_utils.init_ivadomed()

    cli = get_parser()
    options = imed_utils.get_arguments(cli, args)

    # "-c T1w,T2w" becomes ["T1w", "T2w"]; without the option no contrast
    # filter is passed on.
    contrast_list = None if options.contrasts is None else options.contrasts.split(",")

    # Option values are converted explicitly before being handed over.
    extract_small_dataset(
        options.input,
        options.output,
        int(options.number),
        contrast_list,
        bool(int(options.derivatives)),
        int(options.seed),
    )


if __name__ == '__main__':
    main()