diff --git a/train.py b/train.py index df61e6dc74..ed926a93db 100755 --- a/train.py +++ b/train.py @@ -485,7 +485,7 @@ def main(): bn_eps=args.bn_eps, scriptable=args.torchscript, checkpoint_path=args.initial_checkpoint, - features_only=args.use_pyramid_head, + features_only=args.use_pyramid_head or args.use_fp2t, **factory_kwargs, **args.model_kwargs, )