Skip to content

Commit

Permalink
updated Optimizer Class in JAX and Torch.
Browse files Browse the repository at this point in the history
  • Loading branch information
yonghakim committed Mar 6, 2023
1 parent d661b02 commit 548306a
Show file tree
Hide file tree
Showing 13 changed files with 110 additions and 440 deletions.
10 changes: 2 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,11 @@
Meent is a RCWA solver and its applications on optimization problem. We are expecting that this tool can accelerate ML research in photonics.

## How to install

You can install from PyPI

```shell
pip install meent
```

or download this repo and run

```shell
pip install .
```
JAX and PyTorch is needed for advanced utilization.

## How to use

Expand All @@ -34,6 +27,7 @@ mode_key = 1
solver = meent.rcwa.call_solver(mode=mode_key, ...)
```


## When to use

| | Numpy | JAX | PyTorch |
Expand Down
12 changes: 6 additions & 6 deletions benchmarks/interface/Reticolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,8 +136,8 @@ def run_acs(self, pattern, n_si='SILICON'):
from meent.on_numpy.emsolver.convolution_matrix import to_conv_mat_discrete, to_conv_mat_continuous
E_conv_all = to_conv_mat_continuous(solver.ucell, solver.fourier_order)
o_E_conv_all = to_conv_mat_continuous(1 / solver.ucell, solver.fourier_order)
# E_conv_all = to_conv_mat_discrete(solver.ucell, solver.fourier_order)
# o_E_conv_all = to_conv_mat_discrete(1 / solver.ucell, solver.fourier_order)
# E_conv_all = to_conv_mat_discrete(mee.ucell, mee.fourier_order)
# o_E_conv_all = to_conv_mat_discrete(1 / mee.ucell, mee.fourier_order)

de_ri, de_ti, _, _, _ = solver.solve(solver.wavelength, E_conv_all, o_E_conv_all)
c = de_ri.shape[0]//2
Expand All @@ -163,8 +163,8 @@ def run_acs(self, pattern, n_si='SILICON'):
from meent.on_jax.emsolver.convolution_matrix import to_conv_mat_discrete, to_conv_mat_continuous
E_conv_all = to_conv_mat_continuous(solver.ucell, solver.fourier_order)
o_E_conv_all = to_conv_mat_continuous(1 / solver.ucell, solver.fourier_order)
# E_conv_all = to_conv_mat_discrete(solver.ucell, solver.fourier_order)
# o_E_conv_all = to_conv_mat_discrete(1 / solver.ucell, solver.fourier_order)
# E_conv_all = to_conv_mat_discrete(mee.ucell, mee.fourier_order)
# o_E_conv_all = to_conv_mat_discrete(1 / mee.ucell, mee.fourier_order)

de_ri, de_ti, _, _, _ = solver.solve(solver.wavelength, E_conv_all, o_E_conv_all)
c = de_ri.shape[0]//2
Expand All @@ -190,8 +190,8 @@ def run_acs(self, pattern, n_si='SILICON'):
from meent.on_torch.emsolver.convolution_matrix import to_conv_mat_discrete, to_conv_mat_continuous
E_conv_all = to_conv_mat_continuous(solver.ucell, solver.fourier_order)
o_E_conv_all = to_conv_mat_continuous(1 / solver.ucell, solver.fourier_order)
# E_conv_all = to_conv_mat_discrete(solver.ucell, solver.fourier_order)
# o_E_conv_all = to_conv_mat_discrete(1 / solver.ucell, solver.fourier_order)
# E_conv_all = to_conv_mat_discrete(mee.ucell, mee.fourier_order)
# o_E_conv_all = to_conv_mat_discrete(1 / mee.ucell, mee.fourier_order)

de_ri, de_ti, _, _, _ = solver.solve(solver.wavelength, E_conv_all, o_E_conv_all)
c = de_ri.shape[0]//2
Expand Down
44 changes: 0 additions & 44 deletions examples/optimization/aaa.py

This file was deleted.

28 changes: 28 additions & 0 deletions examples/optimization/ex_jax.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
import optax

import meent
from meent.on_jax.optimizer.loss import LossDeflector
from meent.on_jax.optimizer.optimizer import OptimizerJax


mode = 1
dtype = 0
device = 0
grating_type = 2

conditions = meent.testcase.load_setting(mode, dtype, device, grating_type)
mee = OptimizerJax(**conditions)

pois = ['ucell', 'thickness']

forward = mee.conv_solve
loss_fn = LossDeflector(x_order=0, y_order=0)

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)
print(1, grad)

# case 2: SGD
optimizer = optax.sgd(learning_rate=1e-2)
mee.fit(pois, forward, loss_fn, optimizer)
print(3, mee.thickness*1E5)
31 changes: 31 additions & 0 deletions examples/optimization/ex_torch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
import torch

import meent
from meent.on_torch.optimizer.loss import LossDeflector
from meent.on_torch.optimizer.optimizer import OptimizerTorch


mode = 2
dtype = 0
device = 0
grating_type = 2

conditions = meent.testcase.load_setting(mode, dtype, device, grating_type)
mee = OptimizerTorch(**conditions)

pois = ['ucell', 'thickness']

forward = mee.conv_solve
loss_fn = LossDeflector(x_order=0, y_order=0)

# case 1: Gradient
grad = mee.grad(pois, forward, loss_fn)
print(1, grad)

# case 2: SGD
opt_torch = torch.optim.SGD
opt_options = {'lr': 1E-2,
'momentum': 0.9,
}

mee.fit(pois, forward, loss_fn, opt_torch, opt_options)
Loading

0 comments on commit 548306a

Please sign in to comment.