diff --git a/core/leras/layers/MsSsim.py b/core/leras/layers/MsSsim.py new file mode 100644 index 0000000..d45a599 --- /dev/null +++ b/core/leras/layers/MsSsim.py @@ -0,0 +1,30 @@ +from core.leras import nn +tf = nn.tf + + +class MsSsim(nn.LayerBase): + default_power_factors = (0.0448, 0.2856, 0.3001, 0.2363, 0.1333) + + def __init__(self, resolution, kernel_size=11, **kwargs): + # restrict mssim factors to those greater/equal to kernel size + power_factors = [p for i, p in enumerate(self.default_power_factors) if resolution//(2**i) >= kernel_size] + # normalize power factors if reduced because of size + if sum(power_factors) < 1.0: + power_factors = [x/sum(power_factors) for x in power_factors] + self.power_factors = power_factors + self.kernel_size = kernel_size + + super().__init__(**kwargs) + + def __call__(self, y_true, y_pred, max_val): + # Transpose images from NCHW to NHWC + y_true_t = tf.transpose(tf.cast(y_true, tf.float32), [0, 2, 3, 1]) + y_pred_t = tf.transpose(tf.cast(y_pred, tf.float32), [0, 2, 3, 1]) + + ms_ssim_val = tf.image.ssim_multiscale(y_true_t, y_pred_t, max_val, power_factors=self.power_factors, filter_size=self.kernel_size) + # ssim_multiscale returns values in range [0, 1] (where 1 is completely identical) + # subtract from 1 to get loss + return 1.0 - ms_ssim_val + + +nn.MsSsim = MsSsim diff --git a/core/leras/layers/__init__.py b/core/leras/layers/__init__.py index 1c81963..d8f1c9d 100644 --- a/core/leras/layers/__init__.py +++ b/core/leras/layers/__init__.py @@ -13,4 +13,5 @@ from .FRNorm2D import * from .TLU import * from .ScaleAdd import * from .DenseNorm import * -from .AdaIN import * \ No newline at end of file +from .AdaIN import * +from .MsSsim import * diff --git a/core/leras/nn.py b/core/leras/nn.py index ef5c2c9..8ac437b 100644 --- a/core/leras/nn.py +++ b/core/leras/nn.py @@ -40,7 +40,7 @@ class nn(): conv2d_spatial_axes = None floatx = None - + @staticmethod def initialize(device_config=None, floatx="float32", data_format="NHWC"): @@ -50,7 +50,7 @@ class nn(): nn.setCurrentDeviceConfig(device_config) # Manipulate environment variables before import tensorflow - + if 'CUDA_VISIBLE_DEVICES' in os.environ.keys(): os.environ.pop('CUDA_VISIBLE_DEVICES') @@ -77,13 +77,13 @@ class nn(): io.log_info("Caching GPU kernels...") import tensorflow - + tf_version = getattr(tensorflow,'VERSION', None) if tf_version is None: tf_version = tensorflow.version.GIT_VERSION if tf_version[0] == 'v': tf_version = tf_version[1:] - + if tf_version[0] == '2': tf = tensorflow.compat.v1 else: @@ -93,7 +93,7 @@ class nn(): # Disable tensorflow warnings tf_logger = logging.getLogger('tensorflow') tf_logger.setLevel(logging.ERROR) - + if tf_version[0] == '2': tf.disable_v2_behavior() nn.tf = tf @@ -105,20 +105,20 @@ class nn(): import core.leras.optimizers import core.leras.models import core.leras.archis - + # Configure tensorflow session-config if len(device_config.devices) == 0: nn.tf_default_device = "/CPU:0" config = tf.ConfigProto(device_count={'GPU': 0}) else: nn.tf_default_device = "/GPU:0" - config = tf.ConfigProto() + config = tf.ConfigProto(allow_soft_placement=True) config.gpu_options.visible_device_list = ','.join([str(device.index) for device in device_config.devices]) config.gpu_options.force_gpu_compatible = True config.gpu_options.allow_growth = True nn.tf_sess_config = config - + if nn.tf_sess is None: nn.tf_sess = tf.Session(config=nn.tf_sess_config) @@ -273,7 +273,7 @@ class nn(): @staticmethod def ask_choose_device(*args, **kwargs): return nn.DeviceConfig.GPUIndexes( nn.ask_choose_device_idxs(*args,**kwargs) ) - + def __init__ (self, devices=None): devices = devices or [] diff --git a/doc/features/ms-ssim/README.md b/doc/features/ms-ssim/README.md new file mode 100644 index 0000000..c7a0476 --- /dev/null +++ b/doc/features/ms-ssim/README.md @@ -0,0 +1,42 @@ +# Multiscale SSIM (MS-SSIM) + +Allows you to train using the MS-SSIM (multiscale structural similarity index measure) as the main loss metric, +a perceptually more accurate measure of image quality than MSE (mean squared error). + +- [DESCRIPTION](#description) +- [USAGE](#usage) + +![](example.png) + +## DESCRIPTION + +[SSIM](https://en.wikipedia.org/wiki/Structural_similarity) is metric for comparing the perceptial quality of an image: +> SSIM is a perception-based model that considers image degradation as perceived change in structural information, +> while also incorporating important perceptual phenomena, including both luminance masking and contrast masking terms. +> [...] +> Structural information is the idea that the pixels have strong inter-dependencies especially when they are spatially +> close. These dependencies carry important information about the structure of the objects in the visual scene. +> Luminance masking is a phenomenon whereby image distortions (in this context) tend to be less visible in bright +> regions, while contrast masking is a phenomenon whereby distortions become less visible where there is significant +> activity or "texture" in the image. + +The current loss metric is a combination of SSIM (structural similarity index measure) and +[MSE](https://en.wikipedia.org/wiki/Mean_squared_error) (mean squared error). + +[Multiscale SSIM](https://en.wikipedia.org/wiki/Structural_similarity#Multi-Scale_SSIM) is a variant of SSIM that +improves upon SSIM by comparing the similarity at multiple scales (e.g.: full-size, half-size, 1/4 size, etc.) +By using MS-SSIM as our main loss metric, we should expect the image similarity to improve across each scale, improving +both the large scale and small scale detail of the predicted images. + +Original paper: [Wang, Zhou, Eero P. Simoncelli, and Alan C. Bovik. +"Multiscale structural similarity for image quality assessment." +Signals, Systems and Computers, 2004.](https://www.cns.nyu.edu/pub/eero/wang03b.pdf) + +## USAGE + +``` +[n] Use multiscale loss? ( y/n ?:help ) : y +``` + + + diff --git a/doc/features/random-color/README.md b/doc/features/random-color/README.md index 8d8298c..d1aeac1 100644 --- a/doc/features/random-color/README.md +++ b/doc/features/random-color/README.md @@ -19,4 +19,7 @@ maintaining the same `C*` (chroma, relative saturation). ## USAGE -`[n] Random color ( y/n ?:help ) : y` +``` +[n] Random color ( y/n ?:help ) : y +``` + diff --git a/models/Model_SAEHD/Model.py b/models/Model_SAEHD/Model.py index 5522891..cfd67e2 100644 --- a/models/Model_SAEHD/Model.py +++ b/models/Model_SAEHD/Model.py @@ -53,6 +53,8 @@ class SAEHDModel(ModelBase): lr_dropout = {True:'y', False:'n'}.get(lr_dropout, lr_dropout) #backward comp default_lr_dropout = self.options['lr_dropout'] = lr_dropout + default_ms_ssim_loss = self.options['ms_ssim_loss'] = self.load_or_def_option('ms_ssim_loss', False) + default_random_warp = self.options['random_warp'] = self.load_or_def_option('random_warp', True) default_background_power = self.options['background_power'] = self.load_or_def_option('background_power', 0.0) default_true_face_power = self.options['true_face_power'] = self.load_or_def_option('true_face_power', 0.0) @@ -152,6 +154,8 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... self.options['lr_dropout'] = io.input_str (f"Use learning rate dropout", default_lr_dropout, ['n','y','cpu'], help_message="When the face is trained enough, you can enable this option to get extra sharpness and reduce subpixel shake for less amount of iterations. Enabled it before `disable random warp` and before GAN. \nn - disabled.\ny - enabled\ncpu - enabled on CPU. This allows not to use extra VRAM, sacrificing 20% time of iteration.") + self.options['ms_ssim_loss'] = io.input_bool("Use multiscale loss?", default_ms_ssim_loss, help_message="Use Multiscale structural similarity for image quality assessment.") + self.options['random_warp'] = io.input_bool ("Enable random warp of samples", default_random_warp, help_message="Random warp is required to generalize facial expressions of both faces. When the face is trained enough, you can disable it to get extra sharpness and reduce subpixel shake for less amount of iterations.") self.options['gan_version'] = np.clip (io.input_int("GAN version", default_gan_version, add_info="2 or 3", help_message="Choose GAN version (v2: 7/16/2020, v3: 1/3/2021):"), 2, 3) @@ -450,11 +454,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_psd_target_dst_style_masked = gpu_pred_src_dst*gpu_target_dstm_style_blur gpu_psd_target_dst_style_anti_masked = gpu_pred_src_dst*(1.0 - gpu_target_dstm_style_blur) - if resolution < 256: - gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + if self.options['ms_ssim_loss']: + gpu_src_loss = 10 * nn.MsSsim(resolution)(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0) else: - gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) - gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + if resolution < 256: + gpu_src_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + else: + gpu_src_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_src_masked_opt, gpu_pred_src_src_masked_opt, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_src_loss += tf.reduce_mean ( 10*tf.square ( gpu_target_src_masked_opt - gpu_pred_src_src_masked_opt ), axis=[1,2,3]) if eyes_prio or mouth_prio: @@ -471,11 +478,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... if self.options['background_power'] > 0: bg_factor = self.options['background_power'] - if resolution < 256: - gpu_src_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + if self.options['ms_ssim_loss']: + gpu_src_loss = 10 * nn.MsSsim(resolution)(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0) else: - gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) - gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + if resolution < 256: + gpu_src_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + else: + gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_src_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_src_anti_masked, gpu_pred_src_src_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_src_loss += bg_factor * tf.reduce_mean ( 10*tf.square ( gpu_target_src_anti_masked - gpu_pred_src_src_anti_masked ), axis=[1,2,3]) face_style_power = self.options['face_style_power'] / 100.0 @@ -487,11 +497,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*nn.dssim( gpu_psd_target_dst_style_anti_masked, gpu_target_dst_style_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) gpu_src_loss += tf.reduce_mean( (10*bg_style_power)*tf.square(gpu_psd_target_dst_style_anti_masked - gpu_target_dst_style_anti_masked), axis=[1,2,3] ) - if resolution < 256: - gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + if self.options['ms_ssim_loss']: + gpu_dst_loss = 10 * nn.MsSsim(resolution)(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0) else: - gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) - gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) + if resolution < 256: + gpu_dst_loss = tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + else: + gpu_dst_loss = tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/11.6) ), axis=[1]) + gpu_dst_loss += tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_masked_opt, gpu_pred_dst_dst_masked_opt, max_val=1.0, filter_size=int(resolution/23.2) ), axis=[1]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dst_masked_opt- gpu_pred_dst_dst_masked_opt ), axis=[1,2,3]) @@ -507,11 +520,14 @@ Examples: df, liae, df-d, df-ud, liae-ud, ... if self.options['background_power'] > 0: bg_factor = self.options['background_power'] - if resolution < 256: - gpu_dst_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + if self.options['ms_ssim_loss']: + gpu_src_loss = 10 * nn.MsSsim(resolution)(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0) else: - gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) - gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) + if resolution < 256: + gpu_dst_loss += bg_factor * tf.reduce_mean ( 10*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + else: + gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/11.6)), axis=[1]) + gpu_dst_loss += bg_factor * tf.reduce_mean ( 5*nn.dssim(gpu_target_dst_anti_masked, gpu_pred_dst_dst_anti_masked, max_val=1.0, filter_size=int(resolution/23.2)), axis=[1]) gpu_dst_loss += bg_factor * tf.reduce_mean ( 10*tf.square ( gpu_target_dst_anti_masked - gpu_pred_dst_dst_anti_masked ), axis=[1,2,3]) gpu_dst_loss += tf.reduce_mean ( 10*tf.square( gpu_target_dstm - gpu_pred_dst_dstm ),axis=[1,2,3] )