From 0b6643f06491fd105f084dc305a3225542a89183 Mon Sep 17 00:00:00 2001 From: jiarui-lu2 <136027585+jiarui-lu2@users.noreply.github.com> Date: Mon, 4 Nov 2024 15:43:58 -0800 Subject: [PATCH 01/27] Fix Regression from https://github.com/apple/axlearn/pull/737 (#810) --- .../common/quantized_dot_general/layers.py | 41 ++++--------------- .../quantized_dot_general/layers_test.py | 25 ++++++++--- 2 files changed, 28 insertions(+), 38 deletions(-) diff --git a/axlearn/common/quantized_dot_general/layers.py b/axlearn/common/quantized_dot_general/layers.py index 2935a5d15..c0983e7a6 100644 --- a/axlearn/common/quantized_dot_general/layers.py +++ b/axlearn/common/quantized_dot_general/layers.py @@ -26,12 +26,10 @@ import jax from absl import logging -from aqt.jax.v2 import aqt_dot_general -from aqt.jax.v2 import utils as aqt_utils +from aqt.jax.v2.config import DotGeneral, set_context from jax import numpy as jnp from jax.lax import DotDimensionNumbers, Precision from jax.typing import DTypeLike -from typing_extensions import Protocol from axlearn.common.base_layer import BaseLayer from axlearn.common.config import config_class @@ -65,26 +63,6 @@ class ClippingChoice(Enum): OUTPUT_ACTIVATION = 1 -class AQTDotGeneralType(Protocol): - """Typedef for AQT DotGeneral functions. - - Adds context kwarg containing prng key comparing to jax.lax.dot_general. - - """ - - def __call__( - self, - lhs: Tensor, - rhs: Tensor, - *, - dimension_numbers: DotDimensionNumbers, - precision: PrecisionLike = None, - preferred_element_type: Optional[DTypeLike] = None, - context: aqt_utils.Context = aqt_utils.Context(key=None, train_step=None), - ) -> Tensor: - ... - - class QuantizedDotGeneral(BaseLayer): """Hardware accelerated quantized dot general layer. @@ -132,13 +110,9 @@ def __init__(self, cfg: Config, *, parent: Module): # for anything, we just need to init an aqt_dot_general function # with recommended configs. # Dot general with default config. - self.lhs_act_dot_general: AQTDotGeneralType = aqt_dot_general.make_dot_general( - lhs_activation_aqt_config() - ) + self.lhs_act_dot_general: DotGeneral = lhs_activation_aqt_config() # Dot general with mirrored config where lhs and rhs are swapped. - self.rhs_act_dot_general: AQTDotGeneralType = aqt_dot_general.make_dot_general( - rhs_activation_aqt_config() - ) + self.rhs_act_dot_general: DotGeneral = rhs_activation_aqt_config() elif cfg.quantization_type == DotGeneralQuantizationType.FP_8: # TODO(jiarui): Is there a way to identify if we are running on H100? if jax.default_backend() != "gpu": @@ -203,18 +177,19 @@ def _dot_general_maybe_quantized( elif cfg.quantization_type == DotGeneralQuantizationType.INT_8: # Provide prng_key and call self.aqt_dot_general. if lhs_is_activation: - fn: AQTDotGeneralType = self.lhs_act_dot_general + fn: DotGeneral = self.lhs_act_dot_general else: fn = self.rhs_act_dot_general + # Pass in prng_key for stochastic rounding + set_context( + cfg=fn, key=prng_key if prng_key is not None else self.prng_key, train_step=None + ) return fn( lhs, rhs, dimension_numbers=dimension_numbers, precision=precision, preferred_element_type=preferred_element_type, - context=aqt_dot_general.Context( - key=prng_key if prng_key is not None else self.prng_key, train_step=None - ), ) elif cfg.quantization_type == DotGeneralQuantizationType.FP_8: raise NotImplementedError("Fp8 quantization on GPU is not yet implemented") diff --git a/axlearn/common/quantized_dot_general/layers_test.py b/axlearn/common/quantized_dot_general/layers_test.py index 6aea04532..134c01c32 100644 --- a/axlearn/common/quantized_dot_general/layers_test.py +++ b/axlearn/common/quantized_dot_general/layers_test.py @@ -22,12 +22,26 @@ class TestQuantizedDotGeneral(TestCase): """Tests QuantizedDotGeneral layer.""" - # TODO(jiarui): Add TPU / GPU tests once they are available in CI - @parameterized.product(b=[2, 16], d=[4, 32], h=[8, 64]) - def test_einsum_maybe_quantized(self, b, d, h): + # TODO(jiarui): Assert output for INT8 once TPU tests are available in CI + @parameterized.product( + b=[2, 16], + d=[4, 32], + h=[8, 64], + quantization_type_and_assert_output=[ + (None, True), # Test bf16, ensure parity on output + ( + DotGeneralQuantizationType.INT_8, + False, + ), # Test INT8, ignore output parity since this is executing on CPU instead of TPU + ], + ) + def test_einsum_maybe_quantized(self, b, d, h, quantization_type_and_assert_output): + quantization_type, assert_output = quantization_type_and_assert_output # When config is None, maybe_quantized_einsum should reduce to einsum with Mesh(mesh_utils.create_device_mesh((1, 1)), ("data", "fsdp")): - quantized_dot_general_cfg = QuantizedDotGeneral.default_config() + quantized_dot_general_cfg = QuantizedDotGeneral.default_config().set( + quantization_type=quantization_type + ) quantized_dot_general_layer = quantized_dot_general_cfg.set( name="quantized_dot_general_layer" ).instantiate(parent=None) @@ -56,7 +70,8 @@ def test_einsum_maybe_quantized(self, b, d, h): method="einsum_maybe_quantized", ) reference = jnp.einsum(*inputs) - self.assertNestedAllClose(output, reference) + if assert_output: + self.assertNestedAllClose(output, reference) def test_set_quantized_dot_general_recursively(self): cfg = Decoder.default_config() From c568e53046fd9e4085545985a1085320c60dee7d Mon Sep 17 00:00:00 2001 From: Floris Weers Date: Mon, 4 Nov 2024 20:09:28 -0800 Subject: [PATCH 02/27] Fix tfds autotune (#811) --- axlearn/common/input_tf_data.py | 23 ++++++++++++++--------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/axlearn/common/input_tf_data.py b/axlearn/common/input_tf_data.py index 1a9df4a97..1d44ad1bb 100644 --- a/axlearn/common/input_tf_data.py +++ b/axlearn/common/input_tf_data.py @@ -429,16 +429,21 @@ def fn() -> tf.data.Dataset: if autotune_ram_budget_gb is not None: autotuned_ds_list = [] - options = tf.data.Options() - options.autotune.enabled = True - options.autotune.ram_budget = int( - # Soft constrain to this many bytes of memory per component. - (autotune_ram_budget_gb / len(source_ds_list)) - * 1024**3 - ) - # Start fetching data on iterator creation. - options.experimental_warm_start = True for el in source_ds_list: + # We need a new Options object for each dataset, + # due to limitations on tfds side. + # It seems like only the first dataset gets the options, + # while others do not respect autotune. + options = tf.data.Options() + options.autotune.enabled = True + options.autotune.ram_budget = int( + # Soft constrain to this many bytes of memory per component. + (autotune_ram_budget_gb / len(source_ds_list)) + * 1024**3 + ) + # Start fetching data on iterator creation. + options.experimental_warm_start = True + autotuned_ds_list.append(el.with_options(options)) source_ds_list = autotuned_ds_list From c2d7518f8917c4313e65e1470c1118ec02d561fe Mon Sep 17 00:00:00 2001 From: kelvin-zou <166073445+kelvin-zou@users.noreply.github.com> Date: Tue, 5 Nov 2024 09:23:12 -0800 Subject: [PATCH 03/27] [Bug fix] Update gpu flash attention after syntax change and fixed unit tests for flash attention (#809) * snapshot * fix precision issue --- .../common/flash_attention/gpu_attention.py | 157 +++++++++++++----- .../flash_attention/gpu_attention_test.py | 51 ++---- axlearn/common/flash_attention/layer_test.py | 12 +- 3 files changed, 135 insertions(+), 85 deletions(-) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 5f80254ee..915914cc1 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -31,11 +31,10 @@ import jax import jax.numpy as jnp - -# pytype: disable=import-error # pylint: disable=import-error from jax import lax from jax._src.cudnn.fused_attention_stablehlo import MaskType, dot_product_attention from jax.experimental import pallas as pl +from jax.experimental.pallas import gpu as plgpu from axlearn.common.attention import NEG_INF @@ -249,12 +248,16 @@ def flash_attention( def bias_index_map(_, j, k): return (j if bias.shape[0] != 1 else 0, k if bias.shape[1] != 1 else 0, 0, 0) - bias_block_spec = pl.BlockSpec(bias_index_map, (None, None, seq_len, seq_len)) + bias_block_spec = pl.BlockSpec( + index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + ) # Segment Ids segment_ids_block_spec = None if segment_ids is not None: assert segment_ids.ndim == 2 - segment_ids_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + segment_ids_block_spec = pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0)), block_shape=(None, seq_len) + ) num_warps_ = num_warps if num_warps_ is None: @@ -276,14 +279,25 @@ def bias_index_map(_, j, k): kernel, grid=grid_, in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # query - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # key - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # value + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # query + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # key + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # value bias_block_spec, # bias segment_ids_block_spec, # segment_ids ], - out_specs=pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - compiler_params=dict(triton=dict(num_warps=num_warps_, num_stages=num_stages_)), + out_specs=pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim) + ), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps_, num_stages=num_stages_), out_shape=out_shape, debug=debug, interpret=interpret, @@ -327,13 +341,17 @@ def _mha_forward( def bias_index_map(_, j, k): return (j if bias.shape[0] != 1 else 0, k if bias.shape[1] != 1 else 0, 0, 0) - bias_block_spec = pl.BlockSpec(bias_index_map, (None, None, seq_len, seq_len)) + bias_block_spec = pl.BlockSpec( + index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + ) # Segment Ids. segment_ids_block_spec = None if segment_ids is not None: assert segment_ids.ndim == 2 - segment_ids_block_spec = pl.BlockSpec(lambda _, j, k: (j, 0), (None, seq_len)) + segment_ids_block_spec = pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0)), block_shape=(None, seq_len) + ) num_warps_ = num_warps if num_warps_ is None: @@ -359,18 +377,30 @@ def bias_index_map(_, j, k): kernel, grid=grid_, in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # query - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # key - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # value + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # query + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # key + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # value bias_block_spec, # bias segment_ids_block_spec, # segment_ids ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), ], - compiler_params=dict(triton=dict(num_warps=num_warps_, num_stages=num_stages_)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps_, num_stages=num_stages_), out_shape=out_shape, debug=debug, interpret=interpret, @@ -426,15 +456,24 @@ def _preprocess_backward( functools.partial(_preprocess_backward_kernel, block_q=block_q), grid=(pl.cdiv(seq_len, block_q), batch_size, num_heads), in_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), ], out_specs=[ - pl.BlockSpec(lambda _, j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda _, j, k: (j, k, 0), (None, None, seq_len)), + pl.BlockSpec( + index_map=(lambda _, j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec(index_map=(lambda _, j, k: (j, k, 0)), block_shape=(None, None, seq_len)), ], - compiler_params=dict(triton=dict(num_warps=4, num_stages=3)), + compiler_params=plgpu.TritonCompilerParams(num_warps=4, num_stages=3), out_shape=out_shape, debug=debug, interpret=interpret, @@ -586,7 +625,7 @@ def _mha_backward( if backward_pass_impl == "triton": # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. if jnp.float32 in (q.dtype, k.dtype, v.dtype): - block_q = block_k = 64 + block_q = block_k = 32 batch_size, seq_len, num_heads, head_dim = q.shape # Backward heuristics, using the same block size for block q and block k. block_q = min(block_q, seq_len) @@ -610,19 +649,23 @@ def _mha_backward( b = jnp.moveaxis(b, -1, -2) # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. if b.dtype == jnp.float32: - block_q = block_k = 64 + block_q = block_k = 32 def bias_index_map(j, k): return (j if b.shape[0] != 1 else 0, k if b.shape[1] != 1 else 0, 0, 0) - bias_block_spec = pl.BlockSpec(bias_index_map, (None, None, seq_len, seq_len)) + bias_block_spec = pl.BlockSpec( + index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) + ) num_input += 1 # Segment Ids. segment_ids_block_spec = None if s is not None: assert s.ndim == 2 - segment_ids_block_spec = pl.BlockSpec(lambda j, k: (j, 0), (None, seq_len)) + segment_ids_block_spec = pl.BlockSpec( + index_map=(lambda j, k: (j, 0)), block_shape=(None, seq_len) + ) num_input += 1 input_output_aliases = {num_input: 0} @@ -643,27 +686,54 @@ def bias_index_map(j, k): grid=grid, out_shape=out_shapes, in_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # query - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # key - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), # value + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # query + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # key + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), # value bias_block_spec, # bias segment_ids_block_spec, # segment_ids - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, k, 0), (None, None, seq_len)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec(index_map=(lambda j, k: (j, k, 0)), block_shape=(None, None, seq_len)), + pl.BlockSpec(index_map=(lambda j, k: (j, k, 0)), block_shape=(None, None, seq_len)), + pl.BlockSpec(index_map=(lambda j, k: (j, k, 0)), block_shape=(None, None, seq_len)), + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), ], out_specs=[ - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), - pl.BlockSpec(lambda j, k: (j, 0, k, 0), (None, seq_len, None, head_dim)), + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), + pl.BlockSpec( + index_map=(lambda j, k: (j, 0, k, 0)), + block_shape=(None, seq_len, None, head_dim), + ), ], name="mha_backward", debug=debug, interpret=interpret, - compiler_params=dict(triton=dict(num_warps=num_warps, num_stages=1)), + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=1), input_output_aliases=input_output_aliases, )(q, k, v, b, s, out, do_scaled, l, m, delta, dq) else: @@ -696,11 +766,6 @@ def cudnn_dot_product_attention( https://github.com/google/jax/blob/f4158ace933482844c145a6b919bf5dc86e084ba/jax/_src/cudnn/fused_attention_stablehlo.py#L927. https://github.com/openxla/xla/blob/536ba0b7d74f6637a7a772471a99ecf4f578aef2/xla/service/gpu/cublas_cudnn.cc#L77. - We override the Jax fused multihead attention(fMHA) interface in axlearn - due to following reasons: - 1. Original Jax implementation has a bug to support multi-node training (fixed in jax 0.4.32). - 2. We may want to leverage more lower level CuDNN capabilities from xla and expose to users. - Args: query: Query of shape [batch_size, target_length, num_heads, per_head_dim]. key: Key of shape [batch_size, source_length, num_heads, per_head_dim]. diff --git a/axlearn/common/flash_attention/gpu_attention_test.py b/axlearn/common/flash_attention/gpu_attention_test.py index 48f1bf172..085eb39dc 100644 --- a/axlearn/common/flash_attention/gpu_attention_test.py +++ b/axlearn/common/flash_attention/gpu_attention_test.py @@ -10,14 +10,9 @@ Currently tested on A100/H100. """ -# pylint: disable=wrong-import-position import functools -import os from typing import Literal -os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false" -os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform" - import chex import jax import jax.numpy as jnp @@ -29,6 +24,9 @@ ) from axlearn.common.flash_attention.utils import mha_reference +if jax.default_backend() != "gpu": + pytest.skip(reason="Incompatible hardware", allow_module_level=True) + @pytest.mark.parametrize( "batch_size,seq_len,num_heads,per_head_dim", @@ -42,20 +40,17 @@ ], ) @pytest.mark.parametrize("block_size", [64, 128]) -@pytest.mark.parametrize("use_fwd", [True, False]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("sm_scale", [1.0, 0.123]) @pytest.mark.parametrize("attention_bias_type", [None, "2d", "4d"]) @pytest.mark.parametrize("use_segment_ids", [True, False]) @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32]) -@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="Test only runs on GPU.") -def test_fwd_against_ref( +def test_triton_fwd_only_against_ref( batch_size: int, seq_len: int, num_heads: int, per_head_dim: int, block_size: int, - use_fwd: bool, causal: bool, sm_scale: float, attention_bias_type: Literal["2d", "4d", None], @@ -80,37 +75,26 @@ def test_fwd_against_ref( jnp.concatenate([segment_left, segment_right], axis=-1) if use_segment_ids else None ) - # Make sure that it is running on GPU. - assert str(q.devices()) == "{cuda(id=0)}" - - if use_fwd: - - @jax.jit - def impl(q, k, v, bias, segment_ids): - fn = functools.partial( - flash_attention, - block_q=block_size, - block_k=block_size, - causal=causal, - softmax_scale=sm_scale, - ) - out, _ = jax.vjp(fn, q, k, v, bias, segment_ids) - return out - - else: - impl = functools.partial( + @jax.jit + def impl(q, k, v, bias, segment_ids): + fn = functools.partial( flash_attention, block_q=block_size, block_k=block_size, causal=causal, softmax_scale=sm_scale, ) + out, _ = jax.vjp(fn, q, k, v, bias, segment_ids) + return out o = impl(q, k, v, bias, segment_ids) o_ref = mha_reference(q, k, v, bias, segment_ids, causal=causal, softmax_scale=sm_scale) - chex.assert_trees_all_close(o, o_ref, atol=0.05) + chex.assert_trees_all_close(o, o_ref, atol=0.07) +# We test the flash_attention against the reference mha_reference. +# The outputs should be close in both fp16 and fp32, with a relaxed bound due +# to the numerical difference during operations. @pytest.mark.parametrize( "batch_size,num_heads,seq_len,per_head_dim", [ @@ -127,8 +111,7 @@ def impl(q, k, v, bias, segment_ids): @pytest.mark.parametrize("block_size", [64, 128]) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("input_dtype", [jnp.float16, jnp.float32]) -@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="Test only runs on GPU.") -def test_bwd_against_ref( +def test_triton_against_xla_ref( batch_size: int, num_heads: int, seq_len: int, @@ -164,9 +147,6 @@ def test_bwd_against_ref( jnp.concatenate([segment_left, segment_right], axis=-1) if use_segment_ids else None ) - # Make sure that it is running on GPU. - assert str(q.devices()) == "{cuda(id=0)}" - sm_scale = q.shape[-1] ** -0.5 # Compare outputs. @@ -226,7 +206,6 @@ def ref_fn(q, k, v, bias, segment_ids): ) @pytest.mark.parametrize("causal", [True, False]) @pytest.mark.parametrize("dtype", [jnp.bfloat16, jnp.float16]) -@pytest.mark.skipif(jax.devices()[0].platform != "gpu", reason="Test only runs on GPU.") def test_cudnn_against_triton_ref( batch_size: int, num_heads: int, @@ -244,8 +223,6 @@ def test_cudnn_against_triton_ref( v = jax.random.normal( jax.random.PRNGKey(2), (batch_size, seq_len, num_heads, per_head_dim), dtype=dtype ) - # Make sure that it is running on GPU. - assert str(q.devices()) == "{cuda(id=0)}" sm_scale = q.shape[-1] ** -0.5 diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index 89f5d4821..5ac4db3bc 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -82,7 +82,13 @@ def _fake_inputs( def _prepare_layers( - *, num_heads, per_head_dim, mesh_axis_names, causal, sliding_window_size, inference=False + *, + num_heads, + per_head_dim, + mesh_axis_names, + causal, + sliding_window_size, + inference=False, ): hidden_dim = num_heads * per_head_dim kwargs = dict( @@ -406,6 +412,7 @@ def test_forward( ) # TODO(markblee): Test probs. self.assertNestedAllClose(ref_out.data, test_out.data, atol=0.05) + jax.clear_backends() @parameterized.product( _TEST_CONFIGS, @@ -433,7 +440,6 @@ def test_backward( pytest.skip(reason=f"Unsupported mesh {mesh}.") if use_segment_ids and query_len_multiplier != 1: pytest.skip("Segment IDs are not supported for Q and K with different lengths.") - if not causal and sliding_window_size is not None: pytest.skip(reason="Sliding window attention must be causal.") @@ -539,6 +545,7 @@ def loss(params, inputs, layer): atol = 1e-4 self.assertNestedAllClose(ref_value, test_value, atol=atol) self.assertNestedAllClose(ref_grads, test_grads, atol=atol) + jax.clear_backends() @parameterized.product(_TEST_CONFIGS, causal=[True], sliding_window_size=[None, 4]) def test_extend_step( @@ -714,3 +721,4 @@ def test_extend_step( test_out.data, atol=2e-2, ) + jax.clear_backends() From 04e5aac19809d642fe1a1fac56a30a16f37ae0d9 Mon Sep 17 00:00:00 2001 From: Floris Weers Date: Tue, 5 Nov 2024 11:10:11 -0800 Subject: [PATCH 04/27] Enable autotune ram for axlearn dataset (#813) --- axlearn/experiments/text/gpt/common.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/axlearn/experiments/text/gpt/common.py b/axlearn/experiments/text/gpt/common.py index 30ac0ba6a..d32120d25 100644 --- a/axlearn/experiments/text/gpt/common.py +++ b/axlearn/experiments/text/gpt/common.py @@ -466,6 +466,7 @@ def mixture_train_input_source( max_sequence_length: int, replace_newlines_with: str = REPLACE_NEWLINES_WITH, fake_input_source_cfg: Optional[InstantiableConfig] = None, + autotune_ram_budget_gb: Optional[int] = None, ) -> input_tf_data.BuildDatasetFn: """Build mixture training input source for decoder-only LM model. @@ -483,6 +484,9 @@ def mixture_train_input_source( replace_newlines_with: Value to replace newlines with in the text. fake_input_source_cfg: A config that instantiates to a BuildDatasetFn for the input source used during unittest. + autotune_ram_budget_gb: The memory budget (in GiB) the tensorflow datasets optimization + pipeline will target. Typically configure as 50%-75% of available memory. + If None, uses tensorflow defaults. Returns: A BuildDatasetFn that mixes the given list of DataMixtureComponent(s). @@ -535,6 +539,7 @@ def _set_config_for_preprocessor(p: InstantiableConfig) -> InstantiableConfig: sources=sources, weights=weights, is_training=is_training, + autotune_ram_budget_gb=autotune_ram_budget_gb, ) From 41c9a6db86e438957f4736590babbb38f2c76775 Mon Sep 17 00:00:00 2001 From: rhodes73 <64533795+rhodes73@users.noreply.github.com> Date: Tue, 5 Nov 2024 21:17:09 +0100 Subject: [PATCH 05/27] Fix flash attention layer test (#812) --- axlearn/common/flash_attention/layer_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index 5ac4db3bc..0a557cb96 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -641,7 +641,7 @@ def test_extend_step( initial_state = test_layer.init_states( target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state ) - ref_initial_state = test_layer.init_states( + ref_initial_state = ref_layer.init_states( target_batch_size=batch, target_max_len=seq_len, kv_state=kv_state ) for k in ["key", "value"]: From 226d27ab7569668f2c38a35cf32d5dc5190ebdbb Mon Sep 17 00:00:00 2001 From: Matthew Hopkins Date: Tue, 5 Nov 2024 12:48:31 -0800 Subject: [PATCH 06/27] Strict job names (#814) * restrict bastion is_valid_job_name to ASCII char range * Update axlearn/cloud/common/bastion.py Co-authored-by: Mark Lee --------- Co-authored-by: Mark Lee --- axlearn/cloud/common/bastion.py | 15 ++++++++++++--- axlearn/cloud/common/bastion_test.py | 8 ++++++++ 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/axlearn/cloud/common/bastion.py b/axlearn/cloud/common/bastion.py index 0901bb0a9..4a020ee71 100644 --- a/axlearn/cloud/common/bastion.py +++ b/axlearn/cloud/common/bastion.py @@ -53,6 +53,7 @@ import io import json import os +import re import shlex import shutil import signal @@ -100,6 +101,9 @@ FLAGS = flags.FLAGS +_VALID_NAME_CHARS = r"[!-~]+" # match all printing ASCII characters except space +valid_name_re = re.compile(_VALID_NAME_CHARS) + def bastion_job_flags(flag_values: flags.FlagValues = FLAGS): flags.DEFINE_string("name", None, "Name of bastion.", flag_values=flag_values, required=True) @@ -323,11 +327,16 @@ def deserialize_jobspec(f: Union[str, IO]) -> JobSpec: def is_valid_job_name(name: str) -> bool: - """Ensures that job name does not look like a path. + """Ensures job name is not path-like and only contains safe characters. - We use a permissive regex to avoid making assumptions about the underlying compute environment. + This check should avoid making assumptions about the underlying compute environment. """ - return bool(name) and ("/" not in name) and (name not in (".", "..")) and ("\n" not in name) + return ( + bool(name) + and ("/" not in name) + and (name not in (".", "..")) + and bool(valid_name_re.fullmatch(name)) + ) def _download_jobspec(job_name: str, *, remote_dir: str, local_dir: str = _JOB_DIR) -> JobSpec: diff --git a/axlearn/cloud/common/bastion_test.py b/axlearn/cloud/common/bastion_test.py index 30db616e3..adaf66117 100644 --- a/axlearn/cloud/common/bastion_test.py +++ b/axlearn/cloud/common/bastion_test.py @@ -153,6 +153,14 @@ def mock_download_job_state(job_name, *, remote_dir, **kwargs): dict(name="..test", valid=True), # This is a valid file name. dict(name="test.job..", valid=True), # This is a valid file name. dict(name="test\n", valid=False), # newline causes bastion to crash + dict(name="test", valid=True), + dict(name="test“job”test", valid=False), # pinyin quotes are invalid + dict(name="test‘job’test", valid=False), # pinyin quotes are invalid + dict(name="test\\job", valid=True), + dict(name="test,job", valid=True), + dict(name="test:job", valid=True), + dict(name="test_job", valid=True), + dict(name="test job", valid=False), ) def test_is_valid_job_name(self, name, valid): self.assertEqual(valid, is_valid_job_name(name)) From 4b559c5e98b5eefe5bf59392c924159c1c59d89a Mon Sep 17 00:00:00 2001 From: "Meng (Ethan) Li" Date: Wed, 6 Nov 2024 10:50:13 -0800 Subject: [PATCH 07/27] Support job in-place update (#816) --- axlearn/cloud/common/bastion.py | 55 ++++++- axlearn/cloud/common/bastion_test.py | 124 +++++++++++++++- axlearn/cloud/common/types.py | 2 + axlearn/cloud/gcp/bundler.py | 3 + axlearn/cloud/gcp/job.py | 13 +- axlearn/cloud/gcp/job_test.py | 13 +- axlearn/cloud/gcp/jobs/gke_runner.py | 46 +++++- axlearn/cloud/gcp/jobs/gke_runner_test.py | 155 +++++++++++++++++++- axlearn/cloud/gcp/jobs/launch.py | 75 +++++++++- axlearn/cloud/gcp/jobs/launch_test.py | 55 ++++++- axlearn/cloud/gcp/jobs/launch_utils.py | 53 +++++++ axlearn/cloud/gcp/jobs/launch_utils_test.py | 86 +++++++++++ 12 files changed, 653 insertions(+), 27 deletions(-) diff --git a/axlearn/cloud/common/bastion.py b/axlearn/cloud/common/bastion.py index 4a020ee71..6c8bdd28f 100644 --- a/axlearn/cloud/common/bastion.py +++ b/axlearn/cloud/common/bastion.py @@ -98,6 +98,7 @@ _LOG_DIR = "/var/tmp/logs" # Use /var/tmp/ since /tmp/ is cleared every 10 days. _JOB_DIR = "/var/tmp/jobs" _BASTION_SERIALIZED_JOBSPEC_ENV_VAR = "_BASTION_SERIALIZED_JOBSPEC" +BASTION_JOB_VERSION_ENV_VAR = "BASTION_JOB_VERSION" FLAGS = flags.FLAGS @@ -177,6 +178,8 @@ class JobLifecycleState(str, enum.Enum): PREEMPTING = "PREEMPTING" # Job is rescheduling. RESCHEDULING = "RESCHEDULING" + # Job is updating. + UPDATING = "UPDATING" # Job is cancelling. Command is terminating. CANCELLING = "CANCELLING" # Job has completed/terminated the command, is running cleanup command (if any). @@ -239,6 +242,8 @@ def _validate_job_metadata(metadata: JobMetadata): raise ValidationError(f"Expected {metadata.resources=} to have string keys and int values.") if not isinstance(metadata.priority, int): raise ValidationError(f"Expected {metadata.priority=} to be an int.") + if metadata.version is not None and not isinstance(metadata.version, int): + raise ValidationError(f"Expected {metadata.version=} to be None or an int.") def _validate_jobspec(jobspec: JobSpec): @@ -891,6 +896,12 @@ def _sync_jobs(self): else: curr_job = self._active_jobs[job_name] updated_job = active_jobs[job_name] + if updated_job.spec.metadata.version != curr_job.spec.metadata.version: + # When a new version is detected, add "updated" in the metadata to signal + # job state change and job relaunch. + # Note: "updated" is a transient state and should not be persisted. + updated_job.state.metadata["updated"] = True + logging.info("Detected a different version of job %s", job_name) curr_job.spec, curr_job.state = updated_job.spec, updated_job.state # pylint: disable-next=too-many-statements @@ -935,10 +946,15 @@ def _update_single_job(self, job: Job) -> Job: self._append_to_job_history( job, msg=f"ACTIVE: start process command: {job.spec.command} " - f"with metadata: {job.state.metadata}", + f"with metadata: {job.state.metadata} and version: {job.spec.metadata.version}", state=JobLifecycleState.STARTING, ) env_vars = {f"BASTION_{k.upper()}": v for k, v in job.state.metadata.items()} + + if job.spec.metadata.version: + # For backwards compatibility, only set the version in env when not None. + env_vars.update({BASTION_JOB_VERSION_ENV_VAR: job.spec.metadata.version}) + serialized_jobspec = io.StringIO() serialize_jobspec(job.spec, serialized_jobspec) env_vars |= {_BASTION_SERIALIZED_JOBSPEC_ENV_VAR: serialized_jobspec.getvalue()} @@ -1070,8 +1086,19 @@ def _update_jobs(self): new_tier = verdict.metadata.get("tier") changed_tiers = old_tier != new_tier - # Resume if not running, or keep running if scheduling tier did not change. - if job.state.status == JobStatus.PENDING or not changed_tiers: + jobspec_changed = job.state.metadata.get("updated") + + # Jobspec changed, trigger a restart of the runner. + if jobspec_changed: + self._append_to_job_history( + job, + msg="UPDATING: Detected updated jobspec. Will restart the runner " + "by sending to PENDING state", + state=JobLifecycleState.UPDATING, + ) + job.state.status = JobStatus.PENDING + elif job.state.status == JobStatus.PENDING or not changed_tiers: + # Resume if not running, or keep running if scheduling tier did not change. job.state.status = JobStatus.ACTIVE else: # Job changed scheduling tiers, and must be restarted on the new tier. @@ -1288,3 +1315,25 @@ def submit_job(self, job_name: str, *, job_spec_file: str): else: # Upload the job for bastion to pickup. tf_io.gfile.copy(job_spec_file, dst) + + def get_job(self, job_name: str) -> JobSpec: + job_path = os.path.join(self.active_job_dir, job_name) + if not tf_io.gfile.exists(job_path): + raise ValueError(f"Unable to locate jobspec {job_path}") + + with tempfile.TemporaryDirectory() as tmpdir: + job_spec = _download_jobspec(job_name, remote_dir=self.active_job_dir, local_dir=tmpdir) + return job_spec + + def update_job(self, job_name: str, *, job_spec: JobSpec) -> JobSpec: + dst = os.path.join(self.active_job_dir, job_name) + if not tf_io.gfile.exists(dst): + raise ValueError(f"Unable to locate jobspec {dst}") + + with tempfile.NamedTemporaryFile("w") as f: + serialize_jobspec(job_spec, f) + # Upload the job for bastion to pickup. + tf_io.gfile.copy(f.name, dst, overwrite=True) + logging.info("Job %s is updating.", job_name) + + return job_spec diff --git a/axlearn/cloud/common/bastion_test.py b/axlearn/cloud/common/bastion_test.py index adaf66117..bd1dc6736 100644 --- a/axlearn/cloud/common/bastion_test.py +++ b/axlearn/cloud/common/bastion_test.py @@ -781,6 +781,17 @@ def test_sync_jobs(self): resources={"test": 8}, ), ), + new_jobspec( + name="job3", + command="", + metadata=JobMetadata( + user_id="user1", + project_id="project1", + creation_time=datetime(1900, 1, 1, 0, 0, 0, 1), + resources={"test": 8}, + version=1, + ), + ), ] # Write them to the Bastion submission directory. for spec in specs: @@ -795,7 +806,30 @@ def test_sync_jobs(self): # Download the jobspecs. mock_bastion._sync_jobs() # Confirm expected jobs were downloaded. - self.assertSequenceEqual(list(mock_bastion._active_jobs), ["job1"]) + self.assertSequenceEqual( + sorted(list(mock_bastion._active_jobs)), sorted(["job1", "job3"]) + ) + + # Submit the job again to update the version. + updated_job_spec = new_jobspec( + name="job3", + command="", + metadata=JobMetadata( + user_id="user1", + project_id="project1", + creation_time=datetime(1900, 1, 1, 0, 0, 0, 1), + resources={"test": 8}, + version=2, + ), + ) + bastion_dir.update_job(updated_job_spec.name, job_spec=updated_job_spec) + + # Download the jobspecs. + mock_bastion._sync_jobs() + # Confirm the update is received. + self.assertEqual( + mock_bastion._active_jobs.get(updated_job_spec.name).state.metadata["updated"], True + ) @parameterized.product( [ @@ -1107,6 +1141,23 @@ def mock_proc(cmd, **kwargs): command_proc=mock_proc("command"), cleanup_proc=None, # No cleanup_proc for ACTIVE. ), + # This job will go from ACTIVE to PENDING, since it is being updated. + "updating": Job( + spec=new_jobspec( + name="updating", + command="command", + cleanup_command="cleanup", + metadata=JobMetadata( + user_id="e", + project_id="project1", + creation_time=yesterday + timedelta(seconds=2), + resources={"v4": 1}, # Fits within the v4 budget in project1. + ), + ), + state=JobState(status=JobStatus.ACTIVE, metadata={"tier": 0, "updated": True}), + command_proc=mock_proc("command"), + cleanup_proc=None, # No cleanup_proc for ACTIVE. + ), # This job will go from ACTIVE to CLEANING. "cleaning": Job( spec=new_jobspec( @@ -1196,6 +1247,7 @@ def mock_proc(cmd, **kwargs): "resume": JobState(status=JobStatus.ACTIVE, metadata={"tier": 0}), "active": JobState(status=JobStatus.ACTIVE, metadata={"tier": 0}), "preempt": JobState(status=JobStatus.PENDING), + "updating": JobState(status=JobStatus.PENDING, metadata={"tier": 0}), "cleaning": JobState(status=JobStatus.CLEANING, metadata={"tier": 0}), "cleaning_cancel": JobState(status=JobStatus.CLEANING), "completed": JobState(status=JobStatus.COMPLETED), @@ -1267,6 +1319,8 @@ def mock_proc(cmd, **kwargs): expected_msg = { "resume": "ACTIVE: start process command", "preempt": "PENDING: pre-empting", + "updating": "UPDATING: Detected updated jobspec. Will restart " + "the runner by sending to PENDING state", "cleaning": "CLEANING: process finished", "cleaning_cancel": "CLEANING: process terminated", "completed": "COMPLETED: cleanup finished", @@ -1574,6 +1628,74 @@ def test_delete(self, spec_exists): remote_dir=bastion_dir.user_states_dir, ) + @parameterized.parameters(True, False) + def test_get(self, spec_exists): + job_name = "test-job" + bastion_dir = ( + bastion.BastionDirectory.default_config().set(root_dir="test-dir").instantiate() + ) + + patch_tfio = mock.patch.multiple( + f"{bastion.__name__}.tf_io.gfile", + exists=mock.MagicMock(return_value=spec_exists), + copy=mock.DEFAULT, + ) + + mock_deserialize_jobspec = mock.patch( + f"{bastion.__name__}.deserialize_jobspec", return_value=None + ) + + if spec_exists: + ctx = contextlib.nullcontext() + else: + ctx = self.assertRaisesRegex(ValueError, "Unable to locate jobspec") + + with ctx, mock_deserialize_jobspec, patch_tfio as mock_tfio: + bastion_dir.get_job(job_name) + if spec_exists: + mock_tfio["copy"].assert_called() + self.assertEqual( + mock_tfio["copy"].call_args[0][0], + os.path.join(bastion_dir.active_job_dir, job_name), + ) + self.assertEqual(mock_tfio["copy"].call_args.kwargs["overwrite"], True) + else: + mock_tfio["copy"].assert_not_called() + + @parameterized.parameters(True, False) + def test_update(self, spec_exists): + job_name = "test-job" + bastion_dir = ( + bastion.BastionDirectory.default_config().set(root_dir="test-dir").instantiate() + ) + + patch_tfio = mock.patch.multiple( + f"{bastion.__name__}.tf_io.gfile", + exists=mock.MagicMock(return_value=spec_exists), + copy=mock.DEFAULT, + ) + + mock_serialize_jobspec = mock.patch( + f"{bastion.__name__}.serialize_jobspec", return_value=None + ) + + if spec_exists: + ctx = contextlib.nullcontext() + else: + ctx = self.assertRaisesRegex(ValueError, "Unable to locate jobspec") + + with ctx, mock_serialize_jobspec, patch_tfio as mock_tfio: + bastion_dir.update_job(job_name, job_spec=None) + if spec_exists: + mock_tfio["copy"].assert_called() + self.assertEqual( + mock_tfio["copy"].call_args[0][1], + os.path.join(bastion_dir.active_job_dir, job_name), + ) + self.assertEqual(mock_tfio["copy"].call_args.kwargs["overwrite"], True) + else: + mock_tfio["copy"].assert_not_called() + if __name__ == "__main__": absltest.main() diff --git a/axlearn/cloud/common/types.py b/axlearn/cloud/common/types.py index 98163e7c9..0c61b8088 100644 --- a/axlearn/cloud/common/types.py +++ b/axlearn/cloud/common/types.py @@ -29,6 +29,8 @@ class JobMetadata: # It is not used by the bastion directly. # TODO(haijing-fu): make it as a required field. job_id: Optional[str] = None + # Version of the job. + version: Optional[int] = None @dataclasses.dataclass diff --git a/axlearn/cloud/gcp/bundler.py b/axlearn/cloud/gcp/bundler.py index f8ae4e5a0..524f76e37 100644 --- a/axlearn/cloud/gcp/bundler.py +++ b/axlearn/cloud/gcp/bundler.py @@ -148,6 +148,9 @@ def from_spec( cfg.project = cfg.project or gcp_settings("project", required=False, fv=fv) cfg.repo = cfg.repo or gcp_settings("docker_repo", required=False, fv=fv) cfg.dockerfile = cfg.dockerfile or gcp_settings("default_dockerfile", required=False, fv=fv) + # The value from from_spec is a str and will result in wrong condition. + if isinstance(cfg.is_async, str): + cfg.is_async = cfg.is_async.lower() != "false" return cfg # pylint: disable-next=no-self-use,unused-argument diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 0d09f45c2..66dcff767 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -22,7 +22,11 @@ from absl import flags from google.auth.credentials import Credentials -from axlearn.cloud.common.bastion import _BASTION_SERIALIZED_JOBSPEC_ENV_VAR, deserialize_jobspec +from axlearn.cloud.common.bastion import ( + _BASTION_SERIALIZED_JOBSPEC_ENV_VAR, + BASTION_JOB_VERSION_ENV_VAR, + deserialize_jobspec, +) from axlearn.cloud.common.bundler import BaseDockerBundler from axlearn.cloud.common.job import Job from axlearn.cloud.common.utils import parse_kv_flags, subprocess_run @@ -52,6 +56,9 @@ # Set 80% of the max value as the requested memory. _MEMORY_REQUEST_PERCENTAGE = 0.8 +# A label added to the jobset to indicate job version. +BASTION_JOB_VERSION_LABEL = "bastion-job-version" + class GCPJob(Job): """Base GCP Job definition.""" @@ -552,6 +559,7 @@ def _build_container(self) -> Nested[Any]: # Env var values should always be strings. env=k8s_env_vars, volumeMounts=volume_mounts, + imagePullPolicy="Always", ) def _build_uploader_container(self) -> Nested[Any]: @@ -696,6 +704,9 @@ def _build_pod(self) -> Nested[Any]: } ) + if os.environ.get(BASTION_JOB_VERSION_ENV_VAR): + labels.update({BASTION_JOB_VERSION_LABEL: os.environ.get(BASTION_JOB_VERSION_ENV_VAR)}) + if os.environ.get(_BASTION_SERIALIZED_JOBSPEC_ENV_VAR): spec = deserialize_jobspec( io.StringIO(os.environ.get(_BASTION_SERIALIZED_JOBSPEC_ENV_VAR)) diff --git a/axlearn/cloud/gcp/job_test.py b/axlearn/cloud/gcp/job_test.py index 9719438c3..5b8b4174a 100644 --- a/axlearn/cloud/gcp/job_test.py +++ b/axlearn/cloud/gcp/job_test.py @@ -27,6 +27,7 @@ from axlearn.cloud.common.bastion import ( _BASTION_SERIALIZED_JOBSPEC_ENV_VAR, + BASTION_JOB_VERSION_ENV_VAR, deserialize_jobspec, new_jobspec, serialize_jobspec, @@ -39,6 +40,7 @@ from axlearn.cloud.gcp.config import gcp_settings from axlearn.cloud.gcp.job import ( _MEMORY_REQUEST_PERCENTAGE, + BASTION_JOB_VERSION_LABEL, CPUJob, GCSFuseMount, HostMount, @@ -328,12 +330,13 @@ class Config(Bundler.Config): env={ "BASTION_TIER": "0", _BASTION_SERIALIZED_JOBSPEC_ENV_VAR: _create_serialized_job_spec(1, "user-1"), + BASTION_JOB_VERSION_ENV_VAR: "1", }, reservation="test-reservation", expect_reserved=True, ), dict( - env={"BASTION_TIER": "1"}, + env={"BASTION_TIER": "1", BASTION_JOB_VERSION_ENV_VAR: "2"}, reservation="test-reservation", expect_reserved=False, ), @@ -421,6 +424,8 @@ def test_build_pod( else: self.fail("host-mount not found!") + self.assertEqual(container["imagePullPolicy"], "Always") + self.assertIn("limits", resources) tpu_type = infer_tpu_type(cfg.accelerator.instance_type) tpu_characteristics = USER_FACING_NAME_TO_SYSTEM_CHARACTERISTICS[tpu_type] @@ -515,6 +520,12 @@ def test_build_pod( self.assertNotIn("job-priority", node_selector) self.assertNotIn("user-id", labels) + if BASTION_JOB_VERSION_ENV_VAR in env: + job_version = env.get(BASTION_JOB_VERSION_ENV_VAR) + self.assertEqual(job_version, labels.get(BASTION_JOB_VERSION_LABEL, None)) + else: + self.assertNotIn(BASTION_JOB_VERSION_LABEL, labels) + if enable_tpu_smart_repair: self.assertIn( "cloud.google.com/gke-tpu-auto-restart", diff --git a/axlearn/cloud/gcp/jobs/gke_runner.py b/axlearn/cloud/gcp/jobs/gke_runner.py index fc186925b..a1dc18392 100644 --- a/axlearn/cloud/gcp/jobs/gke_runner.py +++ b/axlearn/cloud/gcp/jobs/gke_runner.py @@ -32,7 +32,11 @@ import kubernetes as k8s from absl import app, flags, logging -from axlearn.cloud.common.bastion import JobLifecycleEvent, JobLifecycleState +from axlearn.cloud.common.bastion import ( + BASTION_JOB_VERSION_ENV_VAR, + JobLifecycleEvent, + JobLifecycleState, +) from axlearn.cloud.common.bundler import get_bundler_config from axlearn.cloud.common.event_queue import BaseQueueClient from axlearn.cloud.common.utils import ( @@ -44,7 +48,7 @@ from axlearn.cloud.gcp.bundler import ArtifactRegistryBundler from axlearn.cloud.gcp.config import gcp_settings from axlearn.cloud.gcp.event_queue import event_queue_from_config -from axlearn.cloud.gcp.job import GCPJob, GKEJob, GPUGKEJob, TPUGKEJob +from axlearn.cloud.gcp.job import BASTION_JOB_VERSION_LABEL, GCPJob, GKEJob, GPUGKEJob, TPUGKEJob from axlearn.cloud.gcp.jobs import runner_utils from axlearn.cloud.gcp.jobs.tpu_runner import with_tpu_training_defaults from axlearn.cloud.gcp.node_pool import ( @@ -82,6 +86,21 @@ def _infer_reservation(jobset_spec: dict) -> Optional[str]: return None +def _infer_job_version(jobset_spec: dict) -> Optional[int]: + """Infers job version given a jobset spec.""" + try: + for job in jobset_spec["replicatedJobs"]: + labels = job["template"]["spec"]["template"]["metadata"]["labels"] + # If any job has a job version label, return it. + job_version = labels.get(BASTION_JOB_VERSION_LABEL, None) + + if job_version is not None: + return int(job_version) + except (TypeError, KeyError) as e: + logging.warning("Failed to infer job version: %s.", e) + return None + + class GKERunnerJob(GCPJob): """Launches and monitors a GKE job via k8s JobSet API.""" @@ -231,6 +250,7 @@ class Status(enum.Enum): STARTUPPOLICYCOMPLETED: JobSet completed StartupPolicy. READY: JobSet is ready (all Jobs are ready). SUCCEEDED: JobSet succeeded (all Jobs succeeded). Typically also manifests as COMPLETED. + UPDATING: Job will be relaunched with new specs. RESCHEDULED: Job was rescheduled onto a different tier. """ @@ -243,6 +263,7 @@ class Status(enum.Enum): STARTUPPOLICYCOMPLETED = "STARTUPPOLICYCOMPLETED" READY = "READY" SUCCEEDED = "SUCCEEDED" + UPDATING = "UPDATING" RESCHEDULED = "RESCHEDULED" # TODO(markblee): Consider moving some of the logic here into the inner impl. @@ -261,6 +282,20 @@ def _get_status(self) -> Status: if runner_utils.should_recreate_job(tier, reservation): return GKERunnerJob.Status.RESCHEDULED + expected_job_version = os.environ.get(BASTION_JOB_VERSION_ENV_VAR, None) + current_job_version = _infer_job_version(resp["spec"]) + + # If the job is expected to run with a newer version, relaunch it. + if expected_job_version is not None and ( + current_job_version is None or int(expected_job_version) > current_job_version + ): + logging.info( + "Current job version is %s; expected job version is %s", + current_job_version, + expected_job_version, + ) + return GKERunnerJob.Status.UPDATING + # According to stogner@google.com, it's possible for "conditions" to be missing until # the overall jobset has completed. However, if the jobset does complete, "conditions" # should be a reliable indicator of overall completion status. @@ -428,6 +463,9 @@ def _execute(self): elif status == GKERunnerJob.Status.RESCHEDULED: logging.info("Jobset does not match scheduling tier. Rescheduling the jobset...") self._reschedule() + elif status == GKERunnerJob.Status.UPDATING: + logging.info("Newer job version is available. Relaunching the jobset...") + self._inner._delete() # pylint: disable=protected-access elif status == GKERunnerJob.Status.NOT_STARTED: logging.info("Task does not exist. Submitting it now...") # Only bundle on first start, not if we're resuming monitoring. @@ -546,7 +584,7 @@ def _delete_k8s_jobset_and_node_pools( @catch_auth def main(argv: Sequence[str], *, flag_values: flags.FlagValues = FLAGS): - action = parse_action(argv, options=["start", "list", "stop"]) + action = parse_action(argv, options=["start", "update", "list", "stop"]) project = gcp_settings("project", fv=flag_values) zone = gcp_settings("zone", fv=flag_values) @@ -554,7 +592,7 @@ def main(argv: Sequence[str], *, flag_values: flags.FlagValues = FLAGS): load_kube_config(project=project, zone=zone, cluster=cluster) - if action == "start": + if action in ("start", "update"): command = " ".join(argv[2:]) if not command: raise app.UsageError("Command is required.") diff --git a/axlearn/cloud/gcp/jobs/gke_runner_test.py b/axlearn/cloud/gcp/jobs/gke_runner_test.py index 5bee068a7..ea59ec28d 100644 --- a/axlearn/cloud/gcp/jobs/gke_runner_test.py +++ b/axlearn/cloud/gcp/jobs/gke_runner_test.py @@ -12,16 +12,27 @@ from absl import app, flags from absl.testing import parameterized +from axlearn.cloud.common.bastion import BASTION_JOB_VERSION_ENV_VAR from axlearn.cloud.gcp import bundler, node_pool_provisioner -from axlearn.cloud.gcp.job import GPUGKEJob, TPUGKEJob +from axlearn.cloud.gcp.job import BASTION_JOB_VERSION_LABEL, GPUGKEJob, TPUGKEJob from axlearn.cloud.gcp.jobs import gke_runner from axlearn.cloud.gcp.jobs.bastion_vm_test import _mock_job -from axlearn.cloud.gcp.jobs.gke_runner import _get_runner_or_exit, _infer_reservation +from axlearn.cloud.gcp.jobs.gke_runner import ( + _get_runner_or_exit, + _infer_job_version, + _infer_reservation, +) from axlearn.cloud.gcp.node_pool import PRE_PROVISIONER_LABEL from axlearn.cloud.gcp.test_utils import mock_gcp_settings -def _mock_replicated_jobs(reservations: Sequence[str]): +def _mock_replicated_jobs(reservations: Sequence[str], bastion_job_version: Optional[int] = None): + job_version_label = ( + {"metadata": {"labels": {BASTION_JOB_VERSION_LABEL: str(bastion_job_version)}}} + if bastion_job_version + else {} + ) + return [ { "template": { @@ -34,6 +45,7 @@ def _mock_replicated_jobs(reservations: Sequence[str]): else {"cloud.google.com/gke-spot": "true"} ) }, + **job_version_label, }, } } @@ -304,11 +316,33 @@ def test_exit(self, status, enable_pre_provisioner): def test_infer_reservation(self, status: dict, expected: Optional[str] = None): self.assertEqual(expected, _infer_reservation(status)) + @parameterized.parameters( + dict( + status=dict( + replicatedJobs=_mock_replicated_jobs(["test-reservation"], bastion_job_version=None) + ), + expected=None, + ), + dict( + status=dict( + replicatedJobs=_mock_replicated_jobs(["test-reservation"], bastion_job_version=1) + ), + expected=1, + ), + dict( + status=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"])), + expected=None, + ), + ) + def test_infer_job_version(self, status: dict, expected: Optional[str] = None): + self.assertEqual(expected, _infer_job_version(status)) + @parameterized.product( ( # Conditions is set, so we use it. dict( tier=None, + job_version=None, status=dict( conditions=[ dict(type="COMPLETED", status="TRUE"), @@ -321,6 +355,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Ignore conditions with status.lower() != "true". dict( tier=None, + job_version=None, status=dict( conditions=[ dict(type="COMPLETED", status="FALSE"), @@ -334,6 +369,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Missing conditions entirely, fallback to child job statuses. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=1, succeeded=0), @@ -347,6 +383,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Ignore conditions with status.lower() != "true". dict( tier=None, + job_version=None, status=dict( conditions=[dict(type="COMPLETED", status="FALSE")], replicatedJobsStatus=[ @@ -361,6 +398,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # or until replicated job statuses change. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=1, ready=1, succeeded=0), @@ -373,6 +411,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # At least one job failed without conditions, and tier does not match. dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=1, ready=1, succeeded=0), @@ -385,6 +424,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Number of replicated job statuses do not match slices. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=1, succeeded=0), @@ -397,6 +437,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # All replicated jobs succeeded. No need to wait for jobset conditions. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=0, succeeded=2), @@ -409,6 +450,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Ignore active and missing statuses. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(active=1, ready=1), @@ -421,6 +463,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Missing jobset is reported as "not started". dict( tier=None, + job_version=None, status=k8s.client.exceptions.ApiException(status=404), spec=None, num_slices=1, @@ -429,6 +472,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # All statuses are 0. dict( tier=None, + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=0, succeeded=0), @@ -441,6 +485,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # All statuses are 0 and tiers do not match (thus will be recreated). dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(failed=0, ready=0, succeeded=0), @@ -453,6 +498,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Jobset reservation and bastion tier do not match. dict( tier="1", + job_version=None, status={}, spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"])), num_slices=2, @@ -461,6 +507,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Jobset reservation and bastion tier do not match. dict( tier="1", + job_version=None, status={}, spec=dict(replicatedJobs=_mock_replicated_jobs(["spot", "test-reservation"])), num_slices=2, @@ -470,6 +517,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # In this case, we allow the job to keep running. dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(active=2, ready=2), @@ -482,6 +530,7 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): # Missing reservation / invalid spec will be treated as spot. dict( tier="0", + job_version=None, status=dict( replicatedJobsStatus=[ dict(active=2, ready=2), @@ -491,6 +540,58 @@ def test_infer_reservation(self, status: dict, expected: Optional[str] = None): num_slices=2, expected=gke_runner.GKERunnerJob.Status.READY, ), + # Job version has increased from None. + dict( + tier="0", + job_version=1, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], None)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.UPDATING, + ), + # Job version has increased from a non-None number. + dict( + tier="0", + job_version=4, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], 3)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.UPDATING, + ), + # Job version has decreased, in which case, no update. + dict( + tier="0", + job_version=1, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], 2)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.READY, + ), + # Job version is set to None, in which case, no update. + dict( + tier="0", + job_version=None, + status=dict( + replicatedJobsStatus=[ + dict(active=1, ready=1), + ], + ), + spec=dict(replicatedJobs=_mock_replicated_jobs(["test-reservation"], 2)), + num_slices=1, + expected=gke_runner.GKERunnerJob.Status.READY, + ), ), enable_pre_provisioner=(None, False, True), ) @@ -500,6 +601,7 @@ def test_get_status( num_slices: int, expected: gke_runner.GKERunnerJob.Status, tier: str, + job_version: Optional[int], spec: dict, enable_pre_provisioner: Optional[bool] = None, ): @@ -519,7 +621,9 @@ def test_get_status( mock_get_status = mock.Mock(return_value=dict(status=status, spec=spec)) with ( - mock.patch.dict("os.environ", {"BASTION_TIER": tier}), + mock.patch.dict( + "os.environ", {"BASTION_TIER": tier, BASTION_JOB_VERSION_ENV_VAR: job_version} + ), mock.patch( "kubernetes.client.CustomObjectsApi", return_value=mock.Mock(get_namespaced_custom_object_status=mock_get_status), @@ -835,6 +939,47 @@ def test_start(self, enable_pre_provisioner): job._inner.execute.assert_called() # pytype: disable=attribute-error + @parameterized.parameters(None, False, True) + def test_update(self, enable_pre_provisioner): + with self._job_config( + name="test-name", + cluster="test-cluster", + service_account="test-sa", + enable_pre_provisioner=enable_pre_provisioner, + ) as ( + cfg, + _, + ): + cfg.bundler.set(image="test") + + job: gke_runner.TPUGKERunnerJob = cfg.set( + command="", + status_interval_seconds=0, + enable_pre_provisioner=enable_pre_provisioner, + ).instantiate() + + mock_job = mock.patch.multiple( + job, + _get_status=mock.Mock( + side_effect=[ + gke_runner.GKERunnerJob.Status.UPDATING, + gke_runner.GKERunnerJob.Status.COMPLETED, + ] + ), + _get_job_credentials=mock.DEFAULT, + _delete=mock.DEFAULT, + _inner=mock.DEFAULT, + _pre_provisioner=mock.DEFAULT, + ) + + with mock_job: + job._execute() + + # pytype: disable=attribute-error + job._pre_provisioner.delete_for.assert_not_called() + job._inner._delete.assert_called() + # pytype: enable=attribute-error + class MainTest(parameterized.TestCase): """Tests CLI entrypoint.""" @@ -857,7 +1002,7 @@ def test_get_runner_or_exit(self, instance_type: str, expected: Union[Exception, dict(runner=gke_runner.TPUGKERunnerJob, instance_type="tpu-v4-8"), dict(runner=gke_runner.GPUGKERunnerJob, instance_type="gpu-a3-highgpu-8g-256"), ], - action=["start", "stop"], + action=["start", "stop", "update"], ) def test_load_kube_config(self, action, runner, instance_type): # load_kube_config should only be called if using gke action. diff --git a/axlearn/cloud/gcp/jobs/launch.py b/axlearn/cloud/gcp/jobs/launch.py index d3f70c0a6..48d7cb3ce 100644 --- a/axlearn/cloud/gcp/jobs/launch.py +++ b/axlearn/cloud/gcp/jobs/launch.py @@ -10,9 +10,10 @@ that decides, for a given CLI action (e.g. 'start') and instance type (e.g. 'tpu-v4-8'), whether the launcher can be used. See `_LAUNCHERS` for a full list, and `BastionManagedTPUJob` for an example. -Possible actions: [start|stop|list] +Possible actions: [start|update|stop|list] Start: submits a job to the queue. + Update: updates a job without resubmission. Stop: stops the job or removes a job from the queue. List: lists jobs and their statuses. @@ -41,9 +42,27 @@ --bundler_spec=dockerfile=Dockerfile \ --bundler_spec=build_arg1=my-build-arg ... + # Update an existing job without resubmission. + axlearn gcp launch update --instance_type=tpu-v4-32 ... -- python3 my_script2.py + # To stop a job. axlearn gcp launch stop --name=... --instance_type=tpu +More on the Update command: + + The update command allows updating bundles and job command of an existing job + without resubmission. It currently only works with axlearn.cloud.gcp.jobs.gke_runner. + + Resource related flags including instance_type, num_replicas and enable_pre_provisioner + are not allowed to change. + + When bundles are updated before the job update, job will run with new bundles. + If bundle update is not desired, use `--bundler_spec=skip_bundle=True` flag + to skip bundle update. + + To be able to update the job without re-provisioning the resources (e.g. TPU node pools), + use `--enable_pre_provisioner` to submit the job. + """ # pylint: disable=redefined-outer-name,protected-access @@ -83,6 +102,7 @@ project_usage_table, serialized_flags_for_job, user_usage_table, + validate_resource_flags, with_k8s_jobset_state, with_qrm_tpu_state, ) @@ -256,8 +276,8 @@ def from_flags(cls, fv: flags.FlagValues, *, command: str, action: str, **kwargs # We use the bundler defined by the runner impl, ensuring that bundling is consistent # between local and bastion. cfg.bundler = None - # Construct runner only for start. - if action == "start": + # Construct runner only for start and update. + if action in ("start", "update"): cfg.runner = cls.runner.from_flags(fv, command=command) runner_flags = " ".join(serialized_flags_for_job(fv, cls.runner)) cfg.command = f"python3 -m {cls.runner.__module__} {action} {runner_flags} -- {command}" @@ -353,6 +373,41 @@ def _execute(self) -> JobSpec: ) return jobspec + def _update(self) -> JobSpec: + """Update an existing job without resubmission. + + This will fetch the existing job from Bastion, change + the trainer command, increment the version in metadata, and then update the job on Bastion. + + The resource related flags including instance_type, num_replicas and enable_pre_provisioner + are not allowed to change. + """ + cfg: BaseBastionManagedJob.Config = self.config + + # Get current job spec. + job_spec = self._bastion_dir.get_job(job_name=cfg.name) + + if self._runner and self._runner.bundler: + self._runner.bundler.bundle(cfg.name) + + logging.info("Starting update for job name %s", cfg.name) + logging.info("Command: %s", cfg.command) + + # Update the job version. + job_version = job_spec.metadata.version or 0 + job_spec.metadata.version = job_version + 1 + + # The resource related flags are not allowed to change. + validate_resource_flags(job_spec.command, cfg.command) + + job_spec.command = cfg.command + + logging.info("Updated jobspec: %s", job_spec) + + jobspec = self._bastion_dir.update_job(cfg.name, job_spec=job_spec) + + return jobspec + # TODO(markblee): Add a BastionManagedCPUJob. class BastionManagedTPUJob(BaseBastionManagedJob): @@ -451,7 +506,7 @@ def define_flags(cls, fv: flags.FlagValues): @classmethod def from_flags(cls, fv: flags.FlagValues, *, command: str, action: str, **kwargs) -> Config: # Set default docker flags. These will automatically propagate to the runner on the bastion. - if action == "start": + if action in ("start", "update"): fv.set_default("bundler_type", CloudBuildBundler.TYPE) cfg: BastionManagedGKEJob.Config = super().from_flags( fv, command=command, action=action, **kwargs @@ -523,12 +578,12 @@ def _execute(self) -> JobSpec: Launcher( job_cls=BastionManagedGKEJob.with_runner(gke_runner.TPUGKERunnerJob), matcher=config_for_function(match_by_regex).set( - match_regex=dict(start=r"tpu-v.+-(\d)+", list=r"tpu.*", stop=r"tpu.*"), + match_regex=dict(start=r"tpu-v.+-(\d)+", update=r"tpu.*", list=r"tpu.*", stop=r"tpu.*"), gcp_api=GCPAPI.GKE.value, ), description=( "Supports launching TPU jobs via GKE. " - "For 'start', provide --gcp_api=gke, as well as the full instance type, " + "For 'start' or 'update', provide --gcp_api=gke, as well as the full instance type, " "e.g. --instance_type=tpu-v4-8. " "For 'list' or 'stop', provide --gcp_api=gke as well as the accelerator type, " "e.g. --instance_type=tpu." @@ -576,7 +631,7 @@ def main(_): if FLAGS.instance_type is None: raise app.UsageError("--instance_type is required.") - action = parse_action(sys.argv, options=["start", "stop", "list"], default="start") + action = parse_action(sys.argv, options=["start", "stop", "update", "list"], default="start") launcher = _get_launcher_or_exit( action=action, instance_type=FLAGS.instance_type, @@ -604,6 +659,8 @@ def main(_): job._list() elif action == "stop": job._delete() + elif action == "update": + job._update() else: raise app.UsageError(f"Unsupported action {action}") @@ -635,7 +692,9 @@ def _private_flags(): # Allow instance_type to be None when running --help without any flags. On the other hand, if # instance_type is provided when running --help, we show additional help info. if FLAGS.instance_type: - action = parse_action(sys.argv, options=["start", "stop", "list"], default="start") + action = parse_action( + sys.argv, options=["start", "update", "stop", "list"], default="start" + ) launcher = _get_launcher_or_exit( action=action, instance_type=FLAGS.instance_type, diff --git a/axlearn/cloud/gcp/jobs/launch_test.py b/axlearn/cloud/gcp/jobs/launch_test.py index 776896f19..c8991221d 100644 --- a/axlearn/cloud/gcp/jobs/launch_test.py +++ b/axlearn/cloud/gcp/jobs/launch_test.py @@ -4,6 +4,7 @@ # pylint: disable=protected-access import contextlib +import copy from datetime import datetime from typing import Optional from unittest import mock @@ -17,6 +18,7 @@ from axlearn.cloud.common.bundler import BUNDLE_EXCLUDE from axlearn.cloud.common.job import Job from axlearn.cloud.common.scheduler import JobMetadata +from axlearn.cloud.common.types import JobSpec from axlearn.cloud.gcp import bundler from axlearn.cloud.gcp import job as gcp_job from axlearn.cloud.gcp.jobs import bastion_vm, gke_runner, launch, tpu_runner @@ -461,7 +463,7 @@ class TestBastionManagedGKEJob(TestWithTemporaryCWD): cluster="test-cluster", ), ], - action=["start", "list"], + action=["start", "list", "update"], ) def test_tpu_flags( self, @@ -550,7 +552,7 @@ def test_tpu_flags( cfg = tpu_gke_job.from_flags(fv, **from_flags_kwargs) self.assertIsNone(cfg.bundler) - if action == "start": + if action in ("start", "update"): self.assertIsNotNone(cfg.runner) self.assertIsNotNone(cfg.runner.bundler) self.assertIn("tpu", cfg.runner.bundler.extras) @@ -581,7 +583,7 @@ def test_tpu_flags( # Test infer tpu resources. self.assertEqual({"v4": 16}, maybe_instantiate(cfg.resources)) - if action == "start": + if action in ("start", "update"): # Make sure command is expected. for flag in ["name", "bundler_type", "instance_type"]: if fv[flag].value is not None: @@ -601,7 +603,7 @@ def test_tpu_flags( ) # Bundler should be propagated to runner. - if action == "start": + if action in ("start", "update"): self.assertIsNotNone(job.runner.bundler) @parameterized.parameters( @@ -638,3 +640,48 @@ class FakeBastionDirectory(BastionDirectory): else: mock_execute.assert_called_once() self.assertIsNotNone(job_spec) + + @parameterized.parameters(None, 0, 1) + def test_update(self, job_version): + job_name = "test_job0" + + job_spec = new_jobspec( + name=job_name, + command="command", + metadata=JobMetadata( + user_id="test_user", + project_id="test_project", + creation_time=datetime.now(), + resources={"v4": 8}, + job_id="test-id0", + version=job_version, + ), + ) + + class FakeBastionDirectory(BastionDirectory): + def get_job(self, job_name: str) -> JobSpec: + return copy.deepcopy(job_spec) + + def update_job(self, job_name: str, *, job_spec: JobSpec) -> JobSpec: + return job_spec + + tpu_gke_job = BastionManagedGKEJob.with_runner(_DummyRunner) + cfg = tpu_gke_job.default_config().set( + **_common_bastion_managed_job_kwargs(), + namespace="default", + project="test-project", + cluster="test-cluster", + bastion_dir=FakeBastionDirectory.default_config().set(root_dir="temp_dir"), + ) + cfg.set(name=job_name) + patch_kube_config = mock.patch(f"{launch.__name__}.load_kube_config") + + with patch_kube_config: + job: BastionManagedGKEJob = cfg.instantiate() + + # Update the job. + updated_job_spec = job._update() + + updated_version = (job_spec.metadata.version or 0) + 1 + + self.assertEqual(updated_job_spec.metadata.version, updated_version) diff --git a/axlearn/cloud/gcp/jobs/launch_utils.py b/axlearn/cloud/gcp/jobs/launch_utils.py index ccfd4f003..756a45b3a 100644 --- a/axlearn/cloud/gcp/jobs/launch_utils.py +++ b/axlearn/cloud/gcp/jobs/launch_utils.py @@ -5,6 +5,7 @@ import collections import json import re +import shlex from typing import Any, Optional, Protocol from absl import flags @@ -253,3 +254,55 @@ def _k8s_jobset_state_from_jobs( else: states.append("PENDING") return states + + +def _parse_resource_flags_from_command(command: str) -> flags.FlagValues: + """Infer resources flags from launch command. + + It parses the resources flags from the command. + + Args: + command: The launch command of a job. + + Returns: + A flags.FlagValues containing the parsed resources flags. + """ + commands = shlex.split(command) + + fv = flags.FlagValues() + flags.DEFINE_string("instance_type", default=None, help="", flag_values=fv) + flags.DEFINE_integer("num_replicas", default=None, help="", flag_values=fv) + flags.DEFINE_boolean("enable_pre_provisioner", default=None, help="", flag_values=fv) + flags.DEFINE_alias("num_slices", "num_replicas", flag_values=fv) + flags.DEFINE_alias("tpu_type", "instance_type", flag_values=fv) + fv(commands, known_only=True) + + return fv + + +def validate_resource_flags(original_command: str, updated_command: str): + """Raise an exception if the resource flags are different + in the original and updated commands.""" + + original_parsed_flags = _parse_resource_flags_from_command(original_command) + updated_parsed_flags = _parse_resource_flags_from_command(updated_command) + + original_instance_type = original_parsed_flags.instance_type or original_parsed_flags.tpu_type + updated_instance_type = updated_parsed_flags.instance_type or updated_parsed_flags.tpu_type + + original_num_replicas = original_parsed_flags.num_replicas or original_parsed_flags.num_slices + updated_num_replicas = updated_parsed_flags.num_replicas or updated_parsed_flags.num_slices + + original_pre_provisioner = original_parsed_flags.enable_pre_provisioner + updated_pre_provisioner = updated_parsed_flags.enable_pre_provisioner + + if original_instance_type != updated_instance_type: + raise ValueError(f"Expected {original_instance_type=} to match {updated_instance_type=}.") + + if original_num_replicas != updated_num_replicas: + raise ValueError(f"Expected {original_num_replicas=} to match {updated_num_replicas=}.") + + if original_pre_provisioner != updated_pre_provisioner: + raise ValueError( + f"Expected {original_pre_provisioner=} to match {updated_pre_provisioner=}." + ) diff --git a/axlearn/cloud/gcp/jobs/launch_utils_test.py b/axlearn/cloud/gcp/jobs/launch_utils_test.py index dbe6a5733..07c280f3b 100644 --- a/axlearn/cloud/gcp/jobs/launch_utils_test.py +++ b/axlearn/cloud/gcp/jobs/launch_utils_test.py @@ -3,10 +3,12 @@ """Tests launch utilities.""" # pylint: disable=protected-access +import contextlib import dataclasses import json from datetime import datetime from types import SimpleNamespace +from typing import Union from unittest import mock from absl import flags @@ -18,11 +20,13 @@ from axlearn.cloud.common.utils import Table from axlearn.cloud.gcp.jobs import launch_utils from axlearn.cloud.gcp.jobs.launch_utils import ( + _parse_resource_flags_from_command, jobs_table, match_by_regex, project_usage_table, serialized_flags_for_job, user_usage_table, + validate_resource_flags, with_k8s_jobset_state, with_qrm_tpu_state, ) @@ -120,6 +124,88 @@ def test_match_by_regex(self, matcher, cases): ), ) + @parameterized.parameters( + dict( + command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update -" + "-enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + enable_pre_provisioner=True, + instance_type="tpu-v5litepod-16", + num_replicas=1, + ), + dict( + command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--noenable_pre_provisioner --tpu_type=tpu-v5litepod-32 --num_slices=2 " + "-- sleep infinity", + enable_pre_provisioner=False, + instance_type="tpu-v5litepod-32", + num_replicas=2, + ), + dict( + command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--tpu_type=tpu-v5litepod-32 --num_slices=2 " + "-- sleep infinity", + enable_pre_provisioner=None, + instance_type="tpu-v5litepod-32", + num_replicas=2, + ), + ) + def test_parse_resource_flags_from_command( + self, command, enable_pre_provisioner, instance_type, num_replicas + ): + parsed_flags = _parse_resource_flags_from_command(command) + + self.assertEqual(parsed_flags.enable_pre_provisioner, enable_pre_provisioner) + self.assertEqual(parsed_flags.instance_type, instance_type) + self.assertEqual(parsed_flags.num_replicas, num_replicas) + + @parameterized.parameters( + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + updated_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep 30", + expected=None, + ), + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + updated_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-32 --num_replicas=1 " + "-- sleep infinity", + expected=ValueError("instance_type"), + ), + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + updated_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update " + "--enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_slices=2 " + "-- sleep infinity", + expected=ValueError("num_replicas"), + ), + dict( + original_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update" + " --instance_type=tpu-v5litepod-16 --num_replicas=1 -- sleep infinity", + updated_command="python3 -m axlearn.cloud.gcp.jobs.gke_runner update" + " --enable_pre_provisioner --instance_type=tpu-v5litepod-16 --num_replicas=1 " + "-- sleep infinity", + expected=ValueError("pre_provisioner"), + ), + ) + def test_validate_resource_flags( + self, original_command, updated_command, expected: Union[Exception, type] + ): + if isinstance(expected, Exception): + ctx = self.assertRaisesRegex(type(expected), str(expected)) + else: + ctx = contextlib.nullcontext() + with ctx: + validate_resource_flags(original_command, updated_command) + class TestListUtils(parameterized.TestCase): """Tests list utils.""" From 14054b4ebb1fbee3cf02b308b14a2926ab6d82b6 Mon Sep 17 00:00:00 2001 From: Matthew Hopkins Date: Wed, 6 Nov 2024 11:35:21 -0800 Subject: [PATCH 08/27] upgrade to jax 0.4.34 (#817) * upgrade jax to 0.4.34 * add workaround for change to jax cluster autodetection --- CHANGELOG.md | 5 +++++ axlearn/common/learner_test.py | 2 +- axlearn/common/optimizers.py | 14 +++++++++++--- axlearn/common/update_transformation_test.py | 12 +++++------- axlearn/common/utils_spmd.py | 4 ++++ pyproject.toml | 10 +++++----- 6 files changed, 31 insertions(+), 16 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 9a9c7eac0..5fc2dff7c 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,10 @@ # Change Log +## 0.1.4 + +* Changes + * Upgrade Jax from 0.4.33 to 0.4.34. + ## 0.1.3 * Changes diff --git a/axlearn/common/learner_test.py b/axlearn/common/learner_test.py index 6ca7e3b2d..1a11a4797 100644 --- a/axlearn/common/learner_test.py +++ b/axlearn/common/learner_test.py @@ -1219,7 +1219,7 @@ def test_learner_masking(test_self): pre-existing `CompositeLearner` implementation. """ - updates = axlearn.common.update_transformation_test.mock_updates() + updates = axlearn.common.update_transformation_test.mock_updates(state_param_none=False) param_keys = updates.opt_params.keys() state_keys = updates.inplace_updates.keys() diff --git a/axlearn/common/optimizers.py b/axlearn/common/optimizers.py index bd507c258..70517ad59 100644 --- a/axlearn/common/optimizers.py +++ b/axlearn/common/optimizers.py @@ -544,11 +544,13 @@ def update_fn(updates: NestedTensor, state: AddDecayedWeightsState, params: Nest lr_scale = lr**learning_rate_exponent param_scales = _weight_decay_scales(params, per_param_scale=per_param_scale) + f = lambda g, p, s: g + weight_decay * lr_scale * p.value * s updates = jax.tree.map( - lambda g, p, s: g + weight_decay * lr_scale * p.value * s, + lambda x, y, z: None if x is None else f(x, y, z), updates, params, param_scales, + is_leaf=lambda x: x is None, ) if learning_rate_exponent is None: updated_state = state @@ -1882,9 +1884,10 @@ def _smoothed_updates( # First compute raw updates. raw_updates, pps_tree = _split_update_results( jax.tree.map( - lambda g, s: _raw_updates(grad=g, pps=s), + lambda g, s: None if g is None else _raw_updates(grad=g, pps=s), grads, state.pps, + is_leaf=lambda x: x is None, ) ) # Clip raw updates if necessary. @@ -1966,7 +1969,12 @@ def _update2(u: Tensor, param: OptParam): context.add_summary("weight_decay_rate", weight_decay * schedule_scale) return -schedule_scale * updates_with_wd - updates2 = jax.tree.map(lambda u, p: _update2(u, param=p), updates, params) + updates2 = jax.tree.map( + lambda u, p: None if u is None else _update2(u, param=p), + updates, + params, + is_leaf=lambda x: x is None, + ) return updates2, optax.safe_int32_increment(step) # Stage 1. diff --git a/axlearn/common/update_transformation_test.py b/axlearn/common/update_transformation_test.py index bcfc49d1e..b29062a6a 100644 --- a/axlearn/common/update_transformation_test.py +++ b/axlearn/common/update_transformation_test.py @@ -166,9 +166,11 @@ def mock_params() -> Nested[Tensor]: ) -def mock_updates() -> axlearn.common.update_transformation.Updates: +def mock_updates(state_param_none: bool = True) -> axlearn.common.update_transformation.Updates: """Create an updates object with various semi-reasonable values.""" model_params = mock_params() + if state_param_none: + model_params["state"] = None opt_params = jax.tree.map( lambda p: OptParam( value=p, @@ -197,6 +199,7 @@ def test_param_values(self): updates = mock_updates() actual = updates.param_values() expected = mock_params() + expected["state"] = None chex.assert_trees_all_equal_structs(actual, expected) self.assertNestedAllClose(actual, expected) @@ -218,12 +221,7 @@ def test_param_specs(self): weight_decay_scale=0.1, ) ), - state=ParameterSpec( - shape=(2,), - dtype=jnp.int32, - factorization=FactorizationSpec([None]), - weight_decay_scale=0.1, - ), + state=None, more_state=ParameterSpec( shape=(3,), dtype=jnp.int32, diff --git a/axlearn/common/utils_spmd.py b/axlearn/common/utils_spmd.py index bc0007cef..cfbbb7456 100644 --- a/axlearn/common/utils_spmd.py +++ b/axlearn/common/utils_spmd.py @@ -87,6 +87,10 @@ def setup( num_processes=num_processes, process_id=process_id, ) + if jax_backend == "gpu": + # jax 0.4.34 introduced a change to cluster auto-detection behavior, supplying + # local_device_ids arg allows us to maintain expected behavior + init_kwargs["local_device_ids"] = list(range(8)) jax.distributed.initialize(**init_kwargs) _jax_distributed_initialized = True diff --git a/pyproject.toml b/pyproject.toml index 54329b6b1..076c4706a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "flit_core.buildapi" [project] name = "axlearn" -version = "0.1.3" +version = "0.1.4" description = "AXLearn" readme = "README.md" requires-python = ">=3.10" @@ -23,8 +23,8 @@ core = [ "absl-py==2.1.0", "chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25. "importlab==0.7", # breaks pytype on 0.8 - "jax==0.4.33", - "jaxlib==0.4.33", + "jax==0.4.34", + "jaxlib==0.4.34", "nltk==3.7", # for text preprocessing "optax==0.1.7", # optimizers (0.1.0 has known bugs). "portpicker", @@ -101,7 +101,7 @@ gcp = [ # Note: Specify -f https://storage.googleapis.com/jax-releases/libtpu_releases.html during install. tpu = [ "axlearn[gcp]", - "jax[tpu]==0.4.33", # must be >=0.4.19 for compat with v5p. + "jax[tpu]==0.4.34", # must be >=0.4.19 for compat with v5p. ] # Vertex AI tensorboard. TODO(markblee): Merge with `gcp`. vertexai_tensorboard = [ @@ -125,7 +125,7 @@ dataflow = [ # GPU custom kernel dependency. gpu = [ "triton==2.1.0", - "jax[cuda12_pip]==0.4.33", + "jax[cuda12]==0.4.34", ] # Open API inference. open_api = [ From 33ec1526b8e6d7482a1860f7a6638f694b3dce64 Mon Sep 17 00:00:00 2001 From: Chang Lan Date: Wed, 6 Nov 2024 16:34:59 -0800 Subject: [PATCH 09/27] fix: Instantiate logits_modifier config in sample_decode (#819) --- axlearn/common/decoder.py | 3 ++- axlearn/common/decoder_test.py | 7 ++++--- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/axlearn/common/decoder.py b/axlearn/common/decoder.py index 5c4b659da..941347e5c 100644 --- a/axlearn/common/decoder.py +++ b/axlearn/common/decoder.py @@ -280,11 +280,12 @@ def sample_decode( The sample decoding outputs. """ cfg: DecodingLayer.Config = self.config + logits_modifier = maybe_instantiate(logits_modifier) tokens_to_scores_fn = self._tokens_to_scores( num_decodes=num_decodes, cross_attention_data=cross_attention_data, cross_attention_logit_biases=cross_attention_logit_biases, - logits_modifier=maybe_instantiate(logits_modifier), + logits_modifier=logits_modifier, ) input_ids = self._pad( prefix, max_sequence_length=max_sequence_length, pad_id=cfg.pad_token_id diff --git a/axlearn/common/decoder_test.py b/axlearn/common/decoder_test.py index 279e6dc59..6eef860d1 100644 --- a/axlearn/common/decoder_test.py +++ b/axlearn/common/decoder_test.py @@ -574,9 +574,10 @@ def test_decode( if method == "sample_decode": # Modify logits so that we will always sample the last token ID. - inputs["logits_modifier"] = ( - lambda logits: jnp.full_like(logits, decoding.NEG_INF).at[:, -1].set(0) - ) + def logits_modifier_fn(): + return lambda logits: jnp.full_like(logits, decoding.NEG_INF).at[:, -1].set(0) + + inputs["logits_modifier"] = config_for_function(logits_modifier_fn) # pylint: disable=protected-access mock_ctx = contextlib.nullcontext() From 6d45610fb17db0cb20de0627e1b588e8e3a639d3 Mon Sep 17 00:00:00 2001 From: qdavid1 <168590940+qdavid1@users.noreply.github.com> Date: Fri, 8 Nov 2024 11:24:03 -0800 Subject: [PATCH 10/27] Stack transformer with skip connection (#821) --- axlearn/common/attention.py | 69 +++++++++++++++++++---------- axlearn/common/attention_test.py | 75 ++++++++++++++++++-------------- 2 files changed, 89 insertions(+), 55 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index fb49d507c..7236ef6e4 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -3601,6 +3601,45 @@ class Config(BaseTransformerLayer.Config): peak_stochastic_depth_rate: Optional[float] = None +class UpdateDataFn(Protocol): + """A function for updating the constituent layers' input in a StackTransformerLayer.""" + + def __call__( + self, data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output] + ) -> Tensor: + """Returns a new Tensor with the same shape as `data`, reflecting some desired updates. + + Args: + data: A Tensor denoting the input data to the upcoming layer. + all_layer_outputs: A list of BaseTransformerLayer.Output that is appended with + the output of each constituent layer in the stack. + + Returns: + A new Tensor with the same shape as `data`. + """ + + +def update_data_with_skip_connection(skip_connections: dict[int, int]) -> UpdateDataFn: + """Creates a function that adds skip connection to the input data tensor. + + Args: + skip_connections: A dictionary where keys and values represent 0-indexed layer indices. + For a (k, v) pair, the output of the v-th layer will be added to the input + of the k-th layer. + + Returns: + A function that implements skip connections, following the UpdateDataFn protocol, . + """ + + def update_data(data: Tensor, all_layer_outputs: list[BaseTransformerLayer.Output]) -> Tensor: + layer_index = len(all_layer_outputs) + if layer_index in skip_connections: + data += all_layer_outputs[skip_connections[layer_index]].data + return data + + return update_data + + class StackedTransformerLayer(BaseStackedTransformerLayer): """A simple implementation of BaseStackedTransformerLayer.""" @@ -3613,10 +3652,15 @@ class Config(BaseStackedTransformerLayer.Config): layer: Union[ BaseTransformerLayer.Config, Sequence[BaseTransformerLayer.Config] ] = TransformerLayer.default_config() + # If set, implements the UpdateDataFn protocol to update individual layers' input + # data in some specified way. This operation is applied before calling every layer. + data_merger: Optional[InstantiableConfig[UpdateDataFn]] = None def __init__(self, cfg: Config, *, parent: Optional[Module]): super().__init__(cfg, parent=parent) cfg = self.config + self._update_data = maybe_instantiate(cfg.data_merger) + if isinstance(cfg.layer, Sequence): layer_cfgs = cfg.layer if len(layer_cfgs) != cfg.num_layers: @@ -3685,7 +3729,8 @@ def _forward_for_mode( all_layer_states = [] for i, layer in enumerate(self._layers): # Prepare inputs to the current layer. - data = self._update_data(data, all_layer_outputs=all_layer_outputs) + if self._update_data is not None: + data = self._update_data(data, all_layer_outputs) self._update_layer_kwargs(layer_kwargs, all_layer_outputs=all_layer_outputs) if mode == ForwardMode.FORWARD: @@ -3712,28 +3757,6 @@ def _forward_for_mode( return all_layer_states, self._aggregate_layer_outputs(all_layer_outputs) - def _update_data( - self, - data: Tensor, - *, - all_layer_outputs: list[BaseTransformerLayer.Output], - ): - """Updates `data` using other args. - - This method is called before we invoke each layer in `self._layers`. - The updated data will be passed to the layer invocation. - - Args: - data: A Tensor denoting the input data to the upcoming layer. - all_layer_outputs: A list of BaseTransformerLayer.Output that is appended with - the output of each constituent layer in the stack. - - Returns: - A new Tensor. - """ - del all_layer_outputs - return data - def _update_layer_kwargs( self, layer_kwargs: dict[str, Any], diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 12fbbf335..b8fda30a4 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -78,6 +78,7 @@ set_double_shard_weights_config, sinusoidal_positional_embeddings, sliding_window_causal_mask, + update_data_with_skip_connection, xl_attention_logits, ) from axlearn.common.base_layer import ( @@ -3761,22 +3762,6 @@ def _aggregate_layer_outputs( ) -class TestStackedTransformerLayerWithDataOverride(NonUniformStack): - """A class with a simple override of _update_data for unit testing.""" - - @property - def forced_input(self): - return jnp.ones((2, 3, 4)) - - def _update_data( - self, - data: Tensor, - *, - all_layer_outputs: list[BaseTransformerLayer.Output], - ): - return self.forced_input - - class TestStackedTransformerLayerWithKVState(NonUniformStack): """A class with a simple override of _update_layer_kwargs for unit testing.""" @@ -3793,6 +3778,16 @@ def _update_layer_kwargs( layer_kwargs["self_attention_kv_state"] = None +class TestStackedTransformerLayerWithSkipConnection(StackedTransformerLayer): + """A class that outputs all layers' output for unit testing.""" + + def _aggregate_layer_outputs( + self, + layer_outputs: Sequence[BaseTransformerLayer.Output], + ) -> Sequence[BaseTransformerLayer.Output]: + return layer_outputs + + class StackedTransformerTest(BaseTransformerTest): """Tests StackedTransformerLayer.""" @@ -4117,47 +4112,63 @@ def test_transformer_prefill_states(self, transformer_type, layer_type): assert_allclose(decoder_self_attention_probs, forward_outputs.self_attention_probs) assert_allclose(decoder_cross_attention_probs, forward_outputs.cross_attention_probs) - def test_update_data(self): + def test_skip_connection(self): batch_size = 2 seq_len = 6 num_heads = 2 input_dim = 4 hidden_dim = 8 + num_layers = 5 + layer_with_skip_input = 3 - # Create a StackedTransformerLayer by specifying a sequence of non-uniform layer configs. - cfg = TestStackedTransformerLayerWithDataOverride.default_config().set(name="test") - cfg.input_dim = input_dim - cfg.num_layers = 2 + cfg = TestStackedTransformerLayerWithSkipConnection.default_config().set( + name="test", input_dim=input_dim, num_layers=num_layers + ) transformer_cfg = TransformerLayer.default_config() transformer_cfg.self_attention.attention.num_heads = num_heads transformer_cfg.feed_forward.hidden_dim = hidden_dim cfg.layer = transformer_cfg - layer: StackedTransformerLayer = cfg.instantiate(parent=None) + test_cfg = cfg.clone().set( + data_merger=config_for_function(update_data_with_skip_connection).set( + skip_connections={layer_with_skip_input: 1} + ) + ) + + base_layer = cfg.instantiate(parent=None) + test_layer = test_cfg.instantiate(parent=None) + random_inputs = jax.random.uniform( jax.random.PRNGKey(1), shape=(batch_size, seq_len, input_dim) ) - state = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - outputs_with_random_input, _ = F( - layer, + state = base_layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + base_output, _ = F( + base_layer, is_training=True, prng_key=jax.random.PRNGKey(123), state=state, inputs=dict(data=random_inputs), ) - outputs_with_forced_input, _ = F( - layer, + test_output, _ = F( + test_layer, is_training=True, prng_key=jax.random.PRNGKey(123), state=state, - inputs=dict(data=layer.forced_input), - ) - self.assertNestedAllClose( - outputs_with_random_input.data, - outputs_with_forced_input.data, + inputs=dict(data=random_inputs), ) + for i in range(layer_with_skip_input): + self.assertNestedAllClose( + base_output[i].data, + test_output[i].data, + ) + for i in range(layer_with_skip_input, num_layers): + self.assertNotAlmostEqual( + jnp.min(jnp.abs(base_output[i].data - test_output[i].data)), + 0.0, + ) + def test_update_layer_kwargs(self): batch_size = 2 seq_len = 6 From bff39377c95f0582c914e1ec45c336a36d00f87e Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Fri, 8 Nov 2024 16:47:52 -0800 Subject: [PATCH 11/27] StackOverTime's partial frame is treated as a valid frame, similar to convolution padding. (#822) PR "Fix inconsistent paddings in conv layer. (#776)" changed the semantics so that partial frames are now considered valid. As a result, StackOverTime has been updated to follow this new semantic. Additionally, StackOverTime has been modified to support convolution padding literals such as CAUSAL, SAME, VALID. Currently, CAUSAL padding is preferred for all downstream tasks. StackOverTime with stride=4 functions as a special case of a convolution with window=4, stride=4, and CAUSAL padding. This corresponds to explicit padding of (0,3) = (window-stride, stride-1). Moreover, StackOverTime is only used in speech models to reduce the input sequence length, similar to the effect of a subsampler. For stride=4, the first reduced output frame should represent inputs[0:4], making this equivalent to a CAUSAL convolution with stride=4 and window=4. However, in downstream tasks, padding has often been incorrectly set to (0,0) or (0,4), which only works for specific sequence lengths. The padding (0,3) works universally. To simplify hyperparameter setup (as mistake prone), the default padding has been changed to CAUSAL, covering all current downstream cases. Figures for above explanation. * subsampler output len = ceil(input_length / 4) When input len is 6 and padding="SAME" or "CAUSAL" ``` 0 0 conv2 out 0 0 0 conv1 out 0 0 0 0 0 0 frontend out ``` * padding=(0,0) vs (0,3) for stride=4, ``` Inputs 1 2 3 4 5 6 7 8 padding=(0,0) 1 5 2 6 3 7 4 8 padding=(0,3) 1 5 P 2 6 P 3 7 P 4 8 ^ discarded as not complete. Inputs 1 2 3 4 5 6 7 padding=(0,0) 1 5 2 6 3 7 4 ^ discarded as not complete, which is WRONG!! padding=(0,3) 1 5 P 2 6 P 3 7 4 P ^ discarded as not complete. Inputs 1 2 3 4 5 padding=(0,0) 1 5 2 3 4 ^ discarded as not complete, which is WRONG!! padding=(0,3) 1 5 2 P 3 P 4 P ``` --- axlearn/common/layers.py | 44 +++++++++++++++++++++++------------ axlearn/common/layers_test.py | 20 +++++++--------- 2 files changed, 38 insertions(+), 26 deletions(-) diff --git a/axlearn/common/layers.py b/axlearn/common/layers.py index c8beacc24..a2f6d4190 100644 --- a/axlearn/common/layers.py +++ b/axlearn/common/layers.py @@ -19,7 +19,7 @@ import enum from collections.abc import Sequence -from typing import Any, Callable, Optional, Union +from typing import Any, Callable, Literal, Optional, Union import chex import jax @@ -2030,7 +2030,8 @@ class StackOverTime(BaseLayer): """Stack inputs along the time axis. StackOverTime behaves the same as Conv2DWith1DPadding w.r.t. paddings along the time axis. - We treat front paddings as valid frames and back paddings as invalid frames. + Please refer to the docstring of Conv2DWith1DPadding to understand how the padding work + including "SAME", "VALID", and "CAUSAL" literals. The padding anchor is set to `left padding`. """ @config_class @@ -2038,9 +2039,13 @@ class Config(BaseLayer.Config): """Configures StackOverTime.""" stride: Required[int] = REQUIRED # Number of frames to stack. - # Number of paddings to apply along the time axis. The two integers indicate - # leading and trailing padding to add respectively. - padding: tuple[int, int] = (0, 0) + + # Number of paddings to apply along the time axis. The two integers specify the amount + # of leading and trailing padding, respectively. Alternatively, this can be a + # convolution padding literals type such as 'SAME', 'VALID', or 'CAUSAL'. + # Note: For backward compatibility, the default is set to VALID, but in most cases, + # CAUSAL is more appropriate as it preserves the sequence length. + padding: Union[tuple[int, int], Literal["SAME", "VALID", "CAUSAL"]] = "VALID" def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: """Stacks stride number of frames into one frame along the time axis. @@ -2060,11 +2065,16 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: cfg = self.config if cfg.stride <= 1: raise ValueError(f"stride should be greater than 1, but got {cfg.stride}.") - inputs = jnp.pad(inputs, ((0, 0), cfg.padding, (0, 0)), constant_values=0) - # Front paddings are valid frames. - paddings = jnp.pad(paddings, ((0, 0), (cfg.padding[0], 0)), constant_values=0) - # Back paddings are invalid frames. - paddings = jnp.pad(paddings, ((0, 0), (0, cfg.padding[1])), constant_values=1) + + # For the last partial frame. + inputs = inputs * (1 - paddings)[:, :, None] + + padding = cfg.padding + if isinstance(padding, str): + padding = conv_explicit_padding( + window=(cfg.stride,), strides=(cfg.stride,), padding=padding + )[0] + inputs = jnp.pad(inputs, ((0, 0), padding, (0, 0)), constant_values=0) batch_size, seq_len, input_dim = inputs.shape output_length = seq_len // cfg.stride @@ -2072,9 +2082,8 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> tuple[Tensor, Tensor]: # Stack inputs over the time dimension. stacked_inputs = jnp.reshape(inputs[:, : output_length * cfg.stride, :], new_shape) # An output frame is padding if at least one of the stacked input frames is padding. - stacked_paddings = jnp.max( - jnp.reshape(paddings[:, : output_length * cfg.stride], [-1, output_length, cfg.stride]), - axis=-1, + stacked_paddings = compute_conv_paddings( + paddings, window=cfg.stride, stride=cfg.stride, conv_padding=(padding,) ) stacked_inputs = stacked_inputs * (1 - stacked_paddings)[:, :, None] return stacked_inputs, stacked_paddings @@ -2092,8 +2101,13 @@ def output_shape(self, *, input_shape: Sequence[Optional[int]]) -> Sequence[Opti """ cfg = self.config batch_size, seq_len, input_dim = input_shape - output_length = (seq_len + sum(cfg.padding)) // cfg.stride if seq_len is not None else None - return [batch_size, output_length, input_dim * cfg.stride] + padding = cfg.padding + if isinstance(padding, tuple): + padding = (padding,) + out_shape = conv_output_shape( + [seq_len], window=(cfg.stride,), strides=(cfg.stride,), padding=padding + ) + return [batch_size, *out_shape, input_dim * cfg.stride] class MultiLinear(BaseLayer): diff --git a/axlearn/common/layers_test.py b/axlearn/common/layers_test.py index eb0943b99..30dba4741 100644 --- a/axlearn/common/layers_test.py +++ b/axlearn/common/layers_test.py @@ -2000,8 +2000,8 @@ def test_drop_tokens(self, drop_rate, num_cls_tokens): ( 3, (0, 0), - [[[1, 1, 2, 2, 3, 3]], [[0, 0, 0, 0, 0, 0]]], - [[0], [1]], + [[[1, 1, 2, 2, 3, 3]], [[7, 7, 8, 8, 0, 0]]], + [[0], [0]], ), ( 3, @@ -2066,18 +2066,14 @@ def test_stack_over_time_data_change(self): ) output_shape = layer.output_shape(input_shape=inputs.shape) self.assertAllEqual(outputs.shape, output_shape) - self.assertAllEqual(np.array([4, 7], dtype=np.float32), np.sum(1 - output_paddings, axis=1)) - self.assertAllClose( - np.sum(inputs**2, (1, 2)), - np.sum(outputs**2, (1, 2)) + np.array([np.sum(inputs[0][8] ** 2), 0.0]), - ) + self.assertAllEqual(np.array([5, 7], dtype=np.float32), np.sum(1 - output_paddings, axis=1)) + self.assertAllClose(np.sum(inputs**2, (1, 2)), np.sum(outputs**2, (1, 2))) - @parameterized.product(stride=(2, 3, 4), pad=((0, 0), (1, 1), (2, 0))) + @parameterized.product(stride=(2, 3, 4), pad=("VALID", "SAME", "CAUSAL")) def test_stack_consistent_outputs(self, stride, pad): """Tests that StackOverTime has consistent outputs under different padding lengths.""" batch_size, input_dim = 2, 1 input_length = 7 - expected_output_length = (input_length + pad[0]) // stride layer: StackOverTime = ( StackOverTime.default_config() .set( @@ -2087,12 +2083,13 @@ def test_stack_consistent_outputs(self, stride, pad): ) .instantiate(parent=None) ) + expected_output_length = layer.output_shape(input_shape=[1, input_length, 1])[1] layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) for ll in range(4, 11): # Batch with another example of length ll. length = max(input_length, ll) inputs = jnp.ones([batch_size, length, input_dim]) - paddings = jnp.arange(length)[None, :] >= jnp.array([7, ll])[:, None] + paddings = jnp.arange(length)[None, :] >= jnp.array([input_length, ll])[:, None] (outputs, output_paddings), _ = F( layer, inputs=dict(inputs=inputs, paddings=paddings), @@ -2102,7 +2099,8 @@ def test_stack_consistent_outputs(self, stride, pad): ) output_shape = layer.output_shape(input_shape=inputs.shape) self.assertAllEqual(outputs.shape, output_shape) - self.assertEqual(expected_output_length, np.sum(1 - output_paddings, axis=1)[0]) + if pad != "VALID": # VALID doesn't preserve length. + self.assertEqual(expected_output_length, np.sum(1 - output_paddings, axis=1)[0]) @parameterized.parameters(((0, 1), (0, 0)), ((1, 1), (3, 0)), ((1, 1), (0, 3))) def test_stack_vs_conv2d_output_len_match(self, conv_padding, stack_padding): From 8bbee7f850e57af2105224c2f868553ccde0905f Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Fri, 8 Nov 2024 16:48:03 -0800 Subject: [PATCH 12/27] Quantizer does not return one-hot vectors. (#823) onehot can be extremely memory hungry in scenarios with long context and a large vocab size. For instance, if the vocab size is 256k and the context length is 32k, it would unnecessarily consume 32GB of memory per batch (256k * 32k * 4 bytes). Note: ids is 128kB (32k * 4 bytes). For similar reasons, we avoid using one-hot vectors during codebook lookup and instead use advanced indexing for memory efficiency. Note: The change in the unittest is unrelated to any logic changes. The difference is likely due to not creating temporary tensors in forward(), which altered the random values of the parameters. --- axlearn/common/quantizer.py | 64 ++++++++++++++------------ axlearn/common/quantizer_test.py | 36 ++++----------- axlearn/vision/beit_image_tokenizer.py | 9 +++- 3 files changed, 51 insertions(+), 58 deletions(-) diff --git a/axlearn/common/quantizer.py b/axlearn/common/quantizer.py index 199b2b9f4..373a2b059 100644 --- a/axlearn/common/quantizer.py +++ b/axlearn/common/quantizer.py @@ -105,8 +105,6 @@ class Config(BaseLayer.Config): class Output(NamedTuple): # [..., num_codebooks]. ids: Tensor - # [..., num_codebooks, codebook_size]. - onehots: Tensor # [..., num_codebooks, codebook_dim]. quantized_vectors: Tensor # Scalar of quantizer loss. @@ -147,25 +145,28 @@ def _lookup(*, ids: Tensor, codebook: Tensor) -> BaseQuantizer.Output: """Codebook look up with ids. Args: - ids: integer tensor of shape [batch_size, seq_len, num_codebooks] with values + ids: integer tensor of shape [..., num_codebooks] with values in range [0, codebook_size). codebook: Tensor of shape [codebook_size, num_codebooks, codebook_dim]. Returns: - BaseQuantizer.Output. + BaseQuantizer.Output + * ids: Tensor [..., num_codebooks]. + * quantized_vectors: Tensor [..., num_codebooks, codebook_dim]. Raises: NotImplementedError: if ids.ndim > 11. """ if ids.ndim - 1 > len(_einsum_dims): raise NotImplementedError(ids.shape) - # [..., num_codebooks, vocab_size]. - onehots = jax.nn.one_hot(ids, num_classes=codebook.shape[0], axis=-1, dtype=codebook.dtype) - batch_dims = _einsum_dims[: onehots.ndim - 2] - quantized_vectors = jnp.einsum(f"{batch_dims}gv,vgh->{batch_dims}gh", onehots, codebook) + + # [..., num_codebooks] + g_index = jnp.expand_dims(jnp.arange(ids.shape[-1]), axis=tuple(range(ids.ndim - 1))) + # codebook: [codebook_size, num_codebooks, codebook_dim], ids: [..., num_codebooks] + # -> [..., num_codebooks, codebook_dim] + quantized_vectors = codebook[ids, g_index] return BaseQuantizer.Output( ids=ids, - onehots=onehots, quantized_vectors=quantized_vectors, ) @@ -236,19 +237,20 @@ def _apply_paddings(*, outputs: BaseQuantizer.Output, paddings: Tensor) -> BaseQ """ # ids are padded with -1. ids = outputs.ids * (1 - paddings)[:, :, None] + (-1) * paddings[:, :, None] - onehots = outputs.onehots * (1 - paddings)[:, :, None, None] quantized_vectors = outputs.quantized_vectors * (1 - paddings)[:, :, None, None] return BaseQuantizer.Output( ids=ids, - onehots=onehots, quantized_vectors=quantized_vectors, loss=outputs.loss, ) -def _add_codebook_summaries( - *, context: InvocationContext, outputs: BaseQuantizer.Output, paddings: Tensor -): +def _ids_to_onehots(ids: Tensor, *, codebook_size: int, dtype: jnp.dtype) -> Tensor: + # [..., num_codebooks, codebook_size]. + return jax.nn.one_hot(ids, num_classes=codebook_size, axis=-1, dtype=dtype) + + +def _add_codebook_summaries(*, context: InvocationContext, onehots: Tensor, paddings: Tensor): """Helper function to compute codebook distribution statistics and add to summaries. The statistics are from all frames, not only on those masked frames in self-supervised training. @@ -256,11 +258,11 @@ def _add_codebook_summaries( Args: context: Module invocation context to add summaries to. - outputs: BaseQuantizer.Output. + onehots: onehot of BaseQuantizer.Output.ids. paddings: 0/1 tensor of shape [batch_size, seq_len], where 0 is valid position. """ - coverage = compute_code_coverage(onehots=outputs.onehots, paddings=paddings) - pplx, entropy = compute_code_pplx(onehots=outputs.onehots, paddings=paddings) + coverage = compute_code_coverage(onehots=onehots, paddings=paddings) + pplx, entropy = compute_code_pplx(onehots=onehots, paddings=paddings) batch_size = paddings.shape[0] num_frames = jnp.sum(1 - paddings) @@ -368,18 +370,19 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> BaseQuantizer.Output: q_outputs = _apply_paddings(outputs=q_outputs, paddings=paddings) # Best-rq freezes the codebook. ids = jax.lax.stop_gradient(q_outputs.ids) - onehots = jax.lax.stop_gradient(q_outputs.onehots) quantized_vectors = jax.lax.stop_gradient(q_outputs.quantized_vectors) outputs = self.Output( # [batch_size, seq_len, num_codebooks]. ids=ids, - # [batch_size, seq_len, num_codebooks, codebook_size]. - onehots=onehots, # [batch_size, seq_len, num_codebooks, codebook_dim]. quantized_vectors=quantized_vectors, ) - _add_codebook_summaries(context=current_context(), outputs=outputs, paddings=paddings) + + onehots = _ids_to_onehots( + outputs.ids, codebook_size=cfg.codebook_size, dtype=paddings.dtype + ) + _add_codebook_summaries(context=current_context(), onehots=onehots, paddings=paddings) return outputs @@ -519,15 +522,16 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> BaseQuantizer.Output: outputs = self.Output( # [batch_size, seq_len, num_codebooks]. ids=quantized_inputs.ids, - # [batch_size, seq_len, num_codebooks, vocab_size]. - onehots=quantized_inputs.onehots, # [batch_size, seq_len, num_codebooks, codebook_dim]. quantized_vectors=jnp.reshape( quantized_vectors, [batch_size, seq_len, cfg.num_codebooks, cfg.codebook_dim] ), loss=total_loss, ) - _add_codebook_summaries(context=current_context(), outputs=outputs, paddings=paddings) + onehots = _ids_to_onehots( + outputs.ids, codebook_size=cfg.codebook_size, dtype=paddings.dtype + ) + _add_codebook_summaries(context=current_context(), onehots=onehots, paddings=paddings) return outputs @@ -620,10 +624,9 @@ def forward( # [batch_size, seq_len, 1]. mask = (1 - paddings)[:, :, None] ids = ids * mask + (-1) * (1 - mask) + # TODO(dhwang2): optimize memory by scan for long context training. # [batch_size, seq_len, num_codebooks, vocab_size]. - onehots = jax.nn.one_hot( - ids, num_classes=cfg.codebook_size, axis=-1, dtype=inputs.dtype - ) + onehots = _ids_to_onehots(ids, codebook_size=cfg.codebook_size, dtype=inputs.dtype) # We need this to stop gradients on the padded frames. onehots = onehots * mask[:, :, :, None] # [batch_size, seq_len, num_codebooks, vocab_size]. @@ -640,13 +643,14 @@ def forward( outputs = self.Output( # [batch_size, seq_len, num_codebooks]. ids=ids, - # [batch_size, seq_len, num_codebooks, vocab_size]. - onehots=onehots, # [batch_size, seq_len, num_codebooks, codebook_dim]. quantized_vectors=quantized_vectors, ) - _add_codebook_summaries(context=current_context(), outputs=outputs, paddings=paddings) + onehots = _ids_to_onehots( + outputs.ids, codebook_size=cfg.codebook_size, dtype=paddings.dtype + ) + _add_codebook_summaries(context=current_context(), onehots=onehots, paddings=paddings) if self.is_training: self.add_module_output("probs", y_soft) self.add_summary("codebook/temperature_schedule_step", self.parameters["step"]) diff --git a/axlearn/common/quantizer_test.py b/axlearn/common/quantizer_test.py index 893dc9dad..465eb4b29 100644 --- a/axlearn/common/quantizer_test.py +++ b/axlearn/common/quantizer_test.py @@ -29,6 +29,7 @@ KmeansVectorQuantizer, RandomVectorQuantizer, SimilarityMetric, + _ids_to_onehots, compute_code_coverage, compute_code_pplx, quantize_by_nearest_neighbor, @@ -86,12 +87,13 @@ def test_quantize(self, num_groups, input_mean, metric): inputs=inputs, codebook=codebook, metric=metric ) # Compute codebook metrics. - coverage = compute_code_coverage(onehots=q_outputs.onehots, paddings=paddings) - pplx, entropy = compute_code_pplx(onehots=q_outputs.onehots, paddings=paddings) + onehots = _ids_to_onehots(q_outputs.ids, codebook_size=vocab_size, dtype=paddings.dtype) + coverage = compute_code_coverage(onehots=onehots, paddings=paddings) + pplx, entropy = compute_code_pplx(onehots=onehots, paddings=paddings) # Check shapes. self.assertEqual(q_outputs.ids.shape, (batch_size, seq_len, num_groups)) - self.assertEqual(q_outputs.onehots.shape, (batch_size, seq_len, num_groups, vocab_size)) + self.assertEqual(onehots.shape, (batch_size, seq_len, num_groups, vocab_size)) self.assertEqual( q_outputs.quantized_vectors.shape, (batch_size, seq_len, num_groups, codebook_dim) ) @@ -314,7 +316,7 @@ def test_forward( np.sum(layer_params["codebook"] ** 2), expected_values[batch_size][normalize_codebook]["codebook"], atol=1e-6, - rtol=1e-6, + rtol=2e-6, ) np.random.seed(2022) @@ -332,7 +334,6 @@ def test_forward( q_outputs.quantized_vectors.shape, ) self.assertEqual((batch_size, seq_len, num_groups), q_outputs.ids.shape) - self.assertEqual((batch_size, seq_len, num_groups, vocab_size), q_outputs.onehots.shape) assert_allclose( np.sum( jnp.reshape( @@ -349,12 +350,6 @@ def test_forward( atol=1e-6, rtol=1e-6, ) - assert_allclose( - np.sum(q_outputs.onehots), - expected_values[batch_size][normalize_codebook]["onehots"], - atol=1e-6, - rtol=1e-6, - ) self.assertEqual( output_collections.summaries["codebook/num_frames"].mean, jnp.sum(1 - paddings) / batch_size, @@ -472,7 +467,6 @@ def test_forward(self, num_groups, input_mean): outputs.quantized_vectors.shape, ) self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) - self.assertEqual((batch_size, seq_len, num_groups, vocab_size), outputs.onehots.shape) assert_allclose( expected_outputs[num_groups][input_mean][0], @@ -619,12 +613,6 @@ def _loss(params, inputs, paddings, layer=layer): atol=1e-6, rtol=1e-6, ) - assert_allclose( - outputs.onehots * paddings[:, :, None, None], - jnp.zeros_like(outputs.onehots), - atol=1e-6, - rtol=1e-6, - ) assert_allclose( outputs.quantized_vectors * paddings[:, :, None, None], jnp.zeros_like(outputs.quantized_vectors), @@ -653,7 +641,8 @@ def _loss(params, inputs, paddings, layer=layer): # [batch_size, seq_len, num_groups, dim]. # Gradient w.r.t codebook comes from kmeans_loss. grad_kmeans = -jnp.reshape(grad_l2_loss, [batch_size, seq_len, num_groups, codebook_dim]) - expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_kmeans, outputs.onehots) + onehots = _ids_to_onehots(outputs.ids, codebook_size=vocab_size, dtype=grad_kmeans.dtype) + expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_kmeans, onehots) self.assertNestedAllClose(grad_params, dict(codebook=expected_grad_codebook)) @@ -706,12 +695,6 @@ def test_forward(self, is_training): atol=1e-6, rtol=1e-6, ) - assert_allclose( - outputs.onehots * paddings[:, :, None, None], - jnp.zeros_like(outputs.onehots), - atol=1e-6, - rtol=1e-6, - ) assert_allclose( outputs.quantized_vectors * paddings[:, :, None, None], jnp.zeros_like(outputs.quantized_vectors), @@ -849,7 +832,8 @@ def _loss(params, inputs, paddings, layer=layer): # [batch_size, seq_len, num_groups, dim]. # Gradient w.r.t codebook. - expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_q_vecs, outputs.onehots) + onehots = _ids_to_onehots(outputs.ids, codebook_size=vocab_size, dtype=grad_q_vecs.dtype) + expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_q_vecs, onehots) assert_allclose(grad_params["codebook"], expected_grad_codebook, atol=1e-6, rtol=1e-6) diff --git a/axlearn/vision/beit_image_tokenizer.py b/axlearn/vision/beit_image_tokenizer.py index 19a7a7f1a..d655c5e3f 100644 --- a/axlearn/vision/beit_image_tokenizer.py +++ b/axlearn/vision/beit_image_tokenizer.py @@ -185,11 +185,16 @@ def forward(self, inputs: Tensor) -> tuple[Tensor, dict[str, Tensor]]: paddings = jnp.zeros(encoded_outputs.shape[:2]) quantized_output = self.quantizer(inputs=encoded_outputs, paddings=paddings) # quantized_output.quantized_vectors shape [batch_size, seq_len, 1, codebook_dim] - # quantized_output.onehots in shape [batch_size, seq_len, 1, codebook_size] # quantized_output.ids in shape [batch_size, seq_len, 1] + onehots = jax.nn.one_hot( + quantized_output.ids, + num_classes=self.config.quantizer.codebook_size, + axis=-1, + dtype=paddings.dtype, + ) return jnp.squeeze(quantized_output.ids, axis=-1), { "quantized_vectors": jnp.squeeze(quantized_output.quantized_vectors, axis=-2), - "quantized_codebook_onehots": jnp.squeeze(quantized_output.onehots, axis=-2), + "quantized_codebook_onehots": jnp.squeeze(onehots, axis=-2), } From 7145ac0465b83d432bd7b7e6d37a8bf6a5d22616 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Fri, 8 Nov 2024 17:07:58 -0800 Subject: [PATCH 13/27] Introduce `model_analysis.txt` in trainer. (#824) trainer saves `model_analysis.txt` to show model parameters details. e.g. ``` 16 [16] fc/bias 48 (3, 16) fc/weight Total number of model params: 64 State: prng_key=uint32((4,)) mesh_axes=ParameterSpec(shape=[4], dtype=, mesh_axes=PartitionSpec(None,), initializer=None, factorization=None, fan_axes=None, weight_decay_scale=None) State: model/fc/bias=float32((16,)) mesh_axes=ParameterSpec(shape=[16], dtype=, mesh_axes=PartitionSpec('model',), initializer=None, factorization=None, fan_axes=None, weight_decay_scale=None) State: model/fc/weight=float32((3, 16)) mesh_axes=ParameterSpec(shape=(3, 16), dtype=, mesh_axes=PartitionSpec(None, 'model'), initializer=None, factorization=FactorizationSpec(axes=('row', 'col')), fan_axes=FanAxes(in_axis=-2, out_axis=-1, batch_axis=()), weight_decay_scale=None) State: learner/optimizer/0/trace/fc/bias=float32((16,)) mesh_axes=TensorSpec(shape=[16], dtype=, mesh_axes=PartitionSpec('model',)) State: learner/optimizer/0/trace/fc/weight=float32((3, 16)) mesh_axes=TensorSpec(shape=(3, 16), dtype=, mesh_axes=PartitionSpec(None, 'model')) State: learner/optimizer/2/count=int32(()) mesh_axes=TensorSpec(shape=[], dtype=, mesh_axes=PartitionSpec()) Training state size: 0.00 GiB Training state size (partitioned): 0.00 GiB Max training state size (partitioned): 0.00 GiB ``` Note: the functionality refers to print_model_analysis.py --- axlearn/common/trainer.py | 40 ++++++++++++++++++++++++++-------- axlearn/common/trainer_test.py | 8 +++++++ 2 files changed, 39 insertions(+), 9 deletions(-) diff --git a/axlearn/common/trainer.py b/axlearn/common/trainer.py index 6ad7720f0..a60560769 100644 --- a/axlearn/common/trainer.py +++ b/axlearn/common/trainer.py @@ -753,17 +753,32 @@ def _init_state(prng_key: Tensor) -> TrainerState: learner=initialized_trainer_state.learner, ) - def _log_trainer_state_stats(self): + def _log_trainer_state_stats(self) -> str: total_num_params = count_model_params(self._trainer_state.model) - self._step_log("Total number of model params: %s", f"{total_num_params:,}") + analysis_logs = [] + + def _step_log(msg, *args, **kwargs): + self._step_log(msg, *args, **kwargs) + analysis_logs.append(msg % args) + + _step_log("##################### Model analysis #####################\n") + _step_log("## Parameters:") + fmt = "%10d %-20s %s" + flatten_name_and_spec = flatten_items(self._model_param_specs) + for name, spec in flatten_name_and_spec: + spec_size = np.prod(spec.shape) + _step_log(fmt, spec_size, spec.shape, name) + + _step_log("Total number of model params: %s", f"{total_num_params:,}") self.summary_writer(0, {"num_model_params": total_num_params}) + _step_log("\n## Trainer States:") # Training state size. total_state_bytes = 0 total_sharded_state_bytes = 0 state_spec_map = dict(utils.flatten_items(self.trainer_state_specs)) for path, value in utils.flatten_items(self._trainer_state): - self._step_log( + _step_log( "State: %s=%s(%s) mesh_axes=%s", path, value.dtype, @@ -780,7 +795,7 @@ def _log_trainer_state_stats(self): else: max_sharded_state_gb = total_sharded_state_gb - self._step_log( + _step_log( "Training state size: %.2f GiB\n" "Training state size (partitioned): %.2f GiB\n" "Max training state size (partitioned): %.2f GiB", @@ -789,6 +804,9 @@ def _log_trainer_state_stats(self): max_sharded_state_gb, ) + _step_log("\n##########################################################") + return "\n".join(analysis_logs) + def _prepare_training(self, prng_key: Tensor) -> bool: """Prepares training. @@ -819,12 +837,16 @@ def _prepare_training(self, prng_key: Tensor) -> bool: # Note the default checkpointer and evaler do nothing at step 0 with min_step=1. self.save_checkpoint(self._run_eval()) - # Log trainer state tree. - if jax.process_index() == 0: - with fs.open(os.path.join(cfg.dir, "trainer_state_tree.txt"), "w") as f: - f.write(str(jax.tree_util.tree_structure(self._trainer_state))) + model_analysis = self._log_trainer_state_stats() + + # Log trainer state tree. + if not self.step and jax.process_index() == 0: + with fs.open(os.path.join(cfg.dir, "trainer_state_tree.txt"), "w") as f: + f.write(str(jax.tree_util.tree_structure(self._trainer_state))) + + with fs.open(os.path.join(cfg.dir, "model_analysis.txt"), "w") as f: + f.write(model_analysis) - self._log_trainer_state_stats() # Log config. self.summary_writer.log_config(cfg, step=self.step) diff --git a/axlearn/common/trainer_test.py b/axlearn/common/trainer_test.py index 149864b82..39177429e 100644 --- a/axlearn/common/trainer_test.py +++ b/axlearn/common/trainer_test.py @@ -417,6 +417,11 @@ def test_trainer( with open(os.path.join(cfg.dir, "trainer_state_tree.txt"), encoding="utf-8") as f: self.assertStartsWith(f.read(), "PyTreeDef(CustomNode(namedtuple[TrainerState], [*, ") + with open(os.path.join(cfg.dir, "model_analysis.txt"), encoding="utf-8") as f: + self.assertStartsWith( + f.read(), "##################### Model analysis #####################" + ) + if start_trace_steps: trace_dir = os.path.join(cfg.dir, "summaries", "train_train", "plugins", "profile") profile_files = [] @@ -856,6 +861,7 @@ def test_run_builder(self, restore_from_builder: bool): first_output = trainer.run(prng_key=jax.random.PRNGKey(123)) assert os.path.exists(os.path.join(cfg.dir, "trainer_state_tree.txt")) + assert os.path.exists(os.path.join(cfg.dir, "model_analysis.txt")) # Make sure checkpoint exists. trainer2: SpmdTrainer = cfg.instantiate(parent=None) with trainer2.mesh(): @@ -931,6 +937,7 @@ def fn(*, step: int, evaler_summaries: dict[str, Any]): trainer.run(prng_key=jax.random.PRNGKey(123)) assert os.path.exists(os.path.join(cfg.dir, "trainer_state_tree.txt")) + assert os.path.exists(os.path.join(cfg.dir, "model_analysis.txt")) trainer2: SpmdTrainer = cfg.clone(save_input_iterator=restore_input_iterator).instantiate( parent=None ) @@ -972,6 +979,7 @@ def test_last_step_checkpoint_policy(self): trainer.run(prng_key=jax.random.PRNGKey(123)) assert os.path.exists(os.path.join(cfg.dir, "trainer_state_tree.txt")) + assert os.path.exists(os.path.join(cfg.dir, "model_analysis.txt")) trainer2: SpmdTrainer = cfg.instantiate(parent=None) with trainer2.mesh(): # We should have checkpointed at the last step. From e6528a1a432fbddd9aa4b995de62980e444fbf61 Mon Sep 17 00:00:00 2001 From: Mark Lee Date: Mon, 11 Nov 2024 16:52:22 -0800 Subject: [PATCH 14/27] Remove cleanup on save. (#825) --- axlearn/common/checkpointer.py | 1 - axlearn/common/checkpointer_test.py | 41 +++++++++++++++++++++++++++++ 2 files changed, 41 insertions(+), 1 deletion(-) diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index 0eb48ae3a..4f1f06e92 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -954,7 +954,6 @@ def save( if step < 0 or step >= 10**8: raise ValueError(f"Out-of-range: {step}") ckpt_dir = self.ckpt_dir(step) - self.cleanup_checkpoint(ckpt_dir) self._storage.save_to_dir( step=step, state=state, ckpt_dir=ckpt_dir, on_commit_callback=write_index_file ) diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index 7b485cd18..9c117e1fb 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -243,6 +243,47 @@ def create_corrupt_ckpt(step): ckpt.wait_until_finished() self.assertNestedEqual((3, state0), ckpt.restore(step=None, state=state0)) + @parameterized.parameters(Checkpointer, OrbaxCheckpointer) + def test_save_can_override_on_gcs(self, checkpointer_cls: Type[BaseCheckpointer]): + mesh_shape = (1, 1) + if not test_utils.is_supported_mesh_shape(mesh_shape): + return + # Patch is_gcs_path for orbax, since it commits differently on gcs vs local. + with _mesh(mesh_shape), mock.patch(f"{ocp.step.__name__}.is_gcs_path", return_value=True): + cfg = _checkpointer_config(checkpointer_cls) + ckpt: BaseCheckpointer = cfg.instantiate(parent=None) + state0 = dict(x=jnp.zeros([], dtype=jnp.int32), y=jnp.ones([2], dtype=jnp.float32)) + + # Save a checkpoint. + ckpt.save(step=1, state=state0) + ckpt.wait_until_finished() + self.assertNestedEqual((1, state0), ckpt.restore(step=None, state=state0)) + + if isinstance(ckpt, (Checkpointer, OrbaxCheckpointer)): + ckpt_dir = ckpt.ckpt_dir(step=1) + else: + raise NotImplementedError(type(ckpt)) + + # Corrupt the checkpoint by removing some files, while ensuring it is non-empty. + commit_file = ( + "index" if isinstance(ckpt, Checkpointer) else ocp.step._COMMIT_SUCCESS_FILE + ) + fs.rmtree(os.path.join(ckpt_dir, commit_file)) + self.assertGreater(len(fs.listdir(ckpt_dir)), 0) + + if isinstance(ckpt, OrbaxCheckpointer): + ckpt._manager.reload() # Orbax caches complete checkpoints. + + self.assertEqual(0, len(ckpt.checkpoint_paths(ckpt.config.dir))) + + # Test that save() should be able to override non-empty ckpt dir. + state1 = dict(x=jnp.ones([], dtype=jnp.int32), y=jnp.zeros([2], dtype=jnp.float32)) + ckpt.save(step=1, state=state1) + ckpt.wait_until_finished() + + # Should match the new state. + self.assertNestedEqual((1, state1), ckpt.restore(step=None, state=state1)) + @parameterized.product( checkpointer_cls=[Checkpointer, OrbaxCheckpointer], mesh_shape=[(1, 1), (2, 2), (4, 2)], From 40cd1967815566a39cef167676c6c9819e830581 Mon Sep 17 00:00:00 2001 From: Zhaoyi Zhang Date: Mon, 11 Nov 2024 17:51:32 -0800 Subject: [PATCH 15/27] [GKE]: support priority class (#828) * [GKE]: support priority class * more comments --------- Co-authored-by: Zhaoyi Zhang --- axlearn/cloud/gcp/job.py | 45 ++++++++++++++++++++++------------- axlearn/cloud/gcp/job_test.py | 20 ++++++++++++++-- 2 files changed, 47 insertions(+), 18 deletions(-) diff --git a/axlearn/cloud/gcp/job.py b/axlearn/cloud/gcp/job.py index 66dcff767..08073acd4 100644 --- a/axlearn/cloud/gcp/job.py +++ b/axlearn/cloud/gcp/job.py @@ -352,6 +352,13 @@ class Config(GCPJob.Config): Each host's output will be placed in `"{output_dir}/output/$HOSTNAME/"`. This directory is used by the sidecar container to sync outputs to GCS using gsutil. Ensure that `output_dir` is a valid GCS path (e.g., `gs://your-bucket/path`). + priority_class: Optional; The GKE PriorityClass for the job. + https://kubernetes.io/docs/concepts/scheduling-eviction/pod-priority-preemption + Note: 1. Values need to be pre-defined in each cluster. + 2. Job level priority is enforced by pod level priority of the leader pod. + This is managed by jobset controller. + 3. For TPU slice, this requires alpha.jobset.sigs.k8s.io/exclusive-topology + 4. [2024-11-11] Does not work on multi-slice TPU training yet. host_mounts: List of volumes from host to mount into the container. See `HostMount` for details. """ @@ -363,6 +370,7 @@ class Config(GCPJob.Config): enable_pre_provisioner: Optional[bool] = None queue: Optional[str] = None output_dir: Optional[str] = None + priority_class: Optional[str] = None host_mounts: Optional[list[HostMount]] = None @classmethod @@ -739,24 +747,29 @@ def _build_pod(self) -> Nested[Any]: } ) + spec = dict( + # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. + terminationGracePeriodSeconds=60, + # Fail if any pod fails, and allow retries to happen at JobSet level. + restartPolicy="Never", + nodeSelector={ + "cloud.google.com/gke-tpu-accelerator": system.gke_accelerator, + "cloud.google.com/gke-tpu-topology": system.topology, + **selector, + }, + tolerations=tolerations, + containers=[self._build_container()], + initContainers=[self._build_uploader_container()], + serviceAccountName=cfg.service_account, + volumes=volumes, + ) + + if cfg.priority_class: + spec["priorityClassName"] = cfg.priority_class + return dict( metadata=dict(annotations=annotations, labels=labels), - spec=dict( - # NOTE: Don't set hostNetwork or dnsPolicy for compat with Workload Identity. - terminationGracePeriodSeconds=60, - # Fail if any pod fails, and allow retries to happen at JobSet level. - restartPolicy="Never", - nodeSelector={ - "cloud.google.com/gke-tpu-accelerator": system.gke_accelerator, - "cloud.google.com/gke-tpu-topology": system.topology, - **selector, - }, - tolerations=tolerations, - containers=[self._build_container()], - initContainers=[self._build_uploader_container()], - serviceAccountName=cfg.service_account, - volumes=volumes, - ), + spec=spec, ) def _build_job(self) -> Nested[Any]: diff --git a/axlearn/cloud/gcp/job_test.py b/axlearn/cloud/gcp/job_test.py index 5b8b4174a..51b114c6b 100644 --- a/axlearn/cloud/gcp/job_test.py +++ b/axlearn/cloud/gcp/job_test.py @@ -246,6 +246,7 @@ def _job_config( service_account: Optional[str] = None, enable_pre_provisioner: Optional[bool] = None, host_mount_spec: Optional[list[str]] = None, + priority_class: Optional[str] = None, ): with mock_gcp_settings([job.__name__, bundler.__name__], self._mock_settings): fv = flags.FlagValues() @@ -261,6 +262,7 @@ def _job_config( cfg.bundler = bundler_cls.from_spec([], fv=fv).set(image="test-image") cfg.accelerator.instance_type = "tpu-v4-8" cfg.enable_pre_provisioner = enable_pre_provisioner + cfg.priority_class = priority_class yield cfg def test_mount_dataclass(self): @@ -286,7 +288,12 @@ def test_mount_dataclass(self): enable_pre_provisioner=[None, False, True], ) def test_instantiate( - self, reservation, service_account, enable_pre_provisioner, bundler_cls, wrap_bundler + self, + reservation, + service_account, + enable_pre_provisioner, + bundler_cls, + wrap_bundler, ): class WrappedBundler(Bundler): @config_class @@ -352,6 +359,7 @@ class Config(Bundler.Config): location_hint=["test-location-hint", None], enable_tpu_smart_repair=[True, False], host_mount_spec=[["name=host-mount,host_path=/tmp,mount_path=/host-tmp"], None], + priority_class=[None, "such-high-priority"], ) def test_build_pod( self, @@ -364,9 +372,12 @@ def test_build_pod( location_hint: Optional[str] = None, enable_tpu_smart_repair: bool = False, host_mount_spec: Optional[list[str]] = None, + priority_class: Optional[str] = None, ): with mock.patch.dict("os.environ", env), self._job_config( - bundler_cls, host_mount_spec=host_mount_spec + bundler_cls, + host_mount_spec=host_mount_spec, + priority_class=priority_class, ) as cfg: gke_job: job.TPUGKEJob = cfg.set( reservation=reservation, @@ -539,6 +550,11 @@ def test_build_pod( ) self.assertNotIn("cloud.google.com/gke-tpu-auto-restart", labels) + if priority_class is None: + self.assertNotIn("priorityClassName", pod_spec) + else: + self.assertEqual(pod_spec.get("priorityClassName", None), priority_class) + class GPUGKEJobTest(TestCase): @property From 1aa987717f96ad7d688a7fbeafc515e7118e468e Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Mon, 11 Nov 2024 18:36:25 -0800 Subject: [PATCH 16/27] Quantizer returns ids as int32, not float32. (#826) https://github.com/apple/axlearn/pull/823 reveals this hidden bug. In addition, add `lookup` API as downstream started abusing `_lookup` function. Actually this API is needed in inference. --- axlearn/common/quantizer.py | 28 +++++- axlearn/common/quantizer_test.py | 143 ++++++++++++++++++++++++++++++- 2 files changed, 165 insertions(+), 6 deletions(-) diff --git a/axlearn/common/quantizer.py b/axlearn/common/quantizer.py index 373a2b059..fdc9fe948 100644 --- a/axlearn/common/quantizer.py +++ b/axlearn/common/quantizer.py @@ -137,9 +137,28 @@ def forward(self, inputs: Tensor, *, paddings: Tensor) -> Output: Returns: BaseQuantizer.Output. + * ids: Tensor [..., num_codebooks]. + * quantized_vectors: Tensor [..., num_codebooks, codebook_dim]. """ raise NotImplementedError(type(self)) + def lookup(self, ids: Tensor) -> Output: + """Codebook look up with ids. + + Args: + ids: integer tensor of shape [..., num_codebooks] with values + in range [0, codebook_size). + + Returns: + BaseQuantizer.Output + * ids: Tensor [..., num_codebooks]. + * quantized_vectors: Tensor [..., num_codebooks, codebook_dim]. + + Raises: + NotImplementedError: if ids.ndim > 11. + """ + return _lookup(ids=ids, codebook=self.parameters["codebook"]) + def _lookup(*, ids: Tensor, codebook: Tensor) -> BaseQuantizer.Output: """Codebook look up with ids. @@ -235,8 +254,10 @@ def _apply_paddings(*, outputs: BaseQuantizer.Output, paddings: Tensor) -> BaseQ Returns: padded_outputs: BaseQuantizer.Output. """ + # ids are padded with -1. - ids = outputs.ids * (1 - paddings)[:, :, None] + (-1) * paddings[:, :, None] + ids_paddings = paddings[:, :, None].astype(outputs.ids.dtype) + ids = outputs.ids * (1 - ids_paddings) + (-1) * ids_paddings quantized_vectors = outputs.quantized_vectors * (1 - paddings)[:, :, None, None] return BaseQuantizer.Output( ids=ids, @@ -618,16 +639,17 @@ def forward( ids = jnp.argmax(logits, axis=-1) if not self.is_training: - outputs = _lookup(ids=ids, codebook=self.parameters["codebook"]) + outputs = self.lookup(ids=ids) outputs = _apply_paddings(outputs=outputs, paddings=paddings) else: # [batch_size, seq_len, 1]. - mask = (1 - paddings)[:, :, None] + mask = (1 - paddings)[:, :, None].astype(ids.dtype) ids = ids * mask + (-1) * (1 - mask) # TODO(dhwang2): optimize memory by scan for long context training. # [batch_size, seq_len, num_codebooks, vocab_size]. onehots = _ids_to_onehots(ids, codebook_size=cfg.codebook_size, dtype=inputs.dtype) # We need this to stop gradients on the padded frames. + mask = mask.astype(inputs.dtype) onehots = onehots * mask[:, :, :, None] # [batch_size, seq_len, num_codebooks, vocab_size]. y_soft = jax.nn.softmax(logits, axis=-1) diff --git a/axlearn/common/quantizer_test.py b/axlearn/common/quantizer_test.py index 465eb4b29..af2ed4a26 100644 --- a/axlearn/common/quantizer_test.py +++ b/axlearn/common/quantizer_test.py @@ -404,9 +404,8 @@ def _loss(params, inputs, paddings, layer=layer): + o_col.summaries["codebook/entropy"].mean ) - np.random.seed(2000) - inputs = np.random.rand(batch_size, seq_len, input_dim).astype(np.float32) - paddings = np.zeros((batch_size, seq_len)).astype(np.float32) + inputs = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, seq_len, input_dim)) + paddings = jnp.zeros((batch_size, seq_len)) _, (grad_params, grad_inputs) = jax.value_and_grad(_loss, argnums=(0, 1), has_aux=False)( layer_params, jnp.asarray(inputs), jnp.asarray(paddings) @@ -414,6 +413,41 @@ def _loss(params, inputs, paddings, layer=layer): self.assertNestedAllClose(grad_params, jax.tree.map(jnp.zeros_like, layer_params)) assert_allclose(grad_inputs, jnp.zeros_like(inputs), atol=1e-6, rtol=1e-6) + def test_lookup(self): + batch_size, seq_len, input_dim = 2, 4, 20 + dim_from_all_codebooks, vocab_size, num_groups = 32, 4, 2 + cfg = RandomVectorQuantizer.default_config().set( + name="test", + input_dim=input_dim, + codebook_dim=dim_from_all_codebooks // num_groups, + codebook_size=vocab_size, + num_codebooks=num_groups, + ) + layer: RandomVectorQuantizer = cfg.instantiate(parent=None) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(1)) + inputs = jax.random.uniform(jax.random.PRNGKey(1), (batch_size, seq_len, input_dim)) + paddings = jnp.zeros((batch_size, seq_len)) + outputs, _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + ) + self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) + + lookup_outputs, _ = F( + layer, + inputs=dict(ids=outputs.ids), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + method="lookup", + ) + quantized_vectors = lookup_outputs.quantized_vectors * (1 - paddings)[:, :, None, None] + self.assertNestedAllClose(quantized_vectors, outputs.quantized_vectors) + class KmeansVectorQuantizerTest(TestCase): @parameterized.product(num_groups=(1, 2), input_mean=(0.0, -0.5)) @@ -467,6 +501,7 @@ def test_forward(self, num_groups, input_mean): outputs.quantized_vectors.shape, ) self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) assert_allclose( expected_outputs[num_groups][input_mean][0], @@ -645,6 +680,54 @@ def _loss(params, inputs, paddings, layer=layer): expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_kmeans, onehots) self.assertNestedAllClose(grad_params, dict(codebook=expected_grad_codebook)) + def test_lookup(self): + num_groups, input_mean = 2, -0.5 + vocab_size, dim_from_all_codebooks = 4, 4 + codebook_dim = dim_from_all_codebooks // num_groups + layer: KmeansVectorQuantizer = ( + KmeansVectorQuantizer.default_config() + .set( + name="test", + codebook_dim=codebook_dim, + codebook_size=vocab_size, + num_codebooks=num_groups, + beta=0.1, + ) + .instantiate(parent=None) + ) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(0)) + # [vocab_size, num_codebooks, codebook_dim]. + layer_params["codebook"] = jnp.reshape(_CODE_BOOK, [vocab_size, num_groups, codebook_dim]) + batch_size, seq_len = 2, 4 + np.random.seed(2021) + inputs = ( + np.random.rand(batch_size, seq_len, dim_from_all_codebooks).astype(np.float32) + + input_mean + ) + paddings = jnp.arange(seq_len)[None, :] >= jnp.array([2, 3])[:, None] + inputs = inputs * (1 - paddings)[:, :, None] + outputs, _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=True, + prng_key=jax.random.PRNGKey(1), + state=layer_params, + drop_output_collections=[], + ) + self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) + + lookup_outputs, _ = F( + layer, + inputs=dict(ids=outputs.ids), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + method="lookup", + ) + quantized_vectors = lookup_outputs.quantized_vectors * (1 - paddings)[:, :, None, None] + self.assertNestedAllClose(quantized_vectors, outputs.quantized_vectors) + class GumbelSoftmaxVectorQuantizerTest(TestCase): @parameterized.parameters(True, False) @@ -836,6 +919,60 @@ def _loss(params, inputs, paddings, layer=layer): expected_grad_codebook = jnp.einsum("btgh,btgv->vgh", grad_q_vecs, onehots) assert_allclose(grad_params["codebook"], expected_grad_codebook, atol=1e-6, rtol=1e-6) + def test_lookup(self): + dim_from_all_codebooks, vocab_size, num_groups = 15, 5, 3 + input_dim = 10 + step = 5 + begin_step, begin_value, end_step, end_value = 0, 21, 10, 1 + codebook_dim = dim_from_all_codebooks // num_groups + layer: GumbelSoftmaxVectorQuantizer = ( + GumbelSoftmaxVectorQuantizer.default_config() + .set( + name="test", + input_dim=input_dim, + codebook_dim=codebook_dim, + codebook_size=vocab_size, + num_codebooks=num_groups, + temperature_schedule=schedule.polynomial( + begin_step=begin_step, + begin_value=begin_value, + end_step=end_step, + end_value=end_value, + ), + ) + .instantiate(parent=None) + ) + layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) + layer_params["step"] = step + batch_size, seq_len = 2, 4 + np.random.seed(2021) + inputs = np.random.rand(batch_size, seq_len, input_dim).astype(np.float32) + paddings = np.array( + np.arange(seq_len)[None, :] >= np.array([2, 3])[:, None], dtype=np.float32 + ) + inputs = inputs * (1 - paddings)[:, :, None] + outputs, _ = F( + layer, + inputs=dict(inputs=inputs, paddings=paddings), + is_training=True, + prng_key=jax.random.PRNGKey(1), + state=layer_params, + drop_output_collections=[], + ) + self.assertEqual((batch_size, seq_len, num_groups), outputs.ids.shape) + self.assertEqual(jnp.int32, outputs.ids.dtype) + + lookup_outputs, _ = F( + layer, + inputs=dict(ids=outputs.ids), + is_training=True, + prng_key=jax.random.PRNGKey(10), + state=layer_params, + method="lookup", + ) + quantized_vectors = lookup_outputs.quantized_vectors * (1 - paddings)[:, :, None, None] + self.assertNestedAllClose(quantized_vectors, outputs.quantized_vectors) + if __name__ == "__main__": absltest.main() From 342a3a511e5d384298ef7080fc8d4c4a4c696581 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Mon, 11 Nov 2024 21:01:12 -0800 Subject: [PATCH 17/27] Add bf16 test to subsampler. (#827) There was a report about a bf16-related bug in the downstream, so I created a unit test to demonstrate that axlearn doesn't have the bug. --- axlearn/audio/subsamplers_test.py | 18 ++++++++++++++---- 1 file changed, 14 insertions(+), 4 deletions(-) diff --git a/axlearn/audio/subsamplers_test.py b/axlearn/audio/subsamplers_test.py index d90b51fb5..5d462ff5d 100644 --- a/axlearn/audio/subsamplers_test.py +++ b/axlearn/audio/subsamplers_test.py @@ -7,7 +7,7 @@ from typing import Optional, Union import jax -from absl.testing import parameterized +from absl.testing import absltest, parameterized from jax import numpy as jnp from axlearn.audio.subsamplers import ConvSubSampler @@ -187,7 +187,8 @@ def test_paddings( self.assertEqual(tuple(subsampled_shape), outputs["outputs"].shape) self.assertEqual(tuple(subsampled_shape)[:2], outputs["paddings"].shape) - def test_activation_summaries(self): + @parameterized.parameters(jnp.float32, jnp.bfloat16) + def test_activation_summaries(self, dtype): """Tests that activation summaries behave as expected.""" input_dim, num_filters, hidden_dim, output_dim = 1, 80, 12, 8 prng_key = jax.random.PRNGKey(567) @@ -195,10 +196,12 @@ def test_activation_summaries(self): # Initialize layer parameters. cfg = ConvSubSampler.default_config().set( - input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim + input_dim=input_dim, hidden_dim=hidden_dim, output_dim=output_dim, dtype=dtype ) layer = cfg.set(name="test").instantiate(parent=None) layer_params = layer.initialize_parameters_recursively(init_key) + dtypes, _ = jax.tree.flatten(jax.tree.map(jnp.dtype, layer_params)) + self.assertTrue(all(dt == dtype for dt in dtypes)) # Build inputs. batch_size, num_frames = 4, 10 @@ -206,6 +209,8 @@ def test_activation_summaries(self): inputs = jax.random.normal(key=data_key, shape=inputs_shape) * 10.0 lengths = jnp.array([5, 10, 9, 0]) paddings = jnp.arange(num_frames)[None, :] >= lengths[:, None] + inputs = inputs.astype(dtype) + paddings = paddings.astype(dtype) outputs, output_collections = F( layer, inputs=dict(inputs=inputs, paddings=paddings), @@ -247,9 +252,14 @@ def test_activation_summaries(self): expected_outputs_norm, ) self.assertNestedAllClose( - output_collections.summaries["activations/subsampler_inputs_mean"].weight, input_weights + output_collections.summaries["activations/subsampler_inputs_mean"].weight, + input_weights.astype(dtype), ) self.assertNestedAllClose( output_collections.summaries["activations/subsampler_outputs_norm"].weight, output_weights, ) + + +if __name__ == "__main__": + absltest.main() From 25025d1d6646f4f7df9177a1ac5a4600995af890 Mon Sep 17 00:00:00 2001 From: kelvin-zou <166073445+kelvin-zou@users.noreply.github.com> Date: Mon, 11 Nov 2024 21:52:16 -0800 Subject: [PATCH 18/27] Speed up FA Backward pass in GPU via parallelizing sequence dimension (#818) * snapshot * address comments --- .../common/flash_attention/gpu_attention.py | 243 +++++++++++------- 1 file changed, 147 insertions(+), 96 deletions(-) diff --git a/axlearn/common/flash_attention/gpu_attention.py b/axlearn/common/flash_attention/gpu_attention.py index 915914cc1..c1b19106e 100644 --- a/axlearn/common/flash_attention/gpu_attention.py +++ b/axlearn/common/flash_attention/gpu_attention.py @@ -264,7 +264,9 @@ def bias_index_map(_, j, k): num_warps_ = 4 if head_dim <= 64 else 8 num_stages_ = num_stages if num_stages_ is None: - num_stages_ = 2 if head_dim <= 64 else 1 + num_stages_ = ( + 2 if bias is None and jnp.float32 not in (query.dtype, key.dtype, value.dtype) else 1 + ) kernel = functools.partial( _mha_forward_kernel, softmax_scale=softmax_scale, @@ -358,7 +360,9 @@ def bias_index_map(_, j, k): num_warps_ = 4 if head_dim <= 64 else 8 num_stages_ = num_stages if num_stages_ is None: - num_stages_ = 2 if head_dim <= 64 else 1 + num_stages_ = ( + 2 if bias is None and jnp.float32 not in (query.dtype, key.dtype, value.dtype) else 1 + ) kernel = functools.partial( _mha_forward_kernel, softmax_scale=softmax_scale, @@ -494,7 +498,6 @@ def _mha_backward_kernel( l_ref, m_ref, delta_ref, - _, # Outputs. dq_ref, dk_ref, @@ -509,9 +512,14 @@ def _mha_backward_kernel( """Computes the backward pass. This algorithm is described in https://arxiv.org/abs/2205.14135 Appendix B.4 Algorithm 4. + Jax reference implementation: + https://github.com/jax-ml/jax/blob/0995bc231c51e2ee66995be8ee2b31adf9236509/jax/experimental/pallas/ops/gpu/attention.py#L343 See also `_mha_forward_kernel` for the forward pass. + The main difference between ours and jax reference implementation is that it supports 4-d bias, + and it supports float32 in the input dtype. + Args: q_ref: Input query ref. k_ref: Input key ref. @@ -536,71 +544,123 @@ def _mha_backward_kernel( del out_ref, l_ref # Not needed seq_len = q_ref.shape[0] - def outer_loop(start_k, _): - dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) - dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) + # Parallelize over k/v's seq dimension. + # Load a block of K and V of size (block_k, block_d). + # Iterate through Q in chunks of (block_q, block_d) to accumulate dK and dV. + start_k = pl.program_id(2) + slice_k = pl.ds(start_k * block_k, block_k) + dv = jnp.zeros([block_k, block_d], dtype=jnp.float32) + dk = jnp.zeros([block_k, block_d], dtype=jnp.float32) + k = pl.load(k_ref, (slice_k, slice(None))) + v = pl.load(v_ref, (slice_k, slice(None))) + span_k = start_k * block_k + jnp.arange(block_k) + kv_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_k,)) + + def inner_loop_dk_dv(start_q, carry): + dv, dk = carry + slice_q = pl.ds(start_q * block_q, block_q) + q = pl.load(q_ref, (slice_q, slice(None))) + qk = pl.dot(q, k.T) + # These casts are needed to avoid precision issues. + qk = qk.astype(jnp.float32) + + if softmax_scale != 1.0: + qk *= softmax_scale + + if b_ref is not None: + # Load bias in transposed order, for hopefully better cache efficiency. + b = pl.load( + b_ref, + (slice_k, slice_q), + ) + b = b.astype(jnp.float32) + qk += b.T # Transpose back. + if s_ref is not None: + q_segment_ids = pl.load(s_ref, (slice_q,)) + mask = _segment_mask(q_segment_ids, kv_segment_ids) + qk = jnp.where(mask, qk, NEG_INF) + if causal: + span_q = start_q * block_q + jnp.arange(block_q) + mask = span_q[:, None] >= span_k[None, :] + qk = jnp.where(mask, qk, NEG_INF) + m = pl.load(m_ref, (slice_q,)) + p = jnp.exp(qk - m[:, None]) + do = pl.load(do_scaled_ref, (slice_q, slice(None))) + dv = dv + pl.dot(p.astype(do.dtype).T, do) + di = pl.load(delta_ref, (slice_q,)) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if softmax_scale != 1.0: + ds = ds * softmax_scale + dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) + + return dv, dk + + lower_bound = lax.div(start_k * block_k, block_q) if causal else 0 + dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop_dk_dv, (dv, dk)) + pl.store(dv_ref, (slice_k, slice(None)), dv.astype(dv_ref.dtype)) + pl.store(dk_ref, (slice_k, slice(None)), dk.astype(dk_ref.dtype)) + # Free up memory. + del dv, dk + + # Parallelize over q's seq dimension. + # 1. Load a block of Q of size (block_q, block_d). + # 2. Iterate through K and V in chunks of (block_k, block_d) to accumulate dQ. + start_q = pl.program_id(2) + slice_q = pl.ds(start_q * block_q, block_q) + q = pl.load(q_ref, (slice_q, slice(None))) + dq = jnp.zeros([block_q, block_d], dtype=jnp.float32) + q_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_q,)) + span_q = start_q * block_q + jnp.arange(block_q) + m = pl.load(m_ref, (slice_q,)) + di = pl.load(delta_ref, (slice_q,)) + do = pl.load(do_scaled_ref, (slice_q, slice(None))) + + def inner_loop_dq(start_k, carry): + dq = carry slice_k = pl.ds(start_k * block_k, block_k) k = pl.load(k_ref, (slice_k, slice(None))) v = pl.load(v_ref, (slice_k, slice(None))) - span_k = start_k * block_k + jnp.arange(block_k) - kv_segment_ids = None if s_ref is None else pl.load(s_ref, (slice_k)) - - def inner_loop(start_q, carry): - dv, dk = carry - slice_q = pl.ds(start_q * block_q, block_q) - q = pl.load(q_ref, (slice_q, slice(None))) - qk = pl.dot(q, k.T) - - # These casts are needed to avoid precision issues. - qk = qk.astype(jnp.float32) - - if softmax_scale != 1.0: - qk *= softmax_scale - if b_ref is not None: - # Load bias in transposed order, for hopefully better cache efficiency. - b = pl.load( - b_ref, - (slice_k, slice_q), - ) - b = b.astype(jnp.float32) - qk += b.T # Transpose back. - if s_ref is not None: - q_segment_ids = pl.load(s_ref, (slice_q)) - mask = _segment_mask(q_segment_ids, kv_segment_ids) - qk = jnp.where(mask, qk, NEG_INF) - if causal: - span_q = start_q * block_q + jnp.arange(block_q) - mask = span_q[:, None] >= span_k[None, :] - qk = jnp.where(mask, qk, NEG_INF) - m = pl.load(m_ref, (slice_q,)) - p = jnp.exp(qk - m[:, None]) - do = pl.load(do_scaled_ref, (slice_q, slice(None))) - dv = dv + pl.dot(p.astype(do.dtype).T, do) - di = pl.load(delta_ref, (slice_q,)) - dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] - dp = dp + pl.dot(do, v.T) - ds = p * dp - if softmax_scale != 1.0: - ds = ds * softmax_scale - dk = dk + pl.dot(ds.astype(q_ref.dtype).T, q) - dq = pl.load( - dq_ref, - (slice_q, slice(None)), - eviction_policy="evict_last", - ) - dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) - pl.store(dq_ref, (slice_q, slice(None)), dq, eviction_policy="evict_last") - return dv, dk + qk = pl.dot(q, k.T) + + # These casts are needed to avoid precision issues. + qk = qk.astype(jnp.float32) + if softmax_scale != 1.0: + qk *= softmax_scale + if b_ref is not None: + # Load bias in transposed order, for hopefully better cache efficiency. + b = pl.load( + b_ref, + (slice_k, slice_q), + ) + b = b.astype(jnp.float32) + qk += b.T # Transpose back. + if s_ref is not None: + kv_segment_ids = pl.load(s_ref, (slice_k,)) + mask = _segment_mask(q_segment_ids, kv_segment_ids) + qk = jnp.where(mask, qk, NEG_INF) if causal: - lower_bound = lax.div(start_k * block_k, block_q) - else: - lower_bound = 0 - dv, dk = lax.fori_loop(lower_bound, pl.cdiv(seq_len, block_q), inner_loop, (dv, dk)) - pl.store(dv_ref, (slice_k, slice(None)), dv.astype(dv_ref.dtype)) - pl.store(dk_ref, (slice_k, slice(None)), dk.astype(dk_ref.dtype)) + span_k = start_k * block_k + jnp.arange(block_k) + mask = span_q[:, None] >= span_k[None, :] + qk = jnp.where(mask, qk, NEG_INF) + p = jnp.exp(qk - m[:, None]) + dp = jnp.zeros((block_q, block_k), dtype=jnp.float32) - di[:, None] + dp = dp + pl.dot(do, v.T) + ds = p * dp + if softmax_scale != 1.0: + ds = ds * softmax_scale + dq = dq + pl.dot(ds.astype(k.dtype), k).astype(dq.dtype) + return dq + + if causal: + upper_bound = lax.div((start_q + 1) * block_q, block_k) + else: + upper_bound = pl.cdiv(seq_len, block_k) - lax.fori_loop(0, pl.cdiv(seq_len, block_k), outer_loop, None) + dq = lax.fori_loop(0, upper_bound, inner_loop_dq, (dq)) + pl.store(dq_ref, (slice_q, slice(None)), dq.astype(dq_ref.dtype)) def _mha_backward( @@ -624,8 +684,9 @@ def _mha_backward( # NOTE: temporarily removed the "xla" branch, which seems unused. if backward_pass_impl == "triton": # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. - if jnp.float32 in (q.dtype, k.dtype, v.dtype): + if jnp.float32 in (q.dtype, k.dtype, v.dtype, jnp.bfloat16 if b is None else b.dtype): block_q = block_k = 32 + batch_size, seq_len, num_heads, head_dim = q.shape # Backward heuristics, using the same block size for block q and block k. block_q = min(block_q, seq_len) @@ -633,47 +694,36 @@ def _mha_backward( # Very tiny amount of time, not worth using pallas_call. do_scaled, delta = _preprocess_backward(out, do, l, block_q, debug, interpret) # We accumulate into dq so we need to initialize it to zeros. - dq = jnp.zeros(q.shape, jnp.float32) out_shapes = [ - jax.ShapeDtypeStruct(dq.shape, dq.dtype), + jax.ShapeDtypeStruct(q.shape, q.dtype), jax.ShapeDtypeStruct(k.shape, k.dtype), jax.ShapeDtypeStruct(v.shape, v.dtype), ] - num_input = 8 - # Bias. bias_block_spec = None if b is not None: assert b.ndim == 4 b = jnp.moveaxis(b, -1, -2) - # We must shrink the block size for float32 inputs to avoid OOM during bwd pass. - if b.dtype == jnp.float32: - block_q = block_k = 32 - def bias_index_map(j, k): + def bias_index_map(j, k, _): return (j if b.shape[0] != 1 else 0, k if b.shape[1] != 1 else 0, 0, 0) bias_block_spec = pl.BlockSpec( index_map=bias_index_map, block_shape=(None, None, seq_len, seq_len) ) - num_input += 1 # Segment Ids. segment_ids_block_spec = None if s is not None: assert s.ndim == 2 segment_ids_block_spec = pl.BlockSpec( - index_map=(lambda j, k: (j, 0)), block_shape=(None, seq_len) + index_map=(lambda j, k, _: (j, 0)), block_shape=(None, seq_len) ) - num_input += 1 - - input_output_aliases = {num_input: 0} - - grid = (batch_size, num_heads) - # TODO(markblee): num_warps=8 seems to work from basic testing, confirm the below comment. - # TODO(sharadmv): figure out why num_warps=8 doesn't work! + grid = (batch_size, num_heads, pl.cdiv(seq_len, block_q)) + # Add some proof check against SRAM for float32 inputs or huge bias input. num_warps = 8 + num_stages = 2 if b is None and jnp.float32 not in (q.dtype, k.dtype, v.dtype) else 1 dq, dk, dv = pl.pallas_call( functools.partial( _mha_backward_kernel, @@ -687,55 +737,56 @@ def bias_index_map(j, k): out_shape=out_shapes, in_specs=[ pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), # query pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), # key pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), # value bias_block_spec, # bias segment_ids_block_spec, # segment_ids pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), - pl.BlockSpec(index_map=(lambda j, k: (j, k, 0)), block_shape=(None, None, seq_len)), - pl.BlockSpec(index_map=(lambda j, k: (j, k, 0)), block_shape=(None, None, seq_len)), - pl.BlockSpec(index_map=(lambda j, k: (j, k, 0)), block_shape=(None, None, seq_len)), pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), - block_shape=(None, seq_len, None, head_dim), + index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) + ), + pl.BlockSpec( + index_map=(lambda j, k, _: (j, k, 0)), block_shape=(None, None, seq_len) ), ], out_specs=[ pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), pl.BlockSpec( - index_map=(lambda j, k: (j, 0, k, 0)), + index_map=(lambda j, k, _: (j, 0, k, 0)), block_shape=(None, seq_len, None, head_dim), ), ], name="mha_backward", debug=debug, interpret=interpret, - compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=1), - input_output_aliases=input_output_aliases, - )(q, k, v, b, s, out, do_scaled, l, m, delta, dq) + compiler_params=plgpu.TritonCompilerParams(num_warps=num_warps, num_stages=num_stages), + )(q, k, v, b, s, out, do_scaled, l, m, delta) else: raise ValueError(f"Invalid backward pass implementation: {backward_pass_impl}") return dq.astype(q.dtype), dk, dv, None, None From 6433a18ff899a8e6db9ed9502d2a392459e310a3 Mon Sep 17 00:00:00 2001 From: Mark Lee Date: Tue, 12 Nov 2024 07:59:21 -0800 Subject: [PATCH 19/27] Removes legacy bias check for flash attention. (#829) --- axlearn/common/flash_attention/layer.py | 19 ++---------- axlearn/common/flash_attention/layer_test.py | 32 ++++++++++++++------ 2 files changed, 24 insertions(+), 27 deletions(-) diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py index b5b57dfe2..5588ae1d6 100644 --- a/axlearn/common/flash_attention/layer.py +++ b/axlearn/common/flash_attention/layer.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """FlashAttention layers.""" + from collections.abc import Sequence from typing import Optional @@ -18,8 +19,7 @@ causal_mask, make_segment_mask, ) -from axlearn.common.base_layer import BaseLayer -from axlearn.common.config import ConfigBase, config_class +from axlearn.common.config import config_class from axlearn.common.flash_attention.utils import ( MultiHeadAttentionImpl, flash_attention_implementation, @@ -28,20 +28,6 @@ from axlearn.common.utils import Tensor, with_sharding_constraint -def _check_bias_recursively(cfg: ConfigBase): - """Ensures that `cfg.bias` is set to False for all descendants.""" - - def visit_fn(_, value): - if isinstance(value, BaseLayer.Config) and getattr(value, "bias", False): - raise NotImplementedError("cfg.bias is not yet supported.") - - def enter_fn(_, value, default_kv): - return None if isinstance(value, BaseLayer.Config) and "bias" in value else default_kv - - cfg.visit(visit_fn=visit_fn, enter_fn=enter_fn) - return cfg - - class FlashAttention(GroupedQueryAttention): """FlashAttention layer. @@ -87,7 +73,6 @@ class Config(GroupedQueryAttention.Config): def __init__(self, cfg: Config, *, parent: Module): super().__init__(cfg, parent=parent) cfg = self.config - _check_bias_recursively(cfg) # Bias not supported. if getattr(cfg, "atten_logit_cap", None) is not None: raise NotImplementedError("cfg.atten_logit_cap is not supported.") # TODO(kelvinzou): enable dropout for flash attention. diff --git a/axlearn/common/flash_attention/layer_test.py b/axlearn/common/flash_attention/layer_test.py index 0a557cb96..52d1bc4ec 100644 --- a/axlearn/common/flash_attention/layer_test.py +++ b/axlearn/common/flash_attention/layer_test.py @@ -1,6 +1,7 @@ # Copyright © 2023 Apple Inc. """Tests FlashAttention layers.""" + import math import os from unittest import mock @@ -89,6 +90,7 @@ def _prepare_layers( causal, sliding_window_size, inference=False, + set_layer_bias_recursively=False, ): hidden_dim = num_heads * per_head_dim kwargs = dict( @@ -124,8 +126,8 @@ def _prepare_layers( ref_cfg.set(causal=causal) test_cfg.set(causal=causal) - set_bias_recursively(ref_cfg, False) - set_bias_recursively(test_cfg, False) + set_bias_recursively(ref_cfg, set_layer_bias_recursively) + set_bias_recursively(test_cfg, set_layer_bias_recursively) ref_layer = ref_cfg.set(name="ref").instantiate(parent=None) test_layer = test_cfg.set(name="test").instantiate(parent=None) @@ -421,6 +423,7 @@ def test_forward( sliding_window_size=[None, 4], use_bias=[False, True], use_segment_ids=[False, True], + set_layer_bias_recursively=[False, True], ) def test_backward( self, @@ -435,6 +438,7 @@ def test_backward( sliding_window_size, use_bias, use_segment_ids, + set_layer_bias_recursively, ): if not is_supported_mesh_shape(mesh): pytest.skip(reason=f"Unsupported mesh {mesh}.") @@ -496,17 +500,17 @@ def forward(self, *, query, key, value, attention_logit_biases, segment_ids): layer=GroupedQueryAttention.default_config().set(**kwargs), ) test_cfg = DummyModel.default_config().set( - layer=FlashAttention.default_config() - .set(**kwargs, tpu_block_size=128) - .set( + layer=FlashAttention.default_config().set( + tpu_block_size=128, mha_dim_to_partition_spec=default_mha_dim_to_partition_spec(mesh_axis_names), output_dim_to_partition_spec=default_output_dim_to_partition_spec( mesh_axis_names ), + **kwargs, ) ) - set_bias_recursively(ref_cfg, False) - set_bias_recursively(test_cfg, False) + set_bias_recursively(ref_cfg, set_layer_bias_recursively) + set_bias_recursively(test_cfg, set_layer_bias_recursively) ref_layer = ref_cfg.set(name="ref").instantiate(parent=None) test_layer = test_cfg.set(name="test").instantiate(parent=None) # pylint: disable-next=protected-access @@ -541,10 +545,18 @@ def loss(params, inputs, layer): ref_value, ref_grads = jax.value_and_grad(loss)(params, ref_inputs, ref_layer) test_value, test_grads = jax.value_and_grad(loss)(params, inputs, test_layer) + + # Have slightly higher diffs with layer bias on GPU. We don't see this on TPU or CPU. + # pylint: disable-next=protected-access + if set_layer_bias_recursively and test_layer.layer._backend() == "gpu": + atol, rtol = 5e-4, 5e-2 + # Can be 1e-5 on x86_64/GPU/TPU, needed to be slightly higher on ARM. - atol = 1e-4 - self.assertNestedAllClose(ref_value, test_value, atol=atol) - self.assertNestedAllClose(ref_grads, test_grads, atol=atol) + else: + atol, rtol = 1e-4, 1e-3 + + self.assertNestedAllClose(ref_value, test_value, atol=atol, rtol=rtol) + self.assertNestedAllClose(ref_grads, test_grads, atol=atol, rtol=rtol) jax.clear_backends() @parameterized.product(_TEST_CONFIGS, causal=[True], sliding_window_size=[None, 4]) From 58fb213d875a8414ac2ac95ad9bc15ba351ffc02 Mon Sep 17 00:00:00 2001 From: Mark Lee Date: Tue, 12 Nov 2024 18:56:45 -0800 Subject: [PATCH 20/27] Skip dst dir creation if no tf savables. (#830) --- axlearn/common/checkpointer.py | 8 ++++++-- axlearn/common/checkpointer_test.py | 13 +++++++++++++ 2 files changed, 19 insertions(+), 2 deletions(-) diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index 4f1f06e92..bd1eec98d 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -154,8 +154,12 @@ def _upload_dir(src_dir_handle: tempfile.TemporaryDirectory, *, dst_dir: str): Temporary dir will be deleted after the upload is complete. """ src_dir = src_dir_handle.name - fs.makedirs(dst_dir) - for item in fs.listdir(src_dir): + src_files = fs.listdir(src_dir) + # src_files will be empty if there are no tf savables (i.e., don't have any tf state to save). + # In this case, do not create empty dst_dirs. + if len(src_files): + fs.makedirs(dst_dir) + for item in src_files: src_file = os.path.join(src_dir, item) dst_file = os.path.join(dst_dir, item) assert not fs.isdir(src_file) diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index 9c117e1fb..b164cbb2b 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -1088,6 +1088,19 @@ def test_restored_iterator_resumes(self): # should continue from the interruption. self.assertSetEqual(set(seen), set(range(num_examples))) + def test_no_save_input_iterator(self): + executor = ThreadPoolExecutor(1) + tmpdir = tempfile.mkdtemp() + ckpt_dir = os.path.join(tmpdir, "tf_ckpt") + self.assertEqual(0, len(fs.listdir(tmpdir))) + # Test that when we don't save input iterator, tf dirs are not created. + async_save_tf_savables({}, executor=executor, dir=ckpt_dir) + self.assertEqual([], fs.listdir(tmpdir)) + # Test that dirs are created if we save. + ds = tf.data.Dataset.from_tensor_slices([]) + async_save_tf_savables({"it": iter(ds)}, executor=executor, dir=ckpt_dir) + self.assertEqual(["tf_ckpt"], fs.listdir(tmpdir)) + SWITCHABLE_VDICT_IMPL: Optional[type[VDict]] = None From 76afcc5fef520434572a55ae9518ab63917faca2 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Wed, 13 Nov 2024 11:58:01 -0800 Subject: [PATCH 21/27] Transformer extend_step supports multi steps generation. (#831) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Streaming encoder and streaming synthesizer require multi-step extend_step: * First, because the input and output modalities are different (e.g., audio in, token out), there’s no need to process one step at a time in an autoregressive manner. * Moreover, depending on the latency requirements, processing multiple steps at once is beneficial for both throughput and latency (Real-Time Factor, RTF). This is likely a common requirement for multimodal streaming encoders and synthesizers. Additionally, it serves as a prerequisite for speculative decoding or a funnel transformer decoder. * Performance benchmark I benchmarked it on TPUv4 in the Notebook TPU. This change doesn't change performance much (some little faster, other little slower). ASIS --------------------------------------------------------------------------------------- Benchmark Time CPU Iterations --------------------------------------------------------------------------------------- QkvLinearExtendStepBenchmark/2048/16/1024/1 1.22 ms 0.444 ms 1497 QkvLinearExtendStepBenchmark/2048/16/4096/1 3.29 ms 0.494 ms 927 QkvLinearExtendStepBenchmark/2048/16/32768/1 23.6 ms 1.07 ms 158 QkvLinearExtendStepBenchmark/2048/16/4096/8 N/A Note: multi step benchmark QkvLinearExtendStepBenchmark/2048/16/4096/64 N/A QkvLinearExtendStepBenchmark/2048/16/4096/512 N/A This PR ---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- QkvLinearExtendStepBenchmark/2048/16/1024/1 1.70 ms 0.513 ms 1125 QkvLinearExtendStepBenchmark/2048/16/4096/1 3.40 ms 0.519 ms 1174 QkvLinearExtendStepBenchmark/2048/16/32768/1 20.1 ms 0.930 ms 404 QkvLinearExtendStepBenchmark/2048/16/4096/8 3.68 ms 0.524 ms 1139 QkvLinearExtendStepBenchmark/2048/16/4096/64 3.74 ms 0.532 ms 1125 QkvLinearExtendStepBenchmark/2048/16/4096/512 2530 ms 80.4 ms 1 If remove the weird moveaxis hack, there is further speed up, especially when step size is big (512). This PR w/o moveaxis ---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- QkvLinearExtendStepBenchmark/2048/16/1024/1 1.52 ms 0.542 ms 1082 QkvLinearExtendStepBenchmark/2048/16/4096/1 3.18 ms 0.547 ms 1096 QkvLinearExtendStepBenchmark/2048/16/32768/1 19.6 ms 0.824 ms 430 QkvLinearExtendStepBenchmark/2048/16/4096/8 3.34 ms 0.542 ms 1139 QkvLinearExtendStepBenchmark/2048/16/4096/64 3.48 ms 0.553 ms 1091 QkvLinearExtendStepBenchmark/2048/16/4096/512 36.5 ms 1.71 ms 71 --- axlearn/common/attention.py | 54 ++++++++++++++++++-------------- axlearn/common/attention_test.py | 51 ++++++++++++++++-------------- 2 files changed, 57 insertions(+), 48 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 7236ef6e4..7ac168697 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -724,6 +724,7 @@ def init_states( dtype = cfg.cache_dtype or self.dtype() assert dtype is not None + # TODO(dhwang2): Use [BTNH], because our benchmark shows the current is slower. # Following T5X, we cache key, value as [batch, num_heads, head_dim, seq_len] to take # advantage of TPU optimizations (see `extend_step`). # Reference: @@ -831,6 +832,8 @@ def prefill_states( k_proj = k_proj * time_step_mask v_proj = v_proj * time_step_mask + # TODO(dhwang2): remove this unnecessary transpose, because our benchmark shows it + # slow down. # Following T5X, we cache key, value as [batch, num_heads, head_dim, seq_len] to take # advantage of TPU optimizations (see `extend_step`). # Reference: @@ -861,8 +864,8 @@ def extend_step( previous attentions, and index used for fast decoding. Contains "key" and "value" of shape [batch, num_heads, per_head_dim, target_length], and a Tensor "time_step" of shape [batch]. - query: Tensor of shape [batch, 1, target_dim] corresponding to query vector at - "time_step" indices. + query: Tensor of shape [batch, steps, target_dim] corresponding to query vector starting + at "time_step" indices. key: An optional Tensor of shape [batch, source_length, source_dim]. If None, will use `query`. value: An optional Tensor of shape [batch, source_length, source_dim]. If None, will @@ -884,40 +887,42 @@ def extend_step( kv_kwargs = dict(kv_state=kv_state) else: kv_kwargs = dict(key=key, value=value) - # Project inputs to key, value and query. Each has shape [B, 1, N, H]. + num_query_steps = query.shape[1] + # Project inputs to key, value and query. Each has shape [B, steps, N, H]. q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, time_step=time_step) - - updated_state = dict(time_step=time_step + 1) + updated_state = dict(time_step=time_step + num_query_steps) if kv_state is None: - # Move the length axis to the back. This allows us to update the cache key, value with - # the "scatter via one-hot broadcast" trick, rather than a scatter/gather operation. - # Profiling suggests moveaxis is competitive with tweaking einsum in `i_proj` -- it's - # also a bit simpler, so we keep it for now. - # [B, 1, N, H] --> [B, N, H, 1]. + # TODO(dhwang2): remove this unnecessary transpose. + # [B, S, N, H] --> [B, N, H, S]. k_proj = jnp.moveaxis(k_proj, -3, -1) v_proj = jnp.moveaxis(v_proj, -3, -1) # Update the cache via one-hot broadcast and addition. cached_key = cached_states["key"] cached_value = cached_states["value"] - target_len = cached_key.shape[-1] - oh_indices = jax.nn.one_hot(time_step, target_len, dtype=k_proj.dtype) - # [B, 1, 1, T] to broadcast. - oh_indices = oh_indices[:, None, None, :] - negated_oh_indices = (1 - oh_indices).astype(cached_key.dtype) # Ensure that we accumulate using the original dtype. - new_k_proj = (cached_key * negated_oh_indices) + (k_proj * oh_indices).astype( - cached_key.dtype + k_proj = k_proj.astype(cached_key.dtype) + v_proj = v_proj.astype(cached_value.dtype) + + # Function to update the cached_key for a single batch element. + def update_single(cached_key_slice, k_proj_slice, time_idx): + start_indices = (0, 0, time_idx) + return jax.lax.dynamic_update_slice(cached_key_slice, k_proj_slice, start_indices) + + # Use jax.vmap to vectorize over the batch dimension. + new_cached_key = jax.vmap(update_single, in_axes=(0, 0, 0))( + cached_key, k_proj, time_step ) - new_v_proj = (cached_value * negated_oh_indices) + (v_proj * oh_indices).astype( - cached_value.dtype + new_cached_value = jax.vmap(update_single, in_axes=(0, 0, 0))( + cached_value, v_proj, time_step ) + # TODO(dhwang2): remove this unnecessary transpose. # Move back to original [B, T, N, H] layout. - k_proj = jnp.moveaxis(new_k_proj, -1, -3) - v_proj = jnp.moveaxis(new_v_proj, -1, -3) + k_proj = jnp.moveaxis(new_cached_key, -1, -3) + v_proj = jnp.moveaxis(new_cached_value, -1, -3) - updated_state.update(key=new_k_proj, value=new_v_proj) + updated_state.update(key=new_cached_key, value=new_cached_value) return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj) @@ -1391,8 +1396,9 @@ def forward( else: # Time step shape is [batch_size] # The expected input shape for rope_pos_emb_layer is [batch_size, seq_len] - # Therefore, expanding the shape of time_step to [batch_size, 1] - time_step = jnp.expand_dims(time_step, 1) + # Therefore, expanding the shape of time_step to [batch_size, step]. + step = query.shape[1] + time_step = jnp.arange(step)[None] + time_step[:, None] sinusoidal_pos_emb = self.rope_pos_emb_layer.forward(time_step).astype(query.dtype) # sinusoidal_pos_emb shape should be [batch_size, seq_len, 1, dim] sinusoidal_pos_emb = jnp.expand_dims(sinusoidal_pos_emb, 2) diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index b8fda30a4..a7a38bd3f 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -1429,16 +1429,20 @@ def test_qlinear(self): self.assertNestedAllClose(outputs[layer_a], outputs[layer_b]) @parameterized.parameters( - attention.QKVLinear, - attention.FusedQKVLinear, - attention.GroupedQKVLinear, - attention.FusedGroupedQKVLinear, - attention.RoFormerQKVLinear, + (attention.QKVLinear, 1), + (attention.FusedQKVLinear, 1), + (attention.GroupedQKVLinear, 1), + (attention.FusedGroupedQKVLinear, 1), + (attention.RoFormerQKVLinear, 1), + (attention.QKVLinear, 2), + (attention.FusedQKVLinear, 3), + (attention.GroupedQKVLinear, 4), + (attention.FusedGroupedQKVLinear, 3), + (attention.RoFormerQKVLinear, 2), ) - def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear]): + def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], stride): """Tests that calling QKVLinear.extend_step() multiple times with the same time_step results in the same output.""" - model_dim = 8 num_heads = 2 per_head_dim = model_dim // num_heads @@ -1460,34 +1464,33 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear]): batch_size, tgt_len = 2, 4 query = jax.random.uniform(jax.random.PRNGKey(0), [batch_size, tgt_len, model_dim]) - extend_step_state, _ = F( + fwd_output, _ = F( layer, state=layer_state, is_training=False, prng_key=jax.random.PRNGKey(456), - inputs=dict(target_batch_size=batch_size, target_max_len=tgt_len), - method="init_states", + inputs=dict(query=query), ) - for t in range(tgt_len): - (first_call_state, first_call_output), _ = F( - layer, - state=layer_state, - is_training=False, - prng_key=jax.random.PRNGKey(456), - inputs=dict(cached_states=extend_step_state, query=query[:, t : t + 1]), - method="extend_step", - ) - # Rewind the time_step. - first_call_state["time_step"] -= 1 - (extend_step_state, second_call_output), _ = F( + + cache_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) + step_querys = [] + step_keys = step_values = None + for t in range(0, tgt_len, stride): + (cache_state, step_output), _ = F( layer, state=layer_state, is_training=False, prng_key=jax.random.PRNGKey(456), - inputs=dict(cached_states=first_call_state, query=query[:, t : t + 1]), + inputs=dict(cached_states=cache_state, query=query[:, t : t + stride]), method="extend_step", ) - self.assertNestedAllClose(first_call_output, second_call_output) + step_querys.append(step_output.query) + step_keys = step_output.key + step_values = step_output.value + + self.assertNestedAllClose(fwd_output.query, jnp.concat(step_querys, axis=1)) + self.assertNestedAllClose(fwd_output.key, step_keys) + self.assertNestedAllClose(fwd_output.value, step_values) @parameterized.parameters(jnp.float32, jnp.float16, jnp.bfloat16) def test_dtypes_inherited_from_parent(self, dtype: jnp.dtype): From ef1dc596b3377a791ab7cf44a65c5388d2625691 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Wed, 13 Nov 2024 15:40:44 -0800 Subject: [PATCH 22/27] Implement sequence_mask(). (#832) axlearn version of tf.sequence_mask. https://www.tensorflow.org/api_docs/python/tf/sequence_mask --- axlearn/common/utils.py | 22 ++++++++++++++++++++++ axlearn/common/utils_test.py | 15 ++++++++++++++- 2 files changed, 36 insertions(+), 1 deletion(-) diff --git a/axlearn/common/utils.py b/axlearn/common/utils.py index 37b88d56c..2c755a325 100644 --- a/axlearn/common/utils.py +++ b/axlearn/common/utils.py @@ -1419,3 +1419,25 @@ class DeviceUsage: hbm_memory_usage_bytes: Optional[int] = None hbm_memory_total_bytes: Optional[int] = None hbm_memory_bandwidth_utilization: Optional[float] = None + + +def sequence_mask(*, lengths: Tensor, max_len: int, dtype: Optional[jnp.dtype] = None) -> Tensor: + """Computes a mask over sequence positions for each given length. + + Args: + lengths: [...]. int32 + max_len: T, int + dtype: outputs dtype. + + Returns: + Tensor [..., T]. 1 is valid and 0 is padding. + """ + if dtype is None: + dtype = lengths.dtype + + prefix_axis = tuple(range(lengths.ndim)) + # [..., T] + sequence = jnp.expand_dims(jnp.arange(max_len), axis=prefix_axis) + # [..., 1] + lengths = lengths[..., jnp.newaxis] + return (sequence < lengths).astype(dtype) diff --git a/axlearn/common/utils_test.py b/axlearn/common/utils_test.py index 85d557f94..3b6311557 100644 --- a/axlearn/common/utils_test.py +++ b/axlearn/common/utils_test.py @@ -21,7 +21,7 @@ from jax.experimental import checkify, mesh_utils from jax.sharding import PartitionSpec -from axlearn.common import learner, optimizers, serialization, struct +from axlearn.common import learner, optimizers, serialization, struct, utils from axlearn.common.base_layer import BaseLayer, FactorizationSpec, ParameterSpec from axlearn.common.config import config_class, config_for_function, similar_names from axlearn.common.layers import BatchNorm, LayerNorm, Linear @@ -761,6 +761,19 @@ def test_check_jax_type(self): with self.assertRaisesRegex(ValueError, "^Argument key has leaf with non-JAX type"): check_jax_type(pretty_named_args={"key": "1"}) + @parameterized.parameters( + dict(lengths=[3, 4], dtype=None, expected=[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]), + dict(lengths=[3, 4], dtype=jnp.int32, expected=[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]), + dict(lengths=[3, 4], dtype=jnp.float32, expected=[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]), + dict(lengths=[[3], [4]], dtype=jnp.int32, expected=[[[1, 1, 1, 0, 0]], [[1, 1, 1, 1, 0]]]), + dict(lengths=[[3, 4]], dtype=jnp.int32, expected=[[[1, 1, 1, 0, 0], [1, 1, 1, 1, 0]]]), + ) + def test_sequence_mask(self, lengths, dtype, expected): + max_len = 5 + mask = utils.sequence_mask(lengths=jnp.array(lengths), max_len=max_len, dtype=dtype) + expected = jnp.array(expected).astype(dtype if dtype else jnp.int32) + self.assertNestedAllClose(mask, expected) + class SimilarNamesTest(TestCase): @parameterized.parameters( From fd0fc5b611863ef291121e255e53303fc7ddb0d2 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Wed, 13 Nov 2024 17:12:49 -0800 Subject: [PATCH 23/27] Remove stale moveaxis optimization in attention. (#835) With the hack, ---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- QkvLinearExtendStepBenchmark/2048/16/1024/1 1.70 ms 0.513 ms 1125 QkvLinearExtendStepBenchmark/2048/16/4096/1 3.40 ms 0.519 ms 1174 QkvLinearExtendStepBenchmark/2048/16/32768/1 20.1 ms 0.930 ms 404 QkvLinearExtendStepBenchmark/2048/16/4096/8 3.68 ms 0.524 ms 1139 QkvLinearExtendStepBenchmark/2048/16/4096/64 3.74 ms 0.532 ms 1125 QkvLinearExtendStepBenchmark/2048/16/4096/512 2530 ms 80.4 ms 1 If remove the weird moveaxis hack, there is speed up, especially when step size is big (512). This PR w/o moveaxis ---------------------------------------------------------------------------------------- Benchmark Time CPU Iterations ---------------------------------------------------------------------------------------- QkvLinearExtendStepBenchmark/2048/16/1024/1 1.52 ms 0.542 ms 1082 QkvLinearExtendStepBenchmark/2048/16/4096/1 3.18 ms 0.547 ms 1096 QkvLinearExtendStepBenchmark/2048/16/32768/1 19.6 ms 0.824 ms 430 QkvLinearExtendStepBenchmark/2048/16/4096/8 3.34 ms 0.542 ms 1139 QkvLinearExtendStepBenchmark/2048/16/4096/64 3.48 ms 0.553 ms 1091 QkvLinearExtendStepBenchmark/2048/16/4096/512 36.5 ms 1.71 ms 71 --- axlearn/common/attention.py | 46 ++++++-------------------------- axlearn/common/attention_test.py | 2 +- 2 files changed, 9 insertions(+), 39 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index 7ac168697..b6dbdb9f1 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -724,21 +724,16 @@ def init_states( dtype = cfg.cache_dtype or self.dtype() assert dtype is not None - # TODO(dhwang2): Use [BTNH], because our benchmark shows the current is slower. - # Following T5X, we cache key, value as [batch, num_heads, head_dim, seq_len] to take - # advantage of TPU optimizations (see `extend_step`). - # Reference: - # https://github.com/google-research/t5x/blob/4d94d8bf41230d492e15e255c9888b5bfd9a5ee8/t5x/examples/t5/layers.py#L215 cache = dict(time_step=jnp.zeros(target_batch_size, dtype=jnp.int32)) # If `kv_state` is provided externally, we do not have to maintain key/value in cache. if kv_state is None: cache.update( key=jnp.zeros( - shape=(target_batch_size, self.num_kv_heads, cfg.per_head_dim, target_max_len), + shape=(target_batch_size, target_max_len, self.num_kv_heads, cfg.per_head_dim), dtype=dtype, ), value=jnp.zeros( - shape=(target_batch_size, self.num_kv_heads, cfg.per_head_dim, target_max_len), + shape=(target_batch_size, target_max_len, self.num_kv_heads, cfg.per_head_dim), dtype=dtype, ), ) @@ -831,17 +826,7 @@ def prefill_states( time_step_mask = (jnp.arange(k_proj.shape[1]) < time_step[:, None])[..., None, None] k_proj = k_proj * time_step_mask v_proj = v_proj * time_step_mask - - # TODO(dhwang2): remove this unnecessary transpose, because our benchmark shows it - # slow down. - # Following T5X, we cache key, value as [batch, num_heads, head_dim, seq_len] to take - # advantage of TPU optimizations (see `extend_step`). - # Reference: - # https://github.com/google-research/t5x/blob/4d94d8bf41230d492e15e255c9888b5bfd9a5ee8/t5x/examples/t5/layers.py#L215 - init_state.update( - key=jnp.moveaxis(k_proj, -3, -1).astype(dtype), - value=jnp.moveaxis(v_proj, -3, -1).astype(dtype), - ) + init_state.update(key=k_proj.astype(dtype), value=v_proj.astype(dtype)) return init_state, self.Output(query=q_proj, key=k_proj, value=v_proj) def extend_step( @@ -892,12 +877,7 @@ def extend_step( q_proj, k_proj, v_proj = self.forward(query, **kv_kwargs, time_step=time_step) updated_state = dict(time_step=time_step + num_query_steps) if kv_state is None: - # TODO(dhwang2): remove this unnecessary transpose. - # [B, S, N, H] --> [B, N, H, S]. - k_proj = jnp.moveaxis(k_proj, -3, -1) - v_proj = jnp.moveaxis(v_proj, -3, -1) - - # Update the cache via one-hot broadcast and addition. + # Update the cache via one-hot broadcast and addition. [B, S, N, H]. cached_key = cached_states["key"] cached_value = cached_states["value"] # Ensure that we accumulate using the original dtype. @@ -906,23 +886,13 @@ def extend_step( # Function to update the cached_key for a single batch element. def update_single(cached_key_slice, k_proj_slice, time_idx): - start_indices = (0, 0, time_idx) + start_indices = (time_idx, 0, 0) return jax.lax.dynamic_update_slice(cached_key_slice, k_proj_slice, start_indices) # Use jax.vmap to vectorize over the batch dimension. - new_cached_key = jax.vmap(update_single, in_axes=(0, 0, 0))( - cached_key, k_proj, time_step - ) - new_cached_value = jax.vmap(update_single, in_axes=(0, 0, 0))( - cached_value, v_proj, time_step - ) - - # TODO(dhwang2): remove this unnecessary transpose. - # Move back to original [B, T, N, H] layout. - k_proj = jnp.moveaxis(new_cached_key, -1, -3) - v_proj = jnp.moveaxis(new_cached_value, -1, -3) - - updated_state.update(key=new_cached_key, value=new_cached_value) + k_proj = jax.vmap(update_single)(cached_key, k_proj, time_step) + v_proj = jax.vmap(update_single)(cached_value, v_proj, time_step) + updated_state.update(key=k_proj, value=v_proj) return updated_state, self.Output(query=q_proj, key=k_proj, value=v_proj) diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index a7a38bd3f..87dae373e 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -2670,7 +2670,7 @@ def _test_prefill_states( self.assertTrue(jnp.all(time_step == initial_states["i_proj"]["time_step"])) for proj in ["key", "value"]: self.assertEqual( - (batch_size, num_kv_heads or num_heads, model_dim // num_heads, tgt_len), + (batch_size, tgt_len, num_kv_heads or num_heads, model_dim // num_heads), initial_states["i_proj"][proj].shape, ) self.assertEqual( From a54654106113506228c0d66cc9aea71357448b87 Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Wed, 13 Nov 2024 18:31:36 -0800 Subject: [PATCH 24/27] Transformer extend_step supports multi steps generation (2/2). (#836) In MultiheadAttention.extend_step, logit_bias was hardcoded to have a length of 1. This PR modified it to support multi-step inputs. This change also makes extend_step more aligned with forward, reducing the overall code complexity. --- axlearn/common/attention.py | 61 ++++---------- axlearn/common/attention_test.py | 106 ++++++++++-------------- axlearn/common/flash_attention/layer.py | 13 +-- 3 files changed, 62 insertions(+), 118 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index b6dbdb9f1..bcd9c65a6 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -1855,17 +1855,13 @@ def _forward_for_mode( f"Invalid attention_logit_biases shape: {attention_logit_biases.shape}." ) if self._mask_fn is not None: - kv_len = k_proj.shape[1] + kv_pos = jnp.arange(k_proj.shape[1])[None, :] # [1, source_len] + query_pos = jnp.arange(q_proj.shape[1])[None] # [1, target_length] if mode == ForwardMode.EXTEND_STEP: - # query_len is unused because extend_step assumes query to be length 1. - query_len = None - time_step = cached_states["i_proj"]["time_step"] - else: - query_len = q_proj.shape[1] - time_step = None - mask = self._logit_biases_for_mask( - mode=mode, kv_len=kv_len, query_len=query_len, time_step=time_step - ) + time_step = cached_states["i_proj"]["time_step"] # [B] + # [B, target_length], target_length is often 1 for decoding, but not always. + query_pos = query_pos + time_step[:, None] + mask = self._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) if mask is not None: attention_logit_biases = apply_attention_logit_biases( mask.astype(q_proj.dtype), @@ -1894,12 +1890,7 @@ def _forward_for_mode( return dict(i_proj=i_proj_state), output def _logit_biases_for_mask( - self, - *, - mode: ForwardMode, - kv_len: int, - query_len: Optional[int] = None, - time_step: Optional[Tensor] = None, + self, *, mode: ForwardMode, query_pos: Tensor, kv_pos: Tensor ) -> Optional[Tensor]: """Returns the configured attention mask in the form of logit biases. @@ -1908,39 +1899,17 @@ def _logit_biases_for_mask( Args: mode: The forward propagation mode, chosen from (ForwardMode.FORWARD, ForwardMode.INIT_STATES, ForwardMode.EXTEND_STEP). - kv_len: The sequence length. For (ForwardMode.INIT_STATES, ForwardMode.EXTEND_STEP), - this is equal to the KV cache size. - query_len: Only used for (ForwardMode.FORWARD, ForwardMode.INIT_STATES). - If set, this is the query length. Otherwise, it uses kv_len as the query length. - Must be None for ForwardMode.EXTEND_STEP. - time_step: Only used for (ForwardMode.EXTEND_STEP). A tensor of size [batch] denoting - the 0-indexed position of the current input token. + query_pos: The index in the sequence of query vectors, [1|batch, target_length]. + kv_pos: The index in the sequence of kv vectors, [1|batch, source_length]. Returns: - For (ForwardMode.FORWARD, ForwardMode.INIT_STATES), a logit bias tensor that can be - broadcast to [batch, num_heads, query_len, kv_len]. - - For ForwardMode.EXTEND_STEP, a logit bias tensor that can be broadcast to - [batch, num_heads, 1, kv_len]. + A logit bias tensor [1|batch, 1, target_length, source_length]. """ - kv_pos = jnp.arange(kv_len) - - if mode in (ForwardMode.FORWARD, ForwardMode.INIT_STATES): - if time_step is not None: - raise ValueError( - "FORWARD or INIT_STATES modes do not expect `time_step` as an argument." - ) - query_pos = jnp.arange(kv_len if query_len is None else query_len) - mask = self._mask_fn(query_pos[:, None], kv_pos[None, :])[None, None] - elif mode == ForwardMode.EXTEND_STEP: - if query_len is not None: - raise ValueError("EXTEND_STEP mode does not expect `query_len` as an argument.") - # [batch, 1, 1, kv_len]. - # Ex: for a causal mask, mask[b, :, :, kv_pos] = 0 if time_step[b] > kv_pos else 1. - mask = self._mask_fn(time_step[:, None], kv_pos[None, :]) - mask = mask[:, None, None, :] - else: - raise ValueError(f"Unrecognized mode {mode}.") + del mode + kv_pos = kv_pos[:, None] # [1|B, 1, source_len] + query_pos = query_pos[..., None] # [1|B, target_len, 1] + # [1|B, 1, target_len, source_len] + mask = self._mask_fn(query_pos, kv_pos)[:, None] mask = bool_to_bias(mask) return mask diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index 87dae373e..c19992b0f 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -1440,7 +1440,7 @@ def test_qlinear(self): (attention.FusedGroupedQKVLinear, 3), (attention.RoFormerQKVLinear, 2), ) - def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], stride): + def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], extend_step_len): """Tests that calling QKVLinear.extend_step() multiple times with the same time_step results in the same output.""" model_dim = 8 @@ -1475,13 +1475,13 @@ def test_repeated_extend_step(self, layer_cls: type[attention.BaseQKVLinear], st cache_state = layer.init_states(target_batch_size=batch_size, target_max_len=tgt_len) step_querys = [] step_keys = step_values = None - for t in range(0, tgt_len, stride): + for t in range(0, tgt_len, extend_step_len): (cache_state, step_output), _ = F( layer, state=layer_state, is_training=False, prng_key=jax.random.PRNGKey(456), - inputs=dict(cached_states=cache_state, query=query[:, t : t + stride]), + inputs=dict(cached_states=cache_state, query=query[:, t : t + extend_step_len]), method="extend_step", ) step_querys.append(step_output.query) @@ -2187,7 +2187,10 @@ def test_logit_biases_for_mask(self): layer = cfg.instantiate(parent=None) layer_params = layer.initialize_parameters_recursively(prng_key=jax.random.PRNGKey(123)) - inputs = dict(mode=ForwardMode.FORWARD, kv_len=3, query_len=2) + query_len, kv_len = 2, 3 + query_pos = jnp.arange(query_len)[None] + kv_pos = jnp.arange(kv_len)[None] + inputs = dict(mode=ForwardMode.FORWARD, query_pos=query_pos, kv_pos=kv_pos) layer_outputs, _ = F( layer, state=layer_params, @@ -2201,33 +2204,11 @@ def test_logit_biases_for_mask(self): bool_to_bias(jnp.array([[1, 0, 0], [1, 1, 0]], dtype=jnp.bool))[None, None], ) - inputs = dict(mode=ForwardMode.FORWARD, kv_len=3, query_len=2, time_step=jnp.array([3, 4])) - with self.assertRaises(ValueError) as cm: - layer_outputs, _ = F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="_logit_biases_for_mask", - ) - self.assertTrue(isinstance(cm.exception, ValueError)) - - inputs = dict( - mode=ForwardMode.EXTEND_STEP, kv_len=3, query_len=2, time_step=jnp.array([3, 4]) - ) - with self.assertRaises(ValueError) as cm: - F( - layer, - state=layer_params, - is_training=True, - prng_key=jax.random.PRNGKey(456), - inputs=inputs, - method="_logit_biases_for_mask", - ) - self.assertTrue(isinstance(cm.exception, ValueError)) - - inputs = dict(mode=ForwardMode.EXTEND_STEP, kv_len=4, time_step=jnp.array([1, 2])) + time_step = jnp.array([1, 2]) + query_pos = time_step[:, None] + kv_len = 4 + kv_pos = jnp.arange(kv_len)[None] + inputs = dict(mode=ForwardMode.EXTEND_STEP, query_pos=query_pos, kv_pos=kv_pos) layer_outputs, _ = F( layer, state=layer_params, @@ -2439,6 +2420,7 @@ def _test_extend_step( num_heads: int, dtype: jnp.dtype, bias: bool, + extend_step_len: int, ): cfg = attention_cfg.set( query_dim=model_dim, @@ -2502,18 +2484,16 @@ def _test_extend_step( self.assertNotIn("key", initial_state["i_proj"]) self.assertNotIn("value", initial_state["i_proj"]) inputs = dict(cached_states=initial_state, kv_state=kv_state, return_aux=return_aux) - decoder_output = jnp.zeros(shape=[tgt_len, batch_size, model_dim]) - decoder_probs = jnp.zeros(shape=[tgt_len, batch_size, num_heads, tgt_len]) - for t in range(tgt_len): - inputs["query"] = jnp.expand_dims(query[:, t, :], axis=1) + decoder_output = [] + decoder_probs = [] + for t in range(0, tgt_len, extend_step_len): + inputs["query"] = query[:, t : t + extend_step_len, :] if key is not None: - inputs["key"] = jnp.expand_dims(key[:, t, :], axis=1) + inputs["key"] = key[:, t : t + extend_step_len, :] if value is not None: - inputs["value"] = jnp.expand_dims(value[:, t, :], axis=1) - inputs["attention_logit_biases"] = attention_logit_biases[ - jnp.newaxis, jnp.newaxis, t, : - ] - extend_step_outputs, _ = F( + inputs["value"] = value[:, t : t + extend_step_len, :] + inputs["attention_logit_biases"] = attention_logit_biases[t : t + extend_step_len, :] + (cached_states, extend_step_outputs), _ = F( layer, state=layer_params, is_training=False, @@ -2521,25 +2501,13 @@ def _test_extend_step( inputs=inputs, method="extend_step", ) - inputs["cached_states"] = extend_step_outputs[0] - decoder_output = decoder_output.at[t].set( - jnp.squeeze(extend_step_outputs[1].data, axis=1) - ) - decoder_probs = decoder_probs.at[t].set( - jnp.squeeze(extend_step_outputs[1].probs, axis=2) - ) - decoder_out_transposed = jnp.transpose(decoder_output, [1, 0, 2]) - decoder_probs_transposed = jnp.transpose(decoder_probs, [1, 2, 0, 3]) - assert_allclose( - decoder_out_transposed, - forward_outputs.data, - atol=1e-6, - ) - assert_allclose( - decoder_probs_transposed, - forward_outputs.probs, - atol=1e-6, - ) + inputs["cached_states"] = cached_states + decoder_output.append(extend_step_outputs.data) + decoder_probs.append(extend_step_outputs.probs) + decoder_output = jnp.concatenate(decoder_output, axis=1) + decoder_probs = jnp.concatenate(decoder_probs, axis=2) + assert_allclose(decoder_output, forward_outputs.data, atol=1e-6) + assert_allclose(decoder_probs, forward_outputs.probs, atol=1e-6) @parameterized.product( dtype=(jnp.float32, jnp.float16, jnp.bfloat16), @@ -2547,6 +2515,7 @@ def _test_extend_step( atten_logit_cap=(0.0, 20.0), bias=(True, False), input_linear=(QKVLinear, RoFormerQKVLinear, QLinear), + extend_step_len=(1, 4), ) def test_extend_step( self, @@ -2555,6 +2524,7 @@ def test_extend_step( atten_logit_cap: float, input_linear: attention.BaseQKVLinear, bias: bool, + extend_step_len: int, ): model_dim = 16 num_heads = 4 @@ -2568,7 +2538,12 @@ def test_extend_step( input_linear=input_linear, ) self._test_extend_step( - cfg, model_dim=model_dim, num_heads=num_heads, dtype=dtype, bias=bias + cfg, + model_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + bias=bias, + extend_step_len=extend_step_len, ) @parameterized.product( @@ -2578,6 +2553,7 @@ def test_extend_step( num_kv_heads=(1, 2, 4), input_linear=(attention.GroupedQKVLinear, attention.FusedGroupedQKVLinear), bias=(True, False), + extend_step_len=(1, 4), ) def test_gqa_extend_step( self, @@ -2587,6 +2563,7 @@ def test_gqa_extend_step( num_kv_heads: int, input_linear: type[attention.BaseQKVLinear], bias: bool, + extend_step_len: int, ): model_dim = 16 num_heads = 4 @@ -2596,7 +2573,12 @@ def test_gqa_extend_step( input_linear=input_linear.default_config().set(num_kv_heads=num_kv_heads), ) self._test_extend_step( - cfg, model_dim=model_dim, num_heads=num_heads, dtype=dtype, bias=bias + cfg, + model_dim=model_dim, + num_heads=num_heads, + dtype=dtype, + bias=bias, + extend_step_len=extend_step_len, ) def _test_prefill_states( diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py index 5588ae1d6..4aff9488b 100644 --- a/axlearn/common/flash_attention/layer.py +++ b/axlearn/common/flash_attention/layer.py @@ -109,18 +109,13 @@ def _is_mask_fn_used(self): ) def _logit_biases_for_mask( - self, - *, - mode: ForwardMode, - kv_len: int, - query_len: Optional[int] = None, - time_step: Optional[Tensor] = None, + self, *, mode: ForwardMode, query_pos: Tensor, kv_pos: Tensor ) -> Optional[Tensor]: if self._mask_fn is None: return None elif mode == ForwardMode.EXTEND_STEP: # Use biases for decoding. - return super()._logit_biases_for_mask(mode=mode, kv_len=kv_len, time_step=time_step) + return super()._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) elif self._is_mask_fn_used(): # Biases are not needed in favor of mask_fn, which is supported in Splash Attention. return None @@ -130,9 +125,7 @@ def _logit_biases_for_mask( else: # Fall back to biases. In the subsequent _compute_attention calls, _mask_fn should not # be used. - return super()._logit_biases_for_mask( - mode=mode, kv_len=kv_len, query_len=query_len, time_step=time_step - ) + return super()._logit_biases_for_mask(mode=mode, query_pos=query_pos, kv_pos=kv_pos) def _backend(self): # For compatibility with AOT compilation, we obtain the backend type from physical_mesh. From 3213e048bdc742b614dd9be1413f9a1c2f830c3d Mon Sep 17 00:00:00 2001 From: Mark Lee Date: Thu, 14 Nov 2024 00:02:06 -0800 Subject: [PATCH 25/27] Updates orbax and adds support for max save/restore concurrent gb. (#834) --- axlearn/common/checkpointer_orbax.py | 12 ++++++++---- axlearn/common/checkpointer_test.py | 6 ++++-- pyproject.toml | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/axlearn/common/checkpointer_orbax.py b/axlearn/common/checkpointer_orbax.py index 2f714605c..befde370f 100644 --- a/axlearn/common/checkpointer_orbax.py +++ b/axlearn/common/checkpointer_orbax.py @@ -45,7 +45,7 @@ _GRAIN_INSTALLED = False -class _TfIteratorHandler(ocp.pytree_checkpoint_handler.TypeHandler): +class _TfIteratorHandler(ocp.type_handlers.TypeHandler): """Serializes tf.data.Iterator. Reference: @@ -105,7 +105,7 @@ async def metadata( if _GRAIN_INSTALLED: - class _GrainDatasetIteratorHandler(ocp.pytree_checkpoint_handler.TypeHandler): + class _GrainDatasetIteratorHandler(ocp.type_handlers.TypeHandler): """Serializes grain dataset iterators.""" @dataclasses.dataclass @@ -178,6 +178,8 @@ class Config(BaseCheckpointer.Config): keep_last_n: int = 1 validation_type: CheckpointValidationType = CheckpointValidationType.EXACT async_timeout_secs: int = 300 + max_concurrent_save_gb: Optional[int] = None + max_concurrent_restore_gb: Optional[int] = None @classmethod def checkpoint_paths(cls, base_dir: str) -> List[str]: @@ -225,10 +227,12 @@ def save_fn_with_summaries(step: int, last_saved_step: Optional[int]) -> bool: # for simplicity. The test cases ensure that this is compatible with # `read_index_file`. "index": ocp.JsonCheckpointHandler(filename="index"), - # TODO(markblee): Add save/restore_concurrent_gb when available. # Note that this defaults to use_ocdb=True. Note also that custom `TypeHandler`s are # ignored by `StandardCheckpointHandler`, so we use `PyTreeCheckpointHandler`. - "state": ocp.PyTreeCheckpointHandler(), + "state": ocp.PyTreeCheckpointHandler( + save_concurrent_gb=cfg.max_concurrent_save_gb, + restore_concurrent_gb=cfg.max_concurrent_restore_gb, + ), }, ) diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index b164cbb2b..673a9c84f 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -112,7 +112,9 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): ) # When the given state has a different array shape: [3] instead of [2] for y. - with self.assertRaisesRegex(ValueError, "checkpoint tree dtypes or shapes"): + with self.assertRaisesRegex( + ValueError, "(checkpoint tree dtypes or shapes|not compatible)" + ): ckpt.restore( step=None, state=dict( @@ -124,7 +126,7 @@ def test_save_and_restore(self, checkpointer_cls: Type[BaseCheckpointer]): # Orbax throws AssertionError in this case. with self.assertRaisesRegex( (AssertionError, ValueError), - "(checkpoint tree dtypes or shapes|do not match)", + "(checkpoint tree dtypes or shapes|not compatible)", ): ckpt.restore( step=None, diff --git a/pyproject.toml b/pyproject.toml index 076c4706a..f3f28b8e5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -145,7 +145,7 @@ mmau = [ # Orbax checkpointing. orbax = [ "humanize==4.10.0", - "orbax-checkpoint==0.5.23", + "orbax-checkpoint==0.9.1", ] # Grain input processing. Currently does not support macos. grain = [ From c4bbaa18706afbf41a30c7ae00675cba9dccc871 Mon Sep 17 00:00:00 2001 From: Hanzhi Zhou Date: Thu, 14 Nov 2024 04:37:09 -0800 Subject: [PATCH 26/27] Implement custom `max_data_shard_degree` and `shard_threshold_bytes` (#838) --- axlearn/common/array_serialization.py | 52 ++++++++---- axlearn/common/array_serialization_test.py | 85 +++++++++++++++---- axlearn/common/checkpointer.py | 12 ++- axlearn/common/checkpointer_test.py | 96 +++++++++++++++++++++- 4 files changed, 211 insertions(+), 34 deletions(-) diff --git a/axlearn/common/array_serialization.py b/axlearn/common/array_serialization.py index 34d3b87c9..1b0095452 100644 --- a/axlearn/common/array_serialization.py +++ b/axlearn/common/array_serialization.py @@ -72,7 +72,9 @@ def _num_replicas_per_shard(arr: Tensor) -> dict[tuple[_SliceTuple, ...], int]: return dict(replica_count) -def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_ShardInfo]: +def _get_shard_infos( + arr_inp: Tensor, *, max_data_shard_degree: int, shard_threshold_bytes: int +) -> list[_ShardInfo]: """Returns a list of _ShardInfo for addressable shards that need to be saved. If replica count for the shards are greater than 0, all replicas will save slices of the @@ -84,11 +86,21 @@ def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_Sh for shard in arr_inp.addressable_shards: replica_count = replica_count_map[_slices_to_tuple(shard.index)] assert replica_count > 0 + shard_degree = ( + min(replica_count, max_data_shard_degree) + if max_data_shard_degree > 0 + else replica_count + ) + should_skip = ( + shard_degree == 1 + or shard.data.nbytes < shard_threshold_bytes + or shard.replica_id >= shard_degree + ) for axis, size in enumerate(shard.data.shape): # Find the first dim divisible by partial replication size. - if max_data_shard_degree == 1 or replica_count == 1 or size % replica_count != 0: + if should_skip or size % shard_degree != 0: continue - part_size = size // replica_count + part_size = size // shard_degree slice_obj = shard.index[axis] assert slice_obj.step is None start_offset = shard.replica_id * part_size @@ -103,7 +115,7 @@ def _get_shard_infos(arr_inp: Tensor, *, max_data_shard_degree: int) -> list[_Sh + (slice(slice_start + start_offset, slice_start + end_offset),) + shard.index[axis + 1 :], (start_offset, end_offset, axis), - replica_count, + shard_degree, ) ) break @@ -181,7 +193,8 @@ async def _async_serialize( d2h_future: futures.Future, *, limiter: Optional[serialization._LimitInFlightBytes] = None, - max_data_shard_degree: Optional[int] = None, + max_data_shard_degree: int, + shard_threshold_bytes: int, ): """Similar to `serialization.async_serialize`, but limiting peak host memory usage and sharding along data-parallel axis. @@ -195,7 +208,11 @@ async def _async_serialize( Reference: https://github.com/google/jax/blob/595a620804e810335a870e93975a78504b2e95e5/jax/experimental/array_serialization/serialization.py#L188 """ - shard_infos = _get_shard_infos(arr_inp, max_data_shard_degree=max_data_shard_degree) + shard_infos = _get_shard_infos( + arr_inp, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, + ) if not shard_infos: d2h_future.set_result(shard_infos) return @@ -261,7 +278,8 @@ async def _run_serializer( d2h_futures: list[futures.Future], *, max_concurrent_bytes: Optional[int] = None, - max_data_shard_degree: Optional[int] = None, + max_data_shard_degree: int, + shard_threshold_bytes: int, ): """Asynchronously serializes a list of tensors with _async_serialize.""" # We add 1 because LimitInFlightBytes expects a limit strictly greater than any request. @@ -274,7 +292,10 @@ async def _run_serializer( # pylint: enable=protected-access future_writer = jax.tree.map( functools.partial( - _async_serialize, limiter=limiter, max_data_shard_degree=max_data_shard_degree + _async_serialize, + limiter=limiter, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, ), arrays, tensorstore_specs, @@ -385,7 +406,9 @@ class BoundedDataShardedAsyncCheckpointManager(serialization.GlobalAsyncCheckpoi max_concurrent_gb: Max concurrent shards (in GB) to write. max_data_shard_degree: Max sharding degree of model weights along data-parallel axis. `None` and `1` means no sharding. `-1` means fully shard along data-parallel - replicas. `>1` means custom sharding degree (currently not implemented). + replicas. `>1` means custom sharding degree and should almost always be a power of 2. + shard_threshold_bytes: Threshold for a array shard to be data-sharded. A value of None + or <= 0 means always data-shard according to max_data_shard_degree. timeout_secs: Barrier timeout in seconds. """ @@ -395,6 +418,7 @@ def __init__( max_concurrent_gb: Optional[int] = None, timeout_secs: int = 300, max_data_shard_degree: Optional[int] = None, + shard_threshold_bytes: Optional[int] = None, ): super().__init__(timeout_secs) self._logged_spec = False @@ -406,11 +430,10 @@ def __init__( raise ValueError("max_concurrent_gb must be strictly positive.") self._max_concurrent_bytes = int(max_concurrent_gb * 10**9) - self._max_data_shard_degree = max_data_shard_degree or 1 - if self._max_data_shard_degree not in (1, -1): - raise NotImplementedError( - "max_data_shard_degree is not implemented for values other than 1 and -1" - ) + self._max_data_shard_degree = 1 if max_data_shard_degree is None else max_data_shard_degree + if self._max_data_shard_degree == 0: + raise NotImplementedError("max_data_shard_degree cannot be 0.") + self._shard_threshold_bytes = shard_threshold_bytes or 0 def serialize( self, @@ -457,6 +480,7 @@ def serialize( d2h_futures, max_concurrent_bytes=max_concurrent_bytes, max_data_shard_degree=self._max_data_shard_degree, + shard_threshold_bytes=self._shard_threshold_bytes, ) ) ] diff --git a/axlearn/common/array_serialization_test.py b/axlearn/common/array_serialization_test.py index 8ab56d1aa..3b59168ae 100644 --- a/axlearn/common/array_serialization_test.py +++ b/axlearn/common/array_serialization_test.py @@ -70,7 +70,12 @@ def test_fully_addressable(self): with mock.patch("jax.process_count", return_value=2), self.assertRaises(Exception): asyncio.run( _async_serialize( - jnp.array(1), {}, futures.Future(), limiter=serialization._LimitInFlightBytes(1) + jnp.array(1), + {}, + futures.Future(), + limiter=serialization._LimitInFlightBytes(1), + max_data_shard_degree=-1, + shard_threshold_bytes=0, ), debug=True, ) @@ -122,7 +127,14 @@ def transfer_to_host_patch(*args, **kwargs): # ValueError(...Buffer has been deleted or donated...) may occur. with pytest.raises((RuntimeError, ValueError), match=re.escape("delete")): f = _CommitFuture( - _run_serializer([arr], [spec], [d2h_future], max_concurrent_bytes=arr.nbytes) + _run_serializer( + [arr], + [spec], + [d2h_future], + max_concurrent_bytes=arr.nbytes, + max_data_shard_degree=-1, + shard_threshold_bytes=-1, + ) ) # Throws Array deleted exception if not waiting for d2h_future. jit_fn(arr) @@ -138,7 +150,14 @@ def transfer_to_host_patch(*args, **kwargs): f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch ): f = _CommitFuture( - _run_serializer([arr], [spec], [d2h_future], max_concurrent_bytes=arr.nbytes) + _run_serializer( + [arr], + [spec], + [d2h_future], + max_concurrent_bytes=arr.nbytes, + max_data_shard_degree=-1, + shard_threshold_bytes=-1, + ) ) d2h_future.result() # If D2H is finished, arr can be safely donated. @@ -162,7 +181,11 @@ async def ts_open_patch(*_, **__): f"{array_serialization.__name__}.serialization.ts.open", ts_open_patch, ), get_tensorstore_spec(arr) as spec: - f = _CommitFuture(_run_serializer([arr], [spec], [d2h_future])) + f = _CommitFuture( + _run_serializer( + [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 + ) + ) d2h_future.result() with pytest.raises(RuntimeError, match=re.escape("Test")): f.result() @@ -175,7 +198,11 @@ def transfer_to_host_patch(*_): f"{array_serialization.__name__}._transfer_to_host", transfer_to_host_patch, ), get_tensorstore_spec(arr) as spec: - f = _CommitFuture(_run_serializer([arr], [spec], [d2h_future])) + f = _CommitFuture( + _run_serializer( + [arr], [spec], [d2h_future], max_data_shard_degree=-1, shard_threshold_bytes=-1 + ) + ) # Exceptions will be raised in both the d2h future and the commit future. with pytest.raises(RuntimeError, match=re.escape("Test")): d2h_future.result() @@ -285,9 +312,17 @@ def _donate_argnum_fn(x): self.assertTrue(np.all(x_zero_copy == x_np)) def _verify_shard_info( - self, single_device_arr: jax.Array, arr: jax.Array, max_data_shard_degree: int + self, + single_device_arr: jax.Array, + arr: jax.Array, + max_data_shard_degree: int, + shard_threshold_bytes: int, ): - shard_infos = _get_shard_infos(arr, max_data_shard_degree=max_data_shard_degree) + shard_infos = _get_shard_infos( + arr, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, + ) # Write each shard to output and check if it's the same as the original # single device array. If same, that means all shards should cover all @@ -299,12 +334,16 @@ def _verify_shard_info( out_array[info.index] = info.data self.assertTrue(np.all(out_array == np.array(single_device_arr))) - @parameterized.parameters(1, -1) + @parameterized.product( + max_data_shard_degree=[1, -1, 2, 4, 8], shard_threshold_bytes=[1000 * 1000 * 1000, 1] + ) @pytest.mark.skipif( jax.device_count() != 8 or jax.process_count() != 1, reason="Incorrect device count for mesh.", ) - def test_shard_info_partially_replicated(self, max_data_shard_degree): + def test_shard_info_partially_replicated( + self, max_data_shard_degree: int, shard_threshold_bytes: int + ): single_device_arr = jnp.arange(0, 1024 * 1024).reshape(1024, 1024) devices = mesh_utils.create_device_mesh((8,)) sharding = PositionalSharding(devices) @@ -315,14 +354,18 @@ def test_shard_info_partially_replicated(self, max_data_shard_degree): self.assertEqual(replica_count[((None, None, None), (0, 512, None))], 4) self.assertEqual(replica_count[((None, None, None), (512, 1024, None))], 4) - self._verify_shard_info(single_device_arr, arr, max_data_shard_degree) + self._verify_shard_info( + single_device_arr, arr, max_data_shard_degree, shard_threshold_bytes + ) - @parameterized.parameters(1, -1) + @parameterized.product( + max_data_shard_degree=[1, -1, 2, 4, 8], shard_threshold_bytes=[1000 * 1000 * 1000, 1] + ) @pytest.mark.skipif( jax.device_count() != 8 or jax.process_count() != 1, reason="Incorrect device count for mesh.", ) - def test_shard_info_fully_sharded(self, max_data_shard_degree): + def test_shard_info_fully_sharded(self, max_data_shard_degree: int, shard_threshold_bytes: int): single_device_arr = jnp.arange(0, 1024 * 1024).reshape(1024, 1024) devices = mesh_utils.create_device_mesh((8,)) sharding = PositionalSharding(devices) @@ -332,14 +375,22 @@ def test_shard_info_fully_sharded(self, max_data_shard_degree): replica_count = _num_replicas_per_shard(arr) self.assertEqual(replica_count[((0, 256, None), (0, 512, None))], 1) - self._verify_shard_info(single_device_arr, arr, max_data_shard_degree) + self._verify_shard_info( + single_device_arr, arr, max_data_shard_degree, shard_threshold_bytes + ) - @parameterized.product(sz=[1, 11, 16, 21], max_data_shard_degree=[1, -1]) + @parameterized.product( + sz=[1, 11, 16, 21], + max_data_shard_degree=[1, -1, 2, 4, 8], + shard_threshold_bytes=[1000 * 1000 * 1000, 1], + ) @pytest.mark.skipif( jax.device_count() != 8 or jax.process_count() != 1, reason="Incorrect device count for mesh.", ) - def test_shard_info_fully_replicated(self, sz: int, max_data_shard_degree: int): + def test_shard_info_fully_replicated( + self, sz: int, max_data_shard_degree: int, shard_threshold_bytes: int + ): single_device_arr = jnp.arange(0, sz) devices = mesh_utils.create_device_mesh((8,)) sharding = PositionalSharding(devices) @@ -350,4 +401,6 @@ def test_shard_info_fully_replicated(self, sz: int, max_data_shard_degree: int): # Fully replicated on 8 devices. self.assertEqual(replica_count[((None, None, None),)], 8) - self._verify_shard_info(single_device_arr, arr, max_data_shard_degree) + self._verify_shard_info( + single_device_arr, arr, max_data_shard_degree, shard_threshold_bytes + ) diff --git a/axlearn/common/checkpointer.py b/axlearn/common/checkpointer.py index bd1eec98d..027a51157 100644 --- a/axlearn/common/checkpointer.py +++ b/axlearn/common/checkpointer.py @@ -368,10 +368,13 @@ class Config(StateStorage.Config): timeout_secs: Barrier timeout in seconds. max_data_shard_degree: Max sharding degree of model weights along data-parallel axis. `None` and `1` means no sharding. `-1` means fully shard along data-parallel - replicas. `>1` means custom sharding degree (currently not implemented). + replicas. `>1` means custom sharding degree and should almost always be a power + of 2. max_concurrent_gb: Max concurrent shards (in GB) to write. max_concurrent_restore_gb: Max concurrent shards (in GB) to read during checkpoint restore. `None` or `0` means using a default value of 32GB. + shard_threshold_bytes: Threshold for a array shard to be data-sharded. A value of None + or <= 0 means always data-shard according to max_data_shard_degree. """ timeout_secs: float = 3600 @@ -379,6 +382,7 @@ class Config(StateStorage.Config): # TODO(hanzhi-zhou): rename this to max_concurrent_save_gb. max_concurrent_gb: Optional[int] = None max_concurrent_restore_gb: Optional[int] = None + shard_threshold_bytes: Optional[int] = None def __init__(self, cfg: Config): super().__init__(cfg) @@ -390,8 +394,14 @@ def __init__(self, cfg: Config): max_concurrent_gb=cfg.max_concurrent_gb, timeout_secs=cfg.timeout_secs, max_data_shard_degree=cfg.max_data_shard_degree, + shard_threshold_bytes=cfg.shard_threshold_bytes, ) else: + if cfg.shard_threshold_bytes is not None: + raise ValueError( + f"shard_threshold_bytes is set to {cfg.shard_threshold_bytes}, but " + "max_data_shard_degree is not set. It will not take any effect." + ) self._manager = GlobalAsyncCheckpointManager(timeout_secs=cfg.timeout_secs) if cfg.max_concurrent_restore_gb is not None and cfg.max_concurrent_restore_gb <= 0: raise ValueError( diff --git a/axlearn/common/checkpointer_test.py b/axlearn/common/checkpointer_test.py index 673a9c84f..8678692a4 100644 --- a/axlearn/common/checkpointer_test.py +++ b/axlearn/common/checkpointer_test.py @@ -199,6 +199,77 @@ def state_specs(state, partition_spec): self.assertEqual(step, restored_step) self.assertNestedEqual(state, restored_state) + @parameterized.parameters( + # Number of files minus index and .zarray metadata. + dict( + max_data_shard_degree=None, + shard_threshold_bytes=None, + num_files=4, # 2 ararys * 2 shards (2 model) per array. + ), + dict( + max_data_shard_degree=-1, + shard_threshold_bytes=None, + num_files=16, # 2 ararys * 8 shards (2 model, 4 data) per array. + ), + dict( + max_data_shard_degree=2, + shard_threshold_bytes=None, + num_files=8, # 2 ararys * 4 shards (2 model, 2 data) per array. + ), + dict( + max_data_shard_degree=2, + shard_threshold_bytes=1024, + num_files=6, # 1 array 4 shards (2 model, 2 data) + 1 array 2 shards (small array). + ), + ) + def test_save_restore_files_count( + self, max_data_shard_degree: int, shard_threshold_bytes: int, num_files: int + ): + # Tests the effect of max_data_shard_degree and shard_threshold_bytes on number of files. + mesh_shape = (4, 2) + if not test_utils.is_supported_mesh_shape(mesh_shape): + return + + cfg: Checkpointer.Config = _checkpointer_config(Checkpointer) + cfg.storage.max_data_shard_degree = max_data_shard_degree + cfg.storage.shard_threshold_bytes = shard_threshold_bytes + ckpt: Checkpointer = cfg.instantiate(parent=None) + state = dict( + x=jnp.zeros((1024, 1024), dtype=jnp.float32), + small_x=jnp.zeros((16, 16), dtype=jnp.float32), + ) + step = 1 + + def count_files(directory): + file_count = 0 + for _, _, files in os.walk(directory): + file_count += len(files) + return file_count + + def state_specs(state): + return jax.tree.map( + lambda x: utils.TensorSpec( + shape=x.shape, + dtype=x.dtype, + mesh_axes=jax.sharding.PartitionSpec(None, "model"), + ), + state, + ) + + with _mesh(mesh_shape) as mesh: + sharding = jax.sharding.NamedSharding( + mesh, spec=jax.sharding.PartitionSpec(None, "model") + ) + state = jax.tree.map(lambda x: jax.device_put(x, device=sharding), state) + ckpt.save(step=step, state=state) + ckpt.wait_until_finished() + + restored_step, restored_state = ckpt.restore(step=step, state=state_specs(state)) + self.assertEqual(step, restored_step) + self.assertNestedEqual(state, restored_state) + + self.assertEqual(count_files(ckpt.ckpt_dir(step)), num_files + 3) + @parameterized.parameters(Checkpointer, OrbaxCheckpointer) def test_save_and_restore_latest_valid(self, checkpointer_cls: Type[BaseCheckpointer]): mesh_shape = (1, 1) @@ -936,11 +1007,30 @@ def tree_unflatten(cls, keys, values): class TensorStoreStateStorageTest(test_utils.TestCase): - @parameterized.product(max_concurrent_gb=[None, 1], max_data_shard_degree=[None, 1, -1]) - def test_max_concurrent_gb(self, max_concurrent_gb: Optional[int], max_data_shard_degree: int): + @parameterized.product( + max_concurrent_gb=[None, 1], + max_data_shard_degree=[None, 1, -1], + shard_threshold_bytes=[None, 0, int(1024**3)], + ) + def test_checkpointer_configs( + self, + max_concurrent_gb: Optional[int], + max_data_shard_degree: int, + shard_threshold_bytes: int, + ): cfg = TensorStoreStateStorage.default_config().set( - max_concurrent_gb=max_concurrent_gb, max_data_shard_degree=max_data_shard_degree + max_concurrent_gb=max_concurrent_gb, + max_data_shard_degree=max_data_shard_degree, + shard_threshold_bytes=shard_threshold_bytes, ) + if ( + not max_concurrent_gb + and not max_data_shard_degree + and shard_threshold_bytes is not None + ): + with self.assertRaises(ValueError): + storage = cfg.instantiate() + return storage = cfg.instantiate() if max_concurrent_gb is not None or max_data_shard_degree: self.assertIsInstance(storage._manager, BoundedDataShardedAsyncCheckpointManager) From e080157b7e3ed695ddce67998a6bbf2b52dfe54f Mon Sep 17 00:00:00 2001 From: Dongseong Hwang Date: Thu, 14 Nov 2024 07:13:13 -0800 Subject: [PATCH 27/27] Optimize MQA computation. (#837) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit The advantage of multi-query attention (MQA) lies in both reducing the size of the KV cache and making self-attention computation more efficient. The current implementation only saves on KV cache size. This PR improves it further by not only reducing the computation cost, but also saving the per-layer KV cache memory. This becomes especially critical when dealing with very long contexts. For instance, if an LLM is processing a context length of 1 million tokens using the Character.ai architecture [1], there might be around 4 unique KV cache layers. Let’s assume there are 4 KV heads and 32 total attention heads, with a dim_per_head of 128. In the current implementation, each layer consumes significant memory for self-attention KV caching (using bfloat16): * Current (ASIS): 8GB (128 * 32 * 2 * 1M) * Optimized (TODO): 1GB (128 * 4 * 2 * 1M) [1] https://research.character.ai/optimizing-inference/ * Benchmark results: it saves memory and computation. tools/attention_benchmark.py on TPUv5p ASIS ----------------------------------------------------------------------------------------- Benchmark Time CPU Iterations HBM (over 95.74G) ----------------------------------------------------------------------------------------- MQABenchmark/2048/16/2/1024 1.42 ms 0.247 ms 2347 291.16M MQABenchmark/4096/16/2/1024 3.60 ms 0.277 ms 1257 322.95M MQABenchmark/4096/16/2/4096 47.3 ms 0.818 ms 139 4.25G MQABenchmark/4096/16/2/8192 869 ms 0.932 ms 140 48.00G This PR ----------------------------------------------------------------------------------------- Benchmark Time CPU Iterations HBM (over 95.74G) ----------------------------------------------------------------------------------------- MQABenchmark/2048/16/2/1024 1.16 ms 0.256 ms 2535 262.35M MQABenchmark/4096/16/2/1024 3.46 ms 0.294 ms 1114 266.88M MQABenchmark/4096/16/2/4096 24.8 ms 0.769 ms 137 4.04G MQABenchmark/4096/16/2/8192 860 ms 1.19 ms 136 48.00G --- axlearn/common/attention.py | 104 +++++++++++++++--------- axlearn/common/attention_test.py | 25 ------ axlearn/common/flash_attention/layer.py | 11 +++ axlearn/vision/attention.py | 2 +- pyproject.toml | 3 +- 5 files changed, 79 insertions(+), 66 deletions(-) diff --git a/axlearn/common/attention.py b/axlearn/common/attention.py index bcd9c65a6..00bc604fa 100644 --- a/axlearn/common/attention.py +++ b/axlearn/common/attention.py @@ -57,6 +57,7 @@ from enum import Enum, unique from typing import Any, Callable, Literal, NamedTuple, Optional, Protocol, Union +import einops import jax from jax import numpy as jnp from jax.ad_checkpoint import checkpoint_policies as jax_remat_policies @@ -710,7 +711,7 @@ class Output(NamedTuple): @property def num_kv_heads(self): - raise NotImplementedError(type(self)) + return self.config.num_heads def init_states( self, @@ -920,10 +921,6 @@ def __init__(self, cfg: Config, *, parent: Module): proj_cfg.per_head_dim = cfg.per_head_dim self._add_child(f"{name}_proj", proj_cfg) - @property - def num_kv_heads(self): - return self.config.num_heads - def forward( self, query: Tensor, @@ -994,10 +991,6 @@ def __init__(self, cfg: Config, *, parent: Module): proj_cfg.per_head_dim = cfg.per_head_dim self._add_child("q_proj", proj_cfg) - @property - def num_kv_heads(self): - raise NotImplementedError(type(self)) - def forward( self, query: Tensor, @@ -1046,10 +1039,6 @@ def __init__(self, cfg: Config, *, parent: Module): proj_cfg.per_head_dim = cfg.per_head_dim self._add_child("qkv_proj", proj_cfg) - @property - def num_kv_heads(self): - return self.config.num_heads - def create_parameter_specs_recursively(self) -> NestedParameterSpec: specs = VDict(**super().create_parameter_specs_recursively()) @@ -1951,7 +1940,7 @@ def _compute_attention( self.vlog(3, "atten.logits=%s", logits[0, 0, 0, :]) probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases) probs = self.dropout(probs) - context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + context = self._compute_context(probs, v_proj) context = self._remat_name(context, "context") return context, probs @@ -2007,10 +1996,31 @@ def _cap_logits(self, logits: Tensor) -> Tensor: return cap * jnp.tanh(logits / cap) def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: + """Compute attention logits. + + Args: + q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. + k_proj: key tensor, [batch, source_length, num_heads, per_head_dim]. + + Returns: + logits: [batch, num_heads, target_length, source_length]. + """ q_proj = self.scale_query(q_proj) k_proj = self.scale_key(k_proj) return jnp.einsum("btnh,bsnh->bnts", q_proj, k_proj) + def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: + """Compute attention context. + + Args: + probs: probs tensor, [batch, num_heads, target_length, source_length]. + v_proj: value tensor, [batch, source_length, num_heads, per_head_dim]. + + Returns: + context: [batch, target_length, num_heads, per_head_dim]. + """ + return jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + def init_states( self, *, @@ -2173,31 +2183,47 @@ class GroupedQueryAttention(MultiheadAttention): def num_kv_heads(self): return self.i_proj.num_kv_heads - def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor: - """Repeats key or value heads dim to match the query.""" - num_head_repeats = self.config.num_heads // key_or_value.shape[-2] - if num_head_repeats == 1: - return key_or_value - # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim]. - return jnp.repeat(key_or_value, num_head_repeats, axis=-2) + def _compute_logits(self, q_proj: Tensor, k_proj: Tensor) -> Tensor: + """Compute attention logits. - def _compute_attention( - self, - *, - q_proj: Tensor, - k_proj: Tensor, - v_proj: Tensor, - **kwargs, - ) -> tuple[Tensor, Tensor]: - """See `MultiheadAttention._compute_attention` for details.""" - k_proj = self._repeat_kv_heads(k_proj) - v_proj = self._repeat_kv_heads(v_proj) - return super()._compute_attention( - q_proj=q_proj, - k_proj=k_proj, - v_proj=v_proj, - **kwargs, - ) + Args: + q_proj: query tensor, [batch, target_length, num_heads, per_head_dim]. + k_proj: key tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + logits: [batch, num_heads, target_length, source_length]. + """ + kv_heads = k_proj.shape[-2] + num_head_group = self.config.num_heads // kv_heads + if num_head_group == 1: + return super()._compute_logits(q_proj=q_proj, k_proj=k_proj) + + q_proj = self.scale_query(q_proj) + k_proj = self.scale_key(k_proj) + q_proj = einops.rearrange(q_proj, "b t (g k) h -> b t g k h", g=num_head_group, k=kv_heads) + k_proj = einops.rearrange(k_proj, "b s k h -> b s 1 k h") + logits = jnp.einsum("btgkh,bs1kh->bgkts", q_proj, k_proj) + return einops.rearrange(logits, "b g k t s -> b (g k) t s") + + def _compute_context(self, probs: Tensor, v_proj: Tensor) -> Tensor: + """Compute attention context. + + Args: + probs: probs tensor, [batch, num_heads, target_length, source_length]. + v_proj: value tensor, [batch, source_length, num_kv_heads, per_head_dim]. + + Returns: + context: [batch, target_length, num_heads, per_head_dim]. + """ + kv_heads = v_proj.shape[-2] + num_head_group = self.config.num_heads // kv_heads + if num_head_group == 1: + return super()._compute_context(probs=probs, v_proj=v_proj) + + probs = einops.rearrange(probs, "b (g k) t s -> b g k t s", g=num_head_group, k=kv_heads) + v_proj = einops.rearrange(v_proj, "b s k h -> b s 1 k h") + context = jnp.einsum("bgkts,bs1kh->btgkh", probs, v_proj) + return einops.rearrange(context, "b t g k h -> b t (g k) h") class SigmoidAttention(MultiheadAttention): @@ -2248,7 +2274,7 @@ def _compute_attention( ) probs = self.dropout(probs) - context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + context = self._compute_context(probs, v_proj) context = self._remat_name(context, "context") return context, probs diff --git a/axlearn/common/attention_test.py b/axlearn/common/attention_test.py index c19992b0f..502b16088 100644 --- a/axlearn/common/attention_test.py +++ b/axlearn/common/attention_test.py @@ -2303,31 +2303,6 @@ def test_sliding_window( # The outputs are equivalent. self.assertNestedAllClose(outputs[0], outputs[1]) - def test_gqa_kv_heads(self): - """Checks that only the heads dim is repeated.""" - batch = source_length = num_heads = 8 - per_head_dim = 2 - num_kv_heads = 4 - dtype = jnp.float32 - key_or_value = jnp.zeros((batch, source_length, num_kv_heads, per_head_dim), dtype=dtype) - model_dim = per_head_dim * num_heads - cfg = attention.GroupedQueryAttention.default_config().set( - query_dim=model_dim, - key_dim=model_dim, - value_dim=model_dim, - num_heads=num_heads, - input_linear=attention.FusedGroupedQKVLinear.default_config().set( - num_kv_heads=num_kv_heads - ), - dtype=dtype, - ) - test_layer = cfg.set(name="test").instantiate(parent=None) - # pylint: disable-next=protected-access - repeated_key_or_value = test_layer._repeat_kv_heads(key_or_value) - self.assertEqual( - repeated_key_or_value.shape, (batch, source_length, num_heads, per_head_dim) - ) - @parameterized.product( dtype=(jnp.float32, jnp.float16, jnp.bfloat16), per_dim_scale=(None, PerDimScale.default_config()), diff --git a/axlearn/common/flash_attention/layer.py b/axlearn/common/flash_attention/layer.py index 4aff9488b..36da8d595 100644 --- a/axlearn/common/flash_attention/layer.py +++ b/axlearn/common/flash_attention/layer.py @@ -146,6 +146,17 @@ def _logit_biases_spec(self, attention_logit_biases: Tensor) -> Tensor: spec = PartitionSpec(spec[0], None, *spec[2:]) return spec + def _repeat_kv_heads(self, key_or_value: Tensor) -> Tensor: + """Repeats key or value heads dim to match the query. + + TODO(dhwang2): optimize computation like GroupedQueryAttention. + """ + num_head_repeats = self.config.num_heads // key_or_value.shape[-2] + if num_head_repeats == 1: + return key_or_value + # Repeat along the num_heads dim: [batch, source_length, num_heads, per_head_dim]. + return jnp.repeat(key_or_value, num_head_repeats, axis=-2) + def _compute_attention( self, *, diff --git a/axlearn/vision/attention.py b/axlearn/vision/attention.py index fd943d9a1..4b96db95b 100644 --- a/axlearn/vision/attention.py +++ b/axlearn/vision/attention.py @@ -229,7 +229,7 @@ def forward( attention_logit_biases = attention_logit_biases[:, None, :, :] probs = softmax_with_biases(logits, attention_logit_biases=attention_logit_biases) probs = self.dropout(probs) - context = jnp.einsum("bnts,bsnh->btnh", probs, v_proj).astype(v_proj.dtype) + context = self._compute_context(probs, v_proj) context = self._remat_name(context, "context") self.vlog(3, "atten.prob=%s", probs[0, 0, 0, :]) self.vlog(3, "atten.context=%s", context.sum()) diff --git a/pyproject.toml b/pyproject.toml index f3f28b8e5..d0bc333ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,6 +22,7 @@ dependencies = [ core = [ "absl-py==2.1.0", "chex==0.1.86", # chex 0.1.86 is required for jax 0.4.25. + "einops==0.8.0", "importlab==0.7", # breaks pytype on 0.8 "jax==0.4.34", "jaxlib==0.4.34", @@ -53,10 +54,10 @@ apple-silicon = [ ] # Requirements for testing and development. dev = [ + "axlearn[core]", # core "axlearn[audio]", # audio tests "axlearn[orbax]", # checkpointer tests "black==23.1a1", # formatting - "einops==0.8.0", "evaluate", "isort", # formatting "pika==1.3.2", # used by event queue