Skip to content

Commit 6bfa880

Browse files
committed
Fixed same bug in another example notebook
1 parent 1ba70fc commit 6bfa880

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

examples/notebooks/DistributedTripletMarginLossMNIST.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -171,6 +171,8 @@
171171
"def test(train_set, test_set, model, accuracy_calculator, data_device):\n",
172172
" train_embeddings, train_labels = get_all_embeddings(train_set, model, data_device)\n",
173173
" test_embeddings, test_labels = get_all_embeddings(test_set, model, data_device)\n",
174+
" train_labels = train_labels.squeeze(1)\n",
175+
" test_labels = test_labels.squeeze(1)\n",
174176
" print(\"Computing accuracy\")\n",
175177
" accuracies = accuracy_calculator.get_accuracy(test_embeddings, \n",
176178
" train_embeddings,\n",

0 commit comments

Comments
 (0)