Skip to content

Commit

Permalink
[stdlib] Make pack_mask infer its return type
Browse files Browse the repository at this point in the history
Signed-off-by: Yiwu Chen <210at85@gmail.com>
  • Loading branch information
soraros committed Oct 1, 2024
1 parent f24df7e commit 282ea3c
Show file tree
Hide file tree
Showing 2 changed files with 16 additions and 4 deletions.
14 changes: 13 additions & 1 deletion stdlib/src/memory/unsafe.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,22 @@ fn bitcast[
](val.value)


@always_inline("nodebug")
fn _uint(n: Int) -> DType:
    """Selects an unsigned integer DType from a bit count.

    Args:
        n: The requested width in bits.

    Returns:
        `DType.uint8`, `DType.uint16` or `DType.uint32` when `n` is 8, 16 or
        32 respectively; `DType.uint64` for every other value of `n`.
    """
    # A single conditional expression; any width other than 8/16/32 falls
    # through to the 64-bit type, exactly as the if/elif chain did.
    return (
        DType.uint8 if n == 8
        else DType.uint16 if n == 16
        else DType.uint32 if n == 32
        else DType.uint64
    )


@always_inline("nodebug")
fn pack_mask[
width: Int, //,
new_type: DType,
new_type: DType = _uint(width),
](val: SIMD[DType.bool, width]) -> Scalar[new_type]:
"""Packs a SIMD bool into an integer.
Expand Down
6 changes: 3 additions & 3 deletions stdlib/src/utils/stringref.mojo
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from bit import count_trailing_zeros
from builtin.dtype import _uint_type_of_width
from collections.string import _atol, _isspace
from memory import UnsafePointer, memcmp, bitcast, pack_mask
from memory import UnsafePointer, memcmp, pack_mask
from memory.memory import _memcmp_impl_unconstrained
from utils import StringSlice
from sys.ffi import c_char
Expand Down Expand Up @@ -634,7 +634,7 @@ fn _memchr[

for i in range(0, vectorized_end, bool_mask_width):
var bool_mask = source.load[width=bool_mask_width](i) == first_needle
var mask = pack_mask[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_mask(bool_mask)
if mask:
return source + int(i + count_trailing_zeros(mask))

Expand Down Expand Up @@ -678,7 +678,7 @@ fn _memmem[
var eq_last = last_needle == last_block

var bool_mask = eq_first & eq_last
var mask = pack_mask[_uint_type_of_width[bool_mask_width]()](bool_mask)
var mask = pack_mask(bool_mask)

while mask:
var offset = int(i + count_trailing_zeros(mask))
Expand Down

0 comments on commit 282ea3c

Please sign in to comment.