Skip to content

Commit

Permalink
Set seed
Browse files Browse the repository at this point in the history
  • Loading branch information
VibhuJawa committed Aug 25, 2023
1 parent 36c82fb commit f99a518
Showing 1 changed file with 9 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import os
import time
import json
import random
import numpy as np
from argparse import ArgumentParser


Expand Down Expand Up @@ -218,6 +220,11 @@ def dataloading_benchmark(g, train_idx, fanouts, batch_sizes, use_uva):
print("==============================================")
return time_ls

def set_seed(seed):
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)

if __name__ == "__main__":
parser = ArgumentParser()
Expand All @@ -230,13 +237,14 @@ def dataloading_benchmark(g, train_idx, fanouts, batch_sizes, use_uva):
)
parser.add_argument("--batch_sizes", type=str, default="512,1024")
parser.add_argument("--do_not_use_uva", action="store_true")
parser.add_argument("--seed", type=int, default=42)
args = parser.parse_args()

if args.do_not_use_uva:
use_uva = False
else:
use_uva = True

set_seed(args.seed)
replication_factors = [int(x) for x in args.replication_factors.split(",")]
fanouts = [[int(y) for y in x.split("_")] for x in args.fanouts.split(",")]
batch_sizes = [int(x) for x in args.batch_sizes.split(",")]
Expand Down

0 comments on commit f99a518

Please sign in to comment.