mirror of
https://github.com/iperov/DeepFaceLive
synced 2025-07-06 13:02:16 -07:00
update xlib
This commit is contained in:
parent
8b385f6d80
commit
8fc4447dab
6 changed files with 245 additions and 2 deletions
|
@ -1 +1,2 @@
|
||||||
from .rct import rct
|
from .rct import rct
|
||||||
|
from .sot import sot
|
59
xlib/image/color_transfer/sot.py
Normal file
59
xlib/image/color_transfer/sot.py
Normal file
|
@ -0,0 +1,59 @@
|
||||||
|
import cv2
|
||||||
|
import numpy as np
|
||||||
|
import numpy.linalg as npla
|
||||||
|
|
||||||
|
def sot(src,trg, steps=10, batch_size=30, reg_sigmaXY=16.0, reg_sigmaV=5.0):
|
||||||
|
"""
|
||||||
|
Color Transform via Sliced Optimal Transfer, ported from https://github.com/dcoeurjo/OTColorTransfer
|
||||||
|
|
||||||
|
src - any float range any channel image
|
||||||
|
dst - any float range any channel image, same shape as src
|
||||||
|
steps - number of solver steps
|
||||||
|
batch_size - solver batch size
|
||||||
|
reg_sigmaXY - apply regularization and sigmaXY of filter, otherwise set to 0.0
|
||||||
|
reg_sigmaV - sigmaV of filter
|
||||||
|
|
||||||
|
return value - clip it manually
|
||||||
|
|
||||||
|
TODO check why diff result on float and uint8 images
|
||||||
|
"""
|
||||||
|
if not np.issubdtype(src.dtype, np.floating):
|
||||||
|
raise ValueError("src value must be float")
|
||||||
|
if not np.issubdtype(trg.dtype, np.floating):
|
||||||
|
raise ValueError("trg value must be float")
|
||||||
|
|
||||||
|
if len(src.shape) != 3:
|
||||||
|
raise ValueError("src shape must have rank 3 (h,w,c)")
|
||||||
|
|
||||||
|
if src.shape != trg.shape:
|
||||||
|
raise ValueError("src and trg shapes must be equal")
|
||||||
|
|
||||||
|
src_dtype = src.dtype
|
||||||
|
h,w,c = src.shape
|
||||||
|
new_src = src.copy()
|
||||||
|
|
||||||
|
advect = np.empty ( (h*w,c), dtype=src_dtype )
|
||||||
|
for step in range (steps):
|
||||||
|
advect.fill(0)
|
||||||
|
for batch in range (batch_size):
|
||||||
|
dir = np.random.normal(size=c).astype(src_dtype)
|
||||||
|
dir /= npla.norm(dir)
|
||||||
|
|
||||||
|
projsource = np.sum( new_src*dir, axis=-1).reshape ((h*w))
|
||||||
|
projtarget = np.sum( trg*dir, axis=-1).reshape ((h*w))
|
||||||
|
|
||||||
|
idSource = np.argsort (projsource)
|
||||||
|
idTarget = np.argsort (projtarget)
|
||||||
|
|
||||||
|
a = projtarget[idTarget]-projsource[idSource]
|
||||||
|
for i_c in range(c):
|
||||||
|
advect[idSource,i_c] += a * dir[i_c]
|
||||||
|
new_src += advect.reshape( (h,w,c) ) / batch_size
|
||||||
|
|
||||||
|
if reg_sigmaXY != 0.0:
|
||||||
|
src_diff = new_src-src
|
||||||
|
src_diff_filt = cv2.bilateralFilter (src_diff, 0, reg_sigmaV, reg_sigmaXY )
|
||||||
|
if len(src_diff_filt.shape) == 2:
|
||||||
|
src_diff_filt = src_diff_filt[...,None]
|
||||||
|
new_src = src + src_diff_filt
|
||||||
|
return new_src
|
158
xlib/mp/SPMTWorker.py
Normal file
158
xlib/mp/SPMTWorker.py
Normal file
|
@ -0,0 +1,158 @@
|
||||||
|
import multiprocessing
|
||||||
|
import threading
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import weakref
|
||||||
|
|
||||||
|
def _host_thread_proc(wref):
|
||||||
|
while True:
|
||||||
|
ref = wref()
|
||||||
|
if ref is None:
|
||||||
|
break
|
||||||
|
ref._host_process_messages(0.005)
|
||||||
|
del ref
|
||||||
|
print('_host_thread_proc exit')
|
||||||
|
|
||||||
|
class SPMTWorker:
|
||||||
|
def __init__(self, *sub_args, **sub_kwargs):
|
||||||
|
"""
|
||||||
|
base class for single subprocess multi thread worker
|
||||||
|
|
||||||
|
provides messaging interface between host and subprocess
|
||||||
|
"""
|
||||||
|
host_pipe, sub_pipe = multiprocessing.Pipe()
|
||||||
|
p = multiprocessing.Process(target=self._subprocess_proc, args=(sub_pipe, sub_args, sub_kwargs), daemon=True)
|
||||||
|
p.start()
|
||||||
|
self._p = p
|
||||||
|
self._pipe = host_pipe
|
||||||
|
|
||||||
|
threading.Thread(target=_host_thread_proc, args=(weakref.ref(self),), daemon=True).start()
|
||||||
|
|
||||||
|
def kill(self):
|
||||||
|
"""
|
||||||
|
kill subprocess
|
||||||
|
"""
|
||||||
|
self._p.terminate()
|
||||||
|
self._p.join()
|
||||||
|
|
||||||
|
def stop(self):
|
||||||
|
"""
|
||||||
|
graceful stop subprocess, will wait all thread finalization
|
||||||
|
"""
|
||||||
|
self._send_msg('_stop')
|
||||||
|
self._p.join()
|
||||||
|
|
||||||
|
# overridable
|
||||||
|
def _on_host_sub_message(self, name, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
a message from subprocess
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _host_process_messages(self, timeout : float = 0) -> bool:
|
||||||
|
"""
|
||||||
|
process messages on host side
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
pipe = self._pipe
|
||||||
|
if pipe.poll(timeout):
|
||||||
|
while True:
|
||||||
|
name, args, kwargs = pipe.recv()
|
||||||
|
self._on_host_sub_message(name, *args, **kwargs)
|
||||||
|
if not pipe.poll():
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
...
|
||||||
|
|
||||||
|
# overridable
|
||||||
|
def _on_sub_host_message(self, name, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
a message from host
|
||||||
|
"""
|
||||||
|
|
||||||
|
# overridable
|
||||||
|
def _on_sub_initialize(self):
|
||||||
|
"""
|
||||||
|
on subprocess initialization
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _on_sub_finalize(self):
|
||||||
|
"""
|
||||||
|
on graceful subprocess finalization
|
||||||
|
"""
|
||||||
|
print('_on_sub_finalize')
|
||||||
|
|
||||||
|
# overridable
|
||||||
|
def _on_sub_thread_initialize(self, thread_id):
|
||||||
|
"""
|
||||||
|
called on subprocess thread initialization
|
||||||
|
"""
|
||||||
|
# overridable
|
||||||
|
def _on_sub_thread_finalize(self, thread_id):
|
||||||
|
"""
|
||||||
|
called on subprocess thread finalization
|
||||||
|
"""
|
||||||
|
# overridable
|
||||||
|
def _on_sub_thread_tick(self, thread_id):
|
||||||
|
"""
|
||||||
|
called on subprocess thread tick
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _send_msg(self, name, *args, **kwargs):
|
||||||
|
"""
|
||||||
|
send message to other side (to host or to sub)
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self._pipe.send( (name, args, kwargs) )
|
||||||
|
except:
|
||||||
|
...
|
||||||
|
|
||||||
|
|
||||||
|
def _sub_thread_proc(self, thread_id):
|
||||||
|
self._on_sub_thread_initialize(thread_id)
|
||||||
|
while self._threads_running:
|
||||||
|
self._on_sub_thread_tick(thread_id)
|
||||||
|
time.sleep(0.005)
|
||||||
|
self._on_sub_thread_finalize(thread_id)
|
||||||
|
|
||||||
|
self._threads_exit_barrier.wait()
|
||||||
|
|
||||||
|
def _sub_get_thread_count(self) -> int:
|
||||||
|
return self._thread_count
|
||||||
|
|
||||||
|
def _subprocess_proc(self, pipe, sub_args, sub_kwargs):
|
||||||
|
self._pipe = pipe
|
||||||
|
self._thread_count = multiprocessing.cpu_count()
|
||||||
|
|
||||||
|
self._on_sub_initialize(*sub_args, **sub_kwargs)
|
||||||
|
|
||||||
|
self._threads = []
|
||||||
|
self._threads_running = True
|
||||||
|
self._threads_exit_barrier = threading.Barrier(self._thread_count+1)
|
||||||
|
|
||||||
|
for thread_id in range(self._thread_count):
|
||||||
|
t = threading.Thread(target=self._sub_thread_proc, args=(thread_id,), daemon=True)
|
||||||
|
t.start()
|
||||||
|
self._threads.append(t)
|
||||||
|
|
||||||
|
working = True
|
||||||
|
while working:
|
||||||
|
if pipe.poll(0.005):
|
||||||
|
while True:
|
||||||
|
name, args, kwargs = pipe.recv()
|
||||||
|
if name == '_stop':
|
||||||
|
working = False
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self._on_sub_host_message(name, *args, **kwargs)
|
||||||
|
except:
|
||||||
|
print(f'Error during handling host message {name} : {traceback.format_exc()}')
|
||||||
|
|
||||||
|
if not pipe.poll():
|
||||||
|
break
|
||||||
|
|
||||||
|
self._threads_running = False
|
||||||
|
|
||||||
|
self._threads_exit_barrier.wait()
|
||||||
|
|
||||||
|
self._on_sub_finalize()
|
|
@ -6,3 +6,4 @@ from .PMPI import PMPI
|
||||||
from .MPAtomicInt32 import MPAtomicInt32
|
from .MPAtomicInt32 import MPAtomicInt32
|
||||||
from .MPSPSCMRRingData import MPSPSCMRRingData
|
from .MPSPSCMRRingData import MPSPSCMRRingData
|
||||||
from .MPWeakHeap import MPWeakHeap
|
from .MPWeakHeap import MPWeakHeap
|
||||||
|
from .SPMTWorker import SPMTWorker
|
23
xlib/mt/AtomicInteger.py
Normal file
23
xlib/mt/AtomicInteger.py
Normal file
|
@ -0,0 +1,23 @@
|
||||||
|
import threading
|
||||||
|
|
||||||
|
class AtomicInteger:
|
||||||
|
def __init__(self, value=0):
|
||||||
|
self._value = int(value)
|
||||||
|
self._lock = threading.Lock()
|
||||||
|
|
||||||
|
def inc(self, d = 1):
|
||||||
|
with self._lock:
|
||||||
|
self._value += int(d)
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
def dec(self, d = 1):
|
||||||
|
return self.inc(-d)
|
||||||
|
|
||||||
|
def get_value(self) -> int:
|
||||||
|
with self._lock:
|
||||||
|
return self._value
|
||||||
|
|
||||||
|
def set_value(self, v : int):
|
||||||
|
with self._lock:
|
||||||
|
self._value = int(v)
|
||||||
|
return self._value
|
|
@ -2,4 +2,5 @@
|
||||||
various threading extensions
|
various threading extensions
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
from .AtomicInteger import AtomicInteger
|
||||||
from .MTOrderedData import MTOrderedData
|
from .MTOrderedData import MTOrderedData
|
Loading…
Add table
Add a link
Reference in a new issue