SAE: added new archi 'vg'

commit f0a20b46d3 (parent d66829aae4)
5 changed files with 378 additions and 119 deletions

__dev/test.py (120 changes)
@@ -305,20 +305,98 @@ def get_transform_mat (image_landmarks, output_size, scale=1.0):
     # alignments.append (dflimg.get_source_landmarks())

 import mathlib

 def main():

+    def f ( *args, asd=True, **kwargs ):
+        import code
+        code.interact(local=dict(globals(), **locals()))
+
+    f( 1, asd=True, bg=0)
+
     from nnlib import nnlib
-    exec( nnlib.import_all(), locals(), globals() )
+    exec( nnlib.import_all( device_config=nnlib.device.Config() ), locals(), globals() )
     PMLTile = nnlib.PMLTile
     PMLK = nnlib.PMLK

+    class DSSIMObjective:
+        """Computes DSSIM index between img1 and img2.
+        This function is based on the standard SSIM implementation from:
+        Wang, Z., Bovik, A. C., Sheikh, H. R., & Simoncelli, E. P. (2004).
+        """
+        def __init__(self, k1=0.01, k2=0.03, max_value=1.0):
+            self.__name__ = 'DSSIMObjective'
+            self.k1 = k1
+            self.k2 = k2
+            self.max_value = max_value
+            self.c1 = (self.k1 * self.max_value) ** 2
+            self.c2 = (self.k2 * self.max_value) ** 2
+            self.dim_ordering = K.image_data_format()
+            self.backend = K.backend()
+
+        def __int_shape(self, x):
+            return K.int_shape(x) if self.backend == 'tensorflow' else K.shape(x)
+
+        def __call__(self, y_true, y_pred):
+            ch = K.shape(y_pred)[-1]
+
+            def softmax(x, axis=-1):
+                y = np.exp(x - np.max(x, axis, keepdims=True))
+                return y / np.sum(y, axis, keepdims=True)
+
+            def _fspecial_gauss(size, sigma):
+                #Function to mimic the 'fspecial' gaussian MATLAB function.
+                coords = np.arange(0, size, dtype=K.floatx())
+                coords -= (size - 1 ) / 2.0
+                g = coords**2
+                g *= ( -0.5 / (sigma**2) )
+                g = np.reshape (g, (1,-1)) + np.reshape(g, (-1,1) )
+                g = np.reshape (g, (1,-1))
+                g = softmax(g)
+                g = K.constant ( np.reshape (g, (size, size, 1, 1)) )
+                g = K.tile (g, (1,1,ch,1))
+                return g
+
+            kernel = _fspecial_gauss(11,1.5)
+
+            def reducer(x):
+                return K.depthwise_conv2d(x, kernel, strides=(1, 1), padding='valid')
+
+            c1 = (self.k1 * self.max_value) ** 2
+            c2 = (self.k2 * self.max_value) ** 2
+
+            mean0 = reducer(y_true)
+            mean1 = reducer(y_pred)
+            num0 = mean0 * mean1 * 2.0
+            den0 = K.square(mean0) + K.square(mean1)
+            luminance = (num0 + c1) / (den0 + c1)
+
+            num1 = reducer(y_true * y_pred) * 2.0
+            den1 = reducer(K.square(y_true) + K.square(y_pred))
+            c2 *= 1.0 #compensation factor
+            cs = (num1 - num0 + c2) / (den1 - den0 + c2)
+
+            ssim_val = K.mean(luminance * cs, axis=(-3, -2) )
+            return K.mean( (1.0 - ssim_val ) / 2.0 )
+
     image = cv2.imread('D:\\DeepFaceLab\\test\\00000.png').astype(np.float32) / 255.0
-    image = cv2.resize ( image, (128,128) )
-
-    image = cv2.cvtColor (image, cv2.COLOR_BGR2GRAY)
-    image = np.expand_dims (image, -1)
     image = np.expand_dims (image, 0)
     image_shape = image.shape

-    t = K.placeholder ( image_shape ) #K.constant ( np.ones ( (10,) ) )
+    image2 = cv2.imread('D:\\DeepFaceLab\\test\\00001.png').astype(np.float32) / 255.0
+    image2 = np.expand_dims (image2, 0)
+    image2_shape = image2.shape
+
+    #image = np.random.uniform ( size=(1,256,256,3) )
+    #image2 = np.random.uniform ( size=(1,256,256,3) )
+
+    t1 = K.placeholder ( (None,) + image_shape[1:], name="t1" )
+    t2 = K.placeholder ( (None,) + image_shape[1:], name="t2" )
+
+    l1_t = DSSIMObjective() (t1,t2 )
+    l1, = K.function([t1, t2],[l1_t]) ([image, image2])
+
+    print (l1)
+
     import code
     code.interact(local=dict(globals(), **locals()))
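Aside (not part of the commit): _fspecial_gauss builds the standard 11x11 SSIM window by taking a softmax over the log of a Gaussian; since softmax(log w) = w / sum(w), the result is exactly the unit-sum Gaussian window. A minimal NumPy check of that identity:

    import numpy as np

    def fspecial_gauss_np(size=11, sigma=1.5):
        # same construction as _fspecial_gauss above, in plain NumPy
        coords = np.arange(size, dtype=np.float64) - (size - 1) / 2.0
        g = -0.5 * coords**2 / sigma**2
        g = np.reshape(g, (1, -1)) + np.reshape(g, (-1, 1))  # log of the 2D window
        g = np.exp(g - g.max())                              # softmax numerator
        return g / g.sum()

    k = fspecial_gauss_np()
    x = np.arange(11, dtype=np.float64) - 5.0
    direct = np.exp(-0.5 * (x[:, None]**2 + x[None, :]**2) / 1.5**2)
    assert np.allclose(k, direct / direct.sum())  # identical unit-sum windows

The depthwise_conv2d reducer then applies this window per channel, which is why the kernel is tiled to shape (11, 11, ch, 1).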
@@ -1279,4 +1357,34 @@ O[i0, i1, i2, i3: (1 + 1 - 1)/1, (64 + 1 - 1)/1, (64 + 2 - 1)/2, (1 + 1 - 1)/1]


 if __name__ == "__main__":
+    #os.environ["KERAS_BACKEND"] = "plaidml.keras.backend"
+    #os.environ["PLAIDML_DEVICE_IDS"] = "opencl_nvidia_geforce_gtx_1060_6gb.0"
+    #import keras
+    #K = keras.backend
+    #
+    #image = np.random.uniform ( size=(1,256,256,3) )
+    #image2 = np.random.uniform ( size=(1,256,256,3) )
+    #
+    #y_true = K.placeholder ( (None,) + image.shape[1:] )
+    #y_pred = K.placeholder ( (None,) + image2.shape[1:] )
+    #
+    #def reducer(x):
+    #    shape = K.shape(x)
+    #    x = K.reshape(x, (-1, shape[-3] , shape[-2], shape[-1]) )
+    #    y = K.depthwise_conv2d(x, K.constant(np.ones( (11,11,3,1) )), strides=(1, 1), padding='valid' )
+    #    y_shape = K.shape(y)
+    #    return K.reshape(y, (shape[0], y_shape[1], y_shape[2], y_shape[3] ) )
+    #
+    #mean0 = reducer(y_true)
+    #mean1 = reducer(y_pred)
+    #luminance = mean0 * mean1
+    #cs = y_true * y_pred
+    #
+    #result = K.function([y_true, y_pred],[luminance, cs]) ([image, image2])
+    #
+    #print (result)
+    #import code
+    #code.interact(local=dict(globals(), **locals()))
+
+
     main()
mainscripts/Trainer.py

@@ -12,7 +12,7 @@ from utils import image_utils
 import cv2
 import models

-def trainerThread (input_queue, output_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name, save_interval_min=10, debug=False, **in_options):
+def trainerThread (input_queue, output_queue, training_data_src_dir, training_data_dst_dir, model_path, model_name, save_interval_min=15, debug=False, **in_options):

     while True:
         try:

@@ -39,10 +39,11 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
                                 **in_options)

             is_reached_goal = model.is_reached_epoch_goal()
+            is_upd_save_time_after_train = False

             def model_save():
                 if not debug and not is_reached_goal:
                     model.save()
+                    is_upd_save_time_after_train = True

             def send_preview():
                 if not debug:

@@ -65,11 +66,15 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
             print('Starting. Press "Enter" to stop training and save model.')

             last_save_time = time.time()

             for i in itertools.count(0,1):
                 if not debug:
                     if not is_reached_goal:
                         loss_string = model.train_one_epoch()
+                        if is_upd_save_time_after_train:
+                            #save resets plaidML programs, so upd last_save_time only after plaidML rebuild them
+                            last_save_time = time.time()

                         print (loss_string, end='\r')
                         if model.get_target_epoch() != 0 and model.is_reached_epoch_goal():
                             print ('Reached target epoch.')

@@ -78,7 +83,7 @@ def trainerThread (input_queue, output_queue, training_data_src_dir, training_da
                     print ('You can use preview now.')

                 if not is_reached_goal and (time.time() - last_save_time) >= save_interval_min*60:
                     last_save_time = time.time()
                     model_save()
                     send_preview()
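Aside (illustration only, hypothetical helper, not DeepFaceLab code): the new flag exists because model.save() tears down the compiled plaidML programs, so the time spent rebuilding them on the next training step should not count against the autosave interval. A standalone sketch of the intended pattern; note that rebinding the flag from inside the inner function needs `nonlocal` in Python 3:

    import time

    def autosave_loop_sketch(model, save_interval_min=15):
        last_save_time = time.time()
        upd_save_time_after_train = False

        def model_save():
            nonlocal upd_save_time_after_train
            model.save()
            upd_save_time_after_train = True

        while True:
            model.train_one_epoch()
            if upd_save_time_after_train:
                # saving resets the compiled plaidML programs, so restart the
                # clock only after the next training step has rebuilt them
                last_save_time = time.time()
                upd_save_time_after_train = False
            if (time.time() - last_save_time) >= save_interval_min * 60:
                last_save_time = time.time()
                model_save()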
models/ModelBase.py

@@ -327,9 +327,9 @@ class ModelBase(object):

         self.epoch += 1

-        if epoch_time >= 10000:
+        if epoch_time >= 10:
             #............."Saving...
-            loss_string = "Training [#{0:06d}][{1:.5s}s]".format ( self.epoch, '{:0.4f}'.format(epoch_time / 1000) )
+            loss_string = "Training [#{0:06d}][{1:.5s}s]".format ( self.epoch, '{:0.4f}'.format(epoch_time) )
         else:
             loss_string = "Training [#{0:06d}][{1:04d}ms]".format ( self.epoch, int(epoch_time*1000) )
         for (loss_name, loss_value) in losses:
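Aside: epoch_time is measured in seconds (the unchanged else branch multiplies by 1000 to get milliseconds), so the old ">= 10000" threshold and "/ 1000" division treated it as milliseconds and mislabeled long epochs. After the fix, epochs of 10 s or more print the seconds form and shorter ones the millisecond form:

    # Illustration only: the two loss_string formats from ModelBase above.
    print("Training [#{0:06d}][{1:.5s}s]".format(1234, '{:0.4f}'.format(12.3456)))
    # -> Training [#001234][12.34s]
    print("Training [#{0:06d}][{1:04d}ms]".format(1234, int(0.0875 * 1000)))
    # -> Training [#001234][0087ms]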
models/Model_SAE/Model.py

@@ -30,7 +30,7 @@ class SAEModel(ModelBase):
             self.options['resolution'] = input_int("Resolution (64,128 ?:help skip:128) : ", default_resolution, [64,128], help_message="More resolution requires more VRAM.")
             self.options['face_type'] = input_str ("Half or Full face? (h/f, ?:help skip:f) : ", default_face_type, ['h','f'], help_message="Half face has better resolution, but covers less area of cheeks.").lower()
             self.options['learn_mask'] = input_bool ("Learn mask? (y/n, ?:help skip:y) : ", True, help_message="Learning mask can help model to recognize face directions. Learn without mask can reduce model size, in this case converter forced to use 'not predicted mask' that is not smooth as predicted. Model with style values can be learned without mask and produce same quality result.")
-            self.options['archi'] = input_str ("AE architecture (df, liae, ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae'], help_message="DF keeps faces more natural, while LIAE can fix overly different face shapes.").lower()
+            self.options['archi'] = input_str ("AE architecture (df, liae, vg ?:help skip:%s) : " % (default_archi) , default_archi, ['df','liae','vg'], help_message="'df' keeps faces more natural. 'liae' can fix overly different face shapes. 'vg' - currently testing.").lower()
         else:
             self.options['resolution'] = self.options.get('resolution', default_resolution)
             self.options['face_type'] = self.options.get('face_type', default_face_type)

@@ -48,10 +48,14 @@ class SAEModel(ModelBase):

         if is_first_run:
             self.options['lighter_encoder'] = input_bool ("Use lightweight encoder? (y/n, ?:help skip:n) : ", False, help_message="Lightweight encoder is 35% faster, requires less VRAM, but sacrificing overall quality.")
-            self.options['multiscale_decoder'] = input_bool ("Use multiscale decoder? (y/n, ?:help skip:y) : ", True, help_message="Multiscale decoder helps to get better details.")
+
+            if self.options['archi'] != 'vg':
+                self.options['multiscale_decoder'] = input_bool ("Use multiscale decoder? (y/n, ?:help skip:y) : ", True, help_message="Multiscale decoder helps to get better details.")
         else:
             self.options['lighter_encoder'] = self.options.get('lighter_encoder', False)
-            self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', True)
+
+            if self.options['archi'] != 'vg':
+                self.options['multiscale_decoder'] = self.options.get('multiscale_decoder', True)

         default_face_style_power = 0.0
         default_bg_style_power = 0.0

@@ -74,17 +78,19 @@ class SAEModel(ModelBase):
     #override
     def onInitialize(self, **in_options):
         exec(nnlib.import_all(), locals(), globals())
+        SAEModel.initialize_nn_functions()

         self.set_vram_batch_requirements({1.5:4})

         resolution = self.options['resolution']
         ae_dims = self.options['ae_dims']
         ed_ch_dims = self.options['ed_ch_dims']
-        adapt_k_size = False
         bgr_shape = (resolution, resolution, 3)
         mask_shape = (resolution, resolution, 1)

-        self.ms_count = ms_count = 3 if self.options['multiscale_decoder'] else 1
+        self.ms_count = ms_count = 3 if (self.options['archi'] != 'vg' and self.options['multiscale_decoder']) else 1
+
+        masked_training = True

         epoch_alpha = Input( (1,) )
         warped_src = Input(bgr_shape)

@@ -101,7 +107,7 @@ class SAEModel(ModelBase):
         target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]

         if self.options['archi'] == 'liae':
-            self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
+            self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))

             enc_output_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]

@@ -143,8 +149,8 @@ class SAEModel(ModelBase):
                 pred_dst_dstm = self.decoderm(warped_dst_inter_code)
                 pred_src_dstm = self.decoderm(warped_src_dst_inter_code)

-        else:
-            self.encoder = modelify(SAEModel.DFEncFlow(resolution, adapt_k_size, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
+        elif self.options['archi'] == 'df':
+            self.encoder = modelify(SAEModel.DFEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))

             dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]

@@ -173,7 +179,39 @@ class SAEModel(ModelBase):
                 pred_src_srcm = self.decoder_srcm(warped_src_code)
                 pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                 pred_src_dstm = self.decoder_srcm(warped_dst_code)

+        elif self.options['archi'] == 'vg':
+            self.encoder = modelify(SAEModel.VGEncFlow(resolution, self.options['lighter_encoder'], ae_dims=ae_dims, ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
+
+            dec_Inputs = [ Input(K.int_shape(x)[1:]) for x in self.encoder.outputs ]
+
+            self.decoder_src = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)
+            self.decoder_dst = modelify(SAEModel.VGDecFlow (bgr_shape[2],ed_ch_dims=ed_ch_dims//2 )) (dec_Inputs)
+
+            if self.options['learn_mask']:
+                self.decoder_srcm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
+                self.decoder_dstm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)
+
+            if not self.is_first_run():
+                self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
+                self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
+                self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
+                if self.options['learn_mask']:
+                    self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
+                    self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
+
+            warped_src_code = self.encoder (warped_src)
+            warped_dst_code = self.encoder (warped_dst)
+            pred_src_src = self.decoder_src(warped_src_code)
+            pred_dst_dst = self.decoder_dst(warped_dst_code)
+            pred_src_dst = self.decoder_src(warped_dst_code)
+
+            if self.options['learn_mask']:
+                pred_src_srcm = self.decoder_srcm(warped_src_code)
+                pred_dst_dstm = self.decoder_dstm(warped_dst_code)
+                pred_src_dstm = self.decoder_srcm(warped_dst_code)

         pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]

         if self.options['learn_mask']:

@@ -193,11 +231,20 @@ class SAEModel(ModelBase):
         pred_src_src_sigm_ar = [ x + 1 for x in pred_src_src]
         pred_dst_dst_sigm_ar = [ x + 1 for x in pred_dst_dst]
         pred_src_dst_sigm_ar = [ x + 1 for x in pred_src_dst]

         target_src_masked_ar = [ target_src_sigm_ar[i]*target_srcm_sigm_ar[i] for i in range(len(target_src_sigm_ar))]
         target_dst_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]
         target_dst_anti_masked_ar = [ target_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(target_dst_sigm_ar))]

+        pred_src_src_masked_ar = [ pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] for i in range(len(pred_src_src_sigm_ar))]
+        pred_dst_dst_masked_ar = [ pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] for i in range(len(pred_dst_dst_sigm_ar))]
+
+        target_src_masked_ar_opt = target_src_masked_ar if masked_training else target_src_sigm_ar
+        target_dst_masked_ar_opt = target_dst_masked_ar if masked_training else target_dst_sigm_ar
+
+        pred_src_src_masked_ar_opt = pred_src_src_masked_ar if masked_training else pred_src_src_sigm_ar
+        pred_dst_dst_masked_ar_opt = pred_dst_dst_masked_ar if masked_training else pred_dst_dst_sigm_ar
+
         psd_target_dst_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
         psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]

@@ -215,9 +262,9 @@ class SAEModel(ModelBase):
             src_dst_mask_loss_train_weights = self.encoder.trainable_weights + self.decoder_srcm.trainable_weights + self.decoder_dstm.trainable_weights

         if not self.options['pixel_loss']:
-            src_loss_batch = sum([ ( 100*K.square( dssim(max_value=2.0)( target_src_masked_ar[i], pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ) )) for i in range(len(target_src_masked_ar)) ])
+            src_loss_batch = sum([ ( 100*K.square( dssim(max_value=2.0)( target_src_masked_ar_opt[i], pred_src_src_masked_ar_opt[i] ) )) for i in range(len(target_src_masked_ar_opt)) ])
         else:
-            src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar[i] - pred_src_src_sigm_ar[i] * target_srcm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar)) ])
+            src_loss_batch = sum([ K.mean ( 100*K.square( target_src_masked_ar_opt[i] - pred_src_src_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_src_masked_ar_opt)) ])

         src_loss = K.mean(src_loss_batch)

@@ -235,9 +282,9 @@ class SAEModel(ModelBase):
         src_loss += bg_loss

         if not self.options['pixel_loss']:
-            dst_loss_batch = sum([ ( 100*K.square(dssim(max_value=2.0)( target_dst_masked_ar[i], pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ) )) for i in range(len(target_dst_masked_ar)) ])
+            dst_loss_batch = sum([ ( 100*K.square(dssim(max_value=2.0)( target_dst_masked_ar_opt[i], pred_dst_dst_masked_ar_opt[i] ) )) for i in range(len(target_dst_masked_ar_opt)) ])
         else:
-            dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar[i] - pred_dst_dst_sigm_ar[i] * target_dstm_sigm_ar[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar)) ])
+            dst_loss_batch = sum([ K.mean ( 100*K.square( target_dst_masked_ar_opt[i] - pred_dst_dst_masked_ar_opt[i] ), axis=[1,2,3]) for i in range(len(target_dst_masked_ar_opt)) ])

         dst_loss = K.mean(dst_loss_batch)
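Aside: both loss branches now compare a pre-masked prediction against a pre-masked target through the *_masked_ar_opt arrays, and the new masked_training switch selects between masked and full-frame tensors. A one-scale NumPy sketch of the pixel-loss branch (illustration only; the array names here are hypothetical):

    import numpy as np

    def pixel_loss_term(target_masked, pred_masked):
        # mirrors K.mean(100*K.square(t - p), axis=[1,2,3]): one value per sample
        return np.mean(100.0 * np.square(target_masked - pred_masked), axis=(1, 2, 3))

    target = np.random.rand(4, 64, 64, 3).astype(np.float32)
    pred   = np.random.rand(4, 64, 64, 3).astype(np.float32)
    mask   = (np.random.rand(4, 64, 64, 1) > 0.5).astype(np.float32)

    batch_loss = pixel_loss_term(target * mask, pred * mask)  # shape (4,)
    loss = batch_loss.mean()                                  # like K.mean(src_loss_batch)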
@@ -390,28 +437,68 @@ class SAEModel(ModelBase):
                                     default_blur_mask_modifier=default_blur_mask_modifier,
                                     clip_hborder_mask_per=0.0625 if self.options['face_type'] == 'f' else 0,
                                     **in_options)

     @staticmethod
-    def LIAEEncFlow(resolution, adapt_k_size, light_enc, ed_ch_dims=42):
+    def initialize_nn_functions():
         exec (nnlib.import_all(), locals(), globals())

-        k_size = resolution // 16 + 1 if adapt_k_size else 5
-        strides = resolution // 32 if adapt_k_size else 2
+        class ResidualBlock(object):
+            def __init__(self, filters, kernel_size=3, padding='same', use_reflection_padding=False):
+                self.filters = filters
+                self.kernel_size = kernel_size
+                self.padding = padding #if not use_reflection_padding else 'valid'
+                self.use_reflection_padding = use_reflection_padding
+
+            def __call__(self, inp):
+                var_x = LeakyReLU(alpha=0.2)(inp)
+
+                #if self.use_reflection_padding:
+                #    #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
+
+                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
+                var_x = LeakyReLU(alpha=0.2)(var_x)
+
+                #if self.use_reflection_padding:
+                #    #var_x = ReflectionPadding2D(stride=1, kernel_size=kernel_size)(var_x)
+
+                var_x = Conv2D(self.filters, kernel_size=self.kernel_size, padding=self.padding, kernel_initializer=RandomNormal(0, 0.02) )(var_x)
+                var_x = Scale(gamma_init=keras.initializers.Constant(value=0.1))(var_x)
+                var_x = Add()([var_x, inp])
+                var_x = LeakyReLU(alpha=0.2)(var_x)
+                return var_x
+        SAEModel.ResidualBlock = ResidualBlock

         def downscale (dim):
             def func(x):
-                return LeakyReLU(0.1)(Conv2D(dim, k_size, strides=strides, padding='same')(x))
+                return LeakyReLU(0.1)(Conv2D(dim, kernel_size=5, strides=2, padding='same', kernel_initializer=RandomNormal(0, 0.02))(x))
             return func
+        SAEModel.downscale = downscale

         def downscale_sep (dim):
             def func(x):
-                return LeakyReLU(0.1)(SeparableConv2D(dim, k_size, strides=strides, padding='same')(x))
+                return LeakyReLU(0.1)(SeparableConv2D(dim, kernel_size=5, strides=2, padding='same', depthwise_initializer=RandomNormal(0, 0.02), pointwise_initializer=RandomNormal(0, 0.02) )(x))
             return func
+        SAEModel.downscale_sep = downscale_sep

         def upscale (dim):
             def func(x):
-                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
+                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, kernel_size=3, strides=1, padding='same', kernel_initializer=RandomNormal(0, 0.02) )(x)))
             return func
+        SAEModel.upscale = upscale

+        def to_bgr (output_nc):
+            def func(x):
+                return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh', kernel_initializer=RandomNormal(0, 0.02))(x)
+            return func
+        SAEModel.to_bgr = to_bgr
+
+    @staticmethod
+    def LIAEEncFlow(resolution, light_enc, ed_ch_dims=42):
+        exec (nnlib.import_all(), locals(), globals())
+        upscale = SAEModel.upscale
+        downscale = SAEModel.downscale
+        downscale_sep = SAEModel.downscale_sep

         def func(input):
             ed_dims = K.int_shape(input)[-1]*ed_ch_dims
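Aside: initialize_nn_functions installs the shared building blocks on the class once instead of re-declaring them inside every flow. ResidualBlock is a two-conv residual unit whose second conv is damped by the new Scale layer (gamma initialized to 0.1), so each block starts close to identity. A hedged plain-Keras sketch of the same topology (a fixed Lambda multiplier stands in for nnlib's trainable Scale):

    import numpy as np
    from keras.layers import Input, Conv2D, LeakyReLU, Add, Lambda
    from keras.models import Model

    def residual_block_sketch(x, filters):
        # same layer order as SAEModel.ResidualBlock; Lambda(0.1*t) stands in
        # for nnlib.Scale with gamma_init=Constant(0.1)
        y = LeakyReLU(alpha=0.2)(x)
        y = Conv2D(filters, kernel_size=3, padding='same')(y)
        y = LeakyReLU(alpha=0.2)(y)
        y = Conv2D(filters, kernel_size=3, padding='same')(y)
        y = Lambda(lambda t: 0.1 * t)(y)
        y = Add()([y, x])
        return LeakyReLU(alpha=0.2)(y)

    inp = Input((32, 32, 16))
    out = residual_block_sketch(inp, 16)
    model = Model(inp, out)  # output shape equals input shape: (None, 32, 32, 16)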
@@ -434,13 +521,9 @@ class SAEModel(ModelBase):
     @staticmethod
     def LIAEInterFlow(resolution, ae_dims=256):
         exec (nnlib.import_all(), locals(), globals())
+        upscale = SAEModel.upscale
         lowest_dense_res=resolution // 16

-        def upscale (dim):
-            def func(x):
-                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
-            return func
-
         def func(input):
             x = input[0]
             x = Dense(ae_dims)(x)

@@ -453,17 +536,10 @@ class SAEModel(ModelBase):
     @staticmethod
     def LIAEDecFlow(output_nc,ed_ch_dims=21, multiscale_count=1):
         exec (nnlib.import_all(), locals(), globals())
+        upscale = SAEModel.upscale
+        to_bgr = SAEModel.to_bgr
         ed_dims = output_nc * ed_ch_dims

-        def upscale (dim):
-            def func(x):
-                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
-            return func
-
-        def to_bgr ():
-            def func(x):
-                return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
-            return func
-
         def func(input):
             x = input[0]

@@ -471,45 +547,28 @@ class SAEModel(ModelBase):
             x1 = upscale(ed_dims*8)( x )

             if multiscale_count >= 3:
-                outputs += [ to_bgr() ( x1 ) ]
+                outputs += [ to_bgr(output_nc) ( x1 ) ]

             x2 = upscale(ed_dims*4)( x1 )

             if multiscale_count >= 2:
-                outputs += [ to_bgr() ( x2 ) ]
+                outputs += [ to_bgr(output_nc) ( x2 ) ]

             x3 = upscale(ed_dims*2)( x2 )

-            outputs += [ to_bgr() ( x3 ) ]
+            outputs += [ to_bgr(output_nc) ( x3 ) ]

             return outputs
         return func

     @staticmethod
-    def DFEncFlow(resolution, adapt_k_size, light_enc, ae_dims=512, ed_ch_dims=42):
+    def DFEncFlow(resolution, light_enc, ae_dims=512, ed_ch_dims=42):
         exec (nnlib.import_all(), locals(), globals())
-        k_size = resolution // 16 + 1 if adapt_k_size else 5
-        strides = resolution // 32 if adapt_k_size else 2
+        upscale = SAEModel.upscale
+        downscale = SAEModel.downscale
+        downscale_sep = SAEModel.downscale_sep
         lowest_dense_res = resolution // 16

-        def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
-            return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
-
-        def downscale (dim):
-            def func(x):
-                return LeakyReLU(0.1)(Conv2D(dim, k_size, strides=strides, padding='same')(x))
-            return func
-
-        def downscale_sep (dim):
-            def func(x):
-                return LeakyReLU(0.1)(SeparableConv2D(dim, k_size, strides=strides, padding='same')(x))
-            return func
-
-        def upscale (dim):
-            def func(x):
-                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
-            return func
-
         def func(input):
             x = input

@@ -536,20 +595,10 @@ class SAEModel(ModelBase):
     @staticmethod
     def DFDecFlow(output_nc, ed_ch_dims=21, multiscale_count=1):
         exec (nnlib.import_all(), locals(), globals())
+        upscale = SAEModel.upscale
+        to_bgr = SAEModel.to_bgr
         ed_dims = output_nc * ed_ch_dims

-        def Conv2D (filters, kernel_size, strides=(1, 1), padding='valid', data_format=None, dilation_rate=(1, 1), activation=None, use_bias=True, kernel_initializer=RandomNormal(0, 0.02), bias_initializer='zeros', kernel_regularizer=None, bias_regularizer=None, activity_regularizer=None, kernel_constraint=None, bias_constraint=None):
-            return keras.layers.Conv2D( filters=filters, kernel_size=kernel_size, strides=strides, padding=padding, data_format=data_format, dilation_rate=dilation_rate, activation=activation, use_bias=use_bias, kernel_initializer=kernel_initializer, bias_initializer=bias_initializer, kernel_regularizer=kernel_regularizer, bias_regularizer=bias_regularizer, activity_regularizer=activity_regularizer, kernel_constraint=kernel_constraint, bias_constraint=bias_constraint )
-
-        def upscale (dim):
-            def func(x):
-                return SubpixelUpscaler()(LeakyReLU(0.1)(Conv2D(dim * 4, 3, strides=1, padding='same')(x)))
-            return func
-
-        def to_bgr ():
-            def func(x):
-                return Conv2D(output_nc, kernel_size=5, padding='same', activation='tanh')(x)
-            return func
-
         def func(input):
             x = input[0]
@@ -557,18 +606,95 @@ class SAEModel(ModelBase):
             x1 = upscale(ed_dims*8)( x )

             if multiscale_count >= 3:
-                outputs += [ to_bgr() ( x1 ) ]
+                outputs += [ to_bgr(output_nc) ( x1 ) ]

             x2 = upscale(ed_dims*4)( x1 )

             if multiscale_count >= 2:
-                outputs += [ to_bgr() ( x2 ) ]
+                outputs += [ to_bgr(output_nc) ( x2 ) ]

             x3 = upscale(ed_dims*2)( x2 )

-            outputs += [ to_bgr() ( x3 ) ]
+            outputs += [ to_bgr(output_nc) ( x3 ) ]

             return outputs
+        return func
+
+
+    @staticmethod
+    def VGEncFlow(resolution, light_enc, ae_dims=512, ed_ch_dims=42):
+        exec (nnlib.import_all(), locals(), globals())
+        upscale = SAEModel.upscale
+        downscale = SAEModel.downscale
+        downscale_sep = SAEModel.downscale_sep
+        ResidualBlock = SAEModel.ResidualBlock
+        lowest_dense_res = resolution // 16
+
+        def func(input):
+            x = input
+            ed_dims = K.int_shape(input)[-1]*ed_ch_dims
+            while np.modf(ed_dims / 4)[0] != 0.0:
+                ed_dims -= 1
+
+            in_conv_filters = ed_dims if resolution <= 128 else ed_dims + (resolution//128)*ed_ch_dims
+
+            x = tmp_x = Conv2D (in_conv_filters, kernel_size=5, strides=2, padding='same') (x)
+
+            for _ in range ( 8 if light_enc else 16 ):
+                x = ResidualBlock(ed_dims)(x)
+
+            x = Add()([x, tmp_x])
+
+            x = downscale(ed_dims)(x)
+            x = SubpixelUpscaler()(x)
+
+            x = downscale(ed_dims)(x)
+            x = SubpixelUpscaler()(x)
+
+            x = downscale(ed_dims)(x)
+            if light_enc:
+                x = downscale_sep (ed_dims*2)(x)
+            else:
+                x = downscale (ed_dims*2)(x)
+
+            x = downscale(ed_dims*4)(x)
+
+            if light_enc:
+                x = downscale_sep (ed_dims*8)(x)
+            else:
+                x = downscale (ed_dims*8)(x)
+
+            x = Dense(ae_dims)(Flatten()(x))
+            x = Dense(lowest_dense_res * lowest_dense_res * ae_dims)(x)
+            x = Reshape((lowest_dense_res, lowest_dense_res, ae_dims))(x)
+            x = upscale(ae_dims)(x)
+            return x
+
+        return func
+
+    @staticmethod
+    def VGDecFlow(output_nc, ed_ch_dims=21, multiscale_count=1):
+        exec (nnlib.import_all(), locals(), globals())
+        upscale = SAEModel.upscale
+        to_bgr = SAEModel.to_bgr
+        ResidualBlock = SAEModel.ResidualBlock
+        ed_dims = output_nc * ed_ch_dims
+
+        def func(input):
+            x = input[0]
+
+            x = upscale( ed_dims*8 )(x)
+            x = ResidualBlock( ed_dims*8 )(x)
+
+            x = upscale( ed_dims*4 )(x)
+            x = ResidualBlock( ed_dims*4 )(x)
+
+            x = upscale( ed_dims*2 )(x)
+            x = ResidualBlock( ed_dims*2 )(x)
+
+            x = to_bgr(output_nc) (x)
+            return x
+
         return func
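Aside: tracing spatial sizes through VGEncFlow at resolution=128 (derived from the strides above; downscale halves the side, SubpixelUpscaler doubles it while dividing channels by four):

    # input                                128x128x3
    # Conv2D stride 2                       64x64   (in_conv_filters)
    # 16x ResidualBlock (8x if light_enc), Add with tmp_x: 64x64
    # downscale + SubpixelUpscaler          64x64   (via 32x32, channels /4)
    # downscale + SubpixelUpscaler          64x64
    # downscale(ed_dims)                    32x32
    # downscale(_sep)(ed_dims*2)            16x16
    # downscale(ed_dims*4)                   8x8
    # downscale(_sep)(ed_dims*8)             4x4
    # Flatten -> Dense(ae_dims) -> Dense(8*8*ae_dims) -> Reshape(8,8,ae_dims)
    # upscale(ae_dims)                      16x16   (encoder output)
    # VGDecFlow: upscale + ResidualBlock three times, 16 -> 32 -> 64 -> 128,
    # then to_bgr for the final tanh BGR output.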
nnlib/nnlib.py

@@ -62,6 +62,7 @@ Lambda = keras.layers.Lambda
 Add = keras.layers.Add
 Concatenate = keras.layers.Concatenate
+
 Flatten = keras.layers.Flatten
 Reshape = keras.layers.Reshape

@@ -77,9 +78,11 @@ gaussian_blur = nnlib.gaussian_blur
 style_loss = nnlib.style_loss
 dssim = nnlib.dssim

-#ReflectionPadding2D = nnlib.ReflectionPadding2D
 PixelShuffler = nnlib.PixelShuffler
 SubpixelUpscaler = nnlib.SubpixelUpscaler
+Scale = nnlib.Scale
+#ReflectionPadding2D = nnlib.ReflectionPadding2D
 #AddUniformNoise = nnlib.AddUniformNoise
 """
 code_import_keras_contrib_string = \

@@ -183,9 +186,10 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator

         if 'TF_SUPPRESS_STD' in os.environ.keys() and os.environ['TF_SUPPRESS_STD'] == '1':
             suppressor.__exit__()

-        nnlib.__initialize_keras_functions()
         nnlib.code_import_keras = compile (nnlib.code_import_keras_string,'<string>','exec')
+        nnlib.__initialize_keras_functions()

         return nnlib.code_import_keras

     @staticmethod
@@ -394,9 +398,41 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
                 return dict(list(base_config.items()) + list(config.items()))

         nnlib.PixelShuffler = PixelShuffler
         nnlib.SubpixelUpscaler = PixelShuffler
-        '''
+
+        class Scale(keras.layers.Layer):
+            """
+            GAN Custom Scal Layer
+            Code borrows from https://github.com/flyyufelix/cnn_finetune
+            """
+            def __init__(self, weights=None, axis=-1, gamma_init='zero', **kwargs):
+                self.axis = axis
+                self.gamma_init = keras.initializers.get(gamma_init)
+                self.initial_weights = weights
+                super(Scale, self).__init__(**kwargs)
+
+            def build(self, input_shape):
+                self.input_spec = [keras.engine.InputSpec(shape=input_shape)]
+
+                # Compatibility with TensorFlow >= 1.0.0
+                self.gamma = K.variable(self.gamma_init((1,)), name='{}_gamma'.format(self.name))
+                self.trainable_weights = [self.gamma]
+
+                if self.initial_weights is not None:
+                    self.set_weights(self.initial_weights)
+                    del self.initial_weights
+
+            def call(self, x, mask=None):
+                return self.gamma * x
+
+            def get_config(self):
+                config = {"axis": self.axis}
+                base_config = super(Scale, self).get_config()
+                return dict(list(base_config.items()) + list(config.items()))
+        nnlib.Scale = Scale
+
+        '''
+        not implemented in plaidML
         class ReflectionPadding2D(keras.layers.Layer):
             def __init__(self, padding=(1, 1), **kwargs):
                 self.padding = tuple(padding)
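Aside: the new Scale layer multiplies its whole input by a single trainable scalar gamma (one weight of shape (1,)). With gamma_init=Constant(0.1), as ResidualBlock uses above, a residual branch starts out contributing only 10% of its signal and learns how much to grow. A short sketch, assuming the names exec'd in by nnlib.import_all():

    # illustration only; Input/Model/keras/np come from the exec'd nnlib namespace
    inp = Input((16, 16, 8))
    out = Scale(gamma_init=keras.initializers.Constant(value=0.1))(inp)
    m = Model(inp, out)
    print(m.predict(np.ones((1, 16, 16, 8))).mean())  # ~0.1 at initialization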
@@ -410,28 +446,11 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
             def call(self, x, mask=None):
                 w_pad,h_pad = self.padding
                 return tf.pad(x, [[0,0], [h_pad,h_pad], [w_pad,w_pad], [0,0] ], 'REFLECT')
         nnlib.ReflectionPadding2D = ReflectionPadding2D
+        '''

-        class AddUniformNoise(keras.layers.Layer):
-            def __init__(self, power=1.0, minval=-1.0, maxval=1.0, **kwargs):
-                super(AddUniformNoise, self).__init__(**kwargs)
-                self.power = power
-                self.supports_masking = True
-                self.minval = minval
-                self.maxval = maxval
-
-            def call(self, inputs, training=None):
-                def noised():
-                    return inputs + self.power*K.random_uniform(shape=K.shape(inputs), minval=self.minval, maxval=self.maxval)
-                return K.in_train_phase(noised, inputs, training=training)
-
-            def get_config(self):
-                config = {'power': self.power, 'minval': self.minval, 'maxval': self.maxval}
-                base_config = super(AddUniformNoise, self).get_config()
-                return dict(list(base_config.items()) + list(config.items()))
-        nnlib.AddUniformNoise = AddUniformNoise
-        '''

     @staticmethod
     def import_keras_contrib(device_config = None):
         if nnlib.keras_contrib is not None:
@@ -489,6 +508,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
                 return 10*dssim() (y_true*mask, y_pred*mask)
         nnlib.DSSIMMSEMaskLoss = DSSIMMSEMaskLoss
+
         '''
         def ResNet(output_nc, use_batch_norm, ngf=64, n_blocks=6, use_dropout=False):
             exec (nnlib.import_all(), locals(), globals())