mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
XSegUtil apply xseg now checks model face type
This commit is contained in:
parent
b15cdd96a1
commit
8b90ca0dac
1 changed files with 22 additions and 6 deletions
|
@ -11,7 +11,7 @@ from core.interact import interact as io
|
||||||
from core.leras import nn
|
from core.leras import nn
|
||||||
from DFLIMG import *
|
from DFLIMG import *
|
||||||
from facelib import XSegNet, LandmarksProcessor, FaceType
|
from facelib import XSegNet, LandmarksProcessor, FaceType
|
||||||
|
import pickle
|
||||||
|
|
||||||
def apply_xseg(input_path, model_path):
|
def apply_xseg(input_path, model_path):
|
||||||
if not input_path.exists():
|
if not input_path.exists():
|
||||||
|
@ -20,20 +20,36 @@ def apply_xseg(input_path, model_path):
|
||||||
if not model_path.exists():
|
if not model_path.exists():
|
||||||
raise ValueError(f'{model_path} not found. Please ensure it exists.')
|
raise ValueError(f'{model_path} not found. Please ensure it exists.')
|
||||||
|
|
||||||
|
face_type = None
|
||||||
|
|
||||||
|
model_dat = model_path / 'XSeg_data.dat'
|
||||||
|
if model_dat.exists():
|
||||||
|
dat = pickle.loads( model_dat.read_bytes() )
|
||||||
|
dat_options = dat.get('options', None)
|
||||||
|
if dat_options is not None:
|
||||||
|
face_type = dat_options.get('face_type', None)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
if face_type is None:
|
||||||
face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower()
|
face_type = io.input_str ("XSeg model face type", 'same', ['h','mf','f','wf','head','same'], help_message="Specify face type of trained XSeg model. For example if XSeg model trained as WF, but faceset is HEAD, specify WF to apply xseg only on WF part of HEAD. Default is 'same'").lower()
|
||||||
if face_type == 'same':
|
if face_type == 'same':
|
||||||
face_type = None
|
face_type = None
|
||||||
else:
|
|
||||||
|
if face_type is not None:
|
||||||
face_type = {'h' : FaceType.HALF,
|
face_type = {'h' : FaceType.HALF,
|
||||||
'mf' : FaceType.MID_FULL,
|
'mf' : FaceType.MID_FULL,
|
||||||
'f' : FaceType.FULL,
|
'f' : FaceType.FULL,
|
||||||
'wf' : FaceType.WHOLE_FACE,
|
'wf' : FaceType.WHOLE_FACE,
|
||||||
'head' : FaceType.HEAD}[face_type]
|
'head' : FaceType.HEAD}[face_type]
|
||||||
|
|
||||||
io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')
|
io.log_info(f'Applying trained XSeg model to {input_path.name}/ folder.')
|
||||||
|
|
||||||
device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
|
device_config = nn.DeviceConfig.ask_choose_device(choose_only_one=True)
|
||||||
nn.initialize(device_config)
|
nn.initialize(device_config)
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
xseg = XSegNet(name='XSeg',
|
xseg = XSegNet(name='XSeg',
|
||||||
load_weights=True,
|
load_weights=True,
|
||||||
weights_file_root=model_path,
|
weights_file_root=model_path,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue