diff --git a/interact/interact.py b/interact/interact.py index fc2e727..6255639 100644 --- a/interact/interact.py +++ b/interact/interact.py @@ -6,7 +6,16 @@ import multiprocessing import cv2 from tqdm import tqdm -class Interact(object): +try: + import IPython #if success we are in colab + from IPython.display import display, clear_output + import PIL + import matplotlib.pyplot as plt + is_colab = True +except: + is_colab = False + +class InteractBase(object): EVENT_LBUTTONDOWN = 1 EVENT_LBUTTONUP = 2 EVENT_RBUTTONDOWN = 5 @@ -21,6 +30,30 @@ class Interact(object): self.key_events = {} self.pg_bar = None + def is_colab(self): + return False + + def on_destroy_all_windows(self): + raise NotImplemented + + def on_create_window (self, wnd_name): + raise NotImplemented + + def on_show_image (self, wnd_name, img): + raise NotImplemented + + def on_capture_mouse (self, wnd_name): + raise NotImplemented + + def on_capture_keys (self, wnd_name): + raise NotImplemented + + def on_process_messages(self, sleep_time=0): + raise NotImplemented + + def on_wait_any_key(self): + raise NotImplemented + def log_info(self, msg, end='\n'): print (msg, end=end) @@ -31,12 +64,11 @@ class Interact(object): if wnd_name not in self.named_windows: #we will show window only on first show_image self.named_windows[wnd_name] = 0 - else: print("named_window: ", wnd_name, " already created.") def destroy_all_windows(self): if len( self.named_windows ) != 0: - cv2.destroyAllWindows() + self.on_destroy_all_windows() self.named_windows = {} self.capture_mouse_windows = {} self.capture_keys_windows = {} @@ -47,35 +79,24 @@ class Interact(object): if wnd_name in self.named_windows: if self.named_windows[wnd_name] == 0: self.named_windows[wnd_name] = 1 - cv2.namedWindow(wnd_name) + self.on_create_window(wnd_name) if wnd_name in self.capture_mouse_windows: self.capture_mouse(wnd_name) - - cv2.imshow (wnd_name, img) + self.on_show_image(wnd_name,img) else: print("show_image: named_window ", wnd_name, " not found.") def capture_mouse(self, wnd_name): - def onMouse(event, x, y, flags, param): - (inst, wnd_name) = param - if event == cv2.EVENT_LBUTTONDOWN: ev = Interact.EVENT_LBUTTONDOWN - elif event == cv2.EVENT_LBUTTONUP: ev = Interact.EVENT_LBUTTONUP - elif event == cv2.EVENT_RBUTTONDOWN: ev = Interact.EVENT_RBUTTONDOWN - elif event == cv2.EVENT_RBUTTONUP: ev = Interact.EVENT_RBUTTONUP - elif event == cv2.EVENT_MOUSEWHEEL: ev = Interact.EVENT_MOUSEWHEEL - - else: ev = 0 - inst.add_mouse_event (wnd_name, x, y, ev, flags) - if wnd_name in self.named_windows: self.capture_mouse_windows[wnd_name] = True if self.named_windows[wnd_name] == 1: - cv2.setMouseCallback(wnd_name, onMouse, (self,wnd_name) ) + self.on_capture_mouse(wnd_name) else: print("capture_mouse: named_window ", wnd_name, " not found.") def capture_keys(self, wnd_name): if wnd_name in self.named_windows: if wnd_name not in self.capture_keys_windows: self.capture_keys_windows[wnd_name] = True + self.on_capture_keys(wnd_name) else: print("capture_keys: already set for window ", wnd_name) else: print("capture_keys: named_window ", wnd_name, " not found.") @@ -101,28 +122,10 @@ class Interact(object): yield x def process_messages(self, sleep_time=0): - has_windows = False - has_capture_keys = False - - if len(self.named_windows) != 0: - has_windows = True - - if len(self.capture_keys_windows) != 0: - has_capture_keys = True - - if has_windows or has_capture_keys: - wait_key_time = max(1, int(sleep_time*1000) ) - key = cv2.waitKey(wait_key_time) & 0xFF - else: - if sleep_time != 0: - time.sleep(sleep_time) - - if has_capture_keys and key != 255: - for wnd_name in self.capture_keys_windows: - self.add_key_event (wnd_name, key) + self.on_process_messages(sleep_time) def wait_any_key(self): - cv2.waitKey(0) + self.on_wait_any_key() def add_mouse_event(self, wnd_name, x, y, ev, flags): if wnd_name not in self.mouse_events: @@ -240,4 +243,101 @@ class Interact(object): sys.stdin = os.fdopen( sys.stdin.fileno() ) return inp -interact = Interact() + + +class InteractDesktop(InteractBase): + + def on_destroy_all_windows(self): + cv2.destroyAllWindows() + + def on_create_window (self, wnd_name): + cv2.namedWindow(wnd_name) + + def on_show_image (self, wnd_name, img): + cv2.imshow (wnd_name, img) + + def on_capture_mouse (self, wnd_name): + def onMouse(event, x, y, flags, param): + (inst, wnd_name) = param + if event == cv2.EVENT_LBUTTONDOWN: ev = InteractBase.EVENT_LBUTTONDOWN + elif event == cv2.EVENT_LBUTTONUP: ev = InteractBase.EVENT_LBUTTONUP + elif event == cv2.EVENT_RBUTTONDOWN: ev = InteractBase.EVENT_RBUTTONDOWN + elif event == cv2.EVENT_RBUTTONUP: ev = InteractBase.EVENT_RBUTTONUP + elif event == cv2.EVENT_MOUSEWHEEL: ev = InteractBase.EVENT_MOUSEWHEEL + + else: ev = 0 + inst.add_mouse_event (wnd_name, x, y, ev, flags) + cv2.setMouseCallback(wnd_name, onMouse, (self,wnd_name) ) + + def on_capture_keys (self, wnd_name): + pass + + def on_process_messages(self, sleep_time=0): + + has_windows = False + has_capture_keys = False + + if len(self.named_windows) != 0: + has_windows = True + + if len(self.capture_keys_windows) != 0: + has_capture_keys = True + + if has_windows or has_capture_keys: + wait_key_time = max(1, int(sleep_time*1000) ) + key = cv2.waitKey(wait_key_time) & 0xFF + else: + if sleep_time != 0: + time.sleep(sleep_time) + + if has_capture_keys and key != 255: + for wnd_name in self.capture_keys_windows: + self.add_key_event (wnd_name, key) + + def on_wait_any_key(self): + cv2.waitKey(0) + +class InteractColab(InteractBase): + + def is_colab(self): + return True + + def on_destroy_all_windows(self): + pass + #clear_output() + + def on_create_window (self, wnd_name): + pass + #clear_output() + + def on_show_image (self, wnd_name, img): + pass + # # cv2 stores colors as BGR; convert to RGB + # if img.ndim == 3: + # if img.shape[2] == 4: + # img = cv2.cvtColor(img, cv2.COLOR_BGRA2RGBA) + # else: + # img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + # img = PIL.Image.fromarray(img) + # plt.imshow(img) + # plt.show() + + def on_capture_mouse (self, wnd_name): + pass + #print("on_capture_mouse(): Colab does not support") + + def on_capture_keys (self, wnd_name): + pass + #print("on_capture_keys(): Colab does not support") + + def on_process_messages(self, sleep_time=0): + time.sleep(sleep_time) + + def on_wait_any_key(self): + pass + #print("on_wait_any_key(): Colab does not support") + +if is_colab: + interact = InteractColab() +else: + interact = InteractDesktop() diff --git a/mainscripts/Trainer.py b/mainscripts/Trainer.py index 1cc642b..e023c26 100644 --- a/mainscripts/Trainer.py +++ b/mainscripts/Trainer.py @@ -100,7 +100,11 @@ def trainerThread (s2c, c2s, args, device_args): else: for loss_value in loss_history[-1]: loss_string += "[%.4f]" % (loss_value) - io.log_info (loss_string, end='\r') + + if io.is_colab(): + io.log_info ('\r' + loss_string, end='') + else: + io.log_info (loss_string, end='\r') if model.get_target_iter() != 0 and model.is_reached_iter_goal(): io.log_info ('Reached target iteration.') @@ -281,6 +285,9 @@ def main(args, device_args): selected_preview = (selected_preview + 1) % len(previews) update_preview = True - io.process_messages(0.1) + try: + io.process_messages(0.1) + except KeyboardInterrupt: + s2c.put ( {'op': 'close'} ) io.destroy_all_windows() diff --git a/requirements-colab.txt b/requirements-colab.txt new file mode 100644 index 0000000..5f2216f --- /dev/null +++ b/requirements-colab.txt @@ -0,0 +1,10 @@ +numpy==1.16.1 +h5py==2.9.0 +Keras==2.2.4 +opencv-python==4.0.0.21 +tensorflow-gpu==1.13.1 +plaidml-keras==0.5.0 +scikit-image +tqdm +ffmpeg-python==0.1.17 +git+https://www.github.com/keras-team/keras-contrib.git \ No newline at end of file