@inproceedings{jiaqigu2021L2ight,
title = {L2ight: Enabling On-Chip Learning for Optical Neural Networks via Efficient in-situ Subspace Optimization},
author = {Jiaqi Gu and Hanqing Zhu and Chenghao Feng and Zixuan Jiang and Ray T. Chen and David Z. Pan},
booktitle = {Conference on Neural Information Processing Systems (NeurIPS)},
year = {2021}
}
Integrated neuromorphic photonics simulation framework based on PyTorch. It supports training and inference of coherent and incoherent optical neural networks (ONNs) on GPUs and, thanks to its efficient implementation, scales up to million-parameter ONNs.
Primary audience: researchers working on neuromorphic photonics, optical AI system design, photonic integrated circuit optimization, and ONN training/inference.
Key features: CUDA-backed fast GPU support, optimized highly parallel tensorized processing, and versatile APIs for device/circuit/architecture/algorithm co-optimization.
- 09/17/2023: v0.0.6 available. Support add-drop MRR weight banks and initialization of ONN layers from standard PyTorch Conv2d/Linear layers!
- 04/19/2022: v0.0.5 available. Automatic differentiable photonic tensor core search! Support customized coherent photonic SuperMesh construction from basic building blocks! (Gu+, ADEPT DAC 2022)
- 04/18/2022: v0.0.4 available. Phase change material (PCM)-based photonic in-memory computing with endurance enhancement! (Zhu+, ELight ASP-DAC 2022)
- 04/18/2022: v0.0.3 available. SqueezeLight architecture based on multi-operand microrings for ultra-compact optical neurocomputing! (Gu+, SqueezeLight DATE 2021)
- 11/28/2021: v0.0.2 available. The FFT-ONN family is supported with trainable butterfly meshes for area-efficient frequency-domain optical neurocomputing! (Gu+, FFT-ONN ASP-DAC 2020) (Gu+, FFT-ONN-v2 IEEE TCAD 2021) (Feng+, PSNN arXiv 2021)
- 06/10/2021: v0.0.1 available. MZI-ONN (Shen+, MZI-ONN) is supported. Feedback is highly welcome!
- Python >= 3.6
- PyTorch >= 1.13.0
- Tensorflow-gpu >= 2.5.0
- pyutils >= 0.0.2
- Others are listed in requirements.txt
- GPU model training requires NVIDIA GPUs and a compatible CUDA toolkit
git clone https://github.com/JeremieMelo/pytorch-onn.git
cd pytorch-onn
python3 setup.py install --user clean
or
./setup.sh
Constructing an optical NN model is as simple as constructing a normal PyTorch model.
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchonn as onn
from torchonn.models import ONNBaseModel

class ONNModel(ONNBaseModel):
    def __init__(self, device=torch.device("cuda:0")):
        super().__init__(device=device)
        # MZI-based 3x3 convolution, 1 -> 8 channels
        self.conv = onn.layers.MZIBlockConv2d(
            in_channels=1,
            out_channels=8,
            kernel_size=3,
            stride=1,
            padding=1,
            dilation=1,
            bias=True,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.pool = nn.AdaptiveAvgPool2d(5)
        # MZI-based head: 8 channels x 5 x 5 pooled features -> 10 outputs
        self.linear = onn.layers.MZIBlockLinear(
            in_features=8 * 5 * 5,
            out_features=10,
            bias=True,
            miniblock=4,
            mode="usv",
            decompose_alg="clements",
            photodetect=True,
            device=device,
        )
        self.conv.reset_parameters()
        self.linear.reset_parameters()

    def forward(self, x):
        x = torch.relu(self.conv(x))
        x = self.pool(x)
        x = x.flatten(1)
        x = self.linear(x)
        return x
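The `miniblock=4` argument suggests that each layer's weight is realized as a grid of small 4x4 blocks. The sketch below illustrates that blocking idea in plain Python for the linear layer above (200 inputs, 10 outputs). It assumes zero padding of the ragged edges; the helpers `to_blocks` and `block_matvec` are hypothetical and not part of torchonn, and optical effects such as `photodetect=True` are ignored.

```python
import math
import random

def to_blocks(weight, k):
    """Zero-pad a dense out x in weight matrix and split it into k x k tiles.

    blocks[p][q] covers output rows p*k..p*k+k-1 and input columns q*k..q*k+k-1
    (hypothetical layout, for illustration only).
    """
    out_f, in_f = len(weight), len(weight[0])
    grid_y, grid_x = math.ceil(out_f / k), math.ceil(in_f / k)
    blocks = []
    for p in range(grid_y):
        row = []
        for q in range(grid_x):
            tile = [[0.0] * k for _ in range(k)]
            for i in range(k):
                for j in range(k):
                    r, c = p * k + i, q * k + j
                    if r < out_f and c < in_f:
                        tile[i][j] = weight[r][c]
            row.append(tile)
        blocks.append(row)
    return blocks

def block_matvec(blocks, x, k, out_f):
    """Compute y = W x tile by tile; each k x k tile is one small block."""
    grid_x = len(blocks[0])
    x_pad = x + [0.0] * (grid_x * k - len(x))
    y = []
    for row in blocks:
        acc = [0.0] * k  # partial sums for this block row
        for q, tile in enumerate(row):
            chunk = x_pad[q * k:(q + 1) * k]
            for i in range(k):
                acc[i] += sum(tile[i][j] * chunk[j] for j in range(k))
        y.extend(acc)
    return y[:out_f]  # drop rows that exist only because of padding

random.seed(0)
out_f, in_f, k = 10, 8 * 5 * 5, 4
W = [[random.uniform(-1, 1) for _ in range(in_f)] for _ in range(out_f)]
x = [random.uniform(-1, 1) for _ in range(in_f)]
blocks = to_blocks(W, k)
y_block = block_matvec(blocks, x, k, out_f)
y_dense = [sum(W[r][c] * x[c] for c in range(in_f)) for r in range(out_f)]
print(len(blocks), len(blocks[0]))  # 3 50
print(max(abs(a - b) for a, b in zip(y_block, y_dense)) < 1e-9)  # True
```

With a 10x200 weight and 4x4 tiles, the grid is 3 block rows by 50 block columns, and the blocked product matches the dense one up to floating-point rounding.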
- Support PyTorch training of MZI-based ONNs, including MZI-based Linear, Conv2d, BlockLinear, and BlockConv2d layers. Support `weight`, `usv`, and `phase` modes and conversion between them.
- Support phase quantization and non-ideality injection, including phase shifter gamma error, phase variations, and crosstalk.
- CUDA-accelerated batched MZI array decomposition and reconstruction for ultra-fast real/complex matrix mapping, which achieves a 10-50X speedup over CPU-based unitary group parametrization. Francis (Triangle), Reck (Triangle), and Clements (Rectangle)-style MZI meshes are supported. To see the efficiency of our CUDA implementation, run `python3 unitest/test_op.py` from the root directory and check the runtime comparison.
- Support PyTorch training of general frequency-domain ONNs (Gu+, FFT-ONN ASP-DAC 2020) (Gu+, FFT-ONN-v2 IEEE TCAD 2021) (Feng+, PSNN). Support FFT-ONN BlockLinear and BlockConv2d. Support `fft`, `hadamard`, `zero_bias`, and `trainable` modes.
- Support multi-operand ring-based ONN architecture (Gu+, SqueezeLight DATE 2021). Support AllpassMORRCirculantLinear and AllpassMORRCirculantConv2d with built-in MORR nonlinearity.
- Support phase change material (PCM)-based ONN architecture (Zhu+, ELight ASP-DAC 2022). Support PCMLinear and PCMConv2d with logarithmic PCM wire quantization and PCM array assignment.
- Support micro-ring resonator (MRR)-based ONN. (Tait+, SciRep 2017)
- Support ONN on-chip learning via zeroth-order optimization. (Gu+, FLOPS DAC 2020) (Gu+, MixedTrain AAAI 2021)
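The conversion from `usv` to `phase` mode rests on factoring a unitary (here, real orthogonal) matrix into 2x2 building blocks. The sketch below is plain Python, not torchonn's CUDA kernels: it factors a small orthogonal matrix into N(N-1)/2 nearest-neighbour Givens rotations plus a diagonal of +/-1, loosely in the spirit of a triangular (Reck-style) mesh where each rotation angle stands in for one MZI's phase setting. Real-valued matrices, this elimination order, and all helper names are assumptions for illustration; the library's Reck/Clements routines handle complex unitaries with their own orderings.

```python
import math
import random

def random_orthogonal(n, seed=0):
    """Random real orthogonal matrix via Gram-Schmidt (rows are orthonormal)."""
    rng = random.Random(seed)
    basis = []
    while len(basis) < n:
        v = [rng.gauss(0, 1) for _ in range(n)]
        for b in basis:
            d = sum(vi * bi for vi, bi in zip(v, b))
            v = [vi - d * bi for vi, bi in zip(v, b)]
        norm = math.sqrt(sum(vi * vi for vi in v))
        if norm > 1e-8:
            basis.append([vi / norm for vi in v])
    return basis

def rotate_rows(m, a, theta, transpose=False):
    """Apply a 2x2 rotation (one 'MZI') to adjacent rows a and a+1 in place."""
    c, s = math.cos(theta), math.sin(theta)
    if transpose:
        s = -s  # G^T is G with the sine negated
    ra, rb = m[a], m[a + 1]
    m[a] = [c * x + s * y for x, y in zip(ra, rb)]
    m[a + 1] = [-s * x + c * y for x, y in zip(ra, rb)]

def decompose(q):
    """Zero the lower triangle with nearest-neighbour rotations.

    Returns the (row, angle) list and the remaining +/-1 diagonal.
    """
    m = [row[:] for row in q]
    n = len(m)
    rotations = []
    for j in range(n - 1):
        for i in range(n - 1, j, -1):
            theta = math.atan2(m[i][j], m[i - 1][j])  # chosen so m[i][j] becomes 0
            rotate_rows(m, i - 1, theta)
            rotations.append((i - 1, theta))
    diag = [round(m[k][k]) for k in range(n)]  # orthogonal + upper triangular => diag(+/-1)
    return rotations, diag

def reconstruct(rotations, diag):
    """Undo the rotations in reverse order: Q = G_1^T ... G_m^T D."""
    n = len(diag)
    m = [[float(diag[i]) if i == j else 0.0 for j in range(n)] for i in range(n)]
    for a, theta in reversed(rotations):
        rotate_rows(m, a, theta, transpose=True)
    return m

q = random_orthogonal(4)
rotations, diag = decompose(q)
q_hat = reconstruct(rotations, diag)
err = max(abs(x - y) for r1, r2 in zip(q, q_hat) for x, y in zip(r1, r2))
print(len(rotations))  # 6 == 4 * 3 / 2
print(err < 1e-9)  # True
```

A 4x4 matrix needs 4 * 3 / 2 = 6 rotations, which matches the number of MZIs in a triangular mesh of that size.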
| File | Description |
|---|---|
| torchonn/ | Library source files with model, layer, and device definitions |
| torchonn/op | Basic operators and CUDA-accelerated operators |
| torchonn/layers | Optical device-implemented layers |
| torchonn/models | Base ONN model template |
| torchonn/devices | Optical device parameters and configurations |
| examples/ | ONN model building and training examples |
| examples/configs | YAML-based configuration files |
| examples/core | ONN model definition and training utilities |
| examples/train.py | Training script |
The `examples/` folder contains more examples of training ONN models.
An example optical convolutional neural network, `MZI_CLASS_CNN`, is defined in `examples/core/models/mzi_cnn.py`.
Training facilities, e.g., the optimizer, criterion, lr_scheduler, and models, are built in `examples/core/builder.py`.
The training and validation logic is defined in `examples/train.py`.
All training hyperparameters are hierarchically defined in the YAML configuration file `examples/configs/mnist/mzi_onn/train.yml` (the final config is the union of all `default.yml` files in higher-level directories and this specific `train.yml`).
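The hierarchical config described above can be pictured as a chain of dictionary merges, followed by command-line overrides of the form `--optimizer.lr=0.001`. The sketch below is a guess at that behaviour, not the actual loader in `examples/`: it assumes that more specific files win on conflicts, uses plain dicts in place of parsed YAML (so no YAML library is needed), and all key names and values are hypothetical.

```python
import copy

def deep_merge(base, override):
    """Recursively merge `override` into a copy of `base`; override wins on conflicts."""
    merged = copy.deepcopy(base)
    for key, value in override.items():
        if isinstance(value, dict) and isinstance(merged.get(key), dict):
            merged[key] = deep_merge(merged[key], value)
        else:
            merged[key] = copy.deepcopy(value)
    return merged

def parse_scalar(raw):
    """Interpret a command-line string as int, float, bool, or plain string."""
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            pass
    return {"true": True, "false": False}.get(raw.lower(), raw)

def apply_override(cfg, arg):
    """Apply one dotted override such as '--optimizer.lr=0.001' in place."""
    key_path, raw = arg.lstrip("-").split("=", 1)
    *parents, leaf = key_path.split(".")
    node = cfg
    for p in parents:
        node = node.setdefault(p, {})
    node[leaf] = parse_scalar(raw)
    return cfg

# Hypothetical file contents, outermost directory first.
configs_default = {"optimizer": {"name": "adam", "lr": 0.002, "weight_decay": 0.0},
                   "run": {"n_epochs": 100}}
mnist_default = {"dataset": {"name": "mnist"}, "run": {"batch_size": 32}}
train_yml = {"model": {"name": "MZI_CLASS_CNN", "mode": "usv"}, "run": {"n_epochs": 20}}

cfg = {}
for layer in (configs_default, mnist_default, train_yml):
    cfg = deep_merge(cfg, layer)
cfg = apply_override(cfg, "--optimizer.lr=0.001")
print(cfg["run"])        # {'n_epochs': 20, 'batch_size': 32}
print(cfg["optimizer"])  # {'name': 'adam', 'lr': 0.001, 'weight_decay': 0.0}
```

Deep-copying on every merge keeps the per-file dictionaries untouched, so the same defaults can be reused across several runs.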
Run the following commands to train the example model:
# train the example MZI-based CNN model with 2 64-channel Conv layers and 1 Linear layer
# training will happen in usv mode to optimize U, Sigma, and V*
# projected gradient descent will be applied to guarantee the orthogonality of U and V*
# the final step will convert unitary matrices into MZI phases and evaluate in the phase mode
cd examples
python3 train.py configs/mnist/mzi_cnn/train.yml # [followed by any command-line arguments that override the values in config file, e.g., --optimizer.lr=0.001]
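The comments above mention projected gradient descent that keeps U and V* orthogonal in `usv` mode. A minimal sketch of that idea, under stated assumptions: after an ordinary gradient step, the factor is replaced by its nearest orthogonal matrix (its polar factor). torchonn's own projection may well use an SVD; here a Newton-Schulz iteration is swapped in so the example needs only the Python standard library, real-valued matrices are assumed, and the random "gradient" is a stand-in.

```python
import random

def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]

def transpose(a):
    return [list(col) for col in zip(*a)]

def project_to_orthogonal(x, iters=30):
    """Nearest orthogonal matrix (polar factor) via Newton-Schulz iteration.

    X <- X (3I - X^T X) / 2 converges when all singular values of X lie in
    (0, sqrt(3)); here X is assumed to be a small perturbation of an orthogonal matrix.
    """
    n = len(x)
    for _ in range(iters):
        xtx = matmul(transpose(x), x)
        corr = [[(3.0 if i == j else 0.0) - xtx[i][j] for j in range(n)] for i in range(n)]
        x = [[0.5 * v for v in row] for row in matmul(x, corr)]
    return x

def orthogonality_error(x):
    """Largest entry of |X^T X - I|."""
    xtx = matmul(transpose(x), x)
    n = len(x)
    return max(abs(xtx[i][j] - (1.0 if i == j else 0.0)) for i in range(n) for j in range(n))

rng = random.Random(0)
n, lr = 4, 0.05
u = [[1.0 if i == j else 0.0 for j in range(n)] for i in range(n)]  # orthogonal starting U
grad = [[rng.uniform(-1, 1) for _ in range(n)] for _ in range(n)]     # stand-in for dL/dU
u_step = [[u[i][j] - lr * grad[i][j] for j in range(n)] for i in range(n)]
u_proj = project_to_orthogonal(u_step)
print(orthogonality_error(u_step) > 1e-3)   # True: a plain step leaves the orthogonal group
print(orthogonality_error(u_proj) < 1e-12)  # True: projection restores U^T U = I
```

Because each projection only corrects a small drift, a handful of Newton-Schulz iterations is typically enough; the fixed count of 30 simply keeps the sketch simple.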
Detailed documentation is coming soon.
Jiaqi Gu (jqgu@utexas.edu)
- Neural operator-enabled fast photonic device simulation: See NeurOLight, NeurIPS 2022.
- Automatic photonic tensor core design: See ADEPT, DAC 2022.
- Endurance-enhanced photonic in-memory computing: See ELight, ASP-DAC 2022.
- Scalable ONN on-chip learning: See L2ight, NeurIPS 2021.
- Memory-efficient ONN architecture: See Memory-Efficient-ONN, ICCV 2021.
- SqueezeLight: Scalable ONNs with Multi-Operand Ring Resonators: See SqueezeLight, DATE 2021.