Skip to content

Commit

Permalink
bugfix building block interfaces
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaskuestner committed Aug 20, 2020
1 parent 5d3aa05 commit 5c1a052
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 9 deletions.
3 changes: 2 additions & 1 deletion models/basic_block.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import tensorflow as tf
from tensorflow.keras.layers import *
from tensorflow.keras.activations import tanh

"""basic network building blocks"""

Expand Down Expand Up @@ -87,7 +88,7 @@ def func(x):
elif item == 's':
x = Softmax()(x)
elif item == 't':
x=Tanh()(x)
x=tanh()(x)

# Pooling#
elif item == 'ap': # average pooling
Expand Down
16 changes: 8 additions & 8 deletions models/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,10 @@ def DilatedConv(filters, conv_param):
return block(filters, k=3, s=1, order=['b', 'r', 'c'], order_param=[None, None, conv_param])

"""atrous spatial pyramid pooling"""
def ASPP(filters, strides, conv_param, image_level_pool_size):
def ASPP(filters, strides, conv_param, dilation_rate_list, image_level_pool_size):
## pyramid part
def func(x):
pyramid_1x1 = block(filters, 1, stride, order=['c', 'b', 'r'], order_param=[conv_param, None, None])(x)
pyramid_1x1 = block(filters, 1, strides, order=['c', 'b', 'r'], order_param=[conv_param, None, None])(x)
branch = [pyramid_1x1]
for rate in dilation_rate_list:
conv_param['dilation_rate'] = rate
Expand Down Expand Up @@ -77,10 +77,10 @@ def MR_local_pr_no_bn(x):
return MR_local_pr_no_bn


def MR_global_pr_no_bn():
def MR_global_pr_no_bn(filters, conv_param):
def MR_global_pr_no_bn(x):
x = block(filters // 2, 1, 1, order=['c', 'r'], order_param=[conv_param, None])(x)
x = DilatedConv_no_bn(filters // 2, conv_param)(x)
x = block(filters // 2, 3, 1, order=['c', 'r', 'c', 'r'], order_param=[conv_param, None])(x)
return block(filters, 1, 1, order=['c', 'r'], order_param=[conv_param, None])(x)
return MR_global_pr_no_bn

Expand Down Expand Up @@ -130,11 +130,11 @@ def MR_GE_block_merge(x, y):

def MRGE_exp_block(filters, dilation_max, conv_param_local,conv_param_global):
def MRGE_exp_block(x):
x, y = MR_block_split(filters, cov_param)(x)
x, y = MR_block_split(filters, conv_param_local)(x)
block_num = int(log2(dilation_max) + 1)
rate_list = [2 ** i for i in range(block_num)]
for rate in rate_list[:-1]:
cov_param['dilation_rate'] = rate
conv_param_local['dilation_rate'] = rate
x, y = MR_GE_block(filters, conv_param_local,conv_param_global)(x, y)
x = MR_GE_block_merge(filters, conv_param_local,conv_param_global)(x, y)
return x
Expand Down Expand Up @@ -219,10 +219,10 @@ def dense_conv(x):

# Transition pool layer
def transitionLayerPool(filters,conv_param):
return lambda x: block(filters, 1,1,order=['b','r','c','ap'], order_param=[None,None, conv_param,{pool_size:2}])(x)
return lambda x: block(filters, 1,1,order=['b','r','c','ap'], order_param=[None,None, conv_param,{'pool_size':2}])(x)

# Transition transpose up layer
def transitionLayerTransposeUp(mode, f, lbda):
def transitionLayerTransposeUp(filters, conv_param):
def func(x):
x=block(filters, 1, 1, order=['b', 'r', 'c'],order_param=[None, None, conv_param])(x)
return block(filters, 3, 2, order=[ 'dc'],order_param=[conv_param])(x)
Expand Down

0 comments on commit 5c1a052

Please sign in to comment.