SAE collapse fix (#245)

* test

* _

* _

* update dev_poseest

* SAE: collapses are finally fixed

* fix batch size help
This commit is contained in:
iperov 2019-04-24 09:38:26 +04:00 committed by GitHub
parent d62785ca5a
commit e1da9c56b4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
4 changed files with 87 additions and 70 deletions

View file

@ -12,6 +12,7 @@ from nnlib import nnlib
""" """
PoseEstimator estimates pitch, yaw, roll, from FAN aligned face. PoseEstimator estimates pitch, yaw, roll, from FAN aligned face.
trained on https://www.umdfaces.io trained on https://www.umdfaces.io
based on https://arxiv.org/pdf/1901.06778.pdf HYBRID COARSE-FINE CLASSIFICATION FOR HEAD POSE ESTIMATION
""" """
class PoseEstimator(object): class PoseEstimator(object):
@ -19,9 +20,12 @@ class PoseEstimator(object):
def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False): def __init__ (self, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False):
exec( nnlib.import_all(), locals(), globals() ) exec( nnlib.import_all(), locals(), globals() )
self.class_num = 91 self.angles = [90, 45, 30, 10, 2]
self.alpha_cat_losses = [0.07,0.05,0.03,0.01,0.01]
self.class_nums = [ angle+1 for angle in self.angles ]
self.model = PoseEstimator.BuildModel(resolution, class_nums=self.class_nums)
self.model = PoseEstimator.BuildModel(resolution, class_num=self.class_num)
if weights_file_root is not None: if weights_file_root is not None:
weights_file_root = Path(weights_file_root) weights_file_root = Path(weights_file_root)
@ -33,42 +37,44 @@ class PoseEstimator(object):
if load_weights: if load_weights:
self.model.load_weights (str(self.weights_path)) self.model.load_weights (str(self.weights_path))
idx_tensor = np.array([idx for idx in range(self.class_num)], dtype=K.floatx() )
idx_tensor = K.constant(idx_tensor)
inp_t, = self.model.inputs inp_t, = self.model.inputs
pitch_bins_t, yaw_bins_t, roll_bins_t = self.model.outputs bins_t = self.model.outputs
pitch_t, yaw_t, roll_t = K.sum ( pitch_bins_t * idx_tensor, 1), K.sum ( yaw_bins_t * idx_tensor, 1), K.sum ( roll_bins_t * idx_tensor, 1)
inp_pitch_bins_t = Input ( (self.class_num,) )
inp_pitch_t = Input ( (1,) ) inp_pitch_t = Input ( (1,) )
inp_yaw_bins_t = Input ( (self.class_num,) )
inp_yaw_t = Input ( (1,) ) inp_yaw_t = Input ( (1,) )
inp_roll_bins_t = Input ( (self.class_num,) )
inp_roll_t = Input ( (1,) ) inp_roll_t = Input ( (1,) )
alpha = 0.001 inp_bins_t = []
for class_num in self.class_nums:
pitch_loss = K.categorical_crossentropy(inp_pitch_bins_t, pitch_bins_t) \ inp_bins_t += [ Input ((class_num,)), Input ((class_num,)), Input ((class_num,)) ]
+ alpha * K.mean(K.square( inp_pitch_t - pitch_t), -1)
yaw_loss = K.categorical_crossentropy(inp_yaw_bins_t, yaw_bins_t) \ loss_pitch = []
+ alpha * K.mean(K.square( inp_yaw_t - yaw_t), -1) loss_yaw = []
loss_roll = []
roll_loss = K.categorical_crossentropy(inp_roll_bins_t, roll_bins_t) \
+ alpha * K.mean(K.square( inp_roll_t - roll_t), -1)
for i,class_num in enumerate(self.class_nums):
a = self.alpha_cat_losses[i]
loss_pitch += [ a*K.categorical_crossentropy( inp_bins_t[i*3+0], bins_t[i*3+0] ) ]
loss_yaw += [ a*K.categorical_crossentropy( inp_bins_t[i*3+1], bins_t[i*3+1] ) ]
loss_roll += [ a*K.categorical_crossentropy( inp_bins_t[i*3+2], bins_t[i*3+2] ) ]
idx_tensor = K.constant( np.array([idx for idx in range(self.class_nums[0])], dtype=K.floatx() ) )
pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * idx_tensor, 1), K.sum ( bins_t[1] * idx_tensor, 1), K.sum ( bins_t[2] * idx_tensor, 1)
loss = K.mean( pitch_loss + yaw_loss + roll_loss ) reg_alpha = 0.02
reg_pitch_loss = reg_alpha * K.mean(K.square( inp_pitch_t - pitch_t), -1)
reg_yaw_loss = reg_alpha * K.mean(K.square( inp_yaw_t - yaw_t), -1)
reg_roll_loss = reg_alpha * K.mean(K.square( inp_roll_t - roll_t), -1)
pitch_loss = reg_pitch_loss + sum(loss_pitch)
yaw_loss = reg_yaw_loss + sum(loss_yaw)
roll_loss = reg_roll_loss + sum(loss_roll)
opt = Adam(lr=0.001, tf_cpu_mode=2) opt = Adam(lr=0.001, tf_cpu_mode=0)
if training: if training:
self.train = K.function ([inp_t, inp_pitch_bins_t, inp_pitch_t, inp_yaw_bins_t, inp_yaw_t, inp_roll_bins_t, inp_roll_t], self.train = K.function ([inp_t, inp_pitch_t, inp_yaw_t, inp_roll_t] + inp_bins_t,
[loss], opt.get_updates(loss, self.model.trainable_weights) ) [K.mean(pitch_loss),K.mean(yaw_loss),K.mean(roll_loss)], opt.get_updates( [pitch_loss,yaw_loss,roll_loss], self.model.trainable_weights) )
self.view = K.function ([inp_t], [pitch_t, yaw_t, roll_t] ) self.view = K.function ([inp_t], [pitch_t, yaw_t, roll_t] )
@ -82,19 +88,27 @@ class PoseEstimator(object):
self.model.save_weights (str(self.weights_path)) self.model.save_weights (str(self.weights_path))
def train_on_batch(self, imgs, pitch_yaw_roll): def train_on_batch(self, imgs, pitch_yaw_roll):
c = ( (pitch_yaw_roll+1) * 45.0 ).astype(np.int).astype(K.floatx()) pyr = pitch_yaw_roll+1
feed = [imgs]
inp_pitch = c[:,0:1] for i, (angle, class_num) in enumerate(zip(self.angles, self.class_nums)):
inp_yaw = c[:,1:2] c = np.round(pyr * (angle / 2) ).astype(K.floatx())
inp_roll = c[:,2:3] inp_pitch = c[:,0:1]
inp_yaw = c[:,1:2]
inp_roll = c[:,2:3]
if i == 0:
feed += [inp_pitch, inp_yaw, inp_roll]
inp_pitch_bins = keras.utils.to_categorical(inp_pitch, class_num )
inp_yaw_bins = keras.utils.to_categorical(inp_yaw, class_num )
inp_roll_bins = keras.utils.to_categorical(inp_roll, class_num )
feed += [inp_pitch_bins, inp_yaw_bins, inp_roll_bins]
#import code
#code.interact(local=dict(globals(), **locals()))
inp_pitch_bins = keras.utils.to_categorical(inp_pitch, self.class_num ) pitch_loss,yaw_loss,roll_loss = self.train(feed)
inp_yaw_bins = keras.utils.to_categorical(inp_yaw, self.class_num ) return pitch_loss,yaw_loss,roll_loss
inp_roll_bins = keras.utils.to_categorical(inp_roll, self.class_num )
loss, = self.train( [imgs, inp_pitch_bins, inp_pitch, inp_yaw_bins, inp_yaw, inp_roll_bins, inp_roll] )
return loss
def extract (self, input_image, is_input_tanh=False): def extract (self, input_image, is_input_tanh=False):
if is_input_tanh: if is_input_tanh:
@ -106,7 +120,7 @@ class PoseEstimator(object):
pitch, yaw, roll = self.view( [input_image] ) pitch, yaw, roll = self.view( [input_image] )
result = np.concatenate( (pitch[...,np.newaxis], yaw[...,np.newaxis], roll[...,np.newaxis]), -1 ) result = np.concatenate( (pitch[...,np.newaxis], yaw[...,np.newaxis], roll[...,np.newaxis]), -1 )
result = np.clip ( result / 45.0 - 1, -1.0, 1.0 ) result = np.clip ( result / (self.angles[0] / 2) - 1, -1.0, 1.0 )
if input_shape_len == 3: if input_shape_len == 3:
result = result[0] result = result[0]
@ -114,28 +128,31 @@ class PoseEstimator(object):
return result return result
@staticmethod @staticmethod
def BuildModel ( resolution, class_num): def BuildModel ( resolution, class_nums):
exec( nnlib.import_all(), locals(), globals() ) exec( nnlib.import_all(), locals(), globals() )
inp = Input ( (resolution,resolution,3) ) inp = Input ( (resolution,resolution,3) )
x = inp x = inp
x = PoseEstimator.Flow(class_num=class_num)(x) x = PoseEstimator.Flow(class_nums=class_nums)(x)
model = Model(inp,x) model = Model(inp,x)
return model return model
@staticmethod @staticmethod
def Flow(class_num): def Flow(class_nums):
exec( nnlib.import_all(), locals(), globals() ) exec( nnlib.import_all(), locals(), globals() )
def func(input): def func(input):
x = input x = input
# resnet50 = keras.applications.ResNet50(include_top=False, weights='imagenet', input_shape=K.int_shape(x)[1:], pooling='avg') # resnet50 = keras.applications.ResNet50(include_top=False, weights=None, input_shape=K.int_shape(x)[1:], pooling='avg')
# x = resnet50(x) # x = resnet50(x)
# pitch = Dense(class_num, activation='softmax', name='pitch')(x) # output = []
# yaw = Dense(class_num, activation='softmax', name='yaw')(x) # for class_num in class_nums:
# roll = Dense(class_num, activation='softmax', name='roll')(x) # pitch = Dense(class_num, activation='softmax')(x)
# yaw = Dense(class_num, activation='softmax')(x)
# return [pitch, yaw, roll] # roll = Dense(class_num, activation='softmax')(x)
# output += [pitch,yaw,roll]
# return output
x = Conv2D(64, kernel_size=11, strides=4, padding='same', activation='relu')(x) x = Conv2D(64, kernel_size=11, strides=4, padding='same', activation='relu')(x)
x = MaxPooling2D( (3,3), strides=2 )(x) x = MaxPooling2D( (3,3), strides=2 )(x)
@ -153,10 +170,13 @@ class PoseEstimator(object):
x = Dropout(0.5)(x) x = Dropout(0.5)(x)
x = Dense(1024, activation='relu')(x) x = Dense(1024, activation='relu')(x)
pitch = Dense(class_num, activation='softmax', name='pitch')(x) output = []
yaw = Dense(class_num, activation='softmax', name='yaw')(x) for class_num in class_nums:
roll = Dense(class_num, activation='softmax', name='roll')(x) pitch = Dense(class_num, activation='softmax')(x)
yaw = Dense(class_num, activation='softmax')(x)
return [pitch, yaw, roll] roll = Dense(class_num, activation='softmax')(x)
output += [pitch,yaw,roll]
return output
return func return func

View file

@ -95,7 +95,7 @@ class ModelBase(object):
if ask_batch_size and (self.iter == 0 or ask_override): if ask_batch_size and (self.iter == 0 or ask_override):
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0) default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error and increases risk of model collapse during training. Tune this value for your videocard manually.")) self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % (default_batch_size), default_batch_size, help_message="Larger batch size is better for NN's generalization, but it can cause Out of Memory error. Tune this value for your videocard manually."))
else: else:
self.options['batch_size'] = self.options.get('batch_size', 0) self.options['batch_size'] = self.options.get('batch_size', 0)

View file

@ -44,18 +44,17 @@ class Model(ModelBase):
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags
face_type = f.FACE_TYPE_FULL if self.options['face_type'] == 'f' else f.FACE_TYPE_HALF face_type = f.FACE_TYPE_FULL if self.options['face_type'] == 'f' else f.FACE_TYPE_HALF
normalize_vgg = False
self.set_training_data_generators ([ self.set_training_data_generators ([
SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
sample_process_options=SampleProcessor.Options( rotation_range=[0,0], motion_blur = [25, 1] ), #random_flip=True, sample_process_options=SampleProcessor.Options( rotation_range=[0,0], motion_blur = [25, 1] ), #random_flip=True,
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution, {'normalize_vgg':normalize_vgg} ], output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE | f.OPT_APPLY_MOTION_BLUR, self.resolution ],
[f.PITCH_YAW_ROLL], [f.PITCH_YAW_ROLL],
]), ]),
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True, sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution, {'normalize_vgg':normalize_vgg} ], output_sample_types=[ [f.TRANSFORMED | face_type | f.MODE_BGR_SHUFFLE, self.resolution ],
[f.PITCH_YAW_ROLL], [f.PITCH_YAW_ROLL],
]) ])
]) ])
@ -68,9 +67,9 @@ class Model(ModelBase):
def onTrainOneIter(self, generators_samples, generators_list): def onTrainOneIter(self, generators_samples, generators_list):
target_src, pitch_yaw_roll = generators_samples[0] target_src, pitch_yaw_roll = generators_samples[0]
loss = self.pose_est.train_on_batch( target_src, pitch_yaw_roll ) pitch_loss,yaw_loss,roll_loss = self.pose_est.train_on_batch( target_src, pitch_yaw_roll )
return ( ('loss', loss), ) return ( ('pitch_loss', pitch_loss), ('yaw_loss', yaw_loss), ('roll_loss', roll_loss) )
#override #override
def onGetPreview(self, generators_samples): def onGetPreview(self, generators_samples):

View file

@ -248,9 +248,7 @@ class SAEModel(ModelBase):
psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))] psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
alpha_rec = 100
if self.is_training_mode: if self.is_training_mode:
self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1) self.src_dst_mask_opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999, tf_cpu_mode=self.options['optimizer_mode']-1)
@ -265,9 +263,9 @@ class SAEModel(ModelBase):
src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
src_loss_batch = sum([ ( alpha_rec*K.square( dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ]) src_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i]) for i in range(len(target_src_masked_ar_opt)) ])
else: else:
src_loss_batch = sum([ K.mean ( alpha_rec*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ]) src_loss_batch = sum([ K.mean ( 50*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])
src_loss = K.mean(src_loss_batch) src_loss = K.mean(src_loss_batch)
@ -279,15 +277,15 @@ class SAEModel(ModelBase):
bg_style_power = self.options['bg_style_power'] / 100.0 bg_style_power = self.options['bg_style_power'] / 100.0
if bg_style_power != 0: if bg_style_power != 0:
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
bg_loss = K.mean( (alpha_rec*bg_style_power)*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))) bg_loss = K.mean( (10*bg_style_power)*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] ))
else: else:
bg_loss = K.mean( (alpha_rec*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] )) bg_loss = K.mean( (50*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
src_loss += bg_loss src_loss += bg_loss
if not self.options['pixel_loss']: if not self.options['pixel_loss']:
dst_loss_batch = sum([ ( alpha_rec*K.square(dssim(kernel_size=int(resolution/11.6),max_value=1.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ]) dst_loss_batch = sum([ 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)(target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i]) for i in range(len(target_dst_masked_ar_opt)) ])
else: else:
dst_loss_batch = sum([ K.mean ( alpha_rec*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ]) dst_loss_batch = sum([ K.mean ( 50*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])
dst_loss = K.mean(dst_loss_batch) dst_loss = K.mean(dst_loss_batch)