This commit is contained in:
iperov 2021-06-28 14:27:09 +04:00
parent c43b3b161b
commit 9d6b6feb1f

View file

@ -413,7 +413,7 @@ class AMPModel(ModelBase):
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , lowest_dense_res, lowest_dense_res]), gpu_src_dst_code = tf.concat( (tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , lowest_dense_res, lowest_dense_res]),
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,ae_dims-inter_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 ) tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 )
gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code) gpu_pred_src_src, gpu_pred_src_srcm = self.decoder(gpu_src_code)
gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code) gpu_pred_dst_dst, gpu_pred_dst_dstm = self.decoder(gpu_dst_code)
@ -569,7 +569,7 @@ class AMPModel(ModelBase):
inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32) inter_dims_slice = tf.cast(inter_dims*self.morph_value_t[0], tf.int32)
gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , lowest_dense_res, lowest_dense_res]), gpu_src_dst_code = tf.concat( ( tf.slice(gpu_dst_inter_src_code, [0,0,0,0], [-1, inter_dims_slice , lowest_dense_res, lowest_dense_res]),
tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,ae_dims-inter_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 ) tf.slice(gpu_dst_inter_dst_code, [0,inter_dims_slice,0,0], [-1,inter_dims-inter_dims_slice, lowest_dense_res,lowest_dense_res]) ), 1 )
gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code) gpu_pred_src_dst, gpu_pred_src_dstm = self.decoder(gpu_src_dst_code)
_, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code) _, gpu_pred_dst_dstm = self.decoder(gpu_dst_inter_dst_code)