From 399bb0fba92f0044d9c4c6d351532054ea451676 Mon Sep 17 00:00:00 2001 From: yonghakim Date: Wed, 15 Mar 2023 01:20:07 +0900 Subject: [PATCH 1/4] async 2D order: numpy done but need cleaning; --- examples/ex_ucell.py | 34 +++-- meent/on_numpy/emsolver/_base.py | 85 ++++++++----- meent/on_numpy/emsolver/convolution_matrix.py | 119 ++++++++++++------ meent/on_numpy/emsolver/rcwa.py | 34 +---- meent/on_numpy/emsolver/transfer_method.py | 44 +++++-- setup.py | 2 +- 6 files changed, 196 insertions(+), 122 deletions(-) diff --git a/examples/ex_ucell.py b/examples/ex_ucell.py index 5916a9a..e9f96e1 100644 --- a/examples/ex_ucell.py +++ b/examples/ex_ucell.py @@ -29,10 +29,10 @@ wavelength = 900 thickness = [500] -ucell_materials = [1, 'p_si'] +ucell_materials = [1, 'p_si__real'] period = [1000, 1000] -# period = [1000, 1000] -fourier_order = 2 + +fourier_order = [2] mode_options = {0: 'numpy', 1: 'JAX', 2: 'Torch', } n_iter = 2 @@ -100,11 +100,13 @@ def run_test(grating_type, mode_key, dtype, device): AA.calculate_field(resolution=resolution, plot=False) print(f'cal_field: {i}', time.time() - t0) + center = np.array(de_ri.shape) // 2 + print(de_ri.sum(), de_ti.sum()) + print(de_ti) try: - center = de_ri.shape[0] // 2 - print(de_ri[center-1:center+2, center-1:center+2]) + print(de_ri[center[0]-1:center[0]+2, center[1]-1:center[1]+2]) except: - print(de_ri[center-1:center+2]) + print(de_ri[center[0]-1:center[0]+2]) return de_ri, de_ti @@ -159,9 +161,27 @@ def load_ucell(grating_type): [0, 0, 0, 1, 1, 0, 0, 0, 0, 0, ], ], ]) + + ucell = np.array([ + + [ + [ + 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + ], + [ + 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + ], + [ + 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + ], + [ + 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + ], + ], + ]) # ucell = ucell * 4 + 1 return ucell if __name__ == '__main__': - run_loop([0, 1, 2], [0, 1, 2], [0], [0]) + run_loop([0], [0], [0], [0]) diff --git a/meent/on_numpy/emsolver/_base.py b/meent/on_numpy/emsolver/_base.py index 821674e..643aa11 100644 --- a/meent/on_numpy/emsolver/_base.py +++ b/meent/on_numpy/emsolver/_base.py @@ -9,7 +9,7 @@ class _BaseRCWA: - def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=10, + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=(2, 2), period=(100, 100), wavelength=900, thickness=None, algo='TMM', perturbation=1E-10, device='cpu', type_complex=np.complex128): @@ -35,8 +35,11 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = int(fourier_order) - self.ff = 2 * self.fourier_order + 1 + # self.fourier_order = int(fourier_order) + + self.fourier_order = [int(v) for v in fourier_order] # TODO: other bds + + # self.ff = 2 * self.fourier_order[0] + 1 # TODO self.period = deepcopy(period) @@ -57,12 +60,13 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= def get_kx_vector(self, wavelength): k0 = 2 * np.pi / wavelength - fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + fourier_indices_x = np.arange(-self.fourier_order[0], self.fourier_order[0] + 1) + if self.grating_type == 0: - kx_vector = k0 * (self.n_I * np.sin(self.theta) + fourier_indices * (wavelength / self.period[0]) + kx_vector = k0 * (self.n_I * np.sin(self.theta) + fourier_indices_x * (wavelength / self.period[0]) ).astype(self.type_complex) else: - kx_vector = k0 * (self.n_I * np.sin(self.theta) * np.cos(self.phi) + fourier_indices * ( + kx_vector = k0 * (self.n_I * np.sin(self.theta) * np.cos(self.phi) + fourier_indices_x * ( wavelength / self.period[0])).astype(self.type_complex) # kx_vector = kx_vector.conjugate() @@ -75,20 +79,22 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = np.zeros(self.ff, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + fourier_indices_x = np.arange(-self.fourier_order[0], self.fourier_order[0] + 1) + + delta_i0 = np.zeros(ff, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + = transfer_1d_1(ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ - = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, + = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices_x, self.period, self.pol, wl=wavelength) else: raise ValueError @@ -143,7 +149,7 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.T1 = T1 elif self.algo == 'SMM': - de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, self.ff, Wr, self.fourier_order, Kzr, Kzt, + de_ri, de_ti = scattering_1d_3(Wt, Wg, Vt, Vg, Sg, ff, Wr, self.fourier_order, Kzr, Kzt, self.n_I, self.n_II, self.theta, self.pol) else: raise ValueError @@ -155,16 +161,18 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None + ff = self.fourier_order[0] * 2 + 1 + # fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) - delta_i0 = np.zeros(self.ff, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + delta_i0 = np.zeros(ff, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + = transfer_1d_conical_1(ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') @@ -186,7 +194,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ - = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, type_complex=self.type_complex) layer_info = [E_conv_i, q_1, q_2, W_1, W_2, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] @@ -198,7 +206,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, type_complex=self.type_complex) self.T1 = big_T1 @@ -212,24 +220,41 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): + # TODO: cleaning + self.layer_info_list = [] self.T1 = None - fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + # fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) + # fourier_indices_x = np.arange(-self.fourier_order[0], self.fourier_order[0] + 1) + fourier_indices_y = np.arange(-self.fourier_order[1], self.fourier_order[1] + 1) + + ff_x = self.fourier_order[0] * 2 + 1 + ff_y = self.fourier_order[1] * 2 + 1 + ff = 2 * self.fourier_order[0] + 1 + ff_xy = ff_x * ff_y + + delta_i0 = np.zeros((ff ** 2, 1), dtype=self.type_complex) + delta_i0[ff ** 2 // 2, 0] = 1 + + delta_i0 = np.zeros((ff_xy, 1), dtype=self.type_complex) + delta_i0[ff_xy // 2, 0] = 1 + + I = np.eye(ff ** 2, dtype=self.type_complex) + O = np.zeros((ff ** 2, ff ** 2), dtype=self.type_complex) - delta_i0 = np.zeros((self.ff ** 2, 1), dtype=self.type_complex) - delta_i0[self.ff ** 2 // 2, 0] = 1 + I = np.eye(ff_xy, dtype=self.type_complex) + O = np.zeros((ff_xy, ff_xy), dtype=self.type_complex) - I = np.eye(self.ff ** 2, dtype=self.type_complex) - O = np.zeros((self.ff ** 2, self.ff ** 2), dtype=self.type_complex) + center = ff ** 2 - center = self.ff ** 2 + center = ff_xy k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + = transfer_2d_1(ff, ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, self.theta, self.phi, wavelength, type_complex=self.type_complex) elif self.algo == 'SMM': @@ -251,7 +276,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): o_E_conv_i = np.linalg.inv(o_E_conv) if self.algo == 'TMM': - W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) + W, V, q = transfer_2d_wv(ff, ff_x, ff_y, ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) big_X, big_F, big_G, big_T, big_A_i, big_B, \ W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ @@ -262,24 +287,24 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, q = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, q = scattering_2d_wv(ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, q) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, ff_x, ff_y, ff_xy, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, type_complex=self.type_complex) self.T1 = big_T1 elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, - self.pol, self.theta, self.phi, self.fourier_order, self.ff) + self.pol, self.theta, self.phi, self.fourier_order, ff) else: raise ValueError - de_ri = de_ri.reshape((self.ff, self.ff)).real - de_ti = de_ti.reshape((self.ff, self.ff)).real + de_ri = de_ri.reshape((ff_y, ff_x)).real + de_ti = de_ti.reshape((ff_y, ff_x)).real return de_ri, de_ti, self.layer_info_list, self.T1 diff --git a/meent/on_numpy/emsolver/convolution_matrix.py b/meent/on_numpy/emsolver/convolution_matrix.py index 6505f9e..a1f7f32 100644 --- a/meent/on_numpy/emsolver/convolution_matrix.py +++ b/meent/on_numpy/emsolver/convolution_matrix.py @@ -42,19 +42,20 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): reference: reticolo """ - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] + if cell.shape[0] == 1: # tODO + fourier_order = fourier_order + [0] + # else: + # fourier_order = [fourier_order, fourier_order] + cell, x, y = cell_compression(cell, type_complex=type_complex) # X axis cell_next_x = np.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) - f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes[None, :], dtype=type_complex) + f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes_x[None, :], dtype=type_complex) c = f_coeffs_x.shape[1] // 2 # x_next = np.vstack(np.roll(x, -1, axis=0)[:-1]) - x @@ -63,15 +64,15 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): f_coeffs_x[:, c] = (cell @ np.vstack((x[0], x_next[:-1]))).flatten() mask = np.ones(f_coeffs_x.shape[1], dtype=bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = np.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes[None, :], dtype=type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes_y[None, :], dtype=type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = np.vstack((np.roll(y, -1, axis=0)[:-1], 1)) - y @@ -81,7 +82,7 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): if c: mask = np.ones(f_coeffs_xy.shape[1], dtype=bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T @@ -91,19 +92,19 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.com reference: reticolo """ - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] + # if cell.shape[0] == 1: + # fourier_order = [0, fourier_order] + # else: + # fourier_order = [fourier_order, fourier_order] # cell, x, y = cell_compression(cell, type_complex=type_complex) # X axis cell_next_x = np.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) - f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes[None, :], dtype=type_complex) + f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes_x[None, :], dtype=type_complex) c = f_coeffs_x.shape[1] // 2 # x_next = np.vstack(np.roll(x, -1, axis=0)[:-1]) - x @@ -112,15 +113,15 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.com f_coeffs_x[:, c] = (cell @ np.vstack((x[0], x_next[:-1]))).flatten() mask = np.ones(f_coeffs_x.shape[1], dtype=bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = np.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = np.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = np.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes[None, :], dtype=type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ np.exp(-1j * 2 * np.pi * y @ modes_y[None, :], dtype=type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = np.vstack((np.roll(y, -1, axis=0)[:-1], 1)) - y @@ -130,17 +131,22 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.com if c: mask = np.ones(f_coeffs_xy.shape[1], dtype=bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, type_complex=np.complex128): # TODO: do conv and 1/conv in simultaneously? - ff = 2 * fourier_order + 1 + # ff = 2 * fourier_order + 1 + + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 - e_conv_all = np.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) - o_e_conv_all = np.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) + ff = fourier_order[0] + fourier_order[1] + 1 # tODO + + e_conv_all = np.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + o_e_conv_all = np.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) # 2D # tODO: 1D for i, ucell_info in enumerate(ucell_info_list): @@ -167,19 +173,21 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, t return e_conv_all, o_e_conv_all -def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=np.complex128): - pmt = pmt ** 2 +def to_conv_mat_continuous(ucell, fourier_order, device=None, type_complex=np.complex128): + ucell_pmt = ucell ** 2 # TODO: do conv and 1/conv in simultaneously? - if len(pmt.shape) == 2: + if len(ucell_pmt.shape) == 2: print('shape is 2') raise ValueError - ff = 2 * fourier_order + 1 - if pmt.shape[1] == 1: # 1D - res = np.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + if ucell_pmt.shape[1] == 1: # 1D - for i, layer in enumerate(pmt): + ff = 2 * fourier_order[0] + 1 + + res = np.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) + + for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) center = f_coeffs.shape[1] // 2 @@ -189,22 +197,46 @@ def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=np.comp res[i] = e_conv else: # 2D + + # ff = 2 * fourier_order + 1 + + # TODO: cleaning + ff = 2 * fourier_order[0] + 1 + + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 + + # ff = fourier_order[0] + fourier_order[1] + 1 # tODO + # attention on the order of axis (Z Y X) - res = np.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) + res = np.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) - for i, layer in enumerate(pmt): + for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) center = np.array(f_coeffs.shape) // 2 - conv_idx = np.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = np.repeat(conv_idx, ff, axis=1) - conv_i = np.repeat(conv_i, [ff] * ff, axis=0) - conv_j = np.tile(conv_idx, (ff, ff)) + # conv_idx = np.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # conv_i = np.repeat(conv_idx, ff, axis=1) + # conv_i = np.repeat(conv_i, [ff] * ff, axis=0) + # conv_j = np.tile(conv_idx, (ff, ff)) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + + conv_idx_y = np.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = np.repeat(conv_idx_y, ff_x, axis=1) + conv_i = np.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = np.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv + return res @@ -213,9 +245,11 @@ def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=np.comple if len(pmt.shape) == 2: print('shape is 2') raise ValueError - ff = 2 * fourier_order + 1 if pmt.shape[1] == 1: # 1D + + ff = 2 * fourier_order + 1 + res = np.zeros((pmt.shape[0], ff, ff)).astype(type_complex) if improve_dft: minimum_pattern_size = 2 * ff * pmt.shape[2] @@ -238,9 +272,16 @@ def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=np.comple res[i] = e_conv else: # 2D + # ff = 2 * fourier_order + 1 + + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 + + ff = fourier_order[0] + fourier_order[1] + 1 # tODO + # attention on the order of axis (Z Y X) # TODO: separate fourier order - res = np.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) + res = np.zeros((pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) if improve_dft: minimum_pattern_size_1 = 2 * ff * pmt.shape[1] minimum_pattern_size_2 = 2 * ff * pmt.shape[2] diff --git a/meent/on_numpy/emsolver/rcwa.py b/meent/on_numpy/emsolver/rcwa.py index 12651ee..a0f4e38 100644 --- a/meent/on_numpy/emsolver/rcwa.py +++ b/meent/on_numpy/emsolver/rcwa.py @@ -52,7 +52,7 @@ def __init__(self, self.layer_info_list = [] def _solve(self, wavelength, e_conv_all, o_e_conv_all): - self.kx_vector = self.get_kx_vector(wavelength) + self.kx_vector = self.get_kx_vector(wavelength) # TODO: add ky_vector? if self.grating_type == 0: de_ri, de_ti, layer_info_list, T1 = self.solve_1d(wavelength, e_conv_all, o_e_conv_all) @@ -77,38 +77,6 @@ def conv_solve(self, *args, **kwargs): improve_dft=self.improve_dft) o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - E_conv_all1 = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) - o_E_conv_all1 = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) - - # print(1, np.linalg.norm(E_conv_all - E_conv_all1)) - # print(2, np.linalg.norm(o_E_conv_all1 - o_E_conv_all)) - - # import matplotlib.pyplot as plt - # plt.imshow(abs(E_conv_all[0])) - # plt.colorbar() - # plt.show() - # import matplotlib.pyplot as plt - # plt.imshow(abs(E_conv_all1[0])) - # plt.colorbar() - # plt.show() - # - # import matplotlib.pyplot as plt - # plt.imshow(abs(o_E_conv_all[0])) - # plt.colorbar() - # plt.show() - # import matplotlib.pyplot as plt - # plt.imshow(abs(o_E_conv_all[0])) - # plt.colorbar() - # plt.show() - # - # plt.imshow(abs(E_conv_all[0] - E_conv_all1[0])) - # plt.colorbar() - # plt.show() - # - # plt.imshow(abs(o_E_conv_all[0] - o_E_conv_all1[0])) - # plt.colorbar() - # plt.show() - elif self.fft_type == 1: E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) diff --git a/meent/on_numpy/emsolver/transfer_method.py b/meent/on_numpy/emsolver/transfer_method.py index c68cbcc..49074f2 100644 --- a/meent/on_numpy/emsolver/transfer_method.py +++ b/meent/on_numpy/emsolver/transfer_method.py @@ -35,7 +35,7 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, f else: raise ValueError - T = np.eye(2 * fourier_order + 1, dtype=type_complex) + T = np.eye(2 * fourier_order[0] + 1, dtype=type_complex) return kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T @@ -52,8 +52,8 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, type_complex=np.comple a_i = np.linalg.inv(a) - f = W @ (np.eye(2 * fourier_order + 1, dtype=type_complex) + X @ b @ a_i @ X) - g = V @ (np.eye(2 * fourier_order + 1, dtype=type_complex) - X @ b @ a_i @ X) + f = W @ (np.eye(2 * fourier_order[0] + 1, dtype=type_complex) + X @ b @ a_i @ X) + g = V @ (np.eye(2 * fourier_order[0] + 1, dtype=type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X return X, f, g, T, a_i, b @@ -229,16 +229,20 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 -def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, +def transfer_2d_1(ff, ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, type_complex=np.complex128): + # TODO: cleaning I = np.eye(ff ** 2, dtype=type_complex) O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = np.eye(ff_xy, dtype=type_complex) + O = np.zeros((ff_xy, ff_xy), dtype=type_complex) + # kx_vector = k0 * (n_I * np.sin(theta) * np.cos(phi) + fourier_indices * ( # wavelength / period[0])).astype(type_complex) - ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) + fourier_indices * ( + ky_vector = k0 * (n_I * np.sin(theta) * np.sin(phi) + fourier_indices_y * ( wavelength / period[1])).astype(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 @@ -250,6 +254,9 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, Kx = np.diag(np.tile(kx_vector, ff).flatten()) / k0 Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + Kx = np.diag(np.tile(kx_vector, ff_y).flatten()) / k0 + Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff_x).flatten()) / k0 + varphi = np.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() Y_I = np.diag(k_I_z / k0) @@ -262,13 +269,16 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, big_G = np.block([[1j * Y_II, O], [O, I]]) big_T = np.eye(2 * ff ** 2, dtype=type_complex) + big_T = np.eye(2 * ff_xy, dtype=type_complex) return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=np.complex128): +def transfer_2d_wv(ff, ff_x, ff_y, ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=np.complex128): + # TODO: cleaning I = np.eye(ff ** 2, dtype=type_complex) + I = np.eye(ff_xy, dtype=type_complex) B = Kx @ E_conv_i @ Kx - I D = Ky @ E_conv_i @ Ky - I @@ -349,12 +359,15 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, typ return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, ff_x, ff_y, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, type_complex=np.complex128): - + # TODO: cleaning I = np.eye(ff ** 2, dtype=type_complex) O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = np.eye(ff_xy, dtype=type_complex) + O = np.zeros((ff_xy, ff_xy), dtype=type_complex) + big_F_11 = big_F[:center, :center] big_F_12 = big_F[:center, center:] big_F_21 = big_F[center:, :center] @@ -386,14 +399,21 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i final_RT = np.linalg.inv(final_A) @ final_B - R_s = final_RT[:ff ** 2, :].flatten() - R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + # R_s = final_RT[:ff ** 2, :].flatten() + # R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + + R_s = final_RT[:ff_xy, :].flatten() + R_p = final_RT[ff_xy: 2 * ff_xy, :].flatten() big_T1 = final_RT[2 * ff ** 2:, :] + big_T1 = final_RT[2 * ff_xy:, :] big_T = big_T @ big_T1 - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + # T_s = big_T[:ff ** 2, :].flatten() + # T_p = big_T[ff ** 2:, :].flatten() + + T_s = big_T[:ff_xy, :].flatten() + T_p = big_T[ff_xy:, :].flatten() de_ri = R_s * np.conj(R_s) * np.real(k_I_z / (k0 * n_I * np.cos(theta))) \ + R_p * np.conj(R_p) * np.real((k_I_z / n_I ** 2) / (k0 * n_I * np.cos(theta))) diff --git a/setup.py b/setup.py index fe09168..f2b838b 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ } setup( name='meent', - version='0.8.1', + version='0.8.2', url='https://github.com/kc-ml2/meent', author='KC ML2', author_email='yongha@kc-ml2.com', From 3eb3d13c1c388470e8c3bd82b2c5d2cfc6118678 Mon Sep 17 00:00:00 2001 From: yonghakim Date: Wed, 15 Mar 2023 17:39:05 +0900 Subject: [PATCH 2/4] async 2D order; --- examples/ex_ucell.py | 46 ++-- meent/on_jax/emsolver/_base.py | 74 +++--- meent/on_jax/emsolver/convolution_matrix.py | 219 +++++++++++------- meent/on_jax/emsolver/rcwa.py | 42 ++-- meent/on_jax/emsolver/transfer_method.py | 42 ++-- meent/on_numpy/emsolver/_base.py | 81 +++---- meent/on_numpy/emsolver/convolution_matrix.py | 131 ++++++----- meent/on_numpy/emsolver/rcwa.py | 30 +-- meent/on_numpy/emsolver/transfer_method.py | 35 +-- meent/on_torch/emsolver/_base.py | 81 ++++--- meent/on_torch/emsolver/convolution_matrix.py | 203 +++++++++------- meent/on_torch/emsolver/rcwa.py | 39 ++-- meent/on_torch/emsolver/transfer_method.py | 48 ++-- 13 files changed, 575 insertions(+), 496 deletions(-) diff --git a/examples/ex_ucell.py b/examples/ex_ucell.py index e9f96e1..e18676f 100644 --- a/examples/ex_ucell.py +++ b/examples/ex_ucell.py @@ -17,13 +17,13 @@ import meent # common -pol = 1 # 0: TE, 1: TM +pol = 0 # 0: TE, 1: TM n_I = 1 # n_incidence n_II = 1 # n_transmission -theta = 10 * np.pi / 180 -phi = 0 * np.pi / 180 +theta = 20 * np.pi / 180 +phi = 50 * np.pi / 180 psi = 0 if pol else 90 * np.pi / 180 wavelength = 900 @@ -32,7 +32,7 @@ ucell_materials = [1, 'p_si__real'] period = [1000, 1000] -fourier_order = [2] +fourier_order = [3, 2] mode_options = {0: 'numpy', 1: 'JAX', 2: 'Torch', } n_iter = 2 @@ -102,7 +102,6 @@ def run_test(grating_type, mode_key, dtype, device): center = np.array(de_ri.shape) // 2 print(de_ri.sum(), de_ti.sum()) - print(de_ti) try: print(de_ri[center[0]-1:center[0]+2, center[1]-1:center[1]+2]) except: @@ -119,6 +118,7 @@ def run_loop(a, b, c, d): print(f'grating:{grating_type}, backend:{bd}, dtype:{dtype}, dev:{device}') run_test(grating_type, bd, dtype, device) + def load_ucell(grating_type): if grating_type in [0, 1]: @@ -162,26 +162,26 @@ def load_ucell(grating_type): ], ]) - ucell = np.array([ - - [ - [ - 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, - ], - [ - 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, - ], - [ - 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, - ], - [ - 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, - ], - ], - ]) + # ucell = np.array([ + # + # [ + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # [ + # 0, 1, 0, 1, 1, 0, 1, 0, 1, 1, + # ], + # ], + # ]) # ucell = ucell * 4 + 1 return ucell if __name__ == '__main__': - run_loop([0], [0], [0], [0]) + run_loop([2], [0,1,2], [0], [0]) diff --git a/meent/on_jax/emsolver/_base.py b/meent/on_jax/emsolver/_base.py index 2ad1188..18e3567 100644 --- a/meent/on_jax/emsolver/_base.py +++ b/meent/on_jax/emsolver/_base.py @@ -12,16 +12,17 @@ class _BaseRCWA: - def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=10, + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=(2, 2), period=(100, 100), wavelength=900, thickness=None, algo='TMM', perturbation=1E-10, device='cpu', type_complex=jnp.complex128): self.device = device self.type_complex = type_complex - if self.type_complex == jnp.complex128: + if self.type_complex == jnp.complex128: # TODO: need this? jax.config.update('jax_enable_x64', True) + # TODO: consider to make systematic and apply to other bds self.type_int = jnp.int64 if self.type_complex == jnp.complex128 else jnp.int32 self.type_float = jnp.float64 if self.type_complex == jnp.complex128 else jnp.float32 @@ -44,10 +45,13 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = int(fourier_order) - self.ff = 2 * self.fourier_order + 1 + # TODO: jit-stuff. apply other backends? + if len(fourier_order) == 1: + self.fourier_order = list(fourier_order) + [0] + else: + self.fourier_order = [int(v) for v in fourier_order] - self.period = period + self.period = deepcopy(period) self.wavelength = wavelength self.thickness = deepcopy(thickness) @@ -67,7 +71,7 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= def get_kx_vector(self, wavelength): k0 = 2 * jnp.pi / wavelength - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + fourier_indices = jnp.arange(-self.fourier_order[0], self.fourier_order[0] + 1) if self.grating_type == 0: kx_vector = k0 * (self.n_I * jnp.sin(self.theta) + fourier_indices * (wavelength / self.period[0]) ).astype(self.type_complex) @@ -84,16 +88,18 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + ff = self.fourier_order[0] * 2 + 1 # TODO: list? + + # fourier_indices = jnp.arange(-self.fourier_order[0], self.fourier_order[0] + 1) - delta_i0 = jnp.zeros(self.ff, dtype=self.type_complex) - delta_i0 = delta_i0.at[self.fourier_order].set(1) + delta_i0 = jnp.zeros(ff, dtype=self.type_complex) + delta_i0 = delta_i0.at[self.fourier_order[0]].set(1) k0 = 2 * jnp.pi / wavelength if self.algo == 'TMM': kx_vector, Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + = transfer_1d_1(ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ @@ -159,15 +165,16 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): self.T1 = None # fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = jnp.zeros(self.ff, dtype=self.type_complex) - delta_i0 = delta_i0.at[self.fourier_order].set(1) + delta_i0 = jnp.zeros(ff, dtype=self.type_complex) + delta_i0 = delta_i0.at[self.fourier_order[0]].set(1) k0 = 2 * jnp.pi / wavelength if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + = transfer_1d_conical_1(ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') @@ -181,7 +188,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ - = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, type_complex=self.type_complex, device=self.device) @@ -194,7 +201,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, type_complex=self.type_complex) self.T1 = big_T1 @@ -211,21 +218,26 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + # fourier_indices = jnp.arange(-self.fourier_order, self.fourier_order + 1) + fourier_indices_y = jnp.arange(-self.fourier_order[1], self.fourier_order[1] + 1) + + ff_x = self.fourier_order[0] * 2 + 1 + ff_y = self.fourier_order[1] * 2 + 1 + ff_xy = ff_x * ff_y - delta_i0 = jnp.zeros((self.ff ** 2, 1), dtype=self.type_complex) - delta_i0 = delta_i0.at[self.ff ** 2 // 2, 0].set(1) + delta_i0 = jnp.zeros((ff_xy, 1), dtype=self.type_complex) + delta_i0 = delta_i0.at[ff_xy // 2, 0].set(1) - I = jnp.eye(self.ff ** 2).astype(self.type_complex) - O = jnp.zeros((self.ff ** 2, self.ff ** 2), dtype=self.type_complex) + I = jnp.eye(ff_xy).astype(self.type_complex) + O = jnp.zeros((ff_xy, ff_xy), dtype=self.type_complex) - center = self.ff ** 2 + center = ff_xy k0 = 2 * jnp.pi / wavelength if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + = transfer_2d_1(ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, self.theta, self.phi, wavelength, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ @@ -239,36 +251,36 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): o_E_conv_i = jnp.linalg.inv(o_E_conv) if self.algo == 'TMM': - W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex, - device=self.device) + W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, + device=self.device, type_complex=self.type_complex) # TODO: device? big_X, big_F, big_G, big_T, big_A_i, big_B, \ W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ = transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, - type_complex=self.type_complex) + type_complex=self.type_complex) # tODO: device? layer_info = [E_conv_i, q, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22, big_X, big_A_i, big_B, d] self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, LAMBDA = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff_xy, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, - type_complex=self.type_complex) + type_complex=self.type_complex) # TODO: device? self.T1 = big_T1 elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, - self.pol, self.theta, self.phi, self.fourier_order, self.ff) + self.pol, self.theta, self.phi, self.fourier_order, self.fourier_order) else: raise ValueError - de_ri = de_ri.reshape((self.ff, self.ff)).real - de_ti = de_ti.reshape((self.ff, self.ff)).real + de_ri = de_ri.reshape((ff_y, ff_x)).real + de_ti = de_ti.reshape((ff_y, ff_x)).real return de_ri, de_ti, self.layer_info_list, self.T1 diff --git a/meent/on_jax/emsolver/convolution_matrix.py b/meent/on_jax/emsolver/convolution_matrix.py index 287352e..594f51f 100644 --- a/meent/on_jax/emsolver/convolution_matrix.py +++ b/meent/on_jax/emsolver/convolution_matrix.py @@ -44,21 +44,17 @@ def cell_compression(cell, type_complex=jnp.complex128): # @partial(jax.jit, static_argnums=(1,2 )) -def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): +def fft_piecewise_constant(cell, fourier_order_x, fourier_order_y, type_complex=jnp.complex128): - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] cell, x, y = cell_compression(cell, type_complex=type_complex) # X axis cell_next_x = jnp.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = jnp.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = jnp.arange(-2 * fourier_order_x, 2 * fourier_order_x + 1, 1) - f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes[None, :]).astype(type_complex) + f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes_x[None, :]).astype(type_complex) c = f_coeffs_x.shape[1] // 2 x_next = jnp.vstack((jnp.roll(x, -1, axis=0)[:-1], 1)) - x @@ -67,18 +63,18 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): assign_value = (cell @ jnp.vstack((x[0], x_next[:-1]))).flatten().astype(type_complex) f_coeffs_x = f_coeffs_x.at[assign_index].set(assign_value) - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_x[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) + assign_index = mask + assign_value = f_coeffs_x[:, mask] / (1j * 2 * jnp.pi * modes_x[mask]) f_coeffs_x = f_coeffs_x.at[:, assign_index].set(assign_value) # Y axis f_coeffs_x_next_y = jnp.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = jnp.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = jnp.arange(-2 * fourier_order_y, 2 * fourier_order_y + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes[None, :]).astype(type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes_y[None, :]).astype(type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = jnp.vstack((jnp.roll(y, -1, axis=0)[:-1], 1)) - y @@ -88,10 +84,10 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) if c: - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_xy[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + assign_index = mask + assign_value = f_coeffs_xy[:, mask] / (1j * 2 * jnp.pi * modes_y[mask]) f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) @@ -99,21 +95,14 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=jnp.complex128): # @partial(jax.jit, static_argnums=(1,2 )) # tODO: jit-able? -def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=jnp.complex128): - - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] - # cell, x, y = cell_compression(cell, type_complex=type_complex) - +def fft_piecewise_constant_vector(cell, x, y, fourier_order_x, fourier_order_y, type_complex=jnp.complex128): # X axis cell_next_x = jnp.roll(cell, -1, axis=1) cell_diff_x = cell_next_x - cell - modes = jnp.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1) + modes_x = jnp.arange(-2 * fourier_order_x, 2 * fourier_order_x + 1, 1) - f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes[None, :]).astype(type_complex) + f_coeffs_x = cell_diff_x @ jnp.exp(-1j * 2 * jnp.pi * x @ modes_x[None, :]).astype(type_complex) c = f_coeffs_x.shape[1] // 2 x_next = jnp.vstack((jnp.roll(x, -1, axis=0)[:-1], 1)) - x @@ -122,18 +111,18 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=jnp.co assign_value = (cell @ jnp.vstack((x[0], x_next[:-1]))).flatten().astype(type_complex) f_coeffs_x = f_coeffs_x.at[assign_index].set(assign_value) - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_x[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c+1, f_coeffs_x.shape[1])]) + assign_index = mask + assign_value = f_coeffs_x[:, mask] / (1j * 2 * jnp.pi * modes_x[mask]) f_coeffs_x = f_coeffs_x.at[:, assign_index].set(assign_value) # Y axis f_coeffs_x_next_y = jnp.roll(f_coeffs_x, -1, axis=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = jnp.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1) + modes_y = jnp.arange(-2 * fourier_order_y, 2 * fourier_order_y + 1, 1) - f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes[None, :]).astype(type_complex) + f_coeffs_xy = f_coeffs_x_diff_y.T @ jnp.exp(-1j * 2 * jnp.pi * y @ modes_y[None, :]).astype(type_complex) c = f_coeffs_xy.shape[1] // 2 y_next = jnp.vstack((jnp.roll(y, -1, axis=0)[:-1], 1)) - y @@ -143,22 +132,23 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=jnp.co f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) if c: - mask_int = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) + mask = jnp.hstack([jnp.arange(c), jnp.arange(c + 1, f_coeffs_x.shape[1])]) - assign_index = mask_int - assign_value = f_coeffs_xy[:, mask_int] / (1j * 2 * jnp.pi * modes[mask_int]) + assign_index = mask + assign_value = f_coeffs_xy[:, mask] / (1j * 2 * jnp.pi * modes_y[mask]) f_coeffs_xy = f_coeffs_xy.at[:, assign_index].set(assign_value) return f_coeffs_xy.T -def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, type_complex=jnp.complex128): +def to_conv_mat_continuous_vector(ucell_info_list, fourier_order_x, fourier_order_y, device=None, type_complex=jnp.complex128): - ff = 2 * fourier_order + 1 + ff_x = 2 * fourier_order_x + 1 + ff_y = 2 * fourier_order_y + 1 - e_conv_all = jnp.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) - o_e_conv_all = jnp.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).astype(type_complex) + e_conv_all = jnp.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + o_e_conv_all = jnp.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) # 2D for i, ucell_info in enumerate(ucell_info_list): @@ -166,16 +156,30 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, t ucell_layer = ucell_layer ** 2 f_coeffs = fft_piecewise_constant_vector(ucell_layer, x_list, y_list, - fourier_order, type_complex=type_complex) + fourier_order_x, fourier_order_y, type_complex=type_complex) o_f_coeffs = fft_piecewise_constant_vector(1/ucell_layer, x_list, y_list, - fourier_order, type_complex=type_complex) + fourier_order_x, fourier_order_y, type_complex=type_complex) center = np.array(f_coeffs.shape) // 2 - conv_idx = jnp.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = jnp.repeat(conv_idx, ff, 1) - conv_i = jnp.repeat(conv_i, ff, axis=0) - conv_j = jnp.tile(conv_idx, (ff, ff)) + + + # conv_idx = jnp.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # conv_i = jnp.repeat(conv_idx, ff, 1) + # conv_i = jnp.repeat(conv_i, ff, axis=0) + # conv_j = jnp.tile(conv_idx, (ff, ff)) + + + conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) + conv_i = jnp.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = jnp.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) + + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] @@ -186,19 +190,17 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, t return e_conv_all, o_e_conv_all -def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=jnp.complex128): - pmt = pmt ** 2 +def to_conv_mat_continuous(ucell, fourier_order_x, fourier_order_y, device=None, type_complex=jnp.complex128): + ucell_pmt = ucell ** 2 + + if ucell_pmt.shape[1] == 1: # 1D - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - ff = 2 * fourier_order + 1 + ff = 2 * fourier_order_x + 1 - if pmt.shape[1] == 1: # 1D - res = jnp.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + res = jnp.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) - for i, layer in enumerate(pmt): - f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) + for i, layer in enumerate(ucell_pmt): + f_coeffs = fft_piecewise_constant(layer, fourier_order_x, fourier_order_y, type_complex=type_complex) center = f_coeffs.shape[1] // 2 conv_idx = jnp.arange(-ff + 1, ff, 1) conv_idx = circulant(conv_idx) @@ -206,41 +208,54 @@ def to_conv_mat_continuous(pmt, fourier_order, device=None, type_complex=jnp.com res = res.at[i].set(e_conv) else: # 2D - # attention on the order of axis (Z Y X) - res = jnp.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) - for i, layer in enumerate(pmt): - f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) + ff_x = 2 * fourier_order_x + 1 + ff_y = 2 * fourier_order_y + 1 + + res = jnp.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + + for i, layer in enumerate(ucell_pmt): + f_coeffs = fft_piecewise_constant(layer, fourier_order_x, fourier_order_y, type_complex=type_complex) center = jnp.array(f_coeffs.shape) // 2 - conv_idx = jnp.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = jnp.repeat(conv_idx, ff, 1) - conv_i = jnp.repeat(conv_i, ff, axis=0) - conv_j = jnp.tile(conv_idx, (ff, ff)) + # conv_idx = jnp.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # conv_i = jnp.repeat(conv_idx, ff, 1) + # conv_i = jnp.repeat(conv_i, ff, axis=0) + # conv_j = jnp.tile(conv_idx, (ff, ff)) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res = res.at[i].set(e_conv) + + conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) + conv_i = jnp.repeat(conv_i, jnp.array([ff_x] * ff_y), axis=0) + + conv_idx_x = jnp.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res = res.at[i].set(e_conv) return res -@partial(jax.jit, static_argnums=(1, 2, 3, 4)) -def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=jnp.complex128, improve_dft=True): - pmt = pmt ** 2 +@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) # TODO +def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, type_complex=jnp.complex128, improve_dft=True): + ucell_pmt = ucell ** 2 + + if ucell_pmt.shape[1] == 1: # 1D - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - ff = 2 * fourier_order + 1 + ff = 2 * fourier_order_x + 1 - if pmt.shape[1] == 1: # 1D - res = jnp.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + res = jnp.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) if improve_dft: - minimum_pattern_size = 2 * ff * pmt.shape[2] + minimum_pattern_size = 2 * ff * ucell_pmt.shape[2] else: minimum_pattern_size = 2 * ff - for i, layer in enumerate(pmt): + for i, layer in enumerate(ucell_pmt): n = minimum_pattern_size // layer.shape[1] layer = np.repeat(layer, n + 1, axis=1) @@ -256,33 +271,57 @@ def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=jnp.compl res = res.at[i].set(e_conv) else: # 2D - # attention on the order of axis (Z Y X) - res = jnp.zeros((pmt.shape[0], ff ** 2, ff ** 2)).astype(type_complex) + ff_x = 2 * fourier_order_x + 1 + ff_y = 2 * fourier_order_y + 1 + + res = np.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) + + # if improve_dft: + # minimum_pattern_size_1 = 2 * ff * pmt.shape[1] + # minimum_pattern_size_2 = 2 * ff * pmt.shape[2] + # else: + # minimum_pattern_size_1 = 2 * ff + # minimum_pattern_size_2 = 2 * ff + if improve_dft: - minimum_pattern_size_1 = 2 * ff * pmt.shape[1] - minimum_pattern_size_2 = 2 * ff * pmt.shape[2] + minimum_pattern_size_y = 2 * ff_y * ucell_pmt.shape[1] + minimum_pattern_size_x = 2 * ff_x * ucell_pmt.shape[2] else: - minimum_pattern_size_1 = 2 * ff - minimum_pattern_size_2 = 2 * ff + minimum_pattern_size_y = 2 * ff_y + minimum_pattern_size_x = 2 * ff_x + # 9 * (40*500) * (40*500) / 1E6 = 3600 MB = 3.6 GB - for i, layer in enumerate(pmt): - if layer.shape[0] < minimum_pattern_size_1: - n = minimum_pattern_size_1 // layer.shape[0] + for i, layer in enumerate(ucell_pmt): + if layer.shape[0] < minimum_pattern_size_y: + n = minimum_pattern_size_y // layer.shape[0] layer = jnp.repeat(layer, n + 1, axis=0) - if layer.shape[1] < minimum_pattern_size_2: - n = minimum_pattern_size_2 // layer.shape[1] + if layer.shape[1] < minimum_pattern_size_x: + n = minimum_pattern_size_x // layer.shape[1] layer = jnp.repeat(layer, n + 1, axis=1) f_coeffs = jnp.fft.fftshift(jnp.fft.fft2(layer / layer.size)) center = jnp.array(f_coeffs.shape) // 2 - conv_idx = jnp.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) + # conv_idx = jnp.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # + # conv_i = jnp.repeat(conv_idx, ff, 1) + # conv_i = jnp.repeat(conv_i, ff, axis=0) + # conv_j = jnp.tile(conv_idx, (ff, ff)) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res = res.at[i].set(e_conv) + + conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) + # conv_i = jnp.repeat(conv_i, [ff_x] * ff_y, axis=0) + conv_i = jnp.repeat(conv_i, jnp.array([ff_x] * ff_y), axis=0) + + conv_idx_x = jnp.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) - conv_i = jnp.repeat(conv_idx, ff, 1) - conv_i = jnp.repeat(conv_i, ff, axis=0) - conv_j = jnp.tile(conv_idx, (ff, ff)) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res = res.at[i].set(e_conv) diff --git a/meent/on_jax/emsolver/rcwa.py b/meent/on_jax/emsolver/rcwa.py index 5428d70..48ea0d8 100644 --- a/meent/on_jax/emsolver/rcwa.py +++ b/meent/on_jax/emsolver/rcwa.py @@ -34,7 +34,7 @@ def __init__(self, type_complex=jnp.complex128, fft_type=0, improve_dft=True, - **kwargs, + **kwargs, # TODO: need htis? ): super().__init__(grating_type=grating_type, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, pol=pol, @@ -77,7 +77,7 @@ def _tree_flatten(self): def _tree_unflatten(cls, aux_data, children): return cls(*children, **aux_data) - @jax.jit + # @jax.jit def _solve(self, wavelength, e_conv_all, o_e_conv_all): self.kx_vector = self.get_kx_vector(wavelength) @@ -93,19 +93,28 @@ def _solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri.real, de_ti.real, layer_info_list, T1, self.kx_vector def solve(self, wavelength, e_conv_all, o_e_conv_all): - de_ri, de_ti, layer_info_list, T1, self.kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + + self.layer_info_list = layer_info_list + self.T1 = T1 + self.kx_vector = kx_vector + return de_ri, de_ti # @jax.jit # TODO: can draw field? with jit? def conv_solve(self, **kwargs): - [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? + [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? if self.fft_type == 0: - E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) + E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex, improve_dft=self.improve_dft) + o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex, improve_dft=self.improve_dft) elif self.fft_type == 1: - E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) - o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) + E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex) + o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex) elif self.fft_type == 2: E_conv_all, o_E_conv_all = to_conv_mat_continuous_vector(self.ucell_info_list, self.fourier_order, type_complex=self.type_complex) @@ -122,8 +131,8 @@ def conv_solve(self, **kwargs): @jax.jit def conv_solve_spectrum(self, ucell): # TODO: other backends - E_conv_all = to_conv_mat_discrete(ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - o_E_conv_all = to_conv_mat_discrete(1 / ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) + E_conv_all = to_conv_mat_discrete(ucell, self.fourier_order[0], self.fourier_order[1], type_complex=self.type_complex, improve_dft=self.improve_dft) + o_E_conv_all = to_conv_mat_discrete(1 / ucell, self.fourier_order[0], self.fourier_order[1], type_complex=self.type_complex, improve_dft=self.improve_dft) de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(self.wavelength, E_conv_all, o_E_conv_all) return de_ri, de_ti @@ -149,6 +158,7 @@ def run_ucell_pmap(self, ucell_list): de_ti = np.array(de_ti) return de_ri, de_ti + # TODO: jit? fourier order split in args? def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: @@ -158,15 +168,15 @@ def calculate_field(self, resolution=None, plot=True): type_complex=self.type_complex) elif self.grating_type == 1: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, - type_complex=self.type_complex) + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) else: resolution = [10, 10, 10] if not resolution else resolution - field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, - type_complex=self.type_complex) + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) if plot: field_plot(field_cell, self.pol) return field_cell diff --git a/meent/on_jax/emsolver/transfer_method.py b/meent/on_jax/emsolver/transfer_method.py index a3493ab..19fad12 100644 --- a/meent/on_jax/emsolver/transfer_method.py +++ b/meent/on_jax/emsolver/transfer_method.py @@ -4,8 +4,6 @@ import jax import jax.numpy as jnp -# import meent.on_jax.jitted as ee -# from . import jitted as ee from .primitives import eig @@ -41,7 +39,7 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, f else: raise ValueError - T = jnp.eye(2 * fourier_order + 1).astype(type_complex) + T = jnp.eye(2 * fourier_order[0] + 1).astype(type_complex) return kx_vector, Kx, k_I_z, k_II_z, Kx, f, YZ_I, g, inc_term, T @@ -57,8 +55,8 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, type_complex=jnp.compl a_i = jnp.linalg.inv(a) - f = W @ (jnp.eye(2 * fourier_order + 1).astype(type_complex) + X @ b @ a_i @ X) - g = V @ (jnp.eye(2 * fourier_order + 1).astype(type_complex) - X @ b @ a_i @ X) + f = W @ (jnp.eye(2 * fourier_order[0] + 1).astype(type_complex) + X @ b @ a_i @ X) + g = V @ (jnp.eye(2 * fourier_order[0] + 1).astype(type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X return X, f, g, T, a_i, b @@ -230,15 +228,15 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 -def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, +def transfer_2d_1(ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, type_complex=jnp.complex128): - I = jnp.eye(ff ** 2).astype(type_complex) - O = jnp.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = jnp.eye(ff_xy).astype(type_complex) + O = jnp.zeros((ff_xy, ff_xy), dtype=type_complex) # kx_vector = k0 * (n_I * jnp.sin(theta) * jnp.cos(phi) + fourier_indices * ( # wavelength / period[0])).astype(type_complex) - ky_vector = k0 * (n_I * jnp.sin(theta) * jnp.sin(phi) + fourier_indices * ( + ky_vector = k0 * (n_I * jnp.sin(theta) * jnp.sin(phi) + fourier_indices_y * ( wavelength / period[1])).astype(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 @@ -247,8 +245,8 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, k_I_z = k_I_z.flatten().conjugate() k_II_z = k_II_z.flatten().conjugate() - Kx = jnp.diag(jnp.tile(kx_vector, ff).flatten()) / k0 - Ky = jnp.diag(jnp.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 + Kx = jnp.diag(jnp.tile(kx_vector, ff_y).flatten()) / k0 + Ky = jnp.diag(jnp.tile(ky_vector.reshape((-1, 1)), ff_x).flatten()) / k0 varphi = jnp.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() @@ -261,13 +259,13 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, big_F = jnp.block([[I, O], [O, 1j * Z_II]]) big_G = jnp.block([[1j * Y_II, O], [O, I]]) - big_T = jnp.eye(ff ** 2 * 2).astype(type_complex) + big_T = jnp.eye(2 * ff_xy).astype(type_complex) return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=jnp.complex128, perturbation=1E-10, device='cpu'): - I = jnp.eye(ff ** 2).astype(type_complex) +def transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=jnp.complex128, perturbation=1E-10, device='cpu'): + I = jnp.eye(ff_xy).astype(type_complex) B = Kx @ E_conv_i @ Kx - I D = Ky @ E_conv_i @ Ky - I @@ -347,10 +345,10 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, typ return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, type_complex=jnp.complex128): - I = jnp.eye(ff ** 2).astype(type_complex) - O = jnp.zeros((ff ** 2, ff ** 2), dtype=type_complex) + I = jnp.eye(ff_xy).astype(type_complex) + O = jnp.zeros((ff_xy, ff_xy), dtype=type_complex) big_F_11 = big_F[:center, :center] big_F_12 = big_F[:center, center:] @@ -383,14 +381,14 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i final_RT = jnp.linalg.inv(final_A) @ final_B - R_s = final_RT[:ff ** 2, :].flatten() - R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + R_s = final_RT[:ff_xy, :].flatten() + R_p = final_RT[ff_xy:2 * ff_xy, :].flatten() - big_T1 = final_RT[2 * ff ** 2:, :] + big_T1 = final_RT[2 * ff_xy:, :] big_T = big_T @ big_T1 - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + T_s = big_T[:ff_xy, :].flatten() + T_p = big_T[ff_xy:, :].flatten() de_ri = R_s * jnp.conj(R_s) * jnp.real(k_I_z / (k0 * n_I * jnp.cos(theta))) \ + R_p * jnp.conj(R_p) * jnp.real((k_I_z / n_I ** 2) / (k0 * n_I * jnp.cos(theta))) diff --git a/meent/on_numpy/emsolver/_base.py b/meent/on_numpy/emsolver/_base.py index 643aa11..b4175f4 100644 --- a/meent/on_numpy/emsolver/_base.py +++ b/meent/on_numpy/emsolver/_base.py @@ -35,12 +35,8 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - # self.fourier_order = int(fourier_order) - self.fourier_order = [int(v) for v in fourier_order] # TODO: other bds - # self.ff = 2 * self.fourier_order[0] + 1 # TODO - self.period = deepcopy(period) self.wavelength = wavelength @@ -54,15 +50,16 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= if self.theta == 0: self.theta = self.perturbation + self.theta = np.where(self.theta == 0, self.perturbation, self.theta) # TODO: check whether correct - self.kx_vector = None + self.kx_vector = None # tODO: need this? why only kx, not ky? def get_kx_vector(self, wavelength): k0 = 2 * np.pi / wavelength fourier_indices_x = np.arange(-self.fourier_order[0], self.fourier_order[0] + 1) - if self.grating_type == 0: + if self.grating_type == 0: # TODO: need this? kx_vector = k0 * (self.n_I * np.sin(self.theta) + fourier_indices_x * (wavelength / self.period[0]) ).astype(self.type_complex) else: @@ -99,14 +96,16 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError - count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + # count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) # From the last layer - for layer_index in range(count)[::-1]: + # for layer_index in range(count)[::-1]: + # + # E_conv = E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + # d = self.thickness[layer_index] - E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] - d = self.thickness[layer_index] + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): if self.pol == 0: E_conv_i = None @@ -163,8 +162,6 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): ff = self.fourier_order[0] * 2 + 1 - # fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) - delta_i0 = np.zeros(ff, dtype=self.type_complex) delta_i0[self.fourier_order[0]] = 1 @@ -180,15 +177,16 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError - count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) - - # From the last layer - for layer_index in range(count)[::-1]: - - E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] - d = self.thickness[layer_index] + # count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + # + # # From the last layer + # for layer_index in range(count)[::-1]: + # + # E_conv = E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + # d = self.thickness[layer_index] + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): E_conv_i = np.linalg.inv(E_conv) o_E_conv_i = np.linalg.inv(o_E_conv) @@ -220,41 +218,28 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): - # TODO: cleaning - self.layer_info_list = [] self.T1 = None - # fourier_indices = np.arange(-self.fourier_order, self.fourier_order + 1) - # fourier_indices_x = np.arange(-self.fourier_order[0], self.fourier_order[0] + 1) fourier_indices_y = np.arange(-self.fourier_order[1], self.fourier_order[1] + 1) ff_x = self.fourier_order[0] * 2 + 1 ff_y = self.fourier_order[1] * 2 + 1 - ff = 2 * self.fourier_order[0] + 1 ff_xy = ff_x * ff_y - delta_i0 = np.zeros((ff ** 2, 1), dtype=self.type_complex) - delta_i0[ff ** 2 // 2, 0] = 1 - delta_i0 = np.zeros((ff_xy, 1), dtype=self.type_complex) delta_i0[ff_xy // 2, 0] = 1 - I = np.eye(ff ** 2, dtype=self.type_complex) - O = np.zeros((ff ** 2, ff ** 2), dtype=self.type_complex) - I = np.eye(ff_xy, dtype=self.type_complex) O = np.zeros((ff_xy, ff_xy), dtype=self.type_complex) - center = ff ** 2 - center = ff_xy k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(ff, ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, + = transfer_2d_1(ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, self.theta, self.phi, wavelength, type_complex=self.type_complex) elif self.algo == 'SMM': @@ -263,20 +248,20 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError - count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) - - # From the last layer - for layer_index in range(count)[::-1]: - - E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] - d = self.thickness[layer_index] - + # count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) + # + # # From the last layer + # for layer_index in range(count)[::-1]: + # + # E_conv = E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + # d = self.thickness[layer_index] + for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): E_conv_i = np.linalg.inv(E_conv) o_E_conv_i = np.linalg.inv(o_E_conv) if self.algo == 'TMM': - W, V, q = transfer_2d_wv(ff, ff_x, ff_y, ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) + W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) big_X, big_F, big_G, big_T, big_A_i, big_B, \ W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 \ @@ -287,20 +272,20 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, q = scattering_2d_wv(ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, q = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, q) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, ff_x, ff_y, ff_xy, + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff_xy, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, type_complex=self.type_complex) self.T1 = big_T1 elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, - self.pol, self.theta, self.phi, self.fourier_order, ff) + self.pol, self.theta, self.phi, self.fourier_order) else: raise ValueError de_ri = de_ri.reshape((ff_y, ff_x)).real diff --git a/meent/on_numpy/emsolver/convolution_matrix.py b/meent/on_numpy/emsolver/convolution_matrix.py index a1f7f32..e425e96 100644 --- a/meent/on_numpy/emsolver/convolution_matrix.py +++ b/meent/on_numpy/emsolver/convolution_matrix.py @@ -42,10 +42,8 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): reference: reticolo """ - if cell.shape[0] == 1: # tODO + if len(fourier_order) == 1: fourier_order = fourier_order + [0] - # else: - # fourier_order = [fourier_order, fourier_order] cell, x, y = cell_compression(cell, type_complex=type_complex) @@ -88,15 +86,8 @@ def fft_piecewise_constant(cell, fourier_order, type_complex=np.complex128): def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.complex128): - """ - reference: reticolo - """ - - # if cell.shape[0] == 1: - # fourier_order = [0, fourier_order] - # else: - # fourier_order = [fourier_order, fourier_order] - # cell, x, y = cell_compression(cell, type_complex=type_complex) + if len(fourier_order) == 1: + fourier_order = fourier_order + [0] # X axis cell_next_x = np.roll(cell, -1, axis=1) @@ -107,7 +98,6 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.com f_coeffs_x = cell_diff_x @ np.exp(-1j * 2 * np.pi * x @ modes_x[None, :], dtype=type_complex) c = f_coeffs_x.shape[1] // 2 - # x_next = np.vstack(np.roll(x, -1, axis=0)[:-1]) - x x_next = np.vstack((np.roll(x, -1, axis=0)[:-1], 1)) - x f_coeffs_x[:, c] = (cell @ np.vstack((x[0], x_next[:-1]))).flatten() @@ -138,13 +128,10 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, type_complex=np.com def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, type_complex=np.complex128): # TODO: do conv and 1/conv in simultaneously? - # ff = 2 * fourier_order + 1 ff_x = 2 * fourier_order[0] + 1 ff_y = 2 * fourier_order[1] + 1 - ff = fourier_order[0] + fourier_order[1] + 1 # tODO - e_conv_all = np.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) o_e_conv_all = np.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).astype(type_complex) @@ -152,17 +139,34 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=None, t for i, ucell_info in enumerate(ucell_info_list): ucell_layer, x_list, y_list = ucell_info ucell_layer = ucell_layer ** 2 + f_coeffs = fft_piecewise_constant_vector(ucell_layer, x_list, y_list, fourier_order, type_complex=type_complex) o_f_coeffs = fft_piecewise_constant_vector(1/ucell_layer, x_list, y_list, fourier_order, type_complex=type_complex) center = np.array(f_coeffs.shape) // 2 - conv_idx = np.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = np.repeat(conv_idx, ff, axis=1) - conv_i = np.repeat(conv_i, [ff] * ff, axis=0) - conv_j = np.tile(conv_idx, (ff, ff)) + # conv_idx = np.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # conv_i = np.repeat(conv_idx, ff, axis=1) + # conv_i = np.repeat(conv_i, [ff] * ff, axis=0) + # conv_j = np.tile(conv_idx, (ff, ff)) + # + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] + # + # e_conv_all[i] = e_conv + # o_e_conv_all[i] = o_e_conv + + # TODO: check correct + conv_idx_y = np.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = np.repeat(conv_idx_y, ff_x, axis=1) + conv_i = np.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = np.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] @@ -177,19 +181,14 @@ def to_conv_mat_continuous(ucell, fourier_order, device=None, type_complex=np.co ucell_pmt = ucell ** 2 # TODO: do conv and 1/conv in simultaneously? - if len(ucell_pmt.shape) == 2: - print('shape is 2') - raise ValueError if ucell_pmt.shape[1] == 1: # 1D - ff = 2 * fourier_order[0] + 1 res = np.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, type_complex=type_complex) - center = f_coeffs.shape[1] // 2 conv_idx = np.arange(-ff + 1, ff, 1, dtype=int) conv_idx = circulant(conv_idx) @@ -198,17 +197,11 @@ def to_conv_mat_continuous(ucell, fourier_order, device=None, type_complex=np.co else: # 2D - # ff = 2 * fourier_order + 1 - # TODO: cleaning - ff = 2 * fourier_order[0] + 1 ff_x = 2 * fourier_order[0] + 1 ff_y = 2 * fourier_order[1] + 1 - # ff = fourier_order[0] + fourier_order[1] + 1 # tODO - - # attention on the order of axis (Z Y X) res = np.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) for i, layer in enumerate(ucell_pmt): @@ -236,27 +229,21 @@ def to_conv_mat_continuous(ucell, fourier_order, device=None, type_complex=np.co e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv - return res -def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=np.complex128, improve_dft=True): - pmt = pmt ** 2 - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - - if pmt.shape[1] == 1: # 1D - - ff = 2 * fourier_order + 1 +def to_conv_mat_discrete(ucell, fourier_order, device=None, type_complex=np.complex128, improve_dft=True): + ucell_pmt = ucell ** 2 - res = np.zeros((pmt.shape[0], ff, ff)).astype(type_complex) + if ucell_pmt.shape[1] == 1: # 1D + ff = 2 * fourier_order[0] + 1 + res = np.zeros((ucell_pmt.shape[0], ff, ff)).astype(type_complex) if improve_dft: - minimum_pattern_size = 2 * ff * pmt.shape[2] + minimum_pattern_size = 2 * ff * ucell_pmt.shape[2] else: minimum_pattern_size = 2 * ff - for i, layer in enumerate(pmt): + for i, layer in enumerate(ucell_pmt): n = minimum_pattern_size // layer.shape[1] layer = np.repeat(layer, n + 1, axis=1) @@ -265,48 +252,60 @@ def to_conv_mat_discrete(pmt, fourier_order, device=None, type_complex=np.comple # https://kr.mathworks.com/matlabcentral/answers/15770-scaling-the-fft-and-the-ifft?s_tid=srchtitle center = f_coeffs.shape[1] // 2 - conv_idx = np.arange(-ff + 1, ff, 1, dtype=int) conv_idx = circulant(conv_idx) e_conv = f_coeffs[0, center + conv_idx] res[i] = e_conv else: # 2D - # ff = 2 * fourier_order + 1 - ff_x = 2 * fourier_order[0] + 1 ff_y = 2 * fourier_order[1] + 1 - ff = fourier_order[0] + fourier_order[1] + 1 # tODO + res = np.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) - # attention on the order of axis (Z Y X) - # TODO: separate fourier order - res = np.zeros((pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) if improve_dft: - minimum_pattern_size_1 = 2 * ff * pmt.shape[1] - minimum_pattern_size_2 = 2 * ff * pmt.shape[2] + minimum_pattern_size_y = 2 * ff_y * ucell_pmt.shape[1] + minimum_pattern_size_x = 2 * ff_x * ucell_pmt.shape[2] else: - minimum_pattern_size_1 = 2 * ff - minimum_pattern_size_2 = 2 * ff + minimum_pattern_size_y = 2 * ff_y + minimum_pattern_size_x = 2 * ff_x + # e.g., 9 bytes * (40*500) * (40*500) / 1E6 = 3600 MB = 3.6 GB - for i, layer in enumerate(pmt): - if layer.shape[0] < minimum_pattern_size_1: - n = minimum_pattern_size_1 // layer.shape[0] + for i, layer in enumerate(ucell_pmt): + + if layer.shape[0] < minimum_pattern_size_y: + n = minimum_pattern_size_y // layer.shape[0] layer = np.repeat(layer, n + 1, axis=0) - if layer.shape[1] < minimum_pattern_size_2: - n = minimum_pattern_size_2 // layer.shape[1] + + if layer.shape[1] < minimum_pattern_size_x: + n = minimum_pattern_size_x // layer.shape[1] layer = np.repeat(layer, n + 1, axis=1) f_coeffs = np.fft.fftshift(np.fft.fft2(layer / layer.size)) center = np.array(f_coeffs.shape) // 2 - conv_idx = np.arange(-ff + 1, ff, 1) - conv_idx = circulant(conv_idx) - conv_i = np.repeat(conv_idx, ff, axis=1) - conv_i = np.repeat(conv_i, [ff] * ff, axis=0) - conv_j = np.tile(conv_idx, (ff, ff)) + # conv_idx = np.arange(-ff + 1, ff, 1) + # conv_idx = circulant(conv_idx) + # + # conv_i = np.repeat(conv_idx, ff, axis=1) + # conv_i = np.repeat(conv_i, [ff] * ff, axis=0) + # conv_j = np.tile(conv_idx, (ff, ff)) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + # + + # TODO: check correct + conv_idx_y = np.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y) + conv_i = np.repeat(conv_idx_y, ff_x, axis=1) + conv_i = np.repeat(conv_i, [ff_x] * ff_y, axis=0) + + conv_idx_x = np.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv diff --git a/meent/on_numpy/emsolver/rcwa.py b/meent/on_numpy/emsolver/rcwa.py index a0f4e38..c30fd46 100644 --- a/meent/on_numpy/emsolver/rcwa.py +++ b/meent/on_numpy/emsolver/rcwa.py @@ -31,12 +31,13 @@ def __init__(self, type_complex=np.complex128, fft_type=0, improve_dft=True, + **kwargs, ): super().__init__(grating_type=grating_type, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, pol=pol, fourier_order=fourier_order, period=period, wavelength=wavelength, thickness=thickness, algo=algo, perturbation=perturbation, - device=device, type_complex=type_complex,) + device=device, type_complex=type_complex, ) self.ucell = deepcopy(ucell) self.ucell_materials = ucell_materials @@ -66,7 +67,12 @@ def _solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri.real, de_ti.real, layer_info_list, T1, self.kx_vector def solve(self, wavelength, e_conv_all, o_e_conv_all): - de_ri, de_ti, layer_info_list, T1, self.kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + + self.layer_info_list = layer_info_list + self.T1 = T1 + self.kx_vector = kx_vector + return de_ri, de_ti def conv_solve(self, *args, **kwargs): @@ -82,7 +88,7 @@ def conv_solve(self, *args, **kwargs): o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) elif self.fft_type == 2: E_conv_all, o_E_conv_all = to_conv_mat_continuous_vector(self.ucell_info_list, self.fourier_order, - type_complex=self.type_complex) + type_complex=self.type_complex) else: raise ValueError @@ -98,23 +104,21 @@ def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, - self.layer_info_list, self.period, self.pol, resolution=resolution, + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, + self.T1, self.layer_info_list, self.period, self.pol, resolution=resolution, type_complex=self.type_complex) elif self.grating_type == 1: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, - self.T1, - self.layer_info_list, self.period, resolution=resolution, - type_complex=self.type_complex) + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) else: resolution = [10, 10, 10] if not resolution else resolution t0 = time.time() - field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, - type_complex=self.type_complex) - print(time.time() - t0) + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) if plot: field_plot(field_cell, self.pol) diff --git a/meent/on_numpy/emsolver/transfer_method.py b/meent/on_numpy/emsolver/transfer_method.py index 49074f2..a83bbc7 100644 --- a/meent/on_numpy/emsolver/transfer_method.py +++ b/meent/on_numpy/emsolver/transfer_method.py @@ -60,14 +60,13 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, type_complex=np.comple def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): - T1 = np.linalg.inv(g1 + 1j * YZ_I @ f1) @ (1j * YZ_I @ delta_i0 + inc_term) R = f1 @ T1 - delta_i0 T = T @ T1 de_ri = np.real(R * np.conj(R) * k_I_z / (k0 * n_I * np.cos(theta))) if polarization == 0: - # de_ti = T * np.conj(T) * np.real(k_II_z / (k0 * n_I * np.cos(theta))) + # de_ti = T * np.conj(T) * np.real(k_II_z / (k0 * n_I * np.cos(theta))) #TODO: use this de_ti = np.real(T * np.conj(T) * k_II_z / (k0 * n_I * np.cos(theta))) elif polarization == 1: # de_ti = T * np.conj(T) * np.real(k_II_z / n_II ** 2) / (k0 * np.cos(theta) / n_I) @@ -114,7 +113,6 @@ def transfer_1d_conical_1(ff, k0, n_I, n_II, kx_vector, theta, phi, type_complex def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, type_complex=np.complex128): - I = np.eye(ff, dtype=type_complex) O = np.zeros((ff, ff), dtype=type_complex) @@ -178,7 +176,6 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, type_complex=np.complex128): - I = np.eye(ff, dtype=type_complex) O = np.zeros((ff, ff), dtype=type_complex) @@ -229,13 +226,8 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 -def transfer_2d_1(ff, ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, +def transfer_2d_1(ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, type_complex=np.complex128): - - # TODO: cleaning - I = np.eye(ff ** 2, dtype=type_complex) - O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) - I = np.eye(ff_xy, dtype=type_complex) O = np.zeros((ff_xy, ff_xy), dtype=type_complex) @@ -251,9 +243,6 @@ def transfer_2d_1(ff, ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fouri k_I_z = k_I_z.flatten().conjugate() k_II_z = k_II_z.flatten().conjugate() - Kx = np.diag(np.tile(kx_vector, ff).flatten()) / k0 - Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff).flatten()) / k0 - Kx = np.diag(np.tile(kx_vector, ff_y).flatten()) / k0 Ky = np.diag(np.tile(ky_vector.reshape((-1, 1)), ff_x).flatten()) / k0 @@ -268,16 +257,12 @@ def transfer_2d_1(ff, ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fouri big_F = np.block([[I, O], [O, 1j * Z_II]]) big_G = np.block([[1j * Y_II, O], [O, I]]) - big_T = np.eye(2 * ff ** 2, dtype=type_complex) big_T = np.eye(2 * ff_xy, dtype=type_complex) return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_2d_wv(ff, ff_x, ff_y, ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=np.complex128): - - # TODO: cleaning - I = np.eye(ff ** 2, dtype=type_complex) +def transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=np.complex128): I = np.eye(ff_xy, dtype=type_complex) B = Kx @ E_conv_i @ Kx - I @@ -359,12 +344,8 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, typ return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, ff_x, ff_y, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, type_complex=np.complex128): - # TODO: cleaning - I = np.eye(ff ** 2, dtype=type_complex) - O = np.zeros((ff ** 2, ff ** 2), dtype=type_complex) - I = np.eye(ff_xy, dtype=type_complex) O = np.zeros((ff_xy, ff_xy), dtype=type_complex) @@ -399,19 +380,12 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, ff_x, f final_RT = np.linalg.inv(final_A) @ final_B - # R_s = final_RT[:ff ** 2, :].flatten() - # R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() - R_s = final_RT[:ff_xy, :].flatten() R_p = final_RT[ff_xy: 2 * ff_xy, :].flatten() - big_T1 = final_RT[2 * ff ** 2:, :] big_T1 = final_RT[2 * ff_xy:, :] big_T = big_T @ big_T1 - # T_s = big_T[:ff ** 2, :].flatten() - # T_p = big_T[ff ** 2:, :].flatten() - T_s = big_T[:ff_xy, :].flatten() T_p = big_T[ff_xy:, :].flatten() @@ -422,4 +396,3 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, ff_x, f + T_p * np.conj(T_p) * np.real((k_II_z / n_II ** 2) / (k0 * n_I * np.cos(theta))) return de_ri.real, de_ti.real, big_T1 - diff --git a/meent/on_torch/emsolver/_base.py b/meent/on_torch/emsolver/_base.py index 4608760..002a6cc 100644 --- a/meent/on_torch/emsolver/_base.py +++ b/meent/on_torch/emsolver/_base.py @@ -12,23 +12,20 @@ class _BaseRCWA: - def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=10, + def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol=0, fourier_order=(2, 2), period=(100, 100), wavelength=900, thickness=None, algo='TMM', perturbation=1E-10, - device='cpu', type_complex=torch.complex128, **kwargs): + device='cpu', type_complex=torch.complex128): self.device = device self.type_complex = type_complex # common self.grating_type = grating_type # 1D=0, 1D_conical=1, 2D=2 + self.n_I = n_I self.n_II = n_II - # self.theta = torch.tensor(theta * np.pi / 180) - # self.phi = torch.tensor(phi * np.pi / 180) - # self.psi = torch.tensor(psi * np.pi / 180) - # degree to radian due to JAX JIT self.theta = torch.tensor(theta) self.phi = torch.tensor(phi) @@ -43,8 +40,7 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = int(fourier_order) - self.ff = 2 * self.fourier_order + 1 + self.fourier_order = [int(v) for v in fourier_order] self.period = deepcopy(period) @@ -57,15 +53,16 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= self.layer_info_list = [] self.T1 = None - if self.theta == 0: - self.theta = torch.tensor(self.perturbation) + # if self.theta == 0: + # self.theta = torch.tensor(self.perturbation) + self.theta = torch.where(self.theta == 0, self.perturbation, self.theta) # TODO: check correct? self.kx_vector = None def get_kx_vector(self, wavelength): k0 = 2 * np.pi / wavelength - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + fourier_indices = torch.arange(-self.fourier_order[0], self.fourier_order[0] + 1, device=self.device) if self.grating_type == 0: kx_vector = k0 * (self.n_I * torch.sin(self.theta) + fourier_indices * (wavelength / self.period[0]) ).type(self.type_complex) @@ -82,16 +79,18 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + # fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) - delta_i0 = torch.zeros(self.ff, device=self.device, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + ff = self.fourier_order[0] * 2 + 1 + + delta_i0 = torch.zeros(ff, device=self.device, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T \ - = transfer_1d_1(self.ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, + = transfer_1d_1(ff, self.pol, k0, self.n_I, self.n_II, self.kx_vector, self.theta, delta_i0, self.fourier_order, device=self.device, type_complex=self.type_complex) elif self.algo == 'SMM': @@ -110,6 +109,9 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): o_E_conv = o_E_conv_all[layer_index] d = self.thickness[layer_index] + # Can't use this. seems like bug in Torch. + # for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + if self.pol == 0: E_conv_i = None A = Kx ** 2 - E_conv @@ -168,16 +170,17 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + # fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + ff = self.fourier_order[0] * 2 + 1 - delta_i0 = torch.zeros(self.ff, device=self.device, dtype=self.type_complex) - delta_i0[self.fourier_order] = 1 + delta_i0 = torch.zeros(ff, device=self.device, dtype=self.type_complex) + delta_i0[self.fourier_order[0]] = 1 k0 = 2 * np.pi / wavelength if self.algo == 'TMM': Kx, ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_1d_conical_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, + = transfer_1d_conical_1(ff, k0, self.n_I, self.n_II, self.kx_vector, self.theta, self.phi, device=self.device, type_complex=self.type_complex) elif self.algo == 'SMM': print('SMM for 1D conical is not implemented') @@ -194,12 +197,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): o_E_conv = o_E_conv_all[layer_index] d = self.thickness[layer_index] + # Can't use this. Seems like bug in Torch. + # for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): E_conv_i = torch.linalg.inv(E_conv) o_E_conv_i = torch.linalg.inv(o_E_conv) if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2\ - = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, self.ff, d, + = transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varphi, big_F, big_G, big_T, device=self.device, type_complex=self.type_complex) @@ -212,7 +217,7 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, device=self.device, type_complex=self.type_complex) self.T1 = big_T1 @@ -229,21 +234,26 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list = [] self.T1 = None - fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + # fourier_indices = torch.arange(-self.fourier_order, self.fourier_order + 1, device=self.device) + fourier_indices_y = torch.arange(-self.fourier_order[1], self.fourier_order[1] + 1, device=self.device) - delta_i0 = torch.zeros((self.ff ** 2, 1), device=self.device, dtype=self.type_complex) - delta_i0[self.ff ** 2 // 2, 0] = 1 + ff_x = self.fourier_order[0] * 2 + 1 + ff_y = self.fourier_order[1] * 2 + 1 + ff_xy = ff_x * ff_y - I = torch.eye(self.ff ** 2, device=self.device, dtype=self.type_complex) - O = torch.zeros((self.ff ** 2, self.ff ** 2), device=self.device, dtype=self.type_complex) + delta_i0 = torch.zeros((ff_xy, 1), device=self.device, dtype=self.type_complex) + delta_i0[ff_xy // 2, 0] = 1 - center = self.ff ** 2 + I = torch.eye(ff_xy, device=self.device, dtype=self.type_complex) + O = torch.zeros((ff_xy, ff_xy), device=self.device, dtype=self.type_complex) + + center = ff_xy k0 = 2 * np.pi / wavelength if self.algo == 'TMM': kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T \ - = transfer_2d_1(self.ff, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices, + = transfer_2d_1(ff_x, ff_y, ff_xy, k0, self.n_I, self.n_II, self.kx_vector, self.period, fourier_indices_y, self.theta, self.phi, wavelength, device=self.device, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Ky, kz_inc, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ @@ -251,6 +261,8 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError + # for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): + count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) # From the last layer @@ -259,12 +271,11 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): E_conv = E_conv_all[layer_index] o_E_conv = o_E_conv_all[layer_index] d = self.thickness[layer_index] - E_conv_i = torch.linalg.inv(E_conv) o_E_conv_i = torch.linalg.inv(o_E_conv) if self.algo == 'TMM': - W, V, q = transfer_2d_wv(self.ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, + W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device=self.device, type_complex=self.type_complex) big_X, big_F, big_G, big_T, big_A_i, big_B, \ @@ -276,24 +287,24 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, LAMBDA = scattering_2d_wv(self.ff, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, LAMBDA = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) else: raise ValueError if self.algo == 'TMM': - de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, self.ff, + de_ri, de_ti, big_T1 = transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, self.psi, self.theta, ff_xy, delta_i0, k_I_z, k0, self.n_I, self.n_II, k_II_z, device=self.device, type_complex=self.type_complex) self.T1 = big_T1 elif self.algo == 'SMM': de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, - self.pol, self.theta, self.phi, self.fourier_order, self.ff) + self.pol, self.theta, self.phi, self.fourier_order) else: raise ValueError - de_ri = de_ri.reshape((self.ff, self.ff)).real - de_ti = de_ti.reshape((self.ff, self.ff)).real + de_ri = de_ri.reshape((ff_y, ff_x)).real + de_ti = de_ti.reshape((ff_y, ff_x)).real return de_ri, de_ti, self.layer_info_list, self.T1 diff --git a/meent/on_torch/emsolver/convolution_matrix.py b/meent/on_torch/emsolver/convolution_matrix.py index 91cbeed..13ec428 100644 --- a/meent/on_torch/emsolver/convolution_matrix.py +++ b/meent/on_torch/emsolver/convolution_matrix.py @@ -39,37 +39,37 @@ def cell_compression(cell, device=torch.device('cpu'), type_complex=torch.comple def fft_piecewise_constant(cell, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] - else: - fourier_order = [fourier_order, fourier_order] + if len(fourier_order) == 1: + fourier_order = fourier_order + [0] + cell, x, y = cell_compression(cell, device=device, type_complex=type_complex) # X axis cell_next_x = torch.roll(cell, -1, dims=1) cell_diff_x = cell_next_x - cell + cell_diff_x = cell_diff_x.type(type_complex) - modes = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) + cell = cell.type(type_complex) - cell_diff_x = cell_diff_x.type(type_complex) - f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes[None, :]).type(type_complex) + modes_x = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + + f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes_x[None, :]).type(type_complex) c = f_coeffs_x.shape[1] // 2 - cell = cell.type(type_complex) x_next = torch.vstack((torch.roll(x, -1, dims=0)[:-1], torch.tensor([1], device=device))) - x f_coeffs_x[:, c] = (cell @ torch.vstack((x[0], x_next[:-1]))).flatten() mask = torch.ones(f_coeffs_x.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = torch.roll(f_coeffs_x, -1, dims=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + modes_y = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) - f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes[None, :]) + f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes_y[None, :]) c = f_coeffs_xy.shape[1] // 2 y_next = torch.vstack((torch.roll(y, -1, dims=0)[:-1], torch.tensor([1], device=device))) - y @@ -79,43 +79,41 @@ def fft_piecewise_constant(cell, fourier_order, device=torch.device('cpu'), type if c: mask = torch.ones(f_coeffs_xy.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T def fft_piecewise_constant_vector(cell, x, y, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - if cell.shape[0] == 1: - fourier_order = [0, fourier_order] # tODO - else: - fourier_order = [fourier_order, fourier_order] - # cell, x, y = cell_compression(cell, device=device, type_complex=type_complex) + if len(fourier_order) == 1: + fourier_order = fourier_order + [0] # X axis cell_next_x = torch.roll(cell, -1, dims=1) cell_diff_x = cell_next_x - cell + cell_diff_x = cell_diff_x.type(type_complex) - modes = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) + cell = cell.type(type_complex) - cell_diff_x = cell_diff_x.type(type_complex) - f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes[None, :]).type(type_complex) + modes_x = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + + f_coeffs_x = cell_diff_x @ torch.exp(-1j * 2 * np.pi * x @ modes_x[None, :]).type(type_complex) c = f_coeffs_x.shape[1] // 2 - cell = cell.type(type_complex) x_next = torch.vstack((torch.roll(x, -1, dims=0)[:-1], torch.tensor([1], device=device))) - x f_coeffs_x[:, c] = (cell @ torch.vstack((x[0], x_next[:-1]))).flatten() mask = torch.ones(f_coeffs_x.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_x[:, mask] /= (1j * 2 * np.pi * modes_x[mask]) # Y axis f_coeffs_x_next_y = torch.roll(f_coeffs_x, -1, dims=0) f_coeffs_x_diff_y = f_coeffs_x_next_y - f_coeffs_x - modes = torch.arange(-2 * fourier_order[0], 2 * fourier_order[0] + 1, 1, device=device).type(type_complex) + modes_y = torch.arange(-2 * fourier_order[1], 2 * fourier_order[1] + 1, 1, device=device).type(type_complex) - f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes[None, :]) + f_coeffs_xy = f_coeffs_x_diff_y.T @ torch.exp(-1j * 2 * np.pi * y @ modes_y[None, :]) c = f_coeffs_xy.shape[1] // 2 y_next = torch.vstack((torch.roll(y, -1, dims=0)[:-1], torch.tensor([1], device=device))) - y @@ -125,17 +123,18 @@ def fft_piecewise_constant_vector(cell, x, y, fourier_order, device=torch.device if c: mask = torch.ones(f_coeffs_xy.shape[1], device=device).type(torch.bool) mask[c] = False - f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes[mask]) + f_coeffs_xy[:, mask] /= (1j * 2 * np.pi * modes_y[mask]) return f_coeffs_xy.T def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - ff = 2 * fourier_order + 1 + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 - e_conv_all = torch.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).type(type_complex) - o_e_conv_all = torch.zeros((len(ucell_info_list), ff ** 2, ff ** 2)).type(type_complex) + e_conv_all = torch.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).type(type_complex) + o_e_conv_all = torch.zeros((len(ucell_info_list), ff_x * ff_y, ff_x * ff_y)).type(type_complex) # 2D # tODO: 1D for i, ucell_info in enumerate(ucell_info_list): @@ -152,11 +151,26 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=torch.d center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc') - conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) - conv_idx = circulant(conv_idx, device) - conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) - conv_i = conv_i.repeat_interleave(ff, dim=0) - conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + # conv_idx = circulant(conv_idx, device) + # conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) + # conv_i = conv_i.repeat_interleave(ff, dim=0) + # conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] + # + # e_conv_all[i] = e_conv + # o_e_conv_all[i] = o_e_conv + + conv_idx_y = torch.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y, device=device) + conv_i = conv_idx_y.repeat_interleave(ff_x, dim=1).type(torch.long) + conv_i = conv_i.repeat_interleave(ff_x, dim=0) + + conv_idx_x = torch.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = conv_idx_x.repeat(ff_y, ff_y).type(torch.long) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] @@ -167,21 +181,15 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order, device=torch.d return e_conv_all, o_e_conv_all -def to_conv_mat_continuous(pmt, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): - pmt = pmt ** 2 +def to_conv_mat_continuous(ucell, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128): + ucell_pmt = ucell ** 2 - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError + if ucell_pmt.shape[1] == 1: # 1D + ff = 2 * fourier_order[0] + 1 - # pmt = torch.tensor(pmt) if type(pmt) != torch.Tensor else pmt + res = torch.zeros((ucell_pmt.shape[0], ff, ff), device=device).type(type_complex) - ff = 2 * fourier_order + 1 - - if pmt.shape[1] == 1: # 1D - res = torch.zeros((pmt.shape[0], ff, ff), device=device).type(type_complex) - - for i, layer in enumerate(pmt): + for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, device=device, type_complex=type_complex) center = f_coeffs.shape[1] // 2 conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) @@ -190,39 +198,49 @@ def to_conv_mat_continuous(pmt, fourier_order, device=torch.device('cpu'), type_ res[i] = e_conv else: # 2D - # attention on the order of axis (Z Y X) - res = torch.zeros((pmt.shape[0], ff ** 2, ff ** 2), device=device).type(type_complex) + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 - for i, layer in enumerate(pmt): + res = torch.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y), device=device).type(type_complex) + + for i, layer in enumerate(ucell_pmt): f_coeffs = fft_piecewise_constant(layer, fourier_order, device=device, type_complex=type_complex) center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc') - conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) - conv_idx = circulant(conv_idx, device) - conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) - conv_i = conv_i.repeat_interleave(ff, dim=0) - conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + # conv_idx = circulant(conv_idx, device) + # conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) + # conv_i = conv_i.repeat_interleave(ff, dim=0) + # conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + + conv_idx_y = torch.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y, device=device) + conv_i = conv_idx_y.repeat_interleave(ff_x, dim=1).type(torch.long) + conv_i = conv_i.repeat_interleave(ff_x, dim=0) + + conv_idx_x = torch.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = conv_idx_x.repeat(ff_y, ff_y).type(torch.long) + e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv return res -def to_conv_mat_discrete(pmt, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128, improve_dft=True): - pmt = pmt ** 2 +def to_conv_mat_discrete(ucell, fourier_order, device=torch.device('cpu'), type_complex=torch.complex128, improve_dft=True): + ucell_pmt = ucell ** 2 - if len(pmt.shape) == 2: - print('shape is 2') - raise ValueError - - ff = 2 * fourier_order + 1 - - if pmt.shape[1] == 1: # 1D - res = torch.zeros((pmt.shape[0], ff, ff), device=device).type(type_complex) + if ucell_pmt.shape[1] == 1: # 1D + ff = 2 * fourier_order[0] + 1 + res = torch.zeros((ucell_pmt.shape[0], ff, ff), device=device).type(type_complex) if improve_dft: - minimum_pattern_size = 2 * ff * pmt.shape[2] + minimum_pattern_size = 2 * ff * ucell_pmt.shape[2] else: minimum_pattern_size = 2 * ff - for i, layer in enumerate(pmt): + + for i, layer in enumerate(ucell_pmt): n = minimum_pattern_size // layer.shape[1] layer = layer.repeat_interleave(n + 1, axis=1) @@ -235,33 +253,58 @@ def to_conv_mat_discrete(pmt, fourier_order, device=torch.device('cpu'), type_co res[i] = e_conv else: # 2D - res = torch.zeros((pmt.shape[0], ff ** 2, ff ** 2), device=device).type(type_complex) + ff_x = 2 * fourier_order[0] + 1 + ff_y = 2 * fourier_order[1] + 1 + + res = torch.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y), device=device).type(type_complex) + + # if improve_dft: + # minimum_pattern_size_1 = 2 * ff * ucell_pmt.shape[1] + # minimum_pattern_size_2 = 2 * ff * ucell_pmt.shape[2] + # else: + # minimum_pattern_size_1 = 2 * ff + # minimum_pattern_size_2 = 2 * ff + if improve_dft: - minimum_pattern_size_1 = 2 * ff * pmt.shape[1] - minimum_pattern_size_2 = 2 * ff * pmt.shape[2] + minimum_pattern_size_y = 2 * ff_y * ucell_pmt.shape[1] + minimum_pattern_size_x = 2 * ff_x * ucell_pmt.shape[2] else: - minimum_pattern_size_1 = 2 * ff - minimum_pattern_size_2 = 2 * ff + minimum_pattern_size_y = 2 * ff_y + minimum_pattern_size_x = 2 * ff_x + # 9 * (40*500) * (40*500) / 1E6 = 3600 MB = 3.6 GB - for i, layer in enumerate(pmt): - if layer.shape[0] < minimum_pattern_size_1: - n = torch.div(minimum_pattern_size_1, layer.shape[0], rounding_mode='trunc') + for i, layer in enumerate(ucell_pmt): + if layer.shape[0] < minimum_pattern_size_y: + n = torch.div(minimum_pattern_size_y, layer.shape[0], rounding_mode='trunc') layer = layer.repeat_interleave(n + 1, axis=0) - if layer.shape[1] < minimum_pattern_size_2: - n = torch.div(minimum_pattern_size_2, layer.shape[1], rounding_mode='trunc') + if layer.shape[1] < minimum_pattern_size_x: + n = torch.div(minimum_pattern_size_x, layer.shape[1], rounding_mode='trunc') layer = layer.repeat_interleave(n + 1, axis=1) f_coeffs = torch.fft.fftshift(torch.fft.fft2(layer / (layer.size(0)*layer.size(1)))) center = torch.div(torch.tensor(f_coeffs.shape, device=device), 2, rounding_mode='trunc') - conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) - conv_idx = circulant(conv_idx, device) + # conv_idx = torch.arange(-ff + 1, ff, 1, device=device).type(torch.long) + # conv_idx = circulant(conv_idx, device) + # + # conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) + # conv_i = conv_i.repeat_interleave(ff, dim=0) + # conv_j = conv_idx.repeat(ff, ff).type(torch.long) + # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] + # res[i] = e_conv + + conv_idx_y = torch.arange(-ff_y + 1, ff_y, 1) + conv_idx_y = circulant(conv_idx_y, device=device) + conv_i = conv_idx_y.repeat_interleave(ff_x, dim=1).type(torch.long) + conv_i = conv_i.repeat_interleave(ff_x, dim=0) + + conv_idx_x = torch.arange(-ff_x + 1, ff_x, 1) + conv_idx_x = circulant(conv_idx_x) + conv_j = conv_idx_x.repeat(ff_y, ff_y).type(torch.long) - conv_i = conv_idx.repeat_interleave(ff, dim=1).type(torch.long) - conv_i = conv_i.repeat_interleave(ff, dim=0) - conv_j = conv_idx.repeat(ff, ff).type(torch.long) e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] res[i] = e_conv + return res diff --git a/meent/on_torch/emsolver/rcwa.py b/meent/on_torch/emsolver/rcwa.py index 4b924d7..1bbb65b 100644 --- a/meent/on_torch/emsolver/rcwa.py +++ b/meent/on_torch/emsolver/rcwa.py @@ -1,3 +1,5 @@ +from copy import deepcopy + import torch from ._base import _BaseRCWA @@ -34,9 +36,9 @@ def __init__(self, super().__init__(grating_type=grating_type, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, pol=pol, fourier_order=fourier_order, period=period, wavelength=wavelength, thickness=thickness, algo=algo, perturbation=perturbation, - device=device, type_complex=type_complex, **kwargs) + device=device, type_complex=type_complex) - self.ucell = ucell + self.ucell = deepcopy(ucell) # TODO: deepcopy? self.ucell_materials = ucell_materials self.ucell_info_list = ucell_info_list @@ -64,16 +66,22 @@ def _solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri.real, de_ti.real, layer_info_list, T1, self.kx_vector def solve(self, wavelength, e_conv_all, o_e_conv_all): - de_ri, de_ti, layer_info_list, T1, self.kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) - return de_ri, de_ti + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(wavelength, e_conv_all, o_e_conv_all) + + self.layer_info_list = layer_info_list + self.T1 = T1 + self.kx_vector = kx_vector - def conv_solve(self, *args, **kwargs): + return de_ri, de_ti + def conv_solve(self, *args, **kwargs): # TODO: delete args? [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? if self.fft_type == 0: - E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) - o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, improve_dft=self.improve_dft) + E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order, type_complex=self.type_complex, + improve_dft=self.improve_dft) + o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order, type_complex=self.type_complex, + improve_dft=self.improve_dft) elif self.fft_type == 1: E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order, type_complex=self.type_complex) o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order, type_complex=self.type_complex) @@ -95,20 +103,21 @@ def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, - self.layer_info_list, self.period, self.pol, resolution=resolution, + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, + self.T1, self.layer_info_list, self.period, self.pol, resolution=resolution, device=self.device, type_complex=self.type_complex) elif self.grating_type == 1: resolution = [100, 1, 100] if not resolution else resolution - field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, - self.T1, self.layer_info_list, self.period, resolution=resolution, - device=self.device, type_complex=self.type_complex) + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, device=self.device, + type_complex=self.type_complex) else: resolution = [100, 100, 100] if not resolution else resolution - field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, self.fourier_order, self.T1, - self.layer_info_list, self.period, resolution=resolution, - device=self.device, type_complex=self.type_complex) + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, device=self.device, type_complex=self.type_complex) if plot: field_plot(field_cell, self.pol) diff --git a/meent/on_torch/emsolver/transfer_method.py b/meent/on_torch/emsolver/transfer_method.py index cc48ee9..0393491 100644 --- a/meent/on_torch/emsolver/transfer_method.py +++ b/meent/on_torch/emsolver/transfer_method.py @@ -37,7 +37,7 @@ def transfer_1d_1(ff, polarization, k0, n_I, n_II, kx_vector, theta, delta_i0, f else: raise ValueError - T = torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) + T = torch.eye(2 * fourier_order[0] + 1, device=device, dtype=type_complex) return kx_vector, Kx, k_I_z, k_II_z, f, YZ_I, g, inc_term, T @@ -54,15 +54,14 @@ def transfer_1d_2(k0, q, d, W, V, f, g, fourier_order, T, device='cpu', type_com a_i = torch.linalg.inv(a) - f = W @ (torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) + X @ b @ a_i @ X) - g = V @ (torch.eye(2 * fourier_order + 1, device=device, dtype=type_complex) - X @ b @ a_i @ X) + f = W @ (torch.eye(2 * fourier_order[0] + 1, device=device, dtype=type_complex) + X @ b @ a_i @ X) + g = V @ (torch.eye(2 * fourier_order[0] + 1, device=device, dtype=type_complex) - X @ b @ a_i @ X) T = T @ a_i @ X return X, f, g, T, a_i, b def transfer_1d_3(g1, YZ_I, f1, delta_i0, inc_term, T, k_I_z, k0, n_I, n_II, theta, polarization, k_II_z): - T1 = torch.linalg.inv(g1 + 1j * YZ_I @ f1) @ (1j * YZ_I @ delta_i0 + inc_term) R = f1 @ T1 - delta_i0 T = T @ T1 @@ -205,7 +204,6 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, o_E_conv_i, ff, d, varphi, bi def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, device='cpu', type_complex=torch.complex128): - I = torch.eye(ff, device=device, dtype=type_complex) O = torch.zeros((ff, ff), device=device, dtype=type_complex) @@ -255,16 +253,15 @@ def transfer_1d_conical_3(big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i return de_ri.real, de_ti.real, big_T1 - -def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, phi, wavelength, +def transfer_2d_1(ff_x, ff_y, ff_xy, k0, n_I, n_II, kx_vector, period, fourier_indices_y, theta, phi, wavelength, device='cpu', type_complex=torch.complex128): - - I = torch.eye(ff ** 2, device=device, dtype=type_complex) - O = torch.zeros((ff ** 2, ff ** 2), device=device, dtype=type_complex) + I = torch.eye(ff_xy, device=device, dtype=type_complex) + O = torch.zeros((ff_xy, ff_xy), device=device, dtype=type_complex) # kx_vector = k0 * (n_I * torch.sin(theta) * torch.cos(phi) + fourier_indices * ( # wavelength / period[0])).type(type_complex) - ky_vector = k0 * (n_I * torch.sin(theta) * torch.sin(phi) + fourier_indices * ( + + ky_vector = k0 * (n_I * torch.sin(theta) * torch.sin(phi) + fourier_indices_y * ( wavelength / period[1])).type(type_complex) k_I_z = (k0 ** 2 * n_I ** 2 - kx_vector ** 2 - ky_vector.reshape((-1, 1)) ** 2) ** 0.5 @@ -273,8 +270,8 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, k_I_z = torch.conj(k_I_z.flatten()) k_II_z = torch.conj(k_II_z.flatten()) - Kx = torch.diag(kx_vector.tile(ff).flatten() / k0) - Ky = torch.diag(ky_vector.reshape((-1, 1)).tile(ff).flatten() / k0) + Kx = torch.diag(kx_vector.tile(ff_y).flatten() / k0) + Ky = torch.diag(ky_vector.reshape((-1, 1)).tile(ff_x).flatten() / k0) varphi = torch.arctan(ky_vector.reshape((-1, 1)) / kx_vector).flatten() @@ -298,15 +295,14 @@ def transfer_2d_1(ff, k0, n_I, n_II, kx_vector, period, fourier_indices, theta, ] ) - big_T = torch.eye(ff ** 2 * 2, device=device, dtype=type_complex) + big_T = torch.eye(2 * ff_xy, device=device, dtype=type_complex) return kx_vector, ky_vector, Kx, Ky, k_I_z, k_II_z, varphi, Y_I, Y_II, Z_I, Z_II, big_F, big_G, big_T -def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device='cpu', type_complex=torch.complex128, +def transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device='cpu', type_complex=torch.complex128, perturbation=1E-10): - - I = torch.eye(ff ** 2, device=device, dtype=type_complex) + I = torch.eye(ff_xy, device=device, dtype=type_complex) B = Kx @ E_conv_i @ Kx - I D = Ky @ E_conv_i @ Ky - I @@ -316,6 +312,7 @@ def transfer_2d_wv(ff, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, device='cpu', type_ torch.cat([Ky ** 2 + B @ o_E_conv_i, Kx @ (E_conv_i @ Ky @ E_conv - Ky)], dim=1), torch.cat([Ky @ (E_conv_i @ Kx @ o_E_conv_i - Kx), Kx ** 2 + D @ E_conv], dim=1) ]) + Eig.broadening_parameter = perturbation eigenvalues, W = Eig.apply(S2_from_S) @@ -396,11 +393,10 @@ def transfer_2d_2(k0, d, W, V, center, q, varphi, I, O, big_F, big_G, big_T, dev return big_X, big_F, big_G, big_T, big_A_i, big_B, W_11, W_12, W_21, W_22, V_11, V_12, V_21, V_22 -def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, +def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff_xy, delta_i0, k_I_z, k0, n_I, n_II, k_II_z, device='cpu', type_complex=torch.complex128): - - I = torch.eye(ff ** 2, device=device, dtype=type_complex) - O = torch.zeros((ff ** 2, ff ** 2), device=device, dtype=type_complex) + I = torch.eye(ff_xy, device=device, dtype=type_complex) + O = torch.zeros((ff_xy, ff_xy), device=device, dtype=type_complex) big_F_11 = big_F[:center, :center] big_F_12 = big_F[:center, center:] @@ -433,14 +429,14 @@ def transfer_2d_3(center, big_F, big_G, big_T, Z_I, Y_I, psi, theta, ff, delta_i final_RT = torch.linalg.inv(final_A) @ final_B - R_s = final_RT[:ff ** 2, :].flatten() - R_p = final_RT[ff ** 2:2 * ff ** 2, :].flatten() + R_s = final_RT[:ff_xy, :].flatten() + R_p = final_RT[ff_xy:2 * ff_xy, :].flatten() - big_T1 = final_RT[2 * ff ** 2:, :] + big_T1 = final_RT[2 * ff_xy:, :] big_T = big_T @ big_T1 - T_s = big_T[:ff ** 2, :].flatten() - T_p = big_T[ff ** 2:, :].flatten() + T_s = big_T[:ff_xy, :].flatten() + T_p = big_T[ff_xy:, :].flatten() de_ri = R_s * torch.conj(R_s) * torch.real(k_I_z / (k0 * n_I * torch.cos(theta))) \ + R_p * torch.conj(R_p) * torch.real((k_I_z / n_I ** 2) / (k0 * n_I * torch.cos(theta))) From 2a9b0fd6d07e524356f61dde0b5c02b6f890ccb7 Mon Sep 17 00:00:00 2001 From: yonghakim Date: Thu, 16 Mar 2023 00:47:12 +0900 Subject: [PATCH 3/4] async 2D order; 2D jax backprop does not match; --- QA/backprop.py | 16 +-- benchmarks/interface/Reticolo.py | 22 +++-- examples/ex_ucell.py | 14 ++- meent/on_jax/emsolver/_base.py | 16 ++- meent/on_jax/emsolver/convolution_matrix.py | 40 +------- meent/on_jax/emsolver/field_distribution.py | 99 +++++++++++++++++-- meent/on_jax/emsolver/rcwa.py | 57 +++++++++-- meent/on_numpy/emsolver/_base.py | 25 ++--- meent/on_torch/emsolver/_base.py | 9 +- meent/on_torch/emsolver/field_distribution.py | 2 +- meent/testcase.py | 2 +- 11 files changed, 196 insertions(+), 106 deletions(-) diff --git a/QA/backprop.py b/QA/backprop.py index 44e7011..202ebfe 100644 --- a/QA/backprop.py +++ b/QA/backprop.py @@ -32,7 +32,7 @@ def load_setting(mode_key, dtype, device): wavelength = 900 ucell_materials = [1, 3.48] - fourier_order = 2 + fourier_order = [2, 2] period = [1000, 1000] thickness = [1120., 400, 300] @@ -133,12 +133,12 @@ def optimize_jax_ucell_metasurface(mode_key, dtype, device): @jax.grad def grad_loss(ucell): - E_conv_all = to_conv_mat_discrete(ucell, fourier_order, type_complex=type_complex) - o_E_conv_all = to_conv_mat_discrete(1 / ucell, fourier_order, type_complex=type_complex) + E_conv_all = to_conv_mat_discrete(ucell, *fourier_order, type_complex=type_complex) + o_E_conv_all = to_conv_mat_discrete(1 / ucell, *fourier_order, type_complex=type_complex) de_ri, de_ti = solver.solve(wavelength, E_conv_all, o_E_conv_all) c = de_ti.shape[0] // 2 loss = de_ti[c, c] - print(loss.primal) + # print(loss.primal) return loss def grad_numerical(ucell, delta): @@ -149,14 +149,14 @@ def grad_numerical(ucell, delta): for c in range(ucell.shape[2]): ucell_delta_m = ucell.at[layer, r, c].set(ucell[layer, r, c] - delta) - E_conv_all_m = to_conv_mat_discrete(ucell_delta_m, fourier_order, type_complex=type_complex) - o_E_conv_all_m = to_conv_mat_discrete(1 / ucell_delta_m, fourier_order, type_complex=type_complex) + E_conv_all_m = to_conv_mat_discrete(ucell_delta_m, *fourier_order, type_complex=type_complex) + o_E_conv_all_m = to_conv_mat_discrete(1 / ucell_delta_m, *fourier_order, type_complex=type_complex) de_ri_delta_m, de_ti_delta_m = solver.solve(wavelength, E_conv_all_m, o_E_conv_all_m) ucell_delta_p = ucell.at[layer, r, c].set(ucell[layer, r, c] + delta) - E_conv_all_p = to_conv_mat_discrete(ucell_delta_p, fourier_order, type_complex=type_complex) - o_E_conv_all_p = to_conv_mat_discrete(1 / ucell_delta_p, fourier_order, type_complex=type_complex) + E_conv_all_p = to_conv_mat_discrete(ucell_delta_p, *fourier_order, type_complex=type_complex) + o_E_conv_all_p = to_conv_mat_discrete(1 / ucell_delta_p, *fourier_order, type_complex=type_complex) de_ri_delta_p, de_ti_delta_p = solver.solve(wavelength, E_conv_all_p, o_E_conv_all_p) center = de_ti_delta_m.shape[0] // 2 diff --git a/benchmarks/interface/Reticolo.py b/benchmarks/interface/Reticolo.py index 9d556d8..a00d54c 100644 --- a/benchmarks/interface/Reticolo.py +++ b/benchmarks/interface/Reticolo.py @@ -56,7 +56,6 @@ def run(self, grating_type, period, fourier_order, ucell, thickness, theta, phi, textures = [n_I, *ucell_new, n_II] else: - fourier_order = [fourier_order, fourier_order] Nx = ucell.shape[2] Ny = ucell.shape[1] @@ -114,14 +113,14 @@ def run_acs(self, pattern, n_si='SILICON'): mode = 0 dtype = 0 device = 0 - grating_type = 0 + grating_type = 2 pre = load_setting(mode, dtype, device, grating_type) reti = Reticolo() a,b,c,d = reti.run(**pre) - print(np.array(a).flatten()[::-1]) - print(np.array(b).flatten()[::-1]) + print(np.array(a).flatten()) + print(np.array(b).flatten()) # print(np.array(a)) # print(np.array(b)) # print(c) @@ -131,18 +130,21 @@ def run_acs(self, pattern, n_si='SILICON'): mode = 0 pre = load_setting(mode, dtype, device, grating_type) mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) - mee.fft_type = 1 + mee.fft_type = 0 de_ri, de_ti = mee.conv_solve() - c = de_ri.shape[0]//2 + center = np.array(de_ri.shape) // 2 try: - print(de_ri[c-1:c+2, c-1:c+2]) - print(de_ti[c-1:c+2, c-1:c+2]) + print(de_ri[center[0]-1:center[0]+2, center[1]-1:center[1]+2]) + print(de_ti[center[0]-1:center[0]+2, center[1]-1:center[1]+2]) except: # print(de_ri[c-1:c+2]) # print(de_ti[c-1:c+2]) - print(de_ri) - print(de_ti) + print(de_ri[center[0]-1:center[0]+2]) + print(de_ti[center[0]-1:center[0]+2]) + + # print(de_ri) + # print(de_ti) print(a.sum(),de_ri.sum()) print(b.sum(),de_ti.sum()) diff --git a/examples/ex_ucell.py b/examples/ex_ucell.py index e18676f..58dbbec 100644 --- a/examples/ex_ucell.py +++ b/examples/ex_ucell.py @@ -32,9 +32,9 @@ ucell_materials = [1, 'p_si__real'] period = [1000, 1000] -fourier_order = [3, 2] +fourier_order = 3 mode_options = {0: 'numpy', 1: 'JAX', 2: 'Torch', } -n_iter = 2 +n_iter = 3 def run_test(grating_type, mode_key, dtype, device): @@ -88,14 +88,18 @@ def run_test(grating_type, mode_key, dtype, device): AA = meent.call_mee(mode=mode_key, grating_type=grating_type, pol=pol, n_I=n_I, n_II=n_II, theta=theta, phi=phi, psi=psi, fourier_order=fourier_order, wavelength=wavelength, period=period, ucell=ucell, ucell_materials=ucell_materials, - thickness=thickness, device=device, type_complex=type_complex, fft_type=1, improve_dft=True) + thickness=thickness, device=device, type_complex=type_complex, fft_type=0, improve_dft=True) + # for i in range(n_iter): + # t0 = time.time() + # AA.conv_solve_calculate_field() + # print(f'run_cell: {i}: ', time.time() - t0) for i in range(n_iter): t0 = time.time() de_ri, de_ti = AA.conv_solve() print(f'run_cell: {i}: ', time.time() - t0) resolution = (20, 20, 20) - for i in range(0): + for i in range(2): t0 = time.time() AA.calculate_field(resolution=resolution, plot=False) print(f'cal_field: {i}', time.time() - t0) @@ -184,4 +188,4 @@ def load_ucell(grating_type): if __name__ == '__main__': - run_loop([2], [0,1,2], [0], [0]) + run_loop([0], [2], [0], [0]) diff --git a/meent/on_jax/emsolver/_base.py b/meent/on_jax/emsolver/_base.py index 18e3567..bdb8f11 100644 --- a/meent/on_jax/emsolver/_base.py +++ b/meent/on_jax/emsolver/_base.py @@ -45,8 +45,9 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - # TODO: jit-stuff. apply other backends? - if len(fourier_order) == 1: + if type(fourier_order) == int: + self.fourier_order = [fourier_order, 0] + elif len(fourier_order) == 1: self.fourier_order = list(fourier_order) + [0] else: self.fourier_order = [int(v) for v in fourier_order] @@ -62,8 +63,6 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= self.layer_info_list = [] self.T1 = None - # if self.theta == 0: - # self.theta = self.perturbation self.theta = jnp.where(self.theta == 0, self.perturbation, self.theta) self.kx_vector = None @@ -84,13 +83,10 @@ def get_kx_vector(self, wavelength): return kx_vector def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): - self.layer_info_list = [] self.T1 = None - ff = self.fourier_order[0] * 2 + 1 # TODO: list? - - # fourier_indices = jnp.arange(-self.fourier_order[0], self.fourier_order[0] + 1) + ff = self.fourier_order[0] * 2 + 1 delta_i0 = jnp.zeros(ff, dtype=self.type_complex) delta_i0 = delta_i0.at[self.fourier_order[0]].set(1) @@ -103,7 +99,7 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ - = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices, self.period, + = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, self.period, self.pol, wl=wavelength) else: raise ValueError @@ -134,7 +130,7 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): else: raise ValueError - if self.algo == 'TMM': + if self.algo == 'TMM': # TODO: fourier order? X, f, g, T, a_i, b = transfer_1d_2(k0, q, d, W, V, f, g, self.fourier_order, T, type_complex=self.type_complex) diff --git a/meent/on_jax/emsolver/convolution_matrix.py b/meent/on_jax/emsolver/convolution_matrix.py index 594f51f..5e16a73 100644 --- a/meent/on_jax/emsolver/convolution_matrix.py +++ b/meent/on_jax/emsolver/convolution_matrix.py @@ -161,15 +161,6 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order_x, fourier_orde fourier_order_x, fourier_order_y, type_complex=type_complex) center = np.array(f_coeffs.shape) // 2 - - - # conv_idx = jnp.arange(-ff + 1, ff, 1) - # conv_idx = circulant(conv_idx) - # conv_i = jnp.repeat(conv_idx, ff, 1) - # conv_i = jnp.repeat(conv_i, ff, axis=0) - # conv_j = jnp.tile(conv_idx, (ff, ff)) - - conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) conv_idx_y = circulant(conv_idx_y) conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) @@ -179,8 +170,6 @@ def to_conv_mat_continuous_vector(ucell_info_list, fourier_order_x, fourier_orde conv_idx_x = circulant(conv_idx_x) conv_j = np.tile(conv_idx_x, (ff_y, ff_y)) - - e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] o_e_conv = o_f_coeffs[center[0] + conv_i, center[1] + conv_j] @@ -218,14 +207,6 @@ def to_conv_mat_continuous(ucell, fourier_order_x, fourier_order_y, device=None, f_coeffs = fft_piecewise_constant(layer, fourier_order_x, fourier_order_y, type_complex=type_complex) center = jnp.array(f_coeffs.shape) // 2 - # conv_idx = jnp.arange(-ff + 1, ff, 1) - # conv_idx = circulant(conv_idx) - # conv_i = jnp.repeat(conv_idx, ff, 1) - # conv_i = jnp.repeat(conv_i, ff, axis=0) - # conv_j = jnp.tile(conv_idx, (ff, ff)) - # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] - # res = res.at[i].set(e_conv) - conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) conv_idx_y = circulant(conv_idx_y) conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) @@ -241,7 +222,7 @@ def to_conv_mat_continuous(ucell, fourier_order_x, fourier_order_y, device=None, return res -@partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) # TODO +# @partial(jax.jit, static_argnums=(1, 2, 3, 4, 5)) def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, type_complex=jnp.complex128, improve_dft=True): ucell_pmt = ucell ** 2 @@ -274,14 +255,7 @@ def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, t ff_x = 2 * fourier_order_x + 1 ff_y = 2 * fourier_order_y + 1 - res = np.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) - - # if improve_dft: - # minimum_pattern_size_1 = 2 * ff * pmt.shape[1] - # minimum_pattern_size_2 = 2 * ff * pmt.shape[2] - # else: - # minimum_pattern_size_1 = 2 * ff - # minimum_pattern_size_2 = 2 * ff + res = jnp.zeros((ucell_pmt.shape[0], ff_x * ff_y, ff_x * ff_y)).astype(type_complex) if improve_dft: minimum_pattern_size_y = 2 * ff_y * ucell_pmt.shape[1] @@ -303,19 +277,9 @@ def to_conv_mat_discrete(ucell, fourier_order_x, fourier_order_y, device=None, t f_coeffs = jnp.fft.fftshift(jnp.fft.fft2(layer / layer.size)) center = jnp.array(f_coeffs.shape) // 2 - # conv_idx = jnp.arange(-ff + 1, ff, 1) - # conv_idx = circulant(conv_idx) - # - # conv_i = jnp.repeat(conv_idx, ff, 1) - # conv_i = jnp.repeat(conv_i, ff, axis=0) - # conv_j = jnp.tile(conv_idx, (ff, ff)) - # e_conv = f_coeffs[center[0] + conv_i, center[1] + conv_j] - # res = res.at[i].set(e_conv) - conv_idx_y = jnp.arange(-ff_y + 1, ff_y, 1) conv_idx_y = circulant(conv_idx_y) conv_i = jnp.repeat(conv_idx_y, ff_x, axis=1) - # conv_i = jnp.repeat(conv_i, [ff_x] * ff_y, axis=0) conv_i = jnp.repeat(conv_i, jnp.array([ff_x] * ff_y), axis=0) conv_idx_x = jnp.arange(-ff_x + 1, ff_x, 1) diff --git a/meent/on_jax/emsolver/field_distribution.py b/meent/on_jax/emsolver/field_distribution.py index 895b5fb..315ae36 100644 --- a/meent/on_jax/emsolver/field_distribution.py +++ b/meent/on_jax/emsolver/field_distribution.py @@ -16,11 +16,11 @@ def field_distribution(grating_type, *args, **kwargs): return res -def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100), +def field_dist_1d_original(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_info_list, period, pol, resolution=(100, 1, 100), type_complex=jnp.complex128): k0 = 2 * jnp.pi / wavelength - fourier_indices = jnp.arange(-fourier_order, fourier_order + 1) + fourier_indices = jnp.arange(-fourier_order[0], fourier_order[0] + 1) # kx_vector = k0 * (n_I * jnp.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) Kx = jnp.diag(kx_vector / k0) @@ -28,7 +28,8 @@ def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_in resolution_z, resolution_y, resolution_x = resolution # Here use numpy array due to slow assignment speed in JAX - field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + # field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + field_cell = jnp.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) T_layer = T1 @@ -55,7 +56,93 @@ def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_in for j in range(resolution_y): for i in range(resolution_x): res = x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector) - field_cell[resolution_z * idx_layer + k, j, i] = res + # field_cell[resolution_z * idx_layer + k, j, i] = res + field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set(res) + + T_layer = a_i @ X @ T_layer + + return field_cell + + +@partial(jax.jit, static_argnums=(4, 5, 9, 10, 11, )) +def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order_x, fourier_order_y, T1, layer_info_list, period, pol, resolution=(100, 1, 100), + type_complex=jnp.complex128): + + k0 = 2 * jnp.pi / wavelength + fourier_indices = jnp.arange(-fourier_order_x, fourier_order_x + 1) + # kx_vector = k0 * (n_I * jnp.sin(theta) - fourier_indices * (wavelength / period[0])).astype(type_complex) + + Kx = jnp.diag(kx_vector / k0) + + resolution_z, resolution_y, resolution_x = resolution + + # Here use numpy array due to slow assignment speed in JAX + # field_cell = np.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + field_cell = jnp.zeros((resolution_z * len(layer_info_list), resolution_y, resolution_x, 3), dtype=type_complex) + + T_layer = T1 + + # From the first layer + for idx_layer, (E_conv_i, q, W, X, a_i, b, d) in enumerate(layer_info_list[::-1]): + + c1 = T_layer[:, None] + c2 = b @ a_i @ X @ T_layer[:, None] + + Q = jnp.diag(q) + + if pol == 0: + V = W @ Q + EKx = None + + else: + V = E_conv_i @ W @ Q + EKx = E_conv_i @ Kx + + for k in range(resolution_z): + z = k / resolution_z * d + + if pol == 0: + Sy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Ux = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + C = (-1j) * Kx @ Sy + + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Ey = Sy.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hx = -1j * Ux.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Hz = C.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set([Ey[0, 0], Hx[0, 0], Hz[0, 0]]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set([Ey, Hx, Hz]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i, 0].set(Ey[0, 0]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i, 1].set(Hx[0, 0]) + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i, 2].set(Hz[0, 0]) + # res = [Ey[0, 0], Hx[0, 0], Hz[0, 0]] + + else: + Uy = W @ (expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + Sx = V @ (-expm(-k0 * Q * z) @ c1 + expm(k0 * Q * (z - d)) @ c2) + + C = (-1j) * EKx @ Uy # there is a better option for convergence + for j in range(resolution_y): + for i in range(resolution_x): + x = i * period[0] / resolution_x + + Hy = Uy.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ex = 1j * Sx.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + Ez = C.T @ jnp.exp(-1j * kx_vector.reshape((-1, 1)) * x) + + field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set([Hy[0, 0], Ex[0, 0], Ez[0, 0]]) + # res = [Hy[0, 0], Ex[0, 0], Ez[0, 0]] + + # A, B, C = z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx) + # for j in range(resolution_y): + # for i in range(resolution_x): + # res = x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector) + # # field_cell[resolution_z * idx_layer + k, j, i] = res + # field_cell = field_cell.at[resolution_z * idx_layer + k, j, i].set(res) T_layer = a_i @ X @ T_layer @@ -143,7 +230,7 @@ def field_dist_2d(wavelength, kx_vector, n_I, theta, phi, fourier_order, T1, lay return field_cell -@partial(jax.jit, static_argnums=(0,)) +# @partial(jax.jit, static_argnums=(0,)) def z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx): if pol == 0: # TE @@ -162,7 +249,7 @@ def z_loop_1d(pol, k0, Kx, W, V, Q, c1, c2, d, z, EKx): return Uy, Sx, C -@partial(jax.jit, static_argnums=(0,)) +# @partial(jax.jit, static_argnums=(0,)) def x_loop_1d(pol, resolution_x, period, i, A, B, C, kx_vector): if pol == 0: # TE diff --git a/meent/on_jax/emsolver/rcwa.py b/meent/on_jax/emsolver/rcwa.py index 48ea0d8..817233a 100644 --- a/meent/on_jax/emsolver/rcwa.py +++ b/meent/on_jax/emsolver/rcwa.py @@ -102,19 +102,13 @@ def solve(self, wavelength, e_conv_all, o_e_conv_all): return de_ri, de_ti # @jax.jit # TODO: can draw field? with jit? - def conv_solve(self, **kwargs): - [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? + def _conv_solve(self): if self.fft_type == 0: E_conv_all = to_conv_mat_discrete(self.ucell, self.fourier_order[0], self.fourier_order[1], type_complex=self.type_complex, improve_dft=self.improve_dft) o_E_conv_all = to_conv_mat_discrete(1 / self.ucell, self.fourier_order[0], self.fourier_order[1], type_complex=self.type_complex, improve_dft=self.improve_dft) - elif self.fft_type == 1: - E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order[0], self.fourier_order[1], - type_complex=self.type_complex) - o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order[0], self.fourier_order[1], - type_complex=self.type_complex) elif self.fft_type == 2: E_conv_all, o_E_conv_all = to_conv_mat_continuous_vector(self.ucell_info_list, self.fourier_order, type_complex=self.type_complex) @@ -122,6 +116,29 @@ def conv_solve(self, **kwargs): raise ValueError de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(self.wavelength, E_conv_all, o_E_conv_all) + # self.layer_info_list = layer_info_list + # self.T1 = T1 + # self.kx_vector = kx_vector + return de_ri, de_ti, layer_info_list, T1, kx_vector + + def _conv_solve_no_jit(self): + + E_conv_all = to_conv_mat_continuous(self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex) + o_E_conv_all = to_conv_mat_continuous(1 / self.ucell, self.fourier_order[0], self.fourier_order[1], + type_complex=self.type_complex) + + de_ri, de_ti, layer_info_list, T1, kx_vector = self._solve(self.wavelength, E_conv_all, o_E_conv_all) + + return de_ri, de_ti, layer_info_list, T1, kx_vector + + def conv_solve(self, **kwargs): + [setattr(self, k, v) for k, v in kwargs.items()] # TODO: need this? for optimization? + + if self.fft_type == 1: + de_ri, de_ti, layer_info_list, T1, kx_vector = self._conv_solve_no_jit() + else: + de_ri, de_ti, layer_info_list, T1, kx_vector = self._conv_solve() self.layer_info_list = layer_info_list self.T1 = T1 @@ -159,9 +176,35 @@ def run_ucell_pmap(self, ucell_list): return de_ri, de_ti # TODO: jit? fourier order split in args? + # @jax.jit def calculate_field(self, resolution=None, plot=True): if self.grating_type == 0: + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, *self.fourier_order, self.T1, + self.layer_info_list, self.period, self.pol, resolution=resolution, + type_complex=self.type_complex) + elif self.grating_type == 1: + resolution = [100, 1, 100] if not resolution else resolution + field_cell = field_dist_1d_conical(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) + + else: + resolution = [10, 10, 10] if not resolution else resolution + field_cell = field_dist_2d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.phi, + self.fourier_order, self.T1, self.layer_info_list, self.period, + resolution=resolution, type_complex=self.type_complex) + if plot: + field_plot(field_cell, self.pol) + return field_cell + + @jax.jit + def conv_solve_calculate_field(self, resolution=None, plot=False): + self._conv_solve() + if self.grating_type == 0: + resolution = (20, 20, 20) + resolution = [100, 1, 100] if not resolution else resolution field_cell = field_dist_1d(self.wavelength, self.kx_vector, self.n_I, self.theta, self.fourier_order, self.T1, self.layer_info_list, self.period, self.pol, resolution=resolution, diff --git a/meent/on_numpy/emsolver/_base.py b/meent/on_numpy/emsolver/_base.py index b4175f4..be1c3cd 100644 --- a/meent/on_numpy/emsolver/_base.py +++ b/meent/on_numpy/emsolver/_base.py @@ -35,7 +35,12 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = [int(v) for v in fourier_order] # TODO: other bds + if type(fourier_order) == int: + self.fourier_order = [fourier_order, 0] + elif len(fourier_order) == 1: + self.fourier_order = list(fourier_order) + [0] + else: + self.fourier_order = [int(v) for v in fourier_order] self.period = deepcopy(period) @@ -48,9 +53,7 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= self.layer_info_list = [] self.T1 = None - if self.theta == 0: - self.theta = self.perturbation - self.theta = np.where(self.theta == 0, self.perturbation, self.theta) # TODO: check whether correct + self.theta = np.where(self.theta == 0, self.perturbation, self.theta) self.kx_vector = None # tODO: need this? why only kx, not ky? @@ -66,20 +69,16 @@ def get_kx_vector(self, wavelength): kx_vector = k0 * (self.n_I * np.sin(self.theta) * np.cos(self.phi) + fourier_indices_x * ( wavelength / self.period[0])).astype(self.type_complex) - # kx_vector = kx_vector.conjugate() # kx_vector = np.where(kx_vector == 0, self.perturbation, kx_vector) return kx_vector def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): - self.layer_info_list = [] self.T1 = None ff = self.fourier_order[0] * 2 + 1 - fourier_indices_x = np.arange(-self.fourier_order[0], self.fourier_order[0] + 1) - delta_i0 = np.zeros(ff, dtype=self.type_complex) delta_i0[self.fourier_order[0]] = 1 @@ -91,20 +90,12 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): self.theta, delta_i0, self.fourier_order, type_complex=self.type_complex) elif self.algo == 'SMM': Kx, Wg, Vg, Kzg, Wr, Vr, Kzr, Wt, Vt, Kzt, Ar, Br, Sg \ - = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, fourier_indices_x, self.period, + = scattering_1d_1(k0, self.n_I, self.n_II, self.theta, self.phi, self.period, self.pol, wl=wavelength) else: raise ValueError - # count = min(len(E_conv_all), len(o_E_conv_all), len(self.thickness)) - # From the last layer - # for layer_index in range(count)[::-1]: - # - # E_conv = E_conv_all[layer_index] - # o_E_conv = o_E_conv_all[layer_index] - # d = self.thickness[layer_index] - for E_conv, o_E_conv, d in zip(E_conv_all[::-1], o_E_conv_all[::-1], self.thickness[::-1]): if self.pol == 0: diff --git a/meent/on_torch/emsolver/_base.py b/meent/on_torch/emsolver/_base.py index 002a6cc..a4d09a0 100644 --- a/meent/on_torch/emsolver/_base.py +++ b/meent/on_torch/emsolver/_base.py @@ -40,7 +40,12 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= print('not implemented yet') raise ValueError - self.fourier_order = [int(v) for v in fourier_order] + if type(fourier_order) == int: + self.fourier_order = [fourier_order, 0] + elif len(fourier_order) == 1: + self.fourier_order = list(fourier_order) + [0] + else: + self.fourier_order = [int(v) for v in fourier_order] self.period = deepcopy(period) @@ -53,8 +58,6 @@ def __init__(self, grating_type, n_I=1., n_II=1., theta=0., phi=0., psi=0., pol= self.layer_info_list = [] self.T1 = None - # if self.theta == 0: - # self.theta = torch.tensor(self.perturbation) self.theta = torch.where(self.theta == 0, self.perturbation, self.theta) # TODO: check correct? self.kx_vector = None diff --git a/meent/on_torch/emsolver/field_distribution.py b/meent/on_torch/emsolver/field_distribution.py index 5974b6c..87f334f 100644 --- a/meent/on_torch/emsolver/field_distribution.py +++ b/meent/on_torch/emsolver/field_distribution.py @@ -16,7 +16,7 @@ def field_dist_1d(wavelength, kx_vector, n_I, theta, fourier_order, T1, layer_in device='cpu', type_complex=torch.complex128): k0 = 2 * np.pi / wavelength - fourier_indices = torch.arange(-fourier_order, fourier_order + 1, device=device) + fourier_indices = torch.arange(-fourier_order[0], fourier_order[0] + 1, device=device) # kx_vector = k0 * (n_I * np.sin(theta) - fourier_indices * (wavelength / period[0])).type(type_complex) Kx = torch.diag(kx_vector / k0) diff --git a/meent/testcase.py b/meent/testcase.py index 141980d..3c26ac6 100644 --- a/meent/testcase.py +++ b/meent/testcase.py @@ -20,7 +20,7 @@ def load_setting(mode, dtype, device, grating_type): wavelength = 900 ucell_materials = [1, 3.48] - fourier_order = 5 + fourier_order = [5, 5] # thickness = [1120., 400, 300] thickness = [1000., ] From 9b018dae1b5fca4c8499264f556ed0eccedaf844 Mon Sep 17 00:00:00 2001 From: yonghakim Date: Thu, 16 Mar 2023 21:37:35 +0900 Subject: [PATCH 4/4] async 2D fourier order; --- QA/async_2d_fourier_order.py | 78 ++++++++++++++++++++++++++++++++++++ meent/testcase.py | 2 +- 2 files changed, 79 insertions(+), 1 deletion(-) create mode 100644 QA/async_2d_fourier_order.py diff --git a/QA/async_2d_fourier_order.py b/QA/async_2d_fourier_order.py new file mode 100644 index 0000000..025f91f --- /dev/null +++ b/QA/async_2d_fourier_order.py @@ -0,0 +1,78 @@ +import numpy as np + +import meent + +from meent.testcase import load_setting +from benchmarks.interface.Reticolo import Reticolo + + +mode = 0 +dtype = 0 +device = 0 +grating_type = 2 +pre = load_setting(mode, dtype, device, grating_type) + +reti = Reticolo() +de_ri_top_inc, de_ti_top_inc, de_ri_bot_inc, de_ti_bot_inc = reti.run(**pre) +de_ri_reti = de_ri_top_inc.flatten() +de_ti_reti = de_ti_top_inc.flatten() + +print('reti de_ri:', de_ri_reti) +print('reti de_ti:', de_ti_reti) + +# Numpy +mode = 0 +pre = load_setting(mode, dtype, device, grating_type) +mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) +mee.fft_type = 0 + +de_ri, de_ti = mee.conv_solve() +center = np.array(de_ri.shape) // 2 +de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +cut_index = [3, 1, 4, 7, 5] +# cut_index = [3, 1, 4, 5] +de_ri_cut = de_ri_cut.flatten()[cut_index] +de_ti_cut = de_ti_cut.flatten()[cut_index] + +print('Norm(Reti, NPY): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut)) + +# JAX +mode = 1 +pre = load_setting(mode, dtype, device, grating_type) +mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) +mee.fft_type = 0 + +de_ri, de_ti = mee.conv_solve() +center = np.array(de_ri.shape) // 2 + +de_ri, de_ti = np.array(de_ri), np.array(de_ti) +de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ri_cut = de_ri_cut.flatten()[cut_index] +de_ti_cut = de_ti_cut.flatten()[cut_index] + +# print('meen jx de_ri:', de_ri_cut) +# print('meen jx de_ti:', de_ti_cut) + +print('Norm(Reti, JAX): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut)) + +# Torch +mode = 2 +pre = load_setting(mode, dtype, device, grating_type) +mee = meent.call_mee(mode=mode, perturbation=1E-30, **pre) +mee.fft_type = 0 + +de_ri, de_ti = mee.conv_solve() +center = np.array(de_ri.shape) // 2 + +de_ri, de_ti = np.array(de_ri), np.array(de_ti) +de_ri_cut = de_ri[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ti_cut = de_ti[center[0] - 1:center[0] + 2, center[1] - 1:center[1] + 2] +de_ri_cut = de_ri_cut.flatten()[cut_index] +de_ti_cut = de_ti_cut.flatten()[cut_index] + +# print('meen to de_ri:', de_ri_cut) +# print('meen to de_ti:', de_ti_cut) + +print('Norm(Reti, TOR): ', np.linalg.norm(de_ri_reti - de_ri_cut), np.linalg.norm(de_ti_reti - de_ti_cut)) diff --git a/meent/testcase.py b/meent/testcase.py index 3c26ac6..7ebf3c2 100644 --- a/meent/testcase.py +++ b/meent/testcase.py @@ -20,7 +20,7 @@ def load_setting(mode, dtype, device, grating_type): wavelength = 900 ucell_materials = [1, 3.48] - fourier_order = [5, 5] + fourier_order = [9, 2] # thickness = [1120., 400, 300] thickness = [1000., ]