diff --git a/models/ModelBase.py b/models/ModelBase.py
index 3529de8..173e33e 100644
--- a/models/ModelBase.py
+++ b/models/ModelBase.py
@@ -1,4 +1,5 @@
 import os
+import json
 import time
 import inspect
 import pickle
@@ -119,6 +120,7 @@ class ModelBase(object):
 
         nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **self.device_args) )
         self.device_config = nnlib.active_DeviceConfig
+        self.keras = nnlib.keras
 
         self.onInitialize()
 
@@ -271,25 +273,52 @@ class ModelBase(object):
                           }
         self.model_data_path.write_bytes( pickle.dumps(model_data) )
 
-    def load_weights_safe(self, model_filename_list):
+    def load_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
         for model, filename in model_filename_list:
             filename = self.get_strpath_storage_for_file(filename)
-            if Path(filename).exists():
+            if Path(filename).exists():
                 model.load_weights(filename)
 
-    def save_weights_safe(self, model_filename_list):
+        if len(optimizer_filename_list) != 0:
+            opt_filename = self.get_strpath_storage_for_file('opt.h5')
+            if Path(opt_filename).exists():
+                h5dict = self.keras.utils.io_utils.H5Dict(opt_filename, mode='r')
+                try:
+                    for x in optimizer_filename_list:
+                        opt, filename = x
+                        if filename in h5dict:
+                            opt = opt.__class__.from_config( json.loads(h5dict[filename]) )
+                            x[0] = opt
+                finally:
+                    h5dict.close()
+
+        return [x[0] for x in optimizer_filename_list]
+
+    def save_weights_safe(self, model_filename_list, optimizer_filename_list=[]):
         for model, filename in model_filename_list:
             filename = self.get_strpath_storage_for_file(filename)
             model.save_weights( filename + '.tmp' )
-
-        for model, filename in model_filename_list:
+
+        rename_list = model_filename_list
+        if len(optimizer_filename_list) != 0:
+            opt_filename = self.get_strpath_storage_for_file('opt.h5')
+            h5dict = self.keras.utils.io_utils.H5Dict(opt_filename + '.tmp', mode='w')
+            try:
+                for opt, filename in optimizer_filename_list:
+                    h5dict[filename] = json.dumps(opt.get_config())
+            finally:
+                h5dict.close()
+            rename_list += [('', 'opt.h5')]
+
+        for _, filename in rename_list:
             filename = self.get_strpath_storage_for_file(filename)
             source_filename = Path(filename+'.tmp')
-            target_filename = Path(filename)
-            if target_filename.exists():
-                target_filename.unlink()
-
-            source_filename.rename ( str(target_filename) )
+            if source_filename.exists():
+                target_filename = Path(filename)
+                if target_filename.exists():
+                    target_filename.unlink()
+                source_filename.rename ( str(target_filename) )
+
 
     def debug_one_epoch(self):
         images = []
diff --git a/models/Model_SAE/Model.py b/models/Model_SAE/Model.py
index ea62f8e..965fd02 100644
--- a/models/Model_SAE/Model.py
+++ b/models/Model_SAE/Model.py
@@ -84,7 +84,6 @@ class SAEModel(ModelBase):
     def onInitialize(self):
         exec(nnlib.import_all(), locals(), globals())
         SAEModel.initialize_nn_functions()
-
         self.set_vram_batch_requirements({1.5:4})
 
         resolution = self.options['resolution']
@@ -111,6 +110,7 @@ class SAEModel(ModelBase):
         target_dst_ar  = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
         target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
 
+        weights_to_load = []
         if self.options['archi'] == 'liae':
             self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
@@ -223,8 +223,13 @@ class SAEModel(ModelBase):
                 pred_dst_dstm = self.decoder_dstm(warped_dst_code)
                 pred_src_dstm = self.decoder_srcm(warped_dst_code)
 
-        self.load_weights_safe(weights_to_load)
-
+        self.src_dst_opt, \
+        self.src_dst_mask_opt = self.load_weights_safe(
+                weights_to_load,
+                [ [Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), 'src_dst_opt'],
+                  [Adam(lr=5e-5, beta_1=0.5, beta_2=0.999), 'src_dst_mask_opt']
+                ])
+
         pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]
 
         if self.options['learn_mask']:
@@ -262,9 +267,6 @@ class SAEModel(ModelBase):
             psd_target_dst_anti_masked_ar = [ pred_src_dst_sigm_ar[i]*target_dstm_anti_sigm_ar[i] for i in range(len(pred_src_dst_sigm_ar))]
 
         if self.is_training_mode:
-            def optimizer():
-                return Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
-
             if self.options['archi'] == 'liae':
                 src_dst_loss_train_weights = self.encoder.trainable_weights + self.inter_B.trainable_weights + self.inter_AB.trainable_weights + self.decoder.trainable_weights
                 if self.options['learn_mask']:
@@ -307,7 +309,7 @@ class SAEModel(ModelBase):
                 feed += target_dst_ar[::-1]
                 feed += target_dstm_ar[::-1]
 
-            self.src_dst_train = K.function (feed,[src_loss,dst_loss], optimizer().get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
+            self.src_dst_train = K.function (feed,[src_loss,dst_loss], self.src_dst_opt.get_updates(src_loss+dst_loss, src_dst_loss_train_weights) )
 
             if self.options['learn_mask']:
                 src_mask_loss = sum([ K.mean(K.square(target_srcm_ar[-1]-pred_src_srcm[-1])) for i in range(len(target_srcm_ar)) ])
@@ -317,7 +319,7 @@ class SAEModel(ModelBase):
                 feed += target_srcm_ar[::-1]
                 feed += target_dstm_ar[::-1]
 
-                self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], optimizer().get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
+                self.src_dst_mask_train = K.function (feed,[src_mask_loss, dst_mask_loss], self.src_dst_mask_opt.get_updates(src_mask_loss+dst_mask_loss, src_dst_mask_loss_train_weights) )
 
             if self.options['learn_mask']:
                 self.AE_view = K.function ([warped_src, warped_dst], [pred_src_src[-1], pred_dst_dst[-1], pred_src_dst[-1], pred_src_dstm[-1]])
@@ -353,8 +355,12 @@ class SAEModel(ModelBase):
                                         ])
     #override
     def onSave(self):
+        opt_ar = [ [self.src_dst_opt,      'src_dst_opt'],
+                   [self.src_dst_mask_opt, 'src_dst_mask_opt']
+                 ]
+        ar = []
         if self.options['archi'] == 'liae':
-            ar = [[self.encoder, 'encoder.h5'],
+            ar += [[self.encoder, 'encoder.h5'],
                   [self.inter_B, 'inter_B.h5'],
                   [self.inter_AB, 'inter_AB.h5'],
                   [self.decoder, 'decoder.h5']
@@ -362,15 +368,15 @@ class SAEModel(ModelBase):
             if self.options['learn_mask']:
                 ar += [ [self.decoderm, 'decoderm.h5'] ]
         elif self.options['archi'] == 'df' or self.options['archi'] == 'vg':
-            ar = [[self.encoder, 'encoder.h5'],
+            ar += [[self.encoder, 'encoder.h5'],
                   [self.decoder_src, 'decoder_src.h5'],
                   [self.decoder_dst, 'decoder_dst.h5']
                  ]
             if self.options['learn_mask']:
                 ar += [ [self.decoder_srcm, 'decoder_srcm.h5'],
                         [self.decoder_dstm, 'decoder_dstm.h5'] ]
-
-        self.save_weights_safe(ar)
+
+        self.save_weights_safe(ar, opt_ar)
 
     #override
diff --git a/nnlib/nnlib.py b/nnlib/nnlib.py
index c864e4d..db0cb7c 100644
--- a/nnlib/nnlib.py
+++ b/nnlib/nnlib.py
@@ -71,17 +71,18 @@ ZeroPadding2D = keras.layers.ZeroPadding2D
 RandomNormal = keras.initializers.RandomNormal
 Model = keras.models.Model
 
-Adam = keras.optimizers.Adam
+#Adam = keras.optimizers.Adam
+Adam = nnlib.Adam
 
 modelify = nnlib.modelify
 gaussian_blur = nnlib.gaussian_blur
 style_loss = nnlib.style_loss
 dssim = nnlib.dssim
-
 PixelShuffler = nnlib.PixelShuffler
 SubpixelUpscaler = nnlib.SubpixelUpscaler
 Scale = nnlib.Scale
+
 #ReflectionPadding2D = nnlib.ReflectionPadding2D
 #AddUniformNoise = nnlib.AddUniformNoise
 """
@@ -181,7 +182,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
 
         if device_config.use_fp16:
             nnlib.keras.backend.set_floatx('float16')
-
+
         if "tensorflow" in device_config.backend:
             nnlib.keras.backend.set_session(nnlib.tf_sess)
@@ -432,7 +433,78 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
                 base_config = super(Scale, self).get_config()
                 return dict(list(base_config.items()) + list(config.items()))
         nnlib.Scale = Scale
-
+
+        class Adam(keras.optimizers.Optimizer):
+            def __init__(self, lr=0.001, beta_1=0.9, beta_2=0.999,
+                         epsilon=None, decay=0., amsgrad=False, iterations=0, **kwargs):
+                super(Adam, self).__init__(**kwargs)
+                with K.name_scope(self.__class__.__name__):
+                    self.iterations = K.variable(iterations, dtype='int64', name='iterations')
+                    self.lr = K.variable(lr, name='lr')
+                    self.beta_1 = K.variable(beta_1, name='beta_1')
+                    self.beta_2 = K.variable(beta_2, name='beta_2')
+                    self.decay = K.variable(decay, name='decay')
+                if epsilon is None:
+                    epsilon = K.epsilon()
+                self.epsilon = epsilon
+                self.initial_decay = decay
+                self.amsgrad = amsgrad
+
+            @keras.legacy.interfaces.legacy_get_updates_support
+            def get_updates(self, loss, params):
+                grads = self.get_gradients(loss, params)
+                self.updates = [K.update_add(self.iterations, 1)]
+
+                lr = self.lr
+                if self.initial_decay > 0:
+                    lr = lr * (1. / (1. + self.decay * K.cast(self.iterations,
+                                                              K.dtype(self.decay))))
+
+                t = K.cast(self.iterations, K.floatx()) + 1
+                lr_t = lr * (K.sqrt(1. - K.pow(self.beta_2, t)) /
+                             (1. - K.pow(self.beta_1, t)))
+
+                ms = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+                vs = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+                if self.amsgrad:
+                    vhats = [K.zeros(K.int_shape(p), dtype=K.dtype(p)) for p in params]
+                else:
+                    vhats = [K.zeros(1) for _ in params]
+                self.weights = [self.iterations] + ms + vs + vhats
+
+                for p, g, m, v, vhat in zip(params, grads, ms, vs, vhats):
+                    m_t = (self.beta_1 * m) + (1. - self.beta_1) * g
+                    v_t = (self.beta_2 * v) + (1. - self.beta_2) * K.square(g)
+                    if self.amsgrad:
+                        vhat_t = K.maximum(vhat, v_t)
+                        p_t = p - lr_t * m_t / (K.sqrt(vhat_t) + self.epsilon)
+                        self.updates.append(K.update(vhat, vhat_t))
+                    else:
+                        p_t = p - lr_t * m_t / (K.sqrt(v_t) + self.epsilon)
+
+                    self.updates.append(K.update(m, m_t))
+                    self.updates.append(K.update(v, v_t))
+                    new_p = p_t
+
+                    # Apply constraints.
+                    if getattr(p, 'constraint', None) is not None:
+                        new_p = p.constraint(new_p)
+
+                    self.updates.append(K.update(p, new_p))
+                return self.updates
+
+            def get_config(self):
+                config = {'iterations': int(K.get_value(self.iterations)),
+                          'lr': float(K.get_value(self.lr)),
+                          'beta_1': float(K.get_value(self.beta_1)),
+                          'beta_2': float(K.get_value(self.beta_2)),
+                          'decay': float(K.get_value(self.decay)),
+                          'epsilon': self.epsilon,
+                          'amsgrad': self.amsgrad}
+                base_config = super(Adam, self).get_config()
+                return dict(list(base_config.items()) + list(config.items()))
+        nnlib.Adam = Adam
+
         ''' not implemented in plaidML
         class ReflectionPadding2D(keras.layers.Layer):
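
Note on the new opt.h5 round-trip: save_weights_safe serializes each optimizer's config dict as JSON under its key, and load_weights_safe rebuilds the optimizer from that dict when the key exists. A minimal standalone sketch of the same path, assuming Keras 2.2.x (where H5Dict lives in keras.utils.io_utils); save_opt_state / load_opt_state are illustrative names, not functions from this patch:

    import json
    from keras.utils.io_utils import H5Dict

    def save_opt_state(opt, key, path='opt.h5'):
        # Store only the optimizer's config dict; with the patched Adam
        # above, that dict also carries the 'iterations' step counter.
        h5dict = H5Dict(path, mode='w')
        try:
            h5dict[key] = json.dumps(opt.get_config())
        finally:
            h5dict.close()

    def load_opt_state(opt, key, path='opt.h5'):
        # Rebuild the optimizer from the stored config, mirroring the
        # opt.__class__.from_config(...) call in load_weights_safe.
        h5dict = H5Dict(path, mode='r')
        try:
            if key in h5dict:
                opt = opt.__class__.from_config(json.loads(h5dict[key]))
        finally:
            h5dict.close()
        return opt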
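
Note on the forked Adam: stock keras.optimizers.Adam neither accepts an iterations argument nor includes it in get_config(), so after a process restart the bias-correction term and any lr decay would begin again at step 0. The subclass threads iterations through __init__ and get_config, which is what makes the JSON round-trip above meaningful. Only scalar hyperparameters plus the step counter persist, however; the per-parameter moment tensors (ms, vs, vhats) are recreated as zeros by get_updates on the next run. A hedged illustration, assuming the patched class is in scope as Adam:

    opt = Adam(lr=5e-5, beta_1=0.5, beta_2=0.999)
    # ... training increments opt.iterations inside get_updates ...
    restored = Adam.from_config(opt.get_config())  # step counter preserved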