mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
initial code to extract umdfaces.io dataset and train pose estimator
This commit is contained in:
parent
51a917facc
commit
e58197ca22
18 changed files with 437 additions and 57 deletions
|
@ -35,9 +35,9 @@ class SampleLoader:
|
|||
if datas[sample_type] is None:
|
||||
datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
|
||||
|
||||
# elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||
# if datas[sample_type] is None:
|
||||
# datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
|
||||
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
|
||||
if datas[sample_type] is None:
|
||||
datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
|
||||
|
||||
elif sample_type == SampleType.FACE_YAW_SORTED:
|
||||
if datas[sample_type] is None:
|
||||
|
@ -69,15 +69,12 @@ class SampleLoader:
|
|||
print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
|
||||
continue
|
||||
|
||||
pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )
|
||||
|
||||
sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
|
||||
face_type=FaceType.fromString (dflimg.get_face_type()),
|
||||
shape=dflimg.get_shape(),
|
||||
landmarks=dflimg.get_landmarks(),
|
||||
ie_polys=dflimg.get_ie_polys(),
|
||||
pitch=pitch,
|
||||
yaw=yaw,
|
||||
pitch_yaw_roll=dflimg.get_pitch_yaw_roll(),
|
||||
source_filename=dflimg.get_source_filename(),
|
||||
fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) )
|
||||
except:
|
||||
|
@ -85,12 +82,12 @@ class SampleLoader:
|
|||
|
||||
return sample_list
|
||||
|
||||
# @staticmethod
|
||||
# def upgradeToFaceTemporalSortedSamples( samples ):
|
||||
# new_s = [ (s, s.source_filename) for s in samples]
|
||||
# new_s = sorted(new_s, key=operator.itemgetter(1))
|
||||
@staticmethod
|
||||
def upgradeToFaceTemporalSortedSamples( samples ):
|
||||
new_s = [ (s, s.source_filename) for s in samples]
|
||||
new_s = sorted(new_s, key=operator.itemgetter(1))
|
||||
|
||||
# return [ s[0] for s in new_s]
|
||||
return [ s[0] for s in new_s]
|
||||
|
||||
@staticmethod
|
||||
def upgradeToFaceYawSortedSamples( samples ):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue