Skip to content

Commit

Permalink
Merge branch 'main' into trig-tuning-fix
Browse files Browse the repository at this point in the history
  • Loading branch information
pavpanchekha committed Mar 12, 2024
2 parents fa89ac2 + a45489c commit 05866cf
Show file tree
Hide file tree
Showing 2 changed files with 196 additions and 151 deletions.
189 changes: 94 additions & 95 deletions main.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
(require (for-syntax racket/base))
(module+ test (require rackunit))

(define *rival-precision* (make-parameter 8192))
(define *rival-precision* (make-parameter (expt 2 20)))

(define-match-expander ival-expander
(λ (stx)
Expand Down Expand Up @@ -131,8 +131,9 @@
(define (mk-big-ival x y)
(cond
[(and (bigfloat? x) (bigfloat? y))
(define err? (or (bfnan? x) (bfnan? y)))
(define fix? (bf=? x y))
(define err? (or (bfnan? x) (bfnan? y)
(and (bfinfinite? x) fix?)))
(ival (endpoint x fix?) (endpoint y fix?) err? err?)]
[(and (boolean? x) (boolean? y))
(define fix? (equal? x y))
Expand Down Expand Up @@ -190,43 +191,28 @@
(define (endpoint-min2 e1 e2)
(match-define (endpoint x x!) e1)
(match-define (endpoint y y!) e2)
(cond
[(bflt? x y)
e1]
[(bflt? y x)
e2]
[else
(endpoint (bfmin2 x y) (or x! y!))]))
(define out (bfmin2 x y))
(endpoint out (or (and (bf=? out x) x!) (and (bf=? out y) y!))))

(define (endpoint-max2 e1 e2)
(match-define (endpoint x x!) e1)
(match-define (endpoint y y!) e2)
(cond
[(bfgt? x y)
e1]
[(bfgt? y x)
e2]
[else
(endpoint (bfmax2 x y) (or x! y!))]))
(define out (bfmax2 x y))
(endpoint out (or (and (bf=? out x) x!) (and (bf=? out y) y!))))

(define (ival-union x y)
(cond
[(ival-err x) (struct-copy ival y [err? #t])]
[(ival-err y) (struct-copy ival x [err? #t])]
[(bigfloat? (ival-lo-val x))
(ival (endpoint-min2 (ival-lo x) (ival-lo y))
(endpoint-max2 (ival-hi x) (ival-hi y))
(ival (rnd 'down endpoint-min2 (ival-lo x) (ival-lo y))
(rnd 'up endpoint-max2 (ival-hi x) (ival-hi y))
(or (ival-err? x) (ival-err? y)) (and (ival-err x) (ival-err y)))]
[(boolean? (ival-lo-val x))
(ival (epfn and-fn (ival-lo x) (ival-lo y))
(epfn or-fn (ival-hi x) (ival-hi y))
(or (ival-err? x) (ival-err? y)) (and (ival-err x) (ival-err y)))]))

(define (propagate-err c x)
(ival (ival-lo x) (ival-hi x)
(or (ival-err? c) (ival-err? x))
(or (ival-err c) (ival-err x))))

;; This function computes and propagates the immovable? flag for endpoints
(define (epfn op . args)
(define args-bf (map endpoint-val args))
Expand All @@ -249,13 +235,6 @@
(values out exact?))
;; End hairy code

(define (ival-neg x)
;; No rounding, negation is exact
(ival
(epfn bfneg (ival-hi x))
(epfn bfneg (ival-lo x))
(ival-err? x) (ival-err x)))

;; Endpoint computation for both `add`, `sub`, and `hypot` (which has an add inside)
(define (eplinear bffn a-endpoint b-endpoint)
(match-define (endpoint a a!) a-endpoint)
Expand Down Expand Up @@ -373,11 +352,15 @@

(define ((clamp lo hi) x)
(match-define (ival (endpoint xlo xlo!) (endpoint xhi xhi!) xerr? xerr) x)
(define err? (or xerr? (bflt? xlo lo) (bfgt? xhi hi)))
(define err (or (or xerr (bflt? xhi lo) (bfgt? xlo hi))))

(ival (endpoint (if (bflt? xlo lo) lo xlo) xlo!)
(endpoint (if (bfgt? xhi hi) hi xhi) xhi!)
(or xerr? (bflt? xlo lo) (bfgt? xhi hi))
(or xerr (bflt? xhi lo) (bfgt? xlo hi))))
(if (and (bfzero? lo) (bfzero? xhi))
(ival (endpoint 0.bf xlo!) (endpoint 0.bf xhi!) err? err)
(ival (endpoint (if (bflt? xlo lo) lo xlo) xlo!)
(endpoint (if (bfgt? xhi hi) hi xhi) xhi!)
err?
err)))

(define ((clamp-strict lo hi) x)
(match-define (ival (endpoint xlo xlo!) (endpoint xhi xhi!) xerr? xerr) x)
Expand All @@ -393,21 +376,37 @@
(endpoint yhi (or yhi! (bfgte? xlo hi) (and (bfgte? xhi hi) xhi!)))
xerr? xerr))

(define* ival-rint (monotonic bfrint))
(define* ival-round (monotonic bfround))
(define* ival-ceil (monotonic bfceiling))
(define* ival-floor (monotonic bffloor))
(define* ival-trunc (monotonic bftruncate))
(define* ival-neg (comonotonic bfneg))

(define (ival-fabs x)
;; This function fixes a bug in MPFR where mixed-precision
;; rint/round/ceil/floor/trunc operations are rounded in the input
;; precision, not the output precision, so (rnd 'down bfround xxx) can
;; return +inf.bf
(define (fix-infinite-pt-interval x)
(match-define (ival (endpoint xlo xlo!) (endpoint xhi xhi!) xerr? xerr) x)
(cond
[(bfgt? xlo 0.bf) x]
[(bflt? xhi 0.bf) (ival-neg x)]
[else ; interval stradles 0
(ival (endpoint 0.bf (and xlo! xhi!))
(endpoint-max2 (endpoint (bfneg xlo) xlo!) (ival-hi x))
(ival-err? x) (ival-err x))]))
[(and (bfnegative? xhi) (bfinfinite? xhi))
(ival (endpoint xlo xlo!) (endpoint (bfstep xhi 1) #f) xerr? xerr)]
[(and (bfpositive? xlo) (bfinfinite? xlo))
(ival (endpoint (bfstep xlo -1) #f) (endpoint xhi xhi!) xerr? xerr)]
[else
x]))

(define* ival-rint (compose fix-infinite-pt-interval (monotonic bfrint)))
(define* ival-round (compose fix-infinite-pt-interval (monotonic bfround)))
(define* ival-ceil (compose fix-infinite-pt-interval (monotonic bfceiling)))
(define* ival-floor (compose fix-infinite-pt-interval (monotonic bffloor)))
(define* ival-trunc (compose fix-infinite-pt-interval (monotonic bftruncate)))

(define (ival-fabs x)
(match (classify-ival x)
[-1 ((comonotonic bfabs) x)]
[1 ((monotonic bfabs) x)]
[0
(match-define (ival (endpoint xlo xlo!) (endpoint xhi xhi!) xerr? xerr) x)
(ival (endpoint 0.bf (and xlo! xhi!))
(rnd 'up endpoint-max2 (epfn bfabs (ival-lo x)) (ival-hi x))
(ival-err? x) (ival-err x))]))

;; Since MPFR has a cap on exponents, no value can be more than twice MAX_VAL
(define exp-overflow-threshold (bfadd (bflog (bfprev +inf.bf)) 1.bf))
Expand Down Expand Up @@ -488,33 +487,19 @@

(define (ival-pow-neg x y)
;; Assumes x is negative
(define err? (or (ival-err? x) (ival-err? y) (bflt? (ival-lo-val y) (ival-hi-val y))))
(define err (or (ival-err x) (ival-err y)))
(define xpos (ival-fabs x))
(define a (bfceiling (ival-lo-val y)))
(define b (bffloor (ival-hi-val y)))
(cond
[(bflt? b a) ; y does not contain an integer
; But it still contains many odd fractions
; It is sort-of unclear what we actually do here:
; (-1)^(1/3) = -1 makes sense, but what about
; (-1)^(2/3) = 1? Or (-1)^(2/6)?
; We go with an expansive definition, hoping it will never matter.
(define pos-pow (ival-pow-pos xpos y))
(ival-then ival-maybe (ival-union (ival-neg pos-pow) pos-pow))]
[(bf=? a b)
(define aep (endpoint a (and (endpoint-immovable? (ival-lo y)) (endpoint-immovable? (ival-hi y)))))
(if (bfodd? a)
(ival-neg (ival-pow-pos xpos (ival aep aep err? err)))
(ival-pow-pos xpos (ival aep aep err? err)))]
[else
;; TODO: the movability here is pretty subtle
(define odds (ival (endpoint (if (bfodd? a) a (bfadd a 1.bf)) #f)
(endpoint (if (bfodd? b) b (bfsub b 1.bf)) #f) err? err))
(define evens (ival (endpoint (if (bfodd? a) (bfadd a 1.bf) a) #f)
(endpoint (if (bfodd? b) (bfsub b 1.bf) b) #f) err? err))
(ival-union (ival-pow-pos xpos evens)
(ival-neg (ival-pow-pos xpos odds)))]))
(if (bf=? (ival-lo-val y) (ival-hi-val y))
(if (bfinteger? (ival-lo-val y))
; If y is an integer point interval, there's no error,
; because it's always valid to raise to an integer power.
(if (bfodd? (ival-lo-val y))
(ival-neg (ival-pow-pos (ival-fabs x) y)) ; Use fabs in case of [x, 0]
(ival-pow-pos (ival-fabs x) y))
; If y is non-integer point interval, it must be an even
; fraction (because all bigfloats are) so we always error
ival-illegal)
; Moreover, if we have (-x)^y, that's basically x^y U -(x^y).
(let ([pospow (ival-pow-pos (ival-fabs x) y)])
(ival-then (ival-assert ival-maybe) (ival-union (ival-neg pospow) pospow)))))

(define* ival-pow2 (compose (monotonic (lambda (x) (bfmul x x))) ival-fabs))

Expand Down Expand Up @@ -736,12 +721,14 @@
(define d (rnd 'up bftruncate (bfdiv (ival-hi-val x) (ival-lo-val y))))
(cond
[(bf=? c d) ; No intersection along `x.hi` either; use top-left/bottom-right point
(ival (endpoint (rnd 'down bfsub (ival-lo-val x) (rnd 'up bfmul* c (ival-hi-val y))) #f)
(endpoint (rnd 'up bfsub (ival-hi-val x) (rnd 'down bfmul* c (ival-lo-val y))) #f)
(define lo (rnd 'down bfsub (ival-lo-val x) (rnd 'up bfmul* c (ival-hi-val y))))
(define hi (rnd 'up bfsub (ival-hi-val x) (rnd 'down bfmul* c (ival-lo-val y))))
(ival (endpoint lo #f)
(endpoint hi #f)
err? err)]
[else
(ival (endpoint 0.bf #f)
(endpoint (bfmax2 (rnd 'up bfdiv (ival-hi-val x) (bfadd c 1.bf)) 0.bf) #f) err? err)])]
(endpoint (rnd 'up bfmax2 (bfdiv (ival-hi-val x) (bfadd c 1.bf)) 0.bf) #f) err? err)])]
[else
(ival (endpoint 0.bf #f) (endpoint (ival-hi-val y) #f) err? err)]))

Expand Down Expand Up @@ -772,20 +759,21 @@
(cond
[(bf=? c d) ; No intersection along `x.hi` either; use top-left/bottom-right point
(define y* (bfdiv (ival-hi-val y) 2.bf))
(ival (endpoint (bfmax2 (rnd 'down bfsub (ival-lo-val x) (rnd 'up bfmul c (ival-hi-val y)))
(ival (endpoint (rnd 'down bfmax2 (bfsub (ival-lo-val x) (rnd 'up bfmul c (ival-hi-val y)))
(bfneg y*)) #f)
(endpoint (bfmin2 (rnd 'up bfsub (ival-hi-val x) (rnd 'down bfmul c (ival-lo-val y)))
(endpoint (rnd 'up bfmin2 (bfsub (ival-hi-val x) (rnd 'down bfmul c (ival-lo-val y)))
y*) #f)
err? err)]
[else
;; NOPE! need to subtract half.bf one way, add it another!
(define y*-hi (bfdiv (rnd 'down bfdiv (ival-hi-val x) (bfadd c half.bf)) 2.bf))
(define y*-lo (bfmax2 (rnd 'down bfsub (ival-lo-val x) (rnd 'up bfmul c (ival-hi-val y)))
(bfneg (bfdiv (ival-hi-val y) 2.bf))))
(ival (endpoint (bfmin2 y*-lo (bfneg y*-hi)) #f) (endpoint y*-hi #f) err? err)])]
(define y*-hi (rnd 'up bfdiv (bfdiv (ival-hi-val x) (bfadd c half.bf)) 2.bf))
(define y*-lo (rnd 'down bfmax2
(bfsub (ival-lo-val x) (rnd 'up bfmul c (ival-hi-val y)))
(bfneg (bfdiv (ival-hi-val y) 2.bf))))
(ival (endpoint (rnd 'down bfmin2 y*-lo (bfneg y*-hi)) #f) (endpoint y*-hi #f) err? err)])]
[else
(define y* (bfdiv (ival-hi-val y) 2.bf))
(ival (endpoint (bfneg y*) #f) (endpoint y* #f) err? err)]))
(define y* (rnd 'up bfdiv (ival-hi-val y) 2.bf))
(ival (endpoint (rnd 'down bfneg y*) #f) (endpoint y* #f) err? err)]))

;; Seems unnecessary
(define (ival-remainder x y)
Expand Down Expand Up @@ -913,6 +901,10 @@
;; This case only happens if xnegr = #f meaning lo = rnd[up](lo + 1) meaning lo = -inf
(mk-big-ival -inf.bf +inf.bf)))

(define (exact-bffloor x)
(parameterize ([bf-precision (bigfloat-precision x)])
(bffloor x)))

(define (ival-tgamma x)
(define logy (ival-lgamma x))
(unless logy
Expand All @@ -923,13 +915,13 @@
(cond
[(bfgte? lo 0.bf)
absy]
[(not (bf=? (bffloor lo) (bffloor hi)))
[(not (bf=? (exact-bffloor lo) (exact-bffloor hi)))
(ival (endpoint -inf.bf (ival-lo-fixed? x))
(endpoint +inf.bf (ival-hi-fixed? x))
#t (ival-err x))]
[(and (not (bfpositive? lo)) (bf=? lo hi) (bfinteger? lo))
ival-illegal]
[(bfeven? (bffloor lo))
[(bfeven? (exact-bffloor lo))
absy]
[else
(ival-neg absy)]))
Expand Down Expand Up @@ -1012,29 +1004,36 @@
(or (ival-err? a) (ormap ival-err? as))
(or (ival-err a) (ormap ival-err as))))

(define* ival-identity (monotonic bfcopy))

(define (ival-if c x y)
(cond
[(ival-lo-val c) (propagate-err c x)]
[(not (ival-hi-val c)) (propagate-err c y)]
[else (propagate-err c (ival-union x y))]))
[(ival-lo-val c) (ival-then c (ival-identity x))]
[(not (ival-hi-val c)) (ival-then c (ival-identity y))]
[else (ival-then c (ival-union x y))]))

(define (ival-fmin x y)
(ival (endpoint-min2 (ival-lo x) (ival-lo y)) (endpoint-min2 (ival-hi x) (ival-hi y))
(ival (rnd 'down endpoint-min2 (ival-lo x) (ival-lo y))
(rnd 'up endpoint-min2 (ival-hi x) (ival-hi y))
(or (ival-err? x) (ival-err? y)) (or (ival-err x) (ival-err y))))

(define (ival-fmax x y)
(ival (endpoint-max2 (ival-lo x) (ival-lo y)) (endpoint-max2 (ival-hi x) (ival-hi y))
(ival (rnd 'down endpoint-max2 (ival-lo x) (ival-lo y))
(rnd 'up endpoint-max2 (ival-hi x) (ival-hi y))
(or (ival-err? x) (ival-err? y)) (or (ival-err x) (ival-err y))))

(define (ival-copysign x y)
(match-define (ival xlo xhi xerr? xerr) (ival-fabs x))
(define can-neg (= (bigfloat-signbit (ival-lo-val y)) 1))
(define can-pos (= (bigfloat-signbit (ival-hi-val y)) 0))
(define can-zero
(or (bfzero? (ival-lo-val y)) (bfzero? (ival-hi-val y))))
;; 0 is both positive and negative because we don't handle signed zero well
(define can-neg (or (= (bigfloat-signbit (ival-lo-val y)) 1) can-zero))
(define can-pos (or (= (bigfloat-signbit (ival-hi-val y)) 0) can-zero))
(define err? (or (ival-err? y) xerr?))
(define err (or (ival-err y) xerr))
(match* (can-neg can-pos)
[(#t #t) (ival (epfn bfneg xhi) xhi err? err)]
[(#t #f) (ival (epfn bfneg xhi) (epfn bfneg xlo) err? err)]
[(#t #t) (ival (rnd 'down epfn bfneg xhi) (rnd 'up epfn bfcopy xhi) err? err)]
[(#t #f) (ival (rnd 'down epfn bfneg xhi) (rnd 'up epfn bfneg xlo) err? err)]
[(#f #t) (ival xlo xhi err? err)]
[(#f #f)
(unless (ival-err y)
Expand Down
Loading

0 comments on commit 05866cf

Please sign in to comment.