mirror of https://github.com/iperov/DeepFaceLab.git

commit 4adeb0aa79 (parent ea226b2cba)
fixed bug fix

2 changed files with 43 additions and 12 deletions
@@ -145,16 +145,16 @@ class ModelBase(object):
                                              " Memory error. Tune this value for your"
                                              " videocard manually."))
             self.options['ping_pong'] = io.input_bool(
-                "Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]),
-                True,
+                "Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('batch_cap', False),
+                self.options.get('batch_cap', False),
                 help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
             self.options['paddle'] = self.options.get('paddle','ping')
-            if self.options.get('ping_pong',True):
+            if self.options.get('ping_pong',False):
                 self.options['ping_pong_iter'] = max(0, io.input_int("Ping-pong iteration (skip:1000/default) : ", 1000))

         else:
             self.options['batch_cap'] = self.options.get('batch_cap', 16)
-            self.options['ping_pong'] = self.options.get('ping_pong', True)
+            self.options['ping_pong'] = self.options.get('ping_pong', False)
             self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)

         if ask_sort_by_yaw:
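What this hunk changes: the ping-pong prompt used to hard-code True as its default, so pressing Enter always re-enabled the feature; the default is now read back out of the saved options, with False as the fallback, and the same True-to-False flip is applied to the stored-option reads. Below is a minimal sketch of that "seed the prompt with the stored answer" pattern. ask_bool is only a stand-in for DeepFaceLab's interactive io.input_bool helper, and the option key shown is chosen for illustration rather than copied from the diff.

    # Sketch only: a tiny stand-in for the interactive prompt helper, showing how a
    # previously saved option becomes the default the next time the question is asked.
    def ask_bool(prompt, default):
        answer = input(prompt).strip().lower()
        if answer in ('y', 'yes'):
            return True
        if answer in ('n', 'no'):
            return False
        return default                      # empty input keeps the stored/default value

    options = {}                            # normally loaded from the model's options file
    options['ping_pong'] = ask_bool(
        "Enable ping-pong? (y/n skip:%s) : " % options.get('ping_pong', False),
        options.get('ping_pong', False))    # the old code passed a hard-coded True here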
@@ -525,10 +525,10 @@ class ModelBase(object):

     def train_one_iter(self):

-        if self.iter == 1 and self.options.get('ping_pong', True):
+        if self.iter == 1 and self.options.get('ping_pong', False):
             self.set_batch_size(1)
             self.paddle = 'ping'
-        elif not self.options.get('ping_pong', True) and self.batch_cap != self.batch_size:
+        elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
             self.set_batch_size(self.batch_cap)
         sample = self.generate_next_sample()
         iter_time = time.time()
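The default flip matters most here: with True as the fallback, every model that had never stored the ping_pong key was treated as if ping-pong were enabled and had its batch size forced to 1 on the first iteration; with False, such models go straight to batch_cap. A pure-function restatement of the branch, for illustration only (the names mirror the diff, but this is not the actual ModelBase code):

    # Illustrative restatement of the branch above, with the trainer state passed in explicitly.
    def starting_batch_size(iter_n, options, batch_cap, current_batch_size):
        ping_pong = options.get('ping_pong', False)   # fallback is now False
        if iter_n == 1 and ping_pong:
            return 1            # ping-pong starts its cycle from a batch size of 1
        if not ping_pong and current_batch_size != batch_cap:
            return batch_cap    # without ping-pong, always train at the full cap
        return current_batch_size

    # A model whose options never stored 'ping_pong' now stays at batch_cap;
    # with the old default of True it would have silently entered the cycle.
    assert starting_batch_size(1, {}, 16, 16) == 16
    assert starting_batch_size(1, {'ping_pong': True}, 16, 16) == 1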
@@ -556,7 +556,7 @@ class ModelBase(object):
             img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
             cv2_imwrite(filepath, img)

-        if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', True):
+        if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
             if self.batch_size == self.batch_cap:
                 self.paddle = 'pong'
             if self.batch_size > self.batch_cap:
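This is the other half of the cycle: on every ping_pong_iter boundary the paddle flips to 'pong' once the batch size reaches the cap (the ramp-down side is truncated in this diff). The option's help text describes the intent as cycling the batch size between 1 and the chosen cap, so a rough simulation of that schedule looks like the sketch below; the doubling/halving step is an assumption made for illustration, not taken from the source.

    # Rough simulation of a ping-pong batch-size schedule. Only the paddle flip at
    # the cap is visible in the hunk above; the step rule here is assumed.
    def simulate_ping_pong(batch_cap, ping_pong_iter, total_iters):
        batch_size, paddle = 1, 'ping'
        history = []
        for it in range(1, total_iters + 1):
            if it % ping_pong_iter == 0:
                if paddle == 'ping':
                    batch_size = min(batch_size * 2, batch_cap)   # ramp up
                else:
                    batch_size = max(batch_size // 2, 1)          # ramp down
                if batch_size == batch_cap:
                    paddle = 'pong'
                elif batch_size == 1:
                    paddle = 'ping'
            history.append(batch_size)
        return history

    # e.g. with a cap of 8 and a 1000-iteration period the batch size walks
    # 1 -> 2 -> 4 -> 8 -> 4 -> 2 -> 1 -> 2 -> ...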
@@ -179,6 +179,7 @@ class SAEModel(ModelBase):



+        global apply_random_ct
         apply_random_ct = self.options.get('apply_random_ct', ColorTransferMode.NONE)
         masked_training = True

@@ -457,11 +458,12 @@ class SAEModel(ModelBase):
         global t_mode_bgr
         t_mode_bgr = t.MODE_BGR if not self.pretrain else t.MODE_BGR_SHUFFLE

+        global training_data_src_path
         training_data_src_path = self.training_data_src_path
+        global training_data_dst_path
         training_data_dst_path= self.training_data_dst_path

+        global sort_by_yaw
         sort_by_yaw = self.sort_by_yaw

         if self.pretrain and self.pretraining_data_path is not None:
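The added global statements (here and for apply_random_ct above) are the actual fix on the SAE side: these names are assigned inside a method but read later as module-level variables when the data generators are rebuilt, and without global the assignment only creates a method-local binding. A miniature of the difference, with placeholder paths:

    # Why the global statements matter, in miniature.
    training_data_src_path = None   # module-level name, read elsewhere

    def configure_without_global(path):
        training_data_src_path = path          # binds a local; the module name stays None

    def configure_with_global(path):
        global training_data_src_path
        training_data_src_path = path          # rebinds the module-level name

    configure_without_global('/tmp/src_faces')
    assert training_data_src_path is None
    configure_with_global('/tmp/src_faces')
    assert training_data_src_path == '/tmp/src_faces'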
@@ -537,9 +539,38 @@ class SAEModel(ModelBase):

     # override
     def set_batch_size(self, batch_size):
-        self.batch_size = batch_size
-        for i, generator in enumerate(self.generator_list):
-            generator.update_batch(batch_size)
+        self.set_training_data_generators(None)
+        self.set_training_data_generators([
+            SampleGeneratorFace(training_data_src_path,
+                                sort_by_yaw_target_samples_path=training_data_dst_path if sort_by_yaw else None,
+                                random_ct_samples_path=training_data_dst_path if apply_random_ct != ColorTransferMode.NONE else None,
+                                debug=self.is_debug(), batch_size=batch_size,
+                                sample_process_options=SampleProcessor.Options(random_flip=self.random_flip,
+                                                                               scale_range=np.array([-0.05,
+                                                                                                     0.05]) + self.src_scale_mod / 100.0),
+                                output_sample_types=[{'types': (
+                                    t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+                                    'resolution': resolution, 'apply_ct': apply_random_ct}] + \
+                                    [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
+                                      'resolution': resolution // (2 ** i),
+                                      'apply_ct': apply_random_ct} for i in range(ms_count)] + \
+                                    [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
+                                      'resolution': resolution // (2 ** i)} for i in
+                                     range(ms_count)]
+                                ),
+
+            SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=batch_size,
+                                sample_process_options=SampleProcessor.Options(random_flip=self.random_flip, ),
+                                output_sample_types=[{'types': (
+                                    t.IMG_WARPED_TRANSFORMED, face_type, t_mode_bgr),
+                                    'resolution': resolution}] + \
+                                    [{'types': (t.IMG_TRANSFORMED, face_type, t_mode_bgr),
+                                      'resolution': resolution // (2 ** i)} for i in
+                                     range(ms_count)] + \
+                                    [{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
+                                      'resolution': resolution // (2 ** i)} for i in
+                                     range(ms_count)])
+        ])

     # override
     def onTrainOneIter(self, generators_samples, generators_list):
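The rewritten set_batch_size no longer asks the live generators to change their batch size (the old generator.update_batch loop); it discards them and constructs fresh SampleGeneratorFace instances with the new batch_size, which is exactly why the constructor arguments (sample paths, sort_by_yaw, the color-transfer setting) now have to be reachable through the module-level globals added above. A stripped-down sketch of that design, with a dummy generator class standing in for SampleGeneratorFace and placeholder paths:

    # Sketch of the design change: rebuild generators instead of mutating them.
    class DummyGenerator:
        def __init__(self, samples_path, batch_size):
            self.samples_path = samples_path
            self.batch_size = batch_size

    class Model:
        def __init__(self, src_path, dst_path):
            self.src_path, self.dst_path = src_path, dst_path
            self.generator_list = []

        def set_training_data_generators(self, generators):
            self.generator_list = generators or []

        def set_batch_size(self, batch_size):
            # old approach (roughly): for g in self.generator_list: g.update_batch(batch_size)
            self.set_training_data_generators(None)            # drop the old generators
            self.set_training_data_generators([                # rebuild with the new size
                DummyGenerator(self.src_path, batch_size),
                DummyGenerator(self.dst_path, batch_size),
            ])

    m = Model('data_src/aligned', 'data_dst/aligned')
    m.set_batch_size(4)
    assert [g.batch_size for g in m.generator_list] == [4, 4]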