Mirror of https://github.com/iperov/DeepFaceLab.git

Commit 18d93376fc (parent 6f4ea69d4d): update FANSeg

4 changed files with 61 additions and 17 deletions
@@ -412,7 +412,13 @@ class ModelBase(object):
         return imagelib.equalize_and_stack_square (images)
 
     def generate_next_samples(self):
-        self.last_sample = sample = [ generator.generate_next() for generator in self.generator_list]
+        sample = []
+        for generator in self.generator_list:
+            if generator.is_initialized():
+                sample.append ( generator.generate_next() )
+            else:
+                sample.append ( [] )
+        self.last_sample = sample
         return sample
 
     def train_one_iter(self):
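The ModelBase hunk above makes generate_next_samples() skip generators that never initialized instead of calling generate_next() on all of them unconditionally. A minimal standalone sketch of that pattern follows; ToyGenerator and the free function are made up for illustration and are not DeepFaceLab's actual classes.

# Illustrative sketch (not DeepFaceLab code): tolerate uninitialized generators.
class ToyGenerator:
    def __init__(self, data):
        self.data = list(data)

    def is_initialized(self):
        # A generator that has no data reports itself as uninitialized.
        return len(self.data) > 0

    def generate_next(self):
        return self.data[0]

def generate_next_samples(generator_list):
    # Uninitialized generators contribute an empty placeholder instead of raising.
    sample = []
    for generator in generator_list:
        if generator.is_initialized():
            sample.append(generator.generate_next())
        else:
            sample.append([])
    return sample

print(generate_next_samples([ToyGenerator([1, 2]), ToyGenerator([])]))  # [1, []]

The empty placeholder keeps the returned list aligned with generator_list, so callers can still index samples by generator position.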
@@ -24,7 +24,6 @@ class FANSegModel(ModelBase):
         ask_override = self.ask_override()
         if self.is_first_run() or ask_override:
             self.ask_autobackup_hour()
-            self.ask_write_preview_history()
             self.ask_target_iter()
             self.ask_batch_size(4)
 
@@ -117,21 +116,30 @@ class FANSegModel(ModelBase):
 
         # initializing sample generators
         training_data_src_path = self.training_data_src_path
-        #training_data_dst_path = self.training_data_dst_path
+        training_data_dst_path = self.training_data_dst_path
 
         cpu_count = min(multiprocessing.cpu_count(), 8)
         src_generators_count = cpu_count // 2
         dst_generators_count = cpu_count // 2
         src_generators_count = int(src_generators_count * 1.5)
 
-        self.set_training_data_generators ([
-                SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
-                        sample_process_options=SampleProcessor.Options(random_flip=True),
-                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'idt', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution},
-                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.NONE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
-                                              ],
-                        generators_count=src_generators_count ),
-                ])
+        src_generator = SampleGeneratorFace(training_data_src_path, random_ct_samples_path=training_data_src_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
+                        sample_process_options=SampleProcessor.Options(random_flip=True),
+                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'ct_mode':'idt', 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution},
+                                                {'sample_type': SampleProcessor.SampleType.FACE_MASK, 'warp':True, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.G, 'face_mask_type' : SampleProcessor.FaceMaskType.NONE, 'face_type':self.face_type, 'data_format':nn.data_format, 'resolution': resolution},
+                                              ],
+                        generators_count=src_generators_count )
+
+        dst_generator = SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.get_batch_size(),
+                        sample_process_options=SampleProcessor.Options(random_flip=True),
+                        output_sample_types = [ {'sample_type': SampleProcessor.SampleType.FACE_IMAGE, 'warp':False, 'transform':True, 'channel_type' : SampleProcessor.ChannelType.BGR, 'face_type':self.face_type, 'motion_blur':(25, 5), 'gaussian_blur':(25,5), 'data_format':nn.data_format, 'resolution': resolution},
+                                              ],
+                        generators_count=dst_generators_count,
+                        raise_on_no_data=False )
+        if not dst_generator.is_initialized():
+            io.log_info(f"\nTo view the model on unseen faces, place any aligned faces in {training_data_dst_path}.\n")
+
+        self.set_training_data_generators ([src_generator, dst_generator])
 
     #override
     def get_model_filename_list(self):
@@ -143,7 +151,7 @@ class FANSegModel(ModelBase):
 
     #override
     def onTrainOneIter(self):
-        ( (source_np, target_np), ) = self.generate_next_samples()
+        source_np, target_np = self.generate_next_samples()[0]
         loss = self.train (source_np, target_np)
 
         return ( ('loss', loss ), )
@@ -152,7 +160,8 @@ class FANSegModel(ModelBase):
     def onGetPreview(self, samples):
         n_samples = min(4, self.get_batch_size(), 800 // self.resolution )
 
-        ( (source_np, target_np), ) = samples
+        src_samples, dst_samples = samples
+        source_np, target_np = src_samples
 
         S, T, SM, = [ np.clip(x, 0.0, 1.0) for x in ([source_np,target_np] + self.view (source_np) ) ]
         T, SM, = [ np.repeat (x, (3,), -1) for x in [T, SM] ]
@@ -164,8 +173,22 @@ class FANSegModel(ModelBase):
             ar = S[i], T[i], SM[i], S[i]*SM[i]
             #todo green bg
             st.append ( np.concatenate ( ar, axis=1) )
-        result += [ ('FANSeg', np.concatenate (st, axis=0 )), ]
+        result += [ ('FANSeg training faces', np.concatenate (st, axis=0 )), ]
+
+        if len(dst_samples) != 0:
+            dst_np, = dst_samples
+
+            D, DM, = [ np.clip(x, 0.0, 1.0) for x in ([dst_np] + self.view (dst_np) ) ]
+            DM, = [ np.repeat (x, (3,), -1) for x in [DM] ]
+
+            st = []
+            for i in range(n_samples):
+                ar = D[i], DM[i], D[i]*DM[i]
+                #todo green bg
+                st.append ( np.concatenate ( ar, axis=1) )
+
+            result += [ ('FANSeg unseen faces', np.concatenate (st, axis=0 )), ]
 
         return result
 
 Model = FANSegModel
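In the extended onGetPreview() above, the single-channel masks are repeated to three channels so they can be tiled next to the BGR faces in one preview strip. A small numpy-only sketch of that tiling; the 128x128 resolution and NHWC-style per-image layout are assumptions for illustration only.

# Illustrative numpy sketch of the preview tiling used above.
import numpy as np

face = np.random.rand(128, 128, 3).astype(np.float32)   # BGR face, values in [0,1]
mask = np.random.rand(128, 128, 1).astype(np.float32)   # predicted 1-channel mask

mask3 = np.repeat(mask, 3, axis=-1)                      # G -> 3 channels so shapes match
row   = np.concatenate([face, mask3, face * mask3], axis=1)
print(row.shape)                                         # (128, 384, 3): face | mask | masked face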
@@ -33,3 +33,7 @@ class SampleGeneratorBase(object):
     def __next__(self):
         #implement your own iterator
         return None
+
+    #overridable
+    def is_initialized(self):
+        return True
@@ -27,6 +27,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
                         output_sample_types=[],
                         add_sample_idx=False,
                         generators_count=4,
+                        raise_on_no_data=True,
                         **kwargs):
 
         super().__init__(samples_path, debug, batch_size)
@@ -42,8 +43,12 @@ class SampleGeneratorFace(SampleGeneratorBase):
         samples = SampleLoader.load (SampleType.FACE, self.samples_path)
         self.samples_len = len(samples)
 
+        self.initialized = False
         if self.samples_len == 0:
-            raise ValueError('No training data provided.')
+            if raise_on_no_data:
+                raise ValueError('No training data provided.')
+            else:
+                return
 
         index_host = mplib.IndexHost(self.samples_len)
 
@@ -66,7 +71,13 @@ class SampleGeneratorFace(SampleGeneratorBase):
             SubprocessGenerator.start_in_parallel( self.generators )
 
         self.generator_counter = -1
 
+        self.initialized = True
+
+    #overridable
+    def is_initialized(self):
+        return self.initialized
+
     def __iter__(self):
         return self
 
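Together, the raise_on_no_data argument and the is_initialized() override let a caller treat a sample generator as optional: with raise_on_no_data=False, an empty or missing folder leaves the generator uninitialized instead of aborting. A rough standalone sketch of that pattern; ToyFaceGenerator and the data_dst/aligned path are made up for illustration and this is not the real SampleGeneratorFace.

# Illustrative sketch of the raise_on_no_data / is_initialized pattern added above.
from pathlib import Path

class ToyFaceGenerator:
    def __init__(self, samples_path, raise_on_no_data=True):
        self.initialized = False
        path = Path(samples_path)
        samples = sorted(path.glob("*.jpg")) if path.exists() else []
        if len(samples) == 0:
            if raise_on_no_data:
                raise ValueError('No training data provided.')
            else:
                return  # leave the generator in the "not initialized" state
        self.samples = samples
        self.initialized = True

    def is_initialized(self):
        return self.initialized

# Optional secondary generator: missing data disables it instead of stopping training.
gen = ToyFaceGenerator("data_dst/aligned", raise_on_no_data=False)
if not gen.is_initialized():
    print("No unseen faces found; the unseen-faces preview stays disabled.")

This mirrors how the FANSeg model above builds dst_generator with raise_on_no_data=False and only logs a hint when it is not initialized.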