Skip to content

Commit

Permalink
merged in profiler
Browse files Browse the repository at this point in the history
  • Loading branch information
josephdviviano committed Aug 29, 2024
2 parents fd0ef50 + b5c041c commit 6d7b9d3
Showing 1 changed file with 33 additions and 0 deletions.
33 changes: 33 additions & 0 deletions tutorials/examples/train_hypergrid.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
from gfn.utils.common import set_seed
from gfn.utils.modules import DiscreteUniform, NeuralNet, Tabular
from gfn.utils.training import validate
from torch.profiler import profile, ProfilerActivity

DEFAULT_SEED = 4444

Expand Down Expand Up @@ -308,12 +309,28 @@ def main(args): # noqa: C901
total_opt_time, total_rest_time = 0, 0
time_start = time.time()

if args.profile:
keep_active = args.trajectories_to_profile // args.batch_size
prof = profile(
schedule=torch.profiler.schedule(wait=1, warmup=1, active=keep_active, repeat=1),
activities=[ProfilerActivity.CPU],
record_shapes=True,
with_stack=True
)
prof.start()
for iteration in trange(n_iterations):

iteration_start = time.time()

# Time sample_trajectories method.
sample_start = time.time()

# Use the optional profiler.
if args.profile:
prof.step()
if iteration >= 1 + 1 + keep_active:
break

trajectories = gflownet.sample_trajectories(
env,
n_samples=args.batch_size,
Expand Down Expand Up @@ -437,6 +454,10 @@ def main(args): # noqa: C901
for k, v in to_log.iteritems():
print(" {k}: {.:6f}".format(k, v))

if args.profile:
prof.stop()
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=20))
prof.export_chrome_trace("trace.json")
try:
return validation_info["l1_dist"]
except KeyError:
Expand Down Expand Up @@ -659,6 +680,18 @@ def validate_hypergrid(
action="store_true",
help="Calculates the true partition function.",
)
parser.add_argument(
"--profile",
action="store_true",
help="Profiles the execution using PyTorch Profiler.",
)
parser.add_argument(
"--trajectories_to_profile",
type=int,
default=2048,
help="Number of trajectories to profile using the Pytorch Profiler." +
" Preferably, a multiple of batch size.",
)

args = parser.parse_args()
result = main(args)
Expand Down

0 comments on commit 6d7b9d3

Please sign in to comment.