Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reshape_infer procedure #646

Merged
merged 1 commit into from
Apr 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 28 additions & 3 deletions src/arraymancer/tensor/private/p_shapeshifting.nim
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import ../../laser/tensor/initialization,
./p_checks,
nimblas

import std / sequtils

proc contiguousImpl*[T](t: Tensor[T], layout: OrderType, result: var Tensor[T]) =
if layout == rowMajor:
result = t.map_inline(x)
Expand All @@ -28,16 +30,39 @@ proc contiguousImpl*[T](t: Tensor[T], layout: OrderType, result: var Tensor[T])
apply2_inline(result, t):
y

proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs[int]|Metadata, result: var Tensor[T]) =
proc reshape_with_copy*[T](t: Tensor[T], new_shape: varargs[int]|Metadata|seq[int], result: var Tensor[T]) =
result = newTensorUninit[T](new_shape)
result.apply2_inline(t,y)

proc reshape_no_copy*(t: AnyTensor, new_shape: varargs[int]|Metadata, result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
proc reshape_no_copy*(t: AnyTensor, new_shape: varargs[int]|Metadata|seq[int], result: var AnyTensor, layout: OrderType) {.noSideEffect.}=
result.shape.copyFrom(new_shape)
shape_to_strides(result.shape, layout, result.strides)
result.offset = t.offset

proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|Metadata, result: var AnyTensor) =
proc infer_shape*(t: Tensor, new_shape: varargs[int]): seq[int] {.noinit.} =
## Replace the single -1 value on `new_shape` with the value that
## makes the size the same as that of the input tensor
result = new_shape.toSeq
var auto_axis = -1
var auto_axis_count = 0
for n in 0 .. result.high:
if result[n] == -1:
auto_axis_count += 1
auto_axis = n
break
if auto_axis_count > 1:
raise newException(ValueError, "Only one dimension can be inferred by inferShape")
elif auto_axis_count == 0:
when compileOption("boundChecks"):
raise newException(ValueError, "At least one dimension must be inferred by inferShape")
else:
result[auto_axis] = t.size div result.filterIt(it != -1).prod

proc reshapeImpl*(t: AnyTensor, new_shape: varargs[int]|Metadata|seq[int],
result: var AnyTensor, infer: static bool) =
when infer:
let new_shape = t.infer_shape(new_shape)

when compileOption("boundChecks"):
check_reshape(t, new_shape)

Expand Down
22 changes: 19 additions & 3 deletions src/arraymancer/tensor/shapeshifting.nim
Original file line number Diff line number Diff line change
Expand Up @@ -63,10 +63,12 @@ proc reshape*(t: Tensor, new_shape: varargs[int]): Tensor {.noinit.} =
##
## Input:
## - a tensor
## - a new shape. Number of elements must be the same
## - a new shape. Number of elements must be the same. Unlike numpy,
## dimensions cannot be -1 to infer their value. If that is what you need
## you must use the alternative `reshape_infer` proc.
## Returns:
## - a tensor with the same data but reshaped.
reshapeImpl(t, new_shape, result)
reshapeImpl(t, new_shape, result, infer = false)

proc reshape*(t: Tensor, new_shape: Metadata): Tensor {.noinit.} =
## Reshape a tensor. If possible no data copy is done and the returned tensor
Expand All @@ -78,7 +80,21 @@ proc reshape*(t: Tensor, new_shape: Metadata): Tensor {.noinit.} =
## - a new shape. Number of elements must be the same
## Returns:
## - a tensor with the same data but reshaped.
reshapeImpl(t, new_shape, result)
reshapeImpl(t, new_shape, result, infer = false)

proc reshape_infer*(t: Tensor, new_shape: varargs[int]):
Tensor {.noinit.} =
## Reshape a tensor. If possible no data copy is done and the returned tensor
## shares data with the input. If input is not contiguous, this is not possible
## and a copy will be made.
##
## Input:
## - a tensor
## - a new shape. Number of elements must be the same. The new shape can
## contain -1 to infer the size of one (and only one) dimension
## Returns:
## - a tensor with the same data but reshaped.
reshapeImpl(t, new_shape, result, infer = true)

proc flatten*(t: Tensor): Tensor {.noinit,inline.} =
## Flatten a tensor, returning a rank-1 tensor with the same data as the input.
Expand Down
6 changes: 5 additions & 1 deletion tests/tensor/test_shapeshifting.nim
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
let d = a.asContiguous(colMajor, force = true)
# this test needs `toRawSeq` due to the changed layout. `toFlatSeq` provides the
# same as for `c` above!
check: d.toRawSeq == @[7, 8, 2, 4, 1, 0, 3, 6, 4, 1, 2, 3, 8, 6, 2, 6, 6, 0]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / linux-amd64-c (version-1-6)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / linux-amd64-c (version-1-6)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / macos-amd64-c (version-1-6)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / macos-amd64-c (version-1-6)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / linux-amd64-c (version-2-0)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / linux-amd64-c (version-2-0)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / macos-amd64-c (version-2-0)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / macos-amd64-c (version-2-0)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / linux-amd64-c (devel)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / linux-amd64-c (devel)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / macos-amd64-c (devel)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]

Check warning on line 45 in tests/tensor/test_shapeshifting.nim

View workflow job for this annotation

GitHub Actions / macos-amd64-c (devel)

This proc cannot be reimplemented in a backward compatible way.; toRawSeq is deprecated [Deprecated]


# # Now test with a non contiguous tensor
Expand All @@ -62,9 +62,13 @@
check: a == b

test "Reshape":
let a = toSeq(1..4).toTensor().reshape(2,2)
let a = toSeq(1..4).toTensor().reshape(2, 2)
let b = toSeq(1..4).toTensor().reshape_infer(-1, 2)
let c = toSeq(1..4).toTensor().reshape_infer(2, -1)
check: a == [[1,2],
[3,4]].toTensor()
check: a == b
check: a == c

test "Unsafe reshape":
block:
Expand Down
Loading