diff --git a/tools/get_flops.py b/tools/get_flops.py index 6c9cb23400c..502d57651c0 100644 --- a/tools/get_flops.py +++ b/tools/get_flops.py @@ -1,11 +1,11 @@ import argparse +import json from mmcv import Config from mmdet.models import build_detector from mmdet.utils import get_model_complexity_info - def parse_args(): parser = argparse.ArgumentParser(description='Train a detector') parser.add_argument('config', help='train config file path') @@ -15,6 +15,7 @@ def parse_args(): nargs='+', default=[1280, 800], help='input image size') + parser.add_argument('--out') args = parser.parse_args() return args @@ -50,6 +51,14 @@ def main(): 'You may need to check if all ops are supported and verify that the ' 'flops computation is correct.') + if args.out: + out = list() + out.append({'key': 'size', 'displayName': 'Size', 'value': float(params.split(' ')[0]), 'unit': 'Mp'}) + out.append({'key': 'complexity', 'displayName': 'Complexity', 'value': 2 * float(flops.split(' ')[0]), + 'unit': 'GFLOPs'}) + with open(args.out, 'w') as write_file: + json.dump(out, write_file, indent=4) + if __name__ == '__main__': main()