Skip to content

Commit 465c69b

Browse files
committed
Context
This commit introduces `Context`, a structure that holds configuration of the decimal arithmetics. Eventually, the global variable `DIGITS` should be completely removed in favor of this newly-added structure.
1 parent a789905 commit 465c69b

13 files changed

+1821
-893
lines changed

.github/workflows/ci.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ jobs:
2222
fail-fast: false
2323
matrix:
2424
version:
25-
- '1.6'
25+
- '1.8'
2626
- '1'
2727
os:
2828
- ubuntu-latest

Project.toml

+5-1
Original file line numberDiff line numberDiff line change
@@ -2,8 +2,12 @@ name = "Decimals"
22
uuid = "abce61dc-4473-55a0-ba07-351d65e31d42"
33
version = "0.4.1"
44

5+
[deps]
6+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
7+
58
[compat]
6-
julia = "1"
9+
ScopedValues = "1"
10+
julia = "1.8"
711

812
[extras]
913
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

scripts/dectest.jl

+91-57
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,48 @@
11
function _precision(line)
22
m = match(r"^precision:\s*(\d+)$", line)
3+
isnothing(m) && throw(ArgumentError(line))
34
return parse(Int, m[1])
45
end
56

67
function _rounding(line)
78
m = match(r"^rounding:\s*(\w+)$", line)
8-
return Symbol(m[1])
9+
isnothing(m) && throw(ArgumentError(line))
10+
r = m[1]
11+
if r == "ceiling"
12+
return "RoundUp"
13+
elseif r == "down"
14+
return "RoundToZero"
15+
elseif r == "floor"
16+
return "RoundDown"
17+
elseif r == "half_even"
18+
return "RoundNearest"
19+
elseif r == "half_up"
20+
return "RoundNearestTiesAway"
21+
elseif r == "up"
22+
return "RoundFromZero"
23+
elseif r == "half_down"
24+
return "RoundHalfDownUnsupported"
25+
elseif r == "05up"
26+
return "Round05UpUnsupported"
27+
else
28+
throw(ArgumentError(r))
29+
end
930
end
1031

1132
function _maxexponent(line)
12-
m = match(r"^maxexponent:\s*(\d+)$", line)
33+
m = match(r"^maxexponent:\s*\+?(\d+)$", line)
34+
isnothing(m) && throw(ArgumentError(line))
1335
return parse(Int, m[1])
1436
end
1537

1638
function _minexponent(line)
1739
m = match(r"^minexponent:\s*(-\d+)$", line)
40+
isnothing(m) && throw(ArgumentError(line))
1841
return parse(Int, m[1])
1942
end
2043

2144
function _test(line)
45+
occursin("->", line) || throw(ArgumentError(line))
2246
lhs, rhs = split(line, "->")
2347
id, operation, operands... = split(lhs)
2448
result, conditions... = split(rhs)
@@ -31,47 +55,55 @@ function decimal(x)
3155
return "dec\"$x\""
3256
end
3357

34-
print_precision(io, p::Int) = println(io, " setprecision(Decimal, $p)")
35-
print_maxexponent(io, e::Int) = println(io, " Decimals.CONTEXT.Emax = $e")
36-
print_minexponent(io, e::Int) = println(io, " Decimals.CONTEXT.Emin = $e")
37-
function print_rounding(io, r::Symbol)
38-
modes = Dict(:ceiling => "RoundUp",
39-
:down => "RoundToZero",
40-
:floor => "RoundDown",
41-
:half_even => "RoundNearest",
42-
:half_up => "RoundNearestTiesAway",
43-
:up => "RoundFromZero",
44-
:half_down => "RoundHalfDownUnsupported",
45-
Symbol("05up") => "Round05UpUnsupported")
46-
haskey(modes, r) || throw(ArgumentError(r))
47-
rmod = modes[r]
48-
println(io, " setrounding(Decimal, $rmod)")
49-
end
50-
5158
function print_operation(io, operation, operands)
52-
if operation == "plus"
53-
print_plus(io, operands...)
54-
elseif operation == "minus"
55-
print_minus(io, operands...)
59+
if operation == "abs"
60+
print_abs(io, operands...)
61+
elseif operation == "add"
62+
print_add(io, operands...)
63+
elseif operation == "apply"
64+
print_apply(io, operands...)
5665
elseif operation == "compare"
5766
print_compare(io, operands...)
67+
elseif operation == "divide"
68+
print_divide(io, operands...)
69+
elseif operation == "minus"
70+
print_minus(io, operands...)
71+
elseif operation == "multiply"
72+
print_multiply(io, operands...)
73+
elseif operation == "plus"
74+
print_plus(io, operands...)
75+
elseif operation == "reduce"
76+
print_reduce(io, operands...)
77+
elseif operation == "subtract"
78+
print_subtract(io, operands...)
5879
else
5980
throw(ArgumentError(operation))
6081
end
6182
end
83+
print_abs(io, x) = print(io, "abs(", decimal(x), ")")
84+
print_add(io, x, y) = print(io, decimal(x), " + ", decimal(y))
85+
print_apply(io, x) = print(io, decimal(x))
6286
print_compare(io, x, y) = print(io, "cmp(", decimal(x), ", ", decimal(y), ")")
87+
print_divide(io, x, y) = print(io, decimal(x), " / ", decimal(y))
6388
print_minus(io, x) = print(io, "-(", decimal(x), ")")
89+
print_multiply(io, x, y) = print(io, decimal(x), " * ", decimal(y))
6490
print_plus(io, x) = print(io, "+(", decimal(x), ")")
91+
print_reduce(io, x) = print(io, "reduce(", decimal(x), ")")
92+
print_subtract(io, x, y) = print(io, decimal(x), " - ", decimal(y))
6593

66-
function print_test(io, test)
94+
function print_test(io, test, directives)
6795
println(io, " # $(test.id)")
6896

97+
names = sort!(collect(keys(directives)))
98+
params = join(("$k=$(directives[k])" for k in names), ", ")
99+
print(io, " @with_context ($params) ")
100+
69101
if :overflow test.conditions
70-
print(io, " @test_throws OverflowError ")
102+
print(io, "@test_throws OverflowError ")
71103
print_operation(io, test.operation, test.operands)
72104
println(io)
73105
else
74-
print(io, " @test ")
106+
print(io, "@test ")
75107
print_operation(io, test.operation, test.operands)
76108
print(io, " == ")
77109
println(io, decimal(test.result))
@@ -83,34 +115,36 @@ function isspecial(value)
83115
return occursin(r"(inf|nan|#)", value)
84116
end
85117

86-
function translate(io, line)
87-
isempty(line) && return
88-
startswith(line, "--") && return
89-
90-
line = lowercase(line)
91-
92-
if startswith(line, "version:")
93-
# ...
94-
elseif startswith(line, "extended:")
95-
# ...
96-
elseif startswith(line, "clamp:")
97-
# ...
98-
elseif startswith(line, "precision:")
99-
precision = _precision(line)
100-
print_precision(io, precision)
101-
elseif startswith(line, "rounding:")
102-
rounding = _rounding(line)
103-
print_rounding(io, rounding)
104-
elseif startswith(line, "maxexponent:")
105-
maxexponent = _maxexponent(line)
106-
print_maxexponent(io, maxexponent)
107-
elseif startswith(line, "minexponent:")
108-
minexponent = _minexponent(line)
109-
print_minexponent(io, minexponent)
110-
else
111-
test = _test(line)
112-
any(isspecial, test.operands) && return
113-
print_test(io, test)
118+
function translate(io, dectest_path)
119+
directives = Dict{String, Any}()
120+
121+
for line in eachline(dectest_path)
122+
line = strip(line)
123+
124+
isempty(line) && continue
125+
startswith(line, "--") && continue
126+
127+
line = lowercase(line)
128+
129+
if startswith(line, "version:")
130+
# ...
131+
elseif startswith(line, "extended:")
132+
# ...
133+
elseif startswith(line, "clamp:")
134+
# ...
135+
elseif startswith(line, "precision:")
136+
directives["precision"] = _precision(line)
137+
elseif startswith(line, "rounding:")
138+
directives["rounding"] = _rounding(line)
139+
elseif startswith(line, "maxexponent:")
140+
directives["Emax"] = _maxexponent(line)
141+
elseif startswith(line, "minexponent:")
142+
directives["Emin"] = _minexponent(line)
143+
else
144+
test = _test(line)
145+
any(isspecial, test.operands) && continue
146+
print_test(io, test, directives)
147+
end
114148
end
115149
end
116150

@@ -120,13 +154,13 @@ function (@main)(args=ARGS)
120154
open(output_path, "w") do io
121155
println(io, """
122156
using Decimals
157+
using ScopedValues
123158
using Test
159+
using Decimals: @with_context
124160
125161
@testset \"$name\" begin""")
126162

127-
for line in eachline(dectest_path)
128-
translate(io, line)
129-
end
163+
translate(io, dectest_path)
130164

131165
println(io, "end")
132166
end

src/Decimals.jl

+1
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ struct Decimal <: AbstractFloat
2121
end
2222

2323
include("bigint.jl")
24+
include("context.jl")
2425

2526
# Convert between Decimal objects, numbers, and strings
2627
include("decimal.jl")

src/arithmetic.jl

+2-4
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,8 @@ Base.promote_rule(::Type{Decimal}, ::Type{<:Real}) = Decimal
44
Base.promote_rule(::Type{BigFloat}, ::Type{Decimal}) = Decimal
55
Base.promote_rule(::Type{BigInt}, ::Type{Decimal}) = Decimal
66

7-
const BigTen = BigInt(10)
7+
Base.:(+)(x::Decimal) = fix(x)
8+
Base.:(-)(x::Decimal) = fix(Decimal(!x.s, x.c, x.q))
89

910
# Addition
1011
# To add, convert both decimals to the same exponent.
@@ -24,9 +25,6 @@ function Base.:(+)(x::Decimal, y::Decimal)
2425
return normalize(Decimal(s, abs(c), y.q))
2526
end
2627

27-
# Negation
28-
Base.:(-)(x::Decimal) = Decimal(!x.s, x.c, x.q)
29-
3028
# Subtraction
3129
Base.:(-)(x::Decimal, y::Decimal) = +(x, -y)
3230

src/bigint.jl

+2
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ else
44
const libgmp = Base.GMP.libgmp
55
end
66

7+
const BigTen = BigInt(10)
8+
79
function isdivisible(x::BigInt, n::Int)
810
r = ccall((:__gmpz_divisible_ui_p, libgmp), Cint,
911
(Base.GMP.MPZ.mpz_t, Culong), x, n)

0 commit comments

Comments
 (0)