Skip to content

Commit

Permalink
Add checkpoint topology discovery for the replicator service
Browse files Browse the repository at this point in the history
  • Loading branch information
xuefgu committed Nov 25, 2024
1 parent 0c2f0d1 commit c4c0bbb
Show file tree
Hide file tree
Showing 3 changed files with 59 additions and 17 deletions.
6 changes: 6 additions & 0 deletions MaxText/configs/base.yml
Original file line number Diff line number Diff line change
Expand Up @@ -208,6 +208,12 @@ local_checkpoint_directory: ""
# It should be a positive number when and only when `enable_emergency_checkpoint` is True.
local_checkpoint_period: 0

# Whether or not to use emergency checkpoint with the replicator service.
use_replicator_service: False

# The interval to backup local checkpoints to the persistent storage.
replicator_backup_interval_minutes: 0

# Jax cache directory
jax_cache_dir: "~/jax_cache"

Expand Down
43 changes: 43 additions & 0 deletions MaxText/max_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,34 @@ def initialize_jax_for_tpu_with_emergency_checkpointing(raw_keys):
" coordinator_address to initialize JAX distributed runtime..."
)
jax.distributed.initialize(coordinator_address=coordinator_address, process_id=int(process_id))
if raw_keys["use_replicator_service"]:
REPLICATOR_FILE = "replicator.yaml"
TEMP_FILE = REPLICATOR_FILE + ".tmp"
replicator_file = epath.Path(raw_keys["local_checkpoint_directory"]) / REPLICATOR_FILE
temp_file = epath.Path(raw_keys["local_checkpoint_directory"]) / TEMP_FILE
num_slices = get_num_slices(raw_keys)
num_nodes = jax.process_count()
nodes_per_slice = num_nodes // num_slices
max_logging.log(f"num_slices: {num_slices}, num_nodes: {num_nodes}, nodes_per_slice: {nodes_per_slice}")
node_rank = jax.process_index()
peer_ranks = []
for i in range(num_slices):
peer = node_rank % nodes_per_slice + i * nodes_per_slice
if peer != node_rank:
peer_ranks.append(peer)
run_name = raw_keys["run_name"]
if run_name == "":
run_name = os.environ.get("JOBSET_NAME") # using XPK default

replicator_yaml = f"""job-name: {run_name}
node-rank: {node_rank}
nodes: {num_nodes}
workers-per-node: 1
peer-ranks: {peer_ranks}
backup-interval-minutes: {raw_keys["replicator_backup_interval_minutes"]}"""

temp_file.write_text("\n".join([l.strip() for l in replicator_yaml.split("\n")]))
os.rename(temp_file, replicator_file)
else:
max_logging.log(
"Initializing JAX distributed runtime without args when emergency checkpointing is"
Expand Down Expand Up @@ -319,6 +347,21 @@ def _retrieve_jax_init_info(raw_keys):
return "", ""


def get_num_slices(raw_keys):
"""Calculate num_slices based on number of devices."""
if raw_keys["hardware"] == "cpu":
max_logging.log(" Setting num_slices=1 for CPU hardware type")
return 1
if int(raw_keys["compile_topology_num_slices"]) > 0:
return raw_keys["compile_topology_num_slices"]
else:
devices = jax.devices()
try:
return 1 + max(d.slice_index for d in devices)
except (ValueError, AttributeError):
return 1


def is_cpu_backend(raw_keys):
"""Determine whether Maxtext is intended to run on a CPU backend."""
return raw_keys["hardware"] == "cpu"
Expand Down
27 changes: 10 additions & 17 deletions MaxText/pyconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,16 @@ def validate_keys(keys):
assert (
keys["local_checkpoint_period"] > 0
), "A positive local checkpoint period must be specified when using emergency checkpoint"
if keys["use_replicator_service"]:
assert (
keys["replicator_backup_interval_minutes"] > 0
), "Replicator service is enabled, the backup interval minutes must be positive"
else:
max_logging.log("Not using emergency checkpoint, ignoring local_checkpoint_directory and local_checkpoint_period")
max_logging.log(
"Not using emergency checkpoint, ignoring local_checkpoint_directory, local_checkpoint_period,"
" use_replicator_service and replicator_backup_interval_minutes"
)

if keys["num_experts"] > 1:
validate_megablox_parallelism(keys)

Expand Down Expand Up @@ -388,7 +396,7 @@ def user_init(raw_keys):
raw_keys["eval_per_device_batch_size"], raw_keys["expansion_factor_real_data"], get_num_target_devices(raw_keys), 1
)

raw_keys["num_slices"] = get_num_slices(raw_keys)
raw_keys["num_slices"] = max_utils.get_num_slices(raw_keys)
raw_keys["quantization_local_shard_count"] = get_quantization_local_shard_count(raw_keys)

if using_pipeline_parallelism(raw_keys):
Expand Down Expand Up @@ -589,21 +597,6 @@ def get_num_target_devices(raw_keys):
return len(jax.devices())


def get_num_slices(raw_keys):
"""Calculate num_slices based on number of devices."""
if raw_keys["hardware"] == "cpu":
max_logging.log(" Setting num_slices=1 for CPU hardware type")
return 1
if int(raw_keys["compile_topology_num_slices"]) > 0:
return raw_keys["compile_topology_num_slices"]
else:
devices = jax.devices()
try:
return 1 + max([d.slice_index for d in devices])
except:
return 1


def get_quantization_local_shard_count(raw_keys):
if raw_keys["quantization_local_shard_count"] == -1:
return raw_keys["num_slices"]
Expand Down

0 comments on commit c4c0bbb

Please sign in to comment.