diff --git a/models/Model_Quick96/Model.py b/models/Model_Quick96/Model.py index 7ac3f0d..dba5738 100644 --- a/models/Model_Quick96/Model.py +++ b/models/Model_Quick96/Model.py @@ -13,6 +13,13 @@ from samplelib import * from pathlib import Path class QModel(ModelBase): + #override + def on_initialize_options(self): + ask_override = False if self.read_from_conf else self.ask_override() + if self.is_first_run() or ask_override: + if (self.read_from_conf and not self.config_file_exists) or not self.read_from_conf: + self.ask_batch_size() + #override def on_initialize(self): device_config = nn.getCurrentDeviceConfig() @@ -82,7 +89,7 @@ class QModel(ModelBase): if self.is_training: # Adjust batch size for multiple GPU gpu_count = max(1, len(devices) ) - bs_per_gpu = max(1, 4 // gpu_count) + bs_per_gpu = max(1, self.get_batch_size() // gpu_count) self.set_batch_size( gpu_count*bs_per_gpu) # Compute losses per GPU diff --git a/models/Model_XSeg/Model.py b/models/Model_XSeg/Model.py index b6e875f..4f536cf 100644 --- a/models/Model_XSeg/Model.py +++ b/models/Model_XSeg/Model.py @@ -20,7 +20,7 @@ class XSegModel(ModelBase): #override def on_initialize_options(self): - ask_override = self.ask_override() + ask_override = False if self.read_from_conf else self.ask_override() if not self.is_first_run() and ask_override: if io.input_bool(f"Restart training?", False, help_message="Reset model weights and start training from scratch."): @@ -30,11 +30,13 @@ class XSegModel(ModelBase): default_pretrain = self.options['pretrain'] = self.load_or_def_option('pretrain', False) if self.is_first_run(): - self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() + if (self.read_from_conf and not self.config_file_exists) or not self.read_from_conf: + self.options['face_type'] = io.input_str ("Face type", default_face_type, ['h','mf','f','wf','head'], help_message="Half / mid face / full face / whole face / head. Choose the same as your deepfake model.").lower() if self.is_first_run() or ask_override: - self.ask_batch_size(4, range=[2,16]) - self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain) + if (self.read_from_conf and not self.config_file_exists) or not self.read_from_conf: + self.ask_batch_size(4, range=[2,16]) + self.options['pretrain'] = io.input_bool ("Enable pretraining mode", default_pretrain) if not self.is_exporting and (self.options['pretrain'] and self.get_pretraining_data_path() is None): raise Exception("pretraining_data_path is not defined") @@ -53,13 +55,11 @@ class XSegModel(ModelBase): self.resolution = resolution = 256 - self.face_type = {'h' : FaceType.HALF, 'mf' : FaceType.MID_FULL, 'f' : FaceType.FULL, 'wf' : FaceType.WHOLE_FACE, 'head' : FaceType.HEAD}[ self.options['face_type'] ] - place_model_on_cpu = len(devices) == 0 models_opt_device = '/CPU:0' if place_model_on_cpu else nn.tf_default_device_name @@ -287,4 +287,4 @@ class XSegModel(ModelBase): config_path = Path(__file__).parent.absolute() / Path("config_schema.json") return config_path -Model = XSegModel \ No newline at end of file +Model = XSegModel