diff --git a/QA/async_2d_fourier_order.py b/QA/async_2d_fourier_order.py new file mode 100644 index 0000000..025f91f --- /dev/null +++ b/QA/async_2d_fourier_order.py @@ -0,0 +1,78 @@ +import numpy as np + +import meent + +from meent.testcase import load_setting +from benchmarks.interface.Reticolo import Reticolo + + +mode = 0 +dtype = 0 +device = 0 +grating_type = 2 +pre = load_setting(mode, dtype, device, grating_type) + +reti = Reticolo() +de_ri_top_inc, de_ti_top_inc, de_ri_bot_inc, de_ti_bot_inc = reti.run(**pre) +de_ri_reti = de_ri_top_inc.flatten() +de_ti_reti = de_ti_top_inc.flatten() + +print('reti de_ri:', de_ri_reti) +print('reti de_ti:', de_ti_reti) + +# Numpy +mode = 0 +pre = load_setting(mode, dtype, device, grating_type) +mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) +mee.fft_type = 0 + +de_ri, de_ti = mee.conv_solve() +center = np.array(de_ri.shape) // 2 +de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +cut_index = [3, 1, 4, 7, 5] +# cut_index = [3, 1, 4, 5] +de_ri_cut = de_ri_cut.flatten()[cut_index] +de_ti_cut = de_ti_cut.flatten()[cut_index] + +print('Norm(Reti, NPY): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut)) + +# JAX +mode = 1 +pre = load_setting(mode, dtype, device, grating_type) +mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) +mee.fft_type = 0 + +de_ri, de_ti = mee.conv_solve() +center = np.array(de_ri.shape) // 2 + +de_ri, de_ti = np.array(de_ri), np.array(de_ti) +de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ri_cut = de_ri_cut.flatten()[cut_index] +de_ti_cut = de_ti_cut.flatten()[cut_index] + +# print('meen jx de_ri:', de_ri_cut) +# print('meen jx de_ti:', de_ti_cut) + +print('Norm(Reti, JAX): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut)) + +# Torch +mode = 2 +pre = load_setting(mode, dtype, device, grating_type) +mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) +mee.fft_type = 0 + +de_ri, de_ti = mee.conv_solve() +center = np.array(de_ri.shape) // 2 + +de_ri, de_ti = np.array(de_ri), np.array(de_ti) +de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ri_cut = de_ri_cut.flatten()[cut_index] +de_ti_cut = de_ti_cut.flatten()[cut_index] + +# print('meen to de_ri:', de_ri_cut) +# print('meen to de_ti:', de_ti_cut) + +print('Norm(Reti, TOR): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut)) diff --git a/QA/backprop.py b/QA/backprop.py index 44e7011..202ebfe 100644 --- a/QA/backprop.py +++ b/QA/backprop.py @@ -32,7 +32,7 @@ def load_setting(mode_key, dtype, device): wavelength = 900 ucell_materials = [1, 3.48] - fourier_order = 2 + fourier_order = [2, 2] period = [1000, 1000] thickness = [1120., 400, 300] @@ -133,12 +133,12 @@ def optimize_jax_ucell_metasurface(mode_key, dtype, device): @jax.grad def grad_loss(ucell): - E_conv_all = to_conv_mat_discrete(ucell, fourier_order, type_complex=type_complex) - o_E_conv_all = to_conv_mat_discrete(1 / ucell, fourier_order, type_complex=type_complex) + E_conv_all = to_conv_mat_discrete(ucell, *fourier_order, type_complex=type_complex) + o_E_conv_all = to_conv_mat_discrete(1 / ucell, *fourier_order, type_complex=type_complex) de_ri, de_ti = solver.solve(wavelength, E_conv_all, o_E_conv_all) c = de_ti.shape[0] // 2 loss = de_ti[c, c] - print(loss.primal) + # print(loss.primal) return loss def grad_numerical(ucell, delta): @@ -149,14 +149,14 @@ def grad_numerical(ucell, delta): for c in range(ucell.shape[2]): ucell_delta_m = ucell.at[layer, r, c].set(ucell[layer, r, c] - delta) - E_conv_all_m = to_conv_mat_discrete(ucell_delta_m, fourier_order, type_complex=type_complex) - o_E_conv_all_m = to_conv_mat_discrete(1 / ucell_delta_m, fourier_order, type_complex=type_complex) + E_conv_all_m = to_conv_mat_discrete(ucell_delta_m, *fourier_order, type_complex=type_complex) + o_E_conv_all_m = to_conv_mat_discrete(1 / ucell_delta_m, *fourier_order, type_complex=type_complex) de_ri_delta_m, de_ti_delta_m = solver.solve(wavelength, E_conv_all_m, o_E_conv_all_m) ucell_delta_p = ucell.at[layer, r, c].set(ucell[layer, r, c] + delta) - E_conv_all_p = to_conv_mat_discrete(ucell_delta_p, fourier_order, type_complex=type_complex) - o_E_conv_all_p = to_conv_mat_discrete(1 / ucell_delta_p, fourier_order, type_complex=type_complex) + E_conv_all_p = to_conv_mat_discrete(ucell_delta_p, *fourier_order, type_complex=type_complex) + o_E_conv_all_p = to_conv_mat_discrete(1 / ucell_delta_p, *fourier_order, type_complex=type_complex) de_ri_delta_p, de_ti_delta_p = solver.solve(wavelength, E_conv_all_p, o_E_conv_all_p) center = de_ti_delta_m.shape[0] // 2 diff --git a/benchmarks/interface/Reticolo.py b/benchmarks/interface/Reticolo.py index 9d556d8..a00d54c 100644 --- a/benchmarks/interface/Reticolo.py +++ b/benchmarks/interface/Reticolo.py @@ -56,7 +56,6 @@ def run(self, grating_type, period, fourier_order, ucell, thickness, theta, phi, textures = [n_I, *ucell_new, n_II] else: - fourier_order = [fourier_order, fourier_order] Nx = ucell.shape[2] Ny = ucell.shape[1] @@ -114,14 +113,14 @@ def run_acs(self, pattern, n_si='SILICON'): mode = 0 dtype = 0 device = 0 - grating_type = 0 + grating_type = 2 pre = load_setting(mode, dtype, device, grating_type) reti = Reticolo() a,b,c,d = reti.run(**pre) - print(np.array(a).flatten()[::-1]) - print(np.array(b).flatten()[::-1]) + print(np.array(a).flatten()) + print(np.array(b).flatten()) # print(np.array(a)) # print(np.array(b)) # print(c) @@ -131,18 +130,21 @@ def run_acs(self, pattern, n_si='SILICON'): mode = 0 pre = load_setting(mode, dtype, device, grating_type) mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) - mee.fft_type = 1 + mee.fft_type = 0 de_ri, de_ti = mee.conv_solve() - c = de_ri.shape[0]//2 + center = np.array(de_ri.shape) // 2 try: - print(de_ri[c-1:c+2, c-1:c+2]) - print(de_ti[c-1:c+2, c-1:c+2]) + print(de_ri[center[0]-1:center[0]+2, center[1]-1:center[1]+2]) + print(de_ti[center[0]-1:center[0]+2, center[1]-1:center[1]+2]) except: # print(de_ri[c-1:c+2]) # print(de_ti[c-1:c+2]) - print(de_ri) - print(de_ti) + print(de_ri[center[0]-1:center[0]+2]) + print(de_ti[center[0]-1:center[0]+2]) + + # print(de_ri) + # print(de_ti) print(a.sum(),de_ri.sum()) print(b.sum(),de_ti.sum()) diff --git a/examples/ex_ucell.py b/examples/ex_ucell.py index 5916a9a..58dbbec 100644 --- a/examples/ex_ucell.py +++ b/examples/ex_ucell.py @@ -17,24 +17,24 @@ import meent # common -pol = 1 # 0: TE, 1: TM +pol = 0 # 0: TE, 1: TM n_I = 1 # n_incidence n_II = 1 # n_transmission -theta = 10 * np.pi / 180 -phi = 0 * np.pi / 180 +theta = 20 * np.pi / 180 +phi = 50 * np.pi / 180 psi = 0 if pol else 90 * np.pi / 180 wavelength = 900 thickness = [500] -ucell_materials = [1, 'p_si'] +ucell_materials = [1, 'p_si__real'] period = [1000, 1000] -# period = [1000, 1000] -fourier_order = 2 + +fourier_order = 3 mode_options = {0: 'numpy', 1: 'JAX', 2: 'Torch', } -n_iter = 2 +n_iter = 3 def run_test(grating_type, mode_key, dtype, device): @@ -88,23 +88,28 @@ def run_test(grating_type, mode_key, dtype, device): AA = meent.call_mee(mode=mode_key, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell, ucell_materials=ucell_materials, - thickness=thickness, device=device, type_complex=type_complex, fft_type=1, improve_dft=True) + thickness=thickness, device=device, type_complex=type_complex, fft_type=0, improve_dft=True) + # for i in range(n_iter): + # t0 = time.time() + # AA.conv_solve_calculate_field() + # print(f'run_cell: {i}: ', time.time() - t0) for i in range(n_iter): t0 = time.time() de_ri, de_ti = AA.conv_solve() print(f'run_cell: {i}: ', time.time() - t0) resolution = (20, 20, 20) - for i in range(0): + for i in range(2): t0 = time.time() AA.calculate_field(resolution=resolution, plot=False) print(f'cal_field: {i}', time.time() - t0) + center = np.array(de_ri.shape) // 2 + print(de_ri.sum(), de_ti.sum()) try: - center = de_ri.shape[0] // 2 - print(de_ri[center-1:center+2, center-1:center+2]) + print(de_ri[center[0]-1:center[0]+2, center[1]-1:center[1]+2]) except: - print(de_ri[center-1:center+2]) + print(de_ri[center[0]-1:center[0]+2]) return de_ri, de_ti @@ -117,6 +122,7 @@ def run_loop(a, b, c, d): print(f'grating:{grating_type}, backend:{bd}, dtype:{dtype}, dev:{device}') run_test(grating_type, bd, dtype, device) + def load_ucell(grating_type): if grating_type in [0, 1]: @@ -159,9 +165,27 @@ def load_ucell(grating_type): [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ], ], ]) + + # ucell = np.array([ + # + # [ + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # ], + # ]) # ucell = ucell * 4 + 1 return ucell if __name__ == '__main__': - run_loop([0, 1, 2], [0, 1, 2], [0], [0]) + run_loop([0], [2], [0], [0]) diff --git a/meent/on_jax/emsolver/_base.py b/meent/on_jax/emsolver/_base.py index 2ad1188..bdb8f11 100644 --- a/meent/on_jax/emsolver/_base.py +++ b/meent/on_jax/emsolver/_base.py @@ -12,16 +12,17 @@ class _BaseRCWA: - def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=10, + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=(2, 2), period=(100, 100), wavelength=900, thickness=None, algo='TMM', perturbation=1E-10, device='cpu', type_complex=jnp.complex128): self.device = device self.type_complex = type_complex - if self.type_complex == jnp.complex128: + if self.type_complex == jnp.complex128: # TODO: need this? jax.config.update('jax_enable_x64', True) + # TODO: consider to make systematic and apply to other bds self.type_int = jnp.int64 if self.type_complex == jnp.complex128 else jnp.int32 self.type_float = jnp.float64 if self.type_complex == jnp.complex128 else jnp.float32 @@ -44,10 +45,14 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = int(fourier_order) - self.ff = 2 * self.fourier_order + 1 + if type(fourier_order) == int: + self.fourier_order = [fourier_order, 0] + elif len(fourier_order) == 1: + self.fourier_order = list(fourier_order) + [0] + else: + self.fourier_order = [int(v) for v in fourier_order] - self.period = period + self.period = deepcopy(period) self.wavelength = wavelength self.thickness = deepcopy(thickness) @@ -58,8 +63,6 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= self.layer_info_list = [] self.T1 = None - # if self.theta == 0: - # self.theta = self.perturbation self.theta = jnp.where(self.theta == 0, self.perturbation, self.theta) self.kx_vector = None @@ -67,7 +70,7 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= def get_kx_vector(self, wavelength): k0 = 2 * jnp.pi / wavelength - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + fourier_indices = jnp.arange(-self.fourier_order[0], self.fourier_order[0] + 1) if self.grating_type == 0: kx_vector = k0 * (self.n_I * jnp.sin(self.theta) + fourier_indices * (wavelength / self.period[0]) ).astype(self.type_complex) @@ -80,24 +83,23 @@ def get_kx_vector(self, wavelength): return kx_vector def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): - self.layer_info_list = [] self.T1 = None - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = jnp.zeros(self.ff, dtype=self.type_complex) - delta_i0 = delta_i0.at[self.fourier_order].set(1) + delta_i0 = jnp.zeros(ff, dtype=self.type_complex) + delta_i0 = delta_i0.at[self.fourier_order[0]].set(1) k0 = 2 * jnp.pi / wavelength if self.algo == 'TMM': kx_vector, Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + = transfer_1d_1(ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ - = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, + = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, self.period, self.pol, wl=wavelength) else: raise ValueError @@ -128,7 +130,7 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError - if self.algo == 'TMM': + if self.algo == 'TMM': # TODO: fourier order? X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T, type_complex=self.type_complex) @@ -159,15 +161,16 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): self.T1 = None # fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = jnp.zeros(self.ff, dtype=self.type_complex) - delta_i0 = delta_i0.at[self.fourier_order].set(1) + delta_i0 = jnp.zeros(ff, dtype=self.type_complex) + delta_i0 = delta_i0.at[self.fourier_order[0]].set(1) k0 = 2 * jnp.pi / wavelength if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + = transfer_1d_conical_1(ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') @@ -181,7 +184,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ - = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, type_complex=self.type_complex, device=self.device) @@ -194,7 +197,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, type_complex=self.type_complex) self.T1 = big_T1 @@ -211,21 +214,26 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + # fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + fourier_indices_y = jnp.arange(-self.fourier_order[1], self.fourier_order[1] + 1) + + ff_x = self.fourier_order[0] * 2 + 1 + ff_y = self.fourier_order[1] * 2 + 1 + ff_xy = ff_x * ff_y - delta_i0 = jnp.zeros((self.ff ** 2, 1), dtype=self.type_complex) - delta_i0 = delta_i0.at[self.ff ** 2 // 2, 0].set(1) + delta_i0 = jnp.zeros((ff_xy, 1), dtype=self.type_complex) + delta_i0 = delta_i0.at[ff_xy // 2, 0].set(1) - I = jnp.eye(self.ff ** 2).astype(self.type_complex) - O = jnp.zeros((self.ff ** 2, self.ff ** 2), dtype=self.type_complex) + I = jnp.eye(ff_xy).astype(self.type_complex) + O = jnp.zeros((ff_xy, ff_xy), dtype=self.type_complex) - center = self.ff ** 2 + center = ff_xy k0 = 2 * jnp.pi / wavelength if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + = transfer_2d_1(ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, self.theta, self.phi, wavelength, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ @@ -239,36 +247,36 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): o_E_conv_i = jnp.linalg.inv(o_E_conv) if self.algo == 'TMM': - W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex, - device=self.device) + W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, + device=self.device, type_complex=self.type_complex) # TODO: device? big_X, big_F, big_G, big_T, big_A_i, big_B, \ W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, - type_complex=self.type_complex) + type_complex=self.type_complex) # tODO: device? layer_info = [E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, LAMBDA = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff_xy, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, - type_complex=self.type_complex) + type_complex=self.type_complex) # TODO: device? self.T1 = big_T1 elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, - self.pol, self.theta, self.phi, self.fourier_order, self.ff) + self.pol, self.theta, self.phi, self.fourier_order, self.fourier_order) else: raise ValueError - de_ri = de_ri.reshape((self.ff, self.ff)).real - de_ti = de_ti.reshape((self.ff, self.ff)).real + de_ri = de_ri.reshape((ff_y, ff_x)).real + de_ti = de_ti.reshape((ff_y, ff_x)).real return de_ri, de_ti, self.layer_info_list, self.T1 diff --git a/meent/on_jax/emsolver/convolution_matrix.py b/meent/on_jax/emsolver/convolution_matrix.py index 287352e..5e16a73 100644 --- a/meent/on_jax/emsolver/convolution_matrix.py +++ b/meent/on_jax/emsolver/convolution_matrix.py @@ -44,21 +44,17 @@ def cell_compression(cell, type_complex=jnp.complex128): # @partial(jax.jit, static_argnums=(1,2 )) -def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): +def fft_piecewise_constant(cell, fourier_order_x, fourier_order_y, type_complex=jnp.complex128): - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] cell, x, y = cell_compression(cell, type_complex=type_complex) # X axis cell_next_x = jnp.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = jnp.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = jnp.arange(-2 * fourier_order_x, 2 * fourier_order_x + 1, 1) - f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes[None, :]).astype(type_complex) + f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes_x[None, :]).astype(type_complex) c = f_coeffs_x.shape[1] // 2 x_next = jnp.vstack((jnp.roll(x, -1, axis=0)[:-1], 1)) - x @@ -67,18 +63,18 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): assign_value = (cell @ jnp.vstack((x[0], x_next[:-1]))).flatten().astype(type_complex) f_coeffs_x = f_coeffs_x.at[assign_index].set(assign_value) - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_x[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) + assign_index = mask + assign_value = f_coeffs_x[:, mask] / (1j * 2 * jnp.pi * modes_x[mask]) f_coeffs_x = f_coeffs_x.at[:, assign_index].set(assign_value) # Y axis f_coeffs_x_next_y = jnp.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = jnp.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = jnp.arange(-2 * fourier_order_y, 2 * fourier_order_y + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes[None, :]).astype(type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes_y[None, :]).astype(type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = jnp.vstack((jnp.roll(y, -1, axis=0)[:-1], 1)) - y @@ -88,10 +84,10 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) if c: - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_xy[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + assign_index = mask + assign_value = f_coeffs_xy[:, mask] / (1j * 2 * jnp.pi * modes_y[mask]) f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) @@ -99,21 +95,14 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): # @partial(jax.jit, static_argnums=(1,2 )) # tODO: jit-able? -def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=jnp.complex128): - - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] - # cell, x, y = cell_compression(cell, type_complex=type_complex) - +def fft_piecewise_constant_vector(cell, x, y, fourier_order_x, fourier_order_y, type_complex=jnp.complex128): # X axis cell_next_x = jnp.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = jnp.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = jnp.arange(-2 * fourier_order_x, 2 * fourier_order_x + 1, 1) - f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes[None, :]).astype(type_complex) + f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes_x[None, :]).astype(type_complex) c = f_coeffs_x.shape[1] // 2 x_next = jnp.vstack((jnp.roll(x, -1, axis=0)[:-1], 1)) - x @@ -122,18 +111,18 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=jnp.co assign_value = (cell @ jnp.vstack((x[0], x_next[:-1]))).flatten().astype(type_complex) f_coeffs_x = f_coeffs_x.at[assign_index].set(assign_value) - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_x[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) + assign_index = mask + assign_value = f_coeffs_x[:, mask] / (1j * 2 * jnp.pi * modes_x[mask]) f_coeffs_x = f_coeffs_x.at[:, assign_index].set(assign_value) # Y axis f_coeffs_x_next_y = jnp.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = jnp.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = jnp.arange(-2 * fourier_order_y, 2 * fourier_order_y + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes[None, :]).astype(type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes_y[None, :]).astype(type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = jnp.vstack((jnp.roll(y, -1, axis=0)[:-1], 1)) - y @@ -143,22 +132,23 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=jnp.co f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) if c: - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_xy[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + assign_index = mask + assign_value = f_coeffs_xy[:, mask] / (1j * 2 * jnp.pi * modes_y[mask]) f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) return f_coeffs_xy.T -def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, type_complex=jnp.complex128): +def to_conv_mat_continuous_vector(ucell_info_list, fourier_order_x, fourier_order_y, device=None, type_complex=jnp.complex128): - ff = 2 * fourier_order + 1 + ff_x = 2 * fourier_order_x + 1 + ff_y = 2 * fourier_order_y + 1 - e_conv_all = jnp.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) - o_e_conv_all = jnp.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) + e_conv_all = jnp.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + o_e_conv_all = jnp.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) # 2D for i, ucell_info in enumerate(ucell_info_list): @@ -166,16 +156,19 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, t ucell_layer = ucell_layer ** 2 f_coeffs = fft_piecewise_constant_vector(ucell_layer, x_list, y_list, - fourier_order, type_complex=type_complex) + fourier_order_x, fourier_order_y, type_complex=type_complex) o_f_coeffs = fft_piecewise_constant_vector(1/ucell_layer, x_list, y_list, - fourier_order, type_complex=type_complex) + fourier_order_x, fourier_order_y, type_complex=type_complex) center = np.array(f_coeffs.shape) // 2 - conv_idx = jnp.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = jnp.repeat(conv_idx, ff, 1) - conv_i = jnp.repeat(conv_i, ff, axis=0) - conv_j = jnp.tile(conv_idx, (ff, ff)) + conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) + conv_i = jnp.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = jnp.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] @@ -186,19 +179,17 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, t return e_conv_all, o_e_conv_all -def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=jnp.complex128): - pmt = pmt ** 2 +def to_conv_mat_continuous(ucell, fourier_order_x, fourier_order_y, device=None, type_complex=jnp.complex128): + ucell_pmt = ucell ** 2 - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - ff = 2 * fourier_order + 1 + if ucell_pmt.shape[1] == 1: # 1D - if pmt.shape[1] == 1: # 1D - res = jnp.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + ff = 2 * fourier_order_x + 1 - for i, layer in enumerate(pmt): - f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) + res = jnp.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) + + for i, layer in enumerate(ucell_pmt): + f_coeffs = fft_piecewise_constant(layer, fourier_order_x, fourier_order_y, type_complex=type_complex) center = f_coeffs.shape[1] // 2 conv_idx = jnp.arange(-ff + 1, ff, 1) conv_idx = circulant(conv_idx) @@ -206,41 +197,46 @@ def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=jnp.com res = res.at[i].set(e_conv) else: # 2D - # attention on the order of axis (Z Y X) - res = jnp.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) - for i, layer in enumerate(pmt): - f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) + ff_x = 2 * fourier_order_x + 1 + ff_y = 2 * fourier_order_y + 1 + + res = jnp.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + + for i, layer in enumerate(ucell_pmt): + f_coeffs = fft_piecewise_constant(layer, fourier_order_x, fourier_order_y, type_complex=type_complex) center = jnp.array(f_coeffs.shape) // 2 - conv_idx = jnp.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = jnp.repeat(conv_idx, ff, 1) - conv_i = jnp.repeat(conv_i, ff, axis=0) - conv_j = jnp.tile(conv_idx, (ff, ff)) + conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) + conv_i = jnp.repeat(conv_i, jnp.array([ff_x] * ff_y), axis=0) + + conv_idx_x = jnp.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res = res.at[i].set(e_conv) return res -@partial(jax.jit, static_argnums=(1, 2, 3, 4)) -def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=jnp.complex128, improve_dft=True): - pmt = pmt ** 2 +# @partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) +def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, type_complex=jnp.complex128, improve_dft=True): + ucell_pmt = ucell ** 2 + + if ucell_pmt.shape[1] == 1: # 1D - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - ff = 2 * fourier_order + 1 + ff = 2 * fourier_order_x + 1 - if pmt.shape[1] == 1: # 1D - res = jnp.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + res = jnp.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) if improve_dft: - minimum_pattern_size = 2 * ff * pmt.shape[2] + minimum_pattern_size = 2 * ff * ucell_pmt.shape[2] else: minimum_pattern_size = 2 * ff - for i, layer in enumerate(pmt): + for i, layer in enumerate(ucell_pmt): n = minimum_pattern_size // layer.shape[1] layer = np.repeat(layer, n + 1, axis=1) @@ -256,33 +252,40 @@ def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=jnp.compl res = res.at[i].set(e_conv) else: # 2D - # attention on the order of axis (Z Y X) - res = jnp.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) + ff_x = 2 * fourier_order_x + 1 + ff_y = 2 * fourier_order_y + 1 + + res = jnp.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + if improve_dft: - minimum_pattern_size_1 = 2 * ff * pmt.shape[1] - minimum_pattern_size_2 = 2 * ff * pmt.shape[2] + minimum_pattern_size_y = 2 * ff_y * ucell_pmt.shape[1] + minimum_pattern_size_x = 2 * ff_x * ucell_pmt.shape[2] else: - minimum_pattern_size_1 = 2 * ff - minimum_pattern_size_2 = 2 * ff + minimum_pattern_size_y = 2 * ff_y + minimum_pattern_size_x = 2 * ff_x + # 9 * (40*500) * (40*500) / 1E6 = 3600 MB = 3.6 GB - for i, layer in enumerate(pmt): - if layer.shape[0] < minimum_pattern_size_1: - n = minimum_pattern_size_1 // layer.shape[0] + for i, layer in enumerate(ucell_pmt): + if layer.shape[0] < minimum_pattern_size_y: + n = minimum_pattern_size_y // layer.shape[0] layer = jnp.repeat(layer, n + 1, axis=0) - if layer.shape[1] < minimum_pattern_size_2: - n = minimum_pattern_size_2 // layer.shape[1] + if layer.shape[1] < minimum_pattern_size_x: + n = minimum_pattern_size_x // layer.shape[1] layer = jnp.repeat(layer, n + 1, axis=1) f_coeffs = jnp.fft.fftshift(jnp.fft.fft2(layer / layer.size)) center = jnp.array(f_coeffs.shape) // 2 - conv_idx = jnp.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) + conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) + conv_i = jnp.repeat(conv_i, jnp.array([ff_x] * ff_y), axis=0) + + conv_idx_x = jnp.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) - conv_i = jnp.repeat(conv_idx, ff, 1) - conv_i = jnp.repeat(conv_i, ff, axis=0) - conv_j = jnp.tile(conv_idx, (ff, ff)) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res = res.at[i].set(e_conv) diff --git a/meent/on_jax/emsolver/field_distribution.py b/meent/on_jax/emsolver/field_distribution.py index 895b5fb..315ae36 100644 --- a/meent/on_jax/emsolver/field_distribution.py +++ b/meent/on_jax/emsolver/field_distribution.py @@ -16,11 +16,11 @@ def field_distribution(grating_type, *args, **kwargs): return res -def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100), +def field_dist_1d_original(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100), type_complex=jnp.complex128): k0 = 2 * jnp.pi / wavelength - fourier_indices = jnp.arange(-fourier_order, fourier_order + 1) + fourier_indices = jnp.arange(-fourier_order[0], fourier_order[0] + 1) # kx_vector = k0 * (n_I * jnp.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) Kx = jnp.diag(kx_vector / k0) @@ -28,7 +28,8 @@ def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_in resolution_z, resolution_y, resolution_x = resolution # Here use numpy array due to slow assignment speed in JAX - field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + # field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + field_cell = jnp.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) T_layer = T1 @@ -55,7 +56,93 @@ def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_in for j in range(resolution_y): for i in range(resolution_x): res = x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector) - field_cell[resolution_z * idx_layer + k, j, i] = res + # field_cell[resolution_z * idx_layer + k, j, i] = res + field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set(res) + + T_layer = a_i @ X @ T_layer + + return field_cell + + +@partial(jax.jit, static_argnums=(4, 5, 9, 10, 11, )) +def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order_x, fourier_order_y, T1, layer_info_list, period, pol, resolution=(100, 1, 100), + type_complex=jnp.complex128): + + k0 = 2 * jnp.pi / wavelength + fourier_indices = jnp.arange(-fourier_order_x, fourier_order_x + 1) + # kx_vector = k0 * (n_I * jnp.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) + + Kx = jnp.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + + # Here use numpy array due to slow assignment speed in JAX + # field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + field_cell = jnp.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + + T_layer = T1 + + # From the first layer + for idx_layer, (E_conv_i, q, W, X, a_i, b, d) in enumerate(layer_info_list[::-1]): + + c1 = T_layer[:, None] + c2 = b @ a_i @ X @ T_layer[:, None] + + Q = jnp.diag(q) + + if pol == 0: + V = W @ Q + EKx = None + + else: + V = E_conv_i @ W @ Q + EKx = E_conv_i @ Kx + + for k in range(resolution_z): + z = k / resolution_z * d + + if pol == 0: + Sy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Ux = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + C = (-1j) * Kx @ Sy + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Ey = Sy.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hx = -1j * Ux.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hz = C.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set([Ey[0, 0], Hx[0, 0], Hz[0, 0]]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set([Ey, Hx, Hz]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i, 0].set(Ey[0, 0]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i, 1].set(Hx[0, 0]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i, 2].set(Hz[0, 0]) + # res = [Ey[0, 0], Hx[0, 0], Hz[0, 0]] + + else: + Uy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Sx = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + + C = (-1j) * EKx @ Uy # there is a better option for convergence + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Hy = Uy.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ex = 1j * Sx.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ez = C.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set([Hy[0, 0], Ex[0, 0], Ez[0, 0]]) + # res = [Hy[0, 0], Ex[0, 0], Ez[0, 0]] + + # A, B, C = z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx) + # for j in range(resolution_y): + # for i in range(resolution_x): + # res = x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector) + # # field_cell[resolution_z * idx_layer + k, j, i] = res + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set(res) T_layer = a_i @ X @ T_layer @@ -143,7 +230,7 @@ def field_dist_2d(wavelength, kx_vector, n_I, theta, phi, fourier_order, T1, lay return field_cell -@partial(jax.jit, static_argnums=(0,)) +# @partial(jax.jit, static_argnums=(0,)) def z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx): if pol == 0: # TE @@ -162,7 +249,7 @@ def z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx): return Uy, Sx, C -@partial(jax.jit, static_argnums=(0,)) +# @partial(jax.jit, static_argnums=(0,)) def x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector): if pol == 0: # TE diff --git a/meent/on_jax/emsolver/rcwa.py b/meent/on_jax/emsolver/rcwa.py index 5428d70..817233a 100644 --- a/meent/on_jax/emsolver/rcwa.py +++ b/meent/on_jax/emsolver/rcwa.py @@ -34,7 +34,7 @@ def __init__(self, type_complex=jnp.complex128, fft_type=0, improve_dft=True, - **kwargs, + **kwargs, # TODO: need htis? ): super().__init__(grating_type=grating_type, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, pol=pol, @@ -77,7 +77,7 @@ def _tree_flatten(self): def _tree_unflatten(cls, aux_data, children): return cls(*children, **aux_data) - @jax.jit + # @jax.jit def _solve(self, wavelength, e_conv_all, o_e_conv_all): self.kx_vector = self.get_kx_vector(wavelength) @@ -93,19 +93,22 @@ def _solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri.real, de_ti.real, layer_info_list, T1, self.kx_vector def solve(self, wavelength, e_conv_all, o_e_conv_all): - de_ri, de_ti, layer_info_list, T1, self.kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + + self.layer_info_list = layer_info_list + self.T1 = T1 + self.kx_vector = kx_vector + return de_ri, de_ti # @jax.jit # TODO: can draw field? with jit? - def conv_solve(self, **kwargs): - [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? + def _conv_solve(self): if self.fft_type == 0: - E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - elif self.fft_type == 1: - E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) - o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) + E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex, improve_dft=self.improve_dft) + o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex, improve_dft=self.improve_dft) elif self.fft_type == 2: E_conv_all, o_E_conv_all = to_conv_mat_continuous_vector(self.ucell_info_list, self.fourier_order, type_complex=self.type_complex) @@ -113,6 +116,29 @@ def conv_solve(self, **kwargs): raise ValueError de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(self.wavelength, E_conv_all, o_E_conv_all) + # self.layer_info_list = layer_info_list + # self.T1 = T1 + # self.kx_vector = kx_vector + return de_ri, de_ti, layer_info_list, T1, kx_vector + + def _conv_solve_no_jit(self): + + E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex) + o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex) + + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(self.wavelength, E_conv_all, o_E_conv_all) + + return de_ri, de_ti, layer_info_list, T1, kx_vector + + def conv_solve(self, **kwargs): + [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? + + if self.fft_type == 1: + de_ri, de_ti, layer_info_list, T1, kx_vector = self._conv_solve_no_jit() + else: + de_ri, de_ti, layer_info_list, T1, kx_vector = self._conv_solve() self.layer_info_list = layer_info_list self.T1 = T1 @@ -122,8 +148,8 @@ def conv_solve(self, **kwargs): @jax.jit def conv_solve_spectrum(self, ucell): # TODO: other backends - E_conv_all = to_conv_mat_discrete(ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - o_E_conv_all = to_conv_mat_discrete(1 / ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) + E_conv_all = to_conv_mat_discrete(ucell, self.fourier_order[0], self.fourier_order[1], type_complex=self.type_complex, improve_dft=self.improve_dft) + o_E_conv_all = to_conv_mat_discrete(1 / ucell, self.fourier_order[0], self.fourier_order[1], type_complex=self.type_complex, improve_dft=self.improve_dft) de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(self.wavelength, E_conv_all, o_E_conv_all) return de_ri, de_ti @@ -149,24 +175,51 @@ def run_ucell_pmap(self, ucell_list): de_ti = np.array(de_ti) return de_ri, de_ti + # TODO: jit? fourier order split in args? + # @jax.jit def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, *self.fourier_order, self.T1, self.layer_info_list, self.period, self.pol, resolution=resolution, type_complex=self.type_complex) elif self.grating_type == 1: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, - type_complex=self.type_complex) + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) else: resolution = [10, 10, 10] if not resolution else resolution - field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) + if plot: + field_plot(field_cell, self.pol) + return field_cell + + @jax.jit + def conv_solve_calculate_field(self, resolution=None, plot=False): + self._conv_solve() + if self.grating_type == 0: + resolution = (20, 20, 20) + + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, + self.layer_info_list, self.period, self.pol, resolution=resolution, type_complex=self.type_complex) + elif self.grating_type == 1: + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) + + else: + resolution = [10, 10, 10] if not resolution else resolution + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) if plot: field_plot(field_cell, self.pol) return field_cell diff --git a/meent/on_jax/emsolver/transfer_method.py b/meent/on_jax/emsolver/transfer_method.py index a3493ab..19fad12 100644 --- a/meent/on_jax/emsolver/transfer_method.py +++ b/meent/on_jax/emsolver/transfer_method.py @@ -4,8 +4,6 @@ import jax import jax.numpy as jnp -# import meent.on_jax.jitted as ee -# from . import jitted as ee from .primitives import eig @@ -41,7 +39,7 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, f else: raise ValueError - T = jnp.eye(2 * fourier_order + 1).astype(type_complex) + T = jnp.eye(2 * fourier_order[0] + 1).astype(type_complex) return kx_vector, Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T @@ -57,8 +55,8 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, type_complex=jnp.compl a_i = jnp.linalg.inv(a) - f = W @ (jnp.eye(2 * fourier_order + 1).astype(type_complex) + X @ b @ a_i @ X) - g = V @ (jnp.eye(2 * fourier_order + 1).astype(type_complex) - X @ b @ a_i @ X) + f = W @ (jnp.eye(2 * fourier_order[0] + 1).astype(type_complex) + X @ b @ a_i @ X) + g = V @ (jnp.eye(2 * fourier_order[0] + 1).astype(type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X return X, f, g, T, a_i, b @@ -230,15 +228,15 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 -def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, +def transfer_2d_1(ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, type_complex=jnp.complex128): - I = jnp.eye(ff ** 2).astype(type_complex) - O = jnp.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = jnp.eye(ff_xy).astype(type_complex) + O = jnp.zeros((ff_xy, ff_xy), dtype=type_complex) # kx_vector = k0 * (n_I * jnp.sin(theta) * jnp.cos(phi) + fourier_indices * ( # wavelength / period[0])).astype(type_complex) - ky_vector = k0 * (n_I * jnp.sin(theta) * jnp.sin(phi) + fourier_indices * ( + ky_vector = k0 * (n_I * jnp.sin(theta) * jnp.sin(phi) + fourier_indices_y * ( wavelength / period[1])).astype(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 @@ -247,8 +245,8 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, k_I_z = k_I_z.flatten().conjugate() k_II_z = k_II_z.flatten().conjugate() - Kx = jnp.diag(jnp.tile(kx_vector, ff).flatten()) / k0 - Ky = jnp.diag(jnp.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + Kx = jnp.diag(jnp.tile(kx_vector, ff_y).flatten()) / k0 + Ky = jnp.diag(jnp.tile(ky_vector.reshape((-1, 1)), ff_x).flatten()) / k0 varphi = jnp.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() @@ -261,13 +259,13 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, big_F = jnp.block([[I, O], [O, 1j * Z_II]]) big_G = jnp.block([[1j * Y_II, O], [O, I]]) - big_T = jnp.eye(ff ** 2 * 2).astype(type_complex) + big_T = jnp.eye(2 * ff_xy).astype(type_complex) return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=jnp.complex128, perturbation=1E-10, device='cpu'): - I = jnp.eye(ff ** 2).astype(type_complex) +def transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=jnp.complex128, perturbation=1E-10, device='cpu'): + I = jnp.eye(ff_xy).astype(type_complex) B = Kx @ E_conv_i @ Kx - I D = Ky @ E_conv_i @ Ky - I @@ -347,10 +345,10 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, typ return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, type_complex=jnp.complex128): - I = jnp.eye(ff ** 2).astype(type_complex) - O = jnp.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = jnp.eye(ff_xy).astype(type_complex) + O = jnp.zeros((ff_xy, ff_xy), dtype=type_complex) big_F_11 = big_F[:center, :center] big_F_12 = big_F[:center, center:] @@ -383,14 +381,14 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i final_RT = jnp.linalg.inv(final_A) @ final_B - R_s = final_RT[:ff ** 2, :].flatten() - R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + R_s = final_RT[:ff_xy, :].flatten() + R_p = final_RT[ff_xy:2 * ff_xy, :].flatten() - big_T1 = final_RT[2 * ff ** 2:, :] + big_T1 = final_RT[2 * ff_xy:, :] big_T = big_T @ big_T1 - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + T_s = big_T[:ff_xy, :].flatten() + T_p = big_T[ff_xy:, :].flatten() de_ri = R_s * jnp.conj(R_s) * jnp.real(k_I_z / (k0 * n_I * jnp.cos(theta))) \ + R_p * jnp.conj(R_p) * jnp.real((k_I_z / n_I ** 2) / (k0 * n_I * jnp.cos(theta))) diff --git a/meent/on_numpy/emsolver/_base.py b/meent/on_numpy/emsolver/_base.py index 821674e..be1c3cd 100644 --- a/meent/on_numpy/emsolver/_base.py +++ b/meent/on_numpy/emsolver/_base.py @@ -9,7 +9,7 @@ class _BaseRCWA: - def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=10, + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=(2, 2), period=(100, 100), wavelength=900, thickness=None, algo='TMM', perturbation=1E-10, device='cpu', type_complex=np.complex128): @@ -35,8 +35,12 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = int(fourier_order) - self.ff = 2 * self.fourier_order + 1 + if type(fourier_order) == int: + self.fourier_order = [fourier_order, 0] + elif len(fourier_order) == 1: + self.fourier_order = list(fourier_order) + [0] + else: + self.fourier_order = [int(v) for v in fourier_order] self.period = deepcopy(period) @@ -49,58 +53,50 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= self.layer_info_list = [] self.T1 = None - if self.theta == 0: - self.theta = self.perturbation + self.theta = np.where(self.theta == 0, self.perturbation, self.theta) - self.kx_vector = None + self.kx_vector = None # tODO: need this? why only kx, not ky? def get_kx_vector(self, wavelength): k0 = 2 * np.pi / wavelength - fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) - if self.grating_type == 0: - kx_vector = k0 * (self.n_I * np.sin(self.theta) + fourier_indices * (wavelength / self.period[0]) + fourier_indices_x = np.arange(-self.fourier_order[0], self.fourier_order[0] + 1) + + if self.grating_type == 0: # TODO: need this? + kx_vector = k0 * (self.n_I * np.sin(self.theta) + fourier_indices_x * (wavelength / self.period[0]) ).astype(self.type_complex) else: - kx_vector = k0 * (self.n_I * np.sin(self.theta) * np.cos(self.phi) + fourier_indices * ( + kx_vector = k0 * (self.n_I * np.sin(self.theta) * np.cos(self.phi) + fourier_indices_x * ( wavelength / self.period[0])).astype(self.type_complex) - # kx_vector = kx_vector.conjugate() # kx_vector = np.where(kx_vector == 0, self.perturbation, kx_vector) return kx_vector def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): - self.layer_info_list = [] self.T1 = None - fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = np.zeros(self.ff, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + delta_i0 = np.zeros(ff, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + = transfer_1d_1(ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ - = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, + = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, self.period, self.pol, wl=wavelength) else: raise ValueError - count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) - # From the last layer - for layer_index in range(count)[::-1]: - - E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] - d = self.thickness[layer_index] + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): if self.pol == 0: E_conv_i = None @@ -143,7 +139,7 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.T1 = T1 elif self.algo == 'SMM': - de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, self.ff, Wr, self.fourier_order, Kzr, Kzt, + de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, ff, Wr, self.fourier_order, Kzr, Kzt, self.n_I, self.n_II, self.theta, self.pol) else: raise ValueError @@ -155,16 +151,16 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - # fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = np.zeros(self.ff, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + delta_i0 = np.zeros(ff, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + = transfer_1d_conical_1(ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') @@ -172,21 +168,22 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError - count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) - - # From the last layer - for layer_index in range(count)[::-1]: - - E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] - d = self.thickness[layer_index] + # count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + # + # # From the last layer + # for layer_index in range(count)[::-1]: + # + # E_conv = E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + # d = self.thickness[layer_index] + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): E_conv_i = np.linalg.inv(E_conv) o_E_conv_i = np.linalg.inv(o_E_conv) if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ - = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, type_complex=self.type_complex) layer_info = [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] @@ -198,7 +195,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, type_complex=self.type_complex) self.T1 = big_T1 @@ -215,21 +212,25 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + fourier_indices_y = np.arange(-self.fourier_order[1], self.fourier_order[1] + 1) + + ff_x = self.fourier_order[0] * 2 + 1 + ff_y = self.fourier_order[1] * 2 + 1 + ff_xy = ff_x * ff_y - delta_i0 = np.zeros((self.ff ** 2, 1), dtype=self.type_complex) - delta_i0[self.ff ** 2 // 2, 0] = 1 + delta_i0 = np.zeros((ff_xy, 1), dtype=self.type_complex) + delta_i0[ff_xy // 2, 0] = 1 - I = np.eye(self.ff ** 2, dtype=self.type_complex) - O = np.zeros((self.ff ** 2, self.ff ** 2), dtype=self.type_complex) + I = np.eye(ff_xy, dtype=self.type_complex) + O = np.zeros((ff_xy, ff_xy), dtype=self.type_complex) - center = self.ff ** 2 + center = ff_xy k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + = transfer_2d_1(ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, self.theta, self.phi, wavelength, type_complex=self.type_complex) elif self.algo == 'SMM': @@ -238,20 +239,20 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError - count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) - - # From the last layer - for layer_index in range(count)[::-1]: - - E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] - d = self.thickness[layer_index] - + # count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + # + # # From the last layer + # for layer_index in range(count)[::-1]: + # + # E_conv = E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + # d = self.thickness[layer_index] + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): E_conv_i = np.linalg.inv(E_conv) o_E_conv_i = np.linalg.inv(o_E_conv) if self.algo == 'TMM': - W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) + W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) big_X, big_F, big_G, big_T, big_A_i, big_B, \ W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ @@ -262,24 +263,24 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, q = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, q = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, q) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff_xy, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, type_complex=self.type_complex) self.T1 = big_T1 elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, - self.pol, self.theta, self.phi, self.fourier_order, self.ff) + self.pol, self.theta, self.phi, self.fourier_order) else: raise ValueError - de_ri = de_ri.reshape((self.ff, self.ff)).real - de_ti = de_ti.reshape((self.ff, self.ff)).real + de_ri = de_ri.reshape((ff_y, ff_x)).real + de_ti = de_ti.reshape((ff_y, ff_x)).real return de_ri, de_ti, self.layer_info_list, self.T1 diff --git a/meent/on_numpy/emsolver/convolution_matrix.py b/meent/on_numpy/emsolver/convolution_matrix.py index 6505f9e..e425e96 100644 --- a/meent/on_numpy/emsolver/convolution_matrix.py +++ b/meent/on_numpy/emsolver/convolution_matrix.py @@ -42,19 +42,18 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): reference: reticolo """ - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] + if len(fourier_order) == 1: + fourier_order = fourier_order + [0] + cell, x, y = cell_compression(cell, type_complex=type_complex) # X axis cell_next_x = np.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) - f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes[None, :], dtype=type_complex) + f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes_x[None, :], dtype=type_complex) c = f_coeffs_x.shape[1] // 2 # x_next = np.vstack(np.roll(x, -1, axis=0)[:-1]) - x @@ -63,15 +62,15 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): f_coeffs_x[:, c] = (cell @ np.vstack((x[0], x_next[:-1]))).flatten() mask = np.ones(f_coeffs_x.shape[1], dtype=bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = np.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes[None, :], dtype=type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes_y[None, :], dtype=type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = np.vstack((np.roll(y, -1, axis=0)[:-1], 1)) - y @@ -81,46 +80,38 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): if c: mask = np.ones(f_coeffs_xy.shape[1], dtype=bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.complex128): - """ - reference: reticolo - """ - - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] - # cell, x, y = cell_compression(cell, type_complex=type_complex) + if len(fourier_order) == 1: + fourier_order = fourier_order + [0] # X axis cell_next_x = np.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) - f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes[None, :], dtype=type_complex) + f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes_x[None, :], dtype=type_complex) c = f_coeffs_x.shape[1] // 2 - # x_next = np.vstack(np.roll(x, -1, axis=0)[:-1]) - x x_next = np.vstack((np.roll(x, -1, axis=0)[:-1], 1)) - x f_coeffs_x[:, c] = (cell @ np.vstack((x[0], x_next[:-1]))).flatten() mask = np.ones(f_coeffs_x.shape[1], dtype=bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = np.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes[None, :], dtype=type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes_y[None, :], dtype=type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = np.vstack((np.roll(y, -1, axis=0)[:-1], 1)) - y @@ -130,33 +121,52 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.com if c: mask = np.ones(f_coeffs_xy.shape[1], dtype=bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, type_complex=np.complex128): # TODO: do conv and 1/conv in simultaneously? - ff = 2 * fourier_order + 1 - e_conv_all = np.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) - o_e_conv_all = np.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 + + e_conv_all = np.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + o_e_conv_all = np.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) # 2D # tODO: 1D for i, ucell_info in enumerate(ucell_info_list): ucell_layer, x_list, y_list = ucell_info ucell_layer = ucell_layer ** 2 + f_coeffs = fft_piecewise_constant_vector(ucell_layer, x_list, y_list, fourier_order, type_complex=type_complex) o_f_coeffs = fft_piecewise_constant_vector(1/ucell_layer, x_list, y_list, fourier_order, type_complex=type_complex) center = np.array(f_coeffs.shape) // 2 - conv_idx = np.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = np.repeat(conv_idx, ff, axis=1) - conv_i = np.repeat(conv_i, [ff] * ff, axis=0) - conv_j = np.tile(conv_idx, (ff, ff)) + # conv_idx = np.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # conv_i = np.repeat(conv_idx, ff, axis=1) + # conv_i = np.repeat(conv_i, [ff] * ff, axis=0) + # conv_j = np.tile(conv_idx, (ff, ff)) + # + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] + # + # e_conv_all[i] = e_conv + # o_e_conv_all[i] = o_e_conv + + # TODO: check correct + conv_idx_y = np.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = np.repeat(conv_idx_y, ff_x, axis=1) + conv_i = np.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = np.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] @@ -167,21 +177,18 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, t return e_conv_all, o_e_conv_all -def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=np.complex128): - pmt = pmt ** 2 +def to_conv_mat_continuous(ucell, fourier_order, device=None, type_complex=np.complex128): + ucell_pmt = ucell ** 2 # TODO: do conv and 1/conv in simultaneously? - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - ff = 2 * fourier_order + 1 - if pmt.shape[1] == 1: # 1D - res = np.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + if ucell_pmt.shape[1] == 1: # 1D + ff = 2 * fourier_order[0] + 1 - for i, layer in enumerate(pmt): - f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) + res = np.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) + for i, layer in enumerate(ucell_pmt): + f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) center = f_coeffs.shape[1] // 2 conv_idx = np.arange(-ff + 1, ff, 1, dtype=int) conv_idx = circulant(conv_idx) @@ -189,40 +196,54 @@ def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=np.comp res[i] = e_conv else: # 2D - # attention on the order of axis (Z Y X) - res = np.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) - for i, layer in enumerate(pmt): + # TODO: cleaning + + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 + + res = np.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + + for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) center = np.array(f_coeffs.shape) // 2 - conv_idx = np.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = np.repeat(conv_idx, ff, axis=1) - conv_i = np.repeat(conv_i, [ff] * ff, axis=0) - conv_j = np.tile(conv_idx, (ff, ff)) + # conv_idx = np.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # conv_i = np.repeat(conv_idx, ff, axis=1) + # conv_i = np.repeat(conv_i, [ff] * ff, axis=0) + # conv_j = np.tile(conv_idx, (ff, ff)) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + + conv_idx_y = np.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = np.repeat(conv_idx_y, ff_x, axis=1) + conv_i = np.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = np.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv return res -def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=np.complex128, improve_dft=True): - pmt = pmt ** 2 - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - ff = 2 * fourier_order + 1 +def to_conv_mat_discrete(ucell, fourier_order, device=None, type_complex=np.complex128, improve_dft=True): + ucell_pmt = ucell ** 2 - if pmt.shape[1] == 1: # 1D - res = np.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + if ucell_pmt.shape[1] == 1: # 1D + ff = 2 * fourier_order[0] + 1 + res = np.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) if improve_dft: - minimum_pattern_size = 2 * ff * pmt.shape[2] + minimum_pattern_size = 2 * ff * ucell_pmt.shape[2] else: minimum_pattern_size = 2 * ff - for i, layer in enumerate(pmt): + for i, layer in enumerate(ucell_pmt): n = minimum_pattern_size // layer.shape[1] layer = np.repeat(layer, n + 1, axis=1) @@ -231,41 +252,60 @@ def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=np.comple # https://kr.mathworks.com/matlabcentral/answers/15770-scaling-the-fft-and-the-ifft?s_tid=srchtitle center = f_coeffs.shape[1] // 2 - conv_idx = np.arange(-ff + 1, ff, 1, dtype=int) conv_idx = circulant(conv_idx) e_conv = f_coeffs[0, center + conv_idx] res[i] = e_conv else: # 2D - # attention on the order of axis (Z Y X) - # TODO: separate fourier order - res = np.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 + + res = np.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + if improve_dft: - minimum_pattern_size_1 = 2 * ff * pmt.shape[1] - minimum_pattern_size_2 = 2 * ff * pmt.shape[2] + minimum_pattern_size_y = 2 * ff_y * ucell_pmt.shape[1] + minimum_pattern_size_x = 2 * ff_x * ucell_pmt.shape[2] else: - minimum_pattern_size_1 = 2 * ff - minimum_pattern_size_2 = 2 * ff + minimum_pattern_size_y = 2 * ff_y + minimum_pattern_size_x = 2 * ff_x + # e.g., 9 bytes * (40*500) * (40*500) / 1E6 = 3600 MB = 3.6 GB - for i, layer in enumerate(pmt): - if layer.shape[0] < minimum_pattern_size_1: - n = minimum_pattern_size_1 // layer.shape[0] + for i, layer in enumerate(ucell_pmt): + + if layer.shape[0] < minimum_pattern_size_y: + n = minimum_pattern_size_y // layer.shape[0] layer = np.repeat(layer, n + 1, axis=0) - if layer.shape[1] < minimum_pattern_size_2: - n = minimum_pattern_size_2 // layer.shape[1] + + if layer.shape[1] < minimum_pattern_size_x: + n = minimum_pattern_size_x // layer.shape[1] layer = np.repeat(layer, n + 1, axis=1) f_coeffs = np.fft.fftshift(np.fft.fft2(layer / layer.size)) center = np.array(f_coeffs.shape) // 2 - conv_idx = np.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = np.repeat(conv_idx, ff, axis=1) - conv_i = np.repeat(conv_i, [ff] * ff, axis=0) - conv_j = np.tile(conv_idx, (ff, ff)) + # conv_idx = np.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # + # conv_i = np.repeat(conv_idx, ff, axis=1) + # conv_i = np.repeat(conv_i, [ff] * ff, axis=0) + # conv_j = np.tile(conv_idx, (ff, ff)) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + # + + # TODO: check correct + conv_idx_y = np.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = np.repeat(conv_idx_y, ff_x, axis=1) + conv_i = np.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = np.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv diff --git a/meent/on_numpy/emsolver/rcwa.py b/meent/on_numpy/emsolver/rcwa.py index 12651ee..c30fd46 100644 --- a/meent/on_numpy/emsolver/rcwa.py +++ b/meent/on_numpy/emsolver/rcwa.py @@ -31,12 +31,13 @@ def __init__(self, type_complex=np.complex128, fft_type=0, improve_dft=True, + **kwargs, ): super().__init__(grating_type=grating_type, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, pol=pol, fourier_order=fourier_order, period=period, wavelength=wavelength, thickness=thickness, algo=algo, perturbation=perturbation, - device=device, type_complex=type_complex,) + device=device, type_complex=type_complex, ) self.ucell = deepcopy(ucell) self.ucell_materials = ucell_materials @@ -52,7 +53,7 @@ def __init__(self, self.layer_info_list = [] def _solve(self, wavelength, e_conv_all, o_e_conv_all): - self.kx_vector = self.get_kx_vector(wavelength) + self.kx_vector = self.get_kx_vector(wavelength) # TODO: add ky_vector? if self.grating_type == 0: de_ri, de_ti, layer_info_list, T1 = self.solve_1d(wavelength, e_conv_all, o_e_conv_all) @@ -66,7 +67,12 @@ def _solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri.real, de_ti.real, layer_info_list, T1, self.kx_vector def solve(self, wavelength, e_conv_all, o_e_conv_all): - de_ri, de_ti, layer_info_list, T1, self.kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + + self.layer_info_list = layer_info_list + self.T1 = T1 + self.kx_vector = kx_vector + return de_ri, de_ti def conv_solve(self, *args, **kwargs): @@ -77,44 +83,12 @@ def conv_solve(self, *args, **kwargs): improve_dft=self.improve_dft) o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - E_conv_all1 = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) - o_E_conv_all1 = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) - - # print(1, np.linalg.norm(E_conv_all - E_conv_all1)) - # print(2, np.linalg.norm(o_E_conv_all1 - o_E_conv_all)) - - # import matplotlib.pyplot as plt - # plt.imshow(abs(E_conv_all[0])) - # plt.colorbar() - # plt.show() - # import matplotlib.pyplot as plt - # plt.imshow(abs(E_conv_all1[0])) - # plt.colorbar() - # plt.show() - # - # import matplotlib.pyplot as plt - # plt.imshow(abs(o_E_conv_all[0])) - # plt.colorbar() - # plt.show() - # import matplotlib.pyplot as plt - # plt.imshow(abs(o_E_conv_all[0])) - # plt.colorbar() - # plt.show() - # - # plt.imshow(abs(E_conv_all[0] - E_conv_all1[0])) - # plt.colorbar() - # plt.show() - # - # plt.imshow(abs(o_E_conv_all[0] - o_E_conv_all1[0])) - # plt.colorbar() - # plt.show() - elif self.fft_type == 1: E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) elif self.fft_type == 2: E_conv_all, o_E_conv_all = to_conv_mat_continuous_vector(self.ucell_info_list, self.fourier_order, - type_complex=self.type_complex) + type_complex=self.type_complex) else: raise ValueError @@ -130,23 +104,21 @@ def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, - self.layer_info_list, self.period, self.pol, resolution=resolution, + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, + self.T1, self.layer_info_list, self.period, self.pol, resolution=resolution, type_complex=self.type_complex) elif self.grating_type == 1: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, - self.T1, - self.layer_info_list, self.period, resolution=resolution, - type_complex=self.type_complex) + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) else: resolution = [10, 10, 10] if not resolution else resolution t0 = time.time() - field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, - type_complex=self.type_complex) - print(time.time() - t0) + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) if plot: field_plot(field_cell, self.pol) diff --git a/meent/on_numpy/emsolver/transfer_method.py b/meent/on_numpy/emsolver/transfer_method.py index c68cbcc..a83bbc7 100644 --- a/meent/on_numpy/emsolver/transfer_method.py +++ b/meent/on_numpy/emsolver/transfer_method.py @@ -35,7 +35,7 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, f else: raise ValueError - T = np.eye(2 * fourier_order + 1, dtype=type_complex) + T = np.eye(2 * fourier_order[0] + 1, dtype=type_complex) return kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T @@ -52,22 +52,21 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, type_complex=np.comple a_i = np.linalg.inv(a) - f = W @ (np.eye(2 * fourier_order + 1, dtype=type_complex) + X @ b @ a_i @ X) - g = V @ (np.eye(2 * fourier_order + 1, dtype=type_complex) - X @ b @ a_i @ X) + f = W @ (np.eye(2 * fourier_order[0] + 1, dtype=type_complex) + X @ b @ a_i @ X) + g = V @ (np.eye(2 * fourier_order[0] + 1, dtype=type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X return X, f, g, T, a_i, b def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): - T1 = np.linalg.inv(g1 + 1j * YZ_I @ f1) @ (1j * YZ_I @ delta_i0 + inc_term) R = f1 @ T1 - delta_i0 T = T @ T1 de_ri = np.real(R * np.conj(R) * k_I_z / (k0 * n_I * np.cos(theta))) if polarization == 0: - # de_ti = T * np.conj(T) * np.real(k_II_z / (k0 * n_I * np.cos(theta))) + # de_ti = T * np.conj(T) * np.real(k_II_z / (k0 * n_I * np.cos(theta))) #TODO: use this de_ti = np.real(T * np.conj(T) * k_II_z / (k0 * n_I * np.cos(theta))) elif polarization == 1: # de_ti = T * np.conj(T) * np.real(k_II_z / n_II ** 2) / (k0 * np.cos(theta) / n_I) @@ -114,7 +113,6 @@ def transfer_1d_conical_1(ff, k0, n_I, n_II, kx_vector, theta, phi, type_complex def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, type_complex=np.complex128): - I = np.eye(ff, dtype=type_complex) O = np.zeros((ff, ff), dtype=type_complex) @@ -178,7 +176,6 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, type_complex=np.complex128): - I = np.eye(ff, dtype=type_complex) O = np.zeros((ff, ff), dtype=type_complex) @@ -229,16 +226,15 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 -def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, +def transfer_2d_1(ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, type_complex=np.complex128): - - I = np.eye(ff ** 2, dtype=type_complex) - O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = np.eye(ff_xy, dtype=type_complex) + O = np.zeros((ff_xy, ff_xy), dtype=type_complex) # kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) + fourier_indices * ( # wavelength / period[0])).astype(type_complex) - ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) + fourier_indices * ( + ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) + fourier_indices_y * ( wavelength / period[1])).astype(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 @@ -247,8 +243,8 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, k_I_z = k_I_z.flatten().conjugate() k_II_z = k_II_z.flatten().conjugate() - Kx = np.diag(np.tile(kx_vector, ff).flatten()) / k0 - Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + Kx = np.diag(np.tile(kx_vector, ff_y).flatten()) / k0 + Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff_x).flatten()) / k0 varphi = np.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() @@ -261,14 +257,13 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, big_F = np.block([[I, O], [O, 1j * Z_II]]) big_G = np.block([[1j * Y_II, O], [O, I]]) - big_T = np.eye(2 * ff ** 2, dtype=type_complex) + big_T = np.eye(2 * ff_xy, dtype=type_complex) return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=np.complex128): - - I = np.eye(ff ** 2, dtype=type_complex) +def transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=np.complex128): + I = np.eye(ff_xy, dtype=type_complex) B = Kx @ E_conv_i @ Kx - I D = Ky @ E_conv_i @ Ky - I @@ -349,11 +344,10 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, typ return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, type_complex=np.complex128): - - I = np.eye(ff ** 2, dtype=type_complex) - O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = np.eye(ff_xy, dtype=type_complex) + O = np.zeros((ff_xy, ff_xy), dtype=type_complex) big_F_11 = big_F[:center, :center] big_F_12 = big_F[:center, center:] @@ -386,14 +380,14 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i final_RT = np.linalg.inv(final_A) @ final_B - R_s = final_RT[:ff ** 2, :].flatten() - R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + R_s = final_RT[:ff_xy, :].flatten() + R_p = final_RT[ff_xy: 2 * ff_xy, :].flatten() - big_T1 = final_RT[2 * ff ** 2:, :] + big_T1 = final_RT[2 * ff_xy:, :] big_T = big_T @ big_T1 - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + T_s = big_T[:ff_xy, :].flatten() + T_p = big_T[ff_xy:, :].flatten() de_ri = R_s * np.conj(R_s) * np.real(k_I_z / (k0 * n_I * np.cos(theta))) \ + R_p * np.conj(R_p) * np.real((k_I_z / n_I ** 2) / (k0 * n_I * np.cos(theta))) @@ -402,4 +396,3 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i + T_p * np.conj(T_p) * np.real((k_II_z / n_II ** 2) / (k0 * n_I * np.cos(theta))) return de_ri.real, de_ti.real, big_T1 - diff --git a/meent/on_torch/emsolver/_base.py b/meent/on_torch/emsolver/_base.py index 4608760..a4d09a0 100644 --- a/meent/on_torch/emsolver/_base.py +++ b/meent/on_torch/emsolver/_base.py @@ -12,23 +12,20 @@ class _BaseRCWA: - def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=10, + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=(2, 2), period=(100, 100), wavelength=900, thickness=None, algo='TMM', perturbation=1E-10, - device='cpu', type_complex=torch.complex128, **kwargs): + device='cpu', type_complex=torch.complex128): self.device = device self.type_complex = type_complex # common self.grating_type = grating_type # 1D=0, 1D_conical=1, 2D=2 + self.n_I = n_I self.n_II = n_II - # self.theta = torch.tensor(theta * np.pi / 180) - # self.phi = torch.tensor(phi * np.pi / 180) - # self.psi = torch.tensor(psi * np.pi / 180) - # degree to radian due to JAX JIT self.theta = torch.tensor(theta) self.phi = torch.tensor(phi) @@ -43,8 +40,12 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = int(fourier_order) - self.ff = 2 * self.fourier_order + 1 + if type(fourier_order) == int: + self.fourier_order = [fourier_order, 0] + elif len(fourier_order) == 1: + self.fourier_order = list(fourier_order) + [0] + else: + self.fourier_order = [int(v) for v in fourier_order] self.period = deepcopy(period) @@ -57,15 +58,14 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= self.layer_info_list = [] self.T1 = None - if self.theta == 0: - self.theta = torch.tensor(self.perturbation) + self.theta = torch.where(self.theta == 0, self.perturbation, self.theta) # TODO: check correct? self.kx_vector = None def get_kx_vector(self, wavelength): k0 = 2 * np.pi / wavelength - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + fourier_indices = torch.arange(-self.fourier_order[0], self.fourier_order[0] + 1, device=self.device) if self.grating_type == 0: kx_vector = k0 * (self.n_I * torch.sin(self.theta) + fourier_indices * (wavelength / self.period[0]) ).type(self.type_complex) @@ -82,16 +82,18 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + # fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = torch.zeros(self.ff, device=self.device, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + delta_i0 = torch.zeros(ff, device=self.device, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + = transfer_1d_1(ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, self.theta, delta_i0, self.fourier_order, device=self.device, type_complex=self.type_complex) elif self.algo == 'SMM': @@ -110,6 +112,9 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): o_E_conv = o_E_conv_all[layer_index] d = self.thickness[layer_index] + # Can't use this. seems like bug in Torch. + # for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + if self.pol == 0: E_conv_i = None A = Kx ** 2 - E_conv @@ -168,16 +173,17 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + # fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = torch.zeros(self.ff, device=self.device, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + delta_i0 = torch.zeros(ff, device=self.device, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + = transfer_1d_conical_1(ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, device=self.device, type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') @@ -194,12 +200,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): o_E_conv = o_E_conv_all[layer_index] d = self.thickness[layer_index] + # Can't use this. Seems like bug in Torch. + # for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): E_conv_i = torch.linalg.inv(E_conv) o_E_conv_i = torch.linalg.inv(o_E_conv) if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2\ - = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, device=self.device, type_complex=self.type_complex) @@ -212,7 +220,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, device=self.device, type_complex=self.type_complex) self.T1 = big_T1 @@ -229,21 +237,26 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + # fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + fourier_indices_y = torch.arange(-self.fourier_order[1], self.fourier_order[1] + 1, device=self.device) + + ff_x = self.fourier_order[0] * 2 + 1 + ff_y = self.fourier_order[1] * 2 + 1 + ff_xy = ff_x * ff_y - delta_i0 = torch.zeros((self.ff ** 2, 1), device=self.device, dtype=self.type_complex) - delta_i0[self.ff ** 2 // 2, 0] = 1 + delta_i0 = torch.zeros((ff_xy, 1), device=self.device, dtype=self.type_complex) + delta_i0[ff_xy // 2, 0] = 1 - I = torch.eye(self.ff ** 2, device=self.device, dtype=self.type_complex) - O = torch.zeros((self.ff ** 2, self.ff ** 2), device=self.device, dtype=self.type_complex) + I = torch.eye(ff_xy, device=self.device, dtype=self.type_complex) + O = torch.zeros((ff_xy, ff_xy), device=self.device, dtype=self.type_complex) - center = self.ff ** 2 + center = ff_xy k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + = transfer_2d_1(ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, self.theta, self.phi, wavelength, device=self.device, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ @@ -251,6 +264,8 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError + # for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) # From the last layer @@ -259,12 +274,11 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): E_conv = E_conv_all[layer_index] o_E_conv = o_E_conv_all[layer_index] d = self.thickness[layer_index] - E_conv_i = torch.linalg.inv(E_conv) o_E_conv_i = torch.linalg.inv(o_E_conv) if self.algo == 'TMM': - W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, + W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device=self.device, type_complex=self.type_complex) big_X, big_F, big_G, big_T, big_A_i, big_B, \ @@ -276,24 +290,24 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, LAMBDA = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff_xy, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, device=self.device, type_complex=self.type_complex) self.T1 = big_T1 elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, - self.pol, self.theta, self.phi, self.fourier_order, self.ff) + self.pol, self.theta, self.phi, self.fourier_order) else: raise ValueError - de_ri = de_ri.reshape((self.ff, self.ff)).real - de_ti = de_ti.reshape((self.ff, self.ff)).real + de_ri = de_ri.reshape((ff_y, ff_x)).real + de_ti = de_ti.reshape((ff_y, ff_x)).real return de_ri, de_ti, self.layer_info_list, self.T1 diff --git a/meent/on_torch/emsolver/convolution_matrix.py b/meent/on_torch/emsolver/convolution_matrix.py index 91cbeed..13ec428 100644 --- a/meent/on_torch/emsolver/convolution_matrix.py +++ b/meent/on_torch/emsolver/convolution_matrix.py @@ -39,37 +39,37 @@ def cell_compression(cell, device=torch.device('cpu'), type_complex=torch.comple def fft_piecewise_constant(cell, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] + if len(fourier_order) == 1: + fourier_order = fourier_order + [0] + cell, x, y = cell_compression(cell, device=device, type_complex=type_complex) # X axis cell_next_x = torch.roll(cell, -1, dims=1) cell_diff_x = cell_next_x - cell + cell_diff_x = cell_diff_x.type(type_complex) - modes = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) + cell = cell.type(type_complex) - cell_diff_x = cell_diff_x.type(type_complex) - f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes[None, :]).type(type_complex) + modes_x = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + + f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes_x[None, :]).type(type_complex) c = f_coeffs_x.shape[1] // 2 - cell = cell.type(type_complex) x_next = torch.vstack((torch.roll(x, -1, dims=0)[:-1], torch.tensor([1], device=device))) - x f_coeffs_x[:, c] = (cell @ torch.vstack((x[0], x_next[:-1]))).flatten() mask = torch.ones(f_coeffs_x.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = torch.roll(f_coeffs_x, -1, dims=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + modes_y = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) - f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes[None, :]) + f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes_y[None, :]) c = f_coeffs_xy.shape[1] // 2 y_next = torch.vstack((torch.roll(y, -1, dims=0)[:-1], torch.tensor([1], device=device))) - y @@ -79,43 +79,41 @@ def fft_piecewise_constant(cell, fourier_order, device=torch.device('cpu'), type if c: mask = torch.ones(f_coeffs_xy.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T def fft_piecewise_constant_vector(cell, x, y, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] # tODO - else: - fourier_order = [fourier_order, fourier_order] - # cell, x, y = cell_compression(cell, device=device, type_complex=type_complex) + if len(fourier_order) == 1: + fourier_order = fourier_order + [0] # X axis cell_next_x = torch.roll(cell, -1, dims=1) cell_diff_x = cell_next_x - cell + cell_diff_x = cell_diff_x.type(type_complex) - modes = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) + cell = cell.type(type_complex) - cell_diff_x = cell_diff_x.type(type_complex) - f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes[None, :]).type(type_complex) + modes_x = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + + f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes_x[None, :]).type(type_complex) c = f_coeffs_x.shape[1] // 2 - cell = cell.type(type_complex) x_next = torch.vstack((torch.roll(x, -1, dims=0)[:-1], torch.tensor([1], device=device))) - x f_coeffs_x[:, c] = (cell @ torch.vstack((x[0], x_next[:-1]))).flatten() mask = torch.ones(f_coeffs_x.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = torch.roll(f_coeffs_x, -1, dims=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + modes_y = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) - f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes[None, :]) + f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes_y[None, :]) c = f_coeffs_xy.shape[1] // 2 y_next = torch.vstack((torch.roll(y, -1, dims=0)[:-1], torch.tensor([1], device=device))) - y @@ -125,17 +123,18 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, device=torch.device if c: mask = torch.ones(f_coeffs_xy.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - ff = 2 * fourier_order + 1 + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 - e_conv_all = torch.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).type(type_complex) - o_e_conv_all = torch.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).type(type_complex) + e_conv_all = torch.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).type(type_complex) + o_e_conv_all = torch.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).type(type_complex) # 2D # tODO: 1D for i, ucell_info in enumerate(ucell_info_list): @@ -152,11 +151,26 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=torch.d center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc') - conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) - conv_idx = circulant(conv_idx, device) - conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) - conv_i = conv_i.repeat_interleave(ff, dim=0) - conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + # conv_idx = circulant(conv_idx, device) + # conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) + # conv_i = conv_i.repeat_interleave(ff, dim=0) + # conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] + # + # e_conv_all[i] = e_conv + # o_e_conv_all[i] = o_e_conv + + conv_idx_y = torch.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y, device=device) + conv_i = conv_idx_y.repeat_interleave(ff_x, dim=1).type(torch.long) + conv_i = conv_i.repeat_interleave(ff_x, dim=0) + + conv_idx_x = torch.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = conv_idx_x.repeat(ff_y, ff_y).type(torch.long) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] @@ -167,21 +181,15 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=torch.d return e_conv_all, o_e_conv_all -def to_conv_mat_continuous(pmt, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - pmt = pmt ** 2 +def to_conv_mat_continuous(ucell, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): + ucell_pmt = ucell ** 2 - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError + if ucell_pmt.shape[1] == 1: # 1D + ff = 2 * fourier_order[0] + 1 - # pmt = torch.tensor(pmt) if type(pmt) != torch.Tensor else pmt + res = torch.zeros((ucell_pmt.shape[0], ff, ff), device=device).type(type_complex) - ff = 2 * fourier_order + 1 - - if pmt.shape[1] == 1: # 1D - res = torch.zeros((pmt.shape[0], ff, ff), device=device).type(type_complex) - - for i, layer in enumerate(pmt): + for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, device=device, type_complex=type_complex) center = f_coeffs.shape[1] // 2 conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) @@ -190,39 +198,49 @@ def to_conv_mat_continuous(pmt, fourier_order, device=torch.device('cpu'), type_ res[i] = e_conv else: # 2D - # attention on the order of axis (Z Y X) - res = torch.zeros((pmt.shape[0], ff ** 2, ff ** 2), device=device).type(type_complex) + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 - for i, layer in enumerate(pmt): + res = torch.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y), device=device).type(type_complex) + + for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, device=device, type_complex=type_complex) center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc') - conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) - conv_idx = circulant(conv_idx, device) - conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) - conv_i = conv_i.repeat_interleave(ff, dim=0) - conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + # conv_idx = circulant(conv_idx, device) + # conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) + # conv_i = conv_i.repeat_interleave(ff, dim=0) + # conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + + conv_idx_y = torch.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y, device=device) + conv_i = conv_idx_y.repeat_interleave(ff_x, dim=1).type(torch.long) + conv_i = conv_i.repeat_interleave(ff_x, dim=0) + + conv_idx_x = torch.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = conv_idx_x.repeat(ff_y, ff_y).type(torch.long) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv return res -def to_conv_mat_discrete(pmt, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128, improve_dft=True): - pmt = pmt ** 2 +def to_conv_mat_discrete(ucell, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128, improve_dft=True): + ucell_pmt = ucell ** 2 - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - - ff = 2 * fourier_order + 1 - - if pmt.shape[1] == 1: # 1D - res = torch.zeros((pmt.shape[0], ff, ff), device=device).type(type_complex) + if ucell_pmt.shape[1] == 1: # 1D + ff = 2 * fourier_order[0] + 1 + res = torch.zeros((ucell_pmt.shape[0], ff, ff), device=device).type(type_complex) if improve_dft: - minimum_pattern_size = 2 * ff * pmt.shape[2] + minimum_pattern_size = 2 * ff * ucell_pmt.shape[2] else: minimum_pattern_size = 2 * ff - for i, layer in enumerate(pmt): + + for i, layer in enumerate(ucell_pmt): n = minimum_pattern_size // layer.shape[1] layer = layer.repeat_interleave(n + 1, axis=1) @@ -235,33 +253,58 @@ def to_conv_mat_discrete(pmt, fourier_order, device=torch.device('cpu'), type_co res[i] = e_conv else: # 2D - res = torch.zeros((pmt.shape[0], ff ** 2, ff ** 2), device=device).type(type_complex) + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 + + res = torch.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y), device=device).type(type_complex) + + # if improve_dft: + # minimum_pattern_size_1 = 2 * ff * ucell_pmt.shape[1] + # minimum_pattern_size_2 = 2 * ff * ucell_pmt.shape[2] + # else: + # minimum_pattern_size_1 = 2 * ff + # minimum_pattern_size_2 = 2 * ff + if improve_dft: - minimum_pattern_size_1 = 2 * ff * pmt.shape[1] - minimum_pattern_size_2 = 2 * ff * pmt.shape[2] + minimum_pattern_size_y = 2 * ff_y * ucell_pmt.shape[1] + minimum_pattern_size_x = 2 * ff_x * ucell_pmt.shape[2] else: - minimum_pattern_size_1 = 2 * ff - minimum_pattern_size_2 = 2 * ff + minimum_pattern_size_y = 2 * ff_y + minimum_pattern_size_x = 2 * ff_x + # 9 * (40*500) * (40*500) / 1E6 = 3600 MB = 3.6 GB - for i, layer in enumerate(pmt): - if layer.shape[0] < minimum_pattern_size_1: - n = torch.div(minimum_pattern_size_1, layer.shape[0], rounding_mode='trunc') + for i, layer in enumerate(ucell_pmt): + if layer.shape[0] < minimum_pattern_size_y: + n = torch.div(minimum_pattern_size_y, layer.shape[0], rounding_mode='trunc') layer = layer.repeat_interleave(n + 1, axis=0) - if layer.shape[1] < minimum_pattern_size_2: - n = torch.div(minimum_pattern_size_2, layer.shape[1], rounding_mode='trunc') + if layer.shape[1] < minimum_pattern_size_x: + n = torch.div(minimum_pattern_size_x, layer.shape[1], rounding_mode='trunc') layer = layer.repeat_interleave(n + 1, axis=1) f_coeffs = torch.fft.fftshift(torch.fft.fft2(layer / (layer.size(0)*layer.size(1)))) center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc') - conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) - conv_idx = circulant(conv_idx, device) + # conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + # conv_idx = circulant(conv_idx, device) + # + # conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) + # conv_i = conv_i.repeat_interleave(ff, dim=0) + # conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + + conv_idx_y = torch.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y, device=device) + conv_i = conv_idx_y.repeat_interleave(ff_x, dim=1).type(torch.long) + conv_i = conv_i.repeat_interleave(ff_x, dim=0) + + conv_idx_x = torch.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = conv_idx_x.repeat(ff_y, ff_y).type(torch.long) - conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) - conv_i = conv_i.repeat_interleave(ff, dim=0) - conv_j = conv_idx.repeat(ff, ff).type(torch.long) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv + return res diff --git a/meent/on_torch/emsolver/field_distribution.py b/meent/on_torch/emsolver/field_distribution.py index 5974b6c..87f334f 100644 --- a/meent/on_torch/emsolver/field_distribution.py +++ b/meent/on_torch/emsolver/field_distribution.py @@ -16,7 +16,7 @@ def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_in device='cpu', type_complex=torch.complex128): k0 = 2 * np.pi / wavelength - fourier_indices = torch.arange(-fourier_order, fourier_order + 1, device=device) + fourier_indices = torch.arange(-fourier_order[0], fourier_order[0] + 1, device=device) # kx_vector = k0 * (n_I * np.sin(theta) - fourier_indices * (wavelength / period[0])).type(type_complex) Kx = torch.diag(kx_vector / k0) diff --git a/meent/on_torch/emsolver/rcwa.py b/meent/on_torch/emsolver/rcwa.py index 4b924d7..1bbb65b 100644 --- a/meent/on_torch/emsolver/rcwa.py +++ b/meent/on_torch/emsolver/rcwa.py @@ -1,3 +1,5 @@ +from copy import deepcopy + import torch from ._base import _BaseRCWA @@ -34,9 +36,9 @@ def __init__(self, super().__init__(grating_type=grating_type, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, pol=pol, fourier_order=fourier_order, period=period, wavelength=wavelength, thickness=thickness, algo=algo, perturbation=perturbation, - device=device, type_complex=type_complex, **kwargs) + device=device, type_complex=type_complex) - self.ucell = ucell + self.ucell = deepcopy(ucell) # TODO: deepcopy? self.ucell_materials = ucell_materials self.ucell_info_list = ucell_info_list @@ -64,16 +66,22 @@ def _solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri.real, de_ti.real, layer_info_list, T1, self.kx_vector def solve(self, wavelength, e_conv_all, o_e_conv_all): - de_ri, de_ti, layer_info_list, T1, self.kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) - return de_ri, de_ti + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + + self.layer_info_list = layer_info_list + self.T1 = T1 + self.kx_vector = kx_vector - def conv_solve(self, *args, **kwargs): + return de_ri, de_ti + def conv_solve(self, *args, **kwargs): # TODO: delete args? [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? if self.fft_type == 0: - E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) + E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order, type_complex=self.type_complex, + improve_dft=self.improve_dft) + o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, + improve_dft=self.improve_dft) elif self.fft_type == 1: E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) @@ -95,20 +103,21 @@ def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, - self.layer_info_list, self.period, self.pol, resolution=resolution, + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, + self.T1, self.layer_info_list, self.period, self.pol, resolution=resolution, device=self.device, type_complex=self.type_complex) elif self.grating_type == 1: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, - self.T1, self.layer_info_list, self.period, resolution=resolution, - device=self.device, type_complex=self.type_complex) + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, device=self.device, + type_complex=self.type_complex) else: resolution = [100, 100, 100] if not resolution else resolution - field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, - device=self.device, type_complex=self.type_complex) + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, device=self.device, type_complex=self.type_complex) if plot: field_plot(field_cell, self.pol) diff --git a/meent/on_torch/emsolver/transfer_method.py b/meent/on_torch/emsolver/transfer_method.py index cc48ee9..0393491 100644 --- a/meent/on_torch/emsolver/transfer_method.py +++ b/meent/on_torch/emsolver/transfer_method.py @@ -37,7 +37,7 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, f else: raise ValueError - T = torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) + T = torch.eye(2 * fourier_order[0] + 1, device=device, dtype=type_complex) return kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T @@ -54,15 +54,14 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, device='cpu', type_com a_i = torch.linalg.inv(a) - f = W @ (torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) + X @ b @ a_i @ X) - g = V @ (torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) - X @ b @ a_i @ X) + f = W @ (torch.eye(2 * fourier_order[0] + 1, device=device, dtype=type_complex) + X @ b @ a_i @ X) + g = V @ (torch.eye(2 * fourier_order[0] + 1, device=device, dtype=type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X return X, f, g, T, a_i, b def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): - T1 = torch.linalg.inv(g1 + 1j * YZ_I @ f1) @ (1j * YZ_I @ delta_i0 + inc_term) R = f1 @ T1 - delta_i0 T = T @ T1 @@ -205,7 +204,6 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, o_E_conv_i, ff, d, varphi, bi def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, device='cpu', type_complex=torch.complex128): - I = torch.eye(ff, device=device, dtype=type_complex) O = torch.zeros((ff, ff), device=device, dtype=type_complex) @@ -255,16 +253,15 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 - -def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, +def transfer_2d_1(ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, device='cpu', type_complex=torch.complex128): - - I = torch.eye(ff ** 2, device=device, dtype=type_complex) - O = torch.zeros((ff ** 2, ff ** 2), device=device, dtype=type_complex) + I = torch.eye(ff_xy, device=device, dtype=type_complex) + O = torch.zeros((ff_xy, ff_xy), device=device, dtype=type_complex) # kx_vector = k0 * (n_I * torch.sin(theta) * torch.cos(phi) + fourier_indices * ( # wavelength / period[0])).type(type_complex) - ky_vector = k0 * (n_I * torch.sin(theta) * torch.sin(phi) + fourier_indices * ( + + ky_vector = k0 * (n_I * torch.sin(theta) * torch.sin(phi) + fourier_indices_y * ( wavelength / period[1])).type(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 @@ -273,8 +270,8 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, k_I_z = torch.conj(k_I_z.flatten()) k_II_z = torch.conj(k_II_z.flatten()) - Kx = torch.diag(kx_vector.tile(ff).flatten() / k0) - Ky = torch.diag(ky_vector.reshape((-1, 1)).tile(ff).flatten() / k0) + Kx = torch.diag(kx_vector.tile(ff_y).flatten() / k0) + Ky = torch.diag(ky_vector.reshape((-1, 1)).tile(ff_x).flatten() / k0) varphi = torch.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() @@ -298,15 +295,14 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, ] ) - big_T = torch.eye(ff ** 2 * 2, device=device, dtype=type_complex) + big_T = torch.eye(2 * ff_xy, device=device, dtype=type_complex) return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device='cpu', type_complex=torch.complex128, +def transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device='cpu', type_complex=torch.complex128, perturbation=1E-10): - - I = torch.eye(ff ** 2, device=device, dtype=type_complex) + I = torch.eye(ff_xy, device=device, dtype=type_complex) B = Kx @ E_conv_i @ Kx - I D = Ky @ E_conv_i @ Ky - I @@ -316,6 +312,7 @@ def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device='cpu', type_ torch.cat([Ky ** 2 + B @ o_E_conv_i, Kx @ (E_conv_i @ Ky @ E_conv - Ky)], dim=1), torch.cat([Ky @ (E_conv_i @ Kx @ o_E_conv_i - Kx), Kx ** 2 + D @ E_conv], dim=1) ]) + Eig.broadening_parameter = perturbation eigenvalues, W = Eig.apply(S2_from_S) @@ -396,11 +393,10 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, dev return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, device='cpu', type_complex=torch.complex128): - - I = torch.eye(ff ** 2, device=device, dtype=type_complex) - O = torch.zeros((ff ** 2, ff ** 2), device=device, dtype=type_complex) + I = torch.eye(ff_xy, device=device, dtype=type_complex) + O = torch.zeros((ff_xy, ff_xy), device=device, dtype=type_complex) big_F_11 = big_F[:center, :center] big_F_12 = big_F[:center, center:] @@ -433,14 +429,14 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i final_RT = torch.linalg.inv(final_A) @ final_B - R_s = final_RT[:ff ** 2, :].flatten() - R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + R_s = final_RT[:ff_xy, :].flatten() + R_p = final_RT[ff_xy:2 * ff_xy, :].flatten() - big_T1 = final_RT[2 * ff ** 2:, :] + big_T1 = final_RT[2 * ff_xy:, :] big_T = big_T @ big_T1 - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + T_s = big_T[:ff_xy, :].flatten() + T_p = big_T[ff_xy:, :].flatten() de_ri = R_s * torch.conj(R_s) * torch.real(k_I_z / (k0 * n_I * torch.cos(theta))) \ + R_p * torch.conj(R_p) * torch.real((k_I_z / n_I ** 2) / (k0 * n_I * torch.cos(theta))) diff --git a/meent/testcase.py b/meent/testcase.py index 141980d..7ebf3c2 100644 --- a/meent/testcase.py +++ b/meent/testcase.py @@ -20,7 +20,7 @@ def load_setting(mode, dtype, device, grating_type): wavelength = 900 ucell_materials = [1, 3.48] - fourier_order = 5 + fourier_order = [9, 2] # thickness = [1120., 400, 300] thickness = [1000., ] diff --git a/setup.py b/setup.py index fe09168..f2b838b 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ } setup( name='meent', - version='0.8.1', + version='0.8.2', url='https://github.com/kc-ml2/meent', author='KC ML2', author_email='yongha@kc-ml2.com',