This commit is contained in:
iperov 2019-04-18 21:21:43 +04:00
commit 0bdba117ef

View file

@ -80,7 +80,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
if all ( [ samples[idx] == None for idx in samples_idxs] ):
raise ValueError('Not enough training data. Gather more faces!')
if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
if self.sample_type == SampleType.FACE:
shuffle_idxs = []
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
shuffle_idxs = []
@ -100,12 +100,12 @@ class SampleGeneratorFace(SampleGeneratorBase):
if len(repeat_samples_idxs) > 0:
idx = repeat_samples_idxs.pop()
if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
if self.sample_type == SampleType.FACE:
sample = samples[idx]
elif self.sample_type == SampleType.FACE_YAW_SORTED or self.sample_type == SampleType.FACE_YAW_SORTED_AS_TARGET:
sample = samples[(idx >> 16) & 0xFFFF][idx & 0xFFFF]
else:
if self.sample_type == SampleType.FACE or self.sample_type == SampleType.FACE_WITH_CLOSE_TO_SELF:
if self.sample_type == SampleType.FACE:
if len(shuffle_idxs) == 0:
shuffle_idxs = samples_idxs.copy()
np.random.shuffle(shuffle_idxs)