
[Feat] Add multiple dispatch #4

Merged · 2 commits merged into main from yancong/dev-dispatch on Sep 13, 2022

Conversation

@ice-tong (Collaborator) commented on Aug 18, 2022

Motivation

Some mmeval metrics may have different implementations depending on the deep learning framework or numerical computing library in use, such as PyTorch or NumPy.

To dispatch among these different implementations, we adopt a dynamic multiple dispatch mechanism based on type hints.

A simple example of multiple dispatch based on type hints is as follows:

# Assuming `dispatch` can be imported from mmeval.core, as added in this PR.
from mmeval.core import dispatch

@dispatch
def compute(x: int, y: int):
    print('this is int')

@dispatch
def compute(x: str, y: str):
    print('this is str')
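
Calling compute then selects the overload whose type hints match the runtime argument types:

compute(1, 2)      # prints: this is int
compute('a', 'b')  # prints: this is str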

Modification

Currently, we employ plum (a multiple dispatch library) to implement the multiple dispatch mechanism in mmeval.

In this module, we optimize plum's execution speed through the following two tricks:

  • Caching plum Type instances
  • Caching plum Type hash values

Benefiting from the tricks above, plum dispatch becomes roughly twice as fast as before. More details can be found at beartype/plum#53
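
As a rough sketch of the second trick (illustrative only, not the actual mmeval patch; `_hash_cache_patch` and the `_cached_hash` attribute are names invented here):

import functools

def _hash_cache_patch(type_cls):
    # Memoize the potentially expensive __hash__ of a plum type class.
    original_hash = type_cls.__hash__

    @functools.wraps(original_hash)
    def cached_hash(self):
        # Compute the hash once, then reuse the stored value on later lookups.
        if not hasattr(self, '_cached_hash'):
            self._cached_hash = original_hash(self)
        return self._cached_hash

    type_cls.__hash__ = cached_hash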

In addition, we implement MMEvalDispatcher, which extends plum dispatch with better support for typing.ForwardRef.
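
For context, typing.ForwardRef comes up when annotations are written as strings, so an optional backend does not need to be imported at definition time. A hypothetical sketch of such an overload:

@dispatch
def compute_confusion_matrix(prediction: 'torch.Tensor', label: 'torch.Tensor'):
    # 'torch.Tensor' is a string annotation (a typing.ForwardRef); torch only
    # needs to be importable when this overload is actually resolved.
    ...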

Benchmark

1. MeanIoU with dispatch [10k samples, 512*512 size, num_classes 100]

MeanIoU with multiple dispatch: 17.16 s
MeanIoU without multiple dispatch: 16.37 s
Test code
import torch
from mmeval.segmentation import MeanIoU

from functools import wraps, partial
from contextlib import contextmanager
from time import time

def timeit(f):
    @wraps(f)
    def wrap(*args, **kw):
        time_start = time()
        result = f(*args, **kw)
        time_end = time()
        return time_end - time_start, result
    return wrap

@timeit
def test_miou(miou: MeanIoU, num_batch, batch_size, num_classes, size=512**2):
    miou.reset()
    predictions = torch.randint(0, num_classes, size=(batch_size, size))
    labels = torch.randint(0, num_classes, size=(batch_size, size))
    for _ in range(num_batch):
        miou.add(predictions, labels)
    results = miou.compute()
    return results

@contextmanager
def single_torch_impl(miou):
    # Temporarily replace the dispatched methods with their torch-only
    # implementations, bypassing plum's runtime dispatch overhead.
    multiple_method1 = miou._compute_metric
    multiple_method2 = miou.compute_confusion_matrix

    for sig, single_method in multiple_method1.methods.items():
        if 'torch.Tensor' in str(sig.types):
            miou._compute_metric = partial(single_method[0], miou)

    for sig, single_method in multiple_method2.methods.items():
        if 'torch.Tensor' in str(sig.types):
            miou.compute_confusion_matrix = partial(single_method[0], miou)

    yield
    # Restore the original dispatched methods.
    miou._compute_metric = multiple_method1
    miou.compute_confusion_matrix = multiple_method2

def test():
    num_classes = 100
    miou = MeanIoU(num_classes=num_classes)
    t, _ = test_miou(miou, 100, 100, num_classes=num_classes)
    print(f'MeanIoU with multiple dispatch: {t:.4} s')

    with single_torch_impl(miou):
        t, _ = test_miou(miou, 100, 100, num_classes=num_classes)
        print(f'MeanIoU without multiple dispatch: {t:.4} s')

if __name__ == "__main__":
    test()

2. Accuracy with dispatch [1M samples, num_classes 1000]

plum's `type_of` runs into performance issues with large nested parameters. We should avoid this situation or hack `type_of` to inspect only a subset of the nested elements. See: beartype/plum#53
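
As an illustration of the failure mode (a hypothetical sketch, not part of the benchmark below):

import torch
from mmeval.classification import Accuracy

# Hypothetical stress case: a long list of tensors. To infer a parametric
# type such as List[torch.Tensor], type_of may need to inspect the list's
# elements, so inference cost grows with the container size.
accuracy = Accuracy(topk=(1, ))
predictions = [torch.rand(1000) for _ in range(100000)]
labels = torch.randint(0, 1000, size=(100000, ))
accuracy.add(predictions, labels)  # type inference can dominate this call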

Accuracy with multiple dispatch: 5.678 s
Accuracy without multiple dispatch: 2.757 s

P.S. With the tricks above, we sped up Accuracy with multiple dispatch from 8.0 s to 5.6 s.

Test code
import torch
from mmeval.classification import Accuracy

from functools import wraps, partial
from contextlib import contextmanager
from time import time

def timeit(f):
    @wraps(f)
    def wrap(*args, **kw):
        time_start = time()
        result = f(*args, **kw)
        time_end = time()
        return time_end - time_start, result
    return wrap

@timeit
def test_accuracy(accuracy: Accuracy, num_batch, batch_size, num_classes=1000):
    accuracy.reset()
    for _ in range(num_batch):
        # Scores of shape (batch_size, num_classes), required for top-k accuracy.
        predictions = torch.rand(size=(batch_size, num_classes))
        labels = torch.randint(0, num_classes, size=(batch_size, ))
        accuracy.add(predictions, labels)
    results = accuracy.compute()
    return results

@contextmanager
def single_torch_impl(accuracy):
    # Temporarily replace the dispatched method with its torch-only
    # implementation, bypassing plum's runtime dispatch overhead.
    multiple_method = accuracy._compute_metric

    for sig, single_method in multiple_method.methods.items():
        if 'torch.Tensor' in str(sig.types):
            accuracy._compute_metric = partial(single_method[0], accuracy)

    yield
    # Restore the original dispatched method.
    accuracy._compute_metric = multiple_method

def test():
    accuracy = Accuracy(topk=(1, 3, 5), thrs=(0.1, 0.3, 0.5))
    t, _ = test_accuracy(accuracy, 10000, 100)
    print(f'Accuracy with multiple dispatch: {t:.4} s')

    with single_torch_impl(accuracy):
        t, _ = test_accuracy(accuracy, 10000, 100)
        print(f'Accuracy without multiple dispatch: {t:.4} s')

if __name__ == "__main__":
    test()

@ice-tong changed the title from dev(mmeval/core/dispatch): add multiple dispatch to [Feat] Add multiple dispatch on Aug 18, 2022
@ice-tong force-pushed the yancong/dev-dispatch branch 2 times, most recently from e5fe20a to c9f9617, on August 19, 2022 05:00
    _singleton_patch()
except Exception as e:
    logger.warning(
        f'Patch `plum.type.TypeMeta` with singleton failed, raise error: {e}.')
Collaborator

Why does it probably fail?

Collaborator Author

If the patch fails due to a plum version change, we can still work without it.

@ice-tong force-pushed the yancong/dev-dispatch branch 2 times, most recently from 062e847 to e47d1de, on August 24, 2022 08:39
@ice-tong changed the base branch from yancong/dev to main on September 6, 2022 06:43
@ice-tong changed the base branch from main to yancong/dev on September 6, 2022 06:43
@ice-tong changed the base branch from yancong/dev to main on September 6, 2022 06:48
Comment on lines 111 to 112

    logger.warning(
        f'Patch plum Type with hash value cache failed, raise error: {e}.')
Collaborator
We should note that evaluation speed will slow down in this case.

Collaborator Author
Updated

@zhouzaida merged commit f892eb2 into main on Sep 13, 2022
@zhouzaida deleted the yancong/dev-dispatch branch on October 29, 2022 13:01