Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[stdlib] Clean up memory.unsafe #3588

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion stdlib/src/memory/__init__.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,5 @@ from .arc import Arc
from .box import Box
from .memory import memcmp, memcpy, memset, memset_zero, stack_allocation
from .pointer import AddressSpace, Pointer
from .unsafe import bitcast
from .unsafe import bitcast, pack_bits
from .unsafe_pointer import UnsafePointer
68 changes: 30 additions & 38 deletions stdlib/src/memory/unsafe.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -28,18 +28,21 @@ from sys import bitwidthof

@always_inline("nodebug")
fn bitcast[
new_type: DType, new_width: Int, src_type: DType, src_width: Int
](val: SIMD[src_type, src_width]) -> SIMD[new_type, new_width]:
type: DType,
width: Int, //,
new_type: DType,
new_width: Int = width,
](val: SIMD[type, width]) -> SIMD[new_type, new_width]:
"""Bitcasts a SIMD value to another SIMD value.

Constraints:
The bitwidth of the two types must be the same.

Parameters:
type: The source type.
width: The source width.
new_type: The target type.
new_width: The target width.
src_type: The source type.
src_width: The source width.

Args:
val: The source value.
Expand All @@ -49,13 +52,13 @@ fn bitcast[
source SIMD value.
"""
constrained[
bitwidthof[SIMD[src_type, src_width]]()
bitwidthof[SIMD[type, width]]()
== bitwidthof[SIMD[new_type, new_width]](),
"the source and destination types must have the same bitwidth",
]()

@parameter
if new_type == src_type:
if new_type == type:
return rebind[SIMD[new_type, new_width]](val)
return __mlir_op.`pop.bitcast`[
_type = __mlir_type[
Expand All @@ -65,45 +68,31 @@ fn bitcast[


@always_inline("nodebug")
fn bitcast[
new_type: DType, src_type: DType
](val: SIMD[src_type, 1]) -> SIMD[new_type, 1]:
"""Bitcasts a SIMD value to another SIMD value.

Constraints:
The bitwidth of the two types must be the same.

Parameters:
new_type: The target type.
src_type: The source type.

Args:
val: The source value.

Returns:
A new SIMD value with the specified type and width with a bitcopy of the
source SIMD value.
"""
constrained[
bitwidthof[SIMD[src_type, 1]]() == bitwidthof[SIMD[new_type, 1]](),
"the source and destination types must have the same bitwidth",
]()

return bitcast[new_type, 1, src_type, 1](val)
fn _uint(n: Int) -> DType:
if n == 8:
return DType.uint8
elif n == 16:
return DType.uint16
elif n == 32:
return DType.uint32
else:
return DType.uint64


@always_inline("nodebug")
fn bitcast[
new_type: DType, src_width: Int
](val: SIMD[DType.bool, src_width]) -> Scalar[new_type]:
fn pack_bits[
width: Int, //,
new_type: DType = _uint(width),
](val: SIMD[DType.bool, width]) -> Scalar[new_type]:
"""Packs a SIMD bool into an integer.

Constraints:
The bitwidth of the two types must be the same.
The width of the bool vector must be the same as the bitwidth of the
target type.

Parameters:
width: The source width.
new_type: The target type.
src_width: The source width.

Args:
val: The source value.
Expand All @@ -112,8 +101,11 @@ fn bitcast[
A new integer scalar which has the same bitwidth as the bool vector.
"""
constrained[
src_width == bitwidthof[Scalar[new_type]](),
"the source and destination types must have the same bitwidth",
width == bitwidthof[Scalar[new_type]](),
(
"the width of the bool vector must be the same as the bitwidth of"
" the target type"
),
]()

return __mlir_op.`pop.bitcast`[
Expand Down
6 changes: 3 additions & 3 deletions stdlib/src/utils/stringref.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ from bit import count_trailing_zeros
from builtin.dtype import _uint_type_of_width
from collections.string import _atol, _isspace
from hashlib._hasher import _HashableWithHasher, _Hasher
from memory import UnsafePointer, memcmp, bitcast
from memory import UnsafePointer, memcmp, pack_bits
from memory.memory import _memcmp_impl_unconstrained
from utils import StringSlice
from sys.ffi import c_char
Expand Down Expand Up @@ -698,7 +698,7 @@ fn _memchr[

for i in range(0, vectorized_end, bool_mask_width):
var bool_mask = source.load[width=bool_mask_width](i) == first_needle
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_bits(bool_mask)
if mask:
return source + int(i + count_trailing_zeros(mask))

Expand Down Expand Up @@ -742,7 +742,7 @@ fn _memmem[
var eq_last = last_needle == last_block

var bool_mask = eq_first & eq_last
var mask = bitcast[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_bits(bool_mask)

while mask:
var offset = int(i + count_trailing_zeros(mask))
Expand Down