| import onnx | |
| import os | |
| import itertools | |
| import argparse | |
| import shutil | |
| from onnxconverter_common.float16 import convert_float_to_float16 | |
| from onnxruntime.quantization import quantize_dynamic, QuantType | |
| from onnxruntime.tools.symbolic_shape_infer import SymbolicShapeInference | |
| from multiprocessing import Pool | |
| from tabulate import tabulate | |
def float32(input, output):
    """Keep full float32 precision: pass the converted model through unchanged.

    copy2 (rather than copyfile) also preserves file metadata/timestamps.
    """
    shutil.copy2(input, output)
def float16(input, output):
    """Down-cast the model's float32 tensors to float16 and save the result."""
    halved = convert_float_to_float16(onnx.load(input))
    onnx.save(halved, output)
def qint8(input, output):
    """Dynamically quantize model weights to signed 8-bit integers (QInt8)."""
    quantize_dynamic(input, output, weight_type=QuantType.QInt8)
def quint8(input, output):
    """Dynamically quantize model weights to unsigned 8-bit integers (QUInt8)."""
    quantize_dynamic(input, output, weight_type=QuantType.QUInt8)
def infer_shapes(input, output):
    """Annotate the model with symbolic shapes (needed for TensorRT support)."""
    inferred = SymbolicShapeInference.infer_shapes(onnx.load(input))
    onnx.save(inferred, output)
def print_table(table):
    """Render *table* as a GitHub-flavored markdown table, then a blank line."""
    rendered = tabulate(table, headers="keys", tablefmt="github")
    print(rendered, "\n")
def get_file_mb(path):
    """Return the size of *path* in whole (decimal) megabytes as a string.

    Returns the literal string "N/A" when the file does not exist, so the
    value can be dropped straight into a markdown table cell either way.
    """
    try:
        size_bytes = os.stat(path).st_size
    except FileNotFoundError:
        return "N/A"
    # 1 MB = 1,000,000 bytes (decimal megabytes, not MiB).
    return f"{round(size_bytes / 1_000_000)}"
def convert(name, mode, f, markdown):
    """Produce one model variant, or (in markdown mode) describe it.

    name     -- model identifier, e.g. "resnet-50"
    mode     -- "visual" or "textual"
    f        -- conversion callable taking (input_path, output_path);
                its __name__ labels the variant's data type
    markdown -- when True, perform no conversion; return a table row
                [path, name, mode, dtype, availability-emoji] instead
    """
    dtype = f.__name__
    src = f"converted/clip-{name}-{mode}.onnx"
    dst = f"models/clip-{name}-{mode}-{dtype}.onnx"
    available = os.path.exists(dst)
    if markdown:
        return [dst, name, mode, dtype, "✅" if available else "❌"]
    if available:
        print(f"{dst} exists")
        return
    print(f"{dst} converting")
    if mode == "textual":
        # Textual models need an extra symbolic-shape-inference pass for
        # TensorRT; convert into a temp file first, then infer into place.
        staging = f"{dst}.temp"
        f(src, staging)
        print(f"{dst} running shape inference for TensorRT support")
        infer_shapes(staging, dst)
        os.remove(staging)
    else:
        f(src, dst)
    print(f"{dst} done")
if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='Create variants of converted models')
    parser.add_argument(
        '--markdown',
        action='store_true',
        help='Print markdown tables describing the variants'
    )
    args = parser.parse_args()
    # Canonical CLIP model identifiers. The previous list repeated
    # "resnet-50" four times (copy/paste slip), which scheduled redundant
    # conversions and emitted duplicate table rows; each model now appears
    # exactly once.
    names = [
        "resnet-50",
        "resnet-101",
        "resnet-50x4",
        "resnet-50x16",
        "resnet-50x64",
        "vit-base-patch16",
        "vit-base-patch32",
        "vit-large-patch14",
        "vit-large-patch14-336",
    ]
    modes = [
        "visual",
        "textual",
    ]
    funcs = [
        float32,
        float16,
        qint8,
        quint8,
    ]
    markdown = args.markdown
    if markdown:
        print_table({"Model ID": names})
        print_table({"Mode": modes})
        print_table({"Data Type": [f.__name__ for f in funcs]})
    # One task per (model, mode, dtype) combination.
    variants = itertools.product(names, modes, funcs, [markdown])
    # Conversions are CPU-heavy and independent, so fan out across 8
    # workers; markdown mode stays single-process so output ordering is
    # deterministic.
    with Pool(8 if not markdown else 1) as p:
        variants_table = p.starmap(convert, variants)
    if markdown:
        for row in variants_table:
            output = row[0]
            row.append(get_file_mb(output))
        variants_table.insert(0, ["Path", "Model ID", "Mode", "Data Type", "Available", "Size (MB)"])
        print(tabulate(variants_table, headers="firstrow", tablefmt="github"))
    else:
        print("done")