This commit is contained in:
Colombo 2020-02-18 19:16:44 +04:00
parent 6a8426dcdb
commit d4335b5fa5

View file

@ -644,14 +644,14 @@ class SAEHDModel(ModelBase):
self.D_train = D_train self.D_train = D_train
if gan_power != 0: if gan_power != 0:
def D_src_dst_train(warped_src, target_src, target_srcm, \ def D_src_dst_train(warped_src, target_src, target_srcm_all, \
warped_dst, target_dst, target_dstm): warped_dst, target_dst, target_dstm_all:
nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src, nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
self.target_src :target_src, self.target_src :target_src,
self.target_srcm:target_srcm, self.target_srcm_all:target_srcm_all,
self.warped_dst :warped_dst, self.warped_dst :warped_dst,
self.target_dst :target_dst, self.target_dst :target_dst,
self.target_dstm:target_dstm}) self.target_dstm_all:target_dstm_all})
self.D_src_dst_train = D_src_dst_train self.D_src_dst_train = D_src_dst_train
if learn_mask: if learn_mask: