fix DFLJPG,

SAE: added "rare sample booster"
SAE: pixel loss replaced to smooth transition from DSSIM to PixelLoss in 15k epochs by default
This commit is contained in:
iperov 2019-02-09 18:53:37 +04:00
parent f93b4713a9
commit 4d37fd62cd
11 changed files with 174 additions and 101 deletions

View file

@ -199,7 +199,7 @@ class ModelBase(object):
pass
#overridable
def onTrainOneEpoch(self, sample):
def onTrainOneEpoch(self, sample, generator_list):
#train your keras models here
#return array of losses
@ -293,7 +293,8 @@ class ModelBase(object):
images = []
for generator in self.generator_list:
for i,batch in enumerate(next(generator)):
images.append( batch[0] )
if len(batch.shape) == 4:
images.append( batch[0] )
return image_utils.equalize_and_stack_square (images)
@ -305,14 +306,12 @@ class ModelBase(object):
supressor = std_utils.suppress_stdout_stderr()
supressor.__enter__()
self.last_sample = self.generate_next_sample()
epoch_time = time.time()
losses = self.onTrainOneEpoch(self.last_sample)
sample = self.generate_next_sample()
epoch_time = time.time()
losses = self.onTrainOneEpoch(sample, self.generator_list)
epoch_time = time.time() - epoch_time
self.last_sample = sample
self.loss_history.append ( [float(loss[1]) for loss in losses] )
if self.supress_std_once: