Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Min/max/if optimizations based on the insights from rival-analyze #87

Merged
merged 31 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 29 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ff3b2ef
some hints template
AYadrov Dec 14, 2024
9cbf9e4
hint seems to work, needs to be tested
AYadrov Dec 17, 2024
33f5566
tests are done, one weird syntax issue exists
AYadrov Dec 18, 2024
0c174db
debugging sesh, turned out that we care about error flags when doing …
AYadrov Dec 18, 2024
007ce20
Update tests.yml, added tests from eval/*.rkt
AYadrov Dec 18, 2024
441caa7
rival-machine-run looksmaxxing
AYadrov Dec 18, 2024
c07cd6e
Merge branch 'min-max-optimizations' of github.com:herbie-fp/rival in…
AYadrov Dec 18, 2024
7037ccf
some comments
AYadrov Dec 18, 2024
fd1275e
contract change for rival-analyze
AYadrov Dec 20, 2024
5ad41cd
no backward pass when hint says so
AYadrov Dec 23, 2024
b26b236
removed skipping of backward pass due to some unknown yet bugs + conv…
AYadrov Jan 8, 2025
1445846
the boog with backward pass + hint has been found
AYadrov Jan 8, 2025
936c72f
contract change for rival-analyze including completeness flag
AYadrov Jan 8, 2025
5fe3f68
restoring rival-profile
AYadrov Jan 8, 2025
4b2080d
quite obvious bug
AYadrov Jan 9, 2025
757f9be
change of contract + return value for rival-analyze (temporarily
AYadrov Jan 9, 2025
145d410
oops, unit tests fix
AYadrov Jan 9, 2025
646f266
hint is added as a field of rival-machine
AYadrov Jan 9, 2025
af8950d
added hint for assert function
AYadrov Jan 9, 2025
9a1db61
emperical evaluation of skipped instructions for unit tests in main.rkt
AYadrov Jan 9, 2025
9976e64
unit tests update for rivals hint to verify automatically that the in…
AYadrov Jan 9, 2025
b2af4c8
less allocation in run.rkt
AYadrov Jan 9, 2025
3a8b855
introduced some threshold value for unit tests that rival should pass
AYadrov Jan 10, 2025
93ef101
Suggested update at test.yml
AYadrov Jan 19, 2025
d389cdd
PR suggestions and fixes
AYadrov Jan 20, 2025
da3942e
Merge branch 'min-max-optimizations' of github.com:herbie-fp/rival in…
AYadrov Jan 20, 2025
8d2a955
added hint support for or, and, not operations
AYadrov Jan 20, 2025
f2dd254
missed converged flag
AYadrov Jan 20, 2025
4a7bb19
a bug that was bothering me. Wrong length of some vectors
AYadrov Jan 20, 2025
ec2b709
moved tests to a separate file
AYadrov Jan 20, 2025
0fd937c
made tests silent
AYadrov Jan 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,4 +21,4 @@ jobs:
run: raco fmt -i **/*.rkt
- name: "Make sure files are correctly formatted with raco fmt"
run: git diff --exit-code
- run: raco test *.rkt
- run: raco test .
124 changes: 119 additions & 5 deletions eval/adjust.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,122 @@
racket/list
racket/match)

(provide backward-pass)
(provide backward-pass
make-hint)

; Hint is a vector with len(ivec) elements which
; guides Rival on which instructions should not be executed
; for points from a particular hyperrect of input parameters.
; make-hint is called as a last step of rival-analyze and returns hint as a result
; Values of a hint:
; #f - instruction should not be executed
; #t - instruction should be executed
; integer n - instead of executing, refer to vregs with (list-ref instr n) index
; (the result is known and stored in another register)
; ival - instead of executing, just copy ival as a result of the instruction
(define (make-hint machine)
(define args (rival-machine-arguments machine))
(define ivec (rival-machine-instructions machine))
(define rootvec (rival-machine-outputs machine))
(define vregs (rival-machine-registers machine))

(define (backward-pass machine)
(define varc (vector-length args))
(define vhint (make-vector (vector-length ivec) #f))
(define converged? #t)

; helper function
(define (vhint-set! idx val)
(when (>= idx varc)
(vector-set! vhint (- idx varc) val)))

; roots always should be executed
(for ([root-reg (in-vector rootvec)])
(vhint-set! root-reg #t))
(for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)]
[hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)]
[n (in-range (- (vector-length vhint) 1) -1 -1)]
#:when hint)
(define hint*
(case (object-name (car instr))
[(ival-assert)
(match-define (list _ bool-idx) instr)
(define bool-reg (vector-ref vregs bool-idx))
(match* ((ival-lo bool-reg) (ival-hi bool-reg) (ival-err? bool-reg))
; assert and its children should not be reexecuted if it is true already
[(#t #t #f) (ival-bool #t)]
; assert and its children should not be reexecuted if it is false already
[(#f #f #f) (ival-bool #f)]
[(_ _ _) ; assert and its children should be reexecuted
(vhint-set! bool-idx #t)
(set! converged? #f)
#t])]
AYadrov marked this conversation as resolved.
Show resolved Hide resolved
[(ival-if)
(match-define (list _ cond tru fls) instr)
(define cond-reg (vector-ref vregs cond))
(match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg))
pavpanchekha marked this conversation as resolved.
Show resolved Hide resolved
[(#t #t #f) ; only true path should be executed
(vhint-set! tru #t)
2]
[(#f #f #f) ; only false path should be executed
(vhint-set! fls #t)
3]
[(_ _ _) ; execute both paths and cond as well
(vhint-set! cond #t)
(vhint-set! tru #t)
(vhint-set! fls #t)
(set! converged? #f)
#t])]
[(ival-fmax)
(match-define (list _ arg1 arg2) instr)
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
pavpanchekha marked this conversation as resolved.
Show resolved Hide resolved
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
[(#t #t #f) ; only arg1 should be executed
(vhint-set! arg1 #t)
1]
[(#f #f #f) ; only arg2 should be executed
(vhint-set! arg2 #t)
2]
[(_ _ _) ; both paths should be executed
(vhint-set! arg1 #t)
(vhint-set! arg2 #t)
(set! converged? #f)
#t])]
[(ival-fmin)
(match-define (list _ arg1 arg2) instr)
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
[(#t #t #f) ; only arg2 should be executed
(vhint-set! arg2 #t)
2]
[(#f #f #f) ; only arg1 should be executed
(vhint-set! arg1 #t)
1]
[(_ _ _) ; both paths should be executed
(vhint-set! arg1 #t)
(vhint-set! arg2 #t)
(set! converged? #f)
#t])]
[(ival-< ival-<= ival-> ival->= ival-== ival-!= ival-and ival-or ival-not)
(define cmp (vector-ref vregs (+ varc n)))
(match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp))
; result is known
[(#t #t #f) (ival-bool #t)]
; result is known
[(#f #f #f) (ival-bool #f)]
[(_ _ _) ; all the paths should be executed
(define srcs (rest instr))
(for-each (λ (x) (vhint-set! x #t)) srcs)
(set! converged? #f)
#t])]

[else ; at this point we are given that the current instruction should be executed
(define srcs (rest instr)) ; then, children instructions should be executed as well
(for-each (λ (x) (vhint-set! x #t)) srcs)
#t]))
(vector-set! vhint n hint*))
(values vhint converged?))

(define (backward-pass machine vhint)
; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3
(define args (rival-machine-arguments machine))
(define ivec (rival-machine-instructions machine))
Expand Down Expand Up @@ -50,7 +163,7 @@
(vector-set! vuseful (- arg varc) #t))]))

; Step 2. Precision tuning
(precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful)
(precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful vhint)

; Step 3. Repeating precisions check + Assigning if a operation should be computed again at all
; vrepeats[i] = #t if the node has the same precision as an iteration before and children have #t flag as well
Expand Down Expand Up @@ -90,12 +203,13 @@
; Roughly speaking, the upper precision bound is calculated as:
; vprecs-max[i] = (+ max-prec vstart-precs[i]), where min-prec < (+ max-prec vstart-precs[i]) < max-prec
; max-prec = (car (get-bounds parent))
(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful)
(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful vhint)
(define vprecs-min (make-vector (vector-length ivec) 0))
(for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)]
[useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)]
[n (in-range (- (vector-length vregs) 1) -1 -1)]
#:when useful?)
[hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)]
#:when (and hint useful?))
(define op (car instr))
(define tail-registers (cdr instr))
(define srcs (map (lambda (x) (vector-ref vregs x)) tail-registers))
Expand Down
9 changes: 6 additions & 3 deletions eval/compile.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -270,14 +270,16 @@
([node (in-vector nodes num-vars)])
(fn->ival-fn node)))

(define register-count (+ (length vars) (vector-length instructions)))
(define ivec-length (vector-length instructions))
(define register-count (+ (length vars) ivec-length))
(define registers (make-vector register-count))
(define repeats (make-vector register-count #f)) ; flags whether an op should be evaluated
(define precisions (make-vector register-count)) ; vector that stores working precisions
(define repeats (make-vector ivec-length #f)) ; flags whether an op should be evaluated
(define precisions (make-vector ivec-length)) ; vector that stores working precisions
AYadrov marked this conversation as resolved.
Show resolved Hide resolved
;; vector for adjusting precisions
(define incremental-precisions (setup-vstart-precs instructions (length vars) roots discs))
(define initial-precision
(+ (argmax identity (map discretization-target discs)) (*base-tuning-precision*)))
(define hint (make-vector ivec-length #t))

(rival-machine (list->vector vars)
instructions
Expand All @@ -289,6 +291,7 @@
incremental-precisions
(make-vector (vector-length roots))
initial-precision
hint
0
0
0
Expand Down
1 change: 1 addition & 0 deletions eval/machine.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
incremental-precisions
output-distance
initial-precision
hint
[iteration #:mutable]
[bumps #:mutable]
[profile-ptr #:mutable]
Expand Down
127 changes: 121 additions & 6 deletions eval/main.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

(define ground-truth-require-convergence (make-parameter #t))

(define (rival-machine-full machine inputs)
(define (rival-machine-full machine inputs [vhint (rival-machine-hint machine)])
(set-rival-machine-iteration! machine (*sampling-iteration*))
(rival-machine-adjust machine)
(rival-machine-adjust machine vhint)
(cond
[(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)]
[else
(rival-machine-load machine inputs)
(rival-machine-run machine)
(rival-machine-run machine vhint)
(rival-machine-return machine)]))

(struct exn:rival exn:fail ())
Expand Down Expand Up @@ -62,14 +62,14 @@
(define (ival-real x)
(ival x))

(define (rival-apply machine pt)
(define (rival-apply machine pt [hint (rival-machine-hint machine)])
(define discs (rival-machine-discs machine))
(set-rival-machine-bumps! machine 0)
(let loop ([iter 0])
(define-values (good? done? bad? stuck? fvec)
(parameterize ([*sampling-iteration* iter]
[ground-truth-require-convergence #t])
(rival-machine-full machine (vector-map ival-real pt))))
(rival-machine-full machine (vector-map ival-real pt) hint)))
(cond
[bad? (raise (exn:rival:invalid "Invalid input" (current-continuation-marks) pt))]
[done? fvec]
Expand All @@ -83,4 +83,119 @@
(parameterize ([*sampling-iteration* 0]
[ground-truth-require-convergence #f])
(rival-machine-full machine rect)))
(ival (or bad? stuck?) (not good?)))
(define-values (hint hint-converged?) (make-hint machine))
(list (ival (or bad? stuck?) (not good?)) hint hint-converged?))

(module+ test
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests represent a random testing of the hints implementation.

It samples random hyperrects in a global box of [-100, 100] and points inside these hyperrects.
The hyperrects are analyzed using rival-analyze and obtained hint is executed for points.
The tests verify that execution with hint and without hint produces the same results.
Additionally, the tests check the percentage of instruction executions being skipped by hint.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move the tests to a separate file? Otherwise, sure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not review the tests closely.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sure

(require rackunit
"compile.rkt"
"../utils.rkt"
math/bigfloat)
(define number-of-random-hyperrects 100)
(define number-of-random-pts-per-rect 100)
(define rect (list (bf -100) (bf 100)))
(bf-precision 53)

; Check whether outputs are the same for the hint and without hint executions
(define (rival-check-hint machine hint pt)
(define hint-result
(with-handlers ([exn:rival:invalid? (λ (e) 'invalid)]
[exn:rival:unsamplable? (λ (e) 'unsamplable)])
(rival-apply machine pt hint)))
(define hint-instr-count (vector-length (rival-profile machine 'executions)))

(define no-hint-result
(with-handlers ([exn:rival:invalid? (λ (e) 'invalid)]
[exn:rival:unsamplable? (λ (e) 'unsamplable)])
(rival-apply machine pt)))
(define no-hint-instr-count (vector-length (rival-profile machine 'executions)))

(check-equal? hint-result no-hint-result)
(values no-hint-instr-count hint-instr-count))

; Random sampling hyperrects given a general range as [rect-lo, rect-hi]
(define (sample-hyperrect-within-bounds rect-lo rect-hi varc)
(for/vector ([_ (in-range varc)])
(define xlo-range-length (bf- rect-hi rect-lo))
(define xlo (bf+ (bf* (bfrandom) xlo-range-length) rect-lo))
(define xhi-range-length (bf- rect-hi xlo))
(define xhi (bf+ (bf* (bfrandom) xhi-range-length) xlo))
(check-true (and (bf> rect-hi xhi) (bf> xlo rect-lo) (bf> xhi xlo))
"Hyperrect is out of bounds")
(ival xlo xhi)))

; Sample points with respect to the input hyperrect
(define (sample-pts hyperrect)
(for/vector ([rect (in-vector hyperrect)])
(define range-length (bf- (ival-hi rect) (ival-lo rect)))
(define pt (bf+ (bf* (bfrandom) range-length) (ival-lo rect)))
(check-true (and (bf> pt (ival-lo rect)) (bf< pt (ival-hi rect)))
"Sampled point is out of hyperrect range")
pt))

; Testing hint on an expression for 'number-of-random-hyperrects' hyperrects by
; 'number-of-random-pts-per-rect' points each
(define (hints-random-checks machine rect-lo rect-hi varc)
(define evaluated-instructions 0)
(define number-of-instructions-total
(* number-of-random-hyperrects (vector-length (rival-machine-instructions machine))))

(define hint-cnt 0)
(define no-hint-cnt 0)
(for ([n (in-range number-of-random-hyperrects)])
(define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc))
(match-define (list res hint _) (rival-analyze machine hyperrect))
(set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint)))

(for ([_ (in-range number-of-random-pts-per-rect)])
(define pt (sample-pts hyperrect))
(define-values (no-hint-cnt* hint-cnt*) (rival-check-hint machine hint pt))
(set! hint-cnt (+ hint-cnt hint-cnt*))
(set! no-hint-cnt (+ no-hint-cnt no-hint-cnt*))))
(define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100))
skipped-percentage)

(define (expressions2d-check expressions)
(define discs (list boolean-discretization flonum-discretization))
(define vars '(x y))
(define varc (length vars))
(define machine (rival-compile expressions vars discs))
(define skipped-instr (hints-random-checks machine (first rect) (second rect) varc))
(printf "Percentage of skipped instructions by hint in expr = ~a\n" (round skipped-instr)))

(expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y))))
'(+ (if (> (/ (log x) (log y)) (* (log x) (log y)))
(fmax (* (log x) (log y)) (+ (log x) (log y)))
(fmin (* (log x) (log y)) (+ (log x) (log y))))
(if (> (+ (log x) (log y)) (* (log x) (log y)))
(fmax (/ (log x) (log y)) (- (log x) (log y)))
(fmin (/ (log x) (log y)) (- (log x) (log y)))))))

(expressions2d-check
(list '(TRUE)
'(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2))))
(fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2)))))))

(expressions2d-check (list '(TRUE)
'(if (> (exp x) (+ 10 (log y)))
(if (> (fmax (* x y) (+ x y)) 4)
(cos (fmax x y))
(cos (fmin x y)))
(if (< (pow 2 x) (- (exp x) 10))
(* PI x)
(fmax x (- (cos y) (+ 10 (log y))))))))

; Test checks hint on assert where an error can be observed
(expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y))))
'(+ (cos x) (cos y))))

; Test checks hint on fmax where an error can be observed
(expressions2d-check (list '(TRUE) '(fmax (log x) y)))

; Test checks hints on comparison operators
(expressions2d-check
(list '(and (and (> (log x) y) (or (== (exp x) (exp y)) (> (cos x) (cos y)))) (<= (log y) (log x)))
'(if (or (or (< (log x) y) (and (!= (exp x) (exp y)) (< (cos x) (cos y))))
(>= (log y) (log x)))
x
y))))
19 changes: 13 additions & 6 deletions eval/run.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
(flvector-set! profile-time profile-ptr time)
(set-rival-machine-profile-ptr! machine (add1 profile-ptr))))

(define (rival-machine-run machine)
(define (rival-machine-run machine vhint)
(define ivec (rival-machine-instructions machine))
(define varc (vector-length (rival-machine-arguments machine)))
(define precisions (rival-machine-precisions machine))
Expand All @@ -44,10 +44,17 @@
[n (in-naturals varc)]
[precision (in-vector precisions)]
[repeat (in-vector repeats)]
#:unless (and (not first-iter?) repeat))
[hint (in-vector vhint)]
#:unless (or (not hint) (and (not first-iter?) repeat)))
(define start (current-inexact-milliseconds))
(parameterize ([bf-precision precision])
(vector-set! vregs n (apply-instruction instr vregs)))
(define out
(match hint
[#t
(parameterize ([bf-precision precision])
(apply-instruction instr vregs))]
[(? integer? _) (vector-ref vregs (list-ref instr hint))]
[(? ival? _) hint]))
(vector-set! vregs n out)
(define name (object-name (car instr)))
(define time (- (current-inexact-milliseconds) start))
(rival-machine-record machine name n precision time)))
Expand Down Expand Up @@ -94,10 +101,10 @@
lo))
(values good? (and good? done?) bad? stuck? fvec))

(define (rival-machine-adjust machine)
(define (rival-machine-adjust machine vhint)
(define iter (rival-machine-iteration machine))
(let ([start (current-inexact-milliseconds)])
(if (zero? iter)
(vector-fill! (rival-machine-precisions machine) (rival-machine-initial-precision machine))
(backward-pass machine))
(backward-pass machine vhint))
(rival-machine-record machine 'adjust -1 (* iter 1000) (- (current-inexact-milliseconds) start))))
Loading
Loading