mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
Added new face type : head
Now you can replace the head. Example: https://www.youtube.com/watch?v=xr5FHd0AdlQ Requirements: Post processing skill in Adobe After Effects or Davinci Resolve. Usage: 1) Find suitable dst footage with the monotonous background behind head 2) Use “extract head” script 3) Gather rich src headset from only one scene (same color and haircut) 4) Mask whole head for src and dst using XSeg editor 5) Train XSeg 6) Apply trained XSeg mask for src and dst headsets 7) Train SAEHD using ‘head’ face_type as regular deepfake model with DF archi. You can use pretrained model for head. Minimum recommended resolution for head is 224. 8) Extract multiple tracks, using Merger: a. Raw-rgb b. XSeg-prd mask c. XSeg-dst mask 9) Using AAE or DavinciResolve, do: a. Hide source head using XSeg-prd mask: content-aware-fill, clone-stamp, background retraction, or other technique b. Overlay new head using XSeg-dst mask Warning: Head faceset can be used for whole_face or less types of training only with XSeg masking. XSegEditor: added button ‘view trained XSeg mask’, so you can see which frames should be masked to improve mask quality.
This commit is contained in:
parent
383d4d3736
commit
2b7364005d
21 changed files with 506 additions and 413 deletions
|
@ -18,6 +18,7 @@ from PyQt5.QtWidgets import *
|
|||
|
||||
from core import pathex
|
||||
from core.cv2ex import *
|
||||
from core import imagelib
|
||||
from core.imagelib import SegIEPoly, SegIEPolys, SegIEPolyType, sd
|
||||
from core.qtex import *
|
||||
from DFLIMG import *
|
||||
|
@ -33,6 +34,7 @@ class OpMode(IntEnum):
|
|||
DRAW_PTS = 1
|
||||
EDIT_PTS = 2
|
||||
VIEW_BAKED = 3
|
||||
VIEW_XSEG_MASK = 4
|
||||
|
||||
class PTEditMode(IntEnum):
|
||||
MOVE = 0
|
||||
|
@ -244,11 +246,17 @@ class QCanvasControlsRightBar(QFrame):
|
|||
btn_view_baked_mask.setDefaultAction(self.btn_view_baked_mask_act)
|
||||
btn_view_baked_mask.setIconSize(QUIConfig.icon_q_size)
|
||||
|
||||
btn_view_xseg_mask = QToolButton()
|
||||
self.btn_view_xseg_mask_act = QActionEx( QIconDB.view_xseg, QStringDB.btn_view_xseg_mask_tip, shortcut='5', shortcut_in_tooltip=True, is_checkable=True)
|
||||
btn_view_xseg_mask.setDefaultAction(self.btn_view_xseg_mask_act)
|
||||
btn_view_xseg_mask.setIconSize(QUIConfig.icon_q_size)
|
||||
|
||||
self.btn_poly_color_act_grp = QActionGroup (self)
|
||||
self.btn_poly_color_act_grp.addAction(self.btn_poly_color_red_act)
|
||||
self.btn_poly_color_act_grp.addAction(self.btn_poly_color_green_act)
|
||||
self.btn_poly_color_act_grp.addAction(self.btn_poly_color_blue_act)
|
||||
self.btn_poly_color_act_grp.addAction(self.btn_view_baked_mask_act)
|
||||
self.btn_poly_color_act_grp.addAction(self.btn_view_xseg_mask_act)
|
||||
self.btn_poly_color_act_grp.setExclusive(True)
|
||||
#==============================================
|
||||
|
||||
|
@ -257,6 +265,7 @@ class QCanvasControlsRightBar(QFrame):
|
|||
controls_bar_frame1_l.addWidget ( btn_poly_color_green )
|
||||
controls_bar_frame1_l.addWidget ( btn_poly_color_blue )
|
||||
controls_bar_frame1_l.addWidget ( btn_view_baked_mask )
|
||||
controls_bar_frame1_l.addWidget ( btn_view_xseg_mask )
|
||||
controls_bar_frame1 = QFrame()
|
||||
controls_bar_frame1.setFrameShape(QFrame.StyledPanel)
|
||||
controls_bar_frame1.setSizePolicy (QSizePolicy.Fixed, QSizePolicy.Fixed)
|
||||
|
@ -274,12 +283,13 @@ class QCanvasOperator(QWidget):
|
|||
super().__init__()
|
||||
self.cbar = cbar
|
||||
|
||||
self.set_cbar_disabled(initialize=False)
|
||||
self.set_cbar_disabled()
|
||||
|
||||
self.cbar.btn_poly_color_red_act.triggered.connect ( lambda : self.set_color_scheme_id(0) )
|
||||
self.cbar.btn_poly_color_green_act.triggered.connect ( lambda : self.set_color_scheme_id(1) )
|
||||
self.cbar.btn_poly_color_blue_act.triggered.connect ( lambda : self.set_color_scheme_id(2) )
|
||||
self.cbar.btn_view_baked_mask_act.toggled.connect ( self.set_view_baked_mask )
|
||||
self.cbar.btn_view_baked_mask_act.toggled.connect ( lambda : self.set_op_mode(OpMode.VIEW_BAKED) )
|
||||
self.cbar.btn_view_xseg_mask_act.toggled.connect ( self.set_view_xseg_mask )
|
||||
|
||||
self.cbar.btn_poly_type_include_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.INCLUDE) )
|
||||
self.cbar.btn_poly_type_exclude_act.triggered.connect ( lambda : self.set_poly_include_type(SegIEPolyType.EXCLUDE) )
|
||||
|
@ -298,10 +308,19 @@ class QCanvasOperator(QWidget):
|
|||
|
||||
self.qp = QPainter()
|
||||
self.initialized = False
|
||||
self.last_state = None
|
||||
|
||||
def initialize(self, q_img, img_look_pt=None, view_scale=None, ie_polys=None, canvas_config=None ):
|
||||
def initialize(self, q_img, img_look_pt=None, view_scale=None, ie_polys=None, xseg_mask=None, canvas_config=None ):
|
||||
self.q_img = q_img
|
||||
self.img_pixmap = QPixmap.fromImage(q_img)
|
||||
|
||||
self.xseg_mask_pixmap = None
|
||||
if xseg_mask is not None:
|
||||
w,h = QSize_to_np ( q_img.size() )
|
||||
xseg_mask = cv2.resize(xseg_mask, (w,h), cv2.INTER_CUBIC)
|
||||
xseg_mask = (imagelib.normalize_channels(xseg_mask, 1) * 255).astype(np.uint8)
|
||||
self.xseg_mask_pixmap = QPixmap.fromImage(QImage_from_np(xseg_mask))
|
||||
|
||||
self.img_size = QSize_to_np (self.img_pixmap.size())
|
||||
|
||||
self.img_look_pt = img_look_pt
|
||||
|
@ -314,45 +333,49 @@ class QCanvasOperator(QWidget):
|
|||
if canvas_config is None:
|
||||
canvas_config = CanvasConfig()
|
||||
self.canvas_config = canvas_config
|
||||
|
||||
|
||||
# UI init
|
||||
self.set_cbar_disabled()
|
||||
self.cbar.btn_poly_color_act_grp.setDisabled(False)
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(False)
|
||||
|
||||
# Initial vars
|
||||
self.current_cursor = None
|
||||
|
||||
|
||||
self.mouse_hull_poly = None
|
||||
self.mouse_wire_poly = None
|
||||
|
||||
self.drag_type = DragType.NONE
|
||||
self.op_mode = None
|
||||
self.pt_edit_mode = None
|
||||
|
||||
if not hasattr(self, 'color_scheme_id' ):
|
||||
self.color_scheme_id = 1
|
||||
self.set_color_scheme_id(self.color_scheme_id)
|
||||
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
|
||||
|
||||
# Initial state
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
self.set_color_scheme_id(1)
|
||||
self.set_poly_include_type(SegIEPolyType.INCLUDE)
|
||||
self.set_pt_edit_mode(PTEditMode.MOVE)
|
||||
self.set_view_baked_mask(False)
|
||||
|
||||
self.set_cbar_disabled(initialize=True)
|
||||
|
||||
if not hasattr(self, 'poly_include_type' ):
|
||||
self.poly_include_type = SegIEPolyType.INCLUDE
|
||||
self.set_poly_include_type(self.poly_include_type)
|
||||
|
||||
|
||||
# Apply last state
|
||||
if self.last_state is not None:
|
||||
self.set_color_scheme_id(self.last_state.color_scheme_id)
|
||||
if self.last_state.op_mode is not None:
|
||||
self.set_op_mode(self.last_state.op_mode)
|
||||
|
||||
self.initialized = True
|
||||
|
||||
self.setMouseTracking(True)
|
||||
self.update_cursor()
|
||||
self.update()
|
||||
self.initialized = True
|
||||
|
||||
|
||||
def finalize(self):
|
||||
if self.initialized:
|
||||
|
||||
self.last_state = sn(op_mode = self.op_mode if self.op_mode in [OpMode.VIEW_BAKED, OpMode.VIEW_XSEG_MASK] else None,
|
||||
color_scheme_id = self.color_scheme_id,
|
||||
)
|
||||
|
||||
self.img_pixmap = None
|
||||
self.update_cursor(is_finalize=True)
|
||||
self.setMouseTracking(False)
|
||||
self.setFocusPolicy(Qt.NoFocus)
|
||||
self.set_cbar_disabled(initialize=False)
|
||||
self.set_cbar_disabled()
|
||||
self.initialized = False
|
||||
self.update()
|
||||
|
||||
|
@ -445,16 +468,18 @@ class QCanvasOperator(QWidget):
|
|||
# ====================================== SETTERS =====================================
|
||||
# ====================================================================================
|
||||
# ====================================================================================
|
||||
|
||||
def set_op_mode(self, op_mode, op_poly=None):
|
||||
if op_mode != self.op_mode:
|
||||
|
||||
if not hasattr(self,'op_mode'):
|
||||
self.op_mode = None
|
||||
self.op_poly = None
|
||||
|
||||
if self.op_mode != op_mode:
|
||||
# Finalize prev mode
|
||||
if self.op_mode == OpMode.NONE:
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(True)
|
||||
elif self.op_mode == OpMode.DRAW_PTS:
|
||||
self.cbar.btn_undo_pt_act.setDisabled(True)
|
||||
self.cbar.btn_redo_pt_act.setDisabled(True)
|
||||
|
||||
if self.op_poly.get_pts_count() < 3:
|
||||
# Remove unfinished poly
|
||||
self.ie_polys.remove_poly(self.op_poly)
|
||||
|
@ -463,59 +488,69 @@ class QCanvasOperator(QWidget):
|
|||
self.cbar.btn_delete_poly_act.setDisabled(True)
|
||||
# Reset pt_edit_move when exit from EDIT_PTS
|
||||
self.set_pt_edit_mode(PTEditMode.MOVE)
|
||||
elif self.op_mode == OpMode.VIEW_BAKED:
|
||||
self.cbar.btn_view_baked_mask_act.setChecked(False)
|
||||
elif self.op_mode == OpMode.VIEW_XSEG_MASK:
|
||||
self.cbar.btn_view_xseg_mask_act.setChecked(False)
|
||||
|
||||
self.op_mode = op_mode
|
||||
|
||||
if self.op_mode == OpMode.NONE:
|
||||
|
||||
# Initialize new mode
|
||||
if op_mode == OpMode.NONE:
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(False)
|
||||
elif self.op_mode == OpMode.DRAW_PTS:
|
||||
elif op_mode == OpMode.DRAW_PTS:
|
||||
self.cbar.btn_undo_pt_act.setDisabled(False)
|
||||
self.cbar.btn_redo_pt_act.setDisabled(False)
|
||||
elif self.op_mode == OpMode.EDIT_PTS:
|
||||
elif op_mode == OpMode.EDIT_PTS:
|
||||
self.cbar.btn_pt_edit_mode_act.setDisabled(False)
|
||||
self.cbar.btn_delete_poly_act.setDisabled(False)
|
||||
|
||||
if self.op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]:
|
||||
elif op_mode == OpMode.VIEW_BAKED:
|
||||
self.cbar.btn_view_baked_mask_act.setChecked(True )
|
||||
n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0
|
||||
h,w,c = n.shape
|
||||
mask = np.zeros( (h,w,1), dtype=np.float32 )
|
||||
self.ie_polys.overlay_mask(mask)
|
||||
n = (mask*255).astype(np.uint8)
|
||||
self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n))
|
||||
elif op_mode == OpMode.VIEW_XSEG_MASK:
|
||||
self.cbar.btn_view_xseg_mask_act.setChecked(True)
|
||||
if op_mode in [OpMode.DRAW_PTS, OpMode.EDIT_PTS]:
|
||||
self.mouse_op_poly_pt_id = None
|
||||
self.mouse_op_poly_edge_id = None
|
||||
self.mouse_op_poly_edge_id_pt = None
|
||||
#
|
||||
self.op_poly = op_poly
|
||||
if op_poly is not None:
|
||||
self.update_mouse_info()
|
||||
|
||||
self.set_op_poly(op_poly)
|
||||
self.update_cursor()
|
||||
self.update()
|
||||
|
||||
def set_op_poly(self, op_poly):
|
||||
self.op_poly = op_poly
|
||||
if op_poly is not None:
|
||||
self.update_mouse_info()
|
||||
self.update()
|
||||
|
||||
def set_pt_edit_mode(self, pt_edit_mode):
|
||||
if self.pt_edit_mode != pt_edit_mode:
|
||||
if not hasattr(self, 'pt_edit_mode') or self.pt_edit_mode != pt_edit_mode:
|
||||
self.pt_edit_mode = pt_edit_mode
|
||||
self.update_cursor()
|
||||
self.update()
|
||||
|
||||
self.cbar.btn_pt_edit_mode_act.setChecked( self.pt_edit_mode == PTEditMode.ADD_DEL )
|
||||
|
||||
def set_cbar_disabled(self, initialize):
|
||||
def set_cbar_disabled(self):
|
||||
self.cbar.btn_delete_poly_act.setDisabled(True)
|
||||
self.cbar.btn_undo_pt_act.setDisabled(True)
|
||||
self.cbar.btn_redo_pt_act.setDisabled(True)
|
||||
self.cbar.btn_pt_edit_mode_act.setDisabled(True)
|
||||
|
||||
if initialize:
|
||||
self.cbar.btn_poly_color_act_grp.setDisabled(False)
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(False)
|
||||
else:
|
||||
self.cbar.btn_poly_color_act_grp.setDisabled(True)
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(True)
|
||||
self.cbar.btn_poly_color_act_grp.setDisabled(True)
|
||||
self.cbar.btn_poly_type_act_grp.setDisabled(True)
|
||||
|
||||
def set_color_scheme_id(self, id):
|
||||
if self.color_scheme_id != id:
|
||||
if self.op_mode == OpMode.VIEW_BAKED:
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
|
||||
if not hasattr(self, 'color_scheme_id') or self.color_scheme_id != id:
|
||||
self.color_scheme_id = id
|
||||
self.update_cursor()
|
||||
self.update()
|
||||
|
||||
if self.color_scheme_id == 0:
|
||||
self.cbar.btn_poly_color_red_act.setChecked( True )
|
||||
elif self.color_scheme_id == 1:
|
||||
|
@ -524,33 +559,33 @@ class QCanvasOperator(QWidget):
|
|||
self.cbar.btn_poly_color_blue_act.setChecked( True )
|
||||
|
||||
def set_poly_include_type(self, poly_include_type):
|
||||
if self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS]:
|
||||
if self.poly_include_type != poly_include_type:
|
||||
self.poly_include_type = poly_include_type
|
||||
self.update()
|
||||
if not hasattr(self, 'poly_include_type' ) or \
|
||||
( self.poly_include_type != poly_include_type and \
|
||||
self.op_mode in [OpMode.NONE, OpMode.EDIT_PTS] ):
|
||||
self.poly_include_type = poly_include_type
|
||||
self.update()
|
||||
|
||||
self.cbar.btn_poly_type_include_act.setChecked(self.poly_include_type == SegIEPolyType.INCLUDE)
|
||||
self.cbar.btn_poly_type_exclude_act.setChecked(self.poly_include_type == SegIEPolyType.EXCLUDE)
|
||||
|
||||
|
||||
|
||||
def set_view_baked_mask(self, is_checked):
|
||||
def set_view_xseg_mask(self, is_checked):
|
||||
if is_checked:
|
||||
self.set_op_mode(OpMode.VIEW_BAKED)
|
||||
self.set_op_mode(OpMode.VIEW_XSEG_MASK)
|
||||
|
||||
n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0
|
||||
h,w,c = n.shape
|
||||
#n = QImage_to_np ( self.q_img ).astype(np.float32) / 255.0
|
||||
#h,w,c = n.shape
|
||||
|
||||
mask = np.zeros( (h,w,1), dtype=np.float32 )
|
||||
self.ie_polys.overlay_mask(mask)
|
||||
#mask = np.zeros( (h,w,1), dtype=np.float32 )
|
||||
#self.ie_polys.overlay_mask(mask)
|
||||
|
||||
n = (mask*255).astype(np.uint8)
|
||||
#n = (mask*255).astype(np.uint8)
|
||||
|
||||
self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n))
|
||||
#self.img_baked_pixmap = QPixmap.fromImage(QImage_from_np(n))
|
||||
else:
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
|
||||
self.cbar.btn_view_baked_mask_act.setChecked(is_checked )
|
||||
self.cbar.btn_view_xseg_mask_act.setChecked(is_checked )
|
||||
|
||||
|
||||
# ====================================================================================
|
||||
# ====================================================================================
|
||||
|
@ -764,7 +799,6 @@ class QCanvasOperator(QWidget):
|
|||
# other cases -> unselect poly
|
||||
self.set_op_mode(OpMode.NONE)
|
||||
|
||||
|
||||
elif btn == Qt.MiddleButton:
|
||||
if self.drag_type == DragType.NONE:
|
||||
# Start image drag
|
||||
|
@ -773,6 +807,7 @@ class QCanvasOperator(QWidget):
|
|||
self.drag_img_look_pt = self.get_img_look_pt()
|
||||
self.update_cursor()
|
||||
|
||||
|
||||
def mouseReleaseEvent(self, ev):
|
||||
super().mouseReleaseEvent(ev)
|
||||
if not self.initialized:
|
||||
|
@ -855,6 +890,11 @@ class QCanvasOperator(QWidget):
|
|||
src_rect = QRect(0, 0, *self.img_size)
|
||||
dst_rect = self.img_to_cli_rect( src_rect )
|
||||
qp.drawPixmap(dst_rect, self.img_baked_pixmap, src_rect)
|
||||
elif self.op_mode == OpMode.VIEW_XSEG_MASK:
|
||||
if self.xseg_mask_pixmap is not None:
|
||||
src_rect = QRect(0, 0, *self.img_size)
|
||||
dst_rect = self.img_to_cli_rect( src_rect )
|
||||
qp.drawPixmap(dst_rect, self.xseg_mask_pixmap, src_rect)
|
||||
else:
|
||||
if self.img_pixmap is not None:
|
||||
src_rect = QRect(0, 0, *self.img_size)
|
||||
|
@ -980,6 +1020,7 @@ class QCanvas(QFrame):
|
|||
btn_poly_color_green_act = self.canvas_control_right_bar.btn_poly_color_green_act,
|
||||
btn_poly_color_blue_act = self.canvas_control_right_bar.btn_poly_color_blue_act,
|
||||
btn_view_baked_mask_act = self.canvas_control_right_bar.btn_view_baked_mask_act,
|
||||
btn_view_xseg_mask_act = self.canvas_control_right_bar.btn_view_xseg_mask_act,
|
||||
btn_poly_color_act_grp = self.canvas_control_right_bar.btn_poly_color_act_grp,
|
||||
|
||||
btn_poly_type_include_act = self.canvas_control_left_bar.btn_poly_type_include_act,
|
||||
|
@ -1124,9 +1165,9 @@ class MainWindow(QXMainWindow):
|
|||
if img is None:
|
||||
img = QImage_from_np(cv2_imread(image_path))
|
||||
if img is None:
|
||||
raise Exception(f'Unable to load {image_path}')
|
||||
io.log_err(f'Unable to load {image_path}')
|
||||
except:
|
||||
io.log_err(f"{traceback.format_exc()}")
|
||||
img = None
|
||||
|
||||
return img
|
||||
|
||||
|
@ -1143,25 +1184,32 @@ class MainWindow(QXMainWindow):
|
|||
return False
|
||||
|
||||
dflimg = DFLIMG.load(image_path)
|
||||
if not dflimg or not dflimg.has_data():
|
||||
return False
|
||||
|
||||
ie_polys = dflimg.get_seg_ie_polys()
|
||||
xseg_mask = dflimg.get_xseg_mask()
|
||||
q_img = self.load_QImage(image_path)
|
||||
|
||||
self.canvas.op.initialize ( q_img, ie_polys=ie_polys )
|
||||
if q_img is None:
|
||||
return False
|
||||
|
||||
self.canvas.op.initialize ( q_img, ie_polys=ie_polys, xseg_mask=xseg_mask )
|
||||
|
||||
self.filename_label.setText(str(image_path.name))
|
||||
|
||||
return True
|
||||
|
||||
def canvas_finalize(self, image_path):
|
||||
dflimg = DFLIMG.load(image_path)
|
||||
|
||||
if image_path.exists():
|
||||
dflimg = DFLIMG.load(image_path)
|
||||
ie_polys = dflimg.get_seg_ie_polys()
|
||||
new_ie_polys = self.canvas.op.get_ie_polys()
|
||||
|
||||
ie_polys = dflimg.get_seg_ie_polys()
|
||||
new_ie_polys = self.canvas.op.get_ie_polys()
|
||||
|
||||
if not new_ie_polys.identical(ie_polys):
|
||||
self.image_paths_has_ie_polys[image_path] = new_ie_polys.has_polys()
|
||||
dflimg.set_seg_ie_polys( new_ie_polys )
|
||||
dflimg.save()
|
||||
if not new_ie_polys.identical(ie_polys):
|
||||
self.image_paths_has_ie_polys[image_path] = new_ie_polys.has_polys()
|
||||
dflimg.set_seg_ie_polys( new_ie_polys )
|
||||
dflimg.save()
|
||||
|
||||
self.canvas.op.finalize()
|
||||
self.filename_label.setText("")
|
||||
|
@ -1182,9 +1230,10 @@ class MainWindow(QXMainWindow):
|
|||
break
|
||||
if len(self.image_paths) == 0:
|
||||
break
|
||||
|
||||
ret = self.canvas_initialize(self.image_paths[0], len(self.image_paths_done) != 0 and only_has_polys)
|
||||
|
||||
|
||||
if self.canvas_initialize(self.image_paths[0], len(self.image_paths_done) != 0 and only_has_polys):
|
||||
if ret or len(self.image_paths_done) == 0:
|
||||
break
|
||||
|
||||
self.update_cached_images()
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue