mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
added different schemas for the model types (currently, amp and saehd)
This commit is contained in:
parent
d7c5e7e9f1
commit
46e7307eca
5 changed files with 33 additions and 10 deletions
|
@ -433,6 +433,10 @@ class ModelBase(object):
|
||||||
#return predictor_func, predictor_input_shape, MergerConfig() for the model
|
#return predictor_func, predictor_input_shape, MergerConfig() for the model
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
#overridable
|
||||||
|
def get_config_schema_path(self):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
def get_pretraining_data_path(self):
|
def get_pretraining_data_path(self):
|
||||||
return self.pretraining_data_path
|
return self.pretraining_data_path
|
||||||
|
|
||||||
|
@ -493,7 +497,7 @@ class ModelBase(object):
|
||||||
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
|
fun = self.get_strpath_configuration_path if not auto_gen else self.get_model_conf_path
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(fun(), 'r') as file, open(models.get_config_schema_path(), 'r') as schema:
|
with open(fun(), 'r') as file, open(self.get_config_schema_path(), 'r') as schema:
|
||||||
data = yaml.safe_load(file)
|
data = yaml.safe_load(file)
|
||||||
validate(data, yaml.safe_load(schema))
|
validate(data, yaml.safe_load(schema))
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
|
|
|
@ -11,6 +11,8 @@ from models import ModelBase
|
||||||
from samplelib import *
|
from samplelib import *
|
||||||
from core.cv2ex import *
|
from core.cv2ex import *
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
class AMPModel(ModelBase):
|
class AMPModel(ModelBase):
|
||||||
|
|
||||||
#override
|
#override
|
||||||
|
@ -812,4 +814,10 @@ class AMPModel(ModelBase):
|
||||||
import merger
|
import merger
|
||||||
return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
|
return predictor_morph, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
|
||||||
|
|
||||||
|
#override
|
||||||
|
def get_config_schema_path(self):
|
||||||
|
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
|
||||||
|
return config_path
|
||||||
|
|
||||||
|
|
||||||
Model = AMPModel
|
Model = AMPModel
|
||||||
|
|
|
@ -4,14 +4,15 @@
|
||||||
"definitions": {
|
"definitions": {
|
||||||
"dfl_config": {
|
"dfl_config": {
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"additionalProperties": false,
|
"additionalProperties": true,
|
||||||
"properties": {
|
"properties": {
|
||||||
"use_fp16": {
|
"use_fp16": {
|
||||||
"type": "boolean"
|
"type": "boolean"
|
||||||
},
|
},
|
||||||
"archi": {
|
"morph_factor": {
|
||||||
"type": "string",
|
"type": "number",
|
||||||
"pattern": "^(df|liae)-(\\b(?!\\w*(\\w)\\w*\\1)[udtc]+\\b)+|^(df|liae)$"
|
"minimum":0.0,
|
||||||
|
"maximum":1.0
|
||||||
},
|
},
|
||||||
"resolution": {
|
"resolution": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
|
@ -44,6 +45,12 @@
|
||||||
"maximum": 256,
|
"maximum": 256,
|
||||||
"multipleOf": 2
|
"multipleOf": 2
|
||||||
},
|
},
|
||||||
|
"inter_dims": {
|
||||||
|
"type": "integer",
|
||||||
|
"minimum": 32,
|
||||||
|
"maximum": 2048,
|
||||||
|
"multipleOf": 2
|
||||||
|
},
|
||||||
"d_dims": {
|
"d_dims": {
|
||||||
"type": "integer",
|
"type": "integer",
|
||||||
"minimum": 16,
|
"minimum": 16,
|
||||||
|
@ -214,7 +221,6 @@
|
||||||
"required": [
|
"required": [
|
||||||
"adabelief",
|
"adabelief",
|
||||||
"ae_dims",
|
"ae_dims",
|
||||||
"archi",
|
|
||||||
"autobackup_hour",
|
"autobackup_hour",
|
||||||
"background_power",
|
"background_power",
|
||||||
"batch_size",
|
"batch_size",
|
||||||
|
@ -225,6 +231,8 @@
|
||||||
"d_dims",
|
"d_dims",
|
||||||
"d_mask_dims",
|
"d_mask_dims",
|
||||||
"e_dims",
|
"e_dims",
|
||||||
|
"inter_dims",
|
||||||
|
"morph_factor",
|
||||||
"eyes_prio",
|
"eyes_prio",
|
||||||
"face_style_power",
|
"face_style_power",
|
||||||
"face_type",
|
"face_type",
|
|
@ -10,6 +10,8 @@ from facelib import FaceType
|
||||||
from models import ModelBase
|
from models import ModelBase
|
||||||
from samplelib import *
|
from samplelib import *
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
class SAEHDModel(ModelBase):
|
class SAEHDModel(ModelBase):
|
||||||
|
|
||||||
#override
|
#override
|
||||||
|
@ -1058,4 +1060,9 @@ class SAEHDModel(ModelBase):
|
||||||
import merger
|
import merger
|
||||||
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
|
return self.predictor_func, (self.options['resolution'], self.options['resolution'], 3), merger.MergerConfigMasked(face_type=self.face_type, default_mode = 'overlay')
|
||||||
|
|
||||||
|
#override
|
||||||
|
def get_config_schema_path(self):
|
||||||
|
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
|
||||||
|
return config_path
|
||||||
|
|
||||||
Model = SAEHDModel
|
Model = SAEHDModel
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
from .ModelBase import ModelBase
|
from .ModelBase import ModelBase
|
||||||
from pathlib import Path
|
|
||||||
|
|
||||||
def import_model(model_class_name):
|
def import_model(model_class_name):
|
||||||
module = __import__('Model_'+model_class_name, globals(), locals(), [], 1)
|
module = __import__('Model_'+model_class_name, globals(), locals(), [], 1)
|
||||||
return getattr(module, 'Model')
|
return getattr(module, 'Model')
|
||||||
|
|
||||||
|
|
||||||
def get_config_schema_path():
|
|
||||||
config_path = Path(__file__).parent.absolute() / Path("config_schema.json")
|
|
||||||
return config_path
|
|
Loading…
Add table
Add a link
Reference in a new issue