diff --git a/models/ModelBase.py b/models/ModelBase.py index 25fa9fe..70ad484 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -197,6 +197,8 @@ class ModelBase(object): self.batch_cap = self.options.get('batch_cap',16) self.sort_by_yaw = self.options.get('sort_by_yaw', False) self.random_flip = self.options.get('random_flip', True) + if self.batch_cap == 0: + self.batch_cap = self.options['batch_cap'] = self.batch_size self.src_scale_mod = self.options.get('src_scale_mod', 0) if self.src_scale_mod == 0 and 'src_scale_mod' in self.options: @@ -553,6 +555,9 @@ class ModelBase(object): if self.iter % 1000 == 0 and self.iter != 0 and self.options.get('ping_pong', True): if self.batch_size == self.batch_cap: self.paddle = 'pong' + if self.batch_size > self.batch_cap: + self.set_batch_size(self.batch_cap) + self.paddle = 'pong' if self.batch_size == 1: self.paddle = 'ping' if self.paddle == 'ping':