Skip to content

Commit 81cd686

Browse files
committed
Move aggregation (convpool) for nest into NestLevel, cleanup and enable features_only use. Finalize weight url.
1 parent 6ae0ac6 commit 81cd686

File tree

3 files changed

+95
-112
lines changed

3 files changed

+95
-112
lines changed

README.md

+3
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,9 @@ I'm fortunate to be able to dedicate significant time and money of my own suppor
2323

2424
## What's New
2525

26+
### July 5, 2021
27+
* Add 'Aggregating Nested Transformer' (NesT) w/ weights converted from official [Flax impl](https://github.com/google-research/nested-transformer). Contributed by [Alexander Soare](https://github.com/alexander-soare).
28+
2629
### June 23, 2021
2730
* Reproduce gMLP model training, `gmlp_s16_224` trained to 79.6 top-1, matching [paper](https://arxiv.org/abs/2105.08050). Hparams for this and other recent MLP training [here](https://gist.github.com/rwightman/d6c264a9001f9167e06c209f630b2cc6)
2831

convert/convert_nest_flax.py

+12-12
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,18 @@ def convert_nest(checkpoint_path, arch):
7979
state_dict[f'levels.{level}.transformer_encoder.{layer}.mlp.fc{i+1}.bias'] = torch.tensor(
8080
flax_dict[f'EncoderNDBlock_{global_layer_ix}']['MlpBlock_0'][f'Dense_{i}']['bias'])
8181

82-
# Block aggregations
83-
for level in range(len(depths)-1):
82+
# Block aggregations (ConvPool)
83+
for level in range(1, len(depths)):
8484
# Convs
85-
state_dict[f'block_aggs.{level}.conv.weight'] = torch.tensor(
86-
flax_dict[f'ConvPool_{level}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
87-
state_dict[f'block_aggs.{level}.conv.bias'] = torch.tensor(
88-
flax_dict[f'ConvPool_{level}']['Conv_0']['bias'])
85+
state_dict[f'levels.{level}.pool.conv.weight'] = torch.tensor(
86+
flax_dict[f'ConvPool_{level-1}']['Conv_0']['kernel']).permute(3, 2, 0, 1)
87+
state_dict[f'levels.{level}.pool.conv.bias'] = torch.tensor(
88+
flax_dict[f'ConvPool_{level-1}']['Conv_0']['bias'])
8989
# Norms
90-
state_dict[f'block_aggs.{level}.norm.weight'] = torch.tensor(
91-
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['scale'])
92-
state_dict[f'block_aggs.{level}.norm.bias'] = torch.tensor(
93-
flax_dict[f'ConvPool_{level}']['LayerNorm_0']['bias'])
90+
state_dict[f'levels.{level}.pool.norm.weight'] = torch.tensor(
91+
flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['scale'])
92+
state_dict[f'levels.{level}.pool.norm.bias'] = torch.tensor(
93+
flax_dict[f'ConvPool_{level-1}']['LayerNorm_0']['bias'])
9494

9595
# Final norm
9696
state_dict[f'norm.weight'] = torch.tensor(flax_dict['LayerNorm_0']['scale'])
@@ -105,5 +105,5 @@ def convert_nest(checkpoint_path, arch):
105105

106106
if __name__ == '__main__':
107107
variant = sys.argv[1] # base, small, or tiny
108-
state_dict = convert_nest(f'../nested-transformer/checkpoints/nest-{variant[0]}_imagenet', f'nest_{variant}')
109-
torch.save(state_dict, f'/home/alexander/.cache/torch/hub/checkpoints/jx_nest_{variant}.pth')
108+
state_dict = convert_nest(f'./nest-{variant[0]}_imagenet', f'nest_{variant}')
109+
torch.save(state_dict, f'./jx_nest_{variant}.pth')

0 commit comments

Comments
 (0)