diff --git a/MaxText/configs/base.yml b/MaxText/configs/base.yml
index fdb99197f..5353686b6 100644
--- a/MaxText/configs/base.yml
+++ b/MaxText/configs/base.yml
@@ -208,6 +208,12 @@ local_checkpoint_directory: ""
 # It should be a positive number when and only when `enable_emergency_checkpoint` is True.
 local_checkpoint_period: 0
 
+# Whether or not to use emergency checkpoint with the replicator service.
+use_replicator_service: False
+
+# The interval to backup local checkpoints to the persistent storage.
+replicator_backup_interval_minutes: 0
+
 # Jax cache directory
 jax_cache_dir: "~/jax_cache"
 
diff --git a/MaxText/max_utils.py b/MaxText/max_utils.py
index 40fb8bbbb..550d22bf3 100644
--- a/MaxText/max_utils.py
+++ b/MaxText/max_utils.py
@@ -289,6 +289,34 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys):
         " coordinator_address to initialize JAX distributed runtime..."
     )
     jax.distributed.initialize(coordinator_address=coordinator_address, process_id=int(process_id))
+    if raw_keys["use_replicator_service"]:
+      REPLICATOR_FILE = "replicator.yaml"
+      TEMP_FILE = REPLICATOR_FILE + ".tmp"
+      replicator_file = epath.Path(raw_keys["local_checkpoint_directory"]) / REPLICATOR_FILE
+      temp_file = epath.Path(raw_keys["local_checkpoint_directory"]) / TEMP_FILE
+      num_slices = get_num_slices(raw_keys)
+      num_nodes = jax.process_count()
+      nodes_per_slice = num_nodes // num_slices
+      max_logging.log(f"num_slices: {num_slices}, num_nodes: {num_nodes}, nodes_per_slice: {nodes_per_slice}")
+      node_rank = jax.process_index()
+      peer_ranks = []
+      for i in range(num_slices):
+        peer = node_rank % nodes_per_slice + i * nodes_per_slice
+        if peer != node_rank:
+          peer_ranks.append(peer)
+      run_name = raw_keys["run_name"]
+      if run_name == "":
+        run_name = os.environ.get("JOBSET_NAME")  # using XPK default
+
+      replicator_yaml = f"""job-name: {run_name}
+        node-rank: {node_rank}
+        nodes: {num_nodes}
+        workers-per-node: 1
+        peer-ranks: {peer_ranks}
+        backup-interval-minutes: {raw_keys["replicator_backup_interval_minutes"]}"""
+
+      temp_file.write_text("\n".join([l.strip() for l in replicator_yaml.split("\n")]))
+      os.rename(temp_file, replicator_file)
   else:
     max_logging.log(
         "Initializing JAX distributed runtime without args when emergency checkpointing is"
@@ -319,6 +347,21 @@ def _retrieve_jax_init_info(raw_keys):
   return "", ""
 
 
+def get_num_slices(raw_keys):
+  """Calculate num_slices based on number of devices."""
+  if raw_keys["hardware"] == "cpu":
+    max_logging.log(" Setting num_slices=1 for CPU hardware type")
+    return 1
+  if int(raw_keys["compile_topology_num_slices"]) > 0:
+    return raw_keys["compile_topology_num_slices"]
+  else:
+    devices = jax.devices()
+    try:
+      return 1 + max(d.slice_index for d in devices)
+    except (ValueError, AttributeError):
+      return 1
+
+
 def is_cpu_backend(raw_keys):
   """Determine whether Maxtext is intended to run on a CPU backend."""
   return raw_keys["hardware"] == "cpu"
diff --git a/MaxText/pyconfig.py b/MaxText/pyconfig.py
index 5adc88199..fa80c9226 100644
--- a/MaxText/pyconfig.py
+++ b/MaxText/pyconfig.py
@@ -105,8 +105,16 @@ def validate_keys(keys):
     assert (
         keys["local_checkpoint_period"] > 0
     ), "A positive local checkpoint period must be specified when using emergency checkpoint"
+    if keys["use_replicator_service"]:
+      assert (
+          keys["replicator_backup_interval_minutes"] > 0
+      ), "Replicator service is enabled, the backup interval minutes must be positive"
   else:
-    max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period")
+    max_logging.log(
+        "Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period,"
+        " use_replicator_service and replicator_backup_interval_minutes"
+    )
+
   if keys["num_experts"] > 1:
     validate_megablox_parallelism(keys)
 
@@ -388,7 +396,7 @@ def user_init(raw_keys):
       raw_keys["eval_per_device_batch_size"], raw_keys["expansion_factor_real_data"], get_num_target_devices(raw_keys), 1
   )
 
-  raw_keys["num_slices"] = get_num_slices(raw_keys)
+  raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
   raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)
 
   if using_pipeline_parallelism(raw_keys):
@@ -589,21 +597,6 @@ def get_num_target_devices(raw_keys):
   return len(jax.devices())
 
 
-def get_num_slices(raw_keys):
-  """Calculate num_slices based on number of devices."""
-  if raw_keys["hardware"] == "cpu":
-    max_logging.log(" Setting num_slices=1 for CPU hardware type")
-    return 1
-  if int(raw_keys["compile_topology_num_slices"]) > 0:
-    return raw_keys["compile_topology_num_slices"]
-  else:
-    devices = jax.devices()
-    try:
-      return 1 + max([d.slice_index for d in devices])
-    except:
-      return 1
-
-
 def get_quantization_local_shard_count(raw_keys):
   if raw_keys["quantization_local_shard_count"] == -1:
     return raw_keys["num_slices"]
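
For intuition, below is a minimal standalone sketch of the peer-rank pairing that the new `use_replicator_service` branch writes into replicator.yaml. The topology (2 slices of 2 nodes), the job name, and the backup interval are hypothetical illustrative values, not taken from this change; the real code derives them from `jax.process_count()`, `jax.process_index()`, `get_num_slices()`, and the config.

# Hypothetical sketch: each node is paired with the node holding the same
# intra-slice position in every other slice, so every local checkpoint has
# an off-slice replica.
num_slices = 2
num_nodes = 4
nodes_per_slice = num_nodes // num_slices  # 2

for node_rank in range(num_nodes):
  peer_ranks = []
  for i in range(num_slices):
    peer = node_rank % nodes_per_slice + i * nodes_per_slice
    if peer != node_rank:
      peer_ranks.append(peer)
  print(f"node {node_rank}: peer-ranks {peer_ranks}")

# Output:
#   node 0: peer-ranks [2]
#   node 1: peer-ranks [3]
#   node 2: peer-ranks [0]
#   node 3: peer-ranks [1]
#
# The replicator.yaml written for node 0 would then look like
# (job name and interval are made up):
#   job-name: my-run
#   node-rank: 0
#   nodes: 4
#   workers-per-node: 1
#   peer-ranks: [2]
#   backup-interval-minutes: 30

Note also the write path in the patch: the config is first written to replicator.yaml.tmp and then moved into place with os.rename, so the replicator service never observes a partially written file.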