diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index a0705224..239a1271 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -1,6 +1,6 @@ # This is a basic workflow to help you get started with Actions -name: test +name: cpu-tests # Controls when the action will run. on: @@ -8,34 +8,62 @@ on: pull_request: branches: - main - types: - - opened - - synchronize - - ready_for_review - # Allows you to run this workflow manually from the Actions tab workflow_dispatch: # A workflow run is made up of one or more jobs that can run sequentially or in parallel jobs: - test-its: - + test-tsadar-cpu: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 - - name: Set up Python 3.11 - uses: actions/setup-python@v2 - with: - python-version: 3.11 + - uses: actions/checkout@v4 + - uses: conda-incubator/setup-miniconda@v3 + with: + python-version: 3.11 + mamba-version: "*" + channels: conda-forge,defaults + channel-priority: true + activate-environment: tsadar-cpu + environment-file: env.yml + - shell: bash -el {0} + run: | + pytest tests/ + + test-tsadar-gpu: + runs-on: "gpu_runner" + steps: + - uses: actions/checkout@v4 + - uses: conda-incubator/setup-miniconda@v3 + with: + python-version: 3.11 + mamba-version: "*" + channels: conda-forge,defaults + channel-priority: true + activate-environment: tsadar-gpu + environment-file: env_gpu.yml + - shell: bash -el {0} + run: | + pytest tests/ + + + # - shell: bash -el {0} + # run: mamba env create -f env.yml + # steps: + # - uses: actions/checkout@v3 + # - name: Set up Python 3.11 + # uses: actions/setup-python@v2 + # with: + # python-version: 3.11 - - name: Install dependencies - run: | - python -m pip install --upgrade pip - python -m pip install --upgrade black - python -m pip install --upgrade pytest wheel - python -m pip install --upgrade -r requirements.txt + # - name: Install dependencies + # run: | + # python -m pip install --upgrade pip + # python -m pip install --upgrade 
black + # python -m pip install --upgrade pytest wheel + # # python -m pip install --upgrade -r requirements.txt - - name: Test with pytest - run: | - pytest tests/ + # - name: Test with pytest + # run: | + # python -m pip install pytest + # CPU_ONLY=True pytest tests/ diff --git a/configs/1d/defaults.yaml b/configs/1d/defaults.yaml index 1349e337..d515ede7 100644 --- a/configs/1d/defaults.yaml +++ b/configs/1d/defaults.yaml @@ -1,7 +1,5 @@ parameters: - species1: - type: - electron: + electron: Te: val: .6 active: False @@ -12,31 +10,8 @@ parameters: active: False lb: 0.001 ub: 1.0 - m: - val: 3.0 - active: False - lb: 2.0 - ub: 5.0 - matte: False - fe: - val: [ ] - active: False - length: 3999 - type: - DLM: - lb: -100. - ub: -0.5 - fe_decrease_strict: False - symmetric: False - dim: 1 - v_res: 0.1 - temp_asym: 1.0 - m_theta: 0.0 - m_asym: 1. - species2: - type: - ion: + ion-1: Ti: val: 0.12 active: False @@ -56,8 +31,6 @@ parameters: active: False general: - type: - general: amp1: val: 1. 
active: False @@ -106,7 +79,8 @@ parameters: other: expandedions: False extraoptions: - load_ion_spec: False + spectype: 1d + load_ion_spec: false load_ele_spec: True fit_IAW: False fit_EPWb: True @@ -115,6 +89,9 @@ other: PhysParams: background: [0, 0] norm: 0 + widIRF: + spect_stddev_ele: 1.3 + spect_stddev_ion: 0.015 iawoff: 0 iawfilter: [1, 4, 24, 528] CCDsize: [1024, 1024] @@ -158,7 +135,7 @@ data: pixel start: 400 end: 600 - skip: 10 + skip: 50 background: type: pixel @@ -186,7 +163,7 @@ optimizer: x_norm: False grad_method: AD batch_size: 2 - num_epochs: 1000 + num_epochs: 120 learning_rate: 1.0e-2 parameter_norm: True refine_factor: 0 @@ -204,4 +181,4 @@ dist_fit: mlflow: experiment: inverse-thomson-scattering - run: base \ No newline at end of file + run: new \ No newline at end of file diff --git a/configs/1d/inputs.yaml b/configs/1d/inputs.yaml index da98408e..ada15a93 100644 --- a/configs/1d/inputs.yaml +++ b/configs/1d/inputs.yaml @@ -1,8 +1,5 @@ parameters: - species1: - type: - electron: - active: False + electron: Te: val: .5 active: True @@ -13,30 +10,23 @@ parameters: active: True lb: .001 ub: 1. - m: - val: 3.0 - active: True - lb: 2. - ub: 5. fe: - val: [ ] - active: False - length: 3999 - type: - DLM: - lb: -100. - ub: -0.5 + active: True + type: dlm dim: 1 - v_res: 0.1 - species2: - type: - ion: - active: False + nv: 64 + params: + m: + val: 2.5 + lb: 2.0 + ub: 5.0 + ion-1: Ti: val: .2 active: False lb: 0.01 ub: 3. + same: False Z: val: 8. active: False @@ -46,12 +36,9 @@ parameters: val: 40. active: False fract: - val: 1.0 + val: 1. active: False general: - type: - general: - active: False amp1: val: 1. 
active: True @@ -101,5 +88,5 @@ parameters: other: refit: False refit_thresh: 5.0 - calc_sigmas: True + calc_sigmas: false diff --git a/configs/arts-1d/defaults.yaml b/configs/arts-1d/defaults.yaml new file mode 100644 index 00000000..0717f502 --- /dev/null +++ b/configs/arts-1d/defaults.yaml @@ -0,0 +1,184 @@ +parameters: + electron: + Te: + val: .5 + active: False + lb: 0.01 + ub: 3. + gradient: 0.0 + fe: + active: False + type: dlm + ne: + val: .2 + active: False + lb: .001 + ub: 10. + + ion-1: + Ti: + val: .2 + active: False + lb: 0.01 + ub: 3. + Z: + val: 1. + active: False + lb: 1. + ub: 25. + A: + val: 40 + active: False + fract: + val: 1.0 + active: False + general: + amp1: + val: 1. + active: False + lb: 0. + ub: 10. + amp2: + val: 1. + active: False + lb: 0. + ub: 10. + amp3: + val: 1. + active: False + lb: 0. + ub: 10. + lam: + val: 526.5 + active: False + lb: 523. + ub: 528. + ud: + val: 0 + active: False + Va: + val: 0 + active: False + + blur: + val: [] + active: False + lb: 0. + ub: 10. + specCurvature: + val: [] + active: False + lb: .1 + ub: 10. 
+ fitprops: + val: [] + active: False + +other: + crop_window: 1 + BinWidth: 10 + NumBinInRng: 0 + TotalNumBin: 1023 + expandedions: False + extraoptions: + load_ion_spec: False + load_ele_spec: True + fit_IAW: False + fit_EPWb: True + fit_EPWr: True + PhysParams: + background: [0, 0] + norm: 0 + iawoff: 0 + iawfilter: [1, 4, 24, 528] + CCDsize: [1024, 1024] + flatbg: 0 + gain: 1 + points_per_pixel: 2 + ang_res_unit: 10 + lam_res_unit: 5 + refit: True + refit_thresh: 0.25 + calc_sigmas: False + +plotting: + n_sigmas: 3 + rolling_std_width: 5 + data_cbar_u: data #1.1 + data_cbar_l: data #0 + ion_window_start: 525 + ion_window_end: 528 + ele_window_start: 425 + ele_window_end: 625 + + +data: + launch_data_visualizer: False + shotnum: 101675 + shotDay: False + fit_rng: + blue_min: 450 + blue_max: 510 + red_min: 545 + red_max: 650 + iaw_min: 350 + iaw_max: 352 + iaw_cf_min: 526.4 + iaw_cf_max: 526.6 + forward_epw_start: 400 + forward_epw_end: 700 + forward_iaw_start: 524 + forward_iaw_end: 528 + bgscaleE: 1.0 + bgscaleI: 0.1 + bgshotmult: 1 + ele_lam_shift: 0.0 + ion_loss_scale: 1.0 + probe_beam: P9 + dpixel: 2 + lineouts: + type: + pixel + start: 500 + end: 502 + skip: 1 + background: + type: + pixel + slice: 900 + + + +optimizer: + # use adam for nn / stochastic gradient descent + # use L-BFGS-B for full batch / parameter learning + # although adam will work here too + method: adam + hessian: False + loss_method: l2 + y_norm: True + x_norm: False + grad_method: AD + grad_scalar: 0.5 + batch_size: 5 + num_epochs: 1 + learning_rate: 1.0e-4 + parameter_norm: True + refine_factor: 1 + num_mins: 2 + moment_loss: false + save_state: False + +nn: + use: false + conv_filters: 32|16|16 #|32|32 + linear_widths: 32|16 + +dist_fit: + window: + len: 0.2 #should be even + type: hamming # one of [hamming, hann, bartlett] + +mlflow: + experiment: inverse-thomson-scattering + run: base \ No newline at end of file diff --git a/configs/arts-1d/inputs.yaml b/configs/arts-1d/inputs.yaml 
new file mode 100644 index 00000000..61698617 --- /dev/null +++ b/configs/arts-1d/inputs.yaml @@ -0,0 +1,123 @@ +parameters: + general: + amp1: + val: 1.0 + active: True + lb: 0.01 + ub: 3.75 + amp2: + val: 1.0 + active: True + lb: 0.01 + ub: 3.75 + amp3: + val: 1. + active: False + lb: 0.01 + ub: 3.75 + lam: + val: 524.5 + active: True + lb: 523.0 + ub: 528.0 + Te_gradient: + val: 0.0 + active: False + lb: 0. + ub: 10. + num_grad_points: 1 + ud: + val: 0.0 + angle: 0.0 + active: False + lb: -2.0 + ub: 2.0 + Va: + val: 0.0 + angle: 0.0 + active: False + lb: -1.1 + ub: 2.5 + ne_gradient: + val: 0. + active: False + lb: 0. + ub: 15. + num_grad_points: 1 + ion-1: + Ti: + val: 0.2 + active: False + lb: 0.01 + ub: 3.0 + same: False + Z: + val: 8. + active: False + lb: 1.0 + ub: 25.0 + A: + val: 14. + active: False + fract: + val: 1.0 + active: False + electron: + Te: + val: 1.0 + active: True + lb: 0.01 + ub: 1.5 + ne: + val: 0.4 + active: True + lb: .001 + ub: 1.0 + fe: + + active: False + type: dlm + params: + m: + val: 2.5 + lb: 2.0 + ub: 5.0 + dim: 1 + nv: 2048 + +data: + shotnum: 94475 + lineouts: + type: + range + start: 90 + end: 950 + skip: 20 + background: + type: + Fit + slice: 94477 + +other: + extraoptions: + load_ion_spec: False + load_ele_spec: True + fit_IAW: False + fit_EPWb: True + fit_EPWr: True + spectype: angular + PhysParams: + widIRF: + spect_stddev_ion: 0.015 + spect_stddev_ele: 0.1 + spect_FWHM_ele: 0.9 + ang_FWHM_ele: 1.0 + refit: False + refit_thresh: 5.0 + calc_sigmas: False + +mlflow: + experiment: inverse-thomson-scattering + run: maxwellian_ang_test_2D_v=0.01_scanned_gpu + +machine: gpu \ No newline at end of file diff --git a/configs/arts/defaults.yaml b/configs/arts-2d/defaults.yaml similarity index 86% rename from configs/arts/defaults.yaml rename to configs/arts-2d/defaults.yaml index 84ca2cac..654c24c1 100644 --- a/configs/arts/defaults.yaml +++ b/configs/arts-2d/defaults.yaml @@ -135,15 +135,13 @@ parameters: num_grad_points: 1 ub: 
15.0 val: 0.0 - type: - general: null ud: active: false angle: 0.0 lb: -100.0 ub: 100.0 val: 0.0 - species1: + electron: Te: active: false lb: 0.01 @@ -151,33 +149,20 @@ parameters: val: 0.6 fe: active: false - dim: 1 - fe_decrease_strict: false - lb: -100.0 - length: 3999 - m_asym: 1.0 - m_theta: 0.0 - symmetric: false - temp_asym: 1.0 - type: - DLM: null - ub: -0.5 - v_res: 0.1 - val: [] - m: - active: true - lb: 2.0 - matte: false - ub: 5.0 - val: 3.0 + dim: 2 + type: sphericalharmonic + nvx: 128 + params: + init_m: 2.0 + Nl: 1 + nvr: 64 ne: active: false lb: 0.001 ub: 1.0 val: 0.2 - type: - electron: null - species2: + + ion-1: A: active: false val: 40.0 @@ -194,8 +179,7 @@ parameters: fract: active: false val: 0.1 - type: - ion: null + plotting: data_cbar_l: 0 data_cbar_u: data diff --git a/configs/arts/inputs.yaml b/configs/arts-2d/inputs.yaml similarity index 78% rename from configs/arts/inputs.yaml rename to configs/arts-2d/inputs.yaml index 8f9db3aa..f758e873 100644 --- a/configs/arts/inputs.yaml +++ b/configs/arts-2d/inputs.yaml @@ -72,16 +72,13 @@ parameters: num_grad_points: 1 ub: 15.0 val: 0.0 - type: - active: false - general: null ud: active: false angle: 0.0 lb: -10.0 ub: 30.0 val: 5.0 - species1: + electron: Te: active: false gradient_scalar: 10.0 @@ -89,37 +86,23 @@ parameters: ub: 2.0 val: 1.1 fe: - active: true - dim: 2 - fe_decrease_strict: false - lb: -100.0 - length: 64 - m_asym: 1.0 - m_theta: 0.0 - symmetric: false - temp_asym: 1.0 - type: - DLM: null - ub: -0.5 - v_res: 0.1 - val: [] - m: active: false - intens: 4.55 - lb: 2.0 - matte: false - ub: 5.0 - val: 2.0 + dim: 2 + type: sphericalharmonic + nvx: 128 + params: + init_m: 2.0 + Nl: 1 + nvr: 64 + ne: active: false gradient_scalar: 10.0 lb: 0.03 ub: 2.0 val: 0.44 - type: - active: false - electron: null - species2: + + ion-1: A: active: false val: 1.0 @@ -137,6 +120,4 @@ parameters: fract: active: false val: 1.0 - type: - active: false - ion: null \ No newline at end of file + \ No 
newline at end of file diff --git a/docs/source/TSfitter.rst b/docs/source/ThomsonScattering.rst similarity index 74% rename from docs/source/TSfitter.rst rename to docs/source/ThomsonScattering.rst index 9aa4844f..0d9456ec 100644 --- a/docs/source/TSfitter.rst +++ b/docs/source/ThomsonScattering.rst @@ -3,6 +3,6 @@ TS Fitter API This is the module that handles the loss function, gradient calculation, hessian calculation, and parameter generation -.. autoclass:: tsadar.model.TSFitter.TSFitter +.. autoclass:: tsadar.model.ThomsonScattering.ThomsonScattering :members: :private-members: diff --git a/docs/source/api_main.rst b/docs/source/api_main.rst index 9247fb05..6371ae2b 100644 --- a/docs/source/api_main.rst +++ b/docs/source/api_main.rst @@ -15,7 +15,7 @@ Note: This section is under heavy development and the API is very much subject t :maxdepth: 2 :caption: Middle Level API: - TSfitter + ThomsonScattering spectrum .. toctree:: diff --git a/env.yml b/env.yml index 744556e4..6133f742 100644 --- a/env.yml +++ b/env.yml @@ -2,28 +2,9 @@ name: tsadar-cpu channels: - conda-forge dependencies: - - python=3.10 - - pyhdf - - xlrd + - python=3.11 - pip - pyhdf - xlrd - pip: - - jaxlib - - "jax[cpu]" - - jaxopt - - numpy - - scipy - - matplotlib - - pyyaml - - mlflow - - boto3 - - flatten-dict - - typing-extensions - - optax - - tqdm - - xarray - - mlflow_export_import - - pandas - - interpax - \ No newline at end of file + - -r requirements.txt \ No newline at end of file diff --git a/env_gpu.yml b/env_gpu.yml index ceeb3ae9..f8914ad0 100644 --- a/env_gpu.yml +++ b/env_gpu.yml @@ -1,27 +1,27 @@ +name: tsadar-gpu channels: - conda-forge dependencies: - - python=3.12 - - pyhdf + - python=3.11 - pip - - jax[cuda12] - - jaxopt - - numpy<2 - - scipy - - matplotlib - - pyyaml - - mlflow - - boto3 - - flatten-dict - - typing-extensions - - optax - - tqdm - - xarray - - pandas + - pyhdf - xlrd - pip: - - mlflow_export_import - - interpax - - psutil - - pynvml + - "jax[cuda12]" + - 
jaxopt + - numpy + - scipy + - matplotlib + - pyyaml + - mlflow + - boto3 + - flatten-dict + - typing-extensions + - optax + - tqdm + - xarray + - mlflow_export_import + - pandas + - interpax + - tabulate \ No newline at end of file diff --git a/requirements.txt b/requirements.txt index 85668d83..378b16cc 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,9 +1,7 @@ -jaxlib jax -numpy<2 +numpy scipy matplotlib -pyhdf xlrd pyyaml mlflow diff --git a/run_tsadar.py b/run_tsadar.py index 40514f1d..0e390ec5 100644 --- a/run_tsadar.py +++ b/run_tsadar.py @@ -5,10 +5,11 @@ from jax import config config.update("jax_enable_x64", True) +# config.update("jax_debug_nans", True) # config.update("jax_disable_jit", True) from tsadar.runner import run, run_job, load_and_make_folders -from tsadar.misc.utils import export_run +from tsadar.utils.misc import export_run if __name__ == "__main__": @@ -23,7 +24,7 @@ run_job(args.run_id, args.mode, nested=None) run_id = args.run_id else: - # run_id, config = load_and_make_folders(args.cfg) + # run_id, config = load_and_make_folders(args.cfg) run_id = run(args.cfg, mode=args.mode) if "MLFLOW_EXPORT" in os.environ: diff --git a/setup.py b/setup.py index f392c2a1..d6e2f7d2 100644 --- a/setup.py +++ b/setup.py @@ -26,10 +26,10 @@ # cmdclass=versioneer.get_cmdclass(), packages=find_packages(), python_requires=">=3.11", - package_data={"tsadar": ["aux/**/*"]}, + package_data={"tsadar": ["external/**/*"]}, include_package_data=True, install_requires=[ - "numpy<2", + "numpy", "scipy", "matplotlib", # "pyhdf", # install using conda, has hdf5 dependencies that need configuring otherwise diff --git a/tests/configs/1d-defaults.yaml b/tests/configs/1d-defaults.yaml new file mode 100644 index 00000000..d515ede7 --- /dev/null +++ b/tests/configs/1d-defaults.yaml @@ -0,0 +1,184 @@ +parameters: + electron: + Te: + val: .6 + active: False + lb: 0.01 + ub: 1.5 + ne: + val: 0.2 + active: False + lb: 0.001 + ub: 1.0 + + ion-1: + Ti: + val: 0.12 + 
active: False + same: False + lb: 0.001 + ub: 1.0 + Z: + val: 14.0 + active: False + lb: 0.5 + ub: 7.0 + A: + val: 40.0 + active: False + fract: + val: 0.1 + active: False + + general: + amp1: + val: 1. + active: False + lb: 0.01 + ub: 3.75 + amp2: + val: 1. + active: False + lb: 0.01 + ub: 3.75 + amp3: + val: 1.0 + active: False + lb: 0.01 + ub: 3.75 + lam: + val: 526.5 + active: False + lb: 523.0 + ub: 528.0 + Te_gradient: + val: 0.0 + active: False + lb: 0. + ub: 10. + num_grad_points: 1 + ne_gradient: + val: 0. + active: False + lb: 0. + ub: 15. + num_grad_points: 1 + ud: + val: 0.0 + angle: 0.0 + active: False + lb: -100.0 + ub: 100.0 + Va: + val: 0.0 + angle: 0.0 + active: False + lb: -20.5 + ub: 20.5 + +other: + expandedions: False + extraoptions: + spectype: 1d + load_ion_spec: false + load_ele_spec: True + fit_IAW: False + fit_EPWb: True + fit_EPWr: True + absolute_timing: false + PhysParams: + background: [0, 0] + norm: 0 + widIRF: + spect_stddev_ele: 1.3 + spect_stddev_ion: 0.015 + iawoff: 0 + iawfilter: [1, 4, 24, 528] + CCDsize: [1024, 1024] + flatbg: 0 + gain: 1 + points_per_pixel: 5 + ang_res_unit: 10 + lam_res_unit: 5 + refit: False + refit_thresh: 0.25 + calc_sigmas: False + +data: + shotnum: 101675 + shotDay: False + launch_data_visualizer: True + fit_rng: + blue_min: 450 + blue_max: 510 + red_min: 540 + red_max: 625 + iaw_min: 525.5 + iaw_max: 527.5 + iaw_cf_min: 526.49 + iaw_cf_max: 526.51 + forward_epw_start: 400 + forward_epw_end: 700 + forward_iaw_start: 525.75 + forward_iaw_end: 527.25 + bgscaleE: 1.0 + bgscaleI: 0.1 + bgshotmult: 1 + ion_loss_scale: 1.0 + ele_t0: 0 + ion_t0_shift: 0 + ele_lam_shift: 0.0 + probe_beam: P9 + dpixel: 2 + lineouts: + type: + pixel + start: 400 + end: 600 + skip: 50 + background: + type: + pixel + slice: 900 + +plotting: + n_sigmas: 3 + rolling_std_width: 5 + data_cbar_u: 10 + data_cbar_l: 0 + ion_window_start: 525 + ion_window_end: 528 + ele_window_start: 425 + ele_window_end: 625 + +optimizer: + # use adam for 
nn / stochastic gradient descent + # use L-BFGS-B for full batch / parameter learning + # although adam will work here too + method: l-bfgs-b + moment_loss: false + loss_method: l2 + hessian: False + y_norm: True + x_norm: False + grad_method: AD + batch_size: 2 + num_epochs: 120 + learning_rate: 1.0e-2 + parameter_norm: True + refine_factor: 0 + num_mins: 1 + +nn: + use: false + conv_filters: 32|32|16 + linear_widths: 16|8 + +dist_fit: + window: + len: 0.2 #should be even + type: hamming # one of [hamming, hann, bartlett] + +mlflow: + experiment: inverse-thomson-scattering + run: new \ No newline at end of file diff --git a/tests/configs/1d-inputs.yaml b/tests/configs/1d-inputs.yaml new file mode 100644 index 00000000..ada15a93 --- /dev/null +++ b/tests/configs/1d-inputs.yaml @@ -0,0 +1,92 @@ +parameters: + electron: + Te: + val: .5 + active: True + lb: 0.001 + ub: 1.5 + ne: + val: .2 + active: True + lb: .001 + ub: 1. + fe: + active: True + type: dlm + dim: 1 + nv: 64 + params: + m: + val: 2.5 + lb: 2.0 + ub: 5.0 + ion-1: + Ti: + val: .2 + active: False + lb: 0.01 + ub: 3. + same: False + Z: + val: 8. + active: False + lb: 1. + ub: 25. + A: + val: 40. + active: False + fract: + val: 1. + active: False + general: + amp1: + val: 1. + active: True + lb: 0.01 + ub: 3.75 + amp2: + val: 1. + active: True + lb: 0.01 + ub: 3.75 + amp3: + val: 1. + active: False + lb: 0. + ub: 10. + lam: + val: 524.0 + active: True + lb: 523. + ub: 528. + Te_gradient: + val: 0.0 + active: False + lb: 0. + ub: 10. + num_grad_points: 1 + ne_gradient: + val: 0. + active: False + lb: 0. + ub: 15. 
+ num_grad_points: 1 + ud: + val: 0.0 + angle: 0.0 + active: False + lb: -10.0 + ub: 10.0 + Va: + val: 0.0 + angle: 0.0 + active: False + lb: -20.5 + ub: 20.5 + + +other: + refit: False + refit_thresh: 5.0 + calc_sigmas: false + diff --git a/tests/configs/arts1d_test_defaults.yaml b/tests/configs/arts1d_test_defaults.yaml new file mode 100644 index 00000000..957e51df --- /dev/null +++ b/tests/configs/arts1d_test_defaults.yaml @@ -0,0 +1,184 @@ +parameters: + electron: + Te: + val: .5 + active: False + lb: 0.01 + ub: 3. + gradient: 0.0 + fe: + active: False + type: dlm + ne: + val: .2 + active: False + lb: .001 + ub: 10. + + ion-1: + Ti: + val: .2 + active: False + lb: 0.01 + ub: 3. + Z: + val: 1. + active: False + lb: 1. + ub: 25. + A: + val: 40 + active: False + fract: + val: 1 + active: False + general: + amp1: + val: 1. + active: False + lb: 0. + ub: 10. + amp2: + val: 1. + active: False + lb: 0. + ub: 10. + amp3: + val: 1. + active: False + lb: 0. + ub: 10. + lam: + val: 526.5 + active: False + lb: 523. + ub: 528. + ud: + val: 0 + active: False + Va: + val: 0 + active: False + + blur: + val: [] + active: False + lb: 0. + ub: 10. + specCurvature: + val: [] + active: False + lb: .1 + ub: 10. 
+ fitprops: + val: [] + active: False + +other: + crop_window: 1 + BinWidth: 10 + NumBinInRng: 0 + TotalNumBin: 1023 + expandedions: False + extraoptions: + load_ion_spec: False + load_ele_spec: True + fit_IAW: False + fit_EPWb: True + fit_EPWr: True + PhysParams: + background: [0, 0] + norm: 0 + iawoff: 0 + iawfilter: [1, 4, 24, 528] + CCDsize: [1024, 1024] + flatbg: 0 + gain: 1 + points_per_pixel: 2 + ang_res_unit: 10 + lam_res_unit: 5 + refit: True + refit_thresh: 0.25 + calc_sigmas: False + +plotting: + n_sigmas: 3 + rolling_std_width: 5 + data_cbar_u: data #1.1 + data_cbar_l: data #0 + ion_window_start: 525 + ion_window_end: 528 + ele_window_start: 425 + ele_window_end: 625 + + +data: + launch_data_visualizer: False + shotnum: 101675 + shotDay: False + fit_rng: + blue_min: 450 + blue_max: 510 + red_min: 545 + red_max: 650 + iaw_min: 350 + iaw_max: 352 + iaw_cf_min: 526.4 + iaw_cf_max: 526.6 + forward_epw_start: 400 + forward_epw_end: 700 + forward_iaw_start: 524 + forward_iaw_end: 528 + bgscaleE: 1.0 + bgscaleI: 0.1 + bgshotmult: 1 + ele_lam_shift: 0.0 + ion_loss_scale: 1.0 + probe_beam: P9 + dpixel: 2 + lineouts: + type: + pixel + start: 500 + end: 502 + skip: 1 + background: + type: + pixel + slice: 900 + + + +optimizer: + # use adam for nn / stochastic gradient descent + # use L-BFGS-B for full batch / parameter learning + # although adam will work here too + method: adam + hessian: False + loss_method: l2 + y_norm: True + x_norm: False + grad_method: AD + grad_scalar: 0.5 + batch_size: 5 + num_epochs: 1 + learning_rate: 1.0e-4 + parameter_norm: True + refine_factor: 1 + num_mins: 2 + moment_loss: false + save_state: False + +nn: + use: false + conv_filters: 32|16|16 #|32|32 + linear_widths: 32|16 + +dist_fit: + window: + len: 0.2 #should be even + type: hamming # one of [hamming, hann, bartlett] + +mlflow: + experiment: tsadar-test + run: arts1d \ No newline at end of file diff --git a/tests/configs/arts1d_test_inputs.yaml 
b/tests/configs/arts1d_test_inputs.yaml new file mode 100644 index 00000000..a32a4bf8 --- /dev/null +++ b/tests/configs/arts1d_test_inputs.yaml @@ -0,0 +1,120 @@ +parameters: + general: + amp1: + val: 1.0 + active: True + lb: 0.01 + ub: 3.75 + amp2: + val: 1.0 + active: True + lb: 0.01 + ub: 3.75 + amp3: + val: 1. + active: False + lb: 0.01 + ub: 3.75 + lam: + val: 524.5 + active: True + lb: 523.0 + ub: 528.0 + Te_gradient: + val: 0.0 + active: False + lb: 0. + ub: 10. + num_grad_points: 1 + ud: + val: 0.0 + angle: 0.0 + active: False + lb: -2.0 + ub: 2.0 + Va: + val: 0.0 + angle: 0.0 + active: False + lb: -1.1 + ub: 2.5 + ne_gradient: + val: 0. + active: False + lb: 0. + ub: 15. + num_grad_points: 1 + ion-1: + Ti: + val: 0.2 + active: False + lb: 0.01 + ub: 3.0 + same: False + Z: + val: 8. + active: False + lb: 1.0 + ub: 25.0 + A: + val: 14. + active: False + electron: + Te: + val: 1.0 + active: True + lb: 0.01 + ub: 1.5 + ne: + val: 0.4 + active: True + lb: .001 + ub: 1.0 + fe: + + active: True + type: dlm + params: + m: + val: 2.5 + lb: 2.0 + ub: 5.0 + dim: 1 + nv: 256 + +data: + shotnum: 94475 + lineouts: + type: + range + start: 90 + end: 950 + skip: 20 + background: + type: + Fit + slice: 94477 + +other: + extraoptions: + load_ion_spec: False + load_ele_spec: True + fit_IAW: False + fit_EPWb: True + fit_EPWr: True + spectype: angular + PhysParams: + widIRF: + spect_stddev_ion: 0.015 + spect_stddev_ele: 0.1 + spect_FWHM_ele: 0.9 + ang_FWHM_ele: 1.0 + refit: False + refit_thresh: 5.0 + calc_sigmas: False + +mlflow: + experiment: tsadar-tests + run: arts1d-fwd + +machine: gpu \ No newline at end of file diff --git a/tests/configs/arts2d_test_defaults.yaml b/tests/configs/arts2d_test_defaults.yaml new file mode 100644 index 00000000..354a8455 --- /dev/null +++ b/tests/configs/arts2d_test_defaults.yaml @@ -0,0 +1,191 @@ +data: + background: + slice: 900 + type: pixel + bgscaleE: 0.0 + bgscaleI: 0.1 + bgshotmult: 1 + dpixel: 2 + ele_lam_shift: 0.0 + ele_t0: 0 + 
fit_rng: + blue_max: 510 + blue_min: 450 + forward_epw_end: 700 + forward_epw_start: 400 + forward_iaw_end: 528 + forward_iaw_start: 524 + iaw_cf_max: 526.6 + iaw_cf_min: 526.4 + iaw_max: 352 + iaw_min: 350 + red_max: 650 + red_min: 545 + ion_loss_scale: 1.0 + ion_t0_shift: 0 + launch_data_visualizer: false + lineouts: + end: 502 + skip: 1 + start: 500 + type: pixel + probe_beam: P9 + shotDay: false + shotnum: 101675 +dist_fit: + smooth: false + window: + len: 0.2 + type: bartlett +mlflow: + experiment: tsadar-test + run: arts2d +nn: + conv_filters: 32|16|16 + linear_widths: 32|16 + use: false +optimizer: + batch_size: 6 + grad_method: AD + grad_scalar: 1.0 + hessian: false + learning_rate: 0.0001 + method: rmsprop + moment_loss: false + loss_method: l2 + num_epochs: 1000 + num_mins: 1 + parameter_norm: true + refine_factor: 1 + save_state: true + save_state_freq: 20 + x_norm: false + y_norm: true +other: + BinWidth: 10 + CCDsize: + - 1024 + - 1024 + NumBinInRng: 0 + PhysParams: + background: + - 0 + - 0 + norm: 0 + TotalNumBin: 1023 + ang_res_unit: 10 + calc_sigmas: false + crop_window: 1 + expandedions: false + extraoptions: + fit_EPWb: true + fit_EPWr: true + fit_IAW: false + load_ele_spec: true + load_ion_spec: false + flatbg: 0 + gain: 1 + iawfilter: + - 1 + - 4 + - 24 + - 528 + iawoff: 0 + lam_res_unit: 5 + points_per_pixel: 1 + refit: true + refit_thresh: 0.25 +parameters: + general: + Te_gradient: + active: false + lb: 0.0 + num_grad_points: 1 + ub: 10.0 + val: 0.0 + Va: + active: false + angle: 0.0 + lb: -20.5 + ub: 20.5 + val: 0.0 + amp1: + active: true + lb: 0.01 + ub: 3.75 + val: 1.0 + amp2: + active: false + lb: 0.01 + ub: 3.75 + val: 1.0 + amp3: + active: false + lb: 0.01 + ub: 3.75 + val: 1.0 + lam: + active: false + lb: 523.0 + ub: 528.0 + val: 526.5 + ne_gradient: + active: false + lb: 0.0 + num_grad_points: 1 + ub: 15.0 + val: 0.0 + ud: + active: false + angle: 0.0 + lb: -100.0 + ub: 100.0 + val: 0.0 + electron: + Te: + active: false + lb: 0.01 + 
ub: 1.5 + val: 0.6 + fe: + active: false + dim: 2 + type: sphericalharmonic + nvx: 128 + params: + init_m: 2.7 + Nl: 1 + nvr: 64 + ne: + active: false + lb: 0.001 + ub: 1.0 + val: 0.2 + + ion-1: + A: + active: false + val: 40.0 + Ti: + active: false + lb: 0.001 + ub: 1.0 + val: 0.12 + Z: + active: false + lb: 0.5 + ub: 7.0 + val: 14.0 + fract: + active: false + val: 0.1 + +plotting: + data_cbar_l: 0 + data_cbar_u: data + ele_window_end: 625 + ele_window_start: 425 + ion_window_end: 528 + ion_window_start: 525 + n_sigmas: 3 + rolling_std_width: 5 \ No newline at end of file diff --git a/tests/configs/arts2d_test_inputs.yaml b/tests/configs/arts2d_test_inputs.yaml new file mode 100644 index 00000000..c29b4d95 --- /dev/null +++ b/tests/configs/arts2d_test_inputs.yaml @@ -0,0 +1,123 @@ +data: + background: + slice: 10002 + type: Fit + lineouts: + end: 950 + skip: 20 + start: 90 + type: range + shot_rot: 90.0 + shotnum: 10007 +machine: gpu +mlflow: + experiment: tsadar-test + run: arts2d +other: + PhysParams: + widIRF: + ang_FWHM_ele: 1.0 + spect_FWHM_ele: 0.9 + spect_stddev_ele: 0.1 + spect_stddev_ion: 0.015 + calc_sigmas: false + extraoptions: + fit_EPWb: true + fit_EPWr: true + fit_IAW: false + load_ele_spec: true + load_ion_spec: false + spectype: angular + refit: false + refit_thresh: 5.0 +parameters: + general: + Te_gradient: + active: false + lb: 0.0 + num_grad_points: 1 + ub: 10.0 + val: 0.0 + Va: + active: false + angle: 0.0 + lb: -40.5 + ub: 40.5 + val: 0.0 + amp1: + active: true + lb: 0.01 + ub: 3.75 + val: 1.0 + amp2: + active: true + lb: 0.01 + ub: 3.75 + val: 1.0 + amp3: + active: false + lb: 0.01 + ub: 3.75 + val: 1.0 + lam: + active: true + lb: 523.0 + ub: 528.0 + val: 526.5 + ne_gradient: + active: false + lb: 0.0 + num_grad_points: 1 + ub: 15.0 + val: 0.0 + ud: + active: false + angle: 0.0 + lb: -10.0 + ub: 30.0 + val: 5.0 + electron: + Te: + active: false + gradient_scalar: 10.0 + lb: 0.01 + ub: 2.0 + val: 1.1 + fe: + active: false + dim: 2 + type: 
sphericalharmonic + nvx: 256 + params: + init_m: 2.7 + init_f10: 1.0e-2 + init_f11: 1.0e-2 + Nl: 1 + nvr: 32 + + ne: + active: true + gradient_scalar: 10.0 + lb: 0.03 + ub: 2.0 + val: 0.44 + + ion-1: + A: + active: false + val: 1.0 + Ti: + active: false + lb: 0.01 + same: false + ub: 1.0 + val: 0.03 + Z: + active: false + lb: 0.5 + ub: 1.0 + val: 1.0 + fract: + active: false + val: 1.0 + \ No newline at end of file diff --git a/tests/configs/epw_defaults.yaml b/tests/configs/epw_defaults.yaml index e64e8a58..df87646b 100644 --- a/tests/configs/epw_defaults.yaml +++ b/tests/configs/epw_defaults.yaml @@ -1,7 +1,5 @@ parameters: - species1: - type: - electron: + electron: Te: val: .6 active: False @@ -12,30 +10,17 @@ parameters: active: False lb: 0.001 ub: 1.0 - m: - val: 3.0 - active: False - lb: 2.0 - ub: 5.0 fe: - val: [ ] - active: False - length: 3999 - type: - DLM: - lb: -100. - ub: -0.5 - fe_decrease_strict: False - symmetric: False + type: "DLM" dim: 1 - v_res: 0.1 - temp_asym: 1.0 - m_theta: 0.0 - m_asym: 1. + nv: 64 + params: + m: + val: 2.0 + ub: 5.0 + lb: 2.0 - species2: - type: - ion: + ion-1: Ti: val: 0.12 active: False @@ -54,8 +39,6 @@ parameters: active: False general: - type: - general: amp1: val: 1. active: False @@ -199,5 +182,5 @@ dist_fit: type: hamming # one of [hamming, hann, bartlett] mlflow: - experiment: inverse-thomson-scattering - run: base + experiment: tsadar-test + run: 1d diff --git a/tests/configs/epw_inputs.yaml b/tests/configs/epw_inputs.yaml index 9c29dc61..497f30f1 100644 --- a/tests/configs/epw_inputs.yaml +++ b/tests/configs/epw_inputs.yaml @@ -1,8 +1,5 @@ parameters: - species1: - type: - electron: - active: False + electron: Te: val: .6 active: False @@ -13,31 +10,17 @@ parameters: active: False lb: 0.001 ub: 1.0 - m: - val: 2.0 - active: False - lb: 2.0 - ub: 4.9 fe: - val: [ ] - active: False - length: 3999 - type: - DLM: - lb: -100. 
- ub: -0.5 - fe_decrease_strict: False - symmetric: False + type: "DLM" dim: 1 - v_res: 0.1 - temp_asym: 1.0 - m_theta: 0.0 - m_asym: 1. + nv: 64 + params: + m: + val: 2.0 + ub: 5.0 + lb: 2.0 - species2: - type: - ion: - active: False + ion-1: Ti: val: 0.2 active: True @@ -57,9 +40,6 @@ parameters: active: False general: - type: - general: - active: False amp1: val: 1. active: False diff --git a/tests/configs/time_test_defaults.yaml b/tests/configs/time_test_defaults.yaml index 37671407..990f3e7d 100644 --- a/tests/configs/time_test_defaults.yaml +++ b/tests/configs/time_test_defaults.yaml @@ -1,7 +1,5 @@ parameters: - species1: - type: - electron: + electron: Te: val: .6 active: False @@ -12,31 +10,14 @@ parameters: active: False lb: 0.001 ub: 1.0 - m: - val: 3.0 - active: False - lb: 2.0 - ub: 5.0 - matte: False fe: - val: [ ] active: False - length: 3999 - type: - DLM: - lb: -100. - ub: -0.5 - fe_decrease_strict: False - symmetric: False + type: "DLM" dim: 1 - v_res: 0.1 - temp_asym: 1.0 - m_theta: 0.0 - m_asym: 1. + nv: 64 + - species2: - type: - ion: + ion-1: Ti: val: 0.12 active: False @@ -56,8 +37,6 @@ parameters: active: False general: - type: - general: amp1: val: 1. 
active: False @@ -112,9 +91,13 @@ other: fit_EPWb: True fit_EPWr: True absolute_timing: false + spectype: 1d PhysParams: background: [0, 0] norm: 0 + widIRF: + spect_stddev_ele: 1.3 + spect_stddev_ion: 0.015 iawoff: 0 iawfilter: [1, 4, 24, 528] CCDsize: [1024, 1024] @@ -203,5 +186,5 @@ dist_fit: type: hamming # one of [hamming, hann, bartlett] mlflow: - experiment: inverse-thomson-scattering - run: base \ No newline at end of file + experiment: tsadar-tests + run: time_test \ No newline at end of file diff --git a/tests/configs/time_test_inputs.yaml b/tests/configs/time_test_inputs.yaml index da98408e..8c7e8fe0 100644 --- a/tests/configs/time_test_inputs.yaml +++ b/tests/configs/time_test_inputs.yaml @@ -1,8 +1,5 @@ parameters: - species1: - type: - electron: - active: False + electron: Te: val: .5 active: True @@ -13,25 +10,17 @@ parameters: active: True lb: .001 ub: 1. - m: - val: 3.0 - active: True - lb: 2. - ub: 5. fe: - val: [ ] - active: False - length: 3999 - type: - DLM: - lb: -100. - ub: -0.5 + active: True + type: dlm + params: + m: + val: 3.0 + lb: 2.0 + ub: 5.0 dim: 1 - v_res: 0.1 - species2: - type: - ion: - active: False + nv: 64 + ion-1: Ti: val: .2 active: False @@ -49,9 +38,6 @@ parameters: val: 1.0 active: False general: - type: - general: - active: False amp1: val: 1. 
active: True @@ -101,5 +87,5 @@ parameters: other: refit: False refit_thresh: 5.0 - calc_sigmas: True + calc_sigmas: False diff --git a/tsadar/data_handleing/__init__.py b/tests/test_form_factor/__init__.py similarity index 100% rename from tsadar/data_handleing/__init__.py rename to tests/test_form_factor/__init__.py diff --git a/tests/test_epw.py b/tests/test_form_factor/test_epw.py similarity index 67% rename from tests/test_epw.py rename to tests/test_form_factor/test_epw.py index b750c88f..40316f4b 100644 --- a/tests/test_epw.py +++ b/tests/test_form_factor/test_epw.py @@ -10,8 +10,10 @@ config.update("jax_enable_x64", True) from scipy.signal import find_peaks -from tsadar.model.physics.form_factor import FormFactor -from tsadar.distribution_functions.gen_num_dist_func import DistFunc +from tsadar.core.physics.form_factor import FormFactor +from tsadar.core.modules import ThomsonParams + +# from tsadar.distribution_functions.gen_num_dist_func import DistFunc def test_epw(): @@ -31,36 +33,29 @@ def test_epw(): # Test #1: Bohm-Gross test, calculate a spectrum and compare the resonance to the Bohm gross dispersion relation npts = 2048 - num_dist_func = DistFunc(config["parameters"]["species1"]) - vcur, fecur = num_dist_func(config["parameters"]["species1"]["m"]["val"]) + # num_dist_func = DistFunc(config["parameters"]["electron"]) + # vcur, fecur = num_dist_func(config["parameters"]["electron"]["m"]["val"]) + ts_params = ThomsonParams(config["parameters"], num_params=1, batch=False) electron_form_factor = FormFactor( [400, 700], npts=npts, - fe_dim=num_dist_func.dim, - vax=vcur, + lam_shift=config["data"]["ele_lam_shift"], + scattering_angles={"sa": np.array([60])}, + num_grad_points=config["parameters"]["general"]["ne_gradient"]["num_grad_points"], + ud_ang=None, + va_ang=None, ) sa = np.array([60]) - params = { - "general": { - "Va": config["parameters"]["general"]["Va"]["val"], - "ud": config["parameters"]["general"]["ud"]["val"], - } - } - - ThryE, lamAxisE = 
jit(electron_form_factor)( - params, - jnp.array(config["parameters"]["species1"]["ne"]["val"] * 1e20).reshape(1, 1), - jnp.array(config["parameters"]["species1"]["Te"]["val"]).reshape(1, 1), - config["parameters"]["species2"]["A"]["val"], - config["parameters"]["species2"]["Z"]["val"], - config["parameters"]["species2"]["Ti"]["val"], - config["parameters"]["species2"]["fract"]["val"], - sa, - (fecur, vcur), - config["parameters"]["general"]["lam"]["val"], - ) + # params = { + # "general": { + # "Va": config["parameters"]["general"]["Va"]["val"], + # "ud": config["parameters"]["general"]["ud"]["val"], + # } + # } + physical_params = ts_params() + ThryE, lamAxisE = jit(electron_form_factor)(physical_params) ThryE = np.squeeze(ThryE) test = deepcopy(np.asarray(ThryE)) peaks, peak_props = find_peaks(test, height=(0.01, 0.5), prominence=0.05) diff --git a/tests/test_iaw.py b/tests/test_form_factor/test_iaw.py similarity index 63% rename from tests/test_iaw.py rename to tests/test_form_factor/test_iaw.py index cb543241..5f0ed38c 100644 --- a/tests/test_iaw.py +++ b/tests/test_form_factor/test_iaw.py @@ -9,8 +9,8 @@ config.update("jax_enable_x64", True) from numpy.testing import assert_allclose from scipy.signal import find_peaks -from tsadar.model.physics.form_factor import FormFactor -from tsadar.distribution_functions.gen_num_dist_func import DistFunc +from tsadar.core.physics.form_factor import FormFactor +from tsadar.core.modules import ThomsonParams def test_iaw(): @@ -36,26 +36,29 @@ def test_iaw(): re = 2.8179e-13 # classical electron radius cm Esq = Me * C**2 * re # sq of the electron charge keV cm - num_dist_func = DistFunc(config["parameters"]["species1"]) - vcur, fecur = num_dist_func(config["parameters"]["species1"]["m"]["val"]) + # num_dist_func = DistFunc(config["parameters"]["electron"]) + # vcur, fecur = num_dist_func(config["parameters"]["electron"]["m"]["val"]) ion_form_factor = FormFactor( [525, 528], npts=8192, - fe_dim=num_dist_func.dim, - vax=vcur, 
+ lam_shift=0.0, + scattering_angles={"sa": np.array([60])}, + num_grad_points=config["parameters"]["general"]["ne_gradient"]["num_grad_points"], + ud_ang=None, + va_ang=None, ) # xie = np.linspace(-7, 7, 1024) # ion_form_factor = FormFactor([525, 528], npts=8192) - sa = np.array([60]) - params = { - "general": { - "Va": config["parameters"]["general"]["Va"]["val"], - "ud": config["parameters"]["general"]["ud"]["val"], - } - } + # sa = np.array([60]) + # params = { + # "general": { + # "Va": config["parameters"]["general"]["Va"]["val"], + # "ud": config["parameters"]["general"]["ud"]["val"], + # } + # } # num_dist_func = get_num_dist_func({"DLM": []}, xie) # fecur = num_dist_func(2.0) # lam = 526.5 @@ -66,20 +69,21 @@ def test_iaw(): # cur_Te = 0.5 # ThryI, lamAxisI = jit(ion_form_factor)(inps, cur_ne, cur_Te, sa, (fecur, xie), lam) - ThryI, lamAxisI = jit(ion_form_factor)( - params, - jnp.array(config["parameters"]["species1"]["ne"]["val"] * 1e20).reshape(1, 1), - jnp.array(config["parameters"]["species1"]["Te"]["val"]).reshape(1, 1), - config["parameters"]["species2"]["A"]["val"], - config["parameters"]["species2"]["Z"]["val"], - config["parameters"]["species2"]["Ti"]["val"], - config["parameters"]["species2"]["fract"]["val"], - sa, - (fecur, vcur), - config["parameters"]["general"]["lam"]["val"], - ) + ts_params = ThomsonParams(config["parameters"], num_params=1, batch=False) + physical_params = ts_params() + ThryI, lamAxisI = jit(ion_form_factor)(physical_params) + # params, + # jnp.array(config["parameters"]["electron"]["ne"]["val"] * 1e20).reshape(1, 1), + # jnp.array(config["parameters"]["electron"]["Te"]["val"]).reshape(1, 1), + # config["parameters"]["ion-1"]["A"]["val"], + # config["parameters"]["ion-1"]["Z"]["val"], + # config["parameters"]["ion-1"]["Ti"]["val"], + # config["parameters"]["ion-1"]["fract"]["val"], + # sa, + # (fecur, vcur), + # config["parameters"]["general"]["lam"]["val"], + # ) - ThryI = jnp.real(ThryI) ThryI = jnp.mean(ThryI, axis=0) 
ThryI = np.squeeze(ThryI) diff --git a/tests/test_forward/ThryE-1d.npy b/tests/test_forward/ThryE-1d.npy new file mode 100644 index 00000000..61605934 Binary files /dev/null and b/tests/test_forward/ThryE-1d.npy differ diff --git a/tests/test_forward/ThryE-arts1d.npy b/tests/test_forward/ThryE-arts1d.npy new file mode 100644 index 00000000..eecd82d0 Binary files /dev/null and b/tests/test_forward/ThryE-arts1d.npy differ diff --git a/tests/test_forward/ThryE-arts2d.npy b/tests/test_forward/ThryE-arts2d.npy new file mode 100644 index 00000000..b36b78db Binary files /dev/null and b/tests/test_forward/ThryE-arts2d.npy differ diff --git a/tsadar/data_handleing/calibrations/__init__.py b/tests/test_forward/__init__.py similarity index 100% rename from tsadar/data_handleing/calibrations/__init__.py rename to tests/test_forward/__init__.py diff --git a/tests/test_forward/test_1d.py b/tests/test_forward/test_1d.py new file mode 100644 index 00000000..72faef65 --- /dev/null +++ b/tests/test_forward/test_1d.py @@ -0,0 +1,88 @@ +from jax import config + +config.update("jax_enable_x64", True) + + +import numpy as np +import matplotlib.pyplot as plt +import yaml, os, mlflow, tempfile +from flatten_dict import flatten, unflatten + +from tsadar.utils import misc +from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic +from tsadar.core.modules import ThomsonParams +from tsadar.utils.data_handling.calibration import get_scattering_angles + + +def test_1d_forward_pass(): + """ + Runs a forward pass with the Thomson scattering diagnostic and ThomsonParams classes. Saves the results to mlflow. 
+ + + Args: + config: Dictionary - Configuration dictionary created from the input deck + + Returns: + Ion data, electron data, and plots are saved to mlflow + + """ + + mlflow.set_experiment("tsadar-tests") + with mlflow.start_run(run_name="test_1d_fwd"): + with open("tests/configs/1d-defaults.yaml", "r") as fi: + defaults = yaml.safe_load(fi) + + with open("tests/configs/1d-inputs.yaml", "r") as fi: + inputs = yaml.safe_load(fi) + + defaults = flatten(defaults) + defaults.update(flatten(inputs)) + config = unflatten(defaults) + + # get scattering angles and weights + config["other"]["lamrangE"] = [ + config["data"]["fit_rng"]["forward_epw_start"], + config["data"]["fit_rng"]["forward_epw_end"], + ] + config["other"]["lamrangI"] = [ + config["data"]["fit_rng"]["forward_iaw_start"], + config["data"]["fit_rng"]["forward_iaw_end"], + ] + config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) + sas = get_scattering_angles(config) + + dummy_batch = { + "i_data": np.array([1]), + "e_data": np.array([1]), + "noise_e": np.array([0]), + "noise_i": np.array([0]), + "e_amps": np.array([1]), + "i_amps": np.array([1]), + } + + ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_params = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True) + ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch) + + # np.save("tests/test_forward/ThryE-1d.npy", ThryE) + + ground_truth = np.load("tests/test_forward/ThryE-1d.npy") + + misc.log_mlflow(config) + with tempfile.TemporaryDirectory() as td: + fig, ax = plt.subplots(1, 1, figsize=(9, 4), tight_layout=True) + ax.plot(np.squeeze(lamAxisE), np.squeeze(ThryE), label="Model") + ax.plot(np.squeeze(lamAxisE), np.squeeze(ground_truth), label="Ground Truth") + ax.grid() + ax.legend() + ax.set_xlabel("Wavelength (nm)") + ax.set_ylabel("Intensity (arb. 
units)") + ax.set_title("Electron Spectrum") + fig.savefig(os.path.join(td, "ThryE.png"), bbox_inches="tight") + mlflow.log_artifacts(td) + + np.testing.assert_allclose(ThryE, ground_truth, rtol=1e-4) + + +if __name__ == "__main__": + test_1d_forward_pass() diff --git a/tests/test_forward/test_angular_1d.py b/tests/test_forward/test_angular_1d.py new file mode 100644 index 00000000..a9ab9370 --- /dev/null +++ b/tests/test_forward/test_angular_1d.py @@ -0,0 +1,94 @@ +from jax import config + +config.update("jax_enable_x64", True) + +import yaml, mlflow, os, tempfile +import numpy as np +import matplotlib.pyplot as plt +from flatten_dict import flatten, unflatten + +from tsadar.utils import misc + +from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic +from tsadar.core.modules import ThomsonParams +from tsadar.utils.data_handling.calibration import get_scattering_angles, get_calibrations + + +def test_arts1d_forward_pass(): + """ + Runs a forward pass with the Thomson scattering diagnostic and ThomsonParams classes. Saves the results to mlflow. 
+ + + Args: + config: Dictionary - Configuration dictionary created from the input deck + + Returns: + Ion data, electron data, and plots are saved to mlflow + + """ + + mlflow.set_experiment("tsadar-tests") + with mlflow.start_run(run_name="test_arts1d_fwd") as run: + with open("tests/configs/arts1d_test_defaults.yaml", "r") as fi: + defaults = yaml.safe_load(fi) + + with open("tests/configs/arts1d_test_inputs.yaml", "r") as fi: + inputs = yaml.safe_load(fi) + + defaults = flatten(defaults) + defaults.update(flatten(inputs)) + config = unflatten(defaults) + + # get scattering angles and weights + config["other"]["lamrangE"] = [ + config["data"]["fit_rng"]["forward_epw_start"], + config["data"]["fit_rng"]["forward_epw_end"], + ] + config["other"]["lamrangI"] = [ + config["data"]["fit_rng"]["forward_iaw_start"], + config["data"]["fit_rng"]["forward_iaw_end"], + ] + config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) + sas = get_scattering_angles(config) + + [axisxE, _, _, _, _, _] = get_calibrations( + 104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"] + ) # shot number hardcoded to get calibration + config["other"]["extraoptions"]["spectype"] = "angular_full" + + sas["angAxis"] = axisxE + + dummy_batch = { + "i_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), + "e_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), + "noise_e": np.array([0]), + "noise_i": np.array([0]), + "e_amps": np.array([1]), + "i_amps": np.array([1]), + } + + ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_params = ThomsonParams(config["parameters"], num_params=1, batch=False, activate=True) + ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch) + # np.save("tests/test_forward/ThryE-arts1d.npy", ThryE) + + ground_truth = np.load("tests/test_forward/ThryE-arts1d.npy") + + misc.log_mlflow(config) + fig, ax = 
plt.subplots(1, 2, figsize=(12, 6), tight_layout=True) + c = ax[0].contourf(ThryE.T) + ax[0].set_title("Forward Model") + fig.colorbar(c) + + c = ax[1].contourf(ground_truth.T) + ax[1].set_title("Ground Truth") + fig.colorbar(c) + with tempfile.TemporaryDirectory() as td: + fig.savefig(os.path.join(td, "ThryE.png"), bbox_inches="tight") + mlflow.log_artifacts(td) + + np.testing.assert_allclose(ThryE, ground_truth, rtol=1e-4) + + +if __name__ == "__main__": + test_arts1d_forward_pass() diff --git a/tests/test_forward/test_angular_2d.py b/tests/test_forward/test_angular_2d.py new file mode 100644 index 00000000..47bda087 --- /dev/null +++ b/tests/test_forward/test_angular_2d.py @@ -0,0 +1,107 @@ +import pytest +from jax import config, devices + +config.update("jax_enable_x64", True) + +import yaml, mlflow, os, tempfile, time +import numpy as np +from equinox import filter_jit +import matplotlib.pyplot as plt +from flatten_dict import flatten, unflatten + +from tsadar.utils import misc +from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic +from tsadar.core.modules import ThomsonParams +from tsadar.utils.data_handling.calibration import get_scattering_angles, get_calibrations + + +def test_arts2d_forward_pass(): + """ + Runs a forward pass with the Thomson scattering diagnostic and ThomsonParams classes. Saves the results to mlflow. 
+ + + Args: + config: Dictionary - Configuration dictionary created from the input deck + + Returns: + Ion data, electron data, and plots are saved to mlflow + + """ + + if not any(["gpu" == device.platform for device in devices()]): + pytest.skip("Takes too long without a GPU") + + mlflow.set_experiment("tsadar-tests") + with mlflow.start_run(run_name="test_arts2d_fwd") as run: + with tempfile.TemporaryDirectory() as td: + + t0 = time.time() + with open("tests/configs/arts2d_test_defaults.yaml", "r") as fi: + defaults = yaml.safe_load(fi) + + with open("tests/configs/arts2d_test_inputs.yaml", "r") as fi: + inputs = yaml.safe_load(fi) + + defaults = flatten(defaults) + defaults.update(flatten(inputs)) + config = unflatten(defaults) + + with open(os.path.join(td, "config.yaml"), "w") as fi: + yaml.dump(config, fi) + + # get scattering angles and weights + config["other"]["lamrangE"] = [ + config["data"]["fit_rng"]["forward_epw_start"], + config["data"]["fit_rng"]["forward_epw_end"], + ] + config["other"]["lamrangI"] = [ + config["data"]["fit_rng"]["forward_iaw_start"], + config["data"]["fit_rng"]["forward_iaw_end"], + ] + config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) + sas = get_scattering_angles(config) + + [axisxE, _, _, _, _, _] = get_calibrations( + 104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"] + ) # shot number hardcoded to get calibration + config["other"]["extraoptions"]["spectype"] = "angular_full" + + sas["angAxis"] = axisxE + + dummy_batch = { + "i_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), + "e_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), + "noise_e": np.array([0]), + "noise_i": np.array([0]), + "e_amps": np.array([1]), + "i_amps": np.array([1]), + } + + ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + ts_params = ThomsonParams(config["parameters"], num_params=1, 
batch=False) + ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params, dummy_batch) + # np.save("tests/test_forward/ThryE-arts2d.npy", ThryE) + + ground_truth = np.load("tests/test_forward/ThryE-arts2d.npy") + + misc.log_mlflow(config) + fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True) + c = ax[0].contourf(ThryE.T) + ax[0].set_title("Forward Model") + fig.colorbar(c) + + c = ax[1].contourf(ground_truth.T) + ax[1].set_title("Ground Truth") + fig.colorbar(c) + + fig.savefig(os.path.join(td, "ThryE.png"), bbox_inches="tight") + mlflow.log_artifacts(td) + + mlflow.log_metric("runtime-sec", time.time() - t0) + + # np.testing.assert_allclose(ThryE, ground_truth, rtol=1e-4) + misc.export_run(run.info.run_id) + + +if __name__ == "__main__": + test_arts2d_forward_pass() diff --git a/tsadar/distribution_functions/__init__.py b/tests/test_inverse/__init__.py similarity index 100% rename from tsadar/distribution_functions/__init__.py rename to tests/test_inverse/__init__.py diff --git a/tests/test_time_data.py b/tests/test_inverse/test_1d.py similarity index 87% rename from tests/test_time_data.py rename to tests/test_inverse/test_1d.py index d149ba05..e982df64 100644 --- a/tests/test_time_data.py +++ b/tests/test_inverse/test_1d.py @@ -10,8 +10,8 @@ # config.update("jax_disable_jit", True) # config.update("jax_check_tracer_leaks", True) -from tsadar import fitter -from tsadar.misc import utils +from tsadar.inverse import fitter +from tsadar.utils import misc @pytest.mark.parametrize("nn", [False]) @@ -37,7 +37,7 @@ def test_data(nn): mlflow.set_experiment(config["mlflow"]["experiment"]) with mlflow.start_run() as run: - utils.log_params(config) + misc.log_mlflow(config) config["num_cores"] = int(mp.cpu_count()) t0 = time.time() @@ -52,9 +52,9 @@ def test_data(nn): assert_allclose(fit_results["amp1_general"][0], 0.734, rtol=1e-1) # 0.9257 assert_allclose(fit_results["amp2_general"][0], 0.519, rtol=1e-1) # 0.6727 assert_allclose(fit_results["lam_general"][0], 
524.016, rtol=5e-3) # 524.2455 - assert_allclose(fit_results["Te_species1"][0], 0.5994, rtol=1e-1) # 0.67585 - assert_allclose(fit_results["ne_species1"][0], 0.2256, rtol=5e-2) # 0.21792 - assert_allclose(fit_results["m_species1"][0], 2.987, rtol=15e-2) # 3.3673 + assert_allclose(fit_results["Te_electron"][0], 0.5994, rtol=1e-1) # 0.67585 + assert_allclose(fit_results["ne_electron"][0], 0.2256, rtol=5e-2) # 0.21792 + assert_allclose(fit_results["m_electron"][0], 2.987, rtol=15e-2) # 3.3673 if __name__ == "__main__": diff --git a/tests/test_inverse/test_1d_random.py b/tests/test_inverse/test_1d_random.py new file mode 100644 index 00000000..2675f9a1 --- /dev/null +++ b/tests/test_inverse/test_1d_random.py @@ -0,0 +1,177 @@ +from jax import config + +config.update("jax_enable_x64", True) + +from jax import numpy as jnp +from jax.flatten_util import ravel_pytree +from scipy.optimize import minimize +import equinox as eqx +import numpy as np +import matplotlib.pyplot as plt +import yaml, os, mlflow, tempfile, optax, tqdm +from flatten_dict import flatten, unflatten + +from tsadar.utils import misc +from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic +from tsadar.core.modules import ThomsonParams, get_filter_spec +from tsadar.utils.data_handling.calibration import get_scattering_angles + + +def _perturb_params_(rng, params): + """ + Perturbs the parameters for the forward pass. 
+ + Args: + params: Dictionary - Parameters to be perturbed + + Returns: + Dictionary - Perturbed parameters + + """ + + params["electron"]["fe"]["params"]["m"]["val"] = float(rng.uniform(2.0, 3.5)) + params["electron"]["Te"]["val"] = float(rng.uniform(0.5, 1.5)) + params["electron"]["ne"]["val"] = float(rng.uniform(0.1, 0.7)) + + params["general"]["amp1"]["val"] = float(rng.uniform(0.5, 2.5)) + params["general"]["amp2"]["val"] = float(rng.uniform(0.5, 2.5)) + params["general"]["lam"]["val"] = float(rng.uniform(523, 527)) + + # for key in params["general"].keys(): + # params[key]["val"] *= rng.uniform(0.75, 1.25) + + # for key in params["ion-1"].keys(): + # params[key]["val"] *= rng.uniform(0.75, 1.25) + + return params + + +def _floatify(_params_, prefix="gt"): + flattened_params = flatten(_params_) + new_params = {} + for key in flattened_params.keys(): + new_params[(prefix,) + key] = float(flattened_params[key][0]) + return unflatten(new_params) + + +def test_1d_inverse(): + """ + Runs a forward pass with the Thomson scattering diagnostic and ThomsonParams classes. Saves the results to mlflow. 
+ + + Args: + config: Dictionary - Configuration dictionary created from the input deck + + Returns: + Ion data, electron data, and plots are saved to mlflow + + """ + + mlflow.set_experiment("tsadar-tests") + with mlflow.start_run(run_name="test_1d_inverse"): + with open("tests/configs/1d-defaults.yaml", "r") as fi: + defaults = yaml.safe_load(fi) + + with open("tests/configs/1d-inputs.yaml", "r") as fi: + inputs = yaml.safe_load(fi) + + defaults = flatten(defaults) + defaults.update(flatten(inputs)) + config = unflatten(defaults) + + # get scattering angles and weights + config["other"]["lamrangE"] = [ + config["data"]["fit_rng"]["forward_epw_start"], + config["data"]["fit_rng"]["forward_epw_end"], + ] + config["other"]["lamrangI"] = [ + config["data"]["fit_rng"]["forward_iaw_start"], + config["data"]["fit_rng"]["forward_iaw_end"], + ] + config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) + sas = get_scattering_angles(config) + + dummy_batch = { + "i_data": np.array([1]), + "e_data": np.array([1]), + "noise_e": np.array([0]), + "noise_i": np.array([0]), + "e_amps": np.array([1]), + "i_amps": np.array([1]), + } + rng = np.random.default_rng() + ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + config["parameters"] = _perturb_params_(rng, config["parameters"]) + misc.log_mlflow(config) + ts_params_gt = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True) + ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params_gt, dummy_batch) + ground_truth = {"ThryE": ThryE, "lamAxisE": lamAxisE, "ThryI": ThryI, "lamAxisI": lamAxisI} + + loss = 1 + while np.nan_to_num(loss, nan=1) > 1e-3: + ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + config["parameters"] = _perturb_params_(rng, config["parameters"]) + ts_params_fit = ThomsonParams(config["parameters"], num_params=1, batch=True, activate=True) + diff_params, static_params = eqx.partition( + ts_params_fit, 
filter_spec=get_filter_spec(cfg_params=config["parameters"], ts_params=ts_params_fit) + ) + + def loss_fn(_diff_params): + _all_params = eqx.combine(_diff_params, static_params) + ThryE, ThryI, _, _ = ts_diag(_all_params, dummy_batch) + return jnp.mean(jnp.square(ThryE - ground_truth["ThryE"])) + + use_optax = False + if use_optax: + opt = optax.adam(0.004) + + opt_state = opt.init(diff_params) + for i in (pbar := tqdm.tqdm(range(1000))): + loss, grad_loss = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn))(diff_params) + updates, opt_state = opt.update(grad_loss, opt_state) + diff_params = eqx.apply_updates(diff_params, updates) + pbar.set_description(f"Loss: {loss:.4f}") + + else: + flattened_diff_params, unravel = ravel_pytree(diff_params) + + def scipy_vg_fn(diff_params_flat): + diff_params_pytree = unravel(diff_params_flat) + loss, grads = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn))(diff_params_pytree) + flattened_grads, _ = ravel_pytree(grads) + + return float(loss), np.array(flattened_grads) + + res = minimize(scipy_vg_fn, flattened_diff_params, method="L-BFGS-B", jac=True, options={"disp": True}) + + diff_params = unravel(res["x"]) + loss = res["fun"] + + gt_params = _floatify(ts_params_gt.get_unnormed_params(), prefix="gt") + learned_params = _floatify(eqx.combine(diff_params, static_params).get_unnormed_params(), prefix="learned") + misc.log_mlflow({"loss": loss} | learned_params | gt_params, which="metrics") + ThryE, _, _, _ = ts_diag(eqx.combine(diff_params, static_params), dummy_batch) + + with tempfile.TemporaryDirectory() as td: + fig, ax = plt.subplots(1, 1, figsize=(9, 4), tight_layout=True) + ax.plot(np.squeeze(lamAxisE), np.squeeze(ThryE), label="Model") + ax.plot(np.squeeze(lamAxisE), np.squeeze(ground_truth["ThryE"]), label="Ground Truth") + ax.grid() + ax.legend() + ax.set_xlabel("Wavelength (nm)") + ax.set_ylabel("Intensity (arb. 
units)") + ax.set_title("Electron Spectrum") + fig.savefig(os.path.join(td, "ThryE.png"), bbox_inches="tight") + mlflow.log_artifacts(td) + + # np.testing.assert_allclose(ThryE, ground_truth["ThryE"], atol=0, rtol=0.2) + + gt_flat = flatten(gt_params) + learned_flat = flatten(learned_params) + + for key in gt_flat.keys(): + np.testing.assert_allclose(gt_flat[key], learned_flat[("learned",) + key[1:]], atol=0, rtol=0.1) + + +if __name__ == "__main__": + test_1d_inverse() diff --git a/tests/test_inverse/test_arts1d_random.py b/tests/test_inverse/test_arts1d_random.py new file mode 100644 index 00000000..e64f1b4d --- /dev/null +++ b/tests/test_inverse/test_arts1d_random.py @@ -0,0 +1,217 @@ +import pytest +from jax import config, block_until_ready, devices + +config.update("jax_enable_x64", True) + +from jax import numpy as jnp +from jax.flatten_util import ravel_pytree +from scipy.optimize import minimize +import equinox as eqx +import numpy as np +import matplotlib.pyplot as plt +import yaml, os, mlflow, tempfile, optax, tqdm, time +from flatten_dict import flatten, unflatten + +from tsadar.utils import misc +from tsadar.core.thomson_diagnostic import ThomsonScatteringDiagnostic +from tsadar.core.modules import ThomsonParams, get_filter_spec +from tsadar.utils.data_handling.calibration import get_scattering_angles, get_calibrations + + +def _perturb_params_(rng, params, arbitrary_distribution: bool = False): + """ + Perturbs the parameters for the forward pass. 
+ + Args: + params: Dictionary - Parameters to be perturbed + + Returns: + Dictionary - Perturbed parameters + + """ + + params["electron"]["Te"]["val"] = float(rng.uniform(0.5, 1.5)) + params["electron"]["ne"]["val"] = float(rng.uniform(0.1, 0.7)) + + params["general"]["amp1"]["val"] = float(rng.uniform(0.5, 2.5)) + params["general"]["amp2"]["val"] = float(rng.uniform(0.5, 2.5)) + params["general"]["lam"]["val"] = float(rng.uniform(523, 527)) + + if arbitrary_distribution: + params["electron"]["fe"]["params"]["init_m"] = float(rng.uniform(2.0, 3.5)) + params["electron"]["fe"]["type"] = "arbitrary" + else: + params["electron"]["fe"]["params"]["m"]["val"] = float(rng.uniform(2.0, 3.5)) + params["electron"]["fe"]["type"] = "dlm" + + # for key in params["general"].keys(): + # params[key]["val"] *= rng.uniform(0.75, 1.25) + + # for key in params["ion-1"].keys(): + # params[key]["val"] *= rng.uniform(0.75, 1.25) + + return params + + +@pytest.mark.parametrize( + "arbitrary_distribution", + [ + False, + ], +) +def test_arts1d_inverse(arbitrary_distribution: bool): + """ + Runs a forward pass with the Thomson scattering diagnostic and ThomsonParams classes. Saves the results to mlflow. 
+ + + Args: + config: Dictionary - Configuration dictionary created from the input deck + + Returns: + Ion data, electron data, and plots are saved to mlflow + + """ + if not any(["gpu" == device.platform for device in devices()]): + pytest.skip("Takes too long without a GPU") + + _t0 = time.time() + mlflow.set_experiment("tsadar-tests") + with mlflow.start_run(run_name="test_arts1d_inverse") as run: + with open("tests/configs/arts1d_test_defaults.yaml", "r") as fi: + defaults = yaml.safe_load(fi) + + with open("tests/configs/arts1d_test_inputs.yaml", "r") as fi: + inputs = yaml.safe_load(fi) + + defaults = flatten(defaults) + defaults.update(flatten(inputs)) + config = unflatten(defaults) + + with tempfile.TemporaryDirectory() as td: + with open(os.path.join(td, "config.yaml"), "w") as fi: + yaml.dump(config, fi) + # get scattering angles and weights + config["other"]["lamrangE"] = [ + config["data"]["fit_rng"]["forward_epw_start"], + config["data"]["fit_rng"]["forward_epw_end"], + ] + config["other"]["lamrangI"] = [ + config["data"]["fit_rng"]["forward_iaw_start"], + config["data"]["fit_rng"]["forward_iaw_end"], + ] + config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) + sas = get_scattering_angles(config) + + sas["angAxis"], _, _, _, _, _ = get_calibrations( + 104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"] + ) # shot number hardcoded to get calibration + config["other"]["extraoptions"]["spectype"] = "angular_full" + + dummy_batch = { + "i_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), + "e_data": np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])), + "noise_e": np.array([0]), + "noise_i": np.array([0]), + "e_amps": np.array([1]), + "i_amps": np.array([1]), + } + rng = np.random.default_rng() + ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + config["parameters"] = _perturb_params_(rng, 
config["parameters"], arbitrary_distribution=False) + misc.log_mlflow(config) + ts_params_gt = ThomsonParams(config["parameters"], num_params=1, batch=False, activate=True) + + ThryE, ThryI, lamAxisE, lamAxisI = ts_diag(ts_params_gt, dummy_batch) + ground_truth = {"ThryE": ThryE, "lamAxisE": lamAxisE, "ThryI": ThryI, "lamAxisI": lamAxisI} + + def loss_fn(_diff_params, _static_params): + _all_params = eqx.combine(_diff_params, _static_params) + ThryE, ThryI, _, _ = ts_diag(_all_params, dummy_batch) + return jnp.mean(jnp.square(ThryE - ground_truth["ThryE"])) + + t0 = time.time() + jit_vg = eqx.filter_jit(eqx.filter_value_and_grad(loss_fn)) + diff_params, static_params = perturb_and_split_params(arbitrary_distribution, config, rng) + temp_out = block_until_ready(jit_vg(diff_params, static_params)) + mlflow.log_metric(f"first run time", time.time() - t0) + + loss = 1 + while np.nan_to_num(loss, nan=1) > 5e-2: + # ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + diff_params, static_params = perturb_and_split_params(arbitrary_distribution, config, rng) + + use_optax = False + if use_optax: + opt = optax.adam(0.001) + opt_state = opt.init(diff_params) + for i in (pbar := tqdm.tqdm(range(25))): + t0 = time.time() + loss, grad_loss = jit_vg(diff_params, static_params) + mlflow.log_metric(f"iteration time", time.time() - t0, step=i) + + updates, opt_state = opt.update(grad_loss, opt_state) + diff_params = eqx.apply_updates(diff_params, updates) + pbar.set_description(f"Loss: {loss:.4f}") + + else: + flattened_diff_params, unravel = ravel_pytree(diff_params) + + def scipy_vg_fn(diff_params_flat): + diff_params_pytree = unravel(diff_params_flat) + loss, grads = jit_vg(diff_params_pytree, static_params) + flattened_grads, _ = ravel_pytree(grads) + + return float(loss), np.array(flattened_grads) + + res = minimize( + scipy_vg_fn, flattened_diff_params, method="L-BFGS-B", jac=True, options={"disp": True} + ) + + diff_params = unravel(res["x"]) + loss = 
res["fun"] + + params_to_log = { + "gt": ts_params_gt.get_unnormed_params(), + "learned": eqx.combine(diff_params, static_params).get_unnormed_params(), + } + + misc.log_mlflow({"loss": loss} | params_to_log, which="metrics") + ThryE, _, _, _ = ts_diag(eqx.combine(diff_params, static_params), dummy_batch) + + fig, ax = plt.subplots(1, 3, figsize=(11, 4), tight_layout=True) + c = ax[0].contourf(np.squeeze(ThryE).T) + fig.colorbar(c) + c = ax[1].contourf(np.squeeze(ground_truth["ThryE"]).T) + fig.colorbar(c) + c = ax[2].contourf((np.squeeze(ground_truth["ThryE"]) - np.squeeze(ThryE)).T) + fig.colorbar(c) + + ax[0].set_title("Model") + ax[1].set_title("Ground Truth") + ax[2].set_title("Model - Ground Truth") + fig.savefig(os.path.join(td, "ThryE.png"), bbox_inches="tight") + + mlflow.log_metric("runtime-sec", time.time() - _t0) + mlflow.log_artifacts(td) + + misc.export_run(run.info.run_id) + # np.testing.assert_allclose(ThryE, ground_truth["ThryE"], atol=0.01, rtol=0) + + +def perturb_and_split_params(arbitrary_distribution, config, rng): + config["parameters"] = _perturb_params_(rng, config["parameters"], arbitrary_distribution=arbitrary_distribution) + ts_params_fit = ThomsonParams( + config["parameters"], + num_params=1, + batch=False, + activate=True, + ) + diff_params, static_params = eqx.partition( + ts_params_fit, filter_spec=get_filter_spec(cfg_params=config["parameters"], ts_params=ts_params_fit) + ) + + return diff_params, static_params + + +if __name__ == "__main__": + test_arts1d_inverse(arbitrary_distribution=False) diff --git a/tsadar/__init__.py b/tsadar/__init__.py index cfc87ad3..4293bbe7 100644 --- a/tsadar/__init__.py +++ b/tsadar/__init__.py @@ -1 +1,3 @@ from .runner import run_for_app +from .core import ThomsonScatteringDiagnostic, ThomsonParams +from .utils.data_handling.calibration import get_scattering_angles diff --git a/tsadar/core/__init__.py b/tsadar/core/__init__.py new file mode 100644 index 00000000..f145920a --- /dev/null +++ 
b/tsadar/core/__init__.py @@ -0,0 +1,2 @@ +from .thomson_diagnostic import ThomsonScatteringDiagnostic +from .modules import ThomsonParams diff --git a/tsadar/core/modules.py b/tsadar/core/modules.py new file mode 100644 index 00000000..4ba15809 --- /dev/null +++ b/tsadar/core/modules.py @@ -0,0 +1,591 @@ +from typing import List, Dict, Union, Callable +from collections import defaultdict + +from jax import Array, numpy as jnp, tree_util as jtu, vmap +from jax.nn import sigmoid, relu +from jax.random import PRNGKey +from jax.scipy.special import gamma, sph_harm +import equinox as eqx + + +class DistributionFunction1D(eqx.Module): + vx: Array + + def __init__(self, dist_cfg: Dict): + super().__init__() + vmax = 6.0 + dv = 2 * vmax / dist_cfg["nv"] + self.vx = jnp.linspace(-vmax + dv / 2, vmax - dv / 2, dist_cfg["nv"]) + + def __call__(self): + raise NotImplementedError + + +class Arbitrary1DNN(DistributionFunction1D): + f_nn: eqx.Module + + def __init__(self, dist_cfg): + super().__init__(dist_cfg) + # self.learn_log = dist_cfg["params"]["learn_log"] + self.f_nn = eqx.nn.MLP(1, 1, 32, 3, final_activation=relu, key=PRNGKey(0)) + + def get_unnormed_params(self): + return {"f": self()} + + def __call__(self): + # if self.learn_log: + # # bound values between 1e-15 and 10 + # f_nn = -16 * sigmoid(self.f_nn) + 1 + # f_nn = jnp.power(10.0, self.f_nn) + # else: + # f_nn = sigmoid(f_nn) * 10 + fval = eqx.filter_vmap(self.f_nn)(self.vx[:, None]) + fval = jnp.squeeze(fval) + # if self.learn_log: + # fval = jnp.power(10.0, -fval) + + return fval / jnp.sum(fval) / (self.vx[1] - self.vx[0]) + + +class Arbitrary1D(DistributionFunction1D): + fval: Array + learn_log: bool + + def __init__(self, dist_cfg): + super().__init__(dist_cfg) + self.learn_log = dist_cfg["params"]["learn_log"] + self.fval = self.init_dlm(dist_cfg["params"]["init_m"]) + + def init_dlm(self, m): + vth_x = jnp.sqrt(2.0) + alpha = jnp.sqrt(3.0 * gamma(3.0 / m) / 2.0 / gamma(5.0 / m)) + cst = m / (4.0 * jnp.pi * 
alpha**3.0 * gamma(3.0 / m)) + fdlm = cst / vth_x**3.0 * jnp.exp(-(jnp.abs(self.vx / alpha / vth_x) ** m)) + fdlm = fdlm / jnp.sum(fdlm) / (self.vx[1] - self.vx[0]) + + if self.learn_log: + # # logit function + # fdlm = 1 / 16 * jnp.log(fdlm / (1 - fdlm)) - 1 + fdlm = -jnp.log10(fdlm) + # else: + # fdlm = 0.1 * jnp.log(fdlm / (1 - fdlm)) + + return jnp.sqrt(fdlm) + + def get_unnormed_params(self): + return {"f": self()} + + def __call__(self): + # if self.learn_log: + # # bound values between 1e-15 and 10 + # fval = -16 * sigmoid(self.fval) + 1 + # fval = jnp.power(10.0, self.fval) + # else: + # fval = sigmoid(fval) * 10 + fval = self.fval**2.0 + if self.learn_log: + fval = jnp.power(10.0, -fval) + + return fval / jnp.sum(fval) / (self.vx[1] - self.vx[0]) + + +class DLM1D(DistributionFunction1D): + normed_m: Array + m_scale: float + m_shift: float + act_fun: Callable + + def __init__(self, dist_cfg, activate=False): + super().__init__(dist_cfg) + self.m_scale = 3.0 # dist_cfg["params"]["m"]["ub"] - dist_cfg["params"]["m"]["lb"] + self.m_shift = 2.0 # dist_cfg["params"]["m"]["lb"] + + if activate: + inv_act_fun = lambda x: x # jnp.log(1e-6 + x / (1 - x)) + self.act_fun = sigmoid + else: + inv_act_fun = lambda x: x + self.act_fun = lambda x: x + + self.normed_m = inv_act_fun((dist_cfg["params"]["m"]["val"] - self.m_shift) / self.m_scale) + + def get_unnormed_params(self): + return {"m": self.act_fun(self.normed_m) * self.m_scale + self.m_shift} + + def __call__(self): + unnormed_m = self.act_fun(self.normed_m) * self.m_scale + self.m_shift + vth_x = jnp.sqrt(2.0) + alpha = jnp.sqrt(3.0 * gamma(3.0 / unnormed_m) / 2.0 / gamma(5.0 / unnormed_m)) + cst = unnormed_m / (4.0 * jnp.pi * alpha**3.0 * gamma(3.0 / unnormed_m)) + fdlm = cst / vth_x**3.0 * jnp.exp(-(jnp.abs(self.vx / alpha / vth_x) ** unnormed_m)) + + return fdlm / jnp.sum(fdlm) / (self.vx[1] - self.vx[0]) + + +class DistributionFunction2D(eqx.Module): + vx: Array + + def __init__(self, dist_cfg): + 
super().__init__() + vmax = 6.0 + dvx = 2 * vmax / dist_cfg["nvx"] + self.vx = jnp.linspace(-vmax + dvx / 2, vmax - dvx / 2, dist_cfg["nvx"]) + + def __call__(self, *args, **kwds): + return super().__call__(*args, **kwds) + + +class SphericalHarmonics(DistributionFunction2D): + vr: Array + th: Array + sph_harm: Callable + vr_vxvy: Array + Nl: int + flm: Dict[str, Dict[str, Array]] + m_scale: float + m_shift: float + act_fun: Callable + normed_m: Array + + def __init__(self, dist_cfg): + super().__init__(dist_cfg) + + vmax = 6.0 * 1.05 * jnp.sqrt(2.0) + dvr = vmax / dist_cfg["params"]["nvr"] + self.vr = jnp.linspace(dvr / 2, vmax - dvr / 2, dist_cfg["params"]["nvr"]) + + vx, vy = jnp.meshgrid(self.vx, self.vx) + self.th = jnp.arctan2(vy, vx) + self.vr_vxvy = jnp.sqrt(vx**2 + vy**2) + self.Nl = dist_cfg["params"]["Nl"] + + self.sph_harm = vmap(sph_harm, in_axes=(None, None, None, 0, None)) + self.flm = defaultdict(dict) + for i in range(self.Nl + 1): + self.flm[i] = {j: jnp.zeros(dist_cfg["params"]["nvr"]) for j in range(i + 1)} + + init_m = dist_cfg["params"]["init_m"] + self.flm[0][0] = self.get_f00(init_m) + if dist_cfg["params"]["init_f10"]: + self.flm[1][0] = ( + dist_cfg["params"]["init_f10"] * jnp.gradient(jnp.gradient(self.flm[0][0])) * self.vr**2.0 * dvr + ) + if dist_cfg["params"]["init_f11"]: + self.flm[1][1] = ( + dist_cfg["params"]["init_f11"] * jnp.gradient(jnp.gradient(self.flm[0][0])) * self.vr**2.0 * dvr + ) + + self.m_scale = 3.0 # dist_cfg["params"]["m"]["ub"] - dist_cfg["params"]["m"]["lb"] + self.m_shift = 2.0 # dist_cfg["params"]["m"]["lb"] + inv_act_fun = lambda x: x # jnp.log(1e-6 + x / (1 - x)) + self.act_fun = sigmoid + self.normed_m = inv_act_fun((init_m - self.m_shift) / self.m_scale) + + def get_unnormed_params(self): + return {"flm": self.flm} + + def get_f00(self, m): + vth_x = 1.0 + alpha = jnp.sqrt(3.0 * gamma(3.0 / m) / 2.0 / gamma(5 / m)) + cst = m / (4 * jnp.pi * alpha**3.0 * gamma(3 / m)) + f00 = cst / vth_x**3.0 * 
jnp.exp(-((self.vr / alpha / vth_x) ** m)) + f00 /= jnp.sum(f00 * 4 * jnp.pi * self.vr**2.0) * (self.vr[1] - self.vr[0]) + + return f00 + + def __call__(self): + # fvxvy = jnp.zeros(jnp.shape(self.vr_vxvy)) + unnormed_m = self.act_fun(self.normed_m) * self.m_scale + self.m_shift + f00 = self.get_f00(unnormed_m) + fvxvy = jnp.interp(self.vr_vxvy, self.vr, f00, right=1e-16) + + for i in range(1, self.Nl + 1): + for j in range(i + 1): + _flmvxvy = jnp.interp(self.vr_vxvy, self.vr, self.flm[i][j], right=1e-16) + _sph_harm = self.sph_harm( + jnp.array([j]), jnp.array([i]), 0.0, self.th.reshape(-1, order="C"), 2 + ).reshape(self.vr_vxvy.shape, order="C") + fvxvy += _flmvxvy * jnp.real(_sph_harm) + + return fvxvy + + +class ElectronParams(eqx.Module): + normed_Te: Array + normed_ne: Array + Te_scale: float + Te_shift: float + ne_scale: float + ne_shift: float + distribution_functions: Union[ + List[DistributionFunction1D], List[DistributionFunction2D], DistributionFunction1D, DistributionFunction2D + ] + batch: bool + act_fun: Callable + + def __init__(self, cfg, batch_size, batch=True, activate=False): + super().__init__() + + self.Te_scale = cfg["Te"]["ub"] - cfg["Te"]["lb"] + self.Te_shift = cfg["Te"]["lb"] + self.ne_scale = cfg["ne"]["ub"] - cfg["ne"]["lb"] + self.ne_shift = cfg["ne"]["lb"] + self.batch = batch + + if activate: + self.act_fun = sigmoid + inv_act_fun = lambda x: x # jnp.log(1e-6 + x / (1 - x)) + else: + self.act_fun = lambda x: x + inv_act_fun = lambda x: x + + if batch: + self.normed_Te = inv_act_fun(jnp.full(batch_size, (cfg["Te"]["val"] - self.Te_shift) / self.Te_scale)) + self.normed_ne = inv_act_fun(jnp.full(batch_size, (cfg["ne"]["val"] - self.ne_shift) / self.ne_scale)) + else: + self.normed_Te = inv_act_fun((cfg["Te"]["val"] - self.Te_shift) / self.Te_scale) + self.normed_ne = inv_act_fun((cfg["ne"]["val"] - self.ne_shift) / self.ne_scale) + + self.distribution_functions = self.init_dists(cfg["fe"], batch_size, batch, activate) + + def 
init_dists(self, dist_cfg, batch_size, batch, activate): + if dist_cfg["dim"] == 1: + if dist_cfg["type"].casefold() == "dlm": + if batch: + distribution_functions = [DLM1D(dist_cfg, activate) for _ in range(batch_size)] + else: + distribution_functions = DLM1D(dist_cfg, activate) + + elif dist_cfg["type"].casefold() == "mx": + if batch: + distribution_functions = [ + lambda vx: jnp.exp(-(vx**2 / 2)) / jnp.sum(jnp.exp(-(vx**2 / 2))) / (vx[1] - vx[0]) + ] + else: + distribution_functions = ( + lambda vx: jnp.exp(-(vx**2 / 2)) / jnp.sum(jnp.exp(-(vx**2 / 2))) / (vx[1] - vx[0]) + ) + elif dist_cfg["type"].casefold() == "arbitrary": + if batch: + distribution_functions = [Arbitrary1D(dist_cfg) for _ in range(batch_size)] + else: + distribution_functions = Arbitrary1D(dist_cfg) + + elif dist_cfg["type"].casefold() == "arbitrary-nn": + if batch: + distribution_functions = [Arbitrary1DNN(dist_cfg) for _ in range(batch_size)] + else: + distribution_functions = Arbitrary1DNN(dist_cfg) + + else: + raise NotImplementedError(f"Unknown 1D distribution type: {dist_cfg['type']}") + elif dist_cfg["dim"] == 2: + if "sph" in dist_cfg["type"].casefold(): + if batch: + raise NotImplementedError( + "Batch mode not implemented for 2D distributions as a precautionary measure against memory issues" + ) + distribution_functions = [SphericalHarmonics(dist_cfg) for _ in range(batch_size)] + else: + distribution_functions = SphericalHarmonics(dist_cfg) + else: + raise NotImplementedError(f"Unknown 2D distribution type: {dist_cfg['type']}") + else: + raise NotImplementedError(f"Not implemented distribution dimension: {dist_cfg['dim']}") + + return distribution_functions + + def get_unnormed_params(self): + unnormed_fe_params = defaultdict(list) + if isinstance(self.distribution_functions, list): + for fe in self.distribution_functions: + for k, v in fe.get_unnormed_params().items(): + unnormed_fe_params[k].append(v) + unnormed_fe_params = {k: jnp.array(v) for k, v in 
unnormed_fe_params.items()} + else: + unnormed_fe_params = self.distribution_functions.get_unnormed_params() + + return { + "Te": self.act_fun(self.normed_Te) * self.Te_scale + self.Te_shift, + "ne": self.act_fun(self.normed_ne) * self.ne_scale + self.ne_shift, + } | unnormed_fe_params + + def __call__(self): + physical_params = { + "Te": self.act_fun(self.normed_Te) * self.Te_scale + self.Te_shift, + "ne": self.act_fun(self.normed_ne) * self.ne_scale + self.ne_shift, + } + if self.batch: + dist_params = { + "fe": jnp.concatenate([df()[None, :] for df in self.distribution_functions]), + "v": jnp.concatenate([df.vx[None, :] for df in self.distribution_functions]), + } + else: + dist_params = { + "fe": self.distribution_functions(), + "v": self.distribution_functions.vx, + } + + return physical_params | dist_params + + +class IonParams(eqx.Module): + normed_Ti: Array + normed_Z: Array + # normed_A: Array + fract: Array + Ti_scale: float + Ti_shift: float + Z_scale: float + Z_shift: float + # A_scale: float + # A_shift: float + A: int + act_fun: Callable + + def __init__(self, cfg, batch_size, batch=True, activate=False): + super().__init__() + self.Ti_scale = cfg["Ti"]["ub"] - cfg["Ti"]["lb"] + self.Ti_shift = cfg["Ti"]["lb"] + self.Z_scale = cfg["Z"]["ub"] - cfg["Z"]["lb"] + self.Z_shift = cfg["Z"]["lb"] + + # self.A_scale = cfg["A"]["ub"] - cfg["A"]["lb"] + # self.A_shift = cfg["A"]["lb"] + + if activate: + inv_act_fun = lambda x: x # jnp.log(1e-6 + x / (1 - x)) + self.act_fun = sigmoid + else: + inv_act_fun = lambda x: x + self.act_fun = lambda x: x + + if batch: + self.normed_Ti = inv_act_fun(jnp.full(batch_size, (cfg["Ti"]["val"] - self.Ti_shift) / self.Ti_scale)) + self.normed_Z = inv_act_fun(jnp.full(batch_size, (cfg["Z"]["val"] - self.Z_shift) / self.Z_scale)) + self.A = jnp.full(batch_size, cfg["A"]["val"]) + self.fract = inv_act_fun(jnp.full(batch_size, cfg["fract"]["val"])) + else: + self.normed_Ti = inv_act_fun((cfg["Ti"]["val"] - self.Ti_shift) / 
self.Ti_scale) + self.normed_Z = inv_act_fun((cfg["Z"]["val"] - self.Z_shift) / self.Z_scale) + self.A = cfg["A"]["val"] + self.fract = float(inv_act_fun(cfg["fract"]["val"])) + + def get_unnormed_params(self): + return self() + + def __call__(self): + + return { + "A": self.A, + "fract": self.act_fun(self.fract), + "Ti": self.act_fun(self.normed_Ti) * self.Ti_scale + self.Ti_shift, + "Z": self.act_fun(self.normed_Z) * self.Z_scale + self.Z_shift, + } + + +class GeneralParams(eqx.Module): + normed_lam: Array + normed_amp1: Array + normed_amp2: Array + normed_amp3: Array + normed_ne_gradient: Array + normed_Te_gradient: Array + normed_ud: Array + normed_vA: Array + lam_scale: float + lam_shift: float + amp1_scale: float + amp1_shift: float + amp2_scale: float + amp2_shift: float + amp3_scale: float + amp3_shift: float + ne_gradient_scale: float + ne_gradient_shift: float + Te_gradient_scale: float + Te_gradient_shift: float + ud_scale: float + ud_shift: float + vA_scale: float + vA_shift: float + act_fun: Callable + + def __init__(self, cfg, batch_size: int, batch=True, activate=False): + super().__init__() + self.lam_scale = cfg["lam"]["ub"] - cfg["lam"]["lb"] + self.lam_shift = cfg["lam"]["lb"] + self.amp1_scale = cfg["amp1"]["ub"] - cfg["amp1"]["lb"] + self.amp1_shift = cfg["amp1"]["lb"] + self.amp2_scale = cfg["amp2"]["ub"] - cfg["amp2"]["lb"] + self.amp2_shift = cfg["amp2"]["lb"] + self.amp3_scale = cfg["amp3"]["ub"] - cfg["amp3"]["lb"] + self.amp3_shift = cfg["amp3"]["lb"] + self.ne_gradient_scale = cfg["ne_gradient"]["ub"] - cfg["ne_gradient"]["lb"] + self.ne_gradient_shift = cfg["ne_gradient"]["lb"] + self.Te_gradient_scale = cfg["Te_gradient"]["ub"] - cfg["Te_gradient"]["lb"] + self.Te_gradient_shift = cfg["Te_gradient"]["lb"] + self.ud_scale = cfg["ud"]["ub"] - cfg["ud"]["lb"] + self.ud_shift = cfg["ud"]["lb"] + self.vA_scale = cfg["Va"]["ub"] - cfg["Va"]["lb"] + self.vA_shift = cfg["Va"]["lb"] + + if activate: + inv_act_fun = lambda x: x # jnp.log(1e-6 + 
x / (1 - x)) + self.act_fun = sigmoid + else: + inv_act_fun = lambda x: x + self.act_fun = lambda x: x + + if batch: + self.normed_amp1 = inv_act_fun( + jnp.full(batch_size, (cfg["amp1"]["val"] - self.amp1_shift) / self.amp1_scale) + ) + self.normed_amp2 = inv_act_fun( + jnp.full(batch_size, (cfg["amp2"]["val"] - self.amp2_shift) / self.amp2_scale) + ) + self.normed_amp3 = inv_act_fun( + jnp.full(batch_size, (cfg["amp3"]["val"] - self.amp3_shift) / self.amp3_scale) + ) + self.normed_ne_gradient = inv_act_fun( + jnp.full(batch_size, (cfg["ne_gradient"]["val"] - self.ne_gradient_shift) / self.ne_gradient_scale) + ) + self.normed_Te_gradient = inv_act_fun( + jnp.full(batch_size, (cfg["Te_gradient"]["val"] - self.Te_gradient_shift) / self.Te_gradient_scale) + ) + self.normed_ud = inv_act_fun(jnp.full(batch_size, (cfg["ud"]["val"] - self.ud_shift) / self.ud_scale)) + self.normed_vA = inv_act_fun(jnp.full(batch_size, (cfg["Va"]["val"] - self.vA_shift) / self.vA_scale)) + self.normed_lam = inv_act_fun(jnp.full(batch_size, (cfg["lam"]["val"] - self.lam_shift) / self.lam_scale)) + else: + self.normed_amp1 = inv_act_fun((cfg["amp1"]["val"] - self.amp1_shift) / self.amp1_scale) + self.normed_amp2 = inv_act_fun((cfg["amp2"]["val"] - self.amp2_shift) / self.amp2_scale) + self.normed_amp3 = inv_act_fun((cfg["amp3"]["val"] - self.amp3_shift) / self.amp3_scale) + self.normed_ne_gradient = inv_act_fun( + (cfg["ne_gradient"]["val"] - self.ne_gradient_shift) / self.ne_gradient_scale + ) + self.normed_Te_gradient = inv_act_fun( + (cfg["Te_gradient"]["val"] - self.Te_gradient_shift) / self.Te_gradient_scale + ) + self.normed_ud = inv_act_fun((cfg["ud"]["val"] - self.ud_shift) / self.ud_scale) + self.normed_vA = inv_act_fun((cfg["Va"]["val"] - self.vA_shift) / self.vA_scale) + self.normed_lam = inv_act_fun((cfg["lam"]["val"] - self.lam_shift) / self.lam_scale) + + def get_unnormed_params(self): + return self() + + def __call__(self): + unnormed_lam = self.act_fun(self.normed_lam) * 
self.lam_scale + self.lam_shift + unnormed_amp1 = self.act_fun(self.normed_amp1) * self.amp1_scale + self.amp1_shift + unnormed_amp2 = self.act_fun(self.normed_amp2) * self.amp2_scale + self.amp2_shift + unnormed_amp3 = self.act_fun(self.normed_amp3) * self.amp3_scale + self.amp3_shift + unnormed_ne_gradient = self.act_fun(self.normed_ne_gradient) * self.ne_gradient_scale + self.ne_gradient_shift + unnormed_Te_gradient = self.act_fun(self.normed_Te_gradient) * self.Te_gradient_scale + self.Te_gradient_shift + unnormed_ud = self.act_fun(self.normed_ud) * self.ud_scale + self.ud_shift + unnormed_vA = self.act_fun(self.normed_vA) * self.vA_scale + self.vA_shift + + return { + "lam": unnormed_lam, + "amp1": unnormed_amp1, + "amp2": unnormed_amp2, + "amp3": unnormed_amp3, + "ne_gradient": unnormed_ne_gradient, + "Te_gradient": unnormed_Te_gradient, + "ud": unnormed_ud, + "Va": unnormed_vA, + } + + +class ThomsonParams(eqx.Module): + electron: ElectronParams + ions: List[IonParams] + general: GeneralParams + + def __init__(self, param_cfg, num_params: int, batch=True, activate=False): + super().__init__() + self.electron = ElectronParams(param_cfg["electron"], num_params, batch, activate) + self.ions = [] + for species in param_cfg.keys(): + if "ion" in species: + self.ions.append(IonParams(param_cfg[species], num_params, batch, activate)) + + assert len(self.ions) > 0, "No ion species found in input deck" + self.general = GeneralParams(param_cfg["general"], num_params, batch, activate) + + def get_unnormed_params(self): + return { + "electron": self.electron.get_unnormed_params(), + "general": self.general.get_unnormed_params(), + } | {f"ion-{i+1}": ion.get_unnormed_params() for i, ion in enumerate(self.ions)} + + def __call__(self): + return {"electron": self.electron(), "general": self.general()} | { + f"ion-{i+1}": ion() for i, ion in enumerate(self.ions) + } + + +def get_filter_spec(cfg_params: Dict, ts_params: ThomsonParams) -> Dict: + # Step 2 + filter_spec = 
jtu.tree_map(lambda _: False, ts_params) + for species, params in cfg_params.items(): + for key, val in params.items(): + if val["active"]: + if key == "fe": + filter_spec = get_distribution_filter_spec(filter_spec, dist_type=val["type"]) + else: + nkey = f"normed_{key}" + filter_spec = eqx.tree_at( + lambda tree: getattr(getattr(tree, species), nkey), + filter_spec, + replace=True, + ) + + return filter_spec + + +def get_distribution_filter_spec(filter_spec: Dict, dist_type: str) -> Dict: + if dist_type.casefold() == "dlm": + if isinstance(filter_spec.electron.distribution_functions, list): + num_dists = len(filter_spec.electron.distribution_functions) + for i in range(num_dists): + filter_spec = eqx.tree_at( + lambda tree: tree.electron.distribution_functions[i].normed_m, filter_spec, replace=True + ) + else: + filter_spec = eqx.tree_at( + lambda tree: tree.electron.distribution_functions.normed_m, filter_spec, replace=True + ) + + elif dist_type.casefold() == "mx": + raise Warning("No trainable parameters for Maxwellian distribution") + + elif dist_type.casefold() == "arbitrary": + if isinstance(filter_spec.electron.distribution_functions, list): + num_dists = len(filter_spec.electron.distribution_functions) + for i in range(num_dists): + filter_spec = eqx.tree_at( + lambda tree: tree.electron.distribution_functions[i].fval, filter_spec, replace=True + ) + else: + filter_spec = eqx.tree_at(lambda tree: tree.electron.distribution_functions.fval, filter_spec, replace=True) + elif dist_type.casefold() == "arbitrary-nn": + df = filter_spec.electron.distribution_functions + if isinstance(df, list): + for i in range(len(df)): + filter_spec = update_distribution_layers(filter_spec, df=df[i]) + else: + filter_spec = update_distribution_layers(filter_spec, df=df) + + else: + raise NotImplementedError(f"Untrainable distribution type: {dist_type}") + + return filter_spec + + +def update_distribution_layers(filter_spec, df): + print(df.f_nn.layers) + for j in 
range(len(df.f_nn.layers)): + if df.f_nn.layers[j].weight: + filter_spec = eqx.tree_at(lambda tree: df.f_nn.layers[j].linear.weight, filter_spec, replace=True) + filter_spec = eqx.tree_at(lambda tree: df.f_nn.layers[j].linear.bias, filter_spec, replace=True) + + return filter_spec diff --git a/tsadar/misc/__init__.py b/tsadar/core/physics/__init__.py similarity index 100% rename from tsadar/misc/__init__.py rename to tsadar/core/physics/__init__.py diff --git a/tsadar/model/physics/form_factor.py b/tsadar/core/physics/form_factor.py similarity index 68% rename from tsadar/model/physics/form_factor.py rename to tsadar/core/physics/form_factor.py index 75fc361d..e2e3a3da 100644 --- a/tsadar/model/physics/form_factor.py +++ b/tsadar/core/physics/form_factor.py @@ -1,4 +1,4 @@ -from jax import numpy as jnp, vmap, device_put +from jax import numpy as jnp, vmap, device_put, device_count, devices from jax.experimental import mesh_utils from jax.sharding import Mesh, PartitionSpec as P, NamedSharding @@ -11,11 +11,10 @@ from jax.lax import scan, map as jmap from jax import checkpoint -from tsadar.model.physics import ratintn -from tsadar.data_handleing import lam_parse -from tsadar.misc.vector_tools import vsub, vdot, vdiv +from . import ratintn +from ...utils.vector_tools import vsub, vdot, vdiv -BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "aux") +BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "external") def zprimeMaxw(xi): @@ -49,13 +48,13 @@ def zprimeMaxw(xi): class FormFactor: - def __init__(self, lamrang, npts, fe_dim, vax=None): + def __init__(self, lambda_range, npts, lam_shift, scattering_angles, num_grad_points, ud_ang, va_ang): """ Creates a FormFactor object holding all the static values to use for repeated calculations of the Thomson scattering structure factor or spectral density function. Args: - lamrang: list of the starting and ending wavelengths over which to calculate the spectrum. 
+ lambda_range: list of the starting and ending wavelengths over which to calculate the spectrum. npts: number of wavelength points to use in the calculation fe_dim: dimension of the electron velocity distribution function (EDF), should be 1 or 2 vax: (optional) velocity axis coordinates that the 2D EDF is defined on @@ -68,26 +67,44 @@ def __init__(self, lamrang, npts, fe_dim, vax=None): self.C = 2.99792458e10 self.Me = 510.9896 / self.C**2 # electron mass keV/C^2 self.Mp = self.Me * 1836.1 # proton mass keV/C^2 - self.lamrang = lamrang + # self.lambda_range = lambda_range self.npts = npts self.h = 0.01 minmax = 8.2 - h1 = 1024 # 1024 # 1024 + h1 = 1024 # 1024 + c = 2.99792458e10 + lamAxis = jnp.linspace(lambda_range[0], lambda_range[1], npts) + self.omgL_num = 2 * jnp.pi * 1e7 * c + omgs = 2e7 * jnp.pi * c / lamAxis # Scattered frequency axis(1 / sec) + self.omgs = omgs[None, ..., None] + self.xi1 = jnp.linspace(-minmax - jnp.sqrt(2.0) / h1, minmax + jnp.sqrt(2.0) / h1, h1) self.xi2 = jnp.array(jnp.arange(-minmax, minmax, self.h)) self.Zpi = jnp.array(zprimeMaxw(self.xi2)) - - if (vax is not None) and (fe_dim == 2): - self.coords = jnp.concatenate([np.copy(vax[0][..., None]), np.copy(vax[1][..., None])], axis=-1) - self.v = vax[0][0] + self.lam_shift = lam_shift + self.scattering_angles = scattering_angles + self.num_grad_points = num_grad_points self.vmap_calc_chi_vals = vmap(checkpoint(self.calc_chi_vals), in_axes=(None, None, 0, 0, 0), out_axes=0) + self.ud_angle, self.va_angle = ud_ang, va_ang # Create a Sharding object to distribute a value across devices: - mesh = Mesh(devices=mesh_utils.create_device_mesh(1), axis_names=("x")) - self.sharding = NamedSharding(mesh, P("x")) - - def __call__(self, params, cur_ne, cur_Te, A, Z, Ti, fract, sa, f_and_v, lam): + is_gpu_present = any(["gpu" == device.platform for device in devices()]) + self.calc_all_chi_vals = self._calc_all_chi_vals_ + + if is_gpu_present: + num_gpus = device_count(backend="gpu") + if num_gpus > 
1: + print( + f"If this is a 2D Angular calculation, it will be parallelized across {num_gpus} GPUs. Otherwise, only a single GPU is used" + ) + mesh = Mesh(devices=mesh_utils.create_device_mesh((device_count(backend="gpu"),)), axis_names=("x")) + self.sharding = NamedSharding(mesh, P("x")) + self.calc_all_chi_vals = self.parallel_calc_all_chi_vals + else: + self.calc_all_chi_vals = self._calc_all_chi_vals_ + + def __call__(self, params): """ Calculates the standard collisionless Thomson spectral density function S(k,omg) and is capable of handling multiple plasma conditions and scattering angles. Distribution functions can be arbitrary as calculations of the @@ -113,32 +130,43 @@ def __call__(self, params, cur_ne, cur_Te, A, Z, Ti, fract, sa, f_and_v, lam): wavelength points, number of angles] """ - Te, ne, Va, ud, fe = ( - cur_Te.squeeze(-1), - cur_ne.squeeze(-1), - params["general"]["Va"], - params["general"]["ud"], - f_and_v, # this is now a DistFunc object + ne = ( + 1.0e20 + * params["electron"]["ne"] + * jnp.linspace( + (1 - params["general"]["ne_gradient"] / 200), + (1 + params["general"]["ne_gradient"] / 200), + self.num_grad_points, + ) + ) + Te = params["electron"]["Te"] * jnp.linspace( + (1 - params["general"]["Te_gradient"] / 200), + (1 + params["general"]["Te_gradient"] / 200), + self.num_grad_points, ) + lam = params["general"]["lam"] + self.lam_shift + A = [params[species]["A"] for species in params.keys() if "ion" in species] + Z = [params[species]["Z"] for species in params.keys() if "ion" in species] + Ti = [params[species]["Ti"] for species in params.keys() if "ion" in species] + fract = [params[species]["fract"] for species in params.keys() if "ion" in species] + Va = params["general"]["Va"] * 1e6 # flow velocity in 1e6 cm/s + ud = params["general"]["ud"] * 1e6 # drift velocity in 1e6 cm/s + fe = params["electron"]["fe"] + vx = params["electron"]["v"] Mi = jnp.array(A) * self.Mp # ion mass re = 2.8179e-13 # classical electron radius cm Esq = 
self.Me * self.C**2 * re # sq of the electron charge keV cm constants = jnp.sqrt(4 * jnp.pi * Esq / self.Me) - sarad = sa * jnp.pi / 180 # scattering angle in radians + sarad = self.scattering_angles["sa"] * jnp.pi / 180 # scattering angle in radians sarad = jnp.reshape(sarad, [1, 1, -1]) - - Va = Va * 1e6 # flow velocity in 1e6 cm/s - ud = ud * 1e6 # drift velocity in 1e6 cm/s - - omgL, omgs, lamAxis, _ = lam_parse.lamParse(self.lamrang, lam, npts=self.npts) # , True) + omgL = self.omgL_num / lam # laser frequency Rad / s # calculate k and omega vectors omgpe = constants * jnp.sqrt(ne[..., jnp.newaxis, jnp.newaxis]) # plasma frequency Rad/cm - omgs = omgs[jnp.newaxis, ..., jnp.newaxis] - omg = omgs - omgL + omg = self.omgs - omgL - ks = jnp.sqrt(omgs**2 - omgpe**2) / self.C + ks = jnp.sqrt(self.omgs**2 - omgpe**2) / self.C kL = jnp.sqrt(omgL**2 - omgpe**2) / self.C k = jnp.sqrt(ks**2 + kL**2 - 2 * ks * kL * jnp.cos(sarad)) @@ -146,7 +174,6 @@ def __call__(self, params, cur_ne, cur_Te, A, Z, Ti, fract, sa, f_and_v, lam): omgdop = omg - kdotv # plasma parameters - # electrons vTe = jnp.sqrt(Te[..., jnp.newaxis, jnp.newaxis] / self.Me) # electron thermal velocity klde = (vTe / omgpe) * k @@ -161,36 +188,36 @@ def __call__(self, params, cur_ne, cur_Te, A, Z, Ti, fract, sa, f_and_v, lam): vTi = jnp.sqrt(jnp.array(Ti) / Mi) # ion thermal velocity kldi = (vTi / omgpi) * (k[..., jnp.newaxis]) + # ion susceptibilities # finding derivative of plasma dispersion function along xii array # proper handeling of multiple ion temperatures is not implemented xii = 1.0 / jnp.transpose((jnp.sqrt(2.0) * vTi), [1, 0, 2, 3]) * ((omgdop / k)[..., jnp.newaxis]) - num_species = len(fract) - num_ion_pts = jnp.shape(xii) - chiI = jnp.zeros(num_ion_pts) + + # num_ion_pts = jnp.shape(xii) + # chiI = jnp.zeros(num_ion_pts) ZpiR = jnp.interp(xii, self.xi2, self.Zpi[0, :], left=xii**-2, right=xii**-2) ZpiI = jnp.interp(xii, self.xi2, self.Zpi[1, :], left=0, right=0) - chiI = jnp.sum(-0.5 / 
(kldi**2) * (ZpiR + jnp.sqrt(-1 + 0j) * ZpiI), 3) + chiI = jnp.sum(-0.5 / (kldi**2) * (ZpiR + 1j * ZpiI), 3) # electron susceptibility # calculating normilized phase velcoity(xi's) for electrons xie = omgdop / (k * vTe) - ud / vTe - DF, x = fe - fe_vphi = jnp.exp(jnp.interp(xie, x, jnp.log(jnp.squeeze(DF)))) + fe_vphi = jnp.exp(jnp.interp(xie, vx, jnp.log(fe))) df = jnp.diff(fe_vphi, 1, 1) / jnp.diff(xie, 1, 1) - df = jnp.append(df, jnp.zeros((len(ne), 1, len(sa))), 1) + df = jnp.append(df, jnp.zeros((len(ne), 1, len(self.scattering_angles["sa"]))), 1) - chiEI = jnp.pi / (klde**2) * jnp.sqrt(-1 + 0j) * df + chiEI = jnp.pi / (klde**2) * 1j * df - ratmod = jnp.exp(jnp.interp(self.xi1, x, jnp.log(jnp.squeeze(DF)))) + ratmod = jnp.exp(jnp.interp(self.xi1, vx, jnp.log(fe))) ratdf = jnp.gradient(ratmod, self.xi1[1] - self.xi1[0]) - def this_ratintn(this_dx): - return jnp.real(ratintn.ratintn(ratdf, this_dx, self.xi1)) + chiERratprim = vmap(ratintn.ratintn, in_axes=(None, 0, None))( + ratdf, self.xi1[None, :] - self.xi2[:, None], self.xi1 + ) - chiERratprim = vmap(this_ratintn)(self.xi1[None, :] - self.xi2[:, None]) # if len(fe) == 2: chiERrat = jnp.reshape(jnp.interp(xie.flatten(), self.xi2, chiERratprim[:, 0]), xie.shape) # else: @@ -217,14 +244,14 @@ def this_ratintn(this_dx): PsOmg = (SKW_ion_omg + SKW_ele_omg) * (1 + 2 * omgdop / omgL) * re**2.0 * ne[:, None, None] # PsOmgE = (SKW_ele_omg) * (1 + 2 * omgdop / omgL) * re**2.0 * jnp.transpose(ne) # commented because unused - lams = 2 * jnp.pi * self.C / omgs + lams = 2 * jnp.pi * self.C / self.omgs PsLam = PsOmg * 2 * jnp.pi * self.C / lams**2 # PsLamE = PsOmgE * 2 * jnp.pi * C / lams**2 # commented because unused formfactor = PsLam return formfactor, lams - def rotate(self, df, angle, reshape: bool = False) -> jnp.ndarray: + def rotate(self, vx, df, angle, reshape: bool = False) -> jnp.ndarray: """ Rotate a 2D array by a given angle in radians @@ -240,16 +267,13 @@ def rotate(self, df, angle, reshape: bool = False) -> 
jnp.ndarray: cos_angle = jnp.cos(rad_angle) sin_angle = jnp.sin(rad_angle) rotation_matrix = jnp.array([[cos_angle, -sin_angle], [sin_angle, cos_angle]]) + _vx, _vy = jnp.meshgrid(vx, vx) + coords = jnp.stack((_vx.flatten(), _vy.flatten())) + rotated_coords = jnp.einsum("ij, ik->kj", rotation_matrix, coords) + xq = rotated_coords[:, 0] + yq = rotated_coords[:, 1] - rotated_mesh = vmap(vmap(jnp.dot, in_axes=(None, 0)), in_axes=(None, 1), out_axes=1)( - rotation_matrix, self.coords - ) - xq = rotated_mesh[..., 0].flatten() - yq = rotated_mesh[..., 1].flatten() - - return interp2d(xq, yq, self.v, self.v, df, extrap=True, method="cubic").reshape( - (self.v.size, self.v.size), order="F" - ) + return interp2d(xq, yq, vx, vx, df, extrap=True, method="cubic").reshape((vx.size, vx.size), order="F") def scan_calc_chi_vals(self, carry, xs): """ @@ -274,7 +298,7 @@ def scan_calc_chi_vals(self, carry, xs): fe_vphi, chiEI, chiERrat = self.calc_chi_vals(x, DF, xs) return (x, DF), (fe_vphi, chiEI, chiERrat) - def calc_chi_vals(self, x, DF, inputs): + def calc_chi_vals(self, vx, DF, inputs): """ Calculate the values of the susceptibility at a given point in the distribution function @@ -282,10 +306,10 @@ def calc_chi_vals(self, x, DF, inputs): carry: container for x: 1D array DF: 2D array - xs: container for - element: angle in radians - xie_mag_at: float - klde_mag_at: float + inputs: container for + element: angle in radians + xie_mag_at: float + klde_mag_at: float Returns: fe_vphi: float, value of the projected distribution function at the point xie @@ -294,29 +318,28 @@ def calc_chi_vals(self, x, DF, inputs): """ element, xie_mag_at, klde_mag_at = inputs - fe_2D_k = checkpoint(self.rotate)(DF, element * 180 / jnp.pi, reshape=False) - fe_1D_k = jnp.sum(fe_2D_k, axis=0) * (x[1] - x[0]) + dvx = vx[1] - vx[0] + fe_2D_k = checkpoint(self.rotate)(vx, DF, element * 180 / jnp.pi, reshape=False) + fe_1D_k = jnp.sum(fe_2D_k, axis=0) * dvx + df = jnp.gradient(fe_1D_k, dvx) # find the 
location of xie in axis array - loc = jnp.argmin(jnp.abs(x - xie_mag_at)) # add the value of fe to the fe container - fe_vphi = fe_1D_k[loc] - - # derivative of f along k - df = jnp.real(jnp.gradient(fe_1D_k, x[1] - x[0])) + fe_vphi = jnp.interp(xie_mag_at, vx, fe_1D_k) + dfe = jnp.interp(xie_mag_at, vx, df) # Chi is really chi evaluated at the points xie # so the imaginary part is - chiEI = jnp.pi / (klde_mag_at**2) * df[loc] + chiEI = jnp.pi / (klde_mag_at**2) * dfe # the real part is solved with rational integration # giving the value at a single point where the pole is located at xie_mag[ind] chiERrat = ( - -1.0 / (klde_mag_at**2) * jnp.real(ratintn.ratintn(df, x - xie_mag_at, x)) + -1.0 / (klde_mag_at**2) * ratintn.ratintn(df, vx - xie_mag_at, vx) ) # this may need to be downsampled for run time return fe_vphi, chiEI, chiERrat - def calc_all_chi_vals(self, x, DF, beta, xie_mag, klde_mag): + def _calc_all_chi_vals_(self, vx, DF, beta, xie_mag, klde_mag): """ Calculate the susceptibility values for all the desired points xie @@ -339,15 +362,15 @@ def calc_all_chi_vals(self, x, DF, beta, xie_mag, klde_mag): if calc_chi_vals == "scan": _, (fe_vphi, chiEI, chiERrat) = scan( - self.scan_calc_chi_vals, (x, jnp.squeeze(DF)), flattened_inputs, unroll=1 + self.scan_calc_chi_vals, (vx, jnp.squeeze(DF)), flattened_inputs, unroll=1 ) elif calc_chi_vals == "vmap": - fe_vphi, chiEI, chiERrat = self.vmap_calc_chi_vals(x, jnp.squeeze(DF), flattened_inputs) + fe_vphi, chiEI, chiERrat = self.vmap_calc_chi_vals(vx, jnp.squeeze(DF), flattened_inputs) elif calc_chi_vals == "batch_vmap": - batch_vmap_calc_chi_vals = partial(self.calc_chi_vals, x, jnp.squeeze(DF)) - fe_vphi, chiEI, chiERrat = jmap(batch_vmap_calc_chi_vals, xs=flattened_inputs, batch_size=8) + batch_vmap_calc_chi_vals = partial(self.calc_chi_vals, vx, jnp.squeeze(DF)) + fe_vphi, chiEI, chiERrat = jmap(batch_vmap_calc_chi_vals, xs=flattened_inputs, batch_size=128) else: raise NotImplementedError @@ -357,17 +380,25 @@ 
def calc_all_chi_vals(self, x, DF, beta, xie_mag, klde_mag): return fe_vphi, chiEI, chiERrat - def parallel_calc_all_chi_vals(self, x, DF, flattened_inputs): + def parallel_calc_all_chi_vals(self, x, DF, beta, xie_mag, klde_mag): - beta, xie_mag, klde_mag = flattened_inputs + f_beta = beta.reshape(-1) + f_xie_mag = xie_mag.reshape(-1) + f_klde_mag = klde_mag.reshape(-1) - flat_beta = device_put(beta, self.sharding) - flat_xie_mag = device_put(xie_mag, self.sharding) - flat_klde_mag = device_put(klde_mag, self.sharding) + flat_beta = device_put(f_beta, self.sharding) + flat_xie_mag = device_put(f_xie_mag, self.sharding) + flat_klde_mag = device_put(f_klde_mag, self.sharding) - return self.calc_all_chi_vals(x, DF, (flat_beta, flat_xie_mag, flat_klde_mag)) + fe_vphi, chiEI, chiERrat = self._calc_all_chi_vals_(x, DF, flat_beta, flat_xie_mag, flat_klde_mag) - def calc_in_2D(self, params, ud_ang, va_ang, cur_ne, cur_Te, A, Z, Ti, fract, sa, f_and_v, lam): + fe_vphi = fe_vphi.reshape(beta.shape) + chiEI = chiEI.reshape(beta.shape) + chiERrat = chiERrat.reshape(beta.shape) + + return fe_vphi, chiEI, chiERrat + + def calc_in_2D(self, params): """ Calculates the collisionless Thomson spectral density function S(k,omg) for a 2D numerical EDF, capable of handling multiple plasma conditions and scattering angles. 
Distribution functions can be arbitrary as @@ -396,37 +427,52 @@ def calc_in_2D(self, params, ud_ang, va_ang, cur_ne, cur_Te, A, Z, Ti, fract, sa wavelength points, number of angles] """ - Te, ne, Va, ud, fe = ( - cur_Te.squeeze(-1), - cur_ne.squeeze(-1), - params["general"]["Va"], - params["general"]["ud"], - f_and_v, # this is now a DistFunc object + ne = ( + 1.0e20 + * params["electron"]["ne"] + * jnp.linspace( + (1 - params["general"]["ne_gradient"] / 200), + (1 + params["general"]["ne_gradient"] / 200), + self.num_grad_points, + ) ) + Te = params["electron"]["Te"] * jnp.linspace( + (1 - params["general"]["Te_gradient"] / 200), + (1 + params["general"]["Te_gradient"] / 200), + self.num_grad_points, + ) + lam = params["general"]["lam"] + self.lam_shift + A = jnp.array([params[species]["A"] for species in params.keys() if "ion" in species]) + Z = jnp.array([params[species]["Z"] for species in params.keys() if "ion" in species]) + Ti = jnp.array([params[species]["Ti"] for species in params.keys() if "ion" in species]) + fract = jnp.array([params[species]["fract"] for species in params.keys() if "ion" in species]) + Va = params["general"]["Va"] * 1e6 # flow velocity in 1e6 cm/s + ud = params["general"]["ud"] * 1e6 # drift velocity in 1e6 cm/s + fe = params["electron"]["fe"] + vx = params["electron"]["v"] Mi = jnp.array(A) * self.Mp # ion mass re = 2.8179e-13 # classical electron radius cm Esq = self.Me * self.C**2 * re # sq of the electron charge keV cm constants = jnp.sqrt(4 * jnp.pi * Esq / self.Me) - sarad = sa * jnp.pi / 180 # scattering angle in radians + sarad = self.scattering_angles["sa"] * jnp.pi / 180 # scattering angle in radians sarad = jnp.reshape(sarad, [1, 1, -1]) - Va = Va * 1e6 # flow velocity in 1e6 cm/s + # Va = Va * 1e6 # flow velocity in 1e6 cm/s # convert Va from mag, angle to x,y - Va = (Va * jnp.cos(va_ang * jnp.pi / 180), Va * jnp.sin(va_ang * jnp.pi / 180)) - ud = ud * 1e6 # drift velocity in 1e6 cm/s + Va = (Va * jnp.cos(self.va_angle * 
jnp.pi / 180), Va * jnp.sin(self.va_angle * jnp.pi / 180)) + # ud = ud * 1e6 # drift velocity in 1e6 cm/s # convert ua from mag, angle to x,y - ud = (ud * jnp.cos(ud_ang * jnp.pi / 180), ud * jnp.sin(ud_ang * jnp.pi / 180)) - - omgL, omgs, lamAxis, _ = lam_parse.lamParse(self.lamrang, lam, npts=self.npts) # , True) + ud = (ud * jnp.cos(self.ud_angle * jnp.pi / 180), ud * jnp.sin(self.ud_angle * jnp.pi / 180)) + omgL = self.omgL_num / lam # laser frequency Rad / s # calculate k and omega vectors omgpe = constants * jnp.sqrt(ne[..., jnp.newaxis, jnp.newaxis]) # plasma frequency Rad/cm - omgs = omgs[jnp.newaxis, ..., jnp.newaxis] - omg = omgs - omgL + # omgs = omgs[jnp.newaxis, ..., jnp.newaxis] + omg = self.omgs - omgL kL = (jnp.sqrt(omgL**2 - omgpe**2) / self.C, jnp.zeros_like(omgpe)) # defined to be along the x axis - ks_mag = jnp.sqrt(omgs**2 - omgpe**2) / self.C + ks_mag = jnp.sqrt(self.omgs**2 - omgpe**2) / self.C ks = (jnp.cos(sarad) * ks_mag, jnp.sin(sarad) * ks_mag) k = vsub(ks, kL) # 2D k_mag = jnp.sqrt(vdot(k, k)) # 1D @@ -441,9 +487,9 @@ def calc_in_2D(self, params, ud_ang, va_ang, cur_ne, cur_Te, A, Z, Ti, fract, sa klde_mag = (vTe / omgpe) * (k_mag[..., jnp.newaxis]) # 1D # ions - Z = jnp.reshape(Z, [1, 1, 1, -1]) + Z = jnp.reshape(jnp.array(Z), [1, 1, 1, -1]) Mi = jnp.reshape(Mi, [1, 1, 1, -1]) - fract = jnp.reshape(fract, [1, 1, 1, -1]) + fract = jnp.reshape(jnp.array(fract), [1, 1, 1, -1]) Zbar = jnp.sum(Z * fract) ni = fract * ne[..., jnp.newaxis, jnp.newaxis, jnp.newaxis] / Zbar omgpi = constants * Z * jnp.sqrt(ni * self.Me / Mi) @@ -467,15 +513,15 @@ def calc_in_2D(self, params, ud_ang, va_ang, cur_ne, cur_Te, A, Z, Ti, fract, sa # xie = vsub(vdiv(omgdop, vdot(k, vTe)), vdiv(ud, vTe)) xie = vdiv(vsub(vdot(omgdop / k_mag**2, k), ud), vTe) xie_mag = jnp.sqrt(vdot(xie, xie)) - DF, (x, y) = fe - + # DF, (x, y) = fe + # # for each vector in xie # find the rotation angle beta, the heaviside changes the angles to [0, 2pi) beta = jnp.arctan(xie[1] / 
xie[0]) + jnp.pi * (-jnp.heaviside(xie[0], 1) + 1) - fe_vphi, chiEI, chiERrat = self.calc_all_chi_vals(x[0, :], DF, beta, xie_mag, klde_mag) + fe_vphi, chiEI, chiERrat = self.calc_all_chi_vals(vx, fe, beta, xie_mag, klde_mag) - chiE = chiERrat + jnp.sqrt(-1 + 0j) * chiEI + chiE = chiERrat + 1j * chiEI epsilon = 1.0 + chiE + chiI # This line needs to be changed if ion distribution is changed!!! @@ -494,16 +540,10 @@ def calc_in_2D(self, params, ud_ang, va_ang, cur_ne, cur_Te, A, Z, Ti, fract, sa PsOmg = (SKW_ion_omg + SKW_ele_omg) * (1 + 2 * omgdop / omgL) * re**2.0 * ne[:, None, None] # PsOmgE = (SKW_ele_omg) * (1 + 2 * omgdop / omgL) * re**2.0 * jnp.transpose(ne) # commented because unused - lams = 2 * jnp.pi * self.C / omgs + lams = 2 * jnp.pi * self.C / self.omgs PsLam = PsOmg * 2 * jnp.pi * self.C / lams**2 # PsLamE = PsOmgE * 2 * jnp.pi * C / lams**2 # commented because unused formfactor = PsLam # formfactorE = PsLamE # commented because unused - # - # from matplotlib import pyplot as plt - # - # fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True, sharex=False) - # ax[0].plot(fe_vphi[1, :, 0]) - # plt.show() return formfactor, lams diff --git a/tsadar/core/physics/generate_spectra.py b/tsadar/core/physics/generate_spectra.py new file mode 100644 index 00000000..2b4b0926 --- /dev/null +++ b/tsadar/core/physics/generate_spectra.py @@ -0,0 +1,152 @@ +from typing import Dict + +from .form_factor import FormFactor + +from jax import numpy as jnp + + +class FitModel: + """ + The FitModel Class wraps the FormFactor class adding finite aperture effects and finite volume effects. This class + also handles the options for calculating the form factor. 
+ + Args: + config: Dict- configuration dictionary built from input deck + sa: Dict- has fields containing the scattering angles the spectrum will be calculated at and the relative + weights of each of the scattering angles in the final spectrum + """ + + def __init__(self, config: Dict, scattering_angles: Dict): + """ + FitModel class constructor, sets the static properties associated with spectrum generation that will not be + modified from one iteration of the fitter to the next. + + Args: + config: Dict- configuration dictionary built from input deck + sa: Dict- has fields containing the scattering angles the spectrum will be calculated at and the relative + weights of each of the scattering angles in the final spectrum + """ + self.config = config + self.scattering_angles = scattering_angles + + assert ( + config["parameters"]["general"]["Te_gradient"]["num_grad_points"] + == config["parameters"]["general"]["ne_gradient"]["num_grad_points"] + ), "Number of gradient points for Te and ne must be the same" + num_grad_points = config["parameters"]["general"]["Te_gradient"]["num_grad_points"] + + ud_angle = ( + None + if config["parameters"]["electron"]["fe"]["dim"] < 2 + else config["parameters"]["general"]["ud"]["angle"] + ) + va_angle = ( + None + if config["parameters"]["electron"]["fe"]["dim"] < 2 + else config["parameters"]["general"]["Va"]["angle"] + ) + self.electron_form_factor = FormFactor( + config["other"]["lamrangE"], + npts=config["other"]["npts"], + lam_shift=config["data"]["ele_lam_shift"], + scattering_angles=self.scattering_angles, + num_grad_points=num_grad_points, + va_ang=va_angle, + ud_ang=ud_angle, + ) + self.ion_form_factor = FormFactor( + config["other"]["lamrangI"], + npts=config["other"]["npts"], + lam_shift=0, + scattering_angles=scattering_angles, + num_grad_points=num_grad_points, + va_ang=va_angle, + ud_ang=ud_angle, + ) + + def __call__(self, all_params: Dict): + """ + Produces Thomson spectra corrected for finite aperture and 
optionally including gradients in the plasma + conditions based off the current parameter dictionary. Calling this method will automatically choose the + appropriate version of the formfactor class based off the dimension and distribute the conditions for + multiple ion species to their respective inputs. + + + Args: + all_params: Parameter dictionary containing the current values for all active and static parameters. Only a + few permanently static properties from the configuration dictionary will be used, everything else must + be included in this input. + + Returns: + modlE: calculated electron plasma wave spectrum as an array with length of npts. If an angular spectrum is + calculated then it will be 2D. If the EPW is not loaded this is returned as the int 0. + modlI: calculated ion acoustic wave spectrum as an array with length of npts. If the IAW is not loaded this + is returned as the int 0. + lamAxisE: electron plasma wave wavelength axis as an array with length of npts. If the EPW is not loaded + this is returned as an empty list. + lamAxisI: ion acoustic wave wavelength axis as an array with length of npts. If the IAW is not loaded + this is returned as an empty list. 
+ all_params: The input all_params is returned + + """ + + lamAxisI, modlI = self.ion_spectrum(all_params) + lamAxisE, modlE = self.electron_spectrum(all_params) + + return modlE, modlI, lamAxisE, lamAxisI + + def ion_spectrum(self, all_params): + if self.config["other"]["extraoptions"]["load_ion_spec"]: + if self.num_dist_func.dim == 1: + ThryI, lamAxisI = self.ion_form_factor(all_params) + + else: + ThryI, lamAxisI = self.ion_form_factor.calc_in_2D(all_params) + + # remove extra dimensions and rescale to nm + lamAxisI = jnp.squeeze(lamAxisI) * 1e7 # TODO hardcoded + ThryI = jnp.mean(ThryI, axis=0) + modlI = jnp.sum(ThryI * self.scattering_angles["weights"][0], axis=1) + else: + modlI = 0 + lamAxisI = jnp.zeros(1) + return lamAxisI, modlI + + def electron_spectrum(self, all_params): + if self.config["other"]["extraoptions"]["load_ele_spec"]: + if self.config["parameters"]["electron"]["fe"]["dim"] == 1: + ThryE, lamAxisE = self.electron_form_factor(all_params) + elif self.config["parameters"]["electron"]["fe"]["dim"] == 2: + ThryE, lamAxisE = self.electron_form_factor.calc_in_2D(all_params) + + # remove extra dimensions and rescale to nm + lamAxisE = jnp.squeeze(lamAxisE) * 1e7 # TODO hardcoded + + ThryE = jnp.mean(ThryE, axis=0) + if self.config["other"]["extraoptions"]["spectype"] == "angular_full": + modlE = jnp.matmul(self.scattering_angles["weights"], ThryE.transpose()) + else: + modlE = jnp.sum(ThryE * self.scattering_angles["weights"][0], axis=1) + + lam = all_params["general"]["lam"] + if self.config["other"]["iawoff"] and ( + self.config["other"]["lamrangE"][0] < lam < self.config["other"]["lamrangE"][1] + ): + # set the ion feature to 0 #should be switched to a range about lam + lamlocb = jnp.argmin(jnp.abs(lamAxisE - lam - 3.0)) + lamlocr = jnp.argmin(jnp.abs(lamAxisE - lam + 3.0)) + modlE = jnp.concatenate( + [modlE[:lamlocb], jnp.zeros(lamlocr - lamlocb), modlE[lamlocr:]] + ) # TODO hardcoded + + if self.config["other"]["iawfilter"][0]: + filterb = 
self.config["other"]["iawfilter"][3] - self.config["other"]["iawfilter"][2] / 2 + filterr = self.config["other"]["iawfilter"][3] + self.config["other"]["iawfilter"][2] / 2 + + if self.config["other"]["lamrangE"][0] < filterr and self.config["other"]["lamrangE"][1] > filterb: + indices = (filterb < lamAxisE) & (filterr > lamAxisE) + modlE = jnp.where(indices, modlE * 10 ** (-self.config["other"]["iawfilter"][1]), modlE) + else: + modlE = 0 + lamAxisE = [] + return lamAxisE, modlE diff --git a/tsadar/process/irf.py b/tsadar/core/physics/irf.py similarity index 89% rename from tsadar/process/irf.py rename to tsadar/core/physics/irf.py index 4149c1dd..fe1200d1 100644 --- a/tsadar/process/irf.py +++ b/tsadar/core/physics/irf.py @@ -2,7 +2,7 @@ from jax import numpy as jnp -def add_ATS_IRF(config, sas, lamAxisE, modlE, amps, TSins, lam) -> Tuple[jnp.ndarray, jnp.ndarray]: +def add_ATS_IRF(config, sas, lamAxisE, modlE, amps, TSins) -> Tuple[jnp.ndarray, jnp.ndarray]: """ Applies a 2D gaussian smoothing to angular Thomson data to account for the instrument response of the diagnostic. todo: improve doc and typehints @@ -41,9 +41,9 @@ def add_ATS_IRF(config, sas, lamAxisE, modlE, amps, TSins, lam) -> Tuple[jnp.nda if config["other"]["PhysParams"]["norm"] > 0: ThryE = jnp.where( - lamAxisE < lam, - TSins["general"]["amp1"] * (ThryE / jnp.amax(ThryE[lamAxisE < lam])), - TSins["general"]["amp2"] * (ThryE / jnp.amax(ThryE[lamAxisE > lam])), + lamAxisE < TSins["general"]["lam"], + TSins["general"]["amp1"] * (ThryE / jnp.amax(ThryE[lamAxisE < TSins["general"]["lam"]])), + TSins["general"]["amp2"] * (ThryE / jnp.amax(ThryE[lamAxisE > TSins["general"]["lam"]])), ) return lamAxisE, ThryE @@ -85,7 +85,7 @@ def add_ion_IRF(config, lamAxisI, modlI, amps, TSins) -> Tuple[jnp.ndarray, jnp. 
return lamAxisI, ThryI -def add_electron_IRF(config, lamAxisE, modlE, amps, TSins, lam) -> Tuple[jnp.ndarray, jnp.ndarray]: +def add_electron_IRF(config, lamAxisE, modlE, amps, TSins) -> Tuple[jnp.ndarray, jnp.ndarray]: """ electron IRF (Instrument Response Function?) @@ -113,15 +113,17 @@ def add_electron_IRF(config, lamAxisE, modlE, amps, TSins, lam) -> Tuple[jnp.nda if config["other"]["PhysParams"]["norm"] > 0: ThryE = jnp.where( - lamAxisE < lam, - TSins["general"]["amp1"] * (ThryE / jnp.amax(ThryE[lamAxisE < lam])), - TSins["general"]["amp2"] * (ThryE / jnp.amax(ThryE[lamAxisE > lam])), + lamAxisE < TSins["general"]["lam"], + TSins["general"]["amp1"] * (ThryE / jnp.amax(ThryE[lamAxisE < TSins["general"]["lam"]])), + TSins["general"]["amp2"] * (ThryE / jnp.amax(ThryE[lamAxisE > TSins["general"]["lam"]])), ) ThryE = jnp.average(ThryE.reshape(1024, -1), axis=1) if config["other"]["PhysParams"]["norm"] == 0: lamAxisE = jnp.average(lamAxisE.reshape(1024, -1), axis=1) ThryE = amps * ThryE / jnp.amax(ThryE) - ThryE = jnp.where(lamAxisE < lam, TSins["general"]["amp1"] * ThryE, TSins["general"]["amp2"] * ThryE) + ThryE = jnp.where( + lamAxisE < TSins["general"]["lam"], TSins["general"]["amp1"] * ThryE, TSins["general"]["amp2"] * ThryE + ) return lamAxisE, ThryE diff --git a/tsadar/model/physics/ratintn.py b/tsadar/core/physics/ratintn.py similarity index 98% rename from tsadar/model/physics/ratintn.py rename to tsadar/core/physics/ratintn.py index ce467728..818f33b9 100644 --- a/tsadar/model/physics/ratintn.py +++ b/tsadar/core/physics/ratintn.py @@ -51,4 +51,4 @@ def ratcen(f: jnp.ndarray, g: jnp.ndarray) -> jnp.ndarray: rfn = fdif / gdif + tmp * jnp.log((gav + (0.5 + 0j) * gdif) / (gav - 0.5 * gdif)) / gdif**2 out = jnp.where((jnp.abs(gdif) < 1.0e-4 * jnp.abs(gav))[None, :], rf, rfn) - return out + return jnp.real(out) diff --git a/tsadar/model/spectrum.py b/tsadar/core/thomson_diagnostic.py similarity index 63% rename from tsadar/model/spectrum.py rename to 
tsadar/core/thomson_diagnostic.py index d64ed614..83b9fcb0 100644 --- a/tsadar/model/spectrum.py +++ b/tsadar/core/thomson_diagnostic.py @@ -1,11 +1,12 @@ -from jax import numpy as jnp -from jax import vmap +from jax import numpy as jnp, vmap -from tsadar.model.physics.generate_spectra import FitModel -from tsadar.process import irf +from .modules import ThomsonParams +from .physics import irf +from .physics.generate_spectra import FitModel -class SpectrumCalculator: + +class ThomsonScatteringDiagnostic: """ The SpectrumCalculator class wraps the FitModel class adding instrumental effects to the calculated spectrum so it can be compared to data. @@ -15,27 +16,29 @@ class SpectrumCalculator: Args: cfg: Dict- configuration dictionary built from input deck - sas: Dict- has fields containing the scattering angles the spectrum will be calculated at and the relative + scattering_angles: Dict- has fields containing the scattering angles the spectrum will be calculated at and the relative weights of each of the scattering angles in the final spectrum - dummy_batch: Dict- data dictionary containing the electron and ion data as well as the noise and amplitudes """ - def __init__(self, cfg, sas, dummy_batch): + def __init__(self, cfg, scattering_angles): super().__init__() self.cfg = cfg - self.sas = sas - - self.forward_pass = FitModel(cfg, sas) - self.lam = cfg["parameters"]["general"]["lam"]["val"] + self.scattering_angles = scattering_angles + self.model = FitModel(cfg, scattering_angles) - if cfg["other"]["extraoptions"]["spectype"] == "angular_full": - self.vmap_forward_pass = self.forward_pass - self.vmap_postprocess_thry = self.postprocess_thry + if ( + "temporal" in cfg["other"]["extraoptions"]["spectype"] + or "imaging" in cfg["other"]["extraoptions"]["spectype"] + or "1d" in cfg["other"]["extraoptions"]["spectype"] + ): + self.model = vmap(self.model) + self.postprocess_theory = vmap(self.postprocess_theory) + elif "angular" in 
cfg["other"]["extraoptions"]["spectype"]: + pass else: - self.vmap_forward_pass = vmap(self.forward_pass) - self.vmap_postprocess_thry = vmap(self.postprocess_thry) + raise NotImplementedError(f"Unknown spectype: {cfg['other']['extraoptions']['spectype']}") - def postprocess_thry(self, modlE, modlI, lamAxisE, lamAxisI, amps, TSins): + def postprocess_theory(self, modlE, modlI, lamAxisE, lamAxisI, amps, TSins): """ Adds instrumental broadening to the synthetic Thomson spectrum. @@ -53,18 +56,17 @@ def postprocess_thry(self, modlE, modlI, lamAxisE, lamAxisI, amps, TSins): if self.cfg["other"]["extraoptions"]["load_ion_spec"]: lamAxisI, ThryI = irf.add_ion_IRF(self.cfg, lamAxisI, modlI, amps["i_amps"], TSins) else: - # lamAxisI = jnp.nan - ThryI = modlI # jnp.nan - - if self.cfg["other"]["extraoptions"]["load_ele_spec"] & ( - self.cfg["other"]["extraoptions"]["spectype"] == "angular_full" - ): - lamAxisE, ThryE = irf.add_ATS_IRF(self.cfg, self.sas, lamAxisE, modlE, amps["e_amps"], TSins, self.lam) - elif self.cfg["other"]["extraoptions"]["load_ele_spec"]: - lamAxisE, ThryE = irf.add_electron_IRF(self.cfg, lamAxisE, modlE, amps["e_amps"], TSins, self.lam) + ThryI = modlI + + if self.cfg["other"]["extraoptions"]["load_ele_spec"]: + if self.cfg["other"]["extraoptions"]["spectype"] == "angular_full": + lamAxisE, ThryE = irf.add_ATS_IRF( + self.cfg, self.scattering_angles, lamAxisE, modlE, amps["e_amps"], TSins + ) + else: + lamAxisE, ThryE = irf.add_electron_IRF(self.cfg, lamAxisE, modlE, amps["e_amps"], TSins) else: - # lamAxisE = jnp.nan - ThryE = modlE # jnp.nan + ThryE = modlE return ThryE, ThryI, lamAxisE, lamAxisI @@ -94,10 +96,12 @@ def reduce_ATS_to_resunit(self, ThryE, lamAxisE, TSins, batch): ) ThryE = ThryE[self.cfg["data"]["lineouts"]["start"] : self.cfg["data"]["lineouts"]["end"], :] ThryE = batch["e_amps"] * ThryE / jnp.amax(ThryE, axis=1, keepdims=True) - ThryE = jnp.where(lamAxisE < self.lam, TSins["general"]["amp1"] * ThryE, TSins["general"]["amp2"] * 
ThryE) + ThryE = jnp.where( + lamAxisE < TSins["general"]["lam"], TSins["general"]["amp1"] * ThryE, TSins["general"]["amp2"] * ThryE + ) return ThryE, lamAxisE - def __call__(self, params, batch): + def __call__(self, ts_params: ThomsonParams, batch): """ TODO @@ -108,12 +112,14 @@ def __call__(self, params, batch): Returns: """ - modlE, modlI, lamAxisE, lamAxisI, live_TSinputs = self.vmap_forward_pass(params) # , sas["weights"]) - ThryE, ThryI, lamAxisE, lamAxisI = self.vmap_postprocess_thry( - modlE, modlI, lamAxisE, lamAxisI, {"e_amps": batch["e_amps"], "i_amps": batch["i_amps"]}, live_TSinputs + + physical_params = ts_params() + modlE, modlI, lamAxisE, lamAxisI = self.model(physical_params) + ThryE, ThryI, lamAxisE, lamAxisI = self.postprocess_theory( + modlE, modlI, lamAxisE, lamAxisI, {"e_amps": batch["e_amps"], "i_amps": batch["i_amps"]}, physical_params ) if self.cfg["other"]["extraoptions"]["spectype"] == "angular_full": - ThryE, lamAxisE = self.reduce_ATS_to_resunit(ThryE, lamAxisE, live_TSinputs, batch) + ThryE, lamAxisE = self.reduce_ATS_to_resunit(ThryE, lamAxisE, physical_params, batch) ThryE = ThryE + batch["noise_e"] ThryI = ThryI + batch["noise_i"] diff --git a/tsadar/data_handleing/calibrations/sa_table.py b/tsadar/data_handleing/calibrations/sa_table.py deleted file mode 100644 index cadbe217..00000000 --- a/tsadar/data_handleing/calibrations/sa_table.py +++ /dev/null @@ -1,189 +0,0 @@ -import numpy as np - - -def sa_lookup(beam): - """ - Creates the scattering angle dictionary with the scattering angles and their weights based of the chosen probe - beam. All values are precalculated. Available options are P9, B12, B15, B23, B26, B35, B42, B46, B58. 
- - Args: - beam: string with the name of the beam to be used as a probe - - Returns: - sa: dictionary with scattering angles in the 'sa' field and their relative weights in the 'weights' field - """ - if beam == "P9": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(53.637560, 66.1191, 10), - weights=np.array( - [ - 0.00702671050853565, - 0.0391423809738300, - 0.0917976667717670, - 0.150308544660150, - 0.189541011666141, - 0.195351560740507, - 0.164271879645061, - 0.106526733030044, - 0.0474753389486960, - 0.00855817305526778, - ] - ), - ) - elif beam == "B12": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(71.0195, 83.3160, 10), - weights=np.array( - [ - 0.007702, - 0.0404, - 0.09193, - 0.1479, - 0.1860, - 0.1918, - 0.1652, - 0.1083, - 0.05063, - 0.01004, - ] - ), - ) - elif beam == "B15": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(12.0404, 24.0132, 10), - weights=np.array( - [ - 0.0093239, - 0.04189, - 0.0912121, - 0.145579, - 0.182019, - 0.188055, - 0.163506, - 0.1104, - 0.0546822, - 0.0133327, - ] - ), - ) - elif beam == "B23": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(72.281, 84.3307, 10), - weights=np.array( - [ - 0.00945903, - 0.0430611, - 0.0925634, - 0.146705, - 0.182694, - 0.1881, - 0.162876, - 0.109319, - 0.0530607, - 0.0121616, - ] - ), - ) - elif beam == "B26": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(55.5636, 68.1058, 10), - weights=np.array( - [ - 0.00648619, - 0.0386019, - 0.0913923, - 0.150489, - 0.190622, - 0.195171, - 0.166389, - 0.105671, - 0.0470249, - 0.00815279, - ] - ), - ) - elif beam == "B35": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(32.3804, 44.6341, 10), - weights=np.array( - [ - 0.00851313, - 0.0417549, - 0.0926084, - 0.149182, - 0.187019, - 0.191523, - 0.16265, - 0.106842, - 0.049187, - 0.0107202, - ] - ), - ) - elif beam 
== "B42": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(155.667, 167.744, 10), - weights=np.array( - [ - 0.00490969, - 0.0257646, - 0.0601324, - 0.106076, - 0.155308, - 0.187604, - 0.19328, - 0.15702, - 0.0886447, - 0.0212603, - ] - ), - ) - elif beam == "B46": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(56.5615, 69.1863, 10), - weights=np.array( - [ - 0.00608081, - 0.0374307, - 0.0906716, - 0.140714, - 0.191253, - 0.197333, - 0.166164, - 0.106121, - 0.0464844, - 0.0077474, - ] - ), - ) - elif beam == "B58": - # Scattering angle in degrees for OMEGA TIM6 TS - sa = dict( - sa=np.linspace(119.093, 131.666, 10), - weights=np.array( - [ - 0.00549525, - 0.0337372, - 0.0819783, - 0.140084, - 0.186388, - 0.19855, - 0.174136, - 0.117517, - 0.0527003, - 0.00941399, - ] - ), - ) - else: - raise NotImplmentedError("Other probe geometrries are not yet supported") - - return sa diff --git a/tsadar/distribution_functions/dist_functional_forms.py b/tsadar/distribution_functions/dist_functional_forms.py deleted file mode 100644 index c3b301b1..00000000 --- a/tsadar/distribution_functions/dist_functional_forms.py +++ /dev/null @@ -1,215 +0,0 @@ -from jax.scipy.special import gamma -from jax import numpy as jnp -from tsadar.misc.vector_tools import rotate -from interpax import interp2d - - -# we will probably want to add input checks to ensure the proper fields are defined -def DLM_1D(m, h): - """ - Produces a 1-D Dum-Langdon-Matte distribution parametrized by a super-gaussian order m. - - Args: - m: (int) Super-Gaussian order - h: (int) resolution of normalized velocity grid, i.e. 
spacing of the grid - - Returns: - vx: normalized velocity grid - fe_num: numerical distribution function - """ - - def SG(vx, m): - x0 = jnp.sqrt(3 * gamma(3 / m) / gamma(5 / m)) - return jnp.exp(-((jnp.abs(vx) / x0) ** m)) - - vx = jnp.arange(-8, 8, h) - fe_num = jnp.array([trapz(SG(jnp.sqrt(vx**2 + vz**2), m), h) for vz in vx]) - - # x0 = jnp.sqrt(3 * gamma(3 / m) / gamma(5 / m)) - # fe_num = jnp.exp(-((jnp.abs(vx) / x0) ** m)) - fe_num = fe_num / trapz(fe_num, h) - return vx, fe_num - - -def SG_1D(m, h): - """ - Produces a 1-D Super-Gaussian distribution parametrized by a super-gaussian order m. - - Args: - m: (int) Super-Gaussian order - h: (int) resolution of normalized velocity grid, i.e. spacing of the grid - - Returns: - vx: normalized velocity grid - fe_num: numerical distribution function - """ - vx = jnp.arange(-8, 8, h) - x0 = jnp.sqrt(3 * gamma(3 / m) / gamma(5 / m)) - fe_num = jnp.exp(-((jnp.abs(vx) / x0) ** m)) - fe_num = fe_num / trapz(fe_num, h) - return vx, fe_num - - -# Warning: These super-gaussian orders do not follow Matte -def DLM_2D(m, h): - """ - Produces a 2-D symmetric Dum-Langdon-Matte distribution parametrized by a super-gaussian order m. - - Args: - m: (int) Super-Gaussian order - h: (int) resolution of normalized velocity grid, i.e. spacing of the grid - - Returns: - (vx, vy): tuple of the normalized velocity grids in x and y - fe_num: numerical distribution function - """ - vx = jnp.arange(-8, 8, h) - vy = jnp.arange(-8, 8, h) - vx, vy = jnp.meshgrid(vx, vy) - x0 = jnp.sqrt(3 * gamma(3 / m) / gamma(5 / m)) - fe_num = jnp.exp(-((jnp.sqrt(vx**2 + vy**2) / x0) ** m)) - fe_num = fe_num / trapz(trapz(fe_num, h), h) - return (vx, vy), fe_num - - -# Warning: These super-gaussian orders do not follow Matte -def BiDLM(mx, my, tasym, theta, h): - """ - Produces a 2-D Dum-Langdon-Matte distribution that can have different widths and super-gaussian orders in the 2 - dimensions. 
- - Args: - mx: (int) Super-Gaussian order for the x direction - my: (int) Super-Gaussian order for the y direction - tasym: (int) Temperature asymetry, where the y direction will have an effective temperature of Te*tasym. x - direction will have an effective temperature of Te. - theta: (int) counter-clockwise rotation of the distribution in radians - h: (int) resolution of normalized velocity grid, i.e. spacing of the grid - - Returns: - (vx, vy): tuple of the normalized velocity grids in x and y - fe_num: numerical distribution function - """ - vx = jnp.arange(-8, 8, h) - vy = jnp.arange(-8, 8, h) - vx, vy = jnp.meshgrid(vx, vy) - x0x = jnp.sqrt(3 * gamma(3 / mx) / gamma(5 / mx)) - x0y = jnp.sqrt(3 * gamma(3 / my) / gamma(5 / my)) - fe_num = jnp.exp(-((jnp.abs(vx) / x0x) ** mx) - (jnp.abs(vy) / (x0y * jnp.sqrt(tasym))) ** my) - fe_num = rotate(fe_num, theta) - #fe_num = fe_num / calc_moment(fe_num,(vx,vy),0) - - renorm = jnp.sqrt(calc_moment(fe_num,(vx,vy),2)/ (2*calc_moment(fe_num,(vx,vy),0)))#the 2 is to make the moment equal the number of dimensions, not sure on this - h2 = h/renorm - #vx2 = jnp.arange(-8/renorm, 8/renorm, h2) - vx2 = vx[0]/renorm - vy2 = vx[0]/renorm - #vy2 = jnp.arange(-8/renorm, 8/renorm, h2) - print(jnp.shape(fe_num)) - print(jnp.shape(vx2)) - print(h2) - print(jnp.shape(jnp.log(fe_num))) - fe_num = jnp.exp(interp2d(vx.flatten(), vy.flatten(), vx2, vy2, jnp.log(fe_num), extrap=[-100, -100], method="linear").reshape(jnp.shape(vx),order="F")) - fe_num = fe_num / calc_moment(fe_num,(vx,vy),0) - - return (vx, vy), fe_num - - -# not positive on the normalizations for f1 vs f0 so dt may not be =lambda_ei/LT -def Spitzer_3V(dt, vq, h): - """ - Produces a 2-D Spitzer-Harm distribution with the f1 direction given in 3-space. - - Args: - dt: (int) Knudsen number determining the magnitude of the perturbation - vq: array or list with 3 elements giving the direction of the f1 perturbation in x,y,z - h: (int) resolution of normalized velocity grid, i.e. 
spacing of the grid - - Returns: - (vx, vy): tuple of the normalized velocity grids in x and y - fe_num: numerical distribution function - """ - # likely to OOM (probably a shortcut by calculating the anlge out of the plane and multiplying f1 by cos of that angle) - x = jnp.arange(-8, 8, h) - y = jnp.arange(-8, 8, h) - z = jnp.arange(-8, 8, h) - vx, vy, vz = jnp.meshgrid(x, y, z) - # vq = vq/jnp.sqrt(vq[0]**2 + vq[1]**2 + vq[2]**2) - f0 = 1 / (2 * jnp.pi) ** (3 / 2) * jnp.exp(-(vx**2 + vy**2 + vz**2) / 2) - f1 = ( - dt - * jnp.sqrt(2 / (9 * jnp.pi)) - * (vx * vq[0] + vy * vq[1] + vz * vq[2]) ** 4 - * (4 - (vx * vq[0] + vy * vq[1] + vz * vq[2]) / 2) - * f0 - ) - fe_num = f0 + f1 - fe_num = trapz(fe_num, h) # integrate over z - fe_num = fe_num / trapz(trapz(fe_num, h), h) # renormalize - - # redefine to coordinates - vx, vy = jnp.meshgrid(x, y) - - return (vx, vy), fe_num - - -def Spitzer_2V(dt, vq, h): - """ - Produces a 2-D Spitzer-Harm distribution with the f1 direction given in the plane. - - Args: - dt: (int) Knudsen number determining the magnitude of the perturbation - vq: array or list with 2 elements giving the direction of the f1 perturbation in x,y - h: (int) resolution of normalized velocity grid, i.e. 
spacing of the grid - - Returns: - (vx, vy): tuple of the normalized velocity grids in x and y - fe_num: numerical distribution function - """ - x = jnp.arange(-8, 8, h) - y = jnp.arange(-8, 8, h) - vx, vy = jnp.meshgrid(x, y) - # vq = vq/jnp.sqrt(vq[0]**2 + vq[1]**2) - f0 = 1 / (2 * jnp.pi) ** (3 / 2) * jnp.exp(-(vx**2 + vy**2) / 2) - f1 = dt * jnp.sqrt(2 / (9 * jnp.pi)) * (vx * vq[0] + vy * vq[1]) ** 4 * (4 - (vx * vq[0] + vy * vq[1]) / 2) * f0 - fe_num = f0 + f1 - fe_num = fe_num / trapz(trapz(fe_num, h), h) # renormalize - - return (vx, vy), fe_num - -def calc_moment(f,v,m): - """ - Calculates the moment of the distribtuion function specified by m - - Args: - f: function to calculate the moment of - m: moment 0, 1, or 2 - v: velocity grid - - Returns: - moment_val: value of the mth moment - """ - #print(jnp.shape(f)) - #print(jnp.shape(v)) - if len(jnp.shape(f))==1: - moment_val = trapz(v**m *f, v[1]-v[0]) - elif len(jnp.shape(f))==2: - moment_val = trapz(trapz((v[0]**2 + v[1]**2)**(m/2) *f, v[0][0][1]-v[0][0][0]), v[1][1][0]-v[1][0][0]) - - return moment_val - -def trapz(y, dx): - """ - JAX compatible trapizoidal intergration. 
- - Args: - y: numerical array to be integrated - dx: spacing of the associated x-axis - - Returns: - z: integral of ydx - """ - return 0.5 * (dx * (y[..., 1:] + y[..., :-1])).sum(-1) - - -# def MoraYahi_3V(dt, vq, m, h) diff --git a/tsadar/distribution_functions/gen_num_dist_func.py b/tsadar/distribution_functions/gen_num_dist_func.py deleted file mode 100644 index 4bf87d38..00000000 --- a/tsadar/distribution_functions/gen_num_dist_func.py +++ /dev/null @@ -1,134 +0,0 @@ -from jax import numpy as jnp -import scipy.io as sio -import jax, os - - -BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "aux") - -from tsadar.distribution_functions import dist_functional_forms - - -# needs the ability to enforce symetry -class DistFunc: - """ - Distribution function class used to generate numerical distribution functions based off some known functional forms. - Eventually this class will be expanded to handle loading of numerical distribution function from text files. - - """ - - def __init__(self, cfg): - """ - Distribution function class constructor, reads the inout deck and used the relevant fields to set static - parameters for the distribution function creation. These include properties like the dimension and velocity grid - spacing that are static. 
- - - Args: - cfg: Dictionary for the electron species, a subfield of the input deck dictionary - - Returns: - DistFunc: An instance of the DistFunc class - - """ - self.velocity_res = cfg["fe"]["v_res"] - self.fe_name = list(cfg["fe"]["type"].keys())[-1] - - if "dim" in cfg["fe"].keys(): - self.dim = cfg["fe"]["dim"] - else: - self.dim = 1 - - if "dt" in cfg["fe"].keys(): - self.dt = cfg["fe"]["dt"] - - # normalized here so it only is done once - if "f1_direction" in cfg["fe"].keys(): - self.f1_direction = jnp.array(cfg["fe"]["f1_direction"]) / jnp.sqrt( - jnp.sum(jnp.array([ele**2 for ele in cfg["fe"]["f1_direction"]])) - ) - # temperature asymetry for biDLM with Tex = Te and Tey = Te*temp_asym - if "temp_asym" in cfg["fe"].keys(): - self.temp_asym = cfg["fe"]["temp_asym"] - else: - self.temp_asym = 1.0 - - # m asymetry for biDLM with mx = m and my = m*m_asym (with a min of 2) - if "m_asym" in cfg["fe"].keys(): - self.m_asym = cfg["fe"]["m_asym"] - else: - self.m_asym = 1.0 - - # rotion angle for the biDLM defined counter clockwise from the x-axis in degrees - if "m_theta" in cfg["fe"].keys(): - self.m_theta = cfg["fe"]["m_theta"] / 180.0 * jnp.pi - else: - self.m_theta = 0.0 - - def __call__(self, mval): - """ - Distribution function class call, produces a numerical distribution function based of the object and the current - m-value. 
- - - Args: - mval: super-gaussian order to be used in calculation must be a float or shape (1,) - - Returns: - v: Velocity grid, for 1D distribution this is a single array, for 2D this is a tuple of arrays - fe: Numerical distribution function - - """ - if self.fe_name == "DLM": - if self.dim == 1: - # v, fe = dist_functional_forms.DLM_1D(mval, self.velocity_res) - tabl = os.path.join(BASE_FILES_PATH, "numDistFuncs/DLM_x_-3_-10_10_m_-1_2_5.mat") - tablevar = sio.loadmat(tabl, variable_names="IT") - IT = tablevar["IT"] - vx = jnp.arange(-8, 8, self.velocity_res) - xs = jnp.arange(-10, 10, 0.001) - ms = jnp.arange(2, 5, 0.1) - x_float_inds = jnp.interp(vx, xs, jnp.linspace(0, xs.shape[0] - 1, xs.shape[0])) - m_float_inds = jnp.interp(mval, ms, jnp.linspace(0, ms.shape[0] - 1, ms.shape[0])) - - # np.linspace(0, params["x"].size - 1, params["x"].size)) - # m_float_inds = jnp.array(jnp.interp(m, params["m"], np.linspace(0, params["m"].size - 1, params["m"].size))) - m_float_inds = m_float_inds.reshape((1,)) - ind_x_mesh, ind_m_mesh = jnp.meshgrid(x_float_inds, m_float_inds) - indices = jnp.concatenate([ind_x_mesh.flatten()[:, None], ind_m_mesh.flatten()[:, None]], axis=1) - - fe = jax.scipy.ndimage.map_coordinates(IT, indices.T, order=1, mode="constant", cval=0.0) - v = vx - elif self.dim == 2: - # v, fe = dist_functional_forms.DLM_2D(mdict["val"], self.velocity_res) - # v, fe = dist_functional_forms.BiDLM( - # mval, - # jnp.max(jnp.array([mval * self.m_asym, 2.0])), - # jnp.max(jnp.array([jnp.array(mval * self.m_asym).squeeze(), 2.0])), - # self.temp_asym, - # self.m_theta, - # self.velocity_res, - # ) - # this will cause issues if my is less then 2 - v, fe = dist_functional_forms.BiDLM( - mval, mval * self.m_asym, self.temp_asym, self.m_theta, self.velocity_res - ) - - elif self.fe_name == "Spitzer": - if self.dim == 2: - if len(self.f1_direction) == 2: - v, fe = dist_functional_forms.Spitzer_2V(self.dt, self.f1_direction, self.velocity_res) - elif 
len(self.f1_direction) == 3: - v, fe = dist_functional_forms.Spitzer_3V(self.dt, self.f1_direction, self.velocity_res) - else: - raise ValueError("Spitzer distribution can only be computed in 2D") - - elif self.fe_name == "MYDLM": - if self.dim == 2: - if len(self.f1_direction) == 2: - v, fe = dist_functional_forms.MoraYahi_2V(self.dt, self.f1_direction, self.velocity_res) - elif len(self.f1_direction) == 3: - v, fe = dist_functional_forms.MoriYahi_3V(self.dt, self.f1_direction, self.velocity_res) - else: - raise ValueError("Mora and Yahi distribution can only be computed in 2D") - - return v, fe diff --git a/tsadar/aux/data/ATS-s94475.hdf b/tsadar/external/data/ATS-s94475.hdf similarity index 100% rename from tsadar/aux/data/ATS-s94475.hdf rename to tsadar/external/data/ATS-s94475.hdf diff --git a/tsadar/aux/data/ATS-s94477.hdf b/tsadar/external/data/ATS-s94477.hdf similarity index 100% rename from tsadar/aux/data/ATS-s94477.hdf rename to tsadar/external/data/ATS-s94477.hdf diff --git a/tsadar/aux/data/EPW-s101675.hdf b/tsadar/external/data/EPW-s101675.hdf similarity index 100% rename from tsadar/aux/data/EPW-s101675.hdf rename to tsadar/external/data/EPW-s101675.hdf diff --git a/tsadar/aux/data/EPW_CCD-s102583.hdf b/tsadar/external/data/EPW_CCD-s102583.hdf similarity index 100% rename from tsadar/aux/data/EPW_CCD-s102583.hdf rename to tsadar/external/data/EPW_CCD-s102583.hdf diff --git a/tsadar/aux/data/EPW_CCD-s102584.hdf b/tsadar/external/data/EPW_CCD-s102584.hdf similarity index 100% rename from tsadar/aux/data/EPW_CCD-s102584.hdf rename to tsadar/external/data/EPW_CCD-s102584.hdf diff --git a/tsadar/aux/data/IAW-s101675.hdf b/tsadar/external/data/IAW-s101675.hdf similarity index 100% rename from tsadar/aux/data/IAW-s101675.hdf rename to tsadar/external/data/IAW-s101675.hdf diff --git a/tsadar/aux/data/IAW-s108135.hdf b/tsadar/external/data/IAW-s108135.hdf similarity index 100% rename from tsadar/aux/data/IAW-s108135.hdf rename to 
tsadar/external/data/IAW-s108135.hdf diff --git a/tsadar/aux/files/Copy of MeasuredSensitivity_9.21.15.xls b/tsadar/external/files/Copy of MeasuredSensitivity_9.21.15.xls similarity index 100% rename from tsadar/aux/files/Copy of MeasuredSensitivity_9.21.15.xls rename to tsadar/external/files/Copy of MeasuredSensitivity_9.21.15.xls diff --git a/tsadar/aux/files/MeasuredSensitivity_11_30_21.mat b/tsadar/external/files/MeasuredSensitivity_11_30_21.mat similarity index 100% rename from tsadar/aux/files/MeasuredSensitivity_11_30_21.mat rename to tsadar/external/files/MeasuredSensitivity_11_30_21.mat diff --git a/tsadar/aux/files/angleWghtsFredfine.mat b/tsadar/external/files/angleWghtsFredfine.mat similarity index 100% rename from tsadar/aux/files/angleWghtsFredfine.mat rename to tsadar/external/files/angleWghtsFredfine.mat diff --git a/tsadar/aux/files/angsFRED.mat b/tsadar/external/files/angsFRED.mat similarity index 100% rename from tsadar/aux/files/angsFRED.mat rename to tsadar/external/files/angsFRED.mat diff --git a/tsadar/aux/files/epwtestDW5img1x.npy b/tsadar/external/files/epwtestDW5img1x.npy similarity index 100% rename from tsadar/aux/files/epwtestDW5img1x.npy rename to tsadar/external/files/epwtestDW5img1x.npy diff --git a/tsadar/aux/files/epwtestDW5img1y.npy b/tsadar/external/files/epwtestDW5img1y.npy similarity index 100% rename from tsadar/aux/files/epwtestDW5img1y.npy rename to tsadar/external/files/epwtestDW5img1y.npy diff --git a/tsadar/aux/files/idWT.txt b/tsadar/external/files/idWT.txt similarity index 100% rename from tsadar/aux/files/idWT.txt rename to tsadar/external/files/idWT.txt diff --git a/tsadar/aux/files/rdWT.txt b/tsadar/external/files/rdWT.txt similarity index 100% rename from tsadar/aux/files/rdWT.txt rename to tsadar/external/files/rdWT.txt diff --git a/tsadar/aux/files/spectral_sensitivity.mat b/tsadar/external/files/spectral_sensitivity.mat similarity index 100% rename from tsadar/aux/files/spectral_sensitivity.mat rename to 
tsadar/external/files/spectral_sensitivity.mat diff --git a/tsadar/aux/numDistFuncs/DLM_x_-3_-10_10_m_-1_2_5.mat b/tsadar/external/numDistFuncs/DLM_x_-3_-10_10_m_-1_2_5.mat similarity index 100% rename from tsadar/aux/numDistFuncs/DLM_x_-3_-10_10_m_-1_2_5.mat rename to tsadar/external/numDistFuncs/DLM_x_-3_-10_10_m_-1_2_5.mat diff --git a/tsadar/model/__init__.py b/tsadar/forward/__init__.py similarity index 100% rename from tsadar/model/__init__.py rename to tsadar/forward/__init__.py diff --git a/tsadar/forward/calc_series.py b/tsadar/forward/calc_series.py new file mode 100644 index 00000000..4679bbc1 --- /dev/null +++ b/tsadar/forward/calc_series.py @@ -0,0 +1,186 @@ +from time import time +import os +import tempfile +import numpy as np +import matplotlib.pyplot as plt +import mlflow +import xarray as xr +import pandas + +from ..utils.plotting import plotters +from ..core.thomson_diagnostic import ThomsonScatteringDiagnostic +from ..core.modules import ThomsonParams +from ..utils.data_handling.calibration import get_scattering_angles, get_calibrations + + +def forward_pass(config): + """ + Calculates a spectrum or series of spectra from the input deck, i.e. performs a forward pass or series of forward + passes. + + + Args: + config: Dictionary - Configuration dictionary created from the input deck. For series of spectra contains the special + field 'series'. This field can have up to 8 subfields [param1, vals1, param2, vals2, param3, vals3, param4, vals4]. + the param subfields are a string identifying which fields of "parameters" are to be looped over. The vals subfields + give the values of that subfield for each spectrum in the series. 
+ + Returns: + Ion data, electron data, and plots are saved to mlflow + + """ + is_angular = True if "angular" in config["other"]["extraoptions"]["spectype"] else False + # get scattering angles and weights + config["optimizer"]["batch_size"] = 1 + + config["other"]["lamrangE"] = [ + config["data"]["fit_rng"]["forward_epw_start"], + config["data"]["fit_rng"]["forward_epw_end"], + ] + config["other"]["lamrangI"] = [ + config["data"]["fit_rng"]["forward_iaw_start"], + config["data"]["fit_rng"]["forward_iaw_end"], + ] + config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) + + sas = get_scattering_angles(config) + + dummy_batch = { + "i_data": np.array([1]), + "e_data": np.array([1]), + "noise_e": np.array([0]), + "noise_i": np.array([0]), + "e_amps": np.array([1]), + "i_amps": np.array([1]), + } + + if is_angular: + [axisxE, _, _, _, _, _] = get_calibrations( + 104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"] + ) # shot number hardcoded to get calibration + config["other"]["extraoptions"]["spectype"] = "angular_full" + + sas["angAxis"] = axisxE + dummy_batch["i_data"] = np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])) + dummy_batch["e_data"] = np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])) + + if "series" in config.keys(): + serieslen = len(config["series"]["vals1"]) + else: + serieslen = 1 + ThryE = [None] * serieslen + ThryI = [None] * serieslen + lamAxisE = [None] * serieslen + lamAxisI = [None] * serieslen + + t_start = time() + for i in range(serieslen): + # if "series" in config.keys(): + # config["parameters"]["species"][config["series"]["param1"]]["val"] = config["series"]["vals1"][i] + # if "param2" in config["series"].keys(): + # config["parameters"]["species"][config["series"]["param2"]]["val"] = config["series"]["vals2"][i] + # if "param3" in config["series"].keys(): + # 
config["parameters"]["species"][config["series"]["param3"]]["val"] = config["series"]["vals3"][i] + # if "param4" in config["series"].keys(): + # config["parameters"]["species"][config["series"]["param4"]]["val"] = config["series"]["vals4"][i] + + ts_params = ThomsonParams(config["parameters"], num_params=1, batch=not is_angular) + ts_diag = ThomsonScatteringDiagnostic(config, scattering_angles=sas) + + # params = ts_diag.get_plasma_parameters(ts_diag.pytree_weights["active"]) + ThryE[i], ThryI[i], lamAxisE[i], lamAxisI[i] = ts_diag(ts_params, dummy_batch) + + spectime = time() - t_start + ThryE = np.array(ThryE) + ThryI = np.array(ThryI) + lamAxisE = np.array(lamAxisE) + lamAxisI = np.array(lamAxisI) + + # physical_params = ts_params() + # fe_val = physical_params["electron"]["fe"][0] + # velocity = physical_params["electron"]["v"][0] + + with tempfile.TemporaryDirectory() as td: + os.makedirs(os.path.join(td, "plots"), exist_ok=True) + os.makedirs(os.path.join(td, "binary"), exist_ok=True) + os.makedirs(os.path.join(td, "csv"), exist_ok=True) + if is_angular: + physical_params = ts_params() + fe_val = physical_params["electron"]["fe"] + velocity = physical_params["electron"]["v"] + + savedata = plotters.plot_data_angular( + config, + {"ele": np.squeeze(ThryE)}, + {"e_data": np.zeros((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1]))}, + {"epw_x": sas["angAxis"], "epw_y": lamAxisE}, + td, + ) + # plotters.plot_dist(config, "electron", {"fe": np.squeeze(fe_val), "v": velocity}, np.zeros_like(fe_val), td) + # if len(np.shape(np.squeeze(fe_val))) == 1: + # final_dist = pandas.DataFrame({"fe": [l for l in fe_val], "vx": [vx for vx in velocity]}) + # elif len(np.shape(np.squeeze(fe_val))) == 2: + # final_dist = pandas.DataFrame( + # data=np.squeeze(fe_val), + # columns=velocity[0][0], + # index=velocity[0][:, 0], + # ) + # final_dist.to_csv(os.path.join(td, "csv", "learned_dist.csv")) + else: + if config["parameters"]["electron"]["fe"]["dim"] == 2: + 
plotters.plot_dist(config, "electron", {"fe": fe_val, "v": velocity}, np.zeros_like(fe_val), td) + + fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True, sharex=False) + if config["other"]["extraoptions"]["load_ele_spec"]: + ax[0].plot( + lamAxisE.squeeze().transpose(), ThryE.squeeze().transpose() + ) # transpose might break single specs? + ax[0].set_title("Simulated Data", fontsize=14) + ax[0].set_ylabel("Amp (arb. units)") + ax[0].set_xlabel("Wavelength (nm)") + ax[0].grid() + + if "series" in config.keys(): + ax[0].legend([str(ele) for ele in config["series"]["vals1"]]) + if config["series"]["param1"] == "fract" or config["series"]["param1"] == "Z": + coords_ele = ( + ("series", np.array(config["series"]["vals1"])[:, 0]), + ("Wavelength", lamAxisE[0, :]), + ) + else: + coords_ele = (("series", config["series"]["vals1"]), ("Wavelength", lamAxisE[0, :])) + ele_dat = {"Sim": ThryE} + ele_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ele) for k, v in ele_dat.items()}) + else: + coords_ele = (("series", [0]), ("Wavelength", lamAxisE[0, :].squeeze())) + ele_dat = {"Sim": ThryE.squeeze(0)} + ele_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ele) for k, v in ele_dat.items()}) + ele_data.to_netcdf(os.path.join(td, "binary", "electron_data.nc")) + + if config["other"]["extraoptions"]["load_ion_spec"]: + ax[1].plot(lamAxisI.squeeze().transpose(), ThryI.squeeze().transpose()) + ax[1].set_title("Simulated Data", fontsize=14) + ax[1].set_ylabel("Amp (arb. 
units)") + ax[1].set_xlabel("Wavelength (nm)") + ax[1].grid() + + if "series" in config.keys(): + ax[1].legend([str(ele) for ele in config["series"]["vals1"]]) + if config["series"]["param1"] == "fract" or config["series"]["param1"] == "Z": + coords_ion = ( + ("series", np.array(config["series"]["vals1"])[:, 0]), + ("Wavelength", lamAxisI[0, :]), + ) + else: + coords_ion = (("series", config["series"]["vals1"]), ("Wavelength", lamAxisI[0, :])) + ion_dat = {"Sim": ThryI} + ion_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ion) for k, v in ion_dat.items()}) + else: + coords_ion = (("series", [0]), ("Wavelength", lamAxisI[0, :].squeeze())) + ion_dat = {"Sim": ThryI.squeeze(0)} + ion_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ion) for k, v in ion_dat.items()}) + ion_data.to_netcdf(os.path.join(td, "binary", "ion_data.nc")) + fig.savefig(os.path.join(td, "plots", "simulated_data"), bbox_inches="tight") + mlflow.log_artifacts(td) + metrics_dict = {"spectrum_calc_time": spectime} + mlflow.log_metrics(metrics=metrics_dict) diff --git a/tsadar/model/physics/__init__.py b/tsadar/inverse/__init__.py similarity index 100% rename from tsadar/model/physics/__init__.py rename to tsadar/inverse/__init__.py diff --git a/tsadar/fitter.py b/tsadar/inverse/fitter.py similarity index 62% rename from tsadar/fitter.py rename to tsadar/inverse/fitter.py index cddd7833..92459083 100644 --- a/tsadar/fitter.py +++ b/tsadar/inverse/fitter.py @@ -2,19 +2,18 @@ import time import numpy as np import pandas as pd -import copy import pickle import scipy.optimize as spopt import mlflow, optax -from optax import tree_utils as otu +import equinox as eqx +from optax import tree_utils as otu from tqdm import trange from jax.flatten_util import ravel_pytree -import jaxopt -from tsadar.distribution_functions.gen_num_dist_func import DistFunc -from tsadar.model.TSFitter import TSFitter -from tsadar.process import prepare, postprocess +from .loss_function import LossFunction +from 
..core.modules import get_filter_spec, ThomsonParams +from ..utils.process import prepare, postprocess def init_param_norm_and_shift(config: Dict) -> Dict: @@ -81,30 +80,45 @@ def _validate_inputs_(config: Dict) -> Dict: """ # get derived quantities - for species in config["parameters"].keys(): - if "electron" in config["parameters"][species]["type"].keys(): - dist_obj = DistFunc(config["parameters"][species]) - config["parameters"][species]["fe"]["velocity"], config["parameters"][species]["fe"]["val"] = dist_obj( - config["parameters"][species]["m"]["val"] - ) - config["parameters"][species]["fe"]["val"] = np.log(config["parameters"][species]["fe"]["val"])[None, :] - # config["velocity"] = np.linspace(-7, 7, config["parameters"]["fe"]["length"]) - Warning("fe length is currently overwritten by v_res") - config["parameters"][species]["fe"]["length"] = len(config["parameters"][species]["fe"]["val"]) - if config["parameters"][species]["fe"]["symmetric"]: - Warning("Symmetric EDF has been disabled") - # config["velocity"] = np.linspace(0, 7, config["parameters"]["fe"]["length"]) - if config["parameters"][species]["fe"]["dim"] == 2 and config["parameters"][species]["fe"]["active"]: - Warning("2D EDFs can only be fit for angular data") - - config["parameters"][species]["fe"]["lb"] = np.multiply( - config["parameters"][species]["fe"]["lb"], np.ones(config["parameters"][species]["fe"]["length"]) - ) - config["parameters"][species]["fe"]["ub"] = np.multiply( - config["parameters"][species]["fe"]["ub"], np.ones(config["parameters"][species]["fe"]["length"]) - ) - if "dist_obj" in locals(): - ValueError("Only 1 electron species is currently supported") + # electron_params = config["parameters"]["electron"] + + # if electron_params["fe"]["type"].casefold() == "arbitrary": + # if isinstance(electron_params["fe"]["val"]) in [list, np.array]: + # pass + # elif isinstance(electron_params["fe"]["val"], str): + # if electron_params["fe"]["val"].casefold() == "dlm": + # 
electron_params["fe"]["val"] = DLM1D(electron_params)(electron_params["m"]["val"]) + # elif "file" in electron_params["fe"]["val"]: # file-/pscratch/a/.../file.txt + # filename = electron_params["fe"]["val"].split("-")[1] + # else: + # raise NotImplementedError(f"Functional form {electron_params['fe']['val']} not implemented") + + # elif electron_params["fe"]["type"].casefold() == "dlm": + # assert electron_params["m"]["val"] >= 2, "DLM requires m >= 2" + # assert electron_params["m"]["val"] <= 5, "DLM requires m <= 5" + + # elif electron_params["fe"]["type"].casefold() == "sphericalharmonic": + # pass + + # elif electron_params["fe"]["type"].casefold() == "spitzer": + # pass # dont need anything here + # elif electron_params["fe"]["type"].casefold() == "mydlm": + # pass # don't need anything here + # else: + # raise NotImplementedError(f"Functional form {electron_params['fe']['type']} not implemented") + + # dist_obj = DistFunc(electron_params) + # electron_params["fe"]["velocity"], electron_params["fe"]["val"] = dist_obj(electron_params["m"]["val"]) + # electron_params["fe"]["val"] = np.log(electron_params["fe"]["val"])[None, :] + # Warning("fe length is currently overwritten by v_res") + # electron_params["fe"]["length"] = len(electron_params["fe"]["val"]) + # if electron_params["fe"]["symmetric"]: + # Warning("Symmetric EDF has been disabled") + # if electron_params["fe"]["dim"] == 2 and electron_params["fe"]["active"]: + # Warning("2D EDFs can only be fit for angular data") + + # electron_params["fe"]["lb"] = np.multiply(electron_params["fe"]["lb"], np.ones(electron_params["fe"]["length"])) + # electron_params["fe"]["ub"] = np.multiply(electron_params["fe"]["ub"], np.ones(electron_params["fe"]["length"])) # get slices config["data"]["lineouts"]["val"] = [ @@ -124,11 +138,10 @@ def _validate_inputs_(config: Dict) -> Dict: config["data"]["lineouts"]["val"] = config["data"]["lineouts"]["val"][: -(num_slices % batch_size)] print(f"final {num_slices % batch_size} 
lineouts have been removed") - config["units"] = init_param_norm_and_shift(config) - return config -def angular_optax(config, all_data, sa, batch_indices, num_batches): + +def angular_optax(config, all_data, sa): """ This performs an fitting routines from the optax packages, different minimizers have different requirements for updating steps @@ -136,13 +149,11 @@ def angular_optax(config, all_data, sa, batch_indices, num_batches): config: Configuration dictionary build from the input decks all_data: dictionary of the datasets, amplitudes, and backgrounds as constructed by the prepare.py code sa: dictionary of the scattering angles and thier relative weights - batch_indices: NA - num_batches: NA Returns: best_weights: best parameter weights as returned by the minimizer best_loss: best value of the fit metric found by ther minimizer - ts_fitter: instance of the TSFitter object used for minimization + ts_instance: instance of the ThomsonScattering object used for minimization """ @@ -154,35 +165,35 @@ def angular_optax(config, all_data, sa, batch_indices, num_batches): "e_amps": all_data["e_amps"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], "i_data": all_data["i_data"], "i_amps": all_data["i_amps"], - "noise_e": all_data["noiseE"][ - config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], : - ], - "noise_i": all_data["noiseI"][ - config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], : - ], + "noise_e": all_data["noiseE"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], + "noise_i": all_data["noiseI"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], } - if isinstance(config["data"]["shotnum"],list): + if isinstance(config["data"]["shotnum"], list): batch2 = { - "e_data": all_data["e_data_rot"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], - "e_amps": all_data["e_amps_rot"][config["data"]["lineouts"]["start"] : 
config["data"]["lineouts"]["end"], :], - "noise_e": all_data["noiseE_rot"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], - "i_data": all_data["i_data"], - "i_amps": all_data["i_amps"], - "noise_i": all_data["noiseI"][ + "e_data": all_data["e_data_rot"][ config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], : ], - } - test_batch = {'b1':batch1,'b2':batch2} + "e_amps": all_data["e_amps_rot"][ + config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], : + ], + "noise_e": all_data["noiseE_rot"][ + config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], : + ], + "i_data": all_data["i_data"], + "i_amps": all_data["i_amps"], + "noise_i": all_data["noiseI"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], + } + actual_data = {"b1": batch1, "b2": batch2} else: - test_batch = batch1 + actual_data = batch1 - ts_fitter = TSFitter(config, sa, batch1) + loss_fn = LossFunction(config, sa, batch1) minimizer = getattr(optax, config["optimizer"]["method"]) - #schedule = optax.schedules.cosine_decay_schedule(config["optimizer"]["learning_rate"], 100, alpha = 0.00001) - #solver = minimizer(schedule) + # schedule = optax.schedules.cosine_decay_schedule(config["optimizer"]["learning_rate"], 100, alpha = 0.00001) + # solver = minimizer(schedule) solver = minimizer(config["optimizer"]["learning_rate"]) - weights = ts_fitter.pytree_weights["active"] + weights = loss_fn.pytree_weights["active"] opt_state = solver.init(weights) # start train loop @@ -193,23 +204,20 @@ def angular_optax(config, all_data, sa, batch_indices, num_batches): num_g_wait = 0 num_b_wait = 0 for i_epoch in (pbar := trange(config["optimizer"]["num_epochs"])): - if config["nn"]["use"]: - np.random.shuffle(batch_indices) - - (val, aux), grad = ts_fitter.vg_loss(weights, test_batch) + (val, aux), grad = loss_fn.vg_loss(weights, actual_data) updates, opt_state = solver.update(grad, opt_state, weights) - + 
epoch_loss = val if epoch_loss < best_loss: print(f"delta loss {best_loss - epoch_loss}") if best_loss - epoch_loss < 0.000001: best_loss = epoch_loss - num_g_wait+=1 + num_g_wait += 1 if num_g_wait > 5: print("Minimizer exited due to change in loss < 1e-6") break elif epoch_loss > best_loss: - num_b_wait+=1 + num_b_wait += 1 if num_b_wait > 5: print("Minimizer exited due to increase in loss") break @@ -218,41 +226,44 @@ def angular_optax(config, all_data, sa, batch_indices, num_batches): num_b_wait = 0 num_g_wait = 0 pbar.set_description(f"Loss {epoch_loss:.2e}, Learning rate {otu.tree_get(opt_state, 'scale')}") - + weights = optax.apply_updates(weights, updates) - + if config["optimizer"]["save_state"]: if i_epoch % config["optimizer"]["save_state_freq"] == 0: state_weights[i_epoch] = weights mlflow.log_metrics({"epoch loss": float(epoch_loss)}, step=i_epoch) - with open('state_weights.txt', 'wb') as file: + with open("state_weights.txt", "wb") as file: file.write(pickle.dumps(state_weights)) - mlflow.log_artifact('state_weights.txt') - return weights, epoch_loss, ts_fitter + mlflow.log_artifact("state_weights.txt") + return weights, epoch_loss, loss_fn + def _1d_adam_loop_( - config: Dict, ts_fitter: TSFitter, previous_weights: np.ndarray, batch: Dict, tbatch + config: Dict, loss_fn: LossFunction, previous_weights: np.ndarray, batch: Dict, tbatch ) -> Tuple[float, Dict]: - jaxopt_kwargs = dict( - fun=ts_fitter.vg_loss, maxiter=config["optimizer"]["num_epochs"], value_and_grad=True, has_aux=True - ) + # jaxopt_kwargs = dict( + # fun=loss_fn.vg_loss, maxiter=config["optimizer"]["num_epochs"], value_and_grad=True, has_aux=True + # ) opt = optax.adam(config["optimizer"]["learning_rate"]) - solver = jaxopt.OptaxSolver(opt=opt, **jaxopt_kwargs) + ts_params = ThomsonParams(config["parameters"], config["optimizer"]["batch_size"]) + diff_params, static_params = eqx.partition(ts_params, get_filter_spec(config["parameters"], ts_params)) + opt_state = opt.init(diff_params) 
- if previous_weights is None: - init_weights = ts_fitter.pytree_weights["active"] - else: - init_weights = previous_weights + # if previous_weights is None: + # init_weights = loss_fn.pytree_weights["active"] + # else: + # init_weights = previous_weights # if "sequential" in config["optimizer"]: # if config["optimizer"]["sequential"]: # if previous_weights is not None: # init_weights = previous_weights - opt_state = solver.init_state(init_weights, batch=batch) + # opt_state = solver.init_state(init_weights, batch=batch) best_loss = 1e16 epoch_loss = 1e19 @@ -260,21 +271,30 @@ def _1d_adam_loop_( tbatch.set_description(f"Epoch {i_epoch + 1}, Prev Epoch Loss {epoch_loss:.2e}") # if config["nn"]["use"]: # np.random.shuffle(batch_indices) + (epoch_loss, aux), grad = loss_fn.vg_loss(diff_params, static_params, batch) + updates, opt_state = opt.update(grad, opt_state) + diff_params = eqx.apply_updates(diff_params, updates) - init_weights, opt_state = solver.update(params=init_weights, state=opt_state, batch=batch) - epoch_loss = opt_state.value + # init_weights, opt_state = solver.update(params=init_weights, state=opt_state, batch=batch) + # epoch_loss = opt_state.value if epoch_loss < best_loss: best_loss = epoch_loss - best_weights = init_weights + best_weights = eqx.combine(diff_params, static_params) return best_loss, best_weights -def _1d_scipy_loop_(config: Dict, ts_fitter: TSFitter, previous_weights: np.ndarray, batch: Dict) -> Tuple[float, Dict]: - if previous_weights is None: # if prev, then use that, if not then use flattened weights - init_weights = np.copy(ts_fitter.flattened_weights) - else: - init_weights = np.array(previous_weights) +def _1d_scipy_loop_( + config: Dict, loss_fn: LossFunction, previous_weights: np.ndarray, batch: Dict +) -> Tuple[float, Dict]: + # if previous_weights is None: # if prev, then use that, if not then use flattened weights + # init_weights = np.copy(loss_fn.ts_diag.flattened_weights) + # else: + # init_weights = 
np.array(previous_weights) + + ts_params = ThomsonParams(config["parameters"], config["optimizer"]["batch_size"]) + diff_params, static_params = eqx.partition(ts_params, get_filter_spec(config["parameters"], ts_params)) + init_weights, loss_fn.unravel_weights = ravel_pytree(diff_params) # if "sequential" in config["optimizer"]: # if config["optimizer"]["sequential"]: @@ -282,24 +302,24 @@ def _1d_scipy_loop_(config: Dict, ts_fitter: TSFitter, previous_weights: np.ndar # init_weights = previous_weights res = spopt.minimize( - ts_fitter.vg_loss if config["optimizer"]["grad_method"] == "AD" else ts_fitter.loss, + loss_fn.vg_loss if config["optimizer"]["grad_method"] == "AD" else loss_fn.loss, init_weights, - args=batch, + args=(static_params, batch), method=config["optimizer"]["method"], jac=True if config["optimizer"]["grad_method"] == "AD" else False, - bounds=ts_fitter.bounds, + bounds=((0, 1) for _ in range(len(init_weights))), options={"disp": True, "maxiter": config["optimizer"]["num_epochs"]}, ) best_loss = res["fun"] - best_weights = ts_fitter.unravel_pytree(res["x"]) + best_weights = eqx.combine(loss_fn.unravel_weights(res["x"]), static_params) return best_loss, best_weights def one_d_loop( config: Dict, all_data: Dict, sa: Tuple, batch_indices: np.ndarray, num_batches: int -) -> Tuple[List, float, TSFitter]: +) -> Tuple[List, float, LossFunction]: """ This is the higher level wrapper that prepares the data and the fitting code for the 1D fits @@ -322,7 +342,7 @@ def one_d_loop( "noise_e": all_data["noiseE"][: config["optimizer"]["batch_size"]], "noise_i": all_data["noiseI"][: config["optimizer"]["batch_size"]], } | sample - ts_fitter = TSFitter(config, sa, sample) + loss_fn = LossFunction(config, sa, sample) print("minimizing") mlflow.set_tag("status", "minimizing") @@ -343,11 +363,11 @@ def one_d_loop( } if config["optimizer"]["method"] == "adam": # Stochastic Gradient Descent - best_loss, best_weights = _1d_adam_loop_(config, ts_fitter, previous_weights, 
batch, tbatch) + best_loss, best_weights = _1d_adam_loop_(config, loss_fn, previous_weights, batch, tbatch) else: # not sure why this is needed but something needs to be reset, either the weights or the bounds - ts_fitter = TSFitter(config, sa, batch) - best_loss, best_weights = _1d_scipy_loop_(config, ts_fitter, previous_weights, batch) + loss_fn = LossFunction(config, sa, batch) + best_loss, best_weights = _1d_scipy_loop_(config, loss_fn, previous_weights, batch) all_weights.append(best_weights) mlflow.log_metrics({"batch loss": float(best_loss)}, step=i_batch) @@ -361,7 +381,7 @@ def one_d_loop( else: previous_weights, _ = ravel_pytree(best_weights) - return all_weights, overall_loss, ts_fitter + return all_weights, overall_loss, loss_fn def fit(config) -> Tuple[pd.DataFrame, float]: @@ -393,21 +413,9 @@ def fit(config) -> Tuple[pd.DataFrame, float]: config = _validate_inputs_(config) # prepare data - if isinstance(config["data"]["shotnum"],list): - startCCDsize = config["other"]["CCDsize"] - all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"][0]) - config["other"]["CCDsize"] = startCCDsize - all_data2, _, _ = prepare.prepare_data(config, config["data"]["shotnum"][1]) - all_data.update({'e_data_rot': all_data2['e_data'], 'e_amps_rot': all_data2['e_amps'], - 'rot_angle': config["data"]['shot_rot'], 'noiseE_rot': all_data2['noiseE']}) - - if config["other"]["extraoptions"]["spectype"] != 'angular_full': - raise NotImplementedError('Muliplexed data fitting is only availible for angular data') - else: - all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"]) - - batch_indices = np.arange(max(len(all_data["e_data"]), len(all_data["i_data"]))) - num_batches = len(batch_indices) // config["optimizer"]["batch_size"] or 1 + all_data, sa, all_axes = load_data_for_fitting(config) + sample_indices = np.arange(max(len(all_data["e_data"]), len(all_data["i_data"]))) + num_batches = len(sample_indices) // 
config["optimizer"]["batch_size"] or 1 mlflow.log_metrics({"setup_time": round(time.time() - t1, 2)}) # perform fit @@ -416,15 +424,37 @@ def fit(config) -> Tuple[pd.DataFrame, float]: print("minimizing") if "angular" in config["other"]["extraoptions"]["spectype"]: - fitted_weights, overall_loss, ts_fitter = angular_optax(config, all_data, sa, batch_indices, num_batches) + fitted_weights, overall_loss, loss_fn = angular_optax(config, all_data, sa) else: - fitted_weights, overall_loss, ts_fitter = one_d_loop(config, all_data, sa, batch_indices, num_batches) + fitted_weights, overall_loss, loss_fn = one_d_loop(config, all_data, sa, sample_indices, num_batches) mlflow.log_metrics({"overall loss": float(overall_loss)}) mlflow.log_metrics({"fit_time": round(time.time() - t1, 2)}) mlflow.set_tag("status", "postprocessing") print("postprocessing") - final_params = postprocess.postprocess(config, batch_indices, all_data, all_axes, ts_fitter, sa, fitted_weights) + final_params = postprocess.postprocess(config, sample_indices, all_data, all_axes, loss_fn, sa, fitted_weights) return final_params, float(overall_loss) + + +def load_data_for_fitting(config): + if isinstance(config["data"]["shotnum"], list): + startCCDsize = config["other"]["CCDsize"] + all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"][0]) + config["other"]["CCDsize"] = startCCDsize + all_data2, _, _ = prepare.prepare_data(config, config["data"]["shotnum"][1]) + all_data.update( + { + "e_data_rot": all_data2["e_data"], + "e_amps_rot": all_data2["e_amps"], + "rot_angle": config["data"]["shot_rot"], + "noiseE_rot": all_data2["noiseE"], + } + ) + + if config["other"]["extraoptions"]["spectype"] != "angular_full": + raise NotImplementedError("Muliplexed data fitting is only availible for angular data") + else: + all_data, sa, all_axes = prepare.prepare_data(config, config["data"]["shotnum"]) + return all_data, sa, all_axes diff --git a/tsadar/inverse/loss_function.py 
b/tsadar/inverse/loss_function.py new file mode 100644 index 00000000..50b962a8 --- /dev/null +++ b/tsadar/inverse/loss_function.py @@ -0,0 +1,474 @@ +import copy +from typing import Dict + +import jax +from jax import numpy as jnp +from equinox import filter_value_and_grad, filter_hessian, filter_jit +from jax.flatten_util import ravel_pytree +import numpy as np +import equinox as eqx + +from ..core.thomson_diagnostic import ThomsonScatteringDiagnostic +from ..utils.vector_tools import rotate + + +class LossFunction: + """ + This class is responsible for handling the forward pass and using that to create a loss function + + """ + + def __init__(self, cfg: Dict, scattering_angles, dummy_batch): + """ + + Args: + cfg: Configuration dictionary constructed from the inputs + scattering_angles: Dictionary containing the scattering angles and thier relative weights + dummy_batch: Dictionary of dummy data + """ + self.cfg = cfg + + if cfg["optimizer"]["y_norm"]: + self.i_norm = np.amax(dummy_batch["i_data"]) + self.e_norm = np.amax(dummy_batch["e_data"]) + else: + self.i_norm = self.e_norm = 1.0 + + if cfg["optimizer"]["x_norm"] and cfg["nn"]["use"]: + self.i_input_norm = np.amax(dummy_batch["i_data"]) + self.e_input_norm = np.amax(dummy_batch["e_data"]) + else: + self.i_input_norm = self.e_input_norm = 1.0 + + # boolean used to determine if the analyis is performed twice with rotation of the EDF + self.multiplex_ang = isinstance(cfg["data"]["shotnum"], list) + + ############ + + self.ts_diag = ThomsonScatteringDiagnostic(cfg, scattering_angles=scattering_angles) + + self._loss_ = filter_jit(self.__loss__) + self._vg_func_ = filter_jit(filter_value_and_grad(self.__loss__, has_aux=True)) + ## this will be replaced with jacobian params jacobian inverse + self._h_func_ = filter_jit(filter_hessian(self._loss_for_hess_fn_)) + self.array_loss = filter_jit(self.calc_loss) + + def _get_normed_batch_(self, batch: Dict): + """ + Normalizes the batch + + Args: + batch: + + Returns: 
+ + """ + normed_batch = copy.deepcopy(batch) + normed_batch["i_data"] = normed_batch["i_data"] / self.i_input_norm + normed_batch["e_data"] = normed_batch["e_data"] / self.e_input_norm + return normed_batch + + def vg_loss(self, diff_weights, static_weights: Dict, batch: Dict): + """ + This is the primary workhorse high level function. This function returns the value of the loss function which + is used to assess goodness-of-fit and the gradient of that value with respect to the weights, which is used to + update the weights + + This function is used by both optimization methods. It performs the necessary pre-/post- processing that is + needed to work with the optimization software. + + Args: + weights: + batch: + + Returns: + + """ + if self.cfg["optimizer"]["method"] == "l-bfgs-b": + # pytree_weights = self.ts_diag.unravel_pytree(weights) + + diff_weights = self.unravel_weights(diff_weights) + (value, aux), grad = self._vg_func_(diff_weights, static_weights, batch) + + # if "fe" in grad: + # grad["fe"] = self.cfg["optimizer"]["grad_scalar"] * grad["fe"] + + # for species in self.cfg["parameters"].keys(): + # for k, param_dict in self.cfg["parameters"][species].items(): + # if param_dict["active"]: + # scalar = param_dict["gradient_scalar"] if "gradient_scalar" in param_dict else 1.0 + # grad[species][k] *= scalar + + temp_grad, _ = ravel_pytree(grad) + flattened_grads = np.array(temp_grad) + return value, flattened_grads + else: + return self._vg_func_(diff_weights, static_weights, batch) + + def h_loss_wrt_params(self, weights, batch): + return self._h_func_(weights, batch) + + def _loss_for_hess_fn_(self, weights, batch): + # params = params | self.static_params + # params = self.ts_diag.get_plasma_parameters(weights) + ThryE, ThryI, lamAxisE, lamAxisI = self.ts_diag(params, batch) + i_error, e_error, _, _ = self.calc_ei_error( + batch, + ThryI, + lamAxisI, + ThryE, + lamAxisE, + uncert=[jnp.abs(batch["i_data"]) + 1e-10, jnp.abs(batch["e_data"]) + 1e-10], + 
reduce_func=jnp.sum, + ) + + return i_error + e_error + + def calc_ei_error(self, batch, ThryI, lamAxisI, ThryE, lamAxisE, uncert, reduce_func=jnp.mean): + """ + This function calculates the error in the fit of the IAW and EPW + + Args: + batch: dictionary containing the data + ThryI: ion theoretical spectrum + lamAxisI: ion wavelength axis + ThryE: electron theoretical spectrum + lamAxisE: electron wavelength axis + uncert: uncertainty values + reduce_func: method to combine all lineouts into a single metric + + Returns: + + """ + i_error = 0.0 + e_error = 0.0 + used_points = 0 + i_data = batch["i_data"] + e_data = batch["e_data"] + sqdev = {"ele": jnp.zeros(e_data.shape), "ion": jnp.zeros(i_data.shape)} + + if self.cfg["other"]["extraoptions"]["fit_IAW"]: + _error_ = self.loss_functionals(i_data, ThryI, uncert[0], method=self.cfg["optimizer"]["loss_method"]) + _error_ = jnp.where( + ( + (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_min"]) + & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_cf_min"]) + ) + | ( + (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_cf_max"]) + & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_max"]) + ), + _error_, + 0.0, + ) + + used_points += jnp.sum( + ( + (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_min"]) + & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_cf_min"]) + ) + | ( + (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_cf_max"]) + & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_max"]) + ) + ) + # this was temp code to help with 2 species fits + # _error_ = jnp.where( + # (lamAxisI > 526.25) & (lamAxisI < 526.75), + # 10.0 * _error_, + # _error_, + # ) + sqdev["ion"] = _error_ + i_error += reduce_func(_error_) + + if self.cfg["other"]["extraoptions"]["fit_EPWb"]: + _error_ = self.loss_functionals(e_data, ThryE, uncert[1], method=self.cfg["optimizer"]["loss_method"]) + _error_ = jnp.where( + (lamAxisE > self.cfg["data"]["fit_rng"]["blue_min"]) + & (lamAxisE < self.cfg["data"]["fit_rng"]["blue_max"]), + _error_, + 0.0, + ) + used_points += 
jnp.sum( + (lamAxisE > self.cfg["data"]["fit_rng"]["blue_min"]) + & (lamAxisE < self.cfg["data"]["fit_rng"]["blue_max"]) + ) + e_error += reduce_func(_error_) + sqdev["ele"] += _error_ + + if self.cfg["other"]["extraoptions"]["fit_EPWr"]: + _error_ = self.loss_functionals(e_data, ThryE, uncert[1], method=self.cfg["optimizer"]["loss_method"]) + _error_ = jnp.where( + (lamAxisE > self.cfg["data"]["fit_rng"]["red_min"]) + & (lamAxisE < self.cfg["data"]["fit_rng"]["red_max"]), + _error_, + 0.0, + ) + used_points += jnp.sum( + (lamAxisE > self.cfg["data"]["fit_rng"]["red_min"]) + & (lamAxisE < self.cfg["data"]["fit_rng"]["red_max"]) + ) + + e_error += reduce_func(_error_) + sqdev["ele"] += _error_ + + return i_error, e_error, sqdev, used_points + + def calc_loss(self, ts_params, batch: Dict): + """ + This function calculates the value of the loss function + + Args: + params: + batch: + + Returns: + + """ + + if self.multiplex_ang: + ThryE, ThryI, lamAxisE, lamAxisI = self.ts_diag(params, batch["b1"]) + # jax.debug.print("fe size {e_error}", e_error=jnp.shape(params["electron"]['fe'])) + params["electron"]["fe"] = rotate( + jnp.squeeze(params["electron"]["fe"]), self.cfg["data"]["shot_rot"] * jnp.pi / 180.0 + ) + + ThryE_rot, _, _, _ = self.ts_diag(params, batch["b2"]) + i_error1, e_error1, sqdev, used_points = self.calc_ei_error( + batch["b1"], + ThryI, + lamAxisI, + ThryE, + lamAxisE, + denom=[jnp.square(self.i_norm), jnp.square(self.e_norm)], + reduce_func=jnp.mean, + ) + i_error2, e_error2, sqdev, used_points = self.calc_ei_error( + batch["b2"], + ThryI, + lamAxisI, + ThryE_rot, + lamAxisE, + denom=[jnp.square(self.i_norm), jnp.square(self.e_norm)], + reduce_func=jnp.mean, + ) + i_error = i_error1 + i_error2 + e_error = e_error1 + e_error2 + + normed_batch = self._get_normed_batch_(batch["b1"]) + else: + ThryE, ThryI, lamAxisE, lamAxisI = self.ts_diag(ts_params, batch) + + i_error, e_error, sqdev, used_points = self.calc_ei_error( + batch, + ThryI, + lamAxisI, + 
ThryE, + lamAxisE, + uncert=[jnp.square(self.i_norm), jnp.square(self.e_norm)], + reduce_func=jnp.mean, + ) + + normed_batch = self._get_normed_batch_(batch) + + normed_e_data = normed_batch["e_data"] + ion_error = self.cfg["data"]["ion_loss_scale"] * i_error + + penalty_error = 0.0 # self.penalties(weights) + total_loss = ion_error + e_error + penalty_error + # jax.debug.print("e_error {total_loss}", total_loss=e_error) + + return total_loss, sqdev, used_points, ThryE, ThryI, ts_params() + # return total_loss, [ThryE, params] + + def loss(self, weights, batch: Dict): + """ + High level function that returns the value of the loss function + + Args: + weights: + batch: Dict + + Returns: + + """ + if self.cfg["optimizer"]["method"] == "l-bfgs-b": + pytree_weights = self.unravel_pytree(weights) + value, _ = self._loss_(pytree_weights, batch) + return value + else: + return self._loss_(weights, batch) + + def __loss__(self, diff_weights, static_weights, batch: Dict): + """ + Output wrapper + """ + + weights = eqx.combine(static_weights, diff_weights) + total_loss, sqdev, used_points, ThryE, normed_e_data, params = self.calc_loss(weights, batch) + return total_loss, [ThryE, params] + + def loss_functionals(self, d, t, uncert, method="l2"): + """ + This function calculates the error loss metric between d and t for different metrics sepcified by method, + with the default being the l2 norm + + Args: + d: data array + t: theory array + uncert: uncertainty values + method: name of the loss metric method, l1, l2, poisson, log-cosh. 
Currently only l1 and l2 include the uncertainties + + Returns: + loss: value of the loss metric per slice + + """ + if method == "l1": + _error_ = jnp.abs(d - t) / uncert + elif method == "l2": + _error_ = jnp.square(d - t) / uncert + elif method == "log-cosh": + _error_ = jnp.log(jnp.cosh(d - t)) + elif method == "poisson": + _error_ = t - d * jnp.log(t) + return _error_ + + def penalties(self, weights): + """ + This function calculates additional penatlities to be added to the loss function + + Args: + params: parameter weights as supplied to the loss function + batch: + + Returns: + + """ + param_penalty = 0.0 + # this will need to be modified for the params instead of weights + for species in weights.keys(): + for k in weights[species].keys(): + if k != "fe": + # jax.debug.print("fe size {e_error}", e_error=weights[species][k]) + param_penalty += jnp.maximum(0.0, jnp.log(jnp.abs(weights[species][k] - 0.5) + 0.5)) + if self.cfg["optimizer"]["moment_loss"]: + density_loss, temperature_loss, momentum_loss = self._moment_loss_(weights) + param_penalty = param_penalty + density_loss + temperature_loss + momentum_loss + else: + density_loss = 0.0 + temperature_loss = 0.0 + momentum_loss = 0.0 + if self.cfg["parameters"]["electron"]["fe"]["fe_decrease_strict"]: + gradfe = jnp.sign(self.cfg["velocity"][1:]) * jnp.diff(params["fe"].squeeze()) + vals = jnp.where(gradfe > 0.0, gradfe, 0.0).sum() + fe_penalty = jnp.tan(jnp.amin(jnp.array([vals, jnp.pi / 2]))) + else: + fe_penalty = 0.0 + # jax.debug.print("e_err {e_error}", e_error=e_error) + # jax.debug.print("{density_loss}", density_loss=density_loss) + # jax.debug.print("{temperature_loss}", temperature_loss=temperature_loss) + # jax.debug.print("{momentum_loss}", momentum_loss=momentum_loss) + # jax.debug.print("tot loss {total_loss}", total_loss=total_loss) + # jax.debug.print("param_penalty {total_loss}", total_loss=jnp.sum(param_penalty)) + + return jnp.sum(param_penalty) + fe_penalty + density_loss + 
temperature_loss + momentum_loss + + def _moment_loss_(self, params): + """ + This function calculates the loss associated with regularizing the moments of the distribution function i.e. + the density should be 1, the temperature should be 1, and momentum should be 0. + + Args: + params: + + Returns: + + """ + if self.cfg["parameters"]["electron"]["fe"]["dim"] == 1: + dv = ( + self.cfg["parameters"]["electron"]["fe"]["velocity"][1] + - self.cfg["parameters"]["electron"]["fe"]["velocity"][0] + ) + if self.cfg["parameters"]["electron"]["fe"]["symmetric"]: + density_loss = jnp.mean(jnp.square(1.0 - 2.0 * jnp.sum(jnp.exp(params["electron"]["fe"]) * dv, axis=1))) + temperature_loss = jnp.mean( + jnp.square( + 1.0 + - 2.0 + * jnp.sum( + jnp.exp(params["electron"]["fe"]) + * self.cfg["parameters"]["electron"]["fe"]["velocity"] ** 2.0 + * dv, + axis=1, + ) + ) + ) + else: + density_loss = jnp.mean(jnp.square(1.0 - jnp.sum(jnp.exp(params["electron"]["fe"]) * dv, axis=1))) + temperature_loss = jnp.mean( + jnp.square( + 1.0 + - jnp.sum( + jnp.exp(params["electron"]["fe"]) + * self.cfg["parameters"]["electron"]["fe"]["velocity"] ** 2.0 + * dv, + axis=1, + ) + ) + ) + momentum_loss = jnp.mean( + jnp.square( + jnp.sum( + jnp.exp(params["electron"]["fe"]) * self.cfg["parameters"]["electron"]["fe"]["velocity"] * dv, + axis=1, + ) + ) + ) + else: + fedens = ( + jnp.sum(jnp.exp(params["electron"]["fe"])) * self.cfg["parameters"]["electron"]["fe"]["v_res"] ** 2.0 + ) + jax.debug.print("zero moment = {fedens}", fedens=fedens) + density_loss = jnp.mean(jnp.square(1.0 - fedens)) + + # density_loss = jnp.mean( + # jnp.square( + # 1.0 + # - trapz( + # trapz( + # jnp.exp(params["electron"]["fe"]), self.cfg["parameters"]["electron"]["fe"]["v_res"] + # ), + # self.cfg["parameters"]["electron"]["fe"]["v_res"], + # ) + # ) + # ) + second_moment = ( + jnp.sum( + jnp.exp(params["electron"]["fe"]) + * ( + self.cfg["parameters"]["electron"]["fe"]["velocity"][0] ** 2 + + 
self.cfg["parameters"]["electron"]["fe"]["velocity"][1] ** 2 + ) + ) + * self.cfg["parameters"]["electron"]["fe"]["v_res"] ** 2.0 + ) + jax.debug.print("second moment = {fedens}", fedens=second_moment) + temperature_loss = jnp.mean(jnp.square(1.0 - second_moment / 2)) + # needs to be fixed + first_moment = second_moment = trapz( + trapz( + jnp.exp(params["electron"]["fe"]) + * ( + self.cfg["parameters"]["electron"]["fe"]["velocity"][0] ** 2 + + self.cfg["parameters"]["electron"]["fe"]["velocity"][1] ** 2 + ) + ** (1 / 2), + self.cfg["parameters"]["electron"]["fe"]["v_res"], + ), + self.cfg["parameters"]["electron"]["fe"]["v_res"], + ) + jax.debug.print("first moment = {fedens}", fedens=first_moment) + # momentum_loss = jnp.mean(jnp.square(jnp.sum(jnp.exp(params["fe"]) * self.cfg["velocity"] * dv, axis=1))) + momentum_loss = 0.0 + # print(temperature_loss) + return density_loss, temperature_loss, momentum_loss diff --git a/tsadar/model/TSFitter.py b/tsadar/model/TSFitter.py deleted file mode 100644 index 5df77a44..00000000 --- a/tsadar/model/TSFitter.py +++ /dev/null @@ -1,743 +0,0 @@ -import copy -from typing import Dict - -import jax -from jax import numpy as jnp - - -from jax import jit, value_and_grad -from jax.flatten_util import ravel_pytree -from interpax import interp2d -import numpy as np - -from tsadar.model.spectrum import SpectrumCalculator -from tsadar.distribution_functions.dist_functional_forms import calc_moment, trapz -from tsadar.misc.vector_tools import rotate - - -class TSFitter: - """ - This class is responsible for handling the forward pass and using that to create a loss function - - Args: - cfg: Configuration dictionary - sas: TODO - dummy_batch: Dictionary of dummy data - - """ - - def __init__(self, cfg: Dict, sas, dummy_batch): - """ - - Args: - cfg: Configuration dictionary constructed from the inputs - sas: Dictionary containing the scattering angles and thier relative weights - dummy_batch: Dictionary of dummy data - """ - self.cfg = 
cfg - - if cfg["optimizer"]["y_norm"]: - self.i_norm = np.amax(dummy_batch["i_data"]) - self.e_norm = np.amax(dummy_batch["e_data"]) - else: - self.i_norm = self.e_norm = 1.0 - - if cfg["optimizer"]["x_norm"] and cfg["nn"]["use"]: - self.i_input_norm = np.amax(dummy_batch["i_data"]) - self.e_input_norm = np.amax(dummy_batch["e_data"]) - else: - self.i_input_norm = self.e_input_norm = 1.0 - - # this will need to be fixed for multi electron - for species in self.cfg["parameters"].keys(): - if "electron" in self.cfg["parameters"][species]["type"].keys(): - self.e_species = species - - #boolean used to determine if the analyis is performed twice with rotation of the EDF - self.multiplex_ang = isinstance(cfg["data"]["shotnum"],list) - - self.spec_calc = SpectrumCalculator(cfg, sas, dummy_batch) - - self._loss_ = jit(self.__loss__) - self._vg_func_ = jit(value_and_grad(self.__loss__, argnums=0, has_aux=True)) - ##this will be replaced with jacobian params jacobian inverse - self._h_func_ = jit(jax.hessian(self._loss_for_hess_fn_, argnums=0)) - self.array_loss = jit(self.calc_loss) - - ############ - - - lb, ub, init_weights = init_weights_and_bounds(cfg, num_slices=cfg["optimizer"]["batch_size"]) - self.flattened_weights, self.unravel_pytree = ravel_pytree(init_weights["active"]) - self.static_params = init_weights["inactive"] - self.pytree_weights = init_weights - self.lb = lb - self.ub = ub - self.construct_bounds() - - # this needs to be rethought and does not work in all cases - if cfg["parameters"][self.e_species]["fe"]["active"]: - if "dist_fit" in cfg: - if cfg["parameters"][self.e_species]["fe"]["dim"] == 1: - self.smooth_window_len = round( - cfg["parameters"][self.e_species]["fe"]["velocity"].size * cfg["dist_fit"]["window"]["len"] - ) - self.smooth_window_len = self.smooth_window_len if self.smooth_window_len > 1 else 2 - - if cfg["dist_fit"]["window"]["type"] == "hamming": - self.w = jnp.hamming(self.smooth_window_len) - elif cfg["dist_fit"]["window"]["type"] 
== "hann": - self.w = jnp.hanning(self.smooth_window_len) - elif cfg["dist_fit"]["window"]["type"] == "bartlett": - self.w = jnp.bartlett(self.smooth_window_len) - else: - raise NotImplementedError - else: - Warning("Smoothing not enabled for 2D distributions") - else: - Warning( - "\n !!! Distribution function not fitted !!! Make sure this is what you thought you were running \n" - ) - - def construct_bounds(self): - """ - This method construct a bounds zip from the upper and lower bounds. This allows the iterable to be reconstructed - after being used in a fit. - - Args: - - Returns: - - """ - flattened_lb, _ = ravel_pytree(self.lb) - flattened_ub, _ = ravel_pytree(self.ub) - self.bounds = zip(flattened_lb, flattened_ub) - - def smooth(self, distribution: jnp.ndarray) -> jnp.ndarray: - """ - This method is used to smooth the distribution function. It sits right in between the optimization algorithm - that provides the weights/values of the distribution function and the fitting code that uses it. - - Because the optimizer is not constrained to provide a smooth distribution function, this operation smoothens - the output. This is a differentiable operation and we train/fit our weights through this. - - Args: - distribution: - - Returns: - - """ - s = jnp.r_[ - distribution[self.smooth_window_len - 1 : 0 : -1], - distribution, - distribution[-2 : -self.smooth_window_len - 1 : -1], - ] - return jnp.convolve(self.w / self.w.sum(), s, mode="same")[ - self.smooth_window_len - 1 : -(self.smooth_window_len - 1) - ] - - def smooth2D(self, distribution: jnp.ndarray) -> jnp.ndarray: - """ - This method is used to smooth the distribution function. It sits right in between the optimization algorithm - that provides the weights/values of the distribution function and the fitting code that uses it. - - Because the optimizer is not constrained to provide a smooth distribution function, this operation smoothens - the output. 
This is a differentiable operation and we train/fit our weights through this. - - Args: - distribution: - - Returns: - - """ - - smoothing_kernel = jnp.outer(jnp.bartlett(5),jnp.bartlett(5)) - smoothing_kernel = smoothing_kernel/jnp.sum(smoothing_kernel) - #print(distribution) - #print(jnp.shape(distribution)) - - return jax.scipy.signal.convolve2d(distribution,smoothing_kernel,'same') - - def weights_to_params(self, input_weights: Dict, return_static_params: bool = True) -> Dict: - """ - This function creates the physical parameters used in the TS algorithm from the weights. The input input_weights - is mapped to these_params causing the input_weights to also be modified. - - This could be a 1:1 mapping, or it could be a linear transformation e.g. "normalized" parameters, or it could - be something else altogether e.g. a neural network - - Args: - input_weights: dictionary of weights used or supplied by the minimizer, these are bounded from 0 to 1 - return_static_params: boolean determining if the static parameters (these not modified during fitting) will - be inculded in the retuned dictionary. This is nessesary for the physics model which requires values for all - parameters. 
- - Returns: - these_params: dictionary of the paramters in physics units - - """ - Te_mult=1.0 - ne_mult=1.0 - these_params = copy.deepcopy(input_weights) - for species in self.cfg["parameters"].keys(): - for param_name, param_config in self.cfg["parameters"][species].items(): - if param_name == "type": - continue - if param_config["active"]: - if param_name != "fe": - these_params[species][param_name] = ( - these_params[species][param_name] * self.cfg["units"]["norms"][species][param_name] - + self.cfg["units"]["shifts"][species][param_name] - ) - else: - fe_shape = jnp.shape(these_params[species][param_name]) - #convert EDF from 01 bounded log units to unbounded log units - #jax.debug.print("these params {a}", a=these_params[species][param_name]) - - fe_cur = jnp.exp( - these_params[species][param_name] * self.cfg["units"]["norms"][species][param_name].reshape(fe_shape) - + self.cfg["units"]["shifts"][species][param_name].reshape(fe_shape) - ) - #commented out the renormalization to see effect on 2D edfs 9/26/24 - #jax.debug.print("fe_cur {a}", a=fe_cur) - #this only works for 2D edfs and will have to be genralized to 1D - #recaclulate the moments of the EDF - renorm = jnp.sqrt( - calc_moment(jnp.squeeze(fe_cur), - self.cfg["parameters"][self.e_species]["fe"]["velocity"],2) - / (2*calc_moment(jnp.squeeze(fe_cur), - self.cfg["parameters"][self.e_species]["fe"]["velocity"],0))) - Te_mult = renorm**2 - #h2 = self.cfg["parameters"][self.e_species]["fe"]["v_res"]/renorm - vx2 = self.cfg["parameters"][self.e_species]["fe"]["velocity"][0][0]/renorm - vy2 = self.cfg["parameters"][self.e_species]["fe"]["velocity"][0][0]/renorm - # fe_cur = interp2d( - # self.cfg["parameters"][self.e_species]["fe"]["velocity"][0].flatten(), - # self.cfg["parameters"][self.e_species]["fe"]["velocity"][1].flatten(), - # vx2, vy2, - # jnp.squeeze(fe_cur), - # extrap=[0, 0], method="linear").reshape( - # jnp.shape(self.cfg["parameters"][self.e_species]["fe"]["velocity"][0]),order="F") - 
fe_cur = jnp.exp(interp2d( - self.cfg["parameters"][self.e_species]["fe"]["velocity"][0].flatten(), - self.cfg["parameters"][self.e_species]["fe"]["velocity"][1].flatten(), - vx2, vy2, - jnp.log(jnp.squeeze(fe_cur)), - extrap=[-100, -100], method="linear").reshape( - jnp.shape(self.cfg["parameters"][self.e_species]["fe"]["velocity"][0]),order="F")) - ne_mult = calc_moment(jnp.squeeze(fe_cur), - self.cfg["parameters"][self.e_species]["fe"]["velocity"],0) - fe_cur = fe_cur/ ne_mult - these_params[species][param_name]=jnp.log(fe_cur) - - - if self.cfg["parameters"][species]["fe"]["dim"] == 1: - these_params[species]["fe"] = jnp.log( - self.smooth(jnp.exp(these_params[species]["fe"][0]))[None, :] - ) - elif self.cfg["dist_fit"]['smooth']: - these_params[species]["fe"] = self.smooth2D(these_params[species]['fe']) - # jnp.log( - # self.smooth2D(jnp.exp(these_params[species]["fe"][0])) - # ) - # these_params["fe"] = jnp.log(self.smooth(jnp.exp(these_params["fe"]))) - - else: - if return_static_params: - these_params[species][param_name] = self.static_params[species][param_name] - - #need to confirm this works due to jax imutability - #jax.debug.print("Temult {total_loss}", total_loss=Te_mult) - #jax.debug.print("nemult {total_loss}", total_loss=ne_mult) - #jax.debug.print("Tebefore {total_loss}", total_loss=these_params[self.e_species]['Te']) - these_params[self.e_species]['Te']*=Te_mult - these_params[self.e_species]['ne']*=ne_mult - #jax.debug.print("Teafter {total_loss}", total_loss=these_params[self.e_species]['Te']) - #jax.debug.print("fe after has NANs {total_loss}", total_loss=jnp.isnan(fe_cur)) - - return these_params - - - def _get_normed_batch_(self, batch: Dict): - """ - Normalizes the batch - - Args: - batch: - - Returns: - - """ - normed_batch = copy.deepcopy(batch) - normed_batch["i_data"] = normed_batch["i_data"] / self.i_input_norm - normed_batch["e_data"] = normed_batch["e_data"] / self.e_input_norm - return normed_batch - - - def vg_loss(self, weights: 
Dict, batch: Dict): - """ - This is the primary workhorse high level function. This function returns the value of the loss function which - is used to assess goodness-of-fit and the gradient of that value with respect to the weights, which is used to - update the weights - - This function is used by both optimization methods. It performs the necessary pre-/post- processing that is - needed to work with the optimization software. - - Args: - weights: - batch: - - Returns: - - """ - if self.cfg["optimizer"]["method"] == "l-bfgs-b": - pytree_weights = self.unravel_pytree(weights) - (value, aux), grad = self._vg_func_(pytree_weights, batch) - - if "fe" in grad: - grad["fe"] = self.cfg["optimizer"]["grad_scalar"] * grad["fe"] - - for species in self.cfg["parameters"].keys(): - for k, param_dict in self.cfg["parameters"][species].items(): - if param_dict["active"]: - scalar = param_dict["gradient_scalar"] if "gradient_scalar" in param_dict else 1.0 - grad[species][k] *= scalar - - temp_grad, _ = ravel_pytree(grad) - flattened_grads = np.array(temp_grad) - return value, flattened_grads - else: - return self._vg_func_(weights, batch) - - def h_loss_wrt_params(self, weights, batch): - return self._h_func_(weights, batch) - - def _loss_for_hess_fn_(self, weights, batch): - # params = params | self.static_params - params = self.weights_to_params(weights) - ThryE, ThryI, lamAxisE, lamAxisI = self.spec_calc(params, batch) - i_error, e_error, _, _ = self.calc_ei_error( - batch, - ThryI, - lamAxisI, - ThryE, - lamAxisE, - uncert=[jnp.abs(batch["i_data"]) + 1e-10, jnp.abs(batch["e_data"]) + 1e-10], - reduce_func=jnp.sum, - ) - - return i_error + e_error - - def calc_ei_error(self, batch, ThryI, lamAxisI, ThryE, lamAxisE, uncert, reduce_func=jnp.mean): - """ - This function calculates the error in the fit of the IAW and EPW - - Args: - batch: dictionary containing the data - ThryI: ion theoretical spectrum - lamAxisI: ion wavelength axis - ThryE: electron theoretical spectrum - 
lamAxisE: electron wavelength axis - uncert: uncertainty values - reduce_func: method to combine all lineouts into a single metric - - Returns: - - """ - i_error = 0.0 - e_error = 0.0 - used_points = 0 - i_data = batch["i_data"] - e_data = batch["e_data"] - sqdev = {"ele": jnp.zeros(e_data.shape), "ion": jnp.zeros(i_data.shape)} - - if self.cfg["other"]["extraoptions"]["fit_IAW"]: - _error_ = self.loss_functionals(i_data, ThryI, uncert[0], method = self.cfg['optimizer']['loss_method']) - _error_ = jnp.where( - ( - (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_min"]) - & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_cf_min"]) - ) - | ( - (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_cf_max"]) - & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_max"]) - ), - _error_, - 0.0, - ) - - used_points += jnp.sum( - ( - (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_min"]) - & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_cf_min"]) - ) - | ( - (lamAxisI > self.cfg["data"]["fit_rng"]["iaw_cf_max"]) - & (lamAxisI < self.cfg["data"]["fit_rng"]["iaw_max"]) - ) - ) - #this was temp code to help with 2 species fits - # _error_ = jnp.where( - # (lamAxisI > 526.25) & (lamAxisI < 526.75), - # 10.0 * _error_, - # _error_, - # ) - sqdev["ion"] = _error_ - i_error += reduce_func(_error_) - - if self.cfg["other"]["extraoptions"]["fit_EPWb"]: - _error_ = self.loss_functionals(e_data, ThryE, uncert[1], method = self.cfg['optimizer']['loss_method']) - _error_ = jnp.where( - (lamAxisE > self.cfg["data"]["fit_rng"]["blue_min"]) - & (lamAxisE < self.cfg["data"]["fit_rng"]["blue_max"]), - _error_, - 0.0, - ) - used_points += jnp.sum( - (lamAxisE > self.cfg["data"]["fit_rng"]["blue_min"]) - & (lamAxisE < self.cfg["data"]["fit_rng"]["blue_max"]) - ) - e_error += reduce_func(_error_) - sqdev["ele"] += _error_ - - - if self.cfg["other"]["extraoptions"]["fit_EPWr"]: - _error_ = self.loss_functionals(e_data, ThryE, uncert[1], method = self.cfg['optimizer']['loss_method']) - _error_ = jnp.where( - (lamAxisE > 
self.cfg["data"]["fit_rng"]["red_min"]) - & (lamAxisE < self.cfg["data"]["fit_rng"]["red_max"]), - _error_, - 0.0, - ) - used_points += jnp.sum( - (lamAxisE > self.cfg["data"]["fit_rng"]["red_min"]) - & (lamAxisE < self.cfg["data"]["fit_rng"]["red_max"]) - ) - - e_error += reduce_func(_error_) - sqdev["ele"] += _error_ - - return i_error, e_error, sqdev, used_points - - def calc_loss(self, weights, batch: Dict): - """ - This function calculates the value of the loss function - - Args: - params: - batch: - - Returns: - - """ - params = self.weights_to_params(weights) - - if self.multiplex_ang: - ThryE, ThryI, lamAxisE, lamAxisI = self.spec_calc(params, batch['b1']) - #jax.debug.print("fe size {e_error}", e_error=jnp.shape(params[self.e_species]['fe'])) - params[self.e_species]['fe']=rotate(jnp.squeeze(params[self.e_species]['fe']),self.cfg['data']['shot_rot']*jnp.pi/180.0) - - ThryE_rot, _, _, _ = self.spec_calc(params, batch['b2']) - i_error1, e_error1, sqdev, used_points = self.calc_ei_error( - batch['b1'], - ThryI, - lamAxisI, - ThryE, - lamAxisE, - denom=[jnp.square(self.i_norm), jnp.square(self.e_norm)], - reduce_func=jnp.mean, - ) - i_error2, e_error2, sqdev, used_points = self.calc_ei_error( - batch['b2'], - ThryI, - lamAxisI, - ThryE_rot, - lamAxisE, - denom=[jnp.square(self.i_norm), jnp.square(self.e_norm)], - reduce_func=jnp.mean, - ) - i_error = i_error1 +i_error2 - e_error = e_error1 +e_error2 - - normed_batch = self._get_normed_batch_(batch['b1']) - else: - ThryE, ThryI, lamAxisE, lamAxisI = self.spec_calc(params, batch) - - i_error, e_error, sqdev, used_points = self.calc_ei_error( - batch, - ThryI, - lamAxisI, - ThryE, - lamAxisE, - uncert=[jnp.square(self.i_norm), jnp.square(self.e_norm)], - reduce_func=jnp.mean, - ) - - - normed_batch = self._get_normed_batch_(batch) - - normed_e_data = normed_batch["e_data"] - ion_error = self.cfg["data"]["ion_loss_scale"] * i_error - - penalty_error = self.penalties(weights) - total_loss = ion_error + e_error + 
penalty_error - #jax.debug.print("e_error {total_loss}", total_loss=e_error) - - return total_loss, sqdev, used_points, ThryE, ThryI, params - #return total_loss, [ThryE, params] - def loss(self, weights, batch: Dict): - """ - High level function that returns the value of the loss function - - Args: - weights: - batch: Dict - - Returns: - - """ - if self.cfg["optimizer"]["method"] == "l-bfgs-b": - pytree_weights = self.unravel_pytree(weights) - value, _ = self._loss_(pytree_weights, batch) - return value - else: - return self._loss_(weights, batch) - - def __loss__(self, weights, batch: Dict): - """ - Output wrapper - """ - - total_loss, sqdev, used_points, ThryE, normed_e_data, params = self.calc_loss(weights, batch) - return total_loss, [ThryE, params] - - def loss_functionals(self,d,t,uncert,method='l2'): - """ - This function calculates the error loss metric between d and t for different metrics sepcified by method, - with the default being the l2 norm - - Args: - d: data array - t: theory array - uncert: uncertainty values - method: name of the loss metric method, l1, l2, poisson, log-cosh. 
Currently only l1 and l2 include the uncertainties - - Returns: - loss: value of the loss metric per slice - - """ - if method == 'l1': - _error_= jnp.abs(d - t) / uncert - elif method == 'l2': - _error_ = jnp.square(d - t) / uncert - elif method == 'log-cosh': - _error_ = jnp.log(jnp.cosh(d - t)) - elif method == 'poisson': - _error_ = t-d*jnp.log(t) - return _error_ - - - def penalties(self, weights): - """ - This function calculates additional penatlities to be added to the loss function - - Args: - params: parameter weights as supplied to the loss function - batch: - - Returns: - - """ - param_penalty = 0.0 - #this will need to be modified for the params instead of weights - for species in weights.keys(): - for k in weights[species].keys(): - if k!='fe': - #jax.debug.print("fe size {e_error}", e_error=weights[species][k]) - param_penalty += jnp.maximum(0.0, jnp.log(jnp.abs(weights[species][k] - 0.5) + 0.5)) - if self.cfg['optimizer']['moment_loss']: - density_loss, temperature_loss, momentum_loss = self._moment_loss_(weights) - param_penalty= param_penalty+density_loss+temperature_loss+momentum_loss - else: - density_loss = 0.0 - temperature_loss=0.0 - momentum_loss=0.0 - if self.cfg["parameters"][self.e_species]["fe"]["fe_decrease_strict"]: - gradfe = jnp.sign(self.cfg["velocity"][1:]) * jnp.diff(params["fe"].squeeze()) - vals = jnp.where(gradfe > 0.0, gradfe, 0.0).sum() - fe_penalty = jnp.tan(jnp.amin(jnp.array([vals, jnp.pi / 2]))) - else: - fe_penalty = 0.0 - #jax.debug.print("e_err {e_error}", e_error=e_error) - # jax.debug.print("{density_loss}", density_loss=density_loss) - # jax.debug.print("{temperature_loss}", temperature_loss=temperature_loss) - # jax.debug.print("{momentum_loss}", momentum_loss=momentum_loss) - #jax.debug.print("tot loss {total_loss}", total_loss=total_loss) - #jax.debug.print("param_penalty {total_loss}", total_loss=jnp.sum(param_penalty)) - - return jnp.sum(param_penalty)+fe_penalty+density_loss+temperature_loss+momentum_loss - - 
def _moment_loss_(self, params): - """ - This function calculates the loss associated with regularizing the moments of the distribution function i.e. - the density should be 1, the temperature should be 1, and momentum should be 0. - - Args: - params: - - Returns: - - """ - if self.cfg["parameters"][self.e_species]["fe"]["dim"] == 1: - dv = ( - self.cfg["parameters"][self.e_species]["fe"]["velocity"][1] - - self.cfg["parameters"][self.e_species]["fe"]["velocity"][0] - ) - if self.cfg["parameters"][self.e_species]["fe"]["symmetric"]: - density_loss = jnp.mean( - jnp.square(1.0 - 2.0 * jnp.sum(jnp.exp(params[self.e_species]["fe"]) * dv, axis=1)) - ) - temperature_loss = jnp.mean( - jnp.square( - 1.0 - - 2.0 - * jnp.sum( - jnp.exp(params[self.e_species]["fe"]) - * self.cfg["parameters"][self.e_species]["fe"]["velocity"] ** 2.0 - * dv, - axis=1, - ) - ) - ) - else: - density_loss = jnp.mean(jnp.square(1.0 - jnp.sum(jnp.exp(params[self.e_species]["fe"]) * dv, axis=1))) - temperature_loss = jnp.mean( - jnp.square( - 1.0 - - jnp.sum( - jnp.exp(params[self.e_species]["fe"]) - * self.cfg["parameters"][self.e_species]["fe"]["velocity"] ** 2.0 - * dv, - axis=1, - ) - ) - ) - momentum_loss = jnp.mean( - jnp.square( - jnp.sum( - jnp.exp(params[self.e_species]["fe"]) - * self.cfg["parameters"][self.e_species]["fe"]["velocity"] - * dv, - axis=1, - ) - ) - ) - else: - fedens = trapz( - trapz( - jnp.exp(params[self.e_species]["fe"]), self.cfg["parameters"][self.e_species]["fe"]["v_res"] - ), - self.cfg["parameters"][self.e_species]["fe"]["v_res"], - ) - jax.debug.print("zero moment = {fedens}", fedens=fedens) - density_loss = jnp.mean(jnp.square(1.0-fedens)) - - # density_loss = jnp.mean( - # jnp.square( - # 1.0 - # - trapz( - # trapz( - # jnp.exp(params[self.e_species]["fe"]), self.cfg["parameters"][self.e_species]["fe"]["v_res"] - # ), - # self.cfg["parameters"][self.e_species]["fe"]["v_res"], - # ) - # ) - # ) - second_moment = trapz( - trapz( - 
jnp.exp(params[self.e_species]["fe"]) - * (self.cfg["parameters"][self.e_species]["fe"]["velocity"][0]**2 - + self.cfg["parameters"][self.e_species]["fe"]["velocity"][1]**2), - self.cfg["parameters"][self.e_species]["fe"]["v_res"], - ), - self.cfg["parameters"][self.e_species]["fe"]["v_res"], - ) - jax.debug.print("second moment = {fedens}", fedens=second_moment) - temperature_loss = jnp.mean(jnp.square(1.0- second_moment/2)) - # needs to be fixed - first_moment = second_moment = trapz( - trapz( - jnp.exp(params[self.e_species]["fe"]) - * (self.cfg["parameters"][self.e_species]["fe"]["velocity"][0]**2 - + self.cfg["parameters"][self.e_species]["fe"]["velocity"][1]**2)**(1/2), - self.cfg["parameters"][self.e_species]["fe"]["v_res"], - ), - self.cfg["parameters"][self.e_species]["fe"]["v_res"], - ) - jax.debug.print("first moment = {fedens}", fedens=first_moment) - # momentum_loss = jnp.mean(jnp.square(jnp.sum(jnp.exp(params["fe"]) * self.cfg["velocity"] * dv, axis=1))) - momentum_loss = 0.0 - # print(temperature_loss) - return density_loss, temperature_loss, momentum_loss - -def init_weights_and_bounds(config, num_slices): - """ - this dict form will be unpacked for scipy consumption, we assemble them all in the same way so that we can then - use ravel pytree from JAX utilities to unpack it - Args: - config: - init_weights: - num_slices: - - Returns: - - """ - lb = {"active": {}, "inactive": {}} - ub = {"active": {}, "inactive": {}} - iw = {"active": {}, "inactive": {}} - - for species in config["parameters"].keys(): - lb["active"][species] = {} - ub["active"][species] = {} - iw["active"][species] = {} - lb["inactive"][species] = {} - ub["inactive"][species] = {} - iw["inactive"][species] = {} - - for species in config["parameters"].keys(): - for k, v in config["parameters"][species].items(): - if k == "type": - continue - if v["active"]: - active_or_inactive = "active" - else: - active_or_inactive = "inactive" - - if k != "fe": - iw[active_or_inactive][species][k] 
= np.array( - [config["parameters"][species][k]["val"] for _ in range(num_slices)] - )[:, None] - else: - iw[active_or_inactive][species][k] = np.concatenate( - [config["parameters"][species][k]["val"] for _ in range(num_slices)] - ) - - if v["active"]: - lb[active_or_inactive][species][k] = np.array( - [0 * config["units"]["lb"][species][k] for _ in range(num_slices)] - ) - ub[active_or_inactive][species][k] = np.array( - [1.0 + 0 * config["units"]["ub"][species][k] for _ in range(num_slices)] - ) - - if k != "fe": - iw[active_or_inactive][species][k] = ( - iw[active_or_inactive][species][k] - config["units"]["shifts"][species][k] - ) / config["units"]["norms"][species][k] - else: - iw[active_or_inactive][species][k] = ( - iw[active_or_inactive][species][k] - - config["units"]["shifts"][species][k].reshape(jnp.shape(iw[active_or_inactive][species][k])) - ) / config["units"]["norms"][species][k].reshape(jnp.shape(iw[active_or_inactive][species][k])) - - return lb, ub, iw diff --git a/tsadar/model/physics/generate_spectra.py b/tsadar/model/physics/generate_spectra.py deleted file mode 100644 index 0f345e58..00000000 --- a/tsadar/model/physics/generate_spectra.py +++ /dev/null @@ -1,267 +0,0 @@ -from typing import Dict - -from tsadar.model.physics.form_factor import FormFactor - -# from tsadar.distribution_functions.gen_num_dist_func import DistFunc -from tsadar.distribution_functions.gen_num_dist_func import DistFunc - -from jax import numpy as jnp - - -class FitModel: - """ - The FitModel Class wraps the FormFactor class adding finite aperture effects and finite volume effects. This class - also handles the options for calculating the form factor. 
- - Args: - config: Dict- configuration dictionary built from input deck - sa: Dict- has fields containing the scattering angles the spectrum will be calculated at and the relative - weights of each of the scattering angles in the final spectrum - """ - - def __init__(self, config: Dict, sa): - """ - FitModel class constructor, sets the static properties associated with spectrum generation that will not be - modified from one iteration of the fitter to the next. - - Args: - config: Dict- configuration dictionary built from input deck - sa: Dict- has fields containing the scattering angles the spectrum will be calculated at and the relative - weights of each of the scattering angles in the final spectrum - """ - self.config = config - self.sa = sa - # this will need to be fixed for multi electron - self.num_ions = 0 - self.num_electrons = 0 - for species in config["parameters"].keys(): - if "electron" in config["parameters"][species]["type"].keys(): - self.num_dist_func = DistFunc(config["parameters"][species]) - self.e_species = species - self.num_electrons += 1 - elif "ion" in config["parameters"][species]["type"].keys(): - self.num_ions += 1 - - #print(f"{config['other']['npts']=}") - self.electron_form_factor = FormFactor( - config["other"]["lamrangE"], - npts=config["other"]["npts"], - fe_dim=self.num_dist_func.dim, - vax=config["parameters"][self.e_species]["fe"]["velocity"], - ) - self.ion_form_factor = FormFactor( - config["other"]["lamrangI"], - npts=config["other"]["npts"], - fe_dim=self.num_dist_func.dim, - vax=config["parameters"][self.e_species]["fe"]["velocity"], - ) - - def __call__(self, all_params: Dict): - """ - Produces Thomson spectra corrected for finite aperture and optionally including gradients in the plasma - conditions based off the current parameter dictionary. 
Calling this method will automatically choose the - appropriate version of the formfactor class based off the dimension and distribute the conditions for - multiple ion species to their respective inputs. - - - Args: - all_params: Parameter dictionary containing the current values for all active and static parameters. Only a - few permanently static properties from the configuration dictionary will be used, everything else must - be included in this input. - - Returns: - modlE: calculated electron plasma wave spectrum as an array with length of npts. If an angular spectrum is - calculated then it will be 2D. If the EPW is not loaded this is returned as the int 0. - modlI: calculated ion acoustic wave spectrum as an array with length of npts. If the IAW is not loaded this - is returned as the int 0. - lamAxisE: electron plasma wave wavelength axis as an array with length of npts. If the EPW is not loaded - this is returned as an empty list. - lamAxisI: ion acoustic wave wavelength axis as an array with length of npts. If the IAW is not loaded - this is returned as an empty list. 
- all_params: The input all_params is returned - - """ - if self.config["parameters"][self.e_species]["m"]["active"]: - ( - self.config["parameters"][self.e_species]["fe"]["velocity"], - all_params[self.e_species]["fe"], - ) = self.num_dist_func(all_params[self.e_species]["m"]) - # self.config["velocity"], all_params["fe"] = self.num_dist_func(self.config["parameters"]["m"]["val"]) - all_params[self.e_species]["fe"] = jnp.log(all_params[self.e_species]["fe"]) - # all_params["fe"] = jnp.log(self.num_dist_func(self.config["parameters"]["m"])) - if ( - self.config["parameters"][self.e_species]["m"]["active"] - and self.config["parameters"][self.e_species]["fe"]["active"] - ): - raise ValueError("m and fe cannot be actively fit at the same time") - - # Add gradients to electron temperature and density just being applied to EPW - cur_Te = jnp.zeros((self.config["parameters"]["general"]["Te_gradient"]["num_grad_points"], self.num_electrons)) - cur_ne = jnp.zeros((self.config["parameters"]["general"]["ne_gradient"]["num_grad_points"], self.num_electrons)) - A = jnp.zeros(self.num_ions) - Z = jnp.zeros(self.num_ions) - Ti = jnp.zeros(self.num_ions) - fract = jnp.zeros(self.num_ions) - - ion_c = 0 - ele_c = 0 - for species in self.config["parameters"].keys(): - if "electron" in self.config["parameters"][species]["type"].keys(): - cur_Te = cur_Te.at[:, ele_c].set( - jnp.linspace( - (1 - all_params["general"]["Te_gradient"] / 200) * all_params[species]["Te"], - (1 + all_params["general"]["Te_gradient"] / 200) * all_params[species]["Te"], - self.config["parameters"]["general"]["Te_gradient"]["num_grad_points"], - ).squeeze() - ) - - cur_ne = cur_ne.at[:, ele_c].set( - ( - jnp.linspace( - (1 - all_params["general"]["ne_gradient"] / 200) * all_params[species]["ne"], - (1 + all_params["general"]["ne_gradient"] / 200) * all_params[species]["ne"], - self.config["parameters"]["general"]["ne_gradient"]["num_grad_points"], - ) - * 1e20 - ).squeeze() - ) - ele_c += 1 - - elif "ion" in 
self.config["parameters"][species]["type"].keys(): - A = A.at[ion_c].set(all_params[species]["A"].squeeze()) - Z = Z.at[ion_c].set(all_params[species]["Z"].squeeze()) - if self.config["parameters"][species]["Ti"]["same"]: - Ti = Ti.at[ion_c].set(Ti[ion_c - 1]) - else: - Ti = Ti.at[ion_c].set(all_params[species]["Ti"].squeeze()) - fract = fract.at[ion_c].set(all_params[species]["fract"].squeeze()) - ion_c += 1 - - lam = all_params["general"]["lam"] - - if self.config["parameters"][self.e_species]["m"]["active"]: - ( - self.config["parameters"][self.e_species]["fe"]["velocity"], - all_params[self.e_species]["fe"], - ) = self.num_dist_func(all_params[self.e_species]["m"]) - all_params[self.e_species]["fe"] = jnp.log(all_params[self.e_species]["fe"]) - if ( - self.config["parameters"][self.e_species]["m"]["active"] - and self.config["parameters"][self.e_species]["fe"]["active"] - ): - raise ValueError("m and fe cannot be actively fit at the same time") - elif self.config["parameters"][self.e_species]["m"]["matte"]: - # Intensity should be given in effective 3omega intensity e.i. 
I*lamda^2/lamda_3w^2 and in units of 10^14 W/cm^2 - alpha = ( - 0.042 - * self.config["parameters"][self.e_species]["m"]["intens"] - / 9.0 - * jnp.sum(Z**2 * fract) - / (jnp.sum(Z * fract) * cur_Te) - ) - mcur = 2.0 + 3.0 / (1.0 + 1.66 / (alpha**0.724)) - ( - self.config["parameters"][self.e_species]["fe"]["velocity"], - all_params[self.e_species]["fe"], - ) = self.num_dist_func(mcur.squeeze()) - all_params[self.e_species]["fe"] = jnp.log(all_params[self.e_species]["fe"]) - - fecur = jnp.exp(all_params[self.e_species]["fe"]) - vcur = self.config["parameters"][self.e_species]["fe"]["velocity"] - if self.config["parameters"][self.e_species]["fe"]["symmetric"]: - fecur = jnp.concatenate((jnp.flip(fecur[1:]), fecur)) - vcur = jnp.concatenate((-jnp.flip(vcur[1:]), vcur)) - - if self.config["other"]["extraoptions"]["load_ion_spec"]: - if self.num_dist_func.dim == 1: - ThryI, lamAxisI = self.ion_form_factor( - all_params, cur_ne, cur_Te, A, Z, Ti, fract, self.sa["sa"], (fecur, vcur), lam - ) - else: - ThryI, lamAxisI = self.ion_form_factor.calc_in_2D( - all_params, - self.config["parameters"]["general"]["ud"]["angle"], - self.config["parameters"]["general"]["ud"]["angle"], - cur_ne, - cur_Te, - A, - Z, - Ti, - fract, - self.sa["sa"], - (fecur, vcur), - lam, - ) - - # remove extra dimensions and rescale to nm - lamAxisI = jnp.squeeze(lamAxisI) * 1e7 # TODO hardcoded - - ThryI = jnp.real(ThryI) - ThryI = jnp.mean(ThryI, axis=0) - modlI = jnp.sum(ThryI * self.sa["weights"][0], axis=1) - else: - modlI = 0 - lamAxisI = [] - - if self.config["other"]["extraoptions"]["load_ele_spec"]: - if self.num_dist_func.dim == 1: - ThryE, lamAxisE = self.electron_form_factor( - all_params, - cur_ne, - cur_Te, - A, - Z, - Ti, - fract, - self.sa["sa"], - (fecur, vcur), - lam + self.config["data"]["ele_lam_shift"], - ) - else: - ThryE, lamAxisE = self.electron_form_factor.calc_in_2D( - all_params, - self.config["parameters"]["general"]["ud"]["angle"], - 
self.config["parameters"]["general"]["ud"]["angle"], - cur_ne, - cur_Te, - A, - Z, - Ti, - fract, - self.sa["sa"], - (fecur, vcur), - lam + self.config["data"]["ele_lam_shift"], - ) - - # remove extra dimensions and rescale to nm - lamAxisE = jnp.squeeze(lamAxisE) * 1e7 # TODO hardcoded - - ThryE = jnp.real(ThryE) - ThryE = jnp.mean(ThryE, axis=0) - if self.config["other"]["extraoptions"]["spectype"] == "angular_full": - modlE = jnp.matmul(self.sa["weights"], ThryE.transpose()) - else: - modlE = jnp.sum(ThryE * self.sa["weights"][0], axis=1) - - if self.config["other"]["iawoff"] and ( - self.config["other"]["lamrangE"][0] < lam < self.config["other"]["lamrangE"][1] - ): - # set the ion feature to 0 #should be switched to a range about lam - lamlocb = jnp.argmin(jnp.abs(lamAxisE - lam - 3.0)) - lamlocr = jnp.argmin(jnp.abs(lamAxisE - lam + 3.0)) - modlE = jnp.concatenate( - [modlE[:lamlocb], jnp.zeros(lamlocr - lamlocb), modlE[lamlocr:]] - ) # TODO hardcoded - - if self.config["other"]["iawfilter"][0]: - filterb = self.config["other"]["iawfilter"][3] - self.config["other"]["iawfilter"][2] / 2 - filterr = self.config["other"]["iawfilter"][3] + self.config["other"]["iawfilter"][2] / 2 - - if self.config["other"]["lamrangE"][0] < filterr and self.config["other"]["lamrangE"][1] > filterb: - indices = (filterb < lamAxisE) & (filterr > lamAxisE) - modlE = jnp.where(indices, modlE * 10 ** (-self.config["other"]["iawfilter"][1]), modlE) - else: - modlE = 0 - lamAxisE = [] - - return modlE, modlI, lamAxisE, lamAxisI, all_params diff --git a/tsadar/plotting/ele_fit_and_data (1).nc b/tsadar/plotting/ele_fit_and_data (1).nc deleted file mode 100644 index 1a33d6d4..00000000 Binary files a/tsadar/plotting/ele_fit_and_data (1).nc and /dev/null differ diff --git a/tsadar/process/postprocess.py b/tsadar/process/postprocess.py deleted file mode 100644 index 1f2839f3..00000000 --- a/tsadar/process/postprocess.py +++ /dev/null @@ -1,353 +0,0 @@ -from typing import Dict - -import time, 
tempfile, mlflow, os, copy - -import numpy as np -import scipy.optimize as spopt - -from tsadar.plotting import plotters -from tsadar.model.TSFitter import TSFitter - - -def recalculate_with_chosen_weights( - config: Dict, batch_indices, all_data: Dict, ts_fitter: TSFitter, calc_sigma: bool, fitted_weights: Dict -): - """ - Gets parameters and the result of the full forward pass i.e. fits - - - Args: - config: Dict- configuration dictionary built from input deck - batch_indices: - all_data: Dict- contains the electron data, ion data, and their respective amplitudes - ts_fitter: Instance of the TSFitter class - fitted_weights: Dict- best values of the parameters returned by the minimizer - - Returns: - - """ - - all_params = {} - num_params = 0 - for species in config["parameters"].keys(): - all_params[species] = {} - for key in config["parameters"][species].keys(): - if config["parameters"][species][key]["active"]: - all_params[species][key] = np.zeros( - (batch_indices.flatten()[-1] + 1, np.size(config["parameters"][species][key]["val"])), - dtype=np.float64, - ) - num_params += np.shape(all_params[species][key])[1] - batch_indices.sort() - losses = np.zeros(batch_indices.flatten()[-1] + 1, dtype=np.float64) - batch_indices = np.reshape(batch_indices, (-1, config["optimizer"]["batch_size"])) - - fits = {} - sqdevs = {} - fits["ion"] = np.zeros(all_data["i_data"].shape) - sqdevs["ion"] = np.zeros(all_data["i_data"].shape) - fits["ele"] = np.zeros(all_data["e_data"].shape) - sqdevs["ele"] = np.zeros(all_data["e_data"].shape) - - # num_params = 0 - # for key, vec in all_params.items(): - # num_params += np.shape(vec)[1] - - if config["other"]["extraoptions"]["load_ion_spec"]: - sigmas = np.zeros((all_data["i_data"].shape[0], num_params)) - - if config["other"]["extraoptions"]["load_ele_spec"]: - sigmas = np.zeros((all_data["e_data"].shape[0], num_params)) - - if config["other"]["extraoptions"]["spectype"] == "angular_full": - batch = { - "e_data": 
all_data["e_data"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], - "e_amps": all_data["e_amps"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], - "i_data": all_data["i_data"], - "i_amps": all_data["i_amps"], - "noise_e": all_data["noiseE"][ - config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], : - ], - "noise_i": all_data["noiseI"][ - config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], : - ], - } - losses, sqds, used_points, [ThryE, _, params] = ts_fitter.array_loss(fitted_weights, batch) - fits["ele"] = ThryE - sqdevs["ele"] = sqds["ele"] - - for species in all_params.keys(): - for k in all_params[species].keys(): - if k != 'fe': - # all_params[k] = np.concatenate([all_params[k], params[k].reshape(-1)]) - all_params[species][k] = params[species][k].reshape(-1) - else: - all_params[species][k] = params[species][k] - - if calc_sigma: - # this line may need to be omited since the weights may be transformed by line 77 - active_params = ts_fitter.weights_to_params(fitted_weights, return_static_params=False) - hess = ts_fitter.h_loss_wrt_params(active_params, batch) - sigmas = get_sigmas(hess, config["optimizer"]["batch_size"]) - print(f"Number of 0s in sigma: {len(np.where(sigmas==0)[0])}") - - else: - for i_batch, inds in enumerate(batch_indices): - batch = { - "e_data": all_data["e_data"][inds], - "e_amps": all_data["e_amps"][inds], - "i_data": all_data["i_data"][inds], - "i_amps": all_data["i_amps"][inds], - "noise_e": all_data["noiseE"][inds], - "noise_i": all_data["noiseI"][inds], - } - - #loss, sqds, used_points, ThryE, ThryI, params = ts_fitter.array_loss(fitted_weights[i_batch], batch) - loss, sqds, used_points, ThryE, ThryI, params = ts_fitter.array_loss(fitted_weights[i_batch], batch) - #calc_ei_error(self, batch, ThryI, lamAxisI, ThryE, lamAxisE, uncert, reduce_func=jnp.mean) - # these_params = ts_fitter.weights_to_params(fitted_weights[i_batch], 
return_static_params=False) - - if calc_sigma: - hess = ts_fitter.h_loss_wrt_params(fitted_weights[i_batch], batch) - try: - hess = ts_fitter.h_loss_wrt_params(fitted_weights[i_batch], batch) - except: - print("Error calculating Hessian, no hessian based uncertainties have been calculated") - calc_sigma = False - - losses[inds] = loss - sqdevs["ele"][inds] = sqds["ele"] - sqdevs["ion"][inds] = sqds["ion"] - if calc_sigma: - sigmas[inds] = get_sigmas(hess, config["optimizer"]["batch_size"]) - # print(f"Number of 0s in sigma: {len(np.where(sigmas==0)[0])}") number of negatives? - - fits["ele"][inds] = ThryE - fits["ion"][inds] = ThryI - - for species in all_params.keys(): - for k in all_params[species].keys(): - if np.size(np.shape(params[species][k])) == 3: - all_params[species][k][inds] = np.squeeze(params[species][k][inds]) - else: - all_params[species][k][inds] = params[species][k] - - return losses, sqdevs, used_points, fits, sigmas, all_params - - -def get_sigmas(hess: Dict, batch_size: int) -> Dict: - """ - Calculates the variance using the hessian with respect to the parameters and then using the hessian values - as the inverse of the covariance matrix and then inverting that. Negatives in the inverse hessian normally indicate - non-optimal points, to represent this in the final result the uncertainty of those values are reported as negative. - - - Args: - hess: Hessian dictionary, the field for each fitted parameter has subfields corresponding to each of the other - fitted parameters. Within each nested subfield is a batch_size x batch_size array with the hessian values - for that parameter combination and that batch. 
The cross terms of this array are zero since separate - lineouts within a batch do not affect each other, they are therefore discarded - batch_size: int- number of lineouts in a batch - - Returns: - sigmas: batch_size x number_of_parameters array with the uncertainty values for each parameter - """ - sizes = { - key + species: hess[species][key][species][key].shape[1] - for species in hess.keys() - for key in hess[species].keys() - } - # sizes = {key: hess[key][key].shape[1] for key in keys} - actual_num_params = sum([v for k, v in sizes.items()]) - sigmas = np.zeros((batch_size, actual_num_params)) - - for i in range(batch_size): - temp = np.zeros((actual_num_params, actual_num_params)) - k1 = 0 - for species1 in hess.keys(): - for key1 in hess[species1].keys(): - k2 = 0 - for species2 in hess.keys(): - for key2 in hess[species2].keys(): - temp[k1, k2] = np.squeeze(hess[species1][key1][species2][key2])[i, i] - k2 += 1 - k1 += 1 - - # xc = 0 - # for k1, param in enumerate(keys): - # yc = 0 - # for k2, param2 in enumerate(keys): - # if i > 0: - # temp[k1, k2] = np.squeeze(hess[param][param2])[i, i] - # else: - # temp[xc : xc + sizes[param], yc : yc + sizes[param2]] = hess[param][param2][0, :, 0, :] - # - # yc += sizes[param2] - # xc += sizes[param] - - # print(temp) - inv = np.linalg.inv(temp) - # print(inv) - - sigmas[i, :] = np.sign(np.diag(inv)) * np.sqrt(np.abs(np.diag(inv))) - # for k1, param in enumerate(keys): - # sigmas[i, xc : xc + sizes[param]] = np.diag( - # np.sign(inv[xc : xc + sizes[param], xc : xc + sizes[param]]) - # * np.sqrt(np.abs(inv[xc : xc + sizes[param], xc : xc + sizes[param]])) - # ) - # print(sigmas[i, k1]) - # change sigmas into a dictionary? 
- - return sigmas - - -def postprocess(config, batch_indices, all_data: Dict, all_axes: Dict, ts_fitter, sa, fitted_weights): - t1 = time.time() - - for species in config["parameters"].keys(): - if "electron" in config["parameters"][species]["type"].keys(): - elec_species = species - - if config["other"]["extraoptions"]["spectype"] != "angular_full" and config["other"]["refit"]: - losses_init, sqdevs, used_points, fits, sigmas, all_params = recalculate_with_chosen_weights( - config, batch_indices, all_data, ts_fitter, False, fitted_weights - ) - - # refit bad fits - red_losses_init = losses_init / (1.1 * (used_points - len(all_params))) - true_batch_size = config["optimizer"]["batch_size"] - # config["optimizer"]["batch_size"] = 1 - mlflow.log_metrics({"number of fits": len(batch_indices.flatten())}) - mlflow.log_metrics({"number of refits": int(np.sum(red_losses_init > config["other"]["refit_thresh"]))}) - - for i in batch_indices.flatten()[red_losses_init > config["other"]["refit_thresh"]]: - if i == 0: - continue - - batch = { - "e_data": np.reshape(all_data["e_data"][i], (1, -1)), - "e_amps": np.reshape(all_data["e_amps"][i], (1, -1)), - "i_data": np.reshape(all_data["i_data"][i], (1, -1)), - "i_amps": np.reshape(all_data["i_amps"][i], (1, -1)), - "noise_e": np.reshape(all_data["noiseE"][i], (1, -1)), - "noise_i": np.reshape(all_data["noiseI"][i], (1, -1)), - } - - # previous_weights = {} - temp_cfg = copy.copy(config) - temp_cfg["optimizer"]["batch_size"] = 1 - for species in fitted_weights[(i - 1) // true_batch_size].keys(): - for key in fitted_weights[(i - 1) // true_batch_size][species].keys(): - if config["parameters"][species][key]["active"]: - temp_cfg["parameters"][species][key]["val"] = float( - fitted_weights[(i - 1) // true_batch_size][species][key][(i - 1) % true_batch_size] - ) - - ts_fitter_refit = TSFitter(temp_cfg, sa, batch) - - # ts_fitter_refit.flattened_weights, ts_fitter_refit.unravel_pytree = ravel_pytree(previous_weights) - - res = 
spopt.minimize( - ts_fitter_refit.vg_loss if config["optimizer"]["grad_method"] == "AD" else ts_fitter_refit.loss, - np.copy(ts_fitter_refit.flattened_weights), - args=batch, - method=config["optimizer"]["method"], - jac=True if config["optimizer"]["grad_method"] == "AD" else False, - bounds=ts_fitter_refit.bounds, - options={"disp": True, "maxiter": config["optimizer"]["num_epochs"]}, - ) - cur_result = ts_fitter_refit.unravel_pytree(res["x"]) - - for species in cur_result.keys(): - for key in cur_result[species].keys(): - fitted_weights[i // true_batch_size][species][key] = ( - fitted_weights[i // true_batch_size][species][key] - .at[i % true_batch_size] - .set(cur_result[species][key][0]) - ) - # fitted_weights[i // true_batch_size][species][key][i % true_batch_size] = cur_result[species][key] - - # for key in fitted_weights[i // true_batch_size].keys(): - # cur_value = cur_result[key][0, 0] - # new_vals = fitted_weights[i // true_batch_size][key] - # new_vals = new_vals.at[tuple([i % true_batch_size, 0])].set(cur_value) - # fitted_weights[i // true_batch_size][key] = new_vals - - config["optimizer"]["batch_size"] = true_batch_size - - mlflow.log_metrics({"refitting time": round(time.time() - t1, 2)}) - - with tempfile.TemporaryDirectory() as td: - os.makedirs(os.path.join(td, "plots"), exist_ok=True) - os.makedirs(os.path.join(td, "binary"), exist_ok=True) - os.makedirs(os.path.join(td, "csv"), exist_ok=True) - if config["other"]["extraoptions"]["spectype"] == "angular_full": - best_weights_val = {} - best_weights_std = {} - if config["optimizer"]["num_mins"]>1: - for k, v in fitted_weights.items(): - best_weights_val[k] = np.average(v, axis=0) # [0, :] - best_weights_std[k] = np.std(v, axis=0) # [0, :] - else: - best_weights_val = fitted_weights - - losses, sqdevs, used_points, fits, sigmas, all_params = recalculate_with_chosen_weights( - config, batch_indices, all_data, ts_fitter, config["other"]["calc_sigmas"], best_weights_val - ) - - 
mlflow.log_metrics({"postprocessing time": round(time.time() - t1, 2)}) - mlflow.set_tag("status", "plotting") - t1 = time.time() - - final_params = plotters.get_final_params(config, all_params, all_axes, td) - if config["other"]["calc_sigmas"]: - sigma_fe = plotters.save_sigmas_fe(final_params, best_weights_std, sigmas, td) - else: - sigma_fe = np.zeros_like(final_params["fe"]) - savedata = plotters.plot_data_angular(config, fits, all_data, all_axes, td) - plotters.plot_ang_lineouts(used_points, sqdevs, losses, all_params, all_axes, savedata, td) - plotters.plot_dist(config, elec_species, final_params, sigma_fe, td) - - else: - losses, sqdevs, used_points, fits, sigmas, all_params = recalculate_with_chosen_weights( - config, batch_indices, all_data, ts_fitter, config["other"]["calc_sigmas"], fitted_weights - ) - if "losses_init" not in locals(): - losses_init = losses - mlflow.log_metrics({"postprocessing time": round(time.time() - t1, 2)}) - mlflow.set_tag("status", "plotting") - t1 = time.time() - - final_params = plotters.get_final_params(config, all_params, all_axes, td) - # for species in config["parameters"].keys(): - # if "m" in config["parameters"][species].keys(): - # if not config["parameters"][species]["m"]["active"] and config["parameters"][species]["matte"]: - # alpha = ( - # 0.042 - # * config["parameters"][species]["m"]["intens"] - # / 9.0 - # * np.sum(Z**2 * fract) - # / (np.sum(Z * fract) * all_params[species]["Te"]) - # ) - # mcur = 2.0 + 3.0 / (1.0 + 1.66 / (alpha**0.724)) - # ( - # self.config["parameters"][self.e_species]["fe"]["velocity"], - # all_params[self.e_species]["fe"], - # ) = self.num_dist_func(mcur.squeeze()) - - red_losses = plotters.plot_loss_hist(config, losses_init, losses, all_params, used_points, td) - savedata = plotters.plot_ts_data(config, fits, all_data, all_axes, td) - plotters.model_v_actual(config, all_data, all_axes, fits, losses, red_losses, sqdevs, td) - sigma_ds = plotters.save_sigmas_params(config, all_params, 
sigmas, all_axes, td) - plotters.plot_final_params(config, all_params, sigma_ds, td) - # plotters.plot_dist(config, final_params, sigma_fe, td) - - # final_params = plotters.plot_regular( - # config, losses, all_params, used_points, all_axes, fits, all_data, sqdevs, sigmas, td - # ) - mlflow.log_artifacts(td) - mlflow.log_metrics({"plotting time": round(time.time() - t1, 2)}) - - mlflow.set_tag("status", "done plotting") - - return final_params diff --git a/tsadar/runner.py b/tsadar/runner.py index 64c60a40..02f25a5f 100644 --- a/tsadar/runner.py +++ b/tsadar/runner.py @@ -1,21 +1,13 @@ import time, os from typing import Dict, Tuple -import numpy as np -import matplotlib.pyplot as plt -import mlflow, tempfile, yaml, pandas +import mlflow, tempfile, yaml import multiprocessing as mp -import xarray as xr -from tqdm import tqdm from flatten_dict import flatten, unflatten -from tsadar import fitter -from tsadar.distribution_functions.gen_num_dist_func import DistFunc -from tsadar.model.TSFitter import TSFitter -from tsadar.fitter import init_param_norm_and_shift -from tsadar.misc import utils -from tsadar.plotting import plotters -from tsadar.data_handleing.calibrations.calibration import get_calibrations, get_scattering_angles +from .inverse import fitter +from .forward import calc_series +from .utils import misc if "BASE_TEMPDIR" in os.environ: BASE_TEMPDIR = os.environ["BASE_TEMPDIR"] @@ -86,17 +78,17 @@ def run_for_app(run_id: str) -> str: # download config with tempfile.TemporaryDirectory(dir=BASE_TEMPDIR) as temp_path: - dest_file_path = utils.download_file(f"config.yaml", mlflow_run.info.artifact_uri, temp_path) + dest_file_path = misc.download_file(f"config.yaml", mlflow_run.info.artifact_uri, temp_path) with open(dest_file_path, "r") as fi: config = yaml.safe_load(fi) if config["data"]["filenames"]["epw"] is not None: - config["data"]["filenames"]["epw-local"] = utils.download_file( + config["data"]["filenames"]["epw-local"] = misc.download_file( 
config["data"]["filenames"]["epw"], mlflow_run.info.artifact_uri, temp_path ) if config["data"]["filenames"]["iaw"] is not None: - config["data"]["filenames"]["iaw-local"] = utils.download_file( + config["data"]["filenames"]["iaw-local"] = misc.download_file( config["data"]["filenames"]["iaw"], mlflow_run.info.artifact_uri, temp_path ) @@ -118,12 +110,12 @@ def _run_(config: Dict, mode: str = "fit"): Returns: """ - utils.log_params(config) + misc.log_mlflow(config) t0 = time.time() - if mode in ("fit", "FIT", "Fit"): + if mode.casefold() == "fit": fit_results, loss = fitter.fit(config=config) elif mode == "forward" or mode == "series": - calc_series(config=config) + calc_series.forward_pass(config=config) else: raise NotImplementedError(f"Mode {mode} not implemented") @@ -148,7 +140,7 @@ def run_job(run_id: str, mode: str, nested: bool): with tempfile.TemporaryDirectory(dir=BASE_TEMPDIR) as temp_path: all_configs = {} for k in ["defaults", "inputs"]: - dest_file_path = utils.download_file(f"{k}.yaml", run.info.artifact_uri, temp_path) + dest_file_path = misc.download_file(f"{k}.yaml", run.info.artifact_uri, temp_path) with open(f"{os.path.join(temp_path, k)}.yaml", "r") as fi: all_configs[k] = yaml.safe_load(fi) defaults = flatten(all_configs["defaults"]) @@ -156,220 +148,3 @@ def run_job(run_id: str, mode: str, nested: bool): config = unflatten(defaults) _run_(config, mode) - - -def calc_series(config): - """ - Calculates a spectrum or series of spectra from the input deck, i.e. performs a forward pass or series of forward - passes. - - - Args: - config: Dictionary - Configuration dictionary created from the input deck. For series of spectra contains the special - field 'series'. This field can have up to 8 subfields [param1, vals1, param2, vals2, param3, vals3, param4, vals4]. - the param subfields are a string identifying which fields of "parameters" are to be looped over. The vals subfields - give the values of that subfield for each spectrum in the series. 
- - Returns: - Ion data, electron data, and plots are saved to mlflow - - """ - # get scattering angles and weights - config["optimizer"]["batch_size"] = 1 - config["other"]["lamrangE"] = [ - config["data"]["fit_rng"]["forward_epw_start"], - config["data"]["fit_rng"]["forward_epw_end"], - ] - config["other"]["lamrangI"] = [ - config["data"]["fit_rng"]["forward_iaw_start"], - config["data"]["fit_rng"]["forward_iaw_end"], - ] - config["other"]["npts"] = int(config["other"]["CCDsize"][1] * config["other"]["points_per_pixel"]) - - for species in config["parameters"].keys(): - if "electron" in config["parameters"][species]["type"].keys(): - elec_species = species - dist_obj = DistFunc(config["parameters"][species]) - config["parameters"][species]["fe"]["velocity"], config["parameters"][species]["fe"]["val"] = dist_obj( - config["parameters"][species]["m"]["val"] - ) - config["parameters"][species]["fe"]["val"] = np.log(config["parameters"][species]["fe"]["val"])[None, :] - - config["units"] = init_param_norm_and_shift(config) - - sas = get_scattering_angles(config) - dummy_batch = { - "i_data": np.array([1]), - "e_data": np.array([1]), - "noise_e": np.array([0]), - "noise_i": np.array([0]), - "e_amps": np.array([1]), - "i_amps": np.array([1]), - } - - if config["other"]["extraoptions"]["spectype"] == "angular": - [axisxE, _, _, _, _, _] = get_calibrations( - 104000, config["other"]["extraoptions"]["spectype"], 0.0, config["other"]["CCDsize"] - ) # shot number hardcoded to get calibration - config["other"]["extraoptions"]["spectype"] = "angular_full" - - sas["angAxis"] = axisxE - dummy_batch["i_data"] = np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])) - dummy_batch["e_data"] = np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])) - - if "series" in config.keys(): - serieslen = len(config["series"]["vals1"]) - else: - serieslen = 1 - ThryE = [None] * serieslen - ThryI = [None] * serieslen - lamAxisE = [None] * serieslen - lamAxisI 
= [None] * serieslen - - t_start = time.time() - for i in tqdm(range(serieslen), total=serieslen): - if "series" in config.keys(): - - config["parameters"]["species"][config["series"]["param1"]]["val"] = config["series"]["vals1"][i] - if "param2" in config["series"].keys(): - config["parameters"]["species"][config["series"]["param2"]]["val"] = config["series"]["vals2"][i] - if "param3" in config["series"].keys(): - config["parameters"]["species"][config["series"]["param3"]]["val"] = config["series"]["vals3"][i] - if "param4" in config["series"].keys(): - config["parameters"]["species"][config["series"]["param4"]]["val"] = config["series"]["vals4"][i] - - if config["other"]["extraoptions"]["spectype"] == "angular": - [axisxE, _, _, _, _, _] = get_calibrations( - 104000, config["other"]["extraoptions"]["spectype"], config["other"]["CCDsize"] - ) # shot number hardcoded to get calibration - config["other"]["extraoptions"]["spectype"] = "angular_full" - - sas["angAxis"] = axisxE - dummy_batch["i_data"] = np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])) - dummy_batch["e_data"] = np.ones((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1])) - - if "series" in config.keys(): - serieslen = len(config["series"]["vals1"]) - else: - serieslen = 1 - ThryE = [None] * serieslen - ThryI = [None] * serieslen - lamAxisE = [None] * serieslen - lamAxisI = [None] * serieslen - - t_start = time.time() - for i in range(serieslen): - if "series" in config.keys(): - config["parameters"]["species"][config["series"]["param1"]]["val"] = config["series"]["vals1"][i] - if "param2" in config["series"].keys(): - config["parameters"]["species"][config["series"]["param2"]]["val"] = config["series"]["vals2"][i] - if "param3" in config["series"].keys(): - config["parameters"]["species"][config["series"]["param3"]]["val"] = config["series"]["vals3"][i] - if "param4" in config["series"].keys(): - config["parameters"]["species"][config["series"]["param4"]]["val"] = 
config["series"]["vals4"][i] - - ts_fitter = TSFitter(config, sas, dummy_batch) - params = ts_fitter.weights_to_params(ts_fitter.pytree_weights["active"]) - ThryE[i], ThryI[i], lamAxisE[i], lamAxisI[i] = ts_fitter.spec_calc(params, dummy_batch) - - spectime = time.time() - t_start - ThryE = np.array(ThryE) - ThryI = np.array(ThryI) - lamAxisE = np.array(lamAxisE) - lamAxisI = np.array(lamAxisI) - - with tempfile.TemporaryDirectory() as td: - os.makedirs(os.path.join(td, "plots"), exist_ok=True) - os.makedirs(os.path.join(td, "binary"), exist_ok=True) - os.makedirs(os.path.join(td, "csv"), exist_ok=True) - if config["other"]["extraoptions"]["spectype"] == "angular_full": - savedata = plotters.plot_data_angular( - config, - {"ele": np.squeeze(ThryE)}, - {"e_data": np.zeros((config["other"]["CCDsize"][0], config["other"]["CCDsize"][1]))}, - {"epw_x": sas["angAxis"], "epw_y": lamAxisE}, - td, - ) - plotters.plot_dist( - config, - elec_species, - { - "fe": np.squeeze(config["parameters"][elec_species]["fe"]["val"]), - "v": config["parameters"][elec_species]["fe"]["velocity"], - }, - np.zeros_like(config["parameters"][elec_species]["fe"]["val"]), - td, - ) - print(np.shape(config["parameters"][elec_species]["fe"]["val"])) - if len(np.shape(np.squeeze(config["parameters"][elec_species]["fe"]["val"])))==1: - final_dist = pandas.DataFrame({'fe':[l for l in config["parameters"][elec_species]["fe"]["val"]], 'vx':[vx for vx in config["parameters"][elec_species]["fe"]["velocity"]]}) - elif len(np.shape(np.squeeze(config["parameters"][elec_species]["fe"]["val"])))==2: - final_dist = pandas.DataFrame(data=np.squeeze(config["parameters"][elec_species]["fe"]["val"]), columns=config["parameters"][elec_species]["fe"]["velocity"][0][0], index=config["parameters"][elec_species]["fe"]["velocity"][0][:,0]) - final_dist.to_csv(os.path.join(td, "csv", "learned_dist.csv")) - else: - if config["parameters"][elec_species]["fe"]["dim"] == 2: - plotters.plot_dist( - config, - elec_species, - { 
- "fe": config["parameters"][elec_species]["fe"]["val"], - "v": config["parameters"][elec_species]["fe"]["velocity"], - }, - np.zeros_like(config["parameters"][elec_species]["fe"]["val"]), - td, - ) - - fig, ax = plt.subplots(1, 2, figsize=(12, 6), tight_layout=True, sharex=False) - if config["other"]["extraoptions"]["load_ele_spec"]: - ax[0].plot( - lamAxisE.squeeze().transpose(), ThryE.squeeze().transpose() - ) # transpose might break single specs? - ax[0].set_title("Simulated Data, fontsize=14") - ax[0].set_ylabel("Amp (arb. units)") - ax[0].set_xlabel("Wavelength (nm)") - ax[0].grid() - - if "series" in config.keys(): - ax[0].legend([str(ele) for ele in config["series"]["vals1"]]) - if config["series"]["param1"] == "fract" or config["series"]["param1"] == "Z": - coords_ele = ( - ("series", np.array(config["series"]["vals1"])[:, 0]), - ("Wavelength", lamAxisE[0, :]), - ) - else: - coords_ele = (("series", config["series"]["vals1"]), ("Wavelength", lamAxisE[0, :])) - ele_dat = {"Sim": ThryE} - ele_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ele) for k, v in ele_dat.items()}) - else: - coords_ele = (("series", [0]), ("Wavelength", lamAxisE[0, :].squeeze())) - ele_dat = {"Sim": ThryE.squeeze(0)} - ele_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ele) for k, v in ele_dat.items()}) - ele_data.to_netcdf(os.path.join(td, "binary", "ele_fit_and_data.nc")) - - if config["other"]["extraoptions"]["load_ion_spec"]: - ax[1].plot(lamAxisI.squeeze().transpose(), ThryI.squeeze().transpose()) - ax[1].set_title("Simulated Data, fontsize=14") - ax[1].set_ylabel("Amp (arb. 
units)") - ax[1].set_xlabel("Wavelength (nm)") - ax[1].grid() - - if "series" in config.keys(): - ax[1].legend([str(ele) for ele in config["series"]["vals1"]]) - if config["series"]["param1"] == "fract" or config["series"]["param1"] == "Z": - coords_ion = ( - ("series", np.array(config["series"]["vals1"])[:, 0]), - ("Wavelength", lamAxisI[0, :]), - ) - else: - coords_ion = (("series", config["series"]["vals1"]), ("Wavelength", lamAxisI[0, :])) - ion_dat = {"Sim": ThryI} - ion_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ion) for k, v in ion_dat.items()}) - else: - coords_ion = (("series", [0]), ("Wavelength", lamAxisI[0, :].squeeze())) - ion_dat = {"Sim": ThryI.squeeze(0)} - ion_data = xr.Dataset({k: xr.DataArray(v, coords=coords_ion) for k, v in ion_dat.items()}) - ion_data.to_netcdf(os.path.join(td, "binary", "ion_fit_and_data.nc")) - fig.savefig(os.path.join(td, "plots", "simulated_data"), bbox_inches="tight") - mlflow.log_artifacts(td) - metrics_dict = {"spectrum_calc_time": spectime} - mlflow.log_metrics(metrics=metrics_dict) diff --git a/tsadar/plotting/__init__.py b/tsadar/utils/__init__.py similarity index 100% rename from tsadar/plotting/__init__.py rename to tsadar/utils/__init__.py diff --git a/tsadar/process/__init__.py b/tsadar/utils/data_handling/__init__.py similarity index 100% rename from tsadar/process/__init__.py rename to tsadar/utils/data_handling/__init__.py diff --git a/tsadar/data_handleing/calibrations/calibration.py b/tsadar/utils/data_handling/calibration.py similarity index 63% rename from tsadar/data_handleing/calibrations/calibration.py rename to tsadar/utils/data_handling/calibration.py index 168c1e85..b46f502a 100644 --- a/tsadar/data_handleing/calibrations/calibration.py +++ b/tsadar/utils/data_handling/calibration.py @@ -1,9 +1,197 @@ from typing import Dict import numpy as np import scipy.io as sio -from os.path import join +import os -from tsadar.data_handleing.calibrations.sa_table import sa_lookup +BASE_FILES_PATH = 
os.path.join(os.path.dirname(__file__), "..", "..", "external") + + +def sa_lookup(beam): + """ + Creates the scattering angle dictionary with the scattering angles and their weights based of the chosen probe + beam. All values are precalculated. Available options are P9, B12, B15, B23, B26, B35, B42, B46, B58. + + Args: + beam: string with the name of the beam to be used as a probe + + Returns: + sa: dictionary with scattering angles in the 'sa' field and their relative weights in the 'weights' field + """ + if beam == "P9": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(53.637560, 66.1191, 10), + weights=np.array( + [ + 0.00702671050853565, + 0.0391423809738300, + 0.0917976667717670, + 0.150308544660150, + 0.189541011666141, + 0.195351560740507, + 0.164271879645061, + 0.106526733030044, + 0.0474753389486960, + 0.00855817305526778, + ] + ), + ) + elif beam == "B12": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(71.0195, 83.3160, 10), + weights=np.array( + [ + 0.007702, + 0.0404, + 0.09193, + 0.1479, + 0.1860, + 0.1918, + 0.1652, + 0.1083, + 0.05063, + 0.01004, + ] + ), + ) + elif beam == "B15": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(12.0404, 24.0132, 10), + weights=np.array( + [ + 0.0093239, + 0.04189, + 0.0912121, + 0.145579, + 0.182019, + 0.188055, + 0.163506, + 0.1104, + 0.0546822, + 0.0133327, + ] + ), + ) + elif beam == "B23": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(72.281, 84.3307, 10), + weights=np.array( + [ + 0.00945903, + 0.0430611, + 0.0925634, + 0.146705, + 0.182694, + 0.1881, + 0.162876, + 0.109319, + 0.0530607, + 0.0121616, + ] + ), + ) + elif beam == "B26": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(55.5636, 68.1058, 10), + weights=np.array( + [ + 0.00648619, + 0.0386019, + 0.0913923, + 0.150489, + 0.190622, + 0.195171, + 0.166389, + 0.105671, + 0.0470249, + 
0.00815279, + ] + ), + ) + elif beam == "B35": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(32.3804, 44.6341, 10), + weights=np.array( + [ + 0.00851313, + 0.0417549, + 0.0926084, + 0.149182, + 0.187019, + 0.191523, + 0.16265, + 0.106842, + 0.049187, + 0.0107202, + ] + ), + ) + elif beam == "B42": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(155.667, 167.744, 10), + weights=np.array( + [ + 0.00490969, + 0.0257646, + 0.0601324, + 0.106076, + 0.155308, + 0.187604, + 0.19328, + 0.15702, + 0.0886447, + 0.0212603, + ] + ), + ) + elif beam == "B46": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(56.5615, 69.1863, 10), + weights=np.array( + [ + 0.00608081, + 0.0374307, + 0.0906716, + 0.140714, + 0.191253, + 0.197333, + 0.166164, + 0.106121, + 0.0464844, + 0.0077474, + ] + ), + ) + elif beam == "B58": + # Scattering angle in degrees for OMEGA TIM6 TS + sa = dict( + sa=np.linspace(119.093, 131.666, 10), + weights=np.array( + [ + 0.00549525, + 0.0337372, + 0.0819783, + 0.140084, + 0.186388, + 0.19855, + 0.174136, + 0.117517, + 0.0527003, + 0.00941399, + ] + ), + ) + else: + raise NotImplementedError("Other probe geometries are not yet supported") + + return sa def get_calibrations(shotNum, tstype, t0, CCDsize): @@ -170,7 +358,7 @@ def get_calibrations(shotNum, tstype, t0, CCDsize): magE = 5.13 / 0.36175 * 1.118 # um / px times strech factor accounting for tilt in view EPWtcc = 1024 - 503 # 562; - IAWtcc = 1024 - 450 #578 # 469; + IAWtcc = 1024 - 450 # 578 # 469; else: # needs to be updated with the calibrations from 7-26-22 @@ -204,7 +392,7 @@ def get_calibrations(shotNum, tstype, t0, CCDsize): axisxI = axisxI - IAWtcc * magI # axisxI = axisxI + 200 else: - imp = sio.loadmat(join("files", "angsFRED.mat"), variable_names="angsFRED") + imp = sio.loadmat(os.path.join(BASE_FILES_PATH, "files", "angsFRED.mat"), variable_names="angsFRED") axisxE = imp["angsFRED"][0, :] # axisxE =
np.vstack(np.loadtxt("files/angsFRED.txt")) axisxI = np.arange(1, CCDsize[1] + 1) @@ -234,7 +422,9 @@ def get_scattering_angles(config: Dict) -> Dict: sa = sa_lookup(config["data"]["probe_beam"]) else: # Scattering angle in degrees for Artemis - imp = sio.loadmat(join("files", "angleWghtsFredfine.mat"), variable_names="weightMatrix") + imp = sio.loadmat( + os.path.join(BASE_FILES_PATH, "files", "angleWghtsFredfine.mat"), variable_names="weightMatrix" + ) weights = imp["weightMatrix"] sa = dict(sa=np.arange(19, 139.5, 0.5), weights=weights) return sa diff --git a/tsadar/data_handleing/data_visualizer.py b/tsadar/utils/data_handling/data_visualizer.py similarity index 100% rename from tsadar/data_handleing/data_visualizer.py rename to tsadar/utils/data_handling/data_visualizer.py diff --git a/tsadar/data_handleing/lam_parse.py b/tsadar/utils/data_handling/lam_parse.py similarity index 100% rename from tsadar/data_handleing/lam_parse.py rename to tsadar/utils/data_handling/lam_parse.py diff --git a/tsadar/data_handleing/load_ts_data.py b/tsadar/utils/data_handling/load_ts_data.py similarity index 70% rename from tsadar/data_handleing/load_ts_data.py rename to tsadar/utils/data_handling/load_ts_data.py index ae9b25b1..7dd6f7f3 100644 --- a/tsadar/data_handleing/load_ts_data.py +++ b/tsadar/utils/data_handling/load_ts_data.py @@ -1,12 +1,12 @@ from os.path import join from os import listdir import os -from pyhdf.SD import SD, SDC import numpy as np from scipy.signal import find_peaks -from tsadar.process.warpcorr import perform_warp_correction +from tsadar.utils.process.warpcorr import perform_warp_correction + +BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "external") -BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "aux") def loadData(sNum, sDay, loadspecs, custom_path=False): """ @@ -42,36 +42,38 @@ def loadData(sNum, sDay, loadspecs, custom_path=False): else: folder = join(BASE_FILES_PATH, "data") - file_list = listdir(folder) 
+ print(f"{file_list=}") + print(f"{sNum=}") files = [name for name in file_list if str(sNum) in name] + print(f"Files found: {files}") t0 = [0, 0] - #print(sNum) - #print(files) for fl in files: - if "epw" in fl or "EPW" in fl: + if "epw" in fl.casefold(): hdfnameE = join(folder, fl) - if "ccd" in fl or "CCD" in fl: + if "ccd" in fl.casefold(): xlab = "Radius (\mum)" specType = "imaging" else: xlab = "Time (ps)" specType = "temporal" - if "iaw" in fl or "IAW" in fl: + if "iaw" in fl.casefold(): hdfnameI = join(folder, fl) - if "ccd" in fl or "CCD" in fl: + if "ccd" in fl.casefold(): xlab = "Radius (\mum)" specType = "imaging" else: xlab = "Time (ps)" specType = "temporal" - if "ats" in fl or "ATS" in fl: + if "ats" in fl.casefold(): hdfnameE = join(folder, fl) specType = "angular" xlab = "Scattering angle (degrees)" if loadspecs["load_ion_spec"]: + from pyhdf.SD import SD, SDC + try: iDatfile = SD(hdfnameI, SDC.READ) sds_obj = iDatfile.select("Streak_array") # select sds @@ -96,31 +98,33 @@ def loadData(sNum, sDay, loadspecs, custom_path=False): iDat = [] if loadspecs["load_ele_spec"]: - try: - eDatfile = SD(hdfnameE, SDC.READ) - sds_obj = eDatfile.select("Streak_array") # select sds - eDat = sds_obj.get() # get sds data - eDat = eDat.astype("float64") - eDat = eDat[0, :, :] - eDat[1, :, :] - - if specType == "angular": - eDat = np.fliplr(eDat) - print("found angular data") - elif specType == "temporal": - eDat = perform_warp_correction(eDat) - elif specType == "imaging": - eDat = np.rot90(np.squeeze(eDat), 3) - - if specType == "temporal" and loadspecs["absolute_timing"]: - # this sets t0 by locating the fiducial and placing t0 164px earlier - fidu = np.sum(eDat[0:100, :], 0) - res = find_peaks(fidu, prominence=1000, width=10) - peak_center = res[1]["left_ips"][0] + (res[1]["right_ips"][0] - res[1]["left_ips"][0]) / 2.0 - t0[1] = round(peak_center - 95) - except BaseException: - print("Unable to find EPW") - eDat = [] - loadspecs["load_ele_spec"] = False + from 
pyhdf.SD import SD, SDC + + # try: + eDatfile = SD(hdfnameE, SDC.READ) + sds_obj = eDatfile.select("Streak_array") # select sds + eDat = sds_obj.get() # get sds data + eDat = eDat.astype("float64") + eDat = eDat[0, :, :] - eDat[1, :, :] + + if specType == "angular": + eDat = np.fliplr(eDat) + print("found angular data") + elif specType == "temporal": + eDat = perform_warp_correction(eDat) + elif specType == "imaging": + eDat = np.rot90(np.squeeze(eDat), 3) + + if specType == "temporal" and loadspecs["absolute_timing"]: + # this sets t0 by locating the fiducial and placing t0 164px earlier + fidu = np.sum(eDat[0:100, :], 0) + res = find_peaks(fidu, prominence=1000, width=10) + peak_center = res[1]["left_ips"][0] + (res[1]["right_ips"][0] - res[1]["left_ips"][0]) / 2.0 + t0[1] = round(peak_center - 95) + # except BaseException: + # print("Unable to find EPW") + # eDat = [] + # loadspecs["load_ele_spec"] = False else: eDat = [] diff --git a/tsadar/misc/utils.py b/tsadar/utils/misc.py similarity index 90% rename from tsadar/misc/utils.py rename to tsadar/utils/misc.py index 09a563b2..8922a023 100644 --- a/tsadar/misc/utils.py +++ b/tsadar/utils/misc.py @@ -1,8 +1,8 @@ -import os, mlflow, flatten_dict, boto3, yaml, botocore, shutil, time, tempfile +import os, mlflow, flatten_dict, boto3, botocore, shutil, time, tempfile from urllib.parse import urlparse -def log_params(cfg): +def log_mlflow(cfg, which="params"): """ Logs the parameters form the input deck in the parameters section of MLFlow. 
@@ -16,16 +16,22 @@ def log_params(cfg): flattened_dict = flatten_dict.flatten(cfg, reducer="dot") # dict(flatdict.FlatDict(cfg, delimiter=".")) num_entries = len(flattened_dict.keys()) + if which == "params": + log_func = mlflow.log_params + elif which == "metrics": + log_func = mlflow.log_metrics + else: + raise ValueError("which must be either 'params' or 'metrics'") + if num_entries > 100: num_batches = num_entries % 100 fl_list = list(flattened_dict.items()) for i in range(num_batches): end_ind = min((i + 1) * 100, num_entries) trunc_dict = {k: v for k, v in fl_list[i * 100 : end_ind]} - mlflow.log_params(trunc_dict) + log_func(trunc_dict) else: - mlflow.log_params(flattened_dict) - + log_func(flattened_dict) def update(base_dict, new_dict): diff --git a/tsadar/utils/plotting/__init__.py b/tsadar/utils/plotting/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tsadar/plotting/edf_movie.py b/tsadar/utils/plotting/edf_movie.py similarity index 100% rename from tsadar/plotting/edf_movie.py rename to tsadar/utils/plotting/edf_movie.py diff --git a/tsadar/plotting/example_plot.py b/tsadar/utils/plotting/example_plot.py similarity index 100% rename from tsadar/plotting/example_plot.py rename to tsadar/utils/plotting/example_plot.py diff --git a/tsadar/plotting/lineout_plot.py b/tsadar/utils/plotting/lineout_plot.py similarity index 92% rename from tsadar/plotting/lineout_plot.py rename to tsadar/utils/plotting/lineout_plot.py index 6497eed3..40daeb18 100644 --- a/tsadar/plotting/lineout_plot.py +++ b/tsadar/utils/plotting/lineout_plot.py @@ -43,12 +43,10 @@ def lineout_plot(data, fits, sqdev, yaxis, ylim, s_ind, e_ind, titlestr, filenam ax[0][col].set_ylabel("Amp (arb. 
units)") ax[0][col].legend(fontsize=14) ax[0][col].grid() - ax[0][col].set_ylim(ylim) + # ax[0][col].set_ylim(ylim) ax[1][col].plot( - yaxis[col][s_ind[col] : e_ind[col]], - np.squeeze(sqdev[col][s_ind[col] : e_ind[col]]), - label="Residual", + yaxis[col][s_ind[col] : e_ind[col]], np.squeeze(sqdev[col][s_ind[col] : e_ind[col]]), label="Residual" ) ax[1][col].set_xlabel("Wavelength (nm)") ax[1][col].set_ylabel(r"$\chi_i^2$") diff --git a/tsadar/plotting/plotters.py b/tsadar/utils/plotting/plotters.py similarity index 86% rename from tsadar/plotting/plotters.py rename to tsadar/utils/plotting/plotters.py index f4b1ce34..f8cc3861 100644 --- a/tsadar/plotting/plotters.py +++ b/tsadar/utils/plotting/plotters.py @@ -4,7 +4,7 @@ import matplotlib.pyplot as plt import xarray as xr -from tsadar.plotting.lineout_plot import lineout_plot +from tsadar.utils.plotting.lineout_plot import lineout_plot def get_final_params(config, best_weights, all_axes, td): @@ -30,9 +30,10 @@ def get_final_params(config, best_weights, all_axes, td): for species in best_weights.keys(): for k, v in best_weights[species].items(): if k == "fe": - fitted_dist =True - dist[k] = v.squeeze() - dist["v"] = config["parameters"][species]["fe"]["velocity"] + # fitted_dist = True + # dist[k] = v.squeeze() + # dist["v"] = config["parameters"][species]["fe"]["velocity"] + pass else: all_params[k + "_" + species] = pandas.Series(v.reshape(-1)) # if np.shape(v)[1] > 1: @@ -45,18 +46,17 @@ def get_final_params(config, best_weights, all_axes, td): if config["other"]["extraoptions"]["load_ion_spec"]: final_params.insert(0, all_axes["x_label"], np.array(all_axes["iaw_x"][config["data"]["lineouts"]["pixelI"]])) final_params.insert(0, "lineout pixel", config["data"]["lineouts"]["pixelI"]) - elif config["other"]["extraoptions"]["spectype"] != "angular_full": + elif config["other"]["extraoptions"]["spectype"] != "angular_full": final_params.insert(0, all_axes["x_label"], 
np.array(all_axes["epw_x"][config["data"]["lineouts"]["pixelE"]])) final_params.insert(0, "lineout pixel", config["data"]["lineouts"]["pixelE"]) final_params.to_csv(os.path.join(td, "csv", "learned_parameters.csv")) - if fitted_dist: - if len(np.shape(dist['fe']))==1: - final_dist = pandas.DataFrame({'fe':[l for l in dist['fe']], 'vx':[vx for vx in dist['v']]}) - elif len(np.shape(dist['fe']))==2: - final_dist = pandas.DataFrame(data=dist['fe'], columns=dist['v'][0][0], index=dist['v'][0][:,0]) - #final_dist = pandas.DataFrame({'fe':[l for l in dist['fe']], 'vx':[vx for vx in dist['v'][0]], 'vy':[vy for vy in dist['v'][1]]}) + if len(np.shape(dist["fe"])) == 1: + final_dist = pandas.DataFrame({"fe": [l for l in dist["fe"]], "vx": [vx for vx in dist["v"]]}) + elif len(np.shape(dist["fe"])) == 2: + final_dist = pandas.DataFrame(data=dist["fe"], columns=dist["v"][0][0], index=dist["v"][0][:, 0]) + # final_dist = pandas.DataFrame({'fe':[l for l in dist['fe']], 'vx':[vx for vx in dist['v'][0]], 'vy':[vy for vy in dist['v'][1]]}) final_dist.to_csv(os.path.join(td, "csv", "learned_dist.csv")) return all_params | dist @@ -80,35 +80,34 @@ def plot_final_params(config, all_params, sigmas_ds, td): """ for species in all_params.keys(): for param in all_params[species].keys(): - for i in range(all_params[species][param].shape[1]): - vals = pandas.Series(all_params[species][param][:, i].squeeze(), dtype=float) - fig, ax = plt.subplots(1, 1, figsize=(4, 4)) - lineouts = np.array(config["data"]["lineouts"]["val"]) - std = vals.rolling(config["plotting"]["rolling_std_width"], min_periods=1, center=True).std() - - ax.plot(lineouts, vals) - ax.fill_between( - lineouts, - (vals.values - config["plotting"]["n_sigmas"] * sigmas_ds[param + "_" + species].values), - (vals.values + config["plotting"]["n_sigmas"] * sigmas_ds[param + "_" + species].values), - color="b", - alpha=0.1, - ) - ax.fill_between( - lineouts, - (vals.values - config["plotting"]["n_sigmas"] * std.values), - 
(vals.values + config["plotting"]["n_sigmas"] * std.values), - color="r", - alpha=0.1, - ) - ax.set_xlabel("lineout", fontsize=14) - ax.grid() - ax.set_ylim(0.8 * np.min(vals), 1.2 * np.max(vals)) - ax.set_ylabel(param, fontsize=14) - fig.savefig( - os.path.join(td, "plots", "learned_" + param + "_" + species + "_" + str(i) + ".png"), - bbox_inches="tight", - ) + vals = pandas.Series(all_params[species][param], dtype=float) + fig, ax = plt.subplots(1, 1, figsize=(4, 4)) + lineouts = np.array(config["data"]["lineouts"]["val"]) + std = vals.rolling(config["plotting"]["rolling_std_width"], min_periods=1, center=True).std() + + ax.plot(lineouts, vals) + ax.fill_between( + lineouts, + (vals.values - config["plotting"]["n_sigmas"] * sigmas_ds[param + "_" + species].values), + (vals.values + config["plotting"]["n_sigmas"] * sigmas_ds[param + "_" + species].values), + color="b", + alpha=0.1, + ) + ax.fill_between( + lineouts, + (vals.values - config["plotting"]["n_sigmas"] * std.values), + (vals.values + config["plotting"]["n_sigmas"] * std.values), + color="r", + alpha=0.1, + ) + ax.set_xlabel("lineout", fontsize=14) + ax.grid() + ax.set_ylim(0.8 * np.min(vals), 1.2 * np.max(vals)) + ax.set_ylabel(param, fontsize=14) + fig.savefig( + os.path.join(td, "plots", "learned_" + param + "_" + species + ".png"), + bbox_inches="tight", + ) return @@ -217,6 +216,24 @@ def plot_dist(config, ele_species, final_params, sigma_fe, td): ax[2].set_ylabel("f_e") ax[2].grid() else: + fig, ax = plt.subplots(1, 2, figsize=(12, 4), tight_layout=True) + c = ax[0].contourf(final_params["v"][0], final_params["v"][1], final_params["fe"].T) + ax[0].set_xlabel("$v_x/v_{th}$", fontsize=14) + ax[0].set_ylabel("$v_y/v_{th}$", fontsize=14) + ax[0].set_title("$f_e$", fontsize=14) + fig.colorbar(c) + + c = ax[1].contourf(final_params["v"][0], final_params["v"][1], np.log10(final_params["fe"].T)) + ax[1].set_xlabel("$v_x/v_{th}$", fontsize=14) + ax[1].set_ylabel("$v_y/v_{th}$", fontsize=14) + 
ax[1].set_title("log$_{10}(f_e)$", fontsize=14) + fig.colorbar(c) + + print(np.isnan(final_params["fe"]).any()) + + fig.savefig(os.path.join(td, "plots", "fe_contourf.png"), bbox_inches="tight") + plt.close() + fig = plt.figure(figsize=(15, 5)) ax = fig.add_subplot(1, 3, 1, projection="3d") curfe = np.where(final_params["fe"] < -50.0, -50.0, final_params["fe"]) @@ -438,10 +455,10 @@ def plot_ts_data(config, fits, all_data, all_axes, td): Returns: """ if config["other"]["extraoptions"]["load_ion_spec"]: - coords = (all_axes["x_label"], np.array(all_axes["iaw_x"][config["data"]["lineouts"]["pixelI"]])), ( - "Wavelength", - all_axes["iaw_y"], - ) + coords_x = all_axes["x_label"], np.array(all_axes["iaw_x"][config["data"]["lineouts"]["pixelI"]]) + coords_y = "Wavelength", all_axes["iaw_y"] + coords = coords_x, coords_y + ion_dat = {"fit": fits["ion"], "data": all_data["i_data"]} # fit vs data storage and plot ion_savedata = xr.Dataset({k: xr.DataArray(v, coords=coords) for k, v in ion_dat.items()}) @@ -516,28 +533,22 @@ def plot_2D_data_vs_fit( Returns: """ + + if "angular" in config["other"]["extraoptions"]["spectype"]: + vmin, vmax = 0.0, 1.5 + else: + vmin = np.amin(data) if config["plotting"]["data_cbar_l"] == "data" else config["plotting"]["data_cbar_l"] + vmax = np.amax(data) if config["plotting"]["data_cbar_u"] == "data" else config["plotting"]["data_cbar_u"] + # Create fit and data image fig, ax = plt.subplots(1, 2, figsize=(12, 5), tight_layout=True) - pc = ax[0].pcolormesh( - x, - y, - fit, - shading="nearest", - cmap="gist_ncar", - vmin=np.amin(data) if config["plotting"]["data_cbar_l"] == "data" else config["plotting"]["data_cbar_l"], - vmax=np.amax(data) if config["plotting"]["data_cbar_u"] == "data" else config["plotting"]["data_cbar_u"], - ) + pc = ax[0].pcolormesh(x, y, fit, shading="nearest", cmap="gist_ncar") # , vmin=vmin, vmax=vmax) ax[0].set_xlabel(xlabel) ax[0].set_ylabel(ylabel) - ax[1].pcolormesh( - x, - y, - data, - shading="nearest", - 
cmap="gist_ncar", - vmin=np.amin(data) if config["plotting"]["data_cbar_l"] == "data" else config["plotting"]["data_cbar_l"], - vmax=np.amax(data) if config["plotting"]["data_cbar_u"] == "data" else config["plotting"]["data_cbar_u"], - ) + ax[1].pcolormesh(x, y, data, shading="nearest", cmap="gist_ncar") + # vmin=np.amin(data) if config["plotting"]["data_cbar_l"] == "data" else config["plotting"]["data_cbar_l"], + # vmax=np.amax(data) if config["plotting"]["data_cbar_u"] == "data" else config["plotting"]["data_cbar_u"], + # ) ax[1].set_xlabel(xlabel) ax[1].set_ylabel(ylabel) fig.colorbar(pc) @@ -640,7 +651,8 @@ def model_v_actual(config, all_data, all_axes, fits, losses, red_losses, sqdevs, for i in range(num_plots): # plot model vs actual titlestr = ( - r"|Error|$^2$" + f" = {sorted_losses[i]:.2e}, line out # {all_axes['iaw_x'][config['data']['lineouts']['pixelI'][loss_inds[i]]]}" + r"|Error|$^2$" + + f" = {sorted_losses[i]:.2e}, line out # {all_axes['iaw_x'][config['data']['lineouts']['pixelI'][loss_inds[i]]]}" ) filename = f"loss={sorted_losses[i]:.2e}-reduced_loss={sorted_red_losses[i]:.2e}-lineout={config['data']['lineouts']['val'][loss_inds[i]]}.png" diff --git a/tsadar/utils/process/__init__.py b/tsadar/utils/process/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tsadar/process/correct_throughput.py b/tsadar/utils/process/correct_throughput.py similarity index 99% rename from tsadar/process/correct_throughput.py rename to tsadar/utils/process/correct_throughput.py index 333405ea..0634c6e3 100644 --- a/tsadar/process/correct_throughput.py +++ b/tsadar/utils/process/correct_throughput.py @@ -6,7 +6,7 @@ from os.path import join import os -BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "aux") +BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "external") def correctThroughput(data, tstype, axisy, shotNum): diff --git a/tsadar/process/evaluate_background.py b/tsadar/utils/process/evaluate_background.py 
similarity index 98% rename from tsadar/process/evaluate_background.py rename to tsadar/utils/process/evaluate_background.py index a0a7df77..a98124f8 100644 --- a/tsadar/process/evaluate_background.py +++ b/tsadar/utils/process/evaluate_background.py @@ -6,8 +6,8 @@ from scipy.signal import convolve2d as conv2 -from tsadar.data_handleing.load_ts_data import loadData -from tsadar.process.correct_throughput import correctThroughput +from ..data_handling.load_ts_data import loadData +from .correct_throughput import correctThroughput def get_shot_bg(config, shotNum, axisyE, elecData): diff --git a/tsadar/process/feature_detector.py b/tsadar/utils/process/feature_detector.py similarity index 100% rename from tsadar/process/feature_detector.py rename to tsadar/utils/process/feature_detector.py diff --git a/tsadar/process/lineouts.py b/tsadar/utils/process/lineouts.py similarity index 98% rename from tsadar/process/lineouts.py rename to tsadar/utils/process/lineouts.py index 10277ec8..98cd5ebf 100644 --- a/tsadar/process/lineouts.py +++ b/tsadar/utils/process/lineouts.py @@ -3,7 +3,7 @@ from collections import defaultdict import numpy as np -from tsadar.process.evaluate_background import get_lineout_bg +from tsadar.utils.process.evaluate_background import get_lineout_bg def get_lineouts( diff --git a/tsadar/utils/process/postprocess.py b/tsadar/utils/process/postprocess.py new file mode 100644 index 00000000..69fde49e --- /dev/null +++ b/tsadar/utils/process/postprocess.py @@ -0,0 +1,336 @@ +from typing import Dict +from collections import defaultdict + +import time, tempfile, mlflow, os, copy + +import numpy as np +import scipy.optimize as spopt + +from tsadar.utils.plotting import plotters +from tsadar.inverse.loss_function import LossFunction + + +def recalculate_with_chosen_weights( + config: Dict, sample_indices, all_data: Dict, loss_fn: LossFunction, calc_sigma: bool, fitted_weights: Dict +): + """ + Gets parameters and the result of the full forward pass i.e. 
fits + + + Args: + config: Dict- configuration dictionary built from input deck + sample_indices: + all_data: Dict- contains the electron data, ion data, and their respective amplitudes + loss_fn: Instance of the LossFunction class + fitted_weights: Dict- best values of the parameters returned by the minimizer + + Returns: + + """ + + losses = np.zeros_like(sample_indices, dtype=np.float64) + sample_indices.sort() + batch_indices = np.reshape(sample_indices, (-1, config["optimizer"]["batch_size"])) + + # turn list of dictionaries into dictionary of lists + all_params = {k: defaultdict(list) for k in config["parameters"].keys()} + + for _fw in fitted_weights: + unnormed_params = _fw.get_unnormed_params() + for k in all_params.keys(): + for k2 in unnormed_params[k].keys(): + all_params[k][k2].append(unnormed_params[k][k2]) + + # concatenate all the lists in the dictionary + num_params = 0 + for k in all_params.keys(): + for k2 in all_params[k].keys(): + all_params[k][k2] = np.concatenate(all_params[k][k2]) + num_params += len(all_params[k][k2]) + + fits = {} + sqdevs = {} + fits["ion"] = np.zeros(all_data["i_data"].shape) + sqdevs["ion"] = np.zeros(all_data["i_data"].shape) + fits["ele"] = np.zeros(all_data["e_data"].shape) + sqdevs["ele"] = np.zeros(all_data["e_data"].shape) + + if config["other"]["extraoptions"]["load_ion_spec"]: + sigmas = np.zeros((all_data["i_data"].shape[0], num_params)) + + if config["other"]["extraoptions"]["load_ele_spec"]: + sigmas = np.zeros((all_data["e_data"].shape[0], num_params)) + + if config["other"]["extraoptions"]["spectype"] == "angular_full": + batch = { + "e_data": all_data["e_data"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], + "e_amps": all_data["e_amps"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], + "i_data": all_data["i_data"], + "i_amps": all_data["i_amps"], + "noise_e": all_data["noiseE"][config["data"]["lineouts"]["start"] : 
config["data"]["lineouts"]["end"], :], + "noise_i": all_data["noiseI"][config["data"]["lineouts"]["start"] : config["data"]["lineouts"]["end"], :], + } + losses, sqds, used_points, [ThryE, _, params] = loss_fn.array_loss(fitted_weights, batch) + fits["ele"] = ThryE + sqdevs["ele"] = sqds["ele"] + + for species in all_params.keys(): + for k in all_params[species].keys(): + if k != "fe": + # all_params[k] = np.concatenate([all_params[k], params[k].reshape(-1)]) + all_params[species][k] = params[species][k].reshape(-1) + else: + all_params[species][k] = params[species][k] + + if calc_sigma: + # this line may need to be omited since the weights may be transformed by line 77 + active_params = loss_fn.spec_calc.get_plasma_parameters(fitted_weights, return_static_params=False) + hess = loss_fn.h_loss_wrt_params(active_params, batch) + sigmas = get_sigmas(hess, config["optimizer"]["batch_size"]) + print(f"Number of 0s in sigma: {len(np.where(sigmas==0)[0])}") + + else: + for i_batch, inds in enumerate(batch_indices): + batch = { + "e_data": all_data["e_data"][inds], + "e_amps": all_data["e_amps"][inds], + "i_data": all_data["i_data"][inds], + "i_amps": all_data["i_amps"][inds], + "noise_e": all_data["noiseE"][inds], + "noise_i": all_data["noiseI"][inds], + } + + loss, sqds, used_points, ThryE, ThryI, params = loss_fn.array_loss(fitted_weights[i_batch], batch) + + if calc_sigma: + hess = loss_fn.h_loss_wrt_params(fitted_weights[i_batch], batch) + try: + hess = loss_fn.h_loss_wrt_params(fitted_weights[i_batch], batch) + except: + print("Error calculating Hessian, no hessian based uncertainties have been calculated") + calc_sigma = False + + losses[inds] = loss + sqdevs["ele"][inds] = sqds["ele"] + sqdevs["ion"][inds] = sqds["ion"] + if calc_sigma: + sigmas[inds] = get_sigmas(hess, config["optimizer"]["batch_size"]) + # print(f"Number of 0s in sigma: {len(np.where(sigmas==0)[0])}") number of negatives? 
+ + fits["ele"][inds] = ThryE + fits["ion"][inds] = ThryI + + return losses, sqdevs, used_points, fits, sigmas, all_params + + +def get_sigmas(hess: Dict, batch_size: int) -> Dict: + """ + Calculates the variance using the hessian with respect to the parameters and then using the hessian values + as the inverse of the covariance matrix and then inverting that. Negatives in the inverse hessian normally indicate + non-optimal points, to represent this in the final result the uncertainty of those values are reported as negative. + + + Args: + hess: Hessian dictionary, the field for each fitted parameter has subfields corresponding to each of the other + fitted parameters. Within each nested subfield is a batch_size x batch_size array with the hessian values + for that parameter combination and that batch. The cross terms of this array are zero since separate + lineouts within a batch do not affect each other, they are therefore discarded + batch_size: int- number of lineouts in a batch + + Returns: + sigmas: batch_size x number_of_parameters array with the uncertainty values for each parameter + """ + sizes = { + key + species: hess[species][key][species][key].shape[1] + for species in hess.keys() + for key in hess[species].keys() + } + # sizes = {key: hess[key][key].shape[1] for key in keys} + actual_num_params = sum([v for k, v in sizes.items()]) + sigmas = np.zeros((batch_size, actual_num_params)) + + for i in range(batch_size): + temp = np.zeros((actual_num_params, actual_num_params)) + k1 = 0 + for species1 in hess.keys(): + for key1 in hess[species1].keys(): + k2 = 0 + for species2 in hess.keys(): + for key2 in hess[species2].keys(): + temp[k1, k2] = np.squeeze(hess[species1][key1][species2][key2])[i, i] + k2 += 1 + k1 += 1 + + # xc = 0 + # for k1, param in enumerate(keys): + # yc = 0 + # for k2, param2 in enumerate(keys): + # if i > 0: + # temp[k1, k2] = np.squeeze(hess[param][param2])[i, i] + # else: + # temp[xc : xc + sizes[param], yc : yc + sizes[param2]] = 
hess[param][param2][0, :, 0, :] + # + # yc += sizes[param2] + # xc += sizes[param] + + # print(temp) + inv = np.linalg.inv(temp) + # print(inv) + + sigmas[i, :] = np.sign(np.diag(inv)) * np.sqrt(np.abs(np.diag(inv))) + # for k1, param in enumerate(keys): + # sigmas[i, xc : xc + sizes[param]] = np.diag( + # np.sign(inv[xc : xc + sizes[param], xc : xc + sizes[param]]) + # * np.sqrt(np.abs(inv[xc : xc + sizes[param], xc : xc + sizes[param]])) + # ) + # print(sigmas[i, k1]) + # change sigmas into a dictionary? + + return sigmas + + +def postprocess(config, sample_indices, all_data: Dict, all_axes: Dict, loss_fn, sa, fitted_weights): + t1 = time.time() + + for species in config["parameters"].keys(): + if "electron" == species: + elec_species = species + + if config["other"]["extraoptions"]["spectype"] != "angular_full" and config["other"]["refit"]: + refit_bad_fits(config, sample_indices, all_data, loss_fn, sa, fitted_weights) + + mlflow.log_metrics({"refitting time": round(time.time() - t1, 2)}) + + with tempfile.TemporaryDirectory() as td: + _ = [os.makedirs(os.path.join(td, dirname), exist_ok=True) for dirname in ["plots", "binary", "csv"]] + if config["other"]["extraoptions"]["spectype"] == "angular_full": + t1 = process_angular_data( + config, sample_indices, all_data, all_axes, loss_fn, fitted_weights, t1, elec_species, td + ) + + else: + t1, final_params = process_data(config, sample_indices, all_data, all_axes, loss_fn, fitted_weights, t1, td) + + mlflow.log_artifacts(td) + mlflow.log_metrics({"plotting time": round(time.time() - t1, 2)}) + + mlflow.set_tag("status", "done plotting") + + return final_params + + +def refit_bad_fits(config, batch_indices, all_data, loss_fn, sa, fitted_weights): + losses_init, sqdevs, used_points, fits, sigmas, all_params = recalculate_with_chosen_weights( + config, batch_indices, all_data, loss_fn, False, fitted_weights + ) + + # refit bad fits + red_losses_init = losses_init / (1.1 * (used_points - len(all_params))) + 
true_batch_size = config["optimizer"]["batch_size"] + # config["optimizer"]["batch_size"] = 1 + mlflow.log_metrics({"number of fits": len(batch_indices.flatten())}) + mlflow.log_metrics({"number of refits": int(np.sum(red_losses_init > config["other"]["refit_thresh"]))}) + + for i in batch_indices.flatten()[red_losses_init > config["other"]["refit_thresh"]]: + if i == 0: + continue + + batch = { + "e_data": np.reshape(all_data["e_data"][i], (1, -1)), + "e_amps": np.reshape(all_data["e_amps"][i], (1, -1)), + "i_data": np.reshape(all_data["i_data"][i], (1, -1)), + "i_amps": np.reshape(all_data["i_amps"][i], (1, -1)), + "noise_e": np.reshape(all_data["noiseE"][i], (1, -1)), + "noise_i": np.reshape(all_data["noiseI"][i], (1, -1)), + } + + # previous_weights = {} + temp_cfg = copy.copy(config) + temp_cfg["optimizer"]["batch_size"] = 1 + for species in fitted_weights[(i - 1) // true_batch_size].keys(): + for key in fitted_weights[(i - 1) // true_batch_size][species].keys(): + if config["parameters"][species][key]["active"]: + temp_cfg["parameters"][species][key]["val"] = float( + fitted_weights[(i - 1) // true_batch_size][species][key][(i - 1) % true_batch_size] + ) + + loss_fn_refit = LossFunction(temp_cfg, sa, batch) + + # loss_fn_refit.flattened_weights, loss_fn_refit.unravel_pytree = ravel_pytree(previous_weights) + + res = spopt.minimize( + loss_fn_refit.vg_loss if config["optimizer"]["grad_method"] == "AD" else loss_fn_refit.loss, + np.copy(loss_fn_refit.flattened_weights), + args=batch, + method=config["optimizer"]["method"], + jac=True if config["optimizer"]["grad_method"] == "AD" else False, + bounds=loss_fn_refit.bounds, + options={"disp": True, "maxiter": config["optimizer"]["num_epochs"]}, + ) + cur_result = loss_fn_refit.unravel_pytree(res["x"]) + + for species in cur_result.keys(): + for key in cur_result[species].keys(): + fitted_weights[i // true_batch_size][species][key] = ( + fitted_weights[i // true_batch_size][species][key] + .at[i % true_batch_size] 
+ .set(cur_result[species][key][0]) + ) + # fitted_weights[i // true_batch_size][species][key][i % true_batch_size] = cur_result[species][key] + + # for key in fitted_weights[i // true_batch_size].keys(): + # cur_value = cur_result[key][0, 0] + # new_vals = fitted_weights[i // true_batch_size][key] + # new_vals = new_vals.at[tuple([i % true_batch_size, 0])].set(cur_value) + # fitted_weights[i // true_batch_size][key] = new_vals + + config["optimizer"]["batch_size"] = true_batch_size + + +def process_data(config, sample_indices, all_data, all_axes, loss_fn, fitted_weights, t1, td): + losses, sqdevs, used_points, fits, sigmas, all_params = recalculate_with_chosen_weights( + config, sample_indices, all_data, loss_fn, config["other"]["calc_sigmas"], fitted_weights + ) + if "losses_init" not in locals(): + losses_init = losses + mlflow.log_metrics({"postprocessing time": round(time.time() - t1, 2)}) + mlflow.set_tag("status", "plotting") + t1 = time.time() + + final_params = plotters.get_final_params(config, all_params, all_axes, td) + + red_losses = plotters.plot_loss_hist(config, losses_init, losses, all_params, used_points, td) + savedata = plotters.plot_ts_data(config, fits, all_data, all_axes, td) + plotters.model_v_actual(config, all_data, all_axes, fits, losses, red_losses, sqdevs, td) + sigma_ds = plotters.save_sigmas_params(config, all_params, sigmas, all_axes, td) + plotters.plot_final_params(config, all_params, sigma_ds, td) + return t1, final_params + + +def process_angular_data(config, batch_indices, all_data, all_axes, loss_fn, fitted_weights, t1, elec_species, td): + best_weights_val = {} + best_weights_std = {} + if config["optimizer"]["num_mins"] > 1: + for k, v in fitted_weights.items(): + best_weights_val[k] = np.average(v, axis=0) # [0, :] + best_weights_std[k] = np.std(v, axis=0) # [0, :] + else: + best_weights_val = fitted_weights + + losses, sqdevs, used_points, fits, sigmas, all_params = recalculate_with_chosen_weights( + config, batch_indices, 
all_data, loss_fn, config["other"]["calc_sigmas"], best_weights_val + ) + + mlflow.log_metrics({"postprocessing time": round(time.time() - t1, 2)}) + mlflow.set_tag("status", "plotting") + t1 = time.time() + + final_params = plotters.get_final_params(config, all_params, all_axes, td) + if config["other"]["calc_sigmas"]: + sigma_fe = plotters.save_sigmas_fe(final_params, best_weights_std, sigmas, td) + else: + sigma_fe = np.zeros_like(final_params["fe"]) + savedata = plotters.plot_data_angular(config, fits, all_data, all_axes, td) + plotters.plot_ang_lineouts(used_points, sqdevs, losses, all_params, all_axes, savedata, td) + # plotters.plot_dist(config, elec_species, final_params, sigma_fe, td) + return t1 diff --git a/tsadar/process/prepare.py b/tsadar/utils/process/prepare.py similarity index 82% rename from tsadar/process/prepare.py rename to tsadar/utils/process/prepare.py index fd6b30dc..e1f7a110 100644 --- a/tsadar/process/prepare.py +++ b/tsadar/utils/process/prepare.py @@ -3,12 +3,12 @@ import numpy as np import os -from tsadar.process.evaluate_background import get_shot_bg -from tsadar.data_handleing.load_ts_data import loadData -from tsadar.process.correct_throughput import correctThroughput -from tsadar.data_handleing.calibrations.calibration import get_calibrations, get_scattering_angles -from tsadar.process.lineouts import get_lineouts -from tsadar.data_handleing.data_visualizer import launch_data_visualizer +from .evaluate_background import get_shot_bg +from ..data_handling.load_ts_data import loadData +from .correct_throughput import correctThroughput +from ..data_handling.calibration import get_calibrations, get_scattering_angles +from .lineouts import get_lineouts +from ..data_handling.data_visualizer import launch_data_visualizer def prepare_data(config: Dict, shotNum: int) -> Dict: @@ -43,7 +43,7 @@ def prepare_data(config: Dict, shotNum: int) -> Dict: # Calibrate axes [axisxE, axisxI, axisyE, axisyI, magE, stddev] = get_calibrations( - shotNum, 
config["other"]["extraoptions"]["spectype"], t0, config["other"]["CCDsize"] + shotNum, config["other"]["extraoptions"]["spectype"], t0, config["other"]["CCDsize"] ) all_axes = {"epw_x": axisxE, "epw_y": axisyE, "iaw_x": axisxI, "iaw_y": axisyI, "x_label": xlab} @@ -55,15 +55,13 @@ def prepare_data(config: Dict, shotNum: int) -> Dict: config["other"]["extraoptions"]["fit_EPWb"] = 0 config["other"]["extraoptions"]["fit_EPWr"] = 0 print("EPW data not loaded, omitting EPW fit") - #if config["other"]["extraoptions"]["first_guess"]: - #run code - #outs=first_guess(inputs) - #config["data"]["lineouts"]["start"]=start + # if config["other"]["extraoptions"]["first_guess"]: + # run code + # outs=first_guess(inputs) + # config["data"]["lineouts"]["start"]=start # Correct for spectral throughput if config["other"]["extraoptions"]["load_ele_spec"]: - elecData = correctThroughput( - elecData, config["other"]["extraoptions"]["spectype"], axisyE, shotNum - ) + elecData = correctThroughput(elecData, config["other"]["extraoptions"]["spectype"], axisyE, shotNum) # temp fix for zeros elecData = elecData + 0.1 @@ -111,10 +109,10 @@ def prepare_data(config: Dict, shotNum: int) -> Dict: all_data["i_data"] = all_data["i_amps"] = np.zeros(len(data_res_unit)) # changed this 8-29-23 not sure how it worked with =0? 
all_data["noiseI"] = np.zeros(np.shape(bg_res_unit)) - all_data['noiseE']=config["data"]["bgscaleE"]*bg_res_unit + 0.1 + all_data["noiseE"] = config["data"]["bgscaleE"] * bg_res_unit + 0.1 config["other"]["CCDsize"] = np.shape(data_res_unit) - #config["data"]["lineouts"]["start"] = int(config["data"]["lineouts"]["start"] / ang_res_unit) - #config["data"]["lineouts"]["end"] = int(config["data"]["lineouts"]["end"] / ang_res_unit) + # config["data"]["lineouts"]["start"] = int(config["data"]["lineouts"]["start"] / ang_res_unit) + # config["data"]["lineouts"]["end"] = int(config["data"]["lineouts"]["end"] / ang_res_unit) else: all_data = get_lineouts( diff --git a/tsadar/process/warpcorr.py b/tsadar/utils/process/warpcorr.py similarity index 99% rename from tsadar/process/warpcorr.py rename to tsadar/utils/process/warpcorr.py index 813ab987..225fe783 100644 --- a/tsadar/process/warpcorr.py +++ b/tsadar/utils/process/warpcorr.py @@ -3,7 +3,7 @@ import math, os from os.path import join, exists -BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "aux") +BASE_FILES_PATH = os.path.join(os.path.dirname(__file__), "..", "..", "external") def perform_warp_correction(warpedData, instrument="EPW", sweepSpeed=5, flatField=True): diff --git a/tsadar/misc/vector_tools.py b/tsadar/utils/vector_tools.py similarity index 97% rename from tsadar/misc/vector_tools.py rename to tsadar/utils/vector_tools.py index d6c1d915..4db70195 100644 --- a/tsadar/misc/vector_tools.py +++ b/tsadar/utils/vector_tools.py @@ -46,8 +46,8 @@ def vdot(a, b): Returns: c: ND-array or tuple of ND-arrays based off the operation being a dot product or scalar product """ - if type(a) is tuple: - if type(b) is tuple: + if isinstance(a, tuple): + if isinstance(b, tuple): return a[0] * b[0] + a[1] * b[1] else: return (a[0] * b, a[1] * b) @@ -67,8 +67,8 @@ def vdiv(a, b): c: tuple of ND-arrays """ # custom function for vector divided by a scalar - if type(a) is tuple: - if type(b) is tuple: + if 
isinstance(a, tuple): + if isinstance(b, tuple): raise ValueError("vector must be divided by a scalar") else: return (a[0] / b, a[1] / b)