Question about Complex datatype support #265
Hi @wcqc, yes, we do not officially support complex data types yet. Hidet does not rely on any PyTorch components to work (e.g., you can install hidet and use its other frontends, such as ONNX). Can I ask what networks you are using? It would be great if you could share a self-contained script that runs a complex-valued network, and I would be happy to add some basic support for that in hidet.
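For reference, here is a minimal sketch of using hidet's ONNX frontend without going through torch.compile (the model path and input shape are placeholders; this assumes the hidet.graph.frontend.from_onnx / hidet.trace_from / hidet.graph.optimize APIs from hidet's documentation):

import hidet

# Load an ONNX model into a hidet module (the path is a placeholder).
model = hidet.graph.frontend.from_onnx('model.onnx')

# Trace the computation with a symbolic input to get a FlowGraph, then optimize it.
data = hidet.symbol([1, 3, 224, 224], dtype='float32', device='cuda')
output = model(data)
graph = hidet.trace_from(output, inputs=[data])
graph = hidet.graph.optimize(graph)

# Run the optimized graph with a concrete input.
x = hidet.randn([1, 3, 224, 224], device='cuda')
y = graph(x)
print(y.shape)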
Hi @yaoyaoding, thank you for the swift reply! More specifically, I am trying to accelerate code based on this library: https://github.com/mit-han-lab/torchquantum, which mainly involves matrix multiplications where the matrices contain complex numbers. An example program can be found here:
Now my question is: does it look like hidet would work with a library such as TorchQuantum and provide acceleration (after fixing the above error)? Or, as it currently stands, would hidet not work with TorchQuantum due to the lack of complex dtype support? Thanks again.
Hi @wcqc, sorry, I still cannot reproduce the error you encountered, as I do not have a background in quantum machine learning/simulation and do not know where to put the torch.compile call in the example. Again, a self-contained example that can be run directly would be helpful. In general, if torch dynamo can extract the computation graph for us, we are happy to add the support. I expect it would not take long to make it functionally work, as long as torch dynamo can extract the graph and no unusual operators show up. But to achieve good performance, we would still need to add some schedule templates for complex-valued numbers, since a complex64 or complex128 number takes 8 or 16 bytes, which is larger than the heavily optimized data types like float32 and float16 (4 bytes and 2 bytes).
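As a quick illustration of the size difference with plain PyTorch (not hidet-specific):

import torch

# Bytes per element: complex128 = 16, complex64 = 8, float32 = 4, float16 = 2.
for dtype in (torch.complex128, torch.complex64, torch.float32, torch.float16):
    print(dtype, torch.empty(1, dtype=dtype).element_size())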
This is a script to reproduce the above error; it can be run with: python script.py, and it still requires the torchquantum library. Is this sufficient?
Thanks for the script! I will have a look when I have time and try to add the missing operators where needed.
Hi @wcqc, I have added the missing operators in #271. But there is a bug that I cannot fix from the hidet side; you can run the script and get the subsequent error message. It is likely a bug in torch dynamo (I have checked the outputs of the hidet backend for the received sub-graph, and everything looks good). I did observe that there are a lot of fusion opportunities in this network (e.g., see the …). You can also try using hidet's ONNX frontend if you can export the inference task as an ONNX model.

import hidet
import torch
import torch.nn.functional as F
import torch.optim as optim
import argparse
import random
import numpy as np
import torchquantum as tq
from torchquantum.plugins import (
tq2qiskit_measurement,
qiskit_assemble_circs,
op_history2qiskit,
op_history2qiskit_expand_params,
)
from torchquantum.datasets import MNIST
from torch.optim.lr_scheduler import CosineAnnealingLR
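# hidet configuration: cache compiled artifacts under ./outs/cache, dump the hidet
# graph IR to ./outs/graphs, and print the input graph received from torch dynamo.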
hidet.option.cache_dir('./outs/cache')
hidet.torch.dynamo_config.dump_graph_ir('./outs/graphs')
hidet.torch.dynamo_config.print_input_graph()
class QFCModel(tq.QuantumModule):
class QLayer(tq.QuantumModule):
def __init__(self):
super().__init__()
self.n_wires = 4
self.random_layer = tq.RandomLayer(
n_ops=50, wires=list(range(self.n_wires))
)
# gates with trainable parameters
self.rx0 = tq.RX(has_params=True, trainable=True)
self.ry0 = tq.RY(has_params=True, trainable=True)
self.rz0 = tq.RZ(has_params=True, trainable=True)
self.crx0 = tq.CRX(has_params=True, trainable=True)
def forward(self, qdev: tq.QuantumDevice):
self.random_layer(qdev)
# some trainable gates (instantiated ahead of time)
self.rx0(qdev, wires=0)
self.ry0(qdev, wires=1)
self.rz0(qdev, wires=3)
self.crx0(qdev, wires=[0, 2])
# add some more non-parameterized gates (add on-the-fly)
qdev.h(wires=3) # type: ignore
qdev.sx(wires=2) # type: ignore
qdev.cnot(wires=[3, 0]) # type: ignore
qdev.rx(
wires=1,
params=torch.tensor([0.1]),
static=self.static_mode,
parent_graph=self.graph,
) # type: ignore
def __init__(self):
super().__init__()
self.n_wires = 4
self.encoder = tq.GeneralEncoder(tq.encoder_op_list_name_dict["4x4_u3rx"])
self.q_layer = self.QLayer()
self.measure = tq.MeasureAll(tq.PauliZ)
def forward(self, x, use_qiskit=False):
qdev = tq.QuantumDevice(
n_wires=self.n_wires, bsz=x.shape[0], device=x.device, record_op=True
)
bsz = x.shape[0]
x = F.avg_pool2d(x, 6).view(bsz, 16)
devi = x.device
if use_qiskit:
# use qiskit to process the circuit
# create the qiskit circuit for encoder
self.encoder(qdev, x)
op_history_parameterized = qdev.op_history
qdev.reset_op_history()
encoder_circs = op_history2qiskit_expand_params(self.n_wires, op_history_parameterized, bsz=bsz)
# create the qiskit circuit for trainable quantum layers
self.q_layer(qdev)
op_history_fixed = qdev.op_history
qdev.reset_op_history()
q_layer_circ = op_history2qiskit(self.n_wires, op_history_fixed)
# create the qiskit circuit for measurement
measurement_circ = tq2qiskit_measurement(qdev, self.measure)
# assemble the encoder, trainable quantum layers, and measurement circuits
assembled_circs = qiskit_assemble_circs(
encoder_circs, q_layer_circ, measurement_circ
)
# call the qiskit processor to process the circuit
x0 = self.qiskit_processor.process_ready_circs(qdev, assembled_circs).to( # type: ignore
devi
)
x = x0
else:
# use torchquantum to process the circuit
self.encoder(qdev, x)
qdev.reset_op_history()
self.q_layer(qdev)
x = self.measure(qdev)
x = x.reshape(bsz, 2, 2).sum(-1).squeeze()
x = F.log_softmax(x, dim=1)
return x
def train(dataflow, model, device, optimizer):
for feed_dict in dataflow["train"]:
inputs = feed_dict["image"].to(device)
targets = feed_dict["digit"].to(device)
outputs = model(inputs)
loss = F.nll_loss(outputs, targets)
optimizer.zero_grad()
loss.backward()
optimizer.step()
print(f"loss: {loss.item()}", end="\r")
def valid_test(dataflow, split, model, device, qiskit=False):
target_all = []
output_all = []
with torch.no_grad():
for feed_dict in dataflow[split]:
inputs = feed_dict["image"].to(device)
targets = feed_dict["digit"].to(device)
outputs = model(inputs, use_qiskit=qiskit)
target_all.append(targets)
output_all.append(outputs)
target_all = torch.cat(target_all, dim=0)
output_all = torch.cat(output_all, dim=0)
_, indices = output_all.topk(1, dim=1)
masks = indices.eq(target_all.view(-1, 1).expand_as(indices))
size = target_all.shape[0]
corrects = masks.sum().item()
accuracy = corrects / size
loss = F.nll_loss(output_all, target_all).item()
print(f"{split} set accuracy: {accuracy}")
print(f"{split} set loss: {loss}")
def main():
parser = argparse.ArgumentParser()
parser.add_argument(
"--static", action="store_true", help="compute with " "static mode"
)
parser.add_argument("--pdb", action="store_true", help="debug with pdb")
parser.add_argument(
"--wires-per-block", type=int, default=2, help="wires per block int static mode"
)
parser.add_argument(
"--epochs", type=int, default=2, help="number of training epochs"
)
args = parser.parse_args()
if args.pdb:
import pdb
pdb.set_trace()
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
dataset = MNIST(
root="./mnist_data",
train_valid_split_ratio=[0.9, 0.1],
digits_of_interest=[3, 6],
n_test_samples=75,
)
dataflow = dict()
for split in dataset:
sampler = torch.utils.data.RandomSampler(dataset[split])
dataflow[split] = torch.utils.data.DataLoader(
dataset[split],
batch_size=256,
sampler=sampler,
num_workers=8,
pin_memory=True,
)
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
_model = QFCModel().to(device)
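    # compile the model with the hidet backend through torch dynamo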
model = torch.compile(_model, backend='hidet')
n_epochs = args.epochs
optimizer = optim.Adam(model.parameters(), lr=5e-3, weight_decay=1e-4)
scheduler = CosineAnnealingLR(optimizer, T_max=n_epochs)
if args.static:
        # optionally switch to the static mode, which can speed up training
model.q_layer.static_on(wires_per_block=args.wires_per_block)
for epoch in range(1, n_epochs + 1):
# train
print(f"Epoch {epoch}:")
train(dataflow, model, device, optimizer)
print(optimizer.param_groups[0]["lr"])
# valid
valid_test(dataflow, "valid", model, device)
scheduler.step()
# test
valid_test(dataflow, "test", model, device, qiskit=False)
# run on Qiskit simulator and real Quantum Computers
try:
from qiskit import IBMQ
from torchquantum.plugins import QiskitProcessor
        # first, run on the Qiskit simulator
print(f"\nTest with Qiskit Simulator")
processor_simulation = QiskitProcessor(use_real_qc=False)
model.set_qiskit_processor(processor_simulation)
valid_test(dataflow, "test", model, device, qiskit=True)
# then try to run on REAL QC
backend_name = "ibmq_lima"
print(f"\nTest on Real Quantum Computer {backend_name}")
# Please specify your own hub group and project if you have the
# IBMQ premium plan to access more machines.
processor_real_qc = QiskitProcessor(
use_real_qc=True,
backend_name=backend_name,
hub="ibm-q",
group="open",
project="main",
)
model.set_qiskit_processor(processor_real_qc)
valid_test(dataflow, "test", model, device, qiskit=True)
except ImportError:
print(
"Please install qiskit, create an IBM Q Experience Account and "
"save the account token according to the instruction at "
"'https://github.com/Qiskit/qiskit-ibmq-provider', "
"then try again."
)
if __name__ == "__main__":
main()
Hi @yaoyaoding, thanks for adding the operators. I will further debug the new errors. When you say there are lots of fusion opportunities, do you mean that even without complete complex dtype support, once the new errors/bugs are fixed we should see a noticeable speedup with hidet on the above code? (Of course this is empirical; I am just asking whether this would be the case in principle.) Thanks again.
Hi @wcqc, that depends. I am not familiar with the operators commonly used in quantum networks or which ones are the bottleneck. Usually, fusion can greatly reduce memory accesses and speed up your network. But if the bottleneck is an operator like a large matrix multiplication, that will diminish the speedup from fusing the other small operators (say those small operators take 20% of the time; you can get at most a 1/0.8 speedup even if you optimize that 20% down to zero). Thanks for trying hidet, and let me know if you find out how to fix/avoid the above errors (e.g., by writing the PyTorch program in another way).
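For intuition, the bound mentioned above is just Amdahl's law; a small worked example (the 20% figure is purely illustrative):

# Amdahl's law: if a fraction f of the runtime is sped up by a factor s,
# the overall speedup is 1 / ((1 - f) + f / s).
def overall_speedup(f, s):
    return 1.0 / ((1.0 - f) + f / s)

# Eliminating the 20% spent in small fusible operators entirely gives at most 1 / 0.8 = 1.25x.
print(overall_speedup(0.20, 1e9))  # ~1.25
print(overall_speedup(0.20, 2.0))  # ~1.11 if those operators only get 2x faster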
The inductor backend of PyTorch 2.0 does not officially support complex data types yet (pytorch/pytorch#93424); I am just wondering whether hidet currently has the same limitation or not.
If it does, does it rely on other parts of PyTorch 2.0 (e.g., inductor, dynamo, etc.) to fully support complex, or can complex support be added separately?
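A minimal probe along these lines (a hypothetical sketch; whether it compiles, falls back, or errors depends on the backend's complex support at the time):

import torch

def complex_matmul(a, b):
    return torch.matmul(a, b)

a = torch.randn(64, 64, dtype=torch.complex64)
b = torch.randn(64, 64, dtype=torch.complex64)

# Swap backend='hidet' for backend='inductor' to compare the two backends.
compiled = torch.compile(complex_matmul, backend='hidet')
print(compiled(a, b).shape)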