Skip to content

Commit

Permalink
[stdlib] Clean up memory.unsafe
Browse files Browse the repository at this point in the history
- Make more things infer-only
- Remove unnecessary overload
- Rename `bitcast` overload that does "movemask" to `pack_mask`

Signed-off-by: Yiwu Chen <210at85@gmail.com>
  • Loading branch information
soraros committed Oct 1, 2024
1 parent 02fc624 commit 130184e
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 45 deletions.
2 changes: 1 addition & 1 deletion stdlib/src/memory/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ from .arc import Arc
from .box import Box
from .memory import memcmp, memcpy, memset, memset_zero, stack_allocation
from .reference import AddressSpace, Reference
from .unsafe import bitcast
from .unsafe import bitcast, pack_mask
from .unsafe_pointer import UnsafePointer
62 changes: 21 additions & 41 deletions stdlib/src/memory/unsafe.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ from sys import bitwidthof

@always_inline("nodebug")
fn bitcast[
new_type: DType, new_width: Int, src_type: DType, src_width: Int
](val: SIMD[src_type, src_width]) -> SIMD[new_type, new_width]:
type: DType,
width: Int, //,
new_type: DType,
new_width: Int = width,
](val: SIMD[type, width]) -> SIMD[new_type, new_width]:
"""Bitcasts a SIMD value to another SIMD value.
Constraints:
The bitwidth of the two types must be the same.
Parameters:
type: The source type.
width: The source width.
new_type: The target type.
new_width: The target width.
src_type: The source type.
src_width: The source width.
Args:
val: The source value.
Expand All @@ -49,13 +52,13 @@ fn bitcast[
source SIMD value.
"""
constrained[
bitwidthof[SIMD[src_type, src_width]]()
bitwidthof[SIMD[type, width]]()
== bitwidthof[SIMD[new_type, new_width]](),
"the source and destination types must have the same bitwidth",
]()

@parameter
if new_type == src_type:
if new_type == type:
return rebind[SIMD[new_type, new_width]](val)
return __mlir_op.`pop.bitcast`[
_type = __mlir_type[
Expand All @@ -65,45 +68,19 @@ fn bitcast[


@always_inline("nodebug")
fn bitcast[
new_type: DType, src_type: DType
](val: SIMD[src_type, 1]) -> SIMD[new_type, 1]:
"""Bitcasts a SIMD value to another SIMD value.
Constraints:
The bitwidth of the two types must be the same.
Parameters:
new_type: The target type.
src_type: The source type.
Args:
val: The source value.
Returns:
A new SIMD value with the specified type and width with a bitcopy of the
source SIMD value.
"""
constrained[
bitwidthof[SIMD[src_type, 1]]() == bitwidthof[SIMD[new_type, 1]](),
"the source and destination types must have the same bitwidth",
]()

return bitcast[new_type, 1, src_type, 1](val)


@always_inline("nodebug")
fn bitcast[
new_type: DType, src_width: Int
](val: SIMD[DType.bool, src_width]) -> Scalar[new_type]:
fn pack_mask[
width: Int, //,
new_type: DType,
](val: SIMD[DType.bool, width]) -> Scalar[new_type]:
"""Packs a SIMD bool into an integer.
Constraints:
The bitwidth of the two types must be the same.
The width of the bool vector must be the same as the bitwidth of the
target type.
Parameters:
width: The source width.
new_type: The target type.
src_width: The source width.
Args:
val: The source value.
Expand All @@ -112,8 +89,11 @@ fn bitcast[
A new integer scalar which has the same bitwidth as the bool vector.
"""
constrained[
src_width == bitwidthof[Scalar[new_type]](),
"the source and destination types must have the same bitwidth",
width == bitwidthof[Scalar[new_type]](),
(
"the width of the bool vector must be the same as the bitwidth of"
" the target type"
),
]()

return __mlir_op.`pop.bitcast`[
Expand Down
6 changes: 3 additions & 3 deletions stdlib/src/utils/stringref.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from bit import count_trailing_zeros
from builtin.dtype import _uint_type_of_width
from collections.string import _atol, _isspace
from memory import UnsafePointer, memcmp, bitcast
from memory import UnsafePointer, memcmp, bitcast, pack_mask
from memory.memory import _memcmp_impl_unconstrained
from utils import StringSlice
from sys.ffi import c_char
Expand Down Expand Up @@ -634,7 +634,7 @@ fn _memchr[

for i in range(0, vectorized_end, bool_mask_width):
var bool_mask = source.load[width=bool_mask_width](i) == first_needle
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_mask[_uint_type_of_width[bool_mask_width]()](bool_mask)
if mask:
return source + int(i + count_trailing_zeros(mask))

Expand Down Expand Up @@ -678,7 +678,7 @@ fn _memmem[
var eq_last = last_needle == last_block

var bool_mask = eq_first & eq_last
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_mask[_uint_type_of_width[bool_mask_width]()](bool_mask)

while mask:
var offset = int(i + count_trailing_zeros(mask))
Expand Down

0 comments on commit 130184e

Please sign in to comment.