Skip to content
This repository has been archived by the owner on Aug 30, 2022. It is now read-only.

Commit

Permalink
masking doctest fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
finiteprods committed Aug 20, 2020
1 parent 6c0ffbb commit 196e7e9
Showing 1 changed file with 17 additions and 17 deletions.
34 changes: 17 additions & 17 deletions rust/src/mask/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -103,8 +103,8 @@
//! };
//!
//! // mask the local models
//! let (local_mask_seed_1, masked_local_model_1) = Masker::new(config).mask(scalar, local_model_1);
//! let (local_mask_seed_2, masked_local_model_2) = Masker::new(config).mask(scalar, local_model_2);
//! let (local_mask_seed_1, masked_local_model_1, masked_local_scalar_1) = Masker::new(config).mask(scalar, local_model_1);
//! let (local_mask_seed_2, masked_local_model_2, masked_local_scalar_2) = Masker::new(config).mask(scalar, local_model_2);
//!
//! // derive the masks of the local masked models
//! let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config);
Expand All @@ -124,17 +124,17 @@
//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config).mask(scalar, local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config).mask(scalar, local_model_2);
//! # let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config);
//! # let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config);
//! // aggregate the local masks
//! # let (local_mask_seed_1, masked_local_model_1, masked_local_scalar_1) = Masker::new(config).mask(scalar, local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2, masked_local_scalar_2) = Masker::new(config).mask(scalar, local_model_2);
//! # let (local_model_mask_1, local_scalar_mask_1) = local_mask_seed_1.derive_mask(number_weights, config);
//! # let (local_model_mask_2, local_scalar_mask_2) = local_mask_seed_2.derive_mask(number_weights, config);
//! // aggregate the local model masks (similarly for local scalar masks)
//! let mut mask_aggregator = Aggregation::new(config, number_weights);
//! if let Ok(_) = mask_aggregator.validate_aggregation(&local_mask_1) {
//! mask_aggregator.aggregate(local_mask_1);
//! if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_1) {
//! mask_aggregator.aggregate(local_model_mask_1);
//! };
//! if let Ok(_) = mask_aggregator.validate_aggregation(&local_mask_2) {
//! mask_aggregator.aggregate(local_mask_2);
//! if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_2) {
//! mask_aggregator.aggregate(local_model_mask_2);
//! };
//! let global_mask: MaskObject = mask_aggregator.into();
//!
Expand All @@ -160,13 +160,13 @@
//! # let local_model_1 = Model::from_primitives_bounded(vec![0_f32; number_weights].into_iter());
//! # let local_model_2 = Model::from_primitives_bounded(vec![1_f32; number_weights].into_iter());
//! # let config = MaskConfig { group_type: GroupType::Prime, data_type: DataType::F32, bound_type: BoundType::B0, model_type: ModelType::M3};
//! # let (local_mask_seed_1, masked_local_model_1) = Masker::new(config).mask(scalar, local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2) = Masker::new(config).mask(scalar, local_model_2);
//! # let local_mask_1 = local_mask_seed_1.derive_mask(number_weights, config);
//! # let local_mask_2 = local_mask_seed_2.derive_mask(number_weights, config);
//! # let (local_mask_seed_1, masked_local_model_1, masked_local_scalar_1) = Masker::new(config).mask(scalar, local_model_1);
//! # let (local_mask_seed_2, masked_local_model_2, masked_local_scalar_2) = Masker::new(config).mask(scalar, local_model_2);
//! # let (local_model_mask_1, local_scalar_mask_1) = local_mask_seed_1.derive_mask(number_weights, config);
//! # let (local_model_mask_2, local_scalar_mask_2) = local_mask_seed_2.derive_mask(number_weights, config);
//! # let mut mask_aggregator = Aggregation::new(config, number_weights);
//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_mask_1) { mask_aggregator.aggregate(local_mask_1); };
//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_mask_2) { mask_aggregator.aggregate(local_mask_2); };
//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_1) { mask_aggregator.aggregate(local_model_mask_1); };
//! # if let Ok(_) = mask_aggregator.validate_aggregation(&local_model_mask_2) { mask_aggregator.aggregate(local_model_mask_2); };
//! # let global_mask: MaskObject = mask_aggregator.into();
//! # let mut model_aggregator = Aggregation::new(config, number_weights);
//! # if let Ok(_) = model_aggregator.validate_aggregation(&masked_local_model_1) { model_aggregator.aggregate(masked_local_model_1); };
Expand Down

0 comments on commit 196e7e9

Please sign in to comment.