Merge branch 'fix/missing-pred-src-mask' into test/missing-pred-src-mask

# Conflicts:
#	models/Model_SAEHD/Model.py
This commit is contained in:
jh 2021-03-12 10:07:15 -08:00
commit 9d343af29f

View file

@ -583,7 +583,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
def AE_view(warped_src, warped_dst): def AE_view(warped_src, warped_dst):
return nn.tf_sess.run ( [pred_src_src, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm], return nn.tf_sess.run ( [pred_src_src, pred_src_srcm, pred_dst_dst, pred_dst_dstm, pred_src_dst, pred_src_dstm],
feed_dict={self.warped_src:warped_src, feed_dict={self.warped_src:warped_src,
self.warped_dst:warped_dst}) self.warped_dst:warped_dst})
self.AE_view = AE_view self.AE_view = AE_view
@ -737,8 +737,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
( (warped_src, target_src, target_srcm, target_srcm_em), ( (warped_src, target_src, target_srcm, target_srcm_em),
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ] S, D, SS, SSM, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ] SSM, DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [SSM, DDM, SDM] ]
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
@ -758,7 +758,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
for i in range(n_samples): for i in range(n_samples):
SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i]
ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask ar = S[i]*target_srcm[i], SS[i]*SSM[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask
st_m.append ( np.concatenate ( ar, axis=1) ) st_m.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ]
@ -786,7 +786,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
st_m = [] st_m = []
for i in range(n_samples): for i in range(n_samples):
ar = S[i]*target_srcm[i], SS[i] ar = S[i]*target_srcm[i], SS[i]*SSM[i]
st_m.append ( np.concatenate ( ar, axis=1) ) st_m.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ] result += [ ('SAEHD masked src-src', np.concatenate (st_m, axis=0 )), ]