
Commit

Fix keys of save_own_variables and load_own_variables (#19581)
james77777778 authored Apr 22, 2024
1 parent 3afc089 commit afc92f5
Showing 4 changed files with 60 additions and 44 deletions.
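
The fix is the same in all four layers. Previously, save_own_variables and load_own_variables wrote and read each variable under a hard-coded string key, so the keys could end up non-contiguous when optional variables were absent (for example, a quantized Dense without a bias saved its kernel scale under "2" while "1" was never used). The commit instead collects the variables to persist into an ordered target_variables list and enumerates it, so the keys always come out as "0", "1", "2", ... in order. A minimal standalone sketch of the pattern (illustrative only, not the actual layer code):

    def save_own_variables(target_variables, store):
        # Keys are derived from position, so they stay contiguous even when
        # optional variables (e.g. a bias) are missing from the list.
        for i, variable in enumerate(target_variables):
            store[str(i)] = variable

    def load_own_variables(target_variables, store):
        # Loading walks the same ordered list, so save and load always agree
        # on which key maps to which variable.
        for i, variable in enumerate(target_variables):
            variable.assign(store[str(i)])
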
keras/src/layers/convolutional/base_conv.py (12 changes: 8 additions & 4 deletions)

@@ -307,19 +307,23 @@ def save_own_variables(self, store):
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        store["0"] = self.kernel
+        target_variables = [self.kernel]
         if self.use_bias:
-            store["1"] = self.bias
+            target_variables.append(self.bias)
+        for i, variable in enumerate(target_variables):
+            store[str(i)] = variable

     def load_own_variables(self, store):
         if not self.lora_enabled:
             self._check_load_own_variables(store)
         # Do nothing if the layer isn't yet built
         if not self.built:
             return
-        self._kernel.assign(store["0"])
+        target_variables = [self._kernel]
         if self.use_bias:
-            self.bias.assign(store["1"])
+            target_variables.append(self.bias)
+        for i, variable in enumerate(target_variables):
+            variable.assign(store[str(i)])
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))

keras/src/layers/core/dense.py (40 changes: 22 additions & 18 deletions)

@@ -202,24 +202,26 @@ def save_own_variables(self, store):
         # The keys of the `store` will be saved as determined because the
         # default ordering will change after quantization
         kernel_value, kernel_scale = self._get_kernel_with_merged_lora()
-        store["0"] = kernel_value
+        target_variables = [kernel_value]
         if self.use_bias:
-            store["1"] = self.bias
+            target_variables.append(self.bias)
         if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy):
             mode = self.dtype_policy.quantization_mode
             if mode == "int8":
-                store["2"] = kernel_scale
+                target_variables.append(kernel_scale)
             elif mode == "float8":
-                store["2"] = self.inputs_scale
-                store["3"] = self.inputs_amax_history
-                store["4"] = self.kernel_scale
-                store["5"] = self.kernel_amax_history
-                store["6"] = self.outputs_grad_scale
-                store["7"] = self.outputs_grad_amax_history
+                target_variables.append(self.inputs_scale)
+                target_variables.append(self.inputs_amax_history)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_amax_history)
+                target_variables.append(self.outputs_grad_scale)
+                target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
                     self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
                 )
+        for i, variable in enumerate(target_variables):
+            store[str(i)] = variable

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -229,24 +231,26 @@ def load_own_variables(self, store):
             return
         # The keys of the `store` will be saved as determined because the
         # default ordering will change after quantization
-        self._kernel.assign(store["0"])
+        target_variables = [self._kernel]
         if self.use_bias:
-            self.bias.assign(store["1"])
+            target_variables.append(self.bias)
         if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy):
             mode = self.dtype_policy.quantization_mode
             if mode == "int8":
-                self.kernel_scale.assign(store["2"])
+                target_variables.append(self.kernel_scale)
             elif mode == "float8":
-                self.inputs_scale.assign(store["2"])
-                self.inputs_amax_history.assign(store["3"])
-                self.kernel_scale.assign(store["4"])
-                self.kernel_amax_history.assign(store["5"])
-                self.outputs_grad_scale.assign(store["6"])
-                self.outputs_grad_amax_history.assign(store["7"])
+                target_variables.append(self.inputs_scale)
+                target_variables.append(self.inputs_amax_history)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_amax_history)
+                target_variables.append(self.outputs_grad_scale)
+                target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
                     self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
                 )
+        for i, variable in enumerate(target_variables):
+            variable.assign(store[str(i)])
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))

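The effect of the new key scheme is easiest to see on a quantized Dense layer built without a bias. With the old code, the int8 branch wrote kernel_scale under the hard-coded key "2" while "1" (the bias slot) was never written; with this commit the keys are assigned by position and come out contiguous. A rough check of the saved keys (assuming a Keras 3 build where layer.quantize("int8") is available; the shapes are arbitrary):

    import keras

    layer = keras.layers.Dense(4, use_bias=False)
    layer.build((None, 8))
    layer.quantize("int8")  # switch the layer to an int8 quantized dtype policy

    store = {}  # any mapping works; the saving machinery passes a dict-like store
    layer.save_own_variables(store)

    # Old behavior: keys "0" (kernel) and "2" (kernel_scale), with "1" skipped.
    # New behavior: keys "0" (kernel) and "1" (kernel_scale), contiguous.
    print(sorted(store.keys()))
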
keras/src/layers/core/einsum_dense.py (40 changes: 22 additions & 18 deletions)

@@ -257,24 +257,26 @@ def save_own_variables(self, store):
         # The keys of the `store` will be saved as determined because the
         # default ordering will change after quantization
         kernel_value, kernel_scale = self._get_kernel_with_merged_lora()
-        store["0"] = kernel_value
+        target_variables = [kernel_value]
         if self.bias is not None:
-            store["1"] = self.bias
+            target_variables.append(self.bias)
         if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy):
             mode = self.dtype_policy.quantization_mode
             if mode == "int8":
-                store["2"] = kernel_scale
+                target_variables.append(kernel_scale)
             elif mode == "float8":
-                store["2"] = self.inputs_scale
-                store["3"] = self.inputs_amax_history
-                store["4"] = self.kernel_scale
-                store["5"] = self.kernel_amax_history
-                store["6"] = self.outputs_grad_scale
-                store["7"] = self.outputs_grad_amax_history
+                target_variables.append(self.inputs_scale)
+                target_variables.append(self.inputs_amax_history)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_amax_history)
+                target_variables.append(self.outputs_grad_scale)
+                target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
                     self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
                 )
+        for i, variable in enumerate(target_variables):
+            store[str(i)] = variable

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -284,24 +286,26 @@ def load_own_variables(self, store):
             return
         # The keys of the `store` will be saved as determined because the
         # default ordering will change after quantization
-        self._kernel.assign(store["0"])
+        target_variables = [self._kernel]
         if self.bias is not None:
-            self.bias.assign(store["1"])
+            target_variables.append(self.bias)
         if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy):
             mode = self.dtype_policy.quantization_mode
             if mode == "int8":
-                self.kernel_scale.assign(store["2"])
+                target_variables.append(self.kernel_scale)
             elif mode == "float8":
-                self.inputs_scale.assign(store["2"])
-                self.inputs_amax_history.assign(store["3"])
-                self.kernel_scale.assign(store["4"])
-                self.kernel_amax_history.assign(store["5"])
-                self.outputs_grad_scale.assign(store["6"])
-                self.outputs_grad_amax_history.assign(store["7"])
+                target_variables.append(self.inputs_scale)
+                target_variables.append(self.inputs_amax_history)
+                target_variables.append(self.kernel_scale)
+                target_variables.append(self.kernel_amax_history)
+                target_variables.append(self.outputs_grad_scale)
+                target_variables.append(self.outputs_grad_amax_history)
             else:
                 raise NotImplementedError(
                     self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
                 )
+        for i, variable in enumerate(target_variables):
+            variable.assign(store[str(i)])
         if self.lora_enabled:
             self.lora_kernel_a.assign(ops.zeros(self.lora_kernel_a.shape))
             self.lora_kernel_b.assign(ops.zeros(self.lora_kernel_b.shape))

keras/src/layers/core/embedding.py (12 changes: 8 additions & 4 deletions)

@@ -199,15 +199,17 @@ def save_own_variables(self, store):
         embeddings_value, embeddings_scale = (
             self._get_embeddings_with_merged_lora()
         )
-        store["0"] = embeddings_value
+        target_variables = [embeddings_value]
         if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy):
             mode = self.dtype_policy.quantization_mode
             if mode == "int8":
-                store["1"] = embeddings_scale
+                target_variables.append(embeddings_scale)
             else:
                 raise NotImplementedError(
                     self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
                 )
+        for i, variable in enumerate(target_variables):
+            store[str(i)] = variable

     def load_own_variables(self, store):
         if not self.lora_enabled:
@@ -217,15 +219,17 @@ def load_own_variables(self, store):
             return
         # The keys of the `store` will be saved as determined because the
         # default ordering will change after quantization
-        self._embeddings.assign(store["0"])
+        target_variables = [self._embeddings]
         if isinstance(self.dtype_policy, dtype_policies.QuantizedDTypePolicy):
             mode = self.dtype_policy.quantization_mode
             if mode == "int8":
-                self.embeddings_scale.assign(store["1"])
+                target_variables.append(self.embeddings_scale)
             else:
                 raise NotImplementedError(
                     self.QUANTIZATION_MODE_ERROR_TEMPLATE.format(mode)
                 )
+        for i, variable in enumerate(target_variables):
+            variable.assign(store[str(i)])
         if self.lora_enabled:
             self.lora_embeddings_a.assign(
                 ops.zeros(self.lora_embeddings_a.shape)
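
These methods are not usually called directly; they run as part of the weight-saving machinery when a model is saved and reloaded. A hedged end-to-end sketch of that path (it assumes Model.quantize and the .weights.h5 format behave as in current Keras 3; nothing here is taken from the commit itself):

    import numpy as np
    import keras

    model = keras.Sequential([
        keras.Input(shape=(8,)),
        keras.layers.Dense(4, use_bias=False),
    ])
    model.quantize("int8")

    x = np.random.rand(2, 8).astype("float32")
    y_before = model.predict(x)

    # save_own_variables/load_own_variables run under the hood here; with the
    # contiguous keys from this commit, save and load stay in agreement.
    model.save_weights("quantized.weights.h5")
    model.load_weights("quantized.weights.h5")

    np.testing.assert_allclose(y_before, model.predict(x))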
