mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
fixes and optimizations
This commit is contained in:
parent
842a48964f
commit
d3e6b435aa
6 changed files with 72 additions and 58 deletions
|
@ -58,13 +58,12 @@ class AVATARModel(ModelBase):
|
|||
self.D = modelify(AVATARModel.Discriminator() ) (Input(df_bgr_shape))
|
||||
self.C = modelify(AVATARModel.ResNet (9, n_blocks=6, ngf=128, use_dropout=False))( Input(res_bgr_t_shape))
|
||||
|
||||
if self.is_first_run():
|
||||
conv_weights_list = []
|
||||
self.CA_conv_weights_list = []
|
||||
if self.is_first_run():
|
||||
for model, _ in self.get_model_filename_list():
|
||||
for layer in model.layers:
|
||||
if type(layer) == keras.layers.Conv2D:
|
||||
conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
|
||||
CAInitializerMP ( conv_weights_list )
|
||||
self.CA_conv_weights_list += [layer.weights[0]] #Conv2D kernel_weights
|
||||
|
||||
if not self.is_first_run():
|
||||
self.load_weights_safe( self.get_model_filename_list() )
|
||||
|
@ -247,7 +246,14 @@ class AVATARModel(ModelBase):
|
|||
#override
|
||||
def onSave(self):
|
||||
self.save_weights_safe( self.get_model_filename_list() )
|
||||
|
||||
|
||||
#override
|
||||
def on_success_train_one_iter(self):
|
||||
if len(self.CA_conv_weights_list) != 0:
|
||||
exec(nnlib.import_all(), locals(), globals())
|
||||
CAInitializerMP ( self.CA_conv_weights_list )
|
||||
self.CA_conv_weights_list = []
|
||||
|
||||
#override
|
||||
def onTrainOneIter(self, generators_samples, generators_list):
|
||||
warped_src64, src64, src64m = generators_samples[0]
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue