From ebb085e31e46bd1fb88d6c809c230172e295cf45 Mon Sep 17 00:00:00 2001 From: seranus Date: Mon, 6 Dec 2021 23:28:19 +0100 Subject: [PATCH] added filename previews --- models/ModelBase.py | 13 ++++++++++--- models/Model_AMP/Model.py | 8 +++++++- models/Model_Quick96/Model.py | 10 +++++++++- models/Model_SAEHD/Model.py | 9 ++++++++- samplelib/SampleGeneratorFace.py | 3 ++- 5 files changed, 36 insertions(+), 7 deletions(-) diff --git a/models/ModelBase.py b/models/ModelBase.py index 33b7922..5b338da 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -415,7 +415,7 @@ class ModelBase(object): return ( ('loss_src', 0), ('loss_dst', 0) ) #overridable - def onGetPreview(self, sample, for_history=False): + def onGetPreview(self, sample, for_history=False, filenames=None): #you can return multiple previews #return [ ('preview_name',preview_rgb), ... ] return [] @@ -447,7 +447,7 @@ class ModelBase(object): return self.target_iter != 0 and self.iter >= self.target_iter def get_previews(self): - return self.onGetPreview ( self.last_sample ) + return self.onGetPreview ( self.last_sample, filenames=self.last_sample_filenames ) def get_static_previews(self): return self.onGetPreview (self.sample_for_preview) @@ -585,12 +585,19 @@ class ModelBase(object): def generate_next_samples(self): sample = [] + sample_filenames = [] for generator in self.generator_list: if generator.is_initialized(): - sample.append ( generator.generate_next() ) + batch = generator.generate_next() + if type(batch) is tuple: + sample.append ( batch[0] ) + sample_filenames.append( batch[1] ) + else: + sample.append ( batch ) else: sample.append ( [] ) self.last_sample = sample + self.last_sample_filenames = sample_filenames return sample #overridable diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 44af92a..4ab859f 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -13,6 +13,8 @@ from core.cv2ex import * from pathlib import Path +from utils.label_face import label_face_filename + class AMPModel(ModelBase): #override @@ -888,7 +890,7 @@ class AMPModel(ModelBase): return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override - def onGetPreview(self, samples, for_history=False): + def onGetPreview(self, samples, for_history=False, filenames=None): ( (warped_src, target_src, target_srcm, target_srcm_em), (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples @@ -920,6 +922,10 @@ class AMPModel(ModelBase): i = np.random.randint(n_samples) if not for_history else 0 + if filenames is not None and len(filenames) > 0: + S[i] = label_face_filename(S[i], filenames[0][i]) + D[i] = label_face_filename(D[i], filenames[1][i]) + st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ] st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ] diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index 8b999af..a03d673 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -12,6 +12,8 @@ from samplelib import * from pathlib import Path +from utils.label_face import label_face_filename + class QModel(ModelBase): #override def on_initialize_options(self): @@ -287,7 +289,7 @@ class QModel(ModelBase): return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) #override - def onGetPreview(self, samples, for_history=False): + def onGetPreview(self, samples, for_history=False, filenames=None): ( (warped_src, target_src, target_srcm), (warped_dst, target_dst, target_dstm) ) = samples @@ -297,6 +299,12 @@ class QModel(ModelBase): target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] n_samples = min(4, self.get_batch_size() ) + + if filenames is not None and len(filenames) > 0: + for i in range(n_samples): + S[i] = label_face_filename(S[i], filenames[0][i]) + D[i] = label_face_filename(D[i], filenames[1][i]) + result = [] st = [] for i in range(n_samples): diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 2fb3a75..817ce9d 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -12,6 +12,8 @@ from samplelib import * from pathlib import Path +from utils.label_face import label_face_filename + class SAEHDModel(ModelBase): #override @@ -953,7 +955,7 @@ class SAEHDModel(ModelBase): return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override - def onGetPreview(self, samples, for_history=False): + def onGetPreview(self, samples, for_history=False, filenames=None): ( (warped_src, target_src, target_srcm, target_srcm_em), (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples @@ -965,6 +967,11 @@ class SAEHDModel(ModelBase): n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) + if filenames is not None and len(filenames) > 0: + for i in range(n_samples): + S[i] = label_face_filename(S[i], filenames[0][i]) + D[i] = label_face_filename(D[i], filenames[1][i]) + if self.resolution <= 256: result = [] diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 605d327..70f8d71 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -115,6 +115,7 @@ class SampleGeneratorFace(SampleGeneratorBase): samples, index_host, ct_samples, ct_index_host = param bs = self.batch_size + filenames = [] while True: batches = None @@ -141,4 +142,4 @@ class SampleGeneratorFace(SampleGeneratorBase): for i in range(len(x)): batches[i].append ( x[i] ) - yield [ np.array(batch) for batch in batches] + yield ([ np.array(batch) for batch in batches], filenames)