diff --git a/setup.cfg b/setup.cfg index 3ddaccc..4751a84 100644 --- a/setup.cfg +++ b/setup.cfg @@ -32,7 +32,7 @@ setup_requires = # Dependencies of the project (semicolon/line-separated): install_requires = numpy - torch + torch>=2.2.0 einops einconv # The usage of test_requires is discouraged, see `Dependency Management` docs diff --git a/singd/optim/optimizer.py b/singd/optim/optimizer.py index 47f3537..7cad745 100644 --- a/singd/optim/optimizer.py +++ b/singd/optim/optimizer.py @@ -591,13 +591,16 @@ def _accumulate_H_terms( kfac_approx = self._get_param_group_entry(module, "kfac_approx") module_name = self.module_names[module] + # load inputs and gradients to the same precision as the pre-conditioner + (dtype_K, dtype_C), _ = self._get_preconditioner_dtypes_and_device(module) + # 1) PROCESS INPUTS AND GRAD_OUTPUTS - a = inputs[0].data + a = inputs[0].data.to(dtype_K) # Process into matrix according to kfac_approx # For convolutions, unfold the input, for modules with bias terms, append a 1 a = process_input(a, module, kfac_approx) - g = grad_output.data + g = grad_output.data.to(dtype_C) # Process into matrix according to kfac_approx, add scaling from batch average g = process_grad_output(g, module, loss_average, kfac_approx) @@ -694,14 +697,20 @@ def _compute_natural_gradient(self, module: Module) -> Tuple[Tensor, ...]: # 2) COMPUTE THE NATURAL GRADIENT IN CONCATENATED MATRIX FORM module_name = self.module_names[module] + # load the gradient to the pre-conditioner precision while multiplying + dtype_K, dtype_C = self._get_param_group_entry(module, "preconditioner_dtype") + # We need to compute `W @ K @ K^T` where `W` is the weight gradient # `K` supports `K @ ...` and `K^T @ ...`. Hence, we rewrite into # `W @ K @ K^T = ( K @ (K^T @ W^T) )^T`. K = self.Ks[module_name] - nat_grad = (K @ K.rmatmat(grad_mat.T)).T + nat_grad = (K @ K.rmatmat(grad_mat.T.to(dtype_K))).T C = self.Cs[module_name] - nat_grad = C @ (C.rmatmat(nat_grad)) + nat_grad = C @ (C.rmatmat(nat_grad.to(dtype_C))) + + # load the pre-conditioned gradient back to the original precision + nat_grad = nat_grad.to(grad_mat.dtype) # If DDP is used. if dist.is_initialized(): diff --git a/singd/structures/base.py b/singd/structures/base.py index eb6f686..0680b0d 100644 --- a/singd/structures/base.py +++ b/singd/structures/base.py @@ -11,7 +11,7 @@ from torch import Tensor, zeros from torch.linalg import matrix_norm -from singd.structures.utils import diag_add_, supported_eye, supported_matmul +from singd.structures.utils import diag_add_, supported_eye class StructuredMatrix(ABC): @@ -100,10 +100,10 @@ def __matmul__( dense = self.to_dense() if isinstance(other, Tensor): - return supported_matmul(dense, other) + return dense @ other other_dense = other.to_dense() - return self.from_dense(supported_matmul(dense, other_dense)) + return self.from_dense(dense @ other_dense) @classmethod @abstractmethod @@ -206,7 +206,7 @@ def rmatmat(self, mat: Tensor) -> Tensor: A dense PyTorch tensor resulting from the multiplication. 
""" self._warn_naive_implementation("rmatmat") - return supported_matmul(self.to_dense().T, mat) + return self.to_dense().T @ mat @classmethod def _warn_naive_implementation(cls, fn_name: str): @@ -281,7 +281,7 @@ def from_inner(self, X: Union[Tensor, None] = None) -> StructuredMatrix: """ self._warn_naive_implementation("from_inner") S_dense = self.to_dense().T if X is None else self.rmatmat(X) - return self.from_dense(supported_matmul(S_dense, S_dense.T)) + return self.from_dense(S_dense @ S_dense.T) # NOTE This operation should be removed long-term as implementing IF-KFAC # with `from_inner` is more efficient. For now, it will exist as it makes @@ -298,7 +298,7 @@ def from_inner2(self, XXT: Tensor) -> StructuredMatrix: """ self._warn_naive_implementation("from_inner2") dense = self.to_dense() - return self.from_dense(supported_matmul(dense.T, XXT, dense)) + return self.from_dense(dense.T @ XXT @ dense) def average_trace(self) -> Tensor: """Compute the average trace of the represented matrix. diff --git a/singd/structures/blockdiagonal.py b/singd/structures/blockdiagonal.py index f5a4231..2c81b6c 100644 --- a/singd/structures/blockdiagonal.py +++ b/singd/structures/blockdiagonal.py @@ -6,15 +6,10 @@ import torch from einops import rearrange -from torch import Tensor, arange, cat, zeros +from torch import Tensor, arange, cat, einsum, zeros from singd.structures.base import StructuredMatrix -from singd.structures.utils import ( - lowest_precision, - supported_einsum, - supported_eye, - supported_matmul, -) +from singd.structures.utils import lowest_precision, supported_eye class BlockDiagonalMatrixTemplate(StructuredMatrix): @@ -194,7 +189,7 @@ def __matmul__( other_blocks = rearrange( other_blocks, "(block row) col -> block row col", **dims ) - result_blocks = supported_einsum( + result_blocks = einsum( "nij,njk->nik", self._blocks.to(compute_dtype), other_blocks.to(compute_dtype), @@ -205,15 +200,15 @@ def __matmul__( out_dtype = self._last.dtype compute_dtype = lowest_precision(self._last.dtype, other_last.dtype) - result_last = supported_matmul( - self._last.to(compute_dtype), other_last.to(compute_dtype) + result_last = ( + self._last.to(compute_dtype) @ other_last.to(compute_dtype) ).to(out_dtype) return cat([result_blocks, result_last]) else: - out_blocks = supported_einsum("nij,njk->nik", self._blocks, other._blocks) - out_last = supported_matmul(self._last, other._last) + out_blocks = einsum("nij,njk->nik", self._blocks, other._blocks) + out_last = self._last @ other._last return self.__class__(out_blocks, out_last) def __add__( @@ -286,8 +281,8 @@ def from_inner(self, X: Union[Tensor, None] = None) -> BlockDiagonalMatrixTempla dims = {"block": num_blocks, "row": self.BLOCK_DIM} S_blocks = rearrange(S_blocks, "(block row) col -> block row col", **dims) - out_blocks = supported_einsum("nij,nkj->nik", S_blocks, S_blocks) - out_last = supported_matmul(S_last, S_last.T) + out_blocks = einsum("nij,nkj->nik", S_blocks, S_blocks) + out_last = S_last @ S_last.T return self.__class__(out_blocks, out_last) @@ -299,10 +294,7 @@ def average_trace(self) -> Tensor: """ num_blocks, last_dim = self._blocks.shape[0], self._last.shape[0] dim = num_blocks * self.BLOCK_DIM + last_dim - return ( - supported_einsum("nii->", self._blocks / dim) - + (self._last.diag() / dim).sum() - ) + return einsum("nii->", self._blocks / dim) + (self._last.diag() / dim).sum() def diag_add_(self, value: float) -> BlockDiagonalMatrixTemplate: """In-place add a value to the diagonal of the represented matrix. 
diff --git a/singd/structures/hierarchical.py b/singd/structures/hierarchical.py index 5cd7b08..95c43a6 100644 --- a/singd/structures/hierarchical.py +++ b/singd/structures/hierarchical.py @@ -5,16 +5,10 @@ from typing import Tuple, Union import torch -from torch import Tensor, arange, cat, ones, zeros +from torch import Tensor, arange, cat, einsum, ones, zeros from singd.structures.base import StructuredMatrix -from singd.structures.utils import ( - diag_add_, - lowest_precision, - supported_einsum, - supported_eye, - supported_matmul, -) +from singd.structures.utils import diag_add_, lowest_precision, supported_eye class HierarchicalMatrixTemplate(StructuredMatrix): @@ -209,34 +203,25 @@ def __matmul__( [self.K1, self.diag_dim, self.K2] ) - top = ( - supported_matmul(self.A, other_top) - + supported_matmul(B_C, other_middle) - + supported_matmul(B_E, other_bottom) - ) - middle = supported_einsum("i,ij->ij", self.C, other_middle) - bottom = supported_matmul(self.D, other_middle) + supported_matmul( - self.E, other_bottom - ) + top = self.A @ other_top + B_C @ other_middle + B_E @ other_bottom + middle = einsum("i,ij->ij", self.C, other_middle) + bottom = self.D @ other_middle + self.E @ other_bottom return cat([top, middle, bottom], dim=0) else: - A_new = supported_matmul(self.A, other.A) + A_new = self.A @ other.A C_new = self.C * other.C - E_new = supported_matmul(self.E, other.E) - D_new = supported_einsum("ij,j->ij", self.D, other.C) + supported_matmul( - self.E, other.D - ) + E_new = self.E @ other.E + D_new = einsum("ij,j->ij", self.D, other.C) + self.E @ other.D B_C_other, B_E_other = other.B.split([other.diag_dim, other.K2], dim=1) B_new = cat( [ - supported_matmul(self.A, B_C_other) - + supported_einsum("ij,j->ij", B_C, other.C) - + supported_matmul(B_E, other.D), - supported_matmul(self.A, B_E_other) - + supported_matmul(B_E, other.E), + self.A @ B_C_other + + einsum("ij,j->ij", B_C, other.C) + + B_E @ other.D, + self.A @ B_E_other + B_E @ other.E, ], dim=1, ) @@ -291,20 +276,18 @@ def rmatmat(self, mat: Tensor) -> Tensor: # parts of B that share columns with C, E B_C, B_E = self.B.split([self.diag_dim, self.K2], dim=1) - top = supported_matmul(self.A.T, mat_top) + top = self.A.T @ mat_top compute_dtype = lowest_precision(self.C.dtype, mat_middle.dtype) out_dtype = self.C.dtype middle = ( - supported_matmul(B_C.T, mat_top) - + supported_einsum( + B_C.T @ mat_top + + einsum( "i,ij->ij", self.C.to(compute_dtype), mat_middle.to(compute_dtype) ).to(out_dtype) - + supported_matmul(self.D.T, mat_bottom) - ) - bottom = supported_matmul(B_E.T, mat_top) + supported_matmul( - self.E.T, mat_bottom + + self.D.T @ mat_bottom ) + bottom = B_E.T @ mat_top + self.E.T @ mat_bottom return cat([top, middle, bottom]) @@ -323,26 +306,22 @@ def from_inner(self, X: Union[Tensor, None] = None) -> HierarchicalMatrixTemplat `self.T @ X @ X^T @ self`. 
""" if X is None: - A_new = supported_matmul(self.A.T, self.A) - B_new = 2 * supported_matmul(self.A.T, self.B) + A_new = self.A.T @ self.A + B_new = 2 * self.A.T @ self.B # parts of B that share columns with C, E B_C, B_E = self.B.split([self.diag_dim, self.K2], dim=1) C_new = self.C**2 + (B_C**2).sum(0) + (self.D**2).sum(0) - D_new = 2 * ( - supported_matmul(B_E.T, B_C) + supported_matmul(self.E.T, self.D) - ) - E_new = supported_matmul(self.E.T, self.E) + supported_matmul(B_E.T, B_E) + D_new = 2 * (B_E.T @ B_C + self.E.T @ self.D) + E_new = self.E.T @ self.E + B_E.T @ B_E else: S_A, S_C, S_E = self.rmatmat(X).split([self.K1, self.diag_dim, self.K2]) - A_new = supported_matmul(S_A, S_A.T) - B_new = 2 * cat( - [supported_matmul(S_A, S_C.T), supported_matmul(S_A, S_E.T)], dim=1 - ) + A_new = S_A @ S_A.T + B_new = 2 * cat([S_A @ S_C.T, S_A @ S_E.T], dim=1) C_new = (S_C**2).sum(1) - D_new = 2 * supported_matmul(S_E, S_C.T) - E_new = supported_matmul(S_E, S_E.T) + D_new = 2 * S_E @ S_C.T + E_new = S_E @ S_E.T return self.__class__(A_new, B_new, C_new, D_new, E_new) diff --git a/singd/structures/triltoeplitz.py b/singd/structures/triltoeplitz.py index a8bc7a0..3e85b5d 100644 --- a/singd/structures/triltoeplitz.py +++ b/singd/structures/triltoeplitz.py @@ -6,15 +6,10 @@ import torch from torch import Tensor, arange, cat, zeros -from torch.nn.functional import pad +from torch.nn.functional import conv1d, pad from singd.structures.base import StructuredMatrix -from singd.structures.utils import ( - all_traces, - lowest_precision, - supported_conv1d, - toeplitz_matmul, -) +from singd.structures.utils import all_traces, lowest_precision, toeplitz_matmul class TrilToeplitzMatrix(StructuredMatrix): @@ -149,7 +144,7 @@ def __matmul__( # need to create fake channel dimensions conv_input = pad(other._lower_diags, (dim - 1, 0)).unsqueeze(0) conv_weight = col.flip(0).unsqueeze(0).unsqueeze(0) - mat_column = supported_conv1d(conv_input, conv_weight).squeeze(0) + mat_column = conv1d(conv_input, conv_weight).squeeze(0) return TrilToeplitzMatrix(mat_column) def rmatmat(self, mat: Tensor) -> Tensor: diff --git a/singd/structures/triutoeplitz.py b/singd/structures/triutoeplitz.py index 76614f4..832a12e 100644 --- a/singd/structures/triutoeplitz.py +++ b/singd/structures/triutoeplitz.py @@ -6,15 +6,10 @@ import torch from torch import Tensor, arange, cat, triu_indices, zeros -from torch.nn.functional import pad +from torch.nn.functional import conv1d, pad from singd.structures.base import StructuredMatrix -from singd.structures.utils import ( - all_traces, - lowest_precision, - supported_conv1d, - toeplitz_matmul, -) +from singd.structures.utils import all_traces, lowest_precision, toeplitz_matmul class TriuToeplitzMatrix(StructuredMatrix): @@ -147,7 +142,7 @@ def __matmul__( # need to create fake channel dimensions conv_input = pad(other._upper_diags, (dim - 1, 0)).unsqueeze(0) conv_weight = row.flip(0).unsqueeze(0).unsqueeze(0) - mat_row = supported_conv1d(conv_input, conv_weight).squeeze(0) + mat_row = conv1d(conv_input, conv_weight).squeeze(0) return TriuToeplitzMatrix(mat_row) def rmatmat(self, mat: Tensor) -> Tensor: diff --git a/singd/structures/utils.py b/singd/structures/utils.py index af000e2..2ed041b 100644 --- a/singd/structures/utils.py +++ b/singd/structures/utils.py @@ -8,7 +8,6 @@ arange, bfloat16, device, - einsum, eye, float16, float32, @@ -30,40 +29,6 @@ def is_half_precision(dtype: torch.dtype) -> bool: return dtype in [float16, bfloat16] -def supported_matmul(*matrices: Tensor) -> Tensor: - 
"""Multiply matrices with the same or higher numerical precision. - - If the matrix multiplication is not supported on the hardware, - carry out the multiplication in single precision. - - Args: - matrices: The matrices to multiply. - - Returns: - The result of the matrix chain multiplication in the original precision. - - Raises: - RuntimeError: If the matrices are not on the same device. - """ - devices = {m.device for m in matrices} - if len(devices) > 1: - raise RuntimeError("Matrices must be on the same device.") - dev = devices.pop() - - # Use the first matrix's data type as the result's data type. - # The matrices may have different data types if `autocast` was used. - dtype = matrices[0].dtype - - # @ not supported on CPU for float16 (bfloat16 is supported) - convert = dtype == float16 and str(dev) == "cpu" - - result = matrices[0].to(float32) if convert else matrices[0] - for mat in matrices[1:]: - result = result @ mat.to(result.dtype) - - return result.to(dtype) if convert else result - - def supported_eye(n: int, **kwargs: Any) -> Tensor: """Same as PyTorch's `eye`, but uses higher precision if unsupported. @@ -86,40 +51,6 @@ def supported_eye(n: int, **kwargs: Any) -> Tensor: return eye(n, **kwargs, dtype=dtype) -def supported_einsum(equation, *operands: Tensor) -> Tensor: - """Compute an `einsum` with the same or higher numerical precision. - - If the `einsum` is not supported on the hardware, - carry out the multiplication in single precision. - - Args: - equation: The `einsum` equation. - operands: The operands to the `einsum`. - - Returns: - The result of the `einsum` in the original precision. - - Raises: - RuntimeError: If the operands are not on the same device. - """ - devices = {m.device for m in operands} - if len(devices) > 1: - raise RuntimeError("Operands must be on the same device.") - dev = devices.pop() - - # Use the first tensor's data type as the result's data type. - # The tensors may have different data types if `autocast` was used. - dtype = operands[0].dtype - - # @ not supported on CPU for float16 (bfloat16 is supported) - convert = dtype == float16 and str(dev) == "cpu" - - operands = tuple(m.to(float32) if convert else m for m in operands) - result = einsum(equation, *operands) - - return result.to(dtype) if convert else result - - def all_traces(mat: Tensor) -> Tensor: """Compute the traces of a matrix across all diagonals. @@ -149,44 +80,6 @@ def all_traces(mat: Tensor) -> Tensor: return traces -def supported_conv1d( - input: Tensor, weight: Tensor, padding: int = 0, groups: int = 1 -) -> Tensor: - """Same as PyTorch's `conv1d`, but uses higher precision if unsupported. - - For now, we don't support bias and non-default hyper-parameters. - - Args: - input: The input of the convolution. Has shape `[N, C_in, I_1]`. - weight: The kernel of the convolution. Has shape `[C_out, C_in // G, K_1]`. - padding: The amount of padding on both sides of the input. Default: `0`. - groups: The number of groups `G`. Default: `1`. - - Returns: - The output of the convolution in the same precision as `input`. - Has shape `[N, C_out, O_1]`, where `O_1 = I_1 - K_1 + 1`. - - Raises: - RuntimeError: If input and kernel are not on the same device. - """ - devices = {input.device, weight.device} - if len(devices) > 1: - raise RuntimeError("Input and kernel must be on the same device.") - dev = devices.pop() - - # Use the input's data type as the result's data type. - # Input and kernel may have different data types if `autocast` was used. 
- dtype = input.dtype - - # 'slow_conv2d_cpu' not implemented for 'Half' (bfloat16 is supported) - if dtype == float16 and str(dev) == "cpu": - return conv1d( - input.to(float32), weight.to(float32), padding=padding, groups=groups - ).to(dtype) - else: - return conv1d(input, weight, padding=padding, groups=groups) - - def toeplitz_matmul(coeffs: Tensor, mat: Tensor) -> Tensor: """Compute the product of a Toeplitz matrix and a matrix. @@ -216,9 +109,7 @@ def toeplitz_matmul(coeffs: Tensor, mat: Tensor) -> Tensor: # columns act as channels conv_input = mat.T conv_weight = coeffs.unsqueeze(0).unsqueeze(0).expand(num_cols, -1, -1) - conv_result = supported_conv1d( - conv_input, conv_weight, padding=padding, groups=num_cols - ) + conv_result = conv1d(conv_input, conv_weight, padding=padding, groups=num_cols) return conv_result.T diff --git a/test/structures/test_utils.py b/test/structures/test_utils.py index 476ff7b..c45347c 100644 --- a/test/structures/test_utils.py +++ b/test/structures/test_utils.py @@ -1,5 +1,7 @@ """Test utility functions of ``singd.structures``.""" +from sys import platform + from pytest import raises from torch import ( Tensor, @@ -18,41 +20,38 @@ from singd.structures.utils import all_traces, diag_add_ -def test_cpu_float16_matmul_unsupported(): - """Test whether ``@`` between two ``float16`` tensors on CPU is unsupported.""" +def test_cpu_float16_matmul_supported(): + """Test whether ``@`` between two ``float16`` tensors on CPU is supported.""" cpu = device("cpu") mat1 = zeros((2, 2), dtype=float16, device=cpu) mat2 = zeros((2, 2), dtype=float16, device=cpu) - - with raises(RuntimeError): - _ = mat1 @ mat2 + _ = mat1 @ mat2 -def test_cpu_bfloat16_eye_unsupported(): - """Test whether ``eye`` is unsupported in ``bfloat16`` on CPU.""" +def test_eye_support(): + """Test whether ``eye`` is unsupported in ``bfloat16`` on MAC+CPU.""" cpu = device("cpu") - with raises(RuntimeError): + if platform == "darwin": + with raises(RuntimeError): + eye(2, dtype=bfloat16, device=cpu) + else: eye(2, dtype=bfloat16, device=cpu) -def test_cpu_float16_einsum_unsupported(): - """Test whether ``einsum`` is unsupported in ``float16`` on CPU.""" +def test_cpu_float16_einsum_supported(): + """Test whether ``einsum`` is supported in ``float16`` on CPU.""" cpu = device("cpu") mat1 = zeros((2, 2, 2), dtype=float16, device=cpu) mat2 = zeros((2, 2, 2), dtype=float16, device=cpu) + _ = einsum("nij,njk->nik", mat1, mat2) - with raises(RuntimeError): - einsum("nij,njk->nik", mat1, mat2) - -def test_cpu_float16_conv1d_unsupported(): - """Test whether ``conv1d`` is unsupported in ``float16`` on CPU.""" +def test_cpu_float16_conv1d_supported(): + """Test whether ``conv1d`` is supported in ``float16`` on CPU.""" cpu = device("cpu") inputs = zeros((2, 2, 2), dtype=float16, device=cpu) kernel = zeros((2, 2, 1), dtype=float16, device=cpu) - - with raises(RuntimeError): - conv1d(inputs, kernel) + _ = conv1d(inputs, kernel) def test_all_traces(): diff --git a/test/structures/utils.py b/test/structures/utils.py index 17e0cf6..145a738 100644 --- a/test/structures/utils.py +++ b/test/structures/utils.py @@ -14,7 +14,7 @@ from torch.linalg import matrix_norm, vector_norm from singd.structures.base import StructuredMatrix -from singd.structures.utils import is_half_precision, supported_eye, supported_matmul +from singd.structures.utils import is_half_precision, supported_eye DTYPES = [torch.float32, torch.float16, torch.bfloat16] DTYPE_IDS = [str(dt).split(".")[-1] for dt in DTYPES] @@ -48,13 +48,15 @@ def 
_test_matmul( } # multiplication with a structured matrix - truth = supported_matmul(project(sym_mat1), project(sym_mat2)) + truth = project(sym_mat1) @ project(sym_mat2) report_nonclose( - truth, (sym_mat1_structured @ sym_mat2_structured).to_dense(), **tolerances + truth, + (sym_mat1_structured @ sym_mat2_structured).to_dense(), + **tolerances, ) # multiplication with a PyTorch tensor - truth = supported_matmul(project(sym_mat1), mat2) + truth = project(sym_mat1) @ mat2 report_nonclose(truth, sym_mat1_structured @ mat2, **tolerances) @@ -133,7 +135,7 @@ def _test_rmatmat( project: A function which converts an arbitrary symmetric dense matrix into a dense matrix of the tested structure. Used to establish the ground truth. """ - truth = supported_matmul(project(sym_mat1).T, mat2) + truth = project(sym_mat1).T @ mat2 sym_mat1_structured = structured_matrix_cls.from_dense(sym_mat1) report_nonclose( @@ -162,9 +164,10 @@ def _test_from_inner( X: An optional matrix which will be passed to the `from_inner` method. """ if X is None: - truth = project(supported_matmul(project(sym_mat).T, project(sym_mat))) + truth = project(project(sym_mat).T @ project(sym_mat)) else: - truth = project(supported_matmul(project(sym_mat).T, X, X.T, project(sym_mat))) + MTX = project(sym_mat).T @ X + truth = project(MTX @ MTX.T) sym_mat_structured = structured_matrix_cls.from_dense(sym_mat) report_nonclose( @@ -192,7 +195,7 @@ def _test_from_inner2( dense matrix of the tested structure. Used to establish the ground truth. XXT: An symmetric PSD matrix that will be passed to `from_inner2`. """ - truth = project(supported_matmul(project(sym_mat).T, XXT, project(sym_mat))) + truth = project(project(sym_mat).T @ XXT @ project(sym_mat)) sym_mat_structured = structured_matrix_cls.from_dense(sym_mat) report_nonclose( truth, @@ -490,7 +493,7 @@ def test_from_inner2(self, dev: device, dtype: torch.dtype): sym_mat = symmetrize(rand((dim, dim), device=dev, dtype=dtype)) X = rand((dim, 2 * dim), device=dev, dtype=dtype) - XXT = supported_matmul(X, X.T) + XXT = X @ X.T _test_from_inner2(sym_mat, self.STRUCTURED_MATRIX_CLS, self.project, XXT) @mark.parametrize("dtype", DTYPES, ids=DTYPE_IDS)
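# Sketch of the precision round-trip added to `_compute_natural_gradient` above:
# the matricized gradient is cast to the preconditioner dtypes for the K- and
# C-multiplications and cast back afterwards. Dense tensors with `.T @` stand in
# for the structured `K`/`C` and their `rmatmat`; the dtypes and shapes here are
# made up for illustration and are not the optimizer's real attributes.
import torch

dtype_K, dtype_C = torch.bfloat16, torch.bfloat16  # e.g. the `preconditioner_dtype` entry
grad_mat = torch.rand(10, 5)                       # float32 gradient in matrix form
K = torch.rand(5, 5, dtype=dtype_K)                # stand-in for self.Ks[module_name]
C = torch.rand(10, 10, dtype=dtype_C)              # stand-in for self.Cs[module_name]

# W @ K @ K^T = (K @ (K^T @ W^T))^T, evaluated in the preconditioner precision
nat_grad = (K @ (K.T @ grad_mat.T.to(dtype_K))).T
nat_grad = C @ (C.T @ nat_grad.to(dtype_C))

# cast the preconditioned gradient back to the gradient's original precision
nat_grad = nat_grad.to(grad_mat.dtype)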