Skip to content

Commit

Permalink
perf(compiler): use readers as much as possible to avoid copies
Browse files Browse the repository at this point in the history
Readers were automatically casted to Messages which cost a memory copy.
It's now required to explicitly make this conversion (copy).
  • Loading branch information
youben11 committed Dec 9, 2024
1 parent f0e0a08 commit c409a51
Show file tree
Hide file tree
Showing 14 changed files with 244 additions and 127 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ class LweSecretKey {
static LweSecretKey
fromProto(const Message<concreteprotocol::LweSecretKey> &proto);

static LweSecretKey fromProto(concreteprotocol::LweSecretKey::Reader reader);

Message<concreteprotocol::LweSecretKey> toProto() const;

const uint64_t *getRawPtr() const;
Expand Down Expand Up @@ -95,6 +97,10 @@ class LweBootstrapKey {
static LweBootstrapKey
fromProto(const Message<concreteprotocol::LweBootstrapKey> &proto);

/// @brief Initialize the key from a reader.
static LweBootstrapKey
fromProto(concreteprotocol::LweBootstrapKey::Reader reader);

/// @brief Returns the serialized form of the key.
Message<concreteprotocol::LweBootstrapKey> toProto() const;

Expand Down Expand Up @@ -147,6 +153,10 @@ class LweKeyswitchKey {
static LweKeyswitchKey
fromProto(const Message<concreteprotocol::LweKeyswitchKey> &proto);

/// @brief Initialize the key from a reader.
static LweKeyswitchKey
fromProto(concreteprotocol::LweKeyswitchKey::Reader reader);

/// @brief Returns the serialized form of the key.
Message<concreteprotocol::LweKeyswitchKey> toProto() const;

Expand Down Expand Up @@ -199,6 +209,9 @@ class PackingKeyswitchKey {
static PackingKeyswitchKey
fromProto(const Message<concreteprotocol::PackingKeyswitchKey> &proto);

static PackingKeyswitchKey
fromProto(concreteprotocol::PackingKeyswitchKey::Reader reader);

Message<concreteprotocol::PackingKeyswitchKey> toProto() const;

const uint64_t *getRawPtr() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ struct ClientKeyset {
static ClientKeyset
fromProto(const Message<concreteprotocol::ClientKeyset> &proto);

static ClientKeyset fromProto(concreteprotocol::ClientKeyset::Reader reader);

Message<concreteprotocol::ClientKeyset> toProto() const;
};

Expand All @@ -43,6 +45,7 @@ struct ServerKeyset {

static ServerKeyset
fromProto(const Message<concreteprotocol::ServerKeyset> &proto);
static ServerKeyset fromProto(concreteprotocol::ServerKeyset::Reader reader);

Message<concreteprotocol::ServerKeyset> toProto() const;
};
Expand Down Expand Up @@ -73,6 +76,7 @@ struct Keyset {
: server(server), client(client) {}

static Keyset fromProto(const Message<concreteprotocol::Keyset> &proto);
static Keyset fromProto(concreteprotocol::Keyset::Reader reader);

Message<concreteprotocol::Keyset> toProto() const;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,8 @@ template <typename MessageType> struct Message {
message = regionBuilder->initRoot<MessageType>();
}

Message(const typename MessageType::Reader &reader) : message(nullptr) {
explicit Message(const typename MessageType::Reader &reader)
: message(nullptr) {
regionBuilder = new capnp::MallocMessageBuilder(
std::min(reader.totalSize().wordCount, MAX_SEGMENT_SIZE),
capnp::AllocationStrategy::FIXED_SIZE);
Expand Down Expand Up @@ -308,7 +309,12 @@ vectorToProtoPayload(const std::vector<T> &input) {
template <typename T>
std::vector<T>
protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
return protoPayloadToVector<T>(input.asReader());
}

template <typename T>
std::vector<T> protoPayloadToVector(concreteprotocol::Payload::Reader reader) {
auto payloadData = reader.getData();
auto elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
Expand All @@ -331,7 +337,13 @@ protoPayloadToVector(const Message<concreteprotocol::Payload> &input) {
template <typename T>
std::shared_ptr<std::vector<T>>
protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
auto payloadData = input.asReader().getData();
return protoPayloadToSharedVector<T>(input.asReader());
}

template <typename T>
std::shared_ptr<std::vector<T>>
protoPayloadToSharedVector(concreteprotocol::Payload::Reader reader) {
auto payloadData = reader.getData();
size_t elmsPerBlob = capnp::MAX_TEXT_SIZE / sizeof(T);
size_t totalPayloadSize = 0;
for (auto blob : payloadData) {
Expand All @@ -353,6 +365,8 @@ protoPayloadToSharedVector(const Message<concreteprotocol::Payload> &input) {
/// dimensions.
std::vector<size_t>
protoShapeToDimensions(const Message<concreteprotocol::Shape> &shape);
std::vector<size_t>
protoShapeToDimensions(concreteprotocol::Shape::Reader reader);

/// Helper function turning a protocol `Shape` object into a vector of
/// dimensions.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,7 @@ struct Value {

bool
isCompatibleWithShape(const Message<concreteprotocol::Shape> &shape) const;
bool isCompatibleWithShape(concreteprotocol::Shape::Reader reader) const;

bool isScalar() const;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1107,7 +1107,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto secretKeys = std::vector<LweSecretKeyParam>();
for (auto key : keysetInfo.asReader().getLweSecretKeys()) {
secretKeys.push_back(LweSecretKeyParam{key});
secretKeys.push_back(LweSecretKeyParam{
(Message<concreteprotocol::LweSecretKeyInfo>)key});
}
return secretKeys;
},
Expand All @@ -1117,7 +1118,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto bootstrapKeys = std::vector<BootstrapKeyParam>();
for (auto key : keysetInfo.asReader().getLweBootstrapKeys()) {
bootstrapKeys.push_back(BootstrapKeyParam{key});
bootstrapKeys.push_back(BootstrapKeyParam{
(Message<concreteprotocol::LweBootstrapKeyInfo>)key});
}
return bootstrapKeys;
},
Expand All @@ -1127,7 +1129,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto keyswitchKeys = std::vector<KeyswitchKeyParam>();
for (auto key : keysetInfo.asReader().getLweKeyswitchKeys()) {
keyswitchKeys.push_back(KeyswitchKeyParam{key});
keyswitchKeys.push_back(KeyswitchKeyParam{
(Message<concreteprotocol::LweKeyswitchKeyInfo>)key});
}
return keyswitchKeys;
},
Expand All @@ -1137,7 +1140,8 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](KeysetInfo &keysetInfo) {
auto packingKeyswitchKeys = std::vector<PackingKeyswitchKeyParam>();
for (auto key : keysetInfo.asReader().getPackingKeyswitchKeys()) {
packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{key});
packingKeyswitchKeys.push_back(PackingKeyswitchKeyParam{
(Message<concreteprotocol::PackingKeyswitchKeyInfo>)key});
}
return packingKeyswitchKeys;
},
Expand Down Expand Up @@ -1220,13 +1224,13 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(
"get_type_info",
[](GateInfo &gate) -> TypeInfo {
return {gate.asReader().getTypeInfo()};
return {(TypeInfo)gate.asReader().getTypeInfo()};
},
"Return the type associated to the gate.")
.def(
"get_raw_info",
[](GateInfo &gate) -> RawInfo {
return {gate.asReader().getRawInfo()};
return {(RawInfo)gate.asReader().getRawInfo()};
},
"Return the raw type associated to the gate.")
.doc() = "Informations describing a circuit gate (input or output).";
Expand All @@ -1247,7 +1251,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CircuitInfo &circuit) -> std::vector<GateInfo> {
auto output = std::vector<GateInfo>();
for (auto gate : circuit.asReader().getInputs()) {
output.push_back({gate});
output.push_back({(Message<concreteprotocol::GateInfo>)gate});
}
return output;
},
Expand All @@ -1257,7 +1261,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
[](CircuitInfo &circuit) -> std::vector<GateInfo> {
auto output = std::vector<GateInfo>();
for (auto gate : circuit.asReader().getOutputs()) {
output.push_back({gate});
output.push_back({(Message<concreteprotocol::GateInfo>)gate});
}
return output;
},
Expand Down Expand Up @@ -1415,7 +1419,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
.def(
"get_keyset_info",
[](ProgramInfo &programInfo) -> KeysetInfo {
return programInfo.programInfo.asReader().getKeyset();
return (KeysetInfo)programInfo.programInfo.asReader().getKeyset();
},
"Return the keyset info associated to the program.")
.def(
Expand All @@ -1424,7 +1428,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
auto output = std::vector<CircuitInfo>();
for (auto circuit :
programInfo.programInfo.asReader().getCircuits()) {
output.push_back(circuit);
output.push_back((CircuitInfo)circuit);
}
return output;
},
Expand All @@ -1435,7 +1439,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
for (auto circuit :
programInfo.programInfo.asReader().getCircuits()) {
if (circuit.getName() == name) {
return circuit;
return (CircuitInfo)circuit;
}
}
throw std::runtime_error("couldn't find circuit.");
Expand Down Expand Up @@ -1552,7 +1556,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Failed to deserialize server keyset." +
maybeError.as_failure().error().mesg);
}
return ServerKeyset::fromProto(serverKeysetProto);
return ServerKeyset::fromProto(serverKeysetProto.asReader());
},
"Deserialize a ServerKeyset from bytes.", arg("bytes"))
.def(
Expand Down Expand Up @@ -1604,17 +1608,18 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
GET_OR_THROW_RESULT(
Keyset keyset,
(*cache).getKeyset(
programInfo.programInfo.asReader().getKeyset(),
(KeysetInfo)programInfo.programInfo.asReader()
.getKeyset(),
secretSeed, encryptionSeed,
initialLweSecretKeys.value()));
return std::make_unique<Keyset>(std::move(keyset));
} else {
::concretelang::csprng::SecretCSPRNG secCsprng(secretSeed);
::concretelang::csprng::EncryptionCSPRNG encCsprng(
encryptionSeed);
auto keyset =
Keyset(programInfo.programInfo.asReader().getKeyset(),
secCsprng, encCsprng, initialLweSecretKeys.value());
auto keyset = Keyset(
(KeysetInfo)programInfo.programInfo.asReader().getKeyset(),
secCsprng, encCsprng, initialLweSecretKeys.value());
return std::make_unique<Keyset>(std::move(keyset));
}
}),
Expand Down Expand Up @@ -1652,7 +1657,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Failed to deserialize keyset." +
maybeError.as_failure().error().mesg);
}
auto keyset = Keyset::fromProto(keysetProto);
auto keyset = Keyset::fromProto(std::move(keysetProto));
return std::make_unique<Keyset>(std::move(keyset));
},
"Deserialize a Keyset from a file.", arg("path"))
Expand Down Expand Up @@ -2034,8 +2039,10 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(

GET_OR_THROW_RESULT(auto pi, library.getProgramInfo());
GET_OR_THROW_RESULT(
auto result, ServerProgram::load(pi.asReader(), sharedLibPath,
useSimulation));
auto result,
ServerProgram::load(
(Message<concreteprotocol::ProgramInfo>)pi.asReader(),
sharedLibPath, useSimulation));
return result;
}),
arg("library"), arg("use_simulation"))
Expand All @@ -2061,7 +2068,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Unknown position.");
}
auto info = circuit.getCircuitInfo().asReader().getInputs()[pos];
auto typeTransformer = getPythonTypeTransformer(info);
auto typeTransformer = getPythonTypeTransformer((GateInfo)info);
GET_OR_THROW_RESULT(
auto ok, circuit.prepareInput(typeTransformer(arg), pos));
return ok;
Expand All @@ -2084,7 +2091,7 @@ void mlir::concretelang::python::populateCompilerAPISubmodule(
throw std::runtime_error("Unknown position.");
}
auto info = circuit.getCircuitInfo().asReader().getInputs()[pos];
auto typeTransformer = getPythonTypeTransformer(info);
auto typeTransformer = getPythonTypeTransformer((GateInfo)info);
GET_OR_THROW_RESULT(auto ok, circuit.simulatePrepareInput(
typeTransformer(arg), pos));
return ok;
Expand Down
28 changes: 19 additions & 9 deletions compilers/concrete-compiler/compiler/lib/ClientLib/ClientLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,17 @@ ClientCircuit::create(const Message<concreteprotocol::CircuitInfo> &info,
InputTransformer transformer;
if (gateInfo.getTypeInfo().hasIndex()) {
OUTCOME_TRY(transformer,
TransformerFactory::getIndexInputTransformer(gateInfo));
TransformerFactory::getIndexInputTransformer(
(Message<concreteprotocol::GateInfo>)gateInfo));
} else if (gateInfo.getTypeInfo().hasPlaintext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getPlaintextInputTransformer(gateInfo));
TransformerFactory::getPlaintextInputTransformer(
(Message<concreteprotocol::GateInfo>)gateInfo));
} else if (gateInfo.getTypeInfo().hasLweCiphertext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getLweCiphertextInputTransformer(
keyset, gateInfo, csprng, useSimulation));
keyset, (Message<concreteprotocol::GateInfo>)gateInfo,
csprng, useSimulation));
} else {
return StringError("Malformed input gate info.");
}
Expand All @@ -69,14 +72,17 @@ ClientCircuit::create(const Message<concreteprotocol::CircuitInfo> &info,
OutputTransformer transformer;
if (gateInfo.getTypeInfo().hasIndex()) {
OUTCOME_TRY(transformer,
TransformerFactory::getIndexOutputTransformer(gateInfo));
TransformerFactory::getIndexOutputTransformer(
(Message<concreteprotocol::GateInfo>)gateInfo));
} else if (gateInfo.getTypeInfo().hasPlaintext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getPlaintextOutputTransformer(gateInfo));
TransformerFactory::getPlaintextOutputTransformer(
(Message<concreteprotocol::GateInfo>)gateInfo));
} else if (gateInfo.getTypeInfo().hasLweCiphertext()) {
OUTCOME_TRY(transformer,
TransformerFactory::getLweCiphertextOutputTransformer(
keyset, gateInfo, useSimulation));
keyset, (Message<concreteprotocol::GateInfo>)gateInfo,
useSimulation));
} else {
return StringError("Malformed output gate info.");
}
Expand Down Expand Up @@ -161,7 +167,9 @@ Result<ClientProgram> ClientProgram::createEncrypted(
ClientProgram output;
for (auto circuitInfo : info.asReader().getCircuits()) {
OUTCOME_TRY(const ClientCircuit clientCircuit,
ClientCircuit::createEncrypted(circuitInfo, keyset, csprng));
ClientCircuit::createEncrypted(
(Message<concreteprotocol::CircuitInfo>)circuitInfo, keyset,
csprng));
output.circuits.push_back(clientCircuit);
}
return output;
Expand All @@ -172,8 +180,10 @@ Result<ClientProgram> ClientProgram::createSimulated(
std::shared_ptr<csprng::EncryptionCSPRNG> csprng) {
ClientProgram output;
for (auto circuitInfo : info.asReader().getCircuits()) {
OUTCOME_TRY(const ClientCircuit clientCircuit,
ClientCircuit::createSimulated(circuitInfo, csprng));
OUTCOME_TRY(
const ClientCircuit clientCircuit,
ClientCircuit::createSimulated(
(Message<concreteprotocol::CircuitInfo>)circuitInfo, csprng));
output.circuits.push_back(clientCircuit);
}
return output;
Expand Down
Loading

0 comments on commit c409a51

Please sign in to comment.