diff --git a/models/ModelBase.py b/models/ModelBase.py
index fe13698..4cb44b9 100644
--- a/models/ModelBase.py
+++ b/models/ModelBase.py
@@ -140,21 +140,23 @@ class ModelBase(object):

         if ask_batch_size and (self.iter == 0 or ask_override):
             default_batch_size = 0 if self.iter == 0 else self.options.get('batch_size',0)
-            self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % self.options.get('batch_cap', 16),self.options.get('batch_cap', 16),
+            self.options['batch_cap'] = max(0, io.input_int("Batch_size (?:help skip:%d) : " % self.options.get('batch_cap', 1),self.options.get('batch_cap', 1),
                                                             help_message="Larger batch size is better for NN's"
                                                                          " generalization, but it can cause Out of"
                                                                          " Memory error. Tune this value for your"
                                                                          " videocard manually."))

             self.options['ping_pong'] = io.input_bool(
-                "Enable ping-pong? (y/n ?:help skip:%s) : " % self.options.get('ping_pong', False),
+                "Enable ping-pong? (y/n ?:help skip:%s) : " % yn_str[self.options.get('ping_pong', False)],
                 self.options.get('ping_pong', False),
                 help_message="Cycles batch size between 1 and chosen batch size, simulating super convergence")
             self.options['paddle'] = self.options.get('paddle','ping')
             if self.options.get('ping_pong',False):
                 self.options['ping_pong_iter'] = max(0, io.input_int("Ping-pong iteration (skip:1000/default) : ", 1000))
+            else:
+                self.options['batch_size'] = self.options.get('batch_cap', 1)
         else:
-            self.options['batch_cap'] = self.options.get('batch_cap', 16)
+            self.options['batch_cap'] = self.options.get('batch_cap', 1)
             self.options['ping_pong'] = self.options.get('ping_pong', False)
             self.options['ping_pong_iter'] = self.options.get('ping_pong_iter',1000)

@@ -197,13 +199,13 @@ class ModelBase(object):
             self.options.pop('target_iter')

         self.batch_size = self.options.get('batch_size', 8)
-        self.batch_cap = self.options.get('batch_cap',16)
+        self.batch_cap = self.options.get('batch_cap',1)
         self.ping_pong_iter = self.options.get('ping_pong_iter',1000)
         self.sort_by_yaw = self.options.get('sort_by_yaw',False)
         self.random_flip = self.options.get('random_flip',True)
         if self.batch_cap == 0:
             self.options['batch_cap'] = self.batch_size
-            self.batch_cap = self.options.get('batch_cap',16)
+            self.batch_cap = self.options.get('batch_cap',1)

         self.src_scale_mod = self.options.get('src_scale_mod',0)
         if self.src_scale_mod == 0 and 'src_scale_mod' in self.options:
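
For context, the help text in the patch describes ping-pong as cycling the batch size between 1 and the chosen cap over 'ping_pong_iter' iterations. The patch only stores the 'batch_cap', 'ping_pong', 'paddle', and 'ping_pong_iter' options; the scheduling itself happens elsewhere. The sketch below is not code from this repository, just a minimal illustration, under assumptions, of how such a cycle could be derived from those options (the function name and the triangle-wave shape are assumptions, not the fork's actual scheduler).

    # Minimal sketch (not part of the patch): one possible way to turn
    # 'batch_cap' and 'ping_pong_iter' into a batch size that cycles
    # between 1 and the cap, as the help message above describes.
    def ping_pong_batch_size(iteration, batch_cap, ping_pong_iter=1000):
        if batch_cap <= 1:
            return max(1, batch_cap)
        period = 2 * ping_pong_iter              # one full up-and-down cycle
        phase = iteration % period
        if phase < ping_pong_iter:               # ramp up: 1 -> batch_cap
            frac = phase / ping_pong_iter
        else:                                    # ramp down: batch_cap -> 1
            frac = (period - phase) / ping_pong_iter
        return 1 + round(frac * (batch_cap - 1))

    # Example: with batch_cap=8 and ping_pong_iter=1000, iteration 0 gives 1,
    # iteration 1000 gives 8, and iteration 2000 is back down to 1.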