
Classification & Regression multilabel #815

Merged · 14 commits merged into main from psi/multilabel · Aug 12, 2024

Conversation

@psinger (Collaborator) commented Aug 7, 2024

In this PR we add multi-target support for both classification and regression.

The following code adaptations have been made:

  • For classification and regression, the answer column can now be a selection of multiple targets, changing the setting to a tuple
  • Multi-label will only work with BCE loss
  • For plotting and visualizations, the target columns are concatenated as a string
  • The same applies for predictions
  • Classification predictions are now consistently post-processed in postprocess_output instead of in the individual metrics
  • The validation csv file now contains the probabilities instead of the hard predictions
  • The validation pickle file now contains logits, probabilities and predictions
  • For regression, we set the regression head size by the number of answer columns (see the sketch below)
  • For classification, num_classes still needs to be set; it would potentially be easier to derive it automatically as well, which would also allow getting rid of all the related error checks
  • Added metric tests for regression and adjusted them for classification
  • Adjusted integration tests

When reviewing, the main potential for bugs is with respect to post-processing, loss, and metrics.
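To make the head sizing, loss, and post-processing behavior concrete, below is a minimal sketch of what the multi-label setup amounts to, assuming a PyTorch backbone that yields pooled hidden states; the column names, hidden size, and threshold are illustrative assumptions, not the PR's actual code:

import torch
import torch.nn as nn

# Hypothetical multi-target setting; column names are made up for illustration.
answer_column = ("toxic", "obscene", "insult")

hidden_size = 768  # assumed backbone hidden size
# Head size is derived from the number of answer columns (for classification,
# num_classes still has to be set explicitly in the actual config).
head = nn.Linear(hidden_size, len(answer_column))
loss_fn = nn.BCEWithLogitsLoss()  # multi-label only works with BCE

pooled = torch.randn(4, hidden_size)  # batch of 4 pooled representations
targets = torch.randint(0, 2, (4, len(answer_column))).float()

logits = head(pooled)
loss = loss_fn(logits, targets)

probabilities = torch.sigmoid(logits)        # what ends up in the validation csv
predictions = (probabilities > 0.5).float()  # hard predictions, kept in the pickle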

Closes #805 ([FEATURE] Multilabel classification)

@psinger marked this pull request as ready for review on August 8, 2024 at 19:30
preds = []
for col in np.arange(len(cfg.dataset.answer_column)):
    preds.append(
        np.round(output["predictions"][:, col].cpu().numpy(), 3).astype(str)
    )

@pascal-pfeiffer (Collaborator):

specific reason for the rounding here?

@psinger (Collaborator, Author):

yeah this is shown in the visualizations and with lots of digits it's not great to read

@pascal-pfeiffer (Collaborator) commented Aug 12, 2024:

Then, why not truncate in the visualization? The current implementation also rounds the downloadable predictions.

Actually, this even impacts metric calculation.

@psinger (Collaborator, Author) commented Aug 12, 2024:

It doesn't, because output["predicted_text"] is not used there - and truncating in the vis is difficult because it supports all kinds of texts.

@pascal-pfeiffer (Collaborator):

Ah, right, that was changed in the PR. Still, it is odd for the exported dataframe to have rounded values.

@psinger (Collaborator, Author):

Hmh, debatable. Would be tricky to add something different there. The pickle should be used anyway for exact values.
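For what it's worth, consuming the exact values would then look roughly like this; the file names below are hypothetical, but the keys follow the logits/probabilities/predictions layout described in the PR summary:

import pickle

import pandas as pd

# Hypothetical file names. Per the PR description, the validation pickle holds
# logits, probabilities and predictions; the csv holds the rounded display values.
with open("validation_raw_predictions.pkl", "rb") as f:
    raw = pickle.load(f)

exact_probabilities = raw["probabilities"]                # full precision
rounded_view = pd.read_csv("validation_predictions.csv")  # human-readable, rounded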

@pascal-pfeiffer (Collaborator) left a review:

All looks very clean and works well for the usual cases, thank you!

I think we might want to improve a bit on the documentation and on error handling, especially for CLI users.

The required change to use a list/tuple format for the answer column is a bit unfortunate, as the error message is rather cryptic for CLI users when using previously well-working yamls:

Traceback (most recent call last):
  File "/home/xxx/h2o-llmstudio/train.py", line 722, in <module>
    run(cfg=cfg)
  File "/home/xxx/h2o-llmstudio/train.py", line 530, in run
    train_dataset = get_train_dataset(train_df=train_df, cfg=cfg)
  File "/home/xxx/h2o-llmstudio/llm_studio/src/utils/data_utils.py", line 396, in get_train_dataset
    train_dataset: Dataset = cfg.dataset.dataset_class(
  File "/home/xxx/h2o-llmstudio/llm_studio/src/datasets/text_causal_classification_ds.py", line 19, in __init__
    check_for_non_int_answers(cfg, df)
  File "/home/xxx/h2o-llmstudio/llm_studio/src/datasets/text_causal_classification_ds.py", line 104, in check_for_non_int_answers
    x for x in df[column].values if not is_castable_to_int(x)
  File "/home/xxx/.local/share/virtualenvs/h2o-llmstudio-tT3gHl3a/lib/python3.10/site-packages/pandas/core/frame.py", line 4102, in __getitem__
    indexer = self.columns.get_loc(key)
  File "/home/xxx/.local/share/virtualenvs/h2o-llmstudio-tT3gHl3a/lib/python3.10/site-packages/pandas/core/indexes/base.py", line 3812, in get_loc
    raise KeyError(key) from err
KeyError: 'w'

We should handle that better, as it is an expected source of errors. The other improvements to the documentation can be handled in subsequent iterations.
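The KeyError: 'w' presumably comes from the old string setting being iterated character by character. As a hedged sketch of how the config could be normalized up front (the helper name is made up; this is not the repo's actual code):

def normalize_answer_column(answer_column):
    # Hypothetical helper: accept legacy string configs and fail early with a
    # readable message instead of a per-character KeyError downstream.
    if isinstance(answer_column, str):
        # Older yamls used a plain string; wrap it into the new tuple format.
        return (answer_column,)
    if isinstance(answer_column, (list, tuple)):
        return tuple(answer_column)
    raise ValueError(
        "dataset.answer_column must be a column name or a list of column "
        f"names, got {type(answer_column).__name__!r}"
    )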

Comment on lines +15 to +16
answer_column:
- binary_label
@pascal-pfeiffer (Collaborator):

This required change is a bit unfortunate, as the error message is rather cryptic for CLI users when using previously well-working yamls (same traceback as above).

Resolved (outdated) review threads:
documentation/docs/tooltips/experiments/_answer-column.mdx
tests/integration/test_causal_regression_modeling_cfg.yaml
@pascal-pfeiffer (Collaborator) left a review:

Thank you for the quick changes, lgtm!

@psinger merged commit aff8044 into main on Aug 12, 2024 (4 checks passed)
@psinger deleted the psi/multilabel branch on August 12, 2024 at 15:37