Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
61 changes: 43 additions & 18 deletions tidy3d/plugins/smatrix/smatrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,18 +136,44 @@ def plot_sim(self, x: float = None, y: float = None, z: float = None, ax: Ax = N
sim_plot.sources.append(mode_source_0)
return sim_plot.plot(x=x, y=y, z=z, ax=ax)

def _shift_value(self, port: Port) -> float:
"""How far (signed) to shift the monitor from the source."""
normal_index = port.size.index(0.0)
dl = self.simulation.grid_size[normal_index]
if not isinstance(dl, float):
raise NotImplementedError("doesn't support nonuniform. How many grid cells to shift?")
return dl if port.direction == "+" else -1 * dl
def _shift_value_signed(self, port: Port) -> float:
"""How far (signed) to shift the source from the monitor."""

# get the grid boundaries and sizes along port normal from the simulation
normal_axis = port.size.index(0.0)
grid = self.simulation.grid
grid_boundaries = grid.boundaries.to_list[normal_axis]
grid_centers = grid.centers.to_list[normal_axis]

# get the index of the grid cell where the port lies
port_position = port.center[normal_axis]
port_index = np.argwhere(port_position > grid_boundaries)[-1]

# shift the port to the left
if port.direction == "+":
shifted_index = port_index - 2
if shifted_index < 0:
raise SetupError(
f"Port {port.name} normal is too close to boundary "
f"on -{'xyz'[normal_axis]} side."
)

# shift the port to the right
else:
shifted_index = port_index + 2
if shifted_index >= len(grid_centers):
raise SetupError(
f"Port {port.name} normal is too close to boundary "
f"on +{'xyz'[normal_axis]} side."
)

new_pos = grid_centers[shifted_index]
return new_pos - port_position

def _shift_port(self, port: Port) -> Port:
"""Generate a new port shifted by one grid cell in normal direction."""
"""Generate a new port shifted by the shift amount in normal direction."""

shift_value = self._shift_value(port)
shift_value = self._shift_value_signed(port)
center_shifted = list(port.center)
center_shifted[port.size.index(0.0)] += shift_value
port_shifted = port.copy(deep=True)
Expand All @@ -160,26 +186,25 @@ def _task_name(self, port_source: Port, mode_index: int) -> str:

def _make_sims(self) -> Dict[str, Simulation]:
"""Generate all the :class:`Simulation` objects for the S matrix calculation."""

mode_monitors = [self._to_monitor(port) for port in self.ports]
sim_dict = {}
for port_source in self.ports:
port_source = self._shift_port(port_source)
for mode_source in self._to_sources(port_source):
sim_copy = self.simulation.copy(deep=True)
sim_copy.sources = [mode_source]
for port_monitor in self.ports:
if port_source == port_monitor:
port_monitor = self._shift_port(port_source)
mode_monitor = self._to_monitor(port_monitor)
sim_copy.monitors.append(mode_monitor)
task_name = self._task_name(port_source, mode_source.mode_index)
sim_dict[task_name] = sim_copy
sim_copy.monitors += mode_monitors
task_name = self._task_name(port_source, mode_source.mode_index)
sim_dict[task_name] = sim_copy
return sim_dict

def _run_sims(
self, sim_dict: Dict[str, Simulation], folder_name: str, path_dir: str
) -> "BatchData":
"""Run :class:`Simulations` for each port and return the batch after saving."""
batch = Batch(simulations=sim_dict, folder_name=folder_name)

batch = Batch(simulations=sim_dict, folder_name=folder_name)
batch.upload()
batch.start()
batch.monitor()
Expand All @@ -203,7 +228,7 @@ def _normalization_factor(self, port_source: Port, sim_data: SimulationData) ->

k0 = 2 * np.pi * C_0 / self.freq
k_eff = k0 * normalize_n_eff
shift_value = self._shift_value(port_source)
shift_value = self._shift_value_signed(port_source)
return normalize_amp * np.exp(1j * k_eff * shift_value)

def _construct_smatrix(self, batch_data: "BatchData") -> SMatrixType:
Expand Down