fixes and optimizations

This commit is contained in:
Colombo 2020-01-07 13:45:54 +04:00
parent 842a48964f
commit d3e6b435aa
6 changed files with 72 additions and 58 deletions

View file

@ -58,13 +58,12 @@ class AVATARModel(ModelBase):
self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))
if self.is_first_run():
conv_weights_list = []
self.CA_conv_weights_list = []
if self.is_first_run():
for model, _ in self.get_model_filename_list():
for layer in model.layers:
if type(layer) == keras.layers.Conv2D:
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
CAInitializerMP ( conv_weights_list )
self.CA_conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
if not self.is_first_run():
self.load_weights_safe( self.get_model_filename_list() )
@ -247,7 +246,14 @@ class AVATARModel(ModelBase):
#override
def onSave(self):
self.save_weights_safe( self.get_model_filename_list() )
#override
def on_success_train_one_iter(self):
if len(self.CA_conv_weights_list) != 0:
exec(nnlib.import_all(), locals(), globals())
CAInitializerMP ( self.CA_conv_weights_list )
self.CA_conv_weights_list = []
#override
def onTrainOneIter(self, generators_samples, generators_list):
warped_src64, src64, src64m = generators_samples[0]