Skip to content

Commit

Permalink
fix similarity config (#180)
Browse files Browse the repository at this point in the history
* fix similarity config

* don't change numbers yet
  • Loading branch information
deepchatterjeeligo authored Nov 21, 2024
1 parent 6ba4b04 commit 80098ea
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions amplfi/train/configs/similarity/cbc.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -28,29 +28,33 @@ trainer:
init_args:
logging_interval: epoch
model:
class_path: amplfi.train.models.SimilarityModel
class_path: amplfi.train.models.similarity.SimilarityModel
init_args:
outdir: ${oc.env:AMPLFI_OUTDIR}
similarity_loss:
class_path: amplfi.train.losses.VICRegLoss
init_args:
lambda_param: 25
lambda_param: 25.0
mu_param: 25.0
nu_param: 1.0
# add path below when running `trainer.test` to load in a checkpoint
checkpoint: null
arch:
class_path: amplfi.train.architectures.embeddings.MultiModalPsd
class_path: amplfi.train.architectures.similarity.SimilarityEmbedding
init_args:
time_context_dim: 8
freq_context_dim: 128
time_layers: [5, 3, 3]
freq_layers: [5, 3, 3]
norm_layer:
class_path: ml4gw.nn.norm.GroupNorm1DGetter
embedding:
class_path: amplfi.train.architectures.embeddings.MultiModalPsd
init_args:
groups: 8
patience: null
time_context_dim: 8
freq_context_dim: 128
time_layers: [5, 3, 3]
freq_layers: [5, 3, 3]
norm_layer:
class_path: ml4gw.nn.norm.GroupNorm1DGetter
init_args:
groups: 8
expander_factor: 3
patience: 5
learning_rate: 0.00071444
weight_decay: 0.0
data:
Expand Down

0 comments on commit 80098ea

Please sign in to comment.