Is there a way for using gpu acceleration in the finetune gpt2 with LoRA use case example? #918

Open
filippo-merlo opened this issue Oct 14, 2024 · 1 comment

Comments

@filippo-merlo

When I try to use compile_model with CUDA as the specified device, I encounter the following error. Is there a way to resolve this, or is the lora.py code not yet compatible with running on a GPU?

The tutorial I am following: https://github.com/zama-ai/concrete-ml/tree/release/1.7.x/use_case_examples/lora_finetuning

Traceback (most recent call last):
File "/lora_finetuning/lorafinetunegpt2.py", line 135, in <module>
hybrid_model.compile_model(inputset, n_bits=16, device="cuda")
File "/src/concrete/ml/torch/hybrid_model.py", line 516, in compile_model
self.private_q_modules[name] = compile_torch_model(
^^^^^^^^^^^^^^^^^^^^
File "/src/concrete/ml/torch/compile.py", line 342, in compile_torch_model
return _compile_torch_or_onnx_model(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/src/concrete/ml/torch/compile.py", line 224, in _compile_torch_or_onnx_model
quantized_module = build_quantized_module(
^^^^^^^^^^^^^^^^^^^^^^^
File "/src/concrete/ml/torch/compile.py", line 124, in build_quantized_module
numpy_model = NumpyModule(model, dummy_input_for_tracing)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/src/concrete/ml/torch/numpy_module.py", line 51, in __init__
) = get_equivalent_numpy_forward_from_torch(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/src/concrete/ml/onnx/convert.py", line 153, in get_equivalent_numpy_forward_from_torch
torch.onnx.export(
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 516, in export
_export(
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1612, in _export
graph, params_dict, torch_out = _model_to_graph(
^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1134, in _model_to_graph
graph, params, torch_out, module = _create_jit_graph(model, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 1010, in _create_jit_graph
graph, torch_out = _trace_and_get_graph_from_model(model, args)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/onnx/utils.py", line 914, in _trace_and_get_graph_from_model
trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 1310, in _get_trace_graph
outs = ONNXTracedModule(
^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 138, in forward
graph, out = torch._C._create_graph_by_tracing(
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/jit/_trace.py", line 129, in wrapper
outs.append(self.inner(*trace_inputs))
^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
return forward_call(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1522, in _slow_forward
result = self.forward(*input, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/miniconda3/envs/concrete-ml/lib/python3.11/site-packages/torch/nn/modules/linear.py", line 116, in forward
return F.linear(input, self.weight, self.bias)
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu! (when checking argument for argument mat1 in method wrapper_CUDA_addmm)

This is the only part of the code I modified:

# Create the HybridFHEModel with the specified remote modules
hybrid_model = HybridFHEModel(lora_training, module_names=remote_names)

# Prepare input data for calibration
input_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)
label_tensor = torch.randint(0, 2, (PER_DEVICE_TRAIN_BATCH_SIZE, BLOCK_SIZE)) * (
    tokenizer.vocab_size - 1
)

input_tensor = input_tensor.to("cuda")
label_tensor = label_tensor.to("cuda")

inputset = (input_tensor, label_tensor)

# Calibrate and compile the model
hybrid_model.model.toggle_calibrate(enable=True)
hybrid_model.compile_model(inputset, n_bits=16, device="cuda")

@jfrery
Collaborator

jfrery commented Oct 15, 2024

Hi @filippo-merlo,

The error comes from mixing up two separate GPU modes:

  1. Standard torch GPU acceleration (triggered by tensor.to("cuda"))
  2. FHE GPU acceleration (activated by compile(..., device="cuda"))

For now, torch GPU acceleration is not fully supported in the LoRA example, because the bottleneck is FHE rather than cleartext CPU computation. That said, there may be cases where it is useful, and it should be fairly easy to support if you are interested in opening a PR.
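The RuntimeError in the traceback is raised by torch itself when a layer's weights and its input live on different devices. As a minimal sketch in plain PyTorch (no Concrete ML calls; the helper name `devices_in_use` is hypothetical and the single `nn.Linear` is only a stand-in for the real model), one way to check this before tracing or compiling might be:

```python
import torch
from torch import nn


def devices_in_use(module: nn.Module, *tensors: torch.Tensor) -> set[str]:
    """Collect the device names of a module's parameters/buffers and of the given tensors.

    Hypothetical helper for illustration; it is not part of Concrete ML or PyTorch.
    """
    devices = {str(p.device) for p in module.parameters()}
    devices |= {str(b.device) for b in module.buffers()}
    devices |= {str(t.device) for t in tensors}
    return devices


# Stand-in for one linear layer of the model; torch creates it on the CPU by default.
layer = nn.Linear(4, 2)

# Calibration-style input, left on the CPU (no .to("cuda") call).
inputs = torch.randint(0, 2, (3, 4)).float()

found = devices_in_use(layer, inputs)
if len(found) > 1:
    # This is the situation F.linear complains about in the traceback.
    raise RuntimeError(f"mixed devices: {sorted(found)}")

output = layer(inputs)
print(sorted(found), tuple(output.shape))
```

With the snippet from the question, the inputs are moved with `.to("cuda")` while the model weights stay on the CPU, so a check like this would report both `cpu` and `cuda:0`. Leaving the calibration tensors on the CPU avoids that mismatch.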

As for FHE GPU acceleration: the LoRA fine-tuning use case is about learning private weights while the base model parameters stay on a third-party server. To do this, only the linear layers need to run remotely in FHE, and those parts are not yet accelerated on GPU. For now, GPU acceleration is more useful for end-to-end FHE computations that include non-linear parts.
