Classification & Regression multilabel #815
Conversation
```python
preds = []
for col in np.arange(len(cfg.dataset.answer_column)):
    preds.append(
        np.round(output["predictions"][:, col].cpu().numpy(), 3).astype(str)
    )
```
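As a runnable sketch of the quoted loop (with a dummy array standing in for `output["predictions"]`; the real tensor lives on a device, hence the `.cpu().numpy()` call above):

```python
import numpy as np

# Dummy stand-in for output["predictions"], shape (n_samples, n_targets).
predictions = np.array([[0.123456, 0.987654],
                        [0.555555, 0.000100]])

preds = []
for col in range(predictions.shape[1]):
    # Round to 3 digits and stringify, purely for compact display.
    preds.append(np.round(predictions[:, col], 3).astype(str))

print(preds[0])  # ['0.123' '0.556']
```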
Is there a specific reason for the rounding here?
Yeah, this is shown in the visualizations, and with lots of digits it's not great to read.
Then, why not truncate in the visualization? The current implementation also rounds the downloadable predictions.
Actually, this even impacts metric calculation.
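To illustrate the concern with made-up numbers: a metric computed on values rounded to three digits can differ from one computed on the raw predictions.

```python
import numpy as np

y_true = np.array([0.1234, 0.5678])      # made-up labels
raw = np.array([0.12345, 0.56789])       # made-up raw predictions
rounded = np.round(raw, 3)               # what a rounded export would contain

mse_raw = np.mean((y_true - raw) ** 2)
mse_rounded = np.mean((y_true - rounded) ** 2)
print(mse_raw != mse_rounded)  # True: rounding shifts the metric slightly
```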
It doesn't, because `output["predicted_text"]` is not used there. And truncating in the visualization is difficult, because it supports all kinds of texts.
Ah, right, that was changed in the PR. It's still odd for the exported dataframe to have rounded values.
Hmh, debatable. It would be tricky to add something different there. The pickle should be used anyway for exact values.
All looks very clean and works well for the usual cases, thank you!
I think we might want to improve a bit on the documentation and on error handling, especially for CLI users.
The required change to use a list/tuple format for the answer column is a bit unfortunate, as the error message is rather cryptic for CLI users when using previously well-working YAMLs.
```
Traceback (most recent call last):
  File "/home/xxx/h2o-llmstudio/train.py", line 722, in <module>
    run(cfg=cfg)
  File "/home/xxx/h2o-llmstudio/train.py", line 530, in run
    train_dataset = get_train_dataset(train_df=train_df, cfg=cfg)
  File "/home/xxx/h2o-llmstudio/llm_studio/src/utils/data_utils.py", line 396, in get_train_dataset
    train_dataset: Dataset = cfg.dataset.dataset_class(
  File "/home/xxx/h2o-llmstudio/llm_studio/src/datasets/text_causal_classification_ds.py", line 19, in __init__
    check_for_non_int_answers(cfg, df)
  File "/home/xxx/h2o-llmstudio/llm_studio/src/datasets/text_causal_classification_ds.py", line 104, in check_for_non_int_answers
    x for x in df[column].values if not is_castable_to_int(x)
  File "/home/xxx/.local/share/virtualenvs/h2o-llmstudio-tT3gHl3a/lib/python3.10/site-packages/pandas/core/frame.py", line 4102, in __getitem__
    indexer = self.columns.get_loc(key)
  File "/home/xxx/.local/share/virtualenvs/h2o-llmstudio-tT3gHl3a/lib/python3.10/site-packages/pandas/core/indexes/base.py", line 3812, in get_loc
    raise KeyError(key) from err
KeyError: 'w'
```
We should handle that better, as it is an expected source of error. The other documentation improvements we can handle in subsequent iterations.
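One way to handle it (a hypothetical sketch; `normalize_answer_column` is not the project's actual API): accept the legacy plain-string format and wrap it in a list, instead of iterating over the string character by character, which is what produces the cryptic `KeyError: 'w'` above when `answer_column` is the string `"w..."`.

```python
def normalize_answer_column(answer_column):
    """Coerce answer_column to a list of column names with a clear error."""
    if isinstance(answer_column, str):
        # Back-compat: wrap a legacy single-column string in a list so that
        # downstream loops iterate over columns, not over characters.
        return [answer_column]
    if isinstance(answer_column, (list, tuple)):
        return list(answer_column)
    raise ValueError(
        "answer_column must be a string, list, or tuple, "
        f"got {type(answer_column).__name__}"
    )

print(normalize_answer_column("binary_label"))  # ['binary_label']
print(normalize_answer_column(["a", "b"]))      # ['a', 'b']
```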
```yaml
answer_column:
  - binary_label
```
This required change is a bit unfortunate, as the error message is rather cryptic for CLI users when using previously well-working YAMLs.
Thank you for the quick changes, lgtm!
In this PR we add multi-target support for both classification and regression.
The following code adaptations have been made:
- `postprocess_output` instead of in the individual metrics
- `logits`, `probabilities` and `predictions`
- `num_classes` still needs to be set. Potentially it would be easier to also start deriving that automatically, which would also allow getting rid of all the error checks.

When reviewing, the main potential for bugs is with respect to post-processing, loss, and metrics.
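A minimal sketch of the multi-target output shapes this implies (all names, shapes, and the sigmoid threshold are assumptions for illustration, not the PR's exact code): with several answer columns, `logits`, `probabilities`, and `predictions` become 2-D arrays of shape `(n_samples, num_classes)` instead of vectors.

```python
import numpy as np

rng = np.random.default_rng(0)
n_samples, num_classes = 4, 3  # num_classes still set explicitly, as noted above

logits = rng.normal(size=(n_samples, num_classes))
probabilities = 1.0 / (1.0 + np.exp(-logits))    # per-target sigmoid (multilabel)
predictions = (probabilities > 0.5).astype(int)  # per-target 0/1 decision

print(predictions.shape)  # (4, 3)
```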
Closes #805