Merge pull request #129 from faceshiftlabs/feature/ms-ssim-loss-2

Feature/ms ssim loss 2
This commit is contained in:
Jeremy Hummel 2021-03-24 13:10:13 -07:00 committed by GitHub
commit 06f30b0e0d
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
6 changed files with 119 additions and 27 deletions

View file

@ -0,0 +1,30 @@
from core.leras import nn
tf = nn.tf
class MsSsim(nn.LayerBase):
default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333)
def __init__(self, resolution, kernel_size=11, **kwargs):
# restrict mssim factors to those greater/equal to kernel size
power_factors = [p for i, p in enumerate(self.default_power_factors) if resolution//(2**i) >= kernel_size]
# normalize power factors if reduced because of size
if sum(power_factors) < 1.0:
power_factors = [x/sum(power_factors) for x in power_factors]
self.power_factors = power_factors
self.kernel_size = kernel_size
super().__init__(**kwargs)
def __call__(self, y_true, y_pred, max_val):
# Transpose images from NCHW to NHWC
y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1])
y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1])
ms_ssim_val = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors, filter_size=self.kernel_size)
# ssim_multiscale returns values in range [0, 1] (where 1 is completely identical)
# subtract from 1 to get loss
return 1.0 - ms_ssim_val
nn.MsSsim = MsSsim

View file

@ -13,4 +13,5 @@ from .FRNorm2D import *
from .TLU import * from .TLU import *
from .ScaleAdd import * from .ScaleAdd import *
from .DenseNorm import * from .DenseNorm import *
from .AdaIN import * from .AdaIN import *
from .MsSsim import *

View file

@ -40,7 +40,7 @@ class nn():
conv2d_spatial_axes = None conv2d_spatial_axes = None
floatx = None floatx = None
@staticmethod @staticmethod
def initialize(device_config=None, floatx="float32", data_format="NHWC"): def initialize(device_config=None, floatx="float32", data_format="NHWC"):
@ -50,7 +50,7 @@ class nn():
nn.setCurrentDeviceConfig(device_config) nn.setCurrentDeviceConfig(device_config)
# Manipulate environment variables before import tensorflow # Manipulate environment variables before import tensorflow
if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): if 'CUDA_VISIBLE_DEVICES' in os.environ.keys():
os.environ.pop('CUDA_VISIBLE_DEVICES') os.environ.pop('CUDA_VISIBLE_DEVICES')
@ -77,13 +77,13 @@ class nn():
io.log_info("Caching GPU kernels...") io.log_info("Caching GPU kernels...")
import tensorflow import tensorflow
tf_version = getattr(tensorflow,'VERSION', None) tf_version = getattr(tensorflow,'VERSION', None)
if tf_version is None: if tf_version is None:
tf_version = tensorflow.version.GIT_VERSION tf_version = tensorflow.version.GIT_VERSION
if tf_version[0] == 'v': if tf_version[0] == 'v':
tf_version = tf_version[1:] tf_version = tf_version[1:]
if tf_version[0] == '2': if tf_version[0] == '2':
tf = tensorflow.compat.v1 tf = tensorflow.compat.v1
else: else:
@ -93,7 +93,7 @@ class nn():
# Disable tensorflow warnings # Disable tensorflow warnings
tf_logger = logging.getLogger('tensorflow') tf_logger = logging.getLogger('tensorflow')
tf_logger.setLevel(logging.ERROR) tf_logger.setLevel(logging.ERROR)
if tf_version[0] == '2': if tf_version[0] == '2':
tf.disable_v2_behavior() tf.disable_v2_behavior()
nn.tf = tf nn.tf = tf
@ -105,20 +105,20 @@ class nn():
import core.leras.optimizers import core.leras.optimizers
import core.leras.models import core.leras.models
import core.leras.archis import core.leras.archis
# Configure tensorflow session-config # Configure tensorflow session-config
if len(device_config.devices) == 0: if len(device_config.devices) == 0:
nn.tf_default_device = "/CPU:0" nn.tf_default_device = "/CPU:0"
config = tf.ConfigProto(device_count={'GPU': 0}) config = tf.ConfigProto(device_count={'GPU': 0})
else: else:
nn.tf_default_device = "/GPU:0" nn.tf_default_device = "/GPU:0"
config = tf.ConfigProto() config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices]) config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices])
config.gpu_options.force_gpu_compatible = True config.gpu_options.force_gpu_compatible = True
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
nn.tf_sess_config = config nn.tf_sess_config = config
if nn.tf_sess is None: if nn.tf_sess is None:
nn.tf_sess = tf.Session(config=nn.tf_sess_config) nn.tf_sess = tf.Session(config=nn.tf_sess_config)
@ -273,7 +273,7 @@ class nn():
@staticmethod @staticmethod
def ask_choose_device(*args, **kwargs): def ask_choose_device(*args, **kwargs):
return nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(*args,**kwargs) ) return nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(*args,**kwargs) )
def __init__ (self, devices=None): def __init__ (self, devices=None):
devices = devices or [] devices = devices or []

View file

@ -0,0 +1,42 @@
# Multiscale SSIM (MS-SSIM)
Allows you to train using the MS-SSIM (multiscale structural similarity index measure) as the main loss metric,
a perceptually more accurate measure of image quality than MSE (mean squared error).
- [DESCRIPTION](#description)
- [USAGE](#usage)
![](example.png)
## DESCRIPTION
[SSIM](https://en.wikipedia.org/wiki/Structural_similarity) is metric for comparing the perceptial quality of an image:
> SSIM is a perception-based model that considers image degradation as perceived change in structural information,
> while also incorporating important perceptual phenomena, including both luminance masking and contrast masking terms.
> [...]
> Structural information is the idea that the pixels have strong inter-dependencies especially when they are spatially
> close. These dependencies carry important information about the structure of the objects in the visual scene.
> Luminance masking is a phenomenon whereby image distortions (in this context) tend to be less visible in bright
> regions, while contrast masking is a phenomenon whereby distortions become less visible where there is significant
> activity or "texture" in the image.
The current loss metric is a combination of SSIM (structural similarity index measure) and
[MSE](https://en.wikipedia.org/wiki/Mean_squared_error) (mean squared error).
[Multiscale SSIM](https://en.wikipedia.org/wiki/Structural_similarity#Multi-Scale_SSIM) is a variant of SSIM that
improves upon SSIM by comparing the similarity at multiple scales (e.g.: full-size, half-size, 1/4 size, etc.)
By using MS-SSIM as our main loss metric, we should expect the image similarity to improve across each scale, improving
both the large scale and small scale detail of the predicted images.
Original paper: [Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik.
"Multiscale structural similarity for image quality assessment."
Signals, Systems and Computers, 2004.](https://www.cns.nyu.edu/pub/eero/wang03b.pdf)
## USAGE
```
[n] Use multiscale loss? ( y/n ?:help ) : y
```

View file

@ -19,4 +19,7 @@ maintaining the same `C*` (chroma, relative saturation).
## USAGE ## USAGE
`[n] Random color ( y/n ?:help ) : y` ```
[n] Random color ( y/n ?:help ) : y
```

View file

@ -53,6 +53,8 @@ class SAEHDModel(ModelBase):
lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp
default_lr_dropout = self.options['lr_dropout'] = lr_dropout default_lr_dropout = self.options['lr_dropout'] = lr_dropout
default_ms_ssim_loss = self.options['ms_ssim_loss'] = self.load_or_def_option('ms_ssim_loss', False)
default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True)
default_background_power = self.options['background_power'] = self.load_or_def_option('background_power', 0.0) default_background_power = self.options['background_power'] = self.load_or_def_option('background_power', 0.0)
default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0) default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0)
@ -152,6 +154,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.")
self.options['ms_ssim_loss'] = io.input_bool("Use multiscale loss?", default_ms_ssim_loss, help_message="Use Multiscale structural similarity for image quality assessment.")
self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.")
self.options['gan_version'] = np.clip (io.input_int("GAN version", default_gan_version, add_info="2 or 3", help_message="Choose GAN version (v2: 7/16/2020, v3: 1/3/2021):"), 2, 3) self.options['gan_version'] = np.clip (io.input_int("GAN version", default_gan_version, add_info="2 or 3", help_message="Choose GAN version (v2: 7/16/2020, v3: 1/3/2021):"), 2, 3)
@ -450,11 +454,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur
gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur) gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur)
if resolution < 256: if self.options['ms_ssim_loss']:
gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss = 10 * nn.MsSsim(resolution)(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0)
else: else:
gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) if resolution < 256:
gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
else:
gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3])
if eyes_prio or mouth_prio: if eyes_prio or mouth_prio:
@ -471,11 +478,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.options['background_power'] > 0: if self.options['background_power'] > 0:
bg_factor = self.options['background_power'] bg_factor = self.options['background_power']
if resolution < 256: if self.options['ms_ssim_loss']:
gpu_src_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss = 10 * nn.MsSsim(resolution)(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0)
else: else:
gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) if resolution < 256:
gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_src_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
else:
gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_src_loss += bg_factor * tf.reduce_mean ( 10*tf.square ( gpu_target_src_anti_masked - gpu_pred_src_src_anti_masked ), axis=[1,2,3]) gpu_src_loss += bg_factor * tf.reduce_mean ( 10*tf.square ( gpu_target_src_anti_masked - gpu_pred_src_src_anti_masked ), axis=[1,2,3])
face_style_power = self.options['face_style_power'] / 100.0 face_style_power = self.options['face_style_power'] / 100.0
@ -487,11 +497,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_target_dst_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_target_dst_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] ) gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] )
if resolution < 256: if self.options['ms_ssim_loss']:
gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) gpu_dst_loss = 10 * nn.MsSsim(resolution)(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0)
else: else:
gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) if resolution < 256:
gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
else:
gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3])
@ -507,11 +520,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ...
if self.options['background_power'] > 0: if self.options['background_power'] > 0:
bg_factor = self.options['background_power'] bg_factor = self.options['background_power']
if resolution < 256: if self.options['ms_ssim_loss']:
gpu_dst_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss = 10 * nn.MsSsim(resolution)(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0)
else: else:
gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) if resolution < 256:
gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_dst_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
else:
gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1])
gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1])
gpu_dst_loss += bg_factor * tf.reduce_mean ( 10*tf.square ( gpu_target_dst_anti_masked - gpu_pred_dst_dst_anti_masked ), axis=[1,2,3]) gpu_dst_loss += bg_factor * tf.reduce_mean ( 10*tf.square ( gpu_target_dst_anti_masked - gpu_pred_dst_dst_anti_masked ), axis=[1,2,3])
gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] ) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )