Skip to content

Commit

Permalink
fix: support bf16 lora weights (#82)
Browse files Browse the repository at this point in the history
  • Loading branch information
Green-Sky authored Nov 20, 2023
1 parent ae1d5dc commit c874063
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions models/convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ def quantize_q5_1(x):
def quantize_q8_0(x):
assert x.shape[-1] % QK8_0 == 0 and x.shape[-1] > QK8_0
x = x.reshape(-1, QK8_0)
amax = np.max(np.abs(x), axis=-1, keepdims=True)
amax = np.max(np.abs(x), axis=-1, keepdims=True)
d = amax / ((1 << 7) - 1)
qs = (x / d).round().clip(min=-128, max=127).astype(np.int8)
d = d.astype(np.float16).view(np.int8)
Expand Down Expand Up @@ -178,7 +178,7 @@ def preprocess(state_dict):
print("no alphas_cumprod in file, generate new one")
alphas_cumprod = get_alpha_comprod()
state_dict["alphas_cumprod"] = alphas_cumprod

new_state_dict = {}
for name, w in state_dict.items():
# ignore unused tensors
Expand All @@ -192,7 +192,7 @@ def preprocess(state_dict):
if skip:
continue

# # convert BF16 to FP16
# convert BF16 to FP16
if w.dtype == torch.bfloat16:
w = w.to(torch.float16)

Expand Down Expand Up @@ -251,7 +251,7 @@ def preprocess(state_dict):
new_state_dict[new_name] = w
print(f"preprocess {name} => {new_name}")
continue

# convert unet transformer linear to conv2d 1x1
if name.startswith("model.diffusion_model.") and (name.endswith("proj_in.weight") or name.endswith("proj_out.weight")):
if len(w.shape) == 2:
Expand Down Expand Up @@ -342,6 +342,11 @@ def preprocess_lora(state_dict):
for name, w in state_dict.items():
if not isinstance(w, torch.Tensor):
continue

# convert BF16 to FP16
if w.dtype == torch.bfloat16:
w = w.to(torch.float16)

name_without_network_parts, network_part = name.split(".", 1)
new_name_without_network_parts = convert_diffusers_name_to_compvis(name_without_network_parts)
if new_name_without_network_parts == None:
Expand Down Expand Up @@ -421,6 +426,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
continue
if name in unused_tensors:
continue

data = state_dict[name].numpy()

n_dims = len(data.shape)
Expand Down Expand Up @@ -452,7 +458,7 @@ def convert(model_path, out_type = None, out_file=None, lora=False):
else:
data = data.astype(np.float32)
ttype = "f32"

print("Processing tensor: {} with shape {}, {} -> {}".format(name, data.shape, old_type, ttype))

# header
Expand Down

0 comments on commit c874063

Please sign in to comment.