Support torch.bfloat16 in hivemind.compression (learning-at-home#524)
This PR implements bfloat16 support for `CompressionType.NONE` and `CompressionType.BLOCKWISE_8BIT`.

This is important for the Petals client; see bigscience-workshop/petals#79.
borzunov authored Nov 28, 2022 · 1 parent 8d51b97 · commit 1e4af43
Showing 3 changed files with 44 additions and 18 deletions.
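
A round-trip usage sketch (import paths assumed from hivemind's public API, mirroring the tests added below; the tolerances are the ones the new test uses):

```python
import torch
from hivemind.compression import deserialize_torch_tensor, serialize_torch_tensor
from hivemind.proto.runtime_pb2 import CompressionType

tensor = torch.randn(4096, 16, dtype=torch.bfloat16)

# Lossless path: bfloat16 is upcast to float32 on the wire and downcast on extract
restored = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.NONE))
assert restored.dtype == torch.bfloat16 and torch.equal(restored, tensor)

# Lossy path: blockwise 8-bit quantization; the original dtype is still restored
approx = deserialize_torch_tensor(serialize_torch_tensor(tensor, CompressionType.BLOCKWISE_8BIT))
assert approx.dtype == torch.bfloat16
assert torch.allclose(approx.float(), tensor.float(), rtol=0.1, atol=0.01)
```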
hivemind/compression/base.py (15 additions, 6 deletions)
```diff
@@ -80,18 +80,27 @@ class NoCompression(CompressionBase):
     compression_type = runtime_pb2.CompressionType.NONE
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        array = tensor.detach().numpy()
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).lstrip("torch.")
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
         return runtime_pb2.Tensor(
             compression=self.compression_type,
-            buffer=array.tobytes(),
-            size=array.shape,
-            dtype=array.dtype.name,
+            buffer=tensor.numpy().tobytes(),
+            size=tensor.shape,
+            dtype=dtype_name,
             requires_grad=tensor.requires_grad,
         )
 
     def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-        array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
-        return torch.as_tensor(array).reshape(tuple(serialized_tensor.size))
+        if serialized_tensor.dtype == "bfloat16":
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.float32)
+            tensor = torch.as_tensor(array, dtype=torch.bfloat16)
+        else:
+            array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype))
+            tensor = torch.as_tensor(array)
+        return tensor.reshape(tuple(serialized_tensor.size))
 
     def estimate_compression_ratio(self, info: CompressionInfo) -> float:
         return 1.0
```
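
Context for the change above: numpy has no bfloat16 dtype, so a bfloat16 tensor cannot pass through `.numpy()` directly. `compress()` therefore upcasts to float32 (a lossless widening, at the cost of 4 bytes per element on the wire instead of 2), records the original dtype name in the protobuf, and `extract()` downcasts on the way back. A standalone sketch of that round trip, independent of hivemind:

```python
import numpy as np
import torch

t = torch.randn(3, dtype=torch.bfloat16)
# t.numpy() would raise TypeError: numpy cannot represent bfloat16
buf = t.to(torch.float32).numpy().tobytes()            # upcast, then serialize
arr = np.frombuffer(buf, dtype=np.float32)
restored = torch.as_tensor(arr, dtype=torch.bfloat16)  # downcast on extract
assert torch.equal(restored, t)  # lossless: float32 represents every bfloat16 exactly
```

Incidentally, `lstrip("torch.")` strips a character set, not a literal prefix; it happens to produce the right name for "float32" and "bfloat16" (their first letters are outside the stripped set), but it would mangle, say, "torch.complex64" into "mplex64".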
hivemind/compression/quantization.py (12 additions, 5 deletions)
```diff
@@ -120,8 +120,8 @@ def quantile_qq_approximation(array: np.ndarray, n_quantiles: int, min_chunk_size
     return np.quantile(partition_quantiles, quantiles)
 
 
-BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
-Please install it with `pip install bitsandbytes`
+BNB_MISSING_MESSAGE = """BlockwiseQuantization requires bitsandbytes to function properly.
+Please install it with `pip install bitsandbytes`
 or using the instruction from https://github.com/TimDettmers/bitsandbytes."""
@@ -139,7 +139,12 @@ def quantize(
         return quantized.numpy(), (absmax.numpy(), codebook.numpy())
 
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
-        quantized, (absmax, codebook) = self.quantize(tensor.detach(), allow_inplace=allow_inplace)
+        tensor = tensor.detach()
+        dtype_name = str(tensor.dtype).lstrip("torch.")
+        if tensor.dtype == torch.bfloat16:
+            tensor = tensor.to(torch.float32)
+
+        quantized, (absmax, codebook) = self.quantize(tensor, allow_inplace=allow_inplace)
 
         serialized_data = (
             np.int64(len(absmax)).tobytes(),
@@ -153,7 +158,7 @@ def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
             buffer=b"".join(serialized_data),
             size=tensor.shape,
             requires_grad=tensor.requires_grad,
-            dtype=tensor.numpy().dtype.name,
+            dtype=dtype_name,
             compression=self.compression_type,
         )
 
@@ -172,6 +177,8 @@ def extract(self, serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
         codebook = torch.as_tensor(codebook)
         quantized = torch.as_tensor(quantized).reshape(tuple(serialized_tensor.size))
         try:
-            return dequantize_blockwise(quantized, (absmax, codebook))
+            result = dequantize_blockwise(quantized, (absmax, codebook))  # Always returns a float32 tensor
         except NameError:
             raise ImportError(BNB_MISSING_MESSAGE)
+        result = result.to(dtype=getattr(torch, serialized_tensor.dtype))
+        return result
```

(In the first hunk, the removed and added BNB_MISSING_MESSAGE lines render identically; presumably a whitespace-only change that the page does not display.)
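
For intuition, here is a minimal pure-torch sketch of blockwise 8-bit quantization. This is not hivemind's implementation (the real code delegates to bitsandbytes' `quantize_blockwise`/`dequantize_blockwise`, which use a nonuniform dynamic codebook, and the helper names below are hypothetical), but it shows the same structure the diff relies on: bfloat16 inputs are upcast to float32 before quantizing, each block stores one absmax scale plus one byte per value, and dequantization yields float32 that is cast back to the recorded dtype.

```python
import torch

def blockwise_quantize(x: torch.Tensor, block_size: int = 4096):
    """Hypothetical helper: 8-bit codes with one absmax scale per block."""
    flat = x.float().flatten()                       # upcast bfloat16, as compress() does
    pad = (-flat.numel()) % block_size               # pad to a whole number of blocks
    flat = torch.nn.functional.pad(flat, (0, pad))
    blocks = flat.view(-1, block_size)
    absmax = blocks.abs().amax(dim=1)                # one scale per block
    scaled = blocks / absmax.clamp_min(1e-12).unsqueeze(1)     # now in [-1, 1]
    codes = torch.round((scaled + 1) * 127.5).to(torch.uint8)  # uniform codebook, for simplicity
    return codes, absmax

def blockwise_dequantize(codes, absmax, numel, dtype=torch.float32):
    scaled = codes.float() / 127.5 - 1.0
    out = (scaled * absmax.unsqueeze(1)).flatten()[:numel]
    return out.to(dtype)                             # cast back to the original dtype, as extract() now does

x = torch.randn(10_000, dtype=torch.bfloat16)
codes, absmax = blockwise_quantize(x)
x_hat = blockwise_dequantize(codes, absmax, x.numel(), dtype=x.dtype)
assert x_hat.dtype == torch.bfloat16
# Loose tolerance: a uniform codebook is coarser near zero than bitsandbytes' dynamic one
assert torch.allclose(x_hat.float(), x.float(), rtol=0.1, atol=0.05)
```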
tests/test_compression.py (17 additions, 7 deletions)
```diff
@@ -46,15 +46,18 @@ def test_tensor_compression(size=(128, 128, 64), alpha=5e-08, beta=0.0008):
         assert deserialize_torch_tensor(serialize_torch_tensor(zeros, compression_type)).isfinite().all()
 
 
+def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
+    serialized_tensor = serialize_torch_tensor(tensor, compression)
+    chunks = list(split_for_streaming(serialized_tensor, chunk_size))
+    assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
+    restored = combine_from_streaming(chunks)
+    result = deserialize_torch_tensor(restored)
+    assert torch.allclose(result, tensor, rtol=rtol, atol=atol)
+    assert result.dtype == tensor.dtype
+
+
 @pytest.mark.forked
 def test_serialize_tensor():
-    def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
-        serialized_tensor = serialize_torch_tensor(tensor, compression)
-        chunks = list(split_for_streaming(serialized_tensor, chunk_size))
-        assert len(chunks) == (len(serialized_tensor.buffer) - 1) // chunk_size + 1
-        restored = combine_from_streaming(chunks)
-        assert torch.allclose(deserialize_torch_tensor(restored), tensor, rtol=rtol, atol=atol)
-
     tensor = torch.randn(512, 12288)
     for chunk_size in [1024, 64 * 1024, 64 * 1024 + 1, 10**9]:
         _check(tensor, CompressionType.NONE, chunk_size=chunk_size)
@@ -65,6 +68,13 @@ def _check(tensor, compression, rtol=1e-5, atol=1e-8, chunk_size=30 * 1024):
     _check(torch.tensor(1.0), CompressionType.FLOAT16)
 
 
+@pytest.mark.forked
+def test_serialize_bfloat16():
+    tensor = torch.randn(4096, 16, dtype=torch.bfloat16)
+    _check(tensor, CompressionType.NONE)
+    _check(tensor, CompressionType.BLOCKWISE_8BIT, rtol=0.1, atol=0.01, chunk_size=1024)
+
+
 @pytest.mark.forked
 def test_allreduce_compression():
     """this test ensures that compression works correctly when multiple tensors have different compression types"""
```
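
With pytest-forked installed (the `forked` marker requires it), the new case can be run on its own via `pytest tests/test_compression.py -k test_serialize_bfloat16`.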
