RecycleGAN fixes

This commit is contained in:
iperov 2018-12-28 23:13:29 +04:00
parent 1f86d7f1dd
commit 474fff248f
3 changed files with 25 additions and 20 deletions

View file

@ -92,7 +92,7 @@ class ModelBase(object):
if not isinstance(generator, SampleGeneratorBase): if not isinstance(generator, SampleGeneratorBase):
raise Exception('training data generator is not subclass of SampleGeneratorBase') raise Exception('training data generator is not subclass of SampleGeneratorBase')
if self.sample_for_preview is None: if (self.sample_for_preview is None) or (self.epoch == 0):
self.sample_for_preview = self.generate_next_sample() self.sample_for_preview = self.generate_next_sample()
print ("===== Model summary =====") print ("===== Model summary =====")

View file

@ -20,10 +20,7 @@ class Model(ModelBase):
#override #override
def onInitialize(self, batch_size=-1, **in_options): def onInitialize(self, batch_size=-1, **in_options):
exec(nnlib.code_import_all, locals(), globals()) exec(nnlib.code_import_all, locals(), globals())
self.set_vram_batch_requirements( {6:6} )
created_batch_size = self.get_batch_size() created_batch_size = self.get_batch_size()
if self.epoch == 0: if self.epoch == 0:
#first run #first run
@ -31,13 +28,22 @@ class Model(ModelBase):
print ("\nModel first run. Enter options.") print ("\nModel first run. Enter options.")
try: try:
input_created_batch_size = int ( input ("Batch_size (default - based on VRAM) : ") ) created_resolution = int ( input ("Resolution (default:64, valid: 64,128,256) : ") )
except: except:
input_created_batch_size = 0 created_resolution = 64
if input_created_batch_size != 0: if created_resolution not in [64,128,256]:
created_batch_size = input_created_batch_size created_resolution = 64
try:
created_batch_size = int ( input ("Batch_size (minimum/default - 16) : ") )
except:
created_batch_size = 16
created_batch_size = max(created_batch_size,1)
print ("Done. If training won't start, decrease resolution")
self.options['created_resolution'] = created_resolution
self.options['created_batch_size'] = created_batch_size self.options['created_batch_size'] = created_batch_size
self.created_vram_gb = self.device_config.gpu_total_vram_gb self.created_vram_gb = self.device_config.gpu_total_vram_gb
else: else:
@ -45,9 +51,14 @@ class Model(ModelBase):
if 'created_batch_size' in self.options.keys(): if 'created_batch_size' in self.options.keys():
created_batch_size = self.options['created_batch_size'] created_batch_size = self.options['created_batch_size']
else: else:
raise Exception("Continue traning, but created_batch_size not found.") raise Exception("Continue training, but created_batch_size not found.")
if 'created_resolution' in self.options.keys():
created_resolution = self.options['created_resolution']
else:
raise Exception("Continue training, but created_resolution not found.")
resolution = 128 resolution = created_resolution
bgr_shape = (resolution, resolution, 3) bgr_shape = (resolution, resolution, 3)
ngf = 64 ngf = 64
npf = 64 npf = 64
@ -181,10 +192,6 @@ class Model(ModelBase):
output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ), output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),
]) ])
#import code
#code.interact(local=dict(globals(), **locals()))
self.supress_std_once = False
#override #override
def onSave(self): def onSave(self):
self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)], self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)],
@ -237,7 +244,5 @@ class Model(ModelBase):
def get_converter(self, **in_options): def get_converter(self, **in_options):
from models import ConverterImage from models import ConverterImage
return ConverterImage(self.predictor_func, predictor_input_size=128, output_size=128, **in_options) return ConverterImage(self.predictor_func, predictor_input_size=self.options['created_resolution'], output_size=self.options['created_resolution'], **in_options)

View file

@ -551,8 +551,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
for i in range(n_blocks): for i in range(n_blocks):
x = ResnetBlock(ngf*4)(x) x = ResnetBlock(ngf*4)(x)
x = ReLU()(XNormalization(Conv2DTranspose(ngf*2, 3, 2, 'same')(x))) x = ReLU()(XNormalization(PixelShuffler()(Conv2D(ngf*2 *4, 3, 1, 'same')(x))))
x = ReLU()(XNormalization(Conv2DTranspose(ngf , 3, 2, 'same')(x))) x = ReLU()(XNormalization(PixelShuffler()(Conv2D(ngf *4, 3, 1, 'same')(x))))
x = ReflectionPadding2D((3,3))(x) x = ReflectionPadding2D((3,3))(x)
x = Conv2D(output_nc, 7, 1, 'valid')(x) x = Conv2D(output_nc, 7, 1, 'valid')(x)