RecycleGAN fixes

This commit is contained in:
iperov 2018-12-28 23:13:29 +04:00
parent 1f86d7f1dd
commit 474fff248f
3 changed files with 25 additions and 20 deletions

View file

@ -20,10 +20,7 @@ class Model(ModelBase):
#override
def onInitialize(self, batch_size=-1, **in_options):
exec(nnlib.code_import_all, locals(), globals())
self.set_vram_batch_requirements( {6:6} )
created_batch_size = self.get_batch_size()
if self.epoch == 0:
#first run
@ -31,13 +28,22 @@ class Model(ModelBase):
print ("\nModel first run. Enter options.")
try:
input_created_batch_size = int ( input ("Batch_size (default - based on VRAM) : ") )
created_resolution = int ( input ("Resolution (default:64, valid: 64,128,256) : ") )
except:
input_created_batch_size = 0
created_resolution = 64
if input_created_batch_size != 0:
created_batch_size = input_created_batch_size
if created_resolution not in [64,128,256]:
created_resolution = 64
try:
created_batch_size = int ( input ("Batch_size (minimum/default - 16) : ") )
except:
created_batch_size = 16
created_batch_size = max(created_batch_size,1)
print ("Done. If training won't start, decrease resolution")
self.options['created_resolution'] = created_resolution
self.options['created_batch_size'] = created_batch_size
self.created_vram_gb = self.device_config.gpu_total_vram_gb
else:
@ -45,9 +51,14 @@ class Model(ModelBase):
if 'created_batch_size' in self.options.keys():
created_batch_size = self.options['created_batch_size']
else:
raise Exception("Continue traning, but created_batch_size not found.")
raise Exception("Continue training, but created_batch_size not found.")
if 'created_resolution' in self.options.keys():
created_resolution = self.options['created_resolution']
else:
raise Exception("Continue training, but created_resolution not found.")
resolution = 128
resolution = created_resolution
bgr_shape = (resolution, resolution, 3)
ngf = 64
npf = 64
@ -181,10 +192,6 @@ class Model(ModelBase):
output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),
])
#import code
#code.interact(local=dict(globals(), **locals()))
self.supress_std_once = False
#override
def onSave(self):
self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)],
@ -237,7 +244,5 @@ class Model(ModelBase):
def get_converter(self, **in_options):
from models import ConverterImage
return ConverterImage(self.predictor_func, predictor_input_size=128, output_size=128, **in_options)
return ConverterImage(self.predictor_func, predictor_input_size=self.options['created_resolution'], output_size=self.options['created_resolution'], **in_options)