Skip to content

Commit

Permalink
wrapped in main
Browse files Browse the repository at this point in the history
  • Loading branch information
parthraut committed Dec 22, 2024
1 parent c0c53ca commit 0f803cf
Show file tree
Hide file tree
Showing 2 changed files with 25 additions and 24 deletions.
45 changes: 23 additions & 22 deletions examples/jax/train_single.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,25 +101,26 @@ def get_train_batches():
# tfds.as_numpy converts the tf.data.Dataset into an iterable of NumPy arrays
return tfds.as_numpy(ds)

monitor = ZeusMonitor()
plo = GlobalPowerLimitOptimizer(monitor)

for epoch in range(num_epochs):
start_time = time.time()

plo.on_epoch_begin()
for x, y in get_train_batches():
plo.on_step_begin()
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)
params = update(params, x, y)
plo.on_step_end()
plo.on_epoch_end()

epoch_time = time.time() - start_time

train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
if __name__ == "__main__":
monitor = ZeusMonitor(sync_execution_with="jax")
plo = GlobalPowerLimitOptimizer(monitor)

for epoch in range(num_epochs):
start_time = time.time()

plo.on_epoch_begin()
for x, y in get_train_batches():
plo.on_step_begin()
x = jnp.reshape(x, (len(x), num_pixels))
y = one_hot(y, num_labels)
params = update(params, x, y)
plo.on_step_end()
plo.on_epoch_end()

epoch_time = time.time() - start_time

train_acc = accuracy(params, train_images, train_labels)
test_acc = accuracy(params, test_images, test_labels)
print("Epoch {} in {:0.2f} sec".format(epoch, epoch_time))
print("Training set accuracy {}".format(train_acc))
print("Test set accuracy {}".format(test_acc))
4 changes: 2 additions & 2 deletions zeus/utils/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def all_reduce(
array = jax.numpy.array(object)

if operation == "sum":
reduced = jax.lax.psum(array)
reduced = jax.pmap(lambda x: jax.lax.psum(x, axis_name="i"), axis_name="i")(array)
elif operation == "max":
reduced = jax.lax.pmax(array)
reduced = jax.pmap(lambda x: jax.lax.pmax(x, axis_name="i"), axis_name="i")(array)
else:
raise ValueError(f"all_reduce unsupported operation: {operation}")

Expand Down

0 comments on commit 0f803cf

Please sign in to comment.