Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
Merge pull request #507 from xaynetwork/mobile_client_expose_type
Browse files Browse the repository at this point in the history
Expose current state of the client & make scalar configurable
  • Loading branch information
Robert-Steiner authored Sep 7, 2020
2 parents f22e75b + 7239149 commit 6e4a511
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 40 deletions.
18 changes: 12 additions & 6 deletions rust/examples/mobile-client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@ extern crate tracing;
use std::io::{stdin, stdout, Read, Write};
use structopt::StructOpt;
use tracing_subscriber::*;
use xaynet_client::mobile_client::{participant::ParticipantSettings, MobileClient};
use xaynet_client::mobile_client::{
participant::{AggregationConfig, ParticipantSettings},
MobileClient,
};
use xaynet_core::mask::{
BoundType,
DataType,
Expand Down Expand Up @@ -41,11 +44,14 @@ fn get_participant_settings() -> ParticipantSettings {
let secret_key = MobileClient::create_participant_secret_key();
ParticipantSettings {
secret_key,
mask_config: MaskConfig {
group_type: GroupType::Prime,
data_type: DataType::F32,
bound_type: BoundType::B0,
model_type: ModelType::M3,
aggregation_config: AggregationConfig {
mask: MaskConfig {
group_type: GroupType::Prime,
data_type: DataType::F32,
bound_type: BoundType::B0,
model_type: ModelType::M3,
},
scalar: 1_f64,
},
}
}
Expand Down
12 changes: 3 additions & 9 deletions rust/xaynet-client/src/mobile_client/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -170,21 +170,15 @@ impl ClientState<Update> {
.await
.ok_or(ClientError::TooEarly("local model"))?;

debug!("setting model scalar");
let scalar = 1_f64; // TODO parametrise this!

debug!("polling for sum dict");
let sums = api
.get_sums()
.await?
.ok_or(ClientError::TooEarly("sum dict"))?;

let upd_msg = self.participant.compose_update_message(
self.round_params.pk,
&sums,
scalar,
local_model,
);
let upd_msg =
self.participant
.compose_update_message(self.round_params.pk, &sums, local_model);
let sealed_msg = self
.participant
.seal_message(&self.round_params.pk, &upd_msg);
Expand Down
17 changes: 17 additions & 0 deletions rust/xaynet-client/src/mobile_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ impl MobileClient {
})
}

/// Returns the current state of the client.
pub fn get_current_state(&self) -> ClientStateName {
match self.client_state {
ClientStateMachine::Awaiting(_) => ClientStateName::Awaiting,
ClientStateMachine::Sum(_) => ClientStateName::Sum,
ClientStateMachine::Update(_) => ClientStateName::Update,
ClientStateMachine::Sum2(_) => ClientStateName::Sum2,
}
}

/// Sets the local model.
///
/// The local model is only sent if the client has been selected as an update client.
Expand Down Expand Up @@ -177,6 +187,13 @@ impl MobileClient {
}
}

pub enum ClientStateName {
Awaiting,
Sum,
Update,
Sum2,
}

struct LocalModelCache(Option<Model>);

impl LocalModelCache {
Expand Down
18 changes: 11 additions & 7 deletions rust/xaynet-client/src/mobile_client/participant/awaiting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ impl Participant<Awaiting> {
#[cfg(test)]
mod tests {
use super::*;
use crate::mobile_client::participant::AggregationConfig;
use sodiumoxide::randombytes::randombytes;
use xaynet_core::{
crypto::{ByteObject, SigningKeyPair},
Expand All @@ -59,16 +60,19 @@ mod tests {
fn participant_state() -> ParticipantState {
sodiumoxide::init().unwrap();

let mask_config = MaskConfig {
group_type: GroupType::Prime,
data_type: DataType::F32,
bound_type: BoundType::B0,
model_type: ModelType::M3,
};
let aggregation_config = AggregationConfig {
mask: MaskConfig {
group_type: GroupType::Prime,
data_type: DataType::F32,
bound_type: BoundType::B0,
model_type: ModelType::M3,
},

scalar: 1_f64,
};
ParticipantState {
keys: SigningKeyPair::generate(),
mask_config,
aggregation_config,
}
}

Expand Down
14 changes: 10 additions & 4 deletions rust/xaynet-client/src/mobile_client/participant/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,33 +19,39 @@ pub mod update;

pub use self::{awaiting::Awaiting, sum::Sum, sum2::Sum2, update::Update};

#[derive(Serialize, Deserialize)]
pub struct AggregationConfig {
pub mask: MaskConfig,
pub scalar: f64,
}

#[derive(Serialize, Deserialize)]
pub struct ParticipantState {
// credentials
pub keys: SigningKeyPair,
// Mask config
pub mask_config: MaskConfig,
pub aggregation_config: AggregationConfig,
}

#[derive(Serialize, Deserialize)]
pub struct ParticipantSettings {
pub secret_key: ParticipantSecretKey,
pub mask_config: MaskConfig,
pub aggregation_config: AggregationConfig,
}

impl From<ParticipantSettings> for ParticipantState {
fn from(
ParticipantSettings {
secret_key,
mask_config,
aggregation_config,
}: ParticipantSettings,
) -> ParticipantState {
ParticipantState {
keys: SigningKeyPair {
public: secret_key.public_key(),
secret: secret_key,
},
mask_config,
aggregation_config,
}
}
}
Expand Down
25 changes: 15 additions & 10 deletions rust/xaynet-client/src/mobile_client/participant/sum2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,10 +84,11 @@ impl Participant<Sum2> {
return Err(PetError::InvalidMask);
}

let mut model_mask_agg = Aggregation::new(self.state.mask_config, mask_len);
let mut scalar_mask_agg = Aggregation::new(self.state.mask_config, 1);
let mut model_mask_agg = Aggregation::new(self.state.aggregation_config.mask, mask_len);
let mut scalar_mask_agg = Aggregation::new(self.state.aggregation_config.mask, 1);
for seed in mask_seeds.into_iter() {
let (model_mask, scalar_mask) = seed.derive_mask(mask_len, self.state.mask_config);
let (model_mask, scalar_mask) =
seed.derive_mask(mask_len, self.state.aggregation_config.mask);

model_mask_agg
.validate_aggregation(&model_mask)
Expand All @@ -106,6 +107,7 @@ impl Participant<Sum2> {
#[cfg(test)]
mod tests {
use super::*;
use crate::mobile_client::participant::AggregationConfig;
use sodiumoxide::randombytes::{randombytes, randombytes_uniform};
use std::{collections::HashSet, iter};
use xaynet_core::{
Expand All @@ -117,16 +119,19 @@ mod tests {
fn participant_state() -> ParticipantState {
sodiumoxide::init().unwrap();

let mask_config = MaskConfig {
group_type: GroupType::Prime,
data_type: DataType::F32,
bound_type: BoundType::B0,
model_type: ModelType::M3,
};
let aggregation_config = AggregationConfig {
mask: MaskConfig {
group_type: GroupType::Prime,
data_type: DataType::F32,
bound_type: BoundType::B0,
model_type: ModelType::M3,
},

scalar: 1_f64,
};
ParticipantState {
keys: SigningKeyPair::generate(),
mask_config,
aggregation_config,
}
}

Expand Down
9 changes: 5 additions & 4 deletions rust/xaynet-client/src/mobile_client/participant/update.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,10 @@ impl Participant<Update> {
&self,
coordinator_pk: CoordinatorPublicKey,
sum_dict: &SumDict,
scalar: f64,

local_model: Model,
) -> Message {
let (mask_seed, masked_model, masked_scalar) = self.mask_model(scalar, local_model);
let (mask_seed, masked_model, masked_scalar) = self.mask_model(local_model);
let local_seed_dict = Self::create_local_seed_dict(sum_dict, &mask_seed);

Message {
Expand All @@ -56,8 +56,9 @@ impl Participant<Update> {
}

/// Generate a mask seed and mask a local model.
fn mask_model(&self, scalar: f64, local_model: Model) -> (MaskSeed, MaskObject, MaskObject) {
Masker::new(self.state.mask_config).mask(scalar, local_model)
fn mask_model(&self, local_model: Model) -> (MaskSeed, MaskObject, MaskObject) {
Masker::new(self.state.aggregation_config.mask)
.mask(self.state.aggregation_config.scalar, local_model)
}

// Create a local seed dictionary from a sum dictionary.
Expand Down

0 comments on commit 6e4a511

Please sign in to comment.