This commit is contained in:
iperov 2018-06-04 17:12:43 +04:00
parent 73de93b4f1
commit 6bd5a44264
71 changed files with 8448 additions and 0 deletions

View file

@ -0,0 +1,245 @@
import traceback
import random
from pathlib import Path
from tqdm import tqdm
import numpy as np
import cv2
from utils.AlignedPNG import AlignedPNG
from utils import iter_utils
from utils import Path_utils
from .BaseTypes import TrainingDataType
from .BaseTypes import TrainingDataSample
from facelib import FaceType
from facelib import LandmarksProcessor
'''
You can implement your own TrainingDataGenerator
'''
class TrainingDataGeneratorBase(object):
cache = dict()
#DONT OVERRIDE
#use YourOwnTrainingDataGenerator (..., your_opt=1)
#and then this opt will be passed in YourOwnTrainingDataGenerator.onInitialize ( your_opt )
def __init__ (self, trainingdatatype, training_data_path, target_training_data_path=None, debug=False, batch_size=1, **kwargs):
if not isinstance(trainingdatatype, TrainingDataType):
raise Exception('TrainingDataGeneratorBase() trainingdatatype is not TrainingDataType')
if training_data_path is None:
raise Exception('training_data_path is None')
self.training_data_path = Path(training_data_path)
self.target_training_data_path = Path(target_training_data_path) if target_training_data_path is not None else None
self.debug = debug
self.batch_size = 1 if self.debug else batch_size
self.trainingdatatype = trainingdatatype
self.data = TrainingDataGeneratorBase.load (trainingdatatype, self.training_data_path, self.target_training_data_path)
if self.debug:
self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, self.data)]
else:
if len(self.data) > 1:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, self.data[0::2] ),
iter_utils.SubprocessGenerator ( self.batch_func, self.data[1::2] )]
else:
self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, self.data )]
self.generator_counter = -1
self.onInitialize(**kwargs)
#overridable
def onInitialize(self, **kwargs):
#your TrainingDataGenerator initialization here
pass
#overridable
def onProcessSample(self, sample, debug):
#process sample and return tuple of images for your model in onTrainOneEpoch
return ( np.zeros( (64,64,4), dtype=np.float32 ), )
def __iter__(self):
return self
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators) ]
x = next(generator)
return x
def batch_func(self, data):
data_len = len(data)
if data_len == 0:
raise ValueError('No training data provided.')
if self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED or self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
if all ( [ x == None for x in data] ):
raise ValueError('Not enough training data. Gather more faces!')
if self.trainingdatatype == TrainingDataType.IMAGE or self.trainingdatatype == TrainingDataType.FACE:
shuffle_idxs = []
elif self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED or self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
shuffle_idxs = []
shuffle_idxs_2D = [[]]*data_len
while True:
batches = None
for n_batch in range(0, self.batch_size):
while True:
sample = None
if self.trainingdatatype == TrainingDataType.IMAGE or self.trainingdatatype == TrainingDataType.FACE:
if len(shuffle_idxs) == 0:
shuffle_idxs = [ i for i in range(0, data_len) ]
random.shuffle(shuffle_idxs)
idx = shuffle_idxs.pop()
sample = data[ idx ]
elif self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED or self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
if len(shuffle_idxs) == 0:
shuffle_idxs = [ i for i in range(0, data_len) ]
random.shuffle(shuffle_idxs)
idx = shuffle_idxs.pop()
if data[idx] != None:
if len(shuffle_idxs_2D[idx]) == 0:
shuffle_idxs_2D[idx] = [ i for i in range(0, len(data[idx])) ]
random.shuffle(shuffle_idxs_2D[idx])
idx2 = shuffle_idxs_2D[idx].pop()
sample = data[idx][idx2]
if sample is not None:
try:
x = self.onProcessSample (sample, self.debug)
except:
raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )
if type(x) != tuple and type(x) != list:
raise Exception('TrainingDataGenerator.onProcessSample() returns NOT tuple/list')
x_len = len(x)
if batches is None:
batches = [ [] for _ in range(0,x_len) ]
for i in range(0,x_len):
batches[i].append ( x[i] )
break
yield [ np.array(batch) for batch in batches]
def get_dict_state(self):
return {}
def set_dict_state(self, state):
pass
@staticmethod
def load(trainingdatatype, training_data_path, target_training_data_path=None):
cache = TrainingDataGeneratorBase.cache
if str(training_data_path) not in cache.keys():
cache[str(training_data_path)] = [None]*TrainingDataType.QTY
if target_training_data_path is not None and str(target_training_data_path) not in cache.keys():
cache[str(target_training_data_path)] = [None]*TrainingDataType.QTY
datas = cache[str(training_data_path)]
if trainingdatatype == TrainingDataType.IMAGE:
if datas[trainingdatatype] is None:
datas[trainingdatatype] = [ TrainingDataSample(filename=filename) for filename in tqdm( Path_utils.get_image_paths(training_data_path), desc="Loading" ) ]
elif trainingdatatype == TrainingDataType.FACE:
if datas[trainingdatatype] is None:
datas[trainingdatatype] = X_LOAD( [ TrainingDataSample(filename=filename) for filename in Path_utils.get_image_paths(training_data_path) ] )
elif trainingdatatype == TrainingDataType.FACE_YAW_SORTED:
if datas[trainingdatatype] is None:
datas[trainingdatatype] = X_YAW_SORTED( TrainingDataGeneratorBase.load(TrainingDataType.FACE, training_data_path) )
elif trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
if datas[trainingdatatype] is None:
if target_training_data_path is None:
raise Exception('target_training_data_path is None for FACE_YAW_SORTED_AS_TARGET')
datas[trainingdatatype] = X_YAW_AS_Y_SORTED( TrainingDataGeneratorBase.load(TrainingDataType.FACE_YAW_SORTED, training_data_path), TrainingDataGeneratorBase.load(TrainingDataType.FACE_YAW_SORTED, target_training_data_path) )
return datas[trainingdatatype]
def X_LOAD ( RAWS ):
sample_list = []
for s in tqdm( RAWS, desc="Loading" ):
s_filename_path = Path(s.filename)
if s_filename_path.suffix != '.png':
print ("%s is not a png file required for training" % (s_filename_path.name) )
continue
a_png = AlignedPNG.load ( str(s_filename_path) )
if a_png is None:
print ("%s failed to load" % (s_filename_path.name) )
continue
d = a_png.getFaceswapDictData()
if d is None or d['landmarks'] is None or d['yaw_value'] is None:
print ("%s - no embedded faceswap info found required for training" % (s_filename_path.name) )
continue
face_type = d['face_type'] if 'face_type' in d.keys() else 'full_face'
face_type = FaceType.fromString (face_type)
sample_list.append( s.copy_and_set(face_type=face_type, shape=a_png.get_shape(), landmarks=d['landmarks'], yaw=d['yaw_value']) )
return sample_list
def X_YAW_SORTED( YAW_RAWS ):
lowest_yaw, highest_yaw = -32, +32
gradations = 64
diff_rot_per_grad = abs(highest_yaw-lowest_yaw) / gradations
yaws_sample_list = [None]*gradations
for i in tqdm( range(0, gradations), desc="Sorting" ):
yaw = lowest_yaw + i*diff_rot_per_grad
next_yaw = lowest_yaw + (i+1)*diff_rot_per_grad
yaw_samples = []
for s in YAW_RAWS:
s_yaw = s.yaw
if (i == 0 and s_yaw < next_yaw) or \
(i < gradations-1 and s_yaw >= yaw and s_yaw < next_yaw) or \
(i == gradations-1 and s_yaw >= yaw):
yaw_samples.append ( s )
if len(yaw_samples) > 0:
yaws_sample_list[i] = yaw_samples
return yaws_sample_list
def X_YAW_AS_Y_SORTED (s, t):
l = len(s)
if l != len(t):
raise Exception('X_YAW_AS_Y_SORTED() s_len != t_len')
b = l // 2
s_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in s] ) == 1 )[:,0]
t_idxs = np.argwhere ( np.array ( [ 1 if x != None else 0 for x in t] ) == 1 )[:,0]
new_s = [None]*l
for t_idx in t_idxs:
search_idxs = []
for i in range(0,l):
search_idxs += [t_idx - i, (l-t_idx-1) - i, t_idx + i, (l-t_idx-1) + i]
for search_idx in search_idxs:
if search_idx in s_idxs:
mirrored = ( t_idx != search_idx and ((t_idx < b and search_idx >= b) or (search_idx < b and t_idx >= b)) )
new_s[t_idx] = [ sample.copy_and_set(mirror=True, yaw=-sample.yaw, landmarks=LandmarksProcessor.mirror_landmarks (sample.landmarks, sample.shape[1] ))
for sample in s[search_idx]
] if mirrored else s[search_idx]
break
return new_s