Skip to content

Commit 828ba3f

Browse files
authored
StaticInt for n-first/last (#289)
* StaticInt for n-first/last
1 parent 8d89c2f commit 828ba3f

File tree

3 files changed

+31
-2
lines changed

3 files changed

+31
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ArrayInterface"
22
uuid = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
3-
version = "6.0.8"
3+
version = "6.0.9"
44

55
[deps]
66
ArrayInterfaceCore = "30b0a656-2188-435a-8636-2ec0e6a096e2"

src/indexing.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,27 @@
11

2+
function known_lastindex(::Type{T}) where {T}
3+
if known_offset1(T) === nothing || known_length(T) === nothing
4+
return nothing
5+
else
6+
return known_length(T) - known_offset1(T) + 1
7+
end
8+
end
9+
known_lastindex(@nospecialize x) = known_lastindex(typeof(x))
10+
11+
@inline static_lastindex(x) = Static.maybe_static(known_lastindex, lastindex, x)
12+
13+
function Base.first(x::AbstractVector, n::StaticInt)
14+
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
15+
start = offset1(x)
16+
@inbounds x[start:min((start - one(start)) + n, static_lastindex(x))]
17+
end
18+
19+
function Base.last(x::AbstractVector, n::StaticInt)
20+
@boundscheck n < 0 && throw(ArgumentError("Number of elements must be nonnegative"))
21+
stop = static_lastindex(x)
22+
@inbounds x[max(offset1(x), (stop + one(stop)) - n):stop]
23+
end
24+
225
function _is_splat(::Type{I}, i::StaticInt) where {I}
326
if dynamic(is_splat_index(field_type(I, i)))
427
True()

test/indexing.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -192,6 +192,13 @@ end
192192
end
193193
end
194194

195+
@testset "n-first/last" begin
196+
x = MArray([1, 2, 3, 4])
197+
n = static(2)
198+
@test @inferred(first(x, n)) == [1, 2]
199+
@test @inferred(last(x, n)) == [3, 4]
200+
end
201+
195202
A = zeros(3, 4, 5);
196203
A[:] = 1:60
197204
Ap = @view(PermutedDimsArray(A, (3, 1, 2))[:, 1:2, 1])';
@@ -222,4 +229,3 @@ izip = zip(S, S)
222229

223230
sv5 = MArray(zeros(5));
224231
v5 = Vector{Float64}(undef, 5);
225-

0 commit comments

Comments
 (0)