Added new features

- New cli argument --auto-gen-config
	It allows to create and/or write a yaml conf file inside model
	folder. The conf file will be called with the name of the model.
- Bug fixes
- Code refactoring
This commit is contained in:
Cioscos 2021-12-04 15:31:15 +01:00
commit 8976ae3863
3 changed files with 56 additions and 19 deletions

View file

@ -131,7 +131,8 @@ if __name__ == "__main__":
'start_tensorboard' : arguments.start_tensorboard, 'start_tensorboard' : arguments.start_tensorboard,
'dump_ckpt' : arguments.dump_ckpt, 'dump_ckpt' : arguments.dump_ckpt,
'flask_preview' : arguments.flask_preview, 'flask_preview' : arguments.flask_preview,
'config_training_file' : arguments.config_training_file 'config_training_file' : arguments.config_training_file,
'auto_gen_config' : arguments.auto_gen_config
} }
from mainscripts import Trainer from mainscripts import Trainer
Trainer.main(**kwargs) Trainer.main(**kwargs)
@ -152,6 +153,7 @@ if __name__ == "__main__":
p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files") p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files")
p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir") p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir")
p.add_argument('--config-training-file', action=fixPathAction, dest="config_training_file", help="Path to custom yaml configuration file") p.add_argument('--config-training-file', action=fixPathAction, dest="config_training_file", help="Path to custom yaml configuration file")
p.add_argument('--auto-gen-config', action="store_true", dest="auto_gen_config", default=False, help="Saves a configuration file for each model used in the trainer. It'll have the same model name")
p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.") p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.")

View file

@ -103,6 +103,7 @@ def trainerThread (s2c, c2s, e,
cpu_only=cpu_only, cpu_only=cpu_only,
silent_start=silent_start, silent_start=silent_start,
config_training_file=config_training_file, config_training_file=config_training_file,
auto_gen_config=kwargs.get("auto_gen_config", False),
debug=debug) debug=debug)
is_reached_goal = model.is_reached_iter_goal() is_reached_goal = model.is_reached_iter_goal()

View file

@ -1,5 +1,6 @@
import colorsys import colorsys
import inspect import inspect
from io import FileIO
import json import json
import multiprocessing import multiprocessing
import operator import operator
@ -37,6 +38,7 @@ class ModelBase(object):
debug=False, debug=False,
force_model_class_name=None, force_model_class_name=None,
config_training_file=None, config_training_file=None,
auto_gen_config=False,
silent_start=False, silent_start=False,
**kwargs): **kwargs):
self.is_training = is_training self.is_training = is_training
@ -46,6 +48,8 @@ class ModelBase(object):
self.training_data_dst_path = training_data_dst_path self.training_data_dst_path = training_data_dst_path
self.pretraining_data_path = pretraining_data_path self.pretraining_data_path = pretraining_data_path
self.pretrained_model_path = pretrained_model_path self.pretrained_model_path = pretrained_model_path
self.config_training_file = config_training_file
self.auto_gen_config = auto_gen_config
self.no_preview = no_preview self.no_preview = no_preview
self.debug = debug self.debug = debug
@ -145,27 +149,36 @@ class ModelBase(object):
model_data = {} model_data = {}
# True if yaml conf file exists # True if yaml conf file exists
self.config_file_exists = False self.config_file_exists = False
# True if user chooses to read options from conf file # True if user chooses to read options external or internal conf file
self.read_from_conf = False self.read_from_conf = False
#check if config_training_file mode is enabled #check if config_training_file mode is enabled
if config_training_file is not None: if config_training_file is not None:
self.config_file_path = Path(config_training_file) self.config_file_path = Path(config_training_file)
# Creates folder if folder doesn't exist
if not self.config_file_path.exists(): if not self.config_file_path.exists():
os.mkdir(self.config_file_path) os.makedirs(self.config_file_path, exist_ok=True)
if Path(self.get_strpath_configuration_path()).exists(): # Ask if user wants to read options from external or internal conf file only if external conf file exists
# or auto_gen_config is true
if Path(self.get_strpath_configuration_path()).exists() or self.auto_gen_config:
self.read_from_conf = io.input_bool( self.read_from_conf = io.input_bool(
f'Do you want to read training options from {self.config_file_path.stem} file?', f'Do you want to read training options from {"external" if self.auto_gen_config else "internal"} file?',
True, True,
'Read options from configuration file instead of asking one by one each option' 'Read options from configuration file instead of asking one by one each option'
) )
# If user decides to read from external or internal conf file
if self.read_from_conf: if self.read_from_conf:
self.options = self.read_from_config_file() # Try to read dictionary from external of internal yaml file according
# to the value of auto_gen_config
self.options = self.read_from_config_file(auto_gen=self.auto_gen_config)
# If options dict is empty options will be loaded from dat file
if not self.options.keys():
io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.")
else:
self.config_file_exists = True self.config_file_exists = True
else: else:
io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.") io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.")
self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') ) self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') )
if self.model_data_path.exists(): if self.model_data_path.exists():
io.log_info (f"Loading {self.model_name} model...") io.log_info (f"Loading {self.model_name} model...")
@ -210,9 +223,6 @@ class ModelBase(object):
if self.is_first_run(): if self.is_first_run():
# save as default options only for first run model initialize # save as default options only for first run model initialize
self.default_options_path.write_bytes( pickle.dumps (self.options) ) self.default_options_path.write_bytes( pickle.dumps (self.options) )
# save config file
if config_training_file is not None:
self.save_config_file()
self.session_name = self.options.get('session_name', "") self.session_name = self.options.get('session_name', "")
self.autobackup_hour = self.options.get('autobackup_hour', 0) self.autobackup_hour = self.options.get('autobackup_hour', 0)
@ -453,6 +463,10 @@ class ModelBase(object):
} }
pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) ) pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) )
# save config file
if self.config_training_file is not None:
self.save_config_file(self.auto_gen_config)
if self.autobackup_hour != 0: if self.autobackup_hour != 0:
diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 ) diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 )
@ -460,14 +474,23 @@ class ModelBase(object):
self.autobackup_start_time += self.autobackup_hour*3600 self.autobackup_start_time += self.autobackup_hour*3600
self.create_backup() self.create_backup()
def read_from_config_file(self): def read_from_config_file(self, auto_gen=False):
""" """
Read yaml config file and saves it into a dictionary Read yaml config file and saves it into a dictionary
Args:
auto_gen (bool, optional): True if you want that a yaml file is readed from model folder. Defaults to False.
Returns: Returns:
[type]: [description] [dict]: Returns the options dictionary if everything is alright otherwise an empty dictionary.
""" """
with open(self.get_strpath_configuration_path(), 'r') as file: fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
try:
with open(fun(), 'r') as file:
data = yaml.safe_load(file) data = yaml.safe_load(file)
except FileNotFoundError:
return {}
for key, value in data.items(): for key, value in data.items():
if isinstance(value, bool): if isinstance(value, bool):
@ -479,9 +502,12 @@ class ModelBase(object):
return data return data
def save_config_file(self): def save_config_file(self, auto_gen=False):
""" """
Saves options dictionary in a yaml file Saves options dictionary in a yaml file.
Args:
auto_gen ([bool], optional): True if you want that a yaml file is generated inside model folder for each model. Defaults to None.
""" """
saving_dict = {} saving_dict = {}
for key, value in self.options.items(): for key, value in self.options.items():
@ -490,8 +516,13 @@ class ModelBase(object):
else: else:
saving_dict[key] = value saving_dict[key] = value
with open(self.get_strpath_configuration_path(), 'w') as file: fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
try:
with open(fun(), 'w') as file:
yaml.dump(saving_dict, file, sort_keys=False) yaml.dump(saving_dict, file, sort_keys=False)
except OSError as exception:
print('Impossible to write YAML configuration file -> ', exception)
def create_backup(self): def create_backup(self):
io.log_info ("Creating backup...", end='\r') io.log_info ("Creating backup...", end='\r')
@ -631,6 +662,9 @@ class ModelBase(object):
def get_summary_path(self): def get_summary_path(self):
return self.get_strpath_storage_for_file('summary.txt') return self.get_strpath_storage_for_file('summary.txt')
def get_model_conf_path(self):
return self.get_strpath_storage_for_file('configuration_file.yaml')
def get_summary_text(self): def get_summary_text(self):
visible_options = self.options.copy() visible_options = self.options.copy()
visible_options.update(self.options_show_override) visible_options.update(self.options_show_override)