Skip to content

Commit

Permalink
Update training tokens
Browse files Browse the repository at this point in the history
  • Loading branch information
rudolphpienaar committed Apr 1, 2024
1 parent ab2cce1 commit b52f665
Showing 1 changed file with 47 additions and 2 deletions.
49 changes: 47 additions & 2 deletions spleenseg_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,13 +60,57 @@
formatter_class=ArgumentDefaultsHelpFormatter,
)
parser.add_argument(
"-p", "--pattern", default="**/*nii.gz", type=str, help="input file filter glob"
"--trainImageDir",
type=str,
default="imagesTr",
help="name of directory containing training images",
)
parser.add_argument(
"--trainLabelsDir",
type=str,
default="labelsTr",
help="name of directory containing training labels",
)
parser.add_argument(
"--validateSize",
type=int,
default=9,
help="size of the validation set in the input raw/label space",
)
parser.add_argument(
"--pattern", type=str, default="**/[!._]*nii.gz", help="filter glob for input files"
)
parser.add_argument(
"-V", "--version", action="version", version=f"%(prog)s {__version__}"
)


def trainingData_prep(options: Namespace, inputDir: Path) -> list[dict[str, str]]:
trainRaw: list[Path] = []
trainLbl: list[Path] = []
for group in [options.trainImageDir, options.trainLabelsDir]:
for path in inputDir.rglob(group):
if group == path.name and path.name == options.trainImageDir:
trainRaw.extend(path.glob(options.pattern))
elif group == path.name and path.name == options.trainLabelsDir:
trainLbl.extend(path.glob(options.pattern))
trainRaw.sort()
trainLbl.sort()
return [
{"image": str(image_name), "label": str(label_name)}
for image_name, label_name in zip(trainRaw, trainLbl)
]


def inputFiles_splitInto_train_validate(
options: Namespace, inputDir: Path
) -> tuple[list[dict[str, str]], list[dict[str, str]]]:
trainingSpace: list[dict[str, str]] = trainingData_prep(options, inputDir)
trainingSet: list[dict[str, str]] = trainingSpace[: -options.validateSize]
validateSet: list[dict[str, str]] = trainingSpace[-options.validateSize :]
return trainingSet, validateSet


# The main function of this *ChRIS* plugin is denoted by this ``@chris_plugin`` "decorator."
# Some metadata about the plugin is specified here. There is more metadata specified in setup.py.
#
Expand All @@ -89,9 +133,10 @@ def main(options: Namespace, inputdir: Path, outputdir: Path):
:param inputdir: directory containing (read-only) input files
:param outputdir: directory where to write output files
"""

pudb.set_trace()
print(DISPLAY_TITLE)
print_config()
trainingSet, validateSet = inputFiles_splitInto_train_validate(options, inputdir)

# Typically it's easier to think of programs as operating on individual files
# rather than directories. The helper functions provided by a ``PathMapper``
Expand Down

0 comments on commit b52f665

Please sign in to comment.