mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
Added new features
- New cli argument --auto-gen-config It allows to create and/or write a yaml conf file inside model folder. The conf file will be called with the name of the model. - Bug fixes - Code refactoring
This commit is contained in:
parent
57cdf0e794
commit
8976ae3863
3 changed files with 56 additions and 19 deletions
4
main.py
4
main.py
|
@ -131,7 +131,8 @@ if __name__ == "__main__":
|
||||||
'start_tensorboard' : arguments.start_tensorboard,
|
'start_tensorboard' : arguments.start_tensorboard,
|
||||||
'dump_ckpt' : arguments.dump_ckpt,
|
'dump_ckpt' : arguments.dump_ckpt,
|
||||||
'flask_preview' : arguments.flask_preview,
|
'flask_preview' : arguments.flask_preview,
|
||||||
'config_training_file' : arguments.config_training_file
|
'config_training_file' : arguments.config_training_file,
|
||||||
|
'auto_gen_config' : arguments.auto_gen_config
|
||||||
}
|
}
|
||||||
from mainscripts import Trainer
|
from mainscripts import Trainer
|
||||||
Trainer.main(**kwargs)
|
Trainer.main(**kwargs)
|
||||||
|
@ -152,6 +153,7 @@ if __name__ == "__main__":
|
||||||
p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files")
|
p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files")
|
||||||
p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir")
|
p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir")
|
||||||
p.add_argument('--config-training-file', action=fixPathAction, dest="config_training_file", help="Path to custom yaml configuration file")
|
p.add_argument('--config-training-file', action=fixPathAction, dest="config_training_file", help="Path to custom yaml configuration file")
|
||||||
|
p.add_argument('--auto-gen-config', action="store_true", dest="auto_gen_config", default=False, help="Saves a configuration file for each model used in the trainer. It'll have the same model name")
|
||||||
|
|
||||||
|
|
||||||
p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")
|
p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")
|
||||||
|
|
|
@ -103,6 +103,7 @@ def trainerThread (s2c, c2s, e,
|
||||||
cpu_only=cpu_only,
|
cpu_only=cpu_only,
|
||||||
silent_start=silent_start,
|
silent_start=silent_start,
|
||||||
config_training_file=config_training_file,
|
config_training_file=config_training_file,
|
||||||
|
auto_gen_config=kwargs.get("auto_gen_config", False),
|
||||||
debug=debug)
|
debug=debug)
|
||||||
|
|
||||||
is_reached_goal = model.is_reached_iter_goal()
|
is_reached_goal = model.is_reached_iter_goal()
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
import colorsys
|
import colorsys
|
||||||
import inspect
|
import inspect
|
||||||
|
from io import FileIO
|
||||||
import json
|
import json
|
||||||
import multiprocessing
|
import multiprocessing
|
||||||
import operator
|
import operator
|
||||||
|
@ -37,6 +38,7 @@ class ModelBase(object):
|
||||||
debug=False,
|
debug=False,
|
||||||
force_model_class_name=None,
|
force_model_class_name=None,
|
||||||
config_training_file=None,
|
config_training_file=None,
|
||||||
|
auto_gen_config=False,
|
||||||
silent_start=False,
|
silent_start=False,
|
||||||
**kwargs):
|
**kwargs):
|
||||||
self.is_training = is_training
|
self.is_training = is_training
|
||||||
|
@ -46,6 +48,8 @@ class ModelBase(object):
|
||||||
self.training_data_dst_path = training_data_dst_path
|
self.training_data_dst_path = training_data_dst_path
|
||||||
self.pretraining_data_path = pretraining_data_path
|
self.pretraining_data_path = pretraining_data_path
|
||||||
self.pretrained_model_path = pretrained_model_path
|
self.pretrained_model_path = pretrained_model_path
|
||||||
|
self.config_training_file = config_training_file
|
||||||
|
self.auto_gen_config = auto_gen_config
|
||||||
self.no_preview = no_preview
|
self.no_preview = no_preview
|
||||||
self.debug = debug
|
self.debug = debug
|
||||||
|
|
||||||
|
@ -145,27 +149,36 @@ class ModelBase(object):
|
||||||
model_data = {}
|
model_data = {}
|
||||||
# True if yaml conf file exists
|
# True if yaml conf file exists
|
||||||
self.config_file_exists = False
|
self.config_file_exists = False
|
||||||
# True if user chooses to read options from conf file
|
# True if user chooses to read options external or internal conf file
|
||||||
self.read_from_conf = False
|
self.read_from_conf = False
|
||||||
#check if config_training_file mode is enabled
|
#check if config_training_file mode is enabled
|
||||||
if config_training_file is not None:
|
if config_training_file is not None:
|
||||||
self.config_file_path = Path(config_training_file)
|
self.config_file_path = Path(config_training_file)
|
||||||
|
# Creates folder if folder doesn't exist
|
||||||
if not self.config_file_path.exists():
|
if not self.config_file_path.exists():
|
||||||
os.mkdir(self.config_file_path)
|
os.makedirs(self.config_file_path, exist_ok=True)
|
||||||
if Path(self.get_strpath_configuration_path()).exists():
|
# Ask if user wants to read options from external or internal conf file only if external conf file exists
|
||||||
|
# or auto_gen_config is true
|
||||||
|
if Path(self.get_strpath_configuration_path()).exists() or self.auto_gen_config:
|
||||||
self.read_from_conf = io.input_bool(
|
self.read_from_conf = io.input_bool(
|
||||||
f'Do you want to read training options from {self.config_file_path.stem} file?',
|
f'Do you want to read training options from {"external" if self.auto_gen_config else "internal"} file?',
|
||||||
True,
|
True,
|
||||||
'Read options from configuration file instead of asking one by one each option'
|
'Read options from configuration file instead of asking one by one each option'
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# If user decides to read from external or internal conf file
|
||||||
if self.read_from_conf:
|
if self.read_from_conf:
|
||||||
self.options = self.read_from_config_file()
|
# Try to read dictionary from external of internal yaml file according
|
||||||
|
# to the value of auto_gen_config
|
||||||
|
self.options = self.read_from_config_file(auto_gen=self.auto_gen_config)
|
||||||
|
# If options dict is empty options will be loaded from dat file
|
||||||
|
if not self.options.keys():
|
||||||
|
io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.")
|
||||||
|
else:
|
||||||
self.config_file_exists = True
|
self.config_file_exists = True
|
||||||
else:
|
else:
|
||||||
io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.")
|
io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.")
|
||||||
|
|
||||||
|
|
||||||
self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
|
self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
|
||||||
if self.model_data_path.exists():
|
if self.model_data_path.exists():
|
||||||
io.log_info (f"Loading {self.model_name} model...")
|
io.log_info (f"Loading {self.model_name} model...")
|
||||||
|
@ -210,9 +223,6 @@ class ModelBase(object):
|
||||||
if self.is_first_run():
|
if self.is_first_run():
|
||||||
# save as default options only for first run model initialize
|
# save as default options only for first run model initialize
|
||||||
self.default_options_path.write_bytes( pickle.dumps (self.options) )
|
self.default_options_path.write_bytes( pickle.dumps (self.options) )
|
||||||
# save config file
|
|
||||||
if config_training_file is not None:
|
|
||||||
self.save_config_file()
|
|
||||||
|
|
||||||
self.session_name = self.options.get('session_name', "")
|
self.session_name = self.options.get('session_name', "")
|
||||||
self.autobackup_hour = self.options.get('autobackup_hour', 0)
|
self.autobackup_hour = self.options.get('autobackup_hour', 0)
|
||||||
|
@ -453,6 +463,10 @@ class ModelBase(object):
|
||||||
}
|
}
|
||||||
pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) )
|
pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) )
|
||||||
|
|
||||||
|
# save config file
|
||||||
|
if self.config_training_file is not None:
|
||||||
|
self.save_config_file(self.auto_gen_config)
|
||||||
|
|
||||||
if self.autobackup_hour != 0:
|
if self.autobackup_hour != 0:
|
||||||
diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 )
|
diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 )
|
||||||
|
|
||||||
|
@ -460,14 +474,23 @@ class ModelBase(object):
|
||||||
self.autobackup_start_time += self.autobackup_hour*3600
|
self.autobackup_start_time += self.autobackup_hour*3600
|
||||||
self.create_backup()
|
self.create_backup()
|
||||||
|
|
||||||
def read_from_config_file(self):
|
def read_from_config_file(self, auto_gen=False):
|
||||||
"""
|
"""
|
||||||
Read yaml config file and saves it into a dictionary
|
Read yaml config file and saves it into a dictionary
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_gen (bool, optional): True if you want that a yaml file is readed from model folder. Defaults to False.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
[type]: [description]
|
[dict]: Returns the options dictionary if everything is alright otherwise an empty dictionary.
|
||||||
"""
|
"""
|
||||||
with open(self.get_strpath_configuration_path(), 'r') as file:
|
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(fun(), 'r') as file:
|
||||||
data = yaml.safe_load(file)
|
data = yaml.safe_load(file)
|
||||||
|
except FileNotFoundError:
|
||||||
|
return {}
|
||||||
|
|
||||||
for key, value in data.items():
|
for key, value in data.items():
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
|
@ -479,9 +502,12 @@ class ModelBase(object):
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def save_config_file(self):
|
def save_config_file(self, auto_gen=False):
|
||||||
"""
|
"""
|
||||||
Saves options dictionary in a yaml file
|
Saves options dictionary in a yaml file.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
auto_gen ([bool], optional): True if you want that a yaml file is generated inside model folder for each model. Defaults to None.
|
||||||
"""
|
"""
|
||||||
saving_dict = {}
|
saving_dict = {}
|
||||||
for key, value in self.options.items():
|
for key, value in self.options.items():
|
||||||
|
@ -490,8 +516,13 @@ class ModelBase(object):
|
||||||
else:
|
else:
|
||||||
saving_dict[key] = value
|
saving_dict[key] = value
|
||||||
|
|
||||||
with open(self.get_strpath_configuration_path(), 'w') as file:
|
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(fun(), 'w') as file:
|
||||||
yaml.dump(saving_dict, file, sort_keys=False)
|
yaml.dump(saving_dict, file, sort_keys=False)
|
||||||
|
except OSError as exception:
|
||||||
|
print('Impossible to write YAML configuration file -> ', exception)
|
||||||
|
|
||||||
def create_backup(self):
|
def create_backup(self):
|
||||||
io.log_info ("Creating backup...", end='\r')
|
io.log_info ("Creating backup...", end='\r')
|
||||||
|
@ -631,6 +662,9 @@ class ModelBase(object):
|
||||||
def get_summary_path(self):
|
def get_summary_path(self):
|
||||||
return self.get_strpath_storage_for_file('summary.txt')
|
return self.get_strpath_storage_for_file('summary.txt')
|
||||||
|
|
||||||
|
def get_model_conf_path(self):
|
||||||
|
return self.get_strpath_storage_for_file('configuration_file.yaml')
|
||||||
|
|
||||||
def get_summary_text(self):
|
def get_summary_text(self):
|
||||||
visible_options = self.options.copy()
|
visible_options = self.options.copy()
|
||||||
visible_options.update(self.options_show_override)
|
visible_options.update(self.options_show_override)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue