added validation

This commit is contained in:
Jan 2021-12-05 11:15:11 +01:00
commit 13fb700403
2 changed files with 24 additions and 6 deletions

View file

@ -12,6 +12,8 @@ import time
import datetime import datetime
from pathlib import Path from pathlib import Path
import yaml import yaml
from jsonschema import validate, ValidationError
import models
import cv2 import cv2
import numpy as np import numpy as np
@ -151,6 +153,7 @@ class ModelBase(object):
self.config_file_exists = False self.config_file_exists = False
# True if user chooses to read options external or internal conf file # True if user chooses to read options external or internal conf file
self.read_from_conf = False self.read_from_conf = False
config_error = False
#check if config_training_file mode is enabled #check if config_training_file mode is enabled
if config_training_file is not None: if config_training_file is not None:
self.config_file_path = Path(config_training_file) self.config_file_path = Path(config_training_file)
@ -172,7 +175,10 @@ class ModelBase(object):
# to the value of auto_gen_config # to the value of auto_gen_config
self.options = self.read_from_config_file(auto_gen=self.auto_gen_config) self.options = self.read_from_config_file(auto_gen=self.auto_gen_config)
# If options dict is empty options will be loaded from dat file # If options dict is empty options will be loaded from dat file
if not self.options.keys(): if self.options is None:
io.log_info(f"Config file validation error, check your config")
config_error = True
elif not self.options.keys():
io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.") io.log_info(f"Configuration file doesn't exist. A standard configuration file will be created.")
else: else:
self.config_file_exists = True self.config_file_exists = True
@ -224,6 +230,10 @@ class ModelBase(object):
# save as default options only for first run model initialize # save as default options only for first run model initialize
self.default_options_path.write_bytes( pickle.dumps (self.options) ) self.default_options_path.write_bytes( pickle.dumps (self.options) )
# save config file
if self.config_training_file is not None and not self.config_file_exists and not config_error:
self.save_config_file(self.auto_gen_config)
self.session_name = self.options.get('session_name', "") self.session_name = self.options.get('session_name', "")
self.autobackup_hour = self.options.get('autobackup_hour', 0) self.autobackup_hour = self.options.get('autobackup_hour', 0)
self.maximum_n_backups = self.options.get('maximum_n_backups', 24) self.maximum_n_backups = self.options.get('maximum_n_backups', 24)
@ -463,10 +473,6 @@ class ModelBase(object):
} }
pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) ) pathex.write_bytes_safe (self.model_data_path, pickle.dumps(model_data) )
# save config file
if self.config_training_file is not None:
self.save_config_file(self.auto_gen_config)
if self.autobackup_hour != 0: if self.autobackup_hour != 0:
diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 ) diff_hour = int ( (time.time() - self.autobackup_start_time) // 3600 )
@ -487,10 +493,16 @@ class ModelBase(object):
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
try: try:
with open(fun(), 'r') as file: with open(fun(), 'r') as file, open(models.get_config_schema_path(), 'r') as schema:
data = yaml.safe_load(file) data = yaml.safe_load(file)
validate(data, yaml.safe_load(schema))
except FileNotFoundError: except FileNotFoundError:
return {} return {}
except ValidationError as ve:
io.log_err("%s"%ve)
return None
for key, value in data.items(): for key, value in data.items():
if isinstance(value, bool): if isinstance(value, bool):

View file

@ -1,5 +1,11 @@
from .ModelBase import ModelBase from .ModelBase import ModelBase
from pathlib import Path
def import_model(model_class_name): def import_model(model_class_name):
module = __import__('Model_'+model_class_name, globals(), locals(), [], 1) module = __import__('Model_'+model_class_name, globals(), locals(), [], 1)
return getattr(module, 'Model') return getattr(module, 'Model')
def get_config_schema_path():
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
return config_path