From cbd78fbc8a785e682dc3bf9af23c6ac7981d0c8a Mon Sep 17 00:00:00 2001 From: seranus Date: Mon, 6 Dec 2021 22:33:52 +0100 Subject: [PATCH 1/2] mask labels fixed --- models/Model_Quick96/Model.py | 7 ++++++- models/Model_SAEHD/Model.py | 8 ++++++-- 2 files changed, 12 insertions(+), 3 deletions(-) diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index 8babe0f..9d259d6 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -305,7 +305,12 @@ class QModel(ModelBase): st_m = [] for i in range(n_samples): - ar = label_face_filename(S[i]*target_srcm[i], filenames[0][i]), SS[i], label_face_filename(D[i]*target_dstm[i], filenames[1][i]), DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) + SM = S[i]*target_srcm[i] + DM = D[i]*target_dstm[i] + if filenames is not None and len(filenames) > 0: + SM = label_face_filename(SM, filenames[0][i]) + DM = label_face_filename(DM, filenames[0][i]) + ar = SM, SS[i], DM, DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ] diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index a363949..4dd5bfd 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -982,8 +982,12 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... st_m = [] for i in range(n_samples): SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] - - ar = label_face_filename(S[i]*target_srcm[i], filenames[0][i]), SS[i]*SSM[i], label_face_filename(D[i]*target_dstm[i], filenames[1][i]), DD[i]*DDM[i], SD[i]*SD_mask + SM = S[i]*target_srcm[i] + DM = D[i]*target_dstm[i] + if filenames is not None and len(filenames) > 0: + SM = label_face_filename(SM, filenames[0][i]) + DM = label_face_filename(DM, filenames[0][i]) + ar = SM, SS[i]*SSM[i], DM, DD[i]*DDM[i], SD[i]*SD_mask st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] From 6aee9030d2f0472ac8ee166febac6cfd4f059edf Mon Sep 17 00:00:00 2001 From: seranus Date: Mon, 6 Dec 2021 22:38:06 +0100 Subject: [PATCH 2/2] fix --- models/Model_XSeg/Model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/models/Model_XSeg/Model.py b/models/Model_XSeg/Model.py index b0addfd..1313e6f 100644 --- a/models/Model_XSeg/Model.py +++ b/models/Model_XSeg/Model.py @@ -196,7 +196,7 @@ class XSegModel(ModelBase): return ( ('loss', np.mean(loss) ), ) #override - def onGetPreview(self, samples, for_history=False): + def onGetPreview(self, samples, for_history=False, filenames=None): n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) if self.pretrain: