Commit a80ad54

No public description
PiperOrigin-RevId: 700419920
tensorflower-gardener committed Nov 26, 2024
1 parent be8acfe commit a80ad54
Showing 4 changed files with 124 additions and 73 deletions.
48 changes: 30 additions & 18 deletions official/projects/pix2seq/configs/pix2seq.py
@@ -99,11 +99,20 @@ class Backbone(backbones.Backbone):
resnet: backbones.ResNet = dataclasses.field(default_factory=backbones.ResNet)
uvit: uvit_backbones.VisionTransformer = dataclasses.field(
default_factory=uvit_backbones.VisionTransformer)
# Whether to freeze this backbone during training.
freeze: bool = False
# The endpoint name of the features to extract from the backbone.
endpoint_name: str = '5'
norm_activation: common.NormActivation = dataclasses.field(
default_factory=common.NormActivation
)
# Optional checkpoint to load for this backbone.
init_checkpoint: Optional[str] = None


@dataclasses.dataclass
class Pix2Seq(hyperparams.Config):
"""Pix2Seq model definations."""
"""Pix2Seq model definitions."""

max_num_instances: int = 100
hidden_size: int = 256
@@ -115,16 +124,16 @@ class Pix2Seq(hyperparams.Config):
shared_decoder_embedding: bool = True
decoder_output_bias: bool = True
input_size: List[int] = dataclasses.field(default_factory=list)
backbone: Backbone = dataclasses.field(
default_factory=lambda: Backbone( # pylint: disable=g-long-lambda
type='resnet',
resnet=backbones.ResNet(model_id=50, bn_trainable=False),
)
)
norm_activation: common.NormActivation = dataclasses.field(
default_factory=common.NormActivation
# Backbones for each image modality. If just using RGB, you should only set
# one backbone.
backbones: List[Backbone] = dataclasses.field(
default_factory=lambda: [
Backbone( # pylint: disable=g-long-lambda
type='resnet',
resnet=backbones.ResNet(model_id=50, bn_trainable=False),
)
]
)
backbone_endpoint_name: str = '5'
drop_path: float = 0.1
drop_units: float = 0.1
drop_att: float = 0.0
@@ -172,13 +181,16 @@ def pix2seq_r50_coco() -> cfg.ExperimentConfig:
),
model=Pix2Seq(
input_size=[640, 640, 3],
norm_activation=common.NormActivation(
norm_momentum=0.9,
norm_epsilon=1e-5,
use_sync_bn=True),
backbone=Backbone(
type='resnet', resnet=backbones.ResNet(model_id=50)
),
backbones=[
Backbone(
type='resnet',
resnet=backbones.ResNet(model_id=50),
norm_activation=common.NormActivation(
norm_momentum=0.9, norm_epsilon=1e-5, use_sync_bn=True
),
init_checkpoint='',
)
],
),
losses=Losses(l2_weight_decay=0.0),
train_data=DataConfig(
Expand All @@ -188,7 +200,7 @@ def pix2seq_r50_coco() -> cfg.ExperimentConfig:
shuffle_buffer_size=train_batch_size * 10,
aug_scale_min=0.3,
aug_scale_max=2.0,
aug_color_jitter_strength=0.0
aug_color_jitter_strength=0.0,
),
validation_data=DataConfig(
input_path=os.path.join(COCO_INPUT_PATH_BASE, 'val*'),
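The net effect of this file's changes: the model config now carries a list of per-modality backbones, each with its own `freeze` flag, `endpoint_name`, `norm_activation`, and optional `init_checkpoint`, replacing the single `backbone` field and the top-level `backbone_endpoint_name`. A minimal sketch of the new shape; the second, frozen backbone and its checkpoint path are hypothetical illustrations, not part of this commit:

```python
from official.projects.pix2seq.configs import pix2seq as pix2seq_cfg
from official.vision.configs import backbones

model_config = pix2seq_cfg.Pix2Seq(
    input_size=[640, 640, 3],
    backbones=[
        # Primary RGB backbone, as in the updated pix2seq_r50_coco() default.
        pix2seq_cfg.Backbone(
            type='resnet',
            resnet=backbones.ResNet(model_id=50, bn_trainable=False),
            endpoint_name='5',
        ),
        # Hypothetical second modality: frozen during training and restored
        # from its own pretrained checkpoint (placeholder path).
        pix2seq_cfg.Backbone(
            type='resnet',
            resnet=backbones.ResNet(model_id=50),
            freeze=True,
            init_checkpoint='/path/to/second_backbone_ckpt',
        ),
    ],
)
```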
18 changes: 12 additions & 6 deletions official/projects/pix2seq/modeling/pix2seq_model.py
@@ -215,7 +215,7 @@ class Pix2Seq(tf_keras.Model):
def __init__(
self,
backbones: Sequence[tf_keras.Model],
backbone_endpoint_name,
backbone_endpoint_names: Sequence[str],
max_seq_len,
vocab_size,
hidden_size,
@@ -233,7 +233,7 @@ def __init__(
):
super().__init__(**kwargs)
self._backbones = backbones
self._backbone_endpoint_name = backbone_endpoint_name
self._backbone_endpoint_names = backbone_endpoint_names
self._max_seq_len = max_seq_len
self._vocab_size = vocab_size
self._hidden_size = hidden_size
@@ -285,9 +285,7 @@ def transformer(self) -> tf_keras.Model:
return self._transformer

def get_config(self):
return {
"backbone": self._backbone,
"backbone_endpoint_name": self._backbone_endpoint_name,
config = {
"max_seq_len": self._max_seq_len,
"vocab_size": self._vocab_size,
"hidden_size": self._hidden_size,
@@ -302,6 +300,12 @@ def get_config(self):
"early_stopping_token": self._early_stopping_token,
"num_heads": self._num_heads,
}
config["backbone"] = self._backbones[0]
config["backbone_endpoint_name"] = self._backbone_endpoint_names[0]
for i in range(1, len(self._backbones)):
config[f"backbone_{i+1}"] = self._backbones[i]
config[f"backbone_endpoint_name_{i+1}"] = self._backbone_endpoint_names[i]
return config

@classmethod
def from_config(cls, config):
@@ -354,7 +358,9 @@ def call(
if use_input_as_backbone_features:
features = inputs_i
else:
features = self._backbones[i](inputs_i)[self._backbone_endpoint_name]
features = self._backbones[i](inputs_i)[
self._backbone_endpoint_names[i]
]
mask = tf.ones_like(features)
batch_size, h, w, num_channels = get_shape(features)
features = tf.reshape(features, [batch_size, h * w, num_channels])
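`get_config` now flattens the backbone list for serialization: the first backbone keeps the legacy `backbone` and `backbone_endpoint_name` keys, and any extras are stored as `backbone_2`, `backbone_endpoint_name_2`, and so on. The updated `from_config` body is not shown in this diff, so the following inverse is only an assumed sketch of how the lists could be rebuilt:

```python
def _unflatten_backbones(config):
  """Rebuilds the backbone lists from get_config's flattened keys (assumed)."""
  backbones = [config.pop("backbone")]
  endpoint_names = [config.pop("backbone_endpoint_name")]
  i = 2
  while f"backbone_{i}" in config:  # extra backbones start at backbone_2
    backbones.append(config.pop(f"backbone_{i}"))
    endpoint_names.append(config.pop(f"backbone_endpoint_name_{i}"))
    i += 1
  return backbones, endpoint_names
```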
20 changes: 10 additions & 10 deletions official/projects/pix2seq/modeling/pix2seq_model_test.py
@@ -37,10 +37,10 @@ def test_forward(self, num_backbones: int):
backbones = [
resnet.ResNet(50, bn_trainable=False) for _ in range(num_backbones)
]
backbone_endpoint_name = '5'
backbone_endpoint_names = ['5' for _ in range(num_backbones)]
model = pix2seq_model.Pix2Seq(
backbones,
backbone_endpoint_name,
backbone_endpoint_names,
max_seq_len,
vocab_size,
hidden_size,
@@ -68,10 +68,10 @@ def test_forward_infer_teacher_forcing(self, num_backbones: int):
backbones = [
resnet.ResNet(50, bn_trainable=False) for _ in range(num_backbones)
]
backbone_endpoint_name = '5'
backbone_endpoint_names = ['5' for _ in range(num_backbones)]
model = pix2seq_model.Pix2Seq(
backbones,
backbone_endpoint_name,
backbone_endpoint_names,
max_seq_len,
vocab_size,
hidden_size,
@@ -100,10 +100,10 @@ def test_forward_infer(self, num_backbones: int):
backbones = [
resnet.ResNet(50, bn_trainable=False) for _ in range(num_backbones)
]
backbone_endpoint_name = '5'
backbone_endpoint_names = ['5' for _ in range(num_backbones)]
model = pix2seq_model.Pix2Seq(
backbones,
backbone_endpoint_name,
backbone_endpoint_names,
max_seq_len,
vocab_size,
hidden_size,
@@ -125,10 +125,10 @@ def test_forward_infer_with_early_stopping(self):
image_size = 640
batch_size = 2
backbone = resnet.ResNet(50, bn_trainable=False)
backbone_endpoint_name = '5'
backbone_endpoint_names = ['5']
model = pix2seq_model.Pix2Seq(
[backbone],
backbone_endpoint_name,
backbone_endpoint_names,
max_seq_len,
vocab_size,
hidden_size,
@@ -151,10 +151,10 @@ def test_forward_infer_with_long_prompt(self):
image_size = 640
batch_size = 2
backbone = resnet.ResNet(50, bn_trainable=False)
backbone_endpoint_name = '5'
backbone_endpoint_names = ['5']
model = pix2seq_model.Pix2Seq(
[backbone],
backbone_endpoint_name,
backbone_endpoint_names,
max_seq_len,
vocab_size,
hidden_size,
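The tests now pass one endpoint name per backbone, matching the updated `call` loop that reads each backbone's endpoint and flattens the feature map into tokens. A standalone sketch of that per-backbone extraction; the final concatenation across modalities is an assumption, since the diff truncates before showing how the flattened features are merged:

```python
import tensorflow as tf

def extract_tokens(backbones, endpoint_names, inputs_per_backbone):
  """Mirrors the per-backbone loop in Pix2Seq.call (merge step assumed)."""
  tokens = []
  for backbone, name, x in zip(backbones, endpoint_names, inputs_per_backbone):
    features = backbone(x)[name]  # endpoint dict lookup, e.g. key '5'
    _, h, w, c = features.shape.as_list()
    # Flatten the spatial grid into a [batch, h * w, channels] token sequence.
    tokens.append(tf.reshape(features, [-1, h * w, c]))
  # Assumption: token sequences from all backbones are concatenated.
  return tf.concat(tokens, axis=1)
```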
111 changes: 72 additions & 39 deletions official/projects/pix2seq/tasks/pix2seq_task.py
@@ -32,7 +32,7 @@
from official.vision.dataloaders import tfds_factory
from official.vision.dataloaders import tf_example_label_map_decoder
from official.vision.evaluation import coco_evaluator
from official.vision.modeling import backbones
from official.vision.modeling import backbones as backbones_lib


@task_factory.register_task_cls(pix2seq_cfg.Pix2SeqTask)
@@ -44,24 +44,34 @@ class Pix2SeqTask(base_task.Task):
post-processing, and customized metrics with reduction.
"""

def build_model(self):
"""Build Pix2Seq model."""
def _build_backbones_and_endpoint_names(
self,
) -> tuple[list[tf_keras.Model], list[str]]:
"""Build backbones and returns their corresponding endpoint names."""
config: pix2seq_cfg.Pix2Seq = self._task_config.model

input_specs = tf_keras.layers.InputSpec(
shape=[None] + config.input_size
)
backbones = []
endpoint_names = []
for backbone_config in config.backbones:
backbone = backbones_lib.factory.build_backbone(
input_specs=input_specs,
backbone_config=backbone_config,
norm_activation_config=backbone_config.norm_activation,
)
backbone.trainable = not backbone_config.freeze
backbones.append(backbone)
endpoint_names.append(backbone_config.endpoint_name)
return backbones, endpoint_names

backbone = backbones.factory.build_backbone(
input_specs=input_specs,
backbone_config=config.backbone,
norm_activation_config=config.norm_activation,
)

def build_model(self):
"""Build Pix2Seq model."""
config: pix2seq_cfg.Pix2Seq = self._task_config.model
backbones, endpoint_names = self._build_backbones_and_endpoint_names()
model = pix2seq_model.Pix2Seq(
# TODO: b/378885339 - Support multiple backbones from the config.
backbones=[backbone],
backbone_endpoint_name=config.backbone_endpoint_name,
backbones=backbones,
backbone_endpoint_name=endpoint_names,
max_seq_len=config.max_num_instances * 5,
vocab_size=config.vocab_size,
hidden_size=config.hidden_size,
@@ -78,41 +88,64 @@ def build_model(self):
)
return model

def _get_ckpt(self, ckpt_dir_or_file: str) -> str:
if tf.io.gfile.isdir(ckpt_dir_or_file):
return tf.train.latest_checkpoint(ckpt_dir_or_file)
return ckpt_dir_or_file

def initialize(self, model: tf_keras.Model):
"""Loading pretrained checkpoint."""
if not self._task_config.init_checkpoint:
return
if self._task_config.init_checkpoint_modules == 'backbone':
raise ValueError(
'init_checkpoint_modules=backbone is deprecated. Specify backbone '
'checkpoints in each backbone config.'
)

ckpt_dir_or_file = self._task_config.init_checkpoint
if self._task_config.init_checkpoint_modules not in ['all', 'partial', '']:
raise ValueError(
'Unsupported init_checkpoint_modules: '
f'{self._task_config.init_checkpoint_modules}'
)

# Restoring checkpoint.
if tf.io.gfile.isdir(ckpt_dir_or_file):
ckpt_dir_or_file = tf.train.latest_checkpoint(ckpt_dir_or_file)
if self._task_config.init_checkpoint and any(
[b.init_checkpoint for b in self._task_config.model.backbones]
):
raise ValueError(
'A global init_checkpoint and a backbone init_checkpoint cannot be'
' specified at the same time.'
)

if self._task_config.init_checkpoint_modules == 'all':
if self._task_config.init_checkpoint:
global_ckpt_file = self._get_ckpt(self._task_config.init_checkpoint)
ckpt = tf.train.Checkpoint(**model.checkpoint_items)
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
status = ckpt.restore(global_ckpt_file).expect_partial()
if self._task_config.init_checkpoint_modules != 'partial':
status.assert_existing_objects_matched()
logging.info(
'Finished loading pretrained checkpoint from %s', ckpt_dir_or_file
)
elif self._task_config.init_checkpoint_modules == 'backbone':
if self.task_config.model.backbone.type == 'uvit':
model.backbone.load_checkpoint(ckpt_filepath=ckpt_dir_or_file)
else:
# TODO: b/378885339 - Support multiple backbones from the config.
ckpt = tf.train.Checkpoint(backbone=model.backbones[0])
status = ckpt.restore(ckpt_dir_or_file)
status.expect_partial().assert_existing_objects_matched()
logging.info(
'Finished loading pretrained backbone from %s', ckpt_dir_or_file
'Finished loading pretrained checkpoint from %s', global_ckpt_file
)
else:
raise ValueError(
f'Failed to load {ckpt_dir_or_file}. Unsupported '
'init_checkpoint_modules: '
f'{self._task_config.init_checkpoint_modules}'
)
# This case means that no global checkpoint was provided. Possibly,
# backbone-specific checkpoints were.
for backbone_config, backbone in zip(
self._task_config.model.backbones, model.backbones
):
if not backbone_config.init_checkpoint:
continue

backbone_init_ckpt = self._get_ckpt(backbone_config.init_checkpoint)
if backbone_config.type == 'uvit':
# The UVit object has a special function called load_checkpoint.
# The other backbones do not.
backbone.load_checkpoint(ckpt_filepath=backbone_init_ckpt)
else:
ckpt = tf.train.Checkpoint(backbone=backbone)
status = ckpt.restore(backbone_init_ckpt)
status.expect_partial().assert_existing_objects_matched()

logging.info(
'Finished loading pretrained backbone from %s', backbone_init_ckpt
)

def build_inputs(
self, params, input_context: Optional[tf.distribute.InputContext] = None
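Checkpoint loading is now resolved in two mutually exclusive ways: a task-level `init_checkpoint` (with `init_checkpoint_modules` set to `'all'`, `'partial'`, or `''`) restores the whole model, while each backbone config may instead name its own checkpoint; specifying both raises a `ValueError`. A sketch of the per-backbone style, assuming the usual task-config layout (paths are placeholders):

```python
from official.projects.pix2seq.configs import pix2seq as pix2seq_cfg
from official.vision.configs import backbones

task_config = pix2seq_cfg.Pix2SeqTask(
    init_checkpoint='',  # must stay empty when backbones carry their own
    model=pix2seq_cfg.Pix2Seq(
        input_size=[640, 640, 3],
        backbones=[
            pix2seq_cfg.Backbone(
                type='resnet',
                resnet=backbones.ResNet(model_id=50),
                init_checkpoint='/path/to/resnet50_ckpt',  # placeholder
            ),
        ],
    ),
)
```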
