1 change: 0 additions & 1 deletion smdebug/core/hook.py
@@ -141,7 +141,6 @@ def __init__(
self.reduction_config = reduction_config
self.include_regex = include_regex
self.collection_manager = collection_manager
self.collection_manager.set_num_workers(self._get_num_workers())
self.init_step = init_step

self.logger = logger
57 changes: 31 additions & 26 deletions smdebug/tensorflow/base_hook.py
@@ -24,6 +24,7 @@
get_worker_id_from_tf_config,
is_mirrored_strategy,
is_parameter_server_strategy,
load_tf_config_json,
)

try:
@@ -86,9 +87,8 @@ def __init__(
self.device_map = {}
self.writer_map = {}
self.distribution_strategy = None
self.tf_config = os.getenv(
"TF_CONFIG"
) # caches the TF_CONFIG for the parameter server strategy
# caches the TF_CONFIG for the parameter server strategy
self.tf_config_json = load_tf_config_json(os.getenv("TF_CONFIG"))
self._hook_supported = None
set_hook(self)

@@ -101,9 +101,6 @@ def _get_distribution_strategy(self) -> TFDistributionStrategy:
except (ModuleNotFoundError, ValueError, ImportError):
pass

if self.tf_config and is_parameter_server_strategy(self.tf_config):
return TFDistributionStrategy.PARAMETER_SERVER_STRATEGY

strat = tf.distribute.get_strategy()
if is_mirrored_strategy(strat):
return TFDistributionStrategy.MIRRORED_STRATEGY
@@ -112,6 +109,10 @@ def _get_distribution_strategy(self) -> TFDistributionStrategy:
# single device
return TFDistributionStrategy.NONE

# Disable PS till we verify proper support of PS on SM
# if self.tf_config_json and is_parameter_server_strategy(self.tf_config):
# return TFDistributionStrategy.PARAMETER_SERVER_STRATEGY

return TFDistributionStrategy.UNSUPPORTED

def _get_worker_name(self) -> str:
@@ -126,18 +127,20 @@ def _get_worker_name(self) -> str:
It is safe to return the CONFIG_DEFAULT_WORKER_NAME in this case.
:return: str
"""
try:
assert self.distribution_strategy is not None
if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
import horovod.tensorflow as hvd

if hvd.size():
return f"worker_{hvd.rank()}"
except (ModuleNotFoundError, ValueError, ImportError):
pass

tf_config = os.getenv("TF_CONFIG")
if tf_config and is_parameter_server_strategy(tf_config):
return get_worker_id_from_tf_config(tf_config)
return CONFIG_DEFAULT_WORKER_NAME
return f"worker_{hvd.rank()}"
elif self.distribution_strategy == TFDistributionStrategy.MIRRORED_STRATEGY:
# unused for this strategy
raise NotImplementedError
elif self.distribution_strategy == TFDistributionStrategy.NONE:
return CONFIG_DEFAULT_WORKER_NAME
elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
raise NotImplementedError
elif self.tf_config_json and is_parameter_server_strategy(self.tf_config_json):
return get_worker_id_from_tf_config(self.tf_config_json)

def export_collections(self):
num_workers = self._get_num_workers()
@@ -163,18 +166,20 @@ def export_collections():
self.collection_manager.export(self.out_dir, collection_file_name)

def _get_num_workers(self):
try:
assert self.distribution_strategy is not None
if self.distribution_strategy == TFDistributionStrategy.HOROVOD:
import horovod.tensorflow as hvd

if hvd.size():
return hvd.size()
except (ModuleNotFoundError, ValueError, ImportError):
pass
tf_config = os.getenv("TF_CONFIG")
if tf_config and is_parameter_server_strategy(tf_config):
return get_num_workers_from_tf_config(tf_config)
strategy = tf.distribute.get_strategy()
return strategy.num_replicas_in_sync
return hvd.size()
elif self.distribution_strategy == TFDistributionStrategy.MIRRORED_STRATEGY:
strategy = tf.distribute.get_strategy()
return strategy.num_replicas_in_sync
elif self.distribution_strategy == TFDistributionStrategy.NONE:
return 1
elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
raise NotImplementedError
elif self.tf_config_json and is_parameter_server_strategy(self.tf_config_json):
return get_num_workers_from_tf_config(self.tf_config_json)

def _export_model(self):
tb_writer = self._maybe_get_tb_writer()
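As context for the new `self.tf_config_json` attribute above: a minimal sketch (not part of the diff) of what `load_tf_config_json` hands back for a valid, a malformed, and an unset `TF_CONFIG`; the sample cluster spec is made up.

```python
import json
import os

from smdebug.tensorflow.utils import load_tf_config_json

# Valid TF_CONFIG: the hook now caches the parsed dict instead of the raw string.
os.environ["TF_CONFIG"] = json.dumps(
    {"cluster": {"worker": ["host0:2222", "host1:2222"]}, "task": {"type": "worker", "index": 1}}
)
print(load_tf_config_json(os.getenv("TF_CONFIG"))["task"])  # {'type': 'worker', 'index': 1}

# Malformed or unset TF_CONFIG: json.loads raises JSONDecodeError / TypeError,
# so the helper returns None and callers that guard on
# `self.tf_config_json and ...` simply skip the parameter server path.
print(load_tf_config_json("not-json"))  # None
print(load_tf_config_json(None))        # None
```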
5 changes: 3 additions & 2 deletions smdebug/tensorflow/session.py
@@ -284,15 +284,16 @@ def begin(self):

if self.save_all_workers is False:
if self.distribution_strategy == TFDistributionStrategy.PARAMETER_SERVER_STRATEGY:
self.chief_worker = get_chief_worker_parameter_server(self.tf_config)
self.chief_worker = get_chief_worker_parameter_server(self.tf_config_json)
elif self.distribution_strategy == TFDistributionStrategy.HOROVOD:
self.chief_worker = CONFIG_DEFAULT_WORKER_NAME
elif (
len(self.device_map)
and self.distribution_strategy == TFDistributionStrategy.MIRRORED_STRATEGY
):
self.chief_worker = sorted(self.device_map.keys())[0]

elif self.distribution_strategy == TFDistributionStrategy.UNSUPPORTED:
raise NotImplementedError
self._export_model()
self.export_collections()

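A small illustration (not from the PR) of the mirrored-strategy branch above: with `save_all_workers` disabled, the chief is the first device name in the sorted `device_map`; the device names here are invented.

```python
# Hypothetical device_map contents, for illustration only.
device_map = {
    "/job:localhost/replica:0/task:0/device:GPU:1": "writer_gpu1",
    "/job:localhost/replica:0/task:0/device:GPU:0": "writer_gpu0",
}

# Mirrors the MIRRORED_STRATEGY branch in begin(): the lexicographically
# smallest device name is treated as the chief worker.
chief_worker = sorted(device_map.keys())[0]
print(chief_worker)  # /job:localhost/replica:0/task:0/device:GPU:0
```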
42 changes: 21 additions & 21 deletions smdebug/tensorflow/utils.py
@@ -8,7 +8,6 @@
from tensorflow.python.distribute import values

# First Party
from smdebug.core.config_constants import CONFIG_DEFAULT_WORKER_NAME
from smdebug.core.modes import ModeKeys

try:
@@ -195,35 +194,42 @@ def get_original_fetch_ops(fetches):
"""


def is_parameter_server_strategy(tf_config: str) -> bool:
def load_tf_config_json(tf_config: str):
try:
tf_config = json.loads(tf_config)
return json.loads(tf_config)
except (json.JSONDecodeError, TypeError):
return False # Do not break for incorrectly set tf_config
return "cluster" in tf_config and "ps" in tf_config["cluster"]
# json.loads(None) raises TypeError, so an unset TF_CONFIG also returns None
return None


def is_mirrored_strategy(strat):
return isinstance(strat, (tf.distribute.MirroredStrategy, ContribMirroredStrategy))
def is_parameter_server_strategy(tf_config_json) -> bool:
return "cluster" in tf_config_json and "ps" in tf_config_json["cluster"]


def get_worker_id_from_tf_config(tf_config: str) -> str:
def get_worker_id_from_tf_config(tf_config_json) -> str:
"""Valid roles in a cluster is "chief", "worker", "ps" and "evaluator"."""
tf_config = json.loads(tf_config)
task = tf_config["task"]
task = tf_config_json["task"]
worker_type = task["type"]
worker_index = task["index"]
return f"{worker_type}_{worker_index}"


def get_num_workers_from_tf_config(tf_config: str) -> int:
tf_config = json.loads(tf_config)
workers = tf_config["cluster"]["worker"]
if "chief" in tf_config["cluster"]:
workers.extend(tf_config["cluster"]["chief"])
def get_num_workers_from_tf_config(tf_config_json) -> int:
workers = tf_config_json["cluster"]["worker"]
if "chief" in tf_config_json["cluster"]:
workers.extend(tf_config_json["cluster"]["chief"])
return len(workers)


def get_chief_worker_parameter_server(tf_config_json):
if "chief" in tf_config_json["cluster"]:
return "chief_0"


def is_mirrored_strategy(strat):
return isinstance(strat, (tf.distribute.MirroredStrategy, ContribMirroredStrategy))


def is_keras_optimizer(obj):
for cls in obj.__class__.__mro__:
if ".".join([cls.__module__, cls.__name__]) == "keras.optimizers.Optimizer":
@@ -282,9 +288,3 @@ def get_keras_mode(mode):
return KerasModeKeys.TEST
elif mode == ModeKeys.PREDICT:
return KerasModeKeys.PREDICT


def get_chief_worker_parameter_server(tf_config):
if "chief" in tf_config["cluster"]:
return "chief_0"
return CONFIG_DEFAULT_WORKER_NAME
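To round out the helper refactor, a usage sketch (not part of the diff) that feeds a hand-written parameter-server `TF_CONFIG` through the new dict-based helpers; the host names are placeholders.

```python
import json

from smdebug.tensorflow.utils import (
    get_chief_worker_parameter_server,
    get_num_workers_from_tf_config,
    get_worker_id_from_tf_config,
    is_parameter_server_strategy,
    load_tf_config_json,
)

sample_tf_config = json.dumps(
    {
        "cluster": {
            "chief": ["host0:2222"],
            "worker": ["host1:2222", "host2:2222"],
            "ps": ["host3:2222"],
        },
        "task": {"type": "worker", "index": 0},
    }
)

cfg = load_tf_config_json(sample_tf_config)    # parsed dict (None on bad/missing input)
print(is_parameter_server_strategy(cfg))       # True: "ps" listed under "cluster"
print(get_worker_id_from_tf_config(cfg))       # worker_0
print(get_num_workers_from_tf_config(cfg))     # 3 (two workers + the chief)
print(get_chief_worker_parameter_server(cfg))  # chief_0
```

Note that the moved `get_chief_worker_parameter_server` no longer falls back to `CONFIG_DEFAULT_WORKER_NAME` when no chief is listed; it now implicitly returns None in that case.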