update xlib

This commit is contained in:
iperov 2021-10-28 09:30:24 +04:00
parent 8b385f6d80
commit 8fc4447dab
6 changed files with 245 additions and 2 deletions

View file

@ -1 +1,2 @@
from .rct import rct from .rct import rct
from .sot import sot

View file

@ -0,0 +1,59 @@
import cv2
import numpy as np
import numpy.linalg as npla
def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
"""
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
src - any float range any channel image
dst - any float range any channel image, same shape as src
steps - number of solver steps
batch_size - solver batch size
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
reg_sigmaV - sigmaV of filter
return value - clip it manually
TODO check why diff result on float and uint8 images
"""
if not np.issubdtype(src.dtype, np.floating):
raise ValueError("src value must be float")
if not np.issubdtype(trg.dtype, np.floating):
raise ValueError("trg value must be float")
if len(src.shape) != 3:
raise ValueError("src shape must have rank 3 (h,w,c)")
if src.shape != trg.shape:
raise ValueError("src and trg shapes must be equal")
src_dtype = src.dtype
h,w,c = src.shape
new_src = src.copy()
advect = np.empty ( (h*w,c), dtype=src_dtype )
for step in range (steps):
advect.fill(0)
for batch in range (batch_size):
dir = np.random.normal(size=c).astype(src_dtype)
dir /= npla.norm(dir)
projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w))
projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w))
idSource = np.argsort (projsource)
idTarget = np.argsort (projtarget)
a = projtarget[idTarget]-projsource[idSource]
for i_c in range(c):
advect[idSource,i_c] += a * dir[i_c]
new_src += advect.reshape( (h,w,c) ) / batch_size
if reg_sigmaXY != 0.0:
src_diff = new_src-src
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
if len(src_diff_filt.shape) == 2:
src_diff_filt = src_diff_filt[...,None]
new_src = src + src_diff_filt
return new_src

158
xlib/mp/SPMTWorker.py Normal file
View file

@ -0,0 +1,158 @@
import multiprocessing
import threading
import time
import traceback
import weakref
def _host_thread_proc(wref):
while True:
ref = wref()
if ref is None:
break
ref._host_process_messages(0.005)
del ref
print('_host_thread_proc exit')
class SPMTWorker:
def __init__(self, *sub_args, **sub_kwargs):
"""
base class for single subprocess multi thread worker
provides messaging interface between host and subprocess
"""
host_pipe, sub_pipe = multiprocessing.Pipe()
p = multiprocessing.Process(target=self._subprocess_proc, args=(sub_pipe, sub_args, sub_kwargs), daemon=True)
p.start()
self._p = p
self._pipe = host_pipe
threading.Thread(target=_host_thread_proc, args=(weakref.ref(self),), daemon=True).start()
def kill(self):
"""
kill subprocess
"""
self._p.terminate()
self._p.join()
def stop(self):
"""
graceful stop subprocess, will wait all thread finalization
"""
self._send_msg('_stop')
self._p.join()
# overridable
def _on_host_sub_message(self, name, *args, **kwargs):
"""
a message from subprocess
"""
def _host_process_messages(self, timeout : float = 0) -> bool:
"""
process messages on host side
"""
try:
pipe = self._pipe
if pipe.poll(timeout):
while True:
name, args, kwargs = pipe.recv()
self._on_host_sub_message(name, *args, **kwargs)
if not pipe.poll():
break
except:
...
# overridable
def _on_sub_host_message(self, name, *args, **kwargs):
"""
a message from host
"""
# overridable
def _on_sub_initialize(self):
"""
on subprocess initialization
"""
def _on_sub_finalize(self):
"""
on graceful subprocess finalization
"""
print('_on_sub_finalize')
# overridable
def _on_sub_thread_initialize(self, thread_id):
"""
called on subprocess thread initialization
"""
# overridable
def _on_sub_thread_finalize(self, thread_id):
"""
called on subprocess thread finalization
"""
# overridable
def _on_sub_thread_tick(self, thread_id):
"""
called on subprocess thread tick
"""
def _send_msg(self, name, *args, **kwargs):
"""
send message to other side (to host or to sub)
"""
try:
self._pipe.send( (name, args, kwargs) )
except:
...
def _sub_thread_proc(self, thread_id):
self._on_sub_thread_initialize(thread_id)
while self._threads_running:
self._on_sub_thread_tick(thread_id)
time.sleep(0.005)
self._on_sub_thread_finalize(thread_id)
self._threads_exit_barrier.wait()
def _sub_get_thread_count(self) -> int:
return self._thread_count
def _subprocess_proc(self, pipe, sub_args, sub_kwargs):
self._pipe = pipe
self._thread_count = multiprocessing.cpu_count()
self._on_sub_initialize(*sub_args, **sub_kwargs)
self._threads = []
self._threads_running = True
self._threads_exit_barrier = threading.Barrier(self._thread_count+1)
for thread_id in range(self._thread_count):
t = threading.Thread(target=self._sub_thread_proc, args=(thread_id,), daemon=True)
t.start()
self._threads.append(t)
working = True
while working:
if pipe.poll(0.005):
while True:
name, args, kwargs = pipe.recv()
if name == '_stop':
working = False
else:
try:
self._on_sub_host_message(name, *args, **kwargs)
except:
print(f'Error during handling host message {name} : {traceback.format_exc()}')
if not pipe.poll():
break
self._threads_running = False
self._threads_exit_barrier.wait()
self._on_sub_finalize()

View file

@ -6,3 +6,4 @@ from .PMPI import PMPI
from .MPAtomicInt32 import MPAtomicInt32 from .MPAtomicInt32 import MPAtomicInt32
from .MPSPSCMRRingData import MPSPSCMRRingData from .MPSPSCMRRingData import MPSPSCMRRingData
from .MPWeakHeap import MPWeakHeap from .MPWeakHeap import MPWeakHeap
from .SPMTWorker import SPMTWorker

23
xlib/mt/AtomicInteger.py Normal file
View file

@ -0,0 +1,23 @@
import threading
class AtomicInteger:
def __init__(self, value=0):
self._value = int(value)
self._lock = threading.Lock()
def inc(self, d = 1):
with self._lock:
self._value += int(d)
return self._value
def dec(self, d = 1):
return self.inc(-d)
def get_value(self) -> int:
with self._lock:
return self._value
def set_value(self, v : int):
with self._lock:
self._value = int(v)
return self._value

View file

@ -2,4 +2,5 @@
various threading extensions various threading extensions
""" """
from .MTOrderedData import MTOrderedData from .AtomicInteger import AtomicInteger
from .MTOrderedData import MTOrderedData