|  | 
|  | 1 | +import math | 
|  | 2 | + | 
|  | 3 | +import pytest | 
|  | 4 | + | 
|  | 5 | +from ignite.contrib.handlers.custom_events import CustomPeriodicEvent | 
|  | 6 | +from ignite.engine import Engine | 
|  | 7 | + | 
|  | 8 | + | 
|  | 9 | +def test_bad_input(): | 
|  | 10 | + | 
|  | 11 | +    with pytest.warns(DeprecationWarning, match=r"CustomPeriodicEvent is deprecated"): | 
|  | 12 | +        with pytest.raises(TypeError, match="Argument n_iterations should be an integer"): | 
|  | 13 | +            CustomPeriodicEvent(n_iterations="a") | 
|  | 14 | +        with pytest.raises(ValueError, match="Argument n_iterations should be positive"): | 
|  | 15 | +            CustomPeriodicEvent(n_iterations=0) | 
|  | 16 | +        with pytest.raises(TypeError, match="Argument n_iterations should be an integer"): | 
|  | 17 | +            CustomPeriodicEvent(n_iterations=10.0) | 
|  | 18 | +        with pytest.raises(TypeError, match="Argument n_epochs should be an integer"): | 
|  | 19 | +            CustomPeriodicEvent(n_epochs="a") | 
|  | 20 | +        with pytest.raises(ValueError, match="Argument n_epochs should be positive"): | 
|  | 21 | +            CustomPeriodicEvent(n_epochs=0) | 
|  | 22 | +        with pytest.raises(TypeError, match="Argument n_epochs should be an integer"): | 
|  | 23 | +            CustomPeriodicEvent(n_epochs=10.0) | 
|  | 24 | +        with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"): | 
|  | 25 | +            CustomPeriodicEvent() | 
|  | 26 | +        with pytest.raises(ValueError, match="Either n_iterations or n_epochs should be defined"): | 
|  | 27 | +            CustomPeriodicEvent(n_iterations=1, n_epochs=2) | 
|  | 28 | + | 
|  | 29 | + | 
|  | 30 | +def test_new_events(): | 
|  | 31 | +    def update(*args, **kwargs): | 
|  | 32 | +        pass | 
|  | 33 | + | 
|  | 34 | +    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | 
|  | 35 | +        engine = Engine(update) | 
|  | 36 | +        cpe = CustomPeriodicEvent(n_iterations=5) | 
|  | 37 | +        cpe.attach(engine) | 
|  | 38 | + | 
|  | 39 | +        assert hasattr(cpe, "Events") | 
|  | 40 | +        assert hasattr(cpe.Events, "ITERATIONS_5_STARTED") | 
|  | 41 | +        assert hasattr(cpe.Events, "ITERATIONS_5_COMPLETED") | 
|  | 42 | + | 
|  | 43 | +        assert engine._allowed_events[-2] == getattr(cpe.Events, "ITERATIONS_5_STARTED") | 
|  | 44 | +        assert engine._allowed_events[-1] == getattr(cpe.Events, "ITERATIONS_5_COMPLETED") | 
|  | 45 | + | 
|  | 46 | +    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | 
|  | 47 | +        cpe = CustomPeriodicEvent(n_epochs=5) | 
|  | 48 | +        cpe.attach(engine) | 
|  | 49 | + | 
|  | 50 | +        assert hasattr(cpe, "Events") | 
|  | 51 | +        assert hasattr(cpe.Events, "EPOCHS_5_STARTED") | 
|  | 52 | +        assert hasattr(cpe.Events, "EPOCHS_5_COMPLETED") | 
|  | 53 | + | 
|  | 54 | +        assert engine._allowed_events[-2] == getattr(cpe.Events, "EPOCHS_5_STARTED") | 
|  | 55 | +        assert engine._allowed_events[-1] == getattr(cpe.Events, "EPOCHS_5_COMPLETED") | 
|  | 56 | + | 
|  | 57 | + | 
|  | 58 | +def test_integration_iterations(): | 
|  | 59 | +    def _test(n_iterations, max_epochs, n_iters_per_epoch): | 
|  | 60 | +        def update(*args, **kwargs): | 
|  | 61 | +            pass | 
|  | 62 | + | 
|  | 63 | +        engine = Engine(update) | 
|  | 64 | +        with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | 
|  | 65 | +            cpe = CustomPeriodicEvent(n_iterations=n_iterations) | 
|  | 66 | +            cpe.attach(engine) | 
|  | 67 | +        data = list(range(n_iters_per_epoch)) | 
|  | 68 | + | 
|  | 69 | +        custom_period = [0] | 
|  | 70 | +        n_calls_iter_started = [0] | 
|  | 71 | +        n_calls_iter_completed = [0] | 
|  | 72 | + | 
|  | 73 | +        event_started = getattr(cpe.Events, "ITERATIONS_{}_STARTED".format(n_iterations)) | 
|  | 74 | + | 
|  | 75 | +        @engine.on(event_started) | 
|  | 76 | +        def on_my_event_started(engine): | 
|  | 77 | +            assert (engine.state.iteration - 1) % n_iterations == 0 | 
|  | 78 | +            custom_period[0] += 1 | 
|  | 79 | +            custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations)) | 
|  | 80 | +            assert custom_iter == custom_period[0] | 
|  | 81 | +            n_calls_iter_started[0] += 1 | 
|  | 82 | + | 
|  | 83 | +        event_completed = getattr(cpe.Events, "ITERATIONS_{}_COMPLETED".format(n_iterations)) | 
|  | 84 | + | 
|  | 85 | +        @engine.on(event_completed) | 
|  | 86 | +        def on_my_event_ended(engine): | 
|  | 87 | +            assert engine.state.iteration % n_iterations == 0 | 
|  | 88 | +            custom_iter = getattr(engine.state, "iterations_{}".format(n_iterations)) | 
|  | 89 | +            assert custom_iter == custom_period[0] | 
|  | 90 | +            n_calls_iter_completed[0] += 1 | 
|  | 91 | + | 
|  | 92 | +        engine.run(data, max_epochs=max_epochs) | 
|  | 93 | + | 
|  | 94 | +        n = len(data) * max_epochs / n_iterations | 
|  | 95 | +        nf = math.floor(n) | 
|  | 96 | +        assert custom_period[0] == n_calls_iter_started[0] | 
|  | 97 | +        assert n_calls_iter_started[0] == nf + 1 if nf < n else nf | 
|  | 98 | +        assert n_calls_iter_completed[0] == nf | 
|  | 99 | + | 
|  | 100 | +    _test(3, 5, 16) | 
|  | 101 | +    _test(4, 5, 16) | 
|  | 102 | +    _test(5, 5, 16) | 
|  | 103 | +    _test(300, 50, 1000) | 
|  | 104 | + | 
|  | 105 | + | 
|  | 106 | +def test_integration_epochs(): | 
|  | 107 | +    def update(*args, **kwargs): | 
|  | 108 | +        pass | 
|  | 109 | + | 
|  | 110 | +    engine = Engine(update) | 
|  | 111 | + | 
|  | 112 | +    n_epochs = 3 | 
|  | 113 | +    with pytest.warns(DeprecationWarning, match="CustomPeriodicEvent is deprecated"): | 
|  | 114 | +        cpe = CustomPeriodicEvent(n_epochs=n_epochs) | 
|  | 115 | +        cpe.attach(engine) | 
|  | 116 | +    data = list(range(16)) | 
|  | 117 | + | 
|  | 118 | +    custom_period = [1] | 
|  | 119 | + | 
|  | 120 | +    @engine.on(cpe.Events.EPOCHS_3_STARTED) | 
|  | 121 | +    def on_my_epoch_started(engine): | 
|  | 122 | +        assert (engine.state.epoch - 1) % n_epochs == 0 | 
|  | 123 | +        assert engine.state.epochs_3 == custom_period[0] | 
|  | 124 | + | 
|  | 125 | +    @engine.on(cpe.Events.EPOCHS_3_COMPLETED) | 
|  | 126 | +    def on_my_epoch_ended(engine): | 
|  | 127 | +        assert engine.state.epoch % n_epochs == 0 | 
|  | 128 | +        assert engine.state.epochs_3 == custom_period[0] | 
|  | 129 | +        custom_period[0] += 1 | 
|  | 130 | + | 
|  | 131 | +    engine.run(data, max_epochs=10) | 
|  | 132 | + | 
|  | 133 | +    assert custom_period[0] == 4 | 
0 commit comments