diff --git a/main.py b/main.py index c480a9e..8093c77 100644 --- a/main.py +++ b/main.py @@ -136,7 +136,7 @@ if __name__ == "__main__": p = subparsers.add_parser( "train", help="Trainer") p.add_argument('--training-data-src-dir', required=True, action=fixPathAction, dest="training_data_src_dir", help="Dir of extracted SRC faceset.") p.add_argument('--training-data-dst-dir', required=True, action=fixPathAction, dest="training_data_dst_dir", help="Dir of extracted DST faceset.") - p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", help="Optional dir of extracted faceset that will be used in pretraining mode.") + p.add_argument('--pretraining-data-dir', action=fixPathAction, dest="pretraining_data_dir", default=None, help="Optional dir of extracted faceset that will be used in pretraining mode.") p.add_argument('--model-dir', required=True, action=fixPathAction, dest="model_dir", help="Model dir.") p.add_argument('--model', required=True, dest="model_name", choices=Path_utils.get_all_dir_names_startswith ( Path(__file__).parent / 'models' , 'Model_'), help="Type of model") p.add_argument('--no-preview', action="store_true", dest="no_preview", default=False, help="Disable preview window.") diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index bde87e3..2d6505a 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -19,7 +19,10 @@ def trainerThread (s2c, c2s, args, device_args): training_data_src_path = Path( args.get('training_data_src_dir', '') ) training_data_dst_path = Path( args.get('training_data_dst_dir', '') ) - pretraining_data_path = Path( args.get('pretraining_data_dir', '') ) + + pretraining_data_path = args.get('pretraining_data_dir', '') + pretraining_data_path = Path(pretraining_data_path) if pretraining_data_path is not None else None + model_path = Path( args.get('model_path', '') ) model_name = args.get('model_name', '') save_interval_min = 15 diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py index afc5c6b..687a5f9 100644 --- a/models/Model_SAE/Model.py +++ b/models/Model_SAE/Model.py @@ -353,7 +353,7 @@ class SAEModel(ModelBase): training_data_dst_path = self.training_data_dst_path sort_by_yaw = self.sort_by_yaw - if self.pretrain: + if self.pretrain and self.pretraining_data_path is not None: training_data_src_path = self.pretraining_data_path training_data_dst_path = self.pretraining_data_path sort_by_yaw = False