Skip to content

Commit

Permalink
Merge pull request #31 from kc-ml2/DEV
Browse files Browse the repository at this point in the history
Dev
  • Loading branch information
yonghakim authored Mar 16, 2023
2 parents 8ff4b21 + 9b018da commit fcb4f88
Show file tree
Hide file tree
Showing 20 changed files with 898 additions and 577 deletions.
78 changes: 78 additions & 0 deletions QA/async_2d_fourier_order.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import numpy as np

import meent

from meent.testcase import load_setting
from benchmarks.interface.Reticolo import Reticolo


mode = 0
dtype = 0
device = 0
grating_type = 2
pre = load_setting(mode, dtype, device, grating_type)

reti = Reticolo()
de_ri_top_inc, de_ti_top_inc, de_ri_bot_inc, de_ti_bot_inc = reti.run(**pre)
de_ri_reti = de_ri_top_inc.flatten()
de_ti_reti = de_ti_top_inc.flatten()

print('reti de_ri:', de_ri_reti)
print('reti de_ti:', de_ti_reti)

# Numpy
mode = 0
pre = load_setting(mode, dtype, device, grating_type)
mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre)
mee.fft_type = 0

de_ri, de_ti = mee.conv_solve()
center = np.array(de_ri.shape) // 2
de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2]
de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2]
cut_index = [3, 1, 4, 7, 5]
# cut_index = [3, 1, 4, 5]
de_ri_cut = de_ri_cut.flatten()[cut_index]
de_ti_cut = de_ti_cut.flatten()[cut_index]

print('Norm(Reti, NPY): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut))

# JAX
mode = 1
pre = load_setting(mode, dtype, device, grating_type)
mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre)
mee.fft_type = 0

de_ri, de_ti = mee.conv_solve()
center = np.array(de_ri.shape) // 2

de_ri, de_ti = np.array(de_ri), np.array(de_ti)
de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2]
de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2]
de_ri_cut = de_ri_cut.flatten()[cut_index]
de_ti_cut = de_ti_cut.flatten()[cut_index]

# print('meen jx de_ri:', de_ri_cut)
# print('meen jx de_ti:', de_ti_cut)

print('Norm(Reti, JAX): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut))

# Torch
mode = 2
pre = load_setting(mode, dtype, device, grating_type)
mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre)
mee.fft_type = 0

de_ri, de_ti = mee.conv_solve()
center = np.array(de_ri.shape) // 2

de_ri, de_ti = np.array(de_ri), np.array(de_ti)
de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2]
de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2]
de_ri_cut = de_ri_cut.flatten()[cut_index]
de_ti_cut = de_ti_cut.flatten()[cut_index]

# print('meen to de_ri:', de_ri_cut)
# print('meen to de_ti:', de_ti_cut)

print('Norm(Reti, TOR): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut))
16 changes: 8 additions & 8 deletions QA/backprop.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def load_setting(mode_key, dtype, device):
wavelength = 900

ucell_materials = [1, 3.48]
fourier_order = 2
fourier_order = [2, 2]

period = [1000, 1000]
thickness = [1120., 400, 300]
Expand Down Expand Up @@ -133,12 +133,12 @@ def optimize_jax_ucell_metasurface(mode_key, dtype, device):
@jax.grad
def grad_loss(ucell):

E_conv_all = to_conv_mat_discrete(ucell, fourier_order, type_complex=type_complex)
o_E_conv_all = to_conv_mat_discrete(1 / ucell, fourier_order, type_complex=type_complex)
E_conv_all = to_conv_mat_discrete(ucell, *fourier_order, type_complex=type_complex)
o_E_conv_all = to_conv_mat_discrete(1 / ucell, *fourier_order, type_complex=type_complex)
de_ri, de_ti = solver.solve(wavelength, E_conv_all, o_E_conv_all)
c = de_ti.shape[0] // 2
loss = de_ti[c, c]
print(loss.primal)
# print(loss.primal)
return loss

def grad_numerical(ucell, delta):
Expand All @@ -149,14 +149,14 @@ def grad_numerical(ucell, delta):
for c in range(ucell.shape[2]):
ucell_delta_m = ucell.at[layer, r, c].set(ucell[layer, r, c] - delta)

E_conv_all_m = to_conv_mat_discrete(ucell_delta_m, fourier_order, type_complex=type_complex)
o_E_conv_all_m = to_conv_mat_discrete(1 / ucell_delta_m, fourier_order, type_complex=type_complex)
E_conv_all_m = to_conv_mat_discrete(ucell_delta_m, *fourier_order, type_complex=type_complex)
o_E_conv_all_m = to_conv_mat_discrete(1 / ucell_delta_m, *fourier_order, type_complex=type_complex)
de_ri_delta_m, de_ti_delta_m = solver.solve(wavelength, E_conv_all_m, o_E_conv_all_m)

ucell_delta_p = ucell.at[layer, r, c].set(ucell[layer, r, c] + delta)

E_conv_all_p = to_conv_mat_discrete(ucell_delta_p, fourier_order, type_complex=type_complex)
o_E_conv_all_p = to_conv_mat_discrete(1 / ucell_delta_p, fourier_order, type_complex=type_complex)
E_conv_all_p = to_conv_mat_discrete(ucell_delta_p, *fourier_order, type_complex=type_complex)
o_E_conv_all_p = to_conv_mat_discrete(1 / ucell_delta_p, *fourier_order, type_complex=type_complex)
de_ri_delta_p, de_ti_delta_p = solver.solve(wavelength, E_conv_all_p, o_E_conv_all_p)

center = de_ti_delta_m.shape[0] // 2
Expand Down
22 changes: 12 additions & 10 deletions benchmarks/interface/Reticolo.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ def run(self, grating_type, period, fourier_order, ucell, thickness, theta, phi,
textures = [n_I, *ucell_new, n_II]

else:
fourier_order = [fourier_order, fourier_order]

Nx = ucell.shape[2]
Ny = ucell.shape[1]
Expand Down Expand Up @@ -114,14 +113,14 @@ def run_acs(self, pattern, n_si='SILICON'):
mode = 0
dtype = 0
device = 0
grating_type = 0
grating_type = 2
pre = load_setting(mode, dtype, device, grating_type)

reti = Reticolo()
a,b,c,d = reti.run(**pre)

print(np.array(a).flatten()[::-1])
print(np.array(b).flatten()[::-1])
print(np.array(a).flatten())
print(np.array(b).flatten())
# print(np.array(a))
# print(np.array(b))
# print(c)
Expand All @@ -131,18 +130,21 @@ def run_acs(self, pattern, n_si='SILICON'):
mode = 0
pre = load_setting(mode, dtype, device, grating_type)
mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre)
mee.fft_type = 1
mee.fft_type = 0

de_ri, de_ti = mee.conv_solve()
c = de_ri.shape[0]//2
center = np.array(de_ri.shape) // 2
try:
print(de_ri[c-1:c+2, c-1:c+2])
print(de_ti[c-1:c+2, c-1:c+2])
print(de_ri[center[0]-1:center[0]+2, center[1]-1:center[1]+2])
print(de_ti[center[0]-1:center[0]+2, center[1]-1:center[1]+2])
except:
# print(de_ri[c-1:c+2])
# print(de_ti[c-1:c+2])
print(de_ri)
print(de_ti)
print(de_ri[center[0]-1:center[0]+2])
print(de_ti[center[0]-1:center[0]+2])

# print(de_ri)
# print(de_ti)

print(a.sum(),de_ri.sum())
print(b.sum(),de_ti.sum())
Expand Down
50 changes: 37 additions & 13 deletions examples/ex_ucell.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,24 +17,24 @@
import meent

# common
pol = 1 # 0: TE, 1: TM
pol = 0 # 0: TE, 1: TM

n_I = 1 # n_incidence
n_II = 1 # n_transmission

theta = 10 * np.pi / 180
phi = 0 * np.pi / 180
theta = 20 * np.pi / 180
phi = 50 * np.pi / 180
psi = 0 if pol else 90 * np.pi / 180

wavelength = 900

thickness = [500]
ucell_materials = [1, 'p_si']
ucell_materials = [1, 'p_si__real']
period = [1000, 1000]
# period = [1000, 1000]
fourier_order = 2

fourier_order = 3
mode_options = {0: 'numpy', 1: 'JAX', 2: 'Torch', }
n_iter = 2
n_iter = 3


def run_test(grating_type, mode_key, dtype, device):
Expand Down Expand Up @@ -88,23 +88,28 @@ def run_test(grating_type, mode_key, dtype, device):
AA = meent.call_mee(mode=mode_key, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi,
psi=psi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell,
ucell_materials=ucell_materials,
thickness=thickness, device=device, type_complex=type_complex, fft_type=1, improve_dft=True)
thickness=thickness, device=device, type_complex=type_complex, fft_type=0, improve_dft=True)

# for i in range(n_iter):
# t0 = time.time()
# AA.conv_solve_calculate_field()
# print(f'run_cell: {i}: ', time.time() - t0)
for i in range(n_iter):
t0 = time.time()
de_ri, de_ti = AA.conv_solve()
print(f'run_cell: {i}: ', time.time() - t0)
resolution = (20, 20, 20)
for i in range(0):
for i in range(2):
t0 = time.time()
AA.calculate_field(resolution=resolution, plot=False)
print(f'cal_field: {i}', time.time() - t0)

center = np.array(de_ri.shape) // 2
print(de_ri.sum(), de_ti.sum())
try:
center = de_ri.shape[0] // 2
print(de_ri[center-1:center+2, center-1:center+2])
print(de_ri[center[0]-1:center[0]+2, center[1]-1:center[1]+2])
except:
print(de_ri[center-1:center+2])
print(de_ri[center[0]-1:center[0]+2])

return de_ri, de_ti

Expand All @@ -117,6 +122,7 @@ def run_loop(a, b, c, d):
print(f'grating:{grating_type}, backend:{bd}, dtype:{dtype}, dev:{device}')
run_test(grating_type, bd, dtype, device)


def load_ucell(grating_type):
if grating_type in [0, 1]:

Expand Down Expand Up @@ -159,9 +165,27 @@ def load_ucell(grating_type):
[0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ],
],
])

# ucell = np.array([
#
# [
# [
# 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
# ],
# [
# 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
# ],
# [
# 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
# ],
# [
# 0, 1, 0, 1, 1, 0, 1, 0, 1, 1,
# ],
# ],
# ])
# ucell = ucell * 4 + 1
return ucell


if __name__ == '__main__':
run_loop([0, 1, 2], [0, 1, 2], [0], [0])
run_loop([0], [2], [0], [0])
Loading

0 comments on commit fcb4f88

Please sign in to comment.