mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-07 13:32:09 -07:00
parent
b03b147bae
commit
96de328221
3 changed files with 158 additions and 41 deletions
|
@ -6,7 +6,16 @@ import multiprocessing
|
||||||
import cv2
|
import cv2
|
||||||
from tqdm import tqdm
|
from tqdm import tqdm
|
||||||
|
|
||||||
class Interact(object):
|
try:
|
||||||
|
import IPython #if success we are in colab
|
||||||
|
from IPython.display import display, clear_output
|
||||||
|
import PIL
|
||||||
|
import matplotlib.pyplot as plt
|
||||||
|
is_colab = True
|
||||||
|
except:
|
||||||
|
is_colab = False
|
||||||
|
|
||||||
|
class InteractBase(object):
|
||||||
EVENT_LBUTTONDOWN = 1
|
EVENT_LBUTTONDOWN = 1
|
||||||
EVENT_LBUTTONUP = 2
|
EVENT_LBUTTONUP = 2
|
||||||
EVENT_RBUTTONDOWN = 5
|
EVENT_RBUTTONDOWN = 5
|
||||||
|
@ -21,6 +30,30 @@ class Interact(object):
|
||||||
self.key_events = {}
|
self.key_events = {}
|
||||||
self.pg_bar = None
|
self.pg_bar = None
|
||||||
|
|
||||||
|
def is_colab(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def on_destroy_all_windows(self):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def on_create_window (self, wnd_name):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def on_show_image (self, wnd_name, img):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def on_capture_mouse (self, wnd_name):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def on_capture_keys (self, wnd_name):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def on_process_messages(self, sleep_time=0):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
|
def on_wait_any_key(self):
|
||||||
|
raise NotImplemented
|
||||||
|
|
||||||
def log_info(self, msg, end='\n'):
|
def log_info(self, msg, end='\n'):
|
||||||
print (msg, end=end)
|
print (msg, end=end)
|
||||||
|
|
||||||
|
@ -31,12 +64,11 @@ class Interact(object):
|
||||||
if wnd_name not in self.named_windows:
|
if wnd_name not in self.named_windows:
|
||||||
#we will show window only on first show_image
|
#we will show window only on first show_image
|
||||||
self.named_windows[wnd_name] = 0
|
self.named_windows[wnd_name] = 0
|
||||||
|
|
||||||
else: print("named_window: ", wnd_name, " already created.")
|
else: print("named_window: ", wnd_name, " already created.")
|
||||||
|
|
||||||
def destroy_all_windows(self):
|
def destroy_all_windows(self):
|
||||||
if len( self.named_windows ) != 0:
|
if len( self.named_windows ) != 0:
|
||||||
cv2.destroyAllWindows()
|
self.on_destroy_all_windows()
|
||||||
self.named_windows = {}
|
self.named_windows = {}
|
||||||
self.capture_mouse_windows = {}
|
self.capture_mouse_windows = {}
|
||||||
self.capture_keys_windows = {}
|
self.capture_keys_windows = {}
|
||||||
|
@ -47,35 +79,24 @@ class Interact(object):
|
||||||
if wnd_name in self.named_windows:
|
if wnd_name in self.named_windows:
|
||||||
if self.named_windows[wnd_name] == 0:
|
if self.named_windows[wnd_name] == 0:
|
||||||
self.named_windows[wnd_name] = 1
|
self.named_windows[wnd_name] = 1
|
||||||
cv2.namedWindow(wnd_name)
|
self.on_create_window(wnd_name)
|
||||||
if wnd_name in self.capture_mouse_windows:
|
if wnd_name in self.capture_mouse_windows:
|
||||||
self.capture_mouse(wnd_name)
|
self.capture_mouse(wnd_name)
|
||||||
|
self.on_show_image(wnd_name,img)
|
||||||
cv2.imshow (wnd_name, img)
|
|
||||||
else: print("show_image: named_window ", wnd_name, " not found.")
|
else: print("show_image: named_window ", wnd_name, " not found.")
|
||||||
|
|
||||||
def capture_mouse(self, wnd_name):
|
def capture_mouse(self, wnd_name):
|
||||||
def onMouse(event, x, y, flags, param):
|
|
||||||
(inst, wnd_name) = param
|
|
||||||
if event == cv2.EVENT_LBUTTONDOWN: ev = Interact.EVENT_LBUTTONDOWN
|
|
||||||
elif event == cv2.EVENT_LBUTTONUP: ev = Interact.EVENT_LBUTTONUP
|
|
||||||
elif event == cv2.EVENT_RBUTTONDOWN: ev = Interact.EVENT_RBUTTONDOWN
|
|
||||||
elif event == cv2.EVENT_RBUTTONUP: ev = Interact.EVENT_RBUTTONUP
|
|
||||||
elif event == cv2.EVENT_MOUSEWHEEL: ev = Interact.EVENT_MOUSEWHEEL
|
|
||||||
|
|
||||||
else: ev = 0
|
|
||||||
inst.add_mouse_event (wnd_name, x, y, ev, flags)
|
|
||||||
|
|
||||||
if wnd_name in self.named_windows:
|
if wnd_name in self.named_windows:
|
||||||
self.capture_mouse_windows[wnd_name] = True
|
self.capture_mouse_windows[wnd_name] = True
|
||||||
if self.named_windows[wnd_name] == 1:
|
if self.named_windows[wnd_name] == 1:
|
||||||
cv2.setMouseCallback(wnd_name, onMouse, (self,wnd_name) )
|
self.on_capture_mouse(wnd_name)
|
||||||
else: print("capture_mouse: named_window ", wnd_name, " not found.")
|
else: print("capture_mouse: named_window ", wnd_name, " not found.")
|
||||||
|
|
||||||
def capture_keys(self, wnd_name):
|
def capture_keys(self, wnd_name):
|
||||||
if wnd_name in self.named_windows:
|
if wnd_name in self.named_windows:
|
||||||
if wnd_name not in self.capture_keys_windows:
|
if wnd_name not in self.capture_keys_windows:
|
||||||
self.capture_keys_windows[wnd_name] = True
|
self.capture_keys_windows[wnd_name] = True
|
||||||
|
self.on_capture_keys(wnd_name)
|
||||||
else: print("capture_keys: already set for window ", wnd_name)
|
else: print("capture_keys: already set for window ", wnd_name)
|
||||||
else: print("capture_keys: named_window ", wnd_name, " not found.")
|
else: print("capture_keys: named_window ", wnd_name, " not found.")
|
||||||
|
|
||||||
|
@ -101,28 +122,10 @@ class Interact(object):
|
||||||
yield x
|
yield x
|
||||||
|
|
||||||
def process_messages(self, sleep_time=0):
|
def process_messages(self, sleep_time=0):
|
||||||
has_windows = False
|
self.on_process_messages(sleep_time)
|
||||||
has_capture_keys = False
|
|
||||||
|
|
||||||
if len(self.named_windows) != 0:
|
|
||||||
has_windows = True
|
|
||||||
|
|
||||||
if len(self.capture_keys_windows) != 0:
|
|
||||||
has_capture_keys = True
|
|
||||||
|
|
||||||
if has_windows or has_capture_keys:
|
|
||||||
wait_key_time = max(1, int(sleep_time*1000) )
|
|
||||||
key = cv2.waitKey(wait_key_time) & 0xFF
|
|
||||||
else:
|
|
||||||
if sleep_time != 0:
|
|
||||||
time.sleep(sleep_time)
|
|
||||||
|
|
||||||
if has_capture_keys and key != 255:
|
|
||||||
for wnd_name in self.capture_keys_windows:
|
|
||||||
self.add_key_event (wnd_name, key)
|
|
||||||
|
|
||||||
def wait_any_key(self):
|
def wait_any_key(self):
|
||||||
cv2.waitKey(0)
|
self.on_wait_any_key()
|
||||||
|
|
||||||
def add_mouse_event(self, wnd_name, x, y, ev, flags):
|
def add_mouse_event(self, wnd_name, x, y, ev, flags):
|
||||||
if wnd_name not in self.mouse_events:
|
if wnd_name not in self.mouse_events:
|
||||||
|
@ -240,4 +243,101 @@ class Interact(object):
|
||||||
sys.stdin = os.fdopen( sys.stdin.fileno() )
|
sys.stdin = os.fdopen( sys.stdin.fileno() )
|
||||||
return inp
|
return inp
|
||||||
|
|
||||||
interact = Interact()
|
|
||||||
|
|
||||||
|
class InteractDesktop(InteractBase):
|
||||||
|
|
||||||
|
def on_destroy_all_windows(self):
|
||||||
|
cv2.destroyAllWindows()
|
||||||
|
|
||||||
|
def on_create_window (self, wnd_name):
|
||||||
|
cv2.namedWindow(wnd_name)
|
||||||
|
|
||||||
|
def on_show_image (self, wnd_name, img):
|
||||||
|
cv2.imshow (wnd_name, img)
|
||||||
|
|
||||||
|
def on_capture_mouse (self, wnd_name):
|
||||||
|
def onMouse(event, x, y, flags, param):
|
||||||
|
(inst, wnd_name) = param
|
||||||
|
if event == cv2.EVENT_LBUTTONDOWN: ev = InteractBase.EVENT_LBUTTONDOWN
|
||||||
|
elif event == cv2.EVENT_LBUTTONUP: ev = InteractBase.EVENT_LBUTTONUP
|
||||||
|
elif event == cv2.EVENT_RBUTTONDOWN: ev = InteractBase.EVENT_RBUTTONDOWN
|
||||||
|
elif event == cv2.EVENT_RBUTTONUP: ev = InteractBase.EVENT_RBUTTONUP
|
||||||
|
elif event == cv2.EVENT_MOUSEWHEEL: ev = InteractBase.EVENT_MOUSEWHEEL
|
||||||
|
|
||||||
|
else: ev = 0
|
||||||
|
inst.add_mouse_event (wnd_name, x, y, ev, flags)
|
||||||
|
cv2.setMouseCallback(wnd_name, onMouse, (self,wnd_name) )
|
||||||
|
|
||||||
|
def on_capture_keys (self, wnd_name):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_process_messages(self, sleep_time=0):
|
||||||
|
|
||||||
|
has_windows = False
|
||||||
|
has_capture_keys = False
|
||||||
|
|
||||||
|
if len(self.named_windows) != 0:
|
||||||
|
has_windows = True
|
||||||
|
|
||||||
|
if len(self.capture_keys_windows) != 0:
|
||||||
|
has_capture_keys = True
|
||||||
|
|
||||||
|
if has_windows or has_capture_keys:
|
||||||
|
wait_key_time = max(1, int(sleep_time*1000) )
|
||||||
|
key = cv2.waitKey(wait_key_time) & 0xFF
|
||||||
|
else:
|
||||||
|
if sleep_time != 0:
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
if has_capture_keys and key != 255:
|
||||||
|
for wnd_name in self.capture_keys_windows:
|
||||||
|
self.add_key_event (wnd_name, key)
|
||||||
|
|
||||||
|
def on_wait_any_key(self):
|
||||||
|
cv2.waitKey(0)
|
||||||
|
|
||||||
|
class InteractColab(InteractBase):
|
||||||
|
|
||||||
|
def is_colab(self):
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_destroy_all_windows(self):
|
||||||
|
pass
|
||||||
|
#clear_output()
|
||||||
|
|
||||||
|
def on_create_window (self, wnd_name):
|
||||||
|
pass
|
||||||
|
#clear_output()
|
||||||
|
|
||||||
|
def on_show_image (self, wnd_name, img):
|
||||||
|
pass
|
||||||
|
# # cv2 stores colors as BGR; convert to RGB
|
||||||
|
# if img.ndim == 3:
|
||||||
|
# if img.shape[2] == 4:
|
||||||
|
# img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA)
|
||||||
|
# else:
|
||||||
|
# img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||||
|
# img = PIL.Image.fromarray(img)
|
||||||
|
# plt.imshow(img)
|
||||||
|
# plt.show()
|
||||||
|
|
||||||
|
def on_capture_mouse (self, wnd_name):
|
||||||
|
pass
|
||||||
|
#print("on_capture_mouse(): Colab does not support")
|
||||||
|
|
||||||
|
def on_capture_keys (self, wnd_name):
|
||||||
|
pass
|
||||||
|
#print("on_capture_keys(): Colab does not support")
|
||||||
|
|
||||||
|
def on_process_messages(self, sleep_time=0):
|
||||||
|
time.sleep(sleep_time)
|
||||||
|
|
||||||
|
def on_wait_any_key(self):
|
||||||
|
pass
|
||||||
|
#print("on_wait_any_key(): Colab does not support")
|
||||||
|
|
||||||
|
if is_colab:
|
||||||
|
interact = InteractColab()
|
||||||
|
else:
|
||||||
|
interact = InteractDesktop()
|
||||||
|
|
|
@ -100,7 +100,11 @@ def trainerThread (s2c, c2s, args, device_args):
|
||||||
else:
|
else:
|
||||||
for loss_value in loss_history[-1]:
|
for loss_value in loss_history[-1]:
|
||||||
loss_string += "[%.4f]" % (loss_value)
|
loss_string += "[%.4f]" % (loss_value)
|
||||||
io.log_info (loss_string, end='\r')
|
|
||||||
|
if io.is_colab():
|
||||||
|
io.log_info ('\r' + loss_string, end='')
|
||||||
|
else:
|
||||||
|
io.log_info (loss_string, end='\r')
|
||||||
|
|
||||||
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
if model.get_target_iter() != 0 and model.is_reached_iter_goal():
|
||||||
io.log_info ('Reached target iteration.')
|
io.log_info ('Reached target iteration.')
|
||||||
|
@ -281,6 +285,9 @@ def main(args, device_args):
|
||||||
selected_preview = (selected_preview + 1) % len(previews)
|
selected_preview = (selected_preview + 1) % len(previews)
|
||||||
update_preview = True
|
update_preview = True
|
||||||
|
|
||||||
io.process_messages(0.1)
|
try:
|
||||||
|
io.process_messages(0.1)
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
s2c.put ( {'op': 'close'} )
|
||||||
|
|
||||||
io.destroy_all_windows()
|
io.destroy_all_windows()
|
||||||
|
|
10
requirements-colab.txt
Normal file
10
requirements-colab.txt
Normal file
|
@ -0,0 +1,10 @@
|
||||||
|
numpy==1.16.1
|
||||||
|
h5py==2.9.0
|
||||||
|
Keras==2.2.4
|
||||||
|
opencv-python==4.0.0.21
|
||||||
|
tensorflow-gpu==1.13.1
|
||||||
|
plaidml-keras==0.5.0
|
||||||
|
scikit-image
|
||||||
|
tqdm
|
||||||
|
ffmpeg-python==0.1.17
|
||||||
|
git+https://www.github.com/keras-team/keras-contrib.git
|
Loading…
Add table
Add a link
Reference in a new issue