From d9d5b7846df7bf9ba5553848f93681ea71e18a7d Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Fri, 13 Sep 2024 17:38:05 +0000 Subject: [PATCH 01/17] Toy hide a2a --- MaxText/example_hide.py | 77 +++++++++++++++++++++++++++++++++++++++++ 1 file changed, 77 insertions(+) create mode 100644 MaxText/example_hide.py diff --git a/MaxText/example_hide.py b/MaxText/example_hide.py new file mode 100644 index 000000000..20c4afe6a --- /dev/null +++ b/MaxText/example_hide.py @@ -0,0 +1,77 @@ +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +import datetime +import jax +import random +import string +import os +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + +#!!!! Internally in google3 set trace_dir to CNS path or other profiling solution +def simple_timeit(f, *args, tries=10, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + trace_dir = f"gs://mattdavidow-br/{trace_name}" + + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! + jax.profiler.start_trace(trace_dir) + + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + jax.profiler.stop_trace() + + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}") + return average_time_ms + + +# Baseline non-overlapped implementation to compare against +def blocking_a2a(input_activations, weights): + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P(None, 'expert', None))) + return jnp.einsum("BXE,XEM -> BXM", input_activations, weights) + +# Desired overlapped implementaion +def overlap_a2a(input_activations, weights): + num_chunks = 4 + chunk_size = EMBED // num_chunks + + partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) + partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P(None, 'expert', None))) + for i in range(num_chunks): + chunk_start = chunk_size * i + + input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) + input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P(None, 'expert', None))) + weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + + partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) + return partial_sum + +def create_inputs(): + input_activations = jnp.ones((BATCH_PER_EXP, EXP, EMBED),dtype=jnp.bfloat16) + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,None))) + + weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16) + weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P(None, 'expert', None))) + return input_activations, weights + +BATCH_PER_EXP = 2048 +EMBED = 4096 +MLP = 32768 +EXP = 4 + +global mesh +mesh = Mesh(jax.devices(), ('expert',)) + +input_activations, weights = jax.jit(create_inputs)() + +jit_overlap_a2a = jax.jit(overlap_a2a) +simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a") From 49049e8149fa5ae3f35c322456f4533aeb00664a Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Sun, 15 Sep 2024 02:22:38 +0000 Subject: [PATCH 02/17] working on initial hide a2a 
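A rough standalone sketch of the pattern this patch starts moving toward: expressing the expert all-to-all explicitly with shard_map + jax.lax.all_to_all rather than relying only on a sharding constraint, so the collective becomes an op the compiler can schedule asynchronously. Shapes, the 4-device assumption, and tiled=True below are illustrative choices, not the exact calls in this patch.

    import jax
    from jax import numpy as jnp
    from jax.sharding import Mesh, NamedSharding, PartitionSpec as P
    from jax.experimental import shard_map

    mesh = Mesh(jax.devices(), ('expert',))  # assumes 4 devices

    def a2a(chunk):
        # Per-shard view: split the expert axis (1) across devices and
        # concatenate along batch (0), i.e. B/X,E -> B,E/X.
        return jax.lax.all_to_all(chunk, 'expert', 1, 0, tiled=True)

    x = jnp.ones((8, 4, 16), dtype=jnp.bfloat16)  # [batch, expert, embed], toy sizes
    x = jax.device_put(x, NamedSharding(mesh, P('expert', None, None)))
    y = shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None),
                            out_specs=P(None, 'expert', None))(x)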
--- MaxText/example_hide.py | 30 +++++++++---- MaxText/layers/linears.py | 41 +++++++++++++++++ MaxText/save_example_hide.py | 85 ++++++++++++++++++++++++++++++++++++ 3 files changed, 148 insertions(+), 8 deletions(-) create mode 100644 MaxText/save_example_hide.py diff --git a/MaxText/example_hide.py b/MaxText/example_hide.py index 20c4afe6a..4791ba873 100644 --- a/MaxText/example_hide.py +++ b/MaxText/example_hide.py @@ -7,6 +7,7 @@ import random import string import os +from jax.experimental import shard_map os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" #!!!! Internally in google3 set trace_dir to CNS path or other profiling solution @@ -34,22 +35,29 @@ def simple_timeit(f, *args, tries=10, task=None): # Baseline non-overlapped implementation to compare against +# In some ideal world compiler comes up with an overlapped solution even with naive code def blocking_a2a(input_activations, weights): - input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P(None, 'expert', None))) + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X return jnp.einsum("BXE,XEM -> BXM", input_activations, weights) +# Necessary explicit communication (use shard map) +def a2a(input_chunk): + return jax.lax.all_to_all(input_chunk, 'expert', 1, 0) + # Desired overlapped implementaion def overlap_a2a(input_activations, weights): num_chunks = 4 chunk_size = EMBED // num_chunks partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) - partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P(None, 'expert', None))) + partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) for i in range(num_chunks): chunk_start = chunk_size * i input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) - input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P(None, 'expert', None))) + #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) + weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) @@ -57,21 +65,27 @@ def overlap_a2a(input_activations, weights): def create_inputs(): input_activations = jnp.ones((BATCH_PER_EXP, EXP, EMBED),dtype=jnp.bfloat16) - input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,None))) + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,'model'))) weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16) - weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P(None, 'expert', None))) + weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model'))) return input_activations, weights -BATCH_PER_EXP = 2048 +BATCH_PER_EXP = 16384 EMBED = 4096 -MLP = 32768 +MLP = 8192 EXP = 4 global mesh -mesh = Mesh(jax.devices(), ('expert',)) +data_parallelism, model_parallelism, expert_parallelism = 1, 1, 4 +ici_parallelism = [data_parallelism, model_parallelism, expert_parallelism] +devices_array = mesh_utils.create_device_mesh(ici_parallelism) +mesh = Mesh(devices_array, ["data", "model", "expert"]) 
input_activations, weights = jax.jit(create_inputs)() jit_overlap_a2a = jax.jit(overlap_a2a) simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a") + +# jit_blocking_a2a = jax.jit(blocking_a2a) +# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a") diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index b1f26c1c7..b388f276e 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -491,8 +491,49 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): combine_mask = nn.with_logical_constraint(combine_mask, mask_axes) loss = self.load_balance_loss(top_k_indices, softmax_probs) inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + + def dispatch_a2a_overlapped(dispatch_mask,inputs w0, w1): + # We overlap the a2a by chunking up the comms and compute along the embed axis. + # We rely on XLA with `--xla_tpu_enable_async_all_to_all` to schedule the a2a + # so only the first chunk is exposed, the rest can be overlapped. + + # We found explicit communication via shard map is necessary to achieve overlap, details in b/366501973 + def a2a(input_chunk): + return jax.lax.all_to_all(input_chunk, 'expert', 0, 1) + + # Desired overlapped implementaion + def chunking_overlap_a2a(inputs, w0, w1): + num_chunks = 4 + chunk_size = EMBED // num_chunks + + partial_sum = jnp.zeros_like(x) + inputs_shape, weight_shape = jnp.shape(inputs), jnp.shape(w0) + # Inputs are [exp, batch, capacity, model=embed] + exp, batch, capacity = inputs_shape[0], inputs_shape[1], inputs_shape[2] + # weights are [exp, model=embed, hidden=mlp] + mlp = weight_shape[2] + + # We chunk along the contracting dimension (embed), thus each step produces a partial sum + running_partial_sum = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) + running_partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + for i in range(num_chunks): + chunk_start = chunk_size * i + + input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) + #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) + + weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + + partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) + return partial_sum + + + + with jax.named_scope("dispatch"): dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask) + # We expect an A2A from E, B/X -> E/X, B with the below sharding constaint. dispatch = nn.with_logical_constraint(dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")) with jax.named_scope("wi_0"): w0_kernel_axes = ("exp", None, None) diff --git a/MaxText/save_example_hide.py b/MaxText/save_example_hide.py new file mode 100644 index 000000000..5c024c1f4 --- /dev/null +++ b/MaxText/save_example_hide.py @@ -0,0 +1,85 @@ +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +import datetime +import jax +import random +import string +import os +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + +#!!!! 
Internally in google3 set trace_dir to CNS path or other profiling solution +def simple_timeit(f, *args, tries=10, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + trace_dir = f"gs://mattdavidow-br/{trace_name}" + + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! + jax.profiler.start_trace(trace_dir) + + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + jax.profiler.stop_trace() + + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}") + return average_time_ms + + +# Baseline non-overlapped implementation to compare against +# In some ideal world compiler comes up with an overlapped solution even with naive code +def blocking_a2a(input_activations, weights): + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + return jnp.einsum("BXE,XEM -> BXM", input_activations, weights) + +# Desired overlapped implementaion +def overlap_a2a(input_activations, weights): + num_chunks = 4 + chunk_size = EMBED // num_chunks + + partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) + partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + for i in range(num_chunks): + chunk_start = chunk_size * i + + input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) + input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + + weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + + partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) + return partial_sum + +def create_inputs(): + input_activations = jnp.ones((BATCH_PER_EXP, EXP, EMBED),dtype=jnp.bfloat16) + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,'model'))) + + weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16) + weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model'))) + return input_activations, weights + +BATCH_PER_EXP = 2048 +EMBED = 4096 +MLP = 8192 +EXP = 4 + +global mesh +data_parallelism, model_parallelism, expert_parallelism = 1, 1, 4 +ici_parallelism = [data_parallelism, model_parallelism, expert_parallelism] +devices_array = mesh_utils.create_device_mesh(ici_parallelism) +mesh = Mesh(devices_array, ["data", "model", "expert"]) + +input_activations, weights = jax.jit(create_inputs)() + +jit_overlap_a2a = jax.jit(overlap_a2a) +simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a") + +# jit_blocking_a2a = jax.jit(blocking_a2a) +# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a") From c12f550cfaf92cd2e721bf0c59caa3c8bc41873e Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Sun, 15 Sep 2024 02:55:42 +0000 Subject: [PATCH 03/17] Will test actually assingin input chunk to shard map result --- MaxText/layers/linears.py | 37 ++++++++++++++++++++++--------------- 1 file changed, 22 insertions(+), 15 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 
b388f276e..fad4a3ce3 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -492,40 +492,47 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): loss = self.load_balance_loss(top_k_indices, softmax_probs) inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) + # TODO(b/363005676) : Currently this hardcodes two activation functions (e.g. swigLU), we should support any number + def dispatch_a2a_overlapped(dispatch_mask,inputs w0, w1): # We overlap the a2a by chunking up the comms and compute along the embed axis. # We rely on XLA with `--xla_tpu_enable_async_all_to_all` to schedule the a2a # so only the first chunk is exposed, the rest can be overlapped. # We found explicit communication via shard map is necessary to achieve overlap, details in b/366501973 - def a2a(input_chunk): + def input_a2a(input_chunk): return jax.lax.all_to_all(input_chunk, 'expert', 0, 1) # Desired overlapped implementaion def chunking_overlap_a2a(inputs, w0, w1): - num_chunks = 4 - chunk_size = EMBED // num_chunks - partial_sum = jnp.zeros_like(x) - inputs_shape, weight_shape = jnp.shape(inputs), jnp.shape(w0) - # Inputs are [exp, batch, capacity, model=embed] - exp, batch, capacity = inputs_shape[0], inputs_shape[1], inputs_shape[2] + exp, batch, capacity, embed = ijnp.shape(inputs) # weights are [exp, model=embed, hidden=mlp] - mlp = weight_shape[2] + mlp = jnp.shape(w0)[2] + chunk_size = EMBED // config.num_moe_a2a_chunks # We chunk along the contracting dimension (embed), thus each step produces a partial sum - running_partial_sum = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) - running_partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) - for i in range(num_chunks): + running_partial_sum_0 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) + running_partial_sum_1 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) + running_partial_sum_0 = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + running_partial_sum_1 = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + for i in range(config.num_moe_a2a_chunks): chunk_start = chunk_size * i - input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) + input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 3) #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X - shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) - weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + # Inputs are exp, bach, capacity, embed + inputs_before_a2a_spec = nn.get_partition_spec((None, "activation_batch", None, "activation_embed")) + inputs_after_a2a_spec = nn.get_partition_spec(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + # Perform a2a on input_chunk Exp, B/X -> Exp/X, B + shard_map.shard_map(input_a2a, self.mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_chunk) + + w0 = jax.lax.dynamic_slice_in_dim(w0, chunk_start, chunk_size, 1) + w1 = jax.lax.dynamic_slice_in_dim(w1, chunk_start, chunk_size, 1) - partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) + running_partial_sum_0 = 
running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, weight_chunk) + running_partial_sum_0 = running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, weight_chunk) return partial_sum From 7a478a1235d9b12c1a91c2192c06d65d2eb1d631 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Sun, 15 Sep 2024 03:27:17 +0000 Subject: [PATCH 04/17] Initial chunking behavior done --- MaxText/configs/base.yml | 3 ++ MaxText/example_hide.py | 7 ++-- MaxText/layers/linears.py | 69 ++++++++++++++++++++++----------------- 3 files changed, 47 insertions(+), 32 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index e6cf7246a..b3b06b892 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -115,6 +115,9 @@ num_experts_per_tok: 1 megablox: True capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default load_balance_loss_weight: 0.01 # weight for the load balance loss +num_moe_a2a_chunks: 1 # Number of chunks used for MoE FF layeres to pipeline and add the A2A. +# We can potentially hide (chunks - 1) / chunk fraction of the a2a, at the cost of +# each matmul being a factor of chunk smaller - which may make the matmuls less efficient. # pipeline parallelism # The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats. diff --git a/MaxText/example_hide.py b/MaxText/example_hide.py index 4791ba873..ccb6153ea 100644 --- a/MaxText/example_hide.py +++ b/MaxText/example_hide.py @@ -8,8 +8,11 @@ import string import os from jax.experimental import shard_map +from jax.experimental.compilation_cache import compilation_cache os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" +compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"])) + #!!!! Internally in google3 set trace_dir to CNS path or other profiling solution def simple_timeit(f, *args, tries=10, task=None): """Simple utility to time a function for multiple runs""" @@ -56,7 +59,7 @@ def overlap_a2a(input_activations, weights): input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X - shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) + input_chunk = shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) @@ -71,7 +74,7 @@ def create_inputs(): weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model'))) return input_activations, weights -BATCH_PER_EXP = 16384 +BATCH_PER_EXP = 2048 EMBED = 4096 MLP = 8192 EXP = 4 diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index fad4a3ce3..77ff591f9 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -494,64 +494,73 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): # TODO(b/363005676) : Currently this hardcodes two activation functions (e.g. swigLU), we should support any number - def dispatch_a2a_overlapped(dispatch_mask,inputs w0, w1): + def dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs w0, w1): # We overlap the a2a by chunking up the comms and compute along the embed axis. 
# We rely on XLA with `--xla_tpu_enable_async_all_to_all` to schedule the a2a # so only the first chunk is exposed, the rest can be overlapped. # We found explicit communication via shard map is necessary to achieve overlap, details in b/366501973 - def input_a2a(input_chunk): - return jax.lax.all_to_all(input_chunk, 'expert', 0, 1) + # def input_a2a(input_chunk): + # return jax.lax.all_to_all(input_chunk, 'expert', 0, 1) # Desired overlapped implementaion - def chunking_overlap_a2a(inputs, w0, w1): + def chunked_a2a(inputs, w0, w1): + # Returns: inputs @ w0 and inputs @ w1 exp, batch, capacity, embed = ijnp.shape(inputs) # weights are [exp, model=embed, hidden=mlp] mlp = jnp.shape(w0)[2] - chunk_size = EMBED // config.num_moe_a2a_chunks + chunk_size = embed // self.config.num_moe_a2a_chunks # We chunk along the contracting dimension (embed), thus each step produces a partial sum running_partial_sum_0 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) running_partial_sum_1 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) running_partial_sum_0 = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) running_partial_sum_1 = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) - for i in range(config.num_moe_a2a_chunks): + for i in range(self.config.num_moe_a2a_chunks): chunk_start = chunk_size * i input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 3) - #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X - # Inputs are exp, bach, capacity, embed - inputs_before_a2a_spec = nn.get_partition_spec((None, "activation_batch", None, "activation_embed")) - inputs_after_a2a_spec = nn.get_partition_spec(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + # inputs_before_a2a_spec = nn.get_partition_spec((None, "activation_batch", None, "activation_embed")) + # inputs_after_a2a_spec = nn.get_partition_spec(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) # Perform a2a on input_chunk Exp, B/X -> Exp/X, B - shard_map.shard_map(input_a2a, self.mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_chunk) + # shard_map.shard_map(input_a2a, self.mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_chunk) + + w0_chunk = jax.lax.dynamic_slice_in_dim(w0, chunk_start, chunk_size, 1) + w1_chunk = jax.lax.dynamic_slice_in_dim(w1, chunk_start, chunk_size, 1) + + running_partial_sum_0 = running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w0_chunk) + running_partial_sum_1 = running_partial_sum_1 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w1_chunk) + return running_partial_sum_0, running_partial_sum_1 - w0 = jax.lax.dynamic_slice_in_dim(w0, chunk_start, chunk_size, 1) - w1 = jax.lax.dynamic_slice_in_dim(w1, chunk_start, chunk_size, 1) + with jax.named_scope("dispatch"): + dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask) + with jax.named_scope("wi_both"): + return chunked_a2a(dispatch, w0, w1) - running_partial_sum_0 = running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, weight_chunk) - running_partial_sum_0 = running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, weight_chunk) - return partial_sum + - with jax.named_scope("dispatch"): - dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", 
inputs, dispatch_mask) - # We expect an A2A from E, B/X -> E/X, B with the below sharding constaint. - dispatch = nn.with_logical_constraint(dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")) - with jax.named_scope("wi_0"): - w0_kernel_axes = ("exp", None, None) - w0_kernel = nn.with_logical_constraint(w0_kernel, w0_kernel_axes) - layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w0_kernel) - layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp")) - with jax.named_scope("wi_1"): - w1_kernel_axes = ("exp", None, None) - w1_kernel = nn.with_logical_constraint(w1_kernel, w1_kernel_axes) - layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w1_kernel) - layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_exp", "activation_batch_no_exp",None, "activation_mlp")) + if self.config.num_moe_a2a_chunks > 1: + layer_w0, layer_w0 = dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs w0_kernel, w1_kernel) + else: + with jax.named_scope("dispatch"): + dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask) + # When using expert parallelism we expect an A2A from E, B/X -> E/X, B with the below sharding constaint. + dispatch = nn.with_logical_constraint(dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + with jax.named_scope("wi_0"): + w0_kernel_axes = ("exp", None, None) + w0_kernel = nn.with_logical_constraint(w0_kernel, w0_kernel_axes) + layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w0_kernel) + layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp")) + with jax.named_scope("wi_1"): + w1_kernel_axes = ("exp", None, None) + w1_kernel = nn.with_logical_constraint(w1_kernel, w1_kernel_axes) + layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w1_kernel) + layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_exp", "activation_batch_no_exp",None, "activation_mlp")) layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0) layer_multiply = jnp.multiply(layer_w0_act, layer_w1) with jax.named_scope("wo"): From 817afa5f24358dabe887c15e3ada2408a18a80a9 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Sun, 15 Sep 2024 03:52:49 +0000 Subject: [PATCH 05/17] The debuggining starts! --- MaxText/layers/linears.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 77ff591f9..71178ff80 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -493,8 +493,7 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel): inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed")) # TODO(b/363005676) : Currently this hardcodes two activation functions (e.g. swigLU), we should support any number - - def dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs w0, w1): + def dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs, w0, w1): # We overlap the a2a by chunking up the comms and compute along the embed axis. # We rely on XLA with `--xla_tpu_enable_async_all_to_all` to schedule the a2a # so only the first chunk is exposed, the rest can be overlapped. 
@@ -507,7 +506,7 @@ def dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs w0, w1): def chunked_a2a(inputs, w0, w1): # Returns: inputs @ w0 and inputs @ w1 - exp, batch, capacity, embed = ijnp.shape(inputs) + exp, batch, capacity, embed = jnp.shape(inputs) # weights are [exp, model=embed, hidden=mlp] mlp = jnp.shape(w0)[2] @@ -515,8 +514,8 @@ def chunked_a2a(inputs, w0, w1): # We chunk along the contracting dimension (embed), thus each step produces a partial sum running_partial_sum_0 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) running_partial_sum_1 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) - running_partial_sum_0 = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) - running_partial_sum_1 = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + running_partial_sum_0 = nn.with_logical_constraint(running_partial_sum_0, ('activation_exp', 'activation_batch', None, "activation_mlp")) + running_partial_sum_1 = nn.with_logical_constraint(running_partial_sum_1, ('activation_exp', 'activation_batch', None, "activation_mlp")) for i in range(self.config.num_moe_a2a_chunks): chunk_start = chunk_size * i @@ -545,7 +544,7 @@ def chunked_a2a(inputs, w0, w1): if self.config.num_moe_a2a_chunks > 1: - layer_w0, layer_w0 = dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs w0_kernel, w1_kernel) + layer_w0, layer_w0 = dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs, w0_kernel, w1_kernel) else: with jax.named_scope("dispatch"): dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask) From a50673ceaebfab3fe22fa9b57fc8efd922ef0c65 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Sun, 15 Sep 2024 04:55:22 +0000 Subject: [PATCH 06/17] Success! 
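The working version boils down to the chunked partial-sum pattern sketched below: slice the contracting embed dimension so each slice's all-to-all can overlap with the previous slice's matmul. This is a rough standalone sketch with toy shapes and without the shard_map a2a step, not the MaxText code itself.

    import jax
    from jax import numpy as jnp

    def chunked_matmul(x, w, num_chunks=4):
        # x: [exp, batch, capacity, embed], w: [exp, embed, mlp].
        # Chunk the contracting embed dim; each chunk contributes a partial sum,
        # and in the real layer each input chunk is a2a'd before its einsum.
        e, b, c, m = x.shape
        h = w.shape[2]
        chunk = m // num_chunks
        acc = jnp.zeros((e, b, c, h), dtype=x.dtype)
        for i in range(num_chunks):
            xs = jax.lax.dynamic_slice_in_dim(x, i * chunk, chunk, 3)
            ws = jax.lax.dynamic_slice_in_dim(w, i * chunk, chunk, 1)
            acc = acc + jnp.einsum("EBCM,EMH -> EBCH", xs, ws)
        return acc

With num_moe_a2a_chunks = k, only the first chunk's a2a is exposed, so up to (k - 1)/k of the a2a can be hidden, at the cost of each matmul being a factor of k smaller.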
--- MaxText/layers/linears.py | 28 ++++++++++++++++++++++++---- 1 file changed, 24 insertions(+), 4 deletions(-) diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 71178ff80..6c332b590 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -503,6 +503,10 @@ def dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs, w0, w1): # return jax.lax.all_to_all(input_chunk, 'expert', 0, 1) # Desired overlapped implementaion + w0_kernel_axes = ("exp", None, None) + w0 = nn.with_logical_constraint(w0, w0_kernel_axes) + w1 = nn.with_logical_constraint(w1, w0_kernel_axes) + def chunked_a2a(inputs, w0, w1): # Returns: inputs @ w0 and inputs @ w1 @@ -514,12 +518,19 @@ def chunked_a2a(inputs, w0, w1): # We chunk along the contracting dimension (embed), thus each step produces a partial sum running_partial_sum_0 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) running_partial_sum_1 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype) - running_partial_sum_0 = nn.with_logical_constraint(running_partial_sum_0, ('activation_exp', 'activation_batch', None, "activation_mlp")) - running_partial_sum_1 = nn.with_logical_constraint(running_partial_sum_1, ('activation_exp', 'activation_batch', None, "activation_mlp")) + running_partial_sum_0 = nn.with_logical_constraint(running_partial_sum_0, ('activation_exp', 'activation_batch_no_exp', None, "activation_mlp")) + running_partial_sum_1 = nn.with_logical_constraint(running_partial_sum_1, ('activation_exp', 'activation_batch_no_exp', None, "activation_mlp")) + + + for i in range(self.config.num_moe_a2a_chunks): chunk_start = chunk_size * i - input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 3) + input_chunk = jax.lax.dynamic_slice_in_dim(inputs, chunk_start, chunk_size, 3) + + # inputs_after_a2a_spec = nn.get_partition_spec(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + # input_chunk = nn.with_logical_constraint(input_chunk, inputs_after_a2a_spec) + # Inputs are exp, bach, capacity, embed # inputs_before_a2a_spec = nn.get_partition_spec((None, "activation_batch", None, "activation_embed")) # inputs_after_a2a_spec = nn.get_partition_spec(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) @@ -529,12 +540,21 @@ def chunked_a2a(inputs, w0, w1): w0_chunk = jax.lax.dynamic_slice_in_dim(w0, chunk_start, chunk_size, 1) w1_chunk = jax.lax.dynamic_slice_in_dim(w1, chunk_start, chunk_size, 1) + w0_chunk = nn.with_logical_constraint(w0_chunk, w0_kernel_axes) + w1_chunk = nn.with_logical_constraint(w1_chunk, w0_kernel_axes) + + #partial_result_0 = jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w0_chunk) + #partial_result_0 = nn.with_logical_constraint(partial_result_0, ) + + running_partial_sum_0 = running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w0_chunk) running_partial_sum_1 = running_partial_sum_1 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w1_chunk) return running_partial_sum_0, running_partial_sum_1 with jax.named_scope("dispatch"): dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask) + # Keep dispatch sharded like data parallel - we will A2A in chunks + dispatch = nn.with_logical_constraint(dispatch, (None, "activation_batch", None, "activation_embed")) with jax.named_scope("wi_both"): return chunked_a2a(dispatch, w0, w1) @@ -544,7 +564,7 @@ def chunked_a2a(inputs, w0, w1): if self.config.num_moe_a2a_chunks > 1: - layer_w0, layer_w0 = 
dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs, w0_kernel, w1_kernel) + layer_w0, layer_w1 = dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs, w0_kernel, w1_kernel) else: with jax.named_scope("dispatch"): dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask) From efd0d8885f9c0e4e98b8a34f16834ed139dd6166 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Sun, 15 Sep 2024 23:31:34 +0000 Subject: [PATCH 07/17] Working overlap EP + FSDP with shard mapgit add MaxText/my_a2a_playground.py ! --- MaxText/configs/base.yml | 3 ++- MaxText/layers/linears.py | 32 ++++++++++++++------------------ MaxText/my_a2a_playground.py | 32 ++++++++++++++++++++++++++++++++ 3 files changed, 48 insertions(+), 19 deletions(-) create mode 100644 MaxText/my_a2a_playground.py diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index b3b06b892..b3f390fa1 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -117,7 +117,8 @@ capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, a load_balance_loss_weight: 0.01 # weight for the load balance loss num_moe_a2a_chunks: 1 # Number of chunks used for MoE FF layeres to pipeline and add the A2A. # We can potentially hide (chunks - 1) / chunk fraction of the a2a, at the cost of -# each matmul being a factor of chunk smaller - which may make the matmuls less efficient. +# each matmul being a factor of chunk smaller - which may make the matmuls less efficient. +# You should use --xla_tpu_enable_async_all_to_all in conjunction with num_moe_a2a_chunks > 1 # pipeline parallelism # The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats. diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py index 6c332b590..9ea4ff523 100644 --- a/MaxText/layers/linears.py +++ b/MaxText/layers/linears.py @@ -499,10 +499,13 @@ def dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs, w0, w1): # so only the first chunk is exposed, the rest can be overlapped. 
# We found explicit communication via shard map is necessary to achieve overlap, details in b/366501973 - # def input_a2a(input_chunk): - # return jax.lax.all_to_all(input_chunk, 'expert', 0, 1) + def input_a2a(input_chunk): + return jax.lax.all_to_all(input_chunk, 'expert', 0, 1, tiled=True) + # Desired overlapped implementaion + # AG weigts over FSDP, keep sharded by exp - this might not be necessary, or perhaps + # we could include the sharding by weight and rely on XLA to figure out the FSDP AG w0_kernel_axes = ("exp", None, None) w0 = nn.with_logical_constraint(w0, w0_kernel_axes) w1 = nn.with_logical_constraint(w1, w0_kernel_axes) @@ -526,16 +529,19 @@ def chunked_a2a(inputs, w0, w1): for i in range(self.config.num_moe_a2a_chunks): chunk_start = chunk_size * i + print(f"{inputs.shape=}") input_chunk = jax.lax.dynamic_slice_in_dim(inputs, chunk_start, chunk_size, 3) - - # inputs_after_a2a_spec = nn.get_partition_spec(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) - # input_chunk = nn.with_logical_constraint(input_chunk, inputs_after_a2a_spec) + print(f"{input_chunk.shape=}") + #input_chunk = nn.with_logical_constraint(input_chunk, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")) # Inputs are exp, bach, capacity, embed - # inputs_before_a2a_spec = nn.get_partition_spec((None, "activation_batch", None, "activation_embed")) - # inputs_after_a2a_spec = nn.get_partition_spec(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + inputs_before_a2a_spec = nn.logical_to_mesh_axes((None, "activation_batch", None, "activation_embed")) + print(f"{inputs_before_a2a_spec=}") + inputs_after_a2a_spec = nn.logical_to_mesh_axes(("activation_exp", "activation_batch_no_exp", None, "activation_embed")) + print(f"{inputs_after_a2a_spec=}") # Perform a2a on input_chunk Exp, B/X -> Exp/X, B - # shard_map.shard_map(input_a2a, self.mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_chunk) + input_chunk = shard_map.shard_map(input_a2a, self.mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_chunk) + print(f"{input_chunk.shape=}") w0_chunk = jax.lax.dynamic_slice_in_dim(w0, chunk_start, chunk_size, 1) w1_chunk = jax.lax.dynamic_slice_in_dim(w1, chunk_start, chunk_size, 1) @@ -543,10 +549,6 @@ def chunked_a2a(inputs, w0, w1): w0_chunk = nn.with_logical_constraint(w0_chunk, w0_kernel_axes) w1_chunk = nn.with_logical_constraint(w1_chunk, w0_kernel_axes) - #partial_result_0 = jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w0_chunk) - #partial_result_0 = nn.with_logical_constraint(partial_result_0, ) - - running_partial_sum_0 = running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w0_chunk) running_partial_sum_1 = running_partial_sum_1 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w1_chunk) return running_partial_sum_0, running_partial_sum_1 @@ -557,12 +559,6 @@ def chunked_a2a(inputs, w0, w1): dispatch = nn.with_logical_constraint(dispatch, (None, "activation_batch", None, "activation_embed")) with jax.named_scope("wi_both"): return chunked_a2a(dispatch, w0, w1) - - - - - - if self.config.num_moe_a2a_chunks > 1: layer_w0, layer_w1 = dispatch_a2a_overlapped_with_ff1(dispatch_mask,inputs, w0_kernel, w1_kernel) else: diff --git a/MaxText/my_a2a_playground.py b/MaxText/my_a2a_playground.py new file mode 100644 index 000000000..d0c77358e --- /dev/null +++ b/MaxText/my_a2a_playground.py @@ -0,0 +1,32 @@ +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, 
Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +import jax +import os +from jax.debug import visualize_array_sharding +from jax.experimental import shard_map +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + +def input_a2a(input_chunk): + return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, tiled=True) + +BATCH_PER_EXP = 2048 +EXP = 4 + +global mesh +mesh = Mesh(jax.devices(), ('expert',)) + +input_activations = jnp.ones((BATCH_PER_EXP, EXP),dtype=jnp.bfloat16) +input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None))) + +print(f"{input_activations.shape=}") +visualize_array_sharding(input_activations) + +inputs_before_a2a_spec = P("expert", None) +inputs_after_a2a_spec = P(None, "expert") +# Perform a2a on input_chunk B/X, Exp -> B, Exp/X +input_after_a2a = shard_map.shard_map(input_a2a, mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_activations) + +print(f"{input_after_a2a.shape=}") +visualize_array_sharding(input_after_a2a) \ No newline at end of file From bfb0c4ecc46d67eb2e546e30af74bc82db55f7f7 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Sun, 15 Sep 2024 23:50:37 +0000 Subject: [PATCH 08/17] Hide ff2 a2a as well! --- MaxText/example_hide.py | 3 +- MaxText/hide_ff2_a2a.py | 94 +++++++++++++++++++++++++++++++++++++++++ 2 files changed, 95 insertions(+), 2 deletions(-) create mode 100644 MaxText/hide_ff2_a2a.py diff --git a/MaxText/example_hide.py b/MaxText/example_hide.py index ccb6153ea..e664c1e15 100644 --- a/MaxText/example_hide.py +++ b/MaxText/example_hide.py @@ -11,7 +11,6 @@ from jax.experimental.compilation_cache import compilation_cache os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" -compilation_cache.set_cache_dir(os.path.expanduser(raw_keys["jax_cache_dir"])) #!!!! Internally in google3 set trace_dir to CNS path or other profiling solution def simple_timeit(f, *args, tries=10, task=None): @@ -45,7 +44,7 @@ def blocking_a2a(input_activations, weights): # Necessary explicit communication (use shard map) def a2a(input_chunk): - return jax.lax.all_to_all(input_chunk, 'expert', 1, 0) + return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, tiled=True) # Desired overlapped implementaion def overlap_a2a(input_activations, weights): diff --git a/MaxText/hide_ff2_a2a.py b/MaxText/hide_ff2_a2a.py new file mode 100644 index 000000000..24013a558 --- /dev/null +++ b/MaxText/hide_ff2_a2a.py @@ -0,0 +1,94 @@ +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +import datetime +import jax +import random +import string +import os +from jax.experimental import shard_map +from jax.experimental.compilation_cache import compilation_cache +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + +#!!!! Internally in google3 set trace_dir to CNS path or other profiling solution +def simple_timeit(f, *args, tries=10, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + trace_dir = f"gs://mattdavidow-br/{trace_name}" + + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! 
+ jax.profiler.start_trace(trace_dir) + + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + jax.profiler.stop_trace() + + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}") + return average_time_ms + + +# Baseline non-overlapped implementation to compare against +# In some ideal world compiler comes up with an overlapped solution even with naive code +def blocking_a2a(input_activations, weights): + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + return jnp.einsum("BXE,XEM -> BXM", input_activations, weights) + +# Necessary explicit communication (use shard map) +def a2a(input_chunk): + return jax.lax.all_to_all(input_chunk, 'expert', 0, 1, tiled=True) + +# Desired overlapped implementaion +def overlap_a2a(input_activations, weights): + num_chunks = 4 + chunk_size = EMBED // num_chunks + + + partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) + partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + for i in range(num_chunks): + chunk_start = chunk_size * i + + input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) + #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + input_chunk = shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) + + weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + + partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) + return partial_sum + +def create_inputs(): + input_activations = jnp.ones((BATCH_PER_EXP, EXP, MLP),dtype=jnp.bfloat16) + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,'model'))) + + weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16) + weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model'))) + return input_activations, weights + +BATCH_PER_EXP = 2048 +EMBED = 4096 +MLP = 8192 +EXP = 4 + +global mesh +data_parallelism, model_parallelism, expert_parallelism = 1, 1, 4 +ici_parallelism = [data_parallelism, model_parallelism, expert_parallelism] +devices_array = mesh_utils.create_device_mesh(ici_parallelism) +mesh = Mesh(devices_array, ["data", "model", "expert"]) + +input_activations, weights = jax.jit(create_inputs)() + +jit_overlap_a2a = jax.jit(overlap_a2a) +simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a") + +# jit_blocking_a2a = jax.jit(blocking_a2a) +# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a") From ec51ab88156f252b842b54ebed3d5a9f49a6c577 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Mon, 16 Sep 2024 01:14:52 +0000 Subject: [PATCH 09/17] Ensure same inputs before and after a2a --- MaxText/hide_ff2_a2a.py | 2 +- MaxText/my_a2a_playground.py | 13 ++++++++++--- 2 files changed, 11 insertions(+), 4 deletions(-) diff --git a/MaxText/hide_ff2_a2a.py b/MaxText/hide_ff2_a2a.py index 24013a558..980549237 100644 --- a/MaxText/hide_ff2_a2a.py +++ b/MaxText/hide_ff2_a2a.py @@ -51,7 +51,7 @@ def overlap_a2a(input_activations, weights): num_chunks = 4 chunk_size = EMBED // num_chunks 
- + partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) for i in range(num_chunks): diff --git a/MaxText/my_a2a_playground.py b/MaxText/my_a2a_playground.py index d0c77358e..efc865e2b 100644 --- a/MaxText/my_a2a_playground.py +++ b/MaxText/my_a2a_playground.py @@ -11,16 +11,22 @@ def input_a2a(input_chunk): return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, tiled=True) -BATCH_PER_EXP = 2048 +BATCH_PER_EXP = 8 EXP = 4 global mesh mesh = Mesh(jax.devices(), ('expert',)) -input_activations = jnp.ones((BATCH_PER_EXP, EXP),dtype=jnp.bfloat16) +# create inputs that are BATCH_PER_EXP, by EXP with entries 1,2,... using jnp.arrange +input_activations = jnp.arange(BATCH_PER_EXP * EXP).reshape(BATCH_PER_EXP, EXP) +input_activations = input_activations.astype(jnp.bfloat16) input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None))) + + + print(f"{input_activations.shape=}") +print(input_activations) visualize_array_sharding(input_activations) inputs_before_a2a_spec = P("expert", None) @@ -29,4 +35,5 @@ def input_a2a(input_chunk): input_after_a2a = shard_map.shard_map(input_a2a, mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_activations) print(f"{input_after_a2a.shape=}") -visualize_array_sharding(input_after_a2a) \ No newline at end of file +visualize_array_sharding(input_after_a2a) +print(input_after_a2a) \ No newline at end of file From 45a82afcf9b0f7741e0b410c6b1214ed3344a72e Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Mon, 16 Sep 2024 19:08:41 +0000 Subject: [PATCH 10/17] Hide please --- MaxText/hide_ff2_a2a.py | 55 +++++++++++++++++++++++++++--------- MaxText/my_a2a_playground.py | 6 ++-- 2 files changed, 44 insertions(+), 17 deletions(-) diff --git a/MaxText/hide_ff2_a2a.py b/MaxText/hide_ff2_a2a.py index 980549237..fa2ba7cf2 100644 --- a/MaxText/hide_ff2_a2a.py +++ b/MaxText/hide_ff2_a2a.py @@ -39,8 +39,10 @@ def simple_timeit(f, *args, tries=10, task=None): # Baseline non-overlapped implementation to compare against # In some ideal world compiler comes up with an overlapped solution even with naive code def blocking_a2a(input_activations, weights): - input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X - return jnp.einsum("BXE,XEM -> BXM", input_activations, weights) + + outputs = jnp.einsum("BXM,XEM -> BXE", input_activations, weights) + outputs = jax.lax.with_sharding_constraint(outputs, NamedSharding(mesh, P('expert', None, 'model'))) #A2A B,EXP/X -> B/X,EXP + return outputs # Necessary explicit communication (use shard map) def a2a(input_chunk): @@ -52,29 +54,46 @@ def overlap_a2a(input_activations, weights): chunk_size = EMBED // num_chunks - partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) - partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + ff_output_post_a2a = jnp.zeros((BATCH_PER_EXP, EXP, EMBED), dtype=input_activations.dtype) + ff_output_post_a2a = jax.lax.with_sharding_constraint(ff_output_post_a2a, NamedSharding(mesh, P('expert', None, 'model'))) + + output_list=[None for _ in range(num_chunks)] for i in range(num_chunks): chunk_start = chunk_size * i - input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) - #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, 
P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X - input_chunk = shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) - + #input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2)s weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + result_chunk_before_a2a = jnp.einsum("BXM,XEM -> BXE", input_activations, weight_chunk) + + # a2a result from B/X,EXP -> B, EXP/X + result_chunk = shard_map.shard_map(a2a, mesh, in_specs=P(None, 'expert', 'model'), out_specs=P('expert', None, 'model'))(result_chunk_before_a2a) + #result_chunk = jax.lax.with_sharding_constraint(result_chunk, NamedSharding(mesh, P('expert', None, 'model'))) + #print(f"{result_chunk.shape=}", flush=True) + #output_list[i] = result_chunk_before_a2a + ff_output_post_a2a = jax.lax.dynamic_update_slice(ff_output_post_a2a, result_chunk, (0,0,chunk_start)) - partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) - return partial_sum + # Alterantive at API + #ff_output_post_a2a = ff_output_post_a2a.at[:,:,chunk_start:chunk_start+chunk_size].set(result_chunk) + + # to_ret = jnp.concatenate(output_list, axis=-1) + # print(f"{to_ret.shape=}", flush=True) + # to_ret = jax.lax.with_sharding_constraint(to_ret, NamedSharding(mesh, P('expert', None, 'model'))) + + ff_output_post_a2a = jax.lax.with_sharding_constraint(ff_output_post_a2a, NamedSharding(mesh, P('expert', None, 'model'))) + return ff_output_post_a2a + # outputs = jnp.concatenate + # return ff_output_post_a2a def create_inputs(): - input_activations = jnp.ones((BATCH_PER_EXP, EXP, MLP),dtype=jnp.bfloat16) - input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,'model'))) + input_activations = jax.random.normal(jax.random.PRNGKey(0), (BATCH_PER_EXP, EXP, MLP), dtype=jnp.bfloat16) + # Inputs start out expert sharded + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P(None, 'expert','model'))) - weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16) + weights = jax.random.normal(jax.random.PRNGKey(1), (EXP, EMBED, MLP), dtype=jnp.bfloat16) weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model'))) return input_activations, weights -BATCH_PER_EXP = 2048 +BATCH_PER_EXP = 16384 EMBED = 4096 MLP = 8192 EXP = 4 @@ -87,8 +106,16 @@ def create_inputs(): input_activations, weights = jax.jit(create_inputs)() +# correctness test +# overlapped_results = jax.jit(overlap_a2a)(input_activations, weights) +# blocking_results = jax.jit(blocking_a2a)(input_activations, weights) +# # assert overlapped_results and blocking_results are close +# assert jnp.allclose(overlapped_results, blocking_results, rtol=1e-3, atol=1e-2) + +# Profile overlap solution jit_overlap_a2a = jax.jit(overlap_a2a) simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a") +# Profile blocking solution # jit_blocking_a2a = jax.jit(blocking_a2a) # simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a") diff --git a/MaxText/my_a2a_playground.py b/MaxText/my_a2a_playground.py index efc865e2b..2fda3d42c 100644 --- a/MaxText/my_a2a_playground.py +++ b/MaxText/my_a2a_playground.py @@ -9,10 +9,10 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" def input_a2a(input_chunk): - return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, tiled=True) + return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, 
tiled=False) -BATCH_PER_EXP = 8 -EXP = 4 +BATCH_PER_EXP = 12 +EXP = 16 global mesh mesh = Mesh(jax.devices(), ('expert',)) From 498be64d042b9fc011f290d6bfdb2e181cd9d6a5 Mon Sep 17 00:00:00 2001 From: gobbleturk Date: Tue, 17 Sep 2024 22:54:33 +0000 Subject: [PATCH 11/17] pain among pain --- MaxText/hide_ff2_a2a.py | 27 +++++---------------------- MaxText/my_a2a_playground.py | 13 +++++++++++-- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/MaxText/hide_ff2_a2a.py b/MaxText/hide_ff2_a2a.py index fa2ba7cf2..73f513001 100644 --- a/MaxText/hide_ff2_a2a.py +++ b/MaxText/hide_ff2_a2a.py @@ -53,36 +53,19 @@ def overlap_a2a(input_activations, weights): num_chunks = 4 chunk_size = EMBED // num_chunks - ff_output_post_a2a = jnp.zeros((BATCH_PER_EXP, EXP, EMBED), dtype=input_activations.dtype) + # After a2a batch is sharded by expert, expert dim is unsharded ff_output_post_a2a = jax.lax.with_sharding_constraint(ff_output_post_a2a, NamedSharding(mesh, P('expert', None, 'model'))) - - output_list=[None for _ in range(num_chunks)] for i in range(num_chunks): chunk_start = chunk_size * i - #input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2)s weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) result_chunk_before_a2a = jnp.einsum("BXM,XEM -> BXE", input_activations, weight_chunk) - # a2a result from B/X,EXP -> B, EXP/X result_chunk = shard_map.shard_map(a2a, mesh, in_specs=P(None, 'expert', 'model'), out_specs=P('expert', None, 'model'))(result_chunk_before_a2a) - #result_chunk = jax.lax.with_sharding_constraint(result_chunk, NamedSharding(mesh, P('expert', None, 'model'))) - #print(f"{result_chunk.shape=}", flush=True) - #output_list[i] = result_chunk_before_a2a ff_output_post_a2a = jax.lax.dynamic_update_slice(ff_output_post_a2a, result_chunk, (0,0,chunk_start)) + return result_chunk - # Alterantive at API - #ff_output_post_a2a = ff_output_post_a2a.at[:,:,chunk_start:chunk_start+chunk_size].set(result_chunk) - - # to_ret = jnp.concatenate(output_list, axis=-1) - # print(f"{to_ret.shape=}", flush=True) - # to_ret = jax.lax.with_sharding_constraint(to_ret, NamedSharding(mesh, P('expert', None, 'model'))) - - ff_output_post_a2a = jax.lax.with_sharding_constraint(ff_output_post_a2a, NamedSharding(mesh, P('expert', None, 'model'))) - return ff_output_post_a2a - # outputs = jnp.concatenate - # return ff_output_post_a2a def create_inputs(): input_activations = jax.random.normal(jax.random.PRNGKey(0), (BATCH_PER_EXP, EXP, MLP), dtype=jnp.bfloat16) @@ -99,10 +82,10 @@ def create_inputs(): EXP = 4 global mesh -data_parallelism, model_parallelism, expert_parallelism = 1, 1, 4 -ici_parallelism = [data_parallelism, model_parallelism, expert_parallelism] +expert_parallelism, data_parallelism, model_parallelism, = 4, 1, 1 +ici_parallelism = [expert_parallelism, data_parallelism, model_parallelism] devices_array = mesh_utils.create_device_mesh(ici_parallelism) -mesh = Mesh(devices_array, ["data", "model", "expert"]) +mesh = Mesh(devices_array, ["expert", "data", "model"]) input_activations, weights = jax.jit(create_inputs)() diff --git a/MaxText/my_a2a_playground.py b/MaxText/my_a2a_playground.py index 2fda3d42c..1d05ff14c 100644 --- a/MaxText/my_a2a_playground.py +++ b/MaxText/my_a2a_playground.py @@ -9,7 +9,7 @@ os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" def input_a2a(input_chunk): - return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, tiled=False) + return jax.lax.all_to_all(input_chunk, 'expert', 1, 1, tiled=True) BATCH_PER_EXP = 
12 EXP = 16 @@ -36,4 +36,13 @@ def input_a2a(input_chunk): print(f"{input_after_a2a.shape=}") visualize_array_sharding(input_after_a2a) -print(input_after_a2a) \ No newline at end of file +print(input_after_a2a) + + + +# Try without shard map +# with mesh: +# after_a2a_no_shmap = jax.lax.all_to_all(input_activations, 'expert', 1, 0, tiled=True) +# print(f"{after_a2a_no_shmap.shape=}") +# visualize_array_sharding(after_a2a_no_shmap) +# print(after_a2a_no_shmap) From 8aa74d92628664aa54624b961ebe29f57c9fefed Mon Sep 17 00:00:00 2001 From: RissyRan Date: Thu, 7 Nov 2024 00:02:30 +0000 Subject: [PATCH 12/17] Add custom config --- MaxText/configs/models/custom-moe-multi.yml | 31 ++++++++++++++++++++ MaxText/configs/models/custom-moe-single.yml | 31 ++++++++++++++++++++ MaxText/pyconfig.py | 2 ++ 3 files changed, 64 insertions(+) create mode 100644 MaxText/configs/models/custom-moe-multi.yml create mode 100644 MaxText/configs/models/custom-moe-single.yml diff --git a/MaxText/configs/models/custom-moe-multi.yml b/MaxText/configs/models/custom-moe-multi.yml new file mode 100644 index 000000000..573d004ae --- /dev/null +++ b/MaxText/configs/models/custom-moe-multi.yml @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for custom_moe + +base_emb_dim: 8192 +base_num_query_heads: 112 +base_num_kv_heads: 8 +base_mlp_dim: 32768 +base_num_decoder_layers: 4 +head_dim: 256 +mlp_activations: ["silu","linear"] +vocab_size: 32000 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-5 +num_experts: 64 +num_experts_per_tok: 2 +rope_max_timescale: 1_000_000 +decoder_block: "mistral" diff --git a/MaxText/configs/models/custom-moe-single.yml b/MaxText/configs/models/custom-moe-single.yml new file mode 100644 index 000000000..7ab67670b --- /dev/null +++ b/MaxText/configs/models/custom-moe-single.yml @@ -0,0 +1,31 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +# model config for custom_moe + +base_emb_dim: 8192 +base_num_query_heads: 112 +base_num_kv_heads: 8 +base_mlp_dim: 32768 +base_num_decoder_layers: 2 +head_dim: 256 +mlp_activations: ["silu","linear"] +vocab_size: 32000 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-5 +num_experts: 64 +num_experts_per_tok: 2 +rope_max_timescale: 1_000_000 +decoder_block: "mistral" diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 8ea48e7e9..532c6c6f5 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -148,6 +148,8 @@ def validate_model_name(s: str) -> bool: "mistral-7b", "mixtral-8x7b", "mixtral-8x22b", + "custom-moe-single", + "custom-moe-multi", "gemma-7b", "gemma-2b", "gemma2-2b", From 622f1c84fc214e5331dabb70437d2ab3efe79844 Mon Sep 17 00:00:00 2001 From: RissyRan Date: Thu, 7 Nov 2024 18:38:20 +0000 Subject: [PATCH 13/17] Add flexibility of configs --- MaxText/configs/base.yml | 13 ++++++++++--- MaxText/configs/models/custom-moe-single.yml | 2 -- 2 files changed, 10 insertions(+), 5 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index b3f390fa1..184b134ed 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -101,7 +101,7 @@ base_emb_dim: 2048 base_num_query_heads: 16 base_num_kv_heads: 16 base_mlp_dim: 7168 -base_num_decoder_layers: 16 +base_num_decoder_layers: 2 head_dim: 128 mlp_activations: ["silu", "linear"] dropout_rate: 0 @@ -110,7 +110,7 @@ normalize_embedding_logits: True # whether to normlize pre-softmax logits if lo logits_dot_in_fp32: True # whether to use fp32 in logits_dense or shared_embedding dot product for stability # mixture of experts (moe) -num_experts: 1 +num_experts: 64 num_experts_per_tok: 1 megablox: True capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default @@ -432,4 +432,11 @@ enable_single_controller: False allow_split_physical_axes: False use_ragged_attention: False -ragged_block_size: 256 \ No newline at end of file +ragged_block_size: 256 + +### Splash attention block sizes +# These can be tuned for specific hardware generations, and can be set up to +# the model's sequence length. +sa_block_q: 512 +sa_block_q_dkv: 512 +sa_block_q_dq: 512 diff --git a/MaxText/configs/models/custom-moe-single.yml b/MaxText/configs/models/custom-moe-single.yml index 7ab67670b..94178df76 100644 --- a/MaxText/configs/models/custom-moe-single.yml +++ b/MaxText/configs/models/custom-moe-single.yml @@ -18,14 +18,12 @@ base_emb_dim: 8192 base_num_query_heads: 112 base_num_kv_heads: 8 base_mlp_dim: 32768 -base_num_decoder_layers: 2 head_dim: 256 mlp_activations: ["silu","linear"] vocab_size: 32000 enable_dropout: False logits_via_embedding: False normalization_layer_epsilon: 1.0e-5 -num_experts: 64 num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mistral" From 94e675fe899dfe4e595b66424aca4a420e49ca6a Mon Sep 17 00:00:00 2001 From: RissyRan Date: Thu, 7 Nov 2024 23:51:32 +0000 Subject: [PATCH 14/17] Add more flexibility to configs --- MaxText/configs/base.yml | 17 ++++++++++-- MaxText/configs/models/custom-moe-single.yml | 7 +---- MaxText/pyconfig.py | 29 ++++++++++++++++++++ 3 files changed, 45 insertions(+), 8 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 184b134ed..16615cb78 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -97,10 +97,10 @@ decoder_block: "llama2" # which style of DecoderBlock to use. 
# base_mlp_dim, base_num_decoder_layers and/or head_dim. weight_dtype: float32 global_parameter_scale: 1 -base_emb_dim: 2048 +base_emb_dim: 6144 base_num_query_heads: 16 base_num_kv_heads: 16 -base_mlp_dim: 7168 +base_mlp_dim: 36864 base_num_decoder_layers: 2 head_dim: 128 mlp_activations: ["silu", "linear"] @@ -142,6 +142,19 @@ pipeline_delay_activation_forwarding: False # This delays the activation forward # Choose 'remat_policy' between 'minimal', 'save_dot_except_mlpwi', 'save_dot_except_mlp', 'save_qkv_proj', 'qkv_proj_offloaded', 'minimal_offloaded', 'save_out_proj' and 'full'. # These options offer a trade-off between speed (fastest to slowest) and HBM usage (highest to lowest) remat_policy: 'full' +# If custom_save_offload remat_policy is chosen, you can select tensors from the following list to offload on host memory, rematerialize or save on device memory. +# Pick one of these options for following tensors: ['remat','device','offload'] +decoder_layer_input: 'device' # this tensor cannot be rematerialized - it serves as periodic checkpoints that act as the remat start points +mlpwi: 'remat' +mlpwi_0: 'remat' +mlpwi_1: 'remat' +mlpwo: 'remat' +query_proj: 'remat' +key_proj: 'remat' +value_proj: 'remat' +out_proj: 'remat' +qkv_proj: 'remat' + scan_layers: True param_scan_axis: 1 diff --git a/MaxText/configs/models/custom-moe-single.yml b/MaxText/configs/models/custom-moe-single.yml index 94178df76..f6f6c5dfa 100644 --- a/MaxText/configs/models/custom-moe-single.yml +++ b/MaxText/configs/models/custom-moe-single.yml @@ -14,16 +14,11 @@ # model config for custom_moe -base_emb_dim: 8192 -base_num_query_heads: 112 -base_num_kv_heads: 8 -base_mlp_dim: 32768 -head_dim: 256 + mlp_activations: ["silu","linear"] vocab_size: 32000 enable_dropout: False logits_via_embedding: False normalization_layer_epsilon: 1.0e-5 -num_experts_per_tok: 2 rope_max_timescale: 1_000_000 decoder_block: "mistral" diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 532c6c6f5..e0e34eb2b 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -173,6 +173,35 @@ def validate_no_keys_overwritten_twice(keys1: list[str], keys2: list[str]): ) +def validate_and_assign_remat_tensors(keys): + # list of allowed tensors for custom remat policy + tensors = [ + "decoder_layer_input", + "mlpwi", + "mlpwi_0", + "mlpwi_1", + "mlpwo", + "query_proj", + "key_proj", + "value_proj", + "out_proj", + "qkv_proj", + ] + assert keys["decoder_layer_input"] != "remat", "Cannot remeterialize this tensor with scan_layers=True" + tensors_on_device = [] + tensors_to_offload = [] + for t in tensors: + if keys[t] == "device": + tensors_on_device.append(t) + elif keys[t] == "offload": + tensors_to_offload.append(t) + elif keys[t] != "remat": + raise ValueError(f"Invalid value chosen for tensor {t}") + keys["tensors_on_device"] = tensors_on_device + keys["tensors_to_offload"] = tensors_to_offload + return keys + + _config = None config = None From d4b86b951225a3c288f860a474d63ff3d99ee4b5 Mon Sep 17 00:00:00 2001 From: RissyRan Date: Mon, 11 Nov 2024 01:19:15 +0000 Subject: [PATCH 15/17] Move EP ahead --- MaxText/configs/base.yml | 4 ++-- MaxText/max_utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index 16615cb78..4a7b7fa9d 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -198,7 +198,7 @@ jax_cache_dir: "~/jax_cache" hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu' # Parallelism -mesh_axes: 
['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive'] +mesh_axes: ['data', 'expert', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive'] logical_axis_rules: [ ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']], ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']], @@ -238,7 +238,7 @@ logical_axis_rules: [ ['exp', 'expert'], ] # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details -data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']] +data_sharding: [['data', 'expert', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']] # One axis for each parallelism type may hold a placeholder (-1) # value to auto-shard based on available slices and devices. diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py index 93641fdb4..75243cb27 100644 --- a/MaxText/max_utils.py +++ b/MaxText/max_utils.py @@ -386,22 +386,22 @@ def create_device_mesh(config, devices=None): dcn_parallelism = [ config.dcn_data_parallelism, + config.dcn_expert_parallelism, config.dcn_pipeline_parallelism, config.dcn_fsdp_parallelism, config.dcn_fsdp_transpose_parallelism, config.dcn_sequence_parallelism, config.dcn_tensor_parallelism, - config.dcn_expert_parallelism, config.dcn_autoregressive_parallelism, ] ici_parallelism = [ config.ici_data_parallelism, + config.ici_expert_parallelism, config.ici_pipeline_parallelism, config.ici_fsdp_parallelism, config.ici_fsdp_transpose_parallelism, config.ici_sequence_parallelism, config.ici_tensor_parallelism, - config.ici_expert_parallelism, config.ici_autoregressive_parallelism, ] From bc5e7e2d0c82c1d0050abe20afd8021e8f9b1cae Mon Sep 17 00:00:00 2001 From: RissyRan Date: Tue, 12 Nov 2024 22:22:13 +0000 Subject: [PATCH 16/17] fix conflict --- MaxText/optimizers.py | 2 ++ MaxText/train.py | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py index 04c432958..111c8e55c 100644 --- a/MaxText/optimizers.py +++ b/MaxText/optimizers.py @@ -36,6 +36,8 @@ def get_optimizer(config, learning_rate_schedule): eps_root=config.adam_eps_root, weight_decay=config.adam_weight_decay, ) + elif config.opt_type == "sgd": + return optax.sgd(learning_rate_schedule) elif config.opt_type == "adam_pax": return adam_pax( learning_rate_schedule, diff --git a/MaxText/train.py b/MaxText/train.py index 35b48105b..8031549f1 100644 --- a/MaxText/train.py +++ b/MaxText/train.py @@ -459,7 +459,7 @@ def setup_mesh_and_model(config): devices_array = max_utils.create_device_mesh(config) mesh = Mesh(devices_array, config.mesh_axes) - if emergency_checkpoint_manager.should_restore_mesh_from_metadata(epath.Path(config.checkpoint_dir)): + if emergency_checkpoint_manager._should_restore_mesh_from_metadata(epath.Path(config.checkpoint_dir)): mesh = emergency_checkpoint_manager.consistent_restore_mesh_from_metadata(epath.Path(config.checkpoint_dir), mesh) # Model and Optimizer definition From 511c0bc894126f388eb0d6188faac91f50342e70 Mon Sep 17 00:00:00 2001 From: RissyRan Date: Tue, 12 Nov 2024 22:40:16 +0000 Subject: [PATCH 17/17] fix conflict --- MaxText/configs/base.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml index e54d39386..13774a4f7 100644 --- a/MaxText/configs/base.yml +++ b/MaxText/configs/base.yml @@ -466,6 +466,7 @@ ragged_block_size: 256 ### Splash attention block sizes # These can be tuned for specific 
hardware generations, and can be set up to # the model's sequence length. +sa_block_q: 512 sa_block_kv: 512 sa_block_kv_compute: 512 sa_block_q_dkv: 512
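
Usage sketch for the per-tensor remat options introduced in PATCH 14/17: the new
validate_and_assign_remat_tensors helper in MaxText/pyconfig.py buckets each tensor's
choice of 'remat', 'device', or 'offload' into tensors_on_device and tensors_to_offload.
The per-tensor values below are illustrative assumptions (only decoder_layer_input:
'device' matches the base.yml default), and the bare "import pyconfig" assumes the
script is run from the MaxText/ directory of a checkout with this patch applied.

    import pyconfig  # assumption: run from MaxText/ with PATCH 14/17 applied

    # Example per-tensor choices: 'offload' moves the tensor to host memory,
    # 'device' keeps it saved in device memory, 'remat' rematerializes it.
    keys = {
        "decoder_layer_input": "device",  # must not be "remat" when scan_layers=True
        "mlpwi": "remat",
        "mlpwi_0": "remat",
        "mlpwi_1": "remat",
        "mlpwo": "offload",
        "query_proj": "remat",
        "key_proj": "remat",
        "value_proj": "remat",
        "out_proj": "offload",
        "qkv_proj": "remat",
    }

    keys = pyconfig.validate_and_assign_remat_tensors(keys)
    print(keys["tensors_on_device"])   # ['decoder_layer_input']
    print(keys["tensors_to_offload"])  # ['mlpwo', 'out_proj']

The two resulting lists are what an offload-aware remat policy would consume: tensors to
keep saved on device versus tensors to offload to host memory, per the new base.yml comments.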
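
Minimal sketch of the new opt_type == "sgd" branch added to MaxText/optimizers.py in
PATCH 16/17: get_optimizer simply wires the existing learning-rate schedule into
optax.sgd. The constant schedule and the toy parameter/gradient trees below are
assumptions for illustration; MaxText builds the real schedule from its config.

    import jax.numpy as jnp
    import optax

    # Stand-in for MaxText's learning_rate_schedule (assumed constant 1e-3 here).
    learning_rate_schedule = optax.constant_schedule(1e-3)
    tx = optax.sgd(learning_rate_schedule)  # what get_optimizer returns for opt_type == "sgd"

    params = {"w": jnp.array(1.0)}   # toy parameter tree
    grads = {"w": jnp.array(2.0)}    # toy gradient tree
    opt_state = tx.init(params)
    updates, opt_state = tx.update(grads, opt_state)
    params = optax.apply_updates(params, updates)  # w -> 1.0 - 1e-3 * 2.0 = 0.998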