Skip to content

Commit

Permalink
[air/benchmarks] Fix typo in tensorflow_benchmark.py script preventin…
Browse files Browse the repository at this point in the history
…g proper error surfacing (ray-project#32269)

There is a small typo in the tensorflow_benchmark.py script that does not properly catch when a vanilla TF run failed three times. Because of this, we would previously record a training time of 0.0 for vanilla TF, which skews the calculated average and suggests that vanilla TF outperformed Ray Train. Instead, we should have raised an error message to surface the problem.

Signed-off-by: Kai Fricke <kai@anyscale.com>
Signed-off-by: Edward Oakes <ed.nmi.oakes@gmail.com>
  • Loading branch information
krfricke authored and edoakes committed Mar 22, 2023
1 parent f391c23 commit 02e1a45
Showing 1 changed file with 2 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -310,7 +310,7 @@ def run(
config=config,
)
except Exception as e:
if i > +2:
if i >= 2:
raise RuntimeError("Vanilla TF run failed 3 times") from e
print("Vanilla TF run failed:", e)
continue
Expand Down Expand Up @@ -338,6 +338,7 @@ def run(
times_ray.append(time_ray)
times_local_ray.append(time_local_ray)
losses_ray.append(loss_ray)

times_vanilla.append(time_vanilla)
times_local_vanilla.append(time_local_vanilla)
losses_vanilla.append(loss_vanilla)
Expand Down

0 comments on commit 02e1a45

Please sign in to comment.