diff --git a/torchsummary/torchsummary.py b/torchsummary/torchsummary.py index 9bc2d65..c393b3a 100644 --- a/torchsummary/torchsummary.py +++ b/torchsummary/torchsummary.py @@ -3,9 +3,9 @@ from torch.autograd import Variable from collections import OrderedDict +import numpy as np - -def summary(model, input_size, device="cuda"): +def summary(model, input_size, batch_size=-1,device="cuda"): def register_hook(module): def hook(module, input, output): class_name = str(module.__class__).split('.')[-1].split("'")[0] @@ -14,12 +14,12 @@ def hook(module, input, output): m_key = '%s-%i' % (class_name, module_idx+1) summary[m_key] = OrderedDict() summary[m_key]['input_shape'] = list(input[0].size()) - summary[m_key]['input_shape'][0] = -1 + summary[m_key]['input_shape'][0] = batch_size if isinstance(output, (list,tuple)): summary[m_key]['output_shape'] = [[-1] + list(o.size())[1:] for o in output] else: summary[m_key]['output_shape'] = list(output.size()) - summary[m_key]['output_shape'][0] = -1 + summary[m_key]['output_shape'][0] = batch_size params = 0 if hasattr(module, 'weight') and hasattr(module.weight, 'size'): @@ -67,18 +67,31 @@ def hook(module, input, output): print(line_new) print('================================================================') total_params = 0 + total_output = 0 trainable_params = 0 for layer in summary: # input_shape, output_shape, trainable, nb_params line_new = '{:>20} {:>25} {:>15}'.format(layer, str(summary[layer]['output_shape']), '{0:,}'.format(summary[layer]['nb_params'])) total_params += summary[layer]['nb_params'] + total_output += np.prod(summary[layer]['output_shape']) if 'trainable' in summary[layer]: if summary[layer]['trainable'] == True: trainable_params += summary[layer]['nb_params'] print(line_new) + #assume 4 bytes/number (float on cuda). + total_input_size = abs(np.prod(input_size)*batch_size*4./(1024**2.)) + total_output_size = abs(2.*total_output*4./(1024**2.)) #x2 for gradients + total_params_size = abs(total_params.numpy()*4./(1024**2.)) + total_size = total_params_size + total_output_size + total_input_size + print('================================================================') print('Total params: {0:,}'.format(total_params)) print('Trainable params: {0:,}'.format(trainable_params)) print('Non-trainable params: {0:,}'.format(total_params - trainable_params)) print('----------------------------------------------------------------') - # return summary + print('Input size (MB): %0.2f' % total_input_size) + print('Forward/backward pass size (MB): %0.2f' % total_output_size) + print('Params size (MB): %0.2f' % total_params_size) + print('Estimated Total Size (MB): %0.2f' % total_size) + print('----------------------------------------------------------------') + #return summary