Skip to content

Commit

Permalink
Add microlens array phase masks
Browse files Browse the repository at this point in the history
  • Loading branch information
diptodip committed Oct 17, 2024
1 parent 7368b2f commit cde0aef
Showing 1 changed file with 95 additions and 3 deletions.
98 changes: 95 additions & 3 deletions src/chromatix/utils/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,25 @@
from typing import Sequence

import jax.numpy as jnp
import numpy as np
from einops import rearrange
from jax import Array
from scipy.special import comb # type: ignore

from chromatix.typing import ScalarLike

from .utils import create_grid, grid_spatial_to_pupil
from ..typing import ScalarLike
from .utils import (
create_grid,
grid_spatial_to_pupil,
l2_norm,
l2_sq_norm,
rotate_grid,
)

__all__ = [
"flat_phase",
"microlens_array_amplitude_and_phase",
"hexagonal_microlens_array_amplitude_and_phase",
"rectangular_microlens_array_amplitude_and_phase",
"potato_chip",
"seidel_aberrations",
"zernike_aberrations",
Expand All @@ -31,6 +40,89 @@ def flat_phase(shape: tuple[int, int], *args, value: ScalarLike = 0.0) -> Array:
return jnp.full(shape, value)


def microlens_array_amplitude_and_phase(
shape: tuple[int, int],
spacing: ScalarLike,
wavelength: ScalarLike,
n: ScalarLike,
fs: Array,
centers: Array,
radii: Array,
) -> tuple[Array, Array]:
phase = jnp.zeros(shape)
amplitude = jnp.zeros(shape)
grid = create_grid(shape, spacing)
for i in range(centers.shape[1]):
center = centers[:, i]
squared_distance = l2_sq_norm(grid - center[:, jnp.newaxis, jnp.newaxis])
L = wavelength * fs[i] / n
mask = jnp.squeeze(squared_distance) < (radii[i] ** 2)
amplitude += mask
phase += mask * jnp.squeeze(squared_distance / L)
phase *= -jnp.pi
amplitude = jnp.clip(amplitude, 0.0, 1.0)
return amplitude, phase


def hexagonal_microlens_array_amplitude_and_phase(
shape: tuple[int, int],
spacing: ScalarLike,
wavelength: ScalarLike,
n: ScalarLike,
f: ScalarLike,
num_lenses_per_side: ScalarLike,
radius: ScalarLike,
separation: ScalarLike,
) -> tuple[Array, Array]:
hex_distance = num_lenses_per_side - 1
unit_hex_coordinates = []
q_basis = np.array([0, 1])
r_basis = np.array([np.sqrt(3) / 2, 1 / 2])
for q in range(-hex_distance, hex_distance + 1):
for r in range(max(-hex_distance, -q - hex_distance), min(hex_distance, -q + hex_distance) + 1):
unit_hex_coordinates.append(q_basis * q + r_basis * r)
unit_hex_coordinates = np.array(unit_hex_coordinates).T
hex_coordinates = unit_hex_coordinates * separation
return microlens_array_amplitude_and_phase(
shape,
spacing,
wavelength,
n,
jnp.ones(hex_coordinates.shape[1]) * f,
hex_coordinates,
jnp.ones(hex_coordinates.shape[1]) * radius,
)


def rectangular_microlens_array_amplitude_and_phase(
shape: tuple[int, int],
spacing: ScalarLike,
wavelength: ScalarLike,
n: ScalarLike,
f: ScalarLike,
num_lenses_height: ScalarLike,
num_lenses_width: ScalarLike,
radius: ScalarLike,
separation: ScalarLike,
) -> tuple[Array, Array]:
unit_coordinates = np.meshgrid(
np.arange(num_lenses_height) - num_lenses_height // 2,
np.arange(num_lenses_width) - num_lenses_width // 2,
indexing="ij",
)
unit_coordinates = np.array(unit_coordinates).reshape(2, num_lenses_height * num_lenses_width)
coordinates = unit_coordinates * separation
return microlens_array_amplitude_and_phase(
shape,
spacing,
wavelength,
n,
jnp.ones(coordinates.shape[1]) * f,
coordinates,
jnp.ones(coordinates.shape[1]) * radius,
)


def potato_chip(
shape: tuple[int, int],
spacing: ScalarLike,
Expand Down

0 comments on commit cde0aef

Please sign in to comment.