-
Notifications
You must be signed in to change notification settings - Fork 0
/
zx.egg
127 lines (95 loc) · 3.66 KB
/
zx.egg
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
; Dim: a diagram dimension (wire count) -- a sum of two dims, a symbolic
; name, or an integer literal.
(datatype Dim (Plus Dim Dim) (NamedDim String) (Lit i64))
; ZX: diagram terms.
;   Cast zx n m  -- zx relabelled with input dim n and output dim m
;   Stack a b    -- a placed above b; input and output dims add
;   Compose a b  -- a followed by b; dims are only derived when (out a) = (in b)
;   Val s n m    -- opaque diagram labelled s with input n and output m
;   Z n m phase  -- Z spider with n inputs, m outputs and a phase (also a Dim)
;   nWire n      -- n parallel identity wires (input n, output n)
(datatype ZX (Cast ZX Dim Dim) (Stack ZX ZX) (Compose ZX ZX) (Val String Dim Dim) (Z Dim Dim Dim) (nWire Dim))
; Proof: a non-empty list of string labels. Not referenced elsewhere in this file.
(datatype Proof (Cons String Proof) (Base String))
; Input / output dimension of a diagram; values are supplied by the
; `in`/`out` rewrites in this file.
(function in (ZX) Dim)
(function out (ZX) Dim)
; TODO: Restrict growing rewrites to not chain by checking right side
; Demand rules: for every ZX node of any shape, create its `in` and `out`
; terms so that the dimension rewrites have something to rewrite.
(rule ((= diag (Cast body din dout)))
      ((in diag) (out diag)))
(rule ((= diag (Stack upper lower)))
      ((in diag) (out diag)))
(rule ((= diag (Compose f g)))
      ((in diag) (out diag)))
(rule ((= diag (Val label n m)))
      ((in diag) (out diag)))
(rule ((= diag (Z n m phase)))
      ((in diag) (out diag)))
(rule ((= diag (nWire width)))
      ((in diag) (out diag)))
; Constant folding: a sum of two literal dims folds to a single literal,
; e.g. (Plus (Lit 1) (Lit 2)) => (Lit 3). This lets the cast-elimination
; rule fire when both sides of a cast fold to the same literal.
(rewrite (Plus (Lit a) (Lit b)) (Lit (+ a b)))
; Val, Z and Cast carry their input/output dims as explicit arguments:
; read them off directly.
(rewrite (in (Val s n m)) n)
(rewrite (out (Val s n m)) m)
(rewrite (in (Z n m a)) n)
(rewrite (out (Z n m a)) m)
(rewrite (in (Cast zx n m)) n)
(rewrite (out (Cast zx n m)) m)
; A bundle of n wires has n inputs and n outputs.
(rewrite (in (nWire n)) n)
(rewrite (out (nWire n)) n)
; Empty: the diagram with zero inputs and zero outputs.
(let zero (Lit 0))
(let Empty (Val "Empty" zero zero))
; Composing with Empty is the identity, provided the facing side of the
; other diagram really has dimension zero.
(rewrite (Compose Empty zx) zx
    :when ((= (in zx) zero)))
(rewrite (Compose zx Empty) zx
    :when ((= (out zx) zero)))
; A zero-width wire bundle is Empty.
(rewrite (nWire zero) Empty)
; Empty stacked above a diagram is rewritten straight to the diagram.
; Empty stacked below gives dims `n + 0`, so the diagram is wrapped in a
; Cast carrying exactly those dims; the cast becomes an identity once the
; sums fold (e.g. when the dims are literals).
; NOTE(review): the asymmetry presumably mirrors a formalization where
; 0 + n reduces to n but n + 0 does not -- confirm.
(rewrite (Stack Empty zx) zx)
(rewrite (Stack zx Empty)
    (Cast zx (Plus (in zx) zero) (Plus (out zx) zero)))
; Wire: a single 1 -> 1 wire, identified with a one-wide bundle.
(let one (Lit 1))
(let Wire (Val "Wire" one one))
(rewrite Wire (nWire one))
; Stacking two bundles gives one bundle whose width is the sum.
(rewrite (Stack (nWire p) (nWire q)) (nWire (Plus p q)))
; A bundle whose width matches the facing side is an identity for Compose.
(rewrite (Compose zx (nWire w)) zx
    :when ((= w (out zx))))
(rewrite (Compose (nWire w) zx) zx
    :when ((= w (in zx))))
; Stack: input and output dims add component-wise.
(rewrite (in (Stack upper lower)) (Plus (in upper) (in lower)))
(rewrite (out (Stack upper lower)) (Plus (out upper) (out lower)))
; Compose: inputs of the first, outputs of the second -- derived only when
; the composition lines up (the inner dims agree).
(rewrite (in (Compose f g)) (in f)
    :when ((= (out f) (in g))))
(rewrite (out (Compose f g)) (out g)
    :when ((= (out f) (in g))))
; Compose associativity, in one direction only: right-nested chains are
; rewritten to left-nested ones.
(rewrite (Compose p (Compose q r)) (Compose (Compose p q) r))
; Interchange law: a composite of two stacks equals the stack of the
; layer-wise composites, when each layer's dims line up.
(rewrite (Compose (Stack u1 v1) (Stack u2 v2))
    (Stack (Compose u1 u2) (Compose v1 v2))
    :when ((= (out u1) (in u2))
           (= (out v1) (in v2))))
; Stack associativity, both directions. The re-associated stack brackets
; its dims differently ((p+q)+r vs p+(q+r)), so it is wrapped in a Cast
; that keeps the original term's bracketing.
(rewrite (Stack p (Stack q r))
    (Cast (Stack (Stack p q) r)
        (Plus (in p) (Plus (in q) (in r)))
        (Plus (out p) (Plus (out q) (out r)))))
(rewrite (Stack (Stack p q) r)
    (Cast (Stack p (Stack q r))
        (Plus (Plus (in p) (in q)) (in r))
        (Plus (Plus (out p) (out q)) (out r))))
; Spider Fusion, encoding the statement (↕ = Stack, ⟷ = Compose):
;   (n_wire top ↕ Z input (S mid + bot) α)
;     ⟷ cast (top + ((S mid) + bot)) _ prfn prfm (Z (top + (S mid)) output β ↕ n_wire bot)
;   ∝ Z (top + input) (output + bot) (β + α)
; `S mid` is encoded as (Plus one mid); i/o are input/output; a/b are α/β.
; The first spider's outputs must be exactly (1 + mid) + bot, with the SAME
; `bot` that sits under the second spider and in the Cast's input dim --
; that shared variable is what forces the two layers to line up.
(rewrite
    (Compose
        (Stack
            (nWire top)
            (Z i (Plus (Plus one mid) bot) a))
        (Cast
            (Stack
                (Z (Plus top (Plus one mid)) o b)
                (nWire bot))
            (Plus top (Plus (Plus one mid) bot))
            (Plus o bot)))
    ; Phases are summed as (Plus a b), i.e. α + β.
    (Z (Plus top i) (Plus o bot) (Plus a b)))
; cast_id: a Cast whose declared dims already equal the inner diagram's
; own dims is the identity. Together with constant folding of literal
; sums, this removes casts whose dims become equal after folding.
(rewrite (Cast body din dout) body
    :when ((= din (in body))
           (= dout (out body))))
; Cap (2 -> 0) and Cup (0 -> 2).
(let two (Lit 2))
(let Cap (Val "Cap" two zero))
(let Cup (Val "Cup" zero two))
; Yank: a wire followed by a cup/cap "snake" straightens to a single wire.
(rewrite (Compose (Stack Wire Cup) (Stack Cap Wire)) Wire)
; Opaque 1 -> 2 diagram labelled "Z", used by the tests.
(let Z12 (Val "Z" one two))
; Test 1: Z12 beside a wire, versus a cup/cap yank on the second wire
; interleaved with Z12 on the first; expected to be equated via the yank
; rule plus the stack/compose rules.
(let sww (Stack Z12 Wire))
(let hidden_yank
    (Compose
        (Compose
            (Stack (nWire two) Cup)
            (Stack Z12 (nWire (Plus one two))))
        (Stack (nWire two) (Stack Cap Wire))))
; Test 2: two differently-built 2 -> 2 diagrams that the rules should
; identify (touches the Cast, Empty, wire and interchange rules).
(let lhs
    (Compose
        (Stack Wire (Stack (Compose Wire Wire) (Compose Cup Cap)))
        (nWire two)))
(let rhs
    (Compose
        (Stack (Stack (Cast Wire (Plus zero one) (Plus one zero)) Wire) Cup)
        (Stack Wire (Stack (Stack (nWire one) Empty) Cap))))
; 1-1 spider fusion; its check is currently disabled below.
(let one_one_fusion_lhs (Compose (Z one one zero) (Z one one zero)))
(let one_one_fusion_rhs (Z one one zero))
(run 20)
(check (= hidden_yank sww))
(check (= lhs rhs))
; (check (= one_one_fusion_lhs one_one_fusion_rhs))