This commit is contained in:
Colombo 2020-11-22 18:26:54 +04:00
parent c516454566
commit e9e7344424
4 changed files with 38 additions and 12 deletions

View file

@ -26,11 +26,14 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()
seg_samples_len = len(seg_sample_idxs)
if seg_samples_len == 0:
raise Exception(f"No segmented faces found.")
if len(seg_sample_idxs) == 0:
seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples, count_xseg_mask=True).run()
if len(seg_sample_idxs) == 0:
raise Exception(f"No segmented faces found.")
else:
io.log_info(f"Using {len(seg_sample_idxs)} xseg labeled samples.")
else:
io.log_info(f"Using {seg_samples_len} segmented samples.")
io.log_info(f"Using {len(seg_sample_idxs)} segmented samples.")
if self.debug:
self.generators_count = 1
@ -80,8 +83,16 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
def gen_img_mask(sample):
img = sample.load_bgr()
h,w,c = img.shape
mask = np.zeros ((h,w,1), dtype=np.float32)
sample.seg_ie_polys.overlay_mask(mask)
if sample.seg_ie_polys.has_polys():
mask = np.zeros ((h,w,1), dtype=np.float32)
sample.seg_ie_polys.overlay_mask(mask)
elif sample.has_xseg_mask():
mask = sample.get_xseg_mask()
mask[mask < 0.5] = 0.0
mask[mask >= 0.5] = 1.0
else:
raise Exception(f'no mask in sample {sample.filename}')
if face_type == sample.face_type:
if w != resolution:
@ -158,9 +169,10 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
class SegmentedSampleFilterSubprocessor(Subprocessor):
#override
def __init__(self, samples ):
def __init__(self, samples, count_xseg_mask=False ):
self.samples = samples
self.samples_len = len(self.samples)
self.count_xseg_mask = count_xseg_mask
self.idxs = [*range(self.samples_len)]
self.result = []
@ -169,7 +181,7 @@ class SegmentedSampleFilterSubprocessor(Subprocessor):
#override
def process_info_generator(self):
for i in range(multiprocessing.cpu_count()):
yield 'CPU%d' % (i), {}, {'samples':self.samples}
yield 'CPU%d' % (i), {}, {'samples':self.samples, 'count_xseg_mask':self.count_xseg_mask}
#override
def on_clients_initialized(self):
@ -203,6 +215,10 @@ class SegmentedSampleFilterSubprocessor(Subprocessor):
#overridable optional
def on_initialize(self, client_dict):
self.samples = client_dict['samples']
self.count_xseg_mask = client_dict['count_xseg_mask']
def process_data(self, idx):
return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0
if self.count_xseg_mask:
return idx, self.samples[idx].has_xseg_mask()
else:
return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0