From bb77d9f1dc5ff185ac0234ad11214d71d32454c9 Mon Sep 17 00:00:00 2001 From: Ryan McCleeary Date: Tue, 9 Aug 2022 19:54:01 +0300 Subject: [PATCH 1/3] Enhance TypeCheck to reason more about exponents. Add (x >= 2, x^a * x^b = x^c => a + b = c) to Numeric.hs. Add (Nat a, a * a^x = a^y => 1 + x = y) to SimpType.hs Add (x >= 2, x^a >= x^b => a >= b) to Numeric.hs Start adding tests for exponent TC checks. --- src/Cryptol/TypeCheck/SimpType.hs | 5 +- src/Cryptol/TypeCheck/Solver/Numeric.hs | 27 ++++++++++ tests/issues/issue1489/issue1489.cry | 54 ++++++++++++++++++++ tests/issues/issue1489/issue1489.icry | 1 + tests/issues/issue1489/issue1489.icry.stdout | 3 ++ 5 files changed, 89 insertions(+), 1 deletion(-) create mode 100644 tests/issues/issue1489/issue1489.cry create mode 100644 tests/issues/issue1489/issue1489.icry create mode 100644 tests/issues/issue1489/issue1489.icry.stdout diff --git a/src/Cryptol/TypeCheck/SimpType.hs b/src/Cryptol/TypeCheck/SimpType.hs index af3317f62..971da0703 100644 --- a/src/Cryptol/TypeCheck/SimpType.hs +++ b/src/Cryptol/TypeCheck/SimpType.hs @@ -154,7 +154,10 @@ tMul x y , Just b' <- tIsNum b -- XXX: similar for a = b * k? , n == b' = tSub a (tMod a b) - + -- c * c ^ x = c ^ (1 + x) + | TCon (TF TCExp) [a,b] <- t' + , Just n' <- tIsNum a + , n == n' = tf2 TCExp a (tAdd (tNum (1::Int)) b) | otherwise = tf2 TCMul (tNum n) t where t' = tNoUser t diff --git a/src/Cryptol/TypeCheck/Solver/Numeric.hs b/src/Cryptol/TypeCheck/Solver/Numeric.hs index 60123183a..dc7e11c8d 100644 --- a/src/Cryptol/TypeCheck/Solver/Numeric.hs +++ b/src/Cryptol/TypeCheck/Solver/Numeric.hs @@ -48,6 +48,7 @@ cryIsEqual ctxt t1 t2 = <|> tryCancelVar ctxt (=#=) t1 t2 <|> tryLinearSolution t1 t2 <|> tryLinearSolution t2 t1 + <|> tryEqExp t1 t2 -- | Try to solve @t1 /= t2@ cryIsNotEqual :: Ctxt -> Type -> Type -> Solved @@ -67,6 +68,7 @@ cryIsGeq i t1 t2 = <|> tryAddConst (>==) t1 t2 <|> tryCancelVar i (>==) t1 t2 <|> tryMinIsGeq t1 t2 + <|> tryGeqExp i t1 t2 -- XXX: k >= width e -- XXX: width e >= k @@ -137,6 +139,17 @@ tryGeqThanK _ t (Nat k) = -- XXX: K1 ^^ n >= K2 +-- (x >= 2 && x^a >= x^b) => a >= b +tryGeqExp :: Ctxt -> Type -> Type -> Match Solved +tryGeqExp _ x y = + do (x_1, a) <- (|^|) x + (x_2, b) <- (|^|) y + guard (x_1 == x_2) + n <- aNat x_1 + guard (n >= 2) + return $ SolvedIf [ a >== b ] + + tryGeqThanSub :: Ctxt -> Type -> Type -> Match Solved tryGeqThanSub _ x y = @@ -223,6 +236,20 @@ tryCancelVar ctxt p t1 t2 = +-- if (x >= 2) && x^a*x^b = x^c => a+b = c +tryEqExp :: Type -> Type -> Match Solved +tryEqExp x y = check x y <|> check y x + where + check i j = + do (m1,m2) <- aMul i + (x_1, a) <- (|^|) m1 + (x_2, b) <- (|^|) m2 + (x_3, c) <- (|^|) j + guard (x_1 == x_2 && x_2 == x_3) + n <- aNat x_1 + guard (n >= 2) + return $ SolvedIf [ tAdd a b =#= c ] + -- min t1 t2 = t1 ~> t1 <= t2 tryEqMin :: Type -> Type -> Match Solved tryEqMin x y = diff --git a/tests/issues/issue1489/issue1489.cry b/tests/issues/issue1489/issue1489.cry new file mode 100644 index 000000000..bac33e1a9 --- /dev/null +++ b/tests/issues/issue1489/issue1489.cry @@ -0,0 +1,54 @@ +module ID where + +id : {k} (fin k, k > 0) => [2^^k] -> [2^^k] +id x = join(split`{2,2^^(k-1)}x) + +type q = 3329 + +ct_butterfly : + {m, hm} + (m >= 2, m <= 8, hm >= 1, hm <= 7, hm == m - 1) => + [2^^m](Z q) -> (Z q) -> [2^^m](Z q) +ct_butterfly v z = new_v + where + halflen = 2^^`hm + lower, upper : [2^^hm](Z q) + lower@x = v@x + z * v@(x + halflen) + upper@x = v@x - z * v@(x + halflen) + new_v = lower # upper + +zeta_expc : [128](Z q) +zeta_expc = [ 1, 1729, 2580, 3289, 2642, 630, 1897, 848, + 1062, 1919, 193, 797, 2786, 3260, 569, 1746, + 296, 2447, 1339, 1476, 3046, 56, 2240, 1333, + 1426, 2094, 535, 2882, 2393, 2879, 1974, 821, + 289, 331, 3253, 1756, 1197, 2304, 2277, 2055, + 650, 1977, 2513, 632, 2865, 33, 1320, 1915, + 2319, 1435, 807, 452, 1438, 2868, 1534, 2402, + 2647, 2617, 1481, 648, 2474, 3110, 1227, 910, + 17, 2761, 583, 2649, 1637, 723, 2288, 1100, + 1409, 2662, 3281, 233, 756, 2156, 3015, 3050, + 1703, 1651, 2789, 1789, 1847, 952, 1461, 2687, + 939, 2308, 2437, 2388, 733, 2337, 268, 641, + 1584, 2298, 2037, 3220, 375, 2549, 2090, 1645, + 1063, 319, 2773, 757, 2099, 561, 2466, 2594, + 2804, 1092, 403, 1026, 1143, 2150, 2775, 886, + 1722, 1212, 1874, 1029, 2110, 2935, 885, 2154 ] + +fast_nttl : + {lv} // Length of v is a member of {256,128,64,32,16,8,4} + (lv >= 2, lv <= 8) => + [2^^lv](Z q) -> [8] -> [2^^lv](Z q) +fast_nttl v k + // Base case. lv==2 so just compute the butterfly and return + | lv == 2 => ct_butterfly`{lv,lv-1} v (zeta_expc@k) + + // Recursive case. Butterfly what we have, then recurse on each half, + // concatenate the results and return. As above, we need coerceSize + // here (twice) to satisfy the type checker. + | lv > 2 => (fast_nttl`{lv-1} s0 (k * 2)) # + (fast_nttl`{lv-1} s1 (k * 2 + 1)) + where + t = ct_butterfly`{lv,lv-1} v (zeta_expc@k) + // Split t into two halves s0 and s1 + [s0, s1] = split t \ No newline at end of file diff --git a/tests/issues/issue1489/issue1489.icry b/tests/issues/issue1489/issue1489.icry new file mode 100644 index 000000000..972ee633e --- /dev/null +++ b/tests/issues/issue1489/issue1489.icry @@ -0,0 +1 @@ +:load ./issue1489.cry \ No newline at end of file diff --git a/tests/issues/issue1489/issue1489.icry.stdout b/tests/issues/issue1489/issue1489.icry.stdout new file mode 100644 index 000000000..0f044dbc8 --- /dev/null +++ b/tests/issues/issue1489/issue1489.icry.stdout @@ -0,0 +1,3 @@ +Loading module Cryptol +Loading module Cryptol +Loading module ID From 681fcd45ca6e4aaebeb8db730e33070bbc2b29ff Mon Sep 17 00:00:00 2001 From: Ryan Date: Tue, 30 Jul 2024 11:40:24 -0600 Subject: [PATCH 2/3] Add (x>=2, x^a = x^b => a = b) rule to Numeric.hs. --- src/Cryptol/TypeCheck/Solver/Numeric.hs | 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/src/Cryptol/TypeCheck/Solver/Numeric.hs b/src/Cryptol/TypeCheck/Solver/Numeric.hs index dc7e11c8d..929b89089 100644 --- a/src/Cryptol/TypeCheck/Solver/Numeric.hs +++ b/src/Cryptol/TypeCheck/Solver/Numeric.hs @@ -48,6 +48,7 @@ cryIsEqual ctxt t1 t2 = <|> tryCancelVar ctxt (=#=) t1 t2 <|> tryLinearSolution t1 t2 <|> tryLinearSolution t2 t1 + <|> tryEqMulExp t1 t2 <|> tryEqExp t1 t2 -- | Try to solve @t1 /= t2@ @@ -237,8 +238,8 @@ tryCancelVar ctxt p t1 t2 = -- if (x >= 2) && x^a*x^b = x^c => a+b = c -tryEqExp :: Type -> Type -> Match Solved -tryEqExp x y = check x y <|> check y x +tryEqMulExp :: Type -> Type -> Match Solved +tryEqMulExp x y = check x y <|> check y x where check i j = do (m1,m2) <- aMul i @@ -249,6 +250,19 @@ tryEqExp x y = check x y <|> check y x n <- aNat x_1 guard (n >= 2) return $ SolvedIf [ tAdd a b =#= c ] + +-- if (x >= 2) && x^a = x^b => a = b +tryEqExp :: Type -> Type -> Match Solved +tryEqExp x y = check x y <|> check y x + where + check i j = + do + (x_1, a) <- (|^|) i + (x_2, b) <- (|^|) j + guard (x_1 == x_2) + n <- aNat x_1 + guard (n >= 2) + return $ SolvedIf [ a =#= b ] -- min t1 t2 = t1 ~> t1 <= t2 tryEqMin :: Type -> Type -> Match Solved From 110a753a94506829ac11402ccf4f0067bb18bebe Mon Sep 17 00:00:00 2001 From: Ryan Date: Tue, 30 Jul 2024 13:07:43 -0600 Subject: [PATCH 3/3] Move (K^a * K^b) => K^(a + b) from Numeric.hs to SimpType.hs to allow EQ check in Numeric.hs to work for multiplication of exponents as well. Clean up comments by using K for numeric constants, and move numeric constant checks to as early as possibly in Numeric.hs --- src/Cryptol/TypeCheck/SimpType.hs | 13 +++++++-- src/Cryptol/TypeCheck/Solver/Numeric.hs | 35 +++++++------------------ 2 files changed, 21 insertions(+), 27 deletions(-) diff --git a/src/Cryptol/TypeCheck/SimpType.hs b/src/Cryptol/TypeCheck/SimpType.hs index 971da0703..1892d8c3c 100644 --- a/src/Cryptol/TypeCheck/SimpType.hs +++ b/src/Cryptol/TypeCheck/SimpType.hs @@ -142,7 +142,8 @@ tMul x y | Just n <- tIsNum x = mulK n y | Just n <- tIsNum y = mulK n x | Just v <- matchMaybe swapVars = v - | otherwise = tf2 TCMul x y + | otherwise = checkExpMul x y + where mulK 0 _ = tNum (0 :: Int) mulK 1 t = t @@ -158,7 +159,7 @@ tMul x y | TCon (TF TCExp) [a,b] <- t' , Just n' <- tIsNum a , n == n' = tf2 TCExp a (tAdd (tNum (1::Int)) b) - + -- c^x * c^y = c ^ (y + x) | otherwise = tf2 TCMul (tNum n) t where t' = tNoUser t @@ -166,6 +167,14 @@ tMul x y b <- aTVar y guard (b < a) return (tf2 TCMul y x) + + -- Check if (K^a * K^b) => K^(a + b) otherwise default to standard mul + checkExpMul s t | TCon (TF TCExp) [a,aExp] <- s + , Just a' <- tIsNum a + , TCon (TF TCExp) [b,bExp] <- t + , Just b' <- tIsNum b + , (a' >= 2 && a' == b') = tf2 TCExp a (tAdd aExp bExp) + | otherwise = tf2 TCMul x y diff --git a/src/Cryptol/TypeCheck/Solver/Numeric.hs b/src/Cryptol/TypeCheck/Solver/Numeric.hs index 929b89089..1ea8dea23 100644 --- a/src/Cryptol/TypeCheck/Solver/Numeric.hs +++ b/src/Cryptol/TypeCheck/Solver/Numeric.hs @@ -48,7 +48,6 @@ cryIsEqual ctxt t1 t2 = <|> tryCancelVar ctxt (=#=) t1 t2 <|> tryLinearSolution t1 t2 <|> tryLinearSolution t2 t1 - <|> tryEqMulExp t1 t2 <|> tryEqExp t1 t2 -- | Try to solve @t1 /= t2@ @@ -140,14 +139,14 @@ tryGeqThanK _ t (Nat k) = -- XXX: K1 ^^ n >= K2 --- (x >= 2 && x^a >= x^b) => a >= b +-- (K >= 2 && K^a >= K^b) => a >= b tryGeqExp :: Ctxt -> Type -> Type -> Match Solved tryGeqExp _ x y = - do (x_1, a) <- (|^|) x - (x_2, b) <- (|^|) y - guard (x_1 == x_2) - n <- aNat x_1 + do (k_1, a) <- (|^|) x + n <- aNat k_1 guard (n >= 2) + (k_2, b) <- (|^|) y + guard (k_1 == k_2) return $ SolvedIf [ a >== b ] @@ -237,31 +236,17 @@ tryCancelVar ctxt p t1 t2 = --- if (x >= 2) && x^a*x^b = x^c => a+b = c -tryEqMulExp :: Type -> Type -> Match Solved -tryEqMulExp x y = check x y <|> check y x - where - check i j = - do (m1,m2) <- aMul i - (x_1, a) <- (|^|) m1 - (x_2, b) <- (|^|) m2 - (x_3, c) <- (|^|) j - guard (x_1 == x_2 && x_2 == x_3) - n <- aNat x_1 - guard (n >= 2) - return $ SolvedIf [ tAdd a b =#= c ] - --- if (x >= 2) && x^a = x^b => a = b +-- if (K >= 2) && K^a = K^b => a = b tryEqExp :: Type -> Type -> Match Solved tryEqExp x y = check x y <|> check y x where check i j = do - (x_1, a) <- (|^|) i - (x_2, b) <- (|^|) j - guard (x_1 == x_2) - n <- aNat x_1 + (k_1, a) <- (|^|) i + n <- aNat k_1 guard (n >= 2) + (k_2, b) <- (|^|) j + guard (k_1 == k_2) return $ SolvedIf [ a =#= b ] -- min t1 t2 = t1 ~> t1 <= t2