From 46e7307ecad0e63ce5795ae389ec36f6e5958daf Mon Sep 17 00:00:00 2001 From: Jan Date: Sun, 5 Dec 2021 15:33:01 +0100 Subject: [PATCH] added different schemas for the model types (currently, amp and saehd) --- models/ModelBase.py | 6 +++++- models/Model_AMP/Model.py | 8 ++++++++ models/{ => Model_AMP}/config_schema.json | 18 +++++++++++++----- models/Model_SAEHD/Model.py | 7 +++++++ models/__init__.py | 4 ---- 5 files changed, 33 insertions(+), 10 deletions(-) rename models/{ => Model_AMP}/config_schema.json (94%) diff --git a/models/ModelBase.py b/models/ModelBase.py index 75cd4fe..33b7922 100644 --- a/models/ModelBase.py +++ b/models/ModelBase.py @@ -433,6 +433,10 @@ class ModelBase(object): #return predictor_func, predictor_input_shape, MergerConfig() for the model raise NotImplementedError + #overridable + def get_config_schema_path(self): + raise NotImplementedError + def get_pretraining_data_path(self): return self.pretraining_data_path @@ -493,7 +497,7 @@ class ModelBase(object): fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path try: - with open(fun(), 'r') as file, open(models.get_config_schema_path(), 'r') as schema: + with open(fun(), 'r') as file, open(self.get_config_schema_path(), 'r') as schema: data = yaml.safe_load(file) validate(data, yaml.safe_load(schema)) except FileNotFoundError: diff --git a/models/Model_AMP/Model.py b/models/Model_AMP/Model.py index 4a026c2..6aa0093 100644 --- a/models/Model_AMP/Model.py +++ b/models/Model_AMP/Model.py @@ -11,6 +11,8 @@ from models import ModelBase from samplelib import * from core.cv2ex import * +from pathlib import Path + class AMPModel(ModelBase): #override @@ -812,4 +814,10 @@ class AMPModel(ModelBase): import merger return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + #override + def get_config_schema_path(self): + config_path = Path(__file__).parent.absolute() / Path("config_schema.json") + return config_path + + Model = AMPModel diff --git a/models/config_schema.json b/models/Model_AMP/config_schema.json similarity index 94% rename from models/config_schema.json rename to models/Model_AMP/config_schema.json index ae9b930..56c492e 100644 --- a/models/config_schema.json +++ b/models/Model_AMP/config_schema.json @@ -4,14 +4,15 @@ "definitions": { "dfl_config": { "type": "object", - "additionalProperties": false, + "additionalProperties": true, "properties": { "use_fp16": { "type": "boolean" }, - "archi": { - "type": "string", - "pattern": "^(df|liae)-(\\b(?!\\w*(\\w)\\w*\\1)[udtc]+\\b)+|^(df|liae)$" + "morph_factor": { + "type": "number", + "minimum":0.0, + "maximum":1.0 }, "resolution": { "type": "integer", @@ -44,6 +45,12 @@ "maximum": 256, "multipleOf": 2 }, + "inter_dims": { + "type": "integer", + "minimum": 32, + "maximum": 2048, + "multipleOf": 2 + }, "d_dims": { "type": "integer", "minimum": 16, @@ -214,7 +221,6 @@ "required": [ "adabelief", "ae_dims", - "archi", "autobackup_hour", "background_power", "batch_size", @@ -225,6 +231,8 @@ "d_dims", "d_mask_dims", "e_dims", + "inter_dims", + "morph_factor", "eyes_prio", "face_style_power", "face_type", diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 08420ba..5b87442 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -10,6 +10,8 @@ from facelib import FaceType from models import ModelBase from samplelib import * +from pathlib import Path + class SAEHDModel(ModelBase): #override @@ -1058,4 +1060,9 @@ class SAEHDModel(ModelBase): import merger return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay') + #override + def get_config_schema_path(self): + config_path = Path(__file__).parent.absolute() / Path("config_schema.json") + return config_path + Model = SAEHDModel diff --git a/models/__init__.py b/models/__init__.py index 905e2cf..7c0782d 100644 --- a/models/__init__.py +++ b/models/__init__.py @@ -1,11 +1,7 @@ from .ModelBase import ModelBase -from pathlib import Path def import_model(model_class_name): module = __import__('Model_'+model_class_name, globals(), locals(), [], 1) return getattr(module, 'Model') -def get_config_schema_path(): - config_path = Path(__file__).parent.absolute() / Path("config_schema.json") - return config_path \ No newline at end of file