diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 82490da..55df959 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,4 +21,4 @@ jobs: run: raco fmt -i **/*.rkt - name: "Make sure files are correctly formatted with raco fmt" run: git diff --exit-code - - run: raco test *.rkt + - run: raco test . diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 86f7c97..bc97a0c 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -6,9 +6,122 @@ racket/list racket/match) -(provide backward-pass) +(provide backward-pass + make-hint) + +; Hint is a vector with len(ivec) elements which +; guides Rival on which instructions should not be executed +; for points from a particular hyperrect of input parameters. +; make-hint is called as a last step of rival-analyze and returns hint as a result +; Values of a hint: +; #f - instruction should not be executed +; #t - instruction should be executed +; integer n - instead of executing, refer to vregs with (list-ref instr n) index +; (the result is known and stored in another register) +; ival - instead of executing, just copy ival as a result of the instruction +(define (make-hint machine) + (define args (rival-machine-arguments machine)) + (define ivec (rival-machine-instructions machine)) + (define rootvec (rival-machine-outputs machine)) + (define vregs (rival-machine-registers machine)) -(define (backward-pass machine) + (define varc (vector-length args)) + (define vhint (make-vector (vector-length ivec) #f)) + (define converged? #t) + + ; helper function + (define (vhint-set! idx val) + (when (>= idx varc) + (vector-set! vhint (- idx varc) val))) + + ; roots always should be executed + (for ([root-reg (in-vector rootvec)]) + (vhint-set! root-reg #t)) + (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] + [hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)] + [n (in-range (- (vector-length vhint) 1) -1 -1)] + #:when hint) + (define hint* + (case (object-name (car instr)) + [(ival-assert) + (match-define (list _ bool-idx) instr) + (define bool-reg (vector-ref vregs bool-idx)) + (match* ((ival-lo bool-reg) (ival-hi bool-reg) (ival-err? bool-reg)) + ; assert and its children should not be reexecuted if it is true already + [(#t #t #f) (ival-bool #t)] + ; assert and its children should not be reexecuted if it is false already + [(#f #f #f) (ival-bool #f)] + [(_ _ _) ; assert and its children should be reexecuted + (vhint-set! bool-idx #t) + (set! converged? #f) + #t])] + [(ival-if) + (match-define (list _ cond tru fls) instr) + (define cond-reg (vector-ref vregs cond)) + (match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg)) + [(#t #t #f) ; only true path should be executed + (vhint-set! tru #t) + 2] + [(#f #f #f) ; only false path should be executed + (vhint-set! fls #t) + 3] + [(_ _ _) ; execute both paths and cond as well + (vhint-set! cond #t) + (vhint-set! tru #t) + (vhint-set! fls #t) + (set! converged? #f) + #t])] + [(ival-fmax) + (match-define (list _ arg1 arg2) instr) + (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) + (match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp)) + [(#t #t #f) ; only arg1 should be executed + (vhint-set! arg1 #t) + 1] + [(#f #f #f) ; only arg2 should be executed + (vhint-set! arg2 #t) + 2] + [(_ _ _) ; both paths should be executed + (vhint-set! arg1 #t) + (vhint-set! arg2 #t) + (set! converged? #f) + #t])] + [(ival-fmin) + (match-define (list _ arg1 arg2) instr) + (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) + (match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp)) + [(#t #t #f) ; only arg2 should be executed + (vhint-set! arg2 #t) + 2] + [(#f #f #f) ; only arg1 should be executed + (vhint-set! arg1 #t) + 1] + [(_ _ _) ; both paths should be executed + (vhint-set! arg1 #t) + (vhint-set! arg2 #t) + (set! converged? #f) + #t])] + [(ival-< ival-<= ival-> ival->= ival-== ival-!= ival-and ival-or ival-not) + (define cmp (vector-ref vregs (+ varc n))) + (match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp)) + ; result is known + [(#t #t #f) (ival-bool #t)] + ; result is known + [(#f #f #f) (ival-bool #f)] + [(_ _ _) ; all the paths should be executed + (define srcs (rest instr)) + (for-each (λ (x) (vhint-set! x #t)) srcs) + (set! converged? #f) + #t])] + + [else ; at this point we are given that the current instruction should be executed + (define srcs (rest instr)) ; then, children instructions should be executed as well + (for-each (λ (x) (vhint-set! x #t)) srcs) + #t])) + (vector-set! vhint n hint*)) + (values vhint converged?)) + +(define (backward-pass machine vhint) ; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3 (define args (rival-machine-arguments machine)) (define ivec (rival-machine-instructions machine)) @@ -50,7 +163,7 @@ (vector-set! vuseful (- arg varc) #t))])) ; Step 2. Precision tuning - (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful) + (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful vhint) ; Step 3. Repeating precisions check + Assigning if a operation should be computed again at all ; vrepeats[i] = #t if the node has the same precision as an iteration before and children have #t flag as well @@ -90,12 +203,13 @@ ; Roughly speaking, the upper precision bound is calculated as: ; vprecs-max[i] = (+ max-prec vstart-precs[i]), where min-prec < (+ max-prec vstart-precs[i]) < max-prec ; max-prec = (car (get-bounds parent)) -(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful) +(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful vhint) (define vprecs-min (make-vector (vector-length ivec) 0)) (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] [useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)] [n (in-range (- (vector-length vregs) 1) -1 -1)] - #:when useful?) + [hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)] + #:when (and hint useful?)) (define op (car instr)) (define tail-registers (cdr instr)) (define srcs (map (lambda (x) (vector-ref vregs x)) tail-registers)) diff --git a/eval/compile.rkt b/eval/compile.rkt index a7c5da5..3f6a445 100644 --- a/eval/compile.rkt +++ b/eval/compile.rkt @@ -270,14 +270,16 @@ ([node (in-vector nodes num-vars)]) (fn->ival-fn node))) - (define register-count (+ (length vars) (vector-length instructions))) + (define ivec-length (vector-length instructions)) + (define register-count (+ (length vars) ivec-length)) (define registers (make-vector register-count)) - (define repeats (make-vector register-count #f)) ; flags whether an op should be evaluated - (define precisions (make-vector register-count)) ; vector that stores working precisions + (define repeats (make-vector ivec-length #f)) ; flags whether an op should be evaluated + (define precisions (make-vector ivec-length)) ; vector that stores working precisions ;; vector for adjusting precisions (define incremental-precisions (setup-vstart-precs instructions (length vars) roots discs)) (define initial-precision (+ (argmax identity (map discretization-target discs)) (*base-tuning-precision*))) + (define hint (make-vector ivec-length #t)) (rival-machine (list->vector vars) instructions @@ -289,6 +291,7 @@ incremental-precisions (make-vector (vector-length roots)) initial-precision + hint 0 0 0 @@ -325,41 +328,3 @@ idx-prec ; We wanna make sure that we do not tune a precision down (+ current-prec (*ampl-tuning-bits*)))))) vstart-precs) - -(module+ test - (require rackunit - "../utils.rkt") - ; This function is needed to unwrap constant procedure which fails tests otherwise - ; (const (ival 0.bf 0.bf)) != (const (ival 0.bf 0.bf)) - (define (drop-ival-const instrs) - (for/vector ([instr (in-vector instrs)]) - (match instr - [`(,const) (const)] - [_ instr]))) - - (define discs (list flonum-discretization)) - (define vars '(x y z)) - - (define (check-rival-optimization expr target-expr) - (define optimized-instrs - (drop-ival-const (parameterize ([*rival-use-shorthands* #t]) - (rival-machine-instructions (rival-compile (list expr) vars discs))))) - (define target-instrs - (drop-ival-const (parameterize ([*rival-use-shorthands* #f]) - (rival-machine-instructions (rival-compile (list target-expr) vars discs))))) - (check-equal? optimized-instrs target-instrs)) - - (check-rival-optimization `(* (log (exp x)) y) `(* x y)) - (check-rival-optimization `(* (exp (log x)) y) `(* (then (assert (> x 0)) x) y)) - (check-rival-optimization `(fma x y z) `(+ (* x y) z)) - (check-rival-optimization `(- (exp x) 1) `(expm1 x)) - (check-rival-optimization `(- 1 (exp x)) `(neg (expm1 x))) - (check-rival-optimization `(log (+ 1 x)) `(log1p x)) - (check-rival-optimization `(log (+ x 1)) `(log1p x)) - (check-rival-optimization `(sqrt (+ (* x x) (* y y))) `(hypot x y)) - (check-rival-optimization `(sqrt (+ (* x x) 1)) `(hypot x 1)) - (check-rival-optimization `(sqrt (+ 1 (* x x))) `(hypot 1 x)) - (check-rival-optimization `(pow x 2) `(pow2 x)) - (check-rival-optimization `(pow x 1/3) `(cbrt x)) - (check-rival-optimization `(pow x 1/2) `(sqrt x)) - (check-rival-optimization `(pow 2 x) `(exp2 x))) diff --git a/eval/machine.rkt b/eval/machine.rkt index a6c8c38..8d6e3b3 100644 --- a/eval/machine.rkt +++ b/eval/machine.rkt @@ -29,6 +29,7 @@ incremental-precisions output-distance initial-precision + hint [iteration #:mutable] [bumps #:mutable] [profile-ptr #:mutable] diff --git a/eval/main.rkt b/eval/main.rkt index 3b5a3d8..d00753f 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -24,14 +24,14 @@ (define ground-truth-require-convergence (make-parameter #t)) -(define (rival-machine-full machine inputs) +(define (rival-machine-full machine inputs [vhint (rival-machine-hint machine)]) (set-rival-machine-iteration! machine (*sampling-iteration*)) - (rival-machine-adjust machine) + (rival-machine-adjust machine vhint) (cond [(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)] [else (rival-machine-load machine inputs) - (rival-machine-run machine) + (rival-machine-run machine vhint) (rival-machine-return machine)])) (struct exn:rival exn:fail ()) @@ -62,14 +62,14 @@ (define (ival-real x) (ival x)) -(define (rival-apply machine pt) +(define (rival-apply machine pt [hint (rival-machine-hint machine)]) (define discs (rival-machine-discs machine)) (set-rival-machine-bumps! machine 0) (let loop ([iter 0]) (define-values (good? done? bad? stuck? fvec) (parameterize ([*sampling-iteration* iter] [ground-truth-require-convergence #t]) - (rival-machine-full machine (vector-map ival-real pt)))) + (rival-machine-full machine (vector-map ival-real pt) hint))) (cond [bad? (raise (exn:rival:invalid "Invalid input" (current-continuation-marks) pt))] [done? fvec] @@ -83,4 +83,5 @@ (parameterize ([*sampling-iteration* 0] [ground-truth-require-convergence #f]) (rival-machine-full machine rect))) - (ival (or bad? stuck?) (not good?))) + (define-values (hint hint-converged?) (make-hint machine)) + (list (ival (or bad? stuck?) (not good?)) hint hint-converged?)) diff --git a/eval/run.rkt b/eval/run.rkt index be76c5c..5b2cd2b 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -31,7 +31,7 @@ (flvector-set! profile-time profile-ptr time) (set-rival-machine-profile-ptr! machine (add1 profile-ptr)))) -(define (rival-machine-run machine) +(define (rival-machine-run machine vhint) (define ivec (rival-machine-instructions machine)) (define varc (vector-length (rival-machine-arguments machine))) (define precisions (rival-machine-precisions machine)) @@ -44,10 +44,17 @@ [n (in-naturals varc)] [precision (in-vector precisions)] [repeat (in-vector repeats)] - #:unless (and (not first-iter?) repeat)) + [hint (in-vector vhint)] + #:unless (or (not hint) (and (not first-iter?) repeat))) (define start (current-inexact-milliseconds)) - (parameterize ([bf-precision precision]) - (vector-set! vregs n (apply-instruction instr vregs))) + (define out + (match hint + [#t + (parameterize ([bf-precision precision]) + (apply-instruction instr vregs))] + [(? integer? _) (vector-ref vregs (list-ref instr hint))] + [(? ival? _) hint])) + (vector-set! vregs n out) (define name (object-name (car instr))) (define time (- (current-inexact-milliseconds) start)) (rival-machine-record machine name n precision time))) @@ -94,10 +101,10 @@ lo)) (values good? (and good? done?) bad? stuck? fvec)) -(define (rival-machine-adjust machine) +(define (rival-machine-adjust machine vhint) (define iter (rival-machine-iteration machine)) (let ([start (current-inexact-milliseconds)]) (if (zero? iter) (vector-fill! (rival-machine-precisions machine) (rival-machine-initial-precision machine)) - (backward-pass machine)) + (backward-pass machine vhint)) (rival-machine-record machine 'adjust -1 (* iter 1000) (- (current-inexact-milliseconds) start)))) diff --git a/eval/tests.rkt b/eval/tests.rkt new file mode 100644 index 0000000..38d5b1d --- /dev/null +++ b/eval/tests.rkt @@ -0,0 +1,156 @@ +#lang racket + +; Test compile optimizations +(module+ test + (require rackunit + "machine.rkt" + "../main.rkt") + ; This function is needed to unwrap constant procedure which fails tests otherwise + ; (const (ival 0.bf 0.bf)) != (const (ival 0.bf 0.bf)) + (define (drop-ival-const instrs) + (for/vector ([instr (in-vector instrs)]) + (match instr + [`(,const) (const)] + [_ instr]))) + + (define discs (list flonum-discretization)) + (define vars '(x y z)) + + (define (check-rival-optimization expr target-expr) + (define optimized-instrs + (drop-ival-const (parameterize ([*rival-use-shorthands* #t]) + (rival-machine-instructions (rival-compile (list expr) vars discs))))) + (define target-instrs + (drop-ival-const (parameterize ([*rival-use-shorthands* #f]) + (rival-machine-instructions (rival-compile (list target-expr) vars discs))))) + (check-equal? optimized-instrs target-instrs)) + + (check-rival-optimization `(* (log (exp x)) y) `(* x y)) + (check-rival-optimization `(* (exp (log x)) y) `(* (then (assert (> x 0)) x) y)) + (check-rival-optimization `(fma x y z) `(+ (* x y) z)) + (check-rival-optimization `(- (exp x) 1) `(expm1 x)) + (check-rival-optimization `(- 1 (exp x)) `(neg (expm1 x))) + (check-rival-optimization `(log (+ 1 x)) `(log1p x)) + (check-rival-optimization `(log (+ x 1)) `(log1p x)) + (check-rival-optimization `(sqrt (+ (* x x) (* y y))) `(hypot x y)) + (check-rival-optimization `(sqrt (+ (* x x) 1)) `(hypot x 1)) + (check-rival-optimization `(sqrt (+ 1 (* x x))) `(hypot 1 x)) + (check-rival-optimization `(pow x 2) `(pow2 x)) + (check-rival-optimization `(pow x 1/3) `(cbrt x)) + (check-rival-optimization `(pow x 1/2) `(sqrt x)) + (check-rival-optimization `(pow 2 x) `(exp2 x))) + +; Check hints from rival-analyze +(module+ test + (require rackunit + math/bigfloat + "../main.rkt") + (define number-of-random-hyperrects 100) + (define number-of-random-pts-per-rect 100) + (define rect (list (bf -100) (bf 100))) + (bf-precision 53) + + ; Check whether outputs are the same for the hint and without hint executions + (define (rival-check-hint machine hint pt) + (define hint-result + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt hint))) + (define hint-instr-count (vector-length (rival-profile machine 'executions))) + + (define no-hint-result + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt))) + (define no-hint-instr-count (vector-length (rival-profile machine 'executions))) + + (check-equal? hint-result no-hint-result) + (values no-hint-instr-count hint-instr-count)) + + ; Random sampling hyperrects given a general range as [rect-lo, rect-hi] + (define (sample-hyperrect-within-bounds rect-lo rect-hi varc) + (for/vector ([_ (in-range varc)]) + (define xlo-range-length (bf- rect-hi rect-lo)) + (define xlo (bf+ (bf* (bfrandom) xlo-range-length) rect-lo)) + (define xhi-range-length (bf- rect-hi xlo)) + (define xhi (bf+ (bf* (bfrandom) xhi-range-length) xlo)) + (check-true (and (bf> rect-hi xhi) (bf> xlo rect-lo) (bf> xhi xlo)) + "Hyperrect is out of bounds") + (ival xlo xhi))) + + ; Sample points with respect to the input hyperrect + (define (sample-pts hyperrect) + (for/vector ([rect (in-vector hyperrect)]) + (define range-length (bf- (ival-hi rect) (ival-lo rect))) + (define pt (bf+ (bf* (bfrandom) range-length) (ival-lo rect))) + (check-true (and (bf> pt (ival-lo rect)) (bf< pt (ival-hi rect))) + "Sampled point is out of hyperrect range") + pt)) + + ; Testing hint on an expression for 'number-of-random-hyperrects' hyperrects by + ; 'number-of-random-pts-per-rect' points each + (define (hints-random-checks machine rect-lo rect-hi varc) + (define evaluated-instructions 0) + (define number-of-instructions-total + (* number-of-random-hyperrects (vector-length (rival-machine-instructions machine)))) + + (define hint-cnt 0) + (define no-hint-cnt 0) + (for ([n (in-range number-of-random-hyperrects)]) + (define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc)) + (match-define (list res hint _) (rival-analyze machine hyperrect)) + (set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint))) + + (for ([_ (in-range number-of-random-pts-per-rect)]) + (define pt (sample-pts hyperrect)) + (define-values (no-hint-cnt* hint-cnt*) (rival-check-hint machine hint pt)) + (set! hint-cnt (+ hint-cnt hint-cnt*)) + (set! no-hint-cnt (+ no-hint-cnt no-hint-cnt*)))) + (define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100)) + skipped-percentage) + + (define (expressions2d-check expressions) + (define discs (list boolean-discretization flonum-discretization)) + (define vars '(x y)) + (define varc (length vars)) + (define machine (rival-compile expressions vars discs)) + (define skipped-instr (hints-random-checks machine (first rect) (second rect) varc)) + (check-true (< skipped-instr 99) + (format "Almost no instructions got skipped by hint at ~a" expressions))) + + (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) + '(+ (if (> (/ (log x) (log y)) (* (log x) (log y))) + (fmax (* (log x) (log y)) (+ (log x) (log y))) + (fmin (* (log x) (log y)) (+ (log x) (log y)))) + (if (> (+ (log x) (log y)) (* (log x) (log y))) + (fmax (/ (log x) (log y)) (- (log x) (log y))) + (fmin (/ (log x) (log y)) (- (log x) (log y))))))) + + (expressions2d-check + (list '(TRUE) + '(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2)))) + (fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2))))))) + + (expressions2d-check (list '(TRUE) + '(if (> (exp x) (+ 10 (log y))) + (if (> (fmax (* x y) (+ x y)) 4) + (cos (fmax x y)) + (cos (fmin x y))) + (if (< (pow 2 x) (- (exp x) 10)) + (* PI x) + (fmax x (- (cos y) (+ 10 (log y)))))))) + + ; Test checks hint on assert where an error can be observed + (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) + '(+ (cos x) (cos y)))) + + ; Test checks hint on fmax where an error can be observed + (expressions2d-check (list '(TRUE) '(fmax (log x) y))) + + ; Test checks hints on comparison operators + (expressions2d-check + (list '(and (and (> (log x) y) (or (== (exp x) (exp y)) (> (cos x) (cos y)))) (<= (log y) (log x))) + '(if (or (or (< (log x) y) (and (!= (exp x) (exp y)) (< (cos x) (cos y)))) + (>= (log y) (log x))) + x + y)))) \ No newline at end of file diff --git a/main.rkt b/main.rkt index a358c3d..2a52248 100644 --- a/main.rkt +++ b/main.rkt @@ -92,8 +92,11 @@ (only-in "eval/machine.rkt" rival-machine?)) (provide (contract-out [rival-compile (-> (listof any/c) (listof symbol?) (listof discretization?) rival-machine?)] - [rival-apply (-> rival-machine? (vectorof value?) (vectorof any/c))] - [rival-analyze (-> rival-machine? (vectorof ival?) ival?)]) + [rival-apply + (->* (rival-machine? (vectorof value?)) + ((or/c (vectorof any/c) boolean?)) + (vectorof any/c))] + [rival-analyze (-> rival-machine? (vectorof ival?) (listof any/c))]) (struct-out discretization) (struct-out exn:rival) (struct-out exn:rival:invalid)