Merge pull request #8 from MachineEditor/preview_filenames

Preview filenames
This commit is contained in:
Ognjen 2021-12-06 22:42:12 +01:00 committed by GitHub
commit 01ff54fe7d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 13 additions and 4 deletions

View file

@ -307,7 +307,12 @@ class QModel(ModelBase):
st_m = [] st_m = []
for i in range(n_samples): for i in range(n_samples):
ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) SM = S[i]*target_srcm[i]
DM = D[i]*target_dstm[i]
if filenames is not None and len(filenames) > 0:
SM = label_face_filename(SM, filenames[0][i])
DM = label_face_filename(DM, filenames[0][i])
ar = SM, SS[i], DM, DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
st_m.append ( np.concatenate ( ar, axis=1) ) st_m.append ( np.concatenate ( ar, axis=1) )
result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ] result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ]

View file

@ -983,8 +983,12 @@ class SAEHDModel(ModelBase):
st_m = [] st_m = []
for i in range(n_samples): for i in range(n_samples):
SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i]
SM = S[i]*target_srcm[i]
ar = S[i]*target_srcm[i], SS[i]*SSM[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask DM = D[i]*target_dstm[i]
if filenames is not None and len(filenames) > 0:
SM = label_face_filename(SM, filenames[0][i])
DM = label_face_filename(DM, filenames[0][i])
ar = SM, SS[i]*SSM[i], DM, DD[i]*DDM[i], SD[i]*SD_mask
st_m.append ( np.concatenate ( ar, axis=1) ) st_m.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ]

View file

@ -198,7 +198,7 @@ class XSegModel(ModelBase):
return ( ('loss', np.mean(loss) ), ) return ( ('loss', np.mean(loss) ), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples, for_history=False, filenames=None):
n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
if self.pretrain: if self.pretrain: