SAEHD: added option Enable random warp of samples, default is on

Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.
This commit is contained in:
Colombo 2019-10-12 10:31:50 +04:00
parent e15f846d08
commit 92f14dee70
4 changed files with 33 additions and 29 deletions

View file

@ -226,6 +226,12 @@ class ModelBase(object):
io.destroy_window(wnd_name)
else:
self.sample_for_preview = self.generate_next_sample()
try:
self.get_static_preview()
except:
self.sample_for_preview = self.generate_next_sample()
self.last_sample = self.sample_for_preview
###Generate text summary of model hyperparameters

View file

@ -63,6 +63,9 @@ class SAEv2Model(ModelBase):
default_face_style_power = 0.0
default_bg_style_power = 0.0
if is_first_run or ask_override:
default_random_warp = self.options.get('random_warp', True)
self.options['random_warp'] = io.input_str (f"Enable random warp of samples? ( y/n, ?:help skip:{yn_str[default_random_warp]}) : ", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness for less amount of iterations.")
default_face_style_power = default_face_style_power if is_first_run else self.options.get('face_style_power', default_face_style_power)
self.options['face_style_power'] = np.clip ( io.input_number("Face style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_face_style_power), default_face_style_power,
help_message="Learn to transfer face style details such as light and color conditions. Warning: Enable it only after 10k iters, when predicted face is clear enough to start learn style. Start from 0.1 value and check history changes. Enabling this option increases the chance of model collapse."), 0.0, 100.0 )
@ -84,6 +87,7 @@ class SAEv2Model(ModelBase):
self.options['clipgrad'] = False
else:
self.options['random_warp'] = self.options.get('random_warp', True)
self.options['face_style_power'] = self.options.get('face_style_power', default_face_style_power)
self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)
self.options['apply_random_ct'] = self.options.get('apply_random_ct', False)
@ -529,19 +533,21 @@ class SAEv2Model(ModelBase):
training_data_dst_path = self.pretraining_data_path
sort_by_yaw = False
t_img_warped = t.IMG_WARPED_TRANSFORMED if self.options['random_warp'] else t.IMG_TRANSFORMED
self.set_training_data_generators ([
SampleGeneratorFace(training_data_src_path, sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
random_ct_samples_path=training_data_dst_path if apply_random_ct else None,
debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, scale_range=np.array([-0.05, 0.05])+self.src_scale_mod / 100.0 ),
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct},
output_sample_types = [ {'types' : (t_img_warped, face_type, t_mode_bgr), 'resolution':resolution, 'apply_ct': apply_random_ct},
{'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution, 'apply_ct': apply_random_ct },
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution } ]
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
output_sample_types = [ {'types' : (t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr), 'resolution':resolution},
output_sample_types = [ {'types' : (t_img_warped, face_type, t_mode_bgr), 'resolution':resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t_mode_bgr), 'resolution': resolution},
{'types' : (t.IMG_TRANSFORMED, face_type, t.MODE_M), 'resolution': resolution} ])
])

View file

@ -631,26 +631,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
reduction_axes = list(range(len(input_shape)))
del reduction_axes[self.axis]
#broadcast_shape = [1] * len(input_shape)
#broadcast_shape[self.axis] = input_shape[self.axis]
#normed = x# (x - K.reshape(self.moving_mean,broadcast_shape) ) / ( K.sqrt( K.reshape(self.moving_variance,broadcast_shape)) +self.epsilon)
#normed *= K.reshape(gamma,[-1]+broadcast_shape[1:] )
#normed += K.reshape(beta, [-1]+broadcast_shape[1:] )
#mean = K.mean(x, axis=reduction_axes)
#self.moving_mean = self.add_weight(shape=(units,), name='moving_mean', initializer='zeros',trainable=False)
#self.moving_variance = self.add_weight(shape=(units,), name='moving_variance',initializer='ones', trainable=False)
#variance = K.var(x, axis=reduction_axes)
#sample_size = K.prod([ K.shape(x)[axis] for axis in reduction_axes ])
#sample_size = K.cast(sample_size, dtype=K.dtype(x))
#variance *= sample_size / (sample_size - (1.0 + self.epsilon))
#self.add_update([K.moving_average_update(self.moving_mean, mean, self.momentum),
# K.moving_average_update(self.moving_variance, variance, self.momentum)], None)
#return normed
del reduction_axes[0]
broadcast_shape = [1] * len(input_shape)
broadcast_shape[self.axis] = input_shape[self.axis]
mean = K.mean(x, reduction_axes, keepdims=True)

View file

@ -72,6 +72,7 @@ class SampleProcessor(object):
MODE_GGG = 42 #3xGrayscale
MODE_M = 43 #mask only
MODE_BGR_SHUFFLE = 44 #BGR shuffle
MODE_BGR_RANDOM_HUE_SHIFT = 45
MODE_END = 50
class Options(object):
@ -257,6 +258,15 @@ class SampleProcessor(object):
elif mode_type == SPTF.MODE_BGR_SHUFFLE:
rnd_state = np.random.RandomState (sample_rnd_seed)
img = np.take (img_bgr, rnd_state.permutation(img_bgr.shape[-1]), axis=-1)
elif mode_type == SPTF.MODE_BGR_RANDOM_HUE_SHIFT:
rnd_state = np.random.RandomState (sample_rnd_seed)
hsv = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2HSV)
h, s, v = cv2.split(hsv)
h = (h + rnd_state.randint(360) ) % 360
hsv = cv2.merge([h, s, v])
img = np.clip( cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR) , 0, 1 )
elif mode_type == SPTF.MODE_G:
img = np.concatenate ( (np.expand_dims(cv2.cvtColor(img_bgr, cv2.COLOR_BGR2GRAY),-1),img_mask) , -1 )
elif mode_type == SPTF.MODE_GGG: