How to convert YOLO NAS .pth output into Tflite. #1153

Closed
adkbbx opened this issue Jun 9, 2023 · 7 comments

@adkbbx

adkbbx commented Jun 9, 2023

💡 Your Question

I was trying to convert the output of my custom YOLO-NAS object detection model from .pth to .tflite, following this repo, by going ONNX -> TF -> TFLite.

!pip install super-gradients==3.1.0
!pip install imutils
!pip install pytube --upgrade
!pip install --upgrade pillow==9.2.0

import torch
from super_gradients.training import models

dataset_params = {
    'data_dir':'/content/drive/MyDrive/Numberplate/Yolo/Yolov8/data',
    'train_images_dir':'train/images',
    'train_labels_dir':'train/labels',
    'val_images_dir':'validate/images',
    'val_labels_dir':'validate/labels',
    'test_images_dir':'test/images',
    'test_labels_dir':'test/labels',
    'classes': ['blue','red','green','black','yellow','white']
}


model = models.get('yolo_nas_l', 
                   num_classes=len(dataset_params['classes']), 
                   pretrained_weights="coco"
                   ) 


pt_model_path = "/path/to/ckpt_best.pth" 
model.load_state_dict(torch.load(pt_model_path, map_location='cpu'))

but I get a RuntimeError when I try to load the model:

RuntimeError: Error(s) in loading state_dict for YoloNAS_L:
	Missing key(s) in state_dict: "backbone.stem.conv.branch_3x3.conv.weight", "backbone.stem.conv.branch_3x3.bn.weight", "backbone.stem.conv.branch_3x3.bn.bias", "backbone.stem.conv.branch_3x3.bn.running_mean", "backbone.stem.conv.branch_3x3.bn.running_var", "backbone.stem.conv.branch_1x1.weight", "backbone.stem.conv.branch_1x1.bias", "backbone.stem.conv.post_bn.weight", "backbone.stem.conv.post_bn.bias", "backbone.stem.conv.post_bn.running_mean", "backbone.stem.conv.post_bn.running_var", "backbone.stem.conv.rbr_reparam.weight", "backbone.stem.conv.rbr_reparam.bias", "backbone.stage1.downsample.branch_3x3.conv.weight", "backbone.stage1.downsample.branch_3x3.bn.weight", "backbone.stage1.downsample.branch_3x3.bn.bias", "backbone.stage1.downsample.branch_3x3.bn.running_mean", "backbone.stage1.downsample.branch_3x3.bn.running_var", "backbone.stage1.downsample.branch_1x1.weight", "backbone.stage1.downsample.branch_1x1.bias", "backbone.stage1.downsample.post_bn.weight", "backbone.stage1.downsample.post_bn.bias", "backbone.stage1.downsample.post_bn.running_mean", "backbone.stage1.downsample.post_bn.running_var", "backbone.stage1.downsample.rbr_reparam.weight", "backbone.stage1.downsample.rbr_reparam.bias", "backbone.stage1.blocks.conv1.conv.weight", "backbone.stage1.blocks.conv1.bn.weight", "backbone.stage1.blocks.conv1.bn.bias", "backbone.stage1.blocks.conv1.bn.running_mean", "backbone.stage1.blocks.conv1.bn.running_var", "backbone.stage1.blocks.conv2.conv.weight", "backbone.stage...
	Unexpected key(s) in state_dict: "net", "acc", "epoch", "optimizer_state_dict", "scaler_state_dict", "ema_net", "processing_params". 

I am using Google Colab to run this code. Please let me know if you have a solution to this issue.

Thank you.

Versions

No response

@NatanBagrov
Contributor

The .pth checkpoint contains other "states" besides the model weights. The weights themselves are stored under the "ema_net" or "net" key.
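A minimal sketch of pulling the weights out of such a checkpoint, assuming the file has already been loaded with torch.load into a plain dict whose top-level keys are the ones listed in the "Unexpected key(s)" error above. The helper name select_model_weights is hypothetical, and preferring "ema_net" over "net" is an assumption (EMA weights are commonly the ones used for inference):

```python
from typing import Any, Dict


def select_model_weights(checkpoint: Dict[str, Any]) -> Dict[str, Any]:
    """Return the model state dict stored inside a training checkpoint dict.

    Hypothetical helper: prefers the EMA weights ("ema_net") when present
    and not None, otherwise falls back to the raw training weights ("net").
    """
    for key in ("ema_net", "net"):
        weights = checkpoint.get(key)
        if weights is not None:
            return weights
    raise KeyError("checkpoint contains neither 'ema_net' nor 'net'")


# Usage (assumes PyTorch is installed and pt_model_path points to ckpt_best.pth):
# checkpoint = torch.load(pt_model_path, map_location="cpu")
# state_dict = select_model_weights(checkpoint)
```

This only selects the right sub-dictionary; it does not rename keys or check tensor shapes.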

@adkbbx
Author

adkbbx commented Jun 12, 2023

Thank you for your response.

I tried using the "net" key to get the model weights:

my_model = torch.load(pt_model_path, map_location='cpu')
model.load_state_dict(my_model['net'])

but it again gives a RuntimeError similar to the one before:

RuntimeError                              Traceback (most recent call last)
<ipython-input-34-68a3d4abfa3b> in <cell line: 1>()
----> 1 model.load_state_dict(my_model['ema_net'])

/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
   1669 
   1670         if len(error_msgs) > 0:
-> 1671             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
   1672                                self.__class__.__name__, "\n\t".join(error_msgs)))
   1673         return _IncompatibleKeys(missing_keys, unexpected_keys)

RuntimeError: Error(s) in loading state_dict for YoloNAS_L:
	Missing key(s) in state_dict: "backbone.stem.conv.branch_3x3.conv.weight", "backbone.stem.conv.branch_3x3.bn.weight", "backbone.stem.conv.branch_3x3.bn.bias", "backbone.stem.conv.branch_3x3.bn.running_mean", "backbone.stem.conv.branch_3x3.bn.running_var", "backbone.stem.conv.branch_1x1.weight", "backbone.stem.conv.branch_1x1.bias", "backbone.stem.conv.post_bn.weight", "backbone.stem.conv.post_bn.bias", "backbone.stem.conv.post_bn.running_mean", "backbone.stem.conv.post_bn.running_var", "backbone.stem.conv.rbr_reparam.weight", "backbone.stem.conv.rbr_reparam.bias", "backbone.stage1.downsample.branch_3x3.conv.weight", "backbone.stage1.downsample.branch_3x3.bn.weight", "backbone.stage1.downsample.branch_3x3.bn.bias", "backbone.stage1.downsample.branch_3x3.bn.running_mean", "backbone.stage1.downsample.branch_3x3.bn.running_var", "backbone.stage1.downsample.branch_1x1.weight", "backbone.stage1.downsample.branch_1x1.bias", "backbone.stage1.downsample.post_bn.weight", "backbone.stage1.downsample.post_bn.bias", "backbone.stage1.downsample.post_bn.running_mean", "backbone.stage1.downsample.post_bn.running_var", "backbone.stage1.downsample.rbr_reparam.weight", "backbone.stage1.downsample.rbr_reparam.bias", "backbone.stage1.blocks.conv1.conv.weight", "backbone.stage1.blocks.conv1.bn.weight", "backbone.stage1.blocks.conv1.bn.bias", "backbone.stage1.blocks.conv1.bn.running_mean", "backbone.stage1.blocks.conv1.bn.running_var", "backbone.stage1.blocks.conv2.conv.weight", "backbone.stage...
	Unexpected key(s) in state_dict: "module.backbone.stem.conv.branch_3x3.conv.weight", "module.backbone.stem.conv.branch_3x3.bn.weight", "module.backbone.stem.conv.branch_3x3.bn.bias", "module.backbone.stem.conv.branch_3x3.bn.running_mean", "module.backbone.stem.conv.branch_3x3.bn.running_var", "module.backbone.stem.conv.branch_3x3.bn.num_batches_tracked", "module.backbone.stem.conv.branch_1x1.weight", "module.backbone.stem.conv.branch_1x1.bias", "module.backbone.stem.conv.post_bn.weight", "module.backbone.stem.conv.post_bn.bias", "module.backbone.stem.conv.post_bn.running_mean", "module.backbone.stem.conv.post_bn.running_var", "module.backbone.stem.conv.post_bn.num_batches_tracked", "module.backbone.stem.conv.rbr_reparam.weight", "module.backbone.stem.conv.rbr_reparam.bias", "module.backbone.stage1.downsample.branch_3x3.conv.weight", "module.backbone.stage1.downsample.branch_3x3.bn.weight", "module.backbone.stage1.downsample.branch_3x3.bn.bias", "module.backbone.stage1.downsample.branch_3x3.bn.running_mean", "module.backbone.stage1.downsample.branch_3x3.bn.running_var", "module.backbone.stage1.downsample.branch_3x3.bn.num_batches_tracked", "module.backbone.stage1.downsample.branch_1x1.weight", "module.backbone.stage1.downsample.branch_1x1.bias", "module.backbone.stage1.downsample.post_bn.weight", "module.backbone.stage1.downsample.post_bn.bias", "module.backbone.stage1.downsample.post_bn.running_mean", "module.backbone.stage1.downsample.post_bn.running_var", "module.backbo...

Please let me know if there is anything else I can try to solve this.

Thank you in advance.

@NatanBagrov
Contributor

As you can see, the keys pretty much match; the only difference is the "module." prefix. To load the state dict successfully, you can use the following existing method in SG:

adaptive_load_state_dict(model, my_model, strict="no_key_matching")
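For comparison, the key mismatch itself can be sketched in plain Python: wrapping a model in torch.nn.DataParallel or DistributedDataParallel prefixes every parameter name with "module.", and stripping that prefix makes the names line up again. This is a minimal sketch, not SG's adaptive_load_state_dict; the helper name strip_module_prefix is hypothetical. It only renames keys, so tensor shapes (for example the number of classes in the detection heads) must still match for loading to succeed.

```python
from typing import Any, Dict, Mapping


def strip_module_prefix(state_dict: Mapping[str, Any], prefix: str = "module.") -> Dict[str, Any]:
    """Return a copy of state_dict with a leading prefix removed from each key.

    Keys that do not start with the prefix are kept unchanged, and the
    original key order is preserved (plain dicts keep insertion order).
    """
    stripped: Dict[str, Any] = {}
    for key, value in state_dict.items():
        new_key = key[len(prefix):] if key.startswith(prefix) else key
        stripped[new_key] = value
    return stripped


# Usage (assumes PyTorch and the `model` / `my_model` objects from this thread):
# model.load_state_dict(strip_module_prefix(my_model["ema_net"]))
```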

@adkbbx
Author

adkbbx commented Jun 12, 2023

Thank you for your response.

I tried running the code above as follows:

from super_gradients.training.utils.checkpoint_utils import adaptive_load_state_dict
adaptive_load_state_dict(model, my_model, strict="no_key_matching")

but I got another RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/super_gradients/training/utils/checkpoint_utils.py in adaptive_load_state_dict(net, state_dict, strict, solver)
     57         strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
---> 58         net.load_state_dict(state_dict, strict=strict_bool)
     59     except (RuntimeError, ValueError, KeyError) as ex:

3 frames
RuntimeError: Error(s) in loading state_dict for YoloNAS_L:
	Missing key(s) in state_dict: "backbone.stem.conv.branch_3x3.conv.weight", "backbone.stem.conv.branch_3x3.bn.weight", "backbone.stem.conv.branch_3x3.bn.bias", "backbone.stem.conv.branch_3x3.bn.running_mean", "backbone.stem.conv.branch_3x3.bn.running_var", "backbone.stem.conv.branch_1x1.weight", "backbone.stem.conv.branch_1x1.bias", "backbone.stem.conv.post_bn.weight", "backbone.stem.conv.post_bn.bias", "backbone.stem.conv.post_bn.running_mean", "backbone.stem.conv.post_bn.running_var", "backbone.stem.conv.rbr_reparam.weight", "backbone.stem.conv.rbr_reparam.bias", "backbone.stage1.downsample.branch_3x3.conv.weight", "backbone.stage1.downsample.branch_3x3.bn.weight", "backbone.stage1.downsample.branch_3x3.bn.bias", "backbone.stage1.downsample.branch_3x3.bn.running_mean", "backbone.stage1.downsample.branch_3x3.bn.running_var", "backbone.stage1.downsample.branch_1x1.weight", "backbone.stage1.downsample.branch_1x1.bias", "backbone.stage1.downsample.post_bn.weight", "backbone.stage1.downsample.post_bn.bias", "backbone.stage1.downsample.post_bn.running_mean", "backbone.stage1.downsample.post_bn.running_var", "backbone.stage1.downsample.rbr_reparam.weight", "backbone.stage1.downsample.rbr_reparam.bias", "backbone.stage1.blocks.conv1.conv.weight", "backbone.stage1.blocks.conv1.bn.weight", "backbone.stage1.blocks.conv1.bn.bias", "backbone.stage1.blocks.conv1.bn.running_mean", "backbone.stage1.blocks.conv1.bn.running_var", "backbone.stage1.blocks.conv2.conv.weight", "backbone.stage...
	Unexpected key(s) in state_dict: "module.backbone.stem.conv.branch_3x3.conv.weight", "module.backbone.stem.conv.branch_3x3.bn.weight", "module.backbone.stem.conv.branch_3x3.bn.bias", "module.backbone.stem.conv.branch_3x3.bn.running_mean", "module.backbone.stem.conv.branch_3x3.bn.running_var", "module.backbone.stem.conv.branch_3x3.bn.num_batches_tracked", "module.backbone.stem.conv.branch_1x1.weight", "module.backbone.stem.conv.branch_1x1.bias", "module.backbone.stem.conv.post_bn.weight", "module.backbone.stem.conv.post_bn.bias", "module.backbone.stem.conv.post_bn.running_mean", "module.backbone.stem.conv.post_bn.running_var", "module.backbone.stem.conv.post_bn.num_batches_tracked", "module.backbone.stem.conv.rbr_reparam.weight", "module.backbone.stem.conv.rbr_reparam.bias", "module.backbone.stage1.downsample.branch_3x3.conv.weight", "module.backbone.stage1.downsample.branch_3x3.bn.weight", "module.backbone.stage1.downsample.branch_3x3.bn.bias", "module.backbone.stage1.downsample.branch_3x3.bn.running_mean", "module.backbone.stage1.downsample.branch_3x3.bn.running_var", "module.backbone.stage1.downsample.branch_3x3.bn.num_batches_tracked", "module.backbone.stage1.downsample.branch_1x1.weight", "module.backbone.stage1.downsample.branch_1x1.bias", "module.backbone.stage1.downsample.post_bn.weight", "module.backbone.stage1.downsample.post_bn.bias", "module.backbone.stage1.downsample.post_bn.running_mean", "module.backbone.stage1.downsample.post_bn.running_var", "module.backbo...

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/super_gradients/training/utils/checkpoint_utils.py in raise_informative_runtime_error(state_dict, checkpoint, exception_msg)
    178         exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do no fit, e.g.: {ex}\n{'=' * 200}"
    179     finally:
--> 180         raise RuntimeError(exception_msg)
    181 
    182 

RuntimeError: 
======================================================================================================================================================================================================== 
The checkpoint and model shapes do no fit, e.g.: ckpt layer module.heads.head1.cls_pred.weight with shape torch.Size([3, 128, 1, 1]) does not match heads.head1.cls_pred.weight with shape torch.Size([6, 128, 1, 1]) in the model
========================================================================================================================================================================================================

@NatanBagrov
Contributor

This looks like an issue with a different number of classes.
Please follow the instructions in issue #949 and let me know if that helps.
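One way to check how many classes a checkpoint was actually trained with is to read the first dimension of a classification-head weight. In the error above, module.heads.head1.cls_pred.weight has shape [3, 128, 1, 1] (3 output channels), while the model built with num_classes=6 expects [6, 128, 1, 1]. The sketch below assumes the head key naming shown in that error message; the helper name infer_num_classes is hypothetical:

```python
from typing import Any, Mapping


def infer_num_classes(state_dict: Mapping[str, Any], head_key: str = "heads.head1.cls_pred.weight") -> int:
    """Infer the number of classes from a detection-head weight tensor.

    Assumes the classification conv weight has shape
    [num_classes, in_channels, 1, 1], as in the error message above.
    Matches the key with or without a "module." (or other) prefix.
    """
    for key, tensor in state_dict.items():
        if key == head_key or key.endswith("." + head_key):
            return int(tensor.shape[0])
    raise KeyError(f"no key ending in {head_key!r} found in state_dict")


# Usage (assumes PyTorch and the checkpoint from this thread):
# checkpoint = torch.load(pt_model_path, map_location="cpu")
# print(infer_num_classes(checkpoint["ema_net"]))
```

If this reports a different value than the num_classes passed to models.get, the checkpoint and the freshly built model disagree on the head shapes.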

@adkbbx
Author

adkbbx commented Jun 12, 2023

Thank you for your response.

I tried running the code from #949.

I get a RuntimeError similar to the one above when I use:
model = models.get('yolo_nas_l', num_classes=6, pretrained_weights="coco")
Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/super_gradients/training/utils/checkpoint_utils.py in adaptive_load_state_dict(net, state_dict, strict, solver)
     57         strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
---> 58         net.load_state_dict(state_dict, strict=strict_bool)
     59     except (RuntimeError, ValueError, KeyError) as ex:

3 frames
RuntimeError: Error(s) in loading state_dict for YoloNAS_L:
	Missing key(s) in state_dict: "backbone.stem.conv.branch_3x3.conv.weight", "backbone.stem.conv.branch_3x3.bn.weight", "backbone.stem.conv.branch_3x3.bn.bias", "backbone.stem.conv.branch_3x3.bn.running_mean", "backbone.stem.conv.branch_3x3.bn.running_var", "backbone.stem.conv.branch_1x1.weight", "backbone.stem.conv.branch_1x1.bias", "backbone.stem.conv.post_bn.weight", "backbone.stem.conv.post_bn.bias", "backbone.stem.conv.post_bn.running_mean", "backbone.stem.conv.post_bn.running_var", "backbone.stem.conv.rbr_reparam.weight", "backbone.stem.conv.rbr_reparam.bias", "backbone.stage1.downsample.branch_3x3.conv.weight", "backbone.stage1.downsample.branch_3x3.bn.weight", "backbone.stage1.downsample.branch_3x3.bn.bias", "backbone.stage1.downsample.branch_3x3.bn.running_mean", "backbone.stage1.downsample.branch_3x3.bn.running_var", "backbone.stage1.downsample.branch_1x1.weight", "backbone.stage1.downsample.branch_1x1.bias", "backbone.stage1.downsample.post_bn.weight", "backbone.stage1.downsample.post_bn.bias", "backbone.stage1.downsample.post_bn.running_mean", "backbone.stage1.downsample.post_bn.running_var", "backbone.stage1.downsample.rbr_reparam.weight", "backbone.stage1.downsample.rbr_reparam.bias", "backbone.stage1.blocks.conv1.conv.weight", "backbone.stage1.blocks.conv1.bn.weight", "backbone.stage1.blocks.conv1.bn.bias", "backbone.stage1.blocks.conv1.bn.running_mean", "backbone.stage1.blocks.conv1.bn.running_var", "backbone.stage1.blocks.conv2.conv.weight", "backbone.stage...
	Unexpected key(s) in state_dict: "module.backbone.stem.conv.branch_3x3.conv.weight", "module.backbone.stem.conv.branch_3x3.bn.weight", "module.backbone.stem.conv.branch_3x3.bn.bias", "module.backbone.stem.conv.branch_3x3.bn.running_mean", "module.backbone.stem.conv.branch_3x3.bn.running_var", "module.backbone.stem.conv.branch_3x3.bn.num_batches_tracked", "module.backbone.stem.conv.branch_1x1.weight", "module.backbone.stem.conv.branch_1x1.bias", "module.backbone.stem.conv.post_bn.weight", "module.backbone.stem.conv.post_bn.bias", "module.backbone.stem.conv.post_bn.running_mean", "module.backbone.stem.conv.post_bn.running_var", "module.backbone.stem.conv.post_bn.num_batches_tracked", "module.backbone.stem.conv.rbr_reparam.weight", "module.backbone.stem.conv.rbr_reparam.bias", "module.backbone.stage1.downsample.branch_3x3.conv.weight", "module.backbone.stage1.downsample.branch_3x3.bn.weight", "module.backbone.stage1.downsample.branch_3x3.bn.bias", "module.backbone.stage1.downsample.branch_3x3.bn.running_mean", "module.backbone.stage1.downsample.branch_3x3.bn.running_var", "module.backbone.stage1.downsample.branch_3x3.bn.num_batches_tracked", "module.backbone.stage1.downsample.branch_1x1.weight", "module.backbone.stage1.downsample.branch_1x1.bias", "module.backbone.stage1.downsample.post_bn.weight", "module.backbone.stage1.downsample.post_bn.bias", "module.backbone.stage1.downsample.post_bn.running_mean", "module.backbone.stage1.downsample.post_bn.running_var", "module.backbo...

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/super_gradients/training/utils/checkpoint_utils.py in raise_informative_runtime_error(state_dict, checkpoint, exception_msg)
    178         exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do no fit, e.g.: {ex}\n{'=' * 200}"
    179     finally:
--> 180         raise RuntimeError(exception_msg)
    181 
    182 

RuntimeError: 
======================================================================================================================================================================================================== 
The checkpoint and model shapes do no fit, e.g.: ckpt layer module.heads.head1.cls_pred.weight with shape torch.Size([3, 128, 1, 1]) does not match heads.head1.cls_pred.weight with shape torch.Size([6, 128, 1, 1]) in the model
========================================================================================================================================================================================================

However, when I changed the number of classes to 3, it seemed to get further but showed another RuntimeError:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/super_gradients/training/utils/checkpoint_utils.py in adaptive_load_state_dict(net, state_dict, strict, solver)
     57         strict_bool = strict if isinstance(strict, bool) else strict != StrictLoad.OFF
---> 58         net.load_state_dict(state_dict, strict=strict_bool)
     59     except (RuntimeError, ValueError, KeyError) as ex:

3 frames
RuntimeError: Error(s) in loading state_dict for YoloNAS_L:
	Missing key(s) in state_dict: "backbone.stem.conv.branch_3x3.conv.weight", "backbone.stem.conv.branch_3x3.bn.weight", "backbone.stem.conv.branch_3x3.bn.bias", "backbone.stem.conv.branch_3x3.bn.running_mean", "backbone.stem.conv.branch_3x3.bn.running_var", "backbone.stem.conv.branch_1x1.weight", "backbone.stem.conv.branch_1x1.bias", "backbone.stem.conv.post_bn.weight", "backbone.stem.conv.post_bn.bias", "backbone.stem.conv.post_bn.running_mean", "backbone.stem.conv.post_bn.running_var", "backbone.stem.conv.rbr_reparam.weight", "backbone.stem.conv.rbr_reparam.bias", "backbone.stage1.downsample.branch_3x3.conv.weight", "backbone.stage1.downsample.branch_3x3.bn.weight", "backbone.stage1.downsample.branch_3x3.bn.bias", "backbone.stage1.downsample.branch_3x3.bn.running_mean", "backbone.stage1.downsample.branch_3x3.bn.running_var", "backbone.stage1.downsample.branch_1x1.weight", "backbone.stage1.downsample.branch_1x1.bias", "backbone.stage1.downsample.post_bn.weight", "backbone.stage1.downsample.post_bn.bias", "backbone.stage1.downsample.post_bn.running_mean", "backbone.stage1.downsample.post_bn.running_var", "backbone.stage1.downsample.rbr_reparam.weight", "backbone.stage1.downsample.rbr_reparam.bias", "backbone.stage1.blocks.conv1.conv.weight", "backbone.stage1.blocks.conv1.bn.weight", "backbone.stage1.blocks.conv1.bn.bias", "backbone.stage1.blocks.conv1.bn.running_mean", "backbone.stage1.blocks.conv1.bn.running_var", "backbone.stage1.blocks.conv2.conv.weight", "backbone.stage...
	Unexpected key(s) in state_dict: "module.backbone.stem.conv.branch_3x3.conv.weight", "module.backbone.stem.conv.branch_3x3.bn.weight", "module.backbone.stem.conv.branch_3x3.bn.bias", "module.backbone.stem.conv.branch_3x3.bn.running_mean", "module.backbone.stem.conv.branch_3x3.bn.running_var", "module.backbone.stem.conv.branch_3x3.bn.num_batches_tracked", "module.backbone.stem.conv.branch_1x1.weight", "module.backbone.stem.conv.branch_1x1.bias", "module.backbone.stem.conv.post_bn.weight", "module.backbone.stem.conv.post_bn.bias", "module.backbone.stem.conv.post_bn.running_mean", "module.backbone.stem.conv.post_bn.running_var", "module.backbone.stem.conv.post_bn.num_batches_tracked", "module.backbone.stem.conv.rbr_reparam.weight", "module.backbone.stem.conv.rbr_reparam.bias", "module.backbone.stage1.downsample.branch_3x3.conv.weight", "module.backbone.stage1.downsample.branch_3x3.bn.weight", "module.backbone.stage1.downsample.branch_3x3.bn.bias", "module.backbone.stage1.downsample.branch_3x3.bn.running_mean", "module.backbone.stage1.downsample.branch_3x3.bn.running_var", "module.backbone.stage1.downsample.branch_3x3.bn.num_batches_tracked", "module.backbone.stage1.downsample.branch_1x1.weight", "module.backbone.stage1.downsample.branch_1x1.bias", "module.backbone.stage1.downsample.post_bn.weight", "module.backbone.stage1.downsample.post_bn.bias", "module.backbone.stage1.downsample.post_bn.running_mean", "module.backbone.stage1.downsample.post_bn.running_var", "module.backbo...

During handling of the above exception, another exception occurred:

RuntimeError                              Traceback (most recent call last)
/usr/local/lib/python3.10/dist-packages/super_gradients/training/utils/checkpoint_utils.py in raise_informative_runtime_error(state_dict, checkpoint, exception_msg)
    178         exception_msg = f"\n{'=' * 200} \nThe checkpoint and model shapes do no fit, e.g.: {ex}\n{'=' * 200}"
    179     finally:
--> 180         raise RuntimeError(exception_msg)
    181 
    182 

RuntimeError: 
========================================================================================================================================================================================================
Error(s) in loading state_dict for YoloNAS_L:
	Missing key(s) in state_dict: "backbone.stem.conv.branch_3x3.conv.weight", "backbone.stem.conv.branch_3x3.bn.weight", "backbone.stem.conv.branch_3x3.bn.bias", "backbone.stem.conv.branch_3x3.bn.running_mean", "backbone.stem.conv.branch_3x3.bn.running_var", "backbone.stem.conv.branch_1x1.weight", "backbone.stem.conv.branch_1x1.bias", "backbone.stem.conv.post_bn.weight", "backbone.stem.conv.post_bn.bias", "backbone.stem.conv.post_bn.running_mean", "backbone.stem.conv.post_bn.running_var", "backbone.stem.conv.rbr_reparam.weight", "backbone.stem.conv.rbr_reparam.bias", "backbone.stage1.downsample.branch_3x3.conv.weight", "backbone.stage1.downsample.branch_3x3.bn.weight", "backbone.stage1.downsample.branch_3x3.bn.bias", "backbone.stage1.downsample.branch_3x3.bn.running_mean", "backbone.stage1.downsample.branch_3x3.bn.running_var", "backbone.stage1.downsample.branch_1x1.weight", "backbone.stage1.downsample.branch_1x1.bias", "backbone.stage1.downsample.post_bn.weight", "backbone.stage1.downsample.post_bn.bias", "backbone.stage1.downsample.post_bn.running_mean", "backbone.stage1.downsample.post_bn.running_var", "backbone.stage1.downsample.rbr_reparam.weight", "backbone.stage1.downsample.rbr_reparam.bias", "backbone.stage1.blocks.conv1.conv.weight", "backbone.stage1.blocks.conv1.bn.weight", "backbone.stage1.blocks.conv1.bn.bias", "backbone.stage1.blocks.conv1.bn.running_mean", "backbone.stage1.blocks.conv1.bn.running_var", "backbone.stage1.blocks.conv2.conv.weight", "backbone.stage...
	Unexpected key(s) in state_dict: "module.backbone.stem.conv.branch_3x3.conv.weight", "module.backbone.stem.conv.branch_3x3.bn.weight", "module.backbone.stem.conv.branch_3x3.bn.bias", "module.backbone.stem.conv.branch_3x3.bn.running_mean", "module.backbone.stem.conv.branch_3x3.bn.running_var", "module.backbone.stem.conv.branch_3x3.bn.num_batches_tracked", "module.backbone.stem.conv.branch_1x1.weight", "module.backbone.stem.conv.branch_1x1.bias", "module.backbone.stem.conv.post_bn.weight", "module.backbone.stem.conv.post_bn.bias", "module.backbone.stem.conv.post_bn.running_mean", "module.backbone.stem.conv.post_bn.running_var", "module.backbone.stem.conv.post_bn.num_batches_tracked", "module.backbone.stem.conv.rbr_reparam.weight", "module.backbone.stem.conv.rbr_reparam.bias", "module.backbone.stage1.downsample.branch_3x3.conv.weight", "module.backbone.stage1.downsample.branch_3x3.bn.weight", "module.backbone.stage1.downsample.branch_3x3.bn.bias", "module.backbone.stage1.downsample.branch_3x3.bn.running_mean", "module.backbone.stage1.downsample.branch_3x3.bn.running_var", "module.backbone.stage1.downsample.branch_3x3.bn.num_batches_tracked", "module.backbone.stage1.downsample.branch_1x1.weight", "module.backbone.stage1.downsample.branch_1x1.bias", "module.backbone.stage1.downsample.post_bn.weight", "module.backbone.stage1.downsample.post_bn.bias", "module.backbone.stage1.downsample.post_bn.running_mean", "module.backbone.stage1.downsample.post_bn.running_var", "module.backbo...
convert ckpt via the utils.adapt_state_dict_to_fit_model_layer_names method
a converted checkpoint file was saved in the path /tmp/tmpp97s1mos.pt
========================================================================================================================================================================================================

Does this mean the model assumes I only have 3 classes, even though I initialised it with 6 classes when training?

@BloodAxe
Contributor

This should be resolved now that we have merged #1184.


3 participants