Skip to content

Commit

Permalink
the boog with backward pass + hint has been found
Browse files Browse the repository at this point in the history
  • Loading branch information
AYadrov committed Jan 8, 2025
1 parent b26b236 commit 1445846
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 11 deletions.
15 changes: 9 additions & 6 deletions eval/adjust.rkt
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#lang racket/base
#lang racket

(require "tricks.rkt"
"../ops/all.rkt"
Expand Down Expand Up @@ -105,7 +105,7 @@
(vector-set! vhint n hint*))
(values vhint converged?))

(define (backward-pass machine)
(define (backward-pass machine vhint)
; Since Step 2 writes into *sampling-iteration* if the max prec was reached - save the iter number for step 3
(define args (rival-machine-arguments machine))
(define ivec (rival-machine-instructions machine))
Expand Down Expand Up @@ -147,7 +147,7 @@
(vector-set! vuseful (- arg varc) #t))]))

; Step 2. Precision tuning
(precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful)
(precision-tuning ivec vregs vprecs-new varc vstart-precs vuseful vhint)

; Step 3. Repeating precisions check + Assigning if a operation should be computed again at all
; vrepeats[i] = #t if the node has the same precision as an iteration before and children have #t flag as well
Expand Down Expand Up @@ -187,12 +187,15 @@
; Roughly speaking, the upper precision bound is calculated as:
; vprecs-max[i] = (+ max-prec vstart-precs[i]), where min-prec < (+ max-prec vstart-precs[i]) < max-prec
; max-prec = (car (get-bounds parent))
(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful)
(define (precision-tuning ivec vregs vprecs-max varc vstart-precs vuseful vhint)
(define vprecs-min (make-vector (vector-length ivec) 0))
(for ([instr (in-vector ivec (- (vector-length ivec) 1) -1 -1)]
[useful? (in-vector vuseful (- (vector-length vuseful) 1) -1 -1)]
[n (in-range (- (vector-length vregs) 1) -1 -1)]
#:when useful?)
[hint (if vhint
(in-vector vhint)
(in-producer (const #t)))]
#:when (and hint useful?))
(define op (car instr))
(define tail-registers (cdr instr))
(define srcs (map (lambda (x) (vector-ref vregs x)) tail-registers))
Expand Down Expand Up @@ -231,4 +234,4 @@
; Lower precision bound propogation
(vector-set! vprecs-min
(- x varc)
(max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound)))))))
(max (vector-ref vprecs-min (- x varc)) (+ min-prec (max 0 lo-bound)))))))
6 changes: 3 additions & 3 deletions eval/main.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@

(define ground-truth-require-convergence (make-parameter #t))

(define (rival-machine-full machine inputs [hint #f])
(define (rival-machine-full machine inputs [vhint #f])
(set-rival-machine-iteration! machine (*sampling-iteration*))
(rival-machine-adjust machine)
(rival-machine-adjust machine vhint)
(cond
[(>= (*sampling-iteration*) (*rival-max-iterations*)) (values #f #f #f #t #f)]
[else
(rival-machine-load machine inputs)
(rival-machine-run machine hint)
(rival-machine-run machine vhint)
(rival-machine-return machine)]))

(struct exn:rival exn:fail ())
Expand Down
4 changes: 2 additions & 2 deletions eval/run.rkt
Original file line number Diff line number Diff line change
Expand Up @@ -101,10 +101,10 @@
lo))
(values good? (and good? done?) bad? stuck? fvec))

(define (rival-machine-adjust machine)
(define (rival-machine-adjust machine vhint)
(define iter (rival-machine-iteration machine))
(let ([start (current-inexact-milliseconds)])
(if (zero? iter)
(vector-fill! (rival-machine-precisions machine) (rival-machine-initial-precision machine))
(backward-pass machine))
(backward-pass machine vhint))
(rival-machine-record machine 'adjust -1 (* iter 1000) (- (current-inexact-milliseconds) start))))

0 comments on commit 1445846

Please sign in to comment.