diff --git a/facelib/TernausNet.py b/facelib/TernausNet.py
index 1efab9e..9ac54c0 100644
--- a/facelib/TernausNet.py
+++ b/facelib/TernausNet.py
@@ -19,8 +19,8 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat
 class TernausNet(object):
     VERSION = 1

-    def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
-        nn.initialize(data_format="NHWC")
+    def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False, data_format="NHWC"):
+        nn.initialize(data_format=data_format)
         tf = nn.tf

         class Ternaus(nn.ModelBase):
@@ -87,23 +87,23 @@ class TernausNet(object):
                 x = self.conv_center(x)

                 x = tf.nn.relu(self.conv1_up(x))
-                x = tf.concat( [x,x4], -1)
+                x = tf.concat( [x,x4], nn.conv2d_ch_axis)
                 x = tf.nn.relu(self.conv1(x))

                 x = tf.nn.relu(self.conv2_up(x))
-                x = tf.concat( [x,x3], -1)
+                x = tf.concat( [x,x3], nn.conv2d_ch_axis)
                 x = tf.nn.relu(self.conv2(x))

                 x = tf.nn.relu(self.conv3_up(x))
-                x = tf.concat( [x,x2], -1)
+                x = tf.concat( [x,x2], nn.conv2d_ch_axis)
                 x = tf.nn.relu(self.conv3(x))

                 x = tf.nn.relu(self.conv4_up(x))
-                x = tf.concat( [x,x1], -1)
+                x = tf.concat( [x,x1], nn.conv2d_ch_axis)
                 x = tf.nn.relu(self.conv4(x))

                 x = tf.nn.relu(self.conv5_up(x))
-                x = tf.concat( [x,x0], -1)
+                x = tf.concat( [x,x0], nn.conv2d_ch_axis)
                 x = tf.nn.relu(self.conv5(x))

                 logits = self.out_conv(x)
diff --git a/samplelib/SampleProcessor.py b/samplelib/SampleProcessor.py
index d14759d..7705ee1 100644
--- a/samplelib/SampleProcessor.py
+++ b/samplelib/SampleProcessor.py
@@ -10,11 +10,12 @@ from facelib import FaceType, LandmarksProcessor
 class SampleProcessor(object):
     class SampleType(IntEnum):
         NONE = 0
-        FACE_IMAGE = 1
-        FACE_MASK = 2
-        LANDMARKS_ARRAY = 3
-        PITCH_YAW_ROLL = 4
-        PITCH_YAW_ROLL_SIGMOID = 5
+        IMAGE = 1
+        FACE_IMAGE = 2
+        FACE_MASK = 3
+        LANDMARKS_ARRAY = 4
+        PITCH_YAW_ROLL = 5
+        PITCH_YAW_ROLL_SIGMOID = 6

     class ChannelType(IntEnum):
         NONE = 0
@@ -92,11 +93,12 @@ class SampleProcessor(object):
                 ct_mode = opts.get('ct_mode', None)
                 data_format = opts.get('data_format', 'NHWC')

-                if sample_type == SPST.FACE_MASK:
+                if sample_type == SPST.FACE_MASK or sample_type == SPST.IMAGE:
                     border_replicate = False
                 elif sample_type == SPST.FACE_IMAGE:
                     border_replicate = True

+                border_replicate = opts.get('border_replicate', border_replicate)
                 borderMode = cv2.BORDER_REPLICATE if border_replicate else cv2.BORDER_CONSTANT

@@ -230,9 +232,16 @@ class SampleProcessor(object):
                         out_sample = np.clip (out_sample * 2.0 - 1.0, -1.0, 1.0)
                         if data_format == "NCHW":
                             out_sample = np.transpose(out_sample, (2,0,1) )
-                #else:
-                #    img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=True)
-                #    img = cv2.resize( img, (resolution,resolution), cv2.INTER_CUBIC )
+                elif sample_type == SPST.IMAGE:
+                    img = sample_bgr
+                    img = imagelib.warp_by_params (params, img, warp, transform, can_flip=True, border_replicate=True)
+                    img = cv2.resize( img, (resolution, resolution), cv2.INTER_CUBIC )
+                    out_sample = img
+
+                    if data_format == "NCHW":
+                        out_sample = np.transpose(out_sample, (2,0,1) )
+
+
                 elif sample_type == SPST.LANDMARKS_ARRAY:
                     l = sample_landmarks
                     l = np.concatenate ( [ np.expand_dims(l[:,0] / w,-1), np.expand_dims(l[:,1] / h,-1) ], -1 )
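
For reviewers, a minimal usage sketch of the two additions (not part of the patch). It assumes the repo root is on PYTHONPATH; the model name is a placeholder, and the `'sample_type'`, `'warp'`, `'transform'` and `'resolution'` option keys are assumed to be the ones SampleProcessor already reads alongside the `'data_format'` and `'border_replicate'` keys visible in the hunks above.

```python
# Sketch only: exercises the new data_format argument and the new IMAGE sample type.
from facelib.TernausNet import TernausNet
from samplelib.SampleProcessor import SampleProcessor

# TernausNet no longer hard-codes NHWC; the data format is a constructor argument.
net = TernausNet("FANSeg",            # placeholder model name
                 256,                 # resolution
                 load_weights=False,  # skip weight loading for this sketch
                 training=False,
                 data_format="NCHW")  # new argument introduced by this patch

# One per-output options dict in the format consumed by the loop shown above.
# SampleType.IMAGE yields the whole (optionally warped/transformed) frame,
# resized to `resolution`, instead of an extracted face crop.
output_sample_types = [
    {'sample_type': SampleProcessor.SampleType.IMAGE,  # assumed key name
     'warp': False,                                    # assumed key name
     'transform': True,                                # assumed key name
     'resolution': 256,                                # assumed key name
     'data_format': 'NCHW'},                           # read via opts.get('data_format', 'NHWC')
]
```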