Skip to content

Commit

Permalink
First basic proper version works.
Browse files Browse the repository at this point in the history
  • Loading branch information
Gert-Jan Both committed Nov 8, 2024
1 parent bb6cbce commit d7bdb6b
Show file tree
Hide file tree
Showing 3 changed files with 232 additions and 228 deletions.
75 changes: 75 additions & 0 deletions modified_born/04_proper_code.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""
Author: GJ Both
Date: 08/11/2024
In this notebook we replace all the functions from the previous notebook by a proper implementation.
"""

# %% Imports
import jax.numpy as jnp
import matplotlib.pyplot as plt
from samples import vacuum_cylinders
from solvers import add_bc, make_source, maxwell_solver

# Settings
spacing = 0.1
wavelength = 1.0
width = (25 / wavelength, None, 25 / wavelength)
alpha_boundary = 0.35
order = 4

# %%
n_sample = vacuum_cylinders()

# %%
plt.title("Sample - cylinders in vacuum")
plt.imshow(jnp.rot90(n_sample[:, 0, :]))
plt.xlabel("z")
plt.ylabel("x")
plt.colorbar(label="Refractive index")

# %% Now adding the absorbing boundary conditions
permittivity, roi = add_bc(
n_sample**2, width, spacing, wavelength, alpha=alpha_boundary, order=order
)
source = make_source(permittivity.shape, spacing, wavelength, 25)

# %%
plt.figure(figsize=(10, 5))
plt.subplot(121)
plt.title("Ex of source")
plt.imshow(jnp.rot90(source[:, 0, :, 2]))
plt.colorbar(fraction=0.046, pad=0.04)

plt.subplot(122)
plt.title("Ey of source")
plt.imshow(jnp.rot90(source[:, 0, :, 1]))
plt.colorbar(fraction=0.046, pad=0.04)


# %%
field, history = maxwell_solver(permittivity, source, spacing, wavelength)

# %%
plt.title("Relative change in field")
plt.semilogy(history)
plt.ylabel("dE")
plt.xlabel("Iteration")

# %%
plt.figure(figsize=(15, 5))
plt.subplot(131)
plt.title("Ex")
plt.imshow(jnp.rot90(jnp.abs(field[roi][:, 0, :, 2])), vmin=0.0, vmax=1.2)
plt.colorbar(fraction=0.046, pad=0.04)

plt.subplot(132)
plt.title("Ey")
plt.imshow(jnp.rot90(jnp.abs(field[roi][:, 0, :, 1])), vmin=0.0, vmax=1.2)
plt.colorbar(fraction=0.046, pad=0.04)

plt.subplot(133)
plt.title("Ez")
plt.imshow(jnp.rot90(jnp.abs(field[roi][:, 0, :, 0])), vmin=0.0, vmax=1.2)
plt.colorbar(fraction=0.046, pad=0.04)

# %%
167 changes: 8 additions & 159 deletions modified_born/samples.py
Original file line number Diff line number Diff line change
@@ -1,165 +1,6 @@
from __future__ import annotations

from functools import reduce as freduce
from typing import Literal

import jax.numpy as jnp
import numpy as np
from chex import dataclass
from einops import reduce
from jax import Array
from scipy.ndimage import distance_transform_edt
from scipy.special import factorial


@dataclass
class Source:
field: Array
wavelength: float

@property
def source(self):
return (2 * jnp.pi / self.wavelength) ** 2 * self.field


@dataclass
class Sample:
# Simple dataclass to deal with refractive index
# padding, boundary conditions etc. Nothing vital
# but makes life a lot easier.
permittivity: Array
dx: float
roi: tuple[slice, ...]

@classmethod
def init(
cls,
refractive_index: Array, # [N_z, N_y, N_x]
dx: float, # in wavelengths
wavelength: float,
boundary_width: tuple[None | int, ...],
boundary_type: Literal["pbl", "arl"] = "arl",
*,
boundary_strength: float = 0.2,
boundary_order: int = 4,
) -> Sample:
permittivity, roi = cls.add_boundary(
refractive_index**2,
wavelength,
dx,
boundary_type,
boundary_width,
boundary_order,
boundary_strength,
)
permittivity = cls.pad_fourier(permittivity, boundary_width)

return cls(
permittivity=permittivity,
dx=dx,
roi=roi,
)

@property
def shape(self) -> tuple[int, ...]:
return self.permittivity.shape

@property
def spatial_shape(self) -> tuple[int, ...]:
return self.permittivity.shape[:3]

@property
def k_grid(self) -> Array:
# NOTE THE 2PI factor!!!
ks = [
2 * jnp.pi * jnp.fft.fftfreq(shape, self.dx) for shape in self.spatial_shape
]
return jnp.stack(jnp.meshgrid(*ks, indexing="ij"), axis=-1)

@staticmethod
def add_boundary(
permittivity: Array,
wavelength: float,
dx: float,
type: Literal["arl", "pbl"],
width: tuple[None | int, ...],
order: int,
strength: float,
) -> tuple[Array, tuple[slice, ...]]:
# Finding new shapes and rois
n_pad = tuple(0 if width_i is None else int(width_i / dx) for width_i in width)
roi = tuple(slice(n, n + size) for n, size in zip(n_pad, permittivity.shape))

# Padding permittivity to new size
padding = [(0, 0) for _ in range(permittivity.ndim)]
for idx, n in enumerate(n_pad):
padding[idx] = (n, n)
permittivity = jnp.pad(permittivity, padding, mode="edge")

# Adding BCs
match type:
case "pbl":
permittivity = Sample.add_pbl(
permittivity, roi, strength, order, wavelength, dx
)
case "arl":
permittivity = Sample.add_arl(permittivity, roi, np.max(n_pad))
return permittivity, roi

@staticmethod
def add_pbl(
permittivity: Array,
roi: tuple[slice, ...],
strength: float,
order: int,
wavelength: float,
dx: float,
) -> Array:
# Gathering constants
km = 2 * jnp.pi * jnp.sqrt(jnp.mean(permittivity)) * dx / wavelength
alpha = strength * km**2 / (2 * km)

# Defining distance from sample
r = jnp.ones_like(permittivity).at[roi].set(0)
r = distance_transform_edt(r)

# Making boundary
ar = alpha * r
P = freduce(
lambda P, n: P + (ar**n / factorial(n, exact=True)),
range(order + 1),
jnp.zeros_like(ar),
)

numerator = alpha**2 * (order - ar + 2 * 1j * km * r) * ar ** (order - 1)
denominator = P * factorial(order, exact=True)
boundary = 1 / km**2 * numerator / denominator

return permittivity + boundary

@staticmethod
def add_arl(permittivity: Array, roi: tuple[slice, ...], n_pad: Array) -> Array:
# Defining distance from sample
r = jnp.ones_like(permittivity).at[roi].set(0)
r = distance_transform_edt(r)
beta = (jnp.abs(n_pad - r) - 0.21) / (n_pad + 0.66)
beta = beta.at[roi].set(1.0)
return permittivity * beta

@staticmethod
def pad_fourier(x: Array, width: tuple[int | None, ...]) -> Array:
# Pads to fourier friendly shapes (powers of 2), depending
# on periodic or absorbing BCs
def n_pad(order, size):
padding = 0 if order is None else int(2 ** (np.ceil(np.log2(size)))) - size
return (padding//2, padding//2)

padding = [(0, 0) for _ in range(x.ndim)]
for idx, (order, size) in enumerate(
zip(width, x.shape[: len(width)], strict=True)
):
padding[idx] = n_pad(order, size)
return jnp.pad(x, padding, mode="constant", constant_values=0)


def sample_grid(size: tuple[int, int, int]) -> Array:
Expand Down Expand Up @@ -253,3 +94,11 @@ def bio_cylinders():
n_cylinder,
antialiasing=10,
)


def angled_interface():
return (
jnp.full((1000, 1000), 1.0)
.at[jnp.triu_indices(n=1000)]
.set(1.55)[::-1, None, :]
)
Loading

0 comments on commit d7bdb6b

Please sign in to comment.