mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 05:22:06 -07:00
RecycleGAN fixes
This commit is contained in:
parent
1f86d7f1dd
commit
474fff248f
3 changed files with 25 additions and 20 deletions
|
@ -20,10 +20,7 @@ class Model(ModelBase):
|
|||
#override
|
||||
def onInitialize(self, batch_size=-1, **in_options):
|
||||
exec(nnlib.code_import_all, locals(), globals())
|
||||
self.set_vram_batch_requirements( {6:6} )
|
||||
|
||||
|
||||
|
||||
|
||||
created_batch_size = self.get_batch_size()
|
||||
if self.epoch == 0:
|
||||
#first run
|
||||
|
@ -31,13 +28,22 @@ class Model(ModelBase):
|
|||
print ("\nModel first run. Enter options.")
|
||||
|
||||
try:
|
||||
input_created_batch_size = int ( input ("Batch_size (default - based on VRAM) : ") )
|
||||
created_resolution = int ( input ("Resolution (default:64, valid: 64,128,256) : ") )
|
||||
except:
|
||||
input_created_batch_size = 0
|
||||
created_resolution = 64
|
||||
|
||||
if input_created_batch_size != 0:
|
||||
created_batch_size = input_created_batch_size
|
||||
if created_resolution not in [64,128,256]:
|
||||
created_resolution = 64
|
||||
|
||||
try:
|
||||
created_batch_size = int ( input ("Batch_size (minimum/default - 16) : ") )
|
||||
except:
|
||||
created_batch_size = 16
|
||||
created_batch_size = max(created_batch_size,1)
|
||||
|
||||
print ("Done. If training won't start, decrease resolution")
|
||||
|
||||
self.options['created_resolution'] = created_resolution
|
||||
self.options['created_batch_size'] = created_batch_size
|
||||
self.created_vram_gb = self.device_config.gpu_total_vram_gb
|
||||
else:
|
||||
|
@ -45,9 +51,14 @@ class Model(ModelBase):
|
|||
if 'created_batch_size' in self.options.keys():
|
||||
created_batch_size = self.options['created_batch_size']
|
||||
else:
|
||||
raise Exception("Continue traning, but created_batch_size not found.")
|
||||
raise Exception("Continue training, but created_batch_size not found.")
|
||||
|
||||
if 'created_resolution' in self.options.keys():
|
||||
created_resolution = self.options['created_resolution']
|
||||
else:
|
||||
raise Exception("Continue training, but created_resolution not found.")
|
||||
|
||||
resolution = 128
|
||||
resolution = created_resolution
|
||||
bgr_shape = (resolution, resolution, 3)
|
||||
ngf = 64
|
||||
npf = 64
|
||||
|
@ -181,10 +192,6 @@ class Model(ModelBase):
|
|||
output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),
|
||||
])
|
||||
|
||||
#import code
|
||||
#code.interact(local=dict(globals(), **locals()))
|
||||
self.supress_std_once = False
|
||||
|
||||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)],
|
||||
|
@ -237,7 +244,5 @@ class Model(ModelBase):
|
|||
def get_converter(self, **in_options):
|
||||
from models import ConverterImage
|
||||
|
||||
return ConverterImage(self.predictor_func, predictor_input_size=128, output_size=128, **in_options)
|
||||
return ConverterImage(self.predictor_func, predictor_input_size=self.options['created_resolution'], output_size=self.options['created_resolution'], **in_options)
|
||||
|
||||
|
||||
|
Loading…
Add table
Add a link
Reference in a new issue