This commit is contained in:
Colombo 2020-02-18 19:16:44 +04:00
parent 6a8426dcdb
commit d4335b5fa5

View file

@ -644,14 +644,14 @@ class SAEHDModel(ModelBase):
self.D_train = D_train
if gan_power != 0:
def D_src_dst_train(warped_src, target_src, target_srcm, \
warped_dst, target_dst, target_dstm):
def D_src_dst_train(warped_src, target_src, target_srcm_all, \
warped_dst, target_dst, target_dstm_all:
nn.tf_sess.run ([src_D_src_dst_loss_gv_op], feed_dict={self.warped_src :warped_src,
self.target_src :target_src,
self.target_srcm:target_srcm,
self.target_srcm_all:target_srcm_all,
self.warped_dst :warped_dst,
self.target_dst :target_dst,
self.target_dstm:target_dstm})
self.target_dstm_all:target_dstm_all})
self.D_src_dst_train = D_src_dst_train
if learn_mask: