 
 input_dim_t = Union[int, Tuple[int, int]]
 
-# Need for model weight loading.
-NUM_ATTENTION_HEADS = 16
+# Model parameters that are not in config.json.
+# TODO: read from config.json when it is released.
+NUM_ATTENTION_HEADS_FOR_VIT = 16
+IMAGE_SIZE_FOR_VIT = 224
+PATCH_SIZE_FOR_VIT = 16
+EMBED_DIM_FOR_VIT = 1280
+DEPTH_FOR_VIT = 32
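+# These are the ViT-H/16 dimensions of C-RADIOv2-H: a 224x224 input with
+# 16x16 patches yields a 14x14 patch grid (196 tokens before CLS/registers).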
 
 
 class Resolution(NamedTuple):
@@ -34,7 +39,7 @@ class Resolution(NamedTuple):
 class RADIOConfig(PretrainedConfig):
     """Pretrained Hugging Face configuration for RADIO models.
 
-    Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py.
+    Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py.
     """
 
     def __init__(
@@ -55,8 +60,7 @@ def __init__(
         for field in ["dtype", "amp_dtype"]:
             if self.args is not None and field in self.args:
                 # Convert to a string in order to make it serializable.
-                # For example for torch.float32 we will store "float32",
-                # for "bfloat16" we will store "bfloat16".
+                # For example, for torch.float32 we store "float32".
                 self.args[field] = str(args[field]).split(".")[-1]
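                 # e.g. str(torch.bfloat16) == "torch.bfloat16" -> "bfloat16";
                 # a plain string such as "bfloat16" passes through unchanged.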
         self.version = version
         self.patch_size = patch_size
@@ -68,13 +72,13 @@ def __init__(
         self.vitdet_window_size = vitdet_window_size
         self.feature_normalizer_config = feature_normalizer_config
         self.inter_feature_normalizer_config = inter_feature_normalizer_config
-        self.num_key_value_heads = NUM_ATTENTION_HEADS
-        self.num_attention_heads = NUM_ATTENTION_HEADS
+        self.num_key_value_heads = NUM_ATTENTION_HEADS_FOR_VIT
+        self.num_attention_heads = NUM_ATTENTION_HEADS_FOR_VIT
         super().__init__(**kwargs)
 
 
 class ClsToken(nn.Module):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/cls_token.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/cls_token.py."""
 
     def __init__(
         self,
@@ -115,7 +119,7 @@ def forward(self, x: torch.Tensor):
 
 
 class ViTPatchGenerator(nn.Module):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
 
     def __init__(
         self,
@@ -132,8 +136,6 @@ def __init__(
         register_multiple: Optional[int] = None,
         num_registers: Optional[int] = None,
         patch_bias: bool = False,
-        device=None,
-        dtype=None,
     ):
         super().__init__()
 
@@ -151,42 +153,31 @@ def __init__(
         self.cpe_mode = max_input_dims != input_dims
         self.pos_dropout = pos_dropout
         self.return_pos_enc = return_pos_enc
-
-        factory = dict(device=device, dtype=dtype)
-
         self.patch_size = patch_size
         self.abs_pos = abs_pos
         self.embed_dim = embed_dim
-
         self.num_rows = max_input_dims[0] // patch_size
         self.num_cols = max_input_dims[1] // patch_size
         self.input_dims = tuple(d // patch_size for d in input_dims)
         self.num_patches = self.num_rows * self.num_cols
         self.max_input_dims = max_input_dims
 
         self.im_to_patches = Im2Patches(patch_size)
-        self.embedder = ViTPatchLinear(patch_size,
-                                       embed_dim,
-                                       bias=patch_bias,
-                                       **factory)
-
+        self.embedder = ViTPatchLinear(patch_size, embed_dim, bias=patch_bias)
         if abs_pos:
             scale = embed_dim**-0.5
             self.pos_embed = nn.Parameter(
-                torch.randn(1, self.num_patches, embed_dim, **factory) * scale)
-
+                torch.randn(1, self.num_patches, embed_dim) * scale)
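+            # pos_embed covers the maximum patch grid (num_rows x num_cols);
+            # apply_pos_enc selects or interpolates from it for other input sizes.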
         self.cls_token = ClsToken(
             embed_dim,
             num_tokens=num_cls_tokens,
             enabled=cls_token,
             register_multiple=register_multiple,
             num_registers=num_registers,
         )
-
         self.patch_normalizer = nn.LayerNorm(
             embed_dim) if normalize_patches else nn.Identity()
 
-    @torch.compile
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         patches = self.embed_patches(x)
         patches, pos_enc = self.apply_pos_enc(patches, input_size=x.shape[2:])
@@ -265,7 +256,6 @@ def window_select(pos_embed):
                     size=(max_dim, max_dim),
                     align_corners=True,
                     mode='bilinear').to(pos_embed.dtype)
-
                 pos_embed = window_select(pos_embed)
             else:
                 pos_embed = window_select(pos_embed)
@@ -277,12 +267,11 @@ def window_select(pos_embed):
                     mode='bilinear').to(pos_embed.dtype)
 
         pos_embed = pos_embed.flatten(2).permute(0, 2, 1)
-
         return pos_embed
 
 
 class Im2Patches(nn.Module):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
 
     def __init__(self, patch_size: int):
         super().__init__()
@@ -308,22 +297,23 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class ViTPatchLinear(nn.Linear):
-    """Copy from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
-
-    def __init__(self,
-                 patch_size: int,
-                 embed_dim: int,
-                 bias: bool = False,
-                 **kwargs):
-        super().__init__(3 * (patch_size**2), embed_dim, bias=bias, **kwargs)
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/vit_patch_generator.py."""
+
+    def __init__(
+        self,
+        patch_size: int,
+        embed_dim: int,
+        bias: bool = False,
+    ):
+        super().__init__(3 * (patch_size**2), embed_dim, bias=bias)
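+        # Note: the linear layer's input features are 3 * patch_size**2, i.e. a
+        # flattened RGB (3-channel) patch of patch_size x patch_size pixels.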
         self.patch_size = patch_size
 
 
 class Block(nn.Module):
     """Transformer block with pre-normalization.
 
-    Copy from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py
-    and use trtllm_attn and trtllm_mlp to replace attn and mlp.
+    Modified from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py.
+    Uses trtllm_attn and trtllm_mlp in place of the original attention and MLP layers.
     """
 
     def __init__(
@@ -378,16 +368,16 @@ def __init__(
             hidden_size=dim,
             num_attention_heads=num_heads,
             num_key_value_heads=num_heads,
+            max_position_embeddings=None,
             bias=qkv_bias,
-            dense_bias=proj_bias,
-            dtype=self.model_config.torch_dtype,
-            layer_idx=layer_idx,
             pos_embd_params=None,
             rope_fusion=None,
+            layer_idx=layer_idx,
+            dtype=self.model_config.torch_dtype,
+            dense_bias=proj_bias,
+            config=self.model_config,
             q_scaling=1.0,
             attention_chunk_size=None,
-            config=self.model_config,
-            max_position_embeddings=None,
         )
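         # pos_embd_params and rope_fusion are None: this ViT relies on the
         # learned absolute position embeddings from ViTPatchGenerator, not RoPE.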
         if init_values:
             raise IOError(
@@ -399,8 +389,6 @@ def __init__(
399389 "Limited RADIO model support: Block does not support DropPath for now."
400390 )
401391 self .drop_path1 = nn .Identity ()
402-
403- self .norm2 = norm_layer (dim )
404392 if scale_mlp_norm :
405393 raise IOError (
406394 "Limited RADIO model support: Block does not support scale_mlp_norm for now."
@@ -409,6 +397,7 @@ def __init__(
             raise IOError(
                 "Limited RADIO model support: Block does not support proj_drop for now."
             )
+        self.norm2 = norm_layer(dim)
 
         self.mlp = trtllm_mlp.MLP(
             hidden_size=dim,
@@ -442,8 +431,7 @@ def forward(
             position_ids=None,
             hidden_states=x,
             attn_metadata=attn_metadata,
-            attention_mask=attention_interface.PredefinedAttentionMask.
-            FULL  # Always FULL for Vision
+            attention_mask=attention_interface.PredefinedAttentionMask.FULL,
         )
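         # The FULL (bidirectional) mask is always used for vision: every patch
         # token attends to every other token, unlike causal text decoding.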
         x = self.ls1(x)
         x = self.drop_path1(x)
@@ -461,7 +449,7 @@ def forward(
 class VisionTransformer(nn.Module):
     """Vision Transformer.
 
-    Copy from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py.
+    Modified from https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py.
     """
 
     def __init__(
@@ -535,9 +523,11 @@ def __init__(
             **kwargs: Additional keyword arguments, used to store unused arguments.
         """
         super().__init__()
-        assert global_pool in ('', 'avg', 'avgmax', 'max', 'token', 'map')
-        assert class_token or global_pool != 'token'
-        assert pos_embed in ('', 'none', 'learn')
+        if not (class_token or global_pool != 'token'):
+            raise ValueError(
+                "A class token is required when global_pool == 'token'")
+        if pos_embed not in ('', 'none', 'learn'):
+            raise ValueError(f"Invalid pos_embed: {pos_embed}")
         use_fc_norm = global_pool in ('avg', 'avgmax',
                                       'max') if fc_norm is None else fc_norm
 
@@ -555,7 +545,7 @@ def __init__(
 
         self.num_classes = num_classes
         self.global_pool = global_pool
-        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim  # for consistency with other models
+        self.num_features = self.head_hidden_size = self.embed_dim = embed_dim
         self.num_prefix_tokens = 1 if class_token else 0
         self.num_prefix_tokens += reg_tokens
         self.num_reg_tokens = reg_tokens
@@ -565,7 +555,7 @@ def __init__(
         self.patch_drop = nn.Identity()
         self.norm_pre = norm_layer(embed_dim) if pre_norm else nn.Identity()
 
-        # stochastic depth decay rule
+        # Stochastic depth decay rule.
         dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]
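         # Linearly spaced per-block drop-path rates; effectively all zeros here,
         # since Block rejects DropPath (see the IOError above).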
         self.blocks = nn.ModuleList([
             Block(
@@ -590,7 +580,7 @@ def __init__(
         self.norm = norm_layer(
             embed_dim) if final_norm and not use_fc_norm else nn.Identity()
 
-        # Classifier Head but not used for RADIO embedding models.
+        # Initialize the classifier head; it is not used for RADIO embedding models.
         self.attn_pool = None
         self.fc_norm = norm_layer(
             embed_dim) if final_norm and use_fc_norm else nn.Identity()
@@ -664,9 +654,8 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
         if self.model_config is not None:
             seq_lengths = [seq_len] * batch_size
             attn_metadata = self.prepare_attn_metadata(batch_size, seq_lengths)
-            x = x.reshape(
-                batch_size * seq_len,
-                hidden_size)  # Need flatten batch/seq_len for trtllm attention.
+            # Flatten the batch/seq_len dims, as required by trtllm attention.
+            x = x.reshape(batch_size * seq_len, hidden_size)
         else:
             attn_metadata = None
         for block in self.blocks:
@@ -678,7 +667,7 @@ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
 
 
 class RADIOVisionModelBase(nn.Module):
-    """Copy and modify from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/radio_model.py"""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/radio_model.py."""
 
     def __init__(
         self,
@@ -783,17 +772,13 @@ def get_nearest_supported_resolution(self, height: int,
             round(height / self.min_resolution_step) * self.min_resolution_step)
         width = int(
             round(width / self.min_resolution_step) * self.min_resolution_step)
-
         height = max(height, self.min_resolution_step)
         width = max(width, self.min_resolution_step)
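         # Illustrative example, assuming min_resolution_step == 16:
         # height=250 -> round(250/16)*16 = 256; height=7 -> 0, clamped to 16.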
-
         return Resolution(height=height, width=width)
 
-    def forward(
-        self,
-        x: torch.Tensor,
-        feature_fmt: str = 'NLC'
-    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+    def forward(self,
+                x: torch.Tensor,
+                feature_fmt: str = 'NLC') -> torch.Tensor:
         res_step = self.min_resolution_step
         if res_step is not None and (x.shape[-2] % res_step != 0
                                      or x.shape[-1] % res_step != 0):
@@ -807,7 +792,6 @@ def forward(
         ret = self._extract_final(x, y, feature_fmt=feature_fmt)
         return ret
 
-    @torch.compile
     def _extract_final(self,
                        x: torch.Tensor,
                        y: torch.Tensor,
@@ -836,12 +820,11 @@ def _extract_final(self,
             raise ValueError(
                 f'Unsupported feature_fmt: {feature_fmt}. Must be one of ["NLC", "NCHW"]'
             )
-
         return fmt_feat
 
 
 class RADIOVisionModel(PreTrainedModel):
-    """Copy and modify from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py."""
+    """Modified from https://huggingface.co/nvidia/C-RADIOv2-H/blob/main/hf_model.py."""
 
     def __init__(self, model_config: model_config_lib.ModelConfig):
         """
@@ -863,11 +846,11 @@ def __init__(self, model_config: model_config_lib.ModelConfig):
         elif args.input_size is not None:
             in_chans = args.input_size[0]
         vit_model = VisionTransformer(
-            img_size=224,
-            patch_size=16,
-            embed_dim=1280,
-            depth=32,
-            num_heads=NUM_ATTENTION_HEADS,
+            img_size=IMAGE_SIZE_FOR_VIT,
+            patch_size=PATCH_SIZE_FOR_VIT,
+            embed_dim=EMBED_DIM_FOR_VIT,
+            depth=DEPTH_FOR_VIT,
+            num_heads=NUM_ATTENTION_HEADS_FOR_VIT,
             in_chans=in_chans,
             drop_rate=args.drop,
             special_args=args,
@@ -920,11 +903,11 @@ def load_weights(self, weights):
         }
         missing_keys, unexpected_keys = self.radio_model.load_state_dict(
             filter_weights, strict=False)
-
         # Check missing and unexpected keys.
         # The input conditioner is not initialized in the current implementation.
         unexpected_keys.remove("input_conditioner.norm_mean")
         unexpected_keys.remove("input_conditioner.norm_std")
+        # The remaining model.blocks weights will be loaded in a later step.
         for m in missing_keys:
             if not m.startswith('model.blocks.'):
                 raise ValueError(f"Missing key: {m}")