Skip to content

Update tests to use fp8 #508

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,121 @@
import onnxruntime
from onnxruntime.quantization import CalibrationDataReader, create_calibrator, write_calibration_table

def custom_write_calibration_table(calibration_cache, filename):
    """
    Write a calibration cache to disk in three formats.

    ``calibration_cache`` maps tensor names to calibration data objects that
    expose ``to_dict()`` (TensorData or the script-level TensorDataWrapper).
    Three files are produced:

      * ``filename + ".json"``  - JSON dump of the whole cache (reload with ``json.loads``)
      * ``filename``            - FlatBuffers ``TrtTable`` blob consumed by the EP
      * ``filename + ".cache"`` - plain text, one ``<tensor> <max(highest, lowest)>`` per line

    NOTE(review): the previous version wrote the JSON dump to ``filename`` and
    then reopened ``filename`` in "wb" mode for the FlatBuffers blob, silently
    clobbering the JSON; the JSON now goes to its own ``.json`` file.
    """

    import json
    import logging
    import os

    import flatbuffers
    import numpy as np

    import onnxruntime.quantization.CalTableFlatBuffers.KeyValue as KeyValue
    import onnxruntime.quantization.CalTableFlatBuffers.TrtTable as TrtTable
    from onnxruntime.quantization.calibrate import CalibrationMethod, TensorData, TensorsData

    logging.info(f"calibration cache: {calibration_cache}")

    zero = np.array(0)

    def range_max(values):
        # Largest of the recorded "highest"/"lowest" range endpoints for one
        # tensor; missing endpoints default to 0.  Accepts numpy scalars
        # (via .item()) or anything float() can convert.
        d_values = values.to_dict()
        highest = d_values.get("highest", zero)
        lowest = d_values.get("lowest", zero)
        highest_val = highest.item() if hasattr(highest, "item") else float(highest)
        lowest_val = lowest.item() if hasattr(lowest, "item") else float(lowest)
        return max(float(highest_val), float(lowest_val))

    class MyEncoder(json.JSONEncoder):
        # Teach json.dumps about the calibration-specific types.
        def default(self, obj):
            if isinstance(obj, (TensorData, TensorsData)):
                return obj.to_dict()
            # TensorDataWrapper is defined at script level before this
            # function is called -- TODO confirm when reusing elsewhere.
            if isinstance(obj, TensorDataWrapper):
                return obj.data_dict
            if isinstance(obj, np.ndarray):
                return {"data": obj.tolist(), "dtype": str(obj.dtype), "CLS": "numpy.array"}
            if isinstance(obj, CalibrationMethod):
                return {"CLS": obj.__class__.__name__, "value": str(obj)}
            return json.JSONEncoder.default(self, obj)

    json_data = json.dumps(calibration_cache, cls=MyEncoder)

    # JSON copy goes to its own file so the FlatBuffers write below does not
    # overwrite it.
    with open(filename + ".json", "w") as file:
        file.write(json_data)  # use `json.loads` to do the reverse

    # Serialize data using FlatBuffers
    builder = flatbuffers.Builder(1024)
    key_value_list = []

    for key in sorted(calibration_cache.keys()):
        value = str(range_max(calibration_cache[key]))

        flat_key = builder.CreateString(key)
        flat_value = builder.CreateString(value)

        KeyValue.KeyValueStart(builder)
        KeyValue.KeyValueAddKey(builder, flat_key)
        KeyValue.KeyValueAddValue(builder, flat_value)
        key_value_list.append(KeyValue.KeyValueEnd(builder))

    TrtTable.TrtTableStartDictVector(builder, len(key_value_list))
    for key_value in key_value_list:
        builder.PrependUOffsetTRelative(key_value)
    main_dict = builder.EndVector()

    TrtTable.TrtTableStart(builder)
    TrtTable.TrtTableAddDict(builder, main_dict)
    cal_table = TrtTable.TrtTableEnd(builder)

    builder.Finish(cal_table)
    buf = builder.Output()

    with open(filename, "wb") as file:
        file.write(buf)

    # Deserialize data (for validation) when QUANTIZATION_DEBUG is set.
    if os.environ.get("QUANTIZATION_DEBUG", 0) in (1, "1"):
        cal_table = TrtTable.TrtTable.GetRootAsTrtTable(buf, 0)
        dict_len = cal_table.DictLength()
        for i in range(dict_len):
            key_value = cal_table.Dict(i)
            logging.info(key_value.Key())
            logging.info(key_value.Value())

    # Plain-text copy: one "<tensor> <value>" pair per line.
    with open(filename + ".cache", "w") as file:
        for key in sorted(calibration_cache.keys()):
            file.write(key + " " + str(range_max(calibration_cache[key])))
            file.write("\n")


def parse_input_args():
parser = argparse.ArgumentParser()

parser.add_argument(
"--model",
required=False,
default='./resnet50-v2-7.onnx',
help='Target DIR for model. Default is ./resnet50-v2-7.onnx',
)

parser.add_argument(
"--fp16",
action="store_true",
Expand All @@ -29,6 +141,14 @@ def parse_input_args():
help='Perform no quantization',
)

# Quantize to fp8 instead of the default int8 path.
parser.add_argument(
    "--fp8",
    action="store_true",
    required=False,
    default=False,
    help='Perform fp8 quantization instead of int8',  # fixed typo "quantizaton"
)

parser.add_argument(
"--image_dir",
required=False,
Expand All @@ -48,6 +168,29 @@ def parse_input_args():
help='Size of images for calibration',
type=int)

parser.add_argument(
"--exhaustive_tune",
action="store_true",
required=False,
default=False,
help='Enable MIGraphX Exhaustive tune before compile. Default False',
)

parser.add_argument(
    "--cache",
    # Bug fix: action="store_true" combined with default=True made this flag
    # a no-op (it could never become False).  BooleanOptionalAction keeps
    # "--cache" working unchanged and adds "--no-cache" to disable caching.
    action=argparse.BooleanOptionalAction,
    required=False,
    default=True,
    help='cache the compiled model between runs. Saves quantization and compile time. Default true',
)

parser.add_argument(
"--cache_name",
required=False,
default="./cached_model.mxr",
help='Name and path of the compiled model cache. Default: ./cached_model.mxr',
)

return parser.parse_args()

class ImageNetDataReader(CalibrationDataReader):
Expand Down Expand Up @@ -255,6 +398,7 @@ class ImageClassificationEvaluator:
def __init__(self,
model_path,
synset_id,
flags,
data_reader: CalibrationDataReader,
providers=["MIGraphXExecutionProvider"]):
'''
Expand All @@ -276,10 +420,21 @@ def get_result(self):

def predict(self):
sess_options = onnxruntime.SessionOptions()
sess_options.log_severity_level = 0
sess_options.log_verbosity_level = 0
sess_options.log_severity_level = 2
sess_options.log_verbosity_level = 2
sess_options.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_DISABLE_ALL
session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options, providers=self.providers)
session = onnxruntime.InferenceSession(self.model_path, sess_options=sess_options,
providers=[("MIGraphXExecutionProvider",
{"migraphx_fp8_enable": flags.fp8 and not flags.fp32,
"migraphx_int8_enable": not (flags.fp8 or flags.fp32),
"migraphx_fp16_enable": flags.fp16 and not flags.fp32,
"migraphx_int8_calibration_table_name": flags.calibration_table,
"migraphx_use_native_calibration_table": flags.native_calibration_table,
"migraphx_save_compiled_model": flags.cache,
"migraphx_save_model_path": flags.cache_name,
"migraphx_load_compiled_model": flags.cache,
"migraphx_load_model_path": flags.cache_name,
"migraphx_exhaustive_tune": flags.exhaustive_tune})])

inference_outputs_list = []
while True:
Expand Down Expand Up @@ -362,21 +517,31 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
flags = parse_input_args()

# Dataset settings
model_path = "./resnet50-v2-7.onnx"
model_path = flags.model
ilsvrc2012_dataset_path = flags.image_dir
augmented_model_path = "./augmented_model.onnx"
batch_size = flags.batch
calibration_dataset_size = 0 if flags.fp32 else flags.cal_size # Size of dataset for calibration

precision=""

if not (flags.fp8 or flags.fp32):
precision = precision + "_int8"

if flags.fp8 and not flags.fp32:
precision = precision + "_fp8"

if flags.fp16 and not flags.fp32:
precision = "_fp16" + precision

calibration_table_generation_enable = False
if not flags.fp32:
# INT8 calibration setting
calibration_table_generation_enable = True # Enable/Disable INT8 calibration

# MIGraphX EP INT8 settings
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "1" # Enable INT8 precision
os.environ["ORT_MIGRAPHX_INT8_CALIBRATION_TABLE_NAME"] = "calibration.flatbuffers" # Calibration table name
os.environ["ORT_MIGRAPHX_INT8_NATIVE_CALIBRATION_TABLE"] = "0" # Calibration table name
# Name the table after the calibration size and precision so tables from
# different configurations do not collide.
flags.calibration_table = "calibration_cal"+ str(flags.cal_size) + precision + ".flatbuffers"
flags.native_calibration_table = "False"
if os.path.isfile("./" + flags.calibration_table):
    # Bug fix: this previously assigned `calibration_table_generation`, a name
    # never read anywhere, so an existing table was regenerated every run.
    calibration_table_generation_enable = False
    print("Found previous calibration: " + flags.calibration_table + ". Skipping generating table")

execution_provider = ["MIGraphXExecutionProvider"]

Expand All @@ -396,25 +561,46 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
start_index=0,
end_index=calibration_dataset_size,
stride=calibration_dataset_size,
batch_size=batch_size,
batch_size=1,
model_path=augmented_model_path,
input_name=input_name)
calibrator.collect_data(data_reader)
cal_tensors = calibrator.compute_data()

serial_cal_tensors = {}
for keys, values in cal_tensors.data.items():
serial_cal_tensors[keys] = [float(x[0]) for x in values.range_value]
class TensorDataWrapper:
    """Lightweight holder for an already-serialized tensor-data dict.

    Mirrors the ``to_dict()`` interface of TensorData so the JSON encoder
    and the calibration-table writer can treat both uniformly.
    """

    def __init__(self, data_dict):
        # Stored as-is; consumers read .data_dict directly as well.
        self.data_dict = data_dict

    def to_dict(self):
        """Return the wrapped dict unchanged."""
        return self.data_dict

    def __repr__(self):
        return f"{self.data_dict!r}"

    def __serializable__(self):
        # Same payload as to_dict(); kept for serializer compatibility.
        return self.data_dict

# Convert every calibration tensor into a JSON-friendly wrapper: numpy
# scalars/arrays become plain Python numbers/lists; objects without a
# to_dict() are passed through untouched.
calibration_data = {}
for name, tensor in cal_tensors.data.items():
    if not hasattr(tensor, 'to_dict'):
        calibration_data[name] = tensor
        continue
    cleaned = {}
    for field, raw in tensor.to_dict().items():
        if isinstance(raw, np.ndarray):
            cleaned[field] = raw.item() if raw.size == 1 else raw.tolist()
        elif isinstance(raw, np.number):
            cleaned[field] = raw.item()
        else:
            cleaned[field] = raw
    calibration_data[name] = TensorDataWrapper(cleaned)

print("Writing calibration table to:" + flags.calibration_table)
# custom_write_calibration_table writes directly to flags.calibration_table,
# so the old os.rename("./calibration.flatbuffers", ...) step had no source
# file to rename and raised FileNotFoundError -- removed.
custom_write_calibration_table(calibration_data, flags.calibration_table)
print("Write complete")

if flags.fp16:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "1"
else:
os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"

# Run prediction in MIGraphX EP
data_reader = ImageNetDataReader(ilsvrc2012_dataset_path,
start_index=calibration_dataset_size,
Expand All @@ -427,14 +613,9 @@ def get_dataset_size(dataset_path, calibration_dataset_size):
synset_id = data_reader.get_synset_id(ilsvrc2012_dataset_path, calibration_dataset_size,
prediction_dataset_size) # Generate synset id
print("Prepping Evaluator")  # fixed typo "Evalulator"
evaluator = ImageClassificationEvaluator(new_model_path, synset_id, data_reader, providers=execution_provider)
evaluator = ImageClassificationEvaluator(new_model_path, synset_id, flags, data_reader, providers=execution_provider)
print("Performing Predictions")
evaluator.predict()
print("Read out answer")
result = evaluator.get_result()
evaluator.evaluate(result)

#Set OS flags to off to ensure we don't interfere with other test runs

os.environ["ORT_MIGRAPHX_FP16_ENABLE"] = "0"
os.environ["ORT_MIGRAPHX_INT8_ENABLE"] = "0"
Loading