Skip to content

Commit 1ba70fc

Browse files
committed
Fixed bug in example notebook
1 parent 9170418 commit 1ba70fc

File tree

1 file changed

+2
-0
lines changed

1 file changed

+2
-0
lines changed

examples/notebooks/TripletMarginLossMNIST.ipynb

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1142,6 +1142,8 @@
11421142
"def test(train_set, test_set, model, accuracy_calculator):\n",
11431143
" train_embeddings, train_labels = get_all_embeddings(train_set, model)\n",
11441144
" test_embeddings, test_labels = get_all_embeddings(test_set, model)\n",
1145+
" train_labels = train_labels.squeeze(1)\n",
1146+
" test_labels = test_labels.squeeze(1)\n",
11451147
" print(\"Computing accuracy\")\n",
11461148
" accuracies = accuracy_calculator.get_accuracy(test_embeddings, \n",
11471149
" train_embeddings,\n",

0 commit comments

Comments
 (0)