Update the PT2E CV example (#2032)
Signed-off-by: yiliu30 <yi4.liu@intel.com>
yiliu30 authored Oct 16, 2024
1 parent 08ec908 commit d6149aa
Showing 4 changed files with 15 additions and 7 deletions.
3 changes: 2 additions & 1 deletion examples/3.x_api/pytorch/cv/static_quant/README.md
@@ -14,7 +14,8 @@ This implements quantization of popular model architectures, such as ResNet on t
To quantize a model and validate accuracy, run `main.py` with the desired model architecture and the path to the ImageNet dataset:

```bash
python main.py -a resnet18 [imagenet-folder with train and val folders] -q -e
export ImageNetDataPath=/path/to/imagenet
python main.py $ImageNetDataPath --pretrained -a resnet18 --tune --calib_iters 5
```


9 changes: 6 additions & 3 deletions examples/3.x_api/pytorch/cv/static_quant/main.py
@@ -46,6 +46,7 @@
parser.add_argument('--wd', '--weight-decay', default=1e-4, type=float,
metavar='W', help='weight decay (default: 1e-4)',
dest='weight_decay')
parser.add_argument('--w_granularity', default="per_channel", type=str, choices=["per_channel", "per_tensor"], help='weight granularity')
parser.add_argument('-p', '--print-freq', default=10, type=int,
metavar='N', help='print frequency (default: 10)')
parser.add_argument('--resume', default='', type=str, metavar='PATH',
@@ -179,7 +180,7 @@ def eval_func(model):

if args.tune:
from neural_compressor.torch.export import export
from neural_compressor.torch.quantization import prepare, convert, get_default_static_config
from neural_compressor.torch.quantization import prepare, convert, StaticQuantConfig

# Prepare the float model and example inputs for exporting model
x = torch.randn(args.batch_size, 3, 224, 224).contiguous(memory_format=torch.channels_last)
@@ -188,15 +189,15 @@ def eval_func(model):
# Specify that the first dimension of each input is that batch size
from torch.export import Dim
print(args.batch_size)
batch = Dim("batch", min=16)
batch = Dim("batch")

# Specify that the first dimension of each input is that batch size
dynamic_shapes = {"x": {0: batch}}

# Export eager model into FX graph model
exported_model = export(model=model, example_inputs=example_inputs, dynamic_shapes=dynamic_shapes)
# Quantize the model
quant_config = get_default_static_config()
quant_config = StaticQuantConfig(w_granularity=args.w_granularity)

prepared_model = prepare(exported_model, quant_config=quant_config)
# Calibrate
@@ -233,7 +234,9 @@ def eval_func(model):
new_model = opt_model
else:
new_model = model
# For fair comparison, we also compile the float model
new_model.eval()
new_model = torch.compile(new_model)
if args.performance:
benchmark(val_loader, new_model, args)
return
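Read together, the main.py changes above describe the full PT2E static quantization flow: export the eager model with a dynamic batch dimension, attach a `StaticQuantConfig`, calibrate, convert, and finally `torch.compile` both the quantized and the float model. Below is a minimal sketch of that flow, assuming a torchvision ResNet-18 and random tensors in place of the ImageNet calibration loader; the calibration loop and the `convert` call follow the example's pattern rather than quoting it.

```python
import torch
import torchvision.models as models

from neural_compressor.torch.export import export
from neural_compressor.torch.quantization import StaticQuantConfig, prepare, convert
from torch.export import Dim

# Float model in eval mode (stands in for the example's pretrained -a resnet18).
model = models.resnet18(weights="DEFAULT").eval()

# Example input in channels_last, matching the example's export step.
x = torch.randn(16, 3, 224, 224).contiguous(memory_format=torch.channels_last)
example_inputs = (x,)

# Mark the first dimension of the input as a dynamic batch dimension.
batch = Dim("batch")
dynamic_shapes = {"x": {0: batch}}

# Export the eager model into an FX graph model.
exported_model = export(model=model, example_inputs=example_inputs, dynamic_shapes=dynamic_shapes)

# Static quantization config; w_granularity is "per_channel" or "per_tensor",
# mirroring the new --w_granularity flag.
quant_config = StaticQuantConfig(w_granularity="per_channel")

# Insert observers, calibrate on a few batches, then convert to the quantized model.
prepared_model = prepare(exported_model, quant_config=quant_config)
for _ in range(5):  # illustrative stand-in for --calib_iters 5 over the ImageNet loader
    prepared_model(torch.randn(16, 3, 224, 224).contiguous(memory_format=torch.channels_last))
q_model = convert(prepared_model)

# Compile both the quantized and the float model so the benchmark comparison is fair.
opt_model = torch.compile(q_model)
float_model = torch.compile(model)
```

In the example itself, calibration iterates over batches from the ImageNet validation loader (`--calib_iters` controls how many), and `--w_granularity` selects per-channel or per-tensor weight quantization.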
4 changes: 2 additions & 2 deletions examples/3.x_api/pytorch/cv/static_quant/run_benchmark.sh
@@ -79,7 +79,7 @@ function run_benchmark {
python main.py \
--pretrained \
-a resnet18 \
-b 30 \
-b ${batch_size} \
--tuned_checkpoint ${tuned_checkpoint} \
${dataset_location} \
${extra_cmd} \
@@ -89,7 +89,7 @@ function run_benchmark {
main.py \
--pretrained \
-a resnet18 \
-b 30 \
-b ${batch_size} \
--tuned_checkpoint ${tuned_checkpoint} \
${dataset_location} \
${extra_cmd} \
6 changes: 5 additions & 1 deletion examples/3.x_api/pytorch/cv/static_quant/run_quant.sh
@@ -10,6 +10,7 @@ function main {

# init params
function init_params {
batch_size=16
tuned_checkpoint="saved_results"
for var in "$@"
do
@@ -22,6 +23,9 @@ function init_params {
;;
--input_model=*)
input_model=$(echo $var |cut -f2 -d=)
;;
--batch_size=*)
batch_size=$(echo $var |cut -f2 -d=)
;;
--output_model=*)
tuned_checkpoint=$(echo $var |cut -f2 -d=)
@@ -44,7 +48,7 @@ function run_tuning {
--pretrained \
-t \
-a resnet18 \
-b 30 \
-b ${batch_size} \
--tuned_checkpoint ${tuned_checkpoint} \
${dataset_location}
}
