mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
added validation
This commit is contained in:
parent
e754bf5bd6
commit
13fb700403
2 changed files with 24 additions and 6 deletions
|
@ -12,6 +12,8 @@ import time
|
|||
import datetime
|
||||
from pathlib import Path
|
||||
import yaml
|
||||
from jsonschema import validate, ValidationError
|
||||
import models
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
|
@ -151,6 +153,7 @@ class ModelBase(object):
|
|||
self.config_file_exists = False
|
||||
# True if user chooses to read options external or internal conf file
|
||||
self.read_from_conf = False
|
||||
config_error = False
|
||||
#check if config_training_file mode is enabled
|
||||
if config_training_file is not None:
|
||||
self.config_file_path = Path(config_training_file)
|
||||
|
@ -172,7 +175,10 @@ class ModelBase(object):
|
|||
# to the value of auto_gen_config
|
||||
self.options = self.read_from_config_file(auto_gen=self.auto_gen_config)
|
||||
# If options dict is empty options will be loaded from dat file
|
||||
if not self.options.keys():
|
||||
if self.options is None:
|
||||
io.log_info(f"Config file validation error, check your config")
|
||||
config_error = True
|
||||
elif not self.options.keys():
|
||||
io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.")
|
||||
else:
|
||||
self.config_file_exists = True
|
||||
|
@ -224,6 +230,10 @@ class ModelBase(object):
|
|||
# save as default options only for first run model initialize
|
||||
self.default_options_path.write_bytes( pickle.dumps (self.options) )
|
||||
|
||||
# save config file
|
||||
if self.config_training_file is not None and not self.config_file_exists and not config_error:
|
||||
self.save_config_file(self.auto_gen_config)
|
||||
|
||||
self.session_name = self.options.get('session_name', "")
|
||||
self.autobackup_hour = self.options.get('autobackup_hour', 0)
|
||||
self.maximum_n_backups = self.options.get('maximum_n_backups', 24)
|
||||
|
@ -463,10 +473,6 @@ class ModelBase(object):
|
|||
}
|
||||
pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) )
|
||||
|
||||
# save config file
|
||||
if self.config_training_file is not None:
|
||||
self.save_config_file(self.auto_gen_config)
|
||||
|
||||
if self.autobackup_hour != 0:
|
||||
diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 )
|
||||
|
||||
|
@ -487,10 +493,16 @@ class ModelBase(object):
|
|||
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
|
||||
|
||||
try:
|
||||
with open(fun(), 'r') as file:
|
||||
with open(fun(), 'r') as file, open(models.get_config_schema_path(), 'r') as schema:
|
||||
|
||||
|
||||
data = yaml.safe_load(file)
|
||||
validate(data, yaml.safe_load(schema))
|
||||
except FileNotFoundError:
|
||||
return {}
|
||||
except ValidationError as ve:
|
||||
io.log_err("%s"%ve)
|
||||
return None
|
||||
|
||||
for key, value in data.items():
|
||||
if isinstance(value, bool):
|
||||
|
|
|
@ -1,5 +1,11 @@
|
|||
from .ModelBase import ModelBase
|
||||
from pathlib import Path
|
||||
|
||||
def import_model(model_class_name):
|
||||
module = __import__('Model_'+model_class_name, globals(), locals(), [], 1)
|
||||
return getattr(module, 'Model')
|
||||
|
||||
|
||||
def get_config_schema_path():
|
||||
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
|
||||
return config_path
|
Loading…
Add table
Add a link
Reference in a new issue