initial code to extract umdfaces.io dataset and train pose estimator

This commit is contained in:
iperov 2019-04-23 08:14:09 +04:00
parent 51a917facc
commit e58197ca22
18 changed files with 437 additions and 57 deletions

View file

@ -35,9 +35,9 @@ class SampleLoader:
if datas[sample_type] is None:
datas[sample_type] = SampleLoader.upgradeToFaceSamples( [ Sample(filename=filename) for filename in Path_utils.get_image_paths(samples_path) ] )
# elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
# if datas[sample_type] is None:
# datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
elif sample_type == SampleType.FACE_TEMPORAL_SORTED:
if datas[sample_type] is None:
datas[sample_type] = SampleLoader.upgradeToFaceTemporalSortedSamples( SampleLoader.load(SampleType.FACE, samples_path) )
elif sample_type == SampleType.FACE_YAW_SORTED:
if datas[sample_type] is None:
@ -69,15 +69,12 @@ class SampleLoader:
print ("%s is not a dfl image file required for training" % (s_filename_path.name) )
continue
pitch, yaw = LandmarksProcessor.estimate_pitch_yaw ( dflimg.get_landmarks() )
sample_list.append( s.copy_and_set(sample_type=SampleType.FACE,
face_type=FaceType.fromString (dflimg.get_face_type()),
shape=dflimg.get_shape(),
landmarks=dflimg.get_landmarks(),
ie_polys=dflimg.get_ie_polys(),
pitch=pitch,
yaw=yaw,
pitch_yaw_roll=dflimg.get_pitch_yaw_roll(),
source_filename=dflimg.get_source_filename(),
fanseg_mask_exist=dflimg.get_fanseg_mask() is not None, ) )
except:
@ -85,12 +82,12 @@ class SampleLoader:
return sample_list
# @staticmethod
# def upgradeToFaceTemporalSortedSamples( samples ):
# new_s = [ (s, s.source_filename) for s in samples]
# new_s = sorted(new_s, key=operator.itemgetter(1))
@staticmethod
def upgradeToFaceTemporalSortedSamples( samples ):
new_s = [ (s, s.source_filename) for s in samples]
new_s = sorted(new_s, key=operator.itemgetter(1))
# return [ s[0] for s in new_s]
return [ s[0] for s in new_s]
@staticmethod
def upgradeToFaceYawSortedSamples( samples ):