Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【PaddlePaddle Hackathon】16、在 Paddle 中新增 WideResNet #36954

Closed
wants to merge 25 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 16 additions & 27 deletions python/paddle/tests/test_vision_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
unit tests of all models in paddle.vision.models
"""
import unittest
import numpy as np

Expand All @@ -20,18 +23,22 @@


class TestVisonModels(unittest.TestCase):
"""
unit tests of all models in paddle.vision.models
"""
def models_infer(self, arch, pretrained=False, batch_norm=False):

"""
arch: the name of the model to be tested
"""
x = np.array(np.random.random((2, 3, 224, 224)), dtype=np.float32)
if batch_norm:
net = models.__dict__[arch](pretrained=pretrained, batch_norm=True)
else:
net = models.__dict__[arch](pretrained=pretrained)

input = InputSpec([None, 3, 224, 224], 'float32', 'image')
model = paddle.Model(net, input)
model.prepare()

model.predict_batch(x)

def test_mobilenetv2_pretrained(self):
Expand Down Expand Up @@ -70,6 +77,12 @@ def test_resnet101(self):
def test_resnet152(self):
self.models_infer('resnet152')

def test_wide_resnet50_2(self):
self.models_infer('wide_resnet50_2')

def test_wide_resnet101_2(self):
self.models_infer('wide_resnet101_2')

def test_densenet121(self):
self.models_infer('densenet121')

Expand All @@ -88,9 +101,6 @@ def test_densenet264(self):
def test_alexnet(self):
self.models_infer('alexnet')

def test_shufflenetv2_swish(self):
self.models_infer('shufflenet_v2_swish')

def test_resnext50_32x4d(self):
self.models_infer('resnext50_32x4d')

Expand All @@ -112,27 +122,6 @@ def test_resnext152_64x4d(self):
def test_inception_v3(self):
self.models_infer('inception_v3')

def test_googlenet(self):
self.models_infer('googlenet')

def test_shufflenetv2_x0_25(self):
self.models_infer('shufflenet_v2_x0_25')

def test_shufflenetv2_x0_33(self):
self.models_infer('shufflenet_v2_x0_33')

def test_shufflenetv2_x0_5(self):
self.models_infer('shufflenet_v2_x0_5')

def test_shufflenetv2_x1_0(self):
self.models_infer('shufflenet_v2_x1_0')

def test_shufflenetv2_x1_5(self):
self.models_infer('shufflenet_v2_x1_5')

def test_shufflenetv2_x2_0(self):
self.models_infer('shufflenet_v2_x2_0')

def test_vgg16_num_classes(self):
vgg16 = models.__dict__['vgg16'](pretrained=False, num_classes=10)

Expand Down
2 changes: 2 additions & 0 deletions python/paddle/vision/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@
from .models import resnet50 # noqa: F401
from .models import resnet101 # noqa: F401
from .models import resnet152 # noqa: F401
from .models import wide_resnet50_2
from .models import wide_resnet101_2
from .models import MobileNetV1 # noqa: F401
from .models import mobilenet_v1 # noqa: F401
from .models import MobileNetV2 # noqa: F401
Expand Down
5 changes: 4 additions & 1 deletion python/paddle/vision/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,14 @@
#WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#See the License for the specific language governing permissions and
#limitations under the License.

from .resnet import ResNet # noqa: F401
from .resnet import resnet18 # noqa: F401
from .resnet import resnet34 # noqa: F401
from .resnet import resnet50 # noqa: F401
from .resnet import resnet101 # noqa: F401
from .resnet import resnet152 # noqa: F401
from .resnet import wide_resnet50_2
from .resnet import wide_resnet101_2
from .mobilenetv1 import MobileNetV1 # noqa: F401
from .mobilenetv1 import mobilenet_v1 # noqa: F401
from .mobilenetv2 import MobileNetV2 # noqa: F401
Expand Down Expand Up @@ -63,6 +64,8 @@
'resnet50',
'resnet101',
'resnet152',
'wide_resnet50_2',
'wide_resnet101_2',
'VGG',
'vgg11',
'vgg13',
Expand Down
Loading