Update quantization to force gpu usage for blockwise8 (NVIDIA#3256)
Fixes # .

### Description

Account for a QA finding and a [bug in bitsandbytes](bitsandbytes-foundation/bitsandbytes#1540) by forcing GPU usage for blockwise8 quantization and dequantization.
Also add information about the supported precisions.
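
For reference, here is a minimal sketch of the blockwise8 round trip that the filters now run on the GPU. This is an illustration only: it assumes `bitsandbytes` is installed and a CUDA device is available, and the tensor is arbitrary.

```python
import torch
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

# Move the tensor to the GPU before quantizing; CPU support for 8-bit
# blockwise quantization in bitsandbytes is limited, hence the forced .cuda().
weights = torch.randn(1024, dtype=torch.float32).cuda()

# quantize_blockwise returns the 8-bit tensor plus a state holding absmax/code.
quantized, state = quantize_blockwise(weights)

# Dequantize on the GPU using the saved state and check the round-trip error.
restored = dequantize_blockwise(quantized, absmax=state.absmax, code=state.code)
print("max abs error:", (weights - restored).abs().max().item())
```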

### Types of changes
<!--- Put an `x` in all the boxes that apply, and remove the not applicable items -->
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] Quick tests passed locally by running `./runtest.sh`.
- [ ] In-line docstrings updated.
- [ ] Documentation updated.
ZiyueXu77 authored Feb 27, 2025
1 parent 36a40a2 commit 3e0468a
Showing 7 changed files with 71 additions and 69 deletions.
16 changes: 8 additions & 8 deletions examples/advanced/llm_hf/sft_job.py
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner


@@ -67,10 +67,10 @@ def main():

     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])

         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)

     # Export the job
     print("job_dir=", job_dir)
16 changes: 8 additions & 8 deletions
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner


@@ -67,10 +67,10 @@ def main():

     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])

         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)

     # Export the job
     print("job_dir=", job_dir)
16 changes: 8 additions & 8 deletions
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner


@@ -67,10 +67,10 @@ def main():

     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])

         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)

     # Export the job
     print("job_dir=", job_dir)
16 changes: 8 additions & 8 deletions
@@ -19,8 +19,8 @@
 from nvflare.app_common.widgets.intime_model_selector import IntimeModelSelector
 from nvflare.app_common.workflows.fedavg import FedAvg
 from nvflare.app_opt.pt.file_model_persistor import PTFileModelPersistor
-from nvflare.app_opt.pt.quantization.dequantizor import ModelDequantizor
-from nvflare.app_opt.pt.quantization.quantizor import ModelQuantizor
+from nvflare.app_opt.pt.quantization.dequantizer import ModelDequantizer
+from nvflare.app_opt.pt.quantization.quantizer import ModelQuantizer
 from nvflare.job_config.script_runner import ScriptRunner


@@ -67,10 +67,10 @@ def main():

     if args.quantize_mode:
         # If using quantization, add quantize filters.
-        quantizor = ModelQuantizor(quantization_type=args.quantize_mode)
-        dequantizor = ModelDequantizor()
-        job.to(quantizor, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
-        job.to(dequantizor, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)
+        quantizer = ModelQuantizer(quantization_type=args.quantize_mode)
+        dequantizer = ModelDequantizer()
+        job.to(quantizer, "server", tasks=["train"], filter_type=FilterType.TASK_DATA)
+        job.to(dequantizer, "server", tasks=["train"], filter_type=FilterType.TASK_RESULT)

     # Define the model persistor and send to server
     # First send the model to the server
@@ -106,8 +106,8 @@ def main():
         job.to(runner, site_name, tasks=["train"])

         if args.quantize_mode:
-            job.to(quantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
-            job.to(dequantizor, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)
+            job.to(quantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_RESULT)
+            job.to(dequantizer, site_name, tasks=["train"], filter_type=FilterType.TASK_DATA)

     # Export the job
     print("job_dir=", job_dir)
11 changes: 8 additions & 3 deletions nvflare/app_opt/pt/quantization/constant.py
@@ -12,15 +12,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+# Supported Input Data Type
+# Message quantization is mainly for reducing the message that can be
+# significantly large, e.g. LLMs. Thus, the supported input data types
+# we consider are common ones during LLM training, including fp32, fp16, and bf16.
 DATA_TYPE = [
-    "FLOAT64",
     "FLOAT32",
     "FLOAT16",
     "BFLOAT16",
-    "UINT8",
-    "INT8",
 ]

+# Supported Quantization Type to reduce the above input data types
+# The quantization types are mainly for reducing the model size,
+# Hence, we support 16-, 8-, and 4-bits quantization.
+# Note that 8- and 4-bits quantization needs GPU support.
 QUANTIZATION_TYPE = [
     "FLOAT16",
     "BLOCKWISE8",
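
As the new comments note, the 8- and 4-bit modes require GPU support. A hypothetical helper (not part of NVFlare) that validates a requested mode against `QUANTIZATION_TYPE` and checks for a CUDA device might look like the sketch below; the contents of `GPU_ONLY_MODES` are an assumption based on the comment above.

```python
import torch

from nvflare.app_opt.pt.quantization.constant import QUANTIZATION_TYPE

# Assumption: these are the modes that need a GPU, per the note added in this commit.
GPU_ONLY_MODES = {"blockwise8", "float4", "normfloat4"}


def check_quantize_mode(mode: str) -> str:
    """Hypothetical helper: validate a requested quantization mode before building the job."""
    if mode.upper() not in QUANTIZATION_TYPE:
        raise ValueError(f"unsupported quantization type: {mode}")
    if mode.lower() in GPU_ONLY_MODES and not torch.cuda.is_available():
        raise RuntimeError(f"{mode} quantization needs a CUDA-capable GPU")
    return mode.lower()


if __name__ == "__main__":
    print(check_quantize_mode("float16"))      # fine on CPU-only machines
    print(check_quantize_mode("blockwise8"))   # raises unless a CUDA device is present
```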
nvflare/app_opt/pt/quantization/dequantizor.py → nvflare/app_opt/pt/quantization/dequantizer.py
@@ -26,7 +26,7 @@
 from nvflare.app_opt.pt.quantization.constant import QUANTIZATION_TYPE


-class ModelDequantizor(DXOFilter):
+class ModelDequantizer(DXOFilter):
     def __init__(self):
         """Filter to dequantize Shareable object to recover from quantization
@@ -84,17 +84,18 @@ def dequantization(
                 params[param_name] = values
             elif quantization_type in ["blockwise8", "float4", "normfloat4"]:
                 # use bitsandbytes to dequantize the values
+                # need GPU for general support
                 # extract quantization state
                 if quantization_type == "blockwise8":
                     if source_data_format == "numpy":
                         # first convert numpy array to tensor if numpy
-                        quantized = torch.as_tensor(values)
-                        absmax = torch.as_tensor(quant_state[param_name]["absmax"])
-                        code = torch.as_tensor(quant_state[param_name]["code"])
+                        quantized = torch.as_tensor(values).cuda()
+                        absmax = torch.as_tensor(quant_state[param_name]["absmax"]).cuda()
+                        code = torch.as_tensor(quant_state[param_name]["code"]).cuda()
                     elif source_data_format == "torch":
-                        quantized = values
-                        absmax = quant_state[param_name]["absmax"]
-                        code = quant_state[param_name]["code"]
+                        quantized = values.cuda()
+                        absmax = quant_state[param_name]["absmax"].cuda()
+                        code = quant_state[param_name]["code"].cuda()
                     # de-quanitze
                     dequantized = dequantize_blockwise(quantized, absmax=absmax, code=code)
                 else:
@@ -125,6 +126,7 @@ def dequantization(
                         dequantized = dequantize_4bit(quantized, quantize_state, quant_type="fp4")
                     else:
                         dequantized = dequantize_4bit(quantized, quantize_state, quant_type="nf4")
+
                 if source_data_format == "numpy":
                     params[param_name] = dequantized.cpu().numpy()
                 elif source_data_format == "torch":
@@ -135,16 +137,12 @@
                 # convert back to original data type
                 if source_data_type == "float32":
                     params[param_name] = params[param_name].astype(np.float32)
-                elif source_data_type == "float64":
-                    params[param_name] = params[param_name].astype(np.float64)
                 elif source_data_type == "float16":
                     params[param_name] = params[param_name].astype(np.float16)
             elif source_data_format == "torch":
                 # convert back to original data type
                 if source_data_type == "float32":
                     params[param_name] = params[param_name].float()
-                elif source_data_type == "float64":
-                    params[param_name] = params[param_name].double()
                 elif source_data_type == "float16":
                     params[param_name] = params[param_name].half()
                 elif source_data_type == "bfloat16":
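
Below is a self-contained sketch of the numpy path shown above: the receiving side rebuilds CUDA tensors from the serialized arrays, dequantizes on the GPU, and casts back to the source dtype. It is illustrative only; the `message` dict and variable names are not the filter's actual data structures, and a CUDA device plus `bitsandbytes` are assumed.

```python
import numpy as np
import torch
from bitsandbytes.functional import dequantize_blockwise, quantize_blockwise

# Sender side (included so the example runs end to end): quantize on the GPU
# and serialize the payload as numpy arrays, as the quantizer filter does.
original = torch.randn(256, dtype=torch.float32).cuda()
quantized, state = quantize_blockwise(original)
message = {
    "values": quantized.cpu().numpy(),
    "absmax": state.absmax.cpu().numpy(),
    "code": state.code.cpu().numpy(),
}

# Receiver side: rebuild CUDA tensors, dequantize on the GPU, then return to
# numpy and the original float32 dtype, mirroring the lines changed above.
q = torch.as_tensor(message["values"]).cuda()
absmax = torch.as_tensor(message["absmax"]).cuda()
code = torch.as_tensor(message["code"]).cuda()
restored = dequantize_blockwise(q, absmax=absmax, code=code)
result = restored.cpu().numpy().astype(np.float32)
print("max abs error:", np.abs(result - original.cpu().numpy()).max())
```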
nvflare/app_opt/pt/quantization/quantizor.py → nvflare/app_opt/pt/quantization/quantizer.py
@@ -26,7 +26,7 @@
 from nvflare.app_opt.pt.quantization.constant import DATA_TYPE, QUANTIZATION_TYPE


-class ModelQuantizor(DXOFilter):
+class ModelQuantizer(DXOFilter):
     def __init__(
         self,
         quantization_type="float16",
@@ -120,41 +120,39 @@ def quantization(self, params: dict, fl_ctx: FLContext):
             elif self.quantization_type in ["blockwise8", "float4", "normfloat4"]:
                 # use bitsandbytes to quantize the values
                 # input is a tensor, output is a tuple of (quantized tensor, quantized_state)
-                if self.quantization_type == "blockwise8":
-                    if source_data_format == "numpy":
-                        # if numpy, first convert numpy array to tensor
-                        values_tensor = torch.as_tensor(values)
-                    elif source_data_format == "torch":
-                        values_tensor = values

-                    # then quantize the tensor
+                # CPU has limited support for 8- and 4-bits quantization
+                # For general purpose, here we use GPU
+                if source_data_format == "numpy":
+                    # if numpy, first convert numpy array to tensor, need to use GPU
+                    values_tensor = torch.as_tensor(values).cuda()
+                elif source_data_format == "torch":
+                    # if torch, directly use the tensor, need to use GPU
+                    values_tensor = values.cuda()
+
+                if self.quantization_type == "blockwise8":
+                    # quantize the tensor
                     quantized, quantized_state = quantize_blockwise(values_tensor)
                     # add the quantization state and values, keep source data format
                     if source_data_format == "numpy":
-                        quant_state[param_name]["absmax"] = quantized_state.absmax.numpy()
-                        quant_state[param_name]["code"] = quantized_state.code.numpy()
-                        values = quantized.numpy()
+                        quant_state[param_name]["absmax"] = quantized_state.absmax.cpu().numpy()
+                        quant_state[param_name]["code"] = quantized_state.code.cpu().numpy()
+                        values = quantized.cpu().numpy()
                     elif source_data_format == "torch":
-                        quant_state[param_name]["absmax"] = quantized_state.absmax
-                        quant_state[param_name]["code"] = quantized_state.code
-                        values = quantized
+                        quant_state[param_name]["absmax"] = quantized_state.absmax.cpu()
+                        quant_state[param_name]["code"] = quantized_state.code.cpu()
+                        values = quantized.cpu()
                     n_bytes_meta += quant_state[param_name]["absmax"].nbytes
                     n_bytes_meta += quant_state[param_name]["code"].nbytes
                 else:
-                    if source_data_format == "numpy":
-                        # if numpy, first convert numpy array to tensor, need to use GPU
-                        values_tensor = torch.as_tensor(values).cuda()
-                    elif source_data_format == "torch":
-                        # if torch, directly use the tensor, need to use GPU
-                        values_tensor = values.cuda()
-                    # then quantize the tensor
                     if self.quantization_type == "float4":
                         quantized, quantized_state = quantize_4bit(values_tensor, quant_type="fp4")
                     else:
                         quantized, quantized_state = quantize_4bit(values_tensor, quant_type="nf4")
                     # add the quantization state and values, keep source data format
                     quantized_state = quantized_state.as_dict()

+                    # prepared the message
                     for state_name, state in quantized_state.items():
                         if isinstance(state, torch.Tensor):
                             if source_data_format == "numpy":
@@ -171,6 +169,7 @@ def quantization(self, params: dict, fl_ctx: FLContext):
                         values = quantized.cpu().numpy()
                     elif source_data_format == "torch":
                         values = quantized.cpu()
+
             params[param_name] = values
             n_bytes_after += params[param_name].nbytes

@@ -203,8 +202,8 @@ def process_dxo(self, dxo: DXO, shareable: Shareable, fl_ctx: FLContext) -> Unio
         # thus the subsequent communications to the rest of clients will no longer need to apply quantization
         # This will not apply to client job, since the client job will be 1-1 and quantization applies to each client
         # Potentially:
-        # If clients talks to each other, it will also be 1-N and same rule applies
-        # If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
+        # - If clients talks to each other, it will also be 1-N and same rule applies
+        # - If 1-N server-client filters can be different (Filter_1 applies to server-client_subset_1, etc.), then
         # a deep copy of the server data should be made by filter before applying a different filter

         # quantized_flag None if does not exist in meta
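
The pattern added above (quantize on the GPU, then move the quantized values and the absmax/code state back to the CPU before they enter the outgoing message) can be condensed into a small standalone sketch. The helper name and return layout are assumptions for illustration, not the filter's API; a CUDA device and `bitsandbytes` are assumed.

```python
import torch
from bitsandbytes.functional import quantize_blockwise


def quantize_param_for_message(values: torch.Tensor, as_numpy: bool) -> dict:
    """Hypothetical helper: quantize on the GPU, ship the result on the CPU."""
    values_tensor = values.cuda()  # blockwise8 needs a CUDA device
    quantized, state = quantize_blockwise(values_tensor)
    if as_numpy:
        return {
            "values": quantized.cpu().numpy(),
            "absmax": state.absmax.cpu().numpy(),
            "code": state.code.cpu().numpy(),
        }
    return {
        "values": quantized.cpu(),
        "absmax": state.absmax.cpu(),
        "code": state.code.cpu(),
    }


if __name__ == "__main__":
    packed = quantize_param_for_message(torch.randn(512), as_numpy=True)
    # The absmax/code arrays are the per-parameter metadata whose size the
    # filter tracks via n_bytes_meta.
    print({name: array.nbytes for name, array in packed.items()})
```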
