mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
_
This commit is contained in:
parent
7140ee8684
commit
2cd028fb67
2 changed files with 33 additions and 60 deletions
|
@ -48,10 +48,10 @@ class PoseEstimator(object):
|
|||
if training:
|
||||
latent_t = self.encoder(inp_t)
|
||||
bgr_t = self.decoder (latent_t)
|
||||
bins_t = self.model_l(latent_t)
|
||||
pyrs_t = self.model_l(latent_t)
|
||||
else:
|
||||
self.model = Model(inp_t, self.model_l(self.encoder(inp_t)) )
|
||||
bins_t = self.model(inp_t)
|
||||
pyrs_t = self.model(inp_t)
|
||||
|
||||
|
||||
if load_weights:
|
||||
|
@ -76,50 +76,31 @@ class PoseEstimator(object):
|
|||
|
||||
CAInitializerMP ( gather_Conv2D_layers( [self.encoder, self.decoder] ) )
|
||||
|
||||
idx_tensor = self.idx_tensor = K.constant( np.array([idx for idx in range(self.class_nums[0])], dtype=K.floatx() ) )
|
||||
pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * idx_tensor, 1), K.sum (bins_t[1] * idx_tensor, 1), K.sum ( bins_t[2] * idx_tensor, 1)
|
||||
|
||||
if training:
|
||||
inp_bins_t = []
|
||||
inp_pyrs_t = []
|
||||
for class_num in self.class_nums:
|
||||
inp_bins_t += [ Input ((class_num,)), Input ((class_num,)), Input ((class_num,)) ]
|
||||
inp_pyrs_t += [ Input ((3,)) ]
|
||||
|
||||
loss_pitch = []
|
||||
loss_yaw = []
|
||||
loss_roll = []
|
||||
pyr_loss = []
|
||||
|
||||
for i,class_num in enumerate(self.class_nums):
|
||||
a = self.alpha_cat_losses[i]
|
||||
loss_pitch += [ a*K.categorical_crossentropy( inp_bins_t[i*3+0], bins_t[i*3+0] ) ]
|
||||
loss_yaw += [ a*K.categorical_crossentropy( inp_bins_t[i*3+1], bins_t[i*3+1] ) ]
|
||||
loss_roll += [ a*K.categorical_crossentropy( inp_bins_t[i*3+2], bins_t[i*3+2] ) ]
|
||||
pyr_loss += [ a*K.mean( K.square ( inp_pyrs_t[i] - pyrs_t[i]) ) ]
|
||||
|
||||
bgr_loss = K.mean( 10*dssim(kernel_size=int(resolution/11.6),max_value=1.0)( inp_real_t*inp_mask_t, bgr_t*inp_mask_t) )
|
||||
|
||||
reg_alpha = 0.01
|
||||
reg_pitch_loss = reg_alpha * K.mean(K.square( inp_pitch_t - pitch_t), -1)
|
||||
reg_yaw_loss = reg_alpha * K.mean(K.square( inp_yaw_t - yaw_t), -1)
|
||||
reg_roll_loss = reg_alpha * K.mean(K.square( inp_roll_t - roll_t), -1)
|
||||
|
||||
pitch_loss = reg_pitch_loss + sum(loss_pitch)
|
||||
yaw_loss = reg_yaw_loss + sum(loss_yaw)
|
||||
roll_loss = reg_roll_loss + sum(loss_roll)
|
||||
pyr_loss = sum(pyr_loss)
|
||||
|
||||
|
||||
self.train = K.function ([inp_t, inp_real_t, inp_mask_t],
|
||||
[bgr_loss], Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates( bgr_loss, self.encoder.trainable_weights+self.decoder.trainable_weights ) )
|
||||
|
||||
self.train_l = K.function ([inp_t, inp_pitch_t, inp_yaw_t, inp_roll_t] + inp_bins_t,
|
||||
[K.mean(pitch_loss),K.mean(yaw_loss),K.mean(roll_loss)], Adam(lr=0.0001).get_updates( [pitch_loss,yaw_loss,roll_loss], self.model_l.trainable_weights) )
|
||||
self.train_l = K.function ([inp_t] + inp_pyrs_t,
|
||||
[pyr_loss], Adam(lr=0.0001).get_updates( pyr_loss, self.model_l.trainable_weights) )
|
||||
|
||||
|
||||
self.view = K.function ([inp_t], [pitch_t, yaw_t, roll_t] )
|
||||
|
||||
def flow(self, x):
|
||||
bins_t = self.model(x)
|
||||
return bins_t[0], bins_t[1], bins_t[2]
|
||||
pitch_t, yaw_t, roll_t = K.sum ( bins_t[0] * self.idx_tensor, 1), K.sum (bins_t[1] * self.idx_tensor, 1), K.sum ( bins_t[2] * self.idx_tensor, 1)
|
||||
return pitch_t, yaw_t, roll_t
|
||||
self.view = K.function ([inp_t], [ pyrs_t[0] ] )
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
@ -136,7 +117,6 @@ class PoseEstimator(object):
|
|||
Model(inp_t, self.model_l(self.encoder(inp_t)) ).save_weights (str(self.model_weights_path))
|
||||
|
||||
def train_on_batch(self, warps, imgs, masks, pitch_yaw_roll, skip_bgr_train=False):
|
||||
pyr = pitch_yaw_roll+1
|
||||
|
||||
if not skip_bgr_train:
|
||||
bgr_loss, = self.train( [warps, imgs, masks] )
|
||||
|
@ -145,20 +125,11 @@ class PoseEstimator(object):
|
|||
|
||||
feed = [imgs]
|
||||
for i, (angle, class_num) in enumerate(zip(self.angles, self.class_nums)):
|
||||
c = np.round(pyr * (angle / 2) ).astype(K.floatx())
|
||||
inp_pitch = c[:,0:1]
|
||||
inp_yaw = c[:,1:2]
|
||||
inp_roll = c[:,2:3]
|
||||
if i == 0:
|
||||
feed += [inp_pitch, inp_yaw, inp_roll]
|
||||
c = np.round( np.round(pitch_yaw_roll * angle) / angle ) #.astype(K.floatx())
|
||||
feed += [c]
|
||||
|
||||
inp_pitch_bins = keras.utils.to_categorical(inp_pitch, class_num )
|
||||
inp_yaw_bins = keras.utils.to_categorical(inp_yaw, class_num )
|
||||
inp_roll_bins = keras.utils.to_categorical(inp_roll, class_num )
|
||||
feed += [inp_pitch_bins, inp_yaw_bins, inp_roll_bins]
|
||||
|
||||
pitch_loss,yaw_loss,roll_loss = self.train_l(feed)
|
||||
return bgr_loss, pitch_loss, yaw_loss, roll_loss
|
||||
pyr_loss, = self.train_l(feed)
|
||||
return bgr_loss, pyr_loss
|
||||
|
||||
def extract (self, input_image, is_input_tanh=False):
|
||||
if is_input_tanh:
|
||||
|
@ -168,9 +139,10 @@ class PoseEstimator(object):
|
|||
if input_shape_len == 3:
|
||||
input_image = input_image[np.newaxis,...]
|
||||
|
||||
pitch, yaw, roll = self.view( [input_image] )
|
||||
result = np.concatenate( (pitch[...,np.newaxis], yaw[...,np.newaxis], roll[...,np.newaxis]), -1 )
|
||||
result = np.clip ( result / (self.angles[0] / 2) - 1, -1.0, 1.0 )
|
||||
result, = self.view( [input_image] )
|
||||
|
||||
|
||||
#result = np.clip ( result / (self.angles[0] / 2) - 1, 0.0, 1.0 )
|
||||
|
||||
if input_shape_len == 3:
|
||||
result = result[0]
|
||||
|
@ -206,7 +178,7 @@ class PoseEstimator(object):
|
|||
|
||||
def downscale (dim, **kwargs):
|
||||
def func(x):
|
||||
return MaxPooling2D()( Act() ( XConv2D(dim, kernel_size=5, strides=1)(x)) )
|
||||
return Act() ( XConv2D(dim, kernel_size=5, strides=2)(x))
|
||||
return func
|
||||
|
||||
def upscale (dim, **kwargs):
|
||||
|
@ -314,10 +286,8 @@ class PoseEstimator(object):
|
|||
|
||||
output = []
|
||||
for class_num in class_nums:
|
||||
pitch = Dense(class_num, activation='softmax')(x)
|
||||
yaw = Dense(class_num, activation='softmax')(x)
|
||||
roll = Dense(class_num, activation='softmax')(x)
|
||||
output += [pitch,yaw,roll]
|
||||
pyr = Dense(3, activation='sigmoid')(x)
|
||||
output += [pyr]
|
||||
|
||||
return output
|
||||
|
||||
|
|
|
@ -59,13 +59,13 @@ class Model(ModelBase):
|
|||
output_sample_types=[ {'types': (t.IMG_WARPED_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution, 'motion_blur':(25, 1) },
|
||||
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M, t.FACE_MASK_FULL), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL,)}
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
|
||||
]),
|
||||
|
||||
SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, generators_count=4,
|
||||
sample_process_options=SampleProcessor.Options( rotation_range=[0,0] ), #random_flip=True,
|
||||
output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR_SHUFFLE), 'resolution':self.resolution },
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL,)}
|
||||
{'types': (t.IMG_PITCH_YAW_ROLL_SIGMOID,)}
|
||||
])
|
||||
])
|
||||
|
||||
|
@ -77,9 +77,9 @@ class Model(ModelBase):
|
|||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
target_srcw, target_src, target_srcm, pitch_yaw_roll = generators_samples[0]
|
||||
|
||||
bgr_loss, pitch_loss,yaw_loss,roll_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
|
||||
bgr_loss, pyr_loss = self.pose_est.train_on_batch( target_srcw, target_src, target_srcm, pitch_yaw_roll, skip_bgr_train=not self.options['train_bgr'] )
|
||||
|
||||
return ( ('bgr_loss', bgr_loss), ('pitch_loss', pitch_loss), ('yaw_loss', yaw_loss), ('roll_loss', roll_loss) )
|
||||
return ( ('bgr_loss', bgr_loss), ('pyr_loss', pyr_loss), )
|
||||
|
||||
#override
|
||||
def onGetPreview(self, generators_samples):
|
||||
|
@ -99,8 +99,11 @@ class Model(ModelBase):
|
|||
hor_imgs = []
|
||||
for i in range(len(img)):
|
||||
img_info = np.ones ( (h,w,c) ) * 0.1
|
||||
lines = ["%s" % ( str(pyr[i]) ),
|
||||
"%s" % ( str(pyr_pred[i]) ) ]
|
||||
|
||||
i_pyr = pyr[i]
|
||||
i_pyr_pred = pyr_pred[i]
|
||||
lines = ["%.4f %.4f %.4f" % (i_pyr[0],i_pyr[1],i_pyr[2]),
|
||||
"%.4f %.4f %.4f" % (i_pyr_pred[0],i_pyr_pred[1],i_pyr_pred[2]) ]
|
||||
|
||||
lines_count = len(lines)
|
||||
for ln in range(lines_count):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue