1
- from pprint import pprint
2
1
import argparse
3
- import coremltools as ct
4
2
import gc
5
3
import json
6
4
import logging
7
- import numpy as np
8
5
import os
9
6
7
+ import coremltools as ct
8
+ import coremltools .optimize .coreml as cto
9
+ import numpy as np
10
+
10
11
from python_coreml_stable_diffusion .torch2coreml import get_pipeline
11
12
from python_coreml_stable_diffusion .mixed_bit_compression_pre_analysis import (
12
13
NBITS ,
13
14
PALETTIZE_MIN_SIZE as MIN_SIZE
14
15
)
15
16
17
+
16
18
# Module-level logging setup: configure the root logger with the library
# defaults, then create this module's named logger and set its level to INFO
# so the logger.info(...) progress messages in this script are emitted.
logging .basicConfig ()
17
19
logger = logging .getLogger (__name__ )
18
20
logger .setLevel (logging .INFO )
@@ -23,9 +25,6 @@ def main(args):
23
25
coreml_model = ct .models .MLModel (args .mlpackage_path , compute_units = ct .ComputeUnit .CPU_ONLY )
24
26
logger .info (f"Loaded { args .mlpackage_path } " )
25
27
26
- # Keep track of precision stats
27
- precision_stats = {nbits :{'num_tensors' : 0 , 'numel' : 0 } for nbits in NBITS }
28
-
29
28
# Load palettization recipe
30
29
with open (args .pre_analysis_json_path , 'r' ) as f :
31
30
pre_analysis = json .load (f )
@@ -62,53 +61,29 @@ def get_tensor_hash(tensor):
62
61
del pipe
63
62
gc .collect ()
64
63
65
- current_nbits : int
66
-
67
- def op_selector (const ):
68
- parameter_tensor = const .val .val
69
- if parameter_tensor .size < MIN_SIZE :
70
- return False
71
-
72
- if parameter_tensor .dtype != np .float16 :
73
- # These are the tensors that were compressed to look-up indices in previous passes
74
- return False
75
-
76
- tensor_hash = get_tensor_hash (parameter_tensor )
77
- tensor_spec = f"{ tensor_hash } with shape { parameter_tensor .shape } "
78
-
79
-
80
- hashes = list (hashed_recipe )
81
- pdist = np .abs (np .array (hashes ) - tensor_hash )
64
+ op_name_configs = {}
65
+ weight_metadata = cto .get_weights_metadata (coreml_model , weight_threshold = MIN_SIZE )
66
+ hashes = np .array (list (hashed_recipe ))
67
+ for name , metadata in weight_metadata .items ():
68
+ # Look up target bits for this weight
69
+ tensor_hash = get_tensor_hash (metadata .val )
70
+ pdist = np .abs (hashes - tensor_hash )
71
+ assert (pdist .min () < 0.01 )
82
72
matched = pdist .argmin ()
83
- logger .debug (f"{ tensor_spec } : { tensor_hash } matched with { hashes [matched ]} (hash error={ pdist .min ()} )" )
84
-
85
73
target_nbits = hashed_recipe [hashes [matched ]]
86
-
87
- do_palettize = current_nbits == target_nbits
88
- if do_palettize :
89
- logger .debug (f"{ tensor_spec } : Palettizing to { target_nbits } -bit palette" )
90
- precision_stats [current_nbits ]['num_tensors' ] += 1
91
- precision_stats [current_nbits ]['numel' ] += np .prod (parameter_tensor .shape )
92
- return True
93
- return False
94
-
95
- for nbits in NBITS :
96
- logger .info (f"Processing tensors targeting { nbits } -bit palettes" )
97
- current_nbits = nbits
98
-
99
- config = ct .optimize .coreml .OptimizationConfig (
100
- global_config = ct .optimize .coreml .OpPalettizerConfig (mode = "kmeans" , nbits = nbits , weight_threshold = None ,),
101
- is_deprecated = True ,
102
- op_selector = op_selector ,
74
+
75
+ if target_nbits == 16 :
76
+ continue
77
+
78
+ op_name_configs [name ] = cto .OpPalettizerConfig (
79
+ mode = "kmeans" ,
80
+ nbits = target_nbits ,
81
+ weight_threshold = int (MIN_SIZE )
103
82
)
104
- coreml_model = ct .optimize .coreml .palettize_weights (coreml_model , config = config )
105
- logger .info (f"{ precision_stats [nbits ]['num_tensors' ]} tensors are palettized with { nbits } bits" )
106
83
84
+ config = ct .optimize .coreml .OptimizationConfig (op_name_configs = op_name_configs )
85
+ coreml_model = ct .optimize .coreml .palettize_weights (coreml_model , config )
107
86
108
- tot_numel = sum ([precision_stats [nbits ]['numel' ] for nbits in NBITS ])
109
- final_size = sum ([precision_stats [nbits ]['numel' ] * nbits for nbits in NBITS ])
110
- logger .info (f"Palettization result: { final_size / tot_numel :.2f} -bits resulting in { final_size / (8 * 1e6 )} MB" )
111
- pprint (precision_stats )
112
87
coreml_model .save (args .o )
113
88
114
89
0 commit comments