
tune the pytorch MLPF model to be more similar to TF #165

Merged 1 commit into main from pytorch_mlpf_tuning on Jan 26, 2023

Conversation

jpata (Owner) commented on Jan 26, 2023

  • add some more complexity to the pytorch MLPF model
  • convolution layers separated for the ID and regression parts (see the sketch after this list)
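A minimal sketch of the "separate convolution stacks" idea, assuming GravNetConv from PyTorch Geometric; the widths and the space/propagate dimensions below are illustrative placeholders, not the values used in this PR:

import torch.nn as nn
from torch_geometric.nn import GravNetConv

embedding_dim, width, k = 128, 256, 8

# shared encoder for the per-element input features
nn0 = nn.Sequential(
    nn.Linear(16, width), nn.ELU(),
    nn.Linear(width, width), nn.ELU(),
    nn.Linear(width, width), nn.ELU(),
    nn.Linear(width, embedding_dim),
)

# two independent GravNet stacks: one feeds the ID head, the other the regression head
def gravnet_stack(n_layers=3):
    return nn.ModuleList(
        [GravNetConv(embedding_dim, embedding_dim,
                     space_dimensions=4, propagate_dimensions=22, k=k)  # dims are assumptions
         for _ in range(n_layers)]
    )

conv_id = gravnet_stack()
conv_reg = gravnet_stack()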

jpata (Owner, Author) commented on Jan 26, 2023

Attachments: mlpf_loss.pdf, jet_res

$ python3 ssl_pipeline.py --data_split_mode domain_adaptation \
    --prefix_VICReg VICReg_test_v4 --prefix_mlpf MLPF_test \
    --train_mlpf --native --n_epochs_VICReg 0 --batch_size_mlpf 200 \
    --n_epochs_mlpf 100 --patience 100 --width_mlpf 256 --embedding_dim 128

Will use GeForce RTX 2060 SUPER
Will use data split mode `domain_adaptation`.
Will use 79619 events to train VICReg
Will use 19905 events to validate VICReg
Will use 80160 events to train MLPF
Will use 19364 events to validate MLPF
...
------> Progressing to MLPF trainings...
Will use 80160 events for train
Will use 19364 events for valid
MLPF(
  (nn0): Sequential(
    (0): Linear(in_features=16, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=256, out_features=128, bias=True)
  )
  (conv_id): ModuleList(
    (0): GravNetConv(128, 128, k=8)
    (1): GravNetConv(128, 128, k=8)
    (2): GravNetConv(128, 128, k=8)
  )
  (conv_reg): ModuleList(
    (0): GravNetConv(128, 128, k=8)
    (1): GravNetConv(128, 128, k=8)
    (2): GravNetConv(128, 128, k=8)
  )
  (nn_id): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=256, out_features=6, bias=True)
  )
  (nn_reg): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=256, out_features=4, bias=True)
  )
  (nn_charge): Sequential(
    (0): Linear(in_features=384, out_features=256, bias=True)
    (1): ELU(alpha=1.0)
    (2): Linear(in_features=256, out_features=256, bias=True)
    (3): ELU(alpha=1.0)
    (4): Linear(in_features=256, out_features=256, bias=True)
    (5): ELU(alpha=1.0)
    (6): Linear(in_features=256, out_features=1, bias=True)
  )
)
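Reading the printed shapes: nn0 maps the 16 input features to a 128-dimensional embedding, every GravNetConv keeps 128 features, and all three output heads take 384 = 3 × 128 inputs, which suggests each head sees the concatenated outputs of one three-layer GravNet stack. The forward-pass sketch below is only an inference from these shapes, not code from this PR; the sequential chaining and the concatenation wiring are assumptions:

import torch

def forward(model, x, batch):
    # x: [N, 16] per-element input features; batch: graph assignment vector
    embedding = model.nn0(x)                       # -> [N, 128]

    id_feats, reg_feats = [], []
    h = embedding
    for conv in model.conv_id:                     # ID branch
        h = conv(h, batch)                         # -> [N, 128]
        id_feats.append(h)
    h = embedding
    for conv in model.conv_reg:                    # regression branch
        h = conv(h, batch)
        reg_feats.append(h)

    id_input = torch.cat(id_feats, dim=-1)         # -> [N, 384]
    reg_input = torch.cat(reg_feats, dim=-1)       # -> [N, 384]

    preds_id = model.nn_id(id_input)               # -> [N, 6] class logits
    preds_momentum = model.nn_reg(reg_input)       # -> [N, 4] regression targets
    preds_charge = model.nn_charge(reg_input)      # -> [N, 1]
    return preds_id, preds_momentum, preds_charge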
MLPF model name: MLPF_test_native
Num of 'native-mlpf' parameters: 1038179
- Training native MLPF over 100 epochs
Will fix VICReg during mlpf training
---->Initiating a training run
loss_id=1.09 loss_momentum=0.30 loss_charge=0.15
---->Initiating a validation run
loss_id=0.93 loss_momentum=0.33 loss_charge=0.15
epoch=1 / 100 train_loss=1.5441 valid_loss=1.4099 stale=0 time=0.33m eta=33.1m
...
---->Initiating a training run
loss_id=0.51 loss_momentum=0.24 loss_charge=0.15
---->Initiating a validation run
loss_id=0.51 loss_momentum=0.28 loss_charge=0.15
epoch=100 / 100 train_loss=0.9044 valid_loss=0.9425 stale=1 time=0.3m eta=0.0m
----------------------------------------------------------
Done with training. Total training time is 30.638min
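Note that the printed totals are consistent with the epoch loss being the plain sum of the three components (an observation from the numbers above, not taken from the training code): at epoch 1, 1.09 + 0.30 + 0.15 = 1.54 ≈ train_loss=1.5441 and 0.93 + 0.33 + 0.15 = 1.41 ≈ valid_loss=1.4099; at epoch 100, 0.51 + 0.24 + 0.15 = 0.90 ≈ train_loss=0.9044.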

jpata merged commit 0d2d997 into main on Jan 26, 2023
jpata deleted the pytorch_mlpf_tuning branch on January 26, 2023 at 11:14
jpata added a commit that referenced this pull request on Sep 15, 2023