Skip to content

Commit 9a3d463

Browse files
committed
Applying mixed bit compression using new optimize API
1 parent 7449ce4 commit 9a3d463

File tree

3 files changed

+29
-54
lines changed

3 files changed

+29
-54
lines changed

.gitignore

+2
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,5 @@
1+
*~
2+
13
# Swift Package
24
.DS_Store
35
/.build

python_coreml_stable_diffusion/mixed_bit_compression_apply.py

+23-48
Original file line number | Diff line number | Diff line change
@@ -1,18 +1,20 @@
1-
from pprint import pprint
21
import argparse
3-
import coremltools as ct
42
import gc
53
import json
64
import logging
7-
import numpy as np
85
import os
96

7+
import coremltools as ct
8+
import coremltools.optimize.coreml as cto
9+
import numpy as np
10+
1011
from python_coreml_stable_diffusion.torch2coreml import get_pipeline
1112
from python_coreml_stable_diffusion.mixed_bit_compression_pre_analysis import (
1213
NBITS,
1314
PALETTIZE_MIN_SIZE as MIN_SIZE
1415
)
1516

17+
1618
logging.basicConfig()
1719
logger = logging.getLogger(__name__)
1820
logger.setLevel(logging.INFO)
@@ -23,9 +25,6 @@ def main(args):
2325
coreml_model = ct.models.MLModel(args.mlpackage_path, compute_units=ct.ComputeUnit.CPU_ONLY)
2426
logger.info(f"Loaded {args.mlpackage_path}")
2527

26-
# Keep track of precision stats
27-
precision_stats = {nbits:{'num_tensors': 0, 'numel': 0} for nbits in NBITS}
28-
2928
# Load palettization recipe
3029
with open(args.pre_analysis_json_path, 'r') as f:
3130
pre_analysis = json.load(f)
@@ -62,53 +61,29 @@ def get_tensor_hash(tensor):
6261
del pipe
6362
gc.collect()
6463

65-
current_nbits: int
66-
67-
def op_selector(const):
68-
parameter_tensor = const.val.val
69-
if parameter_tensor.size < MIN_SIZE:
70-
return False
71-
72-
if parameter_tensor.dtype != np.float16:
73-
# These are the tensors that were compressed to look-up indices in previous passes
74-
return False
75-
76-
tensor_hash = get_tensor_hash(parameter_tensor)
77-
tensor_spec = f"{tensor_hash} with shape {parameter_tensor.shape}"
78-
79-
80-
hashes = list(hashed_recipe)
81-
pdist = np.abs(np.array(hashes) - tensor_hash)
64+
op_name_configs = {}
65+
weight_metadata = cto.get_weights_metadata(coreml_model, weight_threshold=MIN_SIZE)
66+
hashes = np.array(list(hashed_recipe))
67+
for name, metadata in weight_metadata.items():
68+
# Look up target bits for this weight
69+
tensor_hash = get_tensor_hash(metadata.val)
70+
pdist = np.abs(hashes - tensor_hash)
71+
assert(pdist.min() < 0.01)
8272
matched = pdist.argmin()
83-
logger.debug(f"{tensor_spec}: {tensor_hash} matched with {hashes[matched]} (hash error={pdist.min()})")
84-
8573
target_nbits = hashed_recipe[hashes[matched]]
86-
87-
do_palettize = current_nbits == target_nbits
88-
if do_palettize:
89-
logger.debug(f"{tensor_spec}: Palettizing to {target_nbits}-bit palette")
90-
precision_stats[current_nbits]['num_tensors'] += 1
91-
precision_stats[current_nbits]['numel'] += np.prod(parameter_tensor.shape)
92-
return True
93-
return False
94-
95-
for nbits in NBITS:
96-
logger.info(f"Processing tensors targeting {nbits}-bit palettes")
97-
current_nbits = nbits
98-
99-
config = ct.optimize.coreml.OptimizationConfig(
100-
global_config=ct.optimize.coreml.OpPalettizerConfig(mode="kmeans", nbits=nbits, weight_threshold=None,),
101-
is_deprecated=True,
102-
op_selector=op_selector,
74+
75+
if target_nbits == 16:
76+
continue
77+
78+
op_name_configs[name] = cto.OpPalettizerConfig(
79+
mode="kmeans",
80+
nbits=target_nbits,
81+
weight_threshold=int(MIN_SIZE)
10382
)
104-
coreml_model = ct.optimize.coreml.palettize_weights(coreml_model, config=config)
105-
logger.info(f"{precision_stats[nbits]['num_tensors']} tensors are palettized with {nbits} bits")
10683

84+
config = ct.optimize.coreml.OptimizationConfig(op_name_configs=op_name_configs)
85+
coreml_model = ct.optimize.coreml.palettize_weights(coreml_model, config)
10786

108-
tot_numel = sum([precision_stats[nbits]['numel'] for nbits in NBITS])
109-
final_size = sum([precision_stats[nbits]['numel'] * nbits for nbits in NBITS])
110-
logger.info(f"Palettization result: {final_size / tot_numel:.2f}-bits resulting in {final_size / (8*1e6)} MB")
111-
pprint(precision_stats)
11287
coreml_model.save(args.o)
11388

11489

python_coreml_stable_diffusion/mixed_bit_compression_pre_analysis.py

+4-6
Original file line number | Diff line number | Diff line change
@@ -21,7 +21,7 @@
2121
import requests
2222
torch.set_grad_enabled(False)
2323

24-
from tqdm import tqdm, trange
24+
from tqdm import tqdm
2525

2626
# Bit-widths the Neural Engine is capable of accelerating
2727
NBITS = [1, 2, 4, 6, 8]
@@ -342,8 +342,8 @@ def simulate_quant_fn(ref_pipe, quantization_to_simulate):
342342

343343
ref_out = run_pipe(ref_pipe)
344344
simulated_psnr = sum([
345-
float(f"{compute_psnr(r,t):.1f}")
346-
for r,t in zip(ref_out, simulated_out)
345+
float(f"{compute_psnr(r, t):.1f}")
346+
for r, t in zip(ref_out, simulated_out)
347347
]) / len(ref_out)
348348

349349
return simulated_out, simulated_psnr
@@ -459,9 +459,7 @@ def main(args):
459459
json_name = f"{args.model_version.replace('/','-')}_palettization_recipe.json"
460460
candidates, sizes = get_palettizable_modules(pipe.unet)
461461

462-
sizes_table = {
463-
candidate:size for candidate, size in zip(candidates, sizes)
464-
}
462+
sizes_table = dict(zip(candidates, sizes))
465463

466464
if os.path.isfile(os.path.join(args.o, json_name)):
467465
with open(os.path.join(args.o, json_name), "r") as f:

0 commit comments

Comments
 (0)