diff --git a/models/ModelBase.py b/models/ModelBase.py index e4136b2..5403a48 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -92,7 +92,7 @@ class ModelBase(object): if not isinstance(generator, SampleGeneratorBase): raise Exception('training data generator is not subclass of SampleGeneratorBase') - if self.sample_for_preview is None: + if (self.sample_for_preview is None) or (self.epoch == 0): self.sample_for_preview = self.generate_next_sample() print ("===== Model summary =====") diff --git a/models/Model_RecycleGAN/Model.py b/models/Model_RecycleGAN/Model.py index 2073531..8fcb9ef 100644 --- a/models/Model_RecycleGAN/Model.py +++ b/models/Model_RecycleGAN/Model.py @@ -20,10 +20,7 @@ class Model(ModelBase): #override def onInitialize(self, batch_size=-1, **in_options): exec(nnlib.code_import_all, locals(), globals()) - self.set_vram_batch_requirements( {6:6} ) - - - + created_batch_size = self.get_batch_size() if self.epoch == 0: #first run @@ -31,13 +28,22 @@ class Model(ModelBase): print ("\nModel first run. Enter options.") try: - input_created_batch_size = int ( input ("Batch_size (default - based on VRAM) : ") ) + created_resolution = int ( input ("Resolution (default:64, valid: 64,128,256) : ") ) except: - input_created_batch_size = 0 + created_resolution = 64 - if input_created_batch_size != 0: - created_batch_size = input_created_batch_size + if created_resolution not in [64,128,256]: + created_resolution = 64 + + try: + created_batch_size = int ( input ("Batch_size (minimum/default - 16) : ") ) + except: + created_batch_size = 16 + created_batch_size = max(created_batch_size,1) + print ("Done. If training won't start, decrease resolution") + + self.options['created_resolution'] = created_resolution self.options['created_batch_size'] = created_batch_size self.created_vram_gb = self.device_config.gpu_total_vram_gb else: @@ -45,9 +51,14 @@ class Model(ModelBase): if 'created_batch_size' in self.options.keys(): created_batch_size = self.options['created_batch_size'] else: - raise Exception("Continue traning, but created_batch_size not found.") + raise Exception("Continue training, but created_batch_size not found.") + + if 'created_resolution' in self.options.keys(): + created_resolution = self.options['created_resolution'] + else: + raise Exception("Continue training, but created_resolution not found.") - resolution = 128 + resolution = created_resolution bgr_shape = (resolution, resolution, 3) ngf = 64 npf = 64 @@ -181,10 +192,6 @@ class Model(ModelBase): output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ), ]) - #import code - #code.interact(local=dict(globals(), **locals())) - self.supress_std_once = False - #override def onSave(self): self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)], @@ -237,7 +244,5 @@ class Model(ModelBase): def get_converter(self, **in_options): from models import ConverterImage - return ConverterImage(self.predictor_func, predictor_input_size=128, output_size=128, **in_options) + return ConverterImage(self.predictor_func, predictor_input_size=self.options['created_resolution'], output_size=self.options['created_resolution'], **in_options) - - \ No newline at end of file diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py index 6843457..90a9038 100644 --- a/nnlib/nnlib.py +++ b/nnlib/nnlib.py @@ -551,8 +551,8 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator for i in range(n_blocks): x = ResnetBlock(ngf*4)(x) - x = ReLU()(XNormalization(Conv2DTranspose(ngf*2, 3, 2, 'same')(x))) - x = ReLU()(XNormalization(Conv2DTranspose(ngf , 3, 2, 'same')(x))) + x = ReLU()(XNormalization(PixelShuffler()(Conv2D(ngf*2 *4, 3, 1, 'same')(x)))) + x = ReLU()(XNormalization(PixelShuffler()(Conv2D(ngf *4, 3, 1, 'same')(x)))) x = ReflectionPadding2D((3,3))(x) x = Conv2D(output_nc, 7, 1, 'valid')(x)