diff --git a/pytorch_translate/test/test_export_models.py b/pytorch_translate/test/test_export_models.py new file mode 100644 index 00000000..eb05fd7b --- /dev/null +++ b/pytorch_translate/test/test_export_models.py @@ -0,0 +1,13 @@ +#!/usr/bin/env python3 + +import unittest + +import torch +from fairseq.modules import multihead_attention + + +class TestExportModels(unittest.TestCase): + @unittest.skip("TDD: placeholder for development") + def test_export_multihead_attention(self): + module = multihead_attention.MultiheadAttention(embed_dim=8, num_heads=2) + torch.jit.script(module)