diff --git a/16_tensorboard.py b/16_tensorboard.py index 2fe9a3aaa6c85e..94462f50d629b2 100644 --- a/16_tensorboard.py +++ b/16_tensorboard.py @@ -80,7 +80,7 @@ def forward(self, x): optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate) ############## TENSORBOARD ######################## -writer.add_graph(model, example_data.reshape(-1, 28*28)) +writer.add_graph(model, example_data.reshape(-1, 28*28).to(device)) #writer.close() #sys.exit() ################################################### @@ -156,4 +156,4 @@ def forward(self, x): preds_i = class_preds[:, i] writer.add_pr_curve(str(i), labels_i, preds_i, global_step=0) writer.close() - ################################################### \ No newline at end of file + ###################################################