Upgrade FaceAnimator module: now uses the LIA model (https://github.com/wyhsirius/LIA)

iperov, 2022-09-16 12:10:15 +04:00
parent 42e835de65
commit 02de563a00
14 changed files with 144 additions and 201 deletions
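The heart of the change is the driving-model API swap. TPSMM animated via per-frame keypoints with an optional reference set ("relative mode"); LIA encodes each driver frame as a single latent motion vector that is always taken relative to a cached reference motion, so the relative-mode toggle disappears and only a power multiplier remains. A minimal sketch of the new per-frame flow, using the LIA methods introduced below (the wrapper function itself is hypothetical):

def animate_frame(lia_model, animatable_img, face_align_image, ref_motion, power):
    # The first driver frame seen after a pose reset becomes the reference motion
    # (FaceAnimatorWorker caches it as self.driving_ref_motion).
    if ref_motion is None:
        ref_motion = lia_model.extract_motion(face_align_image)
    # 'power' presumably scales the latent motion delta; the UI slider cap is now 2.0.
    anim_image = lia_model.generate(animatable_img, face_align_image, ref_motion, power=power)
    return anim_image, ref_motion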


@@ -3,6 +3,7 @@ from enum import IntEnum
 import numpy as np
 from xlib import os as lib_os
+from xlib.face import FRect
 from xlib.mp import csw as lib_csw
 from xlib.python import all_is_not_None

@@ -14,9 +15,11 @@ from .BackendBase import (BackendConnection, BackendDB, BackendHost,
 class AlignMode(IntEnum):
     FROM_RECT = 0
     FROM_POINTS = 1
+    FROM_STATIC_RECT = 2

 AlignModeNames = ['@FaceAligner.AlignMode.FROM_RECT',
                   '@FaceAligner.AlignMode.FROM_POINTS',
+                  '@FaceAligner.AlignMode.FROM_STATIC_RECT',
                  ]

 class FaceAligner(BackendHost):

@@ -57,7 +60,7 @@ class FaceAlignerWorker(BackendWorker):
         cs.align_mode.select(state.align_mode if state.align_mode is not None else AlignMode.FROM_POINTS)

         cs.face_coverage.enable()
-        cs.face_coverage.set_config(lib_csw.Number.Config(min=0.1, max=4.0, step=0.1, decimals=1, allow_instant_update=True))
+        cs.face_coverage.set_config(lib_csw.Number.Config(min=0.1, max=8.0, step=0.1, decimals=1, allow_instant_update=True))
         cs.face_coverage.set_number(state.face_coverage if state.face_coverage is not None else 2.2)

         cs.resolution.enable()

@@ -74,11 +77,11 @@ class FaceAlignerWorker(BackendWorker):
         cs.freeze_z_rotation.set_flag(state.freeze_z_rotation if state.freeze_z_rotation is not None else False)

         cs.x_offset.enable()
-        cs.x_offset.set_config(lib_csw.Number.Config(min=-1, max=1, step=0.01, decimals=2, allow_instant_update=True))
+        cs.x_offset.set_config(lib_csw.Number.Config(min=-10, max=10, step=0.01, decimals=2, allow_instant_update=True))
         cs.x_offset.set_number(state.x_offset if state.x_offset is not None else 0)

         cs.y_offset.enable()
-        cs.y_offset.set_config(lib_csw.Number.Config(min=-1, max=1, step=0.01, decimals=2, allow_instant_update=True))
+        cs.y_offset.set_config(lib_csw.Number.Config(min=-10, max=10, step=0.01, decimals=2, allow_instant_update=True))
         cs.y_offset.set_number(state.y_offset if state.y_offset is not None else 0)

     def on_cs_align_mode(self, idx, align_mode):

@@ -164,18 +167,22 @@ class FaceAlignerWorker(BackendWorker):
             if face_ulmrks is not None:
                 fsi.face_resolution = state.resolution

+                H, W = frame_image.shape[:2]
                 if state.align_mode == AlignMode.FROM_RECT:
                     face_align_img, uni_mat = fsi.face_urect.cut(frame_image, coverage= state.face_coverage, output_size=state.resolution,
                                                                  x_offset=state.x_offset, y_offset=state.y_offset)
                 elif state.align_mode == AlignMode.FROM_POINTS:
                     face_align_img, uni_mat = face_ulmrks.cut(frame_image, state.face_coverage, state.resolution,
                                                               exclude_moving_parts=state.exclude_moving_parts,
                                                               head_yaw=head_yaw,
                                                               x_offset=state.x_offset,
                                                               y_offset=state.y_offset-0.08,
                                                               freeze_z_rotation=state.freeze_z_rotation)
+                elif state.align_mode == AlignMode.FROM_STATIC_RECT:
+                    rect = FRect.from_ltrb([ 0.5 - (fsi.face_resolution/W)/2, 0.5 - (fsi.face_resolution/H)/2, 0.5 + (fsi.face_resolution/W)/2, 0.5 + (fsi.face_resolution/H)/2,])
+                    face_align_img, uni_mat = rect.cut(frame_image, coverage= state.face_coverage, output_size=state.resolution,
+                                                       x_offset=state.x_offset, y_offset=state.y_offset)

                 fsi.face_align_image_name = f'{frame_image_name}_{face_id}_aligned'
                 fsi.image_to_align_uni_mat = uni_mat
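The new FROM_STATIC_RECT branch above cuts a fixed, frame-centred square in unit coordinates whose side maps to face_resolution pixels regardless of where the face is; cut() presumably then applies the coverage factor as in the other modes. A quick worked example of that ltrb arithmetic (frame size and resolution are illustrative):

W, H, face_resolution = 1280, 720, 320            # hypothetical frame and resolution
half_w = (face_resolution / W) / 2                # 0.125 of frame width
half_h = (face_resolution / H) / 2                # ~0.2222 of frame height
ltrb = [0.5 - half_w, 0.5 - half_h, 0.5 + half_w, 0.5 + half_h]
print(ltrb)   # [0.375, 0.2778..., 0.625, 0.7222...] -> a centred 320x320-pixel square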


@@ -1,10 +1,8 @@
-import re
 import time
 from pathlib import Path
-import cv2
 import numpy as np
-from modelhub.onnx import TPSMM
+from modelhub.onnx import LIA
 from xlib import cv as lib_cv2
 from xlib import os as lib_os
 from xlib import path as lib_path

@@ -29,7 +27,7 @@ class FaceAnimator(BackendHost):
     def get_control_sheet(self) -> 'Sheet.Host': return super().get_control_sheet()

     def _get_name(self):
-        return super()._get_name()# + f'{self._id}'
+        return super()._get_name()

 class FaceAnimatorWorker(BackendWorker):
     def get_state(self) -> 'WorkerState': return super().get_state()

@@ -44,11 +42,10 @@ class FaceAnimatorWorker(BackendWorker):
         self.pending_bcd = None

-        self.tpsmm_model = None
+        self.lia_model : LIA = None

         self.animatable_img = None
-        self.driving_ref_kp = None
-        self.last_driving_kp = None
+        self.driving_ref_motion = None

         lib_os.set_timer_resolution(1)

@@ -58,14 +55,12 @@ class FaceAnimatorWorker(BackendWorker):
         cs.animatable.call_on_selected(self.on_cs_animatable)
         cs.animator_face_id.call_on_number(self.on_cs_animator_face_id)
-        cs.relative_mode.call_on_flag(self.on_cs_relative_mode)
         cs.relative_power.call_on_number(self.on_cs_relative_power)
         cs.update_animatables.call_on_signal(self.update_animatables)
         cs.reset_reference_pose.call_on_signal(self.on_cs_reset_reference_pose)

         cs.device.enable()
-        cs.device.set_choices( TPSMM.get_available_devices(), none_choice_name='@misc.menu_select')
+        cs.device.set_choices( LIA.get_available_devices(), none_choice_name='@misc.menu_select')
         cs.device.select(state.device)

     def update_animatables(self):

@@ -76,7 +71,7 @@ class FaceAnimatorWorker(BackendWorker):
     def on_cs_device(self, idx, device):
         state, cs = self.get_state(), self.get_control_sheet()
         if device is not None and state.device == device:
-            self.tpsmm_model = TPSMM(device)
+            self.lia_model = LIA(device)

             cs.animatable.enable()
             self.update_animatables()

@@ -85,12 +80,9 @@
             cs.animator_face_id.enable()
             cs.animator_face_id.set_config(lib_csw.Number.Config(min=0, max=16, step=1, decimals=0, allow_instant_update=True))
             cs.animator_face_id.set_number(state.animator_face_id if state.animator_face_id is not None else 0)

-            cs.relative_mode.enable()
-            cs.relative_mode.set_flag(state.relative_mode if state.relative_mode is not None else True)
-
             cs.relative_power.enable()
-            cs.relative_power.set_config(lib_csw.Number.Config(min=0.0, max=1.0, step=0.01, decimals=2, allow_instant_update=True))
+            cs.relative_power.set_config(lib_csw.Number.Config(min=0.0, max=2.0, step=0.01, decimals=2, allow_instant_update=True))
             cs.relative_power.set_number(state.relative_power if state.relative_power is not None else 1.0)

             cs.update_animatables.enable()

@@ -105,20 +97,15 @@ class FaceAnimatorWorker(BackendWorker):
         state.animatable = animatable

         self.animatable_img = None
-        self.animatable_kp = None
-        self.driving_ref_kp = None
+        self.driving_ref_motion = None

         if animatable is not None:
             try:
-                W,H = self.tpsmm_model.get_input_size()
+                W,H = self.lia_model.get_input_size()
                 ip = ImageProcessor(lib_cv2.imread(self.animatables_path / animatable))
                 ip.fit_in(TW=W, TH=H, pad_to_target=True, allow_upscale=True)

-                animatable_img = ip.get_image('HWC')
-                animatable_kp = self.tpsmm_model.extract_kp(animatable_img)
-                self.animatable_img = animatable_img
-                self.animatable_kp = animatable_kp
+                self.animatable_img = ip.get_image('HWC')

             except Exception as e:
                 cs.animatable.unselect()

@@ -133,13 +120,6 @@ class FaceAnimatorWorker(BackendWorker):
         cs.animator_face_id.set_number(animator_face_id)
         self.save_state()
         self.reemit_frame_signal.send()

-    def on_cs_relative_mode(self, relative_mode):
-        state, cs = self.get_state(), self.get_control_sheet()
-        state.relative_mode = relative_mode
-        self.save_state()
-        self.reemit_frame_signal.send()
-
     def on_cs_relative_power(self, relative_power):
         state, cs = self.get_state(), self.get_control_sheet()
         cfg = cs.relative_power.get_config()

@@ -149,7 +129,7 @@ class FaceAnimatorWorker(BackendWorker):
         self.reemit_frame_signal.send()

     def on_cs_reset_reference_pose(self):
-        self.driving_ref_kp = self.last_driving_kp
+        self.driving_ref_motion = None
         self.reemit_frame_signal.send()

     def on_tick(self):

@@ -162,8 +142,8 @@ class FaceAnimatorWorker(BackendWorker):
             if bcd is not None:
                 bcd.assign_weak_heap(self.weak_heap)

-                tpsmm_model = self.tpsmm_model
-                if tpsmm_model is not None and self.animatable_img is not None:
+                lia_model = self.lia_model
+                if lia_model is not None and self.animatable_img is not None:

                     for i, fsi in enumerate(bcd.get_face_swap_info_list()):
                         if state.animator_face_id == i:

@@ -172,14 +152,10 @@ class FaceAnimatorWorker(BackendWorker):
                                 _,H,W,_ = ImageProcessor(face_align_image).get_dims()

-                                driving_kp = self.last_driving_kp = tpsmm_model.extract_kp(face_align_image)
-
-                                if self.driving_ref_kp is None:
-                                    self.driving_ref_kp = driving_kp
-
-                                anim_image = tpsmm_model.generate(self.animatable_img, self.animatable_kp, driving_kp,
-                                                                  self.driving_ref_kp if state.relative_mode else None,
-                                                                  relative_power=state.relative_power)
+                                if self.driving_ref_motion is None:
+                                    self.driving_ref_motion = lia_model.extract_motion(face_align_image)
+
+                                anim_image = lia_model.generate(self.animatable_img, face_align_image, self.driving_ref_motion, power=state.relative_power)
                                 anim_image = ImageProcessor(anim_image).resize((W,H)).get_image('HWC')

                                 fsi.face_swap_image_name = f'{fsi.face_align_image_name}_swapped'

@@ -203,7 +179,6 @@ class Sheet:
             self.device = lib_csw.DynamicSingleSwitch.Client()
             self.animatable = lib_csw.DynamicSingleSwitch.Client()
             self.animator_face_id = lib_csw.Number.Client()
-            self.relative_mode = lib_csw.Flag.Client()
             self.update_animatables = lib_csw.Signal.Client()
             self.reset_reference_pose = lib_csw.Signal.Client()
             self.relative_power = lib_csw.Number.Client()

@@ -214,7 +189,6 @@ class Sheet:
             self.device = lib_csw.DynamicSingleSwitch.Host()
             self.animatable = lib_csw.DynamicSingleSwitch.Host()
             self.animator_face_id = lib_csw.Number.Host()
-            self.relative_mode = lib_csw.Flag.Host()
             self.update_animatables = lib_csw.Signal.Host()
             self.reset_reference_pose = lib_csw.Signal.Host()
             self.relative_power = lib_csw.Number.Host()

@@ -223,5 +197,4 @@ class WorkerState(BackendWorkerState):
     device = None
     animatable : str = None
     animator_face_id : int = None
-    relative_mode : bool = None
     relative_power : float = None


@@ -31,10 +31,8 @@ class QFaceAnimator(QBackendPanel):
         q_animator_face_id_label = QLabelPopupInfo(label=L('@QFaceAnimator.animator_face_id') )
         q_animator_face_id = QSpinBoxCSWNumber(cs.animator_face_id, reflect_state_widgets=[q_animator_face_id_label])

-        q_relative_mode_label = QLabelPopupInfo(label=L('@QFaceAnimator.relative_mode') )
-        q_relative_mode = QCheckBoxCSWFlag(cs.relative_mode, reflect_state_widgets=[q_relative_mode_label])
+        q_relative_power_label = QLabelPopupInfo(label=L('@QFaceAnimator.relative_power') )
         q_relative_power = QSliderCSWNumber(cs.relative_power)

         q_update_animatables = QXPushButtonCSWSignal(cs.update_animatables, image=QXImageDB.reload_outline('light gray'), button_size=(24,22) )

@@ -52,9 +50,8 @@ class QFaceAnimator(QBackendPanel):
         grid_l.addWidget(q_animator_face_id_label, row, 0, alignment=qtx.AlignRight | qtx.AlignVCenter )
         grid_l.addWidget(q_animator_face_id, row, 1, alignment=qtx.AlignLeft )
         row += 1
-        grid_l.addWidget(q_relative_mode_label, row, 0, alignment=qtx.AlignRight | qtx.AlignVCenter )
-        grid_l.addLayout(qtx.QXHBoxLayout([q_relative_mode,2,q_relative_power]), row, 1, alignment=qtx.AlignLeft )
+        grid_l.addWidget(q_relative_power_label, row, 0, alignment=qtx.AlignRight | qtx.AlignVCenter )
+        grid_l.addWidget(q_relative_power, row, 1 )
         row += 1
         grid_l.addWidget(q_reset_reference_pose, row, 0, 1, 2 )


@@ -656,13 +656,13 @@ class Localization:
                     'it-IT' : 'Animatore Face ID',
                     'ja-JP' : '動かす顔のID番号'},

-    'QFaceAnimator.relative_mode':{
-                    'en-US' : 'Relative mode',
-                    'ru-RU' : 'Относительный режим',
-                    'zh-CN' : '相对模式',
-                    'es-ES' : 'Modo relativo',
-                    'it-IT' : 'Modalità relativa',
-                    'ja-JP' : '相対モード'},
+    'QFaceAnimator.relative_power':{
+                    'en-US' : 'Relative power',
+                    'ru-RU' : 'Относительная сила',
+                    'zh-CN' : 'Relative power',
+                    'es-ES' : 'Relative power',
+                    'it-IT' : 'Relative power',
+                    'ja-JP' : 'Relative power'},

     'QFaceAnimator.reset_reference_pose':{
                     'en-US' : 'Reset reference pose',

@@ -1143,7 +1143,15 @@ class Localization:
                     'es-ES' : 'De los puntos',
                     'it-IT' : 'Da punti',
                     'ja-JP' : '点から'},

+    'FaceAligner.AlignMode.FROM_STATIC_RECT':{
+                    'en-US' : 'From static rect',
+                    'ru-RU' : 'Из статичного прямоугольника',
+                    'zh-CN' : '从一个静态的矩形',
+                    'es-ES' : 'From static rect',
+                    'it-IT' : 'From static rect',
+                    'ja-JP' : 'From static rect'},
+
     'FaceSwapper.model_information':{
                     'en-US' : 'Model information',
                     'ru-RU' : 'Информация о модели',

modelhub/onnx/LIA/LIA.py (new file, 89 lines)

@@ -0,0 +1,89 @@
+from pathlib import Path
+from typing import List
+
+import numpy as np
+from xlib.file import SplittedFile
+from xlib.image import ImageProcessor
+from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
+                              get_available_devices_info)
+
+
+class LIA:
+    """
+    Latent Image Animator: Learning to Animate Images via Latent Space Navigation
+    https://github.com/wyhsirius/LIA
+
+    arguments
+
+        device_info    ORTDeviceInfo
+
+            use LIA.get_available_devices()
+            to determine a list of available devices accepted by the model
+
+    raises
+        Exception
+    """
+
+    @staticmethod
+    def get_available_devices() -> List[ORTDeviceInfo]:
+        return get_available_devices_info()
+
+    def __init__(self, device_info : ORTDeviceInfo):
+        if device_info not in LIA.get_available_devices():
+            raise Exception(f'device_info {device_info} is not in available devices for LIA')
+
+        generator_path = Path(__file__).parent / 'generator.onnx'
+        SplittedFile.merge(generator_path, delete_parts=False)
+        if not generator_path.exists():
+            raise FileNotFoundError(f'{generator_path} not found')
+
+        self._generator = InferenceSession_with_device(str(generator_path), device_info)
+
+    def get_input_size(self):
+        """
+        returns optimal (Width, Height) for input images; resize the source image beforehand to avoid extra load
+        """
+        return (256,256)
+
+    def extract_motion(self, img : np.ndarray):
+        """
+        Extract motion from an image.
+
+        arguments
+            img    np.ndarray    HW HWC 1HWC    uint8/float32
+        """
+        feed_img = ImageProcessor(img).resize(self.get_input_size()).ch(3).swap_ch().to_ufloat32(as_tanh=True).get_image('NCHW')
+
+        return self._generator.run(['out_drv_motion'], {'in_src': np.zeros((1,3,256,256), np.float32),
+                                                        'in_drv': feed_img,
+                                                        'in_drv_start_motion': np.zeros((1,20), np.float32),
+                                                        'in_power' : np.zeros((1,), np.float32)
+                                                        })[0]
+
+    def generate(self, img_source : np.ndarray, img_driver : np.ndarray, driver_start_motion : np.ndarray, power):
+        """
+        arguments
+            img_source             np.ndarray    HW HWC 1HWC    uint8/float32
+            img_driver             np.ndarray    HW HWC 1HWC    uint8/float32
+            driver_start_motion    reference motion for the driver
+        """
+        ip = ImageProcessor(img_source)
+        dtype = ip.get_dtype()
+        _,H,W,_ = ip.get_dims()
+
+        out = self._generator.run(['out'], {'in_src': ip.resize(self.get_input_size()).ch(3).swap_ch().to_ufloat32(as_tanh=True).get_image('NCHW'),
+                                            'in_drv' : ImageProcessor(img_driver).resize(self.get_input_size()).ch(3).swap_ch().to_ufloat32(as_tanh=True).get_image('NCHW'),
+                                            'in_drv_start_motion' : driver_start_motion,
+                                            'in_power' : np.array([power], np.float32)
+                                            })[0].transpose(0,2,3,1)[0]
+
+        out = ImageProcessor(out).to_dtype(dtype, from_tanh=True).resize((W,H)).swap_ch().get_image('HWC')
+        return out
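A minimal usage sketch for the class above, assuming the split model parts are present and at least one ONNX device is available; the image file names are placeholders, and imread comes from the same xlib wrapper FaceAnimator uses:

from modelhub.onnx import LIA
from xlib import cv as lib_cv2

lia = LIA(LIA.get_available_devices()[0])     # raises if the device is unsupported

src  = lib_cv2.imread('animatable.png')       # still image to animate (placeholder name)
drv0 = lib_cv2.imread('driver_000.png')       # driver reference frame
drv1 = lib_cv2.imread('driver_001.png')       # driver frame whose motion is transferred

ref_motion = lia.extract_motion(drv0)         # motion latent, (1,20) judging by the in_drv_start_motion placeholder
out = lia.generate(src, drv1, ref_motion, power=1.0)   # HWC, same dtype/size as src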


@@ -1,131 +0,0 @@
-from pathlib import Path
-from typing import List
-
-import cv2
-import numpy as np
-from xlib.file import SplittedFile
-from xlib.image import ImageProcessor
-from xlib.onnxruntime import (InferenceSession_with_device, ORTDeviceInfo,
-                              get_available_devices_info)
-
-
-class TPSMM:
-    """
-    [CVPR2022] Thin-Plate Spline Motion Model for Image Animation
-    https://github.com/yoyo-nb/Thin-Plate-Spline-Motion-Model
-
-    arguments
-
-        device_info    ORTDeviceInfo
-
-            use TPSMM.get_available_devices()
-            to determine a list of available devices accepted by the model
-
-    raises
-        Exception
-    """
-
-    @staticmethod
-    def get_available_devices() -> List[ORTDeviceInfo]:
-        return get_available_devices_info()
-
-    def __init__(self, device_info : ORTDeviceInfo):
-        if device_info not in TPSMM.get_available_devices():
-            raise Exception(f'device_info {device_info} is not in available devices for TPSMM')
-
-        generator_path = Path(__file__).parent / 'generator.onnx'
-        SplittedFile.merge(generator_path, delete_parts=False)
-        if not generator_path.exists():
-            raise FileNotFoundError(f'{generator_path} not found')
-
-        kp_detector_path = Path(__file__).parent / 'kp_detector.onnx'
-        if not kp_detector_path.exists():
-            raise FileNotFoundError(f'{kp_detector_path} not found')
-
-        self._generator = InferenceSession_with_device(str(generator_path), device_info)
-        self._kp_detector = InferenceSession_with_device(str(kp_detector_path), device_info)
-
-    def get_input_size(self):
-        """
-        returns optimal (Width, Height) for input images; resize the source image beforehand to avoid extra load
-        """
-        return (256,256)
-
-    def extract_kp(self, img : np.ndarray):
-        """
-        Extract keypoints from an image.
-
-        arguments
-            img    np.ndarray    HW HWC 1HWC    uint8/float32
-        """
-        feed_img = ImageProcessor(img).resize(self.get_input_size()).swap_ch().to_ufloat32().ch(3).get_image('NCHW')
-        return self._kp_detector.run(None, {'in': feed_img})[0]
-
-    def generate(self, img_source : np.ndarray, kp_source : np.ndarray, kp_driver : np.ndarray, kp_driver_ref : np.ndarray = None, relative_power : float = 1.0):
-        """
-        arguments
-            img_source       np.ndarray    HW HWC 1HWC    uint8/float32
-
-            kp_driver_ref    specify to work in kp relative mode
-        """
-        if kp_driver_ref is not None:
-            kp_driver = self.calc_relative_kp(kp_source=kp_source, kp_driver=kp_driver, kp_driver_ref=kp_driver_ref, power=relative_power)
-
-        theta, control_points, control_params = self.create_transformations_params(kp_source, kp_driver)
-
-        ip = ImageProcessor(img_source)
-        dtype = ip.get_dtype()
-        _,H,W,_ = ip.get_dims()
-
-        feed_img = ip.resize(self.get_input_size()).to_ufloat32().ch(3).get_image('NCHW')
-
-        out = self._generator.run(None, {'in': feed_img,
-                                         'theta' : theta,
-                                         'control_points' : control_points,
-                                         'control_params' : control_params,
-                                         'kp_driver' : kp_driver,
-                                         'kp_source' : kp_source,
-                                         })[0].transpose(0,2,3,1)[0]
-
-        out = ImageProcessor(out).resize((W,H)).to_dtype(dtype).get_image('HWC')
-        return out
-
-    def calc_relative_kp(self, kp_source, kp_driver, kp_driver_ref, power = 1.0):
-        source_area = np.array([ cv2.contourArea(cv2.convexHull(pts)) for pts in kp_source ], dtype=kp_source.dtype)
-        driving_area = np.array([ cv2.contourArea(cv2.convexHull(pts)) for pts in kp_driver_ref ], dtype=kp_driver_ref.dtype)
-        movement_scale = np.sqrt(source_area) / np.sqrt(driving_area)
-        return kp_source + (kp_driver - kp_driver_ref) * movement_scale[:,None,None] * power
-
-    def create_transformations_params(self, kp_source, kp_driver):
-        kp_num = 10
-        kp_sub_num = 5
-
-        kp_d = kp_driver.reshape(-1, kp_num, kp_sub_num, 2)
-        kp_s = kp_source.reshape(-1, kp_num, kp_sub_num, 2)
-
-        K = np.linalg.norm(kp_d[:,:,:,None]-kp_d[:,:,None,:], ord=2, axis=4) ** 2
-        K = K * np.log(K+1e-9)
-
-        kp_1d = np.concatenate([kp_d, np.ones(kp_d.shape[:-1], dtype=kp_d.dtype)[...,None] ], -1)
-
-        P = np.concatenate([kp_1d, np.zeros(kp_d.shape[:2] + (3, 3), dtype=kp_d.dtype)], 2)
-        L = np.concatenate([K, kp_1d.transpose(0,1,3,2)], 2)
-        L = np.concatenate([L, P], 3)
-
-        Y = np.concatenate([kp_s, np.zeros(kp_d.shape[:2] + (3, 2), dtype=kp_d.dtype)], 2)
-        one = np.broadcast_to( np.eye(Y.shape[2], dtype=kp_d.dtype), L.shape)*0.01
-        L = L + one
-
-        param = np.matmul(np.linalg.inv(L), Y)
-        theta = param[:,:,kp_sub_num:,:].transpose(0,1,3,2)
-
-        control_points = kp_d
-        control_params = param[:,:,:kp_sub_num,:]
-
-        return theta, control_points, control_params
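For the record, the relative mode deleted here rescaled driver keypoint deltas by the ratio of source/driver convex-hull areas, so a small driving face still moves a large source face proportionally. The core of calc_relative_kp reduces to the following sketch (bounding-box areas stand in for cv2.convexHull/contourArea to keep it dependency-free, and the values are random placeholders):

import numpy as np

rng = np.random.default_rng(0)
# (batch, points, 2), the layout extract_kp returned; random values for illustration
kp_source, kp_driver, kp_driver_ref = rng.random((3, 1, 50, 2), dtype=np.float32)

def approx_area(kp):                        # per-batch bounding-box area
    span = kp.max(axis=1) - kp.min(axis=1)  # (batch, 2)
    return span[..., 0] * span[..., 1]      # (batch,)

movement_scale = np.sqrt(approx_area(kp_source) / approx_area(kp_driver_ref))
kp_new = kp_source + (kp_driver - kp_driver_ref) * movement_scale[:, None, None]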


@@ -3,4 +3,4 @@ from .FaceMesh.FaceMesh import FaceMesh
 from .S3FD.S3FD import S3FD
 from .YoloV5Face.YoloV5Face import YoloV5Face
 from .InsightFace2d106.InsightFace2D106 import InsightFace2D106
-from .TPSMM.TPSMM import TPSMM
+from .LIA.LIA import LIA


@@ -9,7 +9,7 @@ from xlib import cv as lib_cv
 repo_root = Path(__file__).parent.parent

 large_files_list = [ (repo_root / 'modelhub' / 'onnx' / 'S3FD' / 'S3FD.onnx', 48*1024*1024),
-                     (repo_root / 'modelhub' / 'onnx' / 'TPSMM' / 'generator.onnx', 50*1024*1024),
+                     (repo_root / 'modelhub' / 'onnx' / 'LIA' / 'generator.onnx', 48*1024*1024),
                      (repo_root / 'modelhub' / 'torch' / 'S3FD' / 'S3FD.pth', 48*1024*1024),
                      (repo_root / 'modelhub' / 'cv' / 'FaceMarkerLBF' / 'lbfmodel.yaml', 34*1024*1024),
                    ]
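Context for the size change: each entry pairs a large model file with what appears to be the per-part size used when splitting it for the repository (the TPSMM generator used 50 MiB; LIA's now uses 48 MiB), and the parts are reassembled at load time. A sketch of the consuming side, using only the SplittedFile.merge call already seen in LIA.__init__ above; the path here is relative to the repo root:

from pathlib import Path
from xlib.file import SplittedFile

# Reassemble generator.onnx from its committed part files, exactly as
# LIA.__init__ does before opening the ONNX inference session.
generator_path = Path('modelhub/onnx/LIA/generator.onnx')
SplittedFile.merge(generator_path, delete_parts=False)
assert generator_path.exists(), 'model parts missing or merge failed'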