Skip to content

Commit

Permalink
Fix: Set OMP_NUM_THREADS by default in Elastic (#2569)
Browse files Browse the repository at this point in the history
Signed-off-by: Dennis Keck <26092524+fellhorn@users.noreply.github.com>
  • Loading branch information
fellhorn authored Jul 8, 2024
1 parent caf3139 commit 89e5461
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
20 changes: 20 additions & 0 deletions plugins/flytekit-kf-pytorch/flytekitplugins/kfpytorch/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,10 @@ class Elastic(object):
Single-node elastic training is executed in a k8s pod when `nnodes` is set to 1.
Multi-node training is executed otherwise using a `Pytorch Job <https://github.com/kubeflow/training-operator>`_.
Like `torchrun`, this plugin sets the environment variable `OMP_NUM_THREADS` to 1 if it is not set.
Please see https://pytorch.org/tutorials/recipes/recipes/tuning_guide.html for potential performance improvements.
To change `OMP_NUM_THREADS`, specify it in the environment dict of the flytekit task decorator or via `pyflyte run --env`.
Args:
nnodes (Union[int, str]): Number of nodes, or the range of nodes in form <minimum_nodes>:<maximum_nodes>.
nproc_per_node (str): Number of workers per node.
Expand Down Expand Up @@ -332,6 +336,22 @@ def _execute(self, **kwargs) -> Any:
)
)

# If OMP_NUM_THREADS is not set, set it to 1 to avoid overloading the system.
# Doing so to copy the default behavior of torchrun.
    # See https://github.com/pytorch/pytorch/blob/eea4ece256d74c6f25c1f4eab37b3f2f4aeefd4d/torch/distributed/run.py#L791
if "OMP_NUM_THREADS" not in os.environ and self.task_config.nproc_per_node > 1:
omp_num_threads = 1
logger.warning(
"\n*****************************************\n"
"Setting OMP_NUM_THREADS environment variable for each process to be "
"%s in default, to avoid your system being overloaded, "
"please further tune the variable for optimal performance in "
"your application as needed. \n"
"*****************************************",
omp_num_threads,
)
os.environ["OMP_NUM_THREADS"] = str(omp_num_threads)

config = LaunchConfig(
run_id=flytekit.current_context().execution_id.name,
min_nodes=self.min_nodes,
Expand Down
24 changes: 24 additions & 0 deletions plugins/flytekit-kf-pytorch/tests/test_elastic_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,12 @@
from flytekit.configuration import SerializationSettings
from flytekit.exceptions.user import FlyteRecoverableException

@pytest.fixture(autouse=True, scope="function")
def restore_env():
    """Snapshot ``os.environ`` before each test and restore it afterwards.

    Autouse, function-scoped: guarantees tests that mutate the environment
    (e.g. setting ``OMP_NUM_THREADS``) cannot leak state into other tests.
    """
    snapshot = dict(os.environ)
    try:
        yield
    finally:
        # Wipe anything the test added, then put the snapshot back verbatim.
        os.environ.clear()
        os.environ.update(snapshot)

@dataclass
class Config(DataClassJsonMixin):
Expand Down Expand Up @@ -212,3 +218,21 @@ def test_task():
"ttlSecondsAfterFinished": 600,
"activeDeadlineSeconds": 36000,
}

@pytest.mark.parametrize("start_method", ["spawn", "fork"])
def test_omp_num_threads(start_method: str) -> None:
    """Test that the env var OMP_NUM_THREADS is set by default and not overwritten if set."""

    # nproc_per_node > 1 triggers the torchrun-like default of OMP_NUM_THREADS=1;
    # the assertion runs inside the elastic-launched worker process.
    @task(task_config=Elastic(nnodes=1, nproc_per_node=2, start_method=start_method))
    def test_task_omp_default():
        assert os.environ["OMP_NUM_THREADS"] == "1"

    test_task_omp_default()

    # A pre-existing value must be left untouched by the plugin.
    # NOTE(review): relies on the autouse restore_env fixture to undo this
    # assignment after the test.
    os.environ["OMP_NUM_THREADS"] = "42"

    @task(task_config=Elastic(nnodes=1, nproc_per_node=2, start_method=start_method))
    def test_task_omp_set():
        assert os.environ["OMP_NUM_THREADS"] == "42"

    test_task_omp_set()

0 comments on commit 89e5461

Please sign in to comment.