Skip to content

Commit

Permalink
Address comments
Browse files Browse the repository at this point in the history
  • Loading branch information
dipannita08 committed Nov 15, 2024
2 parents 1417133 + e080157 commit f5d6a37
Show file tree
Hide file tree
Showing 49 changed files with 2,020 additions and 776 deletions.
5 changes: 5 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,10 @@
# Change Log

## 0.1.4

* Changes
* Upgrade Jax from 0.4.33 to 0.4.34.

## 0.1.3

* Changes
Expand Down
18 changes: 14 additions & 4 deletions axlearn/audio/subsamplers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
from typing import Optional, Union

import jax
from absl.testing import parameterized
from absl.testing import absltest, parameterized
from jax import numpy as jnp

from axlearn.audio.subsamplers import ConvSubSampler
Expand Down Expand Up @@ -187,25 +187,30 @@ def test_paddings(
self.assertEqual(tuple(subsampled_shape), outputs["outputs"].shape)
self.assertEqual(tuple(subsampled_shape)[:2], outputs["paddings"].shape)

def test_activation_summaries(self):
@parameterized.parameters(jnp.float32, jnp.bfloat16)
def test_activation_summaries(self, dtype):
"""Tests that activation summaries behave as expected."""
input_dim, num_filters, hidden_dim, output_dim = 1, 80, 12, 8
prng_key = jax.random.PRNGKey(567)
prng_key, init_key, data_key = jax.random.split(prng_key, num=3)

# Initialize layer parameters.
cfg = ConvSubSampler.default_config().set(
input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim
input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, dtype=dtype
)
layer = cfg.set(name="test").instantiate(parent=None)
layer_params = layer.initialize_parameters_recursively(init_key)
dtypes, _ = jax.tree.flatten(jax.tree.map(jnp.dtype, layer_params))
self.assertTrue(all(dt == dtype for dt in dtypes))

# Build inputs.
batch_size, num_frames = 4, 10
inputs_shape = [batch_size, num_frames, num_filters, input_dim]
inputs = jax.random.normal(key=data_key, shape=inputs_shape) * 10.0
lengths = jnp.array([5, 10, 9, 0])
paddings = jnp.arange(num_frames)[None, :] >= lengths[:, None]
inputs = inputs.astype(dtype)
paddings = paddings.astype(dtype)
outputs, output_collections = F(
layer,
inputs=dict(inputs=inputs, paddings=paddings),
Expand Down Expand Up @@ -247,9 +252,14 @@ def test_activation_summaries(self):
expected_outputs_norm,
)
self.assertNestedAllClose(
output_collections.summaries["activations/subsampler_inputs_mean"].weight, input_weights
output_collections.summaries["activations/subsampler_inputs_mean"].weight,
input_weights.astype(dtype),
)
self.assertNestedAllClose(
output_collections.summaries["activations/subsampler_outputs_norm"].weight,
output_weights,
)


if __name__ == "__main__":
absltest.main()
70 changes: 64 additions & 6 deletions axlearn/cloud/common/bastion.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
import io
import json
import os
import re
import shlex
import shutil
import signal
Expand Down Expand Up @@ -97,9 +98,13 @@
_LOG_DIR = "/var/tmp/logs" # Use /var/tmp/ since /tmp/ is cleared every 10 days.
_JOB_DIR = "/var/tmp/jobs"
_BASTION_SERIALIZED_JOBSPEC_ENV_VAR = "_BASTION_SERIALIZED_JOBSPEC"
BASTION_JOB_VERSION_ENV_VAR = "BASTION_JOB_VERSION"

FLAGS = flags.FLAGS

_VALID_NAME_CHARS = r"[!-~]+" # match all printing ASCII characters except space
valid_name_re = re.compile(_VALID_NAME_CHARS)


def bastion_job_flags(flag_values: flags.FlagValues = FLAGS):
flags.DEFINE_string("name", None, "Name of bastion.", flag_values=flag_values, required=True)
Expand Down Expand Up @@ -173,6 +178,8 @@ class JobLifecycleState(str, enum.Enum):
PREEMPTING = "PREEMPTING"
# Job is rescheduling.
RESCHEDULING = "RESCHEDULING"
# Job is updating.
UPDATING = "UPDATING"
# Job is cancelling. Command is terminating.
CANCELLING = "CANCELLING"
# Job has completed/terminated the command, is running cleanup command (if any).
Expand Down Expand Up @@ -235,6 +242,8 @@ def _validate_job_metadata(metadata: JobMetadata):
raise ValidationError(f"Expected {metadata.resources=} to have string keys and int values.")
if not isinstance(metadata.priority, int):
raise ValidationError(f"Expected {metadata.priority=} to be an int.")
if metadata.version is not None and not isinstance(metadata.version, int):
raise ValidationError(f"Expected {metadata.version=} to be None or an int.")


def _validate_jobspec(jobspec: JobSpec):
Expand Down Expand Up @@ -323,11 +332,16 @@ def deserialize_jobspec(f: Union[str, IO]) -> JobSpec:


def is_valid_job_name(name: str) -> bool:
"""Ensures that job name does not look like a path.
"""Ensures job name is not path-like and only contains safe characters.
We use a permissive regex to avoid making assumptions about the underlying compute environment.
This check should avoid making assumptions about the underlying compute environment.
"""
return bool(name) and ("/" not in name) and (name not in (".", "..")) and ("\n" not in name)
return (
bool(name)
and ("/" not in name)
and (name not in (".", ".."))
and bool(valid_name_re.fullmatch(name))
)


def _download_jobspec(job_name: str, *, remote_dir: str, local_dir: str = _JOB_DIR) -> JobSpec:
Expand Down Expand Up @@ -882,6 +896,12 @@ def _sync_jobs(self):
else:
curr_job = self._active_jobs[job_name]
updated_job = active_jobs[job_name]
if updated_job.spec.metadata.version != curr_job.spec.metadata.version:
# When a new version is detected, add "updated" in the metadata to signal
# job state change and job relaunch.
# Note: "updated" is a transient state and should not be persisted.
updated_job.state.metadata["updated"] = True
logging.info("Detected a different version of job %s", job_name)
curr_job.spec, curr_job.state = updated_job.spec, updated_job.state

# pylint: disable-next=too-many-statements
Expand Down Expand Up @@ -926,10 +946,15 @@ def _update_single_job(self, job: Job) -> Job:
self._append_to_job_history(
job,
msg=f"ACTIVE: start process command: {job.spec.command} "
f"with metadata: {job.state.metadata}",
f"with metadata: {job.state.metadata} and version: {job.spec.metadata.version}",
state=JobLifecycleState.STARTING,
)
env_vars = {f"BASTION_{k.upper()}": v for k, v in job.state.metadata.items()}

if job.spec.metadata.version:
# For backwards compatibility, only set the version in env when not None.
env_vars.update({BASTION_JOB_VERSION_ENV_VAR: job.spec.metadata.version})

serialized_jobspec = io.StringIO()
serialize_jobspec(job.spec, serialized_jobspec)
env_vars |= {_BASTION_SERIALIZED_JOBSPEC_ENV_VAR: serialized_jobspec.getvalue()}
Expand Down Expand Up @@ -1061,8 +1086,19 @@ def _update_jobs(self):
new_tier = verdict.metadata.get("tier")
changed_tiers = old_tier != new_tier

# Resume if not running, or keep running if scheduling tier did not change.
if job.state.status == JobStatus.PENDING or not changed_tiers:
jobspec_changed = job.state.metadata.get("updated")

# Jobspec changed, trigger a restart of the runner.
if jobspec_changed:
self._append_to_job_history(
job,
msg="UPDATING: Detected updated jobspec. Will restart the runner "
"by sending to PENDING state",
state=JobLifecycleState.UPDATING,
)
job.state.status = JobStatus.PENDING
elif job.state.status == JobStatus.PENDING or not changed_tiers:
# Resume if not running, or keep running if scheduling tier did not change.
job.state.status = JobStatus.ACTIVE
else:
# Job changed scheduling tiers, and must be restarted on the new tier.
Expand Down Expand Up @@ -1279,3 +1315,25 @@ def submit_job(self, job_name: str, *, job_spec_file: str):
else:
# Upload the job for bastion to pickup.
tf_io.gfile.copy(job_spec_file, dst)

def get_job(self, job_name: str) -> JobSpec:
    """Returns the jobspec stored for `job_name` under the active job directory.

    Args:
        job_name: Name of the job; also used as the file name under `self.active_job_dir`.

    Returns:
        The value produced by `_download_jobspec` for this job.

    Raises:
        ValueError: If nothing exists at the expected jobspec path.
    """
    job_path = os.path.join(self.active_job_dir, job_name)
    if tf_io.gfile.exists(job_path):
        # Download into a scratch directory that is deleted when the `with` exits.
        with tempfile.TemporaryDirectory() as scratch_dir:
            return _download_jobspec(
                job_name, remote_dir=self.active_job_dir, local_dir=scratch_dir
            )
    raise ValueError(f"Unable to locate jobspec {job_path}")

def update_job(self, job_name: str, *, job_spec: JobSpec) -> JobSpec:
    """Overwrites the stored jobspec of an existing active job.

    Args:
        job_name: Name of the job; also used as the file name under `self.active_job_dir`.
        job_spec: The replacement jobspec to serialize and upload.

    Returns:
        The `job_spec` that was passed in.

    Raises:
        ValueError: If no jobspec currently exists at the destination path.
    """
    dst = os.path.join(self.active_job_dir, job_name)
    if not tf_io.gfile.exists(dst):
        raise ValueError(f"Unable to locate jobspec {dst}")

    with tempfile.NamedTemporaryFile("w") as f:
        serialize_jobspec(job_spec, f)
        # The copy below reads the file by name while this handle is still open, so any
        # buffered writes must reach disk first. No-op if serialize_jobspec already flushed.
        f.flush()
        # Upload the job for bastion to pickup.
        tf_io.gfile.copy(f.name, dst, overwrite=True)
    logging.info("Job %s is updating.", job_name)

    return job_spec
132 changes: 131 additions & 1 deletion axlearn/cloud/common/bastion_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,14 @@ def mock_download_job_state(job_name, *, remote_dir, **kwargs):
dict(name="..test", valid=True), # This is a valid file name.
dict(name="test.job..", valid=True), # This is a valid file name.
dict(name="test\n", valid=False), # newline causes bastion to crash
dict(name="test", valid=True),
dict(name="test“job”test", valid=False), # curly (non-ASCII) double quotes are invalid
dict(name="test‘job’test", valid=False), # curly (non-ASCII) single quotes are invalid
dict(name="test\\job", valid=True),
dict(name="test,job", valid=True),
dict(name="test:job", valid=True),
dict(name="test_job", valid=True),
dict(name="test job", valid=False),
)
def test_is_valid_job_name(self, name, valid):
self.assertEqual(valid, is_valid_job_name(name))
Expand Down Expand Up @@ -773,6 +781,17 @@ def test_sync_jobs(self):
resources={"test": 8},
),
),
new_jobspec(
name="job3",
command="",
metadata=JobMetadata(
user_id="user1",
project_id="project1",
creation_time=datetime(1900, 1, 1, 0, 0, 0, 1),
resources={"test": 8},
version=1,
),
),
]
# Write them to the Bastion submission directory.
for spec in specs:
Expand All @@ -787,7 +806,30 @@ def test_sync_jobs(self):
# Download the jobspecs.
mock_bastion._sync_jobs()
# Confirm expected jobs were downloaded.
self.assertSequenceEqual(list(mock_bastion._active_jobs), ["job1"])
self.assertSequenceEqual(
sorted(list(mock_bastion._active_jobs)), sorted(["job1", "job3"])
)

# Submit the job again to update the version.
updated_job_spec = new_jobspec(
name="job3",
command="",
metadata=JobMetadata(
user_id="user1",
project_id="project1",
creation_time=datetime(1900, 1, 1, 0, 0, 0, 1),
resources={"test": 8},
version=2,
),
)
bastion_dir.update_job(updated_job_spec.name, job_spec=updated_job_spec)

# Download the jobspecs.
mock_bastion._sync_jobs()
# Confirm the update is received.
self.assertEqual(
mock_bastion._active_jobs.get(updated_job_spec.name).state.metadata["updated"], True
)

@parameterized.product(
[
Expand Down Expand Up @@ -1099,6 +1141,23 @@ def mock_proc(cmd, **kwargs):
command_proc=mock_proc("command"),
cleanup_proc=None, # No cleanup_proc for ACTIVE.
),
# This job will go from ACTIVE to PENDING, since it is being updated.
"updating": Job(
spec=new_jobspec(
name="updating",
command="command",
cleanup_command="cleanup",
metadata=JobMetadata(
user_id="e",
project_id="project1",
creation_time=yesterday + timedelta(seconds=2),
resources={"v4": 1}, # Fits within the v4 budget in project1.
),
),
state=JobState(status=JobStatus.ACTIVE, metadata={"tier": 0, "updated": True}),
command_proc=mock_proc("command"),
cleanup_proc=None, # No cleanup_proc for ACTIVE.
),
# This job will go from ACTIVE to CLEANING.
"cleaning": Job(
spec=new_jobspec(
Expand Down Expand Up @@ -1188,6 +1247,7 @@ def mock_proc(cmd, **kwargs):
"resume": JobState(status=JobStatus.ACTIVE, metadata={"tier": 0}),
"active": JobState(status=JobStatus.ACTIVE, metadata={"tier": 0}),
"preempt": JobState(status=JobStatus.PENDING),
"updating": JobState(status=JobStatus.PENDING, metadata={"tier": 0}),
"cleaning": JobState(status=JobStatus.CLEANING, metadata={"tier": 0}),
"cleaning_cancel": JobState(status=JobStatus.CLEANING),
"completed": JobState(status=JobStatus.COMPLETED),
Expand Down Expand Up @@ -1259,6 +1319,8 @@ def mock_proc(cmd, **kwargs):
expected_msg = {
"resume": "ACTIVE: start process command",
"preempt": "PENDING: pre-empting",
"updating": "UPDATING: Detected updated jobspec. Will restart "
"the runner by sending to PENDING state",
"cleaning": "CLEANING: process finished",
"cleaning_cancel": "CLEANING: process terminated",
"completed": "COMPLETED: cleanup finished",
Expand Down Expand Up @@ -1566,6 +1628,74 @@ def test_delete(self, spec_exists):
remote_dir=bastion_dir.user_states_dir,
)

@parameterized.parameters(True, False)
def test_get(self, spec_exists):
    """Tests get_job: copies from the active jobspec path if it exists, else raises ValueError."""
    job_name = "test-job"
    bastion_dir = (
        bastion.BastionDirectory.default_config().set(root_dir="test-dir").instantiate()
    )

    # Stub out remote file access; `exists` decides which branch of get_job is exercised.
    patch_tfio = mock.patch.multiple(
        f"{bastion.__name__}.tf_io.gfile",
        exists=mock.MagicMock(return_value=spec_exists),
        copy=mock.DEFAULT,
    )
    mock_deserialize_jobspec = mock.patch(
        f"{bastion.__name__}.deserialize_jobspec", return_value=None
    )

    if spec_exists:
        ctx = contextlib.nullcontext()
    else:
        ctx = self.assertRaisesRegex(ValueError, "Unable to locate jobspec")

    with ctx, mock_deserialize_jobspec, patch_tfio as mock_tfio:
        bastion_dir.get_job(job_name)

    # Check the mock after the `with`: when the expected ValueError is raised, no statement
    # following the call inside the block would run, so the not-called check must live here.
    copy_mock = mock_tfio["copy"]
    if spec_exists:
        copy_mock.assert_called()
        self.assertEqual(
            copy_mock.call_args[0][0],
            os.path.join(bastion_dir.active_job_dir, job_name),
        )
        self.assertEqual(copy_mock.call_args.kwargs["overwrite"], True)
    else:
        copy_mock.assert_not_called()

@parameterized.parameters(True, False)
def test_update(self, spec_exists):
    """Tests update_job: copies onto the active jobspec path if it exists, else raises."""
    job_name = "test-job"
    bastion_dir = (
        bastion.BastionDirectory.default_config().set(root_dir="test-dir").instantiate()
    )

    # Stub out remote file access; `exists` decides which branch of update_job is exercised.
    patch_tfio = mock.patch.multiple(
        f"{bastion.__name__}.tf_io.gfile",
        exists=mock.MagicMock(return_value=spec_exists),
        copy=mock.DEFAULT,
    )
    mock_serialize_jobspec = mock.patch(
        f"{bastion.__name__}.serialize_jobspec", return_value=None
    )

    if spec_exists:
        ctx = contextlib.nullcontext()
    else:
        ctx = self.assertRaisesRegex(ValueError, "Unable to locate jobspec")

    with ctx, mock_serialize_jobspec, patch_tfio as mock_tfio:
        bastion_dir.update_job(job_name, job_spec=None)

    # Check the mock after the `with`: when the expected ValueError is raised, no statement
    # following the call inside the block would run, so the not-called check must live here.
    copy_mock = mock_tfio["copy"]
    if spec_exists:
        copy_mock.assert_called()
        self.assertEqual(
            copy_mock.call_args[0][1],
            os.path.join(bastion_dir.active_job_dir, job_name),
        )
        self.assertEqual(copy_mock.call_args.kwargs["overwrite"], True)
    else:
        copy_mock.assert_not_called()


if __name__ == "__main__":
absltest.main()
2 changes: 2 additions & 0 deletions axlearn/cloud/common/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,8 @@ class JobMetadata:
# It is not used by the bastion directly.
# TODO(haijing-fu): make it as a required field.
job_id: Optional[str] = None
# Version of the job.
version: Optional[int] = None


@dataclasses.dataclass
Expand Down
Loading

0 comments on commit f5d6a37

Please sign in to comment.