Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 21:12:07 -07:00)
SAE: added test option: 'Apply random color transfer to src faceset'
parent: bde700243c
commit: a805f81142
9 changed files with 152 additions and 129 deletions
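
Note on the change before the per-file hunks: SAE gains an 'apply_random_ct' option that, during sample preparation, recolors each src face toward a randomly drawn dst face using Reinhard color transfer (RCT) before the image reaches the network. The repo's actual implementation is imagelib.reinhard_color_transfer (called in the SampleProcessor hunk near the end of this diff, with source/target masks); the sketch below only illustrates the underlying statistics-matching idea and is not the repo's code.

import cv2
import numpy as np

def reinhard_color_transfer_sketch(source_bgr, target_bgr):
    # Illustrative only. Work in LAB space, where per-channel statistics
    # track perceived color well (Reinhard et al., 2001).
    src = cv2.cvtColor(source_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)
    tgt = cv2.cvtColor(target_bgr, cv2.COLOR_BGR2LAB).astype(np.float32)

    src_mean, src_std = src.mean(axis=(0, 1)), src.std(axis=(0, 1)) + 1e-6
    tgt_mean, tgt_std = tgt.mean(axis=(0, 1)), tgt.std(axis=(0, 1))

    # Shift/scale each channel so source statistics match the target's.
    result = (src - src_mean) / src_std * tgt_std + tgt_mean
    result = np.clip(result, 0, 255).astype(np.uint8)
    return cv2.cvtColor(result, cv2.COLOR_LAB2BGR)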
@@ -1,15 +1,18 @@
+import time
 import traceback
-from .Converter import Converter
-from facelib import LandmarksProcessor
-from facelib import FaceType
-from facelib import FANSegmentator
 import cv2
 import numpy as np
+
 import imagelib
+from facelib import FaceType, FANSegmentator, LandmarksProcessor
 from interact import interact as io
 from joblib import SubprocessFunctionCaller
 from utils.pickle_utils import AntiPickler
-import time
+
+from .Converter import Converter
+

 '''
 default_mode = {1:'overlay',
                 2:'hist-match',

@@ -93,10 +96,10 @@ class ConverterMasked(Converter):
 self.blur_mask_modifier = base_blur_mask_modifier + np.clip ( io.input_int ("Choose blur mask modifier [-200..200] (skip:%d) : " % (default_blur_mask_modifier), default_blur_mask_modifier), -200, 200)

 self.output_face_scale = np.clip ( 1.0 + io.input_int ("Choose output face scale modifier [-50..50] (skip:0) : ", 0)*0.01, 0.5, 1.5)

 if self.mode != 'raw':
     self.color_transfer_mode = io.input_str ("Apply color transfer to predicted face? Choose mode ( rct/lct skip:None ) : ", None, ['rct','lct'])

 self.super_resolution = io.input_bool("Apply super resolution? (y/n ?:help skip:n) : ", False, help_message="Enhance details by applying DCSCN network.")

 if self.mode != 'raw':

@@ -173,12 +176,12 @@ class ConverterMasked(Converter):
 prd_face_mask_a_0 = cv2.resize (dst_face_mask_a_0, (output_size,output_size), cv2.INTER_CUBIC)
 elif self.mask_mode >= 3 and self.mask_mode <= 6:

     if self.mask_mode == 3 or self.mask_mode == 5 or self.mask_mode == 6:
         prd_face_bgr_256 = cv2.resize (prd_face_bgr, (256,256) )
         prd_face_bgr_256_mask = self.fan_seg.extract( prd_face_bgr_256 )
         FAN_prd_face_mask_a_0 = cv2.resize (prd_face_bgr_256_mask, (output_size,output_size), cv2.INTER_CUBIC)

     if self.mask_mode == 4 or self.mask_mode == 5 or self.mask_mode == 6:
         face_256_mat = LandmarksProcessor.get_transform_mat (img_face_landmarks, 256, face_type=FaceType.FULL)
         dst_face_256_bgr = cv2.warpAffine(img_bgr, face_256_mat, (256, 256), flags=cv2.INTER_LANCZOS4 )
         dst_face_256_mask = self.fan_seg.extract( dst_face_256_bgr )

@@ -192,7 +195,7 @@ class ConverterMasked(Converter):
 prd_face_mask_a_0 = FAN_prd_face_mask_a_0 * FAN_dst_face_mask_a_0
 elif self.mask_mode == 6:
     prd_face_mask_a_0 = prd_face_mask_a_0 * FAN_prd_face_mask_a_0 * FAN_dst_face_mask_a_0

 prd_face_mask_a_0[ prd_face_mask_a_0 < 0.001 ] = 0.0

 prd_face_mask_a = prd_face_mask_a_0[...,np.newaxis]

@@ -316,7 +319,7 @@ class ConverterMasked(Converter):

 if self.masked_hist_match:
     hist_mask_a *= prd_face_mask_a

 white = (1.0-hist_mask_a)* np.ones ( prd_face_bgr.shape[:2] + (1,) , dtype=np.float32)

 hist_match_1 = prd_face_bgr*hist_mask_a + white

@@ -326,10 +329,10 @@ class ConverterMasked(Converter):
 hist_match_2[ hist_match_1 > 1.0 ] = 1.0

 prd_face_bgr = imagelib.color_hist_match(hist_match_1, hist_match_2, self.hist_match_threshold )

 #if self.masked_hist_match:
 #    prd_face_bgr -= white

 if self.mode == 'hist-match-bw':
     prd_face_bgr = prd_face_bgr.astype(dtype=np.float32)

@@ -401,7 +404,7 @@ class ConverterMasked(Converter):

 if debug:
     debugs += [ np.clip( cv2.warpAffine( new_out_face_bgr, face_output_mat, img_size, np.zeros(img_bgr.shape, dtype=np.float32), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT ), 0, 1.0) ]

 new_out = cv2.warpAffine( new_out_face_bgr, face_mat, img_size, img_bgr.copy(), cv2.WARP_INVERSE_MAP | cv2.INTER_LANCZOS4, cv2.BORDER_TRANSPARENT )
 out_img = np.clip( img_bgr*(1-img_mask_blurry_aaa) + (new_out*img_mask_blurry_aaa) , 0, 1.0 )

@@ -430,4 +433,3 @@ class ConverterMasked(Converter):
 debugs += [out_img.copy()]

 return debugs if debug else out_img
-

@@ -48,7 +48,7 @@ class Model(ModelBase):
 SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
     sample_process_options=SampleProcessor.Options(random_flip=True),
     output_sample_types=[ { 'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution' : self.resolution, 'motion_blur':(25, 1) },
-                          { 'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution': self.resolution },
+                          { 'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_M), 'resolution': self.resolution },
                         ]),

 SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,

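This and the next four Model hunks drop t.FACE_MASK_FULL from the (t.MODE_M, ...) sample types: with the FACE_MASK_* enum deleted in SampleProcessor (later in this diff), MODE_M now always yields the full-face mask — a stored FANSeg mask when one exists, otherwise a convex hull over the landmarks. A hedged sketch of the hull fallback (the real implementation is LandmarksProcessor.get_image_hull_mask):

import cv2
import numpy as np

def hull_mask_sketch(image_shape, landmarks):
    # Illustrative only; landmarks is an (N, 2) array of facial keypoints.
    mask = np.zeros(image_shape[:2], dtype=np.float32)
    hull = cv2.convexHull(np.array(landmarks, dtype=np.int32))
    cv2.fillConvexPoly(mask, hull, 1.0)
    return mask[..., np.newaxis]  # 1.0 inside the face hull, 0.0 outside
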
@@ -47,7 +47,7 @@ class Model(ModelBase):
 t = SampleProcessor.Types
 output_sample_types=[ { 'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution':128},
                       { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution':128},
-                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M, t.FACE_MASK_FULL), 'resolution':128} ]
+                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution':128} ]

 self.set_training_data_generators ([
     SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,

@@ -57,7 +57,7 @@ class Model(ModelBase):
 t = SampleProcessor.Types
 output_sample_types=[ { 'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_BGR), 'resolution':128},
                       { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_BGR), 'resolution':128},
-                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M, t.FACE_MASK_FULL), 'resolution':128} ]
+                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M), 'resolution':128} ]

 self.set_training_data_generators ([
     SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,

@@ -58,7 +58,7 @@ class Model(ModelBase):
 t = SampleProcessor.Types
 output_sample_types=[ { 'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_BGR), 'resolution':64},
                       { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_BGR), 'resolution':64},
-                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M, t.FACE_MASK_FULL), 'resolution':64} ]
+                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_HALF, t.MODE_M), 'resolution':64} ]

 self.set_training_data_generators ([
     SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,

@@ -52,7 +52,7 @@ class Model(ModelBase):
 t = SampleProcessor.Types
 output_sample_types=[ { 'types': (t.IMG_WARPED_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution':128},
                       { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_BGR), 'resolution':128},
-                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M, t.FACE_MASK_FULL), 'resolution':128} ]
+                      { 'types': (t.IMG_TRANSFORMED, t.FACE_TYPE_FULL, t.MODE_M), 'resolution':128} ]

 self.set_training_data_generators ([
     SampleGeneratorFace(self.training_data_src_path, sort_by_yaw_target_samples_path=self.training_data_dst_path if self.sort_by_yaw else None,

@@ -59,7 +59,7 @@ class SAEModel(ModelBase):
 default_e_ch_dims = 42
 default_d_ch_dims = default_e_ch_dims // 2
 def_ca_weights = False

 if is_first_run:
     self.options['ae_dims'] = np.clip ( io.input_int("AutoEncoder dims (32-1024 ?:help skip:%d) : " % (default_ae_dims) , default_ae_dims, help_message="All face information will packed to AE dims. If amount of AE dims are not enough, then for example closed eyes will not be recognized. More dims are better, but require more VRAM. You can fine-tune model size to fit your GPU." ), 32, 1024 )
     self.options['e_ch_dims'] = np.clip ( io.input_int("Encoder dims per channel (21-85 ?:help skip:%d) : " % (default_e_ch_dims) , default_e_ch_dims, help_message="More encoder dims help to recognize more facial features, but require more VRAM. You can fine-tune model size to fit your GPU." ), 21, 85 )

@@ -87,16 +87,21 @@ class SAEModel(ModelBase):
 default_bg_style_power = default_bg_style_power if is_first_run else self.options.get('bg_style_power', default_bg_style_power)
 self.options['bg_style_power'] = np.clip ( io.input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power,
     help_message="Learn to transfer image around face. This can make face more like dst. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )

+default_apply_random_ct = False if is_first_run else self.options.get('apply_random_ct', False)
+self.options['apply_random_ct'] = io.input_bool ("Apply random color transfer to src faceset? (y/n, ?:help skip:%s) : " % (yn_str[default_apply_random_ct]), default_apply_random_ct, help_message="Increase variativity of src samples by apply RCT color transfer from random dst samples.")
+
 else:
     self.options['pixel_loss'] = self.options.get('pixel_loss', False)
     self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
     self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
+    self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)

 if is_first_run:
     self.options['pretrain'] = io.input_bool ("Pretrain the model? (y/n, ?:help skip:n) : ", False, help_message="Pretrain the model with large amount of various faces. This technique may help to train the fake with overly different face shapes and light conditions of src/dst data. Face will be look more like a morphed. To reduce the morph effect, some model files will be initialized but not be updated after pretrain: LIAE: inter_AB.h5 DF: encoder.h5. The longer you pretrain the model the more morphed face will look. After that, save and run the training again.")
 else:
     self.options['pretrain'] = False

 #override
 def onInitialize(self):
     exec(nnlib.import_all(), locals(), globals())

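The hunk above follows the existing options pattern: on a first run the user is prompted and the answer is saved with the model; on later runs the saved value is reused. A schematic restatement (the ask callback stands in for io.input_bool; this is not the repo's ModelBase code):

def resolve_apply_random_ct(options, is_first_run, ask):
    # First run: prompt (default False) and persist the answer.
    # Later runs: reuse whatever was saved with the model.
    if is_first_run:
        options['apply_random_ct'] = ask(default=False)
    else:
        options['apply_random_ct'] = options.get('apply_random_ct', False)
    return options['apply_random_ct']
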
@@ -110,13 +115,14 @@ class SAEModel(ModelBase):
 self.pretrain = self.options['pretrain'] = self.options.get('pretrain', False)
 if not self.pretrain:
     self.options.pop('pretrain')

 d_residual_blocks = True
 bgr_shape = (resolution, resolution, 3)
 mask_shape = (resolution, resolution, 1)

 self.ms_count = ms_count = 3 if (self.options['multiscale_decoder']) else 1

+apply_random_ct = self.options.get('apply_random_ct', False)
 masked_training = True

 warped_src = Input(bgr_shape)

@@ -133,8 +139,8 @@ class SAEModel(ModelBase):
 target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]

 common_flow_kwargs = { 'padding': 'zero',
-                       'norm': 'norm',
+                       'norm': '',
                        'act':'' }
 models_list = []
 weights_to_load = []
 if 'liae' in self.options['archi']:

@@ -149,11 +155,11 @@ class SAEModel(ModelBase):

 self.decoder = modelify(SAEModel.LIAEDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs)) (inter_output_Inputs)
 models_list += [self.encoder, self.inter_B, self.inter_AB, self.decoder]

 if self.options['learn_mask']:
     self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs)) (inter_output_Inputs)
     models_list += [self.decoderm]

 if not self.is_first_run():
     weights_to_load += [ [self.encoder , 'encoder.h5'],
                          [self.inter_B , 'inter_B.h5'],

@@ -191,12 +197,12 @@ class SAEModel(ModelBase):
 self.decoder_src = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
 self.decoder_dst = modelify(SAEModel.DFDecFlow (bgr_shape[2],ch_dims=d_ch_dims, multiscale_count=self.ms_count, add_residual_blocks=d_residual_blocks, **common_flow_kwargs )) (dec_Inputs)
 models_list += [self.encoder, self.decoder_src, self.decoder_dst]

 if self.options['learn_mask']:
     self.decoder_srcm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
     self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ch_dims=d_ch_dims, **common_flow_kwargs )) (dec_Inputs)
     models_list += [self.decoder_srcm, self.decoder_dstm]

 if not self.is_first_run():
     weights_to_load += [ [self.encoder , 'encoder.h5'],
                          [self.decoder_src, 'decoder_src.h5'],

@@ -217,18 +223,18 @@ class SAEModel(ModelBase):
 pred_src_srcm = self.decoder_srcm(warped_src_code)
 pred_dst_dstm = self.decoder_dstm(warped_dst_code)
 pred_src_dstm = self.decoder_srcm(warped_dst_code)

 if self.is_first_run():
     if self.options.get('ca_weights',False):
         conv_weights_list = []
         for model in models_list:
             for layer in model.layers:
                 if type(layer) == keras.layers.Conv2D:
                     conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
         CAInitializerMP ( conv_weights_list )
 else:
     self.load_weights_safe(weights_to_load)

 pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]

 if self.options['learn_mask']:

@@ -264,7 +270,7 @@ class SAEModel(ModelBase):

 psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
 psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]

 if self.is_training_mode:
     self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
     self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)

@@ -328,7 +334,7 @@ class SAEModel(ModelBase):
 else:
     self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1] ] )


 else:
     if self.options['learn_mask']:
         self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1], pred_dst_dstm[-1], pred_src_dstm[-1] ])

@@ -345,29 +351,31 @@ class SAEModel(ModelBase):

 t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE

-output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ]
-output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i) } for i in range(ms_count)]
-output_sample_types += [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution': resolution // (2**i) } for i in range(ms_count)]

 training_data_src_path = self.training_data_src_path
 training_data_dst_path = self.training_data_dst_path
 sort_by_yaw = self.sort_by_yaw

 if self.pretrain and self.pretraining_data_path is not None:
     training_data_src_path = self.pretraining_data_path
     training_data_dst_path = self.pretraining_data_path
     sort_by_yaw = False

 self.set_training_data_generators ([
     SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
-                        debug=self.is_debug(), batch_size=self.batch_size,
+                        random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
+                        debug=self.is_debug(), batch_size=self.batch_size,
                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
-                        output_sample_types=output_sample_types ),
+                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct} ] + \
+                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i), 'apply_ct': apply_random_ct } for i in range(ms_count)] + \
+                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)]
+                        ),

     SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
-                        output_sample_types=output_sample_types )
-    ])
+                        output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution} ] + \
+                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution // (2**i)} for i in range(ms_count)] + \
+                                              [ {'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution // (2**i) } for i in range(ms_count)])
+    ])

 #override
 def onSave(self):

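For a multiscale decoder the inline lists above build one warped input plus target and mask pyramids; only the src generator gets random_ct_samples_path and the 'apply_ct' flag, and the mask entries carry no flag (color transfer applies to the face, not the mask). A standalone illustration of the expansion, with type names abbreviated as strings in place of the SampleProcessor.Types enums:

resolution, ms_count, apply_random_ct = 128, 3, True

output_sample_types = (
    [{'types': ('IMG_WARPED_TRANSFORMED', 'FACE', 'MODE_BGR'),
      'resolution': resolution, 'apply_ct': apply_random_ct}]
    + [{'types': ('IMG_TRANSFORMED', 'FACE', 'MODE_BGR'),
        'resolution': resolution // (2 ** i), 'apply_ct': apply_random_ct}
       for i in range(ms_count)]
    + [{'types': ('IMG_TRANSFORMED', 'FACE', 'MODE_M'),
        'resolution': resolution // (2 ** i)}
       for i in range(ms_count)]
)

print([d['resolution'] for d in output_sample_types])  # [128, 128, 64, 32, 128, 64, 32]
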
@@ -380,23 +388,23 @@ class SAEModel(ModelBase):
 [self.inter_B, 'inter_B.h5'],
 [self.decoder, 'decoder.h5']
 ]

 if not self.pretrain or self.iter == 0:
     ar += [ [self.inter_AB, 'inter_AB.h5'],
           ]

 if self.options['learn_mask']:
     ar += [ [self.decoderm, 'decoderm.h5'] ]

 elif 'df' in self.options['archi']:
     if not self.pretrain or self.iter == 0:
         ar += [ [self.encoder, 'encoder.h5'],
               ]

     ar += [ [self.decoder_src, 'decoder_src.h5'],
             [self.decoder_dst, 'decoder_dst.h5']
           ]

     if self.options['learn_mask']:
         ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                 [self.decoder_dstm, 'decoder_dstm.h5'] ]

@@ -442,15 +450,15 @@ class SAEModel(ModelBase):
 for i in range(0, len(test_S)):
     ar = S[i], SS[i], D[i], DD[i], SD[i]
     st.append ( np.concatenate ( ar, axis=1) )

 result += [ ('SAE', np.concatenate (st, axis=0 )), ]

 if self.options['learn_mask']:
     st_m = []
     for i in range(0, len(test_S)):
         ar = S[i]*test_S_m[i], SS[i], D[i]*test_D_m[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
         st_m.append ( np.concatenate ( ar, axis=1) )

     result += [ ('SAE masked', np.concatenate (st_m, axis=0 )), ]

 return result

@@ -458,7 +466,7 @@ class SAEModel(ModelBase):
 def predictor_func (self, face):
     if self.options['learn_mask']:
         bgr, mask_dst_dstm, mask_src_dstm = self.AE_convert ([face[np.newaxis,...]])
         mask = mask_dst_dstm[0] * mask_src_dstm[0]
         return bgr[0], mask[...,0]
     else:
         bgr, = self.AE_convert ([face[np.newaxis,...]])

@@ -493,13 +501,13 @@ class SAEModel(ModelBase):

 def NormPass(x):
     return x

 def Norm(norm=''):
     if norm == 'bn':
         return BatchNormalization(axis=-1)
     else:
         return NormPass

 def Act(act='', lrelu_alpha=0.1):
     if act == 'prelu':
         return PReLU()

@@ -549,7 +557,7 @@ class SAEModel(ModelBase):
 exec (nnlib.import_all(), locals(), globals())
 upscale = partial(SAEModel.upscale, **kwargs)
 downscale = partial(SAEModel.downscale, **kwargs)

 def func(input):
     dims = K.int_shape(input)[-1]*ch_dims

@@ -571,7 +579,7 @@ class SAEModel(ModelBase):

 def func(input):
     x = input[0]
     x = Dense(ae_dims)(x)
     x = Dense(lowest_dense_res * lowest_dense_res * ae_dims*2)(x)
     x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims*2))(x)
     x = upscale(ae_dims*2)(x)

@@ -635,8 +643,8 @@ class SAEModel(ModelBase):
 x = downscale(dims*4)(x)
 x = downscale(dims*8)(x)

 x = Dense(ae_dims)(Flatten()(x))
 x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
 x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
 x = upscale(ae_dims)(x)
 return x

@@ -1,11 +1,14 @@
-import traceback
-import numpy as np
-import cv2
 import multiprocessing
-from utils import iter_utils
-from facelib import LandmarksProcessor
+import traceback
+
+import cv2
+import numpy as np
+
+from facelib import LandmarksProcessor
+from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
+                       SampleType)
+from utils import iter_utils

-from samplelib import SampleType, SampleProcessor, SampleLoader, SampleGeneratorBase

 '''
 arg

@@ -15,7 +18,7 @@ output_sample_types = [
 ]
 '''
 class SampleGeneratorFace(SampleGeneratorBase):
-    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
+    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None, **kwargs):
         super().__init__(samples_path, debug, batch_size)
         self.sample_process_options = sample_process_options
         self.output_sample_types = output_sample_types

@@ -32,15 +35,17 @@ class SampleGeneratorFace(SampleGeneratorBase):
     raise ValueError("len(generators_random_seed) != generators_count")

 self.generators_random_seed = generators_random_seed

 samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)

+ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None
+
 if self.debug:
     self.generators_count = 1
-    self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples) )]
+    self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )]
 else:
     self.generators_count = min ( generators_count, len(samples) )
-    self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count] ) ) for i in range(self.generators_count) ]
+    self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ]

 self.generator_counter = -1

@@ -53,14 +58,16 @@ class SampleGeneratorFace(SampleGeneratorBase):
     return next(generator)

 def batch_func(self, param ):
-    generator_id, samples = param
+    generator_id, samples, ct_samples = param

     if self.generators_random_seed is not None:
         np.random.seed ( self.generators_random_seed[generator_id] )

     samples_len = len(samples)
     samples_idxs = [*range(samples_len)]

+    ct_samples_len = len(ct_samples) if ct_samples is not None else 0
+
     if len(samples_idxs) == 0:
         raise ValueError('No training data provided.')

@@ -106,7 +113,8 @@ class SampleGeneratorFace(SampleGeneratorBase):

 if sample is not None:
     try:
-        x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug)
+        x = SampleProcessor.process (sample, self.sample_process_options, self.output_sample_types, self.debug,
+                                     ct_sample=ct_samples[np.random.randint(ct_samples_len)] if ct_samples is not None else None )
     except:
         raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )

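Each generator now carries the optional ct_samples list through batch_func and pairs every processed sample with one reference drawn uniformly at random; when the option is off, ct_samples is None and the pairing degrades to None. A minimal sketch of that pairing logic with stand-in types:

import numpy as np

def pair_with_ct_sample(samples, ct_samples):
    # Mirrors the call above: a fresh random reference per processed sample,
    # or None when random color transfer is disabled.
    ct_samples_len = len(ct_samples) if ct_samples is not None else 0
    for sample in samples:
        ct = ct_samples[np.random.randint(ct_samples_len)] if ct_samples is not None else None
        yield sample, ct
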
@@ -13,7 +13,7 @@ output_sample_types = [
 {} opts,
 ...
 ]

 opts:
 'types' : (S,S,...,S)
     where S:

@@ -23,31 +23,30 @@ opts:
 'IMG_TRANSFORMED'
 'IMG_LANDMARKS_ARRAY' #currently unused
 'IMG_PITCH_YAW_ROLL'

 'FACE_TYPE_HALF'
 'FACE_TYPE_FULL'
 'FACE_TYPE_HEAD' #currently unused
 'FACE_TYPE_AVATAR' #currently unused

-'FACE_MASK_FULL'
-'FACE_MASK_EYES' #currently unused
-
 'MODE_BGR' #BGR
 'MODE_G' #Grayscale
 'MODE_GGG' #3xGrayscale
 'MODE_M' #mask only
 'MODE_BGR_SHUFFLE' #BGR shuffle

 'resolution' : N

 'motion_blur' : (chance_int, range) - chance 0..100 to apply to face (not mask), and range [1..3] where 3 is highest power of motion blur

+'apply_ct' : bool
+
 """

 class SampleProcessor(object):
 class Types(IntEnum):
     NONE = 0

     IMG_TYPE_BEGIN = 1
     IMG_SOURCE = 1
     IMG_WARPED = 2

@@ -57,19 +56,14 @@ class SampleProcessor(object):
 IMG_PITCH_YAW_ROLL = 6
 IMG_PITCH_YAW_ROLL_SIGMOID = 7
 IMG_TYPE_END = 10

 FACE_TYPE_BEGIN = 10
 FACE_TYPE_HALF = 10
 FACE_TYPE_FULL = 11
 FACE_TYPE_HEAD = 12 #currently unused
 FACE_TYPE_AVATAR = 13 #currently unused
 FACE_TYPE_END = 20

-FACE_MASK_BEGIN = 20
-FACE_MASK_FULL = 20
-FACE_MASK_EYES = 21 #currently unused
-FACE_MASK_END = 30
-
 MODE_BEGIN = 40
 MODE_BGR = 40 #BGR
 MODE_G = 41 #Grayscale

@@ -77,7 +71,7 @@ class SampleProcessor(object):
 MODE_M = 43 #mask only
 MODE_BGR_SHUFFLE = 44 #BGR shuffle
 MODE_END = 50

 class Options(object):

     def __init__(self, random_flip = True, normalize_tanh = False, rotation_range=[-10,10], scale_range=[-0.05, 0.05], tx_range=[-0.05, 0.05], ty_range=[-0.05, 0.05] ):

@@ -89,10 +83,12 @@ class SampleProcessor(object):
 self.ty_range = ty_range

 @staticmethod
-def process (sample, sample_process_options, output_sample_types, debug):
+def process (sample, sample_process_options, output_sample_types, debug, ct_sample=None):
     SPTF = SampleProcessor.Types

     sample_bgr = sample.load_bgr()
+    ct_sample_bgr = None
+    ct_sample_mask = None
     h,w,c = sample_bgr.shape

     is_face_sample = sample.landmarks is not None

@@ -103,25 +99,26 @@ class SampleProcessor(object):
 params = imagelib.gen_warp_params(sample_bgr, sample_process_options.random_flip, rotation_range=sample_process_options.rotation_range, scale_range=sample_process_options.scale_range, tx_range=sample_process_options.tx_range, ty_range=sample_process_options.ty_range )

 cached_images = collections.defaultdict(dict)

 sample_rnd_seed = np.random.randint(0x80000000)

 SPTF_FACETYPE_TO_FACETYPE = { SPTF.FACE_TYPE_HALF : FaceType.HALF,
                               SPTF.FACE_TYPE_FULL : FaceType.FULL,
                               SPTF.FACE_TYPE_HEAD : FaceType.HEAD,
                               SPTF.FACE_TYPE_AVATAR : FaceType.AVATAR }

 outputs = []
 for opts in output_sample_types:

     resolution = opts.get('resolution', 0)
     types = opts.get('types', [] )

     random_sub_res = opts.get('random_sub_res', 0)
     normalize_std_dev = opts.get('normalize_std_dev', False)
     normalize_vgg = opts.get('normalize_vgg', False)
     motion_blur = opts.get('motion_blur', None)
+    apply_ct = opts.get('apply_ct', False)

     img_type = SPTF.NONE
     target_face_type = SPTF.NONE
     face_mask_type = SPTF.NONE

@@ -131,11 +128,9 @@ class SampleProcessor(object):
 img_type = t
 elif t >= SPTF.FACE_TYPE_BEGIN and t < SPTF.FACE_TYPE_END:
     target_face_type = t
-elif t >= SPTF.FACE_MASK_BEGIN and t < SPTF.FACE_MASK_END:
-    face_mask_type = t
 elif t >= SPTF.MODE_BEGIN and t < SPTF.MODE_END:
     mode_type = t

 if img_type == SPTF.NONE:
     raise ValueError ('expected IMG_ type')

@@ -152,55 +147,52 @@ class SampleProcessor(object):
 pitch, yaw, roll = LandmarksProcessor.estimate_pitch_yaw_roll (sample.landmarks)
 if params['flip']:
     yaw = -yaw

 if img_type == SPTF.IMG_PITCH_YAW_ROLL_SIGMOID:
     pitch = (pitch+1.0) / 2.0
     yaw = (yaw+1.0) / 2.0
     roll = (roll+1.0) / 2.0

 img = (pitch, yaw, roll)
 else:
     if mode_type == SPTF.NONE:
         raise ValueError ('expected MODE_ type')

-    img = cached_images.get(img_type, {}).get(face_mask_type, None)
+    img = cached_images.get(img_type, None)
     if img is None:

         img = sample_bgr
+        mask = None
         cur_sample = sample

         if is_face_sample:
             if motion_blur is not None:
                 chance, mb_range = motion_blur
                 chance = np.clip(chance, 0, 100)

                 if np.random.randint(100) < chance:
                     mb_range = [3,5,7,9][ : np.clip(mb_range, 0, 3)+1 ]
                     dim = mb_range[ np.random.randint(len(mb_range) ) ]
                     img = imagelib.LinearMotionBlur (img, dim, np.random.randint(180) )

-            if face_mask_type == SPTF.FACE_MASK_FULL:
-                mask = cur_sample.load_fanseg_mask() #using fanseg_mask if exist
+            mask = cur_sample.load_fanseg_mask() #using fanseg_mask if exist

             if mask is None:
                 mask = LandmarksProcessor.get_image_hull_mask (img.shape, cur_sample.landmarks)

             if cur_sample.ie_polys is not None:
                 cur_sample.ie_polys.overlay_mask(mask)

-            img = np.concatenate( (img, mask ), -1 )
-            elif face_mask_type == SPTF.FACE_MASK_EYES:
-                mask = LandmarksProcessor.get_image_eye_mask (img.shape, cur_sample.landmarks)
-                mask = np.expand_dims (cv2.blur (mask, ( w // 32, w // 32 ) ), -1)
-                mask[mask > 0.0] = 1.0
-                img = np.concatenate( (img, mask ), -1 )

         warp = (img_type==SPTF.IMG_WARPED or img_type==SPTF.IMG_WARPED_TRANSFORMED)
         transform = (img_type==SPTF.IMG_WARPED_TRANSFORMED or img_type==SPTF.IMG_TRANSFORMED)
         flip = img_type != SPTF.IMG_WARPED
-        is_border_replicate = face_mask_type == SPTF.NONE
-        img = cached_images[img_type][face_mask_type] = imagelib.warp_by_params (params, img, warp, transform, flip, is_border_replicate)
+        img = imagelib.warp_by_params (params, img, warp, transform, flip, True)
+        if mask is not None:
+            mask = imagelib.warp_by_params (params, mask, warp, transform, flip, False)[...,np.newaxis]
+            img = np.concatenate( (img, mask ), -1 )
+
+        cached_images[img_type] = img

 if is_face_sample and target_face_type != SPTF.NONE:
     ft = SPTF_FACETYPE_TO_FACETYPE[target_face_type]

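The refactor above warps the BGR image and its mask separately with the same params: the old code's is_border_replicate name shows the trailing boolean toggles border replication, now fixed to True for the image and False for the mask before the mask is re-attached as a 4th channel. A hedged sketch of why that flag matters, isolated to plain cv2.warpAffine (imagelib.warp_by_params also handles warp/transform/flip, not shown):

import cv2
import numpy as np

def warp_sketch(img, mask, mat, size):
    # img: (h, w, 3) float32; mask: (h, w) float32; mat: 2x3 affine; size: (w, h).
    # Image: replicate edge pixels so rotations don't introduce black wedges.
    img_w = cv2.warpAffine(img, mat, size, borderMode=cv2.BORDER_REPLICATE)
    # Mask: constant zero border, so out-of-frame area stays "not face".
    mask_w = cv2.warpAffine(mask, mat, size, borderMode=cv2.BORDER_CONSTANT, borderValue=0)
    return np.concatenate((img_w, mask_w[..., np.newaxis]), -1)
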
@@ -217,9 +209,24 @@ class SampleProcessor(object):
 start_y = rnd_state.randint(sub_size+1)
 img = img[start_y:start_y+sub_size,start_x:start_x+sub_size,:]

+img = np.clip(img, 0, 1)
 img_bgr = img[...,0:3]
 img_mask = img[...,3:4]

+if apply_ct:
+    if ct_sample_bgr is None:
+        ct_sample_bgr = ct_sample.load_bgr()
+        ct_sample_mask = LandmarksProcessor.get_image_hull_mask (ct_sample_bgr.shape, ct_sample.landmarks)
+
+    ct_sample_bgr_resized = cv2.resize( ct_sample_bgr, (resolution,resolution), cv2.INTER_LINEAR )
+    ct_sample_mask_resized = cv2.resize( ct_sample_mask, (resolution,resolution), cv2.INTER_LINEAR )[...,np.newaxis]
+
+    img_bgr = imagelib.reinhard_color_transfer ( np.clip( (img_bgr*255) .astype(np.uint8), 0, 255),
+                                                 np.clip( (ct_sample_bgr_resized*255).astype(np.uint8), 0, 255),
+                                                 source_mask=img_mask, target_mask=ct_sample_mask_resized)
+
+    img_bgr = np.clip( img_bgr.astype(np.float32) / 255.0, 0.0, 1.0)
+
 if normalize_std_dev:
     img_bgr = (img_bgr - img_bgr.mean( (0,1)) ) / img_bgr.std( (0,1) )
 elif normalize_vgg:

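Note that imagelib.reinhard_color_transfer is fed uint8 here, so the float [0,1] face is scaled to [0,255] and back around the call. A compact restatement of that round-trip, with transfer_fn standing in for the repo's function (hull masks restrict the statistics to face pixels):

import numpy as np

def apply_ct_roundtrip(img_bgr_float, ct_bgr_float, transfer_fn):
    # transfer_fn stands in for imagelib.reinhard_color_transfer.
    src_u8 = np.clip(img_bgr_float * 255, 0, 255).astype(np.uint8)
    ref_u8 = np.clip(ct_bgr_float * 255, 0, 255).astype(np.uint8)
    out_u8 = transfer_fn(src_u8, ref_u8)
    return np.clip(out_u8.astype(np.float32) / 255.0, 0.0, 1.0)
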
@@ -227,7 +234,7 @@ class SampleProcessor(object):
 img_bgr[:,:,0] -= 103.939
 img_bgr[:,:,1] -= 116.779
 img_bgr[:,:,2] -= 123.68

 if mode_type == SPTF.MODE_BGR:
     img = img_bgr
 elif mode_type == SPTF.MODE_BGR_SHUFFLE:

@@ -239,8 +246,6 @@ class SampleProcessor(object):
 elif mode_type == SPTF.MODE_GGG:
     img = np.concatenate ( ( np.repeat ( np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1), (3,), -1), img_mask), -1)
 elif mode_type == SPTF.MODE_M and is_face_sample:
-    if face_mask_type == SPTF.NONE:
-        raise ValueError ('no face_mask_type defined')
     img = img_mask

 if not debug: