Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

A deleter field (function pointer) for CpuStorageObj #664

Open
wants to merge 1 commit into
base: master
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 38 additions & 6 deletions src/arraymancer/laser/tensor/datatypes.nim
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,10 @@ type
CpuStorageObj[T] {.shallow.} = object
# Workaround supportsCopyMem in type section - https://github.com/nim-lang/Nim/issues/13193
when T is KnownSupportsCopyMem:
raw_buffer*: ptr UncheckedArray[T] # 8 bytes
memalloc*: pointer # 8 bytes
isMemOwner*: bool # 1 byte
raw_buffer*: ptr UncheckedArray[T] # 8 bytes
memalloc*: pointer # 8 bytes
isMemOwner*: bool # 1 byte
deleter*: proc(memalloc: pointer) {.noconv, gcsafe, raises: [].} # 8 bytes
else: # Tensors of strings, other ref types or non-trivial destructors
raw_buffer*: seq[T] # 8 bytes (16 for seq v2 backed by destructors?)

Expand Down Expand Up @@ -74,14 +75,21 @@ when not defined(gcDestructors):
proc finalizer[T](storage: CpuStorage[T]) =
static: assert T is KnownSupportsCopyMem, "Tensors of seq, strings, ref types and types with non-trivial destructors cannot be finalized by this proc"
if storage.isMemOwner and not storage.memalloc.isNil:
storage.memalloc.deallocShared()
if storage.deleter.isNil:
storage.memalloc.deallocShared()
else:
storage.deleter(storage.memalloc)
storage.deleter = nil
storage.memalloc = nil
else:
when (NimMajor, NimMinor, NimPatch) >= (2, 0, 0):
proc `=destroy`[T](storage: CpuStorageObj[T]) =
when T is KnownSupportsCopyMem:
if storage.isMemOwner and not storage.memalloc.isNil:
storage.memalloc.deallocShared()
if storage.deleter.isNil:
storage.memalloc.deallocShared()
else:
storage.deleter(storage.memalloc)
else:
# The following cast removes the following warning:
# =destroy(storage.raw_buffer) can raise an unlisted exception: Exception
Expand All @@ -91,7 +99,11 @@ else:
proc `=destroy`[T](storage: var CpuStorageObj[T]) =
when T is KnownSupportsCopyMem:
if storage.isMemOwner and not storage.memalloc.isNil:
storage.memalloc.deallocShared()
if storage.deleter.isNil:
storage.memalloc.deallocShared()
else:
storage.deleter(storage.memalloc)
storage.deleter = nil
storage.memalloc = nil
else:
`=destroy`(storage.raw_buffer)
Expand All @@ -113,6 +125,7 @@ proc allocCpuStorage*[T](storage: var CpuStorage[T], size: int) =
storage.memalloc = allocShared(sizeof(T) * size + LASER_MEM_ALIGN - 1)
storage.isMemOwner = true
storage.raw_buffer = align_raw_data(T, storage.memalloc)
storage.deleter = nil
else: # Always 0-initialize Tensors of seq, strings, ref types and types with non-trivial destructors
new(storage)
storage.raw_buffer.newSeq(size)
Expand All @@ -133,6 +146,25 @@ proc cpuStorageFromBuffer*[T: KnownSupportsCopyMem](
storage.memalloc = rawBuffer
storage.isMemOwner = false
storage.raw_buffer = cast[ptr UncheckedArray[T]](storage.memalloc)
storage.deleter = nil

proc cpuStorageFromBuffer*[T: KnownSupportsCopyMem](
storage: var CpuStorage[T],
rawBuffer: pointer,
memalloc: pointer,
deleter: proc(memalloc: pointer) {.noconv, gcsafe, raises: [].}) =
## Create a `CpuStorage`, which stores data from a given raw pointer, which
## it will own. The destructor/finalizer will do a `deleter(memalloc)` call.
##
## The input buffer must be a raw `pointer`.
when not defined(gcDestructors):
new(storage, finalizer[T])
else:
new(storage)
storage.memalloc = memalloc
storage.isMemOwner = true
storage.raw_buffer = cast[ptr UncheckedArray[T]](rawBuffer)
storage.deleter = deleter

func is_C_contiguous*(t: Tensor): bool =
## Check if the tensor follows C convention / is row major
Expand Down
Loading