mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-22 06:23:20 -07:00
finished ping pong
This commit is contained in:
parent
d433acc4a0
commit
ca33aaa019
1 changed files with 24 additions and 8 deletions
|
@ -139,8 +139,8 @@ class ModelBase(object):
|
||||||
|
|
||||||
if ask_batch_size and (self.iter == 0 or ask_override):
|
if ask_batch_size and (self.iter == 0 or ask_override):
|
||||||
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0)
|
default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0)
|
||||||
self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % default_batch_size,
|
self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % 0,
|
||||||
default_batch_size,
|
0,
|
||||||
help_message="Larger batch size is better for NN's"
|
help_message="Larger batch size is better for NN's"
|
||||||
" generalization, but it can cause Out of"
|
" generalization, but it can cause Out of"
|
||||||
" Memory error. Tune this value for your"
|
" Memory error. Tune this value for your"
|
||||||
|
@ -149,9 +149,10 @@ class ModelBase(object):
|
||||||
"Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]),
|
"Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]),
|
||||||
True,
|
True,
|
||||||
help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
|
help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
|
||||||
|
self.options['paddle'] = self.options.get('paddle','ping')
|
||||||
|
|
||||||
else:
|
else:
|
||||||
self.options['batch_size'] = self.options.get('batch_size', 0)
|
self.options['batch_cap'] = self.options.get('batch_cap', 0)
|
||||||
self.options['ping_pong'] = self.options.get('ping_pong', True)
|
self.options['ping_pong'] = self.options.get('ping_pong', True)
|
||||||
|
|
||||||
if ask_sort_by_yaw:
|
if ask_sort_by_yaw:
|
||||||
|
@ -197,6 +198,7 @@ class ModelBase(object):
|
||||||
self.options.pop('target_iter')
|
self.options.pop('target_iter')
|
||||||
|
|
||||||
self.batch_size = self.options.get('batch_size', 0)
|
self.batch_size = self.options.get('batch_size', 0)
|
||||||
|
self.batch_cap = self.options.get('batch_cap',16)
|
||||||
self.sort_by_yaw = self.options.get('sort_by_yaw', False)
|
self.sort_by_yaw = self.options.get('sort_by_yaw', False)
|
||||||
self.random_flip = self.options.get('random_flip', True)
|
self.random_flip = self.options.get('random_flip', True)
|
||||||
|
|
||||||
|
@ -213,7 +215,7 @@ class ModelBase(object):
|
||||||
self.onInitialize()
|
self.onInitialize()
|
||||||
|
|
||||||
self.options['batch_size'] = self.batch_size
|
self.options['batch_size'] = self.batch_size
|
||||||
self.options['paddle'] = 'ping'
|
self.paddle = self.options.get('paddle', 'ping')
|
||||||
|
|
||||||
if self.debug or self.batch_size == 0:
|
if self.debug or self.batch_size == 0:
|
||||||
self.batch_size = 1
|
self.batch_size = 1
|
||||||
|
@ -390,6 +392,8 @@ class ModelBase(object):
|
||||||
return self.onGetPreview(self.sample_for_preview)[0][1] # first preview, and bgr
|
return self.onGetPreview(self.sample_for_preview)[0][1] # first preview, and bgr
|
||||||
|
|
||||||
def save(self):
|
def save(self):
|
||||||
|
self.options['batch_size'] = self.batch_size
|
||||||
|
self.options['paddle'] = self.paddle
|
||||||
summary_path = self.get_strpath_storage_for_file('summary.txt')
|
summary_path = self.get_strpath_storage_for_file('summary.txt')
|
||||||
Path(summary_path).write_text(self.model_summary_text)
|
Path(summary_path).write_text(self.model_summary_text)
|
||||||
self.onSave()
|
self.onSave()
|
||||||
|
@ -518,6 +522,12 @@ class ModelBase(object):
|
||||||
return [next(generator) for generator in self.generator_list]
|
return [next(generator) for generator in self.generator_list]
|
||||||
|
|
||||||
def train_one_iter(self):
|
def train_one_iter(self):
|
||||||
|
|
||||||
|
if self.iter == 1 and self.options.get('ping_pong', True):
|
||||||
|
self.set_batch_size(1)
|
||||||
|
self.paddle = 'ping'
|
||||||
|
elif not self.options.get('ping_pong', True):
|
||||||
|
self.set_batch_size(self.batch_cap)
|
||||||
sample = self.generate_next_sample()
|
sample = self.generate_next_sample()
|
||||||
iter_time = time.time()
|
iter_time = time.time()
|
||||||
losses = self.onTrainOneIter(sample, self.generator_list)
|
losses = self.onTrainOneIter(sample, self.generator_list)
|
||||||
|
@ -544,9 +554,15 @@ class ModelBase(object):
|
||||||
img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
|
img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8)
|
||||||
cv2_imwrite(filepath, img)
|
cv2_imwrite(filepath, img)
|
||||||
|
|
||||||
if self.iter % 50 == 0 and self.iter != 0:
|
if self.iter % 1000 == 0 and self.iter != 0 and self.options.get('ping_pong', True):
|
||||||
|
if self.batch_size == self.batch_cap:
|
||||||
self.set_batch_size(self.batch_size + 1)
|
self.paddle = 'pong'
|
||||||
|
if self.batch_size == 1:
|
||||||
|
self.paddle = 'ping'
|
||||||
|
if self.paddle == 'ping':
|
||||||
|
self.set_batch_size(self.batch_size + 1)
|
||||||
|
else:
|
||||||
|
self.set_batch_size(self.batch_size - 1)
|
||||||
|
|
||||||
self.iter += 1
|
self.iter += 1
|
||||||
|
|
||||||
|
@ -668,6 +684,6 @@ class ModelBase(object):
|
||||||
lh_text = 'Iter: %d' % iter if iter != 0 else ''
|
lh_text = 'Iter: %d' % iter if iter != 0 else ''
|
||||||
bs_text = 'BS: %d' % batch_size if batch_size is not None else '1'
|
bs_text = 'BS: %d' % batch_size if batch_size is not None else '1'
|
||||||
|
|
||||||
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image((last_line_b - last_line_t, w, c), bs_text,
|
lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image((last_line_b - last_line_t, w, c), lh_text,
|
||||||
color=[0.8] * c)
|
color=[0.8] * c)
|
||||||
return lh_img
|
return lh_img
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue