diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index 9defb0652..13774a4f7 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -110,11 +110,11 @@ decoder_block: "llama2" # which style of DecoderBlock to use.
 # base_mlp_dim, base_num_decoder_layers and/or head_dim.
 weight_dtype: float32
 global_parameter_scale: 1
-base_emb_dim: 2048
+base_emb_dim: 6144
 base_num_query_heads: 16
 base_num_kv_heads: 16
-base_mlp_dim: 7168
-base_num_decoder_layers: 16
+base_mlp_dim: 36864
+base_num_decoder_layers: 2
 head_dim: 128
 mlp_activations: ["silu", "linear"]
 dropout_rate: 0.0
@@ -124,11 +124,15 @@ logits_dot_in_fp32: False # whether to use fp32 in logits_dense or shared_embed
 cast_logits_to_fp32: True # whether to cast the logits to fp32. The higher precision is generally beneficial, but it can vary slightly.

 # mixture of experts (moe)
-num_experts: 1
+num_experts: 64
 num_experts_per_tok: 1
 megablox: True
 capacity_factor: -1.0 # a factor to decide expert capacity for token dropping, and no dropping by default
 load_balance_loss_weight: 0.01 # weight for the load balance loss
+num_moe_a2a_chunks: 1 # Number of chunks used to pipeline the MoE FF layers and overlap the A2A.
+# We can potentially hide a (chunks - 1) / chunks fraction of the a2a, at the cost of
+# each matmul being a factor of chunks smaller, which may make the matmuls less efficient.
+# Use --xla_tpu_enable_async_all_to_all in conjunction with num_moe_a2a_chunks > 1.

 # pipeline parallelism
 # The number of decoder layers is equal to the product of num_stages, num_layers_per_pipeline_stage and num_pipeline_repeats.
@@ -208,7 +212,7 @@ jax_cache_dir: "~/jax_cache"
 hardware: 'tpu' # Supported hardware types are 'tpu', 'gpu', 'gpu_multiprocess' and 'cpu'

 # Parallelism
-mesh_axes: ['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']
+mesh_axes: ['data', 'expert', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']
 logical_axis_rules: [
                       ['activation_batch', ['data', 'fsdp', 'fsdp_transpose', 'expert']],
                       ['activation_batch_no_exp', ['data', 'fsdp', 'fsdp_transpose']],
@@ -248,7 +252,7 @@ logical_axis_rules: [
                       ['exp', 'expert'],
                     ]
 # Axes used for DCN must be earlier in this list than ICI, see (b/339009148) for details
-data_sharding: [['data', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'expert', 'autoregressive']]
+data_sharding: [['data', 'expert', 'stage', 'fsdp', 'fsdp_transpose', 'sequence', 'tensor', 'autoregressive']]

 # One axis for each parallelism type may hold a placeholder (-1)
 # value to auto-shard based on available slices and devices.
diff --git a/MaxText/configs/models/custom-moe-multi.yml b/MaxText/configs/models/custom-moe-multi.yml
new file mode 100644
index 000000000..573d004ae
--- /dev/null
+++ b/MaxText/configs/models/custom-moe-multi.yml
@@ -0,0 +1,31 @@
+# Copyright 2024 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# https://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+ +# model config for custom_moe + +base_emb_dim: 8192 +base_num_query_heads: 112 +base_num_kv_heads: 8 +base_mlp_dim: 32768 +base_num_decoder_layers: 4 +head_dim: 256 +mlp_activations: ["silu","linear"] +vocab_size: 32000 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-5 +num_experts: 64 +num_experts_per_tok: 2 +rope_max_timescale: 1_000_000 +decoder_block: "mistral" diff --git a/MaxText/configs/models/custom-moe-single.yml b/MaxText/configs/models/custom-moe-single.yml new file mode 100644 index 000000000..f6f6c5dfa --- /dev/null +++ b/MaxText/configs/models/custom-moe-single.yml @@ -0,0 +1,24 @@ +# Copyright 2024 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# model config for custom_moe + + +mlp_activations: ["silu","linear"] +vocab_size: 32000 +enable_dropout: False +logits_via_embedding: False +normalization_layer_epsilon: 1.0e-5 +rope_max_timescale: 1_000_000 +decoder_block: "mistral" diff --git a/MaxText/example_hide.py b/MaxText/example_hide.py new file mode 100644 index 000000000..e664c1e15 --- /dev/null +++ b/MaxText/example_hide.py @@ -0,0 +1,93 @@ +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +import datetime +import jax +import random +import string +import os +from jax.experimental import shard_map +from jax.experimental.compilation_cache import compilation_cache +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + +#!!!! Internally in google3 set trace_dir to CNS path or other profiling solution +def simple_timeit(f, *args, tries=10, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + trace_dir = f"gs://mattdavidow-br/{trace_name}" + + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! 
+ jax.profiler.start_trace(trace_dir) + + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + jax.profiler.stop_trace() + + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}") + return average_time_ms + + +# Baseline non-overlapped implementation to compare against +# In some ideal world compiler comes up with an overlapped solution even with naive code +def blocking_a2a(input_activations, weights): + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + return jnp.einsum("BXE,XEM -> BXM", input_activations, weights) + +# Necessary explicit communication (use shard map) +def a2a(input_chunk): + return jax.lax.all_to_all(input_chunk, 'expert', 1, 0, tiled=True) + +# Desired overlapped implementaion +def overlap_a2a(input_activations, weights): + num_chunks = 4 + chunk_size = EMBED // num_chunks + + partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) + partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + for i in range(num_chunks): + chunk_start = chunk_size * i + + input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) + #input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + input_chunk = shard_map.shard_map(a2a, mesh, in_specs=P('expert', None, None), out_specs=P(None, 'expert', None))(input_chunk) + + weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + + partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) + return partial_sum + +def create_inputs(): + input_activations = jnp.ones((BATCH_PER_EXP, EXP, EMBED),dtype=jnp.bfloat16) + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,'model'))) + + weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16) + weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model'))) + return input_activations, weights + +BATCH_PER_EXP = 2048 +EMBED = 4096 +MLP = 8192 +EXP = 4 + +global mesh +data_parallelism, model_parallelism, expert_parallelism = 1, 1, 4 +ici_parallelism = [data_parallelism, model_parallelism, expert_parallelism] +devices_array = mesh_utils.create_device_mesh(ici_parallelism) +mesh = Mesh(devices_array, ["data", "model", "expert"]) + +input_activations, weights = jax.jit(create_inputs)() + +jit_overlap_a2a = jax.jit(overlap_a2a) +simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a") + +# jit_blocking_a2a = jax.jit(blocking_a2a) +# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a") diff --git a/MaxText/hide_ff2_a2a.py b/MaxText/hide_ff2_a2a.py new file mode 100644 index 000000000..73f513001 --- /dev/null +++ b/MaxText/hide_ff2_a2a.py @@ -0,0 +1,104 @@ +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +import datetime +import jax +import random +import string +import os +from jax.experimental import shard_map +from jax.experimental.compilation_cache import compilation_cache +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + + +#!!!! 
Internally in google3 set trace_dir to CNS path or other profiling solution
+def simple_timeit(f, *args, tries=10, task=None):
+  """Simple utility to time a function for multiple runs"""
+  assert task is not None
+
+  trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10))
+  trace_dir = f"gs://mattdavidow-br/{trace_name}"
+
+  outcomes_ms = []
+  jax.block_until_ready(f(*args))  # warm it up!
+  jax.profiler.start_trace(trace_dir)
+
+  for _ in range(tries):
+    s = datetime.datetime.now()
+    jax.block_until_ready(f(*args))
+    e = datetime.datetime.now()
+    outcomes_ms.append(1000 * (e - s).total_seconds())
+  jax.profiler.stop_trace()
+
+  average_time_ms = sum(outcomes_ms) / len(outcomes_ms)
+  print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}")
+  return average_time_ms
+
+
+# Baseline non-overlapped implementation to compare against
+# In some ideal world the compiler comes up with an overlapped solution even with naive code
+def blocking_a2a(input_activations, weights):
+  outputs = jnp.einsum("BXM,XEM -> BXE", input_activations, weights)
+  outputs = jax.lax.with_sharding_constraint(outputs, NamedSharding(mesh, P('expert', None, 'model')))  # A2A B,EXP/X -> B/X,EXP
+  return outputs
+
+# Necessary explicit communication (use shard map)
+def a2a(input_chunk):
+  return jax.lax.all_to_all(input_chunk, 'expert', 0, 1, tiled=True)
+
+# Desired overlapped implementation
+def overlap_a2a(input_activations, weights):
+  num_chunks = 4
+  chunk_size = EMBED // num_chunks
+
+  ff_output_post_a2a = jnp.zeros((BATCH_PER_EXP, EXP, EMBED), dtype=input_activations.dtype)
+  # After a2a batch is sharded by expert, expert dim is unsharded
+  ff_output_post_a2a = jax.lax.with_sharding_constraint(ff_output_post_a2a, NamedSharding(mesh, P('expert', None, 'model')))
+  for i in range(num_chunks):
+    chunk_start = chunk_size * i
+
+    weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1)
+    result_chunk_before_a2a = jnp.einsum("BXM,XEM -> BXE", input_activations, weight_chunk)
+
+    result_chunk = shard_map.shard_map(a2a, mesh, in_specs=P(None, 'expert', 'model'), out_specs=P('expert', None, 'model'))(result_chunk_before_a2a)
+    ff_output_post_a2a = jax.lax.dynamic_update_slice(ff_output_post_a2a, result_chunk, (0, 0, chunk_start))
+  return ff_output_post_a2a
+
+
+def create_inputs():
+  input_activations = jax.random.normal(jax.random.PRNGKey(0), (BATCH_PER_EXP, EXP, MLP), dtype=jnp.bfloat16)
+  # Inputs start out expert sharded
+  input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P(None, 'expert', 'model')))
+
+  weights = jax.random.normal(jax.random.PRNGKey(1), (EXP, EMBED, MLP), dtype=jnp.bfloat16)
+  weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model')))
+  return input_activations, weights
+
+BATCH_PER_EXP = 16384
+EMBED = 4096
+MLP = 8192
+EXP = 4
+
+global mesh
+expert_parallelism, data_parallelism, model_parallelism = 4, 1, 1
+ici_parallelism = [expert_parallelism, data_parallelism, model_parallelism]
+devices_array = mesh_utils.create_device_mesh(ici_parallelism)
+mesh = Mesh(devices_array, ["expert", "data", "model"])
+
+input_activations, weights = jax.jit(create_inputs)()
+
+# correctness test
+# overlapped_results = jax.jit(overlap_a2a)(input_activations, weights)
+# blocking_results = jax.jit(blocking_a2a)(input_activations, weights)
+# # assert overlapped_results and blocking_results are close
+# assert jnp.allclose(overlapped_results, blocking_results, rtol=1e-3, atol=1e-2)
+
+# Profile overlap solution
+jit_overlap_a2a = jax.jit(overlap_a2a)
+simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a")
+
+# Profile blocking solution
+# jit_blocking_a2a = jax.jit(blocking_a2a)
+# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a")
diff --git a/MaxText/layers/linears.py b/MaxText/layers/linears.py
index faef1bd67..283943d91 100644
--- a/MaxText/layers/linears.py
+++ b/MaxText/layers/linears.py
@@ -554,37 +554,92 @@ def dense_matmul(self, inputs, gate_logits, w0_kernel, w1_kernel, wo_kernel):
     combine_mask = nn.with_logical_constraint(combine_mask, mask_axes)
     loss = self.load_balance_loss(top_k_indices, softmax_probs)
     inputs = nn.with_logical_constraint(inputs, ("activation_batch", "activation_length", "activation_embed"))
-    with jax.named_scope("dispatch"):
-      dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)(
-          "BSM,BSEC -> EBCM", inputs, dispatch_mask, precision=matmul_precision
-      )
-      dispatch = nn.with_logical_constraint(
-          dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed")
-      )
-    with jax.named_scope("wi_0"):
+
+    # TODO(b/363005676): Currently this hardcodes two activation functions (e.g. SwiGLU); we should support any number.
+    def dispatch_a2a_overlapped_with_ff1(dispatch_mask, inputs, w0, w1):
+      # We overlap the a2a by chunking the comms and compute along the embed axis.
+      # We rely on XLA with `--xla_tpu_enable_async_all_to_all` to schedule the a2a
+      # so only the first chunk is exposed, the rest can be overlapped.
+
+      # We found explicit communication via shard map is necessary to achieve overlap, details in b/366501973
+      def input_a2a(input_chunk):
+        return jax.lax.all_to_all(input_chunk, 'expert', 0, 1, tiled=True)
+
+      # Desired overlapped implementation
+      # AG weights over FSDP, keep sharded by exp - this might not be necessary, or perhaps
+      # we could include the sharding by weight and rely on XLA to figure out the FSDP AG
       w0_kernel_axes = ("exp", None, None)
-      w0_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w0_kernel, w0_kernel_axes)
-      layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)(
-          "EBCM,EMH -> EBCH", dispatch, w0_kernel, precision=matmul_precision
-      )
-      if self.config.activations_in_float32:
-        layer_w0 = layer_w0.astype(jnp.float32)
-      layer_w0 = nn.with_logical_constraint(
-          layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp")
-      )
-      layer_w0 = checkpoint_name(layer_w0, "mlpwi_0")
-    with jax.named_scope("wi_1"):
-      w1_kernel_axes = ("exp", None, None)
-      w1_kernel = self.maybe_all_gather_kernel_weight_in_expert_parallelism(w1_kernel, w1_kernel_axes)
-      layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)(
-          "EBCM,EMH -> EBCH", dispatch, w1_kernel, precision=matmul_precision
-      )
-      if self.config.activations_in_float32:
-        layer_w1 = layer_w1.astype(jnp.float32)
-      layer_w1 = nn.with_logical_constraint(
-          layer_w1, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp")
-      )
-      layer_w1 = checkpoint_name(layer_w1, "mlpwi_1")
+      w0 = nn.with_logical_constraint(w0, w0_kernel_axes)
+      w1 = nn.with_logical_constraint(w1, w0_kernel_axes)
+
+      def chunked_a2a(inputs, w0, w1):
+        # Returns: inputs @ w0 and inputs @ w1
+        exp, batch, capacity, embed = jnp.shape(inputs)
+        # weights are [exp, model=embed, hidden=mlp]
+        mlp = jnp.shape(w0)[2]
+
+        chunk_size = embed // self.config.num_moe_a2a_chunks
+        # We chunk along the contracting dimension (embed), thus each step produces a partial sum
+        running_partial_sum_0 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype)
+        running_partial_sum_1 = jnp.zeros((exp, batch, capacity, mlp), dtype=inputs.dtype)
+        running_partial_sum_0 = nn.with_logical_constraint(running_partial_sum_0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"))
+        running_partial_sum_1 = nn.with_logical_constraint(running_partial_sum_1, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"))
+
+        for i in range(self.config.num_moe_a2a_chunks):
+          chunk_start = chunk_size * i
+
+          print(f"{inputs.shape=}")
+          input_chunk = jax.lax.dynamic_slice_in_dim(inputs, chunk_start, chunk_size, 3)
+          print(f"{input_chunk.shape=}")
+          # input_chunk = nn.with_logical_constraint(input_chunk, ("activation_exp", "activation_batch_no_exp", None, "activation_embed"))
+
+          # Inputs are exp, batch, capacity, embed
+          inputs_before_a2a_spec = nn.logical_to_mesh_axes((None, "activation_batch", None, "activation_embed"))
+          print(f"{inputs_before_a2a_spec=}")
+          inputs_after_a2a_spec = nn.logical_to_mesh_axes(("activation_exp", "activation_batch_no_exp", None, "activation_embed"))
+          print(f"{inputs_after_a2a_spec=}")
+          # Perform a2a on input_chunk Exp, B/X -> Exp/X, B
+          input_chunk = shard_map.shard_map(input_a2a, self.mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_chunk)
+          print(f"{input_chunk.shape=}")
+
+          w0_chunk = jax.lax.dynamic_slice_in_dim(w0, chunk_start, chunk_size, 1)
+          w1_chunk = jax.lax.dynamic_slice_in_dim(w1, chunk_start, chunk_size, 1)
+
+          w0_chunk = nn.with_logical_constraint(w0_chunk, w0_kernel_axes)
+          w1_chunk = nn.with_logical_constraint(w1_chunk, w0_kernel_axes)
+
+          running_partial_sum_0 = running_partial_sum_0 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w0_chunk)
+          running_partial_sum_1 = running_partial_sum_1 + jnp.einsum("EBCM,EMH -> EBCH", input_chunk, w1_chunk)
+        return running_partial_sum_0, running_partial_sum_1
+
+      with jax.named_scope("dispatch"):
+        dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask)
+        # Keep dispatch sharded like data parallel - we will A2A in chunks
+        dispatch = nn.with_logical_constraint(dispatch, (None, "activation_batch", None, "activation_embed"))
+      with jax.named_scope("wi_both"):
+        return chunked_a2a(dispatch, w0, w1)
+
+    if self.config.num_moe_a2a_chunks > 1:
+      layer_w0, layer_w1 = dispatch_a2a_overlapped_with_ff1(dispatch_mask, inputs, w0_kernel, w1_kernel)
+    else:
+      with jax.named_scope("dispatch"):
+        dispatch = self.get_einsum(rhs_mesh_axes=mask_axes)("BSM,BSEC -> EBCM", inputs, dispatch_mask)
+        # When using expert parallelism we expect an A2A from E, B/X -> E/X, B with the below sharding constraint.
+        dispatch = nn.with_logical_constraint(dispatch, ("activation_exp", "activation_batch_no_exp", None, "activation_embed"))
+      with jax.named_scope("wi_0"):
+        w0_kernel_axes = ("exp", None, None)
+        w0_kernel = nn.with_logical_constraint(w0_kernel, w0_kernel_axes)
+        layer_w0 = self.get_einsum(rhs_mesh_axes=w0_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w0_kernel)
+        layer_w0 = nn.with_logical_constraint(layer_w0, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"))
+      with jax.named_scope("wi_1"):
+        w1_kernel_axes = ("exp", None, None)
+        w1_kernel = nn.with_logical_constraint(w1_kernel, w1_kernel_axes)
+        layer_w1 = self.get_einsum(rhs_mesh_axes=w1_kernel_axes)("EBCM,EMH -> EBCH", dispatch, w1_kernel)
+        layer_w1 = nn.with_logical_constraint(layer_w1, ("activation_exp", "activation_batch_no_exp", None, "activation_mlp"))
+
     layer_w0_act = _convert_to_activation_function(self.config.mlp_activations[0])(layer_w0)
     layer_multiply = jnp.multiply(layer_w0_act, layer_w1).astype(self.dtype)
     with jax.named_scope("wo"):
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index 0a9a26128..ac534ed6c 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -424,22 +424,22 @@ def create_device_mesh(config, devices=None):
   dcn_parallelism = [
       config.dcn_data_parallelism,
+      config.dcn_expert_parallelism,
       config.dcn_pipeline_parallelism,
       config.dcn_fsdp_parallelism,
       config.dcn_fsdp_transpose_parallelism,
       config.dcn_sequence_parallelism,
       config.dcn_tensor_parallelism,
-      config.dcn_expert_parallelism,
       config.dcn_autoregressive_parallelism,
   ]

   ici_parallelism = [
       config.ici_data_parallelism,
+      config.ici_expert_parallelism,
       config.ici_pipeline_parallelism,
       config.ici_fsdp_parallelism,
       config.ici_fsdp_transpose_parallelism,
       config.ici_sequence_parallelism,
       config.ici_tensor_parallelism,
-      config.ici_expert_parallelism,
       config.ici_autoregressive_parallelism,
   ]

diff --git a/MaxText/my_a2a_playground.py b/MaxText/my_a2a_playground.py
new file mode 100644
index 000000000..1d05ff14c
--- /dev/null
+++ b/MaxText/my_a2a_playground.py
@@ -0,0 +1,48 @@
+import jax
+from jax import numpy as jnp
+from jax.sharding import NamedSharding, Mesh, PartitionSpec as P
+from jax.experimental import mesh_utils
+import jax
+import os
+from jax.debug import visualize_array_sharding
+from jax.experimental import shard_map
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0"
+
+def
input_a2a(input_chunk): + return jax.lax.all_to_all(input_chunk, 'expert', 1, 1, tiled=True) + +BATCH_PER_EXP = 12 +EXP = 16 + +global mesh +mesh = Mesh(jax.devices(), ('expert',)) + +# create inputs that are BATCH_PER_EXP, by EXP with entries 1,2,... using jnp.arrange +input_activations = jnp.arange(BATCH_PER_EXP * EXP).reshape(BATCH_PER_EXP, EXP) +input_activations = input_activations.astype(jnp.bfloat16) +input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None))) + + + + +print(f"{input_activations.shape=}") +print(input_activations) +visualize_array_sharding(input_activations) + +inputs_before_a2a_spec = P("expert", None) +inputs_after_a2a_spec = P(None, "expert") +# Perform a2a on input_chunk B/X, Exp -> B, Exp/X +input_after_a2a = shard_map.shard_map(input_a2a, mesh, in_specs=inputs_before_a2a_spec, out_specs=inputs_after_a2a_spec)(input_activations) + +print(f"{input_after_a2a.shape=}") +visualize_array_sharding(input_after_a2a) +print(input_after_a2a) + + + +# Try without shard map +# with mesh: +# after_a2a_no_shmap = jax.lax.all_to_all(input_activations, 'expert', 1, 0, tiled=True) +# print(f"{after_a2a_no_shmap.shape=}") +# visualize_array_sharding(after_a2a_no_shmap) +# print(after_a2a_no_shmap) diff --git a/MaxText/optimizers.py b/MaxText/optimizers.py index 332a95904..b99e8ea77 100644 --- a/MaxText/optimizers.py +++ b/MaxText/optimizers.py @@ -36,6 +36,8 @@ def get_optimizer(config, learning_rate_schedule): eps_root=config.adam_eps_root, weight_decay=config.adam_weight_decay, ) + elif config.opt_type == "sgd": + return optax.sgd(learning_rate_schedule) elif config.opt_type == "adam_pax": return adam_pax( learning_rate_schedule, diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py index 5adc88199..451c4f76b 100644 --- a/MaxText/pyconfig.py +++ b/MaxText/pyconfig.py @@ -157,6 +157,8 @@ def validate_model_name(s: str) -> bool: "mistral-7b", "mixtral-8x7b", "mixtral-8x22b", + "custom-moe-single", + "custom-moe-multi", "gemma-7b", "gemma-2b", "gemma2-2b", @@ -192,6 +194,7 @@ def validate_and_assign_remat_tensors(keys): "key_proj", "value_proj", "out_proj", + "qkv_proj", ] assert keys["decoder_layer_input"] != "remat", "Cannot remeterialize this tensor with scan_layers=True" tensors_on_device = [] diff --git a/MaxText/save_example_hide.py b/MaxText/save_example_hide.py new file mode 100644 index 000000000..5c024c1f4 --- /dev/null +++ b/MaxText/save_example_hide.py @@ -0,0 +1,85 @@ +import jax +from jax import numpy as jnp +from jax.sharding import NamedSharding, Mesh, PartitionSpec as P +from jax.experimental import mesh_utils +import datetime +import jax +import random +import string +import os +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "0" + +#!!!! Internally in google3 set trace_dir to CNS path or other profiling solution +def simple_timeit(f, *args, tries=10, task=None): + """Simple utility to time a function for multiple runs""" + assert task is not None + + trace_name = f"t_{task}_" + "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(10)) + trace_dir = f"gs://mattdavidow-br/{trace_name}" + + outcomes_ms = [] + jax.block_until_ready(f(*args)) # warm it up! 
+ jax.profiler.start_trace(trace_dir) + + for _ in range(tries): + s = datetime.datetime.now() + jax.block_until_ready(f(*args)) + e = datetime.datetime.now() + outcomes_ms.append(1000 * (e - s).total_seconds()) + jax.profiler.stop_trace() + + average_time_ms = sum(outcomes_ms) / len(outcomes_ms) + print(f"{task}: average time milliseconds: {average_time_ms:.2f}, trace {trace_dir}") + return average_time_ms + + +# Baseline non-overlapped implementation to compare against +# In some ideal world compiler comes up with an overlapped solution even with naive code +def blocking_a2a(input_activations, weights): + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + return jnp.einsum("BXE,XEM -> BXM", input_activations, weights) + +# Desired overlapped implementaion +def overlap_a2a(input_activations, weights): + num_chunks = 4 + chunk_size = EMBED // num_chunks + + partial_sum = jnp.zeros((BATCH_PER_EXP, EXP, MLP)) + partial_sum = jax.lax.with_sharding_constraint(partial_sum, NamedSharding(mesh, P('data', 'expert', 'model'))) + for i in range(num_chunks): + chunk_start = chunk_size * i + + input_chunk = jax.lax.dynamic_slice_in_dim(input_activations, chunk_start, chunk_size, 2) + input_chunk = jax.lax.with_sharding_constraint(input_chunk, NamedSharding(mesh, P('data', 'expert', 'model'))) #A2A B/X,EXP -> B,EXP/X + + weight_chunk = jax.lax.dynamic_slice_in_dim(weights, chunk_start, chunk_size, 1) + + partial_sum = partial_sum + jnp.einsum("BXE,XEM -> BXM", input_chunk, weight_chunk) + return partial_sum + +def create_inputs(): + input_activations = jnp.ones((BATCH_PER_EXP, EXP, EMBED),dtype=jnp.bfloat16) + input_activations = jax.lax.with_sharding_constraint(input_activations, NamedSharding(mesh, P('expert', None,'model'))) + + weights = jnp.ones((EXP, EMBED, MLP),dtype=jnp.bfloat16) + weights = jax.lax.with_sharding_constraint(weights, NamedSharding(mesh, P('expert', None, 'model'))) + return input_activations, weights + +BATCH_PER_EXP = 2048 +EMBED = 4096 +MLP = 8192 +EXP = 4 + +global mesh +data_parallelism, model_parallelism, expert_parallelism = 1, 1, 4 +ici_parallelism = [data_parallelism, model_parallelism, expert_parallelism] +devices_array = mesh_utils.create_device_mesh(ici_parallelism) +mesh = Mesh(devices_array, ["data", "model", "expert"]) + +input_activations, weights = jax.jit(create_inputs)() + +jit_overlap_a2a = jax.jit(overlap_a2a) +simple_timeit(jit_overlap_a2a, input_activations, weights, task="hide_a2a") + +# jit_blocking_a2a = jax.jit(blocking_a2a) +# simple_timeit(jit_blocking_a2a, input_activations, weights, task="blocking_a2a")
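The chunked code in linears.py and in the example_hide / hide_ff2_a2a / save_example_hide scripts all rest on the same identity: slicing the contracting (embed) dimension and accumulating partial einsums reproduces the unchunked matmul, which is what lets the per-chunk a2a be issued early and overlapped with the previous chunk's compute. Below is a minimal, single-device sketch of that identity only; the shapes, dimension names (E/B/C/M/H) and num_chunks value are illustrative stand-ins, and no mesh or a2a is involved.

# Minimal numerical sketch (toy shapes, single device): the chunked partial sums
# over the contracting (embed) dimension equal the unchunked einsum.
import jax
import jax.numpy as jnp

E, B, C, M, H = 2, 4, 3, 8, 5  # experts, batch, capacity, embed, mlp (toy sizes)
num_chunks = 4                 # plays the role of num_moe_a2a_chunks
assert M % num_chunks == 0

x = jax.random.normal(jax.random.PRNGKey(0), (E, B, C, M))
w = jax.random.normal(jax.random.PRNGKey(1), (E, M, H))

full = jnp.einsum("EBCM,EMH->EBCH", x, w)

chunk = M // num_chunks
partial = jnp.zeros((E, B, C, H))
for i in range(num_chunks):
  # In the real layer, the a2a of this input slice can overlap with the previous chunk's matmul.
  x_chunk = jax.lax.dynamic_slice_in_dim(x, i * chunk, chunk, 3)
  w_chunk = jax.lax.dynamic_slice_in_dim(w, i * chunk, chunk, 1)
  partial = partial + jnp.einsum("EBCM,EMH->EBCH", x_chunk, w_chunk)

assert jnp.allclose(full, partial, rtol=1e-5, atol=1e-5)
print("chunked partial sums match the unchunked einsum")

Because only the first chunk's a2a is necessarily exposed, at most a (chunks - 1) / chunks fraction of the a2a can be hidden, which is the estimate given for num_moe_a2a_chunks in base.yml.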