Skip to content

Commit

Permalink
analyze dummy models to prevent dense representation during evaluation
Browse files Browse the repository at this point in the history
  • Loading branch information
flaport committed Dec 21, 2023
1 parent 37fbb7e commit c29dab1
Show file tree
Hide file tree
Showing 13 changed files with 235 additions and 99 deletions.
19 changes: 10 additions & 9 deletions examples/01_quick_start.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,7 @@
"\n",
"# Using GPU?\n",
"from jax.lib import xla_bridge\n",
"\n",
"print(xla_bridge.get_backend().platform)"
]
},
Expand Down Expand Up @@ -183,7 +184,7 @@
"outputs": [],
"source": [
"coupling = 0.5\n",
"kappa = coupling ** 0.5\n",
"kappa = coupling**0.5\n",
"tau = (1 - coupling) ** 0.5\n",
"coupler_dict = {\n",
" (\"in0\", \"out0\"): tau,\n",
Expand Down Expand Up @@ -296,7 +297,7 @@
"outputs": [],
"source": [
"def coupler(coupling=0.5) -> sax.SDict:\n",
" kappa = coupling ** 0.5\n",
" kappa = coupling**0.5\n",
" tau = (1 - coupling) ** 0.5\n",
" coupler_dict = sax.reciprocal(\n",
" {\n",
Expand Down Expand Up @@ -450,7 +451,7 @@
" models={\n",
" \"coupler\": coupler,\n",
" \"waveguide\": waveguide,\n",
" }\n",
" },\n",
")"
]
},
Expand All @@ -470,7 +471,7 @@
},
"outputs": [],
"source": [
"mzi?"
"?mzi"
]
},
{
Expand Down Expand Up @@ -1107,10 +1108,10 @@
" \"out1\": \"top,out0\",\n",
" },\n",
" },\n",
" models = {\n",
" models={\n",
" \"coupler\": coupler,\n",
" \"waveguide\": waveguide,\n",
" }\n",
" },\n",
")"
]
},
Expand Down Expand Up @@ -1302,9 +1303,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -1316,7 +1317,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.10.12"
},
"papermill": {
"default_parameters": {},
Expand Down
20 changes: 13 additions & 7 deletions examples/02_all_pass_filter.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -162,12 +162,18 @@
"outputs": [],
"source": [
"def all_pass_analytical():\n",
" \"\"\" Analytic Frequency Domain Response of an all pass filter \"\"\"\n",
" \"\"\"Analytic Frequency Domain Response of an all pass filter\"\"\"\n",
" detected = jnp.zeros_like(wl)\n",
" transmission = 1 - coupling\n",
" neff_wl = neff + (wl0 - wl) * (ng - neff) / wl0 # we expect a linear behavior with respect to wavelength\n",
" out = jnp.sqrt(transmission) - 10 ** (-loss * ring_length / 20.0) * jnp.exp(2j * jnp.pi * neff_wl * ring_length / wl)\n",
" out /= 1 - jnp.sqrt(transmission) * 10 ** (-loss * ring_length / 20.0) * jnp.exp(2j * jnp.pi * neff_wl * ring_length / wl)\n",
" neff_wl = (\n",
" neff + (wl0 - wl) * (ng - neff) / wl0\n",
" ) # we expect a linear behavior with respect to wavelength\n",
" out = jnp.sqrt(transmission) - 10 ** (-loss * ring_length / 20.0) * jnp.exp(\n",
" 2j * jnp.pi * neff_wl * ring_length / wl\n",
" )\n",
" out /= 1 - jnp.sqrt(transmission) * 10 ** (-loss * ring_length / 20.0) * jnp.exp(\n",
" 2j * jnp.pi * neff_wl * ring_length / wl\n",
" )\n",
" detected = abs(out) ** 2\n",
" return detected"
]
Expand Down Expand Up @@ -295,9 +301,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -309,7 +315,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.10.12"
},
"papermill": {
"default_parameters": {},
Expand Down
11 changes: 7 additions & 4 deletions examples/03_circuit_from_yaml.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,10 @@
},
"outputs": [],
"source": [
"mzi, _ = sax.circuit(yaml.safe_load(netlist), models={\"straight\": waveguide_without_dispersion, 'coupler': sax.models.coupler})"
"mzi, _ = sax.circuit(\n",
" yaml.safe_load(netlist),\n",
" models={\"straight\": waveguide_without_dispersion, \"coupler\": sax.models.coupler},\n",
")"
]
},
{
Expand Down Expand Up @@ -310,9 +313,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -324,7 +327,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.10.12"
},
"papermill": {
"default_parameters": {},
Expand Down
26 changes: 13 additions & 13 deletions examples/04_multimode_simulations.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -293,20 +293,20 @@
"def coupler():\n",
" return {\n",
" (\"in0@TE\", \"out0@TE\"): 0.45**0.5,\n",
" (\"in0@TE\", \"out1@TE\"): 1j*0.45**0.5,\n",
" (\"in1@TE\", \"out0@TE\"): 1j*0.45**0.5,\n",
" (\"in0@TE\", \"out1@TE\"): 1j * 0.45**0.5,\n",
" (\"in1@TE\", \"out0@TE\"): 1j * 0.45**0.5,\n",
" (\"in1@TE\", \"out1@TE\"): 0.45**0.5,\n",
" (\"in0@TM\", \"out0@TM\"): 0.45**0.5,\n",
" (\"in0@TM\", \"out1@TM\"): 1j*0.45**0.5,\n",
" (\"in1@TM\", \"out0@TM\"): 1j*0.45**0.5,\n",
" (\"in0@TM\", \"out1@TM\"): 1j * 0.45**0.5,\n",
" (\"in1@TM\", \"out0@TM\"): 1j * 0.45**0.5,\n",
" (\"in1@TM\", \"out1@TM\"): 0.45**0.5,\n",
" (\"in0@TE\", \"out0@TM\"): 0.01**0.5,\n",
" (\"in0@TE\", \"out1@TM\"): 1j*0.01**0.5,\n",
" (\"in1@TE\", \"out0@TM\"): 1j*0.01**0.5,\n",
" (\"in0@TE\", \"out1@TM\"): 1j * 0.01**0.5,\n",
" (\"in1@TE\", \"out0@TM\"): 1j * 0.01**0.5,\n",
" (\"in1@TE\", \"out1@TM\"): 0.01**0.5,\n",
" (\"in0@TM\", \"out0@TE\"): 0.01**0.5,\n",
" (\"in0@TM\", \"out1@TE\"): 1j*0.01**0.5,\n",
" (\"in1@TM\", \"out0@TE\"): 1j*0.01**0.5,\n",
" (\"in0@TM\", \"out1@TE\"): 1j * 0.01**0.5,\n",
" (\"in1@TM\", \"out0@TE\"): 1j * 0.01**0.5,\n",
" (\"in1@TM\", \"out1@TE\"): 0.01**0.5,\n",
" }"
]
Expand Down Expand Up @@ -383,8 +383,8 @@
" },\n",
" },\n",
" models={\n",
" 'coupler': coupler,\n",
" 'straight': waveguide,\n",
" \"coupler\": coupler,\n",
" \"straight\": waveguide,\n",
" },\n",
")"
]
Expand Down Expand Up @@ -448,9 +448,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -462,7 +462,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.10.12"
},
"papermill": {
"default_parameters": {},
Expand Down
63 changes: 47 additions & 16 deletions examples/05_thinfilm.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,7 @@
" r_fresnel_ij = (ni - nj) / (ni + nj) # i->j reflection\n",
" t_fresnel_ij = 2 * ni / (ni + nj) # i->j transmission\n",
" r_fresnel_ji = -r_fresnel_ij # j -> i reflection\n",
" t_fresnel_ji = (1 - r_fresnel_ij ** 2) / t_fresnel_ij # j -> i transmission\n",
" t_fresnel_ji = (1 - r_fresnel_ij**2) / t_fresnel_ij # j -> i transmission\n",
" sdict = {\n",
" (\"in\", \"in\"): r_fresnel_ij,\n",
" (\"in\", \"out\"): t_fresnel_ij,\n",
Expand Down Expand Up @@ -231,7 +231,7 @@
" \"out\": \"B_air,out\",\n",
" },\n",
" },\n",
" backend='fg',\n",
" backend=\"fg\",\n",
")\n",
"\n",
"settings = sax.get_settings(dielectric_fabry_perot)\n",
Expand Down Expand Up @@ -427,7 +427,14 @@
"plt.plot(2 * jnp.pi / wls, Tnorm, \"k\", label=\"T tmm\")\n",
"plt.scatter(2 * jnp.pi / wls, jnp.abs(reflected) ** 2, label=\"R SAX\")\n",
"plt.plot(2 * jnp.pi / wls, Rnorm, \"k--\", label=\"R tmm\")\n",
"plt.vlines(jnp.arange(3, 6) * jnp.pi / (2 * 0.5), ymin=0, ymax=1, color=\"k\", linestyle=\"--\", label=\"m$\\pi$/nd\")\n",
"plt.vlines(\n",
" jnp.arange(3, 6) * jnp.pi / (2 * 0.5),\n",
" ymin=0,\n",
" ymax=1,\n",
" color=\"k\",\n",
" linestyle=\"--\",\n",
" label=\"m$\\pi$/nd\",\n",
")\n",
"plt.xlabel(\"k = 2$\\pi$/λ [1/nm]\")\n",
"plt.ylabel(\"Transmitted and reflected intensities\")\n",
"plt.legend(loc=\"upper right\")\n",
Expand Down Expand Up @@ -697,7 +704,9 @@
" Each mirror is assumed to be lossless and reciprocal : tij = tji, rij = rji\n",
" \"\"\"\n",
" phi = n * 2 * jnp.pi * d / wl\n",
" return r21 + t12 * t12 * r23 * jnp.exp(-2j * phi) / (1 - r21 * r23 * jnp.exp(-2j * phi))"
" return r21 + t12 * t12 * r23 * jnp.exp(-2j * phi) / (\n",
" 1 - r21 * r23 * jnp.exp(-2j * phi)\n",
" )"
]
},
{
Expand Down Expand Up @@ -738,7 +747,7 @@
"\n",
"\n",
"def r_complex(t_amp, t_ang):\n",
" r_amp = jnp.sqrt((1.0 - t_amp ** 2))\n",
" r_amp = jnp.sqrt((1.0 - t_amp**2))\n",
" r_ang = t_ang - jnp.pi / 2\n",
" return r_amp * jnp.exp(-1j * r_ang)"
]
Expand Down Expand Up @@ -783,13 +792,26 @@
"\n",
"wls = jnp.linspace(0.38, 0.78, 500)\n",
"\n",
"T_analytical_initial = jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap)) ** 2\n",
"R_analytical_initial = jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap)) ** 2 \n",
"T_analytical_initial = (\n",
" jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))\n",
" ** 2\n",
")\n",
"R_analytical_initial = (\n",
" jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))\n",
" ** 2\n",
")\n",
"\n",
"plt.title(f\"t={t_initial:1.3f}, d={d_gap} nm, n={n_gap}\")\n",
"plt.plot(2 * jnp.pi / wls, T_analytical_initial, label=\"T\")\n",
"plt.plot(2 * jnp.pi / wls, R_analytical_initial, label=\"R\")\n",
"plt.vlines(jnp.arange(6, 11) * jnp.pi / 2.0, ymin=0, ymax=1, color=\"k\", linestyle=\"--\", label=\"m$\\pi$/nd\")\n",
"plt.vlines(\n",
" jnp.arange(6, 11) * jnp.pi / 2.0,\n",
" ymin=0,\n",
" ymax=1,\n",
" color=\"k\",\n",
" linestyle=\"--\",\n",
" label=\"m$\\pi$/nd\",\n",
")\n",
"plt.xlabel(\"k = 2$\\pi$/$\\lambda$ [1/nm]\")\n",
"plt.ylabel(\"Power (units of input)\")\n",
"plt.legend()\n",
Expand Down Expand Up @@ -885,7 +907,7 @@
" \"out\": \"mirror2,out\",\n",
" },\n",
" },\n",
" backend='fg',\n",
" backend=\"fg\",\n",
")\n",
"\n",
"settings = sax.get_settings(fabry_perot_tunable)\n",
Expand Down Expand Up @@ -924,7 +946,7 @@
" \"out\": \"mirror2,out\",\n",
" },\n",
" },\n",
" backend='fg',\n",
" backend=\"fg\",\n",
")\n",
"\n",
"settings = sax.get_settings(fabry_perot_tunable)\n",
Expand Down Expand Up @@ -973,8 +995,14 @@
},
"outputs": [],
"source": [
"T_analytical_initial = jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n",
"R_analytical_initial = jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))**2\n",
"T_analytical_initial = (\n",
" jnp.abs(airy_t13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))\n",
" ** 2\n",
")\n",
"R_analytical_initial = (\n",
" jnp.abs(airy_r13(t_initial, t_initial, r_initial, r_initial, wls, d=d_gap, n=n_gap))\n",
" ** 2\n",
")\n",
"plt.title(f\"t={t_initial:1.3f}, d={d_gap} nm, n={n_gap}\")\n",
"plt.plot(wls, T_analytical_initial, label=\"T theory\")\n",
"plt.scatter(wls, jnp.abs(transmitted_initial) ** 2, label=\"T SAX\")\n",
Expand Down Expand Up @@ -1062,7 +1090,9 @@
"transmitted = jnp.zeros_like(wls)\n",
"reflected = jnp.zeros_like(wls)\n",
"settings = sax.get_settings(fabry_perot_tunable)\n",
"settings = sax.update_settings(settings, wl=wls, t_amp=ts_initial[:N], t_ang=ts_initial[N:])\n",
"settings = sax.update_settings(\n",
" settings, wl=wls, t_amp=ts_initial[:N], t_ang=ts_initial[N:]\n",
")\n",
"settings[\"gap\"][\"ni\"] = 1.0\n",
"settings[\"gap\"][\"di\"] = 2.0\n",
"# Perform computation\n",
Expand Down Expand Up @@ -1214,6 +1244,7 @@
"source": [
"optim_init, optim_update, optim_params = opt.adam(step_size=0.001)\n",
"\n",
"\n",
"def train_step(step, optim_state):\n",
" ts = optim_params(optim_state)\n",
" lossvalue = loss(ts)\n",
Expand Down Expand Up @@ -1393,9 +1424,9 @@
],
"metadata": {
"kernelspec": {
"display_name": "sax",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "sax"
"name": "python3"
},
"language_info": {
"codemirror_mode": {
Expand All @@ -1407,7 +1438,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.18"
"version": "3.10.12"
},
"papermill": {
"default_parameters": {},
Expand Down
Loading

0 comments on commit c29dab1

Please sign in to comment.