Skip to content

Commit

Permalink
Nd inference update (#115)
Browse files Browse the repository at this point in the history
* inference infers even when output tmaps fail
* bug fix in tf.sh makes -t work
  • Loading branch information
ndiamant authored Feb 3, 2020
1 parent 3da6a7b commit 431e458
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 9 deletions.
37 changes: 29 additions & 8 deletions ml4cvd/recipes.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Imports
import os
import csv
import copy
import logging
import numpy as np
from functools import reduce
Expand Down Expand Up @@ -150,19 +151,37 @@ def compare_multimodal_scalar_task_models(args):
_calculate_and_plot_prediction_stats(args, predictions, labels, paths)


def _make_tmap_nan_on_fail(tmap):
    """
    Build a copy of a TensorMap whose tensor_from_file returns NaNs on errors instead of raising.

    The input map is deep-copied and left untouched; only the copy gets the
    wrapped loader.

    :param tmap: map to wrap; must expose a ``tensor_from_file(tm, hd5, dependents)``
                 callable, and the map passed to the wrapped loader must expose ``shape``
    :return: deep copy of ``tmap`` whose ``tensor_from_file`` yields an array of
             NaNs of shape ``tm.shape`` when the original loader raises
             IndexError, KeyError, ValueError, OSError or RuntimeError
    """
    new_tmap = copy.deepcopy(tmap)
    # Capture the loader once, at wrap time. Looking it up on `tmap` at call time
    # would make the wrapper call itself if it were ever assigned back onto the
    # original map; the resulting RecursionError is a RuntimeError and would be
    # silently converted to NaNs for every file.
    original_tff = tmap.tensor_from_file

    def _tff(tm, hd5, dependents=None):
        try:
            return original_tff(tm, hd5, dependents)
        except (IndexError, KeyError, ValueError, OSError, RuntimeError) as e:
            # Deliberate best-effort: keep going, but leave a trace of what failed.
            logging.debug('tensor_from_file failed (%s: %s); returning NaNs', type(e).__name__, e)
            return np.full(shape=tm.shape, fill_value=np.nan)

    new_tmap.tensor_from_file = _tff
    return new_tmap


def infer_multimodal_multitask(args):
stats = Counter()
tensor_paths_inferred = {}
inference_tsv = os.path.join(args.output_folder, args.id, 'inference_' + args.id + '.tsv')
tensor_paths = [args.tensors + tp for tp in sorted(os.listdir(args.tensors)) if os.path.splitext(tp)[-1].lower() == TENSOR_EXT]
# hard code batch size to 1 so we can iterate over file names and generated tensors together in the tensor_paths for loop
if args.variational:
model, encoder, decoder = make_variational_multimodal_multitask_model(**args.__dict__)
else:
model = make_multimodal_multitask_model(**args.__dict__)
generate_test = TensorGenerator(1, args.tensor_maps_in, args.tensor_maps_out, tensor_paths, num_workers=0,
cache_size=args.cache_size, keep_paths=True, mixup=args.mixup_alpha)
no_fail_tmaps_out = [_make_tmap_nan_on_fail(tmap) for tmap in args.tensor_maps_out]
# hard code batch size to 1 so we can iterate over file names and generated tensors together in the tensor_paths for loop
generate_test = TensorGenerator(1, args.tensor_maps_in, no_fail_tmaps_out, tensor_paths, num_workers=0,
cache_size=0, keep_paths=True, mixup=args.mixup_alpha)
with open(inference_tsv, mode='w') as inference_file:
# TODO: csv.DictWriter is much nicer for this
inference_writer = csv.writer(inference_file, delimiter='\t', quotechar='"', quoting=csv.QUOTE_MINIMAL)
header = ['sample_id']
for ot, otm in zip(args.output_tensors, args.tensor_maps_out):
Expand All @@ -183,21 +202,23 @@ def infer_multimodal_multitask(args):
break

prediction = model.predict(input_data)
if len(args.tensor_maps_out) == 1:
if len(no_fail_tmaps_out) == 1:
prediction = [prediction]

csv_row = [os.path.basename(tensor_path[0]).replace(TENSOR_EXT, '')] # extract sample id
for y, tm in zip(prediction, args.tensor_maps_out):
for y, tm in zip(prediction, no_fail_tmaps_out):
if len(tm.shape) == 1 and tm.is_continuous():
csv_row.append(str(tm.rescale(y)[0][0])) # first index into batch then index into the 1x1 structure
if tm.sentinel is not None and tm.sentinel == true_label[tm.output_name()][0][0]:
if ((tm.sentinel is not None and tm.sentinel == true_label[tm.output_name()][0][0])
or np.isnan(true_label[tm.output_name()][0][0])):
csv_row.append("NA")
else:
csv_row.append(str(tm.rescale(true_label[tm.output_name()])[0][0]))
elif len(tm.shape) == 1 and tm.is_categorical_any():
for k in tm.channel_map:
for k, i in tm.channel_map.items():
csv_row.append(str(y[0][tm.channel_map[k]]))
csv_row.append(str(true_label[tm.output_name()][0][tm.channel_map[k]]))
actual = true_label[tm.output_name()][0][i]
csv_row.append("NA" if np.isnan(actual) else str(actual))

inference_writer.writerow(csv_row)
tensor_paths_inferred[tensor_path[0]] = True
Expand Down
2 changes: 1 addition & 1 deletion scripts/tf.sh
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ while getopts ":i:cth" opt ; do
DOCKER_COMMAND=docker
;;
t)
INTERACTIVE_RUN="-it"
INTERACTIVE="-it"
;;
:)
echo "ERROR: Option -${OPTARG} requires an argument." 1>&2
Expand Down

0 comments on commit 431e458

Please sign in to comment.