diff --git a/models/ModelBase.py b/models/ModelBase.py index d1ea117..33b7922 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -415,7 +415,7 @@ class ModelBase(object): return ( ('loss_src', 0), ('loss_dst', 0) ) #overridable - def onGetPreview(self, sample, for_history=False, filenames=None): + def onGetPreview(self, sample, for_history=False): #you can return multiple previews #return [ ('preview_name',preview_rgb), ... ] return [] @@ -447,7 +447,7 @@ class ModelBase(object): return self.target_iter != 0 and self.iter >= self.target_iter def get_previews(self): - return self.onGetPreview ( self.last_sample, filenames=self.last_sample_filenames) + return self.onGetPreview ( self.last_sample ) def get_static_previews(self): return self.onGetPreview (self.sample_for_preview) @@ -585,19 +585,12 @@ class ModelBase(object): def generate_next_samples(self): sample = [] - sample_filenames = [] for generator in self.generator_list: if generator.is_initialized(): - batch = generator.generate_next() - if type(batch) is tuple: - sample.append ( batch[0] ) - sample_filenames.append( batch[1] ) - else: - sample.append ( batch ) + sample.append ( generator.generate_next() ) else: sample.append ( [] ) self.last_sample = sample - self.last_sample_filenames = sample_filenames return sample #overridable diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index e8feaa3..44af92a 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -10,7 +10,6 @@ from facelib import FaceType from models import ModelBase from samplelib import * from core.cv2ex import * -from utils.label_face import label_face_filename from pathlib import Path @@ -889,7 +888,7 @@ class AMPModel(ModelBase): return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override - def onGetPreview(self, samples, for_history=False, filenames=None): + def onGetPreview(self, samples, for_history=False): ( (warped_src, target_src, target_srcm, target_srcm_em), (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples @@ -921,10 +920,6 @@ class AMPModel(ModelBase): i = np.random.randint(n_samples) if not for_history else 0 - if filenames is not None and len(filenames) > 0: - S[i] = label_face_filename(S[i], filenames[0][i]) - D[i] = label_face_filename(D[i], filenames[1][i]) - st = [ np.concatenate ((S[i], D[i], DD[i]*DDM_000[i]), axis=1) ] st += [ np.concatenate ((SS[i], DD[i], SD_100[i] ), axis=1) ] diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index ce7e10d..dba5738 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -9,7 +9,6 @@ from core.leras import nn from facelib import FaceType from models import ModelBase from samplelib import * -from utils.label_face import label_face_filename from pathlib import Path @@ -288,7 +287,7 @@ class QModel(ModelBase): return ( ('src_loss', src_loss), ('dst_loss', dst_loss), ) #override - def onGetPreview(self, samples, for_history=False, filenames=None): + def onGetPreview(self, samples, for_history=False): ( (warped_src, target_src, target_srcm), (warped_dst, target_dst, target_dstm) ) = samples @@ -298,12 +297,6 @@ class QModel(ModelBase): target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )] n_samples = min(4, self.get_batch_size() ) - - if filenames is not None and len(filenames) > 0: - for i in range(n_samples): - S[i] = label_face_filename(S[i], filenames[0][i]) - D[i] = label_face_filename(D[i], filenames[1][i]) - result = [] st = [] for i in range(n_samples): @@ -314,7 +307,7 @@ class QModel(ModelBase): st_m = [] for i in range(n_samples): - ar = label_face_filename(S[i]*target_srcm[i], filenames[0][i]), SS[i], label_face_filename(D[i]*target_dstm[i], filenames[1][i]), DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) + ar = S[i]*target_srcm[i], SS[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*(DDM[i]*SDM[i]) st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('Quick96 masked', np.concatenate (st_m, axis=0 )), ] diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 33f9a50..5b87442 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -9,7 +9,6 @@ from core.leras import nn from facelib import FaceType from models import ModelBase from samplelib import * -from utils.label_face import label_face_filename from pathlib import Path @@ -794,7 +793,7 @@ class SAEHDModel(ModelBase): random_ct_samples_path=training_data_dst_path if ct_mode is not None and not self.pretrain else None - cpu_count = min(multiprocessing.cpu_count(), 4) + cpu_count = multiprocessing.cpu_count() src_generators_count = cpu_count // 2 dst_generators_count = cpu_count // 2 if ct_mode is not None: @@ -954,7 +953,7 @@ class SAEHDModel(ModelBase): return ( ('src_loss', np.mean(src_loss) ), ('dst_loss', np.mean(dst_loss) ), ) #override - def onGetPreview(self, samples, for_history=False, filenames=None): + def onGetPreview(self, samples, for_history=False): ( (warped_src, target_src, target_srcm, target_srcm_em), (warped_dst, target_dst, target_dstm, target_dstm_em) ) = samples @@ -966,11 +965,6 @@ class SAEHDModel(ModelBase): n_samples = min(4, self.get_batch_size(), 800 // self.resolution ) - if filenames is not None and len(filenames) > 0: - for i in range(n_samples): - S[i] = label_face_filename(S[i], filenames[0][i]) - D[i] = label_face_filename(D[i], filenames[1][i]) - if self.resolution <= 256: result = [] @@ -990,7 +984,7 @@ class SAEHDModel(ModelBase): for i in range(n_samples): SD_mask = DDM[i]*SDM[i] if self.face_type < FaceType.HEAD else SDM[i] - ar = label_face_filename(S[i]*target_srcm[i], filenames[0][i]), SS[i]*SSM[i], label_face_filename(D[i]*target_dstm[i], filenames[1][i]), DD[i]*DDM[i], SD[i]*SD_mask + ar = S[i]*target_srcm[i], SS[i]*SSM[i], D[i]*target_dstm[i], DD[i]*DDM[i], SD[i]*SD_mask st_m.append ( np.concatenate ( ar, axis=1) ) result += [ ('SAEHD masked', np.concatenate (st_m, axis=0 )), ] diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 68229c2..605d327 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -115,7 +115,6 @@ class SampleGeneratorFace(SampleGeneratorBase): samples, index_host, ct_samples, ct_index_host = param bs = self.batch_size - filenames = [] while True: batches = None @@ -142,6 +141,4 @@ class SampleGeneratorFace(SampleGeneratorBase): for i in range(len(x)): batches[i].append ( x[i] ) - filenames.append(sample.filename) - - yield ([ np.array(batch) for batch in batches], filenames) + yield [ np.array(batch) for batch in batches]