refactoring

This commit is contained in:
Colombo 2020-03-09 13:08:32 +04:00
parent 45abcff3d1
commit a030ff6951
2 changed files with 25 additions and 16 deletions

View file

@ -19,8 +19,8 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat
class TernausNet(object):
VERSION = 1
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
nn.initialize(data_format="NHWC")
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, data_format="NHWC"):
nn.initialize(data_format=data_format)
tf = nn.tf
class Ternaus(nn.ModelBase):
@ -87,23 +87,23 @@ class TernausNet(object):
x = self.conv_center(x)
x = tf.nn.relu(self.conv1_up(x))
x = tf.concat( [x,x4], -1)
x = tf.concat( [x,x4], nn.conv2d_ch_axis)
x = tf.nn.relu(self.conv1(x))
x = tf.nn.relu(self.conv2_up(x))
x = tf.concat( [x,x3], -1)
x = tf.concat( [x,x3], nn.conv2d_ch_axis)
x = tf.nn.relu(self.conv2(x))
x = tf.nn.relu(self.conv3_up(x))
x = tf.concat( [x,x2], -1)
x = tf.concat( [x,x2], nn.conv2d_ch_axis)
x = tf.nn.relu(self.conv3(x))
x = tf.nn.relu(self.conv4_up(x))
x = tf.concat( [x,x1], -1)
x = tf.concat( [x,x1], nn.conv2d_ch_axis)
x = tf.nn.relu(self.conv4(x))
x = tf.nn.relu(self.conv5_up(x))
x = tf.concat( [x,x0], -1)
x = tf.concat( [x,x0], nn.conv2d_ch_axis)
x = tf.nn.relu(self.conv5(x))
logits = self.out_conv(x)