Skip to content

Commit

Permalink
Merge pull request #474 from CogStack/metacat_bug_resolve
Browse files Browse the repository at this point in the history
Fixing bug for metacat
  • Loading branch information
shubham-s-agarwal authored Aug 9, 2024
2 parents 33e32fd + 267cd4f commit b7658ee
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 7 deletions.
2 changes: 1 addition & 1 deletion medcat/utils/meta_cat/ml_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def train_model(model: nn.Module, data: List, config: ConfigMetaCAT, save_dir_pa
if config.train['compute_class_weights'] is True:
y_ = [x[2] for x in train_data]
class_weights = compute_class_weight(class_weight="balanced", classes=np.unique(y_), y=y_)
config.train['class_weights'] = class_weights
config.train['class_weights'] = class_weights.tolist()
logger.info(f"Class weights computed: {class_weights}")

class_weights = torch.FloatTensor(class_weights).to(device)
Expand Down
17 changes: 11 additions & 6 deletions medcat/utils/meta_cat/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,8 +114,13 @@ def __init__(self, config):
self.fc3 = nn.Linear(hidden_size_2, hidden_size_2)
# dense layer 3 (Output layer)
model_arch_config = config.model.model_architecture_config

if model_arch_config['fc3'] is True and model_arch_config['fc2'] is False:
logger.warning("FC3 can only be used if FC2 is also enabled. Enabling FC2...")
config.model.model_architecture_config['fc2'] = True

if model_arch_config is not None:
if model_arch_config['fc2'] is True or model_arch_config['fc3'] is True:
if model_arch_config['fc2'] is True:
self.fc4 = nn.Linear(hidden_size_2, self.num_labels)
else:
self.fc4 = nn.Linear(config.model.hidden_size, self.num_labels)
Expand Down Expand Up @@ -190,11 +195,11 @@ def forward(
x = self.relu(x)
x = self.dropout(x)

if self.config.model.model_architecture_config['fc3'] is True:
# fc3
x = self.fc3(x)
x = self.relu(x)
x = self.dropout(x)
if self.config.model.model_architecture_config['fc3'] is True:
# fc3
x = self.fc3(x)
x = self.relu(x)
x = self.dropout(x)
else:
# fc2
x = self.fc2(x)
Expand Down

0 comments on commit b7658ee

Please sign in to comment.