diff --git a/tests/python/unittest/test_gluon.py b/tests/python/unittest/test_gluon.py index 3049674821c9..abe6b136fe0c 100644 --- a/tests/python/unittest/test_gluon.py +++ b/tests/python/unittest/test_gluon.py @@ -2411,7 +2411,7 @@ def hybrid_forward(self, F, x): x_reshape = x.reshape(self.reshape) out = self.act(x_reshape) return out - acts = ["relu", "sigmoid", "tanh", "softrelu"] + acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"] for act in acts: x = mx.nd.random.uniform(-1, 1, shape=(4, 16, 32, 32)) shape = (4, 32, 32, -1) @@ -2433,7 +2433,7 @@ def hybrid_forward(self, F, x): out = self.act(x_slice) return out - acts = ["relu", "sigmoid", "tanh", "softrelu"] + acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"] for act in acts: x = mx.nd.random.uniform(-1, 1, shape=(8, 32, 64, 64)) slice = [(0, 16, 32, 32), (4, 32, 64, 64)] @@ -2457,7 +2457,7 @@ def hybrid_forward(self, F, x): y_reshape = y.reshape(self.reshape[1]) out = self.act1(y_reshape) return out - acts = ["relu", "sigmoid", "tanh", "softrelu"] + acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"] for idx0, act0 in enumerate(acts): for idx1, act1 in enumerate(acts): if idx1 == idx0: @@ -2484,7 +2484,7 @@ def hybrid_forward(self, F, x): y_slice = y.slice(begin=self.slice[1][0], end=self.slice[1][1]) out = self.act1(y_slice) return out - acts = ["relu", "sigmoid", "tanh", "softrelu"] + acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"] for idx0, act0 in enumerate(acts): for idx1, act1 in enumerate(acts): if idx1 == idx0: @@ -2512,7 +2512,7 @@ def hybrid_forward(self, F, x): y_slice = y.slice(begin=self.slice[0], end=self.slice[1]) out = self.act1(y_slice) return out - acts = ["relu", "sigmoid", "tanh", "softrelu"] + acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"] for idx0, act0 in enumerate(acts): for idx1, act1 in enumerate(acts): if idx1 == idx0: @@ -2541,7 +2541,7 @@ def hybrid_forward(self, F, x): y_reshape = y.reshape(self.reshape) out = self.act1(y_reshape) return out - acts = ["relu", "sigmoid", "tanh", "softrelu"] + acts = ["relu", "sigmoid", "tanh", "softrelu", "softsign"] for idx0, act0 in enumerate(acts): for idx1, act1 in enumerate(acts): if idx1 == idx0: