From 8fc4447dab62d5e1dc9a95f4ae543740b0dc1d11 Mon Sep 17 00:00:00 2001 From: iperov Date: Thu, 28 Oct 2021 09:30:24 +0400 Subject: [PATCH] update xlib --- xlib/image/color_transfer/__init__.py | 3 +- xlib/image/color_transfer/sot.py | 59 ++++++++++ xlib/mp/SPMTWorker.py | 158 ++++++++++++++++++++++++++ xlib/mp/__init__.py | 1 + xlib/mt/AtomicInteger.py | 23 ++++ xlib/mt/__init__.py | 3 +- 6 files changed, 245 insertions(+), 2 deletions(-) create mode 100644 xlib/image/color_transfer/sot.py create mode 100644 xlib/mp/SPMTWorker.py create mode 100644 xlib/mt/AtomicInteger.py diff --git a/xlib/image/color_transfer/__init__.py b/xlib/image/color_transfer/__init__.py index eb4596e..c5851f5 100644 --- a/xlib/image/color_transfer/__init__.py +++ b/xlib/image/color_transfer/__init__.py @@ -1 +1,2 @@ -from .rct import rct \ No newline at end of file +from .rct import rct +from .sot import sot \ No newline at end of file diff --git a/xlib/image/color_transfer/sot.py b/xlib/image/color_transfer/sot.py new file mode 100644 index 0000000..4e3d0e1 --- /dev/null +++ b/xlib/image/color_transfer/sot.py @@ -0,0 +1,59 @@ +import cv2 +import numpy as np +import numpy.linalg as npla + +def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0): + """ + Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer + + src - any float range any channel image + dst - any float range any channel image, same shape as src + steps - number of solver steps + batch_size - solver batch size + reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0 + reg_sigmaV - sigmaV of filter + + return value - clip it manually + + TODO check why diff result on float and uint8 images + """ + if not np.issubdtype(src.dtype, np.floating): + raise ValueError("src value must be float") + if not np.issubdtype(trg.dtype, np.floating): + raise ValueError("trg value must be float") + + if len(src.shape) != 3: + raise ValueError("src shape must have rank 3 (h,w,c)") + + if src.shape != trg.shape: + raise ValueError("src and trg shapes must be equal") + + src_dtype = src.dtype + h,w,c = src.shape + new_src = src.copy() + + advect = np.empty ( (h*w,c), dtype=src_dtype ) + for step in range (steps): + advect.fill(0) + for batch in range (batch_size): + dir = np.random.normal(size=c).astype(src_dtype) + dir /= npla.norm(dir) + + projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w)) + projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w)) + + idSource = np.argsort (projsource) + idTarget = np.argsort (projtarget) + + a = projtarget[idTarget]-projsource[idSource] + for i_c in range(c): + advect[idSource,i_c] += a * dir[i_c] + new_src += advect.reshape( (h,w,c) ) / batch_size + + if reg_sigmaXY != 0.0: + src_diff = new_src-src + src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY ) + if len(src_diff_filt.shape) == 2: + src_diff_filt = src_diff_filt[...,None] + new_src = src + src_diff_filt + return new_src diff --git a/xlib/mp/SPMTWorker.py b/xlib/mp/SPMTWorker.py new file mode 100644 index 0000000..a59f055 --- /dev/null +++ b/xlib/mp/SPMTWorker.py @@ -0,0 +1,158 @@ +import multiprocessing +import threading +import time +import traceback +import weakref + +def _host_thread_proc(wref): + while True: + ref = wref() + if ref is None: + break + ref._host_process_messages(0.005) + del ref + print('_host_thread_proc exit') + +class SPMTWorker: + def __init__(self, *sub_args, **sub_kwargs): + """ + base class for single subprocess multi thread worker + + provides messaging interface between host and subprocess + """ + host_pipe, sub_pipe = multiprocessing.Pipe() + p = multiprocessing.Process(target=self._subprocess_proc, args=(sub_pipe, sub_args, sub_kwargs), daemon=True) + p.start() + self._p = p + self._pipe = host_pipe + + threading.Thread(target=_host_thread_proc, args=(weakref.ref(self),), daemon=True).start() + + def kill(self): + """ + kill subprocess + """ + self._p.terminate() + self._p.join() + + def stop(self): + """ + graceful stop subprocess, will wait all thread finalization + """ + self._send_msg('_stop') + self._p.join() + + # overridable + def _on_host_sub_message(self, name, *args, **kwargs): + """ + a message from subprocess + """ + + def _host_process_messages(self, timeout : float = 0) -> bool: + """ + process messages on host side + """ + try: + pipe = self._pipe + if pipe.poll(timeout): + while True: + name, args, kwargs = pipe.recv() + self._on_host_sub_message(name, *args, **kwargs) + if not pipe.poll(): + break + except: + ... + + # overridable + def _on_sub_host_message(self, name, *args, **kwargs): + """ + a message from host + """ + + # overridable + def _on_sub_initialize(self): + """ + on subprocess initialization + """ + + def _on_sub_finalize(self): + """ + on graceful subprocess finalization + """ + print('_on_sub_finalize') + + # overridable + def _on_sub_thread_initialize(self, thread_id): + """ + called on subprocess thread initialization + """ + # overridable + def _on_sub_thread_finalize(self, thread_id): + """ + called on subprocess thread finalization + """ + # overridable + def _on_sub_thread_tick(self, thread_id): + """ + called on subprocess thread tick + """ + + + def _send_msg(self, name, *args, **kwargs): + """ + send message to other side (to host or to sub) + """ + try: + self._pipe.send( (name, args, kwargs) ) + except: + ... + + + def _sub_thread_proc(self, thread_id): + self._on_sub_thread_initialize(thread_id) + while self._threads_running: + self._on_sub_thread_tick(thread_id) + time.sleep(0.005) + self._on_sub_thread_finalize(thread_id) + + self._threads_exit_barrier.wait() + + def _sub_get_thread_count(self) -> int: + return self._thread_count + + def _subprocess_proc(self, pipe, sub_args, sub_kwargs): + self._pipe = pipe + self._thread_count = multiprocessing.cpu_count() + + self._on_sub_initialize(*sub_args, **sub_kwargs) + + self._threads = [] + self._threads_running = True + self._threads_exit_barrier = threading.Barrier(self._thread_count+1) + + for thread_id in range(self._thread_count): + t = threading.Thread(target=self._sub_thread_proc, args=(thread_id,), daemon=True) + t.start() + self._threads.append(t) + + working = True + while working: + if pipe.poll(0.005): + while True: + name, args, kwargs = pipe.recv() + if name == '_stop': + working = False + else: + try: + self._on_sub_host_message(name, *args, **kwargs) + except: + print(f'Error during handling host message {name} : {traceback.format_exc()}') + + if not pipe.poll(): + break + + self._threads_running = False + + self._threads_exit_barrier.wait() + + self._on_sub_finalize() \ No newline at end of file diff --git a/xlib/mp/__init__.py b/xlib/mp/__init__.py index baa9235..ec49a8c 100644 --- a/xlib/mp/__init__.py +++ b/xlib/mp/__init__.py @@ -6,3 +6,4 @@ from .PMPI import PMPI from .MPAtomicInt32 import MPAtomicInt32 from .MPSPSCMRRingData import MPSPSCMRRingData from .MPWeakHeap import MPWeakHeap +from .SPMTWorker import SPMTWorker \ No newline at end of file diff --git a/xlib/mt/AtomicInteger.py b/xlib/mt/AtomicInteger.py new file mode 100644 index 0000000..082b61f --- /dev/null +++ b/xlib/mt/AtomicInteger.py @@ -0,0 +1,23 @@ +import threading + +class AtomicInteger: + def __init__(self, value=0): + self._value = int(value) + self._lock = threading.Lock() + + def inc(self, d = 1): + with self._lock: + self._value += int(d) + return self._value + + def dec(self, d = 1): + return self.inc(-d) + + def get_value(self) -> int: + with self._lock: + return self._value + + def set_value(self, v : int): + with self._lock: + self._value = int(v) + return self._value \ No newline at end of file diff --git a/xlib/mt/__init__.py b/xlib/mt/__init__.py index 8288013..a3fab78 100644 --- a/xlib/mt/__init__.py +++ b/xlib/mt/__init__.py @@ -2,4 +2,5 @@ various threading extensions """ -from .MTOrderedData import MTOrderedData \ No newline at end of file +from .AtomicInteger import AtomicInteger +from .MTOrderedData import MTOrderedData