Skip to content

Commit

Permalink
Merge branch 'main' 4
Browse files Browse the repository at this point in the history
  • Loading branch information
aymeric-roucher committed Jul 4, 2024
2 parents 5f30cc5 + 048f599 commit af843a8
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 20 deletions.
5 changes: 5 additions & 0 deletions src/transformers/models/rt_detr/configuration_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,9 @@ class RTDetrConfig(PretrainedConfig):
Args:
initializer_range (`float`, *optional*, defaults to 0.01):
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
initializer_bias_prior_prob (`float`, *optional*):
The prior probability used by the bias initializer to initialize biases for `enc_score_head` and `class_embed`.
If `None`, `prior_prob` computed as `prior_prob = 1 / (num_labels + 1)` while initializing model weights.
layer_norm_eps (`float`, *optional*, defaults to 1e-05):
The epsilon used by the layer normalization layers.
batch_norm_eps (`float`, *optional*, defaults to 1e-05):
Expand Down Expand Up @@ -179,6 +182,7 @@ class RTDetrConfig(PretrainedConfig):
def __init__(
self,
initializer_range=0.01,
initializer_bias_prior_prob=None,
layer_norm_eps=1e-5,
batch_norm_eps=1e-5,
# backbone
Expand Down Expand Up @@ -239,6 +243,7 @@ def __init__(
**kwargs,
):
self.initializer_range = initializer_range
self.initializer_bias_prior_prob = initializer_bias_prior_prob
self.layer_norm_eps = layer_norm_eps
self.batch_norm_eps = batch_norm_eps
# backbone
Expand Down
50 changes: 36 additions & 14 deletions src/transformers/models/rt_detr/modeling_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -1148,14 +1148,27 @@ class RTDetrPreTrainedModel(PreTrainedModel):
def _init_weights(self, module):
"""Initalize the weights"""

"""initialize conv/fc bias value according to a given probability value."""
if isinstance(module, nn.Linear) and hasattr(module, "class_embed"):
prior_prob = self.config.initializer_range
"""initialize linear layer bias value according to a given probability value."""
if isinstance(module, (RTDetrForObjectDetection, RTDetrDecoder)):
if module.class_embed is not None:
for layer in module.class_embed:
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
bias = float(-math.log((1 - prior_prob) / prior_prob))
nn.init.xavier_uniform_(layer.weight)
nn.init.constant_(layer.bias, bias)

if module.bbox_embed is not None:
for layer in module.bbox_embed:
nn.init.constant_(layer.layers[-1].weight, 0)
nn.init.constant_(layer.layers[-1].bias, 0)

if isinstance(module, RTDetrModel):
prior_prob = self.config.initializer_bias_prior_prob or 1 / (self.config.num_labels + 1)
bias = float(-math.log((1 - prior_prob) / prior_prob))
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, bias)
elif isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
nn.init.xavier_uniform_(module.enc_score_head.weight)
nn.init.constant_(module.enc_score_head.bias, bias)

if isinstance(module, (nn.Linear, nn.Conv2d, nn.BatchNorm2d)):
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
if module.bias is not None:
module.bias.data.zero_()
Expand Down Expand Up @@ -1656,7 +1669,11 @@ def unfreeze_backbone(self):
param.requires_grad_(True)

@lru_cache(maxsize=32)
def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.float32, device="cpu"):
def generate_anchors(self, spatial_shapes=None, grid_size=0.05):
# We always generate anchors in float32 to preserve equivalence between
# dynamic and static anchor inference
dtype = torch.float32

if spatial_shapes is None:
spatial_shapes = [
[int(self.config.anchor_image_size[0] / s), int(self.config.anchor_image_size[1] / s)]
Expand All @@ -1674,7 +1691,7 @@ def generate_anchors(self, spatial_shapes=None, grid_size=0.05, dtype=torch.floa
anchors.append(torch.concat([grid_xy, wh], -1).reshape(-1, height * width, 4))
# define the valid range for anchor coordinates
eps = 1e-2
anchors = torch.concat(anchors, 1).to(device)
anchors = torch.concat(anchors, 1)
valid_mask = ((anchors > eps) * (anchors < 1 - eps)).all(-1, keepdim=True)
anchors = torch.log(anchors / (1 - anchors))
anchors = torch.where(valid_mask, anchors, torch.inf)
Expand Down Expand Up @@ -1769,15 +1786,15 @@ def forward(

# Prepare encoder inputs (by flattening)
source_flatten = []
spatial_shapes = []
spatial_shapes_list = []
for level, source in enumerate(sources):
batch_size, num_channels, height, width = source.shape
spatial_shape = (height, width)
spatial_shapes.append(spatial_shape)
spatial_shapes_list.append(spatial_shape)
source = source.flatten(2).transpose(1, 2)
source_flatten.append(source)
source_flatten = torch.cat(source_flatten, 1)
spatial_shapes = torch.as_tensor(spatial_shapes, dtype=torch.long, device=source_flatten.device)
spatial_shapes = torch.as_tensor(spatial_shapes_list, dtype=torch.long, device=source_flatten.device)
level_start_index = torch.cat((spatial_shapes.new_zeros((1,)), spatial_shapes.prod(1).cumsum(0)[:-1]))

# prepare denoising training
Expand Down Expand Up @@ -1805,9 +1822,14 @@ def forward(

# prepare input for decoder
if self.training or self.config.anchor_image_size is None:
anchors, valid_mask = self.generate_anchors(spatial_shapes, device=device, dtype=dtype)
# Pass spatial_shapes as tuple to make it hashable and make sure
# lru_cache is working for generate_anchors()
spatial_shapes_tuple = tuple(spatial_shapes_list)
anchors, valid_mask = self.generate_anchors(spatial_shapes_tuple)
else:
anchors, valid_mask = self.anchors.to(device, dtype), self.valid_mask.to(device, dtype)
anchors, valid_mask = self.anchors, self.valid_mask

anchors, valid_mask = anchors.to(device, dtype), valid_mask.to(device, dtype)

# use the valid_mask to selectively retain values in the feature map where the mask is `True`
memory = valid_mask.to(source_flatten.dtype) * source_flatten
Expand Down
76 changes: 70 additions & 6 deletions tests/models/rt_detr/test_modeling_rt_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import inspect
import math
import tempfile
import unittest

from parameterized import parameterized
Expand Down Expand Up @@ -583,6 +584,11 @@ def test_initialization(self):
config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

configs_no_init = _config_zero_init(config)
configs_no_init.initializer_bias_prior_prob = 0.2
bias_value = -1.3863 # log_e ((1 - 0.2) / 0.2)

failed_cases = []

for model_class in self.all_model_classes:
model = model_class(config=configs_no_init)
# Skip the check for the backbone
Expand All @@ -593,20 +599,36 @@ def test_initialization(self):

for name, param in model.named_parameters():
if param.requires_grad:
if (
if ("class_embed" in name and "bias" in name) or "enc_score_head.bias" in name:
bias_tensor = torch.full_like(param.data, bias_value)
if not torch.allclose(param.data, bias_tensor, atol=1e-4):
failed_cases.append(
f"Parameter {name} of model {model_class} seems not properly initialized. "
f"Biases should be initialized to {bias_value}, got {param.data}"
)
elif (
"level_embed" in name
or "sampling_offsets.bias" in name
or "value_proj" in name
or "output_proj" in name
or "reference_points" in name
or "enc_score_head.weight" in name
or ("class_embed" in name and "weight" in name)
or name in backbone_params
):
continue
self.assertIn(
((param.data.mean() * 1e9).round() / 1e9).item(),
[0.0, 1.0],
msg=f"Parameter {name} of model {model_class} seems not properly initialized",
)
else:
mean = param.data.mean()
round_mean = (mean * 1e9).round() / 1e9
round_mean = round_mean.item()
if round_mean not in [0.0, 1.0]:
failed_cases.append(
f"Parameter {name} of model {model_class} seems not properly initialized. "
f"Mean is {round_mean}, but should be in [0, 1]"
)

message = "\n" + "\n".join(failed_cases)
self.assertTrue(not failed_cases, message)

@parameterized.expand(["float32", "float16", "bfloat16"])
@require_torch_gpu
Expand All @@ -630,6 +652,48 @@ def test_inference_with_different_dtypes(self, torch_dtype_str):
with torch.no_grad():
_ = model(**self._prepare_for_class(inputs_dict, model_class))

@parameterized.expand(["float32", "float16", "bfloat16"])
@require_torch_gpu
@slow
def test_inference_equivalence_for_static_and_dynamic_anchors(self, torch_dtype_str):
torch_dtype = {
"float32": torch.float32,
"float16": torch.float16,
"bfloat16": torch.bfloat16,
}[torch_dtype_str]

config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()
h, w = inputs_dict["pixel_values"].shape[-2:]

# convert inputs to the desired dtype
for key, tensor in inputs_dict.items():
if tensor.dtype == torch.float32:
inputs_dict[key] = tensor.to(torch_dtype)

for model_class in self.all_model_classes:
with tempfile.TemporaryDirectory() as tmpdirname:
model_class(config).save_pretrained(tmpdirname)
model_static = model_class.from_pretrained(
tmpdirname, anchor_image_size=[h, w], device_map=torch_device, torch_dtype=torch_dtype
).eval()
model_dynamic = model_class.from_pretrained(
tmpdirname, anchor_image_size=None, device_map=torch_device, torch_dtype=torch_dtype
).eval()

self.assertIsNotNone(model_static.config.anchor_image_size)
self.assertIsNone(model_dynamic.config.anchor_image_size)

with torch.no_grad():
outputs_static = model_static(**self._prepare_for_class(inputs_dict, model_class))
outputs_dynamic = model_dynamic(**self._prepare_for_class(inputs_dict, model_class))

self.assertTrue(
torch.allclose(
outputs_static.last_hidden_state, outputs_dynamic.last_hidden_state, rtol=1e-4, atol=1e-4
),
f"Max diff: {(outputs_static.last_hidden_state - outputs_dynamic.last_hidden_state).abs().max()}",
)


TOLERANCE = 1e-4

Expand Down

0 comments on commit af843a8

Please sign in to comment.