added different schemas for the model types (currently, amp and saehd)

This commit is contained in:
Jan 2021-12-05 15:33:01 +01:00
commit 46e7307eca
5 changed files with 33 additions and 10 deletions

View file

@ -433,6 +433,10 @@ class ModelBase(object):
#return predictor_func, predictor_input_shape, MergerConfig() for the model
raise NotImplementedError
#overridable
def get_config_schema_path(self):
raise NotImplementedError
def get_pretraining_data_path(self):
return self.pretraining_data_path
@ -493,7 +497,7 @@ class ModelBase(object):
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
try:
with open(fun(), 'r') as file, open(models.get_config_schema_path(), 'r') as schema:
with open(fun(), 'r') as file, open(self.get_config_schema_path(), 'r') as schema:
data = yaml.safe_load(file)
validate(data, yaml.safe_load(schema))
except FileNotFoundError:

View file

@ -11,6 +11,8 @@ from models import ModelBase
from samplelib import *
from core.cv2ex import *
from pathlib import Path
class AMPModel(ModelBase):
#override
@ -812,4 +814,10 @@ class AMPModel(ModelBase):
import merger
return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
#override
def get_config_schema_path(self):
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
return config_path
Model = AMPModel

View file

@ -4,14 +4,15 @@
"definitions": {
"dfl_config": {
"type": "object",
"additionalProperties": false,
"additionalProperties": true,
"properties": {
"use_fp16": {
"type": "boolean"
},
"archi": {
"type": "string",
"pattern": "^(df|liae)-(\\b(?!\\w*(\\w)\\w*\\1)[udtc]+\\b)+|^(df|liae)$"
"morph_factor": {
"type": "number",
"minimum":0.0,
"maximum":1.0
},
"resolution": {
"type": "integer",
@ -44,6 +45,12 @@
"maximum": 256,
"multipleOf": 2
},
"inter_dims": {
"type": "integer",
"minimum": 32,
"maximum": 2048,
"multipleOf": 2
},
"d_dims": {
"type": "integer",
"minimum": 16,
@ -214,7 +221,6 @@
"required": [
"adabelief",
"ae_dims",
"archi",
"autobackup_hour",
"background_power",
"batch_size",
@ -225,6 +231,8 @@
"d_dims",
"d_mask_dims",
"e_dims",
"inter_dims",
"morph_factor",
"eyes_prio",
"face_style_power",
"face_type",

View file

@ -10,6 +10,8 @@ from facelib import FaceType
from models import ModelBase
from samplelib import *
from pathlib import Path
class SAEHDModel(ModelBase):
#override
@ -1058,4 +1060,9 @@ class SAEHDModel(ModelBase):
import merger
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
#override
def get_config_schema_path(self):
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
return config_path
Model = SAEHDModel

View file

@ -1,11 +1,7 @@
from .ModelBase import ModelBase
from pathlib import Path
def import_model(model_class_name):
module = __import__('Model_'+model_class_name, globals(), locals(), [], 1)
return getattr(module, 'Model')
def get_config_schema_path():
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
return config_path