From ff3b2efd15cb4a5cccf327a2b4e354d7c08bb048 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Fri, 13 Dec 2024 17:06:16 -0700 Subject: [PATCH 01/29] some hints template --- eval/adjust.rkt | 72 ++++++++++++++++++++++++++++++++++++++++++++++++- eval/main.rkt | 3 ++- 2 files changed, 73 insertions(+), 2 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 86f7c97..4ceb2a2 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -6,7 +6,77 @@ racket/list racket/match) -(provide backward-pass) +(provide backward-pass + make-hint) + +(define (make-hint machine) + (define args (rival-machine-arguments machine)) + (define ivec (rival-machine-instructions machine)) + (define rootvec (rival-machine-outputs machine)) + (define vregs (rival-machine-registers machine)) + + (define varc (vector-length args)) + (define vhint (make-vector (vector-length ivec) #f)) + + (for ([root-reg (in-vector rootvec)]) + (vector-set! vhint (- root-reg varc) #t)) + (for/vector ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] + [hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)] + #:when hint) + (case (object-name (car instr)) + [(ival-if) + (match-define (vector _ cond tru fls) instr) + (define cond-reg (vector-ref vregs cond)) + (cond + [(and (ival-lo cond-reg) (ival-hi cond-reg)) + (vector-set! vhint (- cond varc) (or #f (vector-ref vhint (- cond varc)))) + (vector-set! vhint (- tru varc) #t) + (vector-set! vhint (- fls varc) (or #f (vector-ref vhint (- fls varc)))) + 1] + [(not (or (ival-lo cond-reg) (ival-hi cond-reg))) + (vector-set! vhint (- cond varc) (or #f (vector-ref vhint (- cond varc)))) + (vector-set! vhint (- tru varc) (or #f (vector-ref vhint (- tru varc)))) + (vector-set! vhint (- fls varc) #t) + 2] + [else + (vector-set! vhint (- cond varc) #t) + (vector-set! vhint (- tru varc) #t) + (vector-set! vhint (- fls varc) #t) + #t])] + [(ival-fmax) + (match-define (vector _ arg1 arg2) instr) + (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) + (cond + [(and (ival-lo cmp) (ival-hi cmp)) + (vector-set! vhint (- arg2 varc) (or #f (vector-ref vhint (- arg2 varc)))) + (vector-set! vhint (- arg1 varc) #t) + 0] + [(not (or (ival-lo cmp) (ival-hi cmp))) + (vector-set! vhint (- arg1 varc) (or #f (vector-ref vhint (- arg1 varc)))) + (vector-set! vhint (- arg2 varc) #t) + 1] + [else + (vector-set! vhint (- arg1 varc) #t) + (vector-set! vhint (- arg2 varc) #t) + #t])] + [(ival-fmin) + (match-define (vector _ arg1 arg2) instr) + (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) + (cond + [(and (ival-lo cmp) (ival-hi cmp)) + (vector-set! vhint (- arg1 varc) (or #f (vector-ref vhint (- arg1 varc)))) + (vector-set! vhint (- arg2 varc) #t) + 1] + [(not (or (ival-lo cmp) (ival-hi cmp))) + (vector-set! vhint (- arg2 varc) (or #f (vector-ref vhint (- arg2 varc)))) + (vector-set! vhint (- arg1 varc) #t) + 0] + [else + (vector-set! vhint (- arg1 varc) #t) + (vector-set! vhint (- arg2 varc) #t) + #t])] + ;(vector-map (curry vector-ref vregs) (vector-rest instr)) + [else #t]))) (define (backward-pass machine) ; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3 diff --git a/eval/main.rkt b/eval/main.rkt index 3b5a3d8..fa80171 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -83,4 +83,5 @@ (parameterize ([*sampling-iteration* 0] [ground-truth-require-convergence #f]) (rival-machine-full machine rect))) - (ival (or bad? stuck?) (not good?))) + (define hint (make-hint machine)) + (values (ival (or bad? stuck?) (not good?)) hint)) From 9cbf9e463e187af4874943f2ddb14c2a63e98036 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Mon, 16 Dec 2024 18:10:52 -0700 Subject: [PATCH 02/29] hint seems to work, needs to be tested --- eval/adjust.rkt | 129 ++++++++++++++++++++++++++---------------------- eval/main.rkt | 26 ++++++++-- eval/run.rkt | 42 +++++++++++----- 3 files changed, 123 insertions(+), 74 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 4ceb2a2..ca42378 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -18,65 +18,78 @@ (define varc (vector-length args)) (define vhint (make-vector (vector-length ivec) #f)) + (define (vhint-set! idx x) + (when (>= idx varc) + (vector-set! vhint (- idx varc) x))) + (for ([root-reg (in-vector rootvec)]) - (vector-set! vhint (- root-reg varc) #t)) - (for/vector ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] - [hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)] - #:when hint) - (case (object-name (car instr)) - [(ival-if) - (match-define (vector _ cond tru fls) instr) - (define cond-reg (vector-ref vregs cond)) - (cond - [(and (ival-lo cond-reg) (ival-hi cond-reg)) - (vector-set! vhint (- cond varc) (or #f (vector-ref vhint (- cond varc)))) - (vector-set! vhint (- tru varc) #t) - (vector-set! vhint (- fls varc) (or #f (vector-ref vhint (- fls varc)))) - 1] - [(not (or (ival-lo cond-reg) (ival-hi cond-reg))) - (vector-set! vhint (- cond varc) (or #f (vector-ref vhint (- cond varc)))) - (vector-set! vhint (- tru varc) (or #f (vector-ref vhint (- tru varc)))) - (vector-set! vhint (- fls varc) #t) - 2] - [else - (vector-set! vhint (- cond varc) #t) - (vector-set! vhint (- tru varc) #t) - (vector-set! vhint (- fls varc) #t) - #t])] - [(ival-fmax) - (match-define (vector _ arg1 arg2) instr) - (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) - (cond - [(and (ival-lo cmp) (ival-hi cmp)) - (vector-set! vhint (- arg2 varc) (or #f (vector-ref vhint (- arg2 varc)))) - (vector-set! vhint (- arg1 varc) #t) - 0] - [(not (or (ival-lo cmp) (ival-hi cmp))) - (vector-set! vhint (- arg1 varc) (or #f (vector-ref vhint (- arg1 varc)))) - (vector-set! vhint (- arg2 varc) #t) - 1] - [else - (vector-set! vhint (- arg1 varc) #t) - (vector-set! vhint (- arg2 varc) #t) - #t])] - [(ival-fmin) - (match-define (vector _ arg1 arg2) instr) - (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) - (cond - [(and (ival-lo cmp) (ival-hi cmp)) - (vector-set! vhint (- arg1 varc) (or #f (vector-ref vhint (- arg1 varc)))) - (vector-set! vhint (- arg2 varc) #t) - 1] - [(not (or (ival-lo cmp) (ival-hi cmp))) - (vector-set! vhint (- arg2 varc) (or #f (vector-ref vhint (- arg2 varc)))) - (vector-set! vhint (- arg1 varc) #t) - 0] - [else - (vector-set! vhint (- arg1 varc) #t) - (vector-set! vhint (- arg2 varc) #t) - #t])] - ;(vector-map (curry vector-ref vregs) (vector-rest instr)) - [else #t]))) + (vhint-set! root-reg #t)) + (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] + [hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)] + [n (in-range (- (vector-length vhint) 1) -1 -1)] + #:when hint) + (define hint* + (case (object-name (car instr)) + [(ival-if) + (match-define (list _ cond tru fls) instr) + (define cond-reg (vector-ref vregs cond)) + (cond + [(and (ival-lo cond-reg) (ival-hi cond-reg)) + (vhint-set! cond (or #f (vector-ref vhint (- cond varc)))) + (vhint-set! tru #t) + (vhint-set! fls (or #f (vector-ref vhint (- fls varc)))) + 2] + [(not (or (ival-lo cond-reg) (ival-hi cond-reg))) + (vhint-set! cond (or #f (vector-ref vhint (- cond varc)))) + (vhint-set! tru (or #f (vector-ref vhint (- tru varc)))) + (vhint-set! fls #t) + 3] + [#t + (vhint-set! cond #t) + (vhint-set! tru #t) + (vhint-set! fls #t) + #t])] + [(ival-fmax) + (match-define (list _ arg1 arg2) instr) + (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) + (cond + [(and (ival-lo cmp) (ival-hi cmp)) + (vhint-set! arg2 (or #f (vector-ref vhint (- arg2 varc)))) + (vhint-set! arg1 #t) + 1] + [(not (or (ival-lo cmp) (ival-hi cmp))) + (vhint-set! arg1 (or #f (vector-ref vhint (- arg1 varc)))) + (vhint-set! arg2 #t) + 2] + [#t + (vhint-set! arg1 #t) + (vhint-set! arg2 #t) + #t])] + [(ival-fmin) + (match-define (list _ arg1 arg2) instr) + (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) + (cond + [(and (ival-lo cmp) (ival-hi cmp)) + (vhint-set! arg1 (or #f (vector-ref vhint (- arg1 varc)))) + (vhint-set! arg2 #t) + 2] + [(not (or (ival-lo cmp) (ival-hi cmp))) + (vhint-set! arg2 (or #f (vector-ref vhint (- arg2 varc)))) + (vhint-set! arg1 #t) + 1] + [#t + (vhint-set! arg1 #t) + (vhint-set! arg2 #t) + #t])] + [else + (define srcs (rest instr)) + (map (λ (x) + (vhint-set! x) + #t) + srcs) + #t])) + (vector-set! vhint n hint*)) + vhint) (define (backward-pass machine) ; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3 diff --git a/eval/main.rkt b/eval/main.rkt index fa80171..7a131fe 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -24,14 +24,14 @@ (define ground-truth-require-convergence (make-parameter #t)) -(define (rival-machine-full machine inputs) +(define (rival-machine-full machine inputs [hint #f]) (set-rival-machine-iteration! machine (*sampling-iteration*)) (rival-machine-adjust machine) (cond [(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)] [else (rival-machine-load machine inputs) - (rival-machine-run machine) + (rival-machine-run machine hint) (rival-machine-return machine)])) (struct exn:rival exn:fail ()) @@ -62,14 +62,14 @@ (define (ival-real x) (ival x)) -(define (rival-apply machine pt) +(define (rival-apply machine pt [hint #f]) (define discs (rival-machine-discs machine)) (set-rival-machine-bumps! machine 0) (let loop ([iter 0]) (define-values (good? done? bad? stuck? fvec) (parameterize ([*sampling-iteration* iter] [ground-truth-require-convergence #t]) - (rival-machine-full machine (vector-map ival-real pt)))) + (rival-machine-full machine (vector-map ival-real pt) hint))) (cond [bad? (raise (exn:rival:invalid "Invalid input" (current-continuation-marks) pt))] [done? fvec] @@ -85,3 +85,21 @@ (rival-machine-full machine rect))) (define hint (make-hint machine)) (values (ival (or bad? stuck?) (not good?)) hint)) + +(module+ test + (require rackunit + "compile.rkt" + "../utils.rkt" + math/bigfloat) + (define (rival-check-hint machine hint pt) + (check-equal? (rival-apply machine pt hint) (rival-apply machine pt))) + + (define discs (list boolean-discretization flonum-discretization)) + (define vars '(x y)) + (define expr (list '(TRUE) '(fmax -5 (fmin x (fmax y (cos PI)))))) + (define machine (rival-compile expr vars discs)) + + (define-values (a hint) (rival-analyze machine (vector (ival (bf 0) (bf 1)) (ival (bf 4) (bf 5))))) + (rival-check-hint machine hint (vector (bf 1) (bf 5))) + (rival-check-hint machine hint (vector (bf 0) (bf 4))) + (rival-check-hint machine hint (vector (bf 0.5) (bf 4.5)))) diff --git a/eval/run.rkt b/eval/run.rkt index be76c5c..4bc833b 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -31,7 +31,7 @@ (flvector-set! profile-time profile-ptr time) (set-rival-machine-profile-ptr! machine (add1 profile-ptr)))) -(define (rival-machine-run machine) +(define (rival-machine-run machine [vhint #f]) (define ivec (rival-machine-instructions machine)) (define varc (vector-length (rival-machine-arguments machine))) (define precisions (rival-machine-precisions machine)) @@ -40,17 +40,35 @@ ; parameter for sampling histogram table (define first-iter? (zero? (rival-machine-iteration machine))) - (for ([instr (in-vector ivec)] - [n (in-naturals varc)] - [precision (in-vector precisions)] - [repeat (in-vector repeats)] - #:unless (and (not first-iter?) repeat)) - (define start (current-inexact-milliseconds)) - (parameterize ([bf-precision precision]) - (vector-set! vregs n (apply-instruction instr vregs))) - (define name (object-name (car instr))) - (define time (- (current-inexact-milliseconds) start)) - (rival-machine-record machine name n precision time))) + (if vhint + (for ([instr (in-vector ivec)] + [n (in-naturals varc)] + [precision (in-vector precisions)] + [repeat (in-vector repeats)] + [hint (in-vector vhint)] + #:unless (or (not hint) (and (not first-iter?) repeat))) + (define start (current-inexact-milliseconds)) + (parameterize ([bf-precision precision]) + (vector-set! vregs + n + (if (integer? hint) + (vector-ref vregs (list-ref instr hint)) + (apply-instruction instr vregs)))) + (define name (object-name (car instr))) + (define time (- (current-inexact-milliseconds) start)) + (rival-machine-record machine name n precision time)) + + (for ([instr (in-vector ivec)] + [n (in-naturals varc)] + [precision (in-vector precisions)] + [repeat (in-vector repeats)] + #:unless (and (not first-iter?) repeat)) + (define start (current-inexact-milliseconds)) + (parameterize ([bf-precision precision]) + (vector-set! vregs n (apply-instruction instr vregs))) + (define name (object-name (car instr))) + (define time (- (current-inexact-milliseconds) start)) + (rival-machine-record machine name n precision time)))) (define (apply-instruction instr regs) ;; By special-casing the 0-3 instruction case, From 33f556642be49a7eb0b0fd33ee5e0cf5bc75f343 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Tue, 17 Dec 2024 17:04:46 -0700 Subject: [PATCH 03/29] tests are done, one weird syntax issue exists --- eval/adjust.rkt | 38 +++++++++++++------------ eval/main.rkt | 73 +++++++++++++++++++++++++++++++++++++++++++++---- 2 files changed, 87 insertions(+), 24 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index ca42378..5e35b15 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -18,9 +18,13 @@ (define varc (vector-length args)) (define vhint (make-vector (vector-length ivec) #f)) - (define (vhint-set! idx x) + (define (vhint-set! idx val) (when (>= idx varc) - (vector-set! vhint (- idx varc) x))) + (vector-set! vhint (- idx varc) val))) + (define (vhint-ref idx) + (if (>= idx varc) + (vector-ref vhint (- idx varc)) + #f)) (for ([root-reg (in-vector rootvec)]) (vhint-set! root-reg #t)) @@ -35,16 +39,16 @@ (define cond-reg (vector-ref vregs cond)) (cond [(and (ival-lo cond-reg) (ival-hi cond-reg)) - (vhint-set! cond (or #f (vector-ref vhint (- cond varc)))) + (vhint-set! cond (or #f (vhint-ref cond))) (vhint-set! tru #t) - (vhint-set! fls (or #f (vector-ref vhint (- fls varc)))) + (vhint-set! fls (or #f (vhint-ref fls))) 2] [(not (or (ival-lo cond-reg) (ival-hi cond-reg))) - (vhint-set! cond (or #f (vector-ref vhint (- cond varc)))) - (vhint-set! tru (or #f (vector-ref vhint (- tru varc)))) + (vhint-set! cond (or #f (vhint-ref cond))) + (vhint-set! tru (or #f (vhint-ref tru))) (vhint-set! fls #t) 3] - [#t + [else (vhint-set! cond #t) (vhint-set! tru #t) (vhint-set! fls #t) @@ -54,14 +58,14 @@ (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) (cond [(and (ival-lo cmp) (ival-hi cmp)) - (vhint-set! arg2 (or #f (vector-ref vhint (- arg2 varc)))) + (vhint-set! arg2 (or #f (vhint-ref arg2))) (vhint-set! arg1 #t) 1] [(not (or (ival-lo cmp) (ival-hi cmp))) - (vhint-set! arg1 (or #f (vector-ref vhint (- arg1 varc)))) + (vhint-set! arg1 (or #f (vhint-ref arg1))) (vhint-set! arg2 #t) 2] - [#t + [else (vhint-set! arg1 #t) (vhint-set! arg2 #t) #t])] @@ -70,24 +74,22 @@ (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) (cond [(and (ival-lo cmp) (ival-hi cmp)) - (vhint-set! arg1 (or #f (vector-ref vhint (- arg1 varc)))) + (vhint-set! arg1 (or #f (vhint-ref arg1))) (vhint-set! arg2 #t) 2] [(not (or (ival-lo cmp) (ival-hi cmp))) - (vhint-set! arg2 (or #f (vector-ref vhint (- arg2 varc)))) + (vhint-set! arg2 (or #f (vhint-ref arg2))) (vhint-set! arg1 #t) 1] - [#t + [else (vhint-set! arg1 #t) (vhint-set! arg2 #t) #t])] [else (define srcs (rest instr)) - (map (λ (x) - (vhint-set! x) - #t) - srcs) + (map (λ (x) (vhint-set! x #t)) srcs) #t])) + (println "done") (vector-set! vhint n hint*)) vhint) @@ -217,4 +219,4 @@ ; Lower precision bound propogation (vector-set! vprecs-min (- x varc) - (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) + (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) \ No newline at end of file diff --git a/eval/main.rkt b/eval/main.rkt index 7a131fe..3ece430 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -91,15 +91,76 @@ "compile.rkt" "../utils.rkt" math/bigfloat) + (define number-of-random-hyperrects 100) + (define number-of-random-pts-per-rect 100) + (bf-precision 53) + + ; Check whether outputs are the same for the hint and without hint executions (define (rival-check-hint machine hint pt) (check-equal? (rival-apply machine pt hint) (rival-apply machine pt))) + (define (sample-hyperrect-within-bounds rect-lo rect-hi varc) + (for/vector ([_ (in-range varc)]) + (define xlo-range-length (bf- rect-hi rect-lo)) + (define xlo (bf+ (bf* (bfrandom) xlo-range-length) rect-lo)) + (define xhi-range-length (bf- rect-hi xlo)) + (define xhi (bf+ (bf* (bfrandom) xhi-range-length) xlo)) + (check-true (and (bf> rect-hi xhi) (bf> xlo rect-lo) (bf> xhi xlo)) + "Hyperrect is out of bounds") + (ival xlo xhi))) + + (define (sample-pts hyperrect) + (for/vector ([rect (in-vector hyperrect)]) + (define range-length (bf- (ival-hi rect) (ival-lo rect))) + (define pt (bf+ (bf* (bfrandom) range-length) (ival-lo rect))) + (check-true (and (bf> pt (ival-lo rect)) (bf< pt (ival-hi rect))) + "Sampled point is out of hyperrect range") + pt)) + + (define (hints-random-checks machine rect-lo rect-hi varc) + (define evaluated-instructions 0) + (define number-of-instructions-total + (* number-of-random-hyperrects (vector-length (rival-machine-instructions machine)))) + + (for ([n (in-range number-of-random-hyperrects)]) + (define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc)) + (define-values (res hint) (rival-analyze machine hyperrect)) + (set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint))) + + (for ([_ (in-range number-of-random-pts-per-rect)]) + (define pt (sample-pts hyperrect)) + (rival-check-hint machine hint pt))) + + (define skipped-instructions-by-hint (- number-of-instructions-total evaluated-instructions)) + (define skipped-percentage (* (/ skipped-instructions-by-hint number-of-instructions-total) 100)) + skipped-percentage) + (define discs (list boolean-discretization flonum-discretization)) (define vars '(x y)) - (define expr (list '(TRUE) '(fmax -5 (fmin x (fmax y (cos PI)))))) - (define machine (rival-compile expr vars discs)) + (define varc (length vars)) + + (define expr1 (list '(TRUE) '(fmax -5 (fmin x (fmax y (cos PI)))))) + (define machine1 (rival-compile expr1 vars discs)) + #;(define skipped-instr1 (hints-random-checks machine1 (bf -10) (bf 10) varc)) + #;(printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1)) + + (define expr2 + (list '(TRUE) + '(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2)))) + (fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2))))))) + (define machine2 (rival-compile expr2 vars discs)) + #;(define skipped-instr2 (hints-random-checks machine2 (bf -100) (bf 100) varc)) + #;(printf "Percentage of skipped instructions by hint in expr2 = ~a\n" (round skipped-instr2)) - (define-values (a hint) (rival-analyze machine (vector (ival (bf 0) (bf 1)) (ival (bf 4) (bf 5))))) - (rival-check-hint machine hint (vector (bf 1) (bf 5))) - (rival-check-hint machine hint (vector (bf 0) (bf 4))) - (rival-check-hint machine hint (vector (bf 0.5) (bf 4.5)))) + (define expr3 + (list '(TRUE) + '(if (> (exp x) (+ 10 (log y))) + (if (> (fmax (* x y) (+ x y)) 4) + (cos (fmax x y)) + (cos (fmin x y))) + (if (< (pow 2 x) (- (exp x) 10)) + (* PI x) + (fmax x (- (cos y) (+ 10 (log y)))))))) + (define machine3 (rival-compile expr3 vars discs)) + (define skipped-instr3 (hints-random-checks machine3 (bf -100) (bf 100) varc)) + (printf "Percentage of skipped instructions by hint in expr3 = ~a\n" (round skipped-instr3))) From 0c174dbd793d4e95717bda2b5bcbddf2807206f9 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Tue, 17 Dec 2024 17:56:51 -0700 Subject: [PATCH 04/29] debugging sesh, turned out that we care about error flags when doing min/max/if optimizations --- eval/adjust.rkt | 25 ++++++++++++------------- eval/main.rkt | 17 +++++++++++------ 2 files changed, 23 insertions(+), 19 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 5e35b15..c6e7435 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -37,18 +37,18 @@ [(ival-if) (match-define (list _ cond tru fls) instr) (define cond-reg (vector-ref vregs cond)) - (cond - [(and (ival-lo cond-reg) (ival-hi cond-reg)) + (match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg)) + [(#t #t #f) (vhint-set! cond (or #f (vhint-ref cond))) (vhint-set! tru #t) (vhint-set! fls (or #f (vhint-ref fls))) 2] - [(not (or (ival-lo cond-reg) (ival-hi cond-reg))) + [(#f #f #f) (vhint-set! cond (or #f (vhint-ref cond))) (vhint-set! tru (or #f (vhint-ref tru))) (vhint-set! fls #t) 3] - [else + [(_ _ _) (vhint-set! cond #t) (vhint-set! tru #t) (vhint-set! fls #t) @@ -56,32 +56,32 @@ [(ival-fmax) (match-define (list _ arg1 arg2) instr) (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) - (cond - [(and (ival-lo cmp) (ival-hi cmp)) + (match* ((ival-lo cmp) (ival-hi cmp)) + [(#t #t) (vhint-set! arg2 (or #f (vhint-ref arg2))) (vhint-set! arg1 #t) 1] - [(not (or (ival-lo cmp) (ival-hi cmp))) + [(#f #f) (vhint-set! arg1 (or #f (vhint-ref arg1))) (vhint-set! arg2 #t) 2] - [else + [(#f #t) (vhint-set! arg1 #t) (vhint-set! arg2 #t) #t])] [(ival-fmin) (match-define (list _ arg1 arg2) instr) (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) - (cond - [(and (ival-lo cmp) (ival-hi cmp)) + (match* ((ival-lo cmp) (ival-hi cmp)) + [(#t #t) (vhint-set! arg1 (or #f (vhint-ref arg1))) (vhint-set! arg2 #t) 2] - [(not (or (ival-lo cmp) (ival-hi cmp))) + [(#f #f) (vhint-set! arg2 (or #f (vhint-ref arg2))) (vhint-set! arg1 #t) 1] - [else + [(#f #t) (vhint-set! arg1 #t) (vhint-set! arg2 #t) #t])] @@ -89,7 +89,6 @@ (define srcs (rest instr)) (map (λ (x) (vhint-set! x #t)) srcs) #t])) - (println "done") (vector-set! vhint n hint*)) vhint) diff --git a/eval/main.rkt b/eval/main.rkt index 3ece430..20754db 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -97,7 +97,12 @@ ; Check whether outputs are the same for the hint and without hint executions (define (rival-check-hint machine hint pt) - (check-equal? (rival-apply machine pt hint) (rival-apply machine pt))) + (check-equal? (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt)) + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt hint)))) (define (sample-hyperrect-within-bounds rect-lo rect-hi varc) (for/vector ([_ (in-range varc)]) @@ -139,18 +144,18 @@ (define vars '(x y)) (define varc (length vars)) - (define expr1 (list '(TRUE) '(fmax -5 (fmin x (fmax y (cos PI)))))) + (define expr1 (list '(TRUE) '(fmax -5 (fmin (log x) (fmax y (cos PI)))))) (define machine1 (rival-compile expr1 vars discs)) - #;(define skipped-instr1 (hints-random-checks machine1 (bf -10) (bf 10) varc)) - #;(printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1)) + (define skipped-instr1 (hints-random-checks machine1 (bf -10) (bf 10) varc)) + (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1)) (define expr2 (list '(TRUE) '(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2)))) (fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2))))))) (define machine2 (rival-compile expr2 vars discs)) - #;(define skipped-instr2 (hints-random-checks machine2 (bf -100) (bf 100) varc)) - #;(printf "Percentage of skipped instructions by hint in expr2 = ~a\n" (round skipped-instr2)) + (define skipped-instr2 (hints-random-checks machine2 (bf -100) (bf 100) varc)) + (printf "Percentage of skipped instructions by hint in expr2 = ~a\n" (round skipped-instr2)) (define expr3 (list '(TRUE) From 007ce203cc9d1ff266c86be46677c3ebad6ec0e1 Mon Sep 17 00:00:00 2001 From: AYadrov <45910827+AYadrov@users.noreply.github.com> Date: Tue, 17 Dec 2024 18:00:17 -0700 Subject: [PATCH 05/29] Update tests.yml, added tests from eval/*.rkt --- .github/workflows/tests.yml | 1 + 1 file changed, 1 insertion(+) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 82490da..4e4a21c 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -22,3 +22,4 @@ jobs: - name: "Make sure files are correctly formatted with raco fmt" run: git diff --exit-code - run: raco test *.rkt + - run: raco test eval/*.rkt From 441caa7aab3c853756ab1dd9d4b1077d044ed3ce Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 18 Dec 2024 13:12:58 -0700 Subject: [PATCH 06/29] rival-machine-run looksmaxxing --- eval/main.rkt | 15 +++++++++++++-- eval/run.rkt | 47 ++++++++++++++++++----------------------------- 2 files changed, 31 insertions(+), 31 deletions(-) diff --git a/eval/main.rkt b/eval/main.rkt index 20754db..710baaf 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -104,6 +104,7 @@ [exn:rival:unsamplable? (λ (e) 'unsamplable)]) (rival-apply machine pt hint)))) + ; Random sampling hyperrects given a general range as [rect-lo, rect-hi] (define (sample-hyperrect-within-bounds rect-lo rect-hi varc) (for/vector ([_ (in-range varc)]) (define xlo-range-length (bf- rect-hi rect-lo)) @@ -114,6 +115,7 @@ "Hyperrect is out of bounds") (ival xlo xhi))) + ; Sample points with respect to the input hyperrect (define (sample-pts hyperrect) (for/vector ([rect (in-vector hyperrect)]) (define range-length (bf- (ival-hi rect) (ival-lo rect))) @@ -122,6 +124,8 @@ "Sampled point is out of hyperrect range") pt)) + ; Testing hint on an expression for 'number-of-random-hyperrects' hyperrects by + ; 'number-of-random-pts-per-rect' points each (define (hints-random-checks machine rect-lo rect-hi varc) (define evaluated-instructions 0) (define number-of-instructions-total @@ -144,9 +148,16 @@ (define vars '(x y)) (define varc (length vars)) - (define expr1 (list '(TRUE) '(fmax -5 (fmin (log x) (fmax y (cos PI)))))) + (define expr1 + (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) + '(+ (if (> (/ (log x) (log y)) (* (log x) (log y))) + (fmax (* (log x) (log y)) (+ (log x) (log y))) + (fmin (* (log x) (log y)) (+ (log x) (log y)))) + (if (> (+ (log x) (log y)) (* (log x) (log y))) + (fmax (/ (log x) (log y)) (- (log x) (log y))) + (fmin (/ (log x) (log y)) (- (log x) (log y))))))) (define machine1 (rival-compile expr1 vars discs)) - (define skipped-instr1 (hints-random-checks machine1 (bf -10) (bf 10) varc)) + (define skipped-instr1 (hints-random-checks machine1 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1)) (define expr2 diff --git a/eval/run.rkt b/eval/run.rkt index 4bc833b..92a840f 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -40,35 +40,24 @@ ; parameter for sampling histogram table (define first-iter? (zero? (rival-machine-iteration machine))) - (if vhint - (for ([instr (in-vector ivec)] - [n (in-naturals varc)] - [precision (in-vector precisions)] - [repeat (in-vector repeats)] - [hint (in-vector vhint)] - #:unless (or (not hint) (and (not first-iter?) repeat))) - (define start (current-inexact-milliseconds)) - (parameterize ([bf-precision precision]) - (vector-set! vregs - n - (if (integer? hint) - (vector-ref vregs (list-ref instr hint)) - (apply-instruction instr vregs)))) - (define name (object-name (car instr))) - (define time (- (current-inexact-milliseconds) start)) - (rival-machine-record machine name n precision time)) - - (for ([instr (in-vector ivec)] - [n (in-naturals varc)] - [precision (in-vector precisions)] - [repeat (in-vector repeats)] - #:unless (and (not first-iter?) repeat)) - (define start (current-inexact-milliseconds)) - (parameterize ([bf-precision precision]) - (vector-set! vregs n (apply-instruction instr vregs))) - (define name (object-name (car instr))) - (define time (- (current-inexact-milliseconds) start)) - (rival-machine-record machine name n precision time)))) + (for ([instr (in-vector ivec)] + [n (in-naturals varc)] + [precision (in-vector precisions)] + [repeat (in-vector repeats)] + [hint (if vhint + (in-vector vhint) + (in-producer (const #t)))] + #:unless (or (not hint) (and (not first-iter?) repeat))) + (define start (current-inexact-milliseconds)) + (parameterize ([bf-precision precision]) + (vector-set! vregs + n + (if (integer? hint) + (vector-ref vregs (list-ref instr hint)) + (apply-instruction instr vregs)))) + (define name (object-name (car instr))) + (define time (- (current-inexact-milliseconds) start)) + (rival-machine-record machine name n precision time))) (define (apply-instruction instr regs) ;; By special-casing the 0-3 instruction case, From 7037ccf285ec6490068e8b8e10764aab10e9646e Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 18 Dec 2024 14:08:50 -0700 Subject: [PATCH 07/29] some comments --- eval/adjust.rkt | 31 ++++++++++++++++++++----------- 1 file changed, 20 insertions(+), 11 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index c6e7435..dcb8a11 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -9,6 +9,15 @@ (provide backward-pass make-hint) +; Hint is a vector with len(ivec) elements which +; guides Rival on which instructions should not be executed +; for points from a particular hyperrect of input parameters. +; make-hint is called as a last step of rival-analyze and returns hint as a result +; Values of a hint: +; #f - instruction should not be executed +; #t - instruction should be executed +; integer n - instead of executing, refer to vregs with (list-ref instr n) index +; (the result is known and stored in another register) (define (make-hint machine) (define args (rival-machine-arguments machine)) (define ivec (rival-machine-instructions machine)) @@ -38,17 +47,17 @@ (match-define (list _ cond tru fls) instr) (define cond-reg (vector-ref vregs cond)) (match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg)) - [(#t #t #f) + [(#t #t #f) ; only true path should be executed (vhint-set! cond (or #f (vhint-ref cond))) (vhint-set! tru #t) (vhint-set! fls (or #f (vhint-ref fls))) 2] - [(#f #f #f) + [(#f #f #f) ; only false path should be executed (vhint-set! cond (or #f (vhint-ref cond))) (vhint-set! tru (or #f (vhint-ref tru))) (vhint-set! fls #t) 3] - [(_ _ _) + [(_ _ _) ; execute both paths and cond as well (vhint-set! cond #t) (vhint-set! tru #t) (vhint-set! fls #t) @@ -57,15 +66,15 @@ (match-define (list _ arg1 arg2) instr) (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) (match* ((ival-lo cmp) (ival-hi cmp)) - [(#t #t) + [(#t #t) ; only arg1 should be executed (vhint-set! arg2 (or #f (vhint-ref arg2))) (vhint-set! arg1 #t) 1] - [(#f #f) + [(#f #f) ; only arg2 should be executed (vhint-set! arg1 (or #f (vhint-ref arg1))) (vhint-set! arg2 #t) 2] - [(#f #t) + [(#f #t) ; both paths should be executed (vhint-set! arg1 #t) (vhint-set! arg2 #t) #t])] @@ -73,20 +82,20 @@ (match-define (list _ arg1 arg2) instr) (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) (match* ((ival-lo cmp) (ival-hi cmp)) - [(#t #t) + [(#t #t) ; only arg2 should be executed (vhint-set! arg1 (or #f (vhint-ref arg1))) (vhint-set! arg2 #t) 2] - [(#f #f) + [(#f #f) ; only arg1 should be executed (vhint-set! arg2 (or #f (vhint-ref arg2))) (vhint-set! arg1 #t) 1] - [(#f #t) + [(#f #t) ; both paths should be executed (vhint-set! arg1 #t) (vhint-set! arg2 #t) #t])] - [else - (define srcs (rest instr)) + [else ; at this point we are given that the current instruction should be executed + (define srcs (rest instr)) ; then, children instructions should be executed as well (map (λ (x) (vhint-set! x #t)) srcs) #t])) (vector-set! vhint n hint*)) From fd1275e0303be89fe8f7f07b87dddde0f4702131 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Fri, 20 Dec 2024 14:27:22 -0700 Subject: [PATCH 08/29] contract change for rival-analyze --- main.rkt | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/main.rkt b/main.rkt index a358c3d..8a29f14 100644 --- a/main.rkt +++ b/main.rkt @@ -90,10 +90,11 @@ (require "eval/main.rkt" (only-in "eval/machine.rkt" rival-machine?)) -(provide (contract-out [rival-compile - (-> (listof any/c) (listof symbol?) (listof discretization?) rival-machine?)] - [rival-apply (-> rival-machine? (vectorof value?) (vectorof any/c))] - [rival-analyze (-> rival-machine? (vectorof ival?) ival?)]) +(provide (contract-out + [rival-compile (-> (listof any/c) (listof symbol?) (listof discretization?) rival-machine?)] + [rival-apply (-> rival-machine? (vectorof value?) (vectorof any/c))] + [rival-analyze + (-> rival-machine? (vectorof ival?) (values ival? (vectorof (or/c boolean? number?))))]) (struct-out discretization) (struct-out exn:rival) (struct-out exn:rival:invalid) From 5ad41cd80e40d519cadb76b63247d06156f52bcb Mon Sep 17 00:00:00 2001 From: AYadrov Date: Mon, 23 Dec 2024 15:25:28 -0700 Subject: [PATCH 09/29] no backward pass when hint says so --- eval/adjust.rkt | 11 +++++++---- eval/main.rkt | 2 +- eval/run.rkt | 4 ++-- main.rkt | 8 +++++--- 4 files changed, 15 insertions(+), 10 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index dcb8a11..a6d6844 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -101,7 +101,7 @@ (vector-set! vhint n hint*)) vhint) -(define (backward-pass machine) +(define (backward-pass machine [vhint #f]) ; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3 (define args (rival-machine-arguments machine)) (define ivec (rival-machine-instructions machine)) @@ -143,7 +143,7 @@ (vector-set! vuseful (- arg varc) #t))])) ; Step 2. Precision tuning - (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful) + (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful vhint) ; Step 3. Repeating precisions check + Assigning if a operation should be computed again at all ; vrepeats[i] = #t if the node has the same precision as an iteration before and children have #t flag as well @@ -183,12 +183,15 @@ ; Roughly speaking, the upper precision bound is calculated as: ; vprecs-max[i] = (+ max-prec vstart-precs[i]), where min-prec < (+ max-prec vstart-precs[i]) < max-prec ; max-prec = (car (get-bounds parent)) -(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful) +(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful vhint) (define vprecs-min (make-vector (vector-length ivec) 0)) (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] [useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)] [n (in-range (- (vector-length vregs) 1) -1 -1)] - #:when useful?) + [hint (if vhint + (in-vector vhint) + (in-producer (const #t)))] + #:when (and vhint useful?)) (define op (car instr)) (define tail-registers (cdr instr)) (define srcs (map (lambda (x) (vector-ref vregs x)) tail-registers)) diff --git a/eval/main.rkt b/eval/main.rkt index 710baaf..d6aecfb 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -26,7 +26,7 @@ (define (rival-machine-full machine inputs [hint #f]) (set-rival-machine-iteration! machine (*sampling-iteration*)) - (rival-machine-adjust machine) + (rival-machine-adjust machine hint) (cond [(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)] [else diff --git a/eval/run.rkt b/eval/run.rkt index 92a840f..df4330b 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -101,10 +101,10 @@ lo)) (values good? (and good? done?) bad? stuck? fvec)) -(define (rival-machine-adjust machine) +(define (rival-machine-adjust machine [hint #f]) (define iter (rival-machine-iteration machine)) (let ([start (current-inexact-milliseconds)]) (if (zero? iter) (vector-fill! (rival-machine-precisions machine) (rival-machine-initial-precision machine)) - (backward-pass machine)) + (backward-pass machine hint)) (rival-machine-record machine 'adjust -1 (* iter 1000) (- (current-inexact-milliseconds) start)))) diff --git a/main.rkt b/main.rkt index 8a29f14..e20a26b 100644 --- a/main.rkt +++ b/main.rkt @@ -92,9 +92,11 @@ (only-in "eval/machine.rkt" rival-machine?)) (provide (contract-out [rival-compile (-> (listof any/c) (listof symbol?) (listof discretization?) rival-machine?)] - [rival-apply (-> rival-machine? (vectorof value?) (vectorof any/c))] - [rival-analyze - (-> rival-machine? (vectorof ival?) (values ival? (vectorof (or/c boolean? number?))))]) + [rival-apply + (->* (rival-machine? (vectorof value?)) + ((or/c (vectorof any/c) boolean?)) + (vectorof any/c))] + [rival-analyze (-> rival-machine? (vectorof ival?) (values ival? (vectorof any/c)))]) (struct-out discretization) (struct-out exn:rival) (struct-out exn:rival:invalid) From b26b23695fd111a1ba4159bd51741cc389d7c901 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 8 Jan 2025 16:11:36 -0700 Subject: [PATCH 10/29] removed skipping of backward pass due to some unknown yet bugs + converged flag for rival-analyze --- eval/adjust.rkt | 17 +++++++++-------- eval/main.rkt | 10 +++++----- eval/run.rkt | 4 ++-- 3 files changed, 16 insertions(+), 15 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index a6d6844..40447cc 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -26,6 +26,7 @@ (define varc (vector-length args)) (define vhint (make-vector (vector-length ivec) #f)) + (define converged? #t) (define (vhint-set! idx val) (when (>= idx varc) @@ -61,6 +62,7 @@ (vhint-set! cond #t) (vhint-set! tru #t) (vhint-set! fls #t) + (set! converged? #f) #t])] [(ival-fmax) (match-define (list _ arg1 arg2) instr) @@ -77,6 +79,7 @@ [(#f #t) ; both paths should be executed (vhint-set! arg1 #t) (vhint-set! arg2 #t) + (set! converged? #f) #t])] [(ival-fmin) (match-define (list _ arg1 arg2) instr) @@ -93,15 +96,16 @@ [(#f #t) ; both paths should be executed (vhint-set! arg1 #t) (vhint-set! arg2 #t) + (set! converged? #f) #t])] [else ; at this point we are given that the current instruction should be executed (define srcs (rest instr)) ; then, children instructions should be executed as well (map (λ (x) (vhint-set! x #t)) srcs) #t])) (vector-set! vhint n hint*)) - vhint) + (values vhint converged?)) -(define (backward-pass machine [vhint #f]) +(define (backward-pass machine) ; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3 (define args (rival-machine-arguments machine)) (define ivec (rival-machine-instructions machine)) @@ -143,7 +147,7 @@ (vector-set! vuseful (- arg varc) #t))])) ; Step 2. Precision tuning - (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful vhint) + (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful) ; Step 3. Repeating precisions check + Assigning if a operation should be computed again at all ; vrepeats[i] = #t if the node has the same precision as an iteration before and children have #t flag as well @@ -183,15 +187,12 @@ ; Roughly speaking, the upper precision bound is calculated as: ; vprecs-max[i] = (+ max-prec vstart-precs[i]), where min-prec < (+ max-prec vstart-precs[i]) < max-prec ; max-prec = (car (get-bounds parent)) -(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful vhint) +(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful) (define vprecs-min (make-vector (vector-length ivec) 0)) (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] [useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)] [n (in-range (- (vector-length vregs) 1) -1 -1)] - [hint (if vhint - (in-vector vhint) - (in-producer (const #t)))] - #:when (and vhint useful?)) + #:when useful?) (define op (car instr)) (define tail-registers (cdr instr)) (define srcs (map (lambda (x) (vector-ref vregs x)) tail-registers)) diff --git a/eval/main.rkt b/eval/main.rkt index d6aecfb..29a3f43 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -26,7 +26,7 @@ (define (rival-machine-full machine inputs [hint #f]) (set-rival-machine-iteration! machine (*sampling-iteration*)) - (rival-machine-adjust machine hint) + (rival-machine-adjust machine) (cond [(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)] [else @@ -42,7 +42,7 @@ (define (rival-profile machine param) (match param - ['instructions (vector-length (rival-machine-instructions machine))] + ['instructions (rival-machine-instructions machine)] ['iterations (rival-machine-iteration machine)] ['bumps (rival-machine-bumps machine)] ['executions @@ -83,8 +83,8 @@ (parameterize ([*sampling-iteration* 0] [ground-truth-require-convergence #f]) (rival-machine-full machine rect))) - (define hint (make-hint machine)) - (values (ival (or bad? stuck?) (not good?)) hint)) + (define-values (hint hint-converged?) (make-hint machine)) + (values (ival (or bad? stuck?) (not good?)) hint hint-converged?)) (module+ test (require rackunit @@ -133,7 +133,7 @@ (for ([n (in-range number-of-random-hyperrects)]) (define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc)) - (define-values (res hint) (rival-analyze machine hyperrect)) + (define-values (res hint _) (rival-analyze machine hyperrect)) (set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint))) (for ([_ (in-range number-of-random-pts-per-rect)]) diff --git a/eval/run.rkt b/eval/run.rkt index df4330b..92a840f 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -101,10 +101,10 @@ lo)) (values good? (and good? done?) bad? stuck? fvec)) -(define (rival-machine-adjust machine [hint #f]) +(define (rival-machine-adjust machine) (define iter (rival-machine-iteration machine)) (let ([start (current-inexact-milliseconds)]) (if (zero? iter) (vector-fill! (rival-machine-precisions machine) (rival-machine-initial-precision machine)) - (backward-pass machine hint)) + (backward-pass machine)) (rival-machine-record machine 'adjust -1 (* iter 1000) (- (current-inexact-milliseconds) start)))) From 1445846ab550d1888a781500c3a423ce4a2723b0 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 8 Jan 2025 16:53:53 -0700 Subject: [PATCH 11/29] the boog with backward pass + hint has been found --- eval/adjust.rkt | 15 +++++++++------ eval/main.rkt | 6 +++--- eval/run.rkt | 4 ++-- 3 files changed, 14 insertions(+), 11 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 40447cc..2e9320d 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -1,4 +1,4 @@ -#lang racket/base +#lang racket (require "tricks.rkt" "../ops/all.rkt" @@ -105,7 +105,7 @@ (vector-set! vhint n hint*)) (values vhint converged?)) -(define (backward-pass machine) +(define (backward-pass machine vhint) ; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3 (define args (rival-machine-arguments machine)) (define ivec (rival-machine-instructions machine)) @@ -147,7 +147,7 @@ (vector-set! vuseful (- arg varc) #t))])) ; Step 2. Precision tuning - (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful) + (precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful vhint) ; Step 3. Repeating precisions check + Assigning if a operation should be computed again at all ; vrepeats[i] = #t if the node has the same precision as an iteration before and children have #t flag as well @@ -187,12 +187,15 @@ ; Roughly speaking, the upper precision bound is calculated as: ; vprecs-max[i] = (+ max-prec vstart-precs[i]), where min-prec < (+ max-prec vstart-precs[i]) < max-prec ; max-prec = (car (get-bounds parent)) -(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful) +(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful vhint) (define vprecs-min (make-vector (vector-length ivec) 0)) (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] [useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)] [n (in-range (- (vector-length vregs) 1) -1 -1)] - #:when useful?) + [hint (if vhint + (in-vector vhint) + (in-producer (const #t)))] + #:when (and hint useful?)) (define op (car instr)) (define tail-registers (cdr instr)) (define srcs (map (lambda (x) (vector-ref vregs x)) tail-registers)) @@ -231,4 +234,4 @@ ; Lower precision bound propogation (vector-set! vprecs-min (- x varc) - (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) \ No newline at end of file + (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) diff --git a/eval/main.rkt b/eval/main.rkt index 29a3f43..a984ef4 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -24,14 +24,14 @@ (define ground-truth-require-convergence (make-parameter #t)) -(define (rival-machine-full machine inputs [hint #f]) +(define (rival-machine-full machine inputs [vhint #f]) (set-rival-machine-iteration! machine (*sampling-iteration*)) - (rival-machine-adjust machine) + (rival-machine-adjust machine vhint) (cond [(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)] [else (rival-machine-load machine inputs) - (rival-machine-run machine hint) + (rival-machine-run machine vhint) (rival-machine-return machine)])) (struct exn:rival exn:fail ()) diff --git a/eval/run.rkt b/eval/run.rkt index 92a840f..ad378a6 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -101,10 +101,10 @@ lo)) (values good? (and good? done?) bad? stuck? fvec)) -(define (rival-machine-adjust machine) +(define (rival-machine-adjust machine vhint) (define iter (rival-machine-iteration machine)) (let ([start (current-inexact-milliseconds)]) (if (zero? iter) (vector-fill! (rival-machine-precisions machine) (rival-machine-initial-precision machine)) - (backward-pass machine)) + (backward-pass machine vhint)) (rival-machine-record machine 'adjust -1 (* iter 1000) (- (current-inexact-milliseconds) start)))) From 936c72ff1c3a67b38e004a4161031a00e9ebaa08 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 8 Jan 2025 16:55:16 -0700 Subject: [PATCH 12/29] contract change for rival-analyze including completeness flag --- main.rkt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/main.rkt b/main.rkt index e20a26b..2ba5ef3 100644 --- a/main.rkt +++ b/main.rkt @@ -96,7 +96,8 @@ (->* (rival-machine? (vectorof value?)) ((or/c (vectorof any/c) boolean?)) (vectorof any/c))] - [rival-analyze (-> rival-machine? (vectorof ival?) (values ival? (vectorof any/c)))]) + [rival-analyze + (-> rival-machine? (vectorof ival?) (values ival? (vectorof any/c) boolean?))]) (struct-out discretization) (struct-out exn:rival) (struct-out exn:rival:invalid) From 5fe3f6831dcfd6e0e3950556cfe254feb0126d9c Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 8 Jan 2025 16:57:27 -0700 Subject: [PATCH 13/29] restoring rival-profile --- eval/main.rkt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/main.rkt b/eval/main.rkt index a984ef4..8a1492d 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -42,7 +42,7 @@ (define (rival-profile machine param) (match param - ['instructions (rival-machine-instructions machine)] + ['instructions (vector-length (rival-machine-instructions machine))] ['iterations (rival-machine-iteration machine)] ['bumps (rival-machine-bumps machine)] ['executions From 4b2080da55d8c59dcf6d4e31075edc3e7e973693 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 8 Jan 2025 17:18:33 -0700 Subject: [PATCH 14/29] quite obvious bug --- eval/adjust.rkt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 2e9320d..cf2b706 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -193,7 +193,7 @@ [useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)] [n (in-range (- (vector-length vregs) 1) -1 -1)] [hint (if vhint - (in-vector vhint) + (in-vector vhint (- (vector-length vhint) 1) -1 -1) (in-producer (const #t)))] #:when (and hint useful?)) (define op (car instr)) From 757f9be64af6f863974d420276b1c390e15738de Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 8 Jan 2025 17:54:22 -0700 Subject: [PATCH 15/29] change of contract + return value for rival-analyze (temporarily --- eval/main.rkt | 2 +- main.rkt | 18 ++++++++++-------- 2 files changed, 11 insertions(+), 9 deletions(-) diff --git a/eval/main.rkt b/eval/main.rkt index 8a1492d..d43caae 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -84,7 +84,7 @@ [ground-truth-require-convergence #f]) (rival-machine-full machine rect))) (define-values (hint hint-converged?) (make-hint machine)) - (values (ival (or bad? stuck?) (not good?)) hint hint-converged?)) + (list (ival (or bad? stuck?) (not good?)) hint hint-converged?)) (module+ test (require rackunit diff --git a/main.rkt b/main.rkt index 2ba5ef3..e13d2dd 100644 --- a/main.rkt +++ b/main.rkt @@ -90,14 +90,16 @@ (require "eval/main.rkt" (only-in "eval/machine.rkt" rival-machine?)) -(provide (contract-out - [rival-compile (-> (listof any/c) (listof symbol?) (listof discretization?) rival-machine?)] - [rival-apply - (->* (rival-machine? (vectorof value?)) - ((or/c (vectorof any/c) boolean?)) - (vectorof any/c))] - [rival-analyze - (-> rival-machine? (vectorof ival?) (values ival? (vectorof any/c) boolean?))]) +(provide (contract-out [rival-compile + (-> (listof any/c) (listof symbol?) (listof discretization?) rival-machine?)] + [rival-apply + (->* (rival-machine? (vectorof value?)) + ((or/c (vectorof any/c) boolean?)) + (vectorof any/c))] + [rival-analyze + (-> rival-machine? + (vectorof ival?) + (listof any/c))]) ; (values ival? (vectorof any/c) boolean?))]) (struct-out discretization) (struct-out exn:rival) (struct-out exn:rival:invalid) From 145d410b9d539732c476701ddf113ab3385a2b5d Mon Sep 17 00:00:00 2001 From: AYadrov Date: Wed, 8 Jan 2025 18:11:07 -0700 Subject: [PATCH 16/29] oops, unit tests fix --- eval/main.rkt | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/eval/main.rkt b/eval/main.rkt index d43caae..cbc2c58 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -133,7 +133,7 @@ (for ([n (in-range number-of-random-hyperrects)]) (define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc)) - (define-values (res hint _) (rival-analyze machine hyperrect)) + (match-define (list res hint _) (rival-analyze machine hyperrect)) (set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint))) (for ([_ (in-range number-of-random-pts-per-rect)]) From 646f2663ba3d37af76ffdbe4659a17a5dc50c65e Mon Sep 17 00:00:00 2001 From: AYadrov Date: Thu, 9 Jan 2025 14:37:08 -0700 Subject: [PATCH 17/29] hint is added as a field of rival-machine --- eval/compile.rkt | 2 ++ eval/machine.rkt | 1 + eval/main.rkt | 4 ++-- eval/run.rkt | 18 ++++++++---------- 4 files changed, 13 insertions(+), 12 deletions(-) diff --git a/eval/compile.rkt b/eval/compile.rkt index a7c5da5..e8e3135 100644 --- a/eval/compile.rkt +++ b/eval/compile.rkt @@ -278,6 +278,7 @@ (define incremental-precisions (setup-vstart-precs instructions (length vars) roots discs)) (define initial-precision (+ (argmax identity (map discretization-target discs)) (*base-tuning-precision*))) + (define hint (make-vector (vector-length instructions) #t)) (rival-machine (list->vector vars) instructions @@ -289,6 +290,7 @@ incremental-precisions (make-vector (vector-length roots)) initial-precision + hint 0 0 0 diff --git a/eval/machine.rkt b/eval/machine.rkt index a6c8c38..8d6e3b3 100644 --- a/eval/machine.rkt +++ b/eval/machine.rkt @@ -29,6 +29,7 @@ incremental-precisions output-distance initial-precision + hint [iteration #:mutable] [bumps #:mutable] [profile-ptr #:mutable] diff --git a/eval/main.rkt b/eval/main.rkt index cbc2c58..8b36652 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -24,7 +24,7 @@ (define ground-truth-require-convergence (make-parameter #t)) -(define (rival-machine-full machine inputs [vhint #f]) +(define (rival-machine-full machine inputs [vhint (rival-machine-hint machine)]) (set-rival-machine-iteration! machine (*sampling-iteration*)) (rival-machine-adjust machine vhint) (cond @@ -62,7 +62,7 @@ (define (ival-real x) (ival x)) -(define (rival-apply machine pt [hint #f]) +(define (rival-apply machine pt [hint (rival-machine-hint machine)]) (define discs (rival-machine-discs machine)) (set-rival-machine-bumps! machine 0) (let loop ([iter 0]) diff --git a/eval/run.rkt b/eval/run.rkt index ad378a6..db183b4 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -31,7 +31,7 @@ (flvector-set! profile-time profile-ptr time) (set-rival-machine-profile-ptr! machine (add1 profile-ptr)))) -(define (rival-machine-run machine [vhint #f]) +(define (rival-machine-run machine vhint) (define ivec (rival-machine-instructions machine)) (define varc (vector-length (rival-machine-arguments machine))) (define precisions (rival-machine-precisions machine)) @@ -44,17 +44,15 @@ [n (in-naturals varc)] [precision (in-vector precisions)] [repeat (in-vector repeats)] - [hint (if vhint - (in-vector vhint) - (in-producer (const #t)))] + [hint (in-vector vhint)] #:unless (or (not hint) (and (not first-iter?) repeat))) (define start (current-inexact-milliseconds)) - (parameterize ([bf-precision precision]) - (vector-set! vregs - n - (if (integer? hint) - (vector-ref vregs (list-ref instr hint)) - (apply-instruction instr vregs)))) + (define out + (parameterize ([bf-precision precision]) + (if (integer? hint) + (vector-ref vregs (list-ref instr hint)) + (apply-instruction instr vregs)))) + (vector-set! vregs n out) (define name (object-name (car instr))) (define time (- (current-inexact-milliseconds) start)) (rival-machine-record machine name n precision time))) From af8950dc7ca43509cc47c55ce49806ed271fd802 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Thu, 9 Jan 2025 15:54:23 -0700 Subject: [PATCH 18/29] added hint for assert function --- eval/adjust.rkt | 19 +++++++++++++++---- eval/main.rkt | 8 +++++++- eval/run.rkt | 7 ++++--- 3 files changed, 26 insertions(+), 8 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index cf2b706..a3006f9 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -44,6 +44,19 @@ #:when hint) (define hint* (case (object-name (car instr)) + [(ival-assert) + (match-define (list _ bool-idx) instr) + (define bool-reg (vector-ref vregs bool-idx)) + (match* ((ival-lo bool-reg) (ival-hi bool-reg) (ival-err? bool-reg)) + [(#t #t #f) ; assert and its children should not be reexecuted if it is true already + (vhint-set! bool-idx (or #f (vhint-ref bool-idx))) + (ival-bool #t)] + [(#f #f #f) ; assert and its children should not be reexecuted if it is false already + (vhint-set! bool-idx (or #f (vhint-ref bool-idx))) + (ival-bool #f)] + [(_ _ _) ; assert and its children should be reexecuted + (vhint-set! bool-idx #t) + #t])] [(ival-if) (match-define (list _ cond tru fls) instr) (define cond-reg (vector-ref vregs cond)) @@ -192,9 +205,7 @@ (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] [useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)] [n (in-range (- (vector-length vregs) 1) -1 -1)] - [hint (if vhint - (in-vector vhint (- (vector-length vhint) 1) -1 -1) - (in-producer (const #t)))] + [hint (in-vector vhint (- (vector-length vhint) 1) -1 -1)] #:when (and hint useful?)) (define op (car instr)) (define tail-registers (cdr instr)) @@ -234,4 +245,4 @@ ; Lower precision bound propogation (vector-set! vprecs-min (- x varc) - (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) + (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) \ No newline at end of file diff --git a/eval/main.rkt b/eval/main.rkt index 8b36652..6163f59 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -179,4 +179,10 @@ (fmax x (- (cos y) (+ 10 (log y)))))))) (define machine3 (rival-compile expr3 vars discs)) (define skipped-instr3 (hints-random-checks machine3 (bf -100) (bf 100) varc)) - (printf "Percentage of skipped instructions by hint in expr3 = ~a\n" (round skipped-instr3))) + (printf "Percentage of skipped instructions by hint in expr3 = ~a\n" (round skipped-instr3)) + + ; Test checks hint on assert where an error can be observed + (define expr4 (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) '(+ (cos x) (cos y)))) + (define machine4 (rival-compile expr4 vars discs)) + (define skipped-instr4 (hints-random-checks machine4 (bf -100) (bf 100) varc)) + (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr4))) diff --git a/eval/run.rkt b/eval/run.rkt index db183b4..0232f37 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -49,9 +49,10 @@ (define start (current-inexact-milliseconds)) (define out (parameterize ([bf-precision precision]) - (if (integer? hint) - (vector-ref vregs (list-ref instr hint)) - (apply-instruction instr vregs)))) + (match hint + [#t (apply-instruction instr vregs)] + [(? integer? x) (vector-ref vregs (list-ref instr x))] + [(? ival? x) x]))) (vector-set! vregs n out) (define name (object-name (car instr))) (define time (- (current-inexact-milliseconds) start)) From 9a1db6196d4e491c273a41260d99515aa7f1008b Mon Sep 17 00:00:00 2001 From: AYadrov Date: Thu, 9 Jan 2025 16:41:51 -0700 Subject: [PATCH 19/29] emperical evaluation of skipped instructions for unit tests in main.rkt --- eval/adjust.rkt | 1 + eval/main.rkt | 28 +++++++++++++++++++--------- 2 files changed, 20 insertions(+), 9 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index a3006f9..9b9033f 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -18,6 +18,7 @@ ; #t - instruction should be executed ; integer n - instead of executing, refer to vregs with (list-ref instr n) index ; (the result is known and stored in another register) +; ival - instead of executing, just copy ival as a result of the instruction (define (make-hint machine) (define args (rival-machine-arguments machine)) (define ivec (rival-machine-instructions machine)) diff --git a/eval/main.rkt b/eval/main.rkt index 6163f59..9643a41 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -97,12 +97,19 @@ ; Check whether outputs are the same for the hint and without hint executions (define (rival-check-hint machine hint pt) - (check-equal? (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] - [exn:rival:unsamplable? (λ (e) 'unsamplable)]) - (rival-apply machine pt)) - (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] - [exn:rival:unsamplable? (λ (e) 'unsamplable)]) - (rival-apply machine pt hint)))) + (define no-hint-result + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt))) + (define no-hint-instr-count (vector-length (rival-profile machine 'executions))) + + (define hint-result + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt hint))) + (define hint-instr-count (vector-length (rival-profile machine 'executions))) + (check-equal? hint-result no-hint-result) + (values no-hint-instr-count hint-instr-count)) ; Random sampling hyperrects given a general range as [rect-lo, rect-hi] (define (sample-hyperrect-within-bounds rect-lo rect-hi varc) @@ -131,6 +138,8 @@ (define number-of-instructions-total (* number-of-random-hyperrects (vector-length (rival-machine-instructions machine)))) + (define hint-cnt 0) + (define no-hint-cnt 0) (for ([n (in-range number-of-random-hyperrects)]) (define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc)) (match-define (list res hint _) (rival-analyze machine hyperrect)) @@ -138,10 +147,11 @@ (for ([_ (in-range number-of-random-pts-per-rect)]) (define pt (sample-pts hyperrect)) - (rival-check-hint machine hint pt))) + (define-values (no-hint-cnt* hint-cnt*) (rival-check-hint machine hint pt)) + (set! hint-cnt (+ hint-cnt hint-cnt*)) + (set! no-hint-cnt (+ no-hint-cnt no-hint-cnt*)))) - (define skipped-instructions-by-hint (- number-of-instructions-total evaluated-instructions)) - (define skipped-percentage (* (/ skipped-instructions-by-hint number-of-instructions-total) 100)) + (define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100)) skipped-percentage) (define discs (list boolean-discretization flonum-discretization)) From 9976e646daeb1da9f854a6f31978e679990c1880 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Thu, 9 Jan 2025 16:52:26 -0700 Subject: [PATCH 20/29] unit tests update for rivals hint to verify automatically that the instructions got skipped --- eval/adjust.rkt | 2 +- eval/main.rkt | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 9b9033f..5fd8d99 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -1,4 +1,4 @@ -#lang racket +#lang racket/base (require "tricks.rkt" "../ops/all.rkt" diff --git a/eval/main.rkt b/eval/main.rkt index 9643a41..9bbc6f3 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -169,6 +169,7 @@ (define machine1 (rival-compile expr1 vars discs)) (define skipped-instr1 (hints-random-checks machine1 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1)) + (check-true (< skipped-instr1 100)) (define expr2 (list '(TRUE) @@ -177,6 +178,7 @@ (define machine2 (rival-compile expr2 vars discs)) (define skipped-instr2 (hints-random-checks machine2 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr2 = ~a\n" (round skipped-instr2)) + (check-true (< skipped-instr2 100)) (define expr3 (list '(TRUE) @@ -190,9 +192,11 @@ (define machine3 (rival-compile expr3 vars discs)) (define skipped-instr3 (hints-random-checks machine3 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr3 = ~a\n" (round skipped-instr3)) + (check-true (< skipped-instr3 100)) ; Test checks hint on assert where an error can be observed (define expr4 (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) '(+ (cos x) (cos y)))) (define machine4 (rival-compile expr4 vars discs)) (define skipped-instr4 (hints-random-checks machine4 (bf -100) (bf 100) varc)) - (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr4))) + (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr4)) + (check-true (< skipped-instr4 100))) From b2af4c88bfca1be92d6659516343ac0070b14dfb Mon Sep 17 00:00:00 2001 From: AYadrov Date: Thu, 9 Jan 2025 16:55:07 -0700 Subject: [PATCH 21/29] less allocation in run.rkt --- eval/run.rkt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/eval/run.rkt b/eval/run.rkt index 0232f37..6f4b22b 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -51,8 +51,8 @@ (parameterize ([bf-precision precision]) (match hint [#t (apply-instruction instr vregs)] - [(? integer? x) (vector-ref vregs (list-ref instr x))] - [(? ival? x) x]))) + [(? integer? _) (vector-ref vregs (list-ref instr hint))] + [(? ival? _) hint]))) (vector-set! vregs n out) (define name (object-name (car instr))) (define time (- (current-inexact-milliseconds) start)) From 3a8b855810ae414f9885cef88e5ae49765e088fc Mon Sep 17 00:00:00 2001 From: AYadrov Date: Thu, 9 Jan 2025 18:18:32 -0700 Subject: [PATCH 22/29] introduced some threshold value for unit tests that rival should pass --- eval/main.rkt | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/eval/main.rkt b/eval/main.rkt index 9bbc6f3..fbc708a 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -93,6 +93,7 @@ math/bigfloat) (define number-of-random-hyperrects 100) (define number-of-random-pts-per-rect 100) + (define threshold 95) ; at least 5% of instructions should be skipped by hint to pass the tests! (bf-precision 53) ; Check whether outputs are the same for the hint and without hint executions @@ -108,6 +109,7 @@ [exn:rival:unsamplable? (λ (e) 'unsamplable)]) (rival-apply machine pt hint))) (define hint-instr-count (vector-length (rival-profile machine 'executions))) + (check-equal? hint-result no-hint-result) (values no-hint-instr-count hint-instr-count)) @@ -150,7 +152,6 @@ (define-values (no-hint-cnt* hint-cnt*) (rival-check-hint machine hint pt)) (set! hint-cnt (+ hint-cnt hint-cnt*)) (set! no-hint-cnt (+ no-hint-cnt no-hint-cnt*)))) - (define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100)) skipped-percentage) @@ -169,7 +170,7 @@ (define machine1 (rival-compile expr1 vars discs)) (define skipped-instr1 (hints-random-checks machine1 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1)) - (check-true (< skipped-instr1 100)) + (check-true (< skipped-instr1 threshold)) (define expr2 (list '(TRUE) @@ -178,7 +179,7 @@ (define machine2 (rival-compile expr2 vars discs)) (define skipped-instr2 (hints-random-checks machine2 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr2 = ~a\n" (round skipped-instr2)) - (check-true (< skipped-instr2 100)) + (check-true (< skipped-instr2 threshold)) (define expr3 (list '(TRUE) @@ -192,11 +193,11 @@ (define machine3 (rival-compile expr3 vars discs)) (define skipped-instr3 (hints-random-checks machine3 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr3 = ~a\n" (round skipped-instr3)) - (check-true (< skipped-instr3 100)) + (check-true (< skipped-instr3 threshold)) ; Test checks hint on assert where an error can be observed (define expr4 (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) '(+ (cos x) (cos y)))) (define machine4 (rival-compile expr4 vars discs)) (define skipped-instr4 (hints-random-checks machine4 (bf -100) (bf 100) varc)) (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr4)) - (check-true (< skipped-instr4 100))) + (check-true (< skipped-instr4 threshold))) From 93ef101edb0c150c77aabf4b42ce3ec9b3cf8646 Mon Sep 17 00:00:00 2001 From: AYadrov <45910827+AYadrov@users.noreply.github.com> Date: Sun, 19 Jan 2025 16:37:18 -0700 Subject: [PATCH 23/29] Suggested update at test.yml Co-authored-by: Pavel Panchekha --- .github/workflows/tests.yml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4e4a21c..55df959 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -21,5 +21,4 @@ jobs: run: raco fmt -i **/*.rkt - name: "Make sure files are correctly formatted with raco fmt" run: git diff --exit-code - - run: raco test *.rkt - - run: raco test eval/*.rkt + - run: raco test . From d389cdd050e8939c4eb6d353ed8e8af32a922025 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Sun, 19 Jan 2025 19:07:02 -0700 Subject: [PATCH 24/29] PR suggestions and fixes --- eval/adjust.rkt | 39 ++++++++------------ eval/main.rkt | 94 ++++++++++++++++++++++--------------------------- 2 files changed, 58 insertions(+), 75 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 5fd8d99..0e0c9dd 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -49,27 +49,22 @@ (match-define (list _ bool-idx) instr) (define bool-reg (vector-ref vregs bool-idx)) (match* ((ival-lo bool-reg) (ival-hi bool-reg) (ival-err? bool-reg)) - [(#t #t #f) ; assert and its children should not be reexecuted if it is true already - (vhint-set! bool-idx (or #f (vhint-ref bool-idx))) - (ival-bool #t)] - [(#f #f #f) ; assert and its children should not be reexecuted if it is false already - (vhint-set! bool-idx (or #f (vhint-ref bool-idx))) - (ival-bool #f)] + ; assert and its children should not be reexecuted if it is true already + [(#t #t #f) (ival-bool #t)] + ; assert and its children should not be reexecuted if it is false already + [(#f #f #f) (ival-bool #f)] [(_ _ _) ; assert and its children should be reexecuted (vhint-set! bool-idx #t) + (set! converged? #f) #t])] [(ival-if) (match-define (list _ cond tru fls) instr) (define cond-reg (vector-ref vregs cond)) (match* ((ival-lo cond-reg) (ival-hi cond-reg) (ival-err? cond-reg)) [(#t #t #f) ; only true path should be executed - (vhint-set! cond (or #f (vhint-ref cond))) (vhint-set! tru #t) - (vhint-set! fls (or #f (vhint-ref fls))) 2] [(#f #f #f) ; only false path should be executed - (vhint-set! cond (or #f (vhint-ref cond))) - (vhint-set! tru (or #f (vhint-ref tru))) (vhint-set! fls #t) 3] [(_ _ _) ; execute both paths and cond as well @@ -81,16 +76,14 @@ [(ival-fmax) (match-define (list _ arg1 arg2) instr) (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) - (match* ((ival-lo cmp) (ival-hi cmp)) - [(#t #t) ; only arg1 should be executed - (vhint-set! arg2 (or #f (vhint-ref arg2))) + (match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp)) + [(#t #t #f) ; only arg1 should be executed (vhint-set! arg1 #t) 1] - [(#f #f) ; only arg2 should be executed - (vhint-set! arg1 (or #f (vhint-ref arg1))) + [(#f #f #f) ; only arg2 should be executed (vhint-set! arg2 #t) 2] - [(#f #t) ; both paths should be executed + [(_ _ _) ; both paths should be executed (vhint-set! arg1 #t) (vhint-set! arg2 #t) (set! converged? #f) @@ -98,23 +91,21 @@ [(ival-fmin) (match-define (list _ arg1 arg2) instr) (define cmp (ival-> (vector-ref vregs arg1) (vector-ref vregs arg2))) - (match* ((ival-lo cmp) (ival-hi cmp)) - [(#t #t) ; only arg2 should be executed - (vhint-set! arg1 (or #f (vhint-ref arg1))) + (match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp)) + [(#t #t #f) ; only arg2 should be executed (vhint-set! arg2 #t) 2] - [(#f #f) ; only arg1 should be executed - (vhint-set! arg2 (or #f (vhint-ref arg2))) + [(#f #f #f) ; only arg1 should be executed (vhint-set! arg1 #t) 1] - [(#f #t) ; both paths should be executed + [(_ _ _) ; both paths should be executed (vhint-set! arg1 #t) (vhint-set! arg2 #t) (set! converged? #f) #t])] [else ; at this point we are given that the current instruction should be executed (define srcs (rest instr)) ; then, children instructions should be executed as well - (map (λ (x) (vhint-set! x #t)) srcs) + (for-each (λ (x) (vhint-set! x #t)) srcs) #t])) (vector-set! vhint n hint*)) (values vhint converged?)) @@ -246,4 +237,4 @@ ; Lower precision bound propogation (vector-set! vprecs-min (- x varc) - (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) \ No newline at end of file + (max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound))))))) diff --git a/eval/main.rkt b/eval/main.rkt index fbc708a..2cd353a 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -93,23 +93,23 @@ math/bigfloat) (define number-of-random-hyperrects 100) (define number-of-random-pts-per-rect 100) - (define threshold 95) ; at least 5% of instructions should be skipped by hint to pass the tests! + (define rect (list (bf -100) (bf 100))) (bf-precision 53) ; Check whether outputs are the same for the hint and without hint executions (define (rival-check-hint machine hint pt) - (define no-hint-result - (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] - [exn:rival:unsamplable? (λ (e) 'unsamplable)]) - (rival-apply machine pt))) - (define no-hint-instr-count (vector-length (rival-profile machine 'executions))) - (define hint-result (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] [exn:rival:unsamplable? (λ (e) 'unsamplable)]) (rival-apply machine pt hint))) (define hint-instr-count (vector-length (rival-profile machine 'executions))) + (define no-hint-result + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt))) + (define no-hint-instr-count (vector-length (rival-profile machine 'executions))) + (check-equal? hint-result no-hint-result) (values no-hint-instr-count hint-instr-count)) @@ -155,49 +155,41 @@ (define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100)) skipped-percentage) - (define discs (list boolean-discretization flonum-discretization)) - (define vars '(x y)) - (define varc (length vars)) - - (define expr1 - (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) - '(+ (if (> (/ (log x) (log y)) (* (log x) (log y))) - (fmax (* (log x) (log y)) (+ (log x) (log y))) - (fmin (* (log x) (log y)) (+ (log x) (log y)))) - (if (> (+ (log x) (log y)) (* (log x) (log y))) - (fmax (/ (log x) (log y)) (- (log x) (log y))) - (fmin (/ (log x) (log y)) (- (log x) (log y))))))) - (define machine1 (rival-compile expr1 vars discs)) - (define skipped-instr1 (hints-random-checks machine1 (bf -100) (bf 100) varc)) - (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr1)) - (check-true (< skipped-instr1 threshold)) - - (define expr2 - (list '(TRUE) - '(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2)))) - (fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2))))))) - (define machine2 (rival-compile expr2 vars discs)) - (define skipped-instr2 (hints-random-checks machine2 (bf -100) (bf 100) varc)) - (printf "Percentage of skipped instructions by hint in expr2 = ~a\n" (round skipped-instr2)) - (check-true (< skipped-instr2 threshold)) - - (define expr3 - (list '(TRUE) - '(if (> (exp x) (+ 10 (log y))) - (if (> (fmax (* x y) (+ x y)) 4) - (cos (fmax x y)) - (cos (fmin x y))) - (if (< (pow 2 x) (- (exp x) 10)) - (* PI x) - (fmax x (- (cos y) (+ 10 (log y)))))))) - (define machine3 (rival-compile expr3 vars discs)) - (define skipped-instr3 (hints-random-checks machine3 (bf -100) (bf 100) varc)) - (printf "Percentage of skipped instructions by hint in expr3 = ~a\n" (round skipped-instr3)) - (check-true (< skipped-instr3 threshold)) + (define (expressions2d-check expressions) + (define discs (list boolean-discretization flonum-discretization)) + (define vars '(x y)) + (define varc (length vars)) + (define machine (rival-compile expressions vars discs)) + (define skipped-instr + (parameterize ([bf-precision 63]) + (hints-random-checks machine (first rect) (second rect) varc))) + (printf "Percentage of skipped instructions by hint in expr = ~a\n" (round skipped-instr))) + + (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) + '(+ (if (> (/ (log x) (log y)) (* (log x) (log y))) + (fmax (* (log x) (log y)) (+ (log x) (log y))) + (fmin (* (log x) (log y)) (+ (log x) (log y)))) + (if (> (+ (log x) (log y)) (* (log x) (log y))) + (fmax (/ (log x) (log y)) (- (log x) (log y))) + (fmin (/ (log x) (log y)) (- (log x) (log y))))))) + + (expressions2d-check + (list '(TRUE) + '(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2)))) + (fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2))))))) + + (expressions2d-check (list '(TRUE) + '(if (> (exp x) (+ 10 (log y))) + (if (> (fmax (* x y) (+ x y)) 4) + (cos (fmax x y)) + (cos (fmin x y))) + (if (< (pow 2 x) (- (exp x) 10)) + (* PI x) + (fmax x (- (cos y) (+ 10 (log y)))))))) ; Test checks hint on assert where an error can be observed - (define expr4 (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) '(+ (cos x) (cos y)))) - (define machine4 (rival-compile expr4 vars discs)) - (define skipped-instr4 (hints-random-checks machine4 (bf -100) (bf 100) varc)) - (printf "Percentage of skipped instructions by hint in expr1 = ~a\n" (round skipped-instr4)) - (check-true (< skipped-instr4 threshold))) + (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) + '(+ (cos x) (cos y)))) + + ; Test checks hint on fmax where an error can be observed + (expressions2d-check (list '(TRUE) '(fmax (log x) y)))) From 8d2a9551852a1fd5522c044af0903443aabfa028 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Sun, 19 Jan 2025 19:33:13 -0700 Subject: [PATCH 25/29] added hint support for or, and, not operations --- eval/adjust.rkt | 18 ++++++++++++++---- eval/main.rkt | 14 ++++++++++---- eval/run.rkt | 11 ++++++----- main.rkt | 5 +---- 4 files changed, 31 insertions(+), 17 deletions(-) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 0e0c9dd..379718b 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -29,14 +29,12 @@ (define vhint (make-vector (vector-length ivec) #f)) (define converged? #t) + ; helper function (define (vhint-set! idx val) (when (>= idx varc) (vector-set! vhint (- idx varc) val))) - (define (vhint-ref idx) - (if (>= idx varc) - (vector-ref vhint (- idx varc)) - #f)) + ; roots always should be executed (for ([root-reg (in-vector rootvec)]) (vhint-set! root-reg #t)) (for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)] @@ -103,6 +101,18 @@ (vhint-set! arg2 #t) (set! converged? #f) #t])] + [(ival-< ival-<= ival-> ival->= ival-== ival-!= ival-and ival-or ival-not) + (define cmp (vector-ref vregs (+ varc n))) + (match* ((ival-lo cmp) (ival-hi cmp) (ival-err? cmp)) + ; result is known + [(#t #t #f) (ival-bool #t)] + ; result is known + [(#f #f #f) (ival-bool #f)] + [(_ _ _) ; all the paths should be executed + (define srcs (rest instr)) + (for-each (λ (x) (vhint-set! x #t)) srcs) + #t])] + [else ; at this point we are given that the current instruction should be executed (define srcs (rest instr)) ; then, children instructions should be executed as well (for-each (λ (x) (vhint-set! x #t)) srcs) diff --git a/eval/main.rkt b/eval/main.rkt index 2cd353a..3c2f5a9 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -160,9 +160,7 @@ (define vars '(x y)) (define varc (length vars)) (define machine (rival-compile expressions vars discs)) - (define skipped-instr - (parameterize ([bf-precision 63]) - (hints-random-checks machine (first rect) (second rect) varc))) + (define skipped-instr (hints-random-checks machine (first rect) (second rect) varc)) (printf "Percentage of skipped instructions by hint in expr = ~a\n" (round skipped-instr))) (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) @@ -192,4 +190,12 @@ '(+ (cos x) (cos y)))) ; Test checks hint on fmax where an error can be observed - (expressions2d-check (list '(TRUE) '(fmax (log x) y)))) + (expressions2d-check (list '(TRUE) '(fmax (log x) y))) + + ; Test checks hints on comparison operators + (expressions2d-check + (list '(and (and (> (log x) y) (or (== (exp x) (exp y)) (> (cos x) (cos y)))) (<= (log y) (log x))) + '(if (or (or (< (log x) y) (and (!= (exp x) (exp y)) (< (cos x) (cos y)))) + (>= (log y) (log x))) + x + y)))) diff --git a/eval/run.rkt b/eval/run.rkt index 6f4b22b..5b2cd2b 100644 --- a/eval/run.rkt +++ b/eval/run.rkt @@ -48,11 +48,12 @@ #:unless (or (not hint) (and (not first-iter?) repeat))) (define start (current-inexact-milliseconds)) (define out - (parameterize ([bf-precision precision]) - (match hint - [#t (apply-instruction instr vregs)] - [(? integer? _) (vector-ref vregs (list-ref instr hint))] - [(? ival? _) hint]))) + (match hint + [#t + (parameterize ([bf-precision precision]) + (apply-instruction instr vregs))] + [(? integer? _) (vector-ref vregs (list-ref instr hint))] + [(? ival? _) hint])) (vector-set! vregs n out) (define name (object-name (car instr))) (define time (- (current-inexact-milliseconds) start)) diff --git a/main.rkt b/main.rkt index e13d2dd..2a52248 100644 --- a/main.rkt +++ b/main.rkt @@ -96,10 +96,7 @@ (->* (rival-machine? (vectorof value?)) ((or/c (vectorof any/c) boolean?)) (vectorof any/c))] - [rival-analyze - (-> rival-machine? - (vectorof ival?) - (listof any/c))]) ; (values ival? (vectorof any/c) boolean?))]) + [rival-analyze (-> rival-machine? (vectorof ival?) (listof any/c))]) (struct-out discretization) (struct-out exn:rival) (struct-out exn:rival:invalid) From f2dd25478bbeb63c0de5995a19732859633ed53d Mon Sep 17 00:00:00 2001 From: AYadrov Date: Sun, 19 Jan 2025 19:34:37 -0700 Subject: [PATCH 26/29] missed converged flag --- eval/adjust.rkt | 1 + 1 file changed, 1 insertion(+) diff --git a/eval/adjust.rkt b/eval/adjust.rkt index 379718b..bc97a0c 100644 --- a/eval/adjust.rkt +++ b/eval/adjust.rkt @@ -111,6 +111,7 @@ [(_ _ _) ; all the paths should be executed (define srcs (rest instr)) (for-each (λ (x) (vhint-set! x #t)) srcs) + (set! converged? #f) #t])] [else ; at this point we are given that the current instruction should be executed From 4a7bb19474bda2c1a7f57bf343a2307b6f4e72a1 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Sun, 19 Jan 2025 19:40:17 -0700 Subject: [PATCH 27/29] a bug that was bothering me. Wrong length of some vectors --- eval/compile.rkt | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/eval/compile.rkt b/eval/compile.rkt index e8e3135..018f470 100644 --- a/eval/compile.rkt +++ b/eval/compile.rkt @@ -270,15 +270,16 @@ ([node (in-vector nodes num-vars)]) (fn->ival-fn node))) - (define register-count (+ (length vars) (vector-length instructions))) + (define ivec-length (vector-length instructions)) + (define register-count (+ (length vars) ivec-length)) (define registers (make-vector register-count)) - (define repeats (make-vector register-count #f)) ; flags whether an op should be evaluated - (define precisions (make-vector register-count)) ; vector that stores working precisions + (define repeats (make-vector ivec-length #f)) ; flags whether an op should be evaluated + (define precisions (make-vector ivec-length)) ; vector that stores working precisions ;; vector for adjusting precisions (define incremental-precisions (setup-vstart-precs instructions (length vars) roots discs)) (define initial-precision (+ (argmax identity (map discretization-target discs)) (*base-tuning-precision*))) - (define hint (make-vector (vector-length instructions) #t)) + (define hint (make-vector ivec-length #t)) (rival-machine (list->vector vars) instructions From ec2b709edb68d6e780200ce266fc338ad4aea518 Mon Sep 17 00:00:00 2001 From: AYadrov Date: Sun, 19 Jan 2025 19:53:35 -0700 Subject: [PATCH 28/29] moved tests to a separate file --- eval/compile.rkt | 38 ------------ eval/main.rkt | 114 ---------------------------------- eval/tests.rkt | 155 +++++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 155 insertions(+), 152 deletions(-) create mode 100644 eval/tests.rkt diff --git a/eval/compile.rkt b/eval/compile.rkt index 018f470..3f6a445 100644 --- a/eval/compile.rkt +++ b/eval/compile.rkt @@ -328,41 +328,3 @@ idx-prec ; We wanna make sure that we do not tune a precision down (+ current-prec (*ampl-tuning-bits*)))))) vstart-precs) - -(module+ test - (require rackunit - "../utils.rkt") - ; This function is needed to unwrap constant procedure which fails tests otherwise - ; (const (ival 0.bf 0.bf)) != (const (ival 0.bf 0.bf)) - (define (drop-ival-const instrs) - (for/vector ([instr (in-vector instrs)]) - (match instr - [`(,const) (const)] - [_ instr]))) - - (define discs (list flonum-discretization)) - (define vars '(x y z)) - - (define (check-rival-optimization expr target-expr) - (define optimized-instrs - (drop-ival-const (parameterize ([*rival-use-shorthands* #t]) - (rival-machine-instructions (rival-compile (list expr) vars discs))))) - (define target-instrs - (drop-ival-const (parameterize ([*rival-use-shorthands* #f]) - (rival-machine-instructions (rival-compile (list target-expr) vars discs))))) - (check-equal? optimized-instrs target-instrs)) - - (check-rival-optimization `(* (log (exp x)) y) `(* x y)) - (check-rival-optimization `(* (exp (log x)) y) `(* (then (assert (> x 0)) x) y)) - (check-rival-optimization `(fma x y z) `(+ (* x y) z)) - (check-rival-optimization `(- (exp x) 1) `(expm1 x)) - (check-rival-optimization `(- 1 (exp x)) `(neg (expm1 x))) - (check-rival-optimization `(log (+ 1 x)) `(log1p x)) - (check-rival-optimization `(log (+ x 1)) `(log1p x)) - (check-rival-optimization `(sqrt (+ (* x x) (* y y))) `(hypot x y)) - (check-rival-optimization `(sqrt (+ (* x x) 1)) `(hypot x 1)) - (check-rival-optimization `(sqrt (+ 1 (* x x))) `(hypot 1 x)) - (check-rival-optimization `(pow x 2) `(pow2 x)) - (check-rival-optimization `(pow x 1/3) `(cbrt x)) - (check-rival-optimization `(pow x 1/2) `(sqrt x)) - (check-rival-optimization `(pow 2 x) `(exp2 x))) diff --git a/eval/main.rkt b/eval/main.rkt index 3c2f5a9..d00753f 100644 --- a/eval/main.rkt +++ b/eval/main.rkt @@ -85,117 +85,3 @@ (rival-machine-full machine rect))) (define-values (hint hint-converged?) (make-hint machine)) (list (ival (or bad? stuck?) (not good?)) hint hint-converged?)) - -(module+ test - (require rackunit - "compile.rkt" - "../utils.rkt" - math/bigfloat) - (define number-of-random-hyperrects 100) - (define number-of-random-pts-per-rect 100) - (define rect (list (bf -100) (bf 100))) - (bf-precision 53) - - ; Check whether outputs are the same for the hint and without hint executions - (define (rival-check-hint machine hint pt) - (define hint-result - (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] - [exn:rival:unsamplable? (λ (e) 'unsamplable)]) - (rival-apply machine pt hint))) - (define hint-instr-count (vector-length (rival-profile machine 'executions))) - - (define no-hint-result - (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] - [exn:rival:unsamplable? (λ (e) 'unsamplable)]) - (rival-apply machine pt))) - (define no-hint-instr-count (vector-length (rival-profile machine 'executions))) - - (check-equal? hint-result no-hint-result) - (values no-hint-instr-count hint-instr-count)) - - ; Random sampling hyperrects given a general range as [rect-lo, rect-hi] - (define (sample-hyperrect-within-bounds rect-lo rect-hi varc) - (for/vector ([_ (in-range varc)]) - (define xlo-range-length (bf- rect-hi rect-lo)) - (define xlo (bf+ (bf* (bfrandom) xlo-range-length) rect-lo)) - (define xhi-range-length (bf- rect-hi xlo)) - (define xhi (bf+ (bf* (bfrandom) xhi-range-length) xlo)) - (check-true (and (bf> rect-hi xhi) (bf> xlo rect-lo) (bf> xhi xlo)) - "Hyperrect is out of bounds") - (ival xlo xhi))) - - ; Sample points with respect to the input hyperrect - (define (sample-pts hyperrect) - (for/vector ([rect (in-vector hyperrect)]) - (define range-length (bf- (ival-hi rect) (ival-lo rect))) - (define pt (bf+ (bf* (bfrandom) range-length) (ival-lo rect))) - (check-true (and (bf> pt (ival-lo rect)) (bf< pt (ival-hi rect))) - "Sampled point is out of hyperrect range") - pt)) - - ; Testing hint on an expression for 'number-of-random-hyperrects' hyperrects by - ; 'number-of-random-pts-per-rect' points each - (define (hints-random-checks machine rect-lo rect-hi varc) - (define evaluated-instructions 0) - (define number-of-instructions-total - (* number-of-random-hyperrects (vector-length (rival-machine-instructions machine)))) - - (define hint-cnt 0) - (define no-hint-cnt 0) - (for ([n (in-range number-of-random-hyperrects)]) - (define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc)) - (match-define (list res hint _) (rival-analyze machine hyperrect)) - (set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint))) - - (for ([_ (in-range number-of-random-pts-per-rect)]) - (define pt (sample-pts hyperrect)) - (define-values (no-hint-cnt* hint-cnt*) (rival-check-hint machine hint pt)) - (set! hint-cnt (+ hint-cnt hint-cnt*)) - (set! no-hint-cnt (+ no-hint-cnt no-hint-cnt*)))) - (define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100)) - skipped-percentage) - - (define (expressions2d-check expressions) - (define discs (list boolean-discretization flonum-discretization)) - (define vars '(x y)) - (define varc (length vars)) - (define machine (rival-compile expressions vars discs)) - (define skipped-instr (hints-random-checks machine (first rect) (second rect) varc)) - (printf "Percentage of skipped instructions by hint in expr = ~a\n" (round skipped-instr))) - - (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) - '(+ (if (> (/ (log x) (log y)) (* (log x) (log y))) - (fmax (* (log x) (log y)) (+ (log x) (log y))) - (fmin (* (log x) (log y)) (+ (log x) (log y)))) - (if (> (+ (log x) (log y)) (* (log x) (log y))) - (fmax (/ (log x) (log y)) (- (log x) (log y))) - (fmin (/ (log x) (log y)) (- (log x) (log y))))))) - - (expressions2d-check - (list '(TRUE) - '(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2)))) - (fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2))))))) - - (expressions2d-check (list '(TRUE) - '(if (> (exp x) (+ 10 (log y))) - (if (> (fmax (* x y) (+ x y)) 4) - (cos (fmax x y)) - (cos (fmin x y))) - (if (< (pow 2 x) (- (exp x) 10)) - (* PI x) - (fmax x (- (cos y) (+ 10 (log y)))))))) - - ; Test checks hint on assert where an error can be observed - (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) - '(+ (cos x) (cos y)))) - - ; Test checks hint on fmax where an error can be observed - (expressions2d-check (list '(TRUE) '(fmax (log x) y))) - - ; Test checks hints on comparison operators - (expressions2d-check - (list '(and (and (> (log x) y) (or (== (exp x) (exp y)) (> (cos x) (cos y)))) (<= (log y) (log x))) - '(if (or (or (< (log x) y) (and (!= (exp x) (exp y)) (< (cos x) (cos y)))) - (>= (log y) (log x))) - x - y)))) diff --git a/eval/tests.rkt b/eval/tests.rkt new file mode 100644 index 0000000..afd6307 --- /dev/null +++ b/eval/tests.rkt @@ -0,0 +1,155 @@ +#lang racket + +; Test compile optimizations +(module+ test + (require rackunit + "machine.rkt" + "../main.rkt") + ; This function is needed to unwrap constant procedure which fails tests otherwise + ; (const (ival 0.bf 0.bf)) != (const (ival 0.bf 0.bf)) + (define (drop-ival-const instrs) + (for/vector ([instr (in-vector instrs)]) + (match instr + [`(,const) (const)] + [_ instr]))) + + (define discs (list flonum-discretization)) + (define vars '(x y z)) + + (define (check-rival-optimization expr target-expr) + (define optimized-instrs + (drop-ival-const (parameterize ([*rival-use-shorthands* #t]) + (rival-machine-instructions (rival-compile (list expr) vars discs))))) + (define target-instrs + (drop-ival-const (parameterize ([*rival-use-shorthands* #f]) + (rival-machine-instructions (rival-compile (list target-expr) vars discs))))) + (check-equal? optimized-instrs target-instrs)) + + (check-rival-optimization `(* (log (exp x)) y) `(* x y)) + (check-rival-optimization `(* (exp (log x)) y) `(* (then (assert (> x 0)) x) y)) + (check-rival-optimization `(fma x y z) `(+ (* x y) z)) + (check-rival-optimization `(- (exp x) 1) `(expm1 x)) + (check-rival-optimization `(- 1 (exp x)) `(neg (expm1 x))) + (check-rival-optimization `(log (+ 1 x)) `(log1p x)) + (check-rival-optimization `(log (+ x 1)) `(log1p x)) + (check-rival-optimization `(sqrt (+ (* x x) (* y y))) `(hypot x y)) + (check-rival-optimization `(sqrt (+ (* x x) 1)) `(hypot x 1)) + (check-rival-optimization `(sqrt (+ 1 (* x x))) `(hypot 1 x)) + (check-rival-optimization `(pow x 2) `(pow2 x)) + (check-rival-optimization `(pow x 1/3) `(cbrt x)) + (check-rival-optimization `(pow x 1/2) `(sqrt x)) + (check-rival-optimization `(pow 2 x) `(exp2 x))) + +; Check hints from rival-analyze +(module+ test + (require rackunit + math/bigfloat + "../main.rkt") + (define number-of-random-hyperrects 100) + (define number-of-random-pts-per-rect 100) + (define rect (list (bf -100) (bf 100))) + (bf-precision 53) + + ; Check whether outputs are the same for the hint and without hint executions + (define (rival-check-hint machine hint pt) + (define hint-result + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt hint))) + (define hint-instr-count (vector-length (rival-profile machine 'executions))) + + (define no-hint-result + (with-handlers ([exn:rival:invalid? (λ (e) 'invalid)] + [exn:rival:unsamplable? (λ (e) 'unsamplable)]) + (rival-apply machine pt))) + (define no-hint-instr-count (vector-length (rival-profile machine 'executions))) + + (check-equal? hint-result no-hint-result) + (values no-hint-instr-count hint-instr-count)) + + ; Random sampling hyperrects given a general range as [rect-lo, rect-hi] + (define (sample-hyperrect-within-bounds rect-lo rect-hi varc) + (for/vector ([_ (in-range varc)]) + (define xlo-range-length (bf- rect-hi rect-lo)) + (define xlo (bf+ (bf* (bfrandom) xlo-range-length) rect-lo)) + (define xhi-range-length (bf- rect-hi xlo)) + (define xhi (bf+ (bf* (bfrandom) xhi-range-length) xlo)) + (check-true (and (bf> rect-hi xhi) (bf> xlo rect-lo) (bf> xhi xlo)) + "Hyperrect is out of bounds") + (ival xlo xhi))) + + ; Sample points with respect to the input hyperrect + (define (sample-pts hyperrect) + (for/vector ([rect (in-vector hyperrect)]) + (define range-length (bf- (ival-hi rect) (ival-lo rect))) + (define pt (bf+ (bf* (bfrandom) range-length) (ival-lo rect))) + (check-true (and (bf> pt (ival-lo rect)) (bf< pt (ival-hi rect))) + "Sampled point is out of hyperrect range") + pt)) + + ; Testing hint on an expression for 'number-of-random-hyperrects' hyperrects by + ; 'number-of-random-pts-per-rect' points each + (define (hints-random-checks machine rect-lo rect-hi varc) + (define evaluated-instructions 0) + (define number-of-instructions-total + (* number-of-random-hyperrects (vector-length (rival-machine-instructions machine)))) + + (define hint-cnt 0) + (define no-hint-cnt 0) + (for ([n (in-range number-of-random-hyperrects)]) + (define hyperrect (sample-hyperrect-within-bounds rect-lo rect-hi varc)) + (match-define (list res hint _) (rival-analyze machine hyperrect)) + (set! evaluated-instructions (+ evaluated-instructions (vector-count false? hint))) + + (for ([_ (in-range number-of-random-pts-per-rect)]) + (define pt (sample-pts hyperrect)) + (define-values (no-hint-cnt* hint-cnt*) (rival-check-hint machine hint pt)) + (set! hint-cnt (+ hint-cnt hint-cnt*)) + (set! no-hint-cnt (+ no-hint-cnt no-hint-cnt*)))) + (define skipped-percentage (* (/ hint-cnt no-hint-cnt) 100)) + skipped-percentage) + + (define (expressions2d-check expressions) + (define discs (list boolean-discretization flonum-discretization)) + (define vars '(x y)) + (define varc (length vars)) + (define machine (rival-compile expressions vars discs)) + (define skipped-instr (hints-random-checks machine (first rect) (second rect) varc)) + (printf "Percentage of skipped instructions by hint in expr = ~a\n" (round skipped-instr))) + + (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) + '(+ (if (> (/ (log x) (log y)) (* (log x) (log y))) + (fmax (* (log x) (log y)) (+ (log x) (log y))) + (fmin (* (log x) (log y)) (+ (log x) (log y)))) + (if (> (+ (log x) (log y)) (* (log x) (log y))) + (fmax (/ (log x) (log y)) (- (log x) (log y))) + (fmin (/ (log x) (log y)) (- (log x) (log y))))))) + + (expressions2d-check + (list '(TRUE) + '(fmax (fmin (fmax (* x y) (+ x y)) (+ (fmax x (* 2 y)) (fmin y (* x 2)))) + (fmax (fmin (* x y) (+ x y)) (+ (fmin x (* 2 y)) (fmax y (* x 2))))))) + + (expressions2d-check (list '(TRUE) + '(if (> (exp x) (+ 10 (log y))) + (if (> (fmax (* x y) (+ x y)) 4) + (cos (fmax x y)) + (cos (fmin x y))) + (if (< (pow 2 x) (- (exp x) 10)) + (* PI x) + (fmax x (- (cos y) (+ 10 (log y)))))))) + + ; Test checks hint on assert where an error can be observed + (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) + '(+ (cos x) (cos y)))) + + ; Test checks hint on fmax where an error can be observed + (expressions2d-check (list '(TRUE) '(fmax (log x) y))) + + ; Test checks hints on comparison operators + (expressions2d-check + (list '(and (and (> (log x) y) (or (== (exp x) (exp y)) (> (cos x) (cos y)))) (<= (log y) (log x))) + '(if (or (or (< (log x) y) (and (!= (exp x) (exp y)) (< (cos x) (cos y)))) + (>= (log y) (log x))) + x + y)))) \ No newline at end of file From 0fd937c3e1f56fd1b67e84da2f59a71acc2bbd9d Mon Sep 17 00:00:00 2001 From: AYadrov Date: Sun, 19 Jan 2025 19:58:53 -0700 Subject: [PATCH 29/29] made tests silent --- eval/tests.rkt | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/eval/tests.rkt b/eval/tests.rkt index afd6307..38d5b1d 100644 --- a/eval/tests.rkt +++ b/eval/tests.rkt @@ -115,7 +115,8 @@ (define varc (length vars)) (define machine (rival-compile expressions vars discs)) (define skipped-instr (hints-random-checks machine (first rect) (second rect) varc)) - (printf "Percentage of skipped instructions by hint in expr = ~a\n" (round skipped-instr))) + (check-true (< skipped-instr 99) + (format "Almost no instructions got skipped by hint at ~a" expressions))) (expressions2d-check (list '(assert (> (+ (log x) (log y)) (- (log x) (log y)))) '(+ (if (> (/ (log x) (log y)) (* (log x) (log y)))