diff --git a/lib/codecs/array_to_array.ml b/lib/codecs/array_to_array.ml index f4c229f7..05decb2a 100644 --- a/lib/codecs/array_to_array.ml +++ b/lib/codecs/array_to_array.ml @@ -1,6 +1,5 @@ open Codecs_intf -module Ndarray = Owl.Dense.Ndarray.Generic (* https://zarr-specs.readthedocs.io/en/latest/v3/codecs/transpose/v1.0.html *) module TransposeCodec = struct @@ -29,7 +28,6 @@ module TransposeCodec = struct the decoded representation dimensionality." in Result.error @@ `Transpose_order (t, msg) - let parse_order o = if Array.length o = 0 then let msg = "transpose order cannot be empty." in @@ -66,14 +64,23 @@ module TransposeCodec = struct else Ok () - let encode o x = - try Ok (Ndarray.transpose ~axis:o x) with - | Failure s -> Error (`Transpose_order (o, s)) + let transpose ?axis x = + let module A = Owl.Dense.Ndarray.Any in + let module N = Owl.Dense.Ndarray.Generic in + try + let y = A.transpose ?axis @@ A.init_nd (N.shape x) @@ N.get x in + Result.ok @@ N.init_nd (N.kind x) (A.shape y) @@ A.get y + with + | Assert_failure _ -> + Result.error @@ + `Transpose_order (Option.get axis, "Invalid transpose order.") + + let encode o x = transpose ~axis:o x let decode o x = let inv_order = Array.(make (length o) 0) in Array.iteri (fun i x -> inv_order.(x) <- i) o; - Ok (Ndarray.transpose ~axis:inv_order x) + transpose ~axis:inv_order x let to_yojson order = let o = diff --git a/test/test_codecs.ml b/test/test_codecs.ml index 2f67fb8b..ee8a0713 100644 --- a/test/test_codecs.ml +++ b/test/test_codecs.ml @@ -40,10 +40,10 @@ let bytes_encode_decode let tests = [ "test codec chain" >:: (fun _ -> let decoded_repr - : (float, Bigarray.float32_elt) array_repr = + : (int, Bigarray.int16_signed_elt) array_repr = {shape = [|10; 15; 10|] - ;kind = Bigarray.Float32 - ;fill_value = (-10.)} + ;kind = Bigarray.Int16_signed + ;fill_value = 10} in let shard_cfg = {chunk_shape = [|2; 5; 5|] @@ -323,7 +323,7 @@ let tests = [ let cfg = {chunk_shape = [|3; 5; 5|] ;index_location = Start - ;index_codecs = [`Bytes LE; `Crc32c] + ;index_codecs = [`Transpose [|0; 3; 1; 2|]; `Bytes LE; `Crc32c] ;codecs = [`Bytes BE]} in let chain = [`ShardingIndexed cfg] in @@ -362,20 +362,6 @@ let tests = [ assert_failure "Successfully encoded array should decode without fail"); - (* test if including a transpose codec for index_codec chain results in - a failure. *) - let chain' = - [`ShardingIndexed {cfg with - chunk_shape = [|5; 3; 5|] - ;index_codecs = `Transpose [|0; 3; 1; 2|] :: cfg.index_codecs}] - in - let cc = Chain.create decoded_repr chain' |> Result.get_ok in - assert_bool - "shard index chain can't be encoded since Owl does not support transposing - Int64 types. See: - https://github.com/owlbarn/owl/issues/671#issuecomment-2211303040" @@ - Result.is_error @@ Chain.encode cc arr; - (* test correctness of decoding nested sharding codecs.*) let str = {|[