Skip to content

Commit

Permalink
expose data as UnsafeMutableBufferPointer
Browse files Browse the repository at this point in the history
  • Loading branch information
dastrobu committed Apr 2, 2022
1 parent ecca47e commit 5384873
Show file tree
Hide file tree
Showing 13 changed files with 109 additions and 103 deletions.
8 changes: 4 additions & 4 deletions Sources/NdArray/Equitable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ extension NdArray: Equatable where T: Equatable {
case 0:
return true
case 1:
var l = lhs.data
var r = rhs.data
var l = lhs.dataStart
var r = rhs.dataStart
let ls = lhs.strides[0]
let rs = rhs.strides[0]
for _ in 0..<lhs.count {
Expand All @@ -37,8 +37,8 @@ extension NdArray: Equatable where T: Equatable {
}
default:
if (lhs.isCContiguous && rhs.isCContiguous) || (lhs.isFContiguous && rhs.isFContiguous) {
var l = lhs.data
var r = rhs.data
var l = lhs.dataStart
var r = rhs.dataStart
for _ in 0..<lhs.count {
if l.pointee != r.pointee {
return false
Expand Down
28 changes: 14 additions & 14 deletions Sources/NdArray/Matrix.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,15 +28,15 @@ open class Matrix<T>: NdArray<T> {
for i in 0..<rowCount {
let row = a[i]
precondition(row.count == colCount, "\(row.count) == \(colCount) at row \(i)")
memcpy(data + i * strides[0], row, colCount * MemoryLayout<T>.stride)
memcpy(dataStart + i * strides[0], row, colCount * MemoryLayout<T>.stride)
}
case .F:
for i in 0..<rowCount {
let row = a[i]
precondition(row.count == colCount, "\(row.count) == \(colCount) at row \(i)")
// manual memcopy for strided data
for j in 0..<colCount {
data[i * strides[0] + j * strides[1]] = row[j]
dataStart[i * strides[0] + j * strides[1]] = row[j]
}
}
}
Expand Down Expand Up @@ -177,7 +177,7 @@ public extension Matrix where T == Double {
var lda: __CLPK_integer = __CLPK_integer(n)
var ldb = __CLPK_integer(B.shape[0])
var info: __CLPK_integer = 0
dgesv_(&n, &nrhs, A.data, &lda, &ipiv, B.data, &ldb, &info)
dgesv_(&n, &nrhs, A.dataStart, &lda, &ipiv, B.dataStart, &ldb, &info)
if info != 0 {
throw LapackError.dgesv(info)
}
Expand Down Expand Up @@ -212,7 +212,7 @@ public extension Matrix where T == Double {
// do optimal workspace query
var lwork: __CLPK_integer = -1
var work = [__CLPK_doublereal](repeating: 0.0, count: 1)
dgetri_(&n, A.data, &lda, &ipiv, &work, &lwork, &info)
dgetri_(&n, A.dataStart, &lda, &ipiv, &work, &lwork, &info)
if info != 0 {
throw LapackError.getri(info)
}
Expand All @@ -222,7 +222,7 @@ public extension Matrix where T == Double {
work = [__CLPK_doublereal](repeating: 0.0, count: Int(lwork))

// do the inversion
dgetri_(&n, A.data, &lda, &ipiv, &work, &lwork, &info)
dgetri_(&n, A.dataStart, &lda, &ipiv, &work, &lwork, &info)
if info != 0 {
throw LapackError.getri(info)
}
Expand Down Expand Up @@ -257,7 +257,7 @@ public extension Matrix where T == Double {
// leading dimension is the number of rows in column major order
var lda = __CLPK_integer(m)
var info: __CLPK_integer = 0
dgetrf_(&m, &n, data, &lda, &ipiv, &info)
dgetrf_(&m, &n, dataStart, &lda, &ipiv, &info)
if info != 0 {
throw LapackError.getrf(info)
}
Expand Down Expand Up @@ -336,7 +336,7 @@ public extension Matrix where T == Float {
var lda: __CLPK_integer = __CLPK_integer(n)
var ldb = __CLPK_integer(B.shape[0])
var info: __CLPK_integer = 0
sgesv_(&n, &nrhs, A.data, &lda, &ipiv, B.data, &ldb, &info)
sgesv_(&n, &nrhs, A.dataStart, &lda, &ipiv, B.dataStart, &ldb, &info)
if info != 0 {
throw LapackError.dgesv(info)
}
Expand Down Expand Up @@ -371,7 +371,7 @@ public extension Matrix where T == Float {
// do optimal workspace query
var lwork: __CLPK_integer = -1
var work = [__CLPK_real](repeating: 0.0, count: 1)
sgetri_(&n, A.data, &lda, &ipiv, &work, &lwork, &info)
sgetri_(&n, A.dataStart, &lda, &ipiv, &work, &lwork, &info)
if info != 0 {
throw LapackError.getri(info)
}
Expand All @@ -381,7 +381,7 @@ public extension Matrix where T == Float {
work = [__CLPK_real](repeating: 0.0, count: Int(lwork))

// do the inversion
sgetri_(&n, A.data, &lda, &ipiv, &work, &lwork, &info)
sgetri_(&n, A.dataStart, &lda, &ipiv, &work, &lwork, &info)
if info != 0 {
throw LapackError.getri(info)
}
Expand Down Expand Up @@ -416,7 +416,7 @@ public extension Matrix where T == Float {
// leading dimension is the number of rows in column major order
var lda = __CLPK_integer(m)
var info: __CLPK_integer = 0
sgetrf_(&m, &n, data, &lda, &ipiv, &info)
sgetrf_(&m, &n, dataStart, &lda, &ipiv, &info)
if info != 0 {
throw LapackError.getrf(info)
}
Expand Down Expand Up @@ -448,7 +448,7 @@ public func * (A: Matrix<Double>, x: Vector<Double>) -> Vector<Double> {
let lda: Int32 = Int32(a.shape[0])
let incX: Int32 = Int32(x.strides[0])
let incY: Int32 = Int32(y.strides[0])
cblas_dgemv(order, CblasNoTrans, m, n, 1, a.data, lda, x.data, incX, 0, y.data, incY)
cblas_dgemv(order, CblasNoTrans, m, n, 1, a.dataStart, lda, x.dataStart, incX, 0, y.dataStart, incY)

return y
}
Expand Down Expand Up @@ -481,7 +481,7 @@ public func * (A: Matrix<Double>, B: Matrix<Double>) -> Matrix<Double> {
let lda: Int32 = Int32(a.shape[0])
let ldb: Int32 = Int32(b.shape[0])
let ldc: Int32 = Int32(c.shape[0])
cblas_dgemm(order, CblasNoTrans, CblasNoTrans, m, n, k, 1, a.data, lda, b.data, ldb, 0, c.data, ldc)
cblas_dgemm(order, CblasNoTrans, CblasNoTrans, m, n, k, 1, a.dataStart, lda, b.dataStart, ldb, 0, c.dataStart, ldc)
return c
}

Expand Down Expand Up @@ -510,7 +510,7 @@ public func * (A: Matrix<Float>, x: Vector<Float>) -> Vector<Float> {
let lda: Int32 = Int32(a.shape[0])
let incX: Int32 = Int32(x.strides[0])
let incY: Int32 = Int32(y.strides[0])
cblas_sgemv(order, CblasNoTrans, m, n, 1, a.data, lda, x.data, incX, 0, y.data, incY)
cblas_sgemv(order, CblasNoTrans, m, n, 1, a.dataStart, lda, x.dataStart, incX, 0, y.dataStart, incY)

return y
}
Expand Down Expand Up @@ -544,6 +544,6 @@ public func * (A: Matrix<Float>, B: Matrix<Float>) -> Matrix<Float> {
let lda: Int32 = Int32(a.shape[0])
let ldb: Int32 = Int32(b.shape[0])
let ldc: Int32 = Int32(c.shape[0])
cblas_sgemm(order, CblasNoTrans, CblasNoTrans, m, n, k, 1, a.data, lda, b.data, ldb, 0, c.data, ldc)
cblas_sgemm(order, CblasNoTrans, CblasNoTrans, m, n, k, 1, a.dataStart, lda, b.dataStart, ldb, 0, c.dataStart, ldc)
return c
}
29 changes: 17 additions & 12 deletions Sources/NdArray/NdArray.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,13 @@ public enum Contiguous {
open class NdArray<T>: CustomDebugStringConvertible,
CustomStringConvertible {

/// data buffer
internal(set) public var data: UnsafeMutablePointer<T>
/// data buffer start
internal var dataStart: UnsafeMutablePointer<T>

/// data buffer for row data access
public var data: UnsafeMutableBufferPointer<T> {
UnsafeMutableBufferPointer(start: dataStart, count: count)
}

/// length of the buffer
internal var count: Int
Expand Down Expand Up @@ -53,7 +58,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
/// create a new array without initializing any memory
public required init(empty count: Int = 0) {
self.count = count
data = UnsafeMutablePointer<T>.allocate(capacity: count)
dataStart = UnsafeMutablePointer<T>.allocate(capacity: count)
if count == 0 {
shape = [0]
} else {
Expand Down Expand Up @@ -95,7 +100,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
} else {
self.owner = a.owner
}
self.data = a.data
self.dataStart = a.dataStart
self.count = a.count
self.shape = a.shape
self.strides = a.strides
Expand All @@ -104,7 +109,7 @@ open class NdArray<T>: CustomDebugStringConvertible,

deinit {
if ownsData {
data.deallocate()
dataStart.deallocate()
}
}

Expand Down Expand Up @@ -160,7 +165,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
/// create an 1D NdArray from a plain array
public convenience init(_ a: [T]) {
self.init(empty: a.count)
data.initialize(from: a, count: a.count)
dataStart.initialize(from: a, count: a.count)
}

/// create an 2D NdArray from a plain array
Expand All @@ -179,7 +184,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
for i in 0..<rowCount {
let row = a[i]
precondition(row.count == colCount, "\(row.count) == \(colCount) at row \(i)")
memcpy(data + i * strides[0], row, colCount * MemoryLayout<T>.stride)
memcpy(dataStart + i * strides[0], row, colCount * MemoryLayout<T>.stride)
}
case .F:
for i in 0..<rowCount {
Expand Down Expand Up @@ -212,7 +217,7 @@ open class NdArray<T>: CustomDebugStringConvertible,
for j in 0..<jCount {
let aij = ai[j]
precondition(aij.count == kCount, "\(aij.count) == \(kCount) at index \(i), \(j)")
memcpy(data + i * strides[0] + j * strides[1], aij, kCount * MemoryLayout<T>.stride)
memcpy(dataStart + i * strides[0] + j * strides[1], aij, kCount * MemoryLayout<T>.stride)
}
}
case .F:
Expand All @@ -235,11 +240,11 @@ open class NdArray<T>: CustomDebugStringConvertible,
if isEmpty {
return []
}
return Array(UnsafeBufferPointer(start: data, count: count))
return Array(UnsafeBufferPointer(start: dataStart, count: count))
}

public var debugDescription: String {
let address = String(format: "%p", Int(bitPattern: data))
let address = String(format: "%p", Int(bitPattern: dataStart))
return "NdArray(shape: \(shape), strides: \(strides), data: \(address))"
}

Expand Down Expand Up @@ -473,11 +478,11 @@ extension NdArray {
/// - Returns: true if data regions of this array overlap with data region of the other array
public func overlaps(_ other: NdArray<T>) -> Bool {
// check if other starts within our memory
if other.data >= self.data && other.data < self.data + self.count {
if other.dataStart >= self.dataStart && other.dataStart < self.dataStart + self.count {
return true
}
// check if our memory starts within other memory
if self.data >= other.data && self.data < other.data + other.count {
if self.dataStart >= other.dataStart && self.dataStart < other.dataStart + other.count {
return true
}
return false
Expand Down
4 changes: 2 additions & 2 deletions Sources/NdArray/NdArraySlice.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ internal class NdArraySlice<T>: NdArray<T> {
super.init(a)

let start = a.flatIndex(startIndex)
data = a.data + start
dataStart = a.dataStart + start
count = a.len
}

Expand Down Expand Up @@ -318,7 +318,7 @@ internal class NdArraySlice<T>: NdArray<T> {
}

public override var debugDescription: String {
let address = String(format: "%p", Int(bitPattern: data))
let address = String(format: "%p", Int(bitPattern: dataStart))
var sliceDescription = sliceDescription.joined()
if sliceDescription == "" {
sliceDescription = "-"
Expand Down
22 changes: 11 additions & 11 deletions Sources/NdArray/Vector.swift
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ open class Vector<T>: NdArray<T>, Sequence {
/// create an 1D NdArray from a plain array
public convenience init(_ a: [T]) {
self.init(empty: a.count)
data.initialize(from: a, count: a.count)
dataStart.initialize(from: a, count: a.count)
}

public required convenience init(copy a: NdArray<T>) {
Expand Down Expand Up @@ -64,12 +64,12 @@ public extension Vector where T == Double {
Precondition failed while trying to compute dot product for vectors from \(debugDescription) and \(y.debugDescription).
""")
let n = Int32(shape[0])
return cblas_ddot(n, data, Int32(strides[0]), y.data, Int32(y.strides[0]))
return cblas_ddot(n, dataStart, Int32(strides[0]), y.dataStart, Int32(y.strides[0]))
}

func norm2() -> T {
let n = Int32(shape[0])
return cblas_dnrm2(n, data, Int32(strides[0]))
return cblas_dnrm2(n, dataStart, Int32(strides[0]))
}

func sort(order: SortOrder = .ascending) {
Expand All @@ -83,18 +83,18 @@ public extension Vector where T == Double {
}

if isContiguous {
vDSP_vsortD(data, n, sortOrder)
vDSP_vsortD(dataStart, n, sortOrder)
} else {
// make a copy sort it and copy back if array is not contiguous
let cpy = Vector(copy: self)
vDSP_vsortD(cpy.data, n, sortOrder)
vDSP_vsortD(cpy.dataStart, n, sortOrder)
self[[0...]] = cpy[[0...]]
}
}

func reverse() {
let n = vDSP_Length(shape[0])
vDSP_vrvrsD(data, strides[0], n)
vDSP_vrvrsD(dataStart, strides[0], n)
}
}

Expand All @@ -106,12 +106,12 @@ public extension Vector where T == Float {
Precondition failed while trying to compute dot product for vectors from \(debugDescription) and \(y.debugDescription).
""")
let n = Int32(shape[0])
return cblas_sdot(n, data, Int32(strides[0]), y.data, Int32(y.strides[0]))
return cblas_sdot(n, dataStart, Int32(strides[0]), y.dataStart, Int32(y.strides[0]))
}

func norm2() -> T {
let n = Int32(shape[0])
return cblas_snrm2(n, data, Int32(strides[0]))
return cblas_snrm2(n, dataStart, Int32(strides[0]))
}

func sort(order: SortOrder = .ascending) {
Expand All @@ -125,18 +125,18 @@ public extension Vector where T == Float {
}

if isContiguous {
vDSP_vsort(data, n, sortOrder)
vDSP_vsort(dataStart, n, sortOrder)
} else {
// make a copy sort it and copy back if array is not contiguous
let cpy = Vector(copy: self)
vDSP_vsort(cpy.data, n, sortOrder)
vDSP_vsort(cpy.dataStart, n, sortOrder)
self[[0...]] = cpy[[0...]]
}
}

func reverse() {
let n = vDSP_Length(shape[0])
vDSP_vrvrs(data, strides[0], n)
vDSP_vrvrs(dataStart, strides[0], n)
}
}

Expand Down
6 changes: 3 additions & 3 deletions Sources/NdArray/VectorSequence.swift
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ public extension Vector {

func makeIterator() -> VectorIterator<T> {
if isEmpty {
return VectorIterator(baseAddress: data, stride: 0, count: 0)
return VectorIterator(baseAddress: dataStart, stride: 0, count: 0)
}
return VectorIterator(baseAddress: data, stride: strides[0], count: shape[0])
return VectorIterator(baseAddress: dataStart, stride: strides[0], count: shape[0])
}

/// - Returns: shape[0] or 0 if vector is empty
Expand All @@ -46,7 +46,7 @@ public extension Vector {
/// If the vector is not contiguous, body is not called and nil is returned.
func withContiguousStorageIfAvailable<R>(_ body: (UnsafeBufferPointer<T>) throws -> R) rethrows -> R? {
if isContiguous {
return try body(UnsafeBufferPointer(start: data, count: count))
return try body(UnsafeBufferPointer(start: dataStart, count: count))
} else {
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/NdArray/apply.swift
Original file line number Diff line number Diff line change
Expand Up @@ -71,13 +71,13 @@ public extension NdArray {
func apply(_ f: (T) throws -> T) rethrows {
try apply1d(f1d: { n in
let s = strides[0]
var p = data
var p = dataStart
for _ in 0..<n {
p.initialize(to: try f(p.pointee))
p += s
}
}, fContiguous: { n in
var p = data
var p = dataStart
for _ in 0..<n {
p.initialize(to: try f(p.pointee))
p += 1
Expand Down
Loading

0 comments on commit 5384873

Please sign in to comment.