mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 13:02:15 -07:00
upd
This commit is contained in:
parent
c516454566
commit
e9e7344424
4 changed files with 38 additions and 12 deletions
|
@ -281,6 +281,13 @@ class DFLJPG(object):
|
||||||
def has_xseg_mask(self):
|
def has_xseg_mask(self):
|
||||||
return self.dfl_dict.get('xseg_mask',None) is not None
|
return self.dfl_dict.get('xseg_mask',None) is not None
|
||||||
|
|
||||||
|
def get_xseg_mask_compressed(self):
|
||||||
|
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
||||||
|
if mask_buf is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return mask_buf
|
||||||
|
|
||||||
def get_xseg_mask(self):
|
def get_xseg_mask(self):
|
||||||
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
mask_buf = self.dfl_dict.get('xseg_mask',None)
|
||||||
if mask_buf is None:
|
if mask_buf is None:
|
||||||
|
|
|
@ -79,6 +79,9 @@ class Sample(object):
|
||||||
|
|
||||||
self._filename_offset_size = None
|
self._filename_offset_size = None
|
||||||
|
|
||||||
|
def has_xseg_mask(self):
|
||||||
|
return self.xseg_mask is not None or self.xseg_mask_compressed is not None
|
||||||
|
|
||||||
def get_xseg_mask(self):
|
def get_xseg_mask(self):
|
||||||
if self.xseg_mask_compressed is not None:
|
if self.xseg_mask_compressed is not None:
|
||||||
xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED)
|
xseg_mask = cv2.imdecode(self.xseg_mask_compressed, cv2.IMREAD_UNCHANGED)
|
||||||
|
|
|
@ -26,11 +26,14 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
||||||
samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
|
samples = sum([ SampleLoader.load (SampleType.FACE, path) for path in paths ] )
|
||||||
seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()
|
seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples).run()
|
||||||
|
|
||||||
seg_samples_len = len(seg_sample_idxs)
|
if len(seg_sample_idxs) == 0:
|
||||||
if seg_samples_len == 0:
|
seg_sample_idxs = SegmentedSampleFilterSubprocessor(samples, count_xseg_mask=True).run()
|
||||||
|
if len(seg_sample_idxs) == 0:
|
||||||
raise Exception(f"No segmented faces found.")
|
raise Exception(f"No segmented faces found.")
|
||||||
else:
|
else:
|
||||||
io.log_info(f"Using {seg_samples_len} segmented samples.")
|
io.log_info(f"Using {len(seg_sample_idxs)} xseg labeled samples.")
|
||||||
|
else:
|
||||||
|
io.log_info(f"Using {len(seg_sample_idxs)} segmented samples.")
|
||||||
|
|
||||||
if self.debug:
|
if self.debug:
|
||||||
self.generators_count = 1
|
self.generators_count = 1
|
||||||
|
@ -80,8 +83,16 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
||||||
def gen_img_mask(sample):
|
def gen_img_mask(sample):
|
||||||
img = sample.load_bgr()
|
img = sample.load_bgr()
|
||||||
h,w,c = img.shape
|
h,w,c = img.shape
|
||||||
|
|
||||||
|
if sample.seg_ie_polys.has_polys():
|
||||||
mask = np.zeros ((h,w,1), dtype=np.float32)
|
mask = np.zeros ((h,w,1), dtype=np.float32)
|
||||||
sample.seg_ie_polys.overlay_mask(mask)
|
sample.seg_ie_polys.overlay_mask(mask)
|
||||||
|
elif sample.has_xseg_mask():
|
||||||
|
mask = sample.get_xseg_mask()
|
||||||
|
mask[mask < 0.5] = 0.0
|
||||||
|
mask[mask >= 0.5] = 1.0
|
||||||
|
else:
|
||||||
|
raise Exception(f'no mask in sample {sample.filename}')
|
||||||
|
|
||||||
if face_type == sample.face_type:
|
if face_type == sample.face_type:
|
||||||
if w != resolution:
|
if w != resolution:
|
||||||
|
@ -158,9 +169,10 @@ class SampleGeneratorFaceXSeg(SampleGeneratorBase):
|
||||||
|
|
||||||
class SegmentedSampleFilterSubprocessor(Subprocessor):
|
class SegmentedSampleFilterSubprocessor(Subprocessor):
|
||||||
#override
|
#override
|
||||||
def __init__(self, samples ):
|
def __init__(self, samples, count_xseg_mask=False ):
|
||||||
self.samples = samples
|
self.samples = samples
|
||||||
self.samples_len = len(self.samples)
|
self.samples_len = len(self.samples)
|
||||||
|
self.count_xseg_mask = count_xseg_mask
|
||||||
|
|
||||||
self.idxs = [*range(self.samples_len)]
|
self.idxs = [*range(self.samples_len)]
|
||||||
self.result = []
|
self.result = []
|
||||||
|
@ -169,7 +181,7 @@ class SegmentedSampleFilterSubprocessor(Subprocessor):
|
||||||
#override
|
#override
|
||||||
def process_info_generator(self):
|
def process_info_generator(self):
|
||||||
for i in range(multiprocessing.cpu_count()):
|
for i in range(multiprocessing.cpu_count()):
|
||||||
yield 'CPU%d' % (i), {}, {'samples':self.samples}
|
yield 'CPU%d' % (i), {}, {'samples':self.samples, 'count_xseg_mask':self.count_xseg_mask}
|
||||||
|
|
||||||
#override
|
#override
|
||||||
def on_clients_initialized(self):
|
def on_clients_initialized(self):
|
||||||
|
@ -203,6 +215,10 @@ class SegmentedSampleFilterSubprocessor(Subprocessor):
|
||||||
#overridable optional
|
#overridable optional
|
||||||
def on_initialize(self, client_dict):
|
def on_initialize(self, client_dict):
|
||||||
self.samples = client_dict['samples']
|
self.samples = client_dict['samples']
|
||||||
|
self.count_xseg_mask = client_dict['count_xseg_mask']
|
||||||
|
|
||||||
def process_data(self, idx):
|
def process_data(self, idx):
|
||||||
|
if self.count_xseg_mask:
|
||||||
|
return idx, self.samples[idx].has_xseg_mask()
|
||||||
|
else:
|
||||||
return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0
|
return idx, self.samples[idx].seg_ie_polys.get_pts_count() != 0
|
|
@ -81,7 +81,7 @@ class SampleLoader:
|
||||||
shape,
|
shape,
|
||||||
landmarks,
|
landmarks,
|
||||||
seg_ie_polys,
|
seg_ie_polys,
|
||||||
xseg_mask,
|
xseg_mask_compressed,
|
||||||
eyebrows_expand_mod,
|
eyebrows_expand_mod,
|
||||||
source_filename ) = data
|
source_filename ) = data
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ class SampleLoader:
|
||||||
shape=shape,
|
shape=shape,
|
||||||
landmarks=landmarks,
|
landmarks=landmarks,
|
||||||
seg_ie_polys=seg_ie_polys,
|
seg_ie_polys=seg_ie_polys,
|
||||||
xseg_mask=xseg_mask,
|
xseg_mask_compressed=xseg_mask_compressed,
|
||||||
eyebrows_expand_mod=eyebrows_expand_mod,
|
eyebrows_expand_mod=eyebrows_expand_mod,
|
||||||
source_filename=source_filename,
|
source_filename=source_filename,
|
||||||
))
|
))
|
||||||
|
@ -163,7 +163,7 @@ class FaceSamplesLoaderSubprocessor(Subprocessor):
|
||||||
dflimg.get_shape(),
|
dflimg.get_shape(),
|
||||||
dflimg.get_landmarks(),
|
dflimg.get_landmarks(),
|
||||||
dflimg.get_seg_ie_polys(),
|
dflimg.get_seg_ie_polys(),
|
||||||
dflimg.get_xseg_mask(),
|
dflimg.get_xseg_mask_compressed(),
|
||||||
dflimg.get_eyebrows_expand_mod(),
|
dflimg.get_eyebrows_expand_mod(),
|
||||||
dflimg.get_source_filename() )
|
dflimg.get_source_filename() )
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue