Skip to content

Commit 49715dd

Browse files
feat: more trigonometric functions (#462)
* feat: more trigonometric functions * Update src/TracedRNumber.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * test: trig functions --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 1bc64a7 commit 49715dd

File tree

2 files changed

+33
-0
lines changed

2 files changed

+33
-0
lines changed

src/TracedRNumber.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,7 @@ for (jlop, hloop) in (
202202
(:(Base.:-), :negate),
203203
(:(Base.sin), :sine),
204204
(:(Base.cos), :cosine),
205+
(:(Base.tan), :tan),
205206
(:(Base.tanh), :tanh),
206207
(:(Base.FastMath.tanh_fast), :tanh),
207208
(:(Base.exp), :exponential),
@@ -214,6 +215,13 @@ for (jlop, hloop) in (
214215
@eval $(jlop)(@nospecialize(lhs::TracedRNumber)) = Ops.$(hloop)(lhs)
215216
end
216217

218+
for (jlop, hloop) in
219+
((:(Base.sinpi), :sine), (:(Base.cospi), :cosine), (:(Base.tanpi), :tan))
220+
@eval $(jlop)(@nospecialize(lhs::TracedRNumber{T})) where {T} = Ops.$(hloop)(T(π) * lhs)
221+
end
222+
223+
Base.sincospi(x::TracedRNumber{T}) where {T} = Ops.sine(T(π) * x), Ops.cosine(T(π) * x)
224+
217225
Base.conj(x::TracedRNumber) = x
218226
Base.conj(x::TracedRNumber{<:Complex}) = Ops.conj(x)
219227

test/basic.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -885,3 +885,28 @@ end
885885
res = @jit fn(x_ra, Array(idxs_ra))
886886
@test res fn(Array(x_ra), Array(idxs_ra))
887887
end
888+
889+
@testset "Common Trig Functions" begin
890+
x = rand(Float32, 4, 16)
891+
x_ra = Reactant.to_rarray(x)
892+
893+
@testset for fn in (sinpi, cospi, tanpi, sin, cos, tan)
894+
@test @jit(fn.(x_ra)) fn.(x)
895+
@test @jit(fn.(x_ra)) isa ConcreteRArray{Float32,2}
896+
end
897+
898+
x = 0.235f0
899+
x_ra = Reactant.to_rarray(x; track_numbers=(Number,))
900+
901+
@testset for fn in (sinpi, cospi, tanpi, sin, cos, tan)
902+
@test @jit(fn.(x_ra)) fn.(x)
903+
@test @jit(fn.(x_ra)) isa ConcreteRNumber{Float32}
904+
end
905+
@testset for fn in (sincospi, sincos)
906+
res = @jit fn(x_ra)
907+
@test res[1] fn(x)[1]
908+
@test res[2] fn(x)[2]
909+
@test res[1] isa ConcreteRNumber{Float32}
910+
@test res[2] isa ConcreteRNumber{Float32}
911+
end
912+
end

0 commit comments

Comments
 (0)