mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
refactoring
This commit is contained in:
parent
45abcff3d1
commit
a030ff6951
2 changed files with 25 additions and 16 deletions
|
@ -19,8 +19,8 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat
|
||||||
|
|
||||||
class TernausNet(object):
|
class TernausNet(object):
|
||||||
VERSION = 1
|
VERSION = 1
|
||||||
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
|
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, data_format="NHWC"):
|
||||||
nn.initialize(data_format="NHWC")
|
nn.initialize(data_format=data_format)
|
||||||
tf = nn.tf
|
tf = nn.tf
|
||||||
|
|
||||||
class Ternaus(nn.ModelBase):
|
class Ternaus(nn.ModelBase):
|
||||||
|
@ -87,23 +87,23 @@ class TernausNet(object):
|
||||||
x = self.conv_center(x)
|
x = self.conv_center(x)
|
||||||
|
|
||||||
x = tf.nn.relu(self.conv1_up(x))
|
x = tf.nn.relu(self.conv1_up(x))
|
||||||
x = tf.concat( [x,x4], -1)
|
x = tf.concat( [x,x4], nn.conv2d_ch_axis)
|
||||||
x = tf.nn.relu(self.conv1(x))
|
x = tf.nn.relu(self.conv1(x))
|
||||||
|
|
||||||
x = tf.nn.relu(self.conv2_up(x))
|
x = tf.nn.relu(self.conv2_up(x))
|
||||||
x = tf.concat( [x,x3], -1)
|
x = tf.concat( [x,x3], nn.conv2d_ch_axis)
|
||||||
x = tf.nn.relu(self.conv2(x))
|
x = tf.nn.relu(self.conv2(x))
|
||||||
|
|
||||||
x = tf.nn.relu(self.conv3_up(x))
|
x = tf.nn.relu(self.conv3_up(x))
|
||||||
x = tf.concat( [x,x2], -1)
|
x = tf.concat( [x,x2], nn.conv2d_ch_axis)
|
||||||
x = tf.nn.relu(self.conv3(x))
|
x = tf.nn.relu(self.conv3(x))
|
||||||
|
|
||||||
x = tf.nn.relu(self.conv4_up(x))
|
x = tf.nn.relu(self.conv4_up(x))
|
||||||
x = tf.concat( [x,x1], -1)
|
x = tf.concat( [x,x1], nn.conv2d_ch_axis)
|
||||||
x = tf.nn.relu(self.conv4(x))
|
x = tf.nn.relu(self.conv4(x))
|
||||||
|
|
||||||
x = tf.nn.relu(self.conv5_up(x))
|
x = tf.nn.relu(self.conv5_up(x))
|
||||||
x = tf.concat( [x,x0], -1)
|
x = tf.concat( [x,x0], nn.conv2d_ch_axis)
|
||||||
x = tf.nn.relu(self.conv5(x))
|
x = tf.nn.relu(self.conv5(x))
|
||||||
|
|
||||||
logits = self.out_conv(x)
|
logits = self.out_conv(x)
|
||||||
|
|
|
@ -10,11 +10,12 @@ from facelib import FaceType, LandmarksProcessor
|
||||||
class SampleProcessor(object):
|
class SampleProcessor(object):
|
||||||
class SampleType(IntEnum):
|
class SampleType(IntEnum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
FACE_IMAGE = 1
|
IMAGE = 1
|
||||||
FACE_MASK = 2
|
FACE_IMAGE = 2
|
||||||
LANDMARKS_ARRAY = 3
|
FACE_MASK = 3
|
||||||
PITCH_YAW_ROLL = 4
|
LANDMARKS_ARRAY = 4
|
||||||
PITCH_YAW_ROLL_SIGMOID = 5
|
PITCH_YAW_ROLL = 5
|
||||||
|
PITCH_YAW_ROLL_SIGMOID = 6
|
||||||
|
|
||||||
class ChannelType(IntEnum):
|
class ChannelType(IntEnum):
|
||||||
NONE = 0
|
NONE = 0
|
||||||
|
@ -92,11 +93,12 @@ class SampleProcessor(object):
|
||||||
ct_mode = opts.get('ct_mode', None)
|
ct_mode = opts.get('ct_mode', None)
|
||||||
data_format = opts.get('data_format', 'NHWC')
|
data_format = opts.get('data_format', 'NHWC')
|
||||||
|
|
||||||
if sample_type == SPST.FACE_MASK:
|
if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE:
|
||||||
border_replicate = False
|
border_replicate = False
|
||||||
elif sample_type == SPST.FACE_IMAGE:
|
elif sample_type == SPST.FACE_IMAGE:
|
||||||
border_replicate = True
|
border_replicate = True
|
||||||
|
|
||||||
|
|
||||||
border_replicate = opts.get('border_replicate', border_replicate)
|
border_replicate = opts.get('border_replicate', border_replicate)
|
||||||
borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT
|
borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT
|
||||||
|
|
||||||
|
@ -230,9 +232,16 @@ class SampleProcessor(object):
|
||||||
out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0)
|
out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0)
|
||||||
if data_format == "NCHW":
|
if data_format == "NCHW":
|
||||||
out_sample = np.transpose(out_sample, (2,0,1) )
|
out_sample = np.transpose(out_sample, (2,0,1) )
|
||||||
#else:
|
elif sample_type == SPST.IMAGE:
|
||||||
# img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=True)
|
img = sample_bgr
|
||||||
# img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
|
img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=True)
|
||||||
|
img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )
|
||||||
|
out_sample = img
|
||||||
|
|
||||||
|
if data_format == "NCHW":
|
||||||
|
out_sample = np.transpose(out_sample, (2,0,1) )
|
||||||
|
|
||||||
|
|
||||||
elif sample_type == SPST.LANDMARKS_ARRAY:
|
elif sample_type == SPST.LANDMARKS_ARRAY:
|
||||||
l = sample_landmarks
|
l = sample_landmarks
|
||||||
l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
|
l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue