mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-16 10:03:41 -07:00
refactoring
This commit is contained in:
parent
45abcff3d1
commit
a030ff6951
2 changed files with 25 additions and 16 deletions
|
@ -19,8 +19,8 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat
|
|||
|
||||
class TernausNet(object):
|
||||
VERSION = 1
|
||||
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
|
||||
nn.initialize(data_format="NHWC")
|
||||
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, data_format="NHWC"):
|
||||
nn.initialize(data_format=data_format)
|
||||
tf = nn.tf
|
||||
|
||||
class Ternaus(nn.ModelBase):
|
||||
|
@ -87,23 +87,23 @@ class TernausNet(object):
|
|||
x = self.conv_center(x)
|
||||
|
||||
x = tf.nn.relu(self.conv1_up(x))
|
||||
x = tf.concat( [x,x4], -1)
|
||||
x = tf.concat( [x,x4], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv1(x))
|
||||
|
||||
x = tf.nn.relu(self.conv2_up(x))
|
||||
x = tf.concat( [x,x3], -1)
|
||||
x = tf.concat( [x,x3], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv2(x))
|
||||
|
||||
x = tf.nn.relu(self.conv3_up(x))
|
||||
x = tf.concat( [x,x2], -1)
|
||||
x = tf.concat( [x,x2], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv3(x))
|
||||
|
||||
x = tf.nn.relu(self.conv4_up(x))
|
||||
x = tf.concat( [x,x1], -1)
|
||||
x = tf.concat( [x,x1], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv4(x))
|
||||
|
||||
x = tf.nn.relu(self.conv5_up(x))
|
||||
x = tf.concat( [x,x0], -1)
|
||||
x = tf.concat( [x,x0], nn.conv2d_ch_axis)
|
||||
x = tf.nn.relu(self.conv5(x))
|
||||
|
||||
logits = self.out_conv(x)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue