Skip to content

Commit

Permalink
Merge branch 'DEV/Li_off' into DEV/main
Browse files Browse the repository at this point in the history
  • Loading branch information
yonghakim committed Jul 9, 2024
2 parents cb9568b + d92f7b4 commit cc28149
Show file tree
Hide file tree
Showing 6 changed files with 58 additions and 29 deletions.
22 changes: 15 additions & 7 deletions meent/on_jax/emsolver/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -278,13 +278,15 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
elif self.pol == 1:
E_conv_i = jnp.linalg.inv(E_conv)
B = Kx @ E_conv_i @ Kx - jnp.eye(E_conv.shape[0]).astype(self.type_complex)
o_E_conv_i = jnp.linalg.inv(o_E_conv)
eigenvalues, W = eig(o_E_conv_i @ B, type_complex=self.type_complex, perturbation=self.perturbation,
# o_E_conv_i = jnp.linalg.inv(o_E_conv)

eigenvalues, W = eig(E_conv @ B, type_complex=self.type_complex, perturbation=self.perturbation,
device=self.device)
eigenvalues += 0j # to get positive square root
q = eigenvalues ** 0.5
Q = jnp.diag(q)
V = o_E_conv @ W @ Q
# V = o_E_conv @ W @ Q
V = E_conv_i @ W @ Q

else:
raise ValueError
Expand Down Expand Up @@ -345,11 +347,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all):
for layer_index in range(count)[::-1]:

E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]
o_E_conv = None

d = self.thickness[layer_index]

E_conv_i = jnp.linalg.inv(E_conv)
o_E_conv_i = jnp.linalg.inv(o_E_conv)
# o_E_conv_i = jnp.linalg.inv(o_E_conv)
o_E_conv_i = None

if self.algo == 'TMM':
big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \
Expand Down Expand Up @@ -418,11 +423,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
# From the last layer
for layer_index in range(count)[::-1]:
E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]
o_E_conv = None

d = self.thickness[layer_index]

E_conv_i = jnp.linalg.inv(E_conv)
o_E_conv_i = jnp.linalg.inv(o_E_conv)
# o_E_conv_i = jnp.linalg.inv(o_E_conv)
o_E_conv_i = None

if self.algo == 'TMM':
W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv,
Expand Down
3 changes: 2 additions & 1 deletion meent/on_jax/emsolver/transfer_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph
B_i = jnp.linalg.inv(B)

to_decompose_W_1 = (ky/k0) ** 2 * I + A
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
# to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv

eigenvalues_1, W_1 = eig(to_decompose_W_1, type_complex=type_complex, perturbation=perturbation, device=device)
eigenvalues_2, W_2 = eig(to_decompose_W_2, type_complex=type_complex, perturbation=perturbation, device=device)
Expand Down
29 changes: 19 additions & 10 deletions meent/on_numpy/emsolver/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,7 +226,8 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
# From the last layer
for layer_index in range(count)[::-1]:
E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]

d = self.thickness[layer_index]

if self.pol == 0:
Expand All @@ -241,13 +242,15 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
elif self.pol == 1:
E_conv_i = np.linalg.inv(E_conv)
B = Kx @ E_conv_i @ Kx - np.eye(E_conv.shape[0], dtype=self.type_complex)
o_E_conv_i = np.linalg.inv(o_E_conv)
# o_E_conv_i = np.linalg.inv(o_E_conv)

eigenvalues, W = np.linalg.eig(o_E_conv_i @ B)
eigenvalues, W = np.linalg.eig(E_conv @ B)
eigenvalues += 0j # to get positive square root
q = eigenvalues ** 0.5
Q = np.diag(q)
V = o_E_conv @ W @ Q
# V = o_E_conv @ W @ Q
V = E_conv_i @ W @ Q

else:
raise ValueError

Expand Down Expand Up @@ -305,11 +308,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all):
for layer_index in range(count)[::-1]:

E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]
o_E_conv = None

d = self.thickness[layer_index]

E_conv_i = np.linalg.inv(E_conv)
o_E_conv_i = np.linalg.inv(o_E_conv)
# o_E_conv_i = np.linalg.inv(o_E_conv)
o_E_conv_i = None

if self.algo == 'TMM':
big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2 \
Expand Down Expand Up @@ -375,11 +381,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
# From the last layer
for layer_index in range(count)[::-1]:
E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]
o_E_conv = None

d = self.thickness[layer_index]

E_conv_i = np.linalg.inv(E_conv)
o_E_conv_i = np.linalg.inv(o_E_conv)
# o_E_conv_i = np.linalg.inv(o_E_conv)
o_E_conv_i = None

if self.algo == 'TMM':
W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv, type_complex=self.type_complex)
Expand All @@ -393,7 +402,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
self.layer_info_list.append(layer_info)

elif self.algo == 'SMM':
W, V, q = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
W, V, q = scattering_2d_wv(ff_xy, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, q)
else:
raise ValueError
Expand All @@ -405,7 +414,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
self.T1 = big_T1

elif self.algo == 'SMM':
de_ri, de_ti = scattering_2d_3(Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I,
de_ri, de_ti = scattering_2d_3(ff_xy, Wt, Wg, Vt, Vg, Sg, Wr, Kx, Ky, Kzr, Kzt, kz_inc, self.n_I,
self.pol, self.theta, self.phi, self.fourier_order)
else:
raise ValueError
Expand Down
3 changes: 2 additions & 1 deletion meent/on_numpy/emsolver/transfer_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_conv_i, o_E_conv_i, ff, d, varph
B_i = np.linalg.inv(B)

to_decompose_W_1 = (ky/k0) ** 2 * I + A
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
# to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv

eigenvalues_1, W_1 = np.linalg.eig(to_decompose_W_1)
eigenvalues_2, W_2 = np.linalg.eig(to_decompose_W_2)
Expand Down
27 changes: 18 additions & 9 deletions meent/on_torch/emsolver/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,8 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
for layer_index in range(count)[::-1]:

E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]

d = self.thickness[layer_index]

if self.pol == 0:
Expand All @@ -289,13 +290,14 @@ def solve_1d(self, wavelength, E_conv_all, o_E_conv_all):
elif self.pol == 1:
E_conv_i = torch.linalg.inv(E_conv)
B = Kx @ E_conv_i @ Kx - torch.eye(E_conv.shape[0], device=self.device, dtype=self.type_complex)
o_E_conv_i = torch.linalg.inv(o_E_conv)
# o_E_conv_i = torch.linalg.inv(o_E_conv)

Eig.perturbation = self.perturbation
eigenvalues, W = Eig.apply(o_E_conv_i @ B)
eigenvalues, W = Eig.apply(E_conv @ B)
q = eigenvalues ** 0.5
Q = torch.diag(q)
V = o_E_conv @ W @ Q
# V = o_E_conv @ W @ Q
V = E_conv_i @ W @ Q

else:
raise ValueError
Expand Down Expand Up @@ -355,11 +357,14 @@ def solve_1d_conical(self, wavelength, E_conv_all, o_E_conv_all):
for layer_index in range(count)[::-1]:

E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]
o_E_conv = None

d = self.thickness[layer_index]

E_conv_i = torch.linalg.inv(E_conv)
o_E_conv_i = torch.linalg.inv(o_E_conv)
# o_E_conv_i = torch.linalg.inv(o_E_conv)
o_E_conv_i = None

if self.algo == 'TMM':
big_X, big_F, big_G, big_T, big_A_i, big_B, W_1, W_2, V_11, V_12, V_21, V_22, q_1, q_2\
Expand Down Expand Up @@ -429,10 +434,14 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
for layer_index in range(count)[::-1]:

E_conv = E_conv_all[layer_index]
o_E_conv = o_E_conv_all[layer_index]
# o_E_conv = o_E_conv_all[layer_index]
o_E_conv = None

d = self.thickness[layer_index]

E_conv_i = torch.linalg.inv(E_conv)
o_E_conv_i = torch.linalg.inv(o_E_conv)
# o_E_conv_i = torch.linalg.inv(o_E_conv)
o_E_conv_i = None

if self.algo == 'TMM':
W, V, q = transfer_2d_wv(ff_xy, Kx, E_conv_i, Ky, o_E_conv_i, E_conv,
Expand All @@ -447,7 +456,7 @@ def solve_2d(self, wavelength, E_conv_all, o_E_conv_all):
self.layer_info_list.append(layer_info)

elif self.algo == 'SMM':
W, V, LAMBDA = scattering_2d_wv(Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
W, V, LAMBDA = scattering_2d_wv(ff_xy, Kx, Ky, E_conv, o_E_conv, o_E_conv_i, E_conv_i)
A, B, Sl_dict, Sg_matrix, Sg = scattering_2d_2(W, Wg, V, Vg, d, k0, Sg, LAMBDA)
else:
raise ValueError
Expand Down
3 changes: 2 additions & 1 deletion meent/on_torch/emsolver/transfer_method.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,8 @@ def transfer_1d_conical_2(k0, Kx, ky, E_conv, E_i, o_E_conv_i, ff, d, varphi, bi
B_i = torch.linalg.inv(B)

to_decompose_W_1 = (ky/k0) ** 2 * I + A
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
# to_decompose_W_2 = (ky/k0) ** 2 * I + B @ o_E_conv_i
to_decompose_W_2 = (ky/k0) ** 2 * I + B @ E_conv

Eig.perturbation = perturbation
eigenvalues_1, W_1 = Eig.apply(to_decompose_W_1)
Expand Down

0 comments on commit cc28149

Please sign in to comment.