diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a9c7eac..5fc2dff7 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Change Log +## 0.1.4 + +* Changes + * Upgrade Jax from 0.4.33 to 0.4.34. + ## 0.1.3 * Changes diff --git a/axlearn/audio/subsamplers_test.py b/axlearn/audio/subsamplers_test.py index d90b51fb..5d462ff5 100644 --- a/axlearn/audio/subsamplers_test.py +++ b/axlearn/audio/subsamplers_test.py @@ -7,7 +7,7 @@ from typing import Optional, Union import jax -from absl.testing import parameterized +from absl.testing import absltest, parameterized from jax import numpy as jnp from axlearn.audio.subsamplers import ConvSubSampler @@ -187,7 +187,8 @@ def test_paddings( self.assertEqual(tuple(subsampled_shape), outputs["outputs"].shape) self.assertEqual(tuple(subsampled_shape)[:2], outputs["paddings"].shape) - def test_activation_summaries(self): + @parameterized.parameters(jnp.float32, jnp.bfloat16) + def test_activation_summaries(self, dtype): """Tests that activation summaries behave as expected.""" input_dim, num_filters, hidden_dim, output_dim = 1, 80, 12, 8 prng_key = jax.random.PRNGKey(567) @@ -195,10 +196,12 @@ def test_activation_summaries(self): # Initialize layer parameters. cfg = ConvSubSampler.default_config().set( - input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim + input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, dtype=dtype ) layer = cfg.set(name="test").instantiate(parent=None) layer_params = layer.initialize_parameters_recursively(init_key) + dtypes, _ = jax.tree.flatten(jax.tree.map(jnp.dtype, layer_params)) + self.assertTrue(all(dt == dtype for dt in dtypes)) # Build inputs. batch_size, num_frames = 4, 10 @@ -206,6 +209,8 @@ def test_activation_summaries(self): inputs = jax.random.normal(key=data_key, shape=inputs_shape) * 10.0 lengths = jnp.array([5, 10, 9, 0]) paddings = jnp.arange(num_frames)[None, :] >= lengths[:, None] + inputs = inputs.astype(dtype) + paddings = paddings.astype(dtype) outputs, output_collections = F( layer, inputs=dict(inputs=inputs, paddings=paddings), @@ -247,9 +252,14 @@ def test_activation_summaries(self): expected_outputs_norm, ) self.assertNestedAllClose( - output_collections.summaries["activations/subsampler_inputs_mean"].weight, input_weights + output_collections.summaries["activations/subsampler_inputs_mean"].weight, + input_weights.astype(dtype), ) self.assertNestedAllClose( output_collections.summaries["activations/subsampler_outputs_norm"].weight, output_weights, ) + + +if __name__ == "__main__": + absltest.main() diff --git a/axlearn/cloud/common/bastion.py b/axlearn/cloud/common/bastion.py index 0901bb0a..6c8bdd28 100644 --- a/axlearn/cloud/common/bastion.py +++ b/axlearn/cloud/common/bastion.py @@ -53,6 +53,7 @@ import io import json import os +import re import shlex import shutil import signal @@ -97,9 +98,13 @@ _LOG_DIR = "/var/tmp/logs" # Use /var/tmp/ since /tmp/ is cleared every 10 days. _JOB_DIR = "/var/tmp/jobs" _BASTION_SERIALIZED_JOBSPEC_ENV_VAR = "_BASTION_SERIALIZED_JOBSPEC" +BASTION_JOB_VERSION_ENV_VAR = "BASTION_JOB_VERSION" FLAGS = flags.FLAGS +_VALID_NAME_CHARS = r"[!-~]+" # match all printing ASCII characters except space +valid_name_re = re.compile(_VALID_NAME_CHARS) + def bastion_job_flags(flag_values: flags.FlagValues = FLAGS): flags.DEFINE_string("name", None, "Name of bastion.", flag_values=flag_values, required=True) @@ -173,6 +178,8 @@ class JobLifecycleState(str, enum.Enum): PREEMPTING = "PREEMPTING" # Job is rescheduling. 
RESCHEDULING = "RESCHEDULING" + # Job is updating. + UPDATING = "UPDATING" # Job is cancelling. Command is terminating. CANCELLING = "CANCELLING" # Job has completed/terminated the command, is running cleanup command (if any). @@ -235,6 +242,8 @@ def _validate_job_metadata(metadata: JobMetadata): raise ValidationError(f"Expected {metadata.resources=} to have string keys and int values.") if not isinstance(metadata.priority, int): raise ValidationError(f"Expected {metadata.priority=} to be an int.") + if metadata.version is not None and not isinstance(metadata.version, int): + raise ValidationError(f"Expected {metadata.version=} to be None or an int.") def _validate_jobspec(jobspec: JobSpec): @@ -323,11 +332,16 @@ def deserialize_jobspec(f: Union[str, IO]) -> JobSpec: def is_valid_job_name(name: str) -> bool: - """Ensures that job name does not look like a path. + """Ensures job name is not path-like and only contains safe characters. - We use a permissive regex to avoid making assumptions about the underlying compute environment. + This check should avoid making assumptions about the underlying compute environment. """ - return bool(name) and ("/" not in name) and (name not in (".", "..")) and ("\n" not in name) + return ( + bool(name) + and ("/" not in name) + and (name not in (".", "..")) + and bool(valid_name_re.fullmatch(name)) + ) def _download_jobspec(job_name: str, *, remote_dir: str, local_dir: str = _JOB_DIR) -> JobSpec: @@ -882,6 +896,12 @@ def _sync_jobs(self): else: curr_job = self._active_jobs[job_name] updated_job = active_jobs[job_name] + if updated_job.spec.metadata.version != curr_job.spec.metadata.version: + # When a new version is detected, add "updated" in the metadata to signal + # job state change and job relaunch. + # Note: "updated" is a transient state and should not be persisted. + updated_job.state.metadata["updated"] = True + logging.info("Detected a different version of job %s", job_name) curr_job.spec, curr_job.state = updated_job.spec, updated_job.state # pylint: disable-next=too-many-statements @@ -926,10 +946,15 @@ def _update_single_job(self, job: Job) -> Job: self._append_to_job_history( job, msg=f"ACTIVE: start process command: {job.spec.command} " - f"with metadata: {job.state.metadata}", + f"with metadata: {job.state.metadata} and version: {job.spec.metadata.version}", state=JobLifecycleState.STARTING, ) env_vars = {f"BASTION_{k.upper()}": v for k, v in job.state.metadata.items()} + + if job.spec.metadata.version: + # For backwards compatibility, only set the version in env when not None. + env_vars.update({BASTION_JOB_VERSION_ENV_VAR: job.spec.metadata.version}) + serialized_jobspec = io.StringIO() serialize_jobspec(job.spec, serialized_jobspec) env_vars |= {_BASTION_SERIALIZED_JOBSPEC_ENV_VAR: serialized_jobspec.getvalue()} @@ -1061,8 +1086,19 @@ def _update_jobs(self): new_tier = verdict.metadata.get("tier") changed_tiers = old_tier != new_tier - # Resume if not running, or keep running if scheduling tier did not change. - if job.state.status == JobStatus.PENDING or not changed_tiers: + jobspec_changed = job.state.metadata.get("updated") + + # Jobspec changed, trigger a restart of the runner. + if jobspec_changed: + self._append_to_job_history( + job, + msg="UPDATING: Detected updated jobspec. 
Will restart the runner " + "by sending to PENDING state", + state=JobLifecycleState.UPDATING, + ) + job.state.status = JobStatus.PENDING + elif job.state.status == JobStatus.PENDING or not changed_tiers: + # Resume if not running, or keep running if scheduling tier did not change. job.state.status = JobStatus.ACTIVE else: # Job changed scheduling tiers, and must be restarted on the new tier. @@ -1279,3 +1315,25 @@ def submit_job(self, job_name: str, *, job_spec_file: str): else: # Upload the job for bastion to pickup. tf_io.gfile.copy(job_spec_file, dst) + + def get_job(self, job_name: str) -> JobSpec: + job_path = os.path.join(self.active_job_dir, job_name) + if not tf_io.gfile.exists(job_path): + raise ValueError(f"Unable to locate jobspec {job_path}") + + with tempfile.TemporaryDirectory() as tmpdir: + job_spec = _download_jobspec(job_name, remote_dir=self.active_job_dir, local_dir=tmpdir) + return job_spec + + def update_job(self, job_name: str, *, job_spec: JobSpec) -> JobSpec: + dst = os.path.join(self.active_job_dir, job_name) + if not tf_io.gfile.exists(dst): + raise ValueError(f"Unable to locate jobspec {dst}") + + with tempfile.NamedTemporaryFile("w") as f: + serialize_jobspec(job_spec, f) + # Upload the job for bastion to pickup. + tf_io.gfile.copy(f.name, dst, overwrite=True) + logging.info("Job %s is updating.", job_name) + + return job_spec diff --git a/axlearn/cloud/common/bastion_test.py b/axlearn/cloud/common/bastion_test.py index 30db616e..bd1dc673 100644 --- a/axlearn/cloud/common/bastion_test.py +++ b/axlearn/cloud/common/bastion_test.py @@ -153,6 +153,14 @@ def mock_download_job_state(job_name, *, remote_dir, **kwargs): dict(name="..test", valid=True), # This is a valid file name. dict(name="test.job..", valid=True), # This is a valid file name. dict(name="test\n", valid=False), # newline causes bastion to crash + dict(name="test", valid=True), + dict(name="test“job”test", valid=False), # pinyin quotes are invalid + dict(name="test‘job’test", valid=False), # pinyin quotes are invalid + dict(name="test\\job", valid=True), + dict(name="test,job", valid=True), + dict(name="test:job", valid=True), + dict(name="test_job", valid=True), + dict(name="test job", valid=False), ) def test_is_valid_job_name(self, name, valid): self.assertEqual(valid, is_valid_job_name(name)) @@ -773,6 +781,17 @@ def test_sync_jobs(self): resources={"test": 8}, ), ), + new_jobspec( + name="job3", + command="", + metadata=JobMetadata( + user_id="user1", + project_id="project1", + creation_time=datetime(1900, 1, 1, 0, 0, 0, 1), + resources={"test": 8}, + version=1, + ), + ), ] # Write them to the Bastion submission directory. for spec in specs: @@ -787,7 +806,30 @@ def test_sync_jobs(self): # Download the jobspecs. mock_bastion._sync_jobs() # Confirm expected jobs were downloaded. - self.assertSequenceEqual(list(mock_bastion._active_jobs), ["job1"]) + self.assertSequenceEqual( + sorted(list(mock_bastion._active_jobs)), sorted(["job1", "job3"]) + ) + + # Submit the job again to update the version. + updated_job_spec = new_jobspec( + name="job3", + command="", + metadata=JobMetadata( + user_id="user1", + project_id="project1", + creation_time=datetime(1900, 1, 1, 0, 0, 0, 1), + resources={"test": 8}, + version=2, + ), + ) + bastion_dir.update_job(updated_job_spec.name, job_spec=updated_job_spec) + + # Download the jobspecs. + mock_bastion._sync_jobs() + # Confirm the update is received. 
+ self.assertEqual( + mock_bastion._active_jobs.get(updated_job_spec.name).state.metadata["updated"], True + ) @parameterized.product( [ @@ -1099,6 +1141,23 @@ def mock_proc(cmd, **kwargs): command_proc=mock_proc("command"), cleanup_proc=None, # No cleanup_proc for ACTIVE. ), + # This job will go from ACTIVE to PENDING, since it is being updated. + "updating": Job( + spec=new_jobspec( + name="updating", + command="command", + cleanup_command="cleanup", + metadata=JobMetadata( + user_id="e", + project_id="project1", + creation_time=yesterday + timedelta(seconds=2), + resources={"v4": 1}, # Fits within the v4 budget in project1. + ), + ), + state=JobState(status=JobStatus.ACTIVE, metadata={"tier": 0, "updated": True}), + command_proc=mock_proc("command"), + cleanup_proc=None, # No cleanup_proc for ACTIVE. + ), # This job will go from ACTIVE to CLEANING. "cleaning": Job( spec=new_jobspec( @@ -1188,6 +1247,7 @@ def mock_proc(cmd, **kwargs): "resume": JobState(status=JobStatus.ACTIVE, metadata={"tier": 0}), "active": JobState(status=JobStatus.ACTIVE, metadata={"tier": 0}), "preempt": JobState(status=JobStatus.PENDING), + "updating": JobState(status=JobStatus.PENDING, metadata={"tier": 0}), "cleaning": JobState(status=JobStatus.CLEANING, metadata={"tier": 0}), "cleaning_cancel": JobState(status=JobStatus.CLEANING), "completed": JobState(status=JobStatus.COMPLETED), @@ -1259,6 +1319,8 @@ def mock_proc(cmd, **kwargs): expected_msg = { "resume": "ACTIVE: start process command", "preempt": "PENDING: pre-empting", + "updating": "UPDATING: Detected updated jobspec. Will restart " + "the runner by sending to PENDING state", "cleaning": "CLEANING: process finished", "cleaning_cancel": "CLEANING: process terminated", "completed": "COMPLETED: cleanup finished", @@ -1566,6 +1628,74 @@ def test_delete(self, spec_exists): remote_dir=bastion_dir.user_states_dir, ) + @parameterized.parameters(True, False) + def test_get(self, spec_exists): + job_name = "test-job" + bastion_dir = ( + bastion.BastionDirectory.default_config().set(root_dir="test-dir").instantiate() + ) + + patch_tfio = mock.patch.multiple( + f"{bastion.__name__}.tf_io.gfile", + exists=mock.MagicMock(return_value=spec_exists), + copy=mock.DEFAULT, + ) + + mock_deserialize_jobspec = mock.patch( + f"{bastion.__name__}.deserialize_jobspec", return_value=None + ) + + if spec_exists: + ctx = contextlib.nullcontext() + else: + ctx = self.assertRaisesRegex(ValueError, "Unable to locate jobspec") + + with ctx, mock_deserialize_jobspec, patch_tfio as mock_tfio: + bastion_dir.get_job(job_name) + if spec_exists: + mock_tfio["copy"].assert_called() + self.assertEqual( + mock_tfio["copy"].call_args[0][0], + os.path.join(bastion_dir.active_job_dir, job_name), + ) + self.assertEqual(mock_tfio["copy"].call_args.kwargs["overwrite"], True) + else: + mock_tfio["copy"].assert_not_called() + + @parameterized.parameters(True, False) + def test_update(self, spec_exists): + job_name = "test-job" + bastion_dir = ( + bastion.BastionDirectory.default_config().set(root_dir="test-dir").instantiate() + ) + + patch_tfio = mock.patch.multiple( + f"{bastion.__name__}.tf_io.gfile", + exists=mock.MagicMock(return_value=spec_exists), + copy=mock.DEFAULT, + ) + + mock_serialize_jobspec = mock.patch( + f"{bastion.__name__}.serialize_jobspec", return_value=None + ) + + if spec_exists: + ctx = contextlib.nullcontext() + else: + ctx = self.assertRaisesRegex(ValueError, "Unable to locate jobspec") + + with ctx, mock_serialize_jobspec, patch_tfio as mock_tfio: + 
bastion_dir.update_job(job_name, job_spec=None) + if spec_exists: + mock_tfio["copy"].assert_called() + self.assertEqual( + mock_tfio["copy"].call_args[0][1], + os.path.join(bastion_dir.active_job_dir, job_name), + ) + self.assertEqual(mock_tfio["copy"].call_args.kwargs["overwrite"], True) + else: + mock_tfio["copy"].assert_not_called() + if __name__ == "__main__": absltest.main() diff --git a/axlearn/cloud/common/types.py b/axlearn/cloud/common/types.py index 98163e7c..0c61b808 100644 --- a/axlearn/cloud/common/types.py +++ b/axlearn/cloud/common/types.py @@ -29,6 +29,8 @@ class JobMetadata: # It is not used by the bastion directly. # TODO(haijing-fu): make it as a required field. job_id: Optional[str] = None + # Version of the job. + version: Optional[int] = None @dataclasses.dataclass diff --git a/axlearn/cloud/gcp/bundler.py b/axlearn/cloud/gcp/bundler.py index f8ae4e5a..524f76e3 100644 --- a/axlearn/cloud/gcp/bundler.py +++ b/axlearn/cloud/gcp/bundler.py @@ -148,6 +148,9 @@ def from_spec( cfg.project = cfg.project or gcp_settings("project", required=False, fv=fv) cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv) cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv) + # The value from from_spec is a str and will result in wrong condition. + if isinstance(cfg.is_async, str): + cfg.is_async = cfg.is_async.lower() != "false" return cfg # pylint: disable-next=no-self-use,unused-argument diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 0d09f45c..08073acd 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -22,7 +22,11 @@ from absl import flags from google.auth.credentials import Credentials -from axlearn.cloud.common.bastion import _BASTION_SERIALIZED_JOBSPEC_ENV_VAR, deserialize_jobspec +from axlearn.cloud.common.bastion import ( + _BASTION_SERIALIZED_JOBSPEC_ENV_VAR, + BASTION_JOB_VERSION_ENV_VAR, + deserialize_jobspec, +) from axlearn.cloud.common.bundler import BaseDockerBundler from axlearn.cloud.common.job import Job from axlearn.cloud.common.utils import parse_kv_flags, subprocess_run @@ -52,6 +56,9 @@ # Set 80% of the max value as the requested memory. _MEMORY_REQUEST_PERCENTAGE = 0.8 +# A label added to the jobset to indicate job version. +BASTION_JOB_VERSION_LABEL = "bastion-job-version" + class GCPJob(Job): """Base GCP Job definition.""" @@ -345,6 +352,13 @@ class Config(GCPJob.Config): Each host's output will be placed in `"{output_dir}/output/$HOSTNAME/"`. This directory is used by the sidecar container to sync outputs to GCS using gsutil. Ensure that `output_dir` is a valid GCS path (e.g., `gs://your-bucket/path`). + priority_class: Optional; The GKE PriorityClass for the job. + https://kubernetes.io/docs/concepts/scheduling-eviction/pod-priority-preemption + Note: 1. Values need to be pre-defined in each cluster. + 2. Job level priority is enforced by pod level priority of the leader pod. + This is managed by jobset controller. + 3. For TPU slice, this requires alpha.jobset.sigs.k8s.io/exclusive-topology + 4. [2024-11-11] Does not work on multi-slice TPU training yet. host_mounts: List of volumes from host to mount into the container. See `HostMount` for details. 
""" @@ -356,6 +370,7 @@ class Config(GCPJob.Config): enable_pre_provisioner: Optional[bool] = None queue: Optional[str] = None output_dir: Optional[str] = None + priority_class: Optional[str] = None host_mounts: Optional[list[HostMount]] = None @classmethod @@ -552,6 +567,7 @@ def _build_container(self) -> Nested[Any]: # Env var values should always be strings. env=k8s_env_vars, volumeMounts=volume_mounts, + imagePullPolicy="Always", ) def _build_uploader_container(self) -> Nested[Any]: @@ -696,6 +712,9 @@ def _build_pod(self) -> Nested[Any]: } ) + if os.environ.get(BASTION_JOB_VERSION_ENV_VAR): + labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)}) + if os.environ.get(_BASTION_SERIALIZED_JOBSPEC_ENV_VAR): spec = deserialize_jobspec( io.StringIO(os.environ.get(_BASTION_SERIALIZED_JOBSPEC_ENV_VAR)) @@ -728,24 +747,29 @@ def _build_pod(self) -> Nested[Any]: } ) + spec = dict( + # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. + terminationGracePeriodSeconds=60, + # Fail if any pod fails, and allow retries to happen at JobSet level. + restartPolicy="Never", + nodeSelector={ + "cloud.google.com/gke-tpu-accelerator": system.gke_accelerator, + "cloud.google.com/gke-tpu-topology": system.topology, + **selector, + }, + tolerations=tolerations, + containers=[self._build_container()], + initContainers=[self._build_uploader_container()], + serviceAccountName=cfg.service_account, + volumes=volumes, + ) + + if cfg.priority_class: + spec["priorityClassName"] = cfg.priority_class + return dict( metadata=dict(annotations=annotations, labels=labels), - spec=dict( - # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. - terminationGracePeriodSeconds=60, - # Fail if any pod fails, and allow retries to happen at JobSet level. 
- restartPolicy="Never", - nodeSelector={ - "cloud.google.com/gke-tpu-accelerator": system.gke_accelerator, - "cloud.google.com/gke-tpu-topology": system.topology, - **selector, - }, - tolerations=tolerations, - containers=[self._build_container()], - initContainers=[self._build_uploader_container()], - serviceAccountName=cfg.service_account, - volumes=volumes, - ), + spec=spec, ) def _build_job(self) -> Nested[Any]: diff --git a/axlearn/cloud/gcp/job_test.py b/axlearn/cloud/gcp/job_test.py index 9719438c..51b114c6 100644 --- a/axlearn/cloud/gcp/job_test.py +++ b/axlearn/cloud/gcp/job_test.py @@ -27,6 +27,7 @@ from axlearn.cloud.common.bastion import ( _BASTION_SERIALIZED_JOBSPEC_ENV_VAR, + BASTION_JOB_VERSION_ENV_VAR, deserialize_jobspec, new_jobspec, serialize_jobspec, @@ -39,6 +40,7 @@ from axlearn.cloud.gcp.config import gcp_settings from axlearn.cloud.gcp.job import ( _MEMORY_REQUEST_PERCENTAGE, + BASTION_JOB_VERSION_LABEL, CPUJob, GCSFuseMount, HostMount, @@ -244,6 +246,7 @@ def _job_config( service_account: Optional[str] = None, enable_pre_provisioner: Optional[bool] = None, host_mount_spec: Optional[list[str]] = None, + priority_class: Optional[str] = None, ): with mock_gcp_settings([job.__name__, bundler.__name__], self._mock_settings): fv = flags.FlagValues() @@ -259,6 +262,7 @@ def _job_config( cfg.bundler = bundler_cls.from_spec([], fv=fv).set(image="test-image") cfg.accelerator.instance_type = "tpu-v4-8" cfg.enable_pre_provisioner = enable_pre_provisioner + cfg.priority_class = priority_class yield cfg def test_mount_dataclass(self): @@ -284,7 +288,12 @@ def test_mount_dataclass(self): enable_pre_provisioner=[None, False, True], ) def test_instantiate( - self, reservation, service_account, enable_pre_provisioner, bundler_cls, wrap_bundler + self, + reservation, + service_account, + enable_pre_provisioner, + bundler_cls, + wrap_bundler, ): class WrappedBundler(Bundler): @config_class @@ -328,12 +337,13 @@ class Config(Bundler.Config): env={ "BASTION_TIER": "0", _BASTION_SERIALIZED_JOBSPEC_ENV_VAR: _create_serialized_job_spec(1, "user-1"), + BASTION_JOB_VERSION_ENV_VAR: "1", }, reservation="test-reservation", expect_reserved=True, ), dict( - env={"BASTION_TIER": "1"}, + env={"BASTION_TIER": "1", BASTION_JOB_VERSION_ENV_VAR: "2"}, reservation="test-reservation", expect_reserved=False, ), @@ -349,6 +359,7 @@ class Config(Bundler.Config): location_hint=["test-location-hint", None], enable_tpu_smart_repair=[True, False], host_mount_spec=[["name=host-mount,host_path=/tmp,mount_path=/host-tmp"], None], + priority_class=[None, "such-high-priority"], ) def test_build_pod( self, @@ -361,9 +372,12 @@ def test_build_pod( location_hint: Optional[str] = None, enable_tpu_smart_repair: bool = False, host_mount_spec: Optional[list[str]] = None, + priority_class: Optional[str] = None, ): with mock.patch.dict("os.environ", env), self._job_config( - bundler_cls, host_mount_spec=host_mount_spec + bundler_cls, + host_mount_spec=host_mount_spec, + priority_class=priority_class, ) as cfg: gke_job: job.TPUGKEJob = cfg.set( reservation=reservation, @@ -421,6 +435,8 @@ def test_build_pod( else: self.fail("host-mount not found!") + self.assertEqual(container["imagePullPolicy"], "Always") + self.assertIn("limits", resources) tpu_type = infer_tpu_type(cfg.accelerator.instance_type) tpu_characteristics = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[tpu_type] @@ -515,6 +531,12 @@ def test_build_pod( self.assertNotIn("job-priority", node_selector) self.assertNotIn("user-id", labels) + if 
BASTION_JOB_VERSION_ENV_VAR in env: + job_version = env.get(BASTION_JOB_VERSION_ENV_VAR) + self.assertEqual(job_version, labels.get(BASTION_JOB_VERSION_LABEL, None)) + else: + self.assertNotIn(BASTION_JOB_VERSION_LABEL, labels) + if enable_tpu_smart_repair: self.assertIn( "cloud.google.com/gke-tpu-auto-restart", @@ -528,6 +550,11 @@ def test_build_pod( ) self.assertNotIn("cloud.google.com/gke-tpu-auto-restart", labels) + if priority_class is None: + self.assertNotIn("priorityClassName", pod_spec) + else: + self.assertEqual(pod_spec.get("priorityClassName", None), priority_class) + class GPUGKEJobTest(TestCase): @property diff --git a/axlearn/cloud/gcp/jobs/gke_runner.py b/axlearn/cloud/gcp/jobs/gke_runner.py index fc186925..a1dc1839 100644 --- a/axlearn/cloud/gcp/jobs/gke_runner.py +++ b/axlearn/cloud/gcp/jobs/gke_runner.py @@ -32,7 +32,11 @@ import kubernetes as k8s from absl import app, flags, logging -from axlearn.cloud.common.bastion import JobLifecycleEvent, JobLifecycleState +from axlearn.cloud.common.bastion import ( + BASTION_JOB_VERSION_ENV_VAR, + JobLifecycleEvent, + JobLifecycleState, +) from axlearn.cloud.common.bundler import get_bundler_config from axlearn.cloud.common.event_queue import BaseQueueClient from axlearn.cloud.common.utils import ( @@ -44,7 +48,7 @@ from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler from axlearn.cloud.gcp.config import gcp_settings from axlearn.cloud.gcp.event_queue import event_queue_from_config -from axlearn.cloud.gcp.job import GCPJob, GKEJob, GPUGKEJob, TPUGKEJob +from axlearn.cloud.gcp.job import BASTION_JOB_VERSION_LABEL, GCPJob, GKEJob, GPUGKEJob, TPUGKEJob from axlearn.cloud.gcp.jobs import runner_utils from axlearn.cloud.gcp.jobs.tpu_runner import with_tpu_training_defaults from axlearn.cloud.gcp.node_pool import ( @@ -82,6 +86,21 @@ def _infer_reservation(jobset_spec: dict) -> Optional[str]: return None +def _infer_job_version(jobset_spec: dict) -> Optional[int]: + """Infers job version given a jobset spec.""" + try: + for job in jobset_spec["replicatedJobs"]: + labels = job["template"]["spec"]["template"]["metadata"]["labels"] + # If any job has a job version label, return it. + job_version = labels.get(BASTION_JOB_VERSION_LABEL, None) + + if job_version is not None: + return int(job_version) + except (TypeError, KeyError) as e: + logging.warning("Failed to infer job version: %s.", e) + return None + + class GKERunnerJob(GCPJob): """Launches and monitors a GKE job via k8s JobSet API.""" @@ -231,6 +250,7 @@ class Status(enum.Enum): STARTUPPOLICYCOMPLETED: JobSet completed StartupPolicy. READY: JobSet is ready (all Jobs are ready). SUCCEEDED: JobSet succeeded (all Jobs succeeded). Typically also manifests as COMPLETED. + UPDATING: Job will be relaunched with new specs. RESCHEDULED: Job was rescheduled onto a different tier. """ @@ -243,6 +263,7 @@ class Status(enum.Enum): STARTUPPOLICYCOMPLETED = "STARTUPPOLICYCOMPLETED" READY = "READY" SUCCEEDED = "SUCCEEDED" + UPDATING = "UPDATING" RESCHEDULED = "RESCHEDULED" # TODO(markblee): Consider moving some of the logic here into the inner impl. @@ -261,6 +282,20 @@ def _get_status(self) -> Status: if runner_utils.should_recreate_job(tier, reservation): return GKERunnerJob.Status.RESCHEDULED + expected_job_version = os.environ.get(BASTION_JOB_VERSION_ENV_VAR, None) + current_job_version = _infer_job_version(resp["spec"]) + + # If the job is expected to run with a newer version, relaunch it. 
+ if expected_job_version is not None and ( + current_job_version is None or int(expected_job_version) > current_job_version + ): + logging.info( + "Current job version is %s; expected job version is %s", + current_job_version, + expected_job_version, + ) + return GKERunnerJob.Status.UPDATING + # According to stogner@google.com, it's possible for "conditions" to be missing until # the overall jobset has completed. However, if the jobset does complete, "conditions" # should be a reliable indicator of overall completion status. @@ -428,6 +463,9 @@ def _execute(self): elif status == GKERunnerJob.Status.RESCHEDULED: logging.info("Jobset does not match scheduling tier. Rescheduling the jobset...") self._reschedule() + elif status == GKERunnerJob.Status.UPDATING: + logging.info("Newer job version is available. Relaunching the jobset...") + self._inner._delete() # pylint: disable=protected-access elif status == GKERunnerJob.Status.NOT_STARTED: logging.info("Task does not exist. Submitting it now...") # Only bundle on first start, not if we're resuming monitoring. @@ -546,7 +584,7 @@ def _delete_k8s_jobset_and_node_pools( @catch_auth def main(argv: Sequence[str], *, flag_values: flags.FlagValues = FLAGS): - action = parse_action(argv, options=["start", "list", "stop"]) + action = parse_action(argv, options=["start", "update", "list", "stop"]) project = gcp_settings("project", fv=flag_values) zone = gcp_settings("zone", fv=flag_values) @@ -554,7 +592,7 @@ def main(argv: Sequence[str], *, flag_values: flags.FlagValues = FLAGS): load_kube_config(project=project, zone=zone, cluster=cluster) - if action == "start": + if action in ("start", "update"): command = " ".join(argv[2:]) if not command: raise app.UsageError("Command is required.") diff --git a/axlearn/cloud/gcp/jobs/gke_runner_test.py b/axlearn/cloud/gcp/jobs/gke_runner_test.py index 5bee068a..ea59ec28 100644 --- a/axlearn/cloud/gcp/jobs/gke_runner_test.py +++ b/axlearn/cloud/gcp/jobs/gke_runner_test.py @@ -12,16 +12,27 @@ from absl import app, flags from absl.testing import parameterized +from axlearn.cloud.common.bastion import BASTION_JOB_VERSION_ENV_VAR from axlearn.cloud.gcp import bundler, node_pool_provisioner -from axlearn.cloud.gcp.job import GPUGKEJob, TPUGKEJob +from axlearn.cloud.gcp.job import BASTION_JOB_VERSION_LABEL, GPUGKEJob, TPUGKEJob from axlearn.cloud.gcp.jobs import gke_runner from axlearn.cloud.gcp.jobs.bastion_vm_test import _mock_job -from axlearn.cloud.gcp.jobs.gke_runner import _get_runner_or_exit, _infer_reservation +from axlearn.cloud.gcp.jobs.gke_runner import ( + _get_runner_or_exit, + _infer_job_version, + _infer_reservation, +) from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL from axlearn.cloud.gcp.test_utils import mock_gcp_settings -def _mock_replicated_jobs(reservations: Sequence[str]): +def _mock_replicated_jobs(reservations: Sequence[str], bastion_job_version: Optional[int] = None): + job_version_label = ( + {"metadata": {"labels": {BASTION_JOB_VERSION_LABEL: str(bastion_job_version)}}} + if bastion_job_version + else {} + ) + return [ { "template": { @@ -34,6 +45,7 @@ def _mock_replicated_jobs(reservations: Sequence[str]): else {"cloud.google.com/gke-spot": "true"} ) }, + **job_version_label, }, } } @@ -304,11 +316,33 @@ def test_exit(self, status, enable_pre_provisioner): def test_infer_reservation(self, status: dict, expected: Optional[str] = None): self.assertEqual(expected, _infer_reservation(status)) + @parameterized.parameters( + dict( + status=dict( + 
replicatedJobs=_mock_replicated_jobs(["test-reservation"], bastion_job_version=None) + ), + expected=None, + ), + dict( + status=dict( + replicatedJobs=_mock_replicated_jobs(["test-reservation"], bastion_job_version=1) + ), + expected=1, + ), + dict( + status=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"])), + expected=None, + ), + ) + def test_infer_job_version(self, status: dict, expected: Optional[str] = None): + self.assertEqual(expected, _infer_job_version(status)) + @parameterized.product( ( # Conditions is set, so we use it. dict( tier=None, + job_version=None, status=dict( conditions=[ dict(type="COMPLETED", status="TRUE"), @@ -321,6 +355,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Ignore conditions with status.lower() != "true". dict( tier=None, + job_version=None, status=dict( conditions=[ dict(type="COMPLETED", status="FALSE"), @@ -334,6 +369,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Missing conditions entirely, fallback to child job statuses. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=1, succeeded=0), @@ -347,6 +383,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Ignore conditions with status.lower() != "true". dict( tier=None, + job_version=None, status=dict( conditions=[dict(type="COMPLETED", status="FALSE")], replicatedJobsStatus=[ @@ -361,6 +398,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # or until replicated job statuses change. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=1, ready=1, succeeded=0), @@ -373,6 +411,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # At least one job failed without conditions, and tier does not match. dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=1, ready=1, succeeded=0), @@ -385,6 +424,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Number of replicated job statuses do not match slices. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=1, succeeded=0), @@ -397,6 +437,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # All replicated jobs succeeded. No need to wait for jobset conditions. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=0, succeeded=2), @@ -409,6 +450,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Ignore active and missing statuses. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(active=1, ready=1), @@ -421,6 +463,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Missing jobset is reported as "not started". dict( tier=None, + job_version=None, status=k8s.client.exceptions.ApiException(status=404), spec=None, num_slices=1, @@ -429,6 +472,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # All statuses are 0. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=0, succeeded=0), @@ -441,6 +485,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # All statuses are 0 and tiers do not match (thus will be recreated). 
dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=0, succeeded=0), @@ -453,6 +498,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Jobset reservation and bastion tier do not match. dict( tier="1", + job_version=None, status={}, spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"])), num_slices=2, @@ -461,6 +507,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Jobset reservation and bastion tier do not match. dict( tier="1", + job_version=None, status={}, spec=dict(replicatedJobs=_mock_replicated_jobs(["spot", "test-reservation"])), num_slices=2, @@ -470,6 +517,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # In this case, we allow the job to keep running. dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(active=2, ready=2), @@ -482,6 +530,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Missing reservation / invalid spec will be treated as spot. dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(active=2, ready=2), @@ -491,6 +540,58 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): num_slices=2, expected=gke_runner.GKERunnerJob.Status.READY, ), + # Job version has increased from None. + dict( + tier="0", + job_version=1, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], None)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.UPDATING, + ), + # Job version has increased from a non-None number. + dict( + tier="0", + job_version=4, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], 3)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.UPDATING, + ), + # Job version has decreased, in which case, no update. + dict( + tier="0", + job_version=1, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], 2)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.READY, + ), + # Job version is set to None, in which case, no update. 
+ dict( + tier="0", + job_version=None, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], 2)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.READY, + ), ), enable_pre_provisioner=(None, False, True), ) @@ -500,6 +601,7 @@ def test_get_status( num_slices: int, expected: gke_runner.GKERunnerJob.Status, tier: str, + job_version: Optional[int], spec: dict, enable_pre_provisioner: Optional[bool] = None, ): @@ -519,7 +621,9 @@ def test_get_status( mock_get_status = mock.Mock(return_value=dict(status=status, spec=spec)) with ( - mock.patch.dict("os.environ", {"BASTION_TIER": tier}), + mock.patch.dict( + "os.environ", {"BASTION_TIER": tier, BASTION_JOB_VERSION_ENV_VAR: job_version} + ), mock.patch( "kubernetes.client.CustomObjectsApi", return_value=mock.Mock(get_namespaced_custom_object_status=mock_get_status), @@ -835,6 +939,47 @@ def test_start(self, enable_pre_provisioner): job._inner.execute.assert_called() # pytype: disable=attribute-error + @parameterized.parameters(None, False, True) + def test_update(self, enable_pre_provisioner): + with self._job_config( + name="test-name", + cluster="test-cluster", + service_account="test-sa", + enable_pre_provisioner=enable_pre_provisioner, + ) as ( + cfg, + _, + ): + cfg.bundler.set(image="test") + + job: gke_runner.TPUGKERunnerJob = cfg.set( + command="", + status_interval_seconds=0, + enable_pre_provisioner=enable_pre_provisioner, + ).instantiate() + + mock_job = mock.patch.multiple( + job, + _get_status=mock.Mock( + side_effect=[ + gke_runner.GKERunnerJob.Status.UPDATING, + gke_runner.GKERunnerJob.Status.COMPLETED, + ] + ), + _get_job_credentials=mock.DEFAULT, + _delete=mock.DEFAULT, + _inner=mock.DEFAULT, + _pre_provisioner=mock.DEFAULT, + ) + + with mock_job: + job._execute() + + # pytype: disable=attribute-error + job._pre_provisioner.delete_for.assert_not_called() + job._inner._delete.assert_called() + # pytype: enable=attribute-error + class MainTest(parameterized.TestCase): """Tests CLI entrypoint.""" @@ -857,7 +1002,7 @@ def test_get_runner_or_exit(self, instance_type: str, expected: Union[Exception, dict(runner=gke_runner.TPUGKERunnerJob, instance_type="tpu-v4-8"), dict(runner=gke_runner.GPUGKERunnerJob, instance_type="gpu-a3-highgpu-8g-256"), ], - action=["start", "stop"], + action=["start", "stop", "update"], ) def test_load_kube_config(self, action, runner, instance_type): # load_kube_config should only be called if using gke action. diff --git a/axlearn/cloud/gcp/jobs/launch.py b/axlearn/cloud/gcp/jobs/launch.py index d3f70c0a..48d7cb3c 100644 --- a/axlearn/cloud/gcp/jobs/launch.py +++ b/axlearn/cloud/gcp/jobs/launch.py @@ -10,9 +10,10 @@ that decides, for a given CLI action (e.g. 'start') and instance type (e.g. 'tpu-v4-8'), whether the launcher can be used. See `_LAUNCHERS` for a full list, and `BastionManagedTPUJob` for an example. -Possible actions: [start|stop|list] +Possible actions: [start|update|stop|list] Start: submits a job to the queue. + Update: updates a job without resubmission. Stop: stops the job or removes a job from the queue. List: lists jobs and their statuses. @@ -41,9 +42,27 @@ --bundler_spec=dockerfile=Dockerfile \ --bundler_spec=build_arg1=my-build-arg ... + # Update an existing job without resubmission. + axlearn gcp launch update --instance_type=tpu-v4-32 ... -- python3 my_script2.py + # To stop a job. axlearn gcp launch stop --name=... 
--instance_type=tpu +More on the Update command: + + The update command allows updating bundles and job command of an existing job + without resubmission. It currently only works with axlearn.cloud.gcp.jobs.gke_runner. + + Resource related flags including instance_type, num_replicas and enable_pre_provisioner + are not allowed to change. + + When bundles are updated before the job update, job will run with new bundles. + If bundle update is not desired, use `--bundler_spec=skip_bundle=True` flag + to skip bundle update. + + To be able to update the job without re-provisioning the resources (e.g. TPU node pools), + use `--enable_pre_provisioner` to submit the job. + """ # pylint: disable=redefined-outer-name,protected-access @@ -83,6 +102,7 @@ project_usage_table, serialized_flags_for_job, user_usage_table, + validate_resource_flags, with_k8s_jobset_state, with_qrm_tpu_state, ) @@ -256,8 +276,8 @@ def from_flags(cls, fv: flags.FlagValues, *, command: str, action: str, **kwargs # We use the bundler defined by the runner impl, ensuring that bundling is consistent # between local and bastion. cfg.bundler = None - # Construct runner only for start. - if action == "start": + # Construct runner only for start and update. + if action in ("start", "update"): cfg.runner = cls.runner.from_flags(fv, command=command) runner_flags = " ".join(serialized_flags_for_job(fv, cls.runner)) cfg.command = f"python3 -m {cls.runner.__module__} {action} {runner_flags} -- {command}" @@ -353,6 +373,41 @@ def _execute(self) -> JobSpec: ) return jobspec + def _update(self) -> JobSpec: + """Update an existing job without resubmission. + + This will fetch the existing job from Bastion, change + the trainer command, increment the version in metadata, and then update the job on Bastion. + + The resource related flags including instance_type, num_replicas and enable_pre_provisioner + are not allowed to change. + """ + cfg: BaseBastionManagedJob.Config = self.config + + # Get current job spec. + job_spec = self._bastion_dir.get_job(job_name=cfg.name) + + if self._runner and self._runner.bundler: + self._runner.bundler.bundle(cfg.name) + + logging.info("Starting update for job name %s", cfg.name) + logging.info("Command: %s", cfg.command) + + # Update the job version. + job_version = job_spec.metadata.version or 0 + job_spec.metadata.version = job_version + 1 + + # The resource related flags are not allowed to change. + validate_resource_flags(job_spec.command, cfg.command) + + job_spec.command = cfg.command + + logging.info("Updated jobspec: %s", job_spec) + + jobspec = self._bastion_dir.update_job(cfg.name, job_spec=job_spec) + + return jobspec + # TODO(markblee): Add a BastionManagedCPUJob. class BastionManagedTPUJob(BaseBastionManagedJob): @@ -451,7 +506,7 @@ def define_flags(cls, fv: flags.FlagValues): @classmethod def from_flags(cls, fv: flags.FlagValues, *, command: str, action: str, **kwargs) -> Config: # Set default docker flags. These will automatically propagate to the runner on the bastion. 
- if action == "start": + if action in ("start", "update"): fv.set_default("bundler_type", CloudBuildBundler.TYPE) cfg: BastionManagedGKEJob.Config = super().from_flags( fv, command=command, action=action, **kwargs @@ -523,12 +578,12 @@ def _execute(self) -> JobSpec: Launcher( job_cls=BastionManagedGKEJob.with_runner(gke_runner.TPUGKERunnerJob), matcher=config_for_function(match_by_regex).set( - match_regex=dict(start=r"tpu-v.+-(\d)+", list=r"tpu.*", stop=r"tpu.*"), + match_regex=dict(start=r"tpu-v.+-(\d)+", update=r"tpu.*", list=r"tpu.*", stop=r"tpu.*"), gcp_api=GCPAPI.GKE.value, ), description=( "Supports launching TPU jobs via GKE. " - "For 'start', provide --gcp_api=gke, as well as the full instance type, " + "For 'start' or 'update', provide --gcp_api=gke, as well as the full instance type, " "e.g. --instance_type=tpu-v4-8. " "For 'list' or 'stop', provide --gcp_api=gke as well as the accelerator type, " "e.g. --instance_type=tpu." @@ -576,7 +631,7 @@ def main(_): if FLAGS.instance_type is None: raise app.UsageError("--instance_type is required.") - action = parse_action(sys.argv, options=["start", "stop", "list"], default="start") + action = parse_action(sys.argv, options=["start", "stop", "update", "list"], default="start") launcher = _get_launcher_or_exit( action=action, instance_type=FLAGS.instance_type, @@ -604,6 +659,8 @@ def main(_): job._list() elif action == "stop": job._delete() + elif action == "update": + job._update() else: raise app.UsageError(f"Unsupported action {action}") @@ -635,7 +692,9 @@ def _private_flags(): # Allow instance_type to be None when running --help without any flags. On the other hand, if # instance_type is provided when running --help, we show additional help info. if FLAGS.instance_type: - action = parse_action(sys.argv, options=["start", "stop", "list"], default="start") + action = parse_action( + sys.argv, options=["start", "update", "stop", "list"], default="start" + ) launcher = _get_launcher_or_exit( action=action, instance_type=FLAGS.instance_type, diff --git a/axlearn/cloud/gcp/jobs/launch_test.py b/axlearn/cloud/gcp/jobs/launch_test.py index 776896f1..c8991221 100644 --- a/axlearn/cloud/gcp/jobs/launch_test.py +++ b/axlearn/cloud/gcp/jobs/launch_test.py @@ -4,6 +4,7 @@ # pylint: disable=protected-access import contextlib +import copy from datetime import datetime from typing import Optional from unittest import mock @@ -17,6 +18,7 @@ from axlearn.cloud.common.bundler import BUNDLE_EXCLUDE from axlearn.cloud.common.job import Job from axlearn.cloud.common.scheduler import JobMetadata +from axlearn.cloud.common.types import JobSpec from axlearn.cloud.gcp import bundler from axlearn.cloud.gcp import job as gcp_job from axlearn.cloud.gcp.jobs import bastion_vm, gke_runner, launch, tpu_runner @@ -461,7 +463,7 @@ class TestBastionManagedGKEJob(TestWithTemporaryCWD): cluster="test-cluster", ), ], - action=["start", "list"], + action=["start", "list", "update"], ) def test_tpu_flags( self, @@ -550,7 +552,7 @@ def test_tpu_flags( cfg = tpu_gke_job.from_flags(fv, **from_flags_kwargs) self.assertIsNone(cfg.bundler) - if action == "start": + if action in ("start", "update"): self.assertIsNotNone(cfg.runner) self.assertIsNotNone(cfg.runner.bundler) self.assertIn("tpu", cfg.runner.bundler.extras) @@ -581,7 +583,7 @@ def test_tpu_flags( # Test infer tpu resources. self.assertEqual({"v4": 16}, maybe_instantiate(cfg.resources)) - if action == "start": + if action in ("start", "update"): # Make sure command is expected. 
for flag in ["name", "bundler_type", "instance_type"]: if fv[flag].value is not None: @@ -601,7 +603,7 @@ def test_tpu_flags( ) # Bundler should be propagated to runner. - if action == "start": + if action in ("start", "update"): self.assertIsNotNone(job.runner.bundler) @parameterized.parameters( @@ -638,3 +640,48 @@ class FakeBastionDirectory(BastionDirectory): else: mock_execute.assert_called_once() self.assertIsNotNone(job_spec) + + @parameterized.parameters(None, 0, 1) + def test_update(self, job_version): + job_name = "test_job0" + + job_spec = new_jobspec( + name=job_name, + command="command", + metadata=JobMetadata( + user_id="test_user", + project_id="test_project", + creation_time=datetime.now(), + resources={"v4": 8}, + job_id="test-id0", + version=job_version, + ), + ) + + class FakeBastionDirectory(BastionDirectory): + def get_job(self, job_name: str) -> JobSpec: + return copy.deepcopy(job_spec) + + def update_job(self, job_name: str, *, job_spec: JobSpec) -> JobSpec: + return job_spec + + tpu_gke_job = BastionManagedGKEJob.with_runner(_DummyRunner) + cfg = tpu_gke_job.default_config().set( + **_common_bastion_managed_job_kwargs(), + namespace="default", + project="test-project", + cluster="test-cluster", + bastion_dir=FakeBastionDirectory.default_config().set(root_dir="temp_dir"), + ) + cfg.set(name=job_name) + patch_kube_config = mock.patch(f"{launch.__name__}.load_kube_config") + + with patch_kube_config: + job: BastionManagedGKEJob = cfg.instantiate() + + # Update the job. + updated_job_spec = job._update() + + updated_version = (job_spec.metadata.version or 0) + 1 + + self.assertEqual(updated_job_spec.metadata.version, updated_version) diff --git a/axlearn/cloud/gcp/jobs/launch_utils.py b/axlearn/cloud/gcp/jobs/launch_utils.py index ccfd4f00..756a45b3 100644 --- a/axlearn/cloud/gcp/jobs/launch_utils.py +++ b/axlearn/cloud/gcp/jobs/launch_utils.py @@ -5,6 +5,7 @@ import collections import json import re +import shlex from typing import Any, Optional, Protocol from absl import flags @@ -253,3 +254,55 @@ def _k8s_jobset_state_from_jobs( else: states.append("PENDING") return states + + +def _parse_resource_flags_from_command(command: str) -> flags.FlagValues: + """Infer resources flags from launch command. + + It parses the resources flags from the command. + + Args: + command: The launch command of a job. + + Returns: + A flags.FlagValues containing the parsed resources flags. 
+ """ + commands = shlex.split(command) + + fv = flags.FlagValues() + flags.DEFINE_string("instance_type", default=None, help="", flag_values=fv) + flags.DEFINE_integer("num_replicas", default=None, help="", flag_values=fv) + flags.DEFINE_boolean("enable_pre_provisioner", default=None, help="", flag_values=fv) + flags.DEFINE_alias("num_slices", "num_replicas", flag_values=fv) + flags.DEFINE_alias("tpu_type", "instance_type", flag_values=fv) + fv(commands, known_only=True) + + return fv + + +def validate_resource_flags(original_command: str, updated_command: str): + """Raise an exception if the resource flags are different + in the original and updated commands.""" + + original_parsed_flags = _parse_resource_flags_from_command(original_command) + updated_parsed_flags = _parse_resource_flags_from_command(updated_command) + + original_instance_type = original_parsed_flags.instance_type or original_parsed_flags.tpu_type + updated_instance_type = updated_parsed_flags.instance_type or updated_parsed_flags.tpu_type + + original_num_replicas = original_parsed_flags.num_replicas or original_parsed_flags.num_slices + updated_num_replicas = updated_parsed_flags.num_replicas or updated_parsed_flags.num_slices + + original_pre_provisioner = original_parsed_flags.enable_pre_provisioner + updated_pre_provisioner = updated_parsed_flags.enable_pre_provisioner + + if original_instance_type != updated_instance_type: + raise ValueError(f"Expected {original_instance_type=} to match {updated_instance_type=}.") + + if original_num_replicas != updated_num_replicas: + raise ValueError(f"Expected {original_num_replicas=} to match {updated_num_replicas=}.") + + if original_pre_provisioner != updated_pre_provisioner: + raise ValueError( + f"Expected {original_pre_provisioner=} to match {updated_pre_provisioner=}." 
+ ) diff --git a/axlearn/cloud/gcp/jobs/launch_utils_test.py b/axlearn/cloud/gcp/jobs/launch_utils_test.py index dbe6a573..07c280f3 100644 --- a/axlearn/cloud/gcp/jobs/launch_utils_test.py +++ b/axlearn/cloud/gcp/jobs/launch_utils_test.py @@ -3,10 +3,12 @@ """Tests launch utilities.""" # pylint: disable=protected-access +import contextlib import dataclasses import json from datetime import datetime from types import SimpleNamespace +from typing import Union from unittest import mock from absl import flags @@ -18,11 +20,13 @@ from axlearn.cloud.common.utils import Table from axlearn.cloud.gcp.jobs import launch_utils from axlearn.cloud.gcp.jobs.launch_utils import ( + _parse_resource_flags_from_command, jobs_table, match_by_regex, project_usage_table, serialized_flags_for_job, user_usage_table, + validate_resource_flags, with_k8s_jobset_state, with_qrm_tpu_state, ) @@ -120,6 +124,88 @@ def test_match_by_regex(self, matcher, cases): ), ) + @parameterized.parameters( + dict( + command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update -" + "-enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + enable_pre_provisioner=True, + instance_type="tpu-v5litepod-16", + num_replicas=1, + ), + dict( + command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--noenable_pre_provisioner --tpu_type=tpu-v5litepod-32 --num_slices=2 " + "-- sleep infinity", + enable_pre_provisioner=False, + instance_type="tpu-v5litepod-32", + num_replicas=2, + ), + dict( + command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--tpu_type=tpu-v5litepod-32 --num_slices=2 " + "-- sleep infinity", + enable_pre_provisioner=None, + instance_type="tpu-v5litepod-32", + num_replicas=2, + ), + ) + def test_parse_resource_flags_from_command( + self, command, enable_pre_provisioner, instance_type, num_replicas + ): + parsed_flags = _parse_resource_flags_from_command(command) + + self.assertEqual(parsed_flags.enable_pre_provisioner, enable_pre_provisioner) + self.assertEqual(parsed_flags.instance_type, instance_type) + self.assertEqual(parsed_flags.num_replicas, num_replicas) + + @parameterized.parameters( + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + updated_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep 30", + expected=None, + ), + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + updated_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-32 --num_replicas=1 " + "-- sleep infinity", + expected=ValueError("instance_type"), + ), + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + updated_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_slices=2 " + "-- sleep infinity", + expected=ValueError("num_replicas"), + ), + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update" + " --instance_type=tpu-v5litepod-16 --num_replicas=1 -- sleep infinity", + updated_command="python3 -m 
axlearn.cloud.gcp.jobs.gke_runner update" + " --enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + expected=ValueError("pre_provisioner"), + ), + ) + def test_validate_resource_flags( + self, original_command, updated_command, expected: Union[Exception, type] + ): + if isinstance(expected, Exception): + ctx = self.assertRaisesRegex(type(expected), str(expected)) + else: + ctx = contextlib.nullcontext() + with ctx: + validate_resource_flags(original_command, updated_command) + class TestListUtils(parameterized.TestCase): """Tests list utils.""" diff --git a/axlearn/cloud/gcp/measurement.py b/axlearn/cloud/gcp/measurement.py index 89094d85..7a40b755 100644 --- a/axlearn/cloud/gcp/measurement.py +++ b/axlearn/cloud/gcp/measurement.py @@ -9,14 +9,24 @@ from axlearn.cloud.common.utils import parse_kv_flags from axlearn.common import measurement -from axlearn.common.config import maybe_set_config +from axlearn.common.config import REQUIRED, Required, config_class, maybe_set_config @measurement.register_recorder("goodput") class GoodputRecorder(measurement.Recorder): """Records overall training goodput.""" - Config = measurement.Recorder.Config + @config_class + class Config(measurement.Recorder.Config): + """Configures GoodputRecorder. + + Attributes: + upload_dir: Directory to store metrics for the monitor. + upload_interval: Time interval (seconds) for monitoring uploads. + """ + + upload_dir: Required[str] = REQUIRED + upload_interval: Required[int] = REQUIRED @classmethod def from_flags(cls, fv: flags.FlagValues) -> "GoodputRecorder": @@ -76,10 +86,12 @@ def record(self, event: measurement.Event, *args, **kwargs): ) def start_monitoring(self, *args, **kwargs): - # Instantiate ml-goodput-measurement's GoodputMonitor - # to asynchronously calculate goodput and badput at - # the upload_interval and upload to the specified - # tensorboard directory. + """ + Instantiate ml-goodput-measurement's GoodputMonitor to asynchronously calculate + Goodput and Badput at the upload_interval and upload to the specified TensorBoard + directory. + Note: This function requires initialization of distributed JAX before it is called. 
+ """ if self._monitor is None: cfg: GoodputRecorder.Config = self.config self._monitor = goodput_monitoring.GoodputMonitor( diff --git a/axlearn/cloud/gcp/measurement_test.py b/axlearn/cloud/gcp/measurement_test.py index 30214262..ff667d6a 100644 --- a/axlearn/cloud/gcp/measurement_test.py +++ b/axlearn/cloud/gcp/measurement_test.py @@ -60,7 +60,7 @@ def test_start_monitoring(self): fv.mark_as_parsed() recorder = GoodputRecorder.from_flags(fv) - recorder._monitor = None # Ensure _monitor is initially None + self.assertIsNone(recorder._monitor) # Ensure _monitor is initially None with mock.patch("ml_goodput_measurement.monitoring.GoodputMonitor") as mock_goodput_monitor: mock_monitor_instance = mock_goodput_monitor.return_value diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 34d3b87c..1b009545 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -72,7 +72,9 @@ def _num_replicas_per_shard(arr: Tensor) -> dict[tuple[_SliceTuple, ...], int]: return dict(replica_count) -def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_ShardInfo]: +def _get_shard_infos( + arr_inp: Tensor, *, max_data_shard_degree: int, shard_threshold_bytes: int +) -> list[_ShardInfo]: """Returns a list of _ShardInfo for addressable shards that need to be saved. If replica count for the shards are greater than 0, all replicas will save slices of the @@ -84,11 +86,21 @@ def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_Sh for shard in arr_inp.addressable_shards: replica_count = replica_count_map[_slices_to_tuple(shard.index)] assert replica_count > 0 + shard_degree = ( + min(replica_count, max_data_shard_degree) + if max_data_shard_degree > 0 + else replica_count + ) + should_skip = ( + shard_degree == 1 + or shard.data.nbytes < shard_threshold_bytes + or shard.replica_id >= shard_degree + ) for axis, size in enumerate(shard.data.shape): # Find the first dim divisible by partial replication size. - if max_data_shard_degree == 1 or replica_count == 1 or size % replica_count != 0: + if should_skip or size % shard_degree != 0: continue - part_size = size // replica_count + part_size = size // shard_degree slice_obj = shard.index[axis] assert slice_obj.step is None start_offset = shard.replica_id * part_size @@ -103,7 +115,7 @@ def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_Sh + (slice(slice_start + start_offset, slice_start + end_offset),) + shard.index[axis + 1 :], (start_offset, end_offset, axis), - replica_count, + shard_degree, ) ) break @@ -181,7 +193,8 @@ async def _async_serialize( d2h_future: futures.Future, *, limiter: Optional[serialization._LimitInFlightBytes] = None, - max_data_shard_degree: Optional[int] = None, + max_data_shard_degree: int, + shard_threshold_bytes: int, ): """Similar to `serialization.async_serialize`, but limiting peak host memory usage and sharding along data-parallel axis. 
@@ -195,7 +208,11 @@ async def _async_serialize( Reference: https://github.com/google/jax/blob/595a620804e810335a870e93975a78504b2e95e5/jax/experimental/array_serialization/serialization.py#L188 """ - shard_infos = _get_shard_infos(arr_inp, max_data_shard_degree=max_data_shard_degree) + shard_infos = _get_shard_infos( + arr_inp, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, + ) if not shard_infos: d2h_future.set_result(shard_infos) return @@ -261,7 +278,8 @@ async def _run_serializer( d2h_futures: list[futures.Future], *, max_concurrent_bytes: Optional[int] = None, - max_data_shard_degree: Optional[int] = None, + max_data_shard_degree: int, + shard_threshold_bytes: int, ): """Asynchronously serializes a list of tensors with _async_serialize.""" # We add 1 because LimitInFlightBytes expects a limit strictly greater than any request. @@ -274,7 +292,10 @@ async def _run_serializer( # pylint: enable=protected-access future_writer = jax.tree.map( functools.partial( - _async_serialize, limiter=limiter, max_data_shard_degree=max_data_shard_degree + _async_serialize, + limiter=limiter, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, ), arrays, tensorstore_specs, @@ -385,7 +406,9 @@ class BoundedDataShardedAsyncCheckpointManager(serialization.GlobalAsyncCheckpoi max_concurrent_gb: Max concurrent shards (in GB) to write. max_data_shard_degree: Max sharding degree of model weights along data-parallel axis. `None` and `1` means no sharding. `-1` means fully shard along data-parallel - replicas. `>1` means custom sharding degree (currently not implemented). + replicas. `>1` means custom sharding degree and should almost always be a power of 2. + shard_threshold_bytes: Threshold for a array shard to be data-sharded. A value of None + or <= 0 means always data-shard according to max_data_shard_degree. timeout_secs: Barrier timeout in seconds. 
""" @@ -395,6 +418,7 @@ def __init__( max_concurrent_gb: Optional[int] = None, timeout_secs: int = 300, max_data_shard_degree: Optional[int] = None, + shard_threshold_bytes: Optional[int] = None, ): super().__init__(timeout_secs) self._logged_spec = False @@ -406,11 +430,10 @@ def __init__( raise ValueError("max_concurrent_gb must be strictly positive.") self._max_concurrent_bytes = int(max_concurrent_gb * 10**9) - self._max_data_shard_degree = max_data_shard_degree or 1 - if self._max_data_shard_degree not in (1, -1): - raise NotImplementedError( - "max_data_shard_degree is not implemented for values other than 1 and -1" - ) + self._max_data_shard_degree = 1 if max_data_shard_degree is None else max_data_shard_degree + if self._max_data_shard_degree == 0: + raise NotImplementedError("max_data_shard_degree cannot be 0.") + self._shard_threshold_bytes = shard_threshold_bytes or 0 def serialize( self, @@ -457,6 +480,7 @@ def serialize( d2h_futures, max_concurrent_bytes=max_concurrent_bytes, max_data_shard_degree=self._max_data_shard_degree, + shard_threshold_bytes=self._shard_threshold_bytes, ) ) ] diff --git a/axlearn/common/array_serialization_test.py b/axlearn/common/array_serialization_test.py index 8ab56d1a..3b59168a 100644 --- a/axlearn/common/array_serialization_test.py +++ b/axlearn/common/array_serialization_test.py @@ -70,7 +70,12 @@ def test_fully_addressable(self): with mock.patch("jax.process_count", return_value=2), self.assertRaises(Exception): asyncio.run( _async_serialize( - jnp.array(1), {}, futures.Future(), limiter=serialization._LimitInFlightBytes(1) + jnp.array(1), + {}, + futures.Future(), + limiter=serialization._LimitInFlightBytes(1), + max_data_shard_degree=-1, + shard_threshold_bytes=0, ), debug=True, ) @@ -122,7 +127,14 @@ def transfer_to_host_patch(*args, **kwargs): # ValueError(...Buffer has been deleted or donated...) may occur. with pytest.raises((RuntimeError, ValueError), match=re.escape("delete")): f = _CommitFuture( - _run_serializer([arr], [spec], [d2h_future], max_concurrent_bytes=arr.nbytes) + _run_serializer( + [arr], + [spec], + [d2h_future], + max_concurrent_bytes=arr.nbytes, + max_data_shard_degree=-1, + shard_threshold_bytes=-1, + ) ) # Throws Array deleted exception if not waiting for d2h_future. jit_fn(arr) @@ -138,7 +150,14 @@ def transfer_to_host_patch(*args, **kwargs): f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch ): f = _CommitFuture( - _run_serializer([arr], [spec], [d2h_future], max_concurrent_bytes=arr.nbytes) + _run_serializer( + [arr], + [spec], + [d2h_future], + max_concurrent_bytes=arr.nbytes, + max_data_shard_degree=-1, + shard_threshold_bytes=-1, + ) ) d2h_future.result() # If D2H is finished, arr can be safely donated. 
@@ -162,7 +181,11 @@ async def ts_open_patch(*_, **__): f"{array_serialization.__name__}.serialization.ts.open", ts_open_patch, ), get_tensorstore_spec(arr) as spec: - f = _CommitFuture(_run_serializer([arr], [spec], [d2h_future])) + f = _CommitFuture( + _run_serializer( + [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 + ) + ) d2h_future.result() with pytest.raises(RuntimeError, match=re.escape("Test")): f.result() @@ -175,7 +198,11 @@ def transfer_to_host_patch(*_): f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch, ), get_tensorstore_spec(arr) as spec: - f = _CommitFuture(_run_serializer([arr], [spec], [d2h_future])) + f = _CommitFuture( + _run_serializer( + [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 + ) + ) # Exceptions will be raised in both the d2h future and the commit future. with pytest.raises(RuntimeError, match=re.escape("Test")): d2h_future.result() @@ -285,9 +312,17 @@ def _donate_argnum_fn(x): self.assertTrue(np.all(x_zero_copy == x_np)) def _verify_shard_info( - self, single_device_arr: jax.Array, arr: jax.Array, max_data_shard_degree: int + self, + single_device_arr: jax.Array, + arr: jax.Array, + max_data_shard_degree: int, + shard_threshold_bytes: int, ): - shard_infos = _get_shard_infos(arr, max_data_shard_degree=max_data_shard_degree) + shard_infos = _get_shard_infos( + arr, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, + ) # Write each shard to output and check if it's the same as the original # single device array. If same, that means all shards should cover all @@ -299,12 +334,16 @@ def _verify_shard_info( out_array[info.index] = info.data self.assertTrue(np.all(out_array == np.array(single_device_arr))) - @parameterized.parameters(1, -1) + @parameterized.product( + max_data_shard_degree=[1, -1, 2, 4, 8], shard_threshold_bytes=[1000 * 1000 * 1000, 1] + ) @pytest.mark.skipif( jax.device_count() != 8 or jax.process_count() != 1, reason="Incorrect device count for mesh.", ) - def test_shard_info_partially_replicated(self, max_data_shard_degree): + def test_shard_info_partially_replicated( + self, max_data_shard_degree: int, shard_threshold_bytes: int + ): single_device_arr = jnp.arange(0, 1024 * 1024).reshape(1024, 1024) devices = mesh_utils.create_device_mesh((8,)) sharding = PositionalSharding(devices) @@ -315,14 +354,18 @@ def test_shard_info_partially_replicated(self, max_data_shard_degree): self.assertEqual(replica_count[((None, None, None), (0, 512, None))], 4) self.assertEqual(replica_count[((None, None, None), (512, 1024, None))], 4) - self._verify_shard_info(single_device_arr, arr, max_data_shard_degree) + self._verify_shard_info( + single_device_arr, arr, max_data_shard_degree, shard_threshold_bytes + ) - @parameterized.parameters(1, -1) + @parameterized.product( + max_data_shard_degree=[1, -1, 2, 4, 8], shard_threshold_bytes=[1000 * 1000 * 1000, 1] + ) @pytest.mark.skipif( jax.device_count() != 8 or jax.process_count() != 1, reason="Incorrect device count for mesh.", ) - def test_shard_info_fully_sharded(self, max_data_shard_degree): + def test_shard_info_fully_sharded(self, max_data_shard_degree: int, shard_threshold_bytes: int): single_device_arr = jnp.arange(0, 1024 * 1024).reshape(1024, 1024) devices = mesh_utils.create_device_mesh((8,)) sharding = PositionalSharding(devices) @@ -332,14 +375,22 @@ def test_shard_info_fully_sharded(self, max_data_shard_degree): replica_count = _num_replicas_per_shard(arr) 
self.assertEqual(replica_count[((0, 256, None), (0, 512, None))], 1) - self._verify_shard_info(single_device_arr, arr, max_data_shard_degree) + self._verify_shard_info( + single_device_arr, arr, max_data_shard_degree, shard_threshold_bytes + ) - @parameterized.product(sz=[1, 11, 16, 21], max_data_shard_degree=[1, -1]) + @parameterized.product( + sz=[1, 11, 16, 21], + max_data_shard_degree=[1, -1, 2, 4, 8], + shard_threshold_bytes=[1000 * 1000 * 1000, 1], + ) @pytest.mark.skipif( jax.device_count() != 8 or jax.process_count() != 1, reason="Incorrect device count for mesh.", ) - def test_shard_info_fully_replicated(self, sz: int, max_data_shard_degree: int): + def test_shard_info_fully_replicated( + self, sz: int, max_data_shard_degree: int, shard_threshold_bytes: int + ): single_device_arr = jnp.arange(0, sz) devices = mesh_utils.create_device_mesh((8,)) sharding = PositionalSharding(devices) @@ -350,4 +401,6 @@ def test_shard_info_fully_replicated(self, sz: int, max_data_shard_degree: int): # Fully replicated on 8 devices. self.assertEqual(replica_count[((None, None, None),)], 8) - self._verify_shard_info(single_device_arr, arr, max_data_shard_degree) + self._verify_shard_info( + single_device_arr, arr, max_data_shard_degree, shard_threshold_bytes + ) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index fb49d507..00bc604f 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -57,6 +57,7 @@ from enum import Enum, unique from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union +import einops import jax from jax import numpy as jnp from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies @@ -710,7 +711,7 @@ class Output(NamedTuple): @property def num_kv_heads(self): - raise NotImplementedError(type(self)) + return self.config.num_heads def init_states( self, @@ -724,20 +725,16 @@ def init_states( dtype = cfg.cache_dtype or self.dtype() assert dtype is not None - # Following T5X, we cache key, value as [batch, num_heads, head_dim, seq_len] to take - # advantage of TPU optimizations (see `extend_step`). - # Reference: - # https://github.com/google-research/t5x/blob/4d94d8bf41230d492e15e255c9888b5bfd9a5ee8/t5x/examples/t5/layers.py#L215 cache = dict(time_step=jnp.zeros(target_batch_size, dtype=jnp.int32)) # If `kv_state` is provided externally, we do not have to maintain key/value in cache. if kv_state is None: cache.update( key=jnp.zeros( - shape=(target_batch_size, self.num_kv_heads, cfg.per_head_dim, target_max_len), + shape=(target_batch_size, target_max_len, self.num_kv_heads, cfg.per_head_dim), dtype=dtype, ), value=jnp.zeros( - shape=(target_batch_size, self.num_kv_heads, cfg.per_head_dim, target_max_len), + shape=(target_batch_size, target_max_len, self.num_kv_heads, cfg.per_head_dim), dtype=dtype, ), ) @@ -830,15 +827,7 @@ def prefill_states( time_step_mask = (jnp.arange(k_proj.shape[1]) < time_step[:, None])[..., None, None] k_proj = k_proj * time_step_mask v_proj = v_proj * time_step_mask - - # Following T5X, we cache key, value as [batch, num_heads, head_dim, seq_len] to take - # advantage of TPU optimizations (see `extend_step`). 
- # Reference: - # https://github.com/google-research/t5x/blob/4d94d8bf41230d492e15e255c9888b5bfd9a5ee8/t5x/examples/t5/layers.py#L215 - init_state.update( - key=jnp.moveaxis(k_proj, -3, -1).astype(dtype), - value=jnp.moveaxis(v_proj, -3, -1).astype(dtype), - ) + init_state.update(key=k_proj.astype(dtype), value=v_proj.astype(dtype)) return init_state, self.Output(query=q_proj, key=k_proj, value=v_proj) def extend_step( @@ -861,8 +850,8 @@ def extend_step( previous attentions, and index used for fast decoding. Contains "key" and "value" of shape [batch, num_heads, per_head_dim, target_length], and a Tensor "time_step" of shape [batch]. - query: Tensor of shape [batch, 1, target_dim] corresponding to query vector at - "time_step" indices. + query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting + at "time_step" indices. key: An optional Tensor of shape [batch, source_length, source_dim]. If None, will use `query`. value: An optional Tensor of shape [batch, source_length, source_dim]. If None, will @@ -884,40 +873,27 @@ def extend_step( kv_kwargs = dict(kv_state=kv_state) else: kv_kwargs = dict(key=key, value=value) - # Project inputs to key, value and query. Each has shape [B, 1, N, H]. + num_query_steps = query.shape[1] + # Project inputs to key, value and query. Each has shape [B, steps, N, H]. q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, time_step=time_step) - - updated_state = dict(time_step=time_step + 1) + updated_state = dict(time_step=time_step + num_query_steps) if kv_state is None: - # Move the length axis to the back. This allows us to update the cache key, value with - # the "scatter via one-hot broadcast" trick, rather than a scatter/gather operation. - # Profiling suggests moveaxis is competitive with tweaking einsum in `i_proj` -- it's - # also a bit simpler, so we keep it for now. - # [B, 1, N, H] --> [B, N, H, 1]. - k_proj = jnp.moveaxis(k_proj, -3, -1) - v_proj = jnp.moveaxis(v_proj, -3, -1) - - # Update the cache via one-hot broadcast and addition. + # Update the cache via one-hot broadcast and addition. [B, S, N, H]. cached_key = cached_states["key"] cached_value = cached_states["value"] - target_len = cached_key.shape[-1] - oh_indices = jax.nn.one_hot(time_step, target_len, dtype=k_proj.dtype) - # [B, 1, 1, T] to broadcast. - oh_indices = oh_indices[:, None, None, :] - negated_oh_indices = (1 - oh_indices).astype(cached_key.dtype) # Ensure that we accumulate using the original dtype. - new_k_proj = (cached_key * negated_oh_indices) + (k_proj * oh_indices).astype( - cached_key.dtype - ) - new_v_proj = (cached_value * negated_oh_indices) + (v_proj * oh_indices).astype( - cached_value.dtype - ) - - # Move back to original [B, T, N, H] layout. - k_proj = jnp.moveaxis(new_k_proj, -1, -3) - v_proj = jnp.moveaxis(new_v_proj, -1, -3) - - updated_state.update(key=new_k_proj, value=new_v_proj) + k_proj = k_proj.astype(cached_key.dtype) + v_proj = v_proj.astype(cached_value.dtype) + + # Function to update the cached_key for a single batch element. + def update_single(cached_key_slice, k_proj_slice, time_idx): + start_indices = (time_idx, 0, 0) + return jax.lax.dynamic_update_slice(cached_key_slice, k_proj_slice, start_indices) + + # Use jax.vmap to vectorize over the batch dimension. 
+ k_proj = jax.vmap(update_single)(cached_key, k_proj, time_step) + v_proj = jax.vmap(update_single)(cached_value, v_proj, time_step) + updated_state.update(key=k_proj, value=v_proj) return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj) @@ -945,10 +921,6 @@ def __init__(self, cfg: Config, *, parent: Module): proj_cfg.per_head_dim = cfg.per_head_dim self._add_child(f"{name}_proj", proj_cfg) - @property - def num_kv_heads(self): - return self.config.num_heads - def forward( self, query: Tensor, @@ -1019,10 +991,6 @@ def __init__(self, cfg: Config, *, parent: Module): proj_cfg.per_head_dim = cfg.per_head_dim self._add_child("q_proj", proj_cfg) - @property - def num_kv_heads(self): - raise NotImplementedError(type(self)) - def forward( self, query: Tensor, @@ -1071,10 +1039,6 @@ def __init__(self, cfg: Config, *, parent: Module): proj_cfg.per_head_dim = cfg.per_head_dim self._add_child("qkv_proj", proj_cfg) - @property - def num_kv_heads(self): - return self.config.num_heads - def create_parameter_specs_recursively(self) -> NestedParameterSpec: specs = VDict(**super().create_parameter_specs_recursively()) @@ -1391,8 +1355,9 @@ def forward( else: # Time step shape is [batch_size] # The expected input shape for rope_pos_emb_layer is [batch_size, seq_len] - # Therefore, expanding the shape of time_step to [batch_size, 1] - time_step = jnp.expand_dims(time_step, 1) + # Therefore, expanding the shape of time_step to [batch_size, step]. + step = query.shape[1] + time_step = jnp.arange(step)[None] + time_step[:, None] sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(time_step).astype(query.dtype) # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim] sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2) @@ -1879,17 +1844,13 @@ def _forward_for_mode( f"Invalid attention_logit_biases shape: {attention_logit_biases.shape}." ) if self._mask_fn is not None: - kv_len = k_proj.shape[1] + kv_pos = jnp.arange(k_proj.shape[1])[None, :] # [1, source_len] + query_pos = jnp.arange(q_proj.shape[1])[None] # [1, target_length] if mode == ForwardMode.EXTEND_STEP: - # query_len is unused because extend_step assumes query to be length 1. - query_len = None - time_step = cached_states["i_proj"]["time_step"] - else: - query_len = q_proj.shape[1] - time_step = None - mask = self._logit_biases_for_mask( - mode=mode, kv_len=kv_len, query_len=query_len, time_step=time_step - ) + time_step = cached_states["i_proj"]["time_step"] # [B] + # [B, target_length], target_length is often 1 for decoding, but not always. + query_pos = query_pos + time_step[:, None] + mask = self._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) if mask is not None: attention_logit_biases = apply_attention_logit_biases( mask.astype(q_proj.dtype), @@ -1918,12 +1879,7 @@ def _forward_for_mode( return dict(i_proj=i_proj_state), output def _logit_biases_for_mask( - self, - *, - mode: ForwardMode, - kv_len: int, - query_len: Optional[int] = None, - time_step: Optional[Tensor] = None, + self, *, mode: ForwardMode, query_pos: Tensor, kv_pos: Tensor ) -> Optional[Tensor]: """Returns the configured attention mask in the form of logit biases. @@ -1932,39 +1888,17 @@ def _logit_biases_for_mask( Args: mode: The forward propagation mode, chosen from (ForwardMode.FORWARD, ForwardMode.INIT_STATES, ForwardMode.EXTEND_STEP). - kv_len: The sequence length. For (ForwardMode.INIT_STATES, ForwardMode.EXTEND_STEP), - this is equal to the KV cache size. 
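# A self-contained sketch of the cache update above: each batch element writes a chunk of new
# K/V entries at its own time_step offset via dynamic_update_slice, vmapped over the batch.
import jax
import jax.numpy as jnp

batch, seq_len, num_kv_heads, per_head_dim, steps = 2, 8, 1, 4, 3
cache = jnp.zeros((batch, seq_len, num_kv_heads, per_head_dim))
new_kv = jnp.ones((batch, steps, num_kv_heads, per_head_dim))
time_step = jnp.array([0, 5])  # each example decodes from a different position.

def update_single(cache_slice, new_slice, time_idx):
    # Writes new_slice into cache_slice starting at [time_idx, 0, 0].
    return jax.lax.dynamic_update_slice(cache_slice, new_slice, (time_idx, 0, 0))

cache = jax.vmap(update_single)(cache, new_kv, time_step)
assert bool(cache[0, 0:3].all()) and bool(cache[1, 5:8].all())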
- query_len: Only used for (ForwardMode.FORWARD, ForwardMode.INIT_STATES). - If set, this is the query length. Otherwise, it uses kv_len as the query length. - Must be None for ForwardMode.EXTEND_STEP. - time_step: Only used for (ForwardMode.EXTEND_STEP). A tensor of size [batch] denoting - the 0-indexed position of the current input token. + query_pos: The index in the sequence of query vectors, [1|batch, target_length]. + kv_pos: The index in the sequence of kv vectors, [1|batch, source_length]. Returns: - For (ForwardMode.FORWARD, ForwardMode.INIT_STATES), a logit bias tensor that can be - broadcast to [batch, num_heads, query_len, kv_len]. - - For ForwardMode.EXTEND_STEP, a logit bias tensor that can be broadcast to - [batch, num_heads, 1, kv_len]. + A logit bias tensor [1|batch, 1, target_length, source_length]. """ - kv_pos = jnp.arange(kv_len) - - if mode in (ForwardMode.FORWARD, ForwardMode.INIT_STATES): - if time_step is not None: - raise ValueError( - "FORWARD or INIT_STATES modes do not expect `time_step` as an argument." - ) - query_pos = jnp.arange(kv_len if query_len is None else query_len) - mask = self._mask_fn(query_pos[:, None], kv_pos[None, :])[None, None] - elif mode == ForwardMode.EXTEND_STEP: - if query_len is not None: - raise ValueError("EXTEND_STEP mode does not expect `query_len` as an argument.") - # [batch, 1, 1, kv_len]. - # Ex: for a causal mask, mask[b, :, :, kv_pos] = 0 if time_step[b] > kv_pos else 1. - mask = self._mask_fn(time_step[:, None], kv_pos[None, :]) - mask = mask[:, None, None, :] - else: - raise ValueError(f"Unrecognized mode {mode}.") + del mode + kv_pos = kv_pos[:, None] # [1|B, 1, source_len] + query_pos = query_pos[..., None] # [1|B, target_len, 1] + # [1|B, 1, target_len, source_len] + mask = self._mask_fn(query_pos, kv_pos)[:, None] mask = bool_to_bias(mask) return mask @@ -2006,7 +1940,7 @@ def _compute_attention( self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases) probs = self.dropout(probs) - context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + context = self._compute_context(probs, v_proj) context = self._remat_name(context, "context") return context, probs @@ -2062,10 +1996,31 @@ def _cap_logits(self, logits: Tensor) -> Tensor: return cap * jnp.tanh(logits / cap) def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: + """Compute attention logits. + + Args: + q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. + k_proj: key tensor, [batch, source_length, num_heads, per_head_dim]. + + Returns: + logits: [batch, num_heads, target_length, source_length]. + """ q_proj = self.scale_query(q_proj) k_proj = self.scale_key(k_proj) return jnp.einsum("btnh,bsnh->bnts", q_proj, k_proj) + def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: + """Compute attention context. + + Args: + probs: probs tensor, [batch, num_heads, target_length, source_length]. + v_proj: value tensor, [batch, source_length, num_heads, per_head_dim]. + + Returns: + context: [batch, target_length, num_heads, per_head_dim]. 
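# A minimal sketch of the unified position-based masking above: the same path serves full
# forward (query_pos = arange over the target) and extend_step (query_pos = per-example
# time_step offsets). NEG_INF below stands in for the library constant.
import jax.numpy as jnp

NEG_INF = -1e15

def causal_mask(query_pos, kv_pos):
    return query_pos >= kv_pos

def logit_biases(mask_fn, query_pos, kv_pos):
    # query_pos: [1|B, T], kv_pos: [1|B, S] -> bias broadcastable to [1|B, 1, T, S].
    mask = mask_fn(query_pos[..., None], kv_pos[:, None, :])[:, None]
    return jnp.where(mask, 0.0, NEG_INF)

# Decoding one token per example at time steps 1 and 2 against a KV cache of length 4:
bias = logit_biases(causal_mask, jnp.array([[1], [2]]), jnp.arange(4)[None])
# bias[0, 0, 0] allows positions 0-1; bias[1, 0, 0] allows positions 0-2.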
+ """ + return jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + def init_states( self, *, @@ -2228,31 +2183,47 @@ class GroupedQueryAttention(MultiheadAttention): def num_kv_heads(self): return self.i_proj.num_kv_heads - def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor: - """Repeats key or value heads dim to match the query.""" - num_head_repeats = self.config.num_heads // key_or_value.shape[-2] - if num_head_repeats == 1: - return key_or_value - # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim]. - return jnp.repeat(key_or_value, num_head_repeats, axis=-2) + def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: + """Compute attention logits. - def _compute_attention( - self, - *, - q_proj: Tensor, - k_proj: Tensor, - v_proj: Tensor, - **kwargs, - ) -> tuple[Tensor, Tensor]: - """See `MultiheadAttention._compute_attention` for details.""" - k_proj = self._repeat_kv_heads(k_proj) - v_proj = self._repeat_kv_heads(v_proj) - return super()._compute_attention( - q_proj=q_proj, - k_proj=k_proj, - v_proj=v_proj, - **kwargs, - ) + Args: + q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. + k_proj: key tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + logits: [batch, num_heads, target_length, source_length]. + """ + kv_heads = k_proj.shape[-2] + num_head_group = self.config.num_heads // kv_heads + if num_head_group == 1: + return super()._compute_logits(q_proj=q_proj, k_proj=k_proj) + + q_proj = self.scale_query(q_proj) + k_proj = self.scale_key(k_proj) + q_proj = einops.rearrange(q_proj, "b t (g k) h -> b t g k h", g=num_head_group, k=kv_heads) + k_proj = einops.rearrange(k_proj, "b s k h -> b s 1 k h") + logits = jnp.einsum("btgkh,bs1kh->bgkts", q_proj, k_proj) + return einops.rearrange(logits, "b g k t s -> b (g k) t s") + + def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: + """Compute attention context. + + Args: + probs: probs tensor, [batch, num_heads, target_length, source_length]. + v_proj: value tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + context: [batch, target_length, num_heads, per_head_dim]. + """ + kv_heads = v_proj.shape[-2] + num_head_group = self.config.num_heads // kv_heads + if num_head_group == 1: + return super()._compute_context(probs=probs, v_proj=v_proj) + + probs = einops.rearrange(probs, "b (g k) t s -> b g k t s", g=num_head_group, k=kv_heads) + v_proj = einops.rearrange(v_proj, "b s k h -> b s 1 k h") + context = jnp.einsum("bgkts,bs1kh->btgkh", probs, v_proj) + return einops.rearrange(context, "b t g k h -> b t (g k) h") class SigmoidAttention(MultiheadAttention): @@ -2303,7 +2274,7 @@ def _compute_attention( ) probs = self.dropout(probs) - context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + context = self._compute_context(probs, v_proj) context = self._remat_name(context, "context") return context, probs @@ -3601,6 +3572,45 @@ class Config(BaseTransformerLayer.Config): peak_stochastic_depth_rate: Optional[float] = None +class UpdateDataFn(Protocol): + """A function for updating the constituent layers' input in a StackTransformerLayer.""" + + def __call__( + self, data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output] + ) -> Tensor: + """Returns a new Tensor with the same shape as `data`, reflecting some desired updates. + + Args: + data: A Tensor denoting the input data to the upcoming layer. 
+ all_layer_outputs: A list of BaseTransformerLayer.Output that is appended with + the output of each constituent layer in the stack. + + Returns: + A new Tensor with the same shape as `data`. + """ + + +def update_data_with_skip_connection(skip_connections: dict[int, int]) -> UpdateDataFn: + """Creates a function that adds skip connection to the input data tensor. + + Args: + skip_connections: A dictionary where keys and values represent 0-indexed layer indices. + For a (k, v) pair, the output of the v-th layer will be added to the input + of the k-th layer. + + Returns: + A function that implements skip connections, following the UpdateDataFn protocol, . + """ + + def update_data(data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output]) -> Tensor: + layer_index = len(all_layer_outputs) + if layer_index in skip_connections: + data += all_layer_outputs[skip_connections[layer_index]].data + return data + + return update_data + + class StackedTransformerLayer(BaseStackedTransformerLayer): """A simple implementation of BaseStackedTransformerLayer.""" @@ -3613,10 +3623,15 @@ class Config(BaseStackedTransformerLayer.Config): layer: Union[ BaseTransformerLayer.Config, Sequence[BaseTransformerLayer.Config] ] = TransformerLayer.default_config() + # If set, implements the UpdateDataFn protocol to update individual layers' input + # data in some specified way. This operation is applied before calling every layer. + data_merger: Optional[InstantiableConfig[UpdateDataFn]] = None def __init__(self, cfg: Config, *, parent: Optional[Module]): super().__init__(cfg, parent=parent) cfg = self.config + self._update_data = maybe_instantiate(cfg.data_merger) + if isinstance(cfg.layer, Sequence): layer_cfgs = cfg.layer if len(layer_cfgs) != cfg.num_layers: @@ -3685,7 +3700,8 @@ def _forward_for_mode( all_layer_states = [] for i, layer in enumerate(self._layers): # Prepare inputs to the current layer. - data = self._update_data(data, all_layer_outputs=all_layer_outputs) + if self._update_data is not None: + data = self._update_data(data, all_layer_outputs) self._update_layer_kwargs(layer_kwargs, all_layer_outputs=all_layer_outputs) if mode == ForwardMode.FORWARD: @@ -3712,28 +3728,6 @@ def _forward_for_mode( return all_layer_states, self._aggregate_layer_outputs(all_layer_outputs) - def _update_data( - self, - data: Tensor, - *, - all_layer_outputs: list[BaseTransformerLayer.Output], - ): - """Updates `data` using other args. - - This method is called before we invoke each layer in `self._layers`. - The updated data will be passed to the layer invocation. - - Args: - data: A Tensor denoting the input data to the upcoming layer. - all_layer_outputs: A list of BaseTransformerLayer.Output that is appended with - the output of each constituent layer in the stack. - - Returns: - A new Tensor. 
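# For illustration, a configuration sketch of the new data_merger hook (dimensions arbitrary):
# add the output of layer 1 to the input of layer 3 in a 5-layer stack.
from axlearn.common.attention import (
    StackedTransformerLayer,
    TransformerLayer,
    update_data_with_skip_connection,
)
from axlearn.common.config import config_for_function

layer_cfg = TransformerLayer.default_config()
layer_cfg.self_attention.attention.num_heads = 2
layer_cfg.feed_forward.hidden_dim = 8

stack_cfg = StackedTransformerLayer.default_config().set(
    name="stack",
    input_dim=4,
    num_layers=5,
    layer=layer_cfg,
    # Applied before each layer: layer 3 sees its input plus layer 1's output.
    data_merger=config_for_function(update_data_with_skip_connection).set(
        skip_connections={3: 1}
    ),
)
stack = stack_cfg.instantiate(parent=None)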
- """ - del all_layer_outputs - return data - def _update_layer_kwargs( self, layer_kwargs: dict[str, Any], diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 12fbbf33..502b1608 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -78,6 +78,7 @@ set_double_shard_weights_config, sinusoidal_positional_embeddings, sliding_window_causal_mask, + update_data_with_skip_connection, xl_attention_logits, ) from axlearn.common.base_layer import ( @@ -1428,16 +1429,20 @@ def test_qlinear(self): self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) @parameterized.parameters( - attention.QKVLinear, - attention.FusedQKVLinear, - attention.GroupedQKVLinear, - attention.FusedGroupedQKVLinear, - attention.RoFormerQKVLinear, + (attention.QKVLinear, 1), + (attention.FusedQKVLinear, 1), + (attention.GroupedQKVLinear, 1), + (attention.FusedGroupedQKVLinear, 1), + (attention.RoFormerQKVLinear, 1), + (attention.QKVLinear, 2), + (attention.FusedQKVLinear, 3), + (attention.GroupedQKVLinear, 4), + (attention.FusedGroupedQKVLinear, 3), + (attention.RoFormerQKVLinear, 2), ) - def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear]): + def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], extend_step_len): """Tests that calling QKVLinear.extend_step() multiple times with the same time_step results in the same output.""" - model_dim = 8 num_heads = 2 per_head_dim = model_dim // num_heads @@ -1459,34 +1464,33 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear]): batch_size, tgt_len = 2, 4 query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) - extend_step_state, _ = F( + fwd_output, _ = F( layer, state=layer_state, is_training=False, prng_key=jax.random.PRNGKey(456), - inputs=dict(target_batch_size=batch_size, target_max_len=tgt_len), - method="init_states", + inputs=dict(query=query), ) - for t in range(tgt_len): - (first_call_state, first_call_output), _ = F( - layer, - state=layer_state, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict(cached_states=extend_step_state, query=query[:, t : t + 1]), - method="extend_step", - ) - # Rewind the time_step. 
- first_call_state["time_step"] -= 1 - (extend_step_state, second_call_output), _ = F( + + cache_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) + step_querys = [] + step_keys = step_values = None + for t in range(0, tgt_len, extend_step_len): + (cache_state, step_output), _ = F( layer, state=layer_state, is_training=False, prng_key=jax.random.PRNGKey(456), - inputs=dict(cached_states=first_call_state, query=query[:, t : t + 1]), + inputs=dict(cached_states=cache_state, query=query[:, t : t + extend_step_len]), method="extend_step", ) - self.assertNestedAllClose(first_call_output, second_call_output) + step_querys.append(step_output.query) + step_keys = step_output.key + step_values = step_output.value + + self.assertNestedAllClose(fwd_output.query, jnp.concat(step_querys, axis=1)) + self.assertNestedAllClose(fwd_output.key, step_keys) + self.assertNestedAllClose(fwd_output.value, step_values) @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) def test_dtypes_inherited_from_parent(self, dtype: jnp.dtype): @@ -2183,7 +2187,10 @@ def test_logit_biases_for_mask(self): layer = cfg.instantiate(parent=None) layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - inputs = dict(mode=ForwardMode.FORWARD, kv_len=3, query_len=2) + query_len, kv_len = 2, 3 + query_pos = jnp.arange(query_len)[None] + kv_pos = jnp.arange(kv_len)[None] + inputs = dict(mode=ForwardMode.FORWARD, query_pos=query_pos, kv_pos=kv_pos) layer_outputs, _ = F( layer, state=layer_params, @@ -2197,33 +2204,11 @@ def test_logit_biases_for_mask(self): bool_to_bias(jnp.array([[1, 0, 0], [1, 1, 0]], dtype=jnp.bool))[None, None], ) - inputs = dict(mode=ForwardMode.FORWARD, kv_len=3, query_len=2, time_step=jnp.array([3, 4])) - with self.assertRaises(ValueError) as cm: - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="_logit_biases_for_mask", - ) - self.assertTrue(isinstance(cm.exception, ValueError)) - - inputs = dict( - mode=ForwardMode.EXTEND_STEP, kv_len=3, query_len=2, time_step=jnp.array([3, 4]) - ) - with self.assertRaises(ValueError) as cm: - F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="_logit_biases_for_mask", - ) - self.assertTrue(isinstance(cm.exception, ValueError)) - - inputs = dict(mode=ForwardMode.EXTEND_STEP, kv_len=4, time_step=jnp.array([1, 2])) + time_step = jnp.array([1, 2]) + query_pos = time_step[:, None] + kv_len = 4 + kv_pos = jnp.arange(kv_len)[None] + inputs = dict(mode=ForwardMode.EXTEND_STEP, query_pos=query_pos, kv_pos=kv_pos) layer_outputs, _ = F( layer, state=layer_params, @@ -2318,31 +2303,6 @@ def test_sliding_window( # The outputs are equivalent. 
self.assertNestedAllClose(outputs[0], outputs[1]) - def test_gqa_kv_heads(self): - """Checks that only the heads dim is repeated.""" - batch = source_length = num_heads = 8 - per_head_dim = 2 - num_kv_heads = 4 - dtype = jnp.float32 - key_or_value = jnp.zeros((batch, source_length, num_kv_heads, per_head_dim), dtype=dtype) - model_dim = per_head_dim * num_heads - cfg = attention.GroupedQueryAttention.default_config().set( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - input_linear=attention.FusedGroupedQKVLinear.default_config().set( - num_kv_heads=num_kv_heads - ), - dtype=dtype, - ) - test_layer = cfg.set(name="test").instantiate(parent=None) - # pylint: disable-next=protected-access - repeated_key_or_value = test_layer._repeat_kv_heads(key_or_value) - self.assertEqual( - repeated_key_or_value.shape, (batch, source_length, num_heads, per_head_dim) - ) - @parameterized.product( dtype=(jnp.float32, jnp.float16, jnp.bfloat16), per_dim_scale=(None, PerDimScale.default_config()), @@ -2435,6 +2395,7 @@ def _test_extend_step( num_heads: int, dtype: jnp.dtype, bias: bool, + extend_step_len: int, ): cfg = attention_cfg.set( query_dim=model_dim, @@ -2498,18 +2459,16 @@ def _test_extend_step( self.assertNotIn("key", initial_state["i_proj"]) self.assertNotIn("value", initial_state["i_proj"]) inputs = dict(cached_states=initial_state, kv_state=kv_state, return_aux=return_aux) - decoder_output = jnp.zeros(shape=[tgt_len, batch_size, model_dim]) - decoder_probs = jnp.zeros(shape=[tgt_len, batch_size, num_heads, tgt_len]) - for t in range(tgt_len): - inputs["query"] = jnp.expand_dims(query[:, t, :], axis=1) + decoder_output = [] + decoder_probs = [] + for t in range(0, tgt_len, extend_step_len): + inputs["query"] = query[:, t : t + extend_step_len, :] if key is not None: - inputs["key"] = jnp.expand_dims(key[:, t, :], axis=1) + inputs["key"] = key[:, t : t + extend_step_len, :] if value is not None: - inputs["value"] = jnp.expand_dims(value[:, t, :], axis=1) - inputs["attention_logit_biases"] = attention_logit_biases[ - jnp.newaxis, jnp.newaxis, t, : - ] - extend_step_outputs, _ = F( + inputs["value"] = value[:, t : t + extend_step_len, :] + inputs["attention_logit_biases"] = attention_logit_biases[t : t + extend_step_len, :] + (cached_states, extend_step_outputs), _ = F( layer, state=layer_params, is_training=False, @@ -2517,25 +2476,13 @@ def _test_extend_step( inputs=inputs, method="extend_step", ) - inputs["cached_states"] = extend_step_outputs[0] - decoder_output = decoder_output.at[t].set( - jnp.squeeze(extend_step_outputs[1].data, axis=1) - ) - decoder_probs = decoder_probs.at[t].set( - jnp.squeeze(extend_step_outputs[1].probs, axis=2) - ) - decoder_out_transposed = jnp.transpose(decoder_output, [1, 0, 2]) - decoder_probs_transposed = jnp.transpose(decoder_probs, [1, 2, 0, 3]) - assert_allclose( - decoder_out_transposed, - forward_outputs.data, - atol=1e-6, - ) - assert_allclose( - decoder_probs_transposed, - forward_outputs.probs, - atol=1e-6, - ) + inputs["cached_states"] = cached_states + decoder_output.append(extend_step_outputs.data) + decoder_probs.append(extend_step_outputs.probs) + decoder_output = jnp.concatenate(decoder_output, axis=1) + decoder_probs = jnp.concatenate(decoder_probs, axis=2) + assert_allclose(decoder_output, forward_outputs.data, atol=1e-6) + assert_allclose(decoder_probs, forward_outputs.probs, atol=1e-6) @parameterized.product( dtype=(jnp.float32, jnp.float16, jnp.bfloat16), @@ -2543,6 +2490,7 @@ def _test_extend_step( 
atten_logit_cap=(0.0, 20.0), bias=(True, False), input_linear=(QKVLinear, RoFormerQKVLinear, QLinear), + extend_step_len=(1, 4), ) def test_extend_step( self, @@ -2551,6 +2499,7 @@ def test_extend_step( atten_logit_cap: float, input_linear: attention.BaseQKVLinear, bias: bool, + extend_step_len: int, ): model_dim = 16 num_heads = 4 @@ -2564,7 +2513,12 @@ def test_extend_step( input_linear=input_linear, ) self._test_extend_step( - cfg, model_dim=model_dim, num_heads=num_heads, dtype=dtype, bias=bias + cfg, + model_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + bias=bias, + extend_step_len=extend_step_len, ) @parameterized.product( @@ -2574,6 +2528,7 @@ def test_extend_step( num_kv_heads=(1, 2, 4), input_linear=(attention.GroupedQKVLinear, attention.FusedGroupedQKVLinear), bias=(True, False), + extend_step_len=(1, 4), ) def test_gqa_extend_step( self, @@ -2583,6 +2538,7 @@ def test_gqa_extend_step( num_kv_heads: int, input_linear: type[attention.BaseQKVLinear], bias: bool, + extend_step_len: int, ): model_dim = 16 num_heads = 4 @@ -2592,7 +2548,12 @@ def test_gqa_extend_step( input_linear=input_linear.default_config().set(num_kv_heads=num_kv_heads), ) self._test_extend_step( - cfg, model_dim=model_dim, num_heads=num_heads, dtype=dtype, bias=bias + cfg, + model_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + bias=bias, + extend_step_len=extend_step_len, ) def _test_prefill_states( @@ -2666,7 +2627,7 @@ def _test_prefill_states( self.assertTrue(jnp.all(time_step == initial_states["i_proj"]["time_step"])) for proj in ["key", "value"]: self.assertEqual( - (batch_size, num_kv_heads or num_heads, model_dim // num_heads, tgt_len), + (batch_size, tgt_len, num_kv_heads or num_heads, model_dim // num_heads), initial_states["i_proj"][proj].shape, ) self.assertEqual( @@ -3761,22 +3722,6 @@ def _aggregate_layer_outputs( ) -class TestStackedTransformerLayerWithDataOverride(NonUniformStack): - """A class with a simple override of _update_data for unit testing.""" - - @property - def forced_input(self): - return jnp.ones((2, 3, 4)) - - def _update_data( - self, - data: Tensor, - *, - all_layer_outputs: list[BaseTransformerLayer.Output], - ): - return self.forced_input - - class TestStackedTransformerLayerWithKVState(NonUniformStack): """A class with a simple override of _update_layer_kwargs for unit testing.""" @@ -3793,6 +3738,16 @@ def _update_layer_kwargs( layer_kwargs["self_attention_kv_state"] = None +class TestStackedTransformerLayerWithSkipConnection(StackedTransformerLayer): + """A class that outputs all layers' output for unit testing.""" + + def _aggregate_layer_outputs( + self, + layer_outputs: Sequence[BaseTransformerLayer.Output], + ) -> Sequence[BaseTransformerLayer.Output]: + return layer_outputs + + class StackedTransformerTest(BaseTransformerTest): """Tests StackedTransformerLayer.""" @@ -4117,47 +4072,63 @@ def test_transformer_prefill_states(self, transformer_type, layer_type): assert_allclose(decoder_self_attention_probs, forward_outputs.self_attention_probs) assert_allclose(decoder_cross_attention_probs, forward_outputs.cross_attention_probs) - def test_update_data(self): + def test_skip_connection(self): batch_size = 2 seq_len = 6 num_heads = 2 input_dim = 4 hidden_dim = 8 + num_layers = 5 + layer_with_skip_input = 3 - # Create a StackedTransformerLayer by specifying a sequence of non-uniform layer configs. 
- cfg = TestStackedTransformerLayerWithDataOverride.default_config().set(name="test") - cfg.input_dim = input_dim - cfg.num_layers = 2 + cfg = TestStackedTransformerLayerWithSkipConnection.default_config().set( + name="test", input_dim=input_dim, num_layers=num_layers + ) transformer_cfg = TransformerLayer.default_config() transformer_cfg.self_attention.attention.num_heads = num_heads transformer_cfg.feed_forward.hidden_dim = hidden_dim cfg.layer = transformer_cfg - layer: StackedTransformerLayer = cfg.instantiate(parent=None) + test_cfg = cfg.clone().set( + data_merger=config_for_function(update_data_with_skip_connection).set( + skip_connections={layer_with_skip_input: 1} + ) + ) + + base_layer = cfg.instantiate(parent=None) + test_layer = test_cfg.instantiate(parent=None) + random_inputs = jax.random.uniform( jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim) ) - state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - outputs_with_random_input, _ = F( - layer, + state = base_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + base_output, _ = F( + base_layer, is_training=True, prng_key=jax.random.PRNGKey(123), state=state, inputs=dict(data=random_inputs), ) - outputs_with_forced_input, _ = F( - layer, + test_output, _ = F( + test_layer, is_training=True, prng_key=jax.random.PRNGKey(123), state=state, - inputs=dict(data=layer.forced_input), - ) - self.assertNestedAllClose( - outputs_with_random_input.data, - outputs_with_forced_input.data, + inputs=dict(data=random_inputs), ) + for i in range(layer_with_skip_input): + self.assertNestedAllClose( + base_output[i].data, + test_output[i].data, + ) + for i in range(layer_with_skip_input, num_layers): + self.assertNotAlmostEqual( + jnp.min(jnp.abs(base_output[i].data - test_output[i].data)), + 0.0, + ) + def test_update_layer_kwargs(self): batch_size = 2 seq_len = 6 diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index 0eb48ae3..027a5115 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -154,8 +154,12 @@ def _upload_dir(src_dir_handle: tempfile.TemporaryDirectory, *, dst_dir: str): Temporary dir will be deleted after the upload is complete. """ src_dir = src_dir_handle.name - fs.makedirs(dst_dir) - for item in fs.listdir(src_dir): + src_files = fs.listdir(src_dir) + # src_files will be empty if there are no tf savables (i.e., don't have any tf state to save). + # In this case, do not create empty dst_dirs. + if len(src_files): + fs.makedirs(dst_dir) + for item in src_files: src_file = os.path.join(src_dir, item) dst_file = os.path.join(dst_dir, item) assert not fs.isdir(src_file) @@ -364,10 +368,13 @@ class Config(StateStorage.Config): timeout_secs: Barrier timeout in seconds. max_data_shard_degree: Max sharding degree of model weights along data-parallel axis. `None` and `1` means no sharding. `-1` means fully shard along data-parallel - replicas. `>1` means custom sharding degree (currently not implemented). + replicas. `>1` means custom sharding degree and should almost always be a power + of 2. max_concurrent_gb: Max concurrent shards (in GB) to write. max_concurrent_restore_gb: Max concurrent shards (in GB) to read during checkpoint restore. `None` or `0` means using a default value of 32GB. + shard_threshold_bytes: Threshold for a array shard to be data-sharded. A value of None + or <= 0 means always data-shard according to max_data_shard_degree. 
""" timeout_secs: float = 3600 @@ -375,6 +382,7 @@ class Config(StateStorage.Config): # TODO(hanzhi-zhou): rename this to max_concurrent_save_gb. max_concurrent_gb: Optional[int] = None max_concurrent_restore_gb: Optional[int] = None + shard_threshold_bytes: Optional[int] = None def __init__(self, cfg: Config): super().__init__(cfg) @@ -386,8 +394,14 @@ def __init__(self, cfg: Config): max_concurrent_gb=cfg.max_concurrent_gb, timeout_secs=cfg.timeout_secs, max_data_shard_degree=cfg.max_data_shard_degree, + shard_threshold_bytes=cfg.shard_threshold_bytes, ) else: + if cfg.shard_threshold_bytes is not None: + raise ValueError( + f"shard_threshold_bytes is set to {cfg.shard_threshold_bytes}, but " + "max_data_shard_degree is not set. It will not take any effect." + ) self._manager = GlobalAsyncCheckpointManager(timeout_secs=cfg.timeout_secs) if cfg.max_concurrent_restore_gb is not None and cfg.max_concurrent_restore_gb <= 0: raise ValueError( @@ -954,7 +968,6 @@ def save( if step < 0 or step >= 10**8: raise ValueError(f"Out-of-range: {step}") ckpt_dir = self.ckpt_dir(step) - self.cleanup_checkpoint(ckpt_dir) self._storage.save_to_dir( step=step, state=state, ckpt_dir=ckpt_dir, on_commit_callback=write_index_file ) diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 2f714605..befde370 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -45,7 +45,7 @@ _GRAIN_INSTALLED = False -class _TfIteratorHandler(ocp.pytree_checkpoint_handler.TypeHandler): +class _TfIteratorHandler(ocp.type_handlers.TypeHandler): """Serializes tf.data.Iterator. Reference: @@ -105,7 +105,7 @@ async def metadata( if _GRAIN_INSTALLED: - class _GrainDatasetIteratorHandler(ocp.pytree_checkpoint_handler.TypeHandler): + class _GrainDatasetIteratorHandler(ocp.type_handlers.TypeHandler): """Serializes grain dataset iterators.""" @dataclasses.dataclass @@ -178,6 +178,8 @@ class Config(BaseCheckpointer.Config): keep_last_n: int = 1 validation_type: CheckpointValidationType = CheckpointValidationType.EXACT async_timeout_secs: int = 300 + max_concurrent_save_gb: Optional[int] = None + max_concurrent_restore_gb: Optional[int] = None @classmethod def checkpoint_paths(cls, base_dir: str) -> List[str]: @@ -225,10 +227,12 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: # for simplicity. The test cases ensure that this is compatible with # `read_index_file`. "index": ocp.JsonCheckpointHandler(filename="index"), - # TODO(markblee): Add save/restore_concurrent_gb when available. # Note that this defaults to use_ocdb=True. Note also that custom `TypeHandler`s are # ignored by `StandardCheckpointHandler`, so we use `PyTreeCheckpointHandler`. - "state": ocp.PyTreeCheckpointHandler(), + "state": ocp.PyTreeCheckpointHandler( + save_concurrent_gb=cfg.max_concurrent_save_gb, + restore_concurrent_gb=cfg.max_concurrent_restore_gb, + ), }, ) diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index 7b485cd1..8678692a 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -112,7 +112,9 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): ) # When the given state has a different array shape: [3] instead of [2] for y. 
- with self.assertRaisesRegex(ValueError, "checkpoint tree dtypes or shapes"): + with self.assertRaisesRegex( + ValueError, "(checkpoint tree dtypes or shapes|not compatible)" + ) ckpt.restore( step=None, state=dict( @@ -124,7 +126,7 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): # Orbax throws AssertionError in this case. with self.assertRaisesRegex( (AssertionError, ValueError), - "(checkpoint tree dtypes or shapes|do not match)", + "(checkpoint tree dtypes or shapes|not compatible)", ): ckpt.restore( step=None, @@ -197,6 +199,77 @@ def state_specs(state, partition_spec): self.assertEqual(step, restored_step) self.assertNestedEqual(state, restored_state) + @parameterized.parameters( + # Expected number of files, excluding the index and .zarray metadata files. + dict( + max_data_shard_degree=None, + shard_threshold_bytes=None, + num_files=4, # 2 arrays * 2 shards (2 model) per array. + ), + dict( + max_data_shard_degree=-1, + shard_threshold_bytes=None, + num_files=16, # 2 arrays * 8 shards (2 model, 4 data) per array. + ), + dict( + max_data_shard_degree=2, + shard_threshold_bytes=None, + num_files=8, # 2 arrays * 4 shards (2 model, 2 data) per array. + ), + dict( + max_data_shard_degree=2, + shard_threshold_bytes=1024, + num_files=6, # 1 array with 4 shards (2 model, 2 data) + 1 array with 2 shards (small array). + ), + ) + def test_save_restore_files_count( + self, max_data_shard_degree: int, shard_threshold_bytes: int, num_files: int + ): + # Tests the effect of max_data_shard_degree and shard_threshold_bytes on number of files. + mesh_shape = (4, 2) + if not test_utils.is_supported_mesh_shape(mesh_shape): + return + + cfg: Checkpointer.Config = _checkpointer_config(Checkpointer) + cfg.storage.max_data_shard_degree = max_data_shard_degree + cfg.storage.shard_threshold_bytes = shard_threshold_bytes + ckpt: Checkpointer = cfg.instantiate(parent=None) + state = dict( + x=jnp.zeros((1024, 1024), dtype=jnp.float32), + small_x=jnp.zeros((16, 16), dtype=jnp.float32), + ) + step = 1 + + def count_files(directory): + file_count = 0 + for _, _, files in os.walk(directory): + file_count += len(files) + return file_count + + def state_specs(state): + return jax.tree.map( + lambda x: utils.TensorSpec( + shape=x.shape, + dtype=x.dtype, + mesh_axes=jax.sharding.PartitionSpec(None, "model"), + ), + state, + ) + + with _mesh(mesh_shape) as mesh: + sharding = jax.sharding.NamedSharding( + mesh, spec=jax.sharding.PartitionSpec(None, "model") + ) + state = jax.tree.map(lambda x: jax.device_put(x, device=sharding), state) + ckpt.save(step=step, state=state) + ckpt.wait_until_finished() + + restored_step, restored_state = ckpt.restore(step=step, state=state_specs(state)) + self.assertEqual(step, restored_step) + self.assertNestedEqual(state, restored_state) + + self.assertEqual(count_files(ckpt.ckpt_dir(step)), num_files + 3) + @parameterized.parameters(Checkpointer, OrbaxCheckpointer) def test_save_and_restore_latest_valid(self, checkpointer_cls: Type[BaseCheckpointer]): mesh_shape = (1, 1) @@ -243,6 +316,47 @@ def create_corrupt_ckpt(step): ckpt.wait_until_finished() self.assertNestedEqual((3, state0), ckpt.restore(step=None, state=state0)) + @parameterized.parameters(Checkpointer, OrbaxCheckpointer) + def test_save_can_override_on_gcs(self, checkpointer_cls: Type[BaseCheckpointer]): + mesh_shape = (1, 1) + if not test_utils.is_supported_mesh_shape(mesh_shape): + return + # Patch is_gcs_path for orbax, since it commits differently on gcs vs local.
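# A worked count for the shard_threshold_bytes=1024 case above (a hypothetical helper that
# mirrors the test setup): a (4, 2) mesh with PartitionSpec(None, "model") gives each array
# 2 model shards, each replicated on 4 data-parallel devices.
import math

def tensorstore_files(shape, *, itemsize=4, model_shards=2, replicas=4,
                      max_data_shard_degree=2, shard_threshold_bytes=1024):
    shard_bytes = math.prod(shape) * itemsize // model_shards
    degree = replicas if max_data_shard_degree == -1 else min(replicas, max_data_shard_degree)
    data_shards = degree if degree > 1 and shard_bytes >= shard_threshold_bytes else 1
    return model_shards * data_shards

# x: 2 model shards * 2 data shards = 4 files; small_x: its 512-byte shards fall below the
# threshold, so 2 files. The test's "+ 3" covers the index file plus one .zarray per array.
assert tensorstore_files((1024, 1024)) + tensorstore_files((16, 16)) == 6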
+ with _mesh(mesh_shape), mock.patch(f"{ocp.step.__name__}.is_gcs_path", return_value=True): + cfg = _checkpointer_config(checkpointer_cls) + ckpt: BaseCheckpointer = cfg.instantiate(parent=None) + state0 = dict(x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32)) + + # Save a checkpoint. + ckpt.save(step=1, state=state0) + ckpt.wait_until_finished() + self.assertNestedEqual((1, state0), ckpt.restore(step=None, state=state0)) + + if isinstance(ckpt, (Checkpointer, OrbaxCheckpointer)): + ckpt_dir = ckpt.ckpt_dir(step=1) + else: + raise NotImplementedError(type(ckpt)) + + # Corrupt the checkpoint by removing some files, while ensuring it is non-empty. + commit_file = ( + "index" if isinstance(ckpt, Checkpointer) else ocp.step._COMMIT_SUCCESS_FILE + ) + fs.rmtree(os.path.join(ckpt_dir, commit_file)) + self.assertGreater(len(fs.listdir(ckpt_dir)), 0) + + if isinstance(ckpt, OrbaxCheckpointer): + ckpt._manager.reload() # Orbax caches complete checkpoints. + + self.assertEqual(0, len(ckpt.checkpoint_paths(ckpt.config.dir))) + + # Test that save() should be able to override non-empty ckpt dir. + state1 = dict(x=jnp.ones([], dtype=jnp.int32), y=jnp.zeros([2], dtype=jnp.float32)) + ckpt.save(step=1, state=state1) + ckpt.wait_until_finished() + + # Should match the new state. + self.assertNestedEqual((1, state1), ckpt.restore(step=None, state=state1)) + @parameterized.product( checkpointer_cls=[Checkpointer, OrbaxCheckpointer], mesh_shape=[(1, 1), (2, 2), (4, 2)], @@ -893,11 +1007,30 @@ def tree_unflatten(cls, keys, values): class TensorStoreStateStorageTest(test_utils.TestCase): - @parameterized.product(max_concurrent_gb=[None, 1], max_data_shard_degree=[None, 1, -1]) - def test_max_concurrent_gb(self, max_concurrent_gb: Optional[int], max_data_shard_degree: int): + @parameterized.product( + max_concurrent_gb=[None, 1], + max_data_shard_degree=[None, 1, -1], + shard_threshold_bytes=[None, 0, int(1024**3)], + ) + def test_checkpointer_configs( + self, + max_concurrent_gb: Optional[int], + max_data_shard_degree: int, + shard_threshold_bytes: int, + ): cfg = TensorStoreStateStorage.default_config().set( - max_concurrent_gb=max_concurrent_gb, max_data_shard_degree=max_data_shard_degree + max_concurrent_gb=max_concurrent_gb, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, ) + if ( + not max_concurrent_gb + and not max_data_shard_degree + and shard_threshold_bytes is not None + ): + with self.assertRaises(ValueError): + storage = cfg.instantiate() + return storage = cfg.instantiate() if max_concurrent_gb is not None or max_data_shard_degree: self.assertIsInstance(storage._manager, BoundedDataShardedAsyncCheckpointManager) @@ -1047,6 +1180,19 @@ def test_restored_iterator_resumes(self): # should continue from the interruption. self.assertSetEqual(set(seen), set(range(num_examples))) + def test_no_save_input_iterator(self): + executor = ThreadPoolExecutor(1) + tmpdir = tempfile.mkdtemp() + ckpt_dir = os.path.join(tmpdir, "tf_ckpt") + self.assertEqual(0, len(fs.listdir(tmpdir))) + # Test that when we don't save input iterator, tf dirs are not created. + async_save_tf_savables({}, executor=executor, dir=ckpt_dir) + self.assertEqual([], fs.listdir(tmpdir)) + # Test that dirs are created if we save. 
+ ds = tf.data.Dataset.from_tensor_slices([]) + async_save_tf_savables({"it": iter(ds)}, executor=executor, dir=ckpt_dir) + self.assertEqual(["tf_ckpt"], fs.listdir(tmpdir)) + SWITCHABLE_VDICT_IMPL: Optional[type[VDict]] = None diff --git a/axlearn/common/decoder.py b/axlearn/common/decoder.py index 5c4b659d..941347e5 100644 --- a/axlearn/common/decoder.py +++ b/axlearn/common/decoder.py @@ -280,11 +280,12 @@ def sample_decode( The sample decoding outputs. """ cfg: DecodingLayer.Config = self.config + logits_modifier = maybe_instantiate(logits_modifier) tokens_to_scores_fn = self._tokens_to_scores( num_decodes=num_decodes, cross_attention_data=cross_attention_data, cross_attention_logit_biases=cross_attention_logit_biases, - logits_modifier=maybe_instantiate(logits_modifier), + logits_modifier=logits_modifier, ) input_ids = self._pad( prefix, max_sequence_length=max_sequence_length, pad_id=cfg.pad_token_id diff --git a/axlearn/common/decoder_test.py b/axlearn/common/decoder_test.py index 279e6dc5..6eef860d 100644 --- a/axlearn/common/decoder_test.py +++ b/axlearn/common/decoder_test.py @@ -574,9 +574,10 @@ def test_decode( if method == "sample_decode": # Modify logits so that we will always sample the last token ID. - inputs["logits_modifier"] = ( - lambda logits: jnp.full_like(logits, decoding.NEG_INF).at[:, -1].set(0) - ) + def logits_modifier_fn(): + return lambda logits: jnp.full_like(logits, decoding.NEG_INF).at[:, -1].set(0) + + inputs["logits_modifier"] = config_for_function(logits_modifier_fn) # pylint: disable=protected-access mock_ctx = contextlib.nullcontext() diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 5f80254e..c1b19106 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -31,11 +31,10 @@ import jax import jax.numpy as jnp - -# pytype: disable=import-error # pylint: disable=import-error from jax import lax from jax._src.cudnn.fused_attention_stablehlo import MaskType, dot_product_attention from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu from axlearn.common.attention import NEG_INF @@ -249,19 +248,25 @@ def flash_attention( def bias_index_map(_, j, k): return (j if bias.shape[0] != 1 else 0, k if bias.shape[1] != 1 else 0, 0, 0) - bias_block_spec = pl.BlockSpec(bias_index_map, (None, None, seq_len, seq_len)) + bias_block_spec = pl.BlockSpec( + index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + ) # Segment Ids segment_ids_block_spec = None if segment_ids is not None: assert segment_ids.ndim == 2 - segment_ids_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + segment_ids_block_spec = pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0)), block_shape=(None, seq_len) + ) num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 if head_dim <= 64 else 8 num_stages_ = num_stages if num_stages_ is None: - num_stages_ = 2 if head_dim <= 64 else 1 + num_stages_ = ( + 2 if bias is None and jnp.float32 not in (query.dtype, key.dtype, value.dtype) else 1 + ) kernel = functools.partial( _mha_forward_kernel, softmax_scale=softmax_scale, @@ -276,14 +281,25 @@ def bias_index_map(_, j, k): kernel, grid=grid_, in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # query - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # key - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # value + 
pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # query + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # key + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # value bias_block_spec, # bias segment_ids_block_spec, # segment_ids ], - out_specs=pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - compiler_params=dict(triton=dict(num_warps=num_warps_, num_stages=num_stages_)), + out_specs=pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim) + ), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps_, num_stages=num_stages_), out_shape=out_shape, debug=debug, interpret=interpret, @@ -327,20 +343,26 @@ def _mha_forward( def bias_index_map(_, j, k): return (j if bias.shape[0] != 1 else 0, k if bias.shape[1] != 1 else 0, 0, 0) - bias_block_spec = pl.BlockSpec(bias_index_map, (None, None, seq_len, seq_len)) + bias_block_spec = pl.BlockSpec( + index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + ) # Segment Ids. segment_ids_block_spec = None if segment_ids is not None: assert segment_ids.ndim == 2 - segment_ids_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + segment_ids_block_spec = pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0)), block_shape=(None, seq_len) + ) num_warps_ = num_warps if num_warps_ is None: num_warps_ = 4 if head_dim <= 64 else 8 num_stages_ = num_stages if num_stages_ is None: - num_stages_ = 2 if head_dim <= 64 else 1 + num_stages_ = ( + 2 if bias is None and jnp.float32 not in (query.dtype, key.dtype, value.dtype) else 1 + ) kernel = functools.partial( _mha_forward_kernel, softmax_scale=softmax_scale, @@ -359,18 +381,30 @@ def bias_index_map(_, j, k): kernel, grid=grid_, in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # query - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # key - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # value + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # query + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # key + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # value bias_block_spec, # bias segment_ids_block_spec, # segment_ids ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), ], - compiler_params=dict(triton=dict(num_warps=num_warps_, num_stages=num_stages_)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps_, num_stages=num_stages_), out_shape=out_shape, debug=debug, interpret=interpret, @@ -426,15 +460,24 @@ def _preprocess_backward( functools.partial(_preprocess_backward_kernel, block_q=block_q), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - 
pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), ], - compiler_params=dict(triton=dict(num_warps=4, num_stages=3)), + compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3), out_shape=out_shape, debug=debug, interpret=interpret, @@ -455,7 +498,6 @@ def _mha_backward_kernel( l_ref, m_ref, delta_ref, - _, # Outputs. dq_ref, dk_ref, @@ -470,9 +512,14 @@ def _mha_backward_kernel( """Computes the backward pass. This algorithm is described in https://arxiv.org/abs/2205.14135 Appendix B.4 Algorithm 4. + Jax reference implementation: + https://github.com/jax-ml/jax/blob/0995bc231c51e2ee66995be8ee2b31adf9236509/jax/experimental/pallas/ops/gpu/attention.py#L343 See also `_mha_forward_kernel` for the forward pass. + The main difference between ours and jax reference implementation is that it supports 4-d bias, + and it supports float32 in the input dtype. + Args: q_ref: Input query ref. k_ref: Input key ref. @@ -497,71 +544,123 @@ def _mha_backward_kernel( del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] - def outer_loop(start_k, _): - dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) - dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) + # Parallelize over k/v's seq dimension. + # Load a block of K and V of size (block_k, block_d). + # Iterate through Q in chunks of (block_q, block_d) to accumulate dK and dV. + start_k = pl.program_id(2) + slice_k = pl.ds(start_k * block_k, block_k) + dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) + dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) + k = pl.load(k_ref, (slice_k, slice(None))) + v = pl.load(v_ref, (slice_k, slice(None))) + span_k = start_k * block_k + jnp.arange(block_k) + kv_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_k,)) + + def inner_loop_dk_dv(start_q, carry): + dv, dk = carry + slice_q = pl.ds(start_q * block_q, block_q) + q = pl.load(q_ref, (slice_q, slice(None))) + qk = pl.dot(q, k.T) + # These casts are needed to avoid precision issues. + qk = qk.astype(jnp.float32) + + if softmax_scale != 1.0: + qk *= softmax_scale + + if b_ref is not None: + # Load bias in transposed order, for hopefully better cache efficiency. + b = pl.load( + b_ref, + (slice_k, slice_q), + ) + b = b.astype(jnp.float32) + qk += b.T # Transpose back. 
+ if s_ref is not None: + q_segment_ids = pl.load(s_ref, (slice_q,)) + mask = _segment_mask(q_segment_ids, kv_segment_ids) + qk = jnp.where(mask, qk, NEG_INF) + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + mask = span_q[:, None] >= span_k[None, :] + qk = jnp.where(mask, qk, NEG_INF) + m = pl.load(m_ref, (slice_q,)) + p = jnp.exp(qk - m[:, None]) + do = pl.load(do_scaled_ref, (slice_q, slice(None))) + dv = dv + pl.dot(p.astype(do.dtype).T, do) + di = pl.load(delta_ref, (slice_q,)) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if softmax_scale != 1.0: + ds = ds * softmax_scale + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + + return dv, dk + + lower_bound = lax.div(start_k * block_k, block_q) if causal else 0 + dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop_dk_dv, (dv, dk)) + pl.store(dv_ref, (slice_k, slice(None)), dv.astype(dv_ref.dtype)) + pl.store(dk_ref, (slice_k, slice(None)), dk.astype(dk_ref.dtype)) + # Free up memory. + del dv, dk + + # Parallelize over q's seq dimension. + # 1. Load a block of Q of size (block_q, block_d). + # 2. Iterate through K and V in chunks of (block_k, block_d) to accumulate dQ. + start_q = pl.program_id(2) + slice_q = pl.ds(start_q * block_q, block_q) + q = pl.load(q_ref, (slice_q, slice(None))) + dq = jnp.zeros([block_q, block_d], dtype=jnp.float32) + q_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_q,)) + span_q = start_q * block_q + jnp.arange(block_q) + m = pl.load(m_ref, (slice_q,)) + di = pl.load(delta_ref, (slice_q,)) + do = pl.load(do_scaled_ref, (slice_q, slice(None))) + + def inner_loop_dq(start_k, carry): + dq = carry slice_k = pl.ds(start_k * block_k, block_k) k = pl.load(k_ref, (slice_k, slice(None))) v = pl.load(v_ref, (slice_k, slice(None))) - span_k = start_k * block_k + jnp.arange(block_k) - kv_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_k)) - - def inner_loop(start_q, carry): - dv, dk = carry - slice_q = pl.ds(start_q * block_q, block_q) - q = pl.load(q_ref, (slice_q, slice(None))) - qk = pl.dot(q, k.T) - - # These casts are needed to avoid precision issues. - qk = qk.astype(jnp.float32) - - if softmax_scale != 1.0: - qk *= softmax_scale - if b_ref is not None: - # Load bias in transposed order, for hopefully better cache efficiency. - b = pl.load( - b_ref, - (slice_k, slice_q), - ) - b = b.astype(jnp.float32) - qk += b.T # Transpose back. - if s_ref is not None: - q_segment_ids = pl.load(s_ref, (slice_q)) - mask = _segment_mask(q_segment_ids, kv_segment_ids) - qk = jnp.where(mask, qk, NEG_INF) - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - mask = span_q[:, None] >= span_k[None, :] - qk = jnp.where(mask, qk, NEG_INF) - m = pl.load(m_ref, (slice_q,)) - p = jnp.exp(qk - m[:, None]) - do = pl.load(do_scaled_ref, (slice_q, slice(None))) - dv = dv + pl.dot(p.astype(do.dtype).T, do) - di = pl.load(delta_ref, (slice_q,)) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) - ds = p * dp - if softmax_scale != 1.0: - ds = ds * softmax_scale - dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) - dq = pl.load( - dq_ref, - (slice_q, slice(None)), - eviction_policy="evict_last", - ) - dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) - pl.store(dq_ref, (slice_q, slice(None)), dq, eviction_policy="evict_last") - return dv, dk + qk = pl.dot(q, k.T) + + # These casts are needed to avoid precision issues. 
+ qk = qk.astype(jnp.float32) + if softmax_scale != 1.0: + qk *= softmax_scale + if b_ref is not None: + # Load bias in transposed order, for hopefully better cache efficiency. + b = pl.load( + b_ref, + (slice_k, slice_q), + ) + b = b.astype(jnp.float32) + qk += b.T # Transpose back. + if s_ref is not None: + kv_segment_ids = pl.load(s_ref, (slice_k,)) + mask = _segment_mask(q_segment_ids, kv_segment_ids) + qk = jnp.where(mask, qk, NEG_INF) if causal: - lower_bound = lax.div(start_k * block_k, block_q) - else: - lower_bound = 0 - dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop, (dv, dk)) - pl.store(dv_ref, (slice_k, slice(None)), dv.astype(dv_ref.dtype)) - pl.store(dk_ref, (slice_k, slice(None)), dk.astype(dk_ref.dtype)) + span_k = start_k * block_k + jnp.arange(block_k) + mask = span_q[:, None] >= span_k[None, :] + qk = jnp.where(mask, qk, NEG_INF) + p = jnp.exp(qk - m[:, None]) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if softmax_scale != 1.0: + ds = ds * softmax_scale + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + return dq + + if causal: + upper_bound = lax.div((start_q + 1) * block_q, block_k) + else: + upper_bound = pl.cdiv(seq_len, block_k) - lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None) + dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) + pl.store(dq_ref, (slice_q, slice(None)), dq.astype(dq_ref.dtype)) def _mha_backward( @@ -585,8 +684,9 @@ def _mha_backward( # NOTE: temporarily removed the "xla" branch, which seems unused. if backward_pass_impl == "triton": # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. - if jnp.float32 in (q.dtype, k.dtype, v.dtype): - block_q = block_k = 64 + if jnp.float32 in (q.dtype, k.dtype, v.dtype, jnp.bfloat16 if b is None else b.dtype): + block_q = block_k = 32 + batch_size, seq_len, num_heads, head_dim = q.shape # Backward heuristics, using the same block size for block q and block k. block_q = min(block_q, seq_len) @@ -594,43 +694,36 @@ def _mha_backward( # Very tiny amount of time, not worth using pallas_call. do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) # We accumulate into dq so we need to initialize it to zeros. - dq = jnp.zeros(q.shape, jnp.float32) out_shapes = [ - jax.ShapeDtypeStruct(dq.shape, dq.dtype), + jax.ShapeDtypeStruct(q.shape, q.dtype), jax.ShapeDtypeStruct(k.shape, k.dtype), jax.ShapeDtypeStruct(v.shape, v.dtype), ] - num_input = 8 - # Bias. bias_block_spec = None if b is not None: assert b.ndim == 4 b = jnp.moveaxis(b, -1, -2) - # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. - if b.dtype == jnp.float32: - block_q = block_k = 64 - def bias_index_map(j, k): + def bias_index_map(j, k, _): return (j if b.shape[0] != 1 else 0, k if b.shape[1] != 1 else 0, 0, 0) - bias_block_spec = pl.BlockSpec(bias_index_map, (None, None, seq_len, seq_len)) - num_input += 1 + bias_block_spec = pl.BlockSpec( + index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + ) # Segment Ids. segment_ids_block_spec = None if s is not None: assert s.ndim == 2 - segment_ids_block_spec = pl.BlockSpec(lambda j, k: (j, 0), (None, seq_len)) - num_input += 1 - - input_output_aliases = {num_input: 0} - - grid = (batch_size, num_heads) - # TODO(markblee): num_warps=8 seems to work from basic testing, confirm the below comment. - # TODO(sharadmv): figure out why num_warps=8 doesn't work! 
+ segment_ids_block_spec = pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0)), block_shape=(None, seq_len) + ) + grid = (batch_size, num_heads, pl.cdiv(seq_len, block_q)) + # Add some proof check against SRAM for float32 inputs or huge bias input. num_warps = 8 + num_stages = 2 if b is None and jnp.float32 not in (q.dtype, k.dtype, v.dtype) else 1 dq, dk, dv = pl.pallas_call( functools.partial( _mha_backward_kernel, @@ -643,29 +736,57 @@ def bias_index_map(j, k): grid=grid, out_shape=out_shapes, in_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # query - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # key - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # value + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # query + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # key + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # value bias_block_spec, # bias segment_ids_block_spec, # segment_ids - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) + ), ], out_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), ], name="mha_backward", debug=debug, interpret=interpret, - compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), - input_output_aliases=input_output_aliases, - )(q, k, v, b, s, out, do_scaled, l, m, delta, dq) + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages), + )(q, k, v, b, s, out, do_scaled, l, m, delta) else: raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv, None, None @@ -696,11 +817,6 @@ def cudnn_dot_product_attention( https://github.com/google/jax/blob/f4158ace933482844c145a6b919bf5dc86e084ba/jax/_src/cudnn/fused_attention_stablehlo.py#L927. https://github.com/openxla/xla/blob/536ba0b7d74f6637a7a772471a99ecf4f578aef2/xla/service/gpu/cublas_cudnn.cc#L77. 
- We override the Jax fused multihead attention(fMHA) interface in axlearn - due to following reasons: - 1. Original Jax implementation has a bug to support multi-node training (fixed in jax 0.4.32). - 2. We may want to leverage more lower level CuDNN capabilities from xla and expose to users. - Args: query: Query of shape [batch_size, target_length, num_heads, per_head_dim]. key: Key of shape [batch_size, source_length, num_heads, per_head_dim]. diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py index 48f1bf17..085eb39d 100644 --- a/axlearn/common/flash_attention/gpu_attention_test.py +++ b/axlearn/common/flash_attention/gpu_attention_test.py @@ -10,14 +10,9 @@ Currently tested on A100/H100. """ -# pylint: disable=wrong-import-position import functools -import os from typing import Literal -os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" -os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - import chex import jax import jax.numpy as jnp @@ -29,6 +24,9 @@ ) from axlearn.common.flash_attention.utils import mha_reference +if jax.default_backend() != "gpu": + pytest.skip(reason="Incompatible hardware", allow_module_level=True) + @pytest.mark.parametrize( "batch_size,seq_len,num_heads,per_head_dim", @@ -42,20 +40,17 @@ ], ) @pytest.mark.parametrize("block_size", [64, 128]) -@pytest.mark.parametrize("use_fwd", [True, False]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("sm_scale", [1.0, 0.123]) @pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"]) @pytest.mark.parametrize("use_segment_ids", [True, False]) @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32]) -@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="Test only runs on GPU.") -def test_fwd_against_ref( +def test_triton_fwd_only_against_ref( batch_size: int, seq_len: int, num_heads: int, per_head_dim: int, block_size: int, - use_fwd: bool, causal: bool, sm_scale: float, attention_bias_type: Literal["2d", "4d", None], @@ -80,37 +75,26 @@ def test_fwd_against_ref( jnp.concatenate([segment_left, segment_right], axis=-1) if use_segment_ids else None ) - # Make sure that it is running on GPU. - assert str(q.devices()) == "{cuda(id=0)}" - - if use_fwd: - - @jax.jit - def impl(q, k, v, bias, segment_ids): - fn = functools.partial( - flash_attention, - block_q=block_size, - block_k=block_size, - causal=causal, - softmax_scale=sm_scale, - ) - out, _ = jax.vjp(fn, q, k, v, bias, segment_ids) - return out - - else: - impl = functools.partial( + @jax.jit + def impl(q, k, v, bias, segment_ids): + fn = functools.partial( flash_attention, block_q=block_size, block_k=block_size, causal=causal, softmax_scale=sm_scale, ) + out, _ = jax.vjp(fn, q, k, v, bias, segment_ids) + return out o = impl(q, k, v, bias, segment_ids) o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale) - chex.assert_trees_all_close(o, o_ref, atol=0.05) + chex.assert_trees_all_close(o, o_ref, atol=0.07) +# We test the flash_attention against the reference mha_reference. +# The outputs should be close in both fp16 and fp32, with a relaxed bound due +# to the numerical difference during operations. 
@pytest.mark.parametrize( "batch_size,num_heads,seq_len,per_head_dim", [ @@ -127,8 +111,7 @@ def impl(q, k, v, bias, segment_ids): @pytest.mark.parametrize("block_size", [64, 128]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32]) -@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="Test only runs on GPU.") -def test_bwd_against_ref( +def test_triton_against_xla_ref( batch_size: int, num_heads: int, seq_len: int, @@ -164,9 +147,6 @@ def test_bwd_against_ref( jnp.concatenate([segment_left, segment_right], axis=-1) if use_segment_ids else None ) - # Make sure that it is running on GPU. - assert str(q.devices()) == "{cuda(id=0)}" - sm_scale = q.shape[-1] ** -0.5 # Compare outputs. @@ -226,7 +206,6 @@ def ref_fn(q, k, v, bias, segment_ids): ) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16]) -@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="Test only runs on GPU.") def test_cudnn_against_triton_ref( batch_size: int, num_heads: int, @@ -244,8 +223,6 @@ def test_cudnn_against_triton_ref( v = jax.random.normal( jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype ) - # Make sure that it is running on GPU. - assert str(q.devices()) == "{cuda(id=0)}" sm_scale = q.shape[-1] ** -0.5 diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py index b5b57dfe..36da8d59 100644 --- a/axlearn/common/flash_attention/layer.py +++ b/axlearn/common/flash_attention/layer.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """FlashAttention layers.""" + from collections.abc import Sequence from typing import Optional @@ -18,8 +19,7 @@ causal_mask, make_segment_mask, ) -from axlearn.common.base_layer import BaseLayer -from axlearn.common.config import ConfigBase, config_class +from axlearn.common.config import config_class from axlearn.common.flash_attention.utils import ( MultiHeadAttentionImpl, flash_attention_implementation, @@ -28,20 +28,6 @@ from axlearn.common.utils import Tensor, with_sharding_constraint -def _check_bias_recursively(cfg: ConfigBase): - """Ensures that `cfg.bias` is set to False for all descendants.""" - - def visit_fn(_, value): - if isinstance(value, BaseLayer.Config) and getattr(value, "bias", False): - raise NotImplementedError("cfg.bias is not yet supported.") - - def enter_fn(_, value, default_kv): - return None if isinstance(value, BaseLayer.Config) and "bias" in value else default_kv - - cfg.visit(visit_fn=visit_fn, enter_fn=enter_fn) - return cfg - - class FlashAttention(GroupedQueryAttention): """FlashAttention layer. @@ -87,7 +73,6 @@ class Config(GroupedQueryAttention.Config): def __init__(self, cfg: Config, *, parent: Module): super().__init__(cfg, parent=parent) cfg = self.config - _check_bias_recursively(cfg) # Bias not supported. if getattr(cfg, "atten_logit_cap", None) is not None: raise NotImplementedError("cfg.atten_logit_cap is not supported.") # TODO(kelvinzou): enable dropout for flash attention. @@ -124,18 +109,13 @@ def _is_mask_fn_used(self): ) def _logit_biases_for_mask( - self, - *, - mode: ForwardMode, - kv_len: int, - query_len: Optional[int] = None, - time_step: Optional[Tensor] = None, + self, *, mode: ForwardMode, query_pos: Tensor, kv_pos: Tensor ) -> Optional[Tensor]: if self._mask_fn is None: return None elif mode == ForwardMode.EXTEND_STEP: # Use biases for decoding. 
- return super()._logit_biases_for_mask(mode=mode, kv_len=kv_len, time_step=time_step) + return super()._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) elif self._is_mask_fn_used(): # Biases are not needed in favor of mask_fn, which is supported in Splash Attention. return None @@ -145,9 +125,7 @@ def _logit_biases_for_mask( else: # Fall back to biases. In the subsequent _compute_attention calls, _mask_fn should not # be used. - return super()._logit_biases_for_mask( - mode=mode, kv_len=kv_len, query_len=query_len, time_step=time_step - ) + return super()._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) def _backend(self): # For compatibility with AOT compilation, we obtain the backend type from physical_mesh. @@ -168,6 +146,17 @@ def _logit_biases_spec(self, attention_logit_biases: Tensor) -> Tensor: spec = PartitionSpec(spec[0], None, *spec[2:]) return spec + def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor: + """Repeats key or value heads dim to match the query. + + TODO(dhwang2): optimize computation like GroupedQueryAttention. + """ + num_head_repeats = self.config.num_heads // key_or_value.shape[-2] + if num_head_repeats == 1: + return key_or_value + # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim]. + return jnp.repeat(key_or_value, num_head_repeats, axis=-2) + def _compute_attention( self, *, diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index 89f5d482..52d1bc4e 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """Tests FlashAttention layers.""" + import math import os from unittest import mock @@ -82,7 +83,14 @@ def _fake_inputs( def _prepare_layers( - *, num_heads, per_head_dim, mesh_axis_names, causal, sliding_window_size, inference=False + *, + num_heads, + per_head_dim, + mesh_axis_names, + causal, + sliding_window_size, + inference=False, + set_layer_bias_recursively=False, ): hidden_dim = num_heads * per_head_dim kwargs = dict( @@ -118,8 +126,8 @@ def _prepare_layers( ref_cfg.set(causal=causal) test_cfg.set(causal=causal) - set_bias_recursively(ref_cfg, False) - set_bias_recursively(test_cfg, False) + set_bias_recursively(ref_cfg, set_layer_bias_recursively) + set_bias_recursively(test_cfg, set_layer_bias_recursively) ref_layer = ref_cfg.set(name="ref").instantiate(parent=None) test_layer = test_cfg.set(name="test").instantiate(parent=None) @@ -406,6 +414,7 @@ def test_forward( ) # TODO(markblee): Test probs. 
self.assertNestedAllClose(ref_out.data, test_out.data, atol=0.05) + jax.clear_backends() @parameterized.product( _TEST_CONFIGS, @@ -414,6 +423,7 @@ def test_forward( sliding_window_size=[None, 4], use_bias=[False, True], use_segment_ids=[False, True], + set_layer_bias_recursively=[False, True], ) def test_backward( self, @@ -428,12 +438,12 @@ def test_backward( sliding_window_size, use_bias, use_segment_ids, + set_layer_bias_recursively, ): if not is_supported_mesh_shape(mesh): pytest.skip(reason=f"Unsupported mesh {mesh}.") if use_segment_ids and query_len_multiplier != 1: pytest.skip("Segment IDs are not supported for Q and K with different lengths.") - if not causal and sliding_window_size is not None: pytest.skip(reason="Sliding window attention must be causal.") @@ -490,17 +500,17 @@ def forward(self, *, query, key, value, attention_logit_biases, segment_ids): layer=GroupedQueryAttention.default_config().set(**kwargs), ) test_cfg = DummyModel.default_config().set( - layer=FlashAttention.default_config() - .set(**kwargs, tpu_block_size=128) - .set( + layer=FlashAttention.default_config().set( + tpu_block_size=128, mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names), output_dim_to_partition_spec=default_output_dim_to_partition_spec( mesh_axis_names ), + **kwargs, ) ) - set_bias_recursively(ref_cfg, False) - set_bias_recursively(test_cfg, False) + set_bias_recursively(ref_cfg, set_layer_bias_recursively) + set_bias_recursively(test_cfg, set_layer_bias_recursively) ref_layer = ref_cfg.set(name="ref").instantiate(parent=None) test_layer = test_cfg.set(name="test").instantiate(parent=None) # pylint: disable-next=protected-access @@ -535,10 +545,19 @@ def loss(params, inputs, layer): ref_value, ref_grads = jax.value_and_grad(loss)(params, ref_inputs, ref_layer) test_value, test_grads = jax.value_and_grad(loss)(params, inputs, test_layer) + + # Have slightly higher diffs with layer bias on GPU. We don't see this on TPU or CPU. + # pylint: disable-next=protected-access + if set_layer_bias_recursively and test_layer.layer._backend() == "gpu": + atol, rtol = 5e-4, 5e-2 + # Can be 1e-5 on x86_64/GPU/TPU, needed to be slightly higher on ARM. - atol = 1e-4 - self.assertNestedAllClose(ref_value, test_value, atol=atol) - self.assertNestedAllClose(ref_grads, test_grads, atol=atol) + else: + atol, rtol = 1e-4, 1e-3 + + self.assertNestedAllClose(ref_value, test_value, atol=atol, rtol=rtol) + self.assertNestedAllClose(ref_grads, test_grads, atol=atol, rtol=rtol) + jax.clear_backends() @parameterized.product(_TEST_CONFIGS, causal=[True], sliding_window_size=[None, 4]) def test_extend_step( @@ -634,7 +653,7 @@ def test_extend_step( initial_state = test_layer.init_states( target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state ) - ref_initial_state = test_layer.init_states( + ref_initial_state = ref_layer.init_states( target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state ) for k in ["key", "value"]: @@ -714,3 +733,4 @@ def test_extend_step( test_out.data, atol=2e-2, ) + jax.clear_backends() diff --git a/axlearn/common/input_tf_data.py b/axlearn/common/input_tf_data.py index 1a9df4a9..1d44ad1b 100644 --- a/axlearn/common/input_tf_data.py +++ b/axlearn/common/input_tf_data.py @@ -429,16 +429,21 @@ def fn() -> tf.data.Dataset: if autotune_ram_budget_gb is not None: autotuned_ds_list = [] - options = tf.data.Options() - options.autotune.enabled = True - options.autotune.ram_budget = int( - # Soft constrain to this many bytes of memory per component. 
-                (autotune_ram_budget_gb / len(source_ds_list))
-                * 1024**3
-            )
-            # Start fetching data on iterator creation.
-            options.experimental_warm_start = True
             for el in source_ds_list:
+                # We need a new Options object for each dataset,
+                # due to limitations on the tfds side.
+                # It seems like only the first dataset gets the options,
+                # while others do not respect autotune.
+                options = tf.data.Options()
+                options.autotune.enabled = True
+                options.autotune.ram_budget = int(
+                    # Soft constrain to this many bytes of memory per component.
+                    (autotune_ram_budget_gb / len(source_ds_list))
+                    * 1024**3
+                )
+                # Start fetching data on iterator creation.
+                options.experimental_warm_start = True
+
                 autotuned_ds_list.append(el.with_options(options))
             source_ds_list = autotuned_ds_list
diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py
index c8beacc2..a2f6d419 100644
--- a/axlearn/common/layers.py
+++ b/axlearn/common/layers.py
@@ -19,7 +19,7 @@
 import enum
 from collections.abc import Sequence
-from typing import Any, Callable, Optional, Union
+from typing import Any, Callable, Literal, Optional, Union
 import chex
 import jax
@@ -2030,7 +2030,8 @@ class StackOverTime(BaseLayer):
     """Stack inputs along the time axis.
     StackOverTime behaves the same as Conv2DWith1DPadding w.r.t. paddings along the time axis.
-    We treat front paddings as valid frames and back paddings as invalid frames.
+    Please refer to the docstring of Conv2DWith1DPadding to understand how the padding works,
+    including "SAME", "VALID", and "CAUSAL" literals. The padding anchor is set to `left padding`.
     """
     @config_class
@@ -2038,9 +2039,13 @@ class Config(BaseLayer.Config):
         """Configures StackOverTime."""
         stride: Required[int] = REQUIRED # Number of frames to stack.
-        # Number of paddings to apply along the time axis. The two integers indicate
-        # leading and trailing padding to add respectively.
-        padding: tuple[int, int] = (0, 0)
+
+        # Number of paddings to apply along the time axis. The two integers specify the amount
+        # of leading and trailing padding, respectively. Alternatively, this can be one of the
+        # convolution padding literals: 'SAME', 'VALID', or 'CAUSAL'.
+        # Note: For backward compatibility, the default is set to VALID, but in most cases,
+        # CAUSAL is more appropriate as it preserves the sequence length.
+        padding: Union[tuple[int, int], Literal["SAME", "VALID", "CAUSAL"]] = "VALID"
     def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]:
         """Stacks stride number of frames into one frame along the time axis.
@@ -2060,11 +2065,16 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]:
         cfg = self.config
         if cfg.stride <= 1:
             raise ValueError(f"stride should be greater than 1, but got {cfg.stride}.")
-        inputs = jnp.pad(inputs, ((0, 0), cfg.padding, (0, 0)), constant_values=0)
-        # Front paddings are valid frames.
-        paddings = jnp.pad(paddings, ((0, 0), (cfg.padding[0], 0)), constant_values=0)
-        # Back paddings are invalid frames.
-        paddings = jnp.pad(paddings, ((0, 0), (0, cfg.padding[1])), constant_values=1)
+
+        # Zero out padded inputs so the last partial frame does not include padding values.
+ inputs = inputs * (1 - paddings)[:, :, None] + + padding = cfg.padding + if isinstance(padding, str): + padding = conv_explicit_padding( + window=(cfg.stride,), strides=(cfg.stride,), padding=padding + )[0] + inputs = jnp.pad(inputs, ((0, 0), padding, (0, 0)), constant_values=0) batch_size, seq_len, input_dim = inputs.shape output_length = seq_len // cfg.stride @@ -2072,9 +2082,8 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: # Stack inputs over the time dimension. stacked_inputs = jnp.reshape(inputs[:, : output_length * cfg.stride, :], new_shape) # An output frame is padding if at least one of the stacked input frames is padding. - stacked_paddings = jnp.max( - jnp.reshape(paddings[:, : output_length * cfg.stride], [-1, output_length, cfg.stride]), - axis=-1, + stacked_paddings = compute_conv_paddings( + paddings, window=cfg.stride, stride=cfg.stride, conv_padding=(padding,) ) stacked_inputs = stacked_inputs * (1 - stacked_paddings)[:, :, None] return stacked_inputs, stacked_paddings @@ -2092,8 +2101,13 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti """ cfg = self.config batch_size, seq_len, input_dim = input_shape - output_length = (seq_len + sum(cfg.padding)) // cfg.stride if seq_len is not None else None - return [batch_size, output_length, input_dim * cfg.stride] + padding = cfg.padding + if isinstance(padding, tuple): + padding = (padding,) + out_shape = conv_output_shape( + [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding + ) + return [batch_size, *out_shape, input_dim * cfg.stride] class MultiLinear(BaseLayer): diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index eb0943b9..30dba474 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -2000,8 +2000,8 @@ def test_drop_tokens(self, drop_rate, num_cls_tokens): ( 3, (0, 0), - [[[1, 1, 2, 2, 3, 3]], [[0, 0, 0, 0, 0, 0]]], - [[0], [1]], + [[[1, 1, 2, 2, 3, 3]], [[7, 7, 8, 8, 0, 0]]], + [[0], [0]], ), ( 3, @@ -2066,18 +2066,14 @@ def test_stack_over_time_data_change(self): ) output_shape = layer.output_shape(input_shape=inputs.shape) self.assertAllEqual(outputs.shape, output_shape) - self.assertAllEqual(np.array([4, 7], dtype=np.float32), np.sum(1 - output_paddings, axis=1)) - self.assertAllClose( - np.sum(inputs**2, (1, 2)), - np.sum(outputs**2, (1, 2)) + np.array([np.sum(inputs[0][8] ** 2), 0.0]), - ) + self.assertAllEqual(np.array([5, 7], dtype=np.float32), np.sum(1 - output_paddings, axis=1)) + self.assertAllClose(np.sum(inputs**2, (1, 2)), np.sum(outputs**2, (1, 2))) - @parameterized.product(stride=(2, 3, 4), pad=((0, 0), (1, 1), (2, 0))) + @parameterized.product(stride=(2, 3, 4), pad=("VALID", "SAME", "CAUSAL")) def test_stack_consistent_outputs(self, stride, pad): """Tests that StackOverTime has consistent outputs under different padding lengths.""" batch_size, input_dim = 2, 1 input_length = 7 - expected_output_length = (input_length + pad[0]) // stride layer: StackOverTime = ( StackOverTime.default_config() .set( @@ -2087,12 +2083,13 @@ def test_stack_consistent_outputs(self, stride, pad): ) .instantiate(parent=None) ) + expected_output_length = layer.output_shape(input_shape=[1, input_length, 1])[1] layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) for ll in range(4, 11): # Batch with another example of length ll. 
length = max(input_length, ll) inputs = jnp.ones([batch_size, length, input_dim]) - paddings = jnp.arange(length)[None, :] >= jnp.array([7, ll])[:, None] + paddings = jnp.arange(length)[None, :] >= jnp.array([input_length, ll])[:, None] (outputs, output_paddings), _ = F( layer, inputs=dict(inputs=inputs, paddings=paddings), @@ -2102,7 +2099,8 @@ def test_stack_consistent_outputs(self, stride, pad): ) output_shape = layer.output_shape(input_shape=inputs.shape) self.assertAllEqual(outputs.shape, output_shape) - self.assertEqual(expected_output_length, np.sum(1 - output_paddings, axis=1)[0]) + if pad != "VALID": # VALID doesn't preserve length. + self.assertEqual(expected_output_length, np.sum(1 - output_paddings, axis=1)[0]) @parameterized.parameters(((0, 1), (0, 0)), ((1, 1), (3, 0)), ((1, 1), (0, 3))) def test_stack_vs_conv2d_output_len_match(self, conv_padding, stack_padding): diff --git a/axlearn/common/learner_test.py b/axlearn/common/learner_test.py index 6ca7e3b2..1a11a479 100644 --- a/axlearn/common/learner_test.py +++ b/axlearn/common/learner_test.py @@ -1219,7 +1219,7 @@ def test_learner_masking(test_self): pre-existing `CompositeLearner` implementation. """ - updates = axlearn.common.update_transformation_test.mock_updates() + updates = axlearn.common.update_transformation_test.mock_updates(state_param_none=False) param_keys = updates.opt_params.keys() state_keys = updates.inplace_updates.keys() diff --git a/axlearn/common/measurement.py b/axlearn/common/measurement.py index ee0e83da..d17072b1 100644 --- a/axlearn/common/measurement.py +++ b/axlearn/common/measurement.py @@ -46,13 +46,9 @@ class Config(Configurable.Config): Attributes: name: Name of the recorder. - upload_dir: Directory to store metrics for the monitor. - upload_interval: Time interval (seconds) for monitoring uploads. """ name: Required[str] = REQUIRED - upload_dir: Required[str] = REQUIRED - upload_interval: Required[int] = REQUIRED @classmethod def from_flags(cls, fv: Optional[flags.FlagValues]) -> "Recorder": diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index bd507c25..70517ad5 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -544,11 +544,13 @@ def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: Nest lr_scale = lr**learning_rate_exponent param_scales = _weight_decay_scales(params, per_param_scale=per_param_scale) + f = lambda g, p, s: g + weight_decay * lr_scale * p.value * s updates = jax.tree.map( - lambda g, p, s: g + weight_decay * lr_scale * p.value * s, + lambda x, y, z: None if x is None else f(x, y, z), updates, params, param_scales, + is_leaf=lambda x: x is None, ) if learning_rate_exponent is None: updated_state = state @@ -1882,9 +1884,10 @@ def _smoothed_updates( # First compute raw updates. raw_updates, pps_tree = _split_update_results( jax.tree.map( - lambda g, s: _raw_updates(grad=g, pps=s), + lambda g, s: None if g is None else _raw_updates(grad=g, pps=s), grads, state.pps, + is_leaf=lambda x: x is None, ) ) # Clip raw updates if necessary. @@ -1966,7 +1969,12 @@ def _update2(u: Tensor, param: OptParam): context.add_summary("weight_decay_rate", weight_decay * schedule_scale) return -schedule_scale * updates_with_wd - updates2 = jax.tree.map(lambda u, p: _update2(u, param=p), updates, params) + updates2 = jax.tree.map( + lambda u, p: None if u is None else _update2(u, param=p), + updates, + params, + is_leaf=lambda x: x is None, + ) return updates2, optax.safe_int32_increment(step) # Stage 1. 
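For context on the `is_leaf=lambda x: x is None` pattern threaded through the `jax.tree.map` calls in the optimizer changes above: it keeps `None` update leaves as leaves, so they pass through untouched instead of being treated as empty subtrees, which would otherwise raise a tree-structure mismatch against the parameter tree. A minimal, standalone sketch of the pattern; the names and values below are illustrative only and not part of the patch:

import jax
import jax.numpy as jnp

# `None` marks a leaf that receives no update (e.g., a frozen parameter).
updates = {"w": jnp.ones(3), "frozen": None}
params = {"w": jnp.zeros(3), "frozen": jnp.zeros(3)}

# Without `is_leaf`, `None` is treated as an empty subtree and jax.tree.map raises a
# structure mismatch, because `params` has a real array leaf at the same position.
new_updates = jax.tree.map(
    lambda u, p: None if u is None else u + 0.1 * p,
    updates,
    params,
    is_leaf=lambda x: x is None,
)
assert new_updates["frozen"] is None
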
diff --git a/axlearn/common/quantized_dot_general/layers.py b/axlearn/common/quantized_dot_general/layers.py index 2935a5d1..c0983e7a 100644 --- a/axlearn/common/quantized_dot_general/layers.py +++ b/axlearn/common/quantized_dot_general/layers.py @@ -26,12 +26,10 @@ import jax from absl import logging -from aqt.jax.v2 import aqt_dot_general -from aqt.jax.v2 import utils as aqt_utils +from aqt.jax.v2.config import DotGeneral, set_context from jax import numpy as jnp from jax.lax import DotDimensionNumbers, Precision from jax.typing import DTypeLike -from typing_extensions import Protocol from axlearn.common.base_layer import BaseLayer from axlearn.common.config import config_class @@ -65,26 +63,6 @@ class ClippingChoice(Enum): OUTPUT_ACTIVATION = 1 -class AQTDotGeneralType(Protocol): - """Typedef for AQT DotGeneral functions. - - Adds context kwarg containing prng key comparing to jax.lax.dot_general. - - """ - - def __call__( - self, - lhs: Tensor, - rhs: Tensor, - *, - dimension_numbers: DotDimensionNumbers, - precision: PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None, - context: aqt_utils.Context = aqt_utils.Context(key=None, train_step=None), - ) -> Tensor: - ... - - class QuantizedDotGeneral(BaseLayer): """Hardware accelerated quantized dot general layer. @@ -132,13 +110,9 @@ def __init__(self, cfg: Config, *, parent: Module): # for anything, we just need to init an aqt_dot_general function # with recommended configs. # Dot general with default config. - self.lhs_act_dot_general: AQTDotGeneralType = aqt_dot_general.make_dot_general( - lhs_activation_aqt_config() - ) + self.lhs_act_dot_general: DotGeneral = lhs_activation_aqt_config() # Dot general with mirrored config where lhs and rhs are swapped. - self.rhs_act_dot_general: AQTDotGeneralType = aqt_dot_general.make_dot_general( - rhs_activation_aqt_config() - ) + self.rhs_act_dot_general: DotGeneral = rhs_activation_aqt_config() elif cfg.quantization_type == DotGeneralQuantizationType.FP_8: # TODO(jiarui): Is there a way to identify if we are running on H100? if jax.default_backend() != "gpu": @@ -203,18 +177,19 @@ def _dot_general_maybe_quantized( elif cfg.quantization_type == DotGeneralQuantizationType.INT_8: # Provide prng_key and call self.aqt_dot_general. 
if lhs_is_activation: - fn: AQTDotGeneralType = self.lhs_act_dot_general + fn: DotGeneral = self.lhs_act_dot_general else: fn = self.rhs_act_dot_general + # Pass in prng_key for stochastic rounding + set_context( + cfg=fn, key=prng_key if prng_key is not None else self.prng_key, train_step=None + ) return fn( lhs, rhs, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, - context=aqt_dot_general.Context( - key=prng_key if prng_key is not None else self.prng_key, train_step=None - ), ) elif cfg.quantization_type == DotGeneralQuantizationType.FP_8: raise NotImplementedError("Fp8 quantization on GPU is not yet implemented") diff --git a/axlearn/common/quantized_dot_general/layers_test.py b/axlearn/common/quantized_dot_general/layers_test.py index 6aea0453..134c01c3 100644 --- a/axlearn/common/quantized_dot_general/layers_test.py +++ b/axlearn/common/quantized_dot_general/layers_test.py @@ -22,12 +22,26 @@ class TestQuantizedDotGeneral(TestCase): """Tests QuantizedDotGeneral layer.""" - # TODO(jiarui): Add TPU / GPU tests once they are available in CI - @parameterized.product(b=[2, 16], d=[4, 32], h=[8, 64]) - def test_einsum_maybe_quantized(self, b, d, h): + # TODO(jiarui): Assert output for INT8 once TPU tests are available in CI + @parameterized.product( + b=[2, 16], + d=[4, 32], + h=[8, 64], + quantization_type_and_assert_output=[ + (None, True), # Test bf16, ensure parity on output + ( + DotGeneralQuantizationType.INT_8, + False, + ), # Test INT8, ignore output parity since this is executing on CPU instead of TPU + ], + ) + def test_einsum_maybe_quantized(self, b, d, h, quantization_type_and_assert_output): + quantization_type, assert_output = quantization_type_and_assert_output # When config is None, maybe_quantized_einsum should reduce to einsum with Mesh(mesh_utils.create_device_mesh((1, 1)), ("data", "fsdp")): - quantized_dot_general_cfg = QuantizedDotGeneral.default_config() + quantized_dot_general_cfg = QuantizedDotGeneral.default_config().set( + quantization_type=quantization_type + ) quantized_dot_general_layer = quantized_dot_general_cfg.set( name="quantized_dot_general_layer" ).instantiate(parent=None) @@ -56,7 +70,8 @@ def test_einsum_maybe_quantized(self, b, d, h): method="einsum_maybe_quantized", ) reference = jnp.einsum(*inputs) - self.assertNestedAllClose(output, reference) + if assert_output: + self.assertNestedAllClose(output, reference) def test_set_quantized_dot_general_recursively(self): cfg = Decoder.default_config() diff --git a/axlearn/common/quantizer.py b/axlearn/common/quantizer.py index 199b2b9f..fdc9fe94 100644 --- a/axlearn/common/quantizer.py +++ b/axlearn/common/quantizer.py @@ -105,8 +105,6 @@ class Config(BaseLayer.Config): class Output(NamedTuple): # [..., num_codebooks]. ids: Tensor - # [..., num_codebooks, codebook_size]. - onehots: Tensor # [..., num_codebooks, codebook_dim]. quantized_vectors: Tensor # Scalar of quantizer loss. @@ -139,33 +137,55 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> Output: Returns: BaseQuantizer.Output. + * ids: Tensor [..., num_codebooks]. + * quantized_vectors: Tensor [..., num_codebooks, codebook_dim]. """ raise NotImplementedError(type(self)) + def lookup(self, ids: Tensor) -> Output: + """Codebook look up with ids. + + Args: + ids: integer tensor of shape [..., num_codebooks] with values + in range [0, codebook_size). + + Returns: + BaseQuantizer.Output + * ids: Tensor [..., num_codebooks]. 
+ * quantized_vectors: Tensor [..., num_codebooks, codebook_dim]. + + Raises: + NotImplementedError: if ids.ndim > 11. + """ + return _lookup(ids=ids, codebook=self.parameters["codebook"]) + def _lookup(*, ids: Tensor, codebook: Tensor) -> BaseQuantizer.Output: """Codebook look up with ids. Args: - ids: integer tensor of shape [batch_size, seq_len, num_codebooks] with values + ids: integer tensor of shape [..., num_codebooks] with values in range [0, codebook_size). codebook: Tensor of shape [codebook_size, num_codebooks, codebook_dim]. Returns: - BaseQuantizer.Output. + BaseQuantizer.Output + * ids: Tensor [..., num_codebooks]. + * quantized_vectors: Tensor [..., num_codebooks, codebook_dim]. Raises: NotImplementedError: if ids.ndim > 11. """ if ids.ndim - 1 > len(_einsum_dims): raise NotImplementedError(ids.shape) - # [..., num_codebooks, vocab_size]. - onehots = jax.nn.one_hot(ids, num_classes=codebook.shape[0], axis=-1, dtype=codebook.dtype) - batch_dims = _einsum_dims[: onehots.ndim - 2] - quantized_vectors = jnp.einsum(f"{batch_dims}gv,vgh->{batch_dims}gh", onehots, codebook) + + # [..., num_codebooks] + g_index = jnp.expand_dims(jnp.arange(ids.shape[-1]), axis=tuple(range(ids.ndim - 1))) + # codebook: [codebook_size, num_codebooks, codebook_dim], ids: [..., num_codebooks] + # -> [..., num_codebooks, codebook_dim] + quantized_vectors = codebook[ids, g_index] return BaseQuantizer.Output( ids=ids, - onehots=onehots, quantized_vectors=quantized_vectors, ) @@ -234,21 +254,24 @@ def _apply_paddings(*, outputs: BaseQuantizer.Output, paddings: Tensor) -> BaseQ Returns: padded_outputs: BaseQuantizer.Output. """ + # ids are padded with -1. - ids = outputs.ids * (1 - paddings)[:, :, None] + (-1) * paddings[:, :, None] - onehots = outputs.onehots * (1 - paddings)[:, :, None, None] + ids_paddings = paddings[:, :, None].astype(outputs.ids.dtype) + ids = outputs.ids * (1 - ids_paddings) + (-1) * ids_paddings quantized_vectors = outputs.quantized_vectors * (1 - paddings)[:, :, None, None] return BaseQuantizer.Output( ids=ids, - onehots=onehots, quantized_vectors=quantized_vectors, loss=outputs.loss, ) -def _add_codebook_summaries( - *, context: InvocationContext, outputs: BaseQuantizer.Output, paddings: Tensor -): +def _ids_to_onehots(ids: Tensor, *, codebook_size: int, dtype: jnp.dtype) -> Tensor: + # [..., num_codebooks, codebook_size]. + return jax.nn.one_hot(ids, num_classes=codebook_size, axis=-1, dtype=dtype) + + +def _add_codebook_summaries(*, context: InvocationContext, onehots: Tensor, paddings: Tensor): """Helper function to compute codebook distribution statistics and add to summaries. The statistics are from all frames, not only on those masked frames in self-supervised training. @@ -256,11 +279,11 @@ def _add_codebook_summaries( Args: context: Module invocation context to add summaries to. - outputs: BaseQuantizer.Output. + onehots: onehot of BaseQuantizer.Output.ids. paddings: 0/1 tensor of shape [batch_size, seq_len], where 0 is valid position. 
""" - coverage = compute_code_coverage(onehots=outputs.onehots, paddings=paddings) - pplx, entropy = compute_code_pplx(onehots=outputs.onehots, paddings=paddings) + coverage = compute_code_coverage(onehots=onehots, paddings=paddings) + pplx, entropy = compute_code_pplx(onehots=onehots, paddings=paddings) batch_size = paddings.shape[0] num_frames = jnp.sum(1 - paddings) @@ -368,18 +391,19 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> BaseQuantizer.Output: q_outputs = _apply_paddings(outputs=q_outputs, paddings=paddings) # Best-rq freezes the codebook. ids = jax.lax.stop_gradient(q_outputs.ids) - onehots = jax.lax.stop_gradient(q_outputs.onehots) quantized_vectors = jax.lax.stop_gradient(q_outputs.quantized_vectors) outputs = self.Output( # [batch_size, seq_len, num_codebooks]. ids=ids, - # [batch_size, seq_len, num_codebooks, codebook_size]. - onehots=onehots, # [batch_size, seq_len, num_codebooks, codebook_dim]. quantized_vectors=quantized_vectors, ) - _add_codebook_summaries(context=current_context(), outputs=outputs, paddings=paddings) + + onehots = _ids_to_onehots( + outputs.ids, codebook_size=cfg.codebook_size, dtype=paddings.dtype + ) + _add_codebook_summaries(context=current_context(), onehots=onehots, paddings=paddings) return outputs @@ -519,15 +543,16 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> BaseQuantizer.Output: outputs = self.Output( # [batch_size, seq_len, num_codebooks]. ids=quantized_inputs.ids, - # [batch_size, seq_len, num_codebooks, vocab_size]. - onehots=quantized_inputs.onehots, # [batch_size, seq_len, num_codebooks, codebook_dim]. quantized_vectors=jnp.reshape( quantized_vectors, [batch_size, seq_len, cfg.num_codebooks, cfg.codebook_dim] ), loss=total_loss, ) - _add_codebook_summaries(context=current_context(), outputs=outputs, paddings=paddings) + onehots = _ids_to_onehots( + outputs.ids, codebook_size=cfg.codebook_size, dtype=paddings.dtype + ) + _add_codebook_summaries(context=current_context(), onehots=onehots, paddings=paddings) return outputs @@ -614,17 +639,17 @@ def forward( ids = jnp.argmax(logits, axis=-1) if not self.is_training: - outputs = _lookup(ids=ids, codebook=self.parameters["codebook"]) + outputs = self.lookup(ids=ids) outputs = _apply_paddings(outputs=outputs, paddings=paddings) else: # [batch_size, seq_len, 1]. - mask = (1 - paddings)[:, :, None] + mask = (1 - paddings)[:, :, None].astype(ids.dtype) ids = ids * mask + (-1) * (1 - mask) + # TODO(dhwang2): optimize memory by scan for long context training. # [batch_size, seq_len, num_codebooks, vocab_size]. - onehots = jax.nn.one_hot( - ids, num_classes=cfg.codebook_size, axis=-1, dtype=inputs.dtype - ) + onehots = _ids_to_onehots(ids, codebook_size=cfg.codebook_size, dtype=inputs.dtype) # We need this to stop gradients on the padded frames. + mask = mask.astype(inputs.dtype) onehots = onehots * mask[:, :, :, None] # [batch_size, seq_len, num_codebooks, vocab_size]. y_soft = jax.nn.softmax(logits, axis=-1) @@ -640,13 +665,14 @@ def forward( outputs = self.Output( # [batch_size, seq_len, num_codebooks]. ids=ids, - # [batch_size, seq_len, num_codebooks, vocab_size]. - onehots=onehots, # [batch_size, seq_len, num_codebooks, codebook_dim]. 
quantized_vectors=quantized_vectors, ) - _add_codebook_summaries(context=current_context(), outputs=outputs, paddings=paddings) + onehots = _ids_to_onehots( + outputs.ids, codebook_size=cfg.codebook_size, dtype=paddings.dtype + ) + _add_codebook_summaries(context=current_context(), onehots=onehots, paddings=paddings) if self.is_training: self.add_module_output("probs", y_soft) self.add_summary("codebook/temperature_schedule_step", self.parameters["step"]) diff --git a/axlearn/common/quantizer_test.py b/axlearn/common/quantizer_test.py index 893dc9da..af2ed4a2 100644 --- a/axlearn/common/quantizer_test.py +++ b/axlearn/common/quantizer_test.py @@ -29,6 +29,7 @@ KmeansVectorQuantizer, RandomVectorQuantizer, SimilarityMetric, + _ids_to_onehots, compute_code_coverage, compute_code_pplx, quantize_by_nearest_neighbor, @@ -86,12 +87,13 @@ def test_quantize(self, num_groups, input_mean, metric): inputs=inputs, codebook=codebook, metric=metric ) # Compute codebook metrics. - coverage = compute_code_coverage(onehots=q_outputs.onehots, paddings=paddings) - pplx, entropy = compute_code_pplx(onehots=q_outputs.onehots, paddings=paddings) + onehots = _ids_to_onehots(q_outputs.ids, codebook_size=vocab_size, dtype=paddings.dtype) + coverage = compute_code_coverage(onehots=onehots, paddings=paddings) + pplx, entropy = compute_code_pplx(onehots=onehots, paddings=paddings) # Check shapes. self.assertEqual(q_outputs.ids.shape, (batch_size, seq_len, num_groups)) - self.assertEqual(q_outputs.onehots.shape, (batch_size, seq_len, num_groups, vocab_size)) + self.assertEqual(onehots.shape, (batch_size, seq_len, num_groups, vocab_size)) self.assertEqual( q_outputs.quantized_vectors.shape, (batch_size, seq_len, num_groups, codebook_dim) ) @@ -314,7 +316,7 @@ def test_forward( np.sum(layer_params["codebook"] ** 2), expected_values[batch_size][normalize_codebook]["codebook"], atol=1e-6, - rtol=1e-6, + rtol=2e-6, ) np.random.seed(2022) @@ -332,7 +334,6 @@ def test_forward( q_outputs.quantized_vectors.shape, ) self.assertEqual((batch_size, seq_len, num_groups), q_outputs.ids.shape) - self.assertEqual((batch_size, seq_len, num_groups, vocab_size), q_outputs.onehots.shape) assert_allclose( np.sum( jnp.reshape( @@ -349,12 +350,6 @@ def test_forward( atol=1e-6, rtol=1e-6, ) - assert_allclose( - np.sum(q_outputs.onehots), - expected_values[batch_size][normalize_codebook]["onehots"], - atol=1e-6, - rtol=1e-6, - ) self.assertEqual( output_collections.summaries["codebook/num_frames"].mean, jnp.sum(1 - paddings) / batch_size, @@ -409,9 +404,8 @@ def _loss(params, inputs, paddings, layer=layer): + o_col.summaries["codebook/entropy"].mean ) - np.random.seed(2000) - inputs = np.random.rand(batch_size, seq_len, input_dim).astype(np.float32) - paddings = np.zeros((batch_size, seq_len)).astype(np.float32) + inputs = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, seq_len, input_dim)) + paddings = jnp.zeros((batch_size, seq_len)) _, (grad_params, grad_inputs) = jax.value_and_grad(_loss, argnums=(0, 1), has_aux=False)( layer_params, jnp.asarray(inputs), jnp.asarray(paddings) @@ -419,6 +413,41 @@ def _loss(params, inputs, paddings, layer=layer): self.assertNestedAllClose(grad_params, jax.tree.map(jnp.zeros_like, layer_params)) assert_allclose(grad_inputs, jnp.zeros_like(inputs), atol=1e-6, rtol=1e-6) + def test_lookup(self): + batch_size, seq_len, input_dim = 2, 4, 20 + dim_from_all_codebooks, vocab_size, num_groups = 32, 4, 2 + cfg = RandomVectorQuantizer.default_config().set( + name="test", + input_dim=input_dim, + 
codebook_dim=dim_from_all_codebooks // num_groups, + codebook_size=vocab_size, + num_codebooks=num_groups, + ) + layer: RandomVectorQuantizer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(1)) + inputs = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, seq_len, input_dim)) + paddings = jnp.zeros((batch_size, seq_len)) + outputs, _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + ) + self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) + + lookup_outputs, _ = F( + layer, + inputs=dict(ids=outputs.ids), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + method="lookup", + ) + quantized_vectors = lookup_outputs.quantized_vectors * (1 - paddings)[:, :, None, None] + self.assertNestedAllClose(quantized_vectors, outputs.quantized_vectors) + class KmeansVectorQuantizerTest(TestCase): @parameterized.product(num_groups=(1, 2), input_mean=(0.0, -0.5)) @@ -472,7 +501,7 @@ def test_forward(self, num_groups, input_mean): outputs.quantized_vectors.shape, ) self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) - self.assertEqual((batch_size, seq_len, num_groups, vocab_size), outputs.onehots.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) assert_allclose( expected_outputs[num_groups][input_mean][0], @@ -619,12 +648,6 @@ def _loss(params, inputs, paddings, layer=layer): atol=1e-6, rtol=1e-6, ) - assert_allclose( - outputs.onehots * paddings[:, :, None, None], - jnp.zeros_like(outputs.onehots), - atol=1e-6, - rtol=1e-6, - ) assert_allclose( outputs.quantized_vectors * paddings[:, :, None, None], jnp.zeros_like(outputs.quantized_vectors), @@ -653,9 +676,58 @@ def _loss(params, inputs, paddings, layer=layer): # [batch_size, seq_len, num_groups, dim]. # Gradient w.r.t codebook comes from kmeans_loss. grad_kmeans = -jnp.reshape(grad_l2_loss, [batch_size, seq_len, num_groups, codebook_dim]) - expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_kmeans, outputs.onehots) + onehots = _ids_to_onehots(outputs.ids, codebook_size=vocab_size, dtype=grad_kmeans.dtype) + expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_kmeans, onehots) self.assertNestedAllClose(grad_params, dict(codebook=expected_grad_codebook)) + def test_lookup(self): + num_groups, input_mean = 2, -0.5 + vocab_size, dim_from_all_codebooks = 4, 4 + codebook_dim = dim_from_all_codebooks // num_groups + layer: KmeansVectorQuantizer = ( + KmeansVectorQuantizer.default_config() + .set( + name="test", + codebook_dim=codebook_dim, + codebook_size=vocab_size, + num_codebooks=num_groups, + beta=0.1, + ) + .instantiate(parent=None) + ) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + # [vocab_size, num_codebooks, codebook_dim]. 
+ layer_params["codebook"] = jnp.reshape(_CODE_BOOK, [vocab_size, num_groups, codebook_dim]) + batch_size, seq_len = 2, 4 + np.random.seed(2021) + inputs = ( + np.random.rand(batch_size, seq_len, dim_from_all_codebooks).astype(np.float32) + + input_mean + ) + paddings = jnp.arange(seq_len)[None, :] >= jnp.array([2, 3])[:, None] + inputs = inputs * (1 - paddings)[:, :, None] + outputs, _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=True, + prng_key=jax.random.PRNGKey(1), + state=layer_params, + drop_output_collections=[], + ) + self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) + + lookup_outputs, _ = F( + layer, + inputs=dict(ids=outputs.ids), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + method="lookup", + ) + quantized_vectors = lookup_outputs.quantized_vectors * (1 - paddings)[:, :, None, None] + self.assertNestedAllClose(quantized_vectors, outputs.quantized_vectors) + class GumbelSoftmaxVectorQuantizerTest(TestCase): @parameterized.parameters(True, False) @@ -706,12 +778,6 @@ def test_forward(self, is_training): atol=1e-6, rtol=1e-6, ) - assert_allclose( - outputs.onehots * paddings[:, :, None, None], - jnp.zeros_like(outputs.onehots), - atol=1e-6, - rtol=1e-6, - ) assert_allclose( outputs.quantized_vectors * paddings[:, :, None, None], jnp.zeros_like(outputs.quantized_vectors), @@ -849,9 +915,64 @@ def _loss(params, inputs, paddings, layer=layer): # [batch_size, seq_len, num_groups, dim]. # Gradient w.r.t codebook. - expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_q_vecs, outputs.onehots) + onehots = _ids_to_onehots(outputs.ids, codebook_size=vocab_size, dtype=grad_q_vecs.dtype) + expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_q_vecs, onehots) assert_allclose(grad_params["codebook"], expected_grad_codebook, atol=1e-6, rtol=1e-6) + def test_lookup(self): + dim_from_all_codebooks, vocab_size, num_groups = 15, 5, 3 + input_dim = 10 + step = 5 + begin_step, begin_value, end_step, end_value = 0, 21, 10, 1 + codebook_dim = dim_from_all_codebooks // num_groups + layer: GumbelSoftmaxVectorQuantizer = ( + GumbelSoftmaxVectorQuantizer.default_config() + .set( + name="test", + input_dim=input_dim, + codebook_dim=codebook_dim, + codebook_size=vocab_size, + num_codebooks=num_groups, + temperature_schedule=schedule.polynomial( + begin_step=begin_step, + begin_value=begin_value, + end_step=end_step, + end_value=end_value, + ), + ) + .instantiate(parent=None) + ) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + layer_params["step"] = step + batch_size, seq_len = 2, 4 + np.random.seed(2021) + inputs = np.random.rand(batch_size, seq_len, input_dim).astype(np.float32) + paddings = np.array( + np.arange(seq_len)[None, :] >= np.array([2, 3])[:, None], dtype=np.float32 + ) + inputs = inputs * (1 - paddings)[:, :, None] + outputs, _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=True, + prng_key=jax.random.PRNGKey(1), + state=layer_params, + drop_output_collections=[], + ) + self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) + + lookup_outputs, _ = F( + layer, + inputs=dict(ids=outputs.ids), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + method="lookup", + ) + quantized_vectors = lookup_outputs.quantized_vectors * (1 - paddings)[:, :, None, None] + 
self.assertNestedAllClose(quantized_vectors, outputs.quantized_vectors) + if __name__ == "__main__": absltest.main() diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index e75d04ed..4ffcb810 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -755,17 +755,32 @@ def _init_state(prng_key: Tensor) -> TrainerState: learner=initialized_trainer_state.learner, ) - def _log_trainer_state_stats(self): + def _log_trainer_state_stats(self) -> str: total_num_params = count_model_params(self._trainer_state.model) - self._step_log("Total number of model params: %s", f"{total_num_params:,}") + analysis_logs = [] + + def _step_log(msg, *args, **kwargs): + self._step_log(msg, *args, **kwargs) + analysis_logs.append(msg % args) + + _step_log("##################### Model analysis #####################\n") + _step_log("## Parameters:") + fmt = "%10d %-20s %s" + flatten_name_and_spec = flatten_items(self._model_param_specs) + for name, spec in flatten_name_and_spec: + spec_size = np.prod(spec.shape) + _step_log(fmt, spec_size, spec.shape, name) + + _step_log("Total number of model params: %s", f"{total_num_params:,}") self.summary_writer(0, {"num_model_params": total_num_params}) + _step_log("\n## Trainer States:") # Training state size. total_state_bytes = 0 total_sharded_state_bytes = 0 state_spec_map = dict(utils.flatten_items(self.trainer_state_specs)) for path, value in utils.flatten_items(self._trainer_state): - self._step_log( + _step_log( "State: %s=%s(%s) mesh_axes=%s", path, value.dtype, @@ -782,7 +797,7 @@ def _log_trainer_state_stats(self): else: max_sharded_state_gb = total_sharded_state_gb - self._step_log( + _step_log( "Training state size: %.2f GiB\n" "Training state size (partitioned): %.2f GiB\n" "Max training state size (partitioned): %.2f GiB", @@ -791,6 +806,9 @@ def _log_trainer_state_stats(self): max_sharded_state_gb, ) + _step_log("\n##########################################################") + return "\n".join(analysis_logs) + def _prepare_training(self, prng_key: Tensor) -> bool: """Prepares training. @@ -822,12 +840,16 @@ def _prepare_training(self, prng_key: Tensor) -> bool: # Note the default checkpointer and evaler do nothing at step 0 with min_step=1. self.save_checkpoint(self._run_eval()) - # Log trainer state tree. - if jax.process_index() == 0: - with fs.open(os.path.join(cfg.dir, "trainer_state_tree.txt"), "w") as f: - f.write(str(jax.tree_util.tree_structure(self._trainer_state))) + model_analysis = self._log_trainer_state_stats() + + # Log trainer state tree. + if not self.step and jax.process_index() == 0: + with fs.open(os.path.join(cfg.dir, "trainer_state_tree.txt"), "w") as f: + f.write(str(jax.tree_util.tree_structure(self._trainer_state))) + + with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f: + f.write(model_analysis) - self._log_trainer_state_stats() self._maybe_record_event(measurement.Event.END_TRAINING_PREPARATION) # Log config. 
self.summary_writer.log_config(cfg, step=self.step) diff --git a/axlearn/common/trainer_test.py b/axlearn/common/trainer_test.py index 149864b8..39177429 100644 --- a/axlearn/common/trainer_test.py +++ b/axlearn/common/trainer_test.py @@ -417,6 +417,11 @@ def test_trainer( with open(os.path.join(cfg.dir, "trainer_state_tree.txt"), encoding="utf-8") as f: self.assertStartsWith(f.read(), "PyTreeDef(CustomNode(namedtuple[TrainerState], [*, ") + with open(os.path.join(cfg.dir, "model_analysis.txt"), encoding="utf-8") as f: + self.assertStartsWith( + f.read(), "##################### Model analysis #####################" + ) + if start_trace_steps: trace_dir = os.path.join(cfg.dir, "summaries", "train_train", "plugins", "profile") profile_files = [] @@ -856,6 +861,7 @@ def test_run_builder(self, restore_from_builder: bool): first_output = trainer.run(prng_key=jax.random.PRNGKey(123)) assert os.path.exists(os.path.join(cfg.dir, "trainer_state_tree.txt")) + assert os.path.exists(os.path.join(cfg.dir, "model_analysis.txt")) # Make sure checkpoint exists. trainer2: SpmdTrainer = cfg.instantiate(parent=None) with trainer2.mesh(): @@ -931,6 +937,7 @@ def fn(*, step: int, evaler_summaries: dict[str, Any]): trainer.run(prng_key=jax.random.PRNGKey(123)) assert os.path.exists(os.path.join(cfg.dir, "trainer_state_tree.txt")) + assert os.path.exists(os.path.join(cfg.dir, "model_analysis.txt")) trainer2: SpmdTrainer = cfg.clone(save_input_iterator=restore_input_iterator).instantiate( parent=None ) @@ -972,6 +979,7 @@ def test_last_step_checkpoint_policy(self): trainer.run(prng_key=jax.random.PRNGKey(123)) assert os.path.exists(os.path.join(cfg.dir, "trainer_state_tree.txt")) + assert os.path.exists(os.path.join(cfg.dir, "model_analysis.txt")) trainer2: SpmdTrainer = cfg.instantiate(parent=None) with trainer2.mesh(): # We should have checkpointed at the last step. 
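Note on the trainer change above: `_log_trainer_state_stats` uses a log-and-collect pattern, where a local `_step_log` wrapper both emits each line through the regular step logger and appends the formatted text to a list, which `_prepare_training` then writes to `model_analysis.txt`. The following is a minimal standalone sketch of that pattern (not part of the patch); the helper name `make_collecting_logger` and the example values are illustrative only.

import logging


def make_collecting_logger(collected: list[str]):
    """Returns a log function that both logs each message and records the formatted text."""

    def _log(msg: str, *args):
        logging.info(msg, *args)
        collected.append(msg % args if args else msg)

    return _log


# Usage: gather the same lines that were logged, then join them for writing to a file.
lines: list[str] = []
log = make_collecting_logger(lines)
log("Total number of model params: %s", f"{1_000_000:,}")
analysis_text = "\n".join(lines)  # e.g. the text persisted as a model analysis report.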
diff --git a/axlearn/common/update_transformation_test.py b/axlearn/common/update_transformation_test.py
index bcfc49d1..b29062a6 100644
--- a/axlearn/common/update_transformation_test.py
+++ b/axlearn/common/update_transformation_test.py
@@ -166,9 +166,11 @@ def mock_params() -> Nested[Tensor]:
     )
 
 
-def mock_updates() -> axlearn.common.update_transformation.Updates:
+def mock_updates(state_param_none: bool = True) -> axlearn.common.update_transformation.Updates:
     """Create an updates object with various semi-reasonable values."""
     model_params = mock_params()
+    if state_param_none:
+        model_params["state"] = None
     opt_params = jax.tree.map(
         lambda p: OptParam(
             value=p,
@@ -197,6 +199,7 @@ def test_param_values(self):
         updates = mock_updates()
         actual = updates.param_values()
         expected = mock_params()
+        expected["state"] = None
         chex.assert_trees_all_equal_structs(actual, expected)
         self.assertNestedAllClose(actual, expected)
 
@@ -218,12 +221,7 @@ def test_param_specs(self):
                     weight_decay_scale=0.1,
                 )
             ),
-            state=ParameterSpec(
-                shape=(2,),
-                dtype=jnp.int32,
-                factorization=FactorizationSpec([None]),
-                weight_decay_scale=0.1,
-            ),
+            state=None,
             more_state=ParameterSpec(
                 shape=(3,),
                 dtype=jnp.int32,
diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py
index 37b88d56..2c755a32 100644
--- a/axlearn/common/utils.py
+++ b/axlearn/common/utils.py
@@ -1419,3 +1419,25 @@ class DeviceUsage:
     hbm_memory_usage_bytes: Optional[int] = None
    hbm_memory_total_bytes: Optional[int] = None
    hbm_memory_bandwidth_utilization: Optional[float] = None
+
+
+def sequence_mask(*, lengths: Tensor, max_len: int, dtype: Optional[jnp.dtype] = None) -> Tensor:
+    """Computes a mask over sequence positions for each given length.
+
+    Args:
+        lengths: An int32 Tensor of shape [...] with the valid length of each sequence.
+        max_len: The maximum sequence length T (int).
+        dtype: The output dtype. If None, defaults to lengths.dtype.
+
+    Returns:
+        A Tensor of shape [..., T], where 1 marks a valid position and 0 marks padding.
+    """
+    if dtype is None:
+        dtype = lengths.dtype
+
+    prefix_axis = tuple(range(lengths.ndim))
+    # [..., T]
+    sequence = jnp.expand_dims(jnp.arange(max_len), axis=prefix_axis)
+    # [..., 1]
+    lengths = lengths[..., jnp.newaxis]
+    return (sequence < lengths).astype(dtype)
diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py
index bc0007ce..cfbbb745 100644
--- a/axlearn/common/utils_spmd.py
+++ b/axlearn/common/utils_spmd.py
@@ -87,6 +87,10 @@ def setup(
             num_processes=num_processes,
             process_id=process_id,
         )
+        if jax_backend == "gpu":
+            # jax 0.4.34 introduced a change to cluster auto-detection behavior; supplying the
+            # local_device_ids arg allows us to maintain the expected behavior.
+            init_kwargs["local_device_ids"] = list(range(8))
     jax.distributed.initialize(**init_kwargs)
     _jax_distributed_initialized = True
 
diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py
index 85d557f9..3b631155 100644
--- a/axlearn/common/utils_test.py
+++ b/axlearn/common/utils_test.py
@@ -21,7 +21,7 @@ from jax.experimental import checkify, mesh_utils
 from jax.sharding import PartitionSpec
 
-from axlearn.common import learner, optimizers, serialization, struct
+from axlearn.common import learner, optimizers, serialization, struct, utils
 from axlearn.common.base_layer import BaseLayer, FactorizationSpec, ParameterSpec
 from axlearn.common.config import config_class, config_for_function, similar_names
 from axlearn.common.layers import BatchNorm, LayerNorm, Linear
@@ -761,6 +761,19 @@ def test_check_jax_type(self):
         with self.assertRaisesRegex(ValueError, "^Argument key has leaf with non-JAX type"):
             check_jax_type(pretty_named_args={"key": "1"})
 
+    @parameterized.parameters(
+        dict(lengths=[3, 4], dtype=None, expected=[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]),
+        dict(lengths=[3, 4], dtype=jnp.int32, expected=[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]),
+        dict(lengths=[3, 4], dtype=jnp.float32, expected=[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]),
+        dict(lengths=[[3], [4]], dtype=jnp.int32, expected=[[[1, 1, 1, 0, 0]], [[1, 1, 1, 1, 0]]]),
+        dict(lengths=[[3, 4]], dtype=jnp.int32, expected=[[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]]),
+    )
+    def test_sequence_mask(self, lengths, dtype, expected):
+        max_len = 5
+        mask = utils.sequence_mask(lengths=jnp.array(lengths), max_len=max_len, dtype=dtype)
+        expected = jnp.array(expected).astype(dtype if dtype else jnp.int32)
+        self.assertNestedAllClose(mask, expected)
+
 
 class SimilarNamesTest(TestCase):
     @parameterized.parameters(
diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py
index 30ac0ba6..d32120d2 100644
--- a/axlearn/experiments/text/gpt/common.py
+++ b/axlearn/experiments/text/gpt/common.py
@@ -466,6 +466,7 @@ def mixture_train_input_source(
     max_sequence_length: int,
     replace_newlines_with: str = REPLACE_NEWLINES_WITH,
     fake_input_source_cfg: Optional[InstantiableConfig] = None,
+    autotune_ram_budget_gb: Optional[int] = None,
 ) -> input_tf_data.BuildDatasetFn:
     """Build mixture training input source for decoder-only LM model.
 
@@ -483,6 +484,9 @@ def mixture_train_input_source(
         replace_newlines_with: Value to replace newlines with in the text.
         fake_input_source_cfg: A config that instantiates to a BuildDatasetFn for the input
             source used during unittest.
+        autotune_ram_budget_gb: The memory budget (in GiB) the tensorflow datasets optimization
+            pipeline will target. Typically configure as 50%-75% of available memory.
+            If None, uses tensorflow defaults.
 
     Returns:
         A BuildDatasetFn that mixes the given list of DataMixtureComponent(s).
@@ -535,6 +539,7 @@ def _set_config_for_preprocessor(p: InstantiableConfig) -> InstantiableConfig: sources=sources, weights=weights, is_training=is_training, + autotune_ram_budget_gb=autotune_ram_budget_gb, ) diff --git a/axlearn/vision/attention.py b/axlearn/vision/attention.py index fd943d9a..4b96db95 100644 --- a/axlearn/vision/attention.py +++ b/axlearn/vision/attention.py @@ -229,7 +229,7 @@ def forward( attention_logit_biases = attention_logit_biases[:, None, :, :] probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases) probs = self.dropout(probs) - context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + context = self._compute_context(probs, v_proj) context = self._remat_name(context, "context") self.vlog(3, "atten.prob=%s", probs[0, 0, 0, :]) self.vlog(3, "atten.context=%s", context.sum()) diff --git a/axlearn/vision/beit_image_tokenizer.py b/axlearn/vision/beit_image_tokenizer.py index 19a7a7f1..d655c5e3 100644 --- a/axlearn/vision/beit_image_tokenizer.py +++ b/axlearn/vision/beit_image_tokenizer.py @@ -185,11 +185,16 @@ def forward(self, inputs: Tensor) -> tuple[Tensor, dict[str, Tensor]]: paddings = jnp.zeros(encoded_outputs.shape[:2]) quantized_output = self.quantizer(inputs=encoded_outputs, paddings=paddings) # quantized_output.quantized_vectors shape [batch_size, seq_len, 1, codebook_dim] - # quantized_output.onehots in shape [batch_size, seq_len, 1, codebook_size] # quantized_output.ids in shape [batch_size, seq_len, 1] + onehots = jax.nn.one_hot( + quantized_output.ids, + num_classes=self.config.quantizer.codebook_size, + axis=-1, + dtype=paddings.dtype, + ) return jnp.squeeze(quantized_output.ids, axis=-1), { "quantized_vectors": jnp.squeeze(quantized_output.quantized_vectors, axis=-2), - "quantized_codebook_onehots": jnp.squeeze(quantized_output.onehots, axis=-2), + "quantized_codebook_onehots": jnp.squeeze(onehots, axis=-2), } diff --git a/pyproject.toml b/pyproject.toml index 1cd1ea12..2189eea9 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "axlearn" -version = "0.1.3" +version = "0.1.4" description = "AXLearn" readme = "README.md" requires-python = ">=3.10" @@ -22,9 +22,10 @@ dependencies = [ core = [ "absl-py==2.1.0", "chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25. + "einops==0.8.0", "importlab==0.7", # breaks pytype on 0.8 - "jax==0.4.33", - "jaxlib==0.4.33", + "jax==0.4.34", + "jaxlib==0.4.34", "nltk==3.7", # for text preprocessing "optax==0.1.7", # optimizers (0.1.0 has known bugs). "portpicker", @@ -53,10 +54,10 @@ apple-silicon = [ ] # Requirements for testing and development. dev = [ + "axlearn[core]", # core "axlearn[audio]", # audio tests "axlearn[orbax]", # checkpointer tests "black==23.1a1", # formatting - "einops==0.8.0", "evaluate", "isort", # formatting "pika==1.3.2", # used by event queue @@ -101,7 +102,7 @@ gcp = [ # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install. tpu = [ "axlearn[gcp]", - "jax[tpu]==0.4.33", # must be >=0.4.19 for compat with v5p. + "jax[tpu]==0.4.34", # must be >=0.4.19 for compat with v5p. ] # Vertex AI tensorboard. TODO(markblee): Merge with `gcp`. vertexai_tensorboard = [ @@ -125,7 +126,7 @@ dataflow = [ # GPU custom kernel dependency. gpu = [ "triton==2.1.0", - "jax[cuda12_pip]==0.4.33", + "jax[cuda12]==0.4.34", ] # Open API inference. open_api = [ @@ -145,7 +146,7 @@ mmau = [ # Orbax checkpointing. 
orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.5.23", + "orbax-checkpoint==0.9.1", ] # Grain input processing. Currently does not support macos. grain = [