10 changes: 3 additions & 7 deletions smdebug/core/collection.py
@@ -31,7 +31,7 @@ class CollectionKeys:
     # Use this collection to log scalars other than losses/metrics to Minerva.
     # Mainly for Tensorflow. For all other frameworks, call save_scalar() API
     # with details of the scalar to be saved.
-    SEARCHABLE_SCALARS = "searchable_scalars"
+    SM_METRICS = "sm_metrics"

     OPTIMIZER_VARIABLES = "optimizer_variables"
     TENSORFLOW_SUMMARIES = "tensorflow_summaries"
@@ -57,14 +57,10 @@ class CollectionKeys:
     CollectionKeys.SCALARS,
     CollectionKeys.FEATURE_IMPORTANCE,
     CollectionKeys.AVERAGE_SHAP,
-    CollectionKeys.SEARCHABLE_SCALARS,
+    CollectionKeys.SM_METRICS,
 }

-SEARCHABLE_SCALAR_COLLECTIONS = {
-    CollectionKeys.LOSSES,
-    CollectionKeys.METRICS,
-    CollectionKeys.SEARCHABLE_SCALARS,
-}
+SM_METRIC_COLLECTIONS = {CollectionKeys.LOSSES, CollectionKeys.METRICS, CollectionKeys.SM_METRICS}

 # used by pt, mx, keras
 NON_REDUCTION_COLLECTIONS = SCALAR_COLLECTIONS.union(SUMMARIES_COLLECTIONS)
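
For downstream code, the rename is mechanical: the collection key and the set were renamed, and the set literal was collapsed onto one line. A minimal sketch of the resulting usage (constants as defined above; hook wiring elided):

    from smdebug.core.collection import CollectionKeys, SM_METRIC_COLLECTIONS

    # The old CollectionKeys.SEARCHABLE_SCALARS attribute is gone after this change.
    assert CollectionKeys.SM_METRICS == "sm_metrics"
    # Losses and metrics are still routed to the metrics file by default.
    assert CollectionKeys.LOSSES in SM_METRIC_COLLECTIONS
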
26 changes: 13 additions & 13 deletions smdebug/core/hook.py
@@ -13,7 +13,7 @@
     NON_HISTOGRAM_COLLECTIONS,
     NON_REDUCTION_COLLECTIONS,
     SCALAR_COLLECTIONS,
-    SEARCHABLE_SCALAR_COLLECTIONS,
+    SM_METRIC_COLLECTIONS,
     CollectionKeys,
 )
 from smdebug.core.collection_manager import CollectionManager
@@ -46,10 +46,10 @@


 class ScalarCache(object):
-    def __init__(self, scalar_name, scalar_val, searchable, write_tb, write_event):
+    def __init__(self, scalar_name, scalar_val, sm_metric, write_tb, write_event):
         self.name = scalar_name
         self.value = scalar_val
-        self.searchable = searchable
+        self.sm_metric = sm_metric
         self.write_tb = write_tb
         self.write_event = write_event
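
ScalarCache is a plain record whose flags decide where a value is routed. A hypothetical sketch of one entry (values illustrative; the fan-out happens later in _write_scalars()):

    entry = ScalarCache("scalar/accuracy", 0.91, sm_metric=True, write_tb=True, write_event=True)
    # sm_metric -> SageMaker metrics file, write_tb -> TensorBoard summary,
    # write_event -> event file; _write_scalars() checks each flag in turn.
    assert entry.sm_metric and entry.write_tb and entry.write_event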

@@ -357,7 +357,7 @@ def _close_writers(self) -> None:
         if self.dry_run:
             return

-        # flush out searchable scalars to metrics file
+        # flush out sm_metric scalars to metrics file
         if self.metrics_writer is not None:
             self._write_scalars()

@@ -564,16 +564,16 @@ def _write_histogram_summary(self, tensor_name, tensor_value, save_collections):
     def _write_scalars(self):
         """
         This function writes all the scalar values saved in the scalar_cache to file.
-        If searchable is set to True for certain scalars, then that scalar is written to
-        Minerva as well. By default, loss values are searchable.
+        If sm_metric is set to True for a scalar, that scalar is also written to
+        Minerva. By default, loss values are treated as SM metrics.
         """
         for scalar_obj in self.scalar_cache:
             scalar_name = scalar_obj.name
             scalar_val = scalar_obj.value
-            searchable = scalar_obj.searchable
+            sm_metric = scalar_obj.sm_metric
             write_tb = scalar_obj.write_tb
             write_event = scalar_obj.write_event
-            if self.metrics_writer and searchable:
+            if self.metrics_writer and sm_metric:
                 self.metrics_writer.log_metric(scalar_name, scalar_val, self.mode_steps[self.mode])
             if write_tb:
                 tb_writer = self._maybe_get_tb_writer()
@@ -587,20 +587,20 @@ def _write_scalars(self):
         self.scalar_cache = []

     # Fix step number for saving scalar and tensor
-    def save_scalar(self, name, value, searchable=False):
+    def save_scalar(self, name, value, sm_metric=False):
         """
         Call save_scalar at any point in the training script to log a scalar value,
         such as a metric or any other value.
         :param name: Name of the scalar. A prefix 'scalar/' will be added to it
         :param value: Scalar value
-        :param searchable: True/False. If set to True, the scalar value will be written to
+        :param sm_metric: True/False. If set to True, the scalar value will be written to
         SageMaker Minerva
         """
         name = CallbackHook.SCALAR_PREFIX + name
         val = self._make_numpy_array(value)
         if val.size != 1:
             raise TypeError(f"{name} has non scalar value of type: {type(value)}")
-        scalar_obj = ScalarCache(name, val, searchable=True, write_tb=True, write_event=True)
+        scalar_obj = ScalarCache(name, val, sm_metric=sm_metric, write_tb=True, write_event=True)
         self.scalar_cache.append(scalar_obj)

     def _write_raw_tensor(self, tensor_name, tensor_value, save_collections, tensor_ref=None):
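
For callers, only the keyword changes. A hedged usage sketch (the hook is assumed to be an smdebug PyTorch or MXNet hook constructed elsewhere):

    hook.save_scalar("epoch", 3, sm_metric=False)           # event file / TensorBoard only
    hook.save_scalar("val_accuracy", 0.88, sm_metric=True)  # also sent to the SageMaker metrics file
    # The name is prefixed internally, so it is stored as "scalar/val_accuracy".
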
@@ -656,12 +656,12 @@ def _save_for_tensor(self, tensor_name, tensor_value, check_before_write=True):

         self._write_for_tensor(tensor_name, tensor_value, save_collections_for_tensor)
         for s_col in save_collections_for_tensor:
-            if s_col.name in SEARCHABLE_SCALAR_COLLECTIONS:
+            if s_col.name in SM_METRIC_COLLECTIONS:
                 np_val = self._make_numpy_array(tensor_value)
                 # Always log loss to Minerva
                 tensor_val = np.mean(np_val)
                 scalar_obj = ScalarCache(
-                    tensor_name, tensor_val, searchable=True, write_tb=False, write_event=False
+                    tensor_name, tensor_val, sm_metric=True, write_tb=False, write_event=False
                 )
                 self.scalar_cache.append(scalar_obj)
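
The effect: any tensor saved into the losses, metrics, or sm_metrics collections is reduced to its mean and queued for the metrics file. A small sketch of the membership check (collection names as defined in collection.py):

    import numpy as np
    from smdebug.core.collection import SM_METRIC_COLLECTIONS

    # A loss tensor belongs to the "losses" collection, so its mean is cached
    # with sm_metric=True and later written out by _write_scalars().
    assert "losses" in SM_METRIC_COLLECTIONS
    batch_losses = np.array([0.7, 0.3])
    print(np.mean(batch_losses))  # 0.5 would be the value logged
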
8 changes: 4 additions & 4 deletions smdebug/tensorflow/base_hook.py
@@ -42,7 +42,7 @@
 DEFAULT_INCLUDE_COLLECTIONS = [
     CollectionKeys.METRICS,
     CollectionKeys.LOSSES,
-    CollectionKeys.SEARCHABLE_SCALARS,
+    CollectionKeys.SM_METRICS,
 ]


@@ -269,7 +269,7 @@ def _close_writers(self) -> None:
         if self.dry_run:
             return

-        # flush out searchable scalars to metrics file
+        # flush out sm_metric scalars to metrics file
         if self.metrics_writer is not None:
             self._write_scalars()

@@ -368,13 +368,13 @@ def set_optimizer_variables(self, optimizer_variables):
             optimizer_variables, ModeKeys.TRAIN
         )

-    def save_scalar(self, name, value, searchable=False):
+    def save_scalar(self, name, value, sm_metric=False):
         """
         save_scalar() not supported on Tensorflow
         """
         self.logger.warning(
             "save_scalar not supported on Tensorflow. "
-            "Add the scalar to scalars or searchable_scalars collection instead. "
+            "Add the scalar to the scalars or sm_metrics collection instead. "
         )
         return

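
Since save_scalar() is a no-op on TensorFlow, the warning points users at collections. A hypothetical sketch of that route (SessionHook and get_collection()/include() are assumed from smdebug's TF API; the tensor name is illustrative):

    import smdebug.tensorflow as smd

    hook = smd.SessionHook(out_dir="/tmp/smdebug_demo",
                           include_collections=["losses", "sm_metrics"])
    # Any tensor matched into sm_metrics is logged to the metrics file.
    hook.get_collection("sm_metrics").include("my_scalar_metric")
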
2 changes: 1 addition & 1 deletion smdebug/tensorflow/collection.py
@@ -147,7 +147,7 @@ def __init__(self, collections=None, create_default=True):
             CollectionKeys.INPUTS,
             CollectionKeys.OUTPUTS,
             CollectionKeys.ALL,
-            CollectionKeys.SEARCHABLE_SCALARS,
+            CollectionKeys.SM_METRICS,
         ]:
             self.create_collection(n)
         self.get(CollectionKeys.BIASES).include("bias")
22 changes: 11 additions & 11 deletions tests/core/test_hook_save_scalar.py
@@ -36,7 +36,7 @@
 def simple_pt_model(hook, steps=10, register_loss=False):
     """
     Create a PT model. save_scalar() calls are inserted before, during and after training.
-    Only the scalars with searchable=True will be written to a metrics file.
+    Only the scalars with sm_metric=True will be written to a metrics file.
     """

     class Net(nn.Module):
@@ -70,7 +70,7 @@ def forward(self, x):
     model.train()
     optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)

-    hook.save_scalar("pt_before_train", 1, searchable=False)
+    hook.save_scalar("pt_before_train", 1, sm_metric=False)
     hook.set_mode(ModeKeys.TRAIN)
     for i in range(steps):
         batch_size = 32
@@ -82,16 +82,16 @@ def forward(self, x):
             loss = criterion(output, target)
         else:
             loss = F.nll_loss(output, target)
-        hook.save_scalar("pt_train_loss", loss.item(), searchable=True)
+        hook.save_scalar("pt_train_loss", loss.item(), sm_metric=True)
         loss.backward()
         optimizer.step()
-    hook.save_scalar("pt_after_train", 1, searchable=False)
+    hook.save_scalar("pt_after_train", 1, sm_metric=False)


 def simple_mx_model(hook, steps=10, register_loss=False):
     """
     Create a MX model. save_scalar() calls are inserted before, during and after training.
-    Only the scalars with searchable=True will be written to a metrics file.
+    Only the scalars with sm_metric=True will be written to a metrics file.
     """
     net = mxnn.HybridSequential()
     net.add(
@@ -113,7 +113,7 @@ def simple_mx_model(hook, steps=10, register_loss=False):
         hook.register_block(softmax_cross_entropy)
     trainer = gluon.Trainer(net.collect_params(), "sgd", {"learning_rate": 0.1})

-    hook.save_scalar("mx_before_train", 1, searchable=False)
+    hook.save_scalar("mx_before_train", 1, sm_metric=False)
     hook.set_mode(ModeKeys.TRAIN)
     for i in range(steps):
         batch_size = 32
@@ -127,13 +127,13 @@ def simple_mx_model(hook, steps=10, register_loss=False):
         trainer.step(batch_size)
         # calculate training metrics
         train_loss += loss.mean().asscalar()
-        hook.save_scalar("mx_train_loss", loss.mean().asscalar(), searchable=True)
-    hook.save_scalar("mx_after_train", 1, searchable=False)
+        hook.save_scalar("mx_train_loss", loss.mean().asscalar(), sm_metric=True)
+    hook.save_scalar("mx_after_train", 1, sm_metric=False)


 def simple_tf_model(hook, steps=10, lr=0.4):
     """
-    Create a TF model. Tensors registered with the SEARCHABLE_SCALARS collection will be logged
+    Create a TF model. Tensors registered with the SM_METRICS collection will be logged
     to the metrics file.
     """
     mnist = keras.datasets.mnist
@@ -191,7 +191,7 @@ def check_trials(out_dir, save_steps, coll_name, saved_scalars=None):
 def check_metrics_file(saved_scalars):
     """
     Check the SageMaker metrics file to ensure that all the scalars saved using
-    save_scalar(searchable=True) or mentioned through SEARCHABLE_SCALARS collections, have been saved.
+    save_scalar(sm_metric=True) or included through the SM_METRICS collection have been saved.
     """
     if is_sagemaker_job():
         METRICS_DIR = os.environ.get(DEFAULT_SAGEMAKER_METRICS_PATH)
@@ -316,6 +316,6 @@ def helper_tensorflow_tests(collection, save_config):
 @pytest.mark.slow  # 1:30
 def test_tf_save_scalar():
     save_config = SaveConfig(save_steps=[0, 2, 4, 6, 8])
-    collection = ("searchable_scalars", "loss")
+    collection = ("sm_metrics", "loss")
     helper_tensorflow_tests(collection, save_config)
     delete_local_trials([SMDEBUG_TF_HOOK_TESTS_DIR])
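
For reference, a hedged sketch of the kind of lookup check_metrics_file() performs, assuming the metrics writer emits one JSON object per line with a MetricName field (the exact schema is owned by smdebug's metrics writer and may differ):

    import json

    def metric_names(path):
        # Collect the MetricName recorded on each JSON line of the metrics file.
        with open(path) as f:
            return {json.loads(line)["MetricName"] for line in f if line.strip()}

    # e.g. assert "scalar/pt_train_loss" in metric_names(metrics_file_path)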