Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Min/max/if optimizations based on the insights from rival-analyze #87

Merged
merged 31 commits into from
Jan 20, 2025
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ff3b2ef
some hints template
AYadrov Dec 14, 2024
9cbf9e4
hint seems to work, needs to be tested
AYadrov Dec 17, 2024
33f5566
tests are done, one weird syntax issue exists
AYadrov Dec 18, 2024
0c174db
debugging sesh, turned out that we care about error flags when doing …
AYadrov Dec 18, 2024
007ce20
Update tests.yml, added tests from eval/*.rkt
AYadrov Dec 18, 2024
441caa7
rival-machine-run looksmaxxing
AYadrov Dec 18, 2024
c07cd6e
Merge branch 'min-max-optimizations' of github.com:herbie-fp/rival in…
AYadrov Dec 18, 2024
7037ccf
some comments
AYadrov Dec 18, 2024
fd1275e
contract change for rival-analyze
AYadrov Dec 20, 2024
5ad41cd
no backward pass when hint says so
AYadrov Dec 23, 2024
b26b236
removed skipping of backward pass due to some unknown yet bugs + conv…
AYadrov Jan 8, 2025
1445846
the boog with backward pass + hint has been found
AYadrov Jan 8, 2025
936c72f
contract change for rival-analyze including completeness flag
AYadrov Jan 8, 2025
5fe3f68
restoring rival-profile
AYadrov Jan 8, 2025
4b2080d
quite obvious bug
AYadrov Jan 9, 2025
757f9be
change of contract + return value for rival-analyze (temporarily
AYadrov Jan 9, 2025
145d410
oops, unit tests fix
AYadrov Jan 9, 2025
646f266
hint is added as a field of rival-machine
AYadrov Jan 9, 2025
af8950d
added hint for assert function
AYadrov Jan 9, 2025
9a1db61
emperical evaluation of skipped instructions for unit tests in main.rkt
AYadrov Jan 9, 2025
9976e64
unit tests update for rivals hint to verify automatically that the in…
AYadrov Jan 9, 2025
b2af4c8
less allocation in run.rkt
AYadrov Jan 9, 2025
3a8b855
introduced some threshold value for unit tests that rival should pass
AYadrov Jan 10, 2025
93ef101
Suggested update at test.yml
AYadrov Jan 19, 2025
d389cdd
PR suggestions and fixes
AYadrov Jan 20, 2025
da3942e
Merge branch 'min-max-optimizations' of github.com:herbie-fp/rival in…
AYadrov Jan 20, 2025
8d2a955
added hint support for or, and, not operations
AYadrov Jan 20, 2025
f2dd254
missed converged flag
AYadrov Jan 20, 2025
4a7bb19
a bug that was bothering me. Wrong length of some vectors
AYadrov Jan 20, 2025
ec2b709
moved tests to a separate file
AYadrov Jan 20, 2025
0fd937c
made tests silent
AYadrov Jan 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,3 +22,4 @@ jobs:
- name: "Make sure files are correctly formatted with raco fmt"
run: git diff --exit-code
- run: raco test *.rkt
- run: raco test eval/*.rkt
AYadrov marked this conversation as resolved.
Show resolved Hide resolved
124 changes: 118 additions & 6 deletions eval/adjust.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,120 @@
racket/list
racket/match)

(provide backward-pass)
(provide backward-pass
make-hint)

; Hint is a vector with len(ivec) elements which
; guides Rival on which instructions should not be executed
; for points from a particular hyperrect of input parameters.
; make-hint is called as a last step of rival-analyze and returns hint as a result
; Values of a hint:
; #f - instruction should not be executed
; #t - instruction should be executed
; integer n - instead of executing, refer to vregs with (list-ref instr n) index
; (the result is known and stored in another register)
; ival - instead of executing, just copy ival as a result of the instruction
(define (make-hint machine)
(define args (rival-machine-arguments machine))
(define ivec (rival-machine-instructions machine))
(define rootvec (rival-machine-outputs machine))
(define vregs (rival-machine-registers machine))

(define (backward-pass machine)
(define varc (vector-length args))
(define vhint (make-vector (vector-length ivec) #f))
(define converged? #t)

(define (vhint-set! idx val)
(when (>= idx varc)
(vector-set! vhint (- idx varc) val)))
(define (vhint-ref idx)
(if (>= idx varc)
(vector-ref vhint (- idx varc))
#f))

(for ([root-reg (in-vector rootvec)])
(vhint-set! root-reg #t))
(for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)]
[hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)]
[n (in-range (- (vector-length vhint) 1) -1 -1)]
#:when hint)
(define hint*
(case (object-name (car instr))
[(ival-assert)
(match-define (list _ bool-idx) instr)
(define bool-reg (vector-ref vregs bool-idx))
(match* ((ival-lo bool-reg) (ival-hi bool-reg) (ival-err? bool-reg))
[(#t #t #f) ; assert and its children should not be reexecuted if it is true already
(vhint-set! bool-idx (or #f (vhint-ref bool-idx)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a no-op, right? (or #f x) is just x.

Copy link
Contributor Author

@AYadrov AYadrov Jan 19, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My thinking was like this:
since we are using batches, we may have duplicates and this why I made or operation.
The thing is that make-hint function first will create a mask for all the instructions,
and mark #t nodes that are roots (they always need to be executed to write some result in vreg vector).
Then it starts a traverse in the reverse order.
This line of code catches the following case:
Imagine that we have a list of expressions as:
(list '(< (cos x) 0) '(assert (< (cos x) 0)))
The algorithm first will mark roots < and assert as #t - should be executed.
And then, the algorithm will start to analyze instruction by instruction.
It will look at assert ... and figure out that assert, for example, is always true - which means that (< (cos x) 0) basically should not be executed.
And here a mistake can happen, (< (cos x) 0) actually should be executed because it is another root and some result should be written into vregs, if we mark (< (cos x) 0) as #f - there will be no output in vregs.
Therefore, the algorithm before marking anything as #f checks whether it previously was marked as #t and if not - it is safe to mark it as #f, otherwise, the instruction should be executed

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah I get all that. My proposal is delete this line of code entirely. If no one else wants this line of code, its hint is already set to #f (that's the default). If someone else wants it, good, we don't disturb it.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right right, here this line of code is useless

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess the same operation can be removed everywhere

(ival-bool #t)]
[(#f #f #f) ; assert and its children should not be reexecuted if it is false already
(vhint-set! bool-idx (or #f (vhint-ref bool-idx)))
(ival-bool #f)]
[(_ _ _) ; assert and its children should be reexecuted
(vhint-set! bool-idx #t)
#t])]
AYadrov marked this conversation as resolved.
Show resolved Hide resolved
[(ival-if)
(match-define (list _ cond tru fls) instr)
(define cond-reg (vector-ref vregs cond))
(match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg))
pavpanchekha marked this conversation as resolved.
Show resolved Hide resolved
[(#t #t #f) ; only true path should be executed
(vhint-set! cond (or #f (vhint-ref cond)))
(vhint-set! tru #t)
(vhint-set! fls (or #f (vhint-ref fls)))
2]
[(#f #f #f) ; only false path should be executed
(vhint-set! cond (or #f (vhint-ref cond)))
(vhint-set! tru (or #f (vhint-ref tru)))
(vhint-set! fls #t)
3]
[(_ _ _) ; execute both paths and cond as well
(vhint-set! cond #t)
(vhint-set! tru #t)
(vhint-set! fls #t)
(set! converged? #f)
#t])]
[(ival-fmax)
(match-define (list _ arg1 arg2) instr)
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
pavpanchekha marked this conversation as resolved.
Show resolved Hide resolved
(match* ((ival-lo cmp) (ival-hi cmp))
AYadrov marked this conversation as resolved.
Show resolved Hide resolved
[(#t #t) ; only arg1 should be executed
(vhint-set! arg2 (or #f (vhint-ref arg2)))
(vhint-set! arg1 #t)
1]
[(#f #f) ; only arg2 should be executed
(vhint-set! arg1 (or #f (vhint-ref arg1)))
(vhint-set! arg2 #t)
2]
[(#f #t) ; both paths should be executed
(vhint-set! arg1 #t)
(vhint-set! arg2 #t)
(set! converged? #f)
#t])]
[(ival-fmin)
(match-define (list _ arg1 arg2) instr)
(define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2)))
(match* ((ival-lo cmp) (ival-hi cmp))
[(#t #t) ; only arg2 should be executed
(vhint-set! arg1 (or #f (vhint-ref arg1)))
(vhint-set! arg2 #t)
2]
[(#f #f) ; only arg1 should be executed
(vhint-set! arg2 (or #f (vhint-ref arg2)))
(vhint-set! arg1 #t)
1]
[(#f #t) ; both paths should be executed
(vhint-set! arg1 #t)
(vhint-set! arg2 #t)
(set! converged? #f)
#t])]
[else ; at this point we are given that the current instruction should be executed
(define srcs (rest instr)) ; then, children instructions should be executed as well
(map (λ (x) (vhint-set! x #t)) srcs)
AYadrov marked this conversation as resolved.
Show resolved Hide resolved
#t]))
(vector-set! vhint n hint*))
(values vhint converged?))

(define (backward-pass machine vhint)
; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3
(define args (rival-machine-arguments machine))
(define ivec (rival-machine-instructions machine))
Expand Down Expand Up @@ -50,7 +161,7 @@
(vector-set! vuseful (- arg varc) #t))]))

; Step 2. Precision tuning
(precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful)
(precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful vhint)

; Step 3. Repeating precisions check + Assigning if a operation should be computed again at all
; vrepeats[i] = #t if the node has the same precision as an iteration before and children have #t flag as well
Expand Down Expand Up @@ -90,12 +201,13 @@
; Roughly speaking, the upper precision bound is calculated as:
; vprecs-max[i] = (+ max-prec vstart-precs[i]), where min-prec < (+ max-prec vstart-precs[i]) < max-prec
; max-prec = (car (get-bounds parent))
(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful)
(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful vhint)
(define vprecs-min (make-vector (vector-length ivec) 0))
(for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)]
[useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)]
[n (in-range (- (vector-length vregs) 1) -1 -1)]
#:when useful?)
[hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)]
#:when (and hint useful?))
(define op (car instr))
(define tail-registers (cdr instr))
(define srcs (map (lambda (x) (vector-ref vregs x)) tail-registers))
Expand Down Expand Up @@ -134,4 +246,4 @@
; Lower precision bound propogation
(vector-set! vprecs-min
(- x varc)
(max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound)))))))
(max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound)))))))
AYadrov marked this conversation as resolved.
Show resolved Hide resolved
2 changes: 2 additions & 0 deletions eval/compile.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,7 @@
(define incremental-precisions (setup-vstart-precs instructions (length vars) roots discs))
(define initial-precision
(+ (argmax identity (map discretization-target discs)) (*base-tuning-precision*)))
(define hint (make-vector (vector-length instructions) #t))

(rival-machine (list->vector vars)
instructions
Expand All @@ -289,6 +290,7 @@
incremental-precisions
(make-vector (vector-length roots))
initial-precision
hint
0
0
0
Expand Down
1 change: 1 addition & 0 deletions eval/machine.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
incremental-precisions
output-distance
initial-precision
hint
[iteration #:mutable]
[bumps #:mutable]
[profile-ptr #:mutable]
Expand Down
129 changes: 123 additions & 6 deletions eval/main.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

(define ground-truth-require-convergence (make-parameter #t))

(define (rival-machine-full machine inputs)
(define (rival-machine-full machine inputs [vhint (rival-machine-hint machine)])
(set-rival-machine-iteration! machine (*sampling-iteration*))
(rival-machine-adjust machine)
(rival-machine-adjust machine vhint)
(cond
[(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)]
[else
(rival-machine-load machine inputs)
(rival-machine-run machine)
(rival-machine-run machine vhint)
(rival-machine-return machine)]))

(struct exn:rival exn:fail ())
Expand Down Expand Up @@ -62,14 +62,14 @@
(define (ival-real x)
(ival x))

(define (rival-apply machine pt)
(define (rival-apply machine pt [hint (rival-machine-hint machine)])
(define discs (rival-machine-discs machine))
(set-rival-machine-bumps! machine 0)
(let loop ([iter 0])
(define-values (good? done? bad? stuck? fvec)
(parameterize ([*sampling-iteration* iter]
[ground-truth-require-convergence #t])
(rival-machine-full machine (vector-map ival-real pt))))
(rival-machine-full machine (vector-map ival-real pt) hint)))
(cond
[bad? (raise (exn:rival:invalid "Invalid input" (current-continuation-marks) pt))]
[done? fvec]
Expand All @@ -83,4 +83,121 @@
(parameterize ([*sampling-iteration* 0]
[ground-truth-require-convergence #f])
(rival-machine-full machine rect)))
(ival (or bad? stuck?) (not good?)))
(define-values (hint hint-converged?) (make-hint machine))
(list (ival (or bad? stuck?) (not good?)) hint hint-converged?))

(module+ test
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The tests represent a random testing of the hints implementation.

It samples random hyperrects in a global box of [-100, 100] and points inside these hyperrects.
The hyperrects are analyzed using rival-analyze and obtained hint is executed for points.
The tests verify that execution with hint and without hint produces the same results.
Additionally, the tests check the percentage of instruction executions being skipped by hint.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you move the tests to a separate file? Otherwise, sure.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I did not review the tests closely.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, sure

(require rackunit
"compile.rkt"
"../utils.rkt"
math/bigfloat)
(define number-of-random-hyperrects 100)
(define number-of-random-pts-per-rect 100)
(define threshold 95) ; at least 5% of instructions should be skipped by hint to pass the tests!
(bf-precision 53)

; Check whether outputs are the same for the hint and without hint executions
(define (rival-check-hint machine hint pt)
(define no-hint-result
(with-handlers ([exn:rival:invalid? (λ (e) 'invalid)]
[exn:rival:unsamplable? (λ (e) 'unsamplable)])
(rival-apply machine pt)))
(define no-hint-instr-count (vector-length (rival-profile machine 'executions)))

(define hint-result
(with-handlers ([exn:rival:invalid? (λ (e) 'invalid)]
[exn:rival:unsamplable? (λ (e) 'unsamplable)])
(rival-apply machine pt hint)))
(define hint-instr-count (vector-length (rival-profile machine 'executions)))

(check-equal? hint-result no-hint-result)
(values no-hint-instr-count hint-instr-count))

; Random sampling hyperrects given a general range as [rect-lo, rect-hi]
(define (sample-hyperrect-within-bounds rect-lo rect-hi varc)
(for/vector ([_ (in-range varc)])
(define xlo-range-length (bf- rect-hi rect-lo))
(define xlo (bf+ (bf* (bfrandom) xlo-range-length) rect-lo))
(define xhi-range-length (bf- rect-hi xlo))
(define xhi (bf+ (bf* (bfrandom) xhi-range-length) xlo))
(check-true (and (bf> rect-hi xhi) (bf> xlo rect-lo) (bf> xhi xlo))
"Hyperrect is out of bounds")
(ival xlo xhi)))

; Sample points with respect to the input hyperrect
(define (sample-pts hyperrect)
(for/vector ([rect (in-vector hyperrect)])
(define range-length (bf- (ival-hi rect) (ival-lo rect)))
(define pt (bf+ (bf* (bfrandom) range-length) (ival-lo rect)))
(check-true (and (bf> pt (ival-lo rect)) (bf< pt (ival-hi rect)))
"Sampled point is out of hyperrect range")
pt))

; Testing hint on an expression for 'number-of-random-hyperrects' hyperrects by
; 'number-of-random-pts-per-rect' points each
(define (hints-random-checks machine rect-lo rect-hi varc)
(define evaluated-instructions 0)
(define number-of-instructions-total
(* number-of-random-hyperrects (vector-length (rival-machine-instructions machine))))

(define hint-cnt 0)
(define no-hint-cnt 0)
(for ([n (in-range number-of-random-hyperrects)])
(define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc))
(match-define (list res hint _) (rival-analyze machine hyperrect))
(set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint)))

(for ([_ (in-range number-of-random-pts-per-rect)])
(define pt (sample-pts hyperrect))
(define-values (no-hint-cnt* hint-cnt*) (rival-check-hint machine hint pt))
(set! hint-cnt (+ hint-cnt hint-cnt*))
(set! no-hint-cnt (+ no-hint-cnt no-hint-cnt*))))
(define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100))
skipped-percentage)

(define discs (list boolean-discretization flonum-discretization))
(define vars '(x y))
(define varc (length vars))

(define expr1
(list '(assert (> (+ (log x) (log y)) (- (log x) (log y))))
'(+ (if (> (/ (log x) (log y)) (* (log x) (log y)))
(fmax (* (log x) (log y)) (+ (log x) (log y)))
(fmin (* (log x) (log y)) (+ (log x) (log y))))
(if (> (+ (log x) (log y)) (* (log x) (log y)))
(fmax (/ (log x) (log y)) (- (log x) (log y)))
(fmin (/ (log x) (log y)) (- (log x) (log y)))))))
(define machine1 (rival-compile expr1 vars discs))
(define skipped-instr1 (hints-random-checks machine1 (bf -100) (bf 100) varc))
(printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1))
(check-true (< skipped-instr1 threshold))

(define expr2
(list '(TRUE)
'(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2))))
(fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2)))))))
(define machine2 (rival-compile expr2 vars discs))
(define skipped-instr2 (hints-random-checks machine2 (bf -100) (bf 100) varc))
(printf "Percentage of skipped instructions by hint in expr2 = ~a\n" (round skipped-instr2))
(check-true (< skipped-instr2 threshold))

(define expr3
(list '(TRUE)
'(if (> (exp x) (+ 10 (log y)))
(if (> (fmax (* x y) (+ x y)) 4)
(cos (fmax x y))
(cos (fmin x y)))
(if (< (pow 2 x) (- (exp x) 10))
(* PI x)
(fmax x (- (cos y) (+ 10 (log y))))))))
(define machine3 (rival-compile expr3 vars discs))
(define skipped-instr3 (hints-random-checks machine3 (bf -100) (bf 100) varc))
(printf "Percentage of skipped instructions by hint in expr3 = ~a\n" (round skipped-instr3))
(check-true (< skipped-instr3 threshold))

; Test checks hint on assert where an error can be observed
(define expr4 (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) '(+ (cos x) (cos y))))
(define machine4 (rival-compile expr4 vars discs))
(define skipped-instr4 (hints-random-checks machine4 (bf -100) (bf 100) varc))
(printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr4))
(check-true (< skipped-instr4 threshold)))
18 changes: 12 additions & 6 deletions eval/run.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@
(flvector-set! profile-time profile-ptr time)
(set-rival-machine-profile-ptr! machine (add1 profile-ptr))))

(define (rival-machine-run machine)
(define (rival-machine-run machine vhint)
(define ivec (rival-machine-instructions machine))
(define varc (vector-length (rival-machine-arguments machine)))
(define precisions (rival-machine-precisions machine))
Expand All @@ -44,10 +44,16 @@
[n (in-naturals varc)]
[precision (in-vector precisions)]
[repeat (in-vector repeats)]
#:unless (and (not first-iter?) repeat))
[hint (in-vector vhint)]
#:unless (or (not hint) (and (not first-iter?) repeat)))
(define start (current-inexact-milliseconds))
(parameterize ([bf-precision precision])
(vector-set! vregs n (apply-instruction instr vregs)))
(define out
(parameterize ([bf-precision precision])
(match hint
[#t (apply-instruction instr vregs)]
[(? integer? _) (vector-ref vregs (list-ref instr hint))]
[(? ival? _) hint])))
(vector-set! vregs n out)
(define name (object-name (car instr)))
(define time (- (current-inexact-milliseconds) start))
(rival-machine-record machine name n precision time)))
Expand Down Expand Up @@ -94,10 +100,10 @@
lo))
(values good? (and good? done?) bad? stuck? fvec))

(define (rival-machine-adjust machine)
(define (rival-machine-adjust machine vhint)
(define iter (rival-machine-iteration machine))
(let ([start (current-inexact-milliseconds)])
(if (zero? iter)
(vector-fill! (rival-machine-precisions machine) (rival-machine-initial-precision machine))
(backward-pass machine))
(backward-pass machine vhint))
(rival-machine-record machine 'adjust -1 (* iter 1000) (- (current-inexact-milliseconds) start))))
Loading
Loading