diff --git a/experimental/whisper/speedup.ipynb b/experimental/whisper/speedup.ipynb index 33230f69..70fcc650 100644 --- a/experimental/whisper/speedup.ipynb +++ b/experimental/whisper/speedup.ipynb @@ -666,6 +666,9 @@ } ], "source": [ + "# uncomment 2 following lines and comment the third one to use vanilla torch.compile instead of Kernl\n", + "# model.model.decoder.forward_original = model.model.decoder.forward\n", + "# model.model.decoder.forward = torch.compile(model.model.decoder.forward_original, mode=\"reduce-overhead\")\n", "optimize_model(model.model.decoder)\n", "nb_diff = 0\n", "timings_optimized = list()\n",