Commit c9154b5

docs: tentative autograd dev docs
1 parent 2477d96 commit c9154b5

2 files changed: +278 -0 lines changed
Lines changed: 171 additions & 0 deletions
@@ -0,0 +1,171 @@
# Tidy3D Autograd Maintainer Guide

This document targets contributors who work on the native autograd stack. It complements `tidy3d/web/api/autograd/README.md` by focusing on internal interfaces, data contracts, and where new functionality must plug in.

## Scope & Key Modules
- `tidy3d/web/api/autograd/autograd.py`: entry points for `web.run` and `web.run_async` when traced parameters are present. Hosts the autograd `@primitive` wrappers, the `defvjp` registrations, and orchestration of local vs. server gradients.
- `tidy3d/web/api/autograd/{forward,backward,engine,io_utils}.py`: forward monitor injection, adjoint setup/post-processing, wrappers around `Job`/`Batch`, and serialization helpers (`FieldMap`, `TracerKeys`, `SIM_*` files).
- `tidy3d/components/autograd/`: tracer-aware NumPy shims (`boxes.py`), differentiable utilities (`functions.py`), tracing types (`types.py`), field mapping (`field_map.py`), and the `DerivativeInfo` contract (`derivative_utils.py`).
- Geometry & medium implementations (`tidy3d/components/{geometry,medium}.py`): provide `_compute_derivatives` overrides that consume `DerivativeInfo` and emit VJPs.
- Tests live under `tests/test_components/autograd/` and are the canonical signal for regressions in tracing, adjoint batching, and derivative formulas.

## Mathematical Foundation

The implementation follows the adjoint state method, which yields the gradient with respect to all parameters from one forward simulation plus a small number of adjoint simulations (ideally one), independent of the number of parameters.

### Adjoint Method
1. **Forward Problem**: $\nabla \times \nabla \times E - k^2\varepsilon(p)E = -i\omega\mu_0 J$
2. **Objective Function**: $f = F(E(p))$
3. **Adjoint Problem**: $\nabla \times \nabla \times E^\dagger - k^2\varepsilon(p)E^\dagger = \partial F / \partial E$
4. **Parameter Gradient**: $\partial f / \partial p = \text{Re}[\int E^\dagger \cdot (\partial (k^2\varepsilon) / \partial p) \cdot E \, dV]$

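
The four steps above can be checked on a toy problem. The following is a minimal sketch, not Tidy3D code: it assumes a real 2x2 system `(K - p*M) x = B` as a stand-in for the Maxwell operator (with the diagonal `M` playing the role of $\partial(k^2\varepsilon)/\partial p$) and a linear objective `F(x) = C · x`. All names (`K`, `M`, `solve2`, `adjoint_gradient`, `finite_difference`) are hypothetical.

```python
def solve2(a, b):
    """Solve the 2x2 linear system a @ x = b with Cramer's rule."""
    det = a[0][0] * a[1][1] - a[0][1] * a[1][0]
    return [
        (b[0] * a[1][1] - a[0][1] * b[1]) / det,
        (a[0][0] * b[1] - a[1][0] * b[0]) / det,
    ]


K = [[4.0, -1.0], [-1.0, 3.0]]  # stand-in for the curl-curl operator
M = [1.0, 0.5]                  # diagonal stand-in for d(k^2 eps)/dp
B = [1.0, 2.0]                  # source term
C = [0.3, -0.7]                 # dF/dx for the objective F(x) = C . x


def system(p):
    """Operator A(p) = K - p * diag(M)."""
    return [[K[0][0] - p * M[0], K[0][1]], [K[1][0], K[1][1] - p * M[1]]]


def objective(p):
    x = solve2(system(p), B)  # forward solve
    return C[0] * x[0] + C[1] * x[1]


def adjoint_gradient(p):
    a = system(p)
    x = solve2(a, B)                                   # step 1: forward problem
    a_t = [[a[0][0], a[1][0]], [a[0][1], a[1][1]]]     # transpose of A
    lam = solve2(a_t, C)                               # step 3: adjoint problem
    return lam[0] * M[0] * x[0] + lam[1] * M[1] * x[1]  # step 4: lam . (M x)


def finite_difference(p, h=1e-6):
    """Central difference, used only as an independent check."""
    return (objective(p + h) - objective(p - h)) / (2 * h)
```

Because the stand-in operator is `K - p*M`, the gradient `lam · (M x)` carries the same sign as step 4. Comparing `adjoint_gradient(p)` with `finite_difference(p)` is a quick sanity check; the adjoint route needs two solves regardless of how many parameters enter `M`.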
### Derivative Types
- **Volume derivatives** (for material parameters $\varepsilon, \sigma$):
  Computed as $E_{fwd} \cdot E_{adj}$ integrated over the structure volume.
- **Surface derivatives** (for geometry boundaries):
  - Normal component: $\Delta(1/\varepsilon) \cdot D_{fwd,n} \cdot D_{adj,n}$
  - Tangential component: $\Delta\varepsilon \cdot E_{fwd,t} \cdot E_{adj,t}$
  - Integrated over the structure surface.

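
As a rough illustration of the volume term, the sketch below sums the un-conjugated product of forward and adjoint field samples over the cells inside a structure and keeps the real part. It assumes a uniform cell volume and omits any physical prefactors; `volume_derivative` and its arguments are hypothetical names, not part of Tidy3D.

```python
def volume_derivative(e_fwd, e_adj, inside, cell_volume):
    """Discrete version of Re[ integral of E_fwd * E_adj dV ] over a structure.

    e_fwd, e_adj: sequences of complex field samples, one per grid cell.
    inside: sequence of booleans marking cells that belong to the structure.
    cell_volume: volume of one (uniform) grid cell.
    """
    total = 0j
    for ef, ea, keep in zip(e_fwd, e_adj, inside):
        if keep:
            total += ef * ea * cell_volume
    return total.real


# Two cells inside the structure, one outside (ignored).
example = volume_derivative(
    e_fwd=[1 + 1j, 2.0, 3j],
    e_adj=[1 - 1j, 0.5, 1.0],
    inside=[True, True, False],
    cell_volume=0.1,
)
```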
## Tracing Mechanics

The system relies on `HIPS/autograd` for automatic differentiation.

### TidyArrayBox
We use `TidyArrayBox` (an alias for autograd's `ArrayBox`) to wrap traced values.
- **Flow**: `params` (numpy array) -> `anp` operations -> `TidyArrayBox` -> `Structure` fields -> `Simulation`.
- **Extraction**: `Simulation._strip_traced_fields()` recursively extracts these boxes into an `AutogradFieldMap` before simulation execution.
- **Re-insertion**: After the simulation, results are mapped back to these boxes to continue the computation graph.

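
The extraction step can be pictured as a recursive walk that records the path to every traced leaf. The sketch below is a simplified stand-in that uses plain dicts and tuples instead of Tidy3D models and a hypothetical `Box` class instead of autograd's `ArrayBox`; `strip_traced` and `to_static` are illustrative names, not the library's functions.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Box:
    """Hypothetical stand-in for a traced value (not autograd's ArrayBox)."""

    value: float


def strip_traced(node, path=()):
    """Return {path: Box} for every Box found in nested dicts/lists/tuples."""
    if isinstance(node, Box):
        return {path: node}
    if isinstance(node, dict):
        items = node.items()
    elif isinstance(node, (list, tuple)):
        items = enumerate(node)
    else:
        return {}
    found = {}
    for key, child in items:
        found.update(strip_traced(child, path + (key,)))
    return found


def to_static(node):
    """Copy the nested container with every Box replaced by its raw value."""
    if isinstance(node, Box):
        return node.value
    if isinstance(node, dict):
        return {key: to_static(child) for key, child in node.items()}
    if isinstance(node, (list, tuple)):
        return type(node)(to_static(child) for child in node)
    return node


sim = {
    "structures": [
        {"geometry": {"size": (Box(1.0), 2.0, 3.0)}, "medium": {"permittivity": 2.0}},
        {"geometry": {"size": (1.0, 1.0, 1.0)}, "medium": {"permittivity": Box(4.0)}},
    ]
}
field_map = strip_traced(sim["structures"], path=("structures",))
static_sim = to_static(sim)
```

The resulting keys, such as `("structures", 1, "medium", "permittivity")`, have the same shape as the gradient dictionary keys described later in this guide, which is what lets results be mapped back onto the original boxes.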
## Architecture at a Glance
```mermaid
graph TB
  subgraph "Frontend (Python SDK)"
    W["web/api/autograd/autograd.py<br/>@primitive wrappers"]
    FWD["forward.py<br/>setup_fwd + postprocess_fwd"]
    BWD["backward.py<br/>setup_adj + postprocess_adj"]
    ENG["engine.py & io_utils.py<br/>Job/Batch orchestration + FieldMap IO"]
    COMP["components/autograd/*<br/>boxes, types, DerivativeInfo, utils"]
    GEO["components/geometry/*<br/>geometry VJPs"]
    MED["components/medium.py<br/>material VJPs"]
  end
  subgraph "Solver / Storage (cloud or local)"
    S1["autograd_fwd tasks<br/>with adjoint monitors"]
    S2["autograd_bwd tasks<br/>adjoint batches"]
    STORE["Artifacts<br/>- autograd_sim_fields_keys.hdf5<br/>- autograd_fwd_data.hdf5<br/>- autograd_sim_vjp.hdf5"]
  end
  W --> FWD
  W --> BWD
  FWD --> ENG
  BWD --> ENG
  ENG --> S1
  ENG --> S2
  S1 --> STORE
  S2 --> STORE
  STORE --> ENG
  COMP --> GEO
  COMP --> MED
  BWD --> GEO
  BWD --> MED
```

## Forward (primal) data flow
1. **Tracer detection** – `setup_run()` calls `Simulation._strip_traced_fields(..., starting_path=("structures",))` to collect an `AutogradFieldMap`. `is_valid_for_autograd()` (`tidy3d/web/api/autograd/autograd.py`) requires traced content and at least one frequency-domain monitor, and enforces the `config.adjoint.max_traced_structures` limit.
2. **Static snapshot & payload** – The simulation is frozen via `Simulation.to_static()`. When tracers exist, `Simulation._serialized_traced_field_keys(...)` stores `TracerKeys` in `TRACED_FIELD_KEYS_ATTR` so solver sidecars know which structural slots to differentiate.
3. **Monitor injection** – `setup_fwd()` (delegating to `tidy3d/web/api/autograd/forward.py`) asks `Simulation._with_adjoint_monitors(sim_fields)` to create a copy of the simulation with structure-aligned `FieldMonitor` and `PermittivityMonitor` objects added. Low-level placement happens in `Structure._make_adjoint_monitors()` (`tidy3d/components/structure.py`), which consults `config.adjoint.monitor_interval_poly/custom` and adds H-field sampling for PEC materials.
4. **Primitive invocation** – `_run_primitive()` marks the task as `simulation_type="autograd_fwd"` when gradients stay on the server. Local gradients run the combined simulation directly; remote gradients call `_run_tidy3d()`/`_run_async_tidy3d()` with upload-time hooks to send `autograd_sim_fields_keys.hdf5`.
5. **Auxiliary caching** – `postprocess_fwd()` splits the solver output into user data vs. gradient monitors. It populates `aux_data[AUX_KEY_SIM_DATA_ORIGINAL]` and `aux_data[AUX_KEY_SIM_DATA_FWD]`, returning only the tracer-shaped dictionary that autograd sees as the primitive output. These cached blobs are mandatory for the later VJP.

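
The validity check in step 1 can be sketched as follows. This is an illustration under stated assumptions, not the library's implementation: it assumes that missing tracers or missing frequency-domain monitors make the check return `False`, that exceeding the structure limit raises an error, and that traced structures are counted by the index in position 1 of each field-map key. `AdjointLimitError`, `count_traced_structures`, and `is_valid_sketch` are hypothetical names.

```python
class AdjointLimitError(ValueError):
    """Hypothetical stand-in for the library's adjoint error type."""


def count_traced_structures(field_map):
    """Count unique structure indices among keys like ('structures', i, ...)."""
    return len({path[1] for path in field_map if path[0] == "structures"})


def is_valid_sketch(field_map, num_freq_monitors, max_traced_structures):
    """Decide whether a run should go through the autograd pipeline."""
    if not field_map:
        return False  # nothing traced: a regular run is enough
    if num_freq_monitors == 0:
        return False  # no frequency-domain data to differentiate
    num_traced = count_traced_structures(field_map)
    if num_traced > max_traced_structures:
        raise AdjointLimitError(
            f"{num_traced} traced structures exceed the limit of {max_traced_structures}."
        )
    return True


example_map = {
    ("structures", 0, "geometry", "size", 0): 1.0,
    ("structures", 0, "medium", "permittivity"): 2.0,
    ("structures", 2, "medium", "permittivity"): 4.0,
}
```

Counting unique structure indices rather than keys matters because one structure can contribute several traced fields while still inserting only one set of adjoint monitors.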
### High-level execution flow
```mermaid
flowchart LR
  A[Objective via autograd.grad] --> B[Simulation with traced params]
  B --> C[web.api.autograd.run primitive]
  C --> D{local_gradient?}
  D -->|True| E[Run combined forward locally]
  D -->|False| F[Upload autograd_fwd + TracerKeys]
  E --> G[postprocess_fwd caches sim data]
  F --> G
  G --> H[setup_adj builds adjoint sims]
  H --> I[Batch run autograd_bwd jobs]
  I --> J["postprocess_adj (chunked DerivativeInfo)"]
  J --> K[Structure/medium _compute_derivatives]
  K --> L[Autograd VJP returns dJ/dparams]
```

## Backward (adjoint) data flow & batching
1. **Gradient request** – When autograd calls the registered VJP (`_run_bwd` or `_run_async_bwd`), the wrapper pulls `sim_fields_keys`, the original `SimulationData`, and (for local gradients) the stored forward monitor data.
2. **Adjoint source assembly** – `setup_adj()` zero-filters user VJPs, reinserts them into the `SimulationData`, and asks `SimulationData._make_adjoint_sims(...)` to build one adjoint simulation per unique `(monitor, frequency, polarization)` bucket. The number of adjoint simulations is capped by `max_num_adjoint_per_fwd` (passed as a call argument, or defaulting to `config.adjoint.max_adjoint_per_fwd`).
3. **Batch execution** – Local gradients reuse `_run_async_tidy3d` with `path_dir / config.adjoint.local_adjoint_dir`; remote gradients mutate each adjoint sim to `simulation_type="autograd_bwd"`, link them to the forward task via `parent_tasks`, and rely on `_run_async_tidy3d_bwd()` plus `_get_vjp_traced_fields()` to download `output/autograd_sim_vjp.hdf5`.
4. **Adjoint post-processing** – `postprocess_adj()` (`tidy3d/web/api/autograd/backward.py`) pulls forward (`fld_fwd`, `eps_fwd`) and adjoint (`fld_adj`, `eps_adj`) monitors, builds `E_der_map`, `D_der_map`, and optional `H_der_map` via `get_derivative_maps()`, and converts E-fields to D-fields with `E_to_D()`. Frequency batching honors `config.adjoint.solver_freq_chunk_size`, trading extra loop iterations for lower peak memory.
5. **Derivative dispatch** – The routine constructs a `DerivativeInfo` (see below) per chunk, forwards it into `Structure._compute_derivatives()`, and accumulates the returned dict `{('structures', i, 'geometry', ...): gradient}` across adjoint simulations before returning to autograd.

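
The chunked accumulation in steps 4 and 5 can be sketched as a loop over frequency slices that sums per-path contributions. This is a simplified illustration with hypothetical names (`chunked`, `accumulate_vjp`, `toy_chunk_vjp`); the real routine builds a `DerivativeInfo` per chunk rather than passing a bare list of frequencies.

```python
def chunked(seq, size):
    """Yield consecutive slices of seq with at most `size` elements each."""
    for start in range(0, len(seq), size):
        yield seq[start : start + size]


def accumulate_vjp(frequencies, chunk_size, compute_chunk):
    """Sum the gradient dicts returned for each frequency chunk."""
    total = {}
    for freqs in chunked(frequencies, chunk_size):
        for path, value in compute_chunk(freqs).items():
            total[path] = total.get(path, 0.0) + value
    return total


PATH = ("structures", 0, "geometry", "size", 0)


def toy_chunk_vjp(freqs):
    """Toy per-chunk gradient: each frequency contributes 2 * f to one path."""
    return {PATH: sum(2.0 * f for f in freqs)}


frequencies = [1.0, 2.0, 3.0, 4.0, 5.0]
chunked_total = accumulate_vjp(frequencies, 2, toy_chunk_vjp)
single_pass_total = accumulate_vjp(frequencies, len(frequencies), toy_chunk_vjp)
```

Since each chunk's contribution is summed over its own frequencies, the chunk size changes only peak memory and loop count, not the accumulated gradient.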
## DerivativeInfo contract (tidy3d/components/autograd/derivative_utils.py)
`DerivativeInfo` centralizes every tensor the geometry/medium code needs. Key expectations:
- `paths`: tuple of relative paths inside `Structure.geometry` or `.medium` that must be filled with gradients.
- `E_der_map`, `D_der_map`, `H_der_map` (optional): dictionaries mapping field-component names (e.g., `Ex`, `eps_xx`) to `ScalarFieldDataArray`s already multiplied element-wise (`E_fwd * E_adj`, etc.).
- `E_fwd`, `E_adj`, `D_fwd`, `D_adj`, `H_*`: raw fields for terms that require asymmetric handling (e.g., PEC tangential enforcement).
- `eps_data`: slice of the permittivity monitor on the same grid; `eps_in`, `eps_out`, `eps_background`, `eps_no_structure`, and `eps_inf_structure` cover cases where the monitor cannot deliver inside/outside material automatically (geometry groups, approximations, PEC detection).
- `bounds`, `bounds_intersect`, `simulation_bounds`: cached bounding boxes for clipping integrals; all derived from simulation + geometry differences.
- `frequencies`: the chunked frequency array currently being reduced; geometry implementations must sum over this subset because `postprocess_adj()` loops over slices.
- `eps_approx`, `is_medium_pec`, `interpolators`: flags and caches that geometry/medium code can honor to shortcut expensive recomputation across related shapes.

Use `DerivativeInfo.updated_copy(deep=False, paths=...)` to retarget subsets while sharing cached interpolators. Geometry code is expected to tolerate NaNs by calling `_nan_to_num_if_needed` before evaluating interpolators.

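
To illustrate why a shallow copy matters here, the sketch below uses a small frozen dataclass as a hypothetical stand-in for `DerivativeInfo` (it carries only a few of the fields listed above and is not the real class). Its `updated_copy` is built on `dataclasses.replace`, so unchanged fields keep pointing at the same objects, including the interpolator cache.

```python
from dataclasses import dataclass, field, replace


@dataclass(frozen=True)
class DerivativeInfoSketch:
    """Hypothetical, heavily reduced stand-in for DerivativeInfo."""

    paths: tuple
    frequencies: tuple
    E_der_map: dict
    interpolators: dict = field(default_factory=dict)

    def updated_copy(self, **changes):
        """Shallow copy: fields not listed in `changes` are shared, not cloned."""
        return replace(self, **changes)


info = DerivativeInfoSketch(
    paths=(("vertices",), ("slab_bounds", 0)),
    frequencies=(2.0e14,),
    E_der_map={"Ex": [0.1, 0.2]},
)
info.interpolators["Ex"] = "expensive cached object"  # placeholder value

# Retarget to a single path while reusing the cached interpolators.
vertices_info = info.updated_copy(paths=(("vertices",),))
```

Sharing the cache means an interpolator built while handling one path is available to every retargeted copy, which is the behavior the contract relies on for related shapes.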
## Custom VJP providers
`Structure._compute_derivatives()` fans gradients out to the relevant constituent objects. Every class below defines `_compute_derivatives()` and therefore must be updated whenever the contract changes:

**Geometry stack**
- `Geometry` (base dispatch) – `tidy3d/components/geometry/base.py`. Handles shared surface integrals (normal/tangential D/E terms) and loops over child geometries.
- `Box` – `tidy3d/components/geometry/base.py`. Implements closed-form face quadratures; used for axis-aligned primitives.
- `Cylinder` – `tidy3d/components/geometry/primitives.py`. Generates adaptive azimuthal sampling controlled by `config.adjoint.points_per_wavelength`, etc.
- `PolySlab` – `tidy3d/components/geometry/polyslab.py`. Handles polygon meshes, including sidewall extrusion and vertex-by-vertex derivatives.
- `GeometryGroup` – `tidy3d/components/geometry/base.py`. Splits groups into constituent geometries, toggles `DerivativeInfo.eps_approx`, and shares cached interpolators.

**Medium stack** (`tidy3d/components/medium.py`)
- `AbstractMedium` (base) and `Medium` – implement the generic volume integral with `E_der_map`/`D_der_map`.
- Dispersion families each override `_compute_derivatives()` to wire frequency-dependent parameters into the adjoint accumulation: `CustomMedium`, `PoleResidue`/`CustomPoleResidue`, `Sellmeier`/`CustomSellmeier`, `Lorentz`/`CustomLorentz`, `Drude`/`CustomDrude`, `Debye`/`CustomDebye`. Any new dispersive model must emit gradients for both pole frequencies and residues, respecting `config.adjoint.gradient_dtype_*`.

### Example: PolySlab derivative dispatch
```mermaid
sequenceDiagram
  autonumber
  participant PP as postprocess_adj()
  participant DI as DerivativeInfo chunk
  participant ST as Structure._compute_derivatives
  participant GEO as PolySlab geometry
  participant MED as Medium
  PP->>PP: Slice fields & build DI
  PP->>ST: ST._compute_derivatives(DI)
  ST->>ST: Group paths by 'geometry'/'medium'
  ST->>GEO: geometry._compute_derivatives(DI_geom)
  GEO-->>ST: Gradients for vertices/sidewalls
  ST->>MED: medium._compute_derivatives(DI_med)
  MED-->>ST: Gradients for eps / dispersion params
  ST-->>PP: Merge to {('structures',i, ...): value}
  PP-->>Autograd: Accumulate into VJP dict
```

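
The grouping and merging steps in the diagram can be sketched with plain functions. This is an illustrative reduction, assuming each relative path starts with either `"geometry"` or `"medium"`; `compute_structure_derivatives` and the two toy handlers are hypothetical and return constant gradients instead of evaluating field integrals.

```python
def compute_structure_derivatives(index, paths, handlers):
    """Group paths by component, call each handler, and merge with full keys."""
    grouped = {}
    for path in paths:
        grouped.setdefault(path[0], []).append(path[1:])

    merged = {}
    for component, sub_paths in grouped.items():
        gradients = handlers[component](tuple(sub_paths))
        for sub_path, gradient in gradients.items():
            merged[("structures", index, component, *sub_path)] = gradient
    return merged


def toy_geometry_vjp(paths):
    """Toy handler: pretend every geometry path has gradient 1.0."""
    return {path: 1.0 for path in paths}


def toy_medium_vjp(paths):
    """Toy handler: pretend every medium path has gradient 2.0."""
    return {path: 2.0 for path in paths}


result = compute_structure_derivatives(
    index=3,
    paths=[("geometry", "vertices"), ("medium", "permittivity"), ("geometry", "slab_bounds")],
    handlers={"geometry": toy_geometry_vjp, "medium": toy_medium_vjp},
)
```

Each handler sees only paths relative to its own component, which mirrors how a geometry or medium class never needs to know the index of the structure it belongs to.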
## Configuration & local-vs-server gradients
- `config.adjoint.local_gradient`: when `True`, all computations happen locally, their data is stored under `path.parent / config.adjoint.local_adjoint_dir`, and every other `config.adjoint` override takes effect. When `False` (default), overrides besides `local_gradient` are ignored (see `apply_adjoint()`), because backend defaults guarantee reproducibility.
- `config.adjoint.max_traced_structures`: enforced in `is_valid_for_autograd()` before any run is uploaded. Increase cautiously because each structure inserts its own monitor pair.
- `config.adjoint.max_adjoint_per_fwd`: default for `max_num_adjoint_per_fwd` in `run()`/`run_async()`. If the backend returns more adjoint sims than this limit, `setup_adj()` raises `AdjointError` early.
- `config.adjoint.solver_freq_chunk_size`: drives per-chunk slicing inside `postprocess_adj()` so streaming geometries or high-resolution spectra do not explode memory usage.
- `config.adjoint.monitor_interval_poly` / `monitor_interval_custom`: control the spatial sampling density for the auto-inserted monitors; geometry-specific overrides decide which tuple to use based on whether medium parameters are traced.
- `config.adjoint.gradient_precision`: influences dtype selection within `DerivativeInfo` consumers (e.g., `medium.py` uses `config.adjoint.gradient_dtype_float`).
- Other knobs (quadrature order, wavelength fractions, edge clipping tolerances) are consumed by geometry helpers throughout `tidy3d/components/geometry/base.py` & `primitives.py`. Favor these settings instead of sprinkling new constants.

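
The gating rule for `local_gradient` can be sketched as follows. This is a hypothetical model, not the real `apply_adjoint()`: `AdjointSettingsSketch` holds only three fields, and its default values are placeholders chosen for the example rather than the library's actual defaults.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class AdjointSettingsSketch:
    """Hypothetical subset of the adjoint config; defaults are placeholders."""

    local_gradient: bool = False
    max_adjoint_per_fwd: int = 10
    solver_freq_chunk_size: int = 4


def effective_settings(user):
    """Honor user overrides only when gradients are computed locally."""
    if user.local_gradient:
        return user
    # Server-side gradients: fall back to defaults for everything else.
    return AdjointSettingsSketch()


remote = effective_settings(AdjointSettingsSketch(max_adjoint_per_fwd=99))
local = effective_settings(
    AdjointSettingsSketch(local_gradient=True, max_adjoint_per_fwd=99)
)
```

In this model a contributor who changes `max_adjoint_per_fwd` without also setting `local_gradient=True` sees no effect, which is the situation the first bullet above warns about.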
## Serialization artifacts
- **TracerKeys (`autograd_sim_fields_keys.hdf5`)** – produced from `AutogradFieldMap` via `TracerKeys.from_field_mapping()` and uploaded before remote forward runs. They allow the backend to match traced tensors to structural indices, even when the Python-side order changes.
- **Forward data (`AUX_KEY_SIM_DATA_*`)** – always cached locally, even when gradients are done remotely, because `postprocess_run()` rehydrates the user-visible `SimulationData` by copying autograd boxes back onto `sim_data_original`.
- **VJP data (`output/autograd_sim_vjp.hdf5`)** – downloaded automatically for server adjoints via `_get_vjp_traced_fields()` and converted back into `AutogradFieldMap` objects using `FieldMap.from_file().to_autograd_field_map`.

## Testing expectations
- Fast unit coverage lives in `tests/test_components/autograd/`. Core suites:
  - `test_autograd.py`: integration harness that patches the pipeline, emulates server responses, and checks tracing edge cases (e.g., `TRACED_FIELD_KEYS_ATTR`).
  - `test_autograd_dispersive_vjps.py` and `_custom_*` variants: assert each dispersive medium's `_compute_derivatives()` matches analytic expectations.
  - `test_autograd_polyslab*.py`, `test_autograd_rf_*`, and `test_sidewall_edge_cases.py`: stress geometry-derived gradients, particularly for `PolySlab` and radio-frequency (RF) boxes.
- `tests/test_components/autograd/numerical/`: longer-running numerical comparisons enabled via `poetry run pytest -m numerical tests/test_components/autograd/numerical -k case_name` once maintainers approve the simulation cost.
- Always run `poetry run pytest tests/test_components/autograd -q` after touching this stack. Enable `RUN_NUMERICAL` or pass `-m numerical` only when you are ready to run solver-backed adjoints.

## Additional references
- When exposing new public APIs, update the user-facing docs under `docs/` and reference this README so contributors understand the tracing implications.
