diff --git a/DFLIMG/DFLJPG.py b/DFLIMG/DFLJPG.py index ae8f765..291db71 100644 --- a/DFLIMG/DFLJPG.py +++ b/DFLIMG/DFLJPG.py @@ -281,6 +281,13 @@ class DFLJPG(object): def has_xseg_mask(self): return self.dfl_dict.get('xseg_mask',None) is not None + def get_xseg_mask_compressed(self): + mask_buf = self.dfl_dict.get('xseg_mask',None) + if mask_buf is None: + return None + + return mask_buf + def get_xseg_mask(self): mask_buf = self.dfl_dict.get('xseg_mask',None) if mask_buf is None: diff --git a/samplelib/Sample.py b/samplelib/Sample.py index 7ced210..a379275 100644 --- a/samplelib/Sample.py +++ b/samplelib/Sample.py @@ -79,6 +79,9 @@ class Sample(object): self._filename_offset_size = None + def has_xseg_mask(self): + return self.xseg_mask is not None or self.xseg_mask_compressed is not None + def get_xseg_mask(self): if self.xseg_mask_compressed is not None: xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED) diff --git a/samplelib/SampleGeneratorFaceXSeg.py b/samplelib/SampleGeneratorFaceXSeg.py index 7999d16..9c7cf6e 100644 --- a/samplelib/SampleGeneratorFaceXSeg.py +++ b/samplelib/SampleGeneratorFaceXSeg.py @@ -26,11 +26,14 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] ) seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run() - seg_samples_len = len(seg_sample_idxs) - if seg_samples_len == 0: - raise Exception(f"No segmented faces found.") + if len(seg_sample_idxs) == 0: + seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples, count_xseg_mask=True).run() + if len(seg_sample_idxs) == 0: + raise Exception(f"No segmented faces found.") + else: + io.log_info(f"Using {len(seg_sample_idxs)} xseg labeled samples.") else: - io.log_info(f"Using {seg_samples_len} segmented samples.") + io.log_info(f"Using {len(seg_sample_idxs)} segmented samples.") if self.debug: self.generators_count = 1 @@ -80,8 +83,16 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): def gen_img_mask(sample): img = sample.load_bgr() h,w,c = img.shape - mask = np.zeros ((h,w,1), dtype=np.float32) - sample.seg_ie_polys.overlay_mask(mask) + + if sample.seg_ie_polys.has_polys(): + mask = np.zeros ((h,w,1), dtype=np.float32) + sample.seg_ie_polys.overlay_mask(mask) + elif sample.has_xseg_mask(): + mask = sample.get_xseg_mask() + mask[mask < 0.5] = 0.0 + mask[mask >= 0.5] = 1.0 + else: + raise Exception(f'no mask in sample {sample.filename}') if face_type == sample.face_type: if w != resolution: @@ -158,9 +169,10 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase): class SegmentedSampleFilterSubprocessor(Subprocessor): #override - def __init__(self, samples ): + def __init__(self, samples, count_xseg_mask=False ): self.samples = samples self.samples_len = len(self.samples) + self.count_xseg_mask = count_xseg_mask self.idxs = [*range(self.samples_len)] self.result = [] @@ -169,7 +181,7 @@ class SegmentedSampleFilterSubprocessor(Subprocessor): #override def process_info_generator(self): for i in range(multiprocessing.cpu_count()): - yield 'CPU%d' % (i), {}, {'samples':self.samples} + yield 'CPU%d' % (i), {}, {'samples':self.samples, 'count_xseg_mask':self.count_xseg_mask} #override def on_clients_initialized(self): @@ -203,6 +215,10 @@ class SegmentedSampleFilterSubprocessor(Subprocessor): #overridable optional def on_initialize(self, client_dict): self.samples = client_dict['samples'] + self.count_xseg_mask = client_dict['count_xseg_mask'] def process_data(self, idx): - return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0 \ No newline at end of file + if self.count_xseg_mask: + return idx, self.samples[idx].has_xseg_mask() + else: + return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0 \ No newline at end of file diff --git a/samplelib/SampleLoader.py b/samplelib/SampleLoader.py index 46d0962..edb3775 100644 --- a/samplelib/SampleLoader.py +++ b/samplelib/SampleLoader.py @@ -81,7 +81,7 @@ class SampleLoader: shape, landmarks, seg_ie_polys, - xseg_mask, + xseg_mask_compressed, eyebrows_expand_mod, source_filename ) = data @@ -91,7 +91,7 @@ class SampleLoader: shape=shape, landmarks=landmarks, seg_ie_polys=seg_ie_polys, - xseg_mask=xseg_mask, + xseg_mask_compressed=xseg_mask_compressed, eyebrows_expand_mod=eyebrows_expand_mod, source_filename=source_filename, )) @@ -163,7 +163,7 @@ class FaceSamplesLoaderSubprocessor(Subprocessor): dflimg.get_shape(), dflimg.get_landmarks(), dflimg.get_seg_ie_polys(), - dflimg.get_xseg_mask(), + dflimg.get_xseg_mask_compressed(), dflimg.get_eyebrows_expand_mod(), dflimg.get_source_filename() )