
Using weights creates a confusing confusion matrix and inaccurate accuracy score #114

Open
CodingDoug opened this issue Jul 2, 2024 · 1 comment

@CodingDoug

I'm using RandomForestLearner to train a 10-class categorization model using roughly 15000 examples and 12 features. My example set is imbalanced in terms of category distribution, so I need to use class-based weighting to boost the under-represented classes.

I'm post-processing my dataset with weights computed from the entire set:

for row in rows:
    row["weight"] = count / (len(category_counts) * category_counts[row["category"]])

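For reference, here is a minimal sketch of how count and category_counts can be derived; these definitions are an assumption, matching the usual "balanced" class-weight formula weight = N / (num_classes * N_class):

from collections import Counter

# Assumed definitions of the variables used in the snippet above:
# total number of examples, and number of examples per category.
count = len(rows)
category_counts = Counter(row["category"] for row in rows)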
The resulting model is effective, but the confusion matrix is confusing. Here is part of the output from model.describe():

Confusion Table:
truth\prediction
          n     p     e     h     c     b     t     a     d     s
    n1504.9912.17848.717155.255933.717610.384580.384581.794711.281931.79471
    p10.05111525.09     0     0     00.670074     0     0     04.69052
    e23.4119     01517.09     0     0     0     0     0     0     0
    h31.82856.3657     01502.31     0     0     0     0     0     0
    c332.748     0     0     01170.78     0     0     024.64812.324
    b28.527828.5278     0     0     01483.44     0     0     0     0
    t55.6807     0     0     0     0     01484.82     0     0     0
    a418.407     0     0     0     0     0     01122.09     0     0
    d303.432     0     0     023.3409     0     046.68181143.723.3409
    s252.08284.0273     0     028.0091     0     0     028.00911148.37
Total: 15405

Basically unreadable. Here it is again from model.self_evaluation():

accuracy: 0.883005
confusion matrix:
    label (row) \ prediction (col)
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |          |        n |        p |        e |        h |        c |        b |        t |        a |        d |        s |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        n |  1504.99 |  12.1784 |  8.71715 |  5.25593 |  3.71761 |  0.38458 |  0.38458 |  1.79471 |  1.28193 |  1.79471 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        p |  10.0511 |  1525.09 |        0 |        0 |        0 | 0.670074 |        0 |        0 |        0 |  4.69052 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        e |  23.4119 |        0 |  1517.09 |        0 |        0 |        0 |        0 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        h |  31.8285 |   6.3657 |        0 |  1502.31 |        0 |        0 |        0 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        c |  332.748 |        0 |        0 |        0 |  1170.78 |        0 |        0 |        0 |   24.648 |   12.324 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        b |  28.5278 |  28.5278 |        0 |        0 |        0 |  1483.44 |        0 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        t |  55.6807 |        0 |        0 |        0 |        0 |        0 |  1484.82 |        0 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        a |  418.407 |        0 |        0 |        0 |        0 |        0 |        0 |  1122.09 |        0 |        0 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        d |  303.432 |        0 |        0 |        0 |  23.3409 |        0 |        0 |  46.6818 |   1143.7 |  23.3409 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
    |        s |  252.082 |  84.0273 |        0 |        0 |  28.0091 |        0 |        0 |        0 |  28.0091 |  1148.37 |
    +----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+----------+
loss: 1.2135
num examples: 15405
num examples (weighted): 15405

Without weights, the confusion matrix prints integers, as I would expect. With weights, it prints these floating-point numbers, which don't make much sense. I also believe the accuracy number is incorrect: if I run predictions against the model using the same training dataset, I count only 186 of 15405 incorrect predictions (1.2%).
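
For what it's worth, here is roughly how that error count can be computed. This is only a sketch: it assumes a pandas DataFrame df holding the training data, that model.predict() returns per-class probabilities for a multi-class model, and that model.label_classes() lists the classes in prediction-column order.

import numpy as np

probas = model.predict(df)                      # assumed shape: (num_examples, num_classes)
classes = np.array(model.label_classes())       # assumed: class names in column order
predicted = classes[np.argmax(probas, axis=1)]  # most probable class per example

errors = int(np.sum(predicted != df["category"].to_numpy()))
print(f"{errors} of {len(df)} predictions incorrect ({errors / len(df):.1%})")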

@achoum
Collaborator

achoum commented Jul 5, 2024

Hi Doug,
Thanks for the details.

The formatting issue in the confusion matrix printed in the training logs has been fixed. The fix will be included in an upcoming release (the next one or the one after).

Note that model.evaluate computes a non-weighted evaluation. An upcoming release introduces a "weighted" argument to model.evaluate to enable weighted evaluations.
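
Once that argument is available, a weighted evaluation might look roughly like the sketch below. The argument name is taken from the comment above; the exact signature and the test_df variable are assumptions until the release.

evaluation = model.evaluate(test_df)                          # current behavior: non-weighted
weighted_evaluation = model.evaluate(test_df, weighted=True)  # upcoming: weighted
print(evaluation.accuracy, weighted_evaluation.accuracy)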

After some exploration of this case, it seems the model predictions and evaluation (i.e., programmatic access) are correct; it is only a display issue.

Regarding "If I run predictions against the model using the same training dataset, I compute only 186 of 15405 incorrect predictions (1.2%)":

This is possible.

If the training dataset is small, the model's self-evaluation will be noisy. Using example weights (for training or evaluation) further increases this noise.

If you use gradient boosted trees (GBT), the self-evaluation is computed on the validation dataset (which is extracted from the training dataset if not provided). So, if the training dataset is small, the validation dataset is also small, and a discrepancy between the self-evaluation and an evaluation on a test set is expected.

If you use Random Forests, the self-evaluation is computed using an out-of-bag evaluation. The out-of-bag evaluation is a conservative estimate of the model quality. If the dataset is small, this estimate can be poor. In addition, if the model contains a small number of trees, the out-of-bag evaluation can be biased (in the conservative direction). Note that in this case, the winner_take_all Random Forest learner hyper-parameter can help, but using more trees is generally preferable.
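
As an illustration, increasing the number of trees and enabling winner_take_all might look like the sketch below. The argument names (weights, num_trees, winner_take_all) are assumptions based on this thread, so check the RandomForestLearner documentation for the exact API.

import ydf  # assumed package name

learner = ydf.RandomForestLearner(
    label="category",        # label column used in this thread
    weights="weight",        # assumed name of the example-weight column argument
    num_trees=1000,          # more trees: less biased out-of-bag estimate
    winner_take_all=True,    # hyper-parameter mentioned above
)
model = learner.train(train_df)  # train_df: hypothetical training DataFrame
print(model.self_evaluation())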
