Skip to content

Commit 61c178e

Browse files
committed
[Feature] Storage Shared Initialization for Multiprocessing
ghstack-source-id: fbbc221 Pull-Request: #3183
1 parent 4f013a8 commit 61c178e

File tree

24 files changed

+645
-185
lines changed

24 files changed

+645
-185
lines changed

examples/distributed/collectors/multi_nodes/delayed_dist.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def make_env():
150150
if i == 10:
151151
t0 = time.time()
152152
t1 = time.time()
153-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
153+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
154154
collector.shutdown()
155155
exit()
156156

examples/distributed/collectors/multi_nodes/delayed_rpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -148,7 +148,7 @@ def make_env():
148148
if i == 10:
149149
t0 = time.time()
150150
t1 = time.time()
151-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
151+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
152152
collector.shutdown()
153153
exit()
154154

examples/distributed/collectors/multi_nodes/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,5 +125,5 @@ def gym_make():
125125
t0 = time.time()
126126
collector.shutdown()
127127
t1 = time.time()
128-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
128+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
129129
exit()

examples/distributed/collectors/multi_nodes/rpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -113,5 +113,5 @@ def gym_make():
113113
t0 = time.time()
114114
collector.shutdown()
115115
t1 = time.time()
116-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
116+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
117117
exit()

examples/distributed/collectors/multi_nodes/sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -119,5 +119,5 @@ def gym_make():
119119
t0 = time.time()
120120
collector.shutdown()
121121
t1 = time.time()
122-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
122+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
123123
exit()

examples/distributed/collectors/single_machine/generic.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,5 +160,5 @@ def gym_make():
160160
t0 = time.time()
161161
collector.shutdown()
162162
t1 = time.time()
163-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
163+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
164164
exit()

examples/distributed/collectors/single_machine/rpc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,5 +129,5 @@ def gym_make():
129129
t0 = time.time()
130130
collector.shutdown()
131131
t1 = time.time()
132-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
132+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
133133
exit()

examples/distributed/collectors/single_machine/sync.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -149,5 +149,5 @@ def gym_make():
149149
t0 = time.time()
150150
collector.shutdown()
151151
t1 = time.time()
152-
torchrl_logger.info(f"time elapsed: {t1-t0}s, rate: {counter/(t1-t0)} fps")
152+
torchrl_logger.info(f"time elapsed: {t1 - t0}s, rate: {counter / (t1 - t0)} fps")
153153
exit()

examples/rlhf/train.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ def main(cfg):
147147
elif it % log_interval == 0:
148148
# loss as float. note: this is a CPU-GPU sync point
149149
loss = batch.loss.item()
150-
msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt*1000:.2f}ms"
150+
msg = f"TRAIN: {it=}: {loss=:.4f}, time {dt * 1000:.2f}ms"
151151
torchrl_logger.info(msg)
152152
loss_logger.info(msg)
153153

examples/rlhf/train_reward.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ def main(cfg):
155155
acc = _accuracy(
156156
batch.chosen_data.end_scores, batch.rejected_data.end_scores
157157
)
158-
msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt*1000:.2f}ms"
158+
msg = f"TRAIN: {it=}: {loss=:.4f}, {acc=:.4f} time={dt * 1000:.2f}ms"
159159
torchrl_logger.info(msg)
160160
loss_logger.info(msg)
161161

0 commit comments

Comments
 (0)