Skip to content

Commit

Permalink
Pytorch Deep Learning Backend (#179)
Browse files Browse the repository at this point in the history
* added pytorch models

* Add builder_util.py for Pytorch
Clean-up rise mobile v3 for Pytorch

* Add initial TrainerAgentPytorch

* Add more training loop information

* Continue TrainerAgentPytorch

* Change "use_mxnet_style" to "framework"
Update train_cnn.ipynb
Rename trainer_agent.py to trainer_agent_gluon.py

* Delete unused RiseV3 params
Add model definition to train_cnn.ipynb

* Fix PolicyHead definition

* Add pytorch metrcis
Update pytorch trainer

* Fix metric usage
Udpate Risev3 Fwd-Inference
Update train_cnn.ipynb

* Enable wdl and plys_to_end

* Add pytorch metrics

* Add torch.flatten() fro MSELoss

* Use torch.no_grad() for eval and empty_cache()
Correct update of self.batch_proc_tmp

* Add wdl_acc to metric evaluation

* Add torch.flatten() to "value_acc_sign"

* Add missing self.

* Fix tensorboard logging

* Refactoring of trainer_agent_pytorch.py

* Add train_util.py
Implement get_metrics()

* Fix metrics MSE and CrossEntropy

* Increase training batch eval to 25

* Update export_to_onnx()

* Add missing return

* Correct return of train loop
Add onnx export to train_cnn.ipynb
Add missing self. to val_metrics_best

* Integrate Pytorch backend into rl-loop

* Remove unused metrics.py
Enable pytorch training with soft targets

* update rl-loop

* Add pytorch export to convert_to_onnx.py

* Rename train_confi_template.py into train_config_template.py
  • Loading branch information
QueensGambit authored Jul 1, 2022
1 parent 6e43fa9 commit 3d93f03
Show file tree
Hide file tree
Showing 16 changed files with 1,837 additions and 176 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ class TrainConfig:

export_grad_histograms: bool = True

# Decide between 'pytorch', 'mxnet' and 'gluon' style for training
# Reinforcement Learning only works with gluon and pytorch atm
framework: str = 'pytorch'

# Boolean if the policy data is also defined in select_policy_from_plane representation
is_policy_from_plane_data: bool = False

Expand Down Expand Up @@ -103,10 +107,6 @@ class TrainConfig:
# total of training iterations
total_it: int = None

# Decide between mxnet and gluon style for training
# Reinforcement Learning only works with gluon (== False) atm
use_mxnet_style: bool = True

# adds a small mlp to infer the value loss from wdl and plys_to_end_output
use_mlp_wdl_ply: bool = False
# enables training with ply to end head
Expand All @@ -132,4 +132,4 @@ class TrainObjects:
momentum_schedule = None
metrics = None
variant_metrics = None

Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ def value_head(data, channels_value_head=4, value_kernelsize=1, act_type='relu',
:param grad_scale_value: Optional re-weighting of gradient
:param use_se: Indicates if a squeeze excitation layer shall be used
:param use_mix_conv: True, if an additional mix convolutional layer shall be used
:param use_wdl: If a win draw loss head shall be used
:param use_plys_to_end: If a plys to end prediction head shall be used
:param use_mlp_wdl_ply: If a small mlp with value output for the wdl and ply head shall be used
"""
# for value output
value_out = mx.sym.Convolution(data=data, num_filter=channels_value_head,
Expand Down
Loading

0 comments on commit 3d93f03

Please sign in to comment.