mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
fix funit
This commit is contained in:
parent
9e80bbc917
commit
deeb98474b
2 changed files with 11 additions and 12 deletions
|
@ -33,7 +33,7 @@ class TrueFaceModel(ModelBase):
|
|||
self.options['face_type'] = self.options.get('face_type', default_face_type)
|
||||
|
||||
if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend:
|
||||
def_optimizer_mode = self.options.get('optimizer_mode', 1)
|
||||
def_optimizer_mode = self.options.get('optimizer_mode', 3)
|
||||
self.options['optimizer_mode'] = io.input_int ("Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode, help_message="1 - no changes. 2 - allows you to train x2 bigger network consuming RAM. 3 - allows you to train x3 bigger network consuming huge amount of RAM and slower, depends on CPU power.")
|
||||
else:
|
||||
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)
|
||||
|
|
|
@ -112,7 +112,7 @@ class FUNIT(object):
|
|||
if func is not None:
|
||||
tensors = [func(t) for t in tensors]
|
||||
|
||||
return K.sum(tensors) / (b*h*w), acc
|
||||
return K.sum(tensors, axis=[1,2,3] ) / (h*w), acc
|
||||
|
||||
d_xr_la, d_xr_la_acc = dis_gather_mean(d_xr, la, acc_func=lambda x: x >= 0)
|
||||
d_xt_lb, d_xt_lb_acc = dis_gather_mean(d_xt, lb, acc_func=lambda x: x >= 0)
|
||||
|
@ -122,12 +122,12 @@ class FUNIT(object):
|
|||
d_xb_lb_real, d_xb_lb_real_acc = dis_gather_mean(d_xb, lb, lambda x: K.relu(1.0-x), acc_func=lambda x: x >= 0)
|
||||
d_xt_lb_fake, d_xt_lb_fake_acc = dis_gather_mean(d_xt, lb, lambda x: K.relu(1.0+x), acc_func=lambda x: x < 0)
|
||||
|
||||
G_c_rec = K.mean(K.abs(K.mean(d_xr_feat, axis=[1,2]) - K.mean(d_xa_feat, axis=[1,2]))) #* 1.0
|
||||
G_m_rec = K.mean(K.abs(K.mean(d_xt_feat, axis=[1,2]) - K.mean(d_xb_feat, axis=[1,2]))) #* 1.0
|
||||
G_x_rec = 0.1 * K.mean(K.abs(xr-xa))
|
||||
|
||||
G_c_rec = K.mean(K.abs(K.mean(d_xr_feat, axis=[1,2]) - K.mean(d_xa_feat, axis=[1,2])), axis=1 ) #* 1.0
|
||||
G_m_rec = K.mean(K.abs(K.mean(d_xt_feat, axis=[1,2]) - K.mean(d_xb_feat, axis=[1,2])), axis=1 ) #* 1.0
|
||||
G_x_rec = 0.1 * K.mean(K.abs(xr-xa), axis=[1,2,3])
|
||||
|
||||
G_loss = (-d_xr_la-d_xt_lb)*0.5 + G_x_rec + G_c_rec + G_m_rec
|
||||
G_acc = (d_xr_la_acc+d_xt_lb_acc)*0.5
|
||||
|
||||
G_weights = self.enc_class_model.trainable_weights + self.enc_content.trainable_weights + self.decoder.trainable_weights
|
||||
######
|
||||
|
@ -135,16 +135,15 @@ class FUNIT(object):
|
|||
D_real = d_xb_lb_real #1.0 *
|
||||
D_fake = d_xt_lb_fake #1.0 *
|
||||
|
||||
l_reg = 10 * K.sum( K.gradients( d_xb_lb, xb )[0] ** 2 ) # , axis=[1,2,3] / self.batch_size )
|
||||
l_reg = 10 * K.sum( K.gradients( d_xb_lb, xb )[0] ** 2 , axis=[1,2,3] ) #/ self.batch_size )
|
||||
|
||||
D_loss = D_real + D_fake + l_reg
|
||||
D_acc = (d_xb_lb_real_acc+d_xt_lb_fake_acc)*0.5
|
||||
|
||||
D_weights = self.dis.trainable_weights
|
||||
|
||||
self.G_train = K.function ([xa, la, xb, lb],[G_loss], self.G_opt.get_updates(G_loss, G_weights) )
|
||||
self.G_train = K.function ([xa, la, xb, lb],[K.mean(G_loss)], self.G_opt.get_updates(G_loss, G_weights) )
|
||||
|
||||
self.D_train = K.function ([xa, la, xb, lb],[D_loss], self.D_opt.get_updates(D_loss, D_weights) )
|
||||
self.D_train = K.function ([xa, la, xb, lb],[K.mean(D_loss)], self.D_opt.get_updates(D_loss, D_weights) )
|
||||
self.get_average_class_code = K.function ([xa],[s_xa_mean])
|
||||
|
||||
self.G_convert = K.function ([xa,s_xa_one],[xr_one])
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue