Mirror of https://github.com/iperov/DeepFaceLab.git
Synced 2025-07-06 13:02:15 -07:00
fix DFLJPG;
SAE: added "rare sample booster";
SAE: pixel loss option replaced with a smooth transition from DSSIM to pixel loss over the first 15k epochs by default
This commit is contained in:
parent f93b4713a9
commit 4d37fd62cd

11 changed files with 174 additions and 101 deletions
@@ -199,7 +199,7 @@ class ModelBase(object):
         pass

     #overridable
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generator_list):
         #train your keras models here

         #return array of losses
@@ -293,7 +293,8 @@ class ModelBase(object):
         images = []
         for generator in self.generator_list:
             for i,batch in enumerate(next(generator)):
-                images.append( batch[0] )
+                if len(batch.shape) == 4:
+                    images.append( batch[0] )

         return image_utils.equalize_and_stack_square (images)
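The new shape guard exists because generators may now emit a trailing 1-D array of sample indices (see add_sample_idx below); only 4-D image batches belong in the preview. The check in isolation, with made-up data:

import numpy as np

batch_images = np.zeros((4, 64, 64, 3))   # 4-D image batch: goes into the preview
batch_idxs   = np.arange(4)               # 1-D sample indices: skipped

for batch in (batch_images, batch_idxs):
    if len(batch.shape) == 4:
        pass  # images.append( batch[0] )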
@@ -305,14 +306,12 @@ class ModelBase(object):
             supressor = std_utils.suppress_stdout_stderr()
             supressor.__enter__()

-        self.last_sample = self.generate_next_sample()
-
-        epoch_time = time.time()
-
-        losses = self.onTrainOneEpoch(self.last_sample)
-
+        sample = self.generate_next_sample()
+        epoch_time = time.time()
+        losses = self.onTrainOneEpoch(sample, self.generator_list)
         epoch_time = time.time() - epoch_time
+        self.last_sample = sample

         self.loss_history.append ( [float(loss[1]) for loss in losses] )

         if self.supress_std_once:
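The signature change means every model override now also receives the generator list each epoch, so it can send feedback to the sample feeders. A minimal sketch of an override under the new signature (the model class, import path, and loss values here are illustrative, not from this commit):

from models import ModelBase  # repo import path, assumed here

class MyModel(ModelBase):
    #override
    def onTrainOneEpoch(self, sample, generator_list):
        # sample[0]/sample[1] are the src/dst batches from the feeders
        warped_src, target_src, target_src_mask = sample[0]
        warped_dst, target_dst, target_dst_mask = sample[1]

        src_loss, dst_loss = 0.0, 0.0  # train your keras models here

        # generator_list allows feedback to the feeders, e.g.
        # generator_list[0].repeat_sample_idxs([...]) as SAE does below.
        return ( ('src_loss', src_loss), ('dst_loss', dst_loss) )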
@@ -55,7 +55,7 @@ class Model(ModelBase):
                              [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] )

     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_mask = sample[0]
         warped_dst, target_dst, target_dst_mask = sample[1]

@@ -73,7 +73,7 @@ class Model(ModelBase):
                              [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]])

     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_mask = sample[0]
         warped_dst, target_dst, target_dst_mask = sample[1]

@@ -75,7 +75,7 @@ class Model(ModelBase):
                              [self.decoder_dst, self.get_strpath_storage_for_file(self.decoder_dstH5)]] )

     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_full_mask = sample[0]
         warped_dst, target_dst, target_dst_full_mask = sample[1]

@@ -64,7 +64,7 @@ class Model(ModelBase):
                              [self.inter_AB, self.get_strpath_storage_for_file(self.inter_ABH5)]] )

     #override
-    def onTrainOneEpoch(self, sample):
+    def onTrainOneEpoch(self, sample, generators_list):
         warped_src, target_src, target_src_mask = sample[0]
         warped_dst, target_dst, target_dst_mask = sample[1]

@@ -52,13 +52,7 @@ class SAEModel(ModelBase):
             self.options['bg_style_power'] = np.clip ( input_number("Background style power ( 0.0 .. 100.0 ?:help skip:%.2f) : " % (default_bg_style_power), default_bg_style_power, help_message="How fast NN will learn dst background style during generalization of src and dst faces. If style is learned good enough, set this value to 0.1-0.3 to prevent artifacts appearing."), 0.0, 100.0 )
         else:
             self.options['bg_style_power'] = self.options.get('bg_style_power', default_bg_style_power)

-        if is_first_run or ask_override:
-            default_pixel_loss = False if is_first_run else self.options.get('pixel_loss', False)
-            self.options['pixel_loss'] = input_bool ("Use pixel loss? (y/n, ?:help skip: n/default ) : ", default_pixel_loss, help_message="Default DSSIM loss good for initial understanding structure of faces. Use pixel loss after 30-40k epochs to enhance fine details.")
-        else:
-            self.options['pixel_loss'] = self.options.get('pixel_loss', False)
-
         default_ae_dims = 256 if self.options['archi'] == 'liae' else 512
         default_ed_ch_dims = 42
         if is_first_run:
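With the manual pixel_loss toggle removed, SAE now schedules the DSSIM-to-pixel-loss mix itself, as the commit message says: a smooth transition over the first 15k epochs. A minimal sketch of that schedule, assuming NumPy (the function name is illustrative):

import numpy as np

def dssim_pixel_mix(epoch, dssim_loss, pixel_loss, transition_epochs=15000):
    # alpha ramps linearly from 0 (pure DSSIM) to 1 (pure pixel loss)
    alpha = np.clip(epoch / float(transition_epochs), 0.0, 1.0)
    return dssim_loss * (1.0 - alpha) + pixel_loss * alpha

# At epoch 0 only DSSIM contributes; from 15k epochs on, only pixel loss does.
assert dssim_pixel_mix(0, 1.0, 0.0) == 1.0
assert dssim_pixel_mix(15000, 1.0, 0.0) == 0.0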
@@ -83,6 +77,7 @@ class SAEModel(ModelBase):
         bgr_shape = (resolution, resolution, 3)
         mask_shape = (resolution, resolution, 1)

+        dssim_pixel_alpha = Input( (1,) )
         warped_src = Input(bgr_shape)
         target_src = Input(bgr_shape)
         target_srcm = Input(mask_shape)
@@ -199,6 +194,7 @@ class SAEModel(ModelBase):
         def optimizer():
            return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)

+        dssim_pixel_alpha_value = dssim_pixel_alpha[0][0]

         if self.options['archi'] == 'liae':
             src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
@@ -208,29 +204,29 @@ class SAEModel(ModelBase):
             src_dst_loss_train_weights = self.encoder.trainable_weights + self.decoder_src.trainable_weights + self.decoder_dst.trainable_weights
             if self.options['learn_mask']:
                 src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

-        if self.options['pixel_loss']:
-            src_loss = sum([ K.mean( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] )) for i in range(len(target_src_masked_ar)) ])
-        else:
-            src_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+        src_dssim_loss_batch = sum([ ( 100*K.square(tf_dssim(2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+        src_pixel_loss_batch = sum([ tf_reduce_mean ( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar)) ])
+        src_loss_batch = src_dssim_loss_batch*(1.0-dssim_pixel_alpha_value) + src_pixel_loss_batch*dssim_pixel_alpha_value
+        src_loss = K.mean(src_loss_batch)

         if self.options['face_style_power'] != 0:
             face_style_power = self.options['face_style_power'] / 100.0
             src_loss += tf_style_loss(gaussian_blur_radius=resolution // 8, loss_weight=0.2*face_style_power)( psd_target_dst_masked_ar[-1], target_dst_masked_ar[-1] )

         if self.options['bg_style_power'] != 0:
-            bg_style_power = self.options['bg_style_power'] / 100.0
-            if self.options['pixel_loss']:
-                src_loss += K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
-            else:
-                src_loss += K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
-
-        if self.options['pixel_loss']:
-            dst_loss = sum([ K.mean( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] )) for i in range(len(target_dst_masked_ar)) ])
-        else:
-            dst_loss = sum([ K.mean( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
-
-        self.src_dst_train = K.function ([warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss,dst_loss], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
+            bg_style_power = self.options['bg_style_power'] / 100.0
+            bg_dssim_loss = K.mean( (100*bg_style_power)*K.square(tf_dssim(2.0)( psd_target_dst_anti_masked_ar[-1], target_dst_anti_masked_ar[-1] )))
+            bg_pixel_loss = K.mean( (100*bg_style_power)*K.square( psd_target_dst_anti_masked_ar[-1] - target_dst_anti_masked_ar[-1] ))
+            src_loss += bg_dssim_loss*(1.0-dssim_pixel_alpha_value) + bg_pixel_loss*dssim_pixel_alpha_value
+
+        dst_dssim_loss_batch = sum([ ( 100*K.square(tf_dssim(2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+        dst_pixel_loss_batch = sum([ tf_reduce_mean ( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar)) ])
+        dst_loss_batch = dst_dssim_loss_batch*(1.0-dssim_pixel_alpha_value) + dst_pixel_loss_batch*dssim_pixel_alpha_value
+        dst_loss = K.mean(dst_loss_batch)
+
+        self.src_dst_train = K.function ([dssim_pixel_alpha, warped_src, target_src, target_srcm, warped_dst, target_dst, target_dstm ],[src_loss,dst_loss,src_loss_batch,dst_loss_batch], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )

         if self.options['learn_mask']:
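Two things change in the graph above: the blend weight arrives as an extra Input tensor (dssim_pixel_alpha) fed at call time instead of being a build-time option, and losses are additionally kept per sample (src_loss_batch / dst_loss_batch) so the training loop can rank samples. A NumPy emulation of the per-batch blend, with made-up loss values:

import numpy as np

dssim_batch = np.array([0.9, 0.2, 0.7, 0.4])   # per-sample DSSIM losses (made up)
pixel_batch = np.array([0.5, 0.3, 0.8, 0.1])   # per-sample pixel losses (made up)
alpha = 0.3                                     # schedule value around epoch 4.5k

loss_batch = dssim_batch * (1.0 - alpha) + pixel_batch * alpha  # kept per sample
loss = loss_batch.mean()                                        # scalar for the optimizer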
@@ -250,6 +246,9 @@ class SAEModel(ModelBase):
             self.AE_convert = K.function ([warped_dst],[ pred_src_dst[-1] ])

         if self.is_training_mode:
+            self.src_sample_losses = []
+            self.dst_sample_losses = []
+
             f = SampleProcessor.TypeFlags
             face_type = f.FACE_ALIGN_FULL if self.options['face_type'] == 'f' else f.FACE_ALIGN_HALF
             self.set_training_data_generators ([
@@ -259,14 +258,14 @@ class SAEModel(ModelBase):
                         output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
-                                            ] ),
+                                            ], add_sample_idx=True ),

                 SampleGeneratorFace(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                         sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, normalize_tanh = True),
                         output_sample_types=[ [f.WARPED_TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_BGR, resolution],
                                               [f.TRANSFORMED | face_type | f.MODE_M | f.FACE_MASK_FULL, resolution]
-                                            ] )
+                                            ], add_sample_idx=True )
                 ])

     #override
     def onSave(self):
@@ -289,13 +288,39 @@ class SAEModel(ModelBase):

         self.save_weights_safe(ar)

     #override
-    def onTrainOneEpoch(self, sample):
-        warped_src, target_src, target_src_mask = sample[0]
-        warped_dst, target_dst, target_dst_mask = sample[1]
+    def onTrainOneEpoch(self, generators_samples, generators_list):
+        warped_src, target_src, target_src_mask, src_sample_idxs = generators_samples[0]
+        warped_dst, target_dst, target_dst_mask, dst_sample_idxs = generators_samples[1]

-        src_loss, dst_loss = self.src_dst_train ([warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
+        dssim_pixel_alpha = np.clip ( self.epoch / 15000.0, 0.0, 1.0 ) #smooth transition between DSSIM and MSE in 15k epochs
+        dssim_pixel_alpha = np.repeat( dssim_pixel_alpha, (self.batch_size,) )
+        dssim_pixel_alpha = np.expand_dims(dssim_pixel_alpha,-1)
+
+        src_loss, dst_loss, src_sample_losses, dst_sample_losses = self.src_dst_train ([dssim_pixel_alpha, warped_src, target_src, target_src_mask, warped_dst, target_dst, target_dst_mask])
+
+        #gathering array of sample_losses
+        self.src_sample_losses += [[src_sample_idxs[i], src_sample_losses[i]] for i in range(self.batch_size) ]
+        self.dst_sample_losses += [[dst_sample_idxs[i], dst_sample_losses[i]] for i in range(self.batch_size) ]
+
+        if len(self.src_sample_losses) >= 48: #array is big enough
+            #fetching idxs which losses are bigger than average
+            x = np.array (self.src_sample_losses)
+            self.src_sample_losses = []
+            b = x[:,1]
+            idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
+            generators_list[0].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs
+
+        if len(self.dst_sample_losses) >= 48: #array is big enough
+            #fetching idxs which losses are bigger than average
+            x = np.array (self.dst_sample_losses)
+            self.dst_sample_losses = []
+            b = x[:,1]
+            idxs = (x[:,0][ np.argwhere ( b [ b > np.mean(b) ] )[:,0] ]).astype(np.uint)
+            generators_list[1].repeat_sample_idxs(idxs) #ask generator to repeat these sample idxs

         if self.options['learn_mask']:
             src_mask_loss, dst_mask_loss, = self.src_dst_mask_train ([warped_src, target_src_mask, warped_dst, target_dst_mask])
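The "rare sample booster" accumulates (sample index, loss) pairs and, once 48 are collected, asks the generators to re-feed the samples whose loss exceeds the running average. The committed expression takes a detour through np.argwhere; the intended selection can be written directly as this standalone sketch (loss values are made up):

import numpy as np

# Accumulated [sample_idx, loss] pairs (illustrative values).
pairs = np.array([[ 7, 0.50],
                  [12, 1.80],
                  [ 3, 0.90],
                  [21, 2.10]])

idxs, losses = pairs[:, 0], pairs[:, 1]
hard_idxs = idxs[losses > losses.mean()].astype(np.uint)   # -> array([12, 21])

# generators_list[0].repeat_sample_idxs(hard_idxs) then queues them for repetition.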
@@ -453,6 +478,9 @@ class SAEModel(ModelBase):
         strides = resolution // 32 if adapt_k_size else 2
         lowest_dense_res = resolution // 16

+        def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
+            return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
+
         def downscale (dim):
             def func(x):
                 return LeakyReLU(0.1)(Conv2D(dim, k_size, strides=strides, padding='same')(x))
@@ -496,6 +524,10 @@ class SAEModel(ModelBase):
         exec (nnlib.import_all(), locals(), globals())
         ed_dims = output_nc * ed_ch_dims

+        def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
+            return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
+
         def upscale (dim):
             def func(x):
                 return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
@@ -68,7 +68,7 @@ class devicelib:

     @staticmethod
     def getDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb):
-        if not hasNVML and totalmemsize_gb <= 2:
+        if not hasNVML:
             return [0]

         result = []
@@ -52,6 +52,7 @@ class nnlib(object):
         tf = nnlib.tf
         tf_sess = nnlib.tf_sess

+        tf_reduce_mean = tf.reduce_mean # todo tf 12+ = tf.math.reduce_mean
         tf_total_variation = tf.image.total_variation
         tf_dssim = nnlib.tf_dssim
         tf_ssim = nnlib.tf_ssim
@@ -2,7 +2,7 @@ import traceback
 import numpy as np
 import random
 import cv2
-
+import multiprocessing
 from utils import iter_utils

 from samples import SampleType
@@ -18,10 +18,11 @@ output_sample_types = [
 ]
 '''
 class SampleGeneratorFace(SampleGeneratorBase):
-    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], generators_count=2, **kwargs):
+    def __init__ (self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None, with_close_to_self=False, sample_process_options=SampleProcessor.Options(), output_sample_types=[], add_sample_idx=False, generators_count=2, **kwargs):
         super().__init__(samples_path, debug, batch_size)
         self.sample_process_options = sample_process_options
         self.output_sample_types = output_sample_types
+        self.add_sample_idx = add_sample_idx

         if sort_by_yaw_target_samples_path is not None:
             self.sample_type = SampleType.FACE_YAW_SORTED_AS_TARGET
@@ -34,13 +35,15 @@ class SampleGeneratorFace(SampleGeneratorBase):

         self.samples = SampleLoader.load (self.sample_type, self.samples_path, sort_by_yaw_target_samples_path)

+        self.generators_count = min ( generators_count, len(self.samples) )
+
         if self.debug:
             self.generators_count = 1
             self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, 0 )]
         else:
-            self.generators_count = min ( generators_count, len(self.samples) )
             self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, i ) for i in range(self.generators_count) ]

+        self.generators_sq = [ multiprocessing.Queue() for _ in range(self.generators_count) ]
+
         self.generator_counter = -1

     def __iter__(self):
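Each worker now gets its own multiprocessing.Queue, which is how repeat_sample_idxs() below reaches subprocess generators: the trainer puts an index list into every queue, and each worker drains its queue before building the next batch. The pattern in isolation, assuming only the standard library:

import multiprocessing
import time

gen_sq = multiprocessing.Queue()
gen_sq.put([3, 14, 15])        # trainer side: ask for these samples again
time.sleep(0.1)                # in-process demo only: let the feeder thread flush

repeat_samples_idxs = []
while not gen_sq.empty():      # worker side: drain pending requests
    repeat_samples_idxs += gen_sq.get()

print(repeat_samples_idxs)     # -> [3, 14, 15]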
@@ -50,47 +53,73 @@ class SampleGeneratorFace(SampleGeneratorBase):
         self.generator_counter += 1
         generator = self.generators[self.generator_counter % len(self.generators) ]
         return next(generator)

-    def batch_func(self, generator_id):
-        samples = self.samples[generator_id::self.generators_count]
-
-        data_len = len(samples)
-        if data_len == 0:
+    def repeat_sample_idxs(self, idxs): # [ idx, ... ]
+        #send idxs list to all sub generators.
+        for gen_sq in self.generators_sq:
+            gen_sq.put (idxs)
+
+    def batch_func(self, generator_id):
+        gen_sq = self.generators_sq[generator_id]
+        samples = self.samples
+        samples_len = len(samples)
+        samples_idxs = [ *range(samples_len) ] [generator_id::self.generators_count]
+        repeat_samples_idxs = []
+
+        if len(samples_idxs) == 0:
             raise ValueError('No training data provided.')

         if self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
-            if all ( [ x == None for x in samples] ):
+            if all ( [ samples[idx] == None for idx in samples_idxs] ):
                 raise ValueError('Not enough training data. Gather more faces!')

-        if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
-            shuffle_idxs = []
-        elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
-            shuffle_idxs = []
-            shuffle_idxs_2D = [[]]*data_len
-
-        while True:
+        shuffle_idxs = []
+        shuffle_idxs_2D = [[]]*samples_len
+
+        while True:
+            while not gen_sq.empty():
+                idxs = gen_sq.get()
+                for idx in idxs:
+                    if idx in samples_idxs:
+                        repeat_samples_idxs.append(idx)
+
             batches = None
             for n_batch in range(self.batch_size):
-                sample = None
-
-                if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
-                    if len(shuffle_idxs) == 0:
-                        shuffle_idxs = random.sample( range(data_len), data_len )
-                    idx = shuffle_idxs.pop()
-                    sample = samples[ idx ]
-                elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
-                    if len(shuffle_idxs) == 0:
-                        shuffle_idxs = random.sample( range(data_len), data_len )
-
-                    idx = shuffle_idxs.pop()
-                    if samples[idx] != None:
-                        if len(shuffle_idxs_2D[idx]) == 0:
-                            shuffle_idxs_2D[idx] = random.sample( range(len(samples[idx])), len(samples[idx]) )
-
-                        idx2 = shuffle_idxs_2D[idx].pop()
-                        sample = samples[idx][idx2]
+                while True:
+                    sample = None
+
+                    if len(repeat_samples_idxs) > 0:
+                        idx = repeat_samples_idxs.pop()
+                        if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
+                            sample = samples[idx]
+                        elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
+                            sample = samples[(idx >> 16) & 0xFFFF][idx & 0xFFFF]
+                    else:
+                        if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
+                            if len(shuffle_idxs) == 0:
+                                shuffle_idxs = samples_idxs.copy()
+                                np.random.shuffle(shuffle_idxs)
+
+                            idx = shuffle_idxs.pop()
+                            sample = samples[ idx ]
+
+                        elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
+                            if len(shuffle_idxs) == 0:
+                                shuffle_idxs = samples_idxs.copy()
+                                np.random.shuffle(shuffle_idxs)
+
+                            idx = shuffle_idxs.pop()
+                            if samples[idx] != None:
+                                if len(shuffle_idxs_2D[idx]) == 0:
+                                    shuffle_idxs_2D[idx] = random.sample( range(len(samples[idx])), len(samples[idx]) )
+
+                                idx2 = shuffle_idxs_2D[idx].pop()
+                                sample = samples[idx][idx2]
+                                idx = (idx << 16) | (idx2 & 0xFFFF)

                     if sample is not None:
                         try:
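For yaw-sorted sample types a sample lives at samples[idx][idx2], so the generator packs both coordinates into the single integer it reports (and later receives back through the repeat queue): the cell index in the high 16 bits, the index within the cell in the low 16 bits, which caps both at 65535. The packing in isolation:

idx, idx2 = 5, 137
packed = (idx << 16) | (idx2 & 0xFFFF)

cell   = (packed >> 16) & 0xFFFF   # -> 5
within = packed & 0xFFFF           # -> 137
assert (cell, within) == (idx, idx2)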
@@ -103,10 +132,14 @@ class SampleGeneratorFace(SampleGeneratorBase):

                     if batches is None:
                         batches = [ [] for _ in range(len(x)) ]
+                        if self.add_sample_idx:
+                            batches += [ [] ]

                     for i in range(len(x)):
                         batches[i].append ( x[i] )

+                    if self.add_sample_idx:
+                        batches[-1].append (idx)
+
                     break

             yield [ np.array(batch) for batch in batches]
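With add_sample_idx=True a batch therefore carries one extra trailing array holding the source index of each item, which is what SAE unpacks as src_sample_idxs / dst_sample_idxs above. A mock of the yielded layout (shapes assume resolution=64 and batch_size=4; the helper name is illustrative):

import numpy as np

def mock_face_batches(batch_size=4, resolution=64, add_sample_idx=True):
    batches = [ np.zeros((batch_size, resolution, resolution, 3)),  # warped
                np.zeros((batch_size, resolution, resolution, 3)),  # target
                np.zeros((batch_size, resolution, resolution, 1)) ] # mask
    if add_sample_idx:
        batches += [ np.arange(batch_size) ]                        # sample idxs
    while True:
        yield batches

warped, target, mask, sample_idxs = next(mock_face_batches())
assert sample_idxs.shape == (4,)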
@@ -1,3 +1,4 @@
+import traceback
 from enum import IntEnum
 import cv2
 import numpy as np
@@ -56,22 +57,25 @@ class SampleLoader:

             for s in tqdm( samples, desc="Loading", ascii=True ):
                 s_filename_path = Path(s.filename)
-                if s_filename_path.suffix == '.png':
-                    dflimg = DFLPNG.load ( str(s_filename_path), print_on_no_embedded_data=True )
-                    if dflimg is None: continue
-                elif s_filename_path.suffix == '.jpg':
-                    dflimg = DFLJPG.load ( str(s_filename_path), print_on_no_embedded_data=True )
-                    if dflimg is None: continue
-                else:
-                    print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
-                    continue
-
-                sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
-                                                   face_type=FaceType.fromString (dflimg.get_face_type()),
-                                                   shape=dflimg.get_shape(),
-                                                   landmarks=dflimg.get_landmarks(),
-                                                   yaw=dflimg.get_yaw_value()) )
+                try:
+                    if s_filename_path.suffix == '.png':
+                        dflimg = DFLPNG.load ( str(s_filename_path), print_on_no_embedded_data=True )
+                        if dflimg is None: continue
+                    elif s_filename_path.suffix == '.jpg':
+                        dflimg = DFLJPG.load ( str(s_filename_path), print_on_no_embedded_data=True )
+                        if dflimg is None: continue
+                    else:
+                        print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
+                        continue
+
+                    sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
+                                                       face_type=FaceType.fromString (dflimg.get_face_type()),
+                                                       shape=dflimg.get_shape(),
+                                                       landmarks=dflimg.get_landmarks(),
+                                                       yaw=dflimg.get_yaw_value()) )
+                except:
+                    print ("Unable to load %s , error: %s" % (str(s_filename_path), traceback.format_exc() ) )

             return sample_list

     @staticmethod
@@ -120,7 +124,7 @@ class SampleLoader:
             yaw = lowest_yaw + i*diff_rot_per_grad
             next_yaw = lowest_yaw + (i+1)*diff_rot_per_grad

-                yaw_samples = []
+            yaw_samples = []
             for s in samples:
                 s_yaw = s.yaw
                 if (i == 0 and s_yaw < next_yaw) or \
@@ -10,6 +10,7 @@ class DFLJPG(object):
         self.length = 0
         self.chunks = []
         self.dfl_dict = None
+        self.shape = (0,0,0)

     @staticmethod
     def load_raw(filename):
@@ -123,13 +124,16 @@ class DFLJPG(object):

                 if id == b"JFIF":
                     c, ver_major, ver_minor, units, Xdensity, Ydensity, Xthumbnail, Ythumbnail = struct_unpack (d, c, "=BBBHHBB")
-                    if units != 0:
-                        raise Exception("JPG must be in pixel units.")
-                    inst.shape = (Ydensity, Xdensity, 3)
+                    #if units == 0:
+                    #    inst.shape = (Ydensity, Xdensity, 3)
                 else:
                     raise Exception("Unknown jpeg ID: %s" % (id) )
-
-            if chunk['name'] == 'APP15':
+            elif chunk['name'] == 'SOF0' or chunk['name'] == 'SOF2':
+                d, c = chunk['data'], 0
+                c, precision, height, width = struct_unpack (d, c, ">BHH")
+                inst.shape = (height, width, 3)
+
+            elif chunk['name'] == 'APP15':
                 if type(chunk['data']) == bytes:
                     inst.dfl_dict = pickle.loads(chunk['data'])
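This is the DFLJPG fix from the commit message: JFIF's Xdensity/Ydensity are resolution hints (dots per inch or per cm), not pixel dimensions, so deriving the image shape from them was wrong. The real dimensions live in the SOF0/SOF2 frame header as big-endian precision/height/width, matching the ">BHH" unpack above. A standalone sketch of that parse (the function name is illustrative):

import struct

def shape_from_sof_payload(payload):
    # SOF0/SOF2 payload begins: precision:uint8, height:uint16, width:uint16 (big-endian)
    precision, height, width = struct.unpack_from(">BHH", payload, 0)
    return (height, width, 3)

# Synthetic payload for an 8-bit 240x320 frame:
assert shape_from_sof_payload(struct.pack(">BHH", 8, 240, 320)) == (240, 320, 3)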