added different schemas for the model types (currently, amp and saehd)

This commit is contained in:
Jan 2021-12-05 15:33:01 +01:00
commit 46e7307eca
5 changed files with 33 additions and 10 deletions

View file

@ -433,6 +433,10 @@ class ModelBase(object):
#return predictor_func, predictor_input_shape, MergerConfig() for the model #return predictor_func, predictor_input_shape, MergerConfig() for the model
raise NotImplementedError raise NotImplementedError
#overridable
def get_config_schema_path(self):
raise NotImplementedError
def get_pretraining_data_path(self): def get_pretraining_data_path(self):
return self.pretraining_data_path return self.pretraining_data_path
@ -493,7 +497,7 @@ class ModelBase(object):
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
try: try:
with open(fun(), 'r') as file, open(models.get_config_schema_path(), 'r') as schema: with open(fun(), 'r') as file, open(self.get_config_schema_path(), 'r') as schema:
data = yaml.safe_load(file) data = yaml.safe_load(file)
validate(data, yaml.safe_load(schema)) validate(data, yaml.safe_load(schema))
except FileNotFoundError: except FileNotFoundError:

View file

@ -11,6 +11,8 @@ from models import ModelBase
from samplelib import * from samplelib import *
from core.cv2ex import * from core.cv2ex import *
from pathlib import Path
class AMPModel(ModelBase): class AMPModel(ModelBase):
#override #override
@ -812,4 +814,10 @@ class AMPModel(ModelBase):
import merger import merger
return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
#override
def get_config_schema_path(self):
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
return config_path
Model = AMPModel Model = AMPModel

View file

@ -4,14 +4,15 @@
"definitions": { "definitions": {
"dfl_config": { "dfl_config": {
"type": "object", "type": "object",
"additionalProperties": false, "additionalProperties": true,
"properties": { "properties": {
"use_fp16": { "use_fp16": {
"type": "boolean" "type": "boolean"
}, },
"archi": { "morph_factor": {
"type": "string", "type": "number",
"pattern": "^(df|liae)-(\\b(?!\\w*(\\w)\\w*\\1)[udtc]+\\b)+|^(df|liae)$" "minimum":0.0,
"maximum":1.0
}, },
"resolution": { "resolution": {
"type": "integer", "type": "integer",
@ -44,6 +45,12 @@
"maximum": 256, "maximum": 256,
"multipleOf": 2 "multipleOf": 2
}, },
"inter_dims": {
"type": "integer",
"minimum": 32,
"maximum": 2048,
"multipleOf": 2
},
"d_dims": { "d_dims": {
"type": "integer", "type": "integer",
"minimum": 16, "minimum": 16,
@ -214,7 +221,6 @@
"required": [ "required": [
"adabelief", "adabelief",
"ae_dims", "ae_dims",
"archi",
"autobackup_hour", "autobackup_hour",
"background_power", "background_power",
"batch_size", "batch_size",
@ -225,6 +231,8 @@
"d_dims", "d_dims",
"d_mask_dims", "d_mask_dims",
"e_dims", "e_dims",
"inter_dims",
"morph_factor",
"eyes_prio", "eyes_prio",
"face_style_power", "face_style_power",
"face_type", "face_type",

View file

@ -10,6 +10,8 @@ from facelib import FaceType
from models import ModelBase from models import ModelBase
from samplelib import * from samplelib import *
from pathlib import Path
class SAEHDModel(ModelBase): class SAEHDModel(ModelBase):
#override #override
@ -1058,4 +1060,9 @@ class SAEHDModel(ModelBase):
import merger import merger
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
#override
def get_config_schema_path(self):
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
return config_path
Model = SAEHDModel Model = SAEHDModel

View file

@ -1,11 +1,7 @@
from .ModelBase import ModelBase from .ModelBase import ModelBase
from pathlib import Path
def import_model(model_class_name): def import_model(model_class_name):
module = __import__('Model_'+model_class_name, globals(), locals(), [], 1) module = __import__('Model_'+model_class_name, globals(), locals(), [], 1)
return getattr(module, 'Model') return getattr(module, 'Model')
def get_config_schema_path():
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
return config_path