mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
initial
This commit is contained in:
parent
73de93b4f1
commit
6bd5a44264
71 changed files with 8448 additions and 0 deletions
245
models/TrainingDataGeneratorBase.py
Normal file
245
models/TrainingDataGeneratorBase.py
Normal file
|
@ -0,0 +1,245 @@
|
|||
import traceback
|
||||
import random
|
||||
from pathlib import Path
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
import cv2
|
||||
from utils.AlignedPNG import AlignedPNG
|
||||
from utils import iter_utils
|
||||
from utils import Path_utils
|
||||
from .BaseTypes import TrainingDataType
|
||||
from .BaseTypes import TrainingDataSample
|
||||
from facelib import FaceType
|
||||
from facelib import LandmarksProcessor
|
||||
|
||||
'''
|
||||
You can implement your own TrainingDataGenerator
|
||||
'''
|
||||
class TrainingDataGeneratorBase(object):
    """
    Infinite batch generator of training samples.

    Subclass it and override onInitialize() / onProcessSample() to produce
    model-specific training data; iterate the instance to get batches
    (a list of numpy arrays, one per element of the tuple returned by
    onProcessSample, stacked along a new batch axis).
    """
    # Process-wide cache of loaded sample lists, keyed by str(path);
    # one slot per TrainingDataType value so repeated generators reuse data.
    cache = dict()

    #DONT OVERRIDE
    #use YourOwnTrainingDataGenerator (..., your_opt=1)
    #and then this opt will be passed in YourOwnTrainingDataGenerator.onInitialize ( your_opt )
    def __init__ (self, trainingdatatype, training_data_path, target_training_data_path=None, debug=False, batch_size=1, **kwargs):
        """
        trainingdatatype          - a TrainingDataType value selecting what to load.
        training_data_path        - directory with training images.
        target_training_data_path - second directory, required only for FACE_YAW_SORTED_AS_TARGET.
        debug                     - run generation on the calling thread, batch_size forced to 1.
        batch_size                - samples per yielded batch.
        Remaining kwargs are forwarded to the subclass' onInitialize().
        """
        if not isinstance(trainingdatatype, TrainingDataType):
            raise Exception('TrainingDataGeneratorBase() trainingdatatype is not TrainingDataType')

        if training_data_path is None:
            raise Exception('training_data_path is None')

        self.training_data_path = Path(training_data_path)
        self.target_training_data_path = Path(target_training_data_path) if target_training_data_path is not None else None

        self.debug = debug
        # debugging forces batch_size 1 so single samples can be inspected
        self.batch_size = 1 if self.debug else batch_size
        self.trainingdatatype = trainingdatatype
        self.data = TrainingDataGeneratorBase.load (trainingdatatype, self.training_data_path, self.target_training_data_path)

        if self.debug:
            # run on this thread so prints/breakpoints inside batch_func work
            self.generators = [iter_utils.ThisThreadGenerator ( self.batch_func, self.data)]
        else:
            # split the data over two subprocess generators when there is enough of it
            if len(self.data) > 1:
                self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, self.data[0::2] ),
                                   iter_utils.SubprocessGenerator ( self.batch_func, self.data[1::2] )]
            else:
                self.generators = [iter_utils.SubprocessGenerator ( self.batch_func, self.data )]

        self.generator_counter = -1
        self.onInitialize(**kwargs)

    #overridable
    def onInitialize(self, **kwargs):
        #your TrainingDataGenerator initialization here
        pass

    #overridable
    def onProcessSample(self, sample, debug):
        #process sample and return tuple of images for your model in onTrainOneEpoch
        return ( np.zeros( (64,64,4), dtype=np.float32 ), )

    def __iter__(self):
        return self

    def __next__(self):
        # round-robin between the underlying generators
        self.generator_counter += 1
        generator = self.generators[self.generator_counter % len(self.generators) ]
        x = next(generator)
        return x

    def batch_func(self, data):
        """
        Infinite generator: yields [np.array(batch_0), np.array(batch_1), ...],
        one array per element of the tuple returned by onProcessSample().
        For yaw-sorted types, 'data' is a list of per-yaw-bin lists (bins may be None).
        """
        data_len = len(data)
        if data_len == 0:
            raise ValueError('No training data provided.')

        if self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED or self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
            # was 'x == None' — identity comparison is the correct form
            if all ( [ x is None for x in data] ):
                raise ValueError('Not enough training data. Gather more faces!')

        if self.trainingdatatype == TrainingDataType.IMAGE or self.trainingdatatype == TrainingDataType.FACE:
            shuffle_idxs = []
        elif self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED or self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
            shuffle_idxs = []
            # independent list per yaw bin; '[[]]*data_len' aliased ONE shared list
            shuffle_idxs_2D = [ [] for _ in range(data_len) ]

        while True:
            batches = None
            for n_batch in range(0, self.batch_size):
                while True:
                    sample = None

                    if self.trainingdatatype == TrainingDataType.IMAGE or self.trainingdatatype == TrainingDataType.FACE:
                        # reshuffle once the current pass over all indexes is exhausted
                        if len(shuffle_idxs) == 0:
                            shuffle_idxs = [ i for i in range(0, data_len) ]
                            random.shuffle(shuffle_idxs)
                        idx = shuffle_idxs.pop()
                        sample = data[ idx ]
                    elif self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED or self.trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
                        if len(shuffle_idxs) == 0:
                            shuffle_idxs = [ i for i in range(0, data_len) ]
                            random.shuffle(shuffle_idxs)

                        idx = shuffle_idxs.pop()
                        if data[idx] is not None:
                            # per-bin reshuffle of the samples inside this yaw bin
                            if len(shuffle_idxs_2D[idx]) == 0:
                                shuffle_idxs_2D[idx] = [ i for i in range(0, len(data[idx])) ]
                                random.shuffle(shuffle_idxs_2D[idx])

                            idx2 = shuffle_idxs_2D[idx].pop()
                            sample = data[idx][idx2]

                    if sample is not None:
                        try:
                            x = self.onProcessSample (sample, self.debug)
                        except Exception:
                            # was a bare 'except:', which also swallowed KeyboardInterrupt/SystemExit
                            raise Exception ("Exception occured in sample %s. Error: %s" % (sample.filename, traceback.format_exc() ) )

                        if not isinstance(x, (tuple, list)):
                            raise Exception('TrainingDataGenerator.onProcessSample() returns NOT tuple/list')

                        x_len = len(x)
                        if batches is None:
                            batches = [ [] for _ in range(0,x_len) ]

                        for i in range(0,x_len):
                            batches[i].append ( x[i] )

                        break

            yield [ np.array(batch) for batch in batches]

    #overridable
    def get_dict_state(self):
        # subclass may return picklable state to persist alongside the model
        return {}

    #overridable
    def set_dict_state(self, state):
        # subclass may restore state previously produced by get_dict_state()
        pass

    @staticmethod
    def load(trainingdatatype, training_data_path, target_training_data_path=None):
        """
        Load (or fetch from the class-level cache) the sample list for the
        given TrainingDataType. FACE_YAW_SORTED_AS_TARGET additionally
        requires target_training_data_path.
        """
        cache = TrainingDataGeneratorBase.cache

        if str(training_data_path) not in cache:
            cache[str(training_data_path)] = [None]*TrainingDataType.QTY

        if target_training_data_path is not None and str(target_training_data_path) not in cache:
            cache[str(target_training_data_path)] = [None]*TrainingDataType.QTY

        datas = cache[str(training_data_path)]

        if trainingdatatype == TrainingDataType.IMAGE:
            if datas[trainingdatatype] is None:
                datas[trainingdatatype] = [ TrainingDataSample(filename=filename) for filename in tqdm( Path_utils.get_image_paths(training_data_path), desc="Loading" ) ]

        elif trainingdatatype == TrainingDataType.FACE:
            if datas[trainingdatatype] is None:
                datas[trainingdatatype] = X_LOAD( [ TrainingDataSample(filename=filename) for filename in Path_utils.get_image_paths(training_data_path) ] )

        elif trainingdatatype == TrainingDataType.FACE_YAW_SORTED:
            if datas[trainingdatatype] is None:
                datas[trainingdatatype] = X_YAW_SORTED( TrainingDataGeneratorBase.load(TrainingDataType.FACE, training_data_path) )

        elif trainingdatatype == TrainingDataType.FACE_YAW_SORTED_AS_TARGET:
            if datas[trainingdatatype] is None:
                if target_training_data_path is None:
                    raise Exception('target_training_data_path is None for FACE_YAW_SORTED_AS_TARGET')
                datas[trainingdatatype] = X_YAW_AS_Y_SORTED( TrainingDataGeneratorBase.load(TrainingDataType.FACE_YAW_SORTED, training_data_path), TrainingDataGeneratorBase.load(TrainingDataType.FACE_YAW_SORTED, target_training_data_path) )

        return datas[trainingdatatype]
|
||||
|
||||
def X_LOAD ( RAWS ):
    """
    Load embedded face metadata for a list of raw image samples.

    RAWS - list of TrainingDataSample with .filename set.
    Returns a new list of sample copies enriched with face_type, shape,
    landmarks and yaw. Files that are not aligned PNGs with embedded
    faceswap data are reported to stdout and skipped.
    """
    sample_list = []

    for s in tqdm( RAWS, desc="Loading" ):
        s_filename_path = Path(s.filename)
        if s_filename_path.suffix != '.png':
            print ("%s is not a png file required for training" % (s_filename_path.name) )
            continue

        a_png = AlignedPNG.load ( str(s_filename_path) )
        if a_png is None:
            print ("%s failed to load" % (s_filename_path.name) )
            continue

        d = a_png.getFaceswapDictData()
        # .get() so a missing key is treated like a None value instead of raising KeyError
        if d is None or d.get('landmarks') is None or d.get('yaw_value') is None:
            print ("%s - no embedded faceswap info found required for training" % (s_filename_path.name) )
            continue

        # older files may not carry a face_type; default matches previous behavior
        face_type = FaceType.fromString ( d.get('face_type', 'full_face') )
        sample_list.append( s.copy_and_set(face_type=face_type,
                                           shape=a_png.get_shape(),
                                           landmarks=d['landmarks'],
                                           yaw=d['yaw_value']) )

    return sample_list
|
||||
|
||||
def X_YAW_SORTED( YAW_RAWS ):
    """
    Partition face samples into 64 fixed-width yaw bins spanning [-32, +32).

    Returns a list of length 64 where each element is either a list of the
    samples whose .yaw falls in that bin, or None when the bin is empty.
    The first and last bins also absorb samples outside the nominal range.
    """
    lowest_yaw, highest_yaw = -32, +32
    gradations = 64
    bin_width = abs(highest_yaw - lowest_yaw) / gradations

    yaws_sample_list = [None] * gradations

    for bin_idx in tqdm( range(gradations), desc="Sorting" ):
        bin_lo = lowest_yaw + bin_idx * bin_width
        bin_hi = lowest_yaw + (bin_idx + 1) * bin_width

        if bin_idx == 0:
            # lowest bin: everything below its upper edge, however far below
            bucket = [s for s in YAW_RAWS if s.yaw < bin_hi]
        elif bin_idx == gradations - 1:
            # highest bin: everything at or above its lower edge
            bucket = [s for s in YAW_RAWS if s.yaw >= bin_lo]
        else:
            bucket = [s for s in YAW_RAWS if bin_lo <= s.yaw < bin_hi]

        if len(bucket) > 0:
            yaws_sample_list[bin_idx] = bucket

    return yaws_sample_list
|
||||
|
||||
def X_YAW_AS_Y_SORTED (s, t):
    """
    Align a source yaw-sorted list 's' to a target yaw-sorted list 't'.

    Both are equal-length lists of per-yaw-bin sample lists (None = empty bin).
    For every non-empty target bin, the nearest non-empty source bin is found
    (searching outward from both the bin index and its mirrored index); when
    the chosen source bin lies in the opposite yaw half, its samples are
    horizontally mirrored. Returns the new source list; bins with no target
    counterpart stay None.

    Raises Exception when len(s) != len(t).
    """
    l = len(s)
    if l != len(t):
        raise Exception('X_YAW_AS_Y_SORTED() s_len != t_len')
    b = l // 2      # boundary between the two yaw halves

    # set for O(1) membership (was an O(n) scan of a numpy index array
    # per candidate); 'is not None' replaces '!= None'
    s_idxs = { i for i, x in enumerate(s) if x is not None }
    t_idxs = [ i for i, x in enumerate(t) if x is not None ]

    new_s = [None]*l

    for t_idx in t_idxs:
        # candidate indexes, interleaved: outward from t_idx and from its mirror
        search_idxs = []
        for i in range(0,l):
            search_idxs += [t_idx - i, (l-t_idx-1) - i, t_idx + i, (l-t_idx-1) + i]

        for search_idx in search_idxs:
            if search_idx in s_idxs:
                # borrowed from the opposite half -> samples must be flipped
                mirrored = ( t_idx != search_idx and ((t_idx < b and search_idx >= b) or (search_idx < b and t_idx >= b)) )
                if mirrored:
                    new_s[t_idx] = [ sample.copy_and_set(mirror=True, yaw=-sample.yaw,
                                                         landmarks=LandmarksProcessor.mirror_landmarks (sample.landmarks, sample.shape[1] ))
                                     for sample in s[search_idx] ]
                else:
                    new_s[t_idx] = s[search_idx]
                break

    return new_s
|
Loading…
Add table
Add a link
Reference in a new issue