From ca33aaa0193aac18f2d6c2d7ad2c64fcff980bbd Mon Sep 17 00:00:00 2001 From: Brigham Lysenko Date: Tue, 13 Aug 2019 15:23:50 -0600 Subject: [PATCH] finished ping pong --- models/ModelBase.py | 32 ++++++++++++++++++++++++-------- 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/models/ModelBase.py b/models/ModelBase.py index 80e490f..c6d16d3 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -139,8 +139,8 @@ class ModelBase(object): if ask_batch_size and (self.iter == 0 or ask_override): default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size', 0) - self.options['batch_size'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % default_batch_size, - default_batch_size, + self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % 0, + 0, help_message="Larger batch size is better for NN's" " generalization, but it can cause Out of" " Memory error. Tune this value for your" @@ -149,9 +149,10 @@ class ModelBase(object): "Enable ping-pong? (y/n ?:help skip:%s) : " % (yn_str[True]), True, help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence") + self.options['paddle'] = self.options.get('paddle','ping') else: - self.options['batch_size'] = self.options.get('batch_size', 0) + self.options['batch_cap'] = self.options.get('batch_cap', 0) self.options['ping_pong'] = self.options.get('ping_pong', True) if ask_sort_by_yaw: @@ -197,6 +198,7 @@ class ModelBase(object): self.options.pop('target_iter') self.batch_size = self.options.get('batch_size', 0) + self.batch_cap = self.options.get('batch_cap',16) self.sort_by_yaw = self.options.get('sort_by_yaw', False) self.random_flip = self.options.get('random_flip', True) @@ -213,7 +215,7 @@ class ModelBase(object): self.onInitialize() self.options['batch_size'] = self.batch_size - self.options['paddle'] = 'ping' + self.paddle = self.options.get('paddle', 'ping') if self.debug or self.batch_size == 0: self.batch_size = 1 @@ -390,6 +392,8 @@ class ModelBase(object): return self.onGetPreview(self.sample_for_preview)[0][1] # first preview, and bgr def save(self): + self.options['batch_size'] = self.batch_size + self.options['paddle'] = self.paddle summary_path = self.get_strpath_storage_for_file('summary.txt') Path(summary_path).write_text(self.model_summary_text) self.onSave() @@ -518,6 +522,12 @@ class ModelBase(object): return [next(generator) for generator in self.generator_list] def train_one_iter(self): + + if self.iter == 1 and self.options.get('ping_pong', True): + self.set_batch_size(1) + self.paddle = 'ping' + elif not self.options.get('ping_pong', True): + self.set_batch_size(self.batch_cap) sample = self.generate_next_sample() iter_time = time.time() losses = self.onTrainOneIter(sample, self.generator_list) @@ -544,9 +554,15 @@ class ModelBase(object): img = (np.concatenate([preview_lh, preview], axis=0) * 255).astype(np.uint8) cv2_imwrite(filepath, img) - if self.iter % 50 == 0 and self.iter != 0: - - self.set_batch_size(self.batch_size + 1) + if self.iter % 1000 == 0 and self.iter != 0 and self.options.get('ping_pong', True): + if self.batch_size == self.batch_cap: + self.paddle = 'pong' + if self.batch_size == 1: + self.paddle = 'ping' + if self.paddle == 'ping': + self.set_batch_size(self.batch_size + 1) + else: + self.set_batch_size(self.batch_size - 1) self.iter += 1 @@ -668,6 +684,6 @@ class ModelBase(object): lh_text = 'Iter: %d' % iter if iter != 0 else '' bs_text = 'BS: %d' % batch_size if batch_size is not None else '1' - lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image((last_line_b - last_line_t, w, c), bs_text, + lh_img[last_line_t:last_line_b, 0:w] += imagelib.get_text_image((last_line_b - last_line_t, w, c), lh_text, color=[0.8] * c) return lh_img