diff --git a/facelib/LandmarksProcessor.py b/facelib/LandmarksProcessor.py index 608e44c..f8bcd5c 100644 --- a/facelib/LandmarksProcessor.py +++ b/facelib/LandmarksProcessor.py @@ -683,6 +683,9 @@ def draw_rect_landmarks(image, rect, image_landmarks, face_size, face_type, tran [(0, 0), (0, face_size - 1), (face_size - 1, face_size - 1), (face_size - 1, 0)], image_to_face_mat, True) imagelib.draw_polygon (image, points, (0,0,255), 2) + + points = transform_points ( [ ( int(face_size*0.05), 0), ( int(face_size*0.1), int(face_size*0.1) ), ( 0, int(face_size*0.1) ) ], image_to_face_mat, True) + imagelib.draw_polygon (image, points, (0,0,255), 2) def calc_face_pitch(landmarks): diff --git a/facelib/S3FDExtractor.py b/facelib/S3FDExtractor.py index 2d2d937..b56e159 100644 --- a/facelib/S3FDExtractor.py +++ b/facelib/S3FDExtractor.py @@ -12,8 +12,8 @@ class S3FDExtractor(object): S3FD: Single Shot Scale-invariant Face Detector https://arxiv.org/pdf/1708.05237.pdf """ - def __init__(self): - exec(nnlib.import_all(), locals(), globals()) + def __init__(self, do_dummy_predict=False): + exec( nnlib.import_all(), locals(), globals() ) model_path = Path(__file__).parent / "S3FD.h5" if not model_path.exists(): @@ -21,13 +21,14 @@ class S3FDExtractor(object): self.model = nnlib.keras.models.load_model ( str(model_path) ) - self.extract ( np.zeros( (1080,1920,3), dtype=np.uint8) ) + if do_dummy_predict: + self.extract ( np.zeros( (640,640,3), dtype=np.uint8) ) def __enter__(self): return self def __exit__(self, exc_type=None, exc_value=None, traceback=None): - return False # pass exception between __enter__ and __exit__ to outter level + return False #pass exception between __enter__ and __exit__ to outter level def extract(self, input_image, is_bgr=True, is_remove_intersects=False, nms_thresh=0.3): """ @@ -40,7 +41,7 @@ class S3FDExtractor(object): """ if is_bgr: - input_image = input_image[:, :, ::-1] + input_image = input_image[:,:,::-1] is_bgr = False (h, w, ch) = input_image.shape @@ -53,16 +54,16 @@ class S3FDExtractor(object): input_image = cv2.resize(input_image, (int(w / input_scale), int(h / input_scale)), interpolation=cv2.INTER_LINEAR) - olist = self.model.predict(np.expand_dims(input_image, 0)) + olist = self.model.predict( np.expand_dims(input_image,0) ) detected_faces = [] for ltrb in self._refine(olist, nms_thresh): - l, t, r, b = [x * input_scale for x in ltrb] - bt = b - t - if min(r - l, bt) < 40: # filtering faces < 40pix by any side + l,t,r,b = [ x*input_scale for x in ltrb] + bt = b-t + if min(r-l,bt) < 40: #filtering faces < 40pix by any side continue - b += bt * 0.1 # enlarging bottom line a bit for 2DFAN-4, because default is not enough covering a chin - detected_faces.append([int(x) for x in (l, t, r, b)]) + b += bt*0.1 #enlarging bottom line a bit for 2DFAN-4, because default is not enough covering a chin + detected_faces.append ( [int(x) for x in (l,t,r,b) ] ) #sort by largest area first detected_faces = [ [(l,t,r,b), (r-l)*(b-t) ] for (l,t,r,b) in detected_faces ] @@ -83,18 +84,18 @@ class S3FDExtractor(object): def _refine(self, olist, thresh): bboxlist = [] - for i, ((ocls,), (oreg,)) in enumerate(zip(olist[::2], olist[1::2])): - stride = 2 ** (i + 2) # 4,8,16,32,64,128 + for i, ((ocls,), (oreg,)) in enumerate ( zip ( olist[::2], olist[1::2] ) ): + stride = 2**(i + 2) # 4,8,16,32,64,128 s_d2 = stride / 2 s_m4 = stride * 4 for hindex, windex in zip(*np.where(ocls > 0.05)): score = ocls[hindex, windex] - loc = oreg[hindex, windex, :] + loc = oreg[hindex, windex, :] priors = np.array([windex * stride + s_d2, hindex * stride + s_d2, s_m4, s_m4]) priors_2p = priors[2:] box = np.concatenate((priors[:2] + loc[:2] * 0.1 * priors_2p, - priors_2p * np.exp(loc[2:] * 0.2))) + priors_2p * np.exp(loc[2:] * 0.2)) ) box[:2] -= box[2:] / 2 box[2:] += box[:2] @@ -104,7 +105,7 @@ class S3FDExtractor(object): if len(bboxlist) == 0: bboxlist = np.zeros((1, 5)) bboxlist = bboxlist[self._refine_nms(bboxlist, thresh), :] - bboxlist = [x[:-1].astype(np.int) for x in bboxlist if x[-1] >= 0.5] + bboxlist = [ x[:-1].astype(np.int) for x in bboxlist if x[-1] >= 0.5] return bboxlist def _refine_nms(self, dets, nms_thresh): diff --git a/main.py b/main.py index 1896fff..c9b9d1f 100644 --- a/main.py +++ b/main.py @@ -112,7 +112,7 @@ if __name__ == "__main__": p = subparsers.add_parser( "sort", help="Sort faces in a directory.") p.add_argument('--input-dir', required=True, action=fixPathAction, dest="input_dir", help="Input directory. A directory containing the files you wish to process.") - p.add_argument('--by', required=True, dest="sort_by_method", choices=("blur", "face", "face-dissim", "face-yaw", "face-pitch", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-no-blur", "test"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." ) + p.add_argument('--by', required=True, dest="sort_by_method", choices=("blur", "face", "face-dissim", "face-yaw", "face-pitch", "hist", "hist-dissim", "brightness", "hue", "black", "origname", "oneface", "final", "final-no-blur", "vggface", "test"), help="Method of sorting. 'origname' sort by original filename to recover original sequence." ) p.set_defaults (func=process_sort) def process_util(arguments): diff --git a/mainscripts/Extractor.py b/mainscripts/Extractor.py index 8a036db..23cee97 100644 --- a/mainscripts/Extractor.py +++ b/mainscripts/Extractor.py @@ -77,7 +77,7 @@ class ExtractSubprocessor(Subprocessor): self.e = facelib.DLIBExtractor(nnlib.dlib) elif self.type == 'rects-s3fd': nnlib.import_all (device_config) - self.e = facelib.S3FDExtractor() + self.e = facelib.S3FDExtractor(do_dummy_predict=True) else: raise ValueError ("Wrong type.") @@ -89,7 +89,7 @@ class ExtractSubprocessor(Subprocessor): self.e = facelib.FANExtractor() self.e.__enter__() if self.device_vram >= 2: - self.second_pass_e = facelib.S3FDExtractor() + self.second_pass_e = facelib.S3FDExtractor(do_dummy_predict=False) self.second_pass_e.__enter__() else: self.second_pass_e = None diff --git a/mainscripts/Sorter.py b/mainscripts/Sorter.py index e9332da..7c94025 100644 --- a/mainscripts/Sorter.py +++ b/mainscripts/Sorter.py @@ -1,19 +1,26 @@ import os -import sys -import operator -import numpy as np -import cv2 -from shutil import copyfile -from pathlib import Path -from utils import Path_utils -from utils.DFLPNG import DFLPNG -from utils.DFLJPG import DFLJPG -from utils.cv2_utils import * -from facelib import LandmarksProcessor -from joblib import Subprocessor import multiprocessing -from interact import interact as io +import operator +import sys +from pathlib import Path +from shutil import copyfile + +import cv2 +import numpy as np +from numpy import linalg as npla + +import imagelib +from facelib import LandmarksProcessor +from functools import cmp_to_key from imagelib import estimate_sharpness +from interact import interact as io +from joblib import Subprocessor +from nnlib import VGGFace +from utils import Path_utils +from utils.cv2_utils import * +from utils.DFLJPG import DFLJPG +from utils.DFLPNG import DFLPNG + class BlurEstimatorSubprocessor(Subprocessor): class Cli(Subprocessor.Cli): @@ -772,24 +779,97 @@ def sort_final(input_path, include_by_blur=True): for pg in range(pitch_grads): img_list = pitch_sample_list[pg] if img_list is None: - continue + continue final_img_list += [ img_list.pop(0) ] if len(img_list) == 0: - pitch_sample_list[pg] = None + pitch_sample_list[pg] = None n -= 1 if n == 0: break - if n_prev == n: - break + if n_prev == n: + break for pg in range(pitch_grads): img_list = pitch_sample_list[pg] if img_list is None: - continue + continue trash_img_list += img_list return final_img_list, trash_img_list + +def sort_by_vggface(input_path): + io.log_info ("Sorting by face similarity using VGGFace model...") + + model = VGGFace() + + final_img_list = [] + trash_img_list = [] + + image_paths = Path_utils.get_image_paths(input_path) + img_list = [ (x,) for x in image_paths ] + img_list_len = len(img_list) + img_list_range = [*range(img_list_len)] + + feats = [None]*img_list_len + for i in io.progress_bar_generator(img_list_range, "Loading"): + img = cv2_imread( img_list[i][0] ).astype(np.float32) + img = imagelib.normalize_channels (img, 3) + img = cv2.resize (img, (224,224) ) + img = img[..., ::-1] + img[..., 0] -= 93.5940 + img[..., 1] -= 104.7624 + img[..., 2] -= 129.1863 + feats[i] = model.predict( img[None,...] )[0] + + tmp = np.zeros( (img_list_len,) ) + float_inf = float("inf") + for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ): + i_feat = feats[i] + + for j in img_list_range: + tmp[j] = npla.norm(i_feat-feats[j]) if j >= i+1 else float_inf + + idx = np.argmin(tmp) + + img_list[i+1], img_list[idx] = img_list[idx], img_list[i+1] + feats[i+1], feats[idx] = feats[idx], feats[i+1] + + return img_list, trash_img_list + +""" + img_list_len = len(img_list) + + for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ): + a = [] + i_1 = img_list[i][1] + + + for j in range(i+1, img_list_len): + a.append ( [ j, np.linalg.norm(i_1-img_list[j][1]) ] ) + + x = sorted(a, key=operator.itemgetter(1) )[0][0] + saved = img_list[i+1] + img_list[i+1] = img_list[x] + img_list[x] = saved + + + q = np.array ( [ x[1] for x in img_list ] ) + + for i in io.progress_bar_generator ( range(img_list_len-1), "Sorting" ): + + a = np.linalg.norm( q[i] - q[i+1:], axis=1 ) + a = i+1+np.argmin(a) + + saved = img_list[i+1] + img_list[i+1] = img_list[a] + img_list[a] = saved + + saved = q[i+1] + q[i+1] = q[a] + q[a] = saved +""" + def final_process(input_path, img_list, trash_img_list): if len(trash_img_list) != 0: parent_input_path = input_path.parent @@ -851,6 +931,7 @@ def main (input_path, sort_by_method): elif sort_by_method == 'black': img_list = sort_by_black (input_path) elif sort_by_method == 'origname': img_list, trash_img_list = sort_by_origname (input_path) elif sort_by_method == 'oneface': img_list, trash_img_list = sort_by_oneface_in_image (input_path) + elif sort_by_method == 'vggface': img_list, trash_img_list = sort_by_vggface (input_path) elif sort_by_method == 'final': img_list, trash_img_list = sort_final (input_path) elif sort_by_method == 'final-no-blur': img_list, trash_img_list = sort_final (input_path, include_by_blur=False) diff --git a/mainscripts/dev_misc.py b/mainscripts/dev_misc.py index 72be53f..b6e1d4d 100644 --- a/mainscripts/dev_misc.py +++ b/mainscripts/dev_misc.py @@ -37,6 +37,10 @@ def extract_vggface2_dataset(input_dir, device_args={} ): cur_input_path = input_path / dir_name cur_output_path = output_path / dir_name + + l = len(Path_utils.get_image_paths(cur_input_path)) + if l < 250 or l > 350: + continue io.log_info (f"Processing: {str(cur_input_path)} ") diff --git a/mathlib/umeyama.py b/mathlib/umeyama.py index 7c6b2d0..826a88f 100644 --- a/mathlib/umeyama.py +++ b/mathlib/umeyama.py @@ -57,7 +57,7 @@ def umeyama(src, dst, estimate_scale): T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) d[dim - 1] = s else: - T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V.T)) + T[:dim, :dim] = np.dot(U, np.dot(np.diag(d), V)) if estimate_scale: # Eq. (41) and (42). diff --git a/models/Model_DEV_FUNIT/Model.py b/models/Model_DEV_FUNIT/Model.py index 395c56b..16387f2 100644 --- a/models/Model_DEV_FUNIT/Model.py +++ b/models/Model_DEV_FUNIT/Model.py @@ -42,7 +42,7 @@ class FUNITModel(ModelBase): #override def onInitialize(self, batch_size=-1, **in_options): exec(nnlib.code_import_all, locals(), globals()) - self.set_vram_batch_requirements({4:16}) + self.set_vram_batch_requirements({4:16,11:24}) resolution = self.options['resolution'] face_type = FaceType.FULL if self.options['face_type'] == 'f' else FaceType.HALF @@ -75,7 +75,8 @@ class FUNITModel(ModelBase): face_type = t.FACE_TYPE_FULL if self.options['face_type'] == 'f' else t.FACE_TYPE_HALF output_sample_types=[ {'types': (t.IMG_TRANSFORMED, face_type, t.MODE_BGR), 'resolution':128, 'normalize_tanh':True} ] - + output_sample_types1=[ {'types': (t.IMG_SOURCE, face_type, t.MODE_BGR), 'resolution':128, 'normalize_tanh':True} ] + self.set_training_data_generators ([ SampleGeneratorFace(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=True), @@ -87,11 +88,11 @@ class FUNITModel(ModelBase): SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=True), - output_sample_types=output_sample_types, person_id_mode=True ), + output_sample_types=output_sample_types1, person_id_mode=True ), SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size, sample_process_options=SampleProcessor.Options(random_flip=True), - output_sample_types=output_sample_types, person_id_mode=True ), + output_sample_types=output_sample_types1, person_id_mode=True ), ]) #override diff --git a/nnlib/FUNIT.py b/nnlib/FUNIT.py index 9df50c6..0bd5006 100644 --- a/nnlib/FUNIT.py +++ b/nnlib/FUNIT.py @@ -162,10 +162,6 @@ class FUNIT(object): for w in weights_list: K.set_value( w, K.get_value(initer(K.int_shape(w))) ) - #if not self.is_first_run(): - # self.load_weights_safe(self.get_model_filename_list()) - - if load_weights_locally: pass @@ -188,9 +184,6 @@ class FUNIT(object): [self.D_opt, 'D_opt.h5'], ] - #def save_weights(self): - # self.model.save_weights (str(self.weights_path)) - def train(self, xa,la,xb,lb): D_loss, = self.D_train ([xa,la,xb,lb]) G_loss, = self.G_train ([xa,la,xb,lb]) @@ -209,17 +202,17 @@ class FUNIT(object): def ResBlock(dim): def func(input): x = input - x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x)) + x = Conv2D(dim, 3, strides=1, padding='same')(x) x = InstanceNormalization()(x) x = ReLU()(x) - x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x)) + x = Conv2D(dim, 3, strides=1, padding='same')(x) x = InstanceNormalization()(x) return Add()([x,input]) return func def func(x): - x = Conv2D (nf, kernel_size=7, strides=1, padding='valid')(ZeroPadding2D(3)(x)) + x = Conv2D (nf, kernel_size=7, strides=1, padding='same')(x) x = InstanceNormalization()(x) x = ReLU()(x) for i in range(downs): @@ -237,11 +230,11 @@ class FUNIT(object): exec (nnlib.import_all(), locals(), globals()) def func(x): - x = Conv2D (nf, kernel_size=7, strides=1, padding='valid', activation='relu')(ZeroPadding2D(3)(x)) + x = Conv2D (nf, kernel_size=7, strides=1, padding='same', activation='relu')(x) for i in range(downs): x = Conv2D (nf * min ( 4, 2**(i+1) ), kernel_size=4, strides=2, padding='valid', activation='relu')(ZeroPadding2D(1)(x)) x = GlobalAveragePooling2D()(x) - x = Dense(nf)(x) + x = Dense(latent_dim)(x) return x return func @@ -250,16 +243,14 @@ class FUNIT(object): def DecoderFlow(ups, n_res_blks=2, mlp_blks=2, subpixel_decoder=False ): exec (nnlib.import_all(), locals(), globals()) - - def ResBlock(dim): def func(input): inp, mlp = input x = inp - x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x)) + x = Conv2D(dim, 3, strides=1, padding='same')(x) x = FUNITAdain(kernel_initializer='he_normal')([x,mlp]) x = ReLU()(x) - x = Conv2D(dim, 3, strides=1, padding='valid')(ZeroPadding2D(1)(x)) + x = Conv2D(dim, 3, strides=1, padding='same')(x) x = FUNITAdain(kernel_initializer='he_normal')([x,mlp]) return Add()([x,inp]) return func @@ -280,16 +271,16 @@ class FUNIT(object): for i in range(ups): if subpixel_decoder: - x = Conv2D (4* (nf // 2**(i+1)), kernel_size=3, strides=1, padding='valid')(ZeroPadding2D(1)(x)) + x = Conv2D (4* (nf // 2**(i+1)), kernel_size=3, strides=1, padding='same')(x) x = SubpixelUpscaler()(x) else: x = UpSampling2D()(x) - x = Conv2D (nf // 2**(i+1), kernel_size=5, strides=1, padding='valid')(ZeroPadding2D(2)(x)) + x = Conv2D (nf // 2**(i+1), kernel_size=5, strides=1, padding='same')(x) x = InstanceNormalization()(x) x = ReLU()(x) - rgb = Conv2D (3, kernel_size=7, strides=1, padding='valid', activation='tanh')(ZeroPadding2D(3)(x)) + rgb = Conv2D (3, kernel_size=7, strides=1, padding='same', activation='tanh')(x) return rgb return func diff --git a/nnlib/VGGFace.py b/nnlib/VGGFace.py new file mode 100644 index 0000000..06babf0 --- /dev/null +++ b/nnlib/VGGFace.py @@ -0,0 +1,64 @@ +from nnlib import nnlib + +def VGGFace(): + exec(nnlib.import_all(), locals(), globals()) + + img_input = Input(shape=(224,224,3) ) + + # Block 1 + x = Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_1')( + img_input) + x = Conv2D(64, (3, 3), activation='relu', padding='same', name='conv1_2')(x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='pool1')(x) + + # Block 2 + x = Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2_1')( + x) + x = Conv2D(128, (3, 3), activation='relu', padding='same', name='conv2_2')( + x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='pool2')(x) + + # Block 3 + x = Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3_1')( + x) + x = Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3_2')( + x) + x = Conv2D(256, (3, 3), activation='relu', padding='same', name='conv3_3')( + x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='pool3')(x) + + # Block 4 + x = Conv2D(512, (3, 3), activation='relu', padding='same', name='conv4_1')( + x) + x = Conv2D(512, (3, 3), activation='relu', padding='same', name='conv4_2')( + x) + x = Conv2D(512, (3, 3), activation='relu', padding='same', name='conv4_3')( + x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='pool4')(x) + + # Block 5 + x = Conv2D(512, (3, 3), activation='relu', padding='same', name='conv5_1')( + x) + x = Conv2D(512, (3, 3), activation='relu', padding='same', name='conv5_2')( + x) + x = Conv2D(512, (3, 3), activation='relu', padding='same', name='conv5_3')( + x) + x = MaxPooling2D((2, 2), strides=(2, 2), name='pool5')(x) + + + # Classification block + x = Flatten(name='flatten')(x) + x = Dense(4096, name='fc6')(x) + x = Activation('relu', name='fc6/relu')(x) + x = Dense(4096, name='fc7')(x) + x = Activation('relu', name='fc7/relu')(x) + x = Dense(2622, name='fc8')(x) + x = Activation('softmax', name='fc8/softmax')(x) + + model = Model(img_input, x, name='vggface_vgg16') + weights_path = keras.utils.data_utils.get_file('rcmalli_vggface_tf_vgg16.h5', + 'https://github.com/rcmalli/keras-vggface/releases/download/v2.0/rcmalli_vggface_tf_vgg16.h5') + + model.load_weights(weights_path, by_name=True) + + return model \ No newline at end of file diff --git a/nnlib/__init__.py b/nnlib/__init__.py index d0d6a59..e30936a 100644 --- a/nnlib/__init__.py +++ b/nnlib/__init__.py @@ -1,3 +1,4 @@ from .nnlib import nnlib from .FUNIT import FUNIT -from .TernausNet import TernausNet \ No newline at end of file +from .TernausNet import TernausNet +from .VGGFace import VGGFace \ No newline at end of file diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 0443d7b..d9143c1 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -63,6 +63,7 @@ UpSampling2D = KL.UpSampling2D BatchNormalization = KL.BatchNormalization PixelNormalization = nnlib.PixelNormalization +Activation = KL.Activation LeakyReLU = KL.LeakyReLU ELU = KL.ELU ReLU = KL.ReLU diff --git a/samplelib/SampleGeneratorFace.py b/samplelib/SampleGeneratorFace.py index 96bd2cc..46f6824 100644 --- a/samplelib/SampleGeneratorFace.py +++ b/samplelib/SampleGeneratorFace.py @@ -30,7 +30,7 @@ class SampleGeneratorFace(SampleGeneratorPingPong): person_id_mode=False, add_sample_idx=False, generators_count=2, - generators_random_seed=None, + generators_random_seed=None, ping_pong=PingPongOptions(), **kwargs): @@ -53,58 +53,56 @@ class SampleGeneratorFace(SampleGeneratorPingPong): self.generators_random_seed = generators_random_seed samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path, person_id_mode=person_id_mode) - self.total_samples_count = len(samples) + self.samples_len = len(samples) + + if self.samples_len == 0: + raise ValueError('No training data provided.') ct_samples = SampleLoader.load (SampleType.FACE, random_ct_samples_path) if random_ct_samples_path is not None else None self.random_ct_sample_chance = 100 if self.debug: self.generators_count = 1 - self.generators = [iter_utils.ThisThreadGenerator(self.batch_func, (0, samples, ct_samples))] + self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, (0, samples, ct_samples) )] else: - self.generators_count = min(generators_count, len(samples)) - self.generators = [ - iter_utils.SubprocessGenerator(self.batch_func, (i, samples[i::self.generators_count], ct_samples)) for - i in range(self.generators_count)] + self.generators_count = min ( generators_count, self.samples_len ) + self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, (i, samples[i::self.generators_count], ct_samples ) ) for i in range(self.generators_count) ] self.generator_counter = -1 #overridable def get_total_sample_count(self): - return self.total_samples_count + return self.samples_len def __iter__(self): return self def __next__(self): self.generator_counter += 1 - generator = self.generators[self.generator_counter % len(self.generators)] + generator = self.generators[self.generator_counter % len(self.generators) ] super().__next__() return next(generator) - def batch_func(self, param): + def batch_func(self, param ): generator_id, samples, ct_samples = param if self.generators_random_seed is not None: - np.random.seed(self.generators_random_seed[generator_id]) + np.random.seed ( self.generators_random_seed[generator_id] ) samples_len = len(samples) samples_idxs = [*range(samples_len)] ct_samples_len = len(ct_samples) if ct_samples is not None else 0 - if len(samples_idxs) == 0: - raise ValueError('No training data provided.') - if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: - if all([samples[idx] == None for idx in samples_idxs]): + if all ( [ samples[idx] == None for idx in samples_idxs] ): raise ValueError('Not enough training data. Gather more faces!') if self.sample_type == SampleType.FACE: shuffle_idxs = [] elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: shuffle_idxs = [] - shuffle_idxs_2D = [[]] * samples_len + shuffle_idxs_2D = [[]]*samples_len while True: batches = None @@ -118,7 +116,7 @@ class SampleGeneratorFace(SampleGeneratorPingPong): np.random.shuffle(shuffle_idxs) idx = shuffle_idxs.pop() - sample = samples[idx] + sample = samples[ idx ] elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET: if len(shuffle_idxs) == 0: @@ -128,8 +126,8 @@ class SampleGeneratorFace(SampleGeneratorPingPong): idx = shuffle_idxs.pop() if samples[idx] != None: if len(shuffle_idxs_2D[idx]) == 0: - a = shuffle_idxs_2D[idx] = [*range(len(samples[idx]))] - np.random.shuffle(a) + a = shuffle_idxs_2D[idx] = [ *range(len(samples[idx])) ] + np.random.shuffle (a) idx2 = shuffle_idxs_2D[idx].pop() sample = samples[idx][idx2] @@ -138,11 +136,11 @@ class SampleGeneratorFace(SampleGeneratorPingPong): if sample is not None: try: - ct_sample = None - if ct_samples is not None: + ct_sample=None + if ct_samples is not None: if np.random.randint(100) < self.random_ct_sample_chance: - ct_sample = ct_samples[np.random.randint(ct_samples_len)] - + ct_sample=ct_samples[np.random.randint(ct_samples_len)] + x = SampleProcessor.process(sample, self.sample_process_options, self.output_sample_types, self.debug, ct_sample=ct_sample) except: @@ -153,7 +151,7 @@ class SampleGeneratorFace(SampleGeneratorPingPong): raise Exception('SampleProcessor.process returns NOT tuple/list') if batches is None: - batches = [[] for _ in range(len(x))] + batches = [ [] for _ in range(len(x)) ] if self.add_sample_idx: batches += [ [] ] i_sample_idx = len(batches)-1 @@ -163,7 +161,7 @@ class SampleGeneratorFace(SampleGeneratorPingPong): i_person_id = len(batches)-1 for i in range(len(x)): - batches[i].append(x[i]) + batches[i].append ( x[i] ) if self.add_sample_idx: batches[i_sample_idx].append (idx) @@ -172,8 +170,8 @@ class SampleGeneratorFace(SampleGeneratorPingPong): batches[i_person_id].append ( np.array([sample.person_id]) ) break - yield [np.array(batch) for batch in batches] - + yield [ np.array(batch) for batch in batches] + def update_batch(self, batch_size): self.batch_size = batch_size