mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-20 05:23:22 -07:00
Bug fixes
This commit is contained in:
parent
d0b5090879
commit
f1ddbb05de
2 changed files with 24 additions and 19 deletions
|
@ -150,7 +150,9 @@ class ModelBase(object):
|
||||||
#check if config_training_file mode is enabled
|
#check if config_training_file mode is enabled
|
||||||
if config_training_file is not None:
|
if config_training_file is not None:
|
||||||
self.config_file_path = Path(config_training_file)
|
self.config_file_path = Path(config_training_file)
|
||||||
if self.config_file_path.exists():
|
if not self.config_file_path.exists():
|
||||||
|
os.mkdir(self.config_file_path)
|
||||||
|
if Path(self.get_strpath_configuration_path()).exists():
|
||||||
self.read_from_conf = io.input_bool(
|
self.read_from_conf = io.input_bool(
|
||||||
f'Do you want to read training options from {self.config_file_path.stem} file?',
|
f'Do you want to read training options from {self.config_file_path.stem} file?',
|
||||||
False,
|
False,
|
||||||
|
@ -463,7 +465,7 @@ class ModelBase(object):
|
||||||
Returns:
|
Returns:
|
||||||
[type]: [description]
|
[type]: [description]
|
||||||
"""
|
"""
|
||||||
with open(self.config_file_path, 'r') as file:
|
with open(self.get_strpath_configuration_path(), 'r') as file:
|
||||||
data = yaml.safe_load(file)
|
data = yaml.safe_load(file)
|
||||||
|
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
|
@ -487,7 +489,7 @@ class ModelBase(object):
|
||||||
else:
|
else:
|
||||||
saving_dict[key] = value
|
saving_dict[key] = value
|
||||||
|
|
||||||
with open(self.config_file_path, 'w') as file:
|
with open(self.get_strpath_configuration_path(), 'w') as file:
|
||||||
yaml.dump(saving_dict, file, sort_keys=False)
|
yaml.dump(saving_dict, file, sort_keys=False)
|
||||||
|
|
||||||
def create_backup(self):
|
def create_backup(self):
|
||||||
|
@ -622,6 +624,9 @@ class ModelBase(object):
|
||||||
def get_strpath_storage_for_file(self, filename):
|
def get_strpath_storage_for_file(self, filename):
|
||||||
return str( self.saved_models_path / ( self.get_model_name() + '_' + filename) )
|
return str( self.saved_models_path / ( self.get_model_name() + '_' + filename) )
|
||||||
|
|
||||||
|
def get_strpath_configuration_path(self):
|
||||||
|
return str(self.config_file_path / 'configuration_file.yaml')
|
||||||
|
|
||||||
def get_summary_path(self):
|
def get_summary_path(self):
|
||||||
return self.get_strpath_storage_for_file('summary.txt')
|
return self.get_strpath_storage_for_file('summary.txt')
|
||||||
|
|
||||||
|
|
|
@ -28,7 +28,7 @@ class SAEHDModel(ModelBase):
|
||||||
min_res = 64
|
min_res = 64
|
||||||
max_res = 640
|
max_res = 640
|
||||||
|
|
||||||
default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
|
#default_usefp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
|
||||||
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
|
default_resolution = self.options['resolution'] = self.load_or_def_option('resolution', 128)
|
||||||
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
|
default_face_type = self.options['face_type'] = self.load_or_def_option('face_type', 'f')
|
||||||
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
|
default_models_opt_on_gpu = self.options['models_opt_on_gpu'] = self.load_or_def_option('models_opt_on_gpu', True)
|
||||||
|
@ -68,7 +68,7 @@ class SAEHDModel(ModelBase):
|
||||||
default_random_color = self.options['random_color'] = self.load_or_def_option('random_color', False)
|
default_random_color = self.options['random_color'] = self.load_or_def_option('random_color', False)
|
||||||
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
|
default_clipgrad = self.options['clipgrad'] = self.load_or_def_option('clipgrad', False)
|
||||||
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
|
default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False)
|
||||||
default_use_fp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
|
#default_use_fp16 = self.options['use_fp16'] = self.load_or_def_option('use_fp16', False)
|
||||||
|
|
||||||
ask_override = False if self.read_from_conf else self.ask_override()
|
ask_override = False if self.read_from_conf else self.ask_override()
|
||||||
if self.is_first_run() or ask_override:
|
if self.is_first_run() or ask_override:
|
||||||
|
@ -82,7 +82,7 @@ class SAEHDModel(ModelBase):
|
||||||
self.ask_random_src_flip()
|
self.ask_random_src_flip()
|
||||||
self.ask_random_dst_flip()
|
self.ask_random_dst_flip()
|
||||||
self.ask_batch_size(suggest_batch_size)
|
self.ask_batch_size(suggest_batch_size)
|
||||||
self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')
|
#self.options['use_fp16'] = io.input_bool ("Use fp16", default_usefp16, help_message='Increases training/inference speed, reduces model size. Model may crash. Enable it after 1-5k iters.')
|
||||||
|
|
||||||
if self.is_first_run():
|
if self.is_first_run():
|
||||||
if (self.read_from_conf and not self.config_file_exists) or not self.read_from_conf:
|
if (self.read_from_conf and not self.config_file_exists) or not self.read_from_conf:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue