Skip to content

Commit 19e9fd3

Browse files
committed
bugfix: add task_id to JaxSimulationData
1 parent f1e4832 commit 19e9fd3

File tree

3 files changed

+13
-4
lines changed

3 files changed

+13
-4
lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
1818
### Fixed
1919
- Bug in plotting and computing tilted plane intersections of transformed 0 thickness geometries.
2020
- `Simulation.to_gdspy()` and `Simulation.to_gdstk()` now place polygons in GDS layer `(0, 0)` when no `gds_layer_dtype_map` is provided instead of erroring.
21+
- `task_id` now properly stored in `JaxSimulationData`.
2122

2223
## [2.7.0rc1] - 2024-04-22
2324

tidy3d/plugins/adjoint/web.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,7 @@ def run(
123123
callback_url=callback_url,
124124
verbose=verbose,
125125
)
126+
# TODO: add task_id
126127
return JaxSimulationData.from_sim_data(sim_data, jax_info)
127128

128129

@@ -151,7 +152,9 @@ def run_fwd(
151152
)
152153

153154
res = RunResidual(fwd_task_id=task_id)
154-
jax_sim_data_orig = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
155+
jax_sim_data_orig = JaxSimulationData.from_sim_data(
156+
sim_data_orig, jax_info_orig, task_id=task_id
157+
)
155158

156159
return jax_sim_data_orig, (res,)
157160

@@ -410,6 +413,7 @@ def run_async(
410413
task_name = str(_task_name_orig(i))
411414
sim_data_tidy3d = batch_data_tidy3d[task_name]
412415
jax_info = jax_infos[str(task_name)]
416+
# TODO: add task_id
413417
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
414418
jax_batch_data.append(jax_sim_data)
415419

@@ -450,8 +454,10 @@ def run_async_fwd(
450454
batch_data_orig = [sim_data for _, sim_data in batch_data_orig.items()]
451455

452456
jax_batch_data_orig = []
453-
for sim_data_orig, jax_info_orig in zip(batch_data_orig, jax_infos_orig):
454-
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_orig, jax_info_orig)
457+
for sim_data_orig, jax_info_orig, task_id in zip(batch_data_orig, jax_infos_orig, fwd_task_ids):
458+
jax_sim_data = JaxSimulationData.from_sim_data(
459+
sim_data_orig, jax_info_orig, task_id=task_id
460+
)
455461
jax_batch_data_orig.append(jax_sim_data)
456462

457463
residual = RunResidualBatch(fwd_task_ids=fwd_task_ids)
@@ -626,6 +632,7 @@ def run_local(
626632
)
627633

628634
# convert back to jax type and return
635+
# TODO: add task_id
629636
return JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
630637

631638

@@ -779,6 +786,7 @@ def run_async_local(
779786
task_name = _task_name_orig_local(i, task_name_suffix)
780787
sim_data_tidy3d = batch_data_tidy3d[task_name]
781788
jax_info = jax_infos[str(task_name)]
789+
# TODO: add task_id
782790
jax_sim_data = JaxSimulationData.from_sim_data(sim_data_tidy3d, jax_info=jax_info)
783791
jax_batch_data.append(jax_sim_data)
784792

0 commit comments

Comments
 (0)