Skip to content

Commit

Permalink
Merge pull request huggingface#23 from huggingface/modeling_layoutlm_…
Browse files Browse the repository at this point in the history
…v2_lysandre

Fix initialization test
  • Loading branch information
NielsRogge authored Aug 17, 2021
2 parents 41af07f + 49ed4ce commit f71a5e8
Showing 1 changed file with 18 additions and 1 deletion.
19 changes: 18 additions & 1 deletion tests/test_modeling_layoutlmv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from transformers.testing_utils import require_detectron2, require_torch, slow, torch_device

from .test_configuration_common import ConfigTester
from .test_modeling_common import ModelTesterMixin, ids_tensor, random_attention_mask
from .test_modeling_common import ModelTesterMixin, _config_zero_init, ids_tensor, random_attention_mask


if is_torch_available():
Expand Down Expand Up @@ -399,6 +399,23 @@ def test_model_from_pretrained(self):
model = LayoutLMv2Model.from_pretrained(model_name)
self.assertIsNotNone(model)

def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

configs_no_init = _config_zero_init(config)
for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
for name, param in model.named_parameters():
if "backbone" in name or "visual_segment_embedding" in name:
continue

if param.requires_grad:
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)


def prepare_layoutlmv2_batch_inputs():
# Here we prepare a batch of 2 sequences to test a LayoutLMv2 forward pass on:
Expand Down

0 comments on commit f71a5e8

Please sign in to comment.