
Improve speed and memory consumption of binned PrecisionRecallCurve #1493

Merged

Conversation

Callidior
Contributor

What does this PR do?

Fixes #1492 by reverting the implementation of the update step of binary and multiclass PrecisionRecallCurve with a fixed number of thresholds to the implementation from torchmetrics 0.9.

I did not touch the multilabel case because it is difficult to adapt the ignore_index filtering logic, a first attempt at implementing it on a per-threshold basis was much slower, and multilabel classification is uncommon in semantic segmentation scenarios, which is where we typically have to deal with huge numbers of samples.

Before submitting
  • Was this discussed/agreed via a GitHub issue? (no need for typos and docs improvements)
  • Did you read the contributor guideline, Pull Request section?
  • Did you make sure to update the docs?
  • Did you write any new necessary tests?
PR review

Anyone in the community is free to review the PR once the tests have passed.
If we didn't discuss your PR in GitHub issues there's a high chance it will not be merged.

Performance Comparison

Hardware: NVIDIA RTX 3090

BinaryPrecisionRecallCurve.update()

thresholds=200

| #Samples  | Current Speed | New Speed | Current Memory | New Memory |
|-----------|---------------|-----------|----------------|------------|
| 100       | 0.2 ms        | 23.8 ms   | < 1 MB         | < 1 MB     |
| 1,000     | 1.1 ms        | 24.4 ms   | 20 MB          | < 1 MB     |
| 25,000    | 24.0 ms       | 25.8 ms   | 200 MB         | < 1 MB     |
| 100,000   | 115.3 ms      | 25.5 ms   | 722 MB         | 2 MB       |
| 1,000,000 | 1,012.1 ms    | 26.7 ms   | 7,156 MB       | 2 MB       |
| 3,000,000 | 7,948.3 ms    | 56.5 ms   | 18,886 MB      | 44 MB      |

MulticlassPrecisionRecallCurve.update()

num_classes=100, thresholds=200

| #Samples | Current Speed | New Speed | Current Memory | New Memory |
|----------|---------------|-----------|----------------|------------|
| 10       | 0.2 ms        | 15.5 ms   | 20 MB          | < 1 MB     |
| 100      | 0.5 ms        | 15.2 ms   | 52 MB          | < 1 MB     |
| 1,000    | 3.5 ms        | 15.2 ms   | 484 MB         | 4 MB       |
| 10,000   | 30.6 ms       | 30.8 ms   | 4,770 MB       | 22 MB      |
| 25,000   | 76.6 ms       | 70.6 ms   | 12,404 MB      | 60 MB      |

The implementation proposed here is not faster in all cases, but it is always more memory efficient. In particular, it is up to 100x slower for small numbers of samples, such as in a typical image-level classification scenario. However, it is up to 100x faster for scenarios with many samples per batch, such as semantic segmentation use cases. In those scenarios, the extreme memory consumption of the previous implementation is the biggest deal breaker.

Did you have fun?

Make sure you had fun coding 🙃

@codecov

codecov bot commented Feb 9, 2023

Codecov Report

Merging #1493 (6ab32fb) into master (2322414) will decrease coverage by 38%.
The diff coverage is 31%.

Additional details and impacted files
@@           Coverage Diff            @@
##           master   #1493     +/-   ##
========================================
- Coverage      89%     51%    -38%     
========================================
  Files         216     216             
  Lines       11243   11275     +32     
========================================
- Hits         9983    5720   -4263     
- Misses       1260    5555   +4295     

@justusschock
Member

justusschock commented Feb 9, 2023

Hi @Callidior, thanks for this PR. In general the implementation looks good to me. In #1492 you mention that which implementation is faster depends on the number of samples, the number of thresholds (and potentially the number of classes as well). Is there a way we could derive a condition based on the product of those, or should we just make this adjustable via a flag (with the new implementation as default)?

I am just asking this since, while you say that you'd benefit from the new implementation, there might also be a lot of people benefiting from the current implementation (e.g. when memory isn't too much of a constraint). So ideally we would find a way to make both implementations available, or at least switch appropriately.

@Callidior
Contributor Author

Hi @justusschock! Thanks for your feedback. I completely agree that there are also common use cases (such as image-level classification with < 100 samples per batch), for which the current implementation would be the better choice. An automatic switch that selects the more suitable algorithm would be my preferred solution as well.

For the setup I benchmarked above for binary classification, one would probably switch to the old implementation (this PR) from 10k samples and above. I could run some more tests with different numbers of thresholds, but the result might even be hardware-dependent.

What is more predictable, though, is the memory consumption. I would definitely not spend much more than 100 MB for computing a metric if I could do the same with 1-2 MB. We could estimate the memory needed by the current implementation based on number of samples * number of thresholds (* number of classes) * constant and switch implementations based on that.
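A rough sketch of what such a memory-based switch could look like (the byte estimate and the memory_budget cut-off are illustrative assumptions, not part of the PR):

def pick_implementation(num_samples, num_thresholds, num_classes=1, memory_budget=100 * 2**20):
    # The vectorized update materializes roughly one int64 per
    # sample * threshold (* class) for the intermediate mapping tensor.
    estimated_bytes = num_samples * num_thresholds * num_classes * 8
    return "loop" if estimated_bytes > memory_budget else "vectorized"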

What do you think?

@Callidior
Contributor Author

Callidior commented Feb 9, 2023

A more detailed speed benchmark on the binary case (time in ms):

[Figure: detailed speed benchmark for the binary case, vectorized vs. loop implementation (time in ms)]

  • The loop implementation scales linearly with the number of thresholds but is almost insensitive to the number of samples.
  • The vectorized implementation scales linearly with the number of samples.
  • Keep in mind that memory consumption also scales almost linearly with number of samples and thresholds for the vectorized implementation. The case in the bottom right corner takes 24 GB of VRAM, while the loop implementation has a constant memory consumption of 2 MB.

The documentation of PrecisionRecallCurve currently states:

The implementation both supports calculating the metric in a non-binned but accurate version and a binned version that is less accurate but more memory efficient. Setting the thresholds argument to None will activate the non-binned version that uses memory of size O(n_samples) whereas setting the thresholds argument to either an integer, list or a 1d tensor will use a binned version that uses memory of size O(n_thresholds) (constant memory).

That is true for the metric's state but not for intermediate computations during the update step. The documentation does not mention speed but puts the promise of memory efficiency in the foreground, and that promise is currently not kept.

With the current implementation, I don't see any benefit of using the binned version of the metric. If you have few samples, the accurate version (thresholds=None) is still very tractable and more accurate than the binned one. If you have many samples, both the binned and the accurate version are unusable due to high memory requirements and computation time.

@justusschock: Given that users can already switch between the accurate and the binned version, I don't think an explicit second switch for two different implementations of the binned version makes sense.
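For reference, that existing switch already happens through the thresholds argument (a minimal usage sketch; the tensor sizes are illustrative):

import torch
from torchmetrics.classification import BinaryPrecisionRecallCurve

preds, target = torch.rand(1000), torch.randint(2, (1000,))

# Accurate version: keeps all predictions and targets as state, O(n_samples) memory.
exact_metric = BinaryPrecisionRecallCurve(thresholds=None)

# Binned version: keeps only an (n_thresholds, 2, 2) confusion-matrix state.
binned_metric = BinaryPrecisionRecallCurve(thresholds=200)

precision, recall, thresholds = binned_metric(preds, target)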

@SkafteNicki
Member

@Callidior thank you for investigating this issue.

With the current implementation, I don't see any benefit of using the binned version of the metric. If you have few samples, the accurate version (thresholds=None) is still very tractable and more accurate than the binned one. If you have many samples, both the binned and the accurate version are unusable due to high memory requirements and computation time.

The thing your analysis is missing is that you are assuming all samples are evaluated at the same time; however, this is not how it is done in practice in any deep learning pipeline, where samples arrive in mini-batches. Thus, there is a significant difference in memory between doing

preds, target = torch.rand(1_000_000), torch.randint(2, (1_000_000,))
output = metric(preds, target)

and

preds, target = torch.rand(1_000_000), torch.randint(2, (1_000_000,))
for i in range(1000):
    metric.update(preds[1000*i:1000*(i+1)], target[1000*i:1000*(i+1)])
output = metric.compute()

This is the reason the binned version exists: the non-binned version has to save all samples in memory in both cases, whereas the binned version only needs to save a reduced state (the confusion-matrix tensor of shape [n_thresholds, n_classes, 2, 2]), which is significantly less memory.
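As an illustrative back-of-the-envelope comparison (the numbers are assumptions: 10^6 float32 predictions, int64 targets, and 200 thresholds):

# Non-binned: every batch's preds and targets are appended to the state.
non_binned_state_bytes = 1_000_000 * (4 + 8)  # ~12 MB, growing with every update
# Binned: only the accumulated confusion matrix is stored, independent of sample count.
binned_state_bytes = 200 * 2 * 2 * 8          # ~6.4 kB for a (200, 2, 2) int64 tensor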

@SkafteNicki
Member

I would agree with @justusschock that the correct approach is to implement some kind of heuristic that switches between bincount and a for-loop depending on how many samples and thresholds the user is using. To settle on this heuristic, would it be possible for you, @Callidior, to share the benchmark script you have developed? In particular, I am also interested in what the difference is when doing this on CPU (it seems you have only tried the GPU).

@Callidior
Contributor Author

Callidior commented Feb 10, 2023

@SkafteNicki Fair point. I updated the PR to implement the suggested heuristic switch between the two implementations.

Further benchmark results for the binary and multi-class case on both CPU and GPU can be found below, as well as the benchmark script. The picture on CPU doesn't look too different from the situation on the GPU.

Based on those benchmarks, I'd propose the following heuristic (a sketch of the resulting dispatch is shown after the list):

  • For binary classification, use the loop implementation if num_samples > 50,000.
  • For multiclass classification, use the loop implementation if num_samples * num_classes > 1,000,000.
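A minimal sketch of that dispatch (function names refer to the benchmark script shared below; the cut-offs are the ones proposed above, so treat them as illustrative rather than final):

def binary_pr_curve_update(preds, targets, thresholds):
    # Per-threshold loop for large inputs, vectorized bincount otherwise.
    if len(preds) > 50_000:
        return binary_pr_curve_update_loop(preds, targets, thresholds)
    return binary_pr_curve_update_vectorized(preds, targets, thresholds)


def multiclass_pr_curve_update(preds, targets, thresholds, num_classes):
    if len(preds) * num_classes > 1_000_000:
        return multiclass_pr_curve_update_loop(preds, targets, thresholds, num_classes)
    return multiclass_pr_curve_update_vectorized(preds, targets, thresholds, num_classes)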

Speed Benchmark

GPU: NVIDIA GeForce RTX 3090
CPU: AMD Ryzen 9 5900X (12-Core)

Binary

GPU

[Figure: pr_benchmark_binary_gpu]

CPU

[Figure: pr_benchmark_binary_cpu]

Multi-Class

GPU

White cells mean that more than 24 GB of GPU memory would have been required.

num_classes==10

[Figure: pr_benchmark_10cls_gpu]

num_classes==100

[Figure: pr_benchmark_100cls_gpu]

num_classes==1000

[Figure: pr_benchmark_1000cls_gpu]

CPU

num_classes==10

[Figure: pr_benchmark_10cls_cpu]

num_classes==100

[Figure: pr_benchmark_100cls_cpu]

num_classes==1000

My 96 GB of main memory is not sufficient to run the vectorized implementation.

Benchmark Script

import argparse
import numpy as np
import torch
import matplotlib.colors
import matplotlib.pyplot as plt
from time import perf_counter_ns
from functools import partial
from tqdm import tqdm


def binary_pr_curve_update_vectorized(preds, targets, thresholds):
    # Threshold all predictions against all thresholds at once and count the
    # (prediction, target, threshold) combinations with a single bincount.
    # Fast, but materializes (n_samples, n_thresholds)-sized intermediate tensors.
    preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0)).long()
    unique_mapping = preds_t + 2 * targets.unsqueeze(-1) + 4 * torch.arange(len(thresholds), device=targets.device)
    bins = torch.bincount(unique_mapping.flatten(), minlength=4 * len(thresholds))
    confmat = bins.reshape(len(thresholds), 2, 2)
    return confmat

def binary_pr_curve_update_loop(preds, targets, thresholds):
    # Iterate over thresholds and accumulate the confusion-matrix entries directly.
    # Slower for small inputs, but memory consumption is independent of n_samples.
    confmat = thresholds.new_empty((len(thresholds), 2, 2), dtype=torch.int64)
    targets_t = targets == 1
    for i in range(len(thresholds)):
        preds_t = preds >= thresholds[i]
        confmat[i, 1, 1] = (targets_t & preds_t).sum()
        confmat[i, 0, 1] = ((~targets_t) & preds_t).sum()
        confmat[i, 1, 0] = (targets_t & (~preds_t)).sum()
    # True negatives follow from the total number of samples.
    confmat[:, 0, 0] = len(preds) - confmat[:, 0, 1] - confmat[:, 1, 0] - confmat[:, 1, 1]
    return confmat


def multiclass_pr_curve_update_vectorized(preds, targets, thresholds, num_classes):
    # Same bincount trick as in the binary case, with one-hot targets and an extra
    # class dimension; intermediates scale with n_samples * n_classes * n_thresholds.
    len_t = len(thresholds)
    preds_t = (preds.unsqueeze(-1) >= thresholds.unsqueeze(0).unsqueeze(0)).long()
    target_t = torch.nn.functional.one_hot(targets, num_classes=num_classes)
    unique_mapping = preds_t + 2 * target_t.unsqueeze(-1)
    unique_mapping += 4 * torch.arange(num_classes, device=preds.device).unsqueeze(0).unsqueeze(-1)
    unique_mapping += 4 * num_classes * torch.arange(len_t, device=preds.device)
    bins = torch.bincount(unique_mapping.flatten(), minlength=4 * num_classes * len_t)
    return bins.reshape(len_t, num_classes, 2, 2)


def multiclass_pr_curve_update_loop(preds, targets, thresholds, num_classes):
    # Per-threshold loop accumulating one-vs-rest confusion matrices for every class.
    len_t = len(thresholds)
    target_t = torch.nn.functional.one_hot(targets, num_classes=num_classes).bool()
    confmat = thresholds.new_empty((len_t, num_classes, 2, 2), dtype=torch.int64)
    for i in range(len_t):
        preds_t = preds >= thresholds[i]
        confmat[i, :, 1, 1] = (target_t & preds_t).sum(dim=0)
        confmat[i, :, 0, 1] = ((~target_t) & preds_t).sum(dim=0)
        confmat[i, :, 1, 0] = (target_t & (~preds_t)).sum(dim=0)
    confmat[:, :, 0, 0] = len(preds) - confmat[:, :, 0, 1] - confmat[:, :, 1, 0] - confmat[:, :, 1, 1]
    return confmat


def benchmark_gpu(fn1, fn2, num_samples_list, num_thresholds_list, num_classes=1, num_loops=100):
    times1 = np.empty((len(num_samples_list), len(num_thresholds_list)))
    times2 = np.empty((len(num_samples_list), len(num_thresholds_list)))

    evt1 = torch.cuda.Event(enable_timing=True)
    evt2 = torch.cuda.Event(enable_timing=True)
    evt3 = torch.cuda.Event(enable_timing=True)

    for i, num_samples in enumerate(tqdm(num_samples_list)):
        for j, num_thresholds in enumerate(tqdm(num_thresholds_list, leave=False)):
            targets = torch.randint(0, max(2, num_classes), (num_samples,), device="cuda")
            preds = torch.rand((num_samples, num_classes), device="cuda").squeeze()
            thresholds = torch.linspace(0, 1, num_thresholds, device="cuda")

            evt1.record()
            try:
                for _ in range(num_loops):
                    confmat1 = fn1(preds, targets, thresholds)
            except RuntimeError:
                confmat1 = None
            evt2.record()
            try:
                for _ in range(num_loops):
                    confmat2 = fn2(preds, targets, thresholds)
            except RuntimeError:
                confmat2 = None
            evt3.record()

            torch.cuda.synchronize()
            assert confmat1 is None or confmat2 is None or torch.allclose(confmat1, confmat2)

            times1[i,j] = evt1.elapsed_time(evt2) / num_loops if confmat1 is not None else np.nan
            times2[i,j] = evt2.elapsed_time(evt3) / num_loops if confmat2 is not None else np.nan

    return times1, times2


def benchmark_cpu(fn1, fn2, num_samples_list, num_thresholds_list, num_classes=1, num_loops=100):
    times1 = np.empty((len(num_samples_list), len(num_thresholds_list)))
    times2 = np.empty((len(num_samples_list), len(num_thresholds_list)))

    for i, num_samples in enumerate(tqdm(num_samples_list)):
        for j, num_thresholds in enumerate(tqdm(num_thresholds_list, leave=False)):
            targets = torch.randint(0, max(2, num_classes), (num_samples,), device="cpu")
            preds = torch.rand((num_samples, num_classes), device="cpu").squeeze()
            thresholds = torch.linspace(0, 1, num_thresholds, device="cpu")

            evt1 = perf_counter_ns()
            for _ in range(num_loops):
                confmat1 = fn1(preds, targets, thresholds)
            evt2 = perf_counter_ns()
            for _ in range(num_loops):
                confmat2 = fn2(preds, targets, thresholds)
            evt3 = perf_counter_ns()

            assert torch.allclose(confmat1, confmat2)

            times1[i,j] = (evt2 - evt1) / (num_loops * 1e6)
            times2[i,j] = (evt3 - evt2) / (num_loops * 1e6)

    return times1, times2


def plot_timings(*times, num_samples_list, num_thresholds_list, titles=None):
    fig, axes = plt.subplots(1, len(times), figsize=(6*len(times), 8), constrained_layout=True, sharey=True)

    vmin = min([np.nanmin(t) for t in times])
    vmax = max([np.nanmax(t) for t in times])
    scaler = matplotlib.colors.LogNorm(vmin=vmin, vmax=vmax)
    for ax, t in zip(axes, times):
        ax.imshow(t.T, norm=scaler)

    if titles is not None:
        for ax, title in zip(axes, titles):
            ax.set_title(title)

    for i in range(len(num_samples_list)):
        for j in range(len(num_thresholds_list)):
            for ax, t in zip(axes, times):
                if not np.isnan(t[i,j]):
                    ax.annotate(
                        f"{t[i,j]:.1f}",
                        (i, j),
                        ha="center",
                        va="center",
                        color="white" if scaler(t[i,j]) < 0.8 else "black"
                    )

    axes[0].set_ylabel("Number of Thresholds")
    for ax in axes.ravel():
        ax.set_xlabel("Number of Samples")
        ax.set_xticks(np.arange(len(num_samples_list)), [f"{x:,d}" for x in num_samples_list])
        ax.set_yticks(np.arange(len(num_thresholds_list)), [f"{x:,d}" for x in num_thresholds_list])

    return fig


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--num-classes", type=int, default=1, help="Number of classes. Set to 1 for binary classification.")
    parser.add_argument("--num-loops", type=int, default=100, help="Number of repetitions per experiment.")
    parser.add_argument("--cpu", action="store_true", default=False, help="Run on CPU.")
    args = parser.parse_args()
    
    num_thresholds_list = [10, 50, 100, 200, 500, 1000]
    if args.num_classes == 1:
        fn1 = binary_pr_curve_update_vectorized
        fn2 = binary_pr_curve_update_loop
        num_samples_list = [10, 100, 1000, 10000, 100000, 1000000]
    else:
        fn1 = partial(multiclass_pr_curve_update_vectorized, num_classes=args.num_classes)
        fn2 = partial(multiclass_pr_curve_update_loop, num_classes=args.num_classes)
        num_samples_list = [10, 100, 1000, 10000, 25000]

    benchmark_fn = benchmark_cpu if args.cpu else benchmark_gpu
    
    times = benchmark_fn(
        fn1,
        fn2,
        num_samples_list,
        num_thresholds_list,
        num_classes=args.num_classes,
        num_loops=args.num_loops,
    )
    plot_timings(
        *times,
        num_samples_list=num_samples_list,
        num_thresholds_list=num_thresholds_list,
        titles=["Vectorized", "Loop"],
    )
    plt.show()
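Assuming the script is saved as, say, pr_benchmark.py (no filename is given here), the binary GPU benchmark runs with `python pr_benchmark.py`, the multiclass variants with `python pr_benchmark.py --num-classes 100`, and adding `--cpu` benchmarks on the CPU instead.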

@justusschock
Member

@Callidior this heuristic is fine with me. Thanks for all the detailed benchmarks!

@Borda changed the title from "Fix speed and memory consumption of binned PrecisionRecallCurve with large number of samples" to "Improve speed and memory consumption of binned PrecisionRecallCurve" on Feb 20, 2023