From 4b586a5a22002dc955d025b890bc632daa3c01c7 Mon Sep 17 00:00:00 2001 From: hwang595 Date: Tue, 14 Jul 2020 15:43:25 -0500 Subject: [PATCH] add 1-round fedma mnist lenet --- matching/pfnm.py | 37 +++++++++++++++++++++++++------------ matching/utils.py | 16 ++++++++++++++++ 2 files changed, 41 insertions(+), 12 deletions(-) diff --git a/matching/pfnm.py b/matching/pfnm.py index 6fa499d..4bae386 100644 --- a/matching/pfnm.py +++ b/matching/pfnm.py @@ -168,6 +168,8 @@ def block_patching(w_j, L_next, assignment_j_c, layer_index, model_meta_data, shape_estimator = ModerateCNNContainerConvBlocks(num_filters=matching_shapes) elif dataset == "mnist": shape_estimator = ModerateCNNContainerConvBlocksMNIST(num_filters=matching_shapes) + elif network_name == "lenet": + shape_estimator = LeNetContainer(num_filters=matching_shapes, kernel_size=5) if dataset in ("cifar10", "cinic10"): dummy_input = torch.rand(1, 3, 32, 32) @@ -890,7 +892,9 @@ def layer_wise_group_descent(batch_weights, layer_index, batch_frequencies, sigm #sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in range(J)] #sigma_inv_layer = [np.array((matching_shapes[layer_index - 2]) * [1 / sigma] + [1 / sigma_bias] + [y / sigma for y in last_layer_const[j]]) for j in range(J)] - sigma_inv_layer = [np.array((matching_shapes[layer_index - 2]) * [1 / sigma] + [1 / sigma_bias]) for j in range(J)] + + #sigma_inv_layer = [np.array((matching_shapes[layer_index - 2]) * [1 / sigma] + [1 / sigma_bias]) for j in range(J)] + sigma_inv_layer = [np.array([1 / sigma_bias] + (weights_bias[j].shape[1] - 1) * [1 / sigma]) for j in range(J)] elif (layer_index > 1 and layer_index < (n_layers - 1)): layer_type = model_layer_type[2 * layer_index - 2] @@ -966,19 +970,28 @@ def layer_wise_group_descent(batch_weights, layer_index, batch_frequencies, sigm # softmax_inv_sigma] # remove fitting the last layer - if first_fc_identifier: - global_weights_out = [global_weights_c[:, 0:-softmax_bias.shape[0]-1].T, - global_weights_c[:, -softmax_bias.shape[0]-1]] + # if first_fc_identifier: + # global_weights_out = [global_weights_c[:, 0:-softmax_bias.shape[0]-1].T, + # global_weights_c[:, -softmax_bias.shape[0]-1]] - global_inv_sigmas_out = [global_sigmas_c[:, 0:-softmax_bias.shape[0]-1].T, - global_sigmas_c[:, -softmax_bias.shape[0]-1]] - else: - global_weights_out = [global_weights_c[:, 0:matching_shapes[layer_index - 1 - 1]].T, - global_weights_c[:, matching_shapes[layer_index - 1 - 1]]] + # global_inv_sigmas_out = [global_sigmas_c[:, 0:-softmax_bias.shape[0]-1].T, + # global_sigmas_c[:, -softmax_bias.shape[0]-1]] + # else: + # global_weights_out = [global_weights_c[:, 0:matching_shapes[layer_index - 1 - 1]].T, + # global_weights_c[:, matching_shapes[layer_index - 1 - 1]]] - global_inv_sigmas_out = [global_sigmas_c[:, 0:matching_shapes[layer_index - 1 - 1]].T, - global_sigmas_c[:, matching_shapes[layer_index - 1 - 1]]] - logger.info("Branch B, Layer index: {}, Global weights out shapes: {}".format(layer_index, [gwo.shape for gwo in global_weights_out])) + # global_inv_sigmas_out = [global_sigmas_c[:, 0:matching_shapes[layer_index - 1 - 1]].T, + # global_sigmas_c[:, matching_shapes[layer_index - 1 - 1]]] + layer_type = model_layer_type[2 * layer_index - 2] + gwc_shape = global_weights_c.shape + if "conv" in layer_type or 'features' in layer_type: + global_weights_out = [global_weights_c[:, 0:gwc_shape[1]-1], global_weights_c[:, gwc_shape[1]-1]] + global_inv_sigmas_out = [global_sigmas_c[:, 0:gwc_shape[1]-1], global_sigmas_c[:, gwc_shape[1]-1]] + elif "fc" in layer_type or 'classifier' in layer_type: + global_weights_out = [global_weights_c[:, 0:gwc_shape[1]-1].T, global_weights_c[:, gwc_shape[1]-1]] + global_inv_sigmas_out = [global_sigmas_c[:, 0:gwc_shape[1]-1].T, global_sigmas_c[:, gwc_shape[1]-1]] + + logger.info("#### Branch B, Layer index: {}, Global weights out shapes: {}".format(layer_index, [gwo.shape for gwo in global_weights_out])) elif (layer_index > 1 and layer_index < (n_layers - 1)): layer_type = model_layer_type[2 * layer_index - 2] diff --git a/matching/utils.py b/matching/utils.py index 2a4318a..33d3189 100644 --- a/matching/utils.py +++ b/matching/utils.py @@ -101,4 +101,20 @@ def __init__(self, num_filters, output_dim=10): def forward(self, x): x = self.conv_layer(x) + return x + + +class LeNetContainer(nn.Module): + def __init__(self, num_filters, kernel_size=5): + super(LeNetContainer, self).__init__() + self.conv1 = nn.Conv2d(1, num_filters[0], kernel_size, 1) + self.conv2 = nn.Conv2d(num_filters[0], num_filters[1], kernel_size, 1) + + def forward(self, x): + x = self.conv1(x) + x = F.max_pool2d(x, 2, 2) + #x = F.relu(x) + x = self.conv2(x) + x = F.max_pool2d(x, 2, 2) + #x = F.relu(x) return x \ No newline at end of file