Commit

StevenSong committed Jun 15, 2020
1 parent e3540e1 commit 9d3a3b6
Showing 1 changed file with 11 additions and 6 deletions.
17 changes: 11 additions & 6 deletions ml4cvd/tensor_generators.py
@@ -90,12 +90,12 @@ def __init__(
self._started = False
self.workers = []
self.worker_instances = []
if num_workers == 0:
num_workers = 1 # The one worker is the main thread
self.batch_size, self.input_maps, self.output_maps, self.num_workers, self.cache_size, self.weights, self.name, self.keep_paths = \
batch_size, input_maps, output_maps, num_workers, cache_size, weights, name, keep_paths
self.true_epochs = 0
self.stats_string = ""
if num_workers == 0:
num_workers = 1 # The one worker is the main thread
if weights is None:
worker_paths = np.array_split(paths, num_workers)
self.true_epoch_lens = list(map(len, worker_paths))
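The num_workers == 0 fallback above guarantees at least one split section before the paths are divided; np.array_split raises ValueError when asked for zero sections. A minimal standalone sketch of that path partitioning, using hypothetical example_paths rather than real tensor files:

import numpy as np

# Hypothetical file paths standing in for the generator's `paths` argument.
example_paths = [f"/tensors/sample_{i}.hd5" for i in range(10)]

num_workers = 0
if num_workers == 0:
    num_workers = 1  # the one worker is the main thread

# Same split as in __init__: each worker gets a roughly equal slice of paths,
# and true_epoch_lens records how many paths each worker visits per true epoch.
worker_paths = np.array_split(example_paths, num_workers)
true_epoch_lens = list(map(len, worker_paths))
print(true_epoch_lens)  # [10] -> the single main-thread worker sees every path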
@@ -148,7 +148,7 @@ def _init_workers(self):
)
process.start()
self.workers.append(process)
logging.info(f"Started {i} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.")
logging.info(f"Started {i + 1} {self.name.replace('_', ' ')}s with cache size {self.cache_size/1e9}GB.")

def set_worker_paths(self, paths: List[Path]):
"""In the single worker case, set the worker's paths."""
@@ -227,7 +227,7 @@ def aggregate_and_print_stats(self):
f"{stats['Tensors presented']:0.0f} tensors were presented.",
f"{stats['skipped_paths']} paths were skipped because they previously failed.",
f"{error_info}",
f"{self.stats_string}"
f"{self.stats_string}",
])
logging.info(f"\n!!!!>~~~~~~~~~~~~ {self.name} completed true epoch {self.true_epochs} ~~~~~~~~~~~~<!!!!\nAggregated information string:\n\t{info_string}")

@@ -725,6 +725,8 @@ def test_train_valid_tensor_generators(
tensors: str,
batch_size: int,
num_workers: int,
training_steps: int,
validation_steps: int,
cache_size: float,
balance_csvs: List[str],
keep_paths: bool = False,
@@ -786,8 +788,11 @@ def test_train_valid_tensor_generators(
test_csv=test_csv,
)
weights = None
generate_train = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, train_paths, num_workers, cache_size, weights, keep_paths, mixup_alpha, name='train_worker', siamese=siamese, augment=True, sample_weight=sample_weight)
generate_valid = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, valid_paths, num_workers // 2, cache_size, weights, keep_paths, name='validation_worker', siamese=siamese, augment=False)

train_workers = int(training_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
valid_workers = int(validation_steps / (training_steps + validation_steps) * num_workers) or (1 if num_workers else 0)
generate_train = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, train_paths, train_workers, cache_size, weights, keep_paths, mixup_alpha, name='train_worker', siamese=siamese, augment=True, sample_weight=sample_weight)
generate_valid = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, valid_paths, valid_workers, cache_size, weights, keep_paths, name='validation_worker', siamese=siamese, augment=False)
generate_test = TensorGenerator(batch_size, tensor_maps_in, tensor_maps_out, test_paths, num_workers, 0, weights, keep_paths or keep_paths_test, name='test_worker', siamese=siamese, augment=False)
return generate_train, generate_valid, generate_test
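The proportional split above replaces the old fixed allocation of num_workers training workers and num_workers // 2 validation workers: workers are now divided according to the ratio of training_steps to validation_steps, with a floor of one worker for either generator whenever any workers were requested at all. A small standalone sketch of that arithmetic, where split_workers is a hypothetical helper and the step counts are illustrative:

def split_workers(num_workers: int, training_steps: int, validation_steps: int):
    # Mirror of the allocation in test_train_valid_tensor_generators:
    # divide workers by step ratio, but never starve a generator that asked for workers.
    total = training_steps + validation_steps
    train_workers = int(training_steps / total * num_workers) or (1 if num_workers else 0)
    valid_workers = int(validation_steps / total * num_workers) or (1 if num_workers else 0)
    return train_workers, valid_workers

print(split_workers(8, 96, 32))   # (6, 2): workers follow the 3:1 step ratio
print(split_workers(4, 400, 40))  # (3, 1): validation keeps its floor of one worker
print(split_workers(0, 100, 10))  # (0, 0): zero stays zero; __init__ then falls back to the main thread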
