diff --git a/main.py b/main.py index f96157b..054dff5 100644 --- a/main.py +++ b/main.py @@ -131,7 +131,8 @@ if __name__ == "__main__": 'start_tensorboard' : arguments.start_tensorboard, 'dump_ckpt' : arguments.dump_ckpt, 'flask_preview' : arguments.flask_preview, - 'config_training_file' : arguments.config_training_file + 'config_training_file' : arguments.config_training_file, + 'auto_gen_config' : arguments.auto_gen_config } from mainscripts import Trainer Trainer.main(**kwargs) @@ -152,6 +153,7 @@ if __name__ == "__main__": p.add_argument('--tensorboard-logdir', action=fixPathAction, dest="tensorboard_dir", help="Directory of the tensorboard output files") p.add_argument('--start-tensorboard', action="store_true", dest="start_tensorboard", default=False, help="Automatically start the tensorboard server preconfigured to the tensorboard-logdir") p.add_argument('--config-training-file', action=fixPathAction, dest="config_training_file", help="Path to custom yaml configuration file") + p.add_argument('--auto-gen-config', action="store_true", dest="auto_gen_config", default=False, help="Saves a configuration file for each model used in the trainer. It'll have the same model name") p.add_argument('--dump-ckpt', action="store_true", dest="dump_ckpt", default=False, help="Dump the model to ckpt format.") diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 9814d30..b894cc3 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -103,6 +103,7 @@ def trainerThread (s2c, c2s, e, cpu_only=cpu_only, silent_start=silent_start, config_training_file=config_training_file, + auto_gen_config=kwargs.get("auto_gen_config", False), debug=debug) is_reached_goal = model.is_reached_iter_goal() diff --git a/models/ModelBase.py b/models/ModelBase.py index 527b23b..0da7246 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -1,5 +1,6 @@ import colorsys import inspect +from io import FileIO import json import multiprocessing import operator @@ -37,6 +38,7 @@ class ModelBase(object): debug=False, force_model_class_name=None, config_training_file=None, + auto_gen_config=False, silent_start=False, **kwargs): self.is_training = is_training @@ -46,6 +48,8 @@ class ModelBase(object): self.training_data_dst_path = training_data_dst_path self.pretraining_data_path = pretraining_data_path self.pretrained_model_path = pretrained_model_path + self.config_training_file = config_training_file + self.auto_gen_config = auto_gen_config self.no_preview = no_preview self.debug = debug @@ -145,27 +149,36 @@ class ModelBase(object): model_data = {} # True if yaml conf file exists self.config_file_exists = False - # True if user chooses to read options from conf file + # True if user chooses to read options external or internal conf file self.read_from_conf = False #check if config_training_file mode is enabled if config_training_file is not None: self.config_file_path = Path(config_training_file) + # Creates folder if folder doesn't exist if not self.config_file_path.exists(): - os.mkdir(self.config_file_path) - if Path(self.get_strpath_configuration_path()).exists(): + os.makedirs(self.config_file_path, exist_ok=True) + # Ask if user wants to read options from external or internal conf file only if external conf file exists + # or auto_gen_config is true + if Path(self.get_strpath_configuration_path()).exists() or self.auto_gen_config: self.read_from_conf = io.input_bool( - f'Do you want to read training options from {self.config_file_path.stem} file?', + f'Do you want to read training options from {"external" if self.auto_gen_config else "internal"} file?', True, 'Read options from configuration file instead of asking one by one each option' ) + # If user decides to read from external or internal conf file if self.read_from_conf: - self.options = self.read_from_config_file() - self.config_file_exists = True + # Try to read dictionary from external of internal yaml file according + # to the value of auto_gen_config + self.options = self.read_from_config_file(auto_gen=self.auto_gen_config) + # If options dict is empty options will be loaded from dat file + if not self.options.keys(): + io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.") + else: + self.config_file_exists = True else: io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.") - self.model_data_path = Path( self.get_strpath_storage_for_file('data.dat') ) if self.model_data_path.exists(): io.log_info (f"Loading {self.model_name} model...") @@ -210,9 +223,6 @@ class ModelBase(object): if self.is_first_run(): # save as default options only for first run model initialize self.default_options_path.write_bytes( pickle.dumps (self.options) ) - # save config file - if config_training_file is not None: - self.save_config_file() self.session_name = self.options.get('session_name', "") self.autobackup_hour = self.options.get('autobackup_hour', 0) @@ -453,6 +463,10 @@ class ModelBase(object): } pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) ) + # save config file + if self.config_training_file is not None: + self.save_config_file(self.auto_gen_config) + if self.autobackup_hour != 0: diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 ) @@ -460,14 +474,23 @@ class ModelBase(object): self.autobackup_start_time += self.autobackup_hour*3600 self.create_backup() - def read_from_config_file(self): + def read_from_config_file(self, auto_gen=False): """ Read yaml config file and saves it into a dictionary + + Args: + auto_gen (bool, optional): True if you want that a yaml file is readed from model folder. Defaults to False. + Returns: - [type]: [description] + [dict]: Returns the options dictionary if everything is alright otherwise an empty dictionary. """ - with open(self.get_strpath_configuration_path(), 'r') as file: - data = yaml.safe_load(file) + fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path + + try: + with open(fun(), 'r') as file: + data = yaml.safe_load(file) + except FileNotFoundError: + return {} for key, value in data.items(): if isinstance(value, bool): @@ -479,9 +502,12 @@ class ModelBase(object): return data - def save_config_file(self): + def save_config_file(self, auto_gen=False): """ - Saves options dictionary in a yaml file + Saves options dictionary in a yaml file. + + Args: + auto_gen ([bool], optional): True if you want that a yaml file is generated inside model folder for each model. Defaults to None. """ saving_dict = {} for key, value in self.options.items(): @@ -490,8 +516,13 @@ class ModelBase(object): else: saving_dict[key] = value - with open(self.get_strpath_configuration_path(), 'w') as file: - yaml.dump(saving_dict, file, sort_keys=False) + fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path + + try: + with open(fun(), 'w') as file: + yaml.dump(saving_dict, file, sort_keys=False) + except OSError as exception: + print('Impossible to write YAML configuration file -> ', exception) def create_backup(self): io.log_info ("Creating backup...", end='\r') @@ -631,6 +662,9 @@ class ModelBase(object): def get_summary_path(self): return self.get_strpath_storage_for_file('summary.txt') + def get_model_conf_path(self): + return self.get_strpath_storage_for_file('configuration_file.yaml') + def get_summary_text(self): visible_options = self.options.copy() visible_options.update(self.options_show_override)