diff --git a/models/ModelBase.py b/models/ModelBase.py index c6d16d3..639d6a7 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -526,7 +526,7 @@ class ModelBase(object): if self.iter == 1 and self.options.get('ping_pong', True): self.set_batch_size(1) self.paddle = 'ping' - elif not self.options.get('ping_pong', True): + elif not self.options.get('ping_pong', True) and self.batch_cap != self.batch_size: self.set_batch_size(self.batch_cap) sample = self.generate_next_sample() iter_time = time.time()