
Commit 764af86
Fix pretrained initialization. Default initialization is now PCArchetypal
AlbertDominguez committed Jul 21, 2022
1 parent 171b5d8 commit 764af86
Showing 6 changed files with 19 additions and 28 deletions.
17 changes: 15 additions & 2 deletions neural_admixture/model/initializations.py
@@ -53,7 +53,7 @@ def get_decoder_init(cls, X, K, path, run_name, n_components):
centers = np.concatenate([obj.cluster_centers_ for obj in k_means_objs])
P_init = torch.as_tensor(pca_obj.inverse_transform(centers), dtype=torch.float32).view(sum(K), -1)
else:
- k_means_obj = KMeans(n_clusters=K, random_state=42, n_init=10, max_iter=10).fit(X_tsvd)
+ k_means_obj = KMeans(n_clusters=K, random_state=42, n_init=10, max_iter=10).fit(X_pca)
P_init = torch.as_tensor(pca_obj.inverse_transform(k_means_obj.cluster_centers_), dtype=torch.float32).view(K, -1)
te = time.time()
log.info('Weights initialized in {} seconds.'.format(te-t0))
@@ -70,7 +70,7 @@ def get_decoder_init(cls, X, K, path, run_name, n_components):
class PCArchetypal(object):
@classmethod
def get_decoder_init(cls, X, K, path, run_name, n_components, seed):
- log.info('Running ArchetypalPCA initialization...')
+ log.info('Running PCArchetypal initialization...')
np.random.seed(seed)
t0 = time.time()
try:
@@ -123,3 +123,16 @@ def get_decoder_init(cls, X, y, K):
te = time.time()
log.info('Weights initialized in {} seconds.'.format(te-t0))
return P_init
+
+
+class PretrainedInitialization(object):
+    @classmethod
+    def get_decoder_init(cls, X, K, path):
+        log.info('Fetching pretrained weights...')
+        if len(K) > 1:
+            raise NotImplementedError("Pretrained mode is only supported for single-head runs.")
+        # Loads standard ADMIXTURE output format
+        P_init = torch.as_tensor(1-np.genfromtxt(path, delimiter=' ').T, dtype=torch.float32)
+        assert P_init.shape[0] == K[0], 'Input P is not coherent with the value of K'
+        log.info('Weights fetched.')
+        return P_init
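
For reference, here is a minimal standalone sketch of the input the new `PretrainedInitialization` expects: a space-delimited, ADMIXTURE-style P file with one row per variant and one column per population. The toy file name and sizes are illustrative only; the parsing line mirrors the diff above.

```python
import numpy as np
import torch

# Toy ADMIXTURE-style P file: one row per variant, one column per
# population; entries are allele frequencies in [0, 1].
K, n_variants = 3, 5
np.savetxt('toy.P', np.random.rand(n_variants, K), delimiter=' ')

# Same parsing as PretrainedInitialization.get_decoder_init: transpose
# to (K, n_variants) and flip the allele encoding with 1 - P.
P_init = torch.as_tensor(1 - np.genfromtxt('toy.P', delimiter=' ').T,
                         dtype=torch.float32)
assert P_init.shape[0] == K, 'Input P is not coherent with the value of K'
print(P_init.shape)  # torch.Size([3, 5])
```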
21 changes: 0 additions & 21 deletions neural_admixture/model/modules.py
@@ -62,24 +62,3 @@ def forward(self, hid_states):
outputs = [torch.clamp(self._get_decoder_for_k(self.ks[i])(hid_states[i]), 0, 1) for i in range(len(self.ks))]
return outputs

- # class NonLinearMultiHeadDecoder(nn.Module):
- # def __init__(self, ks, output_size, bias=False,
- # hidden_size=512, hidden_activation=nn.ReLU(),
- # inits=None):
- # super().__init__()
- # self.ks = ks
- # self.hidden_size = hidden_size
- # self.output_size = output_size
- # self.heads_decoder = nn.Linear(sum(self.ks), self.hidden_size, bias=bias)
- # self.common_decoder = nn.Linear(self.hidden_size, self.output_size)
- # self.nonlinearity = hidden_activation
- # self.sigmoid = nn.Sigmoid()
-
- # def forward(self, hid_states):
- # if len(hid_states) > 1:
- # concat_states = torch.cat(hid_states, 1)
- # else:
- # concat_states = hid_states[0]
- # dec = self.nonlinearity(self.heads_decoder(concat_states))
- # rec = self.sigmoid(self.common_decoder(dec))
- # return rec
1 change: 0 additions & 1 deletion neural_admixture/model/switchers.py
@@ -15,7 +15,6 @@ class Switchers(object):
'pcarchetypal': lambda X, y, k, seed, path, run_name, n_comp: init.PCArchetypal.get_decoder_init(X, k, path, run_name, n_comp, seed),
'pretrained': lambda X, y, k, seed, path, run_name, n_comp: init.PretrainedInitialization.get_decoder_init(X, k, path),
'supervised': lambda X, y, k, seed, path, run_name, n_comp: init.SupervisedInitialization.get_decoder_init(X, y, k)
-
}

_optimizers = {
4 changes: 2 additions & 2 deletions neural_admixture/src/train.py
@@ -52,10 +52,10 @@ def fit_model(trX, args, valX=None, trY=None, valY=None):
torch.manual_seed(seed)
# Initialization
log.info('Initializing...')
- if init_file is None:
+ if init_file is None and decoder_init != "pretrained":
log.warning(f'Initialization filename not provided. Going to store it to {save_dir}/{run_name}.pkl')
init_file = f'{run_name}.pkl'
- init_path = f'{save_dir}/{init_file}'
+ init_path = f'{save_dir}/{init_file}' if decoder_init != "pretrained" else init_file
P_init = switchers['initializations'][decoder_init](trX, trY, Ks, seed, init_path, run_name, n_components)
activation = switchers['activations'][activation_str](0)
log.info('Variants: {}'.format(trX.shape[1]))
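Taken in isolation, the change means pretrained runs skip both the warning and the `save_dir` prefix; a simplified sketch of the path logic (the wrapper function is hypothetical, variable names follow the diff):

```python
def resolve_init_path(init_file, save_dir, run_name, decoder_init):
    # Non-pretrained runs without a filename store the computed
    # initialization under save_dir as <run_name>.pkl.
    if init_file is None and decoder_init != "pretrained":
        init_file = f'{run_name}.pkl'
    # Pretrained runs load the weights from the user-supplied path verbatim.
    return f'{save_dir}/{init_file}' if decoder_init != "pretrained" else init_file
```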
2 changes: 1 addition & 1 deletion neural_admixture/src/utils.py
@@ -16,7 +16,7 @@ def parse_train_args(argv):
description='Rapid population clustering with autoencoders - training mode')
parser.add_argument('--learning_rate', required=False, default=0.0001, type=float, help='Learning rate')
parser.add_argument('--max_epochs', required=False, type=int, default=50, help='Maximum number of epochs')
- parser.add_argument('--initialization', required=False, type=str, default = 'pckmeans',
+ parser.add_argument('--initialization', required=False, type=str, default = 'pcarchetypal',
choices=['pretrained', 'pckmeans', 'supervised', 'pcarchetypal'],
help='Decoder initialization (overriden if supervised)')
parser.add_argument('--optimizer', required=False, default='adam', type=str, choices=['adam', 'sgd'], help='Optimizer')
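The new default can be checked with a stripped-down parser containing only this argument (a sketch, not the package's full parser):

```python
import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--initialization', required=False, type=str,
                    default='pcarchetypal',
                    choices=['pretrained', 'pckmeans', 'supervised', 'pcarchetypal'],
                    help='Decoder initialization (overridden if supervised)')

print(parser.parse_args([]).initialization)  # pcarchetypal
print(parser.parse_args(['--initialization', 'pretrained']).initialization)  # pretrained
```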
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@

setup(
name='neural-admixture',
- version='1.1.5',
+ version='1.1.6',
long_description=(Path(__file__).parent / 'README.md').read_text(),
long_description_content_type='text/markdown',
description='Population clustering with autoencoders',
