This commit is contained in:
Colombo 2020-01-22 13:41:05 +04:00
parent 0c93b89e87
commit 60804ca3ba
3 changed files with 20 additions and 15 deletions

View file

@ -615,9 +615,9 @@ class SAEHDModel(ModelBase):
with tf.device( f'/GPU:0' if len(devices) != 0 else f'/CPU:0'):
if 'df' in archi:
gpu_dst_code = self.inter(self.encoder(self.warped_dst))
gpu_pred_src_dst = self.decoder_src(gpu_dst_code)
gpu_pred_dst_dstm = self.decoder_dstm(gpu_dst_code)
gpu_pred_src_dstm = self.decoder_srcm(gpu_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder_src(gpu_dst_code)
_, gpu_pred_dst_dstm = self.decoder_dst(gpu_dst_code)
elif 'liae' in archi:
gpu_dst_code = self.encoder (self.warped_dst)
gpu_dst_inter_B_code = self.inter_B (gpu_dst_code)
@ -625,9 +625,8 @@ class SAEHDModel(ModelBase):
gpu_dst_code = tf.concat([gpu_dst_inter_B_code,gpu_dst_inter_AB_code],-1)
gpu_src_dst_code = tf.concat([gpu_dst_inter_AB_code,gpu_dst_inter_AB_code],-1)
gpu_pred_src_dst = self.decoder(gpu_src_dst_code)
gpu_pred_dst_dstm = self.decoderm(gpu_dst_code)
gpu_pred_src_dstm = self.decoderm(gpu_src_dst_code)
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
if learn_mask:
def AE_merge( warped_dst):