
Yolox improve with REPConv/ASFF/TOOD #154

Merged: 85 commits, Aug 24, 2022

Conversation

wuziheng (Collaborator)

Motivation

  1. Use ASFF/TOOD/RepVGG to improve YOLOX, achieving 44.5 mAP with YOLOX-S

Modification

  1. Support for YOLOX model components such as ASFF/ASFFSim/TOOD
  2. Support for a RepVGG-block backbone
  3. Export support for reparameterization, and PAI-Blade best practices for YOLOX

Backward-Compatibility-Breaking (Optional)

Use cases (Optional)

Pull-Request Todo List

  1. Pre-commit or other linting tools are used to fix potential lint issues.
  2. The modification is covered by complete unit tests. If not, please add more unit tests to ensure correctness.
  3. The documentation has been modified accordingly, e.g. docstrings or example tutorials.

@@ -35,7 +42,9 @@
]

# dataset settings
data_root = 'data/coco/'
# data_root = '/apsarapangu/disk2/xinyi.zxy/data/coco/'
data_root = '/apsara/xinyi.zxy/data/coco/'

Delete the commented lines and update data_root; the same applies to the other files.

@@ -659,6 +746,7 @@ def forward(self, image):

with torch.no_grad():
if self.preprocess_fn is not None:
# print('before', image.shape)

Remove line 749.

@@ -126,6 +126,9 @@ def visualize(self, results, vis_num=10, score_thr=0.3, **kwargs):
dict of image meta info, containing filename, img_shape,
origin_img_shape, scale_factor and so on.
"""
import copy
results = copy.deepcopy(results)

It is already fixed, please refer to #67.

Remove the deepcopy.

convert_new.py Outdated
@@ -0,0 +1,78 @@
# conver to new

remove

@@ -0,0 +1,117 @@
# from easycv.models.detection.detectors.yolox import YOLOX

move to easycv/test

export_log.txt (outdated; resolved)
tools/eval.py (outdated; resolved)
easycv/apis/export.py (outdated; resolved)
easycv/apis/export.py (outdated; resolved)
easycv/apis/export.py (outdated; resolved)
pass

@torch.jit.script
def postprocess_fn(output):
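For context, a TorchScript post-processing hook of this shape can be sketched as follows. The output layout (score in column 4) and the threshold are illustrative assumptions, not the PR's actual implementation, which also decodes boxes and applies NMS:

```python
import torch

@torch.jit.script
def postprocess_fn(output: torch.Tensor) -> torch.Tensor:
    # Illustrative only: keep rows whose (assumed) score column exceeds 0.3.
    # Boolean-mask indexing like this is supported in TorchScript.
    return output[output[:, 4] > 0.3]
```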

why remove these export guides?


We have changed to a new design and will discuss it.

- Provide an easy way to use PAI-Blade to accelerate the inference process.
- Provide a convenient way to train/evaluate/export YOLOX-PAI model and conduct end2end object detection.

To learn more details of YOLOX-PAI, you can refer to our technical paper [??link][arxiv].

The ?? should be replaced.


Sorry for the mistake; it was a Chinese character and has been fixed.

Put them in the following format:
```shell
export_blade/
??? epoch_300_pre_notrt.pt.blade

??? should be replaced


Sorry for the mistake; it has been fixed.


else:
if hasattr(cfg, 'export'):
export_type = getattr(cfg.export, 'export_type', 'ori')

update export_type: 'ori' to 'raw'
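The requested rename only changes the fallback string in the `getattr` default. Sketched with a `SimpleNamespace` standing in for the mmcv-style `cfg` object:

```python
from types import SimpleNamespace

# A config with an `export` section but no explicit export_type set.
cfg = SimpleNamespace(export=SimpleNamespace())

if hasattr(cfg, 'export'):
    # 'raw' replaces the old 'ori' as the default plain-model export type.
    export_type = getattr(cfg.export, 'export_type', 'raw')
```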


done

with io.open('/tmp/model.jit', 'wb') as f:
torch.jit.save(model_script, f)
"""
class ModelExportWrapper(torch.nn.Module):

why remove End2endModelExportWrapper and preprocess_fn, postprocess_fn


we will discuss it.

return model_output


class ProcessExportWrapper(torch.nn.Module):

It is named ProcessExportWrapper, but it seems to have nothing to do with export.


We use the wrapper to trace the model and export a standalone preprocessing JIT model.
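As a rough illustration of that pattern (the class name, resize size, and normalization are assumptions, not EasyCV's actual code), a preprocessing module can be wrapped, traced, and saved as its own JIT artifact:

```python
import torch

class PreprocessWrapper(torch.nn.Module):
    """Wraps preprocessing so it can be traced into a standalone JIT model."""

    def forward(self, image: torch.Tensor) -> torch.Tensor:
        # Illustrative preprocessing: resize to 640x640, scale to [0, 1].
        image = torch.nn.functional.interpolate(
            image, size=(640, 640), mode='bilinear', align_corners=False)
        return image / 255.0

wrapper = PreprocessWrapper().eval()
example = torch.zeros(1, 3, 480, 480)
traced = torch.jit.trace(wrapper, example)
# torch.jit.save(traced, 'preprocess.jit')  # shipped alongside the model
```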

@@ -183,6 +185,9 @@ def main():
device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'use device {device}')
checkpoint = load_checkpoint(model, args.checkpoint, map_location=device)

model = reparameterize_models(model)
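The `reparameterize_models` call above folds each RepVGG block's training-time branches (3x3 conv, 1x1 conv, identity) into a single 3x3 conv for inference. A minimal sketch of the kernel fusion, ignoring the BatchNorm folding that the real code also performs:

```python
import torch
import torch.nn.functional as F

def fuse_repvgg_kernels(k3x3: torch.Tensor, k1x1: torch.Tensor) -> torch.Tensor:
    """Merge a 1x1 branch into a 3x3 kernel by zero-padding it to 3x3."""
    return k3x3 + F.pad(k1x1, [1, 1, 1, 1])

# Equivalence check: conv(x, k3) + conv(x, k1) == conv(x, fused kernel)
x = torch.randn(1, 8, 16, 16)
k3 = torch.randn(8, 8, 3, 3)
k1 = torch.randn(8, 8, 1, 1)
y_branches = F.conv2d(x, k3, padding=1) + F.conv2d(x, k1)
y_fused = F.conv2d(x, fuse_repvgg_kernels(k3, k1), padding=1)
```

Because the two-branch sum and the fused conv are numerically identical, inference pays for only one convolution per block.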

Add an explanation for line 189.


done

@Cathy0908
(Collaborator)

LGTM

@wenmengzhou wenmengzhou merged commit 9aaa600 into alibaba:master Aug 24, 2022
8 participants