diff --git a/meent/on_jax/emsolver/_base.py b/meent/on_jax/emsolver/_base.py index b5aba80..28f52a6 100644 --- a/meent/on_jax/emsolver/_base.py +++ b/meent/on_jax/emsolver/_base.py @@ -278,13 +278,15 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): elif self.pol == 1: E_conv_i = jnp.linalg.inv(E_conv) B = Kx @ E_conv_i @ Kx - jnp.eye(E_conv.shape[0]).astype(self.type_complex) - o_E_conv_i = jnp.linalg.inv(o_E_conv) - eigenvalues, W = eig(o_E_conv_i @ B, type_complex=self.type_complex, perturbation=self.perturbation, + # o_E_conv_i = jnp.linalg.inv(o_E_conv) + + eigenvalues, W = eig(E_conv @ B, type_complex=self.type_complex, perturbation=self.perturbation, device=self.device) eigenvalues += 0j # to get positive square root q = eigenvalues ** 0.5 Q = jnp.diag(q) - V = o_E_conv @ W @ Q + # V = o_E_conv @ W @ Q + V = E_conv_i @ W @ Q else: raise ValueError @@ -345,11 +347,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + o_E_conv = None + d = self.thickness[layer_index] E_conv_i = jnp.linalg.inv(E_conv) - o_E_conv_i = jnp.linalg.inv(o_E_conv) + # o_E_conv_i = jnp.linalg.inv(o_E_conv) + o_E_conv_i = None if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ @@ -418,11 +423,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): # From the last layer for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + o_E_conv = None + d = self.thickness[layer_index] E_conv_i = jnp.linalg.inv(E_conv) - o_E_conv_i = jnp.linalg.inv(o_E_conv) + # o_E_conv_i = jnp.linalg.inv(o_E_conv) + o_E_conv_i = None if self.algo == 'TMM': W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, diff --git a/meent/on_jax/emsolver/transfer_method.py b/meent/on_jax/emsolver/transfer_method.py index 8c10ae7..f4cb996 100644 --- a/meent/on_jax/emsolver/transfer_method.py +++ b/meent/on_jax/emsolver/transfer_method.py @@ -135,7 +135,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph B_i = jnp.linalg.inv(B) to_decompose_W_1 = (ky/k0) ** 2 * I + A - to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i + # to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i + to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv eigenvalues_1, W_1 = eig(to_decompose_W_1, type_complex=type_complex, perturbation=perturbation, device=device) eigenvalues_2, W_2 = eig(to_decompose_W_2, type_complex=type_complex, perturbation=perturbation, device=device) diff --git a/meent/on_numpy/emsolver/_base.py b/meent/on_numpy/emsolver/_base.py index 0fecd07..112416a 100644 --- a/meent/on_numpy/emsolver/_base.py +++ b/meent/on_numpy/emsolver/_base.py @@ -226,7 +226,8 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): # From the last layer for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] if self.pol == 0: @@ -241,13 +242,15 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): elif self.pol == 1: E_conv_i = np.linalg.inv(E_conv) B = Kx @ E_conv_i @ Kx - np.eye(E_conv.shape[0], dtype=self.type_complex) - o_E_conv_i = np.linalg.inv(o_E_conv) + # o_E_conv_i = np.linalg.inv(o_E_conv) - eigenvalues, W = np.linalg.eig(o_E_conv_i @ B) + eigenvalues, W = np.linalg.eig(E_conv @ B) eigenvalues += 0j # to get positive square root q = eigenvalues ** 0.5 Q = np.diag(q) - V = o_E_conv @ W @ Q + # V = o_E_conv @ W @ Q + V = E_conv_i @ W @ Q + else: raise ValueError @@ -305,11 +308,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + o_E_conv = None + d = self.thickness[layer_index] E_conv_i = np.linalg.inv(E_conv) - o_E_conv_i = np.linalg.inv(o_E_conv) + # o_E_conv_i = np.linalg.inv(o_E_conv) + o_E_conv_i = None if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \ @@ -375,11 +381,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): # From the last layer for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + o_E_conv = None + d = self.thickness[layer_index] E_conv_i = np.linalg.inv(E_conv) - o_E_conv_i = np.linalg.inv(o_E_conv) + # o_E_conv_i = np.linalg.inv(o_E_conv) + o_E_conv_i = None if self.algo == 'TMM': W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex) @@ -393,7 +402,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, q = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, q = scattering_2d_wv(ff_xy, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, q) else: raise ValueError @@ -405,7 +414,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.T1 = big_T1 elif self.algo == 'SMM': - de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, + de_ri, de_ti = scattering_2d_3(ff_xy, Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I, self.pol, self.theta, self.phi, self.fourier_order) else: raise ValueError diff --git a/meent/on_numpy/emsolver/transfer_method.py b/meent/on_numpy/emsolver/transfer_method.py index 4115b90..2554692 100644 --- a/meent/on_numpy/emsolver/transfer_method.py +++ b/meent/on_numpy/emsolver/transfer_method.py @@ -121,7 +121,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph B_i = np.linalg.inv(B) to_decompose_W_1 = (ky/k0) ** 2 * I + A - to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i + # to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i + to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv eigenvalues_1, W_1 = np.linalg.eig(to_decompose_W_1) eigenvalues_2, W_2 = np.linalg.eig(to_decompose_W_2) diff --git a/meent/on_torch/emsolver/_base.py b/meent/on_torch/emsolver/_base.py index 7c73acf..d7b8545 100644 --- a/meent/on_torch/emsolver/_base.py +++ b/meent/on_torch/emsolver/_base.py @@ -274,7 +274,8 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + d = self.thickness[layer_index] if self.pol == 0: @@ -289,13 +290,14 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all): elif self.pol == 1: E_conv_i = torch.linalg.inv(E_conv) B = Kx @ E_conv_i @ Kx - torch.eye(E_conv.shape[0], device=self.device, dtype=self.type_complex) - o_E_conv_i = torch.linalg.inv(o_E_conv) + # o_E_conv_i = torch.linalg.inv(o_E_conv) Eig.perturbation = self.perturbation - eigenvalues, W = Eig.apply(o_E_conv_i @ B) + eigenvalues, W = Eig.apply(E_conv @ B) q = eigenvalues ** 0.5 Q = torch.diag(q) - V = o_E_conv @ W @ Q + # V = o_E_conv @ W @ Q + V = E_conv_i @ W @ Q else: raise ValueError @@ -355,11 +357,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all): for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + o_E_conv = None + d = self.thickness[layer_index] E_conv_i = torch.linalg.inv(E_conv) - o_E_conv_i = torch.linalg.inv(o_E_conv) + # o_E_conv_i = torch.linalg.inv(o_E_conv) + o_E_conv_i = None if self.algo == 'TMM': big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2\ @@ -429,10 +434,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): for layer_index in range(count)[::-1]: E_conv = E_conv_all[layer_index] - o_E_conv = o_E_conv_all[layer_index] + # o_E_conv = o_E_conv_all[layer_index] + o_E_conv = None + d = self.thickness[layer_index] + E_conv_i = torch.linalg.inv(E_conv) - o_E_conv_i = torch.linalg.inv(o_E_conv) + # o_E_conv_i = torch.linalg.inv(o_E_conv) + o_E_conv_i = None if self.algo == 'TMM': W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, @@ -447,7 +456,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all): self.layer_info_list.append(layer_info) elif self.algo == 'SMM': - W, V, LAMBDA = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) + W, V, LAMBDA = scattering_2d_wv(ff_xy, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i) A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA) else: raise ValueError diff --git a/meent/on_torch/emsolver/transfer_method.py b/meent/on_torch/emsolver/transfer_method.py index 7e55361..8997e9e 100644 --- a/meent/on_torch/emsolver/transfer_method.py +++ b/meent/on_torch/emsolver/transfer_method.py @@ -134,7 +134,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, o_E_conv_i, ff, d, varphi, bi B_i = torch.linalg.inv(B) to_decompose_W_1 = (ky/k0) ** 2 * I + A - to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i + # to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i + to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv Eig.perturbation = perturbation eigenvalues_1, W_1 = Eig.apply(to_decompose_W_1)