refactoring. Added RecycleGAN for testing.

This commit is contained in:
iperov 2018-12-28 19:38:52 +04:00
parent 8686309417
commit f8824f9601
24 changed files with 1661 additions and 1505 deletions

View file

@ -18,14 +18,6 @@ class MTCExtractor(object):
self.thresh2 = 0.85 self.thresh2 = 0.85
self.thresh3 = 0.6 self.thresh3 = 0.6
self.scale_factor = 0.95 self.scale_factor = 0.95
'''
self.min_face_size = self.scale_to * 0.042
self.thresh1 = 7
self.thresh2 = 85
self.thresh3 = 6
self.scale_factor = 0.95
'''
def __enter__(self): def __enter__(self):
with self.tf.variable_scope('pnet2'): with self.tf.variable_scope('pnet2'):

View file

@ -1 +0,0 @@
from .gpufmkmgr import *

View file

@ -1,317 +0,0 @@
import os
import sys
import contextlib
from utils import std_utils
from .pynvml import *
# Cached dlib module; stays None until import_dlib() has succeeded once.
dlib_module = None


def import_dlib(device_idx, cpu_only=False):
    """Import dlib a single time and return the module.

    Unless cpu_only is true, device_idx is passed to dlib.cuda.set_device.
    A second call raises Exception, because the module may be imported only once.
    """
    global dlib_module

    if dlib_module is None:
        import dlib
        dlib_module = dlib
        if not cpu_only:
            dlib_module.cuda.set_device(device_idx)
        return dlib_module

    raise Exception ('Multiple import of dlib is not allowed, reorganize your program.')
# Module-level caches for the lazily imported frameworks.
# Each stays None until the matching import_* function in this module fills it.
# TensorFlow module, assigned by import_tf() and reset by finalize_tf().
tf_module = None
# TensorFlow session created by import_tf(); exposed through get_tf_session().
tf_session = None
# Keras module, assigned by import_keras() and reset by finalize_keras().
keras_module = None
# keras_contrib module, assigned by import_keras_contrib().
keras_contrib_module = None
# keras_vggface module (or None when unavailable), assigned by import_keras_vggface().
keras_vggface_module = None
def set_prefer_GPUConfig(gpu_config):
    """Make gpu_config the module-wide preferred GPU configuration.

    import_tf() falls back to this value when called without an argument.
    """
    global prefer_GPUConfig
    prefer_GPUConfig = gpu_config
def get_tf_session():
    """Return the module-wide TensorFlow session (None if not created yet)."""
    # Reading a module global needs no `global` declaration.
    return tf_session
def import_tf( gpu_config = None ):
    """Import TensorFlow once and create the module-wide session.

    gpu_config: configuration object exposing cpu_only, gpu_idxs and
        allow_growth. When None, the module-wide prefer_GPUConfig is used;
        otherwise the given object also becomes the new prefer_GPUConfig.

    Returns the tensorflow module. Later calls return the cached module
    without creating a new session.

    When the environment variable TF_SUPPRESS_STD is '1', stdout/stderr are
    suppressed during the import and session creation and are restored
    afterwards, even when an exception is raised.
    """
    global prefer_GPUConfig
    global tf_module
    global tf_session

    if gpu_config is None:
        gpu_config = prefer_GPUConfig
    else:
        prefer_GPUConfig = gpu_config

    if tf_module is not None:
        return tf_module

    suppressor = None
    if os.environ.get('TF_SUPPRESS_STD') == '1':
        suppressor = std_utils.suppress_stdout_stderr().__enter__()

    try:
        # Device visibility is controlled through the session config below,
        # so drop any externally imposed restriction.
        os.environ.pop('CUDA_VISIBLE_DEVICES', None)
        os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'

        import tensorflow as tf

        if gpu_config.cpu_only:
            config = tf.ConfigProto( device_count = {'GPU': 0} )
        else:
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = ','.join( str(idx) for idx in gpu_config.gpu_idxs )
            config.gpu_options.force_gpu_compatible = True
        # NOTE(review): the original nesting of this line was lost with the
        # indentation; it is applied to both branches here - confirm.
        config.gpu_options.allow_growth = gpu_config.allow_growth

        session = tf.Session(config=config)

        # Publish the module and the session together, so a failed session
        # creation does not leave a cached module without a session.
        tf_module = tf
        tf_session = session
    finally:
        if suppressor is not None:
            suppressor.__exit__()

    return tf_module
def finalize_tf():
    """Close the module-wide TensorFlow session and forget the cached module.

    Safe to call when no session was ever created.
    """
    global tf_module
    global tf_session

    try:
        if tf_session is not None:
            tf_session.close()
    finally:
        # Reset the caches even if close() fails, so import_tf() can run again.
        tf_session = None
        tf_module = None
def get_keras():
    """Return the cached keras module (None if import_keras() has not run)."""
    # Reading a module global needs no `global` declaration.
    return keras_module
def import_keras():
    """Import keras once and attach it to the existing TensorFlow session.

    Returns the keras module; later calls return the cached module.
    Raises Exception when no TensorFlow session exists yet.

    When the environment variable TF_SUPPRESS_STD is '1', stdout/stderr are
    suppressed during the import and restored afterwards, even when the
    import raises.
    """
    global keras_module
    if keras_module is not None:
        return keras_module

    sess = get_tf_session()
    if sess is None:
        raise Exception ('No TF session found. Import TF first.')

    suppressor = None
    if os.environ.get('TF_SUPPRESS_STD') == '1':
        suppressor = std_utils.suppress_stdout_stderr().__enter__()

    try:
        import keras
        keras.backend.tensorflow_backend.set_session(sess)
    finally:
        if suppressor is not None:
            suppressor.__exit__()

    keras_module = keras
    return keras_module
def finalize_keras():
    """Clear the keras backend session and forget the cached module.

    Safe to call when keras was never imported.
    """
    global keras_module

    try:
        if keras_module is not None:
            keras_module.backend.clear_session()
    finally:
        # Reset the cache even if clear_session() fails.
        keras_module = None
def import_keras_contrib():
    """Import keras_contrib a single time and return the module.

    Raises Exception when the module has already been imported through
    this function.
    """
    global keras_contrib_module

    if keras_contrib_module is None:
        import keras_contrib
        keras_contrib_module = keras_contrib
        return keras_contrib_module

    raise Exception ('Multiple import of keras_contrib is not allowed, reorganize your program.')
def finalize_keras_contrib():
    """Forget the cached keras_contrib module so it may be imported again."""
    global keras_contrib_module
    keras_contrib_module = None
def import_keras_vggface(optional=False):
    """Import keras_vggface a single time and return the module.

    optional: when true, a failed import prints a notice and the function
        returns None instead of raising.

    Raises Exception when the module was already imported through this
    function, or when the import fails and optional is false.
    """
    global keras_vggface_module
    if keras_vggface_module is not None:
        raise Exception ('Multiple import of keras_vggface_module is not allowed, reorganize your program.')

    try:
        import keras_vggface
    except Exception as err:
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still propagate.
        if not optional:
            raise Exception ("Unable to import keras_vggface.") from err
        print ("Unable to import keras_vggface. It will not be used.")
        keras_vggface = None

    keras_vggface_module = keras_vggface
    return keras_vggface_module
def finalize_keras_vggface():
    """Forget the cached keras_vggface module so it may be imported again."""
    global keras_vggface_module
    keras_vggface_module = None
def hasNVML():
    """Return True when NVML can be initialised and shut down, else False."""
    try:
        nvmlInit()
        nvmlShutdown()
    except Exception:
        # Any NVML failure means "not available". KeyboardInterrupt and
        # SystemExit are deliberately not swallowed.
        return False
    return True
def getDevicesWithAtLeastFreeMemory(freememsize):
    """Return the indices of devices with at least freememsize free memory.

    Free memory is computed as total - used from NVML's memory info, so
    freememsize is compared in the units NVML reports.

    Returns a list of device indices (ints), not (index, name) tuples.
    """
    nvmlInit()
    try:
        result = []
        for i in range(nvmlDeviceGetCount()):
            memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
            if (memInfo.total - memInfo.used) >= freememsize:
                result.append (i)
        return result
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
def getDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb):
    """Return the indices of devices whose total memory is at least totalmemsize_gb.

    The threshold is converted with 1024*1024*1024 before being compared
    against NVML's reported total memory.
    """
    threshold = totalmemsize_gb*1024*1024*1024
    nvmlInit()
    try:
        result = []
        for i in range(nvmlDeviceGetCount()):
            memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
            if memInfo.total >= threshold:
                result.append (i)
        return result
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
def getAllDevicesIdxsList ():
    """Return the list [0, 1, ..., device_count - 1] as reported by NVML."""
    nvmlInit()
    try:
        return list(range(nvmlDeviceGetCount()))
    finally:
        # Always balance nvmlInit(), even when the query raises.
        nvmlShutdown()
def getDeviceVRAMFree (idx):
    """Return free memory (total - used, in NVML's units) of device idx.

    Returns 0 when idx is not a valid device index.
    """
    nvmlInit()
    try:
        # Reject negative indices as well; they would otherwise reach NVML.
        if not (0 <= idx < nvmlDeviceGetCount()):
            return 0
        memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(idx) )
        return memInfo.total - memInfo.used
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
def getDeviceVRAMTotalGb (idx):
    """Return the total memory of device idx in GB, rounded to an integer.

    The NVML total is divided by 1024*1024*1024 and rounded.
    Returns 0 when idx is not a valid device index.
    """
    nvmlInit()
    try:
        # Reject negative indices as well; they would otherwise reach NVML.
        if not (0 <= idx < nvmlDeviceGetCount()):
            return 0
        memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(idx) )
        return round( memInfo.total / (1024*1024*1024) )
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
def getBestDeviceIdx():
    """Return the index of the device with the largest total memory.

    Ties go to the lowest index. Returns -1 when no device reports a
    total memory greater than zero (including when there are no devices).
    """
    nvmlInit()
    try:
        idx = -1
        idx_mem = 0
        for i in range(nvmlDeviceGetCount()):
            memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
            if memInfo.total > idx_mem:
                idx = i
                idx_mem = memInfo.total
        return idx
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
def getWorstDeviceIdx():
    """Return the index of the device with the smallest total memory.

    Ties go to the lowest index. Returns -1 when there are no devices.
    """
    nvmlInit()
    try:
        idx = -1
        idx_mem = sys.maxsize
        for i in range(nvmlDeviceGetCount()):
            memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
            if memInfo.total < idx_mem:
                idx = i
                idx_mem = memInfo.total
        return idx
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
def isValidDeviceIdx(idx):
    """Return True when idx is within [0, device_count) as reported by NVML."""
    nvmlInit()
    try:
        # A negative index is never a valid device.
        return 0 <= idx < nvmlDeviceGetCount()
    finally:
        # Always balance nvmlInit(), even when the query raises.
        nvmlShutdown()
def getDeviceIdxsEqualModel(idx):
    """Return the indices of all devices whose name equals that of device idx.

    Names are compared as decoded strings from nvmlDeviceGetName. The
    result includes idx itself.
    """
    nvmlInit()
    try:
        idx_name = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
        return [ i for i in range(nvmlDeviceGetCount())
                 if nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)).decode() == idx_name ]
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
def getDeviceName (idx):
    """Return the decoded NVML name of device idx.

    Returns '' when idx is not a valid device index.
    """
    nvmlInit()
    try:
        # Reject negative indices as well; they would otherwise reach NVML.
        if not (0 <= idx < nvmlDeviceGetCount()):
            return ''
        return nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
    finally:
        # Always balance nvmlInit(), even when a query raises.
        nvmlShutdown()
class GPUConfig():
    """Selection of the compute devices to use.

    After construction either cpu_only is True (gpu_idxs is empty), or
    gpu_idxs lists the chosen device indices and gpu_total_vram_gb holds the
    rounded total memory in GB of the first chosen device.
    """

    # Class-level defaults; __init__ overrides them per instance as needed.
    force_best_gpu_idx = -1
    multi_gpu = False
    force_gpu_idxs = None
    choose_worst_gpu = False
    gpu_idxs = []
    gpu_total_vram_gb = 0
    allow_growth = True
    cpu_only = False

    def __init__ (self, force_best_gpu_idx = -1,
                        multi_gpu = False,
                        force_gpu_idxs = None,
                        choose_worst_gpu = False,
                        allow_growth = True,
                        cpu_only = False,
                        **in_options):
        """Pick the devices to run on.

        force_best_gpu_idx: use this device index when it is >= 0 and valid.
        multi_gpu: use every device with the same name as the chosen one.
        force_gpu_idxs: comma-separated device indices; overrides the
            automatic choice when it contains at least one index.
        choose_worst_gpu: automatic choice picks the device with the least
            total memory instead of the most.
        allow_growth: stored for the TensorFlow session configuration.
        cpu_only: skip GPU selection entirely.
        in_options: extra keyword options, accepted and ignored.

        Falls back to CPU-only when NVML is unavailable or no device can
        be selected.
        """
        # Own list per instance, so the class-level default is never shared.
        self.gpu_idxs = []
        self.gpu_total_vram_gb = 0

        # Probe NVML only when a GPU was actually requested.
        if not cpu_only and not hasNVML():
            cpu_only = True

        if cpu_only:
            self.cpu_only = True
            return

        self.force_best_gpu_idx = force_best_gpu_idx
        self.multi_gpu = multi_gpu
        self.force_gpu_idxs = force_gpu_idxs
        self.choose_worst_gpu = choose_worst_gpu
        self.allow_growth = allow_growth

        if force_gpu_idxs is not None:
            # Skip empty tokens so a trailing comma does not raise ValueError.
            self.gpu_idxs = [ int(x) for x in force_gpu_idxs.split(',') if x.strip() ]

        if not self.gpu_idxs:
            if force_best_gpu_idx >= 0 and isValidDeviceIdx(force_best_gpu_idx):
                gpu_idx = force_best_gpu_idx
            elif choose_worst_gpu:
                gpu_idx = getWorstDeviceIdx()
            else:
                gpu_idx = getBestDeviceIdx()

            if gpu_idx < 0:
                # NVML is present but no device could be chosen: use the CPU
                # instead of querying device -1.
                self.cpu_only = True
                self.multi_gpu = False
                return

            if self.multi_gpu:
                self.gpu_idxs = getDeviceIdxsEqualModel( gpu_idx )
                if len(self.gpu_idxs) <= 1:
                    self.multi_gpu = False
            else:
                self.gpu_idxs = [gpu_idx]

        self.gpu_total_vram_gb = getDeviceVRAMTotalGb ( self.gpu_idxs[0] )
# Module-wide default configuration, built at import time with default arguments.
# import_tf() uses it when called without a gpu_config; set_prefer_GPUConfig() replaces it.
prefer_GPUConfig = GPUConfig()

View file

@ -240,9 +240,10 @@ if __name__ == "__main__":
arguments = parser.parse_args() arguments = parser.parse_args()
if arguments.tf_suppress_std: if arguments.tf_suppress_std:
os.environ['TF_SUPPRESS_STD'] = '1' os.environ['TF_SUPPRESS_STD'] = '1'
arguments.func(arguments) arguments.func(arguments)
print ("Done.")
''' '''
import code import code
code.interact(local=dict(globals(), **locals())) code.interact(local=dict(globals(), **locals()))

View file

@ -41,25 +41,20 @@ def model_process(model_name, model_dir, in_options, sq, cq):
cq.put ( {'op':'init', 'converter' : converter.copy_and_set_predictor( None ) } ) cq.put ( {'op':'init', 'converter' : converter.copy_and_set_predictor( None ) } )
closing = False while True:
while not closing:
while not sq.empty(): while not sq.empty():
obj = sq.get() obj = sq.get()
obj_op = obj['op'] obj_op = obj['op']
if obj_op == 'predict': if obj_op == 'predict':
result = converter.predictor ( obj['face'] ) result = converter.predictor ( obj['face'] )
cq.put ( {'op':'predict_result', 'result':result} ) cq.put ( {'op':'predict_result', 'result':result} )
elif obj_op == 'close':
closing = True
break
time.sleep(0.005) time.sleep(0.005)
model.finalize()
except Exception as e: except Exception as e:
print ( 'Error: %s' % (str(e))) print ( 'Error: %s' % (str(e)))
traceback.print_exc() traceback.print_exc()
from utils.SubprocessorBase import SubprocessorBase from utils.SubprocessorBase import SubprocessorBase
class ConvertSubprocessor(SubprocessorBase): class ConvertSubprocessor(SubprocessorBase):
@ -129,10 +124,11 @@ class ConvertSubprocessor(SubprocessorBase):
self.alignments = client_dict['alignments'] self.alignments = client_dict['alignments']
self.debug = client_dict['debug'] self.debug = client_dict['debug']
import gpufmkmgr from nnlib import nnlib
#model process ate all GPU mem, #model process ate all GPU mem,
#so we cannot use GPU for any TF operations in converter processes (for example image_utils.TFLabConverter) #so we cannot use GPU for any TF operations in converter processes (for example image_utils.TFLabConverter)
gpufmkmgr.set_prefer_GPUConfig ( gpufmkmgr.GPUConfig (cpu_only=True) ) #therefore forcing prefer_DeviceConfig to CPU only
nnlib.prefer_DeviceConfig = nnlib.DeviceConfig (cpu_only=True)
return None return None
@ -156,6 +152,13 @@ class ConvertSubprocessor(SubprocessorBase):
image = (cv2.imread(str(filename_path)) / 255.0).astype(np.float32) image = (cv2.imread(str(filename_path)) / 255.0).astype(np.float32)
if self.converter.get_mode() == ConverterBase.MODE_IMAGE: if self.converter.get_mode() == ConverterBase.MODE_IMAGE:
image = self.converter.convert_image(image, self.debug)
if self.debug:
for img in image:
cv2.imshow ('Debug convert', img )
cv2.waitKey(0)
faces_processed = 1
elif self.converter.get_mode() == ConverterBase.MODE_IMAGE_WITH_LANDMARKS:
image_landmarks = DFLPNG.load( str(filename_path), throw_on_no_embedded_data=True ).get_landmarks() image_landmarks = DFLPNG.load( str(filename_path), throw_on_no_embedded_data=True ).get_landmarks()
image = self.converter.convert_image(image, image_landmarks, self.debug) image = self.converter.convert_image(image, image_landmarks, self.debug)
@ -270,9 +273,8 @@ def main (input_dir, output_dir, model_dir, model_name, aligned_dir=None, **in_o
output_path = output_path, output_path = output_path,
alignments = alignments, alignments = alignments,
**in_options ).process() **in_options ).process()
model_sq.put ( {'op':'close'} ) model_p.terminate()
model_p.join()
''' '''
if model_name == 'AVATAR': if model_name == 'AVATAR':

View file

@ -12,7 +12,7 @@ from utils.DFLPNG import DFLPNG
from utils import image_utils from utils import image_utils
from facelib import FaceType from facelib import FaceType
import facelib import facelib
import gpufmkmgr from nnlib import nnlib
from utils.SubprocessorBase import SubprocessorBase from utils.SubprocessorBase import SubprocessorBase
class ExtractSubprocessor(SubprocessorBase): class ExtractSubprocessor(SubprocessorBase):
@ -63,10 +63,10 @@ class ExtractSubprocessor(SubprocessorBase):
def get_devices_for_type (self, type, multi_gpu): def get_devices_for_type (self, type, multi_gpu):
if (type == 'rects' or type == 'landmarks'): if (type == 'rects' or type == 'landmarks'):
if not multi_gpu: if not multi_gpu:
devices = [gpufmkmgr.getBestDeviceIdx()] devices = [nnlib.device.getBestDeviceIdx()]
else: else:
devices = gpufmkmgr.getDevicesWithAtLeastTotalMemoryGB(2) devices = nnlib.device.getDevicesWithAtLeastTotalMemoryGB(2)
devices = [ (idx, gpufmkmgr.getDeviceName(idx), gpufmkmgr.getDeviceVRAMTotalGb(idx) ) for idx in devices] devices = [ (idx, nnlib.device.getDeviceName(idx), nnlib.device.getDeviceVRAMTotalGb(idx) ) for idx in devices]
elif type == 'final': elif type == 'final':
devices = [ (i, 'CPU%d' % (i), 0 ) for i in range(0, multiprocessing.cpu_count()) ] devices = [ (i, 'CPU%d' % (i), 0 ) for i in range(0, multiprocessing.cpu_count()) ]
@ -253,31 +253,22 @@ class ExtractSubprocessor(SubprocessorBase):
self.debug = client_dict['debug'] self.debug = client_dict['debug']
self.detector = client_dict['detector'] self.detector = client_dict['detector']
self.keras = None
self.tf = None
self.tf_session = None
self.e = None self.e = None
device_config = nnlib.DeviceConfig ( cpu_only=self.cpu_only, force_best_gpu_idx=self.device_idx, allow_growth=True)
if self.type == 'rects': if self.type == 'rects':
if self.detector is not None: if self.detector is not None:
if self.detector == 'mt': if self.detector == 'mt':
nnlib.import_all (device_config)
self.gpu_config = gpufmkmgr.GPUConfig ( cpu_only=self.cpu_only, force_best_gpu_idx=self.device_idx, allow_growth=True) self.e = facelib.MTCExtractor(nnlib.keras, nnlib.tf, nnlib.tf_sess)
self.tf = gpufmkmgr.import_tf ( self.gpu_config )
self.tf_session = gpufmkmgr.get_tf_session()
self.keras = gpufmkmgr.import_keras()
self.e = facelib.MTCExtractor(self.keras, self.tf, self.tf_session)
elif self.detector == 'dlib': elif self.detector == 'dlib':
self.dlib = gpufmkmgr.import_dlib( self.device_idx, cpu_only=self.cpu_only ) nnlib.import_dlib (device_config)
self.e = facelib.DLIBExtractor(self.dlib) self.e = facelib.DLIBExtractor(nnlib.dlib)
self.e.__enter__() self.e.__enter__()
elif self.type == 'landmarks': elif self.type == 'landmarks':
self.gpu_config = gpufmkmgr.GPUConfig ( cpu_only=self.cpu_only, force_best_gpu_idx=self.device_idx, allow_growth=True) nnlib.import_all (device_config)
self.tf = gpufmkmgr.import_tf ( self.gpu_config ) self.e = facelib.LandmarksExtractor(nnlib.keras)
self.tf_session = gpufmkmgr.get_tf_session()
self.keras = gpufmkmgr.import_keras()
self.e = facelib.LandmarksExtractor(self.keras)
self.e.__enter__() self.e.__enter__()
elif self.type == 'final': elif self.type == 'final':

7
mathlib/__init__.py Normal file
View file

@ -0,0 +1,7 @@
from .umeyama import umeyama
def get_power_of_two(x):
    """Return the smallest non-negative integer e such that 2**e >= x."""
    exponent = 0
    while x > (1 << exponent):
        exponent += 1
    return exponent

View file

@ -6,6 +6,7 @@ You can implement your own Converter, check example ConverterMasked.py
class ConverterBase(object): class ConverterBase(object):
MODE_FACE = 0 MODE_FACE = 0
MODE_IMAGE = 1 MODE_IMAGE = 1
MODE_IMAGE_WITH_LANDMARKS = 2
#overridable #overridable
def __init__(self, predictor): def __init__(self, predictor):

View file

@ -34,7 +34,7 @@ class ConverterImage(ConverterBase):
self.predictor ( np.zeros ( (self.predictor_input_size, self.predictor_input_size,3), dtype=np.float32) ) self.predictor ( np.zeros ( (self.predictor_input_size, self.predictor_input_size,3), dtype=np.float32) )
#override #override
def convert_image (self, img_bgr, img_landmarks, debug): def convert_image (self, img_bgr, debug):
img_size = img_bgr.shape[1], img_bgr.shape[0] img_size = img_bgr.shape[1], img_bgr.shape[0]
predictor_input_bgr = cv2.resize ( img_bgr, (self.predictor_input_size, self.predictor_input_size), cv2.INTER_LANCZOS4 ) predictor_input_bgr = cv2.resize ( img_bgr, (self.predictor_input_size, self.predictor_input_size), cv2.INTER_LANCZOS4 )
@ -42,5 +42,5 @@ class ConverterImage(ConverterBase):
output = cv2.resize ( predicted_bgr, (self.output_size, self.output_size), cv2.INTER_LANCZOS4 ) output = cv2.resize ( predicted_bgr, (self.output_size, self.output_size), cv2.INTER_LANCZOS4 )
if debug: if debug:
return (img_bgr,output,) return (predictor_input_bgr,output,)
return output return output

View file

@ -9,9 +9,8 @@ from utils import std_utils
from utils import image_utils from utils import image_utils
import numpy as np import numpy as np
import cv2 import cv2
import gpufmkmgr
from samples import SampleGeneratorBase from samples import SampleGeneratorBase
from nnlib import nnlib
''' '''
You can implement your own model. Check examples. You can implement your own model. Check examples.
''' '''
@ -63,27 +62,22 @@ class ModelBase(object):
if self.epoch == 0: if self.epoch == 0:
for filename in Path_utils.get_image_paths(self.preview_history_path): for filename in Path_utils.get_image_paths(self.preview_history_path):
Path(filename).unlink() Path(filename).unlink()
self.device_config = nnlib.DeviceConfig(allow_growth=False, **in_options)
self.gpu_config = gpufmkmgr.GPUConfig(allow_growth=False, **in_options)
self.gpu_total_vram_gb = self.gpu_config.gpu_total_vram_gb
if self.epoch == 0: if self.epoch == 0:
#first run #first run
self.options['created_vram_gb'] = self.gpu_total_vram_gb self.options['created_vram_gb'] = self.device_config.gpu_total_vram_gb
self.created_vram_gb = self.gpu_total_vram_gb self.created_vram_gb = self.device_config.gpu_total_vram_gb
else: else:
#not first run #not first run
if 'created_vram_gb' in self.options.keys(): if 'created_vram_gb' in self.options.keys():
self.created_vram_gb = self.options['created_vram_gb'] self.created_vram_gb = self.options['created_vram_gb']
else: else:
self.options['created_vram_gb'] = self.gpu_total_vram_gb self.options['created_vram_gb'] = self.device_config.gpu_total_vram_gb
self.created_vram_gb = self.gpu_total_vram_gb self.created_vram_gb = self.device_config.gpu_total_vram_gb
self.tf = gpufmkmgr.import_tf( self.gpu_config ) nnlib.import_all (self.device_config)
self.tf_sess = gpufmkmgr.get_tf_session()
self.keras = gpufmkmgr.import_keras()
self.keras_contrib = gpufmkmgr.import_keras_contrib()
self.onInitialize(**in_options) self.onInitialize(**in_options)
@ -108,18 +102,18 @@ class ModelBase(object):
print ("==") print ("==")
print ("== Options:") print ("== Options:")
print ("== |== batch_size : %s " % (self.batch_size) ) print ("== |== batch_size : %s " % (self.batch_size) )
print ("== |== multi_gpu : %s " % (self.gpu_config.multi_gpu) ) print ("== |== multi_gpu : %s " % (self.device_config.multi_gpu) )
for key in self.options.keys(): for key in self.options.keys():
print ("== |== %s : %s" % (key, self.options[key]) ) print ("== |== %s : %s" % (key, self.options[key]) )
print ("== Running on:") print ("== Running on:")
if self.gpu_config.cpu_only: if self.device_config.cpu_only:
print ("== |== [CPU]") print ("== |== [CPU]")
else: else:
for idx in self.gpu_config.gpu_idxs: for idx in self.device_config.gpu_idxs:
print ("== |== [%d : %s]" % (idx, gpufmkmgr.getDeviceName(idx)) ) print ("== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx)) )
if not self.gpu_config.cpu_only and self.gpu_total_vram_gb == 2: if not self.device_config.cpu_only and self.device_config.gpu_total_vram_gb == 2:
print ("==") print ("==")
print ("== WARNING: You are using 2GB GPU. Result quality may be significantly decreased.") print ("== WARNING: You are using 2GB GPU. Result quality may be significantly decreased.")
print ("== If training does not start, close all programs and try again.") print ("== If training does not start, close all programs and try again.")
@ -168,18 +162,18 @@ class ModelBase(object):
return ConverterBase(self, **in_options) return ConverterBase(self, **in_options)
def to_multi_gpu_model_if_possible (self, models_list): def to_multi_gpu_model_if_possible (self, models_list):
if len(self.gpu_config.gpu_idxs) > 1: if len(self.device_config.gpu_idxs) > 1:
#make batch_size to divide on GPU count without remainder #make batch_size to divide on GPU count without remainder
self.batch_size = int( self.batch_size / len(self.gpu_config.gpu_idxs) ) self.batch_size = int( self.batch_size / len(self.device_config.gpu_idxs) )
if self.batch_size == 0: if self.batch_size == 0:
self.batch_size = 1 self.batch_size = 1
self.batch_size *= len(self.gpu_config.gpu_idxs) self.batch_size *= len(self.device_config.gpu_idxs)
result = [] result = []
for model in models_list: for model in models_list:
for i in range( len(model.output_names) ): for i in range( len(model.output_names) ):
model.output_names = 'output_%d' % (i) model.output_names = 'output_%d' % (i)
result += [ self.keras.utils.multi_gpu_model( model, self.gpu_config.gpu_idxs ) ] result += [ nnlib.keras.utils.multi_gpu_model( model, self.device_config.gpu_idxs ) ]
return result return result
else: else:
@ -259,12 +253,12 @@ class ModelBase(object):
cv2.imwrite ( str (self.preview_history_path / ('%.6d.jpg' %( self.epoch) )), img ) cv2.imwrite ( str (self.preview_history_path / ('%.6d.jpg' %( self.epoch) )), img )
self.epoch += 1 self.epoch += 1
#............."Saving...
if epoch_time >= 10000: if epoch_time >= 10000:
loss_string = "Training [#{0:06d}][{1:03d}s]".format ( self.epoch, epoch_time / 1000 ) #............."Saving...
loss_string = "Training [#{0:06d}][{1:.5s}s]".format ( self.epoch, '{:0.4f}'.format(epoch_time / 1000) )
else: else:
loss_string = "Training [#{0:06d}][{1:04d}ms]".format ( self.epoch, int(epoch_time*1000) % 10000 ) loss_string = "Training [#{0:06d}][{1:04d}ms]".format ( self.epoch, int(epoch_time*1000) )
for (loss_name, loss_value) in losses: for (loss_name, loss_value) in losses:
loss_string += " %s:%.3f" % (loss_name, loss_value) loss_string += " %s:%.3f" % (loss_name, loss_value)
@ -274,7 +268,7 @@ class ModelBase(object):
self.last_sample = self.generate_next_sample() self.last_sample = self.generate_next_sample()
def finalize(self): def finalize(self):
gpufmkmgr.finalize_keras() nnlib.finalize_all()
def is_first_run(self): def is_first_run(self):
return self.epoch == 0 return self.epoch == 0
@ -282,6 +276,12 @@ class ModelBase(object):
def is_debug(self): def is_debug(self):
return self.debug return self.debug
def set_batch_size(self, batch_size):
self.batch_size = batch_size
def get_batch_size(self):
return self.batch_size
def get_epoch(self): def get_epoch(self):
return self.epoch return self.epoch
@ -301,16 +301,16 @@ class ModelBase(object):
#example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48} #example d = {2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48}
keys = [x for x in d.keys()] keys = [x for x in d.keys()]
if self.gpu_config.cpu_only: if self.device_config.cpu_only:
if self.batch_size == 0: if self.batch_size == 0:
self.batch_size = 2 self.batch_size = 2
else: else:
if self.gpu_total_vram_gb < keys[0]: if self.device_config.gpu_total_vram_gb < keys[0]:
raise Exception ('Sorry, this model works only on %dGB+ GPU' % ( keys[0] ) ) raise Exception ('Sorry, this model works only on %dGB+ GPU' % ( keys[0] ) )
if self.batch_size == 0: if self.batch_size == 0:
for x in keys: for x in keys:
if self.gpu_total_vram_gb <= x: if self.device_config.gpu_total_vram_gb <= x:
self.batch_size = d[x] self.batch_size = d[x]
break break

View file

@ -2,10 +2,7 @@ import numpy as np
import cv2 import cv2
from models import ModelBase from models import ModelBase
from samples import * from samples import *
from nnlib import tf_dssim from nnlib import nnlib
from nnlib import DSSIMLossClass
from nnlib import conv
from nnlib import upscale
class Model(ModelBase): class Model(ModelBase):
@ -17,9 +14,7 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
tf = self.tf exec(nnlib.import_all(), locals(), globals())
keras = self.keras
K = keras.backend
self.set_vram_batch_requirements( {3.5:8,4:8,5:12,6:16,7:24,8:32,9:48} ) self.set_vram_batch_requirements( {3.5:8,4:8,5:12,6:16,7:24,8:32,9:48} )
if self.batch_size < 4: if self.batch_size < 4:
@ -34,39 +29,39 @@ class Model(ModelBase):
self.encoder256.load_weights (self.get_strpath_storage_for_file(self.encoder256H5)) self.encoder256.load_weights (self.get_strpath_storage_for_file(self.encoder256H5))
self.decoder256.load_weights (self.get_strpath_storage_for_file(self.decoder256H5)) self.decoder256.load_weights (self.get_strpath_storage_for_file(self.decoder256H5))
if self.is_training_mode: #if self.is_training_mode:
self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256 = self.to_multi_gpu_model_if_possible ( [self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256] ) # self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256 = self.to_multi_gpu_model_if_possible ( [self.encoder64, self.decoder64_src, self.decoder64_dst, self.encoder256, self.decoder256] )
input_A_warped64 = keras.layers.Input(img_shape64) input_A_warped64 = Input(img_shape64)
input_B_warped64 = keras.layers.Input(img_shape64) input_B_warped64 = Input(img_shape64)
A_rec64 = self.decoder64_src(self.encoder64(input_A_warped64)) A_rec64 = self.decoder64_src(self.encoder64(input_A_warped64))
B_rec64 = self.decoder64_dst(self.encoder64(input_B_warped64)) B_rec64 = self.decoder64_dst(self.encoder64(input_B_warped64))
self.ae64 = self.keras.models.Model([input_A_warped64, input_B_warped64], [A_rec64, B_rec64] ) self.ae64 = Model([input_A_warped64, input_B_warped64], [A_rec64, B_rec64] )
if self.is_training_mode: if self.is_training_mode:
self.ae64, = self.to_multi_gpu_model_if_possible ( [self.ae64,] ) self.ae64, = self.to_multi_gpu_model_if_possible ( [self.ae64,] )
self.ae64.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), self.ae64.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[DSSIMLossClass(self.tf)(), DSSIMLossClass(self.tf)()] ) loss=[DSSIMLoss(), DSSIMLoss()] )
self.A64_view = K.function ([input_A_warped64], [A_rec64]) self.A64_view = K.function ([input_A_warped64], [A_rec64])
self.B64_view = K.function ([input_B_warped64], [B_rec64]) self.B64_view = K.function ([input_B_warped64], [B_rec64])
input_A_warped64 = keras.layers.Input(img_shape64) input_A_warped64 = Input(img_shape64)
input_A_target256 = keras.layers.Input(img_shape256) input_A_target256 = Input(img_shape256)
A_rec256 = self.decoder256( self.encoder256(input_A_warped64) ) A_rec256 = self.decoder256( self.encoder256(input_A_warped64) )
input_B_warped64 = keras.layers.Input(img_shape64) input_B_warped64 = Input(img_shape64)
BA_rec64 = self.decoder64_src( self.encoder64(input_B_warped64) ) BA_rec64 = self.decoder64_src( self.encoder64(input_B_warped64) )
BA_rec256 = self.decoder256( self.encoder256(BA_rec64) ) BA_rec256 = self.decoder256( self.encoder256(BA_rec64) )
self.ae256 = self.keras.models.Model([input_A_warped64], [A_rec256] ) self.ae256 = Model([input_A_warped64], [A_rec256] )
if self.is_training_mode: if self.is_training_mode:
self.ae256, = self.to_multi_gpu_model_if_possible ( [self.ae256,] ) self.ae256, = self.to_multi_gpu_model_if_possible ( [self.ae256,] )
self.ae256.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), self.ae256.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[DSSIMLossClass(self.tf)()]) loss=[DSSIMLoss()])
self.A256_view = K.function ([input_A_warped64], [A_rec256]) self.A256_view = K.function ([input_A_warped64], [A_rec256])
self.BA256_view = K.function ([input_B_warped64], [BA_rec256]) self.BA256_view = K.function ([input_B_warped64], [BA_rec256])
@ -153,62 +148,67 @@ class Model(ModelBase):
return ConverterAvatar(self.predictor_func, predictor_input_size=64, output_size=256, **in_options) return ConverterAvatar(self.predictor_func, predictor_input_size=64, output_size=256, **in_options)
def Build(self): def Build(self):
keras, K = self.keras, self.keras.backend exec(nnlib.code_import_all, locals(), globals())
img_shape64 = (64,64,3) img_shape64 = (64,64,3)
img_shape256 = (256,256,3) img_shape256 = (256,256,3)
def upscale (dim):
def func(x):
return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def Encoder(_input): def Encoder(_input):
x = _input x = _input
x = self.keras.layers.convolutional.Conv2D(90, kernel_size=5, strides=1, padding='same')(x) x = Conv2D(90, kernel_size=5, strides=1, padding='same')(x)
x = self.keras.layers.convolutional.Conv2D(90, kernel_size=5, strides=1, padding='same')(x) x = Conv2D(90, kernel_size=5, strides=1, padding='same')(x)
x = self.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x) x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
x = self.keras.layers.convolutional.Conv2D(180, kernel_size=3, strides=1, padding='same')(x) x = Conv2D(180, kernel_size=3, strides=1, padding='same')(x)
x = self.keras.layers.convolutional.Conv2D(180, kernel_size=3, strides=1, padding='same')(x) x = Conv2D(180, kernel_size=3, strides=1, padding='same')(x)
x = self.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x) x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
x = self.keras.layers.convolutional.Conv2D(360, kernel_size=3, strides=1, padding='same')(x) x = Conv2D(360, kernel_size=3, strides=1, padding='same')(x)
x = self.keras.layers.convolutional.Conv2D(360, kernel_size=3, strides=1, padding='same')(x) x = Conv2D(360, kernel_size=3, strides=1, padding='same')(x)
x = self.keras.layers.MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x) x = MaxPooling2D(pool_size=(3, 3), strides=2, padding='same')(x)
x = self.keras.layers.Dense (1024)(x) x = Dense (1024)(x)
x = self.keras.layers.advanced_activations.LeakyReLU(0.1)(x) x = LeakyReLU(0.1)(x)
x = self.keras.layers.Dropout(0.5)(x) x = Dropout(0.5)(x)
x = self.keras.layers.Dense (1024)(x) x = Dense (1024)(x)
x = self.keras.layers.advanced_activations.LeakyReLU(0.1)(x) x = LeakyReLU(0.1)(x)
x = self.keras.layers.Dropout(0.5)(x) x = Dropout(0.5)(x)
x = self.keras.layers.Flatten()(x) x = Flatten()(x)
x = self.keras.layers.Dense (64)(x) x = Dense (64)(x)
return keras.models.Model (_input, x) return keras.models.Model (_input, x)
encoder256 = Encoder( keras.layers.Input (img_shape64) ) encoder256 = Encoder( Input (img_shape64) )
encoder64 = Encoder( keras.layers.Input (img_shape64) ) encoder64 = Encoder( Input (img_shape64) )
def decoder256(encoder): def decoder256(encoder):
decoder_input = keras.layers.Input ( K.int_shape(encoder.outputs[0])[1:] ) decoder_input = Input ( K.int_shape(encoder.outputs[0])[1:] )
x = decoder_input x = decoder_input
x = self.keras.layers.Dense(16 * 16 * 720)(x) x = Dense(16 * 16 * 720)(x)
x = keras.layers.Reshape ( (16, 16, 720) )(x) x = Reshape ( (16, 16, 720) )(x)
x = upscale(keras, x, 720) x = upscale(720)(x)
x = upscale(keras, x, 360) x = upscale(360)(x)
x = upscale(keras, x, 180) x = upscale(180)(x)
x = upscale(keras, x, 90) x = upscale(90)(x)
x = keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return keras.models.Model(decoder_input, x) return keras.models.Model(decoder_input, x)
def decoder64(encoder): def decoder64(encoder):
decoder_input = keras.layers.Input ( K.int_shape(encoder.outputs[0])[1:] ) decoder_input = Input ( K.int_shape(encoder.outputs[0])[1:] )
x = decoder_input x = decoder_input
x = self.keras.layers.Dense(8 * 8 * 720)(x) x = Dense(8 * 8 * 720)(x)
x = keras.layers.Reshape ( (8, 8, 720) )(x) x = Reshape ( (8, 8, 720) )(x)
x = upscale(keras, x, 360) x = upscale(360)(x)
x = upscale(keras, x, 180) x = upscale(180)(x)
x = upscale(keras, x, 90) x = upscale(90)(x)
x = keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return keras.models.Model(decoder_input, x) return Model(decoder_input, x)
return img_shape64, img_shape256, encoder64, decoder64(encoder64), decoder64(encoder64), encoder256, decoder256(encoder256) return img_shape64, img_shape256, encoder64, decoder64(encoder64), decoder64(encoder64), encoder256, decoder256(encoder256)
@ -230,7 +230,7 @@ class ConverterAvatar(ConverterBase):
#override #override
def get_mode(self): def get_mode(self):
return ConverterBase.MODE_IMAGE return ConverterBase.MODE_IMAGE_WITH_LANDMARKS
#override #override
def dummy_predict(self): def dummy_predict(self):

View file

@ -1,12 +1,10 @@
from models import ModelBase
import numpy as np import numpy as np
import cv2
from nnlib import DSSIMMaskLossClass from nnlib import nnlib
from nnlib import conv from models import ModelBase
from nnlib import upscale
from facelib import FaceType from facelib import FaceType
from samples import * from samples import *
class Model(ModelBase): class Model(ModelBase):
encoderH5 = 'encoder.h5' encoderH5 = 'encoder.h5'
@ -15,30 +13,27 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {4.5:16,5:16,6:16,7:16,8:24,9:24,10:32,11:32,12:32,13:48} ) self.set_vram_batch_requirements( {4.5:16,5:16,6:16,7:16,8:24,9:24,10:32,11:32,12:32,13:48} )
ae_input_layer = self.keras.layers.Input(shape=(128, 128, 3)) ae_input_layer = Input(shape=(128, 128, 3))
mask_layer = self.keras.layers.Input(shape=(128, 128, 1)) #same as output mask_layer = Input(shape=(128, 128, 1)) #same as output
self.encoder = self.Encoder(ae_input_layer)
self.decoder_src = self.Decoder()
self.decoder_dst = self.Decoder()
self.encoder, self.decoder_src, self.decoder_dst = self.Build(ae_input_layer)
if not self.is_first_run(): if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
self.autoencoder_src = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer))) self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder_src(self.encoder(ae_input_layer)))
self.autoencoder_dst = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer))) self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder_dst(self.encoder(ae_input_layer)))
if self.is_training_mode: if self.is_training_mode:
self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] ) self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
optimizer = self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
dssimloss = DSSIMMaskLossClass(self.tf)([mask_layer]) self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
self.autoencoder_src.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
self.autoencoder_dst.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags
@ -123,33 +118,48 @@ class Model(ModelBase):
return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options) return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options)
def Encoder(self, input_layer): def Build(self, input_layer):
x = input_layer exec(nnlib.code_import_all, locals(), globals())
x = conv(self.keras, x, 128)
x = conv(self.keras, x, 256) def downscale (dim):
x = conv(self.keras, x, 512) def func(x):
x = conv(self.keras, x, 1024) return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
return func
x = self.keras.layers.Dense(512)(self.keras.layers.Flatten()(x))
x = self.keras.layers.Dense(8 * 8 * 512)(x)
x = self.keras.layers.Reshape((8, 8, 512))(x)
x = upscale(self.keras, x, 512)
return self.keras.models.Model(input_layer, x) def upscale (dim):
def func(x):
def Decoder(self): return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
input_ = self.keras.layers.Input(shape=(16, 16, 512)) return func
x = input_
x = upscale(self.keras, x, 512)
x = upscale(self.keras, x, 256)
x = upscale(self.keras, x, 128)
y = input_ #mask decoder
y = upscale(self.keras, y, 512)
y = upscale(self.keras, y, 256)
y = upscale(self.keras, y, 128)
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) def Encoder(input_layer):
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y) x = input_layer
x = downscale(128)(x)
x = downscale(256)(x)
x = downscale(512)(x)
x = downscale(1024)(x)
x = Dense(512)(Flatten()(x))
x = Dense(8 * 8 * 512)(x)
x = Reshape((8, 8, 512))(x)
x = upscale(512)(x)
return Model(input_layer, x)
def Decoder():
input_ = Input(shape=(16, 16, 512))
x = input_
x = upscale(512)(x)
x = upscale(256)(x)
x = upscale(128)(x)
y = input_ #mask decoder
y = upscale(512)(y)
y = upscale(256)(y)
y = upscale(128)(y)
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
return Model(input_, [x,y])
return self.keras.models.Model(input_, [x,y]) return Encoder(input_layer), Decoder(), Decoder()

View file

@ -1,9 +1,6 @@
import numpy as np import numpy as np
import cv2
from nnlib import DSSIMMaskLossClass from nnlib import nnlib
from nnlib import conv
from nnlib import upscale
from models import ModelBase from models import ModelBase
from facelib import FaceType from facelib import FaceType
from samples import * from samples import *
@ -16,10 +13,8 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
tf = self.tf exec(nnlib.import_all(), locals(), globals())
keras = self.keras self.set_vram_batch_requirements( {2.5:2,3:2,4:2,4:4,5:8,6:12,7:16,8:16,9:24,10:24,11:32,12:32,13:48} )
K = keras.backend
self.set_vram_batch_requirements( {2.5:2,3:2,4:2,4:4,5:8,6:8,7:16,8:16,9:24,10:24,11:32,12:32,13:48} )
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.created_vram_gb) bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.created_vram_gb)
if not self.is_first_run(): if not self.is_first_run():
@ -27,21 +22,21 @@ class Model(ModelBase):
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
input_src_bgr = self.keras.layers.Input(bgr_shape) input_src_bgr = Input(bgr_shape)
input_src_mask = self.keras.layers.Input(mask_shape) input_src_mask = Input(mask_shape)
input_dst_bgr = self.keras.layers.Input(bgr_shape) input_dst_bgr = Input(bgr_shape)
input_dst_mask = self.keras.layers.Input(mask_shape) input_dst_mask = Input(mask_shape)
rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) ) rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) )
rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) ) rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) )
self.ae = self.keras.models.Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] ) self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
if self.is_training_mode: if self.is_training_mode:
self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] ) self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] )
self.ae.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[ DSSIMMaskLossClass(self.tf)([input_src_mask]), 'mae', DSSIMMaskLossClass(self.tf)([input_dst_mask]), 'mae' ] ) loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask]) self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask])
self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask]) self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask])
@ -129,61 +124,73 @@ class Model(ModelBase):
return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.HALF, **in_options) return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.HALF, **in_options)
def Build(self, created_vram_gb): def Build(self, created_vram_gb):
exec(nnlib.code_import_all, locals(), globals())
bgr_shape = (128, 128, 3) bgr_shape = (128, 128, 3)
mask_shape = (128, 128, 1) mask_shape = (128, 128, 1)
def downscale (dim):
def func(x):
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
return func
def upscale (dim):
def func(x):
return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def Encoder(input_shape): def Encoder(input_shape):
input_layer = self.keras.layers.Input(input_shape) input_layer = Input(input_shape)
x = input_layer x = input_layer
if created_vram_gb >= 5: if created_vram_gb >= 5:
x = conv(self.keras, x, 128) x = downscale(128)(x)
x = conv(self.keras, x, 256) x = downscale(256)(x)
x = conv(self.keras, x, 512) x = downscale(512)(x)
x = conv(self.keras, x, 1024) x = downscale(1024)(x)
x = self.keras.layers.Dense(512)(self.keras.layers.Flatten()(x)) x = Dense(512)(Flatten()(x))
x = self.keras.layers.Dense(8 * 8 * 512)(x) x = Dense(8 * 8 * 512)(x)
x = self.keras.layers.Reshape((8, 8, 512))(x) x = Reshape((8, 8, 512))(x)
x = upscale(self.keras, x, 512) x = upscale(512)(x)
else: else:
x = conv(self.keras, x, 128) x = downscale(128)(x)
x = conv(self.keras, x, 256) x = downscale(256)(x)
x = conv(self.keras, x, 512) x = downscale(512)(x)
x = conv(self.keras, x, 1024) x = downscale(1024)(x)
x = self.keras.layers.Dense(256)(self.keras.layers.Flatten()(x)) x = Dense(256)(Flatten()(x))
x = self.keras.layers.Dense(8 * 8 * 256)(x) x = Dense(8 * 8 * 256)(x)
x = self.keras.layers.Reshape((8, 8, 256))(x) x = Reshape((8, 8, 256))(x)
x = upscale(self.keras, x, 256) x = upscale(256)(x)
return self.keras.models.Model(input_layer, x) return Model(input_layer, x)
def Decoder(): def Decoder():
if created_vram_gb >= 5: if created_vram_gb >= 5:
input_ = self.keras.layers.Input(shape=(16, 16, 512)) input_ = Input(shape=(16, 16, 512))
x = input_ x = input_
x = upscale(self.keras, x, 512) x = upscale(512)(x)
x = upscale(self.keras, x, 256) x = upscale(256)(x)
x = upscale(self.keras, x, 128) x = upscale(128)(x)
y = input_ #mask decoder y = input_ #mask decoder
y = upscale(self.keras, y, 512) y = upscale(512)(y)
y = upscale(self.keras, y, 256) y = upscale(256)(y)
y = upscale(self.keras, y, 128) y = upscale(128)(y)
else: else:
input_ = self.keras.layers.Input(shape=(16, 16, 256)) input_ = self.keras.layers.Input(shape=(16, 16, 256))
x = input_ x = input_
x = upscale(self.keras, x, 256) x = upscale(256)(x)
x = upscale(self.keras, x, 128) x = upscale(128)(x)
x = upscale(self.keras, x, 64) x = upscale(64)(x)
y = input_ #mask decoder y = input_ #mask decoder
y = upscale(self.keras, y, 256) y = upscale(256)(y)
y = upscale(self.keras, y, 128) y = upscale(128)(y)
y = upscale(self.keras, y, 64) y = upscale(64)(y)
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y) y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
return self.keras.models.Model(input_, [x,y]) return Model(input_, [x,y])
return bgr_shape, mask_shape, Encoder(bgr_shape), Decoder(), Decoder() return bgr_shape, mask_shape, Encoder(bgr_shape), Decoder(), Decoder()

View file

@ -1,11 +1,9 @@
from models import ModelBase
import numpy as np import numpy as np
from samples import *
from nnlib import DSSIMMaskLossClass from nnlib import nnlib
from nnlib import conv from models import ModelBase
from nnlib import upscale
from facelib import FaceType from facelib import FaceType
from samples import *
class Model(ModelBase): class Model(ModelBase):
@ -15,9 +13,7 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
tf = self.tf exec(nnlib.import_all(), locals(), globals())
keras = self.keras
K = keras.backend
self.set_vram_batch_requirements( {1.5:2,2:2,3:8,4:16,5:24,6:32,7:40,8:48} ) self.set_vram_batch_requirements( {1.5:2,2:2,3:8,4:16,5:24,6:32,7:40,8:48} )
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.created_vram_gb) bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.created_vram_gb)
@ -26,21 +22,21 @@ class Model(ModelBase):
self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5)) self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5)) self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
input_src_bgr = self.keras.layers.Input(bgr_shape) input_src_bgr = Input(bgr_shape)
input_src_mask = self.keras.layers.Input(mask_shape) input_src_mask = Input(mask_shape)
input_dst_bgr = self.keras.layers.Input(bgr_shape) input_dst_bgr = Input(bgr_shape)
input_dst_mask = self.keras.layers.Input(mask_shape) input_dst_mask = Input(mask_shape)
rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) ) rec_src_bgr, rec_src_mask = self.decoder_src( self.encoder(input_src_bgr) )
rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) ) rec_dst_bgr, rec_dst_mask = self.decoder_dst( self.encoder(input_dst_bgr) )
self.ae = self.keras.models.Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] ) self.ae = Model([input_src_bgr,input_src_mask,input_dst_bgr,input_dst_mask], [rec_src_bgr, rec_src_mask, rec_dst_bgr, rec_dst_mask] )
if self.is_training_mode: if self.is_training_mode:
self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] ) self.ae, = self.to_multi_gpu_model_if_possible ( [self.ae,] )
self.ae.compile(optimizer=self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), self.ae.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999),
loss=[ DSSIMMaskLossClass(self.tf)([input_src_mask]), 'mae', DSSIMMaskLossClass(self.tf)([input_dst_mask]), 'mae' ] ) loss=[ DSSIMMaskLoss([input_src_mask]), 'mae', DSSIMMaskLoss([input_dst_mask]), 'mae' ] )
self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask]) self.src_view = K.function([input_src_bgr],[rec_src_bgr, rec_src_mask])
self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask]) self.dst_view = K.function([input_dst_bgr],[rec_dst_bgr, rec_dst_mask])
@ -130,58 +126,69 @@ class Model(ModelBase):
return ConverterMasked(self.predictor_func, predictor_input_size=64, output_size=64, face_type=FaceType.HALF, **in_options) return ConverterMasked(self.predictor_func, predictor_input_size=64, output_size=64, face_type=FaceType.HALF, **in_options)
def Build(self, created_vram_gb): def Build(self, created_vram_gb):
exec(nnlib.code_import_all, locals(), globals())
bgr_shape = (64, 64, 3) bgr_shape = (64, 64, 3)
mask_shape = (64, 64, 1) mask_shape = (64, 64, 1)
def downscale (dim):
def func(x):
return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
return func
def upscale (dim):
def func(x):
return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def Encoder(input_shape): def Encoder(input_shape):
input_layer = self.keras.layers.Input(input_shape) input_layer = Input(input_shape)
x = input_layer x = input_layer
if created_vram_gb >= 4: if created_vram_gb >= 4:
x = conv(self.keras, x, 128) x = downscale(128)(x)
x = conv(self.keras, x, 256) x = downscale(256)(x)
x = conv(self.keras, x, 512) x = downscale(512)(x)
x = conv(self.keras, x, 1024) x = downscale(1024)(x)
x = self.keras.layers.Dense(1024)(self.keras.layers.Flatten()(x)) x = Dense(1024)(Flatten()(x))
x = self.keras.layers.Dense(4 * 4 * 1024)(x) x = Dense(4 * 4 * 1024)(x)
x = self.keras.layers.Reshape((4, 4, 1024))(x) x = Reshape((4, 4, 1024))(x)
x = upscale(self.keras, x, 512) x = upscale(512)(x)
else: else:
x = conv(self.keras, x, 128 ) x = downscale(128)(x)
x = conv(self.keras, x, 256 ) x = downscale(256)(x)
x = conv(self.keras, x, 512 ) x = downscale(512)(x)
x = conv(self.keras, x, 768 ) x = downscale(768)(x)
x = self.keras.layers.Dense(512)(self.keras.layers.Flatten()(x)) x = Dense(512)(Flatten()(x))
x = self.keras.layers.Dense(4 * 4 * 512)(x) x = Dense(4 * 4 * 512)(x)
x = self.keras.layers.Reshape((4, 4, 512))(x) x = Reshape((4, 4, 512))(x)
x = upscale(self.keras, x, 256) x = upscale(256)(x)
return Model(input_layer, x)
return self.keras.models.Model(input_layer, x)
def Decoder(): def Decoder():
if created_vram_gb >= 4: if created_vram_gb >= 4:
input_ = self.keras.layers.Input(shape=(8, 8, 512)) input_ = Input(shape=(8, 8, 512))
x = input_ x = input_
x = upscale(self.keras, x, 512)
x = upscale(self.keras, x, 256)
x = upscale(self.keras, x, 128)
else: x = upscale(512)(x)
input_ = self.keras.layers.Input(shape=(8, 8, 256)) x = upscale(256)(x)
x = upscale(128)(x)
x = input_ else:
x = upscale(self.keras, x, 256) input_ = Input(shape=(8, 8, 256))
x = upscale(self.keras, x, 128)
x = upscale(self.keras, x, 64) x = input_
x = upscale(256)(x)
x = upscale(128)(x)
x = upscale(64)(x)
y = input_ #mask decoder y = input_ #mask decoder
y = upscale(self.keras, y, 256) y = upscale(256)(y)
y = upscale(self.keras, y, 128) y = upscale(128)(y)
y = upscale(self.keras, y, 64) y = upscale(64)(y)
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y) y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
return Model(input_, [x,y])
return self.keras.models.Model(input_, [x,y])
return bgr_shape, mask_shape, Encoder(bgr_shape), Decoder(), Decoder() return bgr_shape, mask_shape, Encoder(bgr_shape), Decoder(), Decoder()

View file

@ -1,10 +1,7 @@
from models import ModelBase
import numpy as np import numpy as np
import cv2
from nnlib import DSSIMMaskLossClass from nnlib import nnlib
from nnlib import conv from models import ModelBase
from nnlib import upscale
from facelib import FaceType from facelib import FaceType
from samples import * from samples import *
@ -17,16 +14,14 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {4.5:4,5:4,6:8,7:12,8:16,9:20,10:24,11:24,12:32,13:48} ) self.set_vram_batch_requirements( {4.5:4,5:4,6:8,7:12,8:16,9:20,10:24,11:24,12:32,13:48} )
ae_input_layer = self.keras.layers.Input(shape=(128, 128, 3)) ae_input_layer = Input(shape=(128, 128, 3))
mask_layer = self.keras.layers.Input(shape=(128, 128, 1)) #same as output mask_layer = Input(shape=(128, 128, 1)) #same as output
self.encoder = self.Encoder(ae_input_layer) self.encoder, self.decoder, self.inter_B, self.inter_AB = self.Build(ae_input_layer)
self.decoder = self.Decoder()
self.inter_B = self.Intermediate ()
self.inter_AB = self.Intermediate ()
if not self.is_first_run(): if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5)) self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
@ -36,16 +31,14 @@ class Model(ModelBase):
code = self.encoder(ae_input_layer) code = self.encoder(ae_input_layer)
AB = self.inter_AB(code) AB = self.inter_AB(code)
B = self.inter_B(code) B = self.inter_B(code)
self.autoencoder_src = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder(self.keras.layers.Concatenate()([AB, AB])) ) self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([AB, AB])) )
self.autoencoder_dst = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder(self.keras.layers.Concatenate()([B, AB])) ) self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([B, AB])) )
if self.is_training_mode: if self.is_training_mode:
self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] ) self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
optimizer = self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
dssimloss = DSSIMMaskLossClass(self.tf)([mask_layer]) self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
self.autoencoder_src.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
self.autoencoder_dst.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags
@ -131,37 +124,52 @@ class Model(ModelBase):
in_options['blur_mask_modifier'] = 0 in_options['blur_mask_modifier'] = 0
return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options) return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options)
def Build(self, input_layer):
exec(nnlib.code_import_all, locals(), globals())
def Encoder(self, input_layer,): def downscale (dim):
x = input_layer def func(x):
x = conv(self.keras, x, 128) return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
x = conv(self.keras, x, 256) return func
x = conv(self.keras, x, 512)
x = conv(self.keras, x, 1024) def upscale (dim):
x = self.keras.layers.Flatten()(x) def func(x):
return self.keras.models.Model(input_layer, x) return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def Encoder():
x = input_layer
x = downscale(128)(x)
x = downscale(256)(x)
x = downscale(512)(x)
x = downscale(1024)(x)
x = Flatten()(x)
return Model(input_layer, x)
def Intermediate(self): def Intermediate():
input_layer = self.keras.layers.Input(shape=(None, 8 * 8 * 1024)) input_layer = Input(shape=(None, 8 * 8 * 1024))
x = input_layer x = input_layer
x = self.keras.layers.Dense(256)(x) x = Dense(256)(x)
x = self.keras.layers.Dense(8 * 8 * 512)(x) x = Dense(8 * 8 * 512)(x)
x = self.keras.layers.Reshape((8, 8, 512))(x) x = Reshape((8, 8, 512))(x)
x = upscale(self.keras, x, 512) x = upscale(512)(x)
return self.keras.models.Model(input_layer, x) return Model(input_layer, x)
def Decoder(self): def Decoder():
input_ = self.keras.layers.Input(shape=(16, 16, 1024)) input_ = Input(shape=(16, 16, 1024))
x = input_ x = input_
x = upscale(self.keras, x, 512) x = upscale(512)(x)
x = upscale(self.keras, x, 256) x = upscale(256)(x)
x = upscale(self.keras, x, 128) x = upscale(128)(x)
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
y = input_ #mask decoder y = input_ #mask decoder
y = upscale(self.keras, y, 512) y = upscale(512)(y)
y = upscale(self.keras, y, 256) y = upscale(256)(y)
y = upscale(self.keras, y, 128) y = upscale(128)(y)
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid' )(y) y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid' )(y)
return self.keras.models.Model(input_, [x,y]) return Model(input_, [x,y])
return Encoder(), Decoder(), Intermediate(), Intermediate()

View file

@ -1,13 +1,10 @@
from models import ModelBase
import numpy as np import numpy as np
import cv2
from nnlib import DSSIMMaskLossClass from nnlib import nnlib
from nnlib import conv from models import ModelBase
from nnlib import upscale
from facelib import FaceType from facelib import FaceType
from samples import * from samples import *
class Model(ModelBase): class Model(ModelBase):
encoderH5 = 'encoder.h5' encoderH5 = 'encoder.h5'
@ -17,16 +14,14 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {4.5:4,5:4,6:8,7:12,8:16,9:20,10:24,11:24,12:32,13:48} ) self.set_vram_batch_requirements( {4.5:4,5:4,6:8,7:12,8:16,9:20,10:24,11:24,12:32,13:48} )
ae_input_layer = self.keras.layers.Input(shape=(128, 128, 3)) ae_input_layer = Input(shape=(128, 128, 3))
mask_layer = self.keras.layers.Input(shape=(128, 128, 1)) #same as output mask_layer = Input(shape=(128, 128, 1)) #same as output
self.encoder = self.Encoder(ae_input_layer) self.encoder, self.decoder, self.inter_B, self.inter_AB = self.Build(ae_input_layer)
self.decoder = self.Decoder()
self.inter_B = self.Intermediate ()
self.inter_AB = self.Intermediate ()
if not self.is_first_run(): if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5)) self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
@ -36,16 +31,14 @@ class Model(ModelBase):
code = self.encoder(ae_input_layer) code = self.encoder(ae_input_layer)
AB = self.inter_AB(code) AB = self.inter_AB(code)
B = self.inter_B(code) B = self.inter_B(code)
self.autoencoder_src = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder(self.keras.layers.Concatenate()([AB, AB])) ) self.autoencoder_src = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([AB, AB])) )
self.autoencoder_dst = self.keras.models.Model([ae_input_layer,mask_layer], self.decoder(self.keras.layers.Concatenate()([B, AB])) ) self.autoencoder_dst = Model([ae_input_layer,mask_layer], self.decoder(Concatenate()([B, AB])) )
if self.is_training_mode: if self.is_training_mode:
self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] ) self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
optimizer = self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
dssimloss = DSSIMMaskLossClass(self.tf)([mask_layer]) self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
self.autoencoder_src.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
self.autoencoder_dst.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags
@ -131,37 +124,52 @@ class Model(ModelBase):
in_options['blur_mask_modifier'] = 0 in_options['blur_mask_modifier'] = 0
return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options) return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options)
def Build(self, input_layer):
exec(nnlib.code_import_all, locals(), globals())
def Encoder(self, input_layer,): def downscale (dim):
x = input_layer def func(x):
x = conv(self.keras, x, 128) return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
x = conv(self.keras, x, 256) return func
x = conv(self.keras, x, 512)
x = conv(self.keras, x, 1024) def upscale (dim):
x = self.keras.layers.Flatten()(x) def func(x):
return self.keras.models.Model(input_layer, x) return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
return func
def Encoder():
x = input_layer
x = downscale(128)(x)
x = downscale(256)(x)
x = downscale(512)(x)
x = downscale(1024)(x)
x = Flatten()(x)
return Model(input_layer, x)
def Intermediate(self): def Intermediate():
input_layer = self.keras.layers.Input(shape=(None, 8 * 8 * 1024)) input_layer = Input(shape=(None, 8 * 8 * 1024))
x = input_layer x = input_layer
x = self.keras.layers.Dense(256)(x) x = Dense(256)(x)
x = self.keras.layers.Dense(8 * 8 * 512)(x) x = Dense(8 * 8 * 512)(x)
x = self.keras.layers.Reshape((8, 8, 512))(x) x = Reshape((8, 8, 512))(x)
x = upscale(self.keras, x, 512) x = upscale(512)(x)
return self.keras.models.Model(input_layer, x) return Model(input_layer, x)
def Decoder(self): def Decoder():
input_ = self.keras.layers.Input(shape=(16, 16, 1024)) input_ = Input(shape=(16, 16, 1024))
x = input_ x = input_
x = upscale(self.keras, x, 512) x = upscale(512)(x)
x = upscale(self.keras, x, 256) x = upscale(256)(x)
x = upscale(self.keras, x, 128) x = upscale(128)(x)
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
y = input_ #mask decoder y = input_ #mask decoder
y = upscale(self.keras, y, 512) y = upscale(512)(y)
y = upscale(self.keras, y, 256) y = upscale(256)(y)
y = upscale(self.keras, y, 128) y = upscale(128)(y)
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid' )(y) y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid' )(y)
return self.keras.models.Model(input_, [x,y]) return Model(input_, [x,y])
return Encoder(), Decoder(), Intermediate(), Intermediate()

View file

@ -1,10 +1,7 @@
import numpy as np import numpy as np
import cv2
from nnlib import nnlib
from models import ModelBase from models import ModelBase
from nnlib import DSSIMMaskLossClass
from nnlib import conv
from nnlib import upscale
from facelib import FaceType from facelib import FaceType
from samples import * from samples import *
@ -21,20 +18,15 @@ class Model(ModelBase):
#override #override
def onInitialize(self, **in_options): def onInitialize(self, **in_options):
exec(nnlib.import_all(), locals(), globals())
self.set_vram_batch_requirements( {4.5:4,5:4,6:8,7:12,8:16,9:20,10:24,11:24,12:32,13:48} ) self.set_vram_batch_requirements( {4.5:4,5:4,6:8,7:12,8:16,9:20,10:24,11:24,12:32,13:48} )
ae_input_layer = self.keras.layers.Input(shape=(128, 128, 3)) ae_input_layer = Input(shape=(128, 128, 3))
mask_layer = self.keras.layers.Input(shape=(128, 128, 1)) #same as output mask_layer = Input(shape=(128, 128, 1)) #same as output
self.encoder = self.Encoder(ae_input_layer) self.encoder, self.decoderMask, self.decoderCommonA, self.decoderCommonB, self.decoderRGB, \
self.decoderMask = self.DecoderMask() self.decoderBW, self.inter_A, self.inter_B = self.Build(ae_input_layer)
self.decoderCommonA = self.DecoderCommon()
self.decoderCommonB = self.DecoderCommon()
self.decoderRGB = self.DecoderRGB()
self.decoderBW = self.DecoderBW()
self.inter_A = self.Intermediate ()
self.inter_B = self.Intermediate ()
if not self.is_first_run(): if not self.is_first_run():
self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5)) self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
self.decoderMask.load_weights (self.get_strpath_storage_for_file(self.decoderMaskH5)) self.decoderMask.load_weights (self.get_strpath_storage_for_file(self.decoderMaskH5))
@ -49,37 +41,35 @@ class Model(ModelBase):
A = self.inter_A(code) A = self.inter_A(code)
B = self.inter_B(code) B = self.inter_B(code)
inter_A_A = self.keras.layers.Concatenate()([A, A]) inter_A_A = Concatenate()([A, A])
inter_B_A = self.keras.layers.Concatenate()([B, A]) inter_B_A = Concatenate()([B, A])
x1,m1 = self.decoderCommonA (inter_A_A) x1,m1 = self.decoderCommonA (inter_A_A)
x2,m2 = self.decoderCommonA (inter_A_A) x2,m2 = self.decoderCommonA (inter_A_A)
self.autoencoder_src = self.keras.models.Model([ae_input_layer,mask_layer], self.autoencoder_src = Model([ae_input_layer,mask_layer],
[ self.decoderBW (self.keras.layers.Concatenate()([x1,x2]) ), [ self.decoderBW (Concatenate()([x1,x2]) ),
self.decoderMask(self.keras.layers.Concatenate()([m1,m2]) ) self.decoderMask(Concatenate()([m1,m2]) )
]) ])
x1,m1 = self.decoderCommonA (inter_A_A) x1,m1 = self.decoderCommonA (inter_A_A)
x2,m2 = self.decoderCommonB (inter_A_A) x2,m2 = self.decoderCommonB (inter_A_A)
self.autoencoder_src_RGB = self.keras.models.Model([ae_input_layer,mask_layer], self.autoencoder_src_RGB = Model([ae_input_layer,mask_layer],
[ self.decoderRGB (self.keras.layers.Concatenate()([x1,x2]) ), [ self.decoderRGB (Concatenate()([x1,x2]) ),
self.decoderMask (self.keras.layers.Concatenate()([m1,m2]) ) self.decoderMask (Concatenate()([m1,m2]) )
]) ])
x1,m1 = self.decoderCommonA (inter_B_A) x1,m1 = self.decoderCommonA (inter_B_A)
x2,m2 = self.decoderCommonB (inter_B_A) x2,m2 = self.decoderCommonB (inter_B_A)
self.autoencoder_dst = self.keras.models.Model([ae_input_layer,mask_layer], self.autoencoder_dst = Model([ae_input_layer,mask_layer],
[ self.decoderRGB (self.keras.layers.Concatenate()([x1,x2]) ), [ self.decoderRGB (Concatenate()([x1,x2]) ),
self.decoderMask (self.keras.layers.Concatenate()([m1,m2]) ) self.decoderMask (Concatenate()([m1,m2]) )
]) ])
if self.is_training_mode: if self.is_training_mode:
self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] ) self.autoencoder_src, self.autoencoder_dst = self.to_multi_gpu_model_if_possible ( [self.autoencoder_src, self.autoencoder_dst] )
optimizer = self.keras.optimizers.Adam(lr=5e-5, beta_1=0.5, beta_2=0.999) self.autoencoder_src.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
dssimloss = DSSIMMaskLossClass(self.tf)([mask_layer]) self.autoencoder_dst.compile(optimizer=Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), loss=[DSSIMMaskLoss([mask_layer]), 'mse'] )
self.autoencoder_src.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
self.autoencoder_dst.compile(optimizer=optimizer, loss=[dssimloss, 'mse'] )
if self.is_training_mode: if self.is_training_mode:
f = SampleProcessor.TypeFlags f = SampleProcessor.TypeFlags
@ -169,53 +159,67 @@ class Model(ModelBase):
return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options) return ConverterMasked(self.predictor_func, predictor_input_size=128, output_size=128, face_type=FaceType.FULL, clip_border_mask_per=0.046875, **in_options)
def Build(self, input_layer):
def Encoder(self, input_layer,): exec(nnlib.code_import_all, locals(), globals())
x = input_layer
x = conv(self.keras, x, 128)
x = conv(self.keras, x, 256)
x = conv(self.keras, x, 512)
x = conv(self.keras, x, 1024)
x = self.keras.layers.Flatten()(x)
return self.keras.models.Model(input_layer, x)
def Intermediate(self):
input_layer = self.keras.layers.Input(shape=(None, 8 * 8 * 1024))
x = input_layer
x = self.keras.layers.Dense(256)(x)
x = self.keras.layers.Dense(8 * 8 * 512)(x)
x = self.keras.layers.Reshape((8, 8, 512))(x)
x = upscale(self.keras, x, 512)
return self.keras.models.Model(input_layer, x)
def DecoderCommon(self):
input_ = self.keras.layers.Input(shape=(16, 16, 1024))
x = input_
x = upscale(self.keras, x, 512)
x = upscale(self.keras, x, 256)
x = upscale(self.keras, x, 128)
y = input_ def downscale (dim):
y = upscale(self.keras, y, 256) def func(x):
y = upscale(self.keras, y, 128) return LeakyReLU(0.1)(Conv2D(dim, 5, strides=2, padding='same')(x))
y = upscale(self.keras, y, 64) return func
return self.keras.models.Model(input_, [x,y]) def upscale (dim):
def func(x):
def DecoderRGB(self): return PixelShuffler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
input_ = self.keras.layers.Input(shape=(128, 128, 256)) return func
x = input_
x = self.keras.layers.convolutional.Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x) def Encoder():
return self.keras.models.Model(input_, [x]) x = input_layer
x = downscale(128)(x)
x = downscale(256)(x)
x = downscale(512)(x)
x = downscale(1024)(x)
x = Flatten()(x)
return Model(input_layer, x)
def DecoderBW(self): def Intermediate():
input_ = self.keras.layers.Input(shape=(128, 128, 256)) input_layer = Input(shape=(None, 8 * 8 * 1024))
x = input_ x = input_layer
x = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(x) x = Dense(256)(x)
return self.keras.models.Model(input_, [x]) x = Dense(8 * 8 * 512)(x)
x = Reshape((8, 8, 512))(x)
def DecoderMask(self): x = upscale(512)(x)
input_ = self.keras.layers.Input(shape=(128, 128, 128)) return Model(input_layer, x)
y = input_
y = self.keras.layers.convolutional.Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y) def DecoderCommon():
return self.keras.models.Model(input_, [y]) input_ = Input(shape=(16, 16, 1024))
x = input_
x = upscale(512)(x)
x = upscale(256)(x)
x = upscale(128)(x)
y = input_
y = upscale(256)(y)
y = upscale(128)(y)
y = upscale(64)(y)
return Model(input_, [x,y])
def DecoderRGB():
input_ = Input(shape=(128, 128, 256))
x = input_
x = Conv2D(3, kernel_size=5, padding='same', activation='sigmoid')(x)
return Model(input_, [x])
def DecoderBW():
input_ = Input(shape=(128, 128, 256))
x = input_
x = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(x)
return Model(input_, [x])
def DecoderMask():
input_ = Input(shape=(128, 128, 128))
y = input_
y = Conv2D(1, kernel_size=5, padding='same', activation='sigmoid')(y)
return Model(input_, [y])
return Encoder(), DecoderMask(), DecoderCommon(), DecoderCommon(), DecoderRGB(), DecoderBW(), Intermediate(), Intermediate()

View file

@ -0,0 +1,243 @@
from models import ModelBase
import numpy as np
import cv2
from mathlib import get_power_of_two
from nnlib import nnlib
from facelib import FaceType
from samples import *
class Model(ModelBase):
    """RecycleGAN model made of two generators, two predictors and two discriminators.

    GA maps frames of domain A to domain B and GB maps B to A.  PA / PB
    predict a third frame from two earlier frames of their own domain.
    DA / DB are the discriminators for domain A and B respectively.
    """

    # Weight file names, one per sub-network.
    GAH5 = 'GA.h5'
    PAH5 = 'PA.h5'
    DAH5 = 'DA.h5'
    GBH5 = 'GB.h5'
    DBH5 = 'DB.h5'
    PBH5 = 'PB.h5'

    #override
    def onInitialize(self, batch_size=-1, **in_options):
        """Build the networks, losses and training functions.

        On the first run (``self.epoch == 0``) the user is asked for a batch
        size; the chosen value is stored in ``self.options`` and reused on
        later runs.

        Raises:
            Exception: when training is continued but no stored
                ``created_batch_size`` option exists.
        """
        # NOTE: K, Input, Adam, modelify, ResNet, ... are not imported at file
        # level; they are expected to be provided by this exec.  The argument
        # order makes module globals() the namespace the executed code
        # assigns into.
        exec(nnlib.code_import_all, locals(), globals())

        self.set_vram_batch_requirements( {6:6} )

        created_batch_size = self.get_batch_size()
        if self.epoch == 0:
            #first run
            print ("\nModel first run. Enter options.")

            try:
                input_created_batch_size = int ( input ("Batch_size (default - based on VRAM) : ") )
            except (ValueError, EOFError):
                # Empty / non-numeric answer or closed stdin: keep the default.
                # Ctrl+C is deliberately NOT swallowed here.
                input_created_batch_size = 0

            if input_created_batch_size != 0:
                created_batch_size = input_created_batch_size

            self.options['created_batch_size'] = created_batch_size
            self.created_vram_gb = self.device_config.gpu_total_vram_gb
        else:
            #not first run
            if 'created_batch_size' in self.options.keys():
                created_batch_size = self.options['created_batch_size']
            else:
                raise Exception("Continue traning, but created_batch_size not found.")

        resolution = 128
        bgr_shape = (resolution, resolution, 3)
        ngf = 64    # base filter count passed to the generators
        npf = 64    # base filter count passed to the temporal predictors
        ndf = 64    # base filter count passed to the discriminators
        lambda_A = 10
        lambda_B = 10

        self.set_batch_size(created_batch_size)

        # Batch normalisation is only enabled when the batch holds more than one sample.
        use_batch_norm = created_batch_size > 1

        self.GA = modelify(ResNet (bgr_shape[2], use_batch_norm, n_blocks=6, ngf=ngf, use_dropout=False))(Input(bgr_shape))
        self.GB = modelify(ResNet (bgr_shape[2], use_batch_norm, n_blocks=6, ngf=ngf, use_dropout=False))(Input(bgr_shape))
        #self.GA = modelify(UNet (bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=ngf, use_dropout=True))(Input(bgr_shape))
        #self.GB = modelify(UNet (bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=ngf, use_dropout=True))(Input(bgr_shape))

        self.PA = modelify(UNetTemporalPredictor(bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=npf, use_dropout=True))([Input(bgr_shape), Input(bgr_shape)])
        self.PB = modelify(UNetTemporalPredictor(bgr_shape[2], use_batch_norm, num_downs=get_power_of_two(resolution)-1, ngf=npf, use_dropout=True))([Input(bgr_shape), Input(bgr_shape)])

        self.DA = modelify(NLayerDiscriminator(use_batch_norm, ndf=ndf, n_layers=3) ) (Input(bgr_shape))
        self.DB = modelify(NLayerDiscriminator(use_batch_norm, ndf=ndf, n_layers=3) ) (Input(bgr_shape))

        if not self.is_first_run():
            self.GA.load_weights (self.get_strpath_storage_for_file(self.GAH5))
            self.DA.load_weights (self.get_strpath_storage_for_file(self.DAH5))
            self.PA.load_weights (self.get_strpath_storage_for_file(self.PAH5))
            self.GB.load_weights (self.get_strpath_storage_for_file(self.GBH5))
            self.DB.load_weights (self.get_strpath_storage_for_file(self.DBH5))
            self.PB.load_weights (self.get_strpath_storage_for_file(self.PBH5))

        # Three consecutive frames per domain.
        real_A0 = Input(bgr_shape, name="real_A0")
        real_A1 = Input(bgr_shape, name="real_A1")
        real_A2 = Input(bgr_shape, name="real_A2")

        real_B0 = Input(bgr_shape, name="real_B0")
        real_B1 = Input(bgr_shape, name="real_B1")
        real_B2 = Input(bgr_shape, name="real_B2")

        # Discriminator targets, shaped like one discriminator output (batch axis dropped).
        DA_ones = K.ones ( K.int_shape(self.DA.outputs[0])[1:] )
        DA_zeros = K.zeros ( K.int_shape(self.DA.outputs[0])[1:] )
        DB_ones = K.ones ( K.int_shape(self.DB.outputs[0])[1:] )
        DB_zeros = K.zeros ( K.int_shape(self.DB.outputs[0])[1:] )

        # All three losses are plain mean squared errors; CycleLoss is
        # currently referenced only from the disabled terms below.
        def CycleLoss (t1,t2):
            return K.mean(K.square(t1 - t2))

        def RecurrentLOSS(t1,t2):
            return K.mean(K.square(t1 - t2))

        def RecycleLOSS(t1,t2):
            return K.mean(K.square(t1 - t2))

        fake_B0 = self.GA(real_A0)
        fake_B1 = self.GA(real_A1)

        fake_A0 = self.GB(real_B0)
        fake_A1 = self.GB(real_B1)

        #rec_FB0 = self.GA(fake_A0)
        #rec_FB1 = self.GA(fake_A1)

        #rec_FA0 = self.GB(fake_B0)
        #rec_FA1 = self.GB(fake_B1)

        pred_A2 = self.PA ( [real_A0, real_A1])
        pred_B2 = self.PB ( [real_B0, real_B1])
        # Recycle path: translate two frames, predict the third in the other
        # domain, translate the prediction back.
        rec_A2 = self.GB ( self.PB ( [fake_B0, fake_B1]) )
        rec_B2 = self.GA ( self.PA ( [fake_A0, fake_A1]))

        loss_G = (
            K.mean(K.square(self.DB(fake_B0) - DB_ones))
            + K.mean(K.square(self.DB(fake_B1) - DB_ones))
            + K.mean(K.square(self.DA(fake_A0) - DA_ones))
            + K.mean(K.square(self.DA(fake_A1) - DA_ones))
            + lambda_A * (
                #CycleLoss(rec_FA0, real_A0) +
                #CycleLoss(rec_FA1, real_A1) +
                RecurrentLOSS(pred_A2, real_A2)
                + RecycleLOSS(rec_A2, real_A2)
            )
            + lambda_B * (
                #CycleLoss(rec_FB0, real_B0) +
                #CycleLoss(rec_FB1, real_B1) +
                RecurrentLOSS(pred_B2, real_B2)
                + RecycleLOSS(rec_B2, real_B2)
            )
        )

        weights_G = self.GA.trainable_weights + self.GB.trainable_weights + self.PA.trainable_weights + self.PB.trainable_weights

        self.G_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2],[loss_G],
                                    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_G, weights_G) )

        ###########

        loss_D_A0 = ( K.mean(K.square( self.DA(real_A0) - DA_ones)) + \
                      K.mean(K.square( self.DA(fake_A0) - DA_zeros)) ) * 0.5

        loss_D_A1 = ( K.mean(K.square( self.DA(real_A1) - DA_ones)) + \
                      K.mean(K.square( self.DA(fake_A1) - DA_zeros)) ) * 0.5

        loss_D_A = loss_D_A0 + loss_D_A1

        self.DA_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2],[loss_D_A],
                                    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_D_A, self.DA.trainable_weights) )

        ############

        loss_D_B0 = ( K.mean(K.square( self.DB(real_B0) - DB_ones)) + \
                      K.mean(K.square( self.DB(fake_B0) - DB_zeros)) ) * 0.5

        loss_D_B1 = ( K.mean(K.square( self.DB(real_B1) - DB_ones)) + \
                      K.mean(K.square( self.DB(fake_B1) - DB_zeros)) ) * 0.5

        loss_D_B = loss_D_B0 + loss_D_B1

        self.DB_train = K.function ([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2],[loss_D_B],
                                    Adam(lr=2e-4, beta_1=0.5, beta_2=0.999).get_updates(loss_D_B, self.DB.trainable_weights) )

        ############

        self.G_view = K.function([real_A0, real_A1, real_A2, real_B0, real_B1, real_B2],[fake_A0, fake_A1, pred_A2, rec_A2, fake_B0, fake_B1, pred_B2, rec_B2 ])
        self.G_convert = K.function([real_B0],[fake_A0])

        if self.is_training_mode:
            f = SampleProcessor.TypeFlags
            self.set_training_data_generators ([
                    SampleGeneratorImageTemporal(self.training_data_src_path, debug=self.is_debug(), batch_size=self.batch_size,
                        temporal_image_count=3,
                        sample_process_options=SampleProcessor.Options(random_flip = False, normalize_tanh = True),
                        output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),

                    SampleGeneratorImageTemporal(self.training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
                        temporal_image_count=3,
                        sample_process_options=SampleProcessor.Options(random_flip = False, normalize_tanh = True),
                        output_sample_types=[ [f.SOURCE | f.MODE_BGR, resolution] ] ),
                   ])

        #import code
        #code.interact(local=dict(globals(), **locals()))
        self.supress_std_once = False

    #override
    def onSave(self):
        """Persist the weights of all six sub-networks."""
        self.save_weights_safe( [[self.GA, self.get_strpath_storage_for_file(self.GAH5)],
                                 [self.GB, self.get_strpath_storage_for_file(self.GBH5)],
                                 [self.DA, self.get_strpath_storage_for_file(self.DAH5)],
                                 [self.DB, self.get_strpath_storage_for_file(self.DBH5)],
                                 [self.PA, self.get_strpath_storage_for_file(self.PAH5)],
                                 [self.PB, self.get_strpath_storage_for_file(self.PBH5)] ])

    #override
    def onTrainOneEpoch(self, sample):
        """Run one generator step and one step per discriminator.

        ``sample[0]`` / ``sample[1]`` each unpack into three frame batches
        (source and destination respectively).  Returns the three losses as
        ``(name, value)`` pairs.
        """
        source_src_0, source_src_1, source_src_2, = sample[0]
        source_dst_0, source_dst_1, source_dst_2, = sample[1]

        feed = [source_src_0, source_src_1, source_src_2, source_dst_0, source_dst_1, source_dst_2]

        loss_G, = self.G_train ( feed )
        loss_DA, = self.DA_train( feed )
        loss_DB, = self.DB_train( feed )

        return ( ('G', loss_G), ('DA', loss_DA), ('DB', loss_DB) )

    #override
    def onGetPreview(self, sample):
        """Return a two-row preview grid built from the first item of each batch."""
        test_A0 = sample[0][0]
        test_A1 = sample[0][1]
        test_A2 = sample[0][2]

        test_B0 = sample[1][0]
        test_B1 = sample[1][1]
        test_B2 = sample[1][2]

        G_view_result = self.G_view([test_A0, test_A1, test_A2, test_B0, test_B1, test_B2])

        # x / 2 + 0.5 undoes the tanh-style normalisation requested from the sample generators.
        fake_A0, fake_A1, pred_A2, rec_A2, fake_B0, fake_B1, pred_B2, rec_B2 = [ x[0] / 2 + 0.5 for x in G_view_result]
        test_A0, test_A1, test_A2, test_B0, test_B1, test_B2 = [ x[0] / 2 + 0.5 for x in [test_A0, test_A1, test_A2, test_B0, test_B1, test_B2] ]

        r = np.concatenate ((np.concatenate ( (test_A0, test_A1, test_A2, pred_A2, fake_B0, fake_B1, rec_A2), axis=1),
                             np.concatenate ( (test_B0, test_B1, test_B2, pred_B2, fake_A0, fake_A1, rec_B2), axis=1)
                             ), axis=0)

        return [ ('RecycleGAN, A0-A1-A2-PA2-FB0-FB1-RA2, B0-B1-B2-PB2-FA0-FA1-RB2, ', r ) ]

    def predictor_func (self, face):
        """Convert one face image with GB; input and output are rescaled via ``*2-1`` / ``/2+0.5``."""
        x = self.G_convert ( [ np.expand_dims(face *2 - 1,0)] )[0]
        return x[0] / 2 + 0.5

    #override
    def get_converter(self, **in_options):
        """Return a ``ConverterImage`` wired to :meth:`predictor_func` at 128x128."""
        from models import ConverterImage

        return ConverterImage(self.predictor_func, predictor_input_size=128, output_size=128, **in_options)

View file

@ -0,0 +1 @@
from .Model import Model

View file

@ -1,635 +1 @@
def tf_image_histogram (tf, input): from .nnlib import nnlib
x = input
x += 1 / 255.0
output = []
for i in range(256, 0, -1):
v = i / 255.0
y = (x - v) * 1000
y = tf.clip_by_value (y, -1.0, 0.0) + 1
output.append ( tf.reduce_sum (y) )
x -= y*v
return tf.stack ( output[::-1] )
def tf_dssim(tf, t1, t2):
    """Return ``(1 - SSIM(t1, t2)) / 2`` using ``tf.image.ssim`` with ``max_val=1.0``.

    *tf* is the TensorFlow module (passed in so this file does not import it).
    """
    return (1.0 - tf.image.ssim (t1, t2, 1.0)) / 2.0
def tf_ssim(tf, t1, t2):
    """Return ``tf.image.ssim(t1, t2, 1.0)``; *tf* is the TensorFlow module."""
    return tf.image.ssim (t1, t2, 1.0)
def DSSIMMaskLossClass(tf):
    """Build a masked DSSIM loss class bound to the given TensorFlow module.

    Instances of the returned class are callables ``loss(y_true, y_pred)``.
    For every mask in ``mask_list`` both tensors are multiplied by the mask
    and ``(1 - SSIM) / 2`` is computed; the per-mask losses are summed.
    With ``is_tanh=True`` targets, predictions and masks are first remapped
    with ``x / 2 + 0.5``.  An empty ``mask_list`` yields ``None``.
    """
    class DSSIMMaskLoss(object):
        def __init__(self, mask_list, is_tanh=False):
            self.mask_list = mask_list
            self.is_tanh = is_tanh

        def __call__(self,y_true, y_pred):
            total_loss = None
            for mask in self.mask_list:
                if not self.is_tanh:
                    loss = (1.0 - tf.image.ssim (y_true*mask, y_pred*mask, 1.0)) / 2.0
                else:
                    loss = (1.0 - tf.image.ssim ( (y_true/2+0.5)*(mask/2+0.5), (y_pred/2+0.5)*(mask/2+0.5), 1.0)) / 2.0

                total_loss = loss if total_loss is None else total_loss + loss

            return total_loss

    return DSSIMMaskLoss
def DSSIMPatchMaskLossClass(tf):
    """Build a patch-wise masked DSSIM loss class bound to the given TensorFlow module.

    Instances are callables ``loss(y_true, y_pred)``.  9x9 patches (rate 8,
    stride 1, 'VALID') are extracted from targets, predictions and each mask
    (tiled to 3 channels), then ``(1 - SSIM) / 2`` of the masked patches is
    summed over all masks.  An empty ``mask_list`` yields ``None``.
    """
    class DSSIMPatchMaskLoss(object):
        def __init__(self, mask_list, is_tanh=False):
            self.mask_list = mask_list
            self.is_tanh = is_tanh

        def __call__(self,y_true, y_pred):
            # Extract the image patches once.  Doing it inside the loop (and
            # overwriting y_true / y_pred) made every mask after the first
            # operate on patches-of-patches.
            true_patches = tf.extract_image_patches ( y_true, (1,9,9,1), (1,1,1,1), (1,8,8,1), 'VALID' )
            pred_patches = tf.extract_image_patches ( y_pred, (1,9,9,1), (1,1,1,1), (1,8,8,1), 'VALID' )

            total_loss = None
            for mask in self.mask_list:
                mask_patches = tf.extract_image_patches ( tf.tile(mask,[1,1,1,3]) , (1,9,9,1), (1,1,1,1), (1,8,8,1), 'VALID' )
                if not self.is_tanh:
                    loss = (1.0 - tf.image.ssim (true_patches*mask_patches, pred_patches*mask_patches, 1.0)) / 2.0
                else:
                    loss = (1.0 - tf.image.ssim ( (true_patches/2+0.5)*(mask_patches/2+0.5), (pred_patches/2+0.5)*(mask_patches/2+0.5), 1.0)) / 2.0

                total_loss = loss if total_loss is None else total_loss + loss

            return total_loss

    return DSSIMPatchMaskLoss
def DSSIMLossClass(tf):
    """Build an unmasked DSSIM loss class bound to the given TensorFlow module.

    Instances are callables returning ``(1 - SSIM(y_true, y_pred)) / 2``.
    With ``is_tanh=True`` both tensors are first remapped with ``x / 2 + 0.5``.
    """
    class DSSIMLoss(object):
        def __init__(self, is_tanh=False):
            self.is_tanh = is_tanh

        def __call__(self,y_true, y_pred):
            if not self.is_tanh:
                return (1.0 - tf.image.ssim (y_true, y_pred, 1.0)) / 2.0
            else:
                return (1.0 - tf.image.ssim ((y_true/2+0.5), (y_pred/2+0.5), 1.0)) / 2.0

    return DSSIMLoss
def rgb_to_lab(tf, rgb_input):
    """Convert sRGB pixels to CIELAB, rescaled to roughly ``[0, 1]`` per channel.

    *rgb_input* is reshaped to ``[-1, 3]`` so its last axis must hold the
    three colour channels; the result has the same shape as the input.
    After the standard conversion the channels are divided by
    ``[100, 220, 220]`` and the a/b channels shifted by 0.5.
    """
    with tf.name_scope("rgb_to_lab"):
        srgb_pixels = tf.reshape(rgb_input, [-1, 3])

        with tf.name_scope("srgb_to_xyz"):
            # Piecewise sRGB decoding: the two masks select the linear and
            # the power-law segment per value.
            linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
            exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
            rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
            rgb_to_xyz = tf.constant([
                #    X        Y          Z
                [0.412453, 0.212671, 0.019334], # R
                [0.357580, 0.715160, 0.119193], # G
                [0.180423, 0.072169, 0.950227], # B
            ])
            xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("xyz_to_cielab"):
            # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)

            # normalize for D65 white point
            xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])

            epsilon = 6/29
            linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
            exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
            fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask

            # convert to lab
            fxfyfz_to_lab = tf.constant([
                #  l       a       b
                [  0.0,  500.0,    0.0], # fx
                [116.0, -500.0,  200.0], # fy
                [  0.0,    0.0, -200.0], # fz
            ])
            lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])

        #output [0, 100] , ~[-110, 110], ~[-110, 110]
        lab_pixels = lab_pixels / tf.constant([100.0, 220.0, 220.0 ]) + tf.constant([0.0, 0.5, 0.5])
        #output [0-1, 0-1, 0-1]
        return tf.reshape(lab_pixels, tf.shape(rgb_input))
def lab_to_rgb(tf, lab):
    """Convert CIELAB pixels to sRGB clipped to ``[0, 1]``.

    *lab* is reshaped to ``[-1, 3]`` so its last axis must hold L, a, b; the
    result has the same shape as the input.  No rescaling is applied to the
    input: L is offset by 16 and divided by 116 directly, i.e. the values
    are used on the unnormalised Lab scale.
    """
    with tf.name_scope("lab_to_rgb"):
        lab_pixels = tf.reshape(lab, [-1, 3])

        # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
        with tf.name_scope("cielab_to_xyz"):
            # convert to fxfyfz
            lab_to_fxfyfz = tf.constant([
                #   fx      fy        fz
                [1/116.0, 1/116.0,  1/116.0], # l
                [1/500.0,     0.0,      0.0], # a
                [    0.0,     0.0, -1/200.0], # b
            ])
            fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)

            # convert to xyz
            epsilon = 6/29
            linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
            exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
            xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask

            # denormalize for D65 white point
            xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])

        with tf.name_scope("xyz_to_srgb"):
            xyz_to_rgb = tf.constant([
                #     r           g          b
                [ 3.2404542, -0.9692660,  0.0556434], # x
                [-1.5371385,  1.8760108, -0.2040259], # y
                [-0.4985314,  0.0415560,  1.0572252], # z
            ])
            rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
            # avoid a slightly negative number messing up the conversion
            rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
            linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
            exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
            srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask

        return tf.reshape(srgb_pixels, tf.shape(lab))
def DSSIMPatchLossClass(tf, keras):
    """Build a loss class that compares images in (rescaled) Lab space.

    Despite the name, the returned loss is currently the mean squared
    difference between ``rgb_to_lab`` of the target and of the prediction.
    ``is_tanh`` is stored but has no effect: the SSIM branches that used it
    sat after an unconditional ``return`` and could never run, so they have
    been removed.
    """
    class DSSIMPatchLoss(object):
        def __init__(self, is_tanh=False):
            self.is_tanh = is_tanh

        def __call__(self,y_true, y_pred):
            y_pred_lab = rgb_to_lab(tf, y_pred)
            y_true_lab = rgb_to_lab(tf, y_true)

            return keras.backend.mean ( keras.backend.square(y_true_lab - y_pred_lab) )

    return DSSIMPatchLoss
def MSEMaskLossClass(keras):
    """Build a masked mean-squared-error loss class bound to the given Keras module.

    Instances are callables ``loss(y_true, y_pred)``.  For every mask the
    mean of ``(y_true*mask - y_pred*mask) ** 2`` is computed and the per-mask
    values are summed.  With ``is_tanh=True`` targets, predictions and masks
    are first remapped with ``x / 2 + 0.5``.  An empty ``mask_list`` yields
    ``None``.
    """
    class MSEMaskLoss(object):
        def __init__(self, mask_list, is_tanh=False):
            self.mask_list = mask_list
            self.is_tanh = is_tanh

        def __call__(self,y_true, y_pred):
            K = keras.backend

            total_loss = None
            for mask in self.mask_list:
                if not self.is_tanh:
                    loss = K.mean(K.square(y_true*mask - y_pred*mask))
                else:
                    loss = K.mean(K.square( (y_true/2+0.5)*(mask/2+0.5) - (y_pred/2+0.5)*(mask/2+0.5) ))

                total_loss = loss if total_loss is None else total_loss + loss

            return total_loss

    return MSEMaskLoss
def PixelShufflerClass(keras):
    """Build a ``PixelShuffler`` layer class deriving from ``keras.layers.Layer``.

    The layer rearranges channels into space: with ``size=(rh, rw)`` a rank-4
    input gets ``rh`` times the height, ``rw`` times the width and
    ``channels // (rh * rw)`` channels.  Both ``channels_first`` and
    ``channels_last`` layouts are handled.
    """
    def rank_error(input_shape):
        # Single, properly concatenated message.  The previous code passed
        # the shape as a second ValueError argument by mistake.
        return ValueError('Inputs should have rank 4; Received input shape: ' + str(input_shape))

    class PixelShuffler(keras.layers.Layer):
        def __init__(self, size=(2, 2), data_format=None, **kwargs):
            super(PixelShuffler, self).__init__(**kwargs)
            self.data_format = keras.backend.common.normalize_data_format(data_format)
            self.size = keras.utils.conv_utils.normalize_tuple(size, 2, 'size')

        def call(self, inputs):
            input_shape = keras.backend.int_shape(inputs)
            if len(input_shape) != 4:
                raise rank_error(input_shape)

            if self.data_format == 'channels_first':
                batch_size, c, h, w = input_shape
                if batch_size is None:
                    batch_size = -1
                rh, rw = self.size
                oh, ow = h * rh, w * rw
                oc = c // (rh * rw)

                out = keras.backend.reshape(inputs, (batch_size, rh, rw, oc, h, w))
                out = keras.backend.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
                out = keras.backend.reshape(out, (batch_size, oc, oh, ow))
                return out

            elif self.data_format == 'channels_last':
                batch_size, h, w, c = input_shape
                if batch_size is None:
                    batch_size = -1
                rh, rw = self.size
                oh, ow = h * rh, w * rw
                oc = c // (rh * rw)

                out = keras.backend.reshape(inputs, (batch_size, h, w, rh, rw, oc))
                out = keras.backend.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
                out = keras.backend.reshape(out, (batch_size, oh, ow, oc))
                return out

        def compute_output_shape(self, input_shape):
            if len(input_shape) != 4:
                raise rank_error(input_shape)

            if self.data_format == 'channels_first':
                height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
                width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
                channels = input_shape[1] // self.size[0] // self.size[1]

                if channels * self.size[0] * self.size[1] != input_shape[1]:
                    raise ValueError('channels of input and size are incompatible')

                return (input_shape[0],
                        channels,
                        height,
                        width)

            elif self.data_format == 'channels_last':
                height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
                width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
                channels = input_shape[3] // self.size[0] // self.size[1]

                if channels * self.size[0] * self.size[1] != input_shape[3]:
                    raise ValueError('channels of input and size are incompatible')

                return (input_shape[0],
                        height,
                        width,
                        channels)

        def get_config(self):
            config = {'size': self.size,
                      'data_format': self.data_format}
            base_config = super(PixelShuffler, self).get_config()

            return dict(list(base_config.items()) + list(config.items()))

    return PixelShuffler
def conv(keras, input_tensor, filters):
    """Apply a 5x5, stride-2, 'same'-padded Conv2D followed by LeakyReLU(0.1)."""
    x = input_tensor
    x = keras.layers.convolutional.Conv2D(filters, kernel_size=5, strides=2, padding='same')(x)
    x = keras.layers.advanced_activations.LeakyReLU(0.1)(x)
    return x
def upscale(keras, input_tensor, filters, k_size=3):
    """Upscale by 2: Conv2D to ``filters * 4`` channels, LeakyReLU(0.1), then PixelShuffler.

    The convolution emits four times the requested filters because the
    default (2, 2) pixel shuffle folds them back into space.
    """
    x = input_tensor
    x = keras.layers.convolutional.Conv2D(filters * 4, kernel_size=k_size, padding='same')(x)
    x = keras.layers.advanced_activations.LeakyReLU(0.1)(x)
    x = PixelShufflerClass(keras)()(x)
    return x
def upscale4(keras, input_tensor, filters):
    """Upscale by 4: 3x3 Conv2D to ``filters * 16`` channels, LeakyReLU(0.1), then PixelShuffler(size=(4, 4))."""
    x = input_tensor
    x = keras.layers.convolutional.Conv2D(filters * 16, kernel_size=3, padding='same')(x)
    x = keras.layers.advanced_activations.LeakyReLU(0.1)(x)
    x = PixelShufflerClass(keras)(size=(4, 4))(x)
    return x
def res(keras, input_tensor, filters):
    """Residual block: two bias-free 3x3 convolutions with a skip connection.

    Both convolutions use ``RandomNormal(0, 0.02)`` kernel initialisation.
    LeakyReLU(0.2) is applied after the first convolution and again after
    the input is added back.  Because of the ``Add``, *filters* has to match
    the channel count of *input_tensor*.
    """
    x = input_tensor
    x = keras.layers.convolutional.Conv2D(filters, kernel_size=3, kernel_initializer=keras.initializers.RandomNormal(0, 0.02), use_bias=False, padding="same")(x)
    x = keras.layers.advanced_activations.LeakyReLU(alpha=0.2)(x)
    x = keras.layers.convolutional.Conv2D(filters, kernel_size=3, kernel_initializer=keras.initializers.RandomNormal(0, 0.02), use_bias=False, padding="same")(x)
    x = keras.layers.Add()([x, input_tensor])
    x = keras.layers.advanced_activations.LeakyReLU(alpha=0.2)(x)
    return x
def resize_like(tf, keras, ref_tensor, input_tensor):
    """Bilinearly resize *input_tensor* to the height/width of *ref_tensor*.

    The resize is wrapped in a Keras ``Lambda`` layer.  Height and width are
    read from axes 1 and 2 of ``ref_tensor.get_shape()`` via ``.value``, so
    those dimensions must be statically known.
    """
    def func(input_tensor, ref_tensor):
        H, W = ref_tensor.get_shape()[1], ref_tensor.get_shape()[2]
        return tf.image.resize_bilinear(input_tensor, [H.value, W.value])

    return keras.layers.Lambda(func, arguments={'ref_tensor':ref_tensor})(input_tensor)
def total_variation_loss(keras, x):
    """Return the mean squared difference between neighbouring pixels of *x*.

    *x* must be rank 4; axes 1 and 2 are treated as height and width.  The
    squared differences along both axes are summed per position (the last
    row and column are excluded) and averaged.
    """
    K = keras.backend
    assert K.ndim(x) == 4
    _, H, W, _ = K.int_shape(x)
    a = K.square(x[:, :H - 1, :W - 1, :] - x[:, 1:, :W - 1, :])   # neighbour along axis 1
    b = K.square(x[:, :H - 1, :W - 1, :] - x[:, :H - 1, 1:, :])   # neighbour along axis 2

    return K.mean (a+b)
# Defines the Unet generator.
# NOTE: the loop that used |num_downs| is commented out below, so the network
# currently always consists of five nested skip-connection blocks (five
# stride-2 downsamplings) regardless of |num_downs|.
def UNet(keras, tf, input_shape, output_nc, num_downs, ngf=64, use_dropout=False):
    """Build a U-Net style Keras model from nested skip-connection blocks.

    Args:
        keras: the Keras module.
        tf: accepted for signature compatibility; not used by the body.
        input_shape: shape passed to the model's ``Input`` layer.
        output_nc: filter count of the outermost transposed convolution.
        num_downs: currently ignored (see note above the function).
        ngf: base filter count; inner blocks use ``ngf * 2/4/8``.
        use_dropout: currently ignored - it is not forwarded to any block.

    Returns:
        ``keras.models.Model`` mapping the input through the nested blocks;
        the outermost block ends in a tanh activation.
    """
    Conv2D = keras.layers.convolutional.Conv2D
    Conv2DTranspose = keras.layers.convolutional.Conv2DTranspose
    LeakyReLU = keras.layers.advanced_activations.LeakyReLU
    BatchNormalization = keras.layers.BatchNormalization
    ReLU = keras.layers.ReLU
    tanh = keras.layers.Activation('tanh')
    Dropout = keras.layers.Dropout
    Concatenate = keras.layers.Concatenate
    ZeroPadding2D = keras.layers.ZeroPadding2D

    conv_kernel_initializer = keras.initializers.RandomNormal(0, 0.02)
    norm_gamma_initializer = keras.initializers.RandomNormal(1, 0.02)

    model_input = keras.layers.Input (input_shape)

    def UNetSkipConnection(outer_nc, inner_nc, sub_model=None, outermost=False, innermost=False, use_dropout=False):
        # Returns a function applying: downsample -> sub_model -> upsample,
        # concatenated with its own input unless it is the outermost block.
        def func(inp):
            # All layers are instantiated up front, in this order, even when
            # a branch below does not use some of them.
            downconv_pad = ZeroPadding2D( (1,1) )
            downconv = Conv2D(inner_nc, kernel_size=4, kernel_initializer=conv_kernel_initializer, strides=2, padding='valid', use_bias=False)
            downrelu = LeakyReLU(0.2)
            downnorm = BatchNormalization( gamma_initializer=norm_gamma_initializer )

            upconv = Conv2DTranspose(outer_nc, kernel_size=4, kernel_initializer=conv_kernel_initializer, strides=2, padding='same', use_bias=False)
            uprelu = ReLU()
            upnorm = BatchNormalization( gamma_initializer=norm_gamma_initializer )

            if outermost:
                x = inp
                x = downconv(downconv_pad(x))
                x = sub_model(x)
                x = uprelu(x)
                x = upconv(x)
                x = tanh(x)
            elif innermost:
                x = inp
                x = downrelu(x)
                x = downconv(downconv_pad(x))
                x = uprelu(x)
                x = upconv(x)
                x = upnorm(x)
                x = Concatenate(axis=3)([inp, x])
            else:
                x = inp
                x = downrelu(x)
                x = downconv(downconv_pad(x))
                x = downnorm(x)
                x = sub_model(x)
                x = uprelu(x)
                x = upconv(x)
                x = upnorm(x)
                if use_dropout:
                    x = Dropout(0.5)(x)
                x = Concatenate(axis=3)([inp, x])

            return x

        return func

    # Blocks are assembled from the innermost one outwards.
    unet_block = UNetSkipConnection(ngf * 8, ngf * 8, sub_model=None, innermost=True)

    #for i in range(num_downs - 5):
    #    unet_block = UNetSkipConnection(ngf * 8, ngf * 8, sub_model=unet_block, use_dropout=use_dropout)

    unet_block = UNetSkipConnection(ngf * 4  , ngf * 8, sub_model=unet_block)
    unet_block = UNetSkipConnection(ngf * 2  , ngf * 4, sub_model=unet_block)
    unet_block = UNetSkipConnection(ngf      , ngf * 2, sub_model=unet_block)
    unet_block = UNetSkipConnection(output_nc, ngf    , sub_model=unet_block, outermost=True)

    x = model_input
    x = unet_block(x)

    return keras.models.Model (model_input,x)
#predicts based on two past_image_tensors
def UNetTemporalPredictor(keras, tf, input_shape, output_nc, num_downs, ngf=32, use_dropout=False):
    """Build a model that predicts an image from two earlier images.

    Each of the two inputs goes through its own convolutional encoder
    (``model1`` is instantiated twice, so the encoders do not share
    weights).  The encoded features are concatenated on axis 3, passed
    through a ``UNet`` and decoded by ``model3``, which ends in tanh.

    Args:
        keras, tf: the Keras and TensorFlow modules.
        input_shape: shape of each of the two image inputs.
        output_nc: channel count of the predicted image.
        num_downs, use_dropout: forwarded to ``UNet``.
        ngf: base filter count of the encoder / decoder.

    Returns:
        ``keras.models.Model`` taking ``[past_2_image, past_1_image]``.
    """
    K = keras.backend
    Conv2D = keras.layers.convolutional.Conv2D
    Conv2DTranspose = keras.layers.convolutional.Conv2DTranspose
    LeakyReLU = keras.layers.advanced_activations.LeakyReLU
    BatchNormalization = keras.layers.BatchNormalization
    ReLU = keras.layers.ReLU
    tanh = keras.layers.Activation('tanh')
    ReflectionPadding2D = ReflectionPadding2DClass(keras, tf)
    ZeroPadding2D = keras.layers.ZeroPadding2D
    Dropout = keras.layers.Dropout
    Concatenate = keras.layers.Concatenate

    conv_kernel_initializer = keras.initializers.RandomNormal(0, 0.02)
    norm_gamma_initializer = keras.initializers.RandomNormal(1, 0.02)

    past_2_image_tensor = keras.layers.Input (input_shape)
    past_1_image_tensor = keras.layers.Input (input_shape)

    def model1(input_shape):
        # Encoder: 7x7 conv to ngf, then 3x3 convs to ngf*2 and ngf*4, all
        # stride 1 with explicit padding, each followed by BatchNorm + ReLU.
        inp = keras.layers.Input (input_shape)
        x = inp
        x = ReflectionPadding2D((3,3))(x)
        x = Conv2D(ngf, kernel_size=7, kernel_initializer=conv_kernel_initializer, strides=1, padding='valid', use_bias=False)(x)
        x = BatchNormalization( gamma_initializer=norm_gamma_initializer )(x)
        x = ReLU()(x)

        x = ZeroPadding2D((1,1))(x)
        x = Conv2D(ngf*2, kernel_size=3, kernel_initializer=conv_kernel_initializer, strides=1, padding='valid', use_bias=False)(x)
        x = BatchNormalization( gamma_initializer=norm_gamma_initializer )(x)
        x = ReLU()(x)

        x = ZeroPadding2D((1,1))(x)
        x = Conv2D(ngf*4, kernel_size=3, kernel_initializer=conv_kernel_initializer, strides=1, padding='valid', use_bias=False)(x)
        x = BatchNormalization( gamma_initializer=norm_gamma_initializer )(x)
        x = ReLU()(x)

        return keras.models.Model(inp, x)

    def model3(input_shape):
        # Decoder: 3x3 convs to ngf*2 and ngf, then a 7x7 conv to output_nc
        # followed by tanh.
        inp = keras.layers.Input (input_shape)
        x = inp

        x = ZeroPadding2D((1,1))(x)
        x = Conv2D(ngf*2, kernel_size=3, kernel_initializer=conv_kernel_initializer, strides=1, padding='valid', use_bias=False)(x)
        x = BatchNormalization( gamma_initializer=norm_gamma_initializer )(x)
        x = ReLU()(x)

        x = ZeroPadding2D((1,1))(x)
        x = Conv2D(ngf, kernel_size=3, kernel_initializer=conv_kernel_initializer, strides=1, padding='valid', use_bias=False)(x)
        x = BatchNormalization( gamma_initializer=norm_gamma_initializer )(x)
        x = ReLU()(x)

        x = ReflectionPadding2D((3,3))(x)
        x = Conv2D(output_nc, kernel_size=7, kernel_initializer=conv_kernel_initializer, strides=1, padding='valid', use_bias=False)(x)
        x = tanh(x)
        return keras.models.Model(inp, x)

    x = Concatenate(axis=3)([ model1(input_shape)(past_2_image_tensor), model1(input_shape)(past_1_image_tensor) ])

    unet = UNet(keras, tf, K.int_shape(x)[1:], ngf*4, num_downs=num_downs, ngf=ngf*4*2, #ngf=ngf*4*4,
                use_dropout=use_dropout)

    x = unet(x)
    x = model3 ( K.int_shape(x)[1:] ) (x)

    return keras.models.Model ( [past_2_image_tensor,past_1_image_tensor], x )
def Resnet(keras, tf, input_shape, output_nc, ngf=64, use_dropout=False, n_blocks=6):
    """Build a ResNet-style image-to-image generator as a Keras model.

    Layout: 7x7 stem, two stride-2 downsampling convolutions, ``n_blocks``
    residual blocks, two stride-2 transposed convolutions, and a 7x7
    output convolution followed by ``tanh``.

    Args:
        keras: keras module used to create every layer.
        tf: tensorflow module, forwarded to ``ReflectionPadding2DClass``.
        input_shape: shape accepted by ``keras.layers.Input``.
        output_nc: filter count of the final convolution.
        ngf: filter count of the stem convolution.
        use_dropout: insert ``Dropout(0.5)`` inside every residual block.
        n_blocks: number of residual blocks.

    Returns:
        ``keras.models.Model`` mapping the input tensor to the tanh output.
    """
    layers = keras.layers

    Conv2D = layers.convolutional.Conv2D
    Conv2DTranspose = layers.convolutional.Conv2DTranspose
    # LeakyReLU and Concatenate are not used below; the lookups are kept so
    # the function touches the same keras attributes as before.
    LeakyReLU = layers.advanced_activations.LeakyReLU
    BatchNormalization = layers.BatchNormalization
    ReLU = layers.ReLU
    Add = layers.Add
    output_activation = layers.Activation('tanh')
    ReflectionPadding2D = ReflectionPadding2DClass(keras, tf)
    ZeroPadding2D = layers.ZeroPadding2D
    Dropout = layers.Dropout
    Concatenate = layers.Concatenate

    kernel_init = keras.initializers.RandomNormal(0, 0.02)
    gamma_init = keras.initializers.RandomNormal(1, 0.02)
    use_bias = False
    n_downsampling = 2

    model_input = layers.Input(input_shape)

    def norm_relu(tensor):
        """Batch normalisation followed by ReLU."""
        tensor = BatchNormalization(gamma_initializer=gamma_init)(tensor)
        return ReLU()(tensor)

    def conv3x3_reflect(tensor, dim):
        """Reflection-pad by 1 and apply a 3x3 'valid' convolution."""
        tensor = ReflectionPadding2D((1, 1))(tensor)
        return Conv2D(dim, kernel_size=3, kernel_initializer=kernel_init,
                      padding='valid', use_bias=use_bias)(tensor)

    def residual_block(tensor, dim):
        """Two padded 3x3 convs with a skip connection added at the end."""
        out = norm_relu(conv3x3_reflect(tensor, dim))
        if use_dropout:
            out = Dropout(0.5)(out)
        out = conv3x3_reflect(out, dim)
        out = BatchNormalization(gamma_initializer=gamma_init)(out)
        return Add()([out, tensor])

    # Stem.
    x = ReflectionPadding2D((3, 3))(model_input)
    x = Conv2D(ngf, kernel_size=7, kernel_initializer=kernel_init,
               padding='valid', use_bias=use_bias)(x)
    x = norm_relu(x)

    # Downsampling: filter count doubles at each level.
    for level in range(n_downsampling):
        x = ZeroPadding2D((1, 1))(x)
        x = Conv2D(ngf * (2 ** level) * 2, kernel_size=3, kernel_initializer=kernel_init,
                   strides=2, padding='valid', use_bias=use_bias)(x)
        x = norm_relu(x)

    bottleneck_dim = ngf * (2 ** n_downsampling)
    for _ in range(n_blocks):
        x = residual_block(x, bottleneck_dim)

    # Upsampling: filter count halves at each level.
    for level in range(n_downsampling):
        filters = int(ngf * (2 ** (n_downsampling - level)) / 2)
        x = Conv2DTranspose(filters, kernel_size=3, kernel_initializer=kernel_init,
                            strides=2, padding='same', output_padding=1, use_bias=use_bias)(x)
        x = norm_relu(x)

    # Output head; this convolution keeps keras' default use_bias.
    x = ReflectionPadding2D((3, 3))(x)
    x = Conv2D(output_nc, kernel_size=7, kernel_initializer=kernel_init, padding='valid')(x)

    return keras.models.Model(model_input, output_activation(x))
def NLayerDiscriminator(keras, tf, input_shape, ndf=64, n_layers=3, use_sigmoid=False):
    """Build a convolutional discriminator as a Keras model.

    Layout: one stride-2 4x4 convolution with ``ndf`` filters, then
    ``n_layers - 1`` stride-2 4x4 convolutions whose filter multiplier is
    ``min(2**n, 8)``, one stride-1 convolution with multiplier
    ``min(2**n_layers, 8)``, and a final single-filter stride-1 convolution.

    Args:
        keras: keras module used to create every layer.
        tf: unused; kept for a signature parallel to the other builders.
        input_shape: shape accepted by ``keras.layers.Input``.
        ndf: filter count of the first convolution.
        n_layers: controls the number of stride-2 stages (see above).
        use_sigmoid: apply a sigmoid activation to the output.

    Returns:
        ``keras.models.Model`` mapping the input to the single-filter output map.
    """
    layers = keras.layers

    Conv2D = layers.convolutional.Conv2D
    LeakyReLU = layers.advanced_activations.LeakyReLU
    BatchNormalization = layers.BatchNormalization
    sigmoid = layers.Activation('sigmoid')
    ZeroPadding2D = layers.ZeroPadding2D

    kernel_init = keras.initializers.RandomNormal(0, 0.02)
    gamma_init = keras.initializers.RandomNormal(1, 0.02)
    use_bias = False

    def conv4x4(filters, strides, padding):
        """Create a 4x4 convolution sharing the common initializer/bias settings."""
        return Conv2D(filters, kernel_size=4, kernel_initializer=kernel_init,
                      strides=strides, padding=padding, use_bias=use_bias)

    def norm_lrelu(tensor):
        """Batch normalisation followed by LeakyReLU(0.2)."""
        tensor = BatchNormalization(gamma_initializer=gamma_init)(tensor)
        return LeakyReLU(0.2)(tensor)

    model_input = layers.Input(input_shape, name="NLayerDiscriminatorInput")

    # First stage has no normalisation.
    x = ZeroPadding2D((1, 1))(model_input)
    x = conv4x4(ndf, 2, 'valid')(x)
    x = LeakyReLU(0.2)(x)

    # Strided stages; the multiplier is capped at 8.
    for stage in range(1, n_layers):
        x = ZeroPadding2D((1, 1))(x)
        x = conv4x4(ndf * min(2 ** stage, 8), 2, 'valid')(x)
        x = norm_lrelu(x)

    # Stride-1 stage followed by the single-filter output convolution.
    x = conv4x4(ndf * min(2 ** n_layers, 8), 1, 'same')(x)
    x = norm_lrelu(x)
    x = conv4x4(1, 1, 'same')(x)

    if use_sigmoid:
        x = sigmoid(x)

    return keras.models.Model(model_input, x)
def ReflectionPadding2DClass(keras, tf):
    """Create a reflection-padding layer class bound to the given modules.

    Args:
        keras: keras module providing ``layers.Layer`` and ``layers.InputSpec``.
        tf: tensorflow module providing ``pad``.

    Returns:
        The ``ReflectionPadding2D`` layer class (not an instance).
    """
    class ReflectionPadding2D(keras.layers.Layer):
        """Pads axes 1 and 2 of a 4-D tensor using ``tf.pad(..., 'REFLECT')``.

        ``padding`` is unpacked as ``(w_pad, h_pad)``: axis 1 is padded by
        ``h_pad`` and axis 2 by ``w_pad`` on both sides.
        """

        def __init__(self, padding=(1, 1), **kwargs):
            self.padding = tuple(padding)
            self.input_spec = [keras.layers.InputSpec(ndim=4)]
            super(ReflectionPadding2D, self).__init__(**kwargs)

        def compute_output_shape(self, s):
            """Return the padded shape for a "channels_last" input shape ``s``."""
            # Use the same unpacking as call() so the reported shape matches
            # the tensor that tf.pad actually produces for non-square padding.
            w_pad, h_pad = self.padding
            # Unknown (None) dimensions stay unknown instead of raising.
            height = s[1] + 2 * h_pad if s[1] is not None else None
            width = s[2] + 2 * w_pad if s[2] is not None else None
            return (s[0], height, width, s[3])

        def call(self, x, mask=None):
            w_pad, h_pad = self.padding
            return tf.pad(x, [[0, 0], [h_pad, h_pad], [w_pad, w_pad], [0, 0]], 'REFLECT')

    return ReflectionPadding2D

209
nnlib/devicelib.py Normal file
View file

@ -0,0 +1,209 @@
from .pynvml import *
class devicelib:
    """Helpers for discovering NVIDIA GPUs through NVML and choosing devices.

    Every query is best-effort: if NVML cannot be initialised or a call
    fails, the method returns its neutral default (``-1``, ``0``, ``''``,
    ``[]`` or ``False``) instead of raising.
    """

    class _NVMLSession(object):
        """Context manager pairing ``nvmlInit()`` with ``nvmlShutdown()``."""

        def __enter__(self):
            nvmlInit()
            return self

        def __exit__(self, exc_type, exc_value, traceback):
            # Release NVML even when the body raised, so init/shutdown
            # calls stay balanced.
            nvmlShutdown()
            return False

    class Config():
        """Resolved device selection used when creating the TF session.

        Attributes set on the instance:
            float16: copied from the constructor argument.
            cpu_only: True when CPU was requested, NVML is unavailable, or
                no usable GPU index was found.
            gpu_idxs: list of selected GPU indices (empty when cpu_only).
            multi_gpu: cleared when fewer than two same-model GPUs exist.
            gpu_total_vram_gb: total VRAM of ``gpu_idxs[0]`` in whole GiB.
        """
        force_best_gpu_idx = -1
        multi_gpu = False
        force_gpu_idxs = None
        choose_worst_gpu = False
        gpu_idxs = []
        gpu_total_vram_gb = 0
        allow_growth = True
        float16 = False
        cpu_only = False

        def __init__ (self, force_best_gpu_idx = -1,
                            multi_gpu = False,
                            force_gpu_idxs = None,
                            choose_worst_gpu = False,
                            allow_growth = True,
                            float16 = False,
                            cpu_only = False,
                            **in_options):
            """Pick the GPU(s) to use.

            Args:
                force_best_gpu_idx: index to use as the "best" GPU when it
                    is >= 0 and valid; otherwise the best/worst GPU is queried.
                multi_gpu: use every GPU with the same model name as the
                    chosen one.
                force_gpu_idxs: comma-separated indices, e.g. ``"0,1"``;
                    overrides the other selection options.
                choose_worst_gpu: pick the GPU with the least total memory.
                allow_growth: stored for the session configuration.
                float16: stored as-is.
                cpu_only: skip GPU discovery entirely.
                **in_options: ignored extra options.
            """
            self.float16 = float16
            # Per-instance list; never share the class-level default.
            self.gpu_idxs = []

            if cpu_only or not devicelib.hasNVML():
                self.cpu_only = True
                return

            self.force_best_gpu_idx = force_best_gpu_idx
            self.multi_gpu = multi_gpu
            self.force_gpu_idxs = force_gpu_idxs
            self.choose_worst_gpu = choose_worst_gpu
            self.allow_growth = allow_growth

            if force_gpu_idxs is not None:
                for idx in force_gpu_idxs.split(','):
                    idx = int(idx)
                    if devicelib.isValidDeviceIdx(idx):
                        self.gpu_idxs.append(idx)
            else:
                if force_best_gpu_idx >= 0 and devicelib.isValidDeviceIdx(force_best_gpu_idx):
                    gpu_idx = force_best_gpu_idx
                elif not choose_worst_gpu:
                    gpu_idx = devicelib.getBestDeviceIdx()
                else:
                    gpu_idx = devicelib.getWorstDeviceIdx()

                if gpu_idx != -1:
                    if self.multi_gpu:
                        self.gpu_idxs = devicelib.getDeviceIdxsEqualModel( gpu_idx )
                        if len(self.gpu_idxs) <= 1:
                            self.multi_gpu = False
                    else:
                        self.gpu_idxs = [gpu_idx]

            if len(self.gpu_idxs) == 0:
                self.cpu_only = True
            else:
                self.cpu_only = False
                self.gpu_total_vram_gb = devicelib.getDeviceVRAMTotalGb ( self.gpu_idxs[0] )

    @staticmethod
    def hasNVML():
        """Return True when NVML can be initialised and shut down."""
        try:
            nvmlInit()
            nvmlShutdown()
        except Exception:
            # Was `except e:`, which raised NameError instead of reporting
            # that NVML is unavailable.
            return False
        return True

    @staticmethod
    def getDevicesWithAtLeastFreeMemory(freememsize):
        """Return indices of devices with ``total - used >= freememsize`` bytes."""
        result = []
        try:
            with devicelib._NVMLSession():
                for i in range(0, nvmlDeviceGetCount() ):
                    memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
                    if (memInfo.total - memInfo.used) >= freememsize:
                        result.append (i)
        except Exception:
            pass
        return result

    @staticmethod
    def getDevicesWithAtLeastTotalMemoryGB(totalmemsize_gb):
        """Return indices of devices whose total memory is >= totalmemsize_gb GiB."""
        result = []
        try:
            with devicelib._NVMLSession():
                for i in range(0, nvmlDeviceGetCount() ):
                    memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
                    if (memInfo.total) >= totalmemsize_gb*1024*1024*1024:
                        result.append (i)
        except Exception:
            pass
        return result

    @staticmethod
    def getAllDevicesIdxsList ():
        """Return ``[0, ..., device_count - 1]``, or ``[]`` on failure."""
        result = []
        try:
            with devicelib._NVMLSession():
                result = list(range(0, nvmlDeviceGetCount() ))
        except Exception:
            pass
        return result

    @staticmethod
    def getDeviceVRAMFree (idx):
        """Return ``total - used`` memory in bytes for device ``idx`` (0 on failure)."""
        result = 0
        try:
            with devicelib._NVMLSession():
                if idx < nvmlDeviceGetCount():
                    memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(idx) )
                    result = (memInfo.total - memInfo.used)
        except Exception:
            pass
        return result

    @staticmethod
    def getDeviceVRAMTotalGb (idx):
        """Return the total memory of device ``idx`` rounded to whole GiB (0 on failure)."""
        result = 0
        try:
            with devicelib._NVMLSession():
                if idx < nvmlDeviceGetCount():
                    memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(idx) )
                    result = memInfo.total / (1024*1024*1024)
            result = round(result)
        except Exception:
            pass
        return result

    @staticmethod
    def getBestDeviceIdx():
        """Return the index of the device with the most total memory, or -1."""
        idx = -1
        try:
            with devicelib._NVMLSession():
                idx_mem = 0
                for i in range(0, nvmlDeviceGetCount() ):
                    memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
                    if memInfo.total > idx_mem:
                        idx = i
                        idx_mem = memInfo.total
        except Exception:
            pass
        return idx

    @staticmethod
    def getWorstDeviceIdx():
        """Return the index of the device with the least total memory, or -1."""
        idx = -1
        try:
            with devicelib._NVMLSession():
                # Infinity as the starting minimum avoids depending on `sys`,
                # which this module does not import itself.
                idx_mem = float('inf')
                for i in range(0, nvmlDeviceGetCount() ):
                    memInfo = nvmlDeviceGetMemoryInfo( nvmlDeviceGetHandleByIndex(i) )
                    if memInfo.total < idx_mem:
                        idx = i
                        idx_mem = memInfo.total
        except Exception:
            pass
        return idx

    @staticmethod
    def isValidDeviceIdx(idx):
        """Return True when ``0 <= idx < device_count``."""
        result = False
        try:
            with devicelib._NVMLSession():
                # Negative indices are rejected too; they are not valid
                # NVML device indices.
                result = (0 <= idx < nvmlDeviceGetCount())
        except Exception:
            pass
        return result

    @staticmethod
    def getDeviceIdxsEqualModel(idx):
        """Return indices of all devices whose name equals that of device ``idx``."""
        result = []
        try:
            with devicelib._NVMLSession():
                idx_name = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
                for i in range(0, nvmlDeviceGetCount() ):
                    if nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(i)).decode() == idx_name:
                        result.append (i)
        except Exception:
            pass
        return result

    @staticmethod
    def getDeviceName (idx):
        """Return the decoded name of device ``idx``, or ``''`` on failure."""
        result = ''
        try:
            with devicelib._NVMLSession():
                if idx < nvmlDeviceGetCount():
                    result = nvmlDeviceGetName(nvmlDeviceGetHandleByIndex(idx)).decode()
        except Exception:
            pass
        return result

698
nnlib/nnlib.py Normal file
View file

@ -0,0 +1,698 @@
import os
import sys
import contextlib
from utils import std_utils
from .devicelib import devicelib
# Namespace class: holds the lazily imported tf / keras / keras_contrib / dlib
# modules, the helpers built on top of them, and the source snippets that
# callers exec() to pull those names into their own scope.
class nnlib(object):
device = devicelib #forwards nnlib.devicelib to device in order to use nnlib as standalone lib
DeviceConfig = devicelib.Config
# NOTE: DeviceConfig() is evaluated once, when the class body runs.
prefer_DeviceConfig = DeviceConfig() #default is one best GPU
# Imported modules and the TF session; None until the matching import_* runs.
dlib = None
keras = None
keras_contrib = None
tf = None
tf_sess = None
# Compiled code objects built from the *_string snippets below.
code_import_tf = None
code_import_keras = None
code_import_keras_contrib = None
code_import_all = None
code_import_dlib = None
# Helpers assigned by __initialize_tf_functions().
tf_dssim = None
tf_ssim = None
tf_resize_like = None
tf_rgb_to_lab = None
tf_lab_to_rgb = None
tf_image_histogram = None
# Helpers assigned by __initialize_keras_functions().
modelify = None
ReflectionPadding2D = None
DSSIMLoss = None
DSSIMMaskLoss = None
PixelShuffler = None
# Model builders assigned by __initialize_all_functions().
ResNet = None
UNet = None
UNetTemporalPredictor = None
NLayerDiscriminator = None
# Source exec'd by callers after import_tf(); binds tf names from nnlib.
code_import_tf_string = \
"""
tf = nnlib.tf
tf_sess = nnlib.tf_sess
tf_total_variation = tf.image.total_variation
tf_dssim = nnlib.tf_dssim
tf_ssim = nnlib.tf_ssim
tf_resize_like = nnlib.tf_resize_like
tf_image_histogram = nnlib.tf_image_histogram
tf_rgb_to_lab = nnlib.tf_rgb_to_lab
tf_lab_to_rgb = nnlib.tf_lab_to_rgb
"""
# Source exec'd by callers after import_keras(); binds keras layer aliases.
# Note that 'tanh' and 'sigmoid' are Activation layer instances, not classes.
code_import_keras_string = \
"""
keras = nnlib.keras
K = keras.backend
Input = keras.layers.Input
Dense = keras.layers.Dense
Conv2D = keras.layers.convolutional.Conv2D
Conv2DTranspose = keras.layers.convolutional.Conv2DTranspose
MaxPooling2D = keras.layers.MaxPooling2D
BatchNormalization = keras.layers.BatchNormalization
LeakyReLU = keras.layers.advanced_activations.LeakyReLU
ReLU = keras.layers.ReLU
tanh = keras.layers.Activation('tanh')
sigmoid = keras.layers.Activation('sigmoid')
Dropout = keras.layers.Dropout
Add = keras.layers.Add
Concatenate = keras.layers.Concatenate
Flatten = keras.layers.Flatten
Reshape = keras.layers.Reshape
ZeroPadding2D = keras.layers.ZeroPadding2D
RandomNormal = keras.initializers.RandomNormal
Model = keras.models.Model
Adam = keras.optimizers.Adam
modelify = nnlib.modelify
ReflectionPadding2D = nnlib.ReflectionPadding2D
DSSIMLoss = nnlib.DSSIMLoss
DSSIMMaskLoss = nnlib.DSSIMMaskLoss
PixelShuffler = nnlib.PixelShuffler
"""
# Source exec'd by callers after import_keras_contrib().
code_import_keras_contrib_string = \
"""
keras_contrib = nnlib.keras_contrib
GroupNormalization = keras_contrib.layers.GroupNormalization
InstanceNormalization = keras_contrib.layers.InstanceNormalization
"""
# Source exec'd by callers after import_dlib().
code_import_dlib_string = \
"""
dlib = nnlib.dlib
"""
# Extra names appended to the tf/keras/keras_contrib snippets by import_all().
code_import_all_string = \
"""
ResNet = nnlib.ResNet
UNet = nnlib.UNet
UNetTemporalPredictor = nnlib.UNetTemporalPredictor
NLayerDiscriminator = nnlib.NLayerDiscriminator
"""
@staticmethod
def import_tf(device_config = None):
    """Import tensorflow once, create the shared session and return import code.

    Args:
        device_config: device configuration to use. When None the stored
            ``nnlib.prefer_DeviceConfig`` is used; otherwise the given
            configuration also becomes the new preferred one.

    Returns:
        The compiled ``code_import_tf_string`` code object, meant to be
        passed to ``exec`` by the caller.
    """
    if nnlib.tf is not None:
        return nnlib.code_import_tf

    if device_config is None:
        device_config = nnlib.prefer_DeviceConfig
    else:
        nnlib.prefer_DeviceConfig = device_config

    suppressor = None
    if os.environ.get('TF_SUPPRESS_STD') == '1':
        suppressor = std_utils.suppress_stdout_stderr().__enter__()

    try:
        # Device visibility is controlled through the session config below.
        os.environ.pop('CUDA_VISIBLE_DEVICES', None)
        os.environ['TF_MIN_GPU_MULTIPROCESSOR_COUNT'] = '2'

        import tensorflow as tf
        nnlib.tf = tf

        if device_config.cpu_only:
            config = tf.ConfigProto( device_count = {'GPU': 0} )
        else:
            config = tf.ConfigProto()
            config.gpu_options.visible_device_list = ','.join(str(idx) for idx in device_config.gpu_idxs)

        # NOTE(review): indentation was lost in this view; these two options
        # are applied to both branches here - confirm against the original.
        config.gpu_options.force_gpu_compatible = True
        config.gpu_options.allow_growth = device_config.allow_growth

        nnlib.tf_sess = tf.Session(config=config)
    finally:
        # Restore stdout/stderr even if the import or session creation failed.
        if suppressor is not None:
            suppressor.__exit__()

    nnlib.__initialize_tf_functions()
    nnlib.code_import_tf = compile (nnlib.code_import_tf_string,'','exec')
    return nnlib.code_import_tf
# Builds the tensorflow helper factories and stores them on nnlib
# (tf_dssim, tf_ssim, tf_resize_like, tf_rgb_to_lab, tf_lab_to_rgb,
# tf_image_histogram). Each factory returns a function that operates on tensors.
@staticmethod
def __initialize_tf_functions():
tf = nnlib.tf
# tf_dssim(max_value) -> func(t1, t2) computing (1 - SSIM) / 2.
def tf_dssim_(max_value=1.0):
def func(t1,t2):
return (1.0 - tf.image.ssim (t1, t2, max_value)) / 2.0
return func
nnlib.tf_dssim = tf_dssim_
# tf_ssim(max_value) -> func(t1, t2) computing tf.image.ssim.
def tf_ssim_(max_value=1.0):
def func(t1,t2):
return tf.image.ssim (t1, t2, max_value)
return func
nnlib.tf_ssim = tf_ssim_
# tf_resize_like(ref_tensor) -> func(input_tensor) that bilinearly resizes
# the input to the static sizes of ref_tensor's axes 1 and 2.
def tf_resize_like_(ref_tensor):
def func(input_tensor):
H, W = ref_tensor.get_shape()[1], ref_tensor.get_shape()[2]
return tf.image.resize_bilinear(input_tensor, [H.value, W.value])
return func
nnlib.tf_resize_like = tf_resize_like_
# tf_rgb_to_lab() -> func(rgb_input): sRGB -> XYZ -> CIELAB conversion.
# The input is flattened to rows of 3 values and the result is reshaped
# back to the input's shape.
# NOTE(review): the thresholds suggest the input is expected in [0, 1] - confirm.
def tf_rgb_to_lab():
def func(rgb_input):
with tf.name_scope("rgb_to_lab"):
srgb_pixels = tf.reshape(rgb_input, [-1, 3])
with tf.name_scope("srgb_to_xyz"):
# The two float masks select the linear and the power segment of the curve.
linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
rgb_to_xyz = tf.constant([
# X Y Z
[0.412453, 0.212671, 0.019334], # R
[0.357580, 0.715160, 0.119193], # G
[0.180423, 0.072169, 0.950227], # B
])
xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
# https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
with tf.name_scope("xyz_to_cielab"):
# convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
# normalize for D65 white point
xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
epsilon = 6/29
linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
# convert to lab
fxfyfz_to_lab = tf.constant([
# l a b
[ 0.0, 500.0, 0.0], # fx
[116.0, -500.0, 200.0], # fy
[ 0.0, 0.0, -200.0], # fz
])
lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
return tf.reshape(lab_pixels, tf.shape(rgb_input))
return func
nnlib.tf_rgb_to_lab = tf_rgb_to_lab
# tf_lab_to_rgb() -> func(lab): CIELAB -> XYZ -> sRGB, the inverse pipeline
# of tf_rgb_to_lab. The linear RGB values are clipped to [0, 1] before the
# sRGB transfer curve is applied.
def tf_lab_to_rgb():
def func(lab):
with tf.name_scope("lab_to_rgb"):
lab_pixels = tf.reshape(lab, [-1, 3])
# https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
with tf.name_scope("cielab_to_xyz"):
# convert to fxfyfz
lab_to_fxfyfz = tf.constant([
# fx fy fz
[1/116.0, 1/116.0, 1/116.0], # l
[1/500.0, 0.0, 0.0], # a
[ 0.0, 0.0, -1/200.0], # b
])
fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
# convert to xyz
epsilon = 6/29
linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
# denormalize for D65 white point
xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
with tf.name_scope("xyz_to_srgb"):
xyz_to_rgb = tf.constant([
# r g b
[ 3.2404542, -0.9692660, 0.0556434], # x
[-1.5371385, 1.8760108, -0.2040259], # y
[-0.4985314, 0.0415560, 1.0572252], # z
])
rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
# avoid a slightly negative number messing up the conversion
rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask
return tf.reshape(srgb_pixels, tf.shape(lab))
return func
nnlib.tf_lab_to_rgb = tf_lab_to_rgb
# tf_image_histogram() -> func(input) returning a stacked tensor of 256
# reduce_sum values, one per loop step, in reversed step order.
# Each step builds a clipped soft indicator y of (x - v) * 1000 and then
# subtracts y*v from x before the next, lower level is processed.
# NOTE(review): presumably a histogram over 256 intensity levels of an
# input in [0, 1] - confirm the expected value range with callers.
def tf_image_histogram():
def func(input):
x = input
x += 1 / 255.0
output = []
for i in range(256, 0, -1):
v = i / 255.0
y = (x - v) * 1000
y = tf.clip_by_value (y, -1.0, 0.0) + 1
output.append ( tf.reduce_sum (y) )
x -= y*v
return tf.stack ( output[::-1] )
return func
nnlib.tf_image_histogram = tf_image_histogram
@staticmethod
def import_keras(device_config = None):
    """Import keras once, attach it to the nnlib TF session and return import code.

    Args:
        device_config: forwarded to ``nnlib.import_tf``.

    Returns:
        The compiled ``code_import_keras_string`` code object, meant to be
        passed to ``exec`` by the caller.
    """
    if nnlib.keras is not None:
        return nnlib.code_import_keras

    nnlib.import_tf(device_config)

    suppressor = None
    if os.environ.get('TF_SUPPRESS_STD') == '1':
        suppressor = std_utils.suppress_stdout_stderr().__enter__()

    try:
        import keras as keras_
        nnlib.keras = keras_
        nnlib.keras.backend.tensorflow_backend.set_session(nnlib.tf_sess)
    finally:
        # Restore stdout/stderr even if the import failed.
        if suppressor is not None:
            suppressor.__exit__()

    nnlib.__initialize_keras_functions()
    nnlib.code_import_keras = compile (nnlib.code_import_keras_string,'','exec')
    # The first call used to fall through and return None, which broke
    # `exec(nnlib.import_keras(), ...)`.
    return nnlib.code_import_keras
@staticmethod
def __initialize_keras_functions():
    """Create the custom keras helpers and publish them on ``nnlib``.

    Sets ``nnlib.modelify``, ``nnlib.ReflectionPadding2D``,
    ``nnlib.DSSIMLoss``, ``nnlib.DSSIMMaskLoss`` and ``nnlib.PixelShuffler``.
    Requires ``nnlib.tf`` and ``nnlib.keras`` to be set already.
    """
    tf = nnlib.tf
    keras = nnlib.keras

    def modelify(model_functor):
        """Wrap a tensor->tensor functor so that calling it returns a keras Model."""
        def func(tensor):
            return keras.models.Model (tensor, model_functor(tensor))
        return func
    nnlib.modelify = modelify

    class ReflectionPadding2D(keras.layers.Layer):
        """Pads axes 1 and 2 of a 4-D tensor using ``tf.pad(..., 'REFLECT')``.

        ``padding`` is unpacked as ``(w_pad, h_pad)``: axis 1 is padded by
        ``h_pad`` and axis 2 by ``w_pad`` on both sides.
        """
        def __init__(self, padding=(1, 1), **kwargs):
            self.padding = tuple(padding)
            self.input_spec = [keras.layers.InputSpec(ndim=4)]
            super(ReflectionPadding2D, self).__init__(**kwargs)

        def compute_output_shape(self, s):
            """ If you are using "channels_last" configuration"""
            # Same unpacking as call(), so the reported shape matches the
            # padded tensor for non-square padding too.
            w_pad, h_pad = self.padding
            height = s[1] + 2 * h_pad if s[1] is not None else None
            width = s[2] + 2 * w_pad if s[2] is not None else None
            return (s[0], height, width, s[3])

        def call(self, x, mask=None):
            w_pad,h_pad = self.padding
            return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')
    nnlib.ReflectionPadding2D = ReflectionPadding2D

    # The block previously defined DSSIMLoss twice with the same behaviour;
    # only the definition that actually took effect is kept.
    class DSSIMLoss(object):
        """Loss callable computing ``(1 - SSIM) / 2``.

        With ``is_tanh`` both tensors are first mapped by ``x/2 + 0.5``.
        """
        def __init__(self, is_tanh=False):
            self.is_tanh = is_tanh

        def __call__(self,y_true, y_pred):
            if not self.is_tanh:
                loss = (1.0 - tf.image.ssim (y_true, y_pred, 1.0)) / 2.0
            else:
                loss = (1.0 - tf.image.ssim ( (y_true/2+0.5), (y_pred/2+0.5), 1.0)) / 2.0
            return loss
    nnlib.DSSIMLoss = DSSIMLoss

    class DSSIMMaskLoss(object):
        """Sum of ``(1 - SSIM) / 2`` over the masked images, one term per mask.

        With ``is_tanh`` the images and the masks are first mapped by
        ``x/2 + 0.5``. Returns None when ``mask_list`` is empty.
        """
        def __init__(self, mask_list, is_tanh=False):
            self.mask_list = mask_list
            self.is_tanh = is_tanh

        def __call__(self,y_true, y_pred):
            total_loss = None
            for mask in self.mask_list:
                if not self.is_tanh:
                    loss = (1.0 - tf.image.ssim (y_true*mask, y_pred*mask, 1.0)) / 2.0
                else:
                    loss = (1.0 - tf.image.ssim ( (y_true/2+0.5)*(mask/2+0.5), (y_pred/2+0.5)*(mask/2+0.5), 1.0)) / 2.0

                if total_loss is None:
                    total_loss = loss
                else:
                    total_loss += loss
            return total_loss
    nnlib.DSSIMMaskLoss = DSSIMMaskLoss

    class PixelShuffler(keras.layers.Layer):
        """Rearranges channels into spatial blocks of ``size`` (default 2x2).

        Height and width are multiplied by ``size`` and the channel count
        is divided by ``size[0] * size[1]``.
        """
        def __init__(self, size=(2, 2), data_format=None, **kwargs):
            super(PixelShuffler, self).__init__(**kwargs)
            self.data_format = keras.backend.common.normalize_data_format(data_format)
            self.size = keras.utils.conv_utils.normalize_tuple(size, 2, 'size')

        def call(self, inputs):
            input_shape = keras.backend.int_shape(inputs)
            if len(input_shape) != 4:
                raise ValueError('Inputs should have rank ' +
                                 str(4) +
                                 '; Received input shape:', str(input_shape))

            if self.data_format == 'channels_first':
                batch_size, c, h, w = input_shape
                if batch_size is None:
                    batch_size = -1
                rh, rw = self.size
                oh, ow = h * rh, w * rw
                oc = c // (rh * rw)

                out = keras.backend.reshape(inputs, (batch_size, rh, rw, oc, h, w))
                out = keras.backend.permute_dimensions(out, (0, 3, 4, 1, 5, 2))
                out = keras.backend.reshape(out, (batch_size, oc, oh, ow))
                return out

            elif self.data_format == 'channels_last':
                batch_size, h, w, c = input_shape
                if batch_size is None:
                    batch_size = -1
                rh, rw = self.size
                oh, ow = h * rh, w * rw
                oc = c // (rh * rw)

                out = keras.backend.reshape(inputs, (batch_size, h, w, rh, rw, oc))
                out = keras.backend.permute_dimensions(out, (0, 1, 3, 2, 4, 5))
                out = keras.backend.reshape(out, (batch_size, oh, ow, oc))
                return out

        def compute_output_shape(self, input_shape):
            if len(input_shape) != 4:
                raise ValueError('Inputs should have rank ' +
                                 str(4) +
                                 '; Received input shape:', str(input_shape))

            if self.data_format == 'channels_first':
                height = input_shape[2] * self.size[0] if input_shape[2] is not None else None
                width = input_shape[3] * self.size[1] if input_shape[3] is not None else None
                channels = input_shape[1] // self.size[0] // self.size[1]

                if channels * self.size[0] * self.size[1] != input_shape[1]:
                    raise ValueError('channels of input and size are incompatible')

                return (input_shape[0],
                        channels,
                        height,
                        width)

            elif self.data_format == 'channels_last':
                height = input_shape[1] * self.size[0] if input_shape[1] is not None else None
                width = input_shape[2] * self.size[1] if input_shape[2] is not None else None
                channels = input_shape[3] // self.size[0] // self.size[1]

                if channels * self.size[0] * self.size[1] != input_shape[3]:
                    raise ValueError('channels of input and size are incompatible')

                return (input_shape[0],
                        height,
                        width,
                        channels)

        def get_config(self):
            config = {'size': self.size,
                      'data_format': self.data_format}
            base_config = super(PixelShuffler, self).get_config()

            return dict(list(base_config.items()) + list(config.items()))
    nnlib.PixelShuffler = PixelShuffler
@staticmethod
def import_keras_contrib(device_config = None):
    """Import keras_contrib once and return its import code.

    Args:
        device_config: unused; kept for a signature parallel to the other
            ``import_*`` methods.

    Returns:
        The compiled ``code_import_keras_contrib_string`` code object, meant
        to be passed to ``exec`` by the caller.
    """
    if nnlib.keras_contrib is not None:
        return nnlib.code_import_keras_contrib

    import keras_contrib as keras_contrib_
    nnlib.keras_contrib = keras_contrib_
    nnlib.__initialize_keras_contrib_functions()
    nnlib.code_import_keras_contrib = compile (nnlib.code_import_keras_contrib_string,'','exec')
    # The first call used to fall through and return None.
    return nnlib.code_import_keras_contrib
@staticmethod
def __initialize_keras_contrib_functions():
    """Hook called after keras_contrib is imported; currently does nothing."""
@staticmethod
def import_dlib( device_config = None):
    """Import dlib once, select its CUDA device and return its import code.

    Args:
        device_config: device configuration; when None the stored
            ``nnlib.prefer_DeviceConfig`` is used.

    Returns:
        The compiled ``code_import_dlib_string`` code object, meant to be
        passed to ``exec`` by the caller.
    """
    if nnlib.dlib is not None:
        return nnlib.code_import_dlib

    # The declared default (None) used to be dereferenced directly.
    if device_config is None:
        device_config = nnlib.prefer_DeviceConfig

    import dlib as dlib_
    nnlib.dlib = dlib_
    if not device_config.cpu_only and len(device_config.gpu_idxs) > 0:
        nnlib.dlib.cuda.set_device(device_config.gpu_idxs[0])

    nnlib.code_import_dlib = compile (nnlib.code_import_dlib_string,'','exec')
    # The first call used to fall through and return None.
    return nnlib.code_import_dlib
@staticmethod
def import_all(device_config = None):
    """Import tf, keras and keras_contrib once and return the combined import code.

    Args:
        device_config: forwarded to the individual ``import_*`` methods.

    Returns:
        Code object compiled from the tf, keras, keras_contrib and
        model-builder import snippets, meant to be passed to ``exec``.
    """
    # Already built: hand back the cached code object.
    if nnlib.code_import_all is not None:
        return nnlib.code_import_all

    nnlib.import_tf(device_config)
    nnlib.import_keras(device_config)
    nnlib.import_keras_contrib(device_config)

    # The contrib and builder snippets are joined without a separator,
    # exactly as in the original concatenation.
    combined_source = '\n'.join([
        nnlib.code_import_tf_string,
        nnlib.code_import_keras_string,
        nnlib.code_import_keras_contrib_string + nnlib.code_import_all_string,
    ])
    nnlib.code_import_all = compile (combined_source,'','exec')
    nnlib.__initialize_all_functions()

    return nnlib.code_import_all
# Defines the model-builder factories (ResNet, UNet, UNetTemporalPredictor,
# NLayerDiscriminator) and stores them on nnlib. Each factory returns a
# function that takes an input tensor and returns the output tensor.
# Every factory starts with exec(nnlib.import_all(), locals(), globals());
# with that argument order the names bound by the import code are written
# into the module globals.
@staticmethod
def __initialize_all_functions():
# ResNet(...) -> func(input): 7x7 stem, two stride-2 convs, n_blocks residual
# blocks, two PixelShuffler upsampling stages, 7x7 output conv + tanh.
# use_batch_norm selects BatchNormalization, otherwise InstanceNormalization.
def ResNet(output_nc, use_batch_norm, ngf=64, n_blocks=6, use_dropout=False):
exec (nnlib.import_all(), locals(), globals())
# NOTE(review): use_bias is assigned here but never passed to the Conv2D
# wrapper below, whose own default is use_bias=True - confirm intent.
if not use_batch_norm:
use_bias = True
def XNormalization(x):
return InstanceNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)#GroupNormalization (axis=3, groups=K.int_shape (x)[3] // 4, gamma_initializer=RandomNormal(1., 0.02))(x)
else:
use_bias = False
def XNormalization(x):
return BatchNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)
# Local Conv2D wrapper: same parameters as the keras layer, but the default
# kernel_initializer is RandomNormal(0, 0.02).
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
return keras.layers.convolutional.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
def func(input):
# Residual block: two reflection-padded 3x3 convs with normalisation and
# ReLU after each, then Add with the block input.
def ResnetBlock(dim):
def func(input):
x = input
x = ReflectionPadding2D((1,1))(x)
x = Conv2D(dim, 3, 1, padding='valid')(x)
x = XNormalization(x)
x = ReLU()(x)
if use_dropout:
x = Dropout(0.5)(x)
x = ReflectionPadding2D((1,1))(x)
x = Conv2D(dim, 3, 1, padding='valid')(x)
x = XNormalization(x)
x = ReLU()(x)
return Add()([x,input])
return func
x = input
x = ReflectionPadding2D((3,3))(x)
x = Conv2D(ngf, 7, 1, 'valid')(x)
x = ReLU()(XNormalization(Conv2D(ngf*2, 4, 2, 'same')(x)))
x = ReLU()(XNormalization(Conv2D(ngf*4, 4, 2, 'same')(x)))
for i in range(n_blocks):
x = ResnetBlock(ngf*4)(x)
# Upsampling: conv to 4x the target filter count, then PixelShuffler.
x = ReLU()(XNormalization(PixelShuffler()(Conv2D(ngf*2 *4, 3, 1, 'same')(x))))
x = ReLU()(XNormalization(PixelShuffler()(Conv2D(ngf *4, 3, 1, 'same')(x))))
x = ReflectionPadding2D((3,3))(x)
x = Conv2D(output_nc, 7, 1, 'valid')(x)
x = tanh(x)
return x
return func
nnlib.ResNet = ResNet
# Defines the Unet generator.
# |num_downs|: number of downsamplings in UNet. For example,
# if |num_downs| == 7, image of size 128x128 will become of size 1x1
# at the bottleneck
# NOTE(review): num_downs is not read in the active code below (the loop
# that used it is commented out); five nested blocks are always built.
def UNet(output_nc, use_batch_norm, num_downs, ngf=64, use_dropout=False):
exec (nnlib.import_all(), locals(), globals())
# NOTE(review): as in ResNet, use_bias is assigned but not passed on.
if not use_batch_norm:
use_bias = True
def XNormalization(x):
return InstanceNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)#GroupNormalization (axis=3, groups=K.int_shape (x)[3] // 4, gamma_initializer=RandomNormal(1., 0.02))(x)
else:
use_bias = False
def XNormalization(x):
return BatchNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)
# Local layer wrappers. Conv2D defaults kernel_initializer to
# RandomNormal(0, 0.02); Conv2DTranspose keeps 'glorot_uniform'.
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
return keras.layers.convolutional.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
def Conv2DTranspose(filters, kernel_size, strides=(1, 1), padding='valid', output_padding=None, data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer='glorot_uniform', bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
return keras.layers.Conv2DTranspose(filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, output_padding=output_padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint)
# One U-Net level: stride-2 conv down, optional sub_model, stride-2
# transposed conv up. The outermost level ends in tanh; other levels
# concatenate a skip connection on axis 3.
# NOTE(review): indentation is missing in this view, so which `if` the
# Concatenate and the `else` belong to cannot be read off here - confirm
# against the original file.
def UNetSkipConnection(outer_nc, inner_nc, sub_model=None, outermost=False, innermost=False, use_dropout=False):
def func(inp):
x = inp
x = Conv2D(inner_nc, 4, 2, 'valid')(ReflectionPadding2D( (1,1) )(x))
x = XNormalization(x)
x = ReLU()(x)
if not innermost:
x = sub_model(x)
if not outermost:
x = Conv2DTranspose(outer_nc, 3, 2, 'same')(x)
x = XNormalization(x)
x = ReLU()(x)
if not innermost:
if use_dropout:
x = Dropout(0.5)(x)
x = Concatenate(axis=3)([inp, x])
else:
x = Conv2DTranspose(outer_nc, 3, 2, 'same')(x)
x = tanh(x)
return x
return func
# Blocks are built from the innermost level outwards.
def func(input):
unet_block = UNetSkipConnection(ngf * 8, ngf * 8, sub_model=None, innermost=True)
#for i in range(num_downs - 5):
# unet_block = UNetSkipConnection(ngf * 8, ngf * 8, sub_model=unet_block, use_dropout=use_dropout)
unet_block = UNetSkipConnection(ngf * 4 , ngf * 8, sub_model=unet_block)
unet_block = UNetSkipConnection(ngf * 2 , ngf * 4, sub_model=unet_block)
unet_block = UNetSkipConnection(ngf , ngf * 2, sub_model=unet_block)
unet_block = UNetSkipConnection(output_nc, ngf , sub_model=unet_block, outermost=True)
return unet_block(input)
return func
nnlib.UNet = UNet
#predicts based on two past_image_tensors
# The two input tensors are concatenated on axis 3 and passed through UNet.
# NOTE(review): UNet is called with a literal 3 as its first argument, so
# the output_nc parameter is not used - confirm intent.
def UNetTemporalPredictor(output_nc, use_batch_norm, num_downs, ngf=64, use_dropout=False):
exec (nnlib.import_all(), locals(), globals())
def func(inputs):
past_2_image_tensor, past_1_image_tensor = inputs
x = Concatenate(axis=3)([ past_2_image_tensor, past_1_image_tensor ])
x = UNet(3, use_batch_norm, num_downs=num_downs, ngf=ngf, use_dropout=use_dropout) (x)
return x
return func
nnlib.UNetTemporalPredictor = UNetTemporalPredictor
# NLayerDiscriminator(...) -> func(input): zero-padded 4x4 convs; the first
# n_layers convs use stride 2 with filter multiplier min(2**i, 8), followed
# by a stride-1 conv and a final single-filter stride-1 conv.
def NLayerDiscriminator(use_batch_norm, ndf=64, n_layers=3):
exec (nnlib.import_all(), locals(), globals())
# NOTE(review): as in ResNet, use_bias is assigned but not passed on.
if not use_batch_norm:
use_bias = True
def XNormalization(x):
return InstanceNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)#GroupNormalization (axis=3, groups=K.int_shape (x)[3] // 4, gamma_initializer=RandomNormal(1., 0.02))(x)
else:
use_bias = False
def XNormalization(x):
return BatchNormalization (axis=3, gamma_initializer=RandomNormal(1., 0.02))(x)
def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
return keras.layers.convolutional.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
def func(input):
x = input
# First stage has no normalisation.
x = ZeroPadding2D((1,1))(x)
x = Conv2D( ndf, 4, 2, 'valid')(x)
x = LeakyReLU(0.2)(x)
for i in range(1, n_layers):
x = ZeroPadding2D((1,1))(x)
x = Conv2D( ndf * min(2 ** i, 8), 4, 2, 'valid')(x)
x = XNormalization(x)
x = LeakyReLU(0.2)(x)
x = ZeroPadding2D((1,1))(x)
x = Conv2D( ndf * min(2 ** n_layers, 8), 4, 1, 'valid')(x)
x = XNormalization(x)
x = LeakyReLU(0.2)(x)
x = ZeroPadding2D((1,1))(x)
return Conv2D( 1, 4, 1, 'valid')(x)
return func
nnlib.NLayerDiscriminator = NLayerDiscriminator
@staticmethod
def finalize_all():
    """Release the keras/tensorflow state held by nnlib.

    Clears the keras session, closes the TF session and resets the module
    references and cached import code, so that a later ``import_*`` call
    goes through the full initialisation again. dlib state is left alone.
    """
    if nnlib.keras_contrib is not None:
        nnlib.keras_contrib = None

    if nnlib.keras is not None:
        nnlib.keras.backend.clear_session()
        nnlib.keras = None

    if nnlib.tf is not None:
        # tf_sess is still None if session creation failed in import_tf.
        if nnlib.tf_sess is not None:
            nnlib.tf_sess.close()
        nnlib.tf_sess = None
        nnlib.tf = None

    # import_all() only checks code_import_all; leaving it cached would
    # make it return stale code without re-importing anything.
    nnlib.code_import_tf = None
    nnlib.code_import_keras = None
    nnlib.code_import_keras_contrib = None
    nnlib.code_import_all = None

View file

@ -5,6 +5,7 @@ import cv2
import localization import localization
from scipy.spatial import Delaunay from scipy.spatial import Delaunay
from PIL import Image, ImageDraw, ImageFont from PIL import Image, ImageDraw, ImageFont
from nnlib import nnlib
def channel_hist_match(source, template, hist_match_threshold=255, mask=None): def channel_hist_match(source, template, hist_match_threshold=255, mask=None):
# Code borrowed from: # Code borrowed from:
@ -276,103 +277,21 @@ def reduce_colors (img_bgr, n_colors):
class TFLabConverter(): class TFLabConverter():
def __init__(self):
exec (nnlib.import_tf(), locals(), globals())
self.tf_sess = tf_sess
def __init__(self,):
import gpufmkmgr
self.tf_module = gpufmkmgr.import_tf() self.bgr_input_tensor = tf.placeholder("float", [None, None, 3])
self.tf_session = gpufmkmgr.get_tf_session() self.lab_input_tensor = tf.placeholder("float", [None, None, 3])
self.bgr_input_tensor = self.tf_module.placeholder("float", [None, None, 3]) self.lab_output_tensor = tf_rgb_to_lab()(self.bgr_input_tensor)
self.lab_input_tensor = self.tf_module.placeholder("float", [None, None, 3]) self.bgr_output_tensor = tf_lab_to_rgb()(self.lab_input_tensor)
self.lab_output_tensor = self.rgb_to_lab(self.tf_module, self.bgr_input_tensor)
self.bgr_output_tensor = self.lab_to_rgb(self.tf_module, self.lab_input_tensor)
def bgr2lab(self, bgr): def bgr2lab(self, bgr):
return self.tf_session.run(self.lab_output_tensor, feed_dict={self.bgr_input_tensor: bgr}) return self.tf_sess.run(self.lab_output_tensor, feed_dict={self.bgr_input_tensor: bgr})
def lab2bgr(self, lab): def lab2bgr(self, lab):
return self.tf_session.run(self.bgr_output_tensor, feed_dict={self.lab_input_tensor: lab}) return self.tf_sess.run(self.bgr_output_tensor, feed_dict={self.lab_input_tensor: lab})
# Build TensorFlow ops that convert `rgb_input` from sRGB to CIELAB.
# `tf` is the TensorFlow module to build the ops with. The result is a
# tensor with the same dynamic shape as `rgb_input`.
# assumes channel values are scaled to [0, 1] (the 0.04045 threshold
# suggests it) -- TODO confirm against callers.
# NOTE(review): __init__ passes a tensor named bgr_input_tensor here while
# the matrix rows below are labeled R, G, B -- confirm the channel order.
def rgb_to_lab(self, tf, rgb_input):
with tf.name_scope("rgb_to_lab"):
# Flatten to one row per pixel, three channels per row.
srgb_pixels = tf.reshape(rgb_input, [-1, 3])
with tf.name_scope("srgb_to_xyz"):
# Complementary 0/1 float masks choose, per value, between the linear
# segment and the power segment of the transfer curve below.
linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32)
exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32)
rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask
rgb_to_xyz = tf.constant([
# X Y Z
[0.412453, 0.212671, 0.019334], # R
[0.357580, 0.715160, 0.119193], # G
[0.180423, 0.072169, 0.950227], # B
])
xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz)
# https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
with tf.name_scope("xyz_to_cielab"):
# convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn)
# normalize for D65 white point
xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754])
# Relies on true division; with Python 2 integer division 6/29 would be 0.
epsilon = 6/29
# Piecewise f(t): linear segment for t <= epsilon**3, cube root above it.
linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32)
exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32)
fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask
# convert to lab
# The matrix plus the [-16, 0, 0] offset give
# L = 116*fy - 16, a = 500*(fx - fy), b = 200*(fy - fz).
fxfyfz_to_lab = tf.constant([
# l a b
[ 0.0, 500.0, 0.0], # fx
[116.0, -500.0, 200.0], # fy
[ 0.0, 0.0, -200.0], # fz
])
lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0])
# Restore the caller's original tensor shape.
return tf.reshape(lab_pixels, tf.shape(rgb_input))
# Build TensorFlow ops that convert `lab` from CIELAB back to sRGB.
# `tf` is the TensorFlow module to build the ops with. The result is a
# tensor with the same dynamic shape as `lab`; linear RGB is clipped to
# [0, 1] before the final transfer curve is applied.
def lab_to_rgb(self, tf, lab):
with tf.name_scope("lab_to_rgb"):
# Flatten to one row per pixel, three channels per row.
lab_pixels = tf.reshape(lab, [-1, 3])
# https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions
with tf.name_scope("cielab_to_xyz"):
# convert to fxfyfz
# With the [16, 0, 0] offset the matrix gives
# fy = (L + 16)/116, fx = fy + a/500, fz = fy - b/200.
lab_to_fxfyfz = tf.constant([
# fx fy fz
[1/116.0, 1/116.0, 1/116.0], # l
[1/500.0, 0.0, 0.0], # a
[ 0.0, 0.0, -1/200.0], # b
])
fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz)
# convert to xyz
# Relies on true division; with Python 2 integer division 6/29 would be 0.
epsilon = 6/29
# Inverse of the piecewise f(t): linear segment at or below epsilon,
# cube above it. The two 0/1 masks are complementary.
linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32)
exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32)
xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask
# denormalize for D65 white point
xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754])
with tf.name_scope("xyz_to_srgb"):
xyz_to_rgb = tf.constant([
# r g b
[ 3.2404542, -0.9692660, 0.0556434], # x
[-1.5371385, 1.8760108, -0.2040259], # y
[-0.4985314, 0.0415560, 1.0572252], # z
])
rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb)
# avoid a slightly negative number messing up the conversion
rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0)
# Linear-to-sRGB transfer curve: linear segment at or below 0.0031308,
# power segment above it.
linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32)
exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32)
srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask
# Restore the caller's original tensor shape.
return tf.reshape(srgb_pixels, tf.shape(lab))