Skip to content

Commit

Permalink
Migrates rust codegen to use tfhe-rs shortint library for boolean gat…
Browse files Browse the repository at this point in the history
…es. The boolean API is only for compatibility, and 1 bit shortints is faster than the boolean module.

PiperOrigin-RevId: 557818764
  • Loading branch information
asraa authored and copybara-github committed Aug 17, 2023
1 parent 3d1ab61 commit 52431fc
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 70 deletions.
13 changes: 0 additions & 13 deletions transpiler/examples/add_one/add_one_lib.rs
Original file line number Diff line number Diff line change
@@ -1,21 +1,11 @@
#[cfg(not(lut))]
use tfhe::boolean::ciphertext::Ciphertext;
#[cfg(not(lut))]
use tfhe::boolean::prelude::*;

#[cfg(lut)]
use tfhe::shortint::prelude::*;
#[cfg(lut)]
use tfhe::shortint::CiphertextBig as Ciphertext;

// Encrypt a u8
pub fn encrypt(value: u8, client_key: &ClientKey) -> Vec<Ciphertext> {
(0..8)
.map(|shift| {
let bit = (value >> shift) & 1;
#[cfg(not(lut))]
return client_key.encrypt(if bit != 0 { true } else { false });
#[cfg(lut)]
return client_key.encrypt(if bit != 0 { 1 } else { 0 });
})
.collect()
Expand All @@ -42,9 +32,6 @@ mod tests {
use tfhe::shortint::parameters::PARAM_MESSAGE_1_CARRY_2;

fn run_test_for(x: u8) -> u8 {
#[cfg(not(lut))]
let (client_key, server_key) = gen_keys();
#[cfg(lut)]
let (client_key, server_key) = gen_keys(PARAM_MESSAGE_1_CARRY_2);
let fhe_val = encrypt(x, &client_key);
decrypt(&add_one(&fhe_val, &server_key), &client_key)
Expand Down
48 changes: 26 additions & 22 deletions transpiler/rust/tfhe_rs_templates.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,10 @@ constexpr absl::string_view kCodegenTemplate = R"rust(
use rayon::prelude::*;
use std::collections::HashMap;
#[cfg(lut)]
use tfhe::shortint;
#[cfg(lut)]
use tfhe::shortint::prelude::*;
#[cfg(lut)]
use tfhe::shortint::CiphertextBig as Ciphertext;
#[cfg(not(lut))]
use tfhe::boolean::prelude::*;
#[cfg(not(lut))]
use tfhe::boolean::ciphertext::Ciphertext;
#[cfg(lut)]
fn generate_lut(lut_as_int: u64, server_key: &ServerKey) -> shortint::server_key::LookupTableOwned {
let f = |x: u64| (lut_as_int >> (x as u8)) & 1;
return server_key.generate_accumulator(f);
Expand All @@ -41,6 +32,7 @@ enum GateInput {
use GateInput::*;
#[cfg(not(lut))]
#[derive(PartialEq, Eq, Hash)]
enum CellType {
AND2,
NAND2,
Expand All @@ -49,7 +41,7 @@ enum CellType {
OR2,
NOR2,
INV,
IMUX2,
// TODO: Add back MUX2
}
#[cfg(lut)]
Expand All @@ -68,12 +60,8 @@ fn prune(temp_nodes: &mut HashMap<usize, Ciphertext>, temp_node_ids: &[usize]) {
}
pub fn $function_signature {
#[cfg(lut)]
let (constant_false, constant_true): (Ciphertext, Ciphertext) = (
server_key.create_trivial(0), server_key.create_trivial(1));
#[cfg(not(lut))]
let (constant_false, constant_true): (Ciphertext, Ciphertext) = (
server_key.trivial_encrypt(false), server_key.trivial_encrypt(true));
let args: &[&Vec<Ciphertext>] = &[$ordered_params];
Expand All @@ -87,6 +75,16 @@ pub fn $function_signature {
luts
};
#[cfg(not(lut))]
let luts = {
let mut luts: HashMap<CellType, shortint::server_key::LookupTableOwned> = HashMap::new();
const CELLS_TO_LUTS: [(CellType, u64); 3] = [(NAND2, 7), (NOR2, 1), (XNOR2, 9)];
for (cell, lut) in CELLS_TO_LUTS {
luts.insert(cell, generate_lut(lut, server_key));
}
luts
};
#[cfg(lut)]
let lut3 = |args: &[&Ciphertext], lut: u64| -> Ciphertext {
let top_bit = server_key.unchecked_scalar_mul(args[2], 4);
Expand All @@ -95,6 +93,13 @@ pub fn $function_signature {
return server_key.apply_lookup_table(&ct_input, &luts[&lut]);
};
#[cfg(not(lut))]
let boolean_lut = |args: &[&Ciphertext], cell: CellType| -> Ciphertext {
let first_bit = server_key.unchecked_scalar_mul(args[1], 2);
let ct_input = &server_key.unchecked_add(&first_bit, args[0]);
return server_key.apply_lookup_table(&ct_input, &luts[&cell]);
};
let mut temp_nodes = HashMap::new();
let mut $output_stem = Vec::new();
$output_stem.resize($num_outputs, constant_false.clone());
Expand All @@ -121,14 +126,13 @@ pub fn $function_signature {
};
#[cfg(not(lut))]
let gate_func = |args: &[&Ciphertext]| match celltype {
AND2 => server_key.and(args[0], args[1]),
NAND2 => server_key.nand(args[0], args[1]),
OR2 => server_key.or(args[0], args[1]),
NOR2 => server_key.nor(args[0], args[1]),
XOR2 => server_key.xor(args[0], args[1]),
XNOR2 => server_key.xnor(args[0], args[1]),
INV => server_key.not(args[0]),
IMUX2 => server_key.mux(args[0], args[1], args[2]),
AND2 => server_key.bitand(args[0], args[1]),
NAND2 => boolean_lut(args, NAND2),
OR2 => server_key.bitor(args[0], args[1]),
NOR2 => boolean_lut(args, NOR2),
XOR2 => server_key.bitxor(args[0], args[1]),
XNOR2 => boolean_lut(args, XNOR2),
INV => server_key.bitxor(args[0], &constant_true),
};
((*id, *is_output), gate_func(&task_args))
})
Expand Down
35 changes: 0 additions & 35 deletions transpiler/yosys/tfhe-rs_cells.liberty
Original file line number Diff line number Diff line change
Expand Up @@ -257,39 +257,4 @@ library(supergate) {
}
}
}

/* 2-input MUX */
cell(imux2) {
// Per TFHE spec, MUX has 2x the cost of the other gates.
area : 200;
pin(A) {
direction : input;
}
pin(B) {
direction : input;
}
pin(S) {
direction : input;
}
pin(Y) {
direction: output;
function : "(A * S) + (B * (S'))";
timing() {
related_pin : "A B S";
timing_sense : non_unate;
cell_rise(scalar) {
values("1000.0");
}
cell_fall(scalar) {
values("1000.0");
}
rise_transition(scalar) {
values("1000.0");
}
fall_transition(scalar) {
values("1000.0");
}
}
}
}
} /* end */

0 comments on commit 52431fc

Please sign in to comment.