Skip to content

Commit

Permalink
import multi_gpu_model
Browse files Browse the repository at this point in the history
  • Loading branch information
thomaskuestner committed Aug 20, 2020
1 parent f06675f commit 194a142
Showing 1 changed file with 5 additions and 2 deletions.
7 changes: 5 additions & 2 deletions models/ModelSet.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import tensorflow as tf
from tensorflow.keras.layers import *
import tensorflow.keras.backend as K
from tensorflow.keras.utils import multi_gpu_model
from .blocks import *
# from .loss_Function import *
import models.loss_function as loss_function
Expand Down Expand Up @@ -144,7 +145,7 @@ def model_MRGE_1(self, config):
if config['feed_pos']:
return create_and_compile_model([inputs, in_pos], out, config)
else:
return create_and_compile_model(in_, out, config)
return create_and_compile_model(inputs, out, config)

def model_MRGE_2(self, config):
"Experimental"
Expand Down Expand Up @@ -211,14 +212,15 @@ def model_MRGE_2(self, config):
if config['feed_pos']:
return create_and_compile_model([inputs, in_pos], out, config)
else:
return create_and_compile_model(in_, out, config)
return create_and_compile_model(inputs, out, config)


def model_U_net_old(self, config, depth=None):

conv_param = config['convolution_parameter']
inputs = Input(shape=(*config['patch_size'],) + (config['channel_img_num'],), name='inp1')
x = inputs
in_pos = None
levels = list()
# add levels with max pooling

Expand Down Expand Up @@ -638,6 +640,7 @@ def model_body_identification_classification(self, config):
'''

inputs = Input(shape=config['patch_size'], name='input_layer')
in_pos = None
n_base_filter = 32
reshaped = Reshape([config['patch_size'][1], config['patch_size'][2], 1])(inputs)

Expand Down

0 comments on commit 194a142

Please sign in to comment.