added file names to model previews - except xseg

This commit is contained in:
seranus 2021-12-06 22:19:05 +01:00
commit 546b72ff12
5 changed files with 38 additions and 10 deletions

View file

@ -364,7 +364,7 @@ class ModelBase(object):
return ( ('loss_src', 0), ('loss_dst', 0) ) return ( ('loss_src', 0), ('loss_dst', 0) )
#overridable #overridable
def onGetPreview(self, sample, for_history=False): def onGetPreview(self, sample, for_history=False, filenames=None):
#you can return multiple previews #you can return multiple previews
#return [ ('preview_name',preview_rgb), ... ] #return [ ('preview_name',preview_rgb), ... ]
return [] return []
@ -392,7 +392,7 @@ class ModelBase(object):
return self.target_iter != 0 and self.iter >= self.target_iter return self.target_iter != 0 and self.iter >= self.target_iter
def get_previews(self): def get_previews(self):
return self.onGetPreview ( self.last_sample ) return self.onGetPreview ( self.last_sample, filenames=self.last_sample_filenames)
def get_static_previews(self): def get_static_previews(self):
return self.onGetPreview (self.sample_for_preview) return self.onGetPreview (self.sample_for_preview)
@ -476,12 +476,19 @@ class ModelBase(object):
def generate_next_samples(self): def generate_next_samples(self):
sample = [] sample = []
sample_filenames = []
for generator in self.generator_list: for generator in self.generator_list:
if generator.is_initialized(): if generator.is_initialized():
sample.append ( generator.generate_next() ) batch = generator.generate_next()
if type(batch) is tuple:
sample.append ( batch[0] )
sample_filenames.append( batch[1] )
else:
sample.append ( batch )
else: else:
sample.append ( [] ) sample.append ( [] )
self.last_sample = sample self.last_sample = sample
self.last_sample_filenames = sample_filenames
return sample return sample
#overridable #overridable

View file

@ -10,6 +10,7 @@ from facelib import FaceType
from models import ModelBase from models import ModelBase
from samplelib import * from samplelib import *
from core.cv2ex import * from core.cv2ex import *
from utils.label_face import label_face_filename
class AMPModel(ModelBase): class AMPModel(ModelBase):
@ -742,7 +743,7 @@ class AMPModel(ModelBase):
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples, for_history=False, filenames=None):
( (warped_src, target_src, target_srcm, target_srcm_em), ( (warped_src, target_src, target_srcm, target_srcm_em),
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
@ -774,6 +775,10 @@ class AMPModel(ModelBase):
i = np.random.randint(n_samples) if not for_history else 0 i = np.random.randint(n_samples) if not for_history else 0
if filenames is not None and len(filenames) > 0:
S[i] = label_face_filename(S[i], filenames[0][i])
D[i] = label_face_filename(D[i], filenames[1][i])
st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ] st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ]
st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ] st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ]

View file

@ -9,6 +9,7 @@ from core.leras import nn
from facelib import FaceType from facelib import FaceType
from models import ModelBase from models import ModelBase
from samplelib import * from samplelib import *
from utils.label_face import label_face_filename
class QModel(ModelBase): class QModel(ModelBase):
#override #override
@ -278,7 +279,7 @@ class QModel(ModelBase):
return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) return ( ('src_loss', src_loss), ('dst_loss', dst_loss), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples, for_history=False, filenames=None):
( (warped_src, target_src, target_srcm), ( (warped_src, target_src, target_srcm),
(warped_dst, target_dst, target_dstm) ) = samples (warped_dst, target_dst, target_dstm) ) = samples
@ -288,6 +289,12 @@ class QModel(ModelBase):
target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
n_samples = min(4, self.get_batch_size() ) n_samples = min(4, self.get_batch_size() )
if filenames is not None and len(filenames) > 0:
for i in range(n_samples):
S[i] = label_face_filename(S[i], filenames[0][i])
D[i] = label_face_filename(D[i], filenames[1][i])
result = [] result = []
st = [] st = []
for i in range(n_samples): for i in range(n_samples):
@ -298,7 +305,7 @@ class QModel(ModelBase):
st_m = [] st_m = []
for i in range(n_samples): for i in range(n_samples):
ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) ar = label_face_filename(S[i]*target_srcm[i], filenames[0][i]), SS[i], label_face_filename(D[i]*target_dstm[i], filenames[1][i]), DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i])
st_m.append ( np.concatenate ( ar, axis=1) ) st_m.append ( np.concatenate ( ar, axis=1) )
result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ] result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ]

View file

@ -9,6 +9,7 @@ from core.leras import nn
from facelib import FaceType from facelib import FaceType
from models import ModelBase from models import ModelBase
from samplelib import * from samplelib import *
from utils.label_face import label_face_filename
class SAEHDModel(ModelBase): class SAEHDModel(ModelBase):
@ -786,7 +787,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None
cpu_count = multiprocessing.cpu_count() cpu_count = min(multiprocessing.cpu_count(), 4)
src_generators_count = cpu_count // 2 src_generators_count = cpu_count // 2
dst_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2
if ct_mode is not None: if ct_mode is not None:
@ -946,7 +947,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), )
#override #override
def onGetPreview(self, samples, for_history=False): def onGetPreview(self, samples, for_history=False, filenames=None):
( (warped_src, target_src, target_srcm, target_srcm_em), ( (warped_src, target_src, target_srcm, target_srcm_em),
(warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples
@ -958,6 +959,11 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
if filenames is not None and len(filenames) > 0:
for i in range(n_samples):
S[i] = label_face_filename(S[i], filenames[0][i])
D[i] = label_face_filename(D[i], filenames[1][i])
if self.resolution <= 256: if self.resolution <= 256:
result = [] result = []
@ -977,7 +983,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
for i in range(n_samples): for i in range(n_samples):
SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i]
ar = S[i]*target_srcm[i], SS[i]*SSM[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask ar = label_face_filename(S[i]*target_srcm[i], filenames[0][i]), SS[i]*SSM[i], label_face_filename(D[i]*target_dstm[i], filenames[1][i]), DD[i]*DDM[i], SD[i]*SD_mask
st_m.append ( np.concatenate ( ar, axis=1) ) st_m.append ( np.concatenate ( ar, axis=1) )
result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ]

View file

@ -115,6 +115,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
samples, index_host, ct_samples, ct_index_host = param samples, index_host, ct_samples, ct_index_host = param
bs = self.batch_size bs = self.batch_size
filenames = []
while True: while True:
batches = None batches = None
@ -141,4 +142,6 @@ class SampleGeneratorFace(SampleGeneratorBase):
for i in range(len(x)): for i in range(len(x)):
batches[i].append ( x[i] ) batches[i].append ( x[i] )
yield [ np.array(batch) for batch in batches] filenames.append(sample.filename)
yield ([ np.array(batch) for batch in batches], filenames)