diff --git a/medcat/utils/meta_cat/ml_utils.py b/medcat/utils/meta_cat/ml_utils.py index 79cedb9f3..3559ce1d8 100644 --- a/medcat/utils/meta_cat/ml_utils.py +++ b/medcat/utils/meta_cat/ml_utils.py @@ -200,7 +200,7 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa if config.train['compute_class_weights'] is True: y_ = [x[2] for x in train_data] class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_), y=y_) - config.train['class_weights'] = class_weights + config.train['class_weights'] = class_weights.tolist() logger.info(f"Class weights computed: {class_weights}") class_weights = torch.FloatTensor(class_weights).to(device) diff --git a/medcat/utils/meta_cat/models.py b/medcat/utils/meta_cat/models.py index 70e235316..2fb20b7b6 100644 --- a/medcat/utils/meta_cat/models.py +++ b/medcat/utils/meta_cat/models.py @@ -114,8 +114,13 @@ def __init__(self, config): self.fc3 = nn.Linear(hidden_size_2, hidden_size_2) # dense layer 3 (Output layer) model_arch_config = config.model.model_architecture_config + if model_arch_config is not None: - if model_arch_config['fc2'] is True or model_arch_config['fc3'] is True: + if model_arch_config['fc3'] is True and model_arch_config['fc2'] is False: + logger.warning("FC3 can only be used if FC2 is also enabled. Enabling FC2...") + config.model.model_architecture_config['fc2'] = True + + if model_arch_config['fc2'] is True: self.fc4 = nn.Linear(hidden_size_2, self.num_labels) else: self.fc4 = nn.Linear(config.model.hidden_size, self.num_labels) @@ -190,11 +195,11 @@ def forward( x = self.relu(x) x = self.dropout(x) - if self.config.model.model_architecture_config['fc3'] is True: - # fc3 - x = self.fc3(x) - x = self.relu(x) - x = self.dropout(x) + if self.config.model.model_architecture_config['fc3'] is True: - # fc3 - x = self.fc3(x) - x = self.relu(x) - x = self.dropout(x) + # fc3 + x = self.fc3(x) + x = self.relu(x) + x = self.dropout(x) else: # fc2 x = self.fc2(x)