Skip to content

Commit

Permalink
Simplify logic in get_default_compression (#6260)
Browse files Browse the repository at this point in the history
  • Loading branch information
jakirkham authored May 5, 2022
1 parent bc3c891 commit 7bd6442
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 11 deletions.
18 changes: 8 additions & 10 deletions distributed/protocol/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,17 +99,15 @@ def zstd_decompress(data):

def get_default_compression():
default = dask.config.get("distributed.comm.compression")
if default != "auto":
if default in compressions:
return default
else:
raise ValueError(
"Default compression '%s' not found.\n"
"Choices include auto, %s"
% (default, ", ".join(sorted(map(str, compressions))))
)
else:
if default == "auto":
return default_compression
if default in compressions:
return default
raise ValueError(
"Default compression '%s' not found.\n"
"Choices include auto, %s"
% (default, ", ".join(sorted(map(str, compressions))))
)


get_default_compression()
Expand Down
26 changes: 25 additions & 1 deletion distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
import pytest

import dask

from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
from distributed.protocol.compression import compressions
from distributed.protocol.compression import (
compressions,
default_compression,
get_default_compression,
)
from distributed.protocol.cuda import cuda_deserialize, cuda_serialize
from distributed.protocol.serialize import (
Serialize,
Expand All @@ -20,6 +26,24 @@ def test_protocol():
assert loads(dumps(msg)) == msg


@pytest.mark.parametrize(
"config,default",
[
("auto", default_compression),
(None, None),
("zlib", "zlib"),
("foo", ValueError),
],
)
def test_compression_config(config, default):
with dask.config.set({"distributed.comm.compression": config}):
if type(default) is type and issubclass(default, Exception):
with pytest.raises(default):
assert get_default_compression()
else:
assert get_default_compression() == default


def test_compression_1():
pytest.importorskip("lz4")
np = pytest.importorskip("numpy")
Expand Down

0 comments on commit 7bd6442

Please sign in to comment.