fix funit

This commit is contained in:
Colombo 2019-09-23 18:54:37 +04:00
parent 9e80bbc917
commit deeb98474b
2 changed files with 11 additions and 12 deletions

View file

@ -33,7 +33,7 @@ class TrueFaceModel(ModelBase):
self.options['face_type'] = self.options.get('face_type', default_face_type) self.options['face_type'] = self.options.get('face_type', default_face_type)
if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend: if (is_first_run or ask_override) and 'tensorflow' in self.device_config.backend:
def_optimizer_mode = self.options.get('optimizer_mode', 1) def_optimizer_mode = self.options.get('optimizer_mode', 3)
self.options['optimizer_mode'] = io.input_int ("Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode, help_message="1 - no changes. 2 - allows you to train x2 bigger network consuming RAM. 3 - allows you to train x3 bigger network consuming huge amount of RAM and slower, depends on CPU power.") self.options['optimizer_mode'] = io.input_int ("Optimizer mode? ( 1,2,3 ?:help skip:%d) : " % (def_optimizer_mode), def_optimizer_mode, help_message="1 - no changes. 2 - allows you to train x2 bigger network consuming RAM. 3 - allows you to train x3 bigger network consuming huge amount of RAM and slower, depends on CPU power.")
else: else:
self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1) self.options['optimizer_mode'] = self.options.get('optimizer_mode', 1)

View file

@ -111,8 +111,8 @@ class FUNIT(object):
if func is not None: if func is not None:
tensors = [func(t) for t in tensors] tensors = [func(t) for t in tensors]
return K.sum(tensors) / (b*h*w), acc return K.sum(tensors, axis=[1,2,3] ) / (h*w), acc
d_xr_la, d_xr_la_acc = dis_gather_mean(d_xr, la, acc_func=lambda x: x >= 0) d_xr_la, d_xr_la_acc = dis_gather_mean(d_xr, la, acc_func=lambda x: x >= 0)
d_xt_lb, d_xt_lb_acc = dis_gather_mean(d_xt, lb, acc_func=lambda x: x >= 0) d_xt_lb, d_xt_lb_acc = dis_gather_mean(d_xt, lb, acc_func=lambda x: x >= 0)
@ -121,30 +121,29 @@ class FUNIT(object):
d_xb_lb_real, d_xb_lb_real_acc = dis_gather_mean(d_xb, lb, lambda x: K.relu(1.0-x), acc_func=lambda x: x >= 0) d_xb_lb_real, d_xb_lb_real_acc = dis_gather_mean(d_xb, lb, lambda x: K.relu(1.0-x), acc_func=lambda x: x >= 0)
d_xt_lb_fake, d_xt_lb_fake_acc = dis_gather_mean(d_xt, lb, lambda x: K.relu(1.0+x), acc_func=lambda x: x < 0) d_xt_lb_fake, d_xt_lb_fake_acc = dis_gather_mean(d_xt, lb, lambda x: K.relu(1.0+x), acc_func=lambda x: x < 0)
G_c_rec = K.mean(K.abs(K.mean(d_xr_feat, axis=[1,2]) - K.mean(d_xa_feat, axis=[1,2]))) #* 1.0 G_c_rec = K.mean(K.abs(K.mean(d_xr_feat, axis=[1,2]) - K.mean(d_xa_feat, axis=[1,2])), axis=1 ) #* 1.0
G_m_rec = K.mean(K.abs(K.mean(d_xt_feat, axis=[1,2]) - K.mean(d_xb_feat, axis=[1,2]))) #* 1.0 G_m_rec = K.mean(K.abs(K.mean(d_xt_feat, axis=[1,2]) - K.mean(d_xb_feat, axis=[1,2])), axis=1 ) #* 1.0
G_x_rec = 0.1 * K.mean(K.abs(xr-xa)) G_x_rec = 0.1 * K.mean(K.abs(xr-xa), axis=[1,2,3])
G_loss = (-d_xr_la-d_xt_lb)*0.5 + G_x_rec + G_c_rec + G_m_rec G_loss = (-d_xr_la-d_xt_lb)*0.5 + G_x_rec + G_c_rec + G_m_rec
G_acc = (d_xr_la_acc+d_xt_lb_acc)*0.5
G_weights = self.enc_class_model.trainable_weights + self.enc_content.trainable_weights + self.decoder.trainable_weights G_weights = self.enc_class_model.trainable_weights + self.enc_content.trainable_weights + self.decoder.trainable_weights
###### ######
D_real = d_xb_lb_real #1.0 * D_real = d_xb_lb_real #1.0 *
D_fake = d_xt_lb_fake #1.0 * D_fake = d_xt_lb_fake #1.0 *
l_reg = 10 * K.sum( K.gradients( d_xb_lb, xb )[0] ** 2 ) # , axis=[1,2,3] / self.batch_size ) l_reg = 10 * K.sum( K.gradients( d_xb_lb, xb )[0] ** 2 , axis=[1,2,3] ) #/ self.batch_size )
D_loss = D_real + D_fake + l_reg D_loss = D_real + D_fake + l_reg
D_acc = (d_xb_lb_real_acc+d_xt_lb_fake_acc)*0.5
D_weights = self.dis.trainable_weights D_weights = self.dis.trainable_weights
self.G_train = K.function ([xa, la, xb, lb],[G_loss], self.G_opt.get_updates(G_loss, G_weights) ) self.G_train = K.function ([xa, la, xb, lb],[K.mean(G_loss)], self.G_opt.get_updates(G_loss, G_weights) )
self.D_train = K.function ([xa, la, xb, lb],[D_loss], self.D_opt.get_updates(D_loss, D_weights) ) self.D_train = K.function ([xa, la, xb, lb],[K.mean(D_loss)], self.D_opt.get_updates(D_loss, D_weights) )
self.get_average_class_code = K.function ([xa],[s_xa_mean]) self.get_average_class_code = K.function ([xa],[s_xa_mean])
self.G_convert = K.function ([xa,s_xa_one],[xr_one]) self.G_convert = K.function ([xa,s_xa_one],[xr_one])