test: try as layer

commit 9b98607b33
jh 2021-03-12 17:52:57 -08:00
3 changed files with 64 additions and 39 deletions


@@ -0,0 +1,25 @@
from core.leras import nn
tf = nn.tf

class MsSsim(nn.LayerBase):
    default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)

    def __init__(self, resolution, kernel_size=11, **kwargs):
        # restrict ms-ssim power factors to scales whose downsampled side is still >= the kernel size
        power_factors = [p for i, p in enumerate(self.default_power_factors) if resolution//(2**i) >= kernel_size]
        # renormalize the power factors if scales were dropped, so they sum to 1.0 again
        if sum(power_factors) < 1.0:
            power_factors = [x/sum(power_factors) for x in power_factors]
        self.power_factors = power_factors
        self.kernel_size = kernel_size
        super().__init__(**kwargs)

    def __call__(self, y_true, y_pred, max_val=1.0):
        # transpose images from NCHW to NHWC, since tf.image.ssim_multiscale expects channels-last
        y_true_t = tf.transpose(y_true, [0, 2, 3, 1])
        y_pred_t = tf.transpose(y_pred, [0, 2, 3, 1])
        # compute on the transposed tensors, with the configured kernel size as the SSIM filter
        loss = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors, filter_size=self.kernel_size)
        return (1.0 - loss) / 2.0

nn.MsSsim = MsSsim
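
As a quick standalone check of the constructor's power-factor logic (plain Python, not part of the commit; resolution and kernel size are assumed values): at 128px the fifth scale drops to 8px, below the 11px kernel, so only four factors survive and get renormalized.

# Sketch of the MsSsim constructor's truncation/renormalization, with assumed inputs.
default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
resolution, kernel_size = 128, 11

# keep only scales whose downsampled side is still at least the SSIM kernel size
power_factors = [p for i, p in enumerate(default_power_factors)
                 if resolution // (2 ** i) >= kernel_size]   # scales 128, 64, 32, 16 survive

# the four survivors sum to ~0.8668, so rescale them to sum to 1.0 again
if sum(power_factors) < 1.0:
    power_factors = [x / sum(power_factors) for x in power_factors]
print(power_factors)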


@@ -308,44 +308,44 @@ def dssim(img1,img2, max_val, filter_size=11, filter_sigma=1.5, k1=0.01, k2=0.03
nn.dssim = dssim

# def ms_ssim(img1, img2, resolution, kernel_size=11, k1=0.01, k2=0.03, max_value=1.0,
#             power_factors=(0.0448, 0.2856, 0.3001, 0.2363, 0.1333)):
#
#     # restrict mssim factors to those greater/equal to kernel size
#     power_factors = [power_factors[i] for i in range(len(power_factors)) if resolution//(2**i) >= kernel_size]
#
#     # normalize power factors if reduced because of size
#     if sum(power_factors) < 1.0:
#         power_factors = [x/sum(power_factors) for x in power_factors]
#
#     img_dtype = img1.dtype
#     if img_dtype != img2.dtype:
#         raise ValueError("img1.dtype != img2.dtype")
#
#     if img_dtype != tf.float32:
#         img1 = tf.cast(img1, tf.float32)
#         img2 = tf.cast(img2, tf.float32)
#
#     # Transpose images from NCHW to NHWC
#     img1_t = tf.transpose(img1, [0, 2, 3, 1])
#     img2_t = tf.transpose(img2, [0, 2, 3, 1])
#
#     def assign_device(op):
#         if op.type != 'ListDiff':
#             return '/gpu:0'
#         else:
#             return '/cpu:0'
#
#     with tf.device(assign_device):
#         ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=max_value, power_factors=power_factors,
#                                                filter_size=kernel_size, k1=k1, k2=k2)
#         ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0
#
#     if img_dtype != tf.float32:
#         ms_ssim_loss = tf.cast(ms_ssim_loss, img_dtype)
#     return ms_ssim_loss
#
# nn.ms_ssim = ms_ssim

def space_to_depth(x, size):
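
Both the retired function and the new layer transpose NCHW batches before calling tf.image.ssim_multiscale, which expects channels-last input. A minimal standalone sketch of that call (shapes and values assumed; 256px is used so all five default scales stay above the 11px filter):

import numpy as np
import tensorflow as tf

# two random NCHW batches standing in for the training tensors (assumed shapes)
img1 = tf.constant(np.random.rand(2, 3, 256, 256), dtype=tf.float32)
img2 = tf.constant(np.random.rand(2, 3, 256, 256), dtype=tf.float32)

# tf.image.ssim_multiscale wants NHWC, so move channels last before the call
img1_t = tf.transpose(img1, [0, 2, 3, 1])
img2_t = tf.transpose(img2, [0, 2, 3, 1])

ms_ssim_val = tf.image.ssim_multiscale(img1_t, img2_t, max_val=1.0)  # shape [2], one score per sample
ms_ssim_loss = (1.0 - ms_ssim_val) / 2.0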


@@ -426,7 +426,7 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
     gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)
     if self.options['ms_ssim_loss']:
-        gpu_src_loss = tf.reduce_mean ( 10*nn.ms_ssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, resolution))
+        gpu_src_loss = tf.reduce_mean ( 10*nn.MsSsim(resolution)(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt))
     else:
         if resolution < 256:
             gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
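
On the dssim branch, the filter size scales with resolution; a quick check (plain Python, values assumed) shows that int(resolution/11.6) reproduces the canonical 11x11 SSIM window at 128px:

# int(resolution/11.6) scales the SSIM window with image size;
# at 128px it yields the standard 11x11 filter (assumed rationale)
for resolution in (64, 128, 192, 256):
    print(resolution, int(resolution / 11.6))   # 64->5, 128->11, 192->16, 256->22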