Mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 21:12:07 -07:00)
added saving model_summary.txt
This commit is contained in:
parent f0a20b46d3
commit 97685ce0ae
2 changed files with 52 additions and 37 deletions
@@ -150,34 +150,39 @@ class ModelBase(object):
         if (self.sample_for_preview is None) or (self.epoch == 0):
             self.sample_for_preview = self.generate_next_sample()

-        print ("===== Model summary =====")
-        print ("== Model name: " + self.get_model_name())
-        print ("==")
-        print ("== Current epoch: " + str(self.epoch) )
-        print ("==")
-        print ("== Model options:")
+        model_summary_text = []
+
+        model_summary_text += ["===== Model summary ====="]
+        model_summary_text += ["== Model name: " + self.get_model_name()]
+        model_summary_text += ["=="]
+        model_summary_text += ["== Current epoch: " + str(self.epoch)]
+        model_summary_text += ["=="]
+        model_summary_text += ["== Model options:"]
         for key in self.options.keys():
-            print ("== |== %s : %s" % (key, self.options[key]) )
+            model_summary_text += ["== |== %s : %s" % (key, self.options[key])]

         if self.device_config.multi_gpu:
-            print ("== |== multi_gpu : True ")
+            model_summary_text += ["== |== multi_gpu : True "]

-        print ("== Running on:")
+        model_summary_text += ["== Running on:"]
         if self.device_config.cpu_only:
-            print ("== |== [CPU]")
+            model_summary_text += ["== |== [CPU]"]
         else:
             for idx in self.device_config.gpu_idxs:
-                print ("== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx)) )
+                model_summary_text += ["== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))]

         if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] == 2:
-            print ("==")
-            print ("== WARNING: You are using 2GB GPU. Result quality may be significantly decreased.")
-            print ("== If training does not start, close all programs and try again.")
-            print ("== Also you can disable Windows Aero Desktop to get extra free VRAM.")
-            print ("==")
+            model_summary_text += ["=="]
+            model_summary_text += ["== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."]
+            model_summary_text += ["== If training does not start, close all programs and try again."]
+            model_summary_text += ["== Also you can disable Windows Aero Desktop to get extra free VRAM."]
+            model_summary_text += ["=="]

-        print ("=========================")
+        model_summary_text += ["========================="]
+        model_summary_text = "\r\n".join (model_summary_text)
+        self.model_summary_text = model_summary_text
+        print(model_summary_text)

     #overridable
     def onInitializeOptions(self, is_first_run, ask_override):
         pass
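
Note: the hunk above replaces per-line prints with a list that is accumulated and joined once, so the same text can be printed and later written to disk. A minimal standalone sketch of that collect-then-join pattern; the names build_summary, model_name, epoch and options are illustrative, not taken from the repository:

# Sketch only: illustrative names, same accumulate-and-join idea as above.
def build_summary(model_name, epoch, options):
    lines = []
    lines += ["===== Model summary ====="]
    lines += ["== Model name: " + model_name]
    lines += ["== Current epoch: " + str(epoch)]
    lines += ["== Model options:"]
    for key, value in options.items():
        lines += ["== |== %s : %s" % (key, value)]
    lines += ["========================="]
    # One CRLF-joined string, usable both for print() and for a text file.
    return "\r\n".join(lines)

if __name__ == "__main__":
    summary = build_summary("SAE", 0, {"resolution": 128, "learn_mask": True})
    print(summary)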
@@ -258,7 +263,8 @@ class ModelBase(object):
         if self.supress_std_once:
             supressor = std_utils.suppress_stdout_stderr()
             supressor.__enter__()

+        Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
         self.onSave()

         if self.supress_std_once:
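
Note: the single added line above writes the previously built self.model_summary_text next to the model files on every save. A hedged sketch of the same pathlib usage, where storage_dir merely stands in for whatever directory get_strpath_storage_for_file() resolves to in the repository:

# Sketch only: storage_dir is a placeholder for the model storage directory.
from pathlib import Path

storage_dir = Path("model_storage")
storage_dir.mkdir(parents=True, exist_ok=True)

model_summary_text = "===== Model summary =====\r\n== Model name: SAE\r\n========================="

# Path.write_text rewrites the file contents; the summary is refreshed on
# every save, matching the behaviour added in the hunk above.
(storage_dir / "summary.txt").write_text(model_summary_text)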
@@ -274,6 +280,7 @@ class ModelBase(object):

     def load_weights_safe(self, model_filename_list):
         for model, filename in model_filename_list:
             filename = self.get_strpath_storage_for_file(filename)
-            model.load_weights(filename)
+            if Path(filename).exists():
+                model.load_weights(filename)
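
Note: the SAEModel hunks below stop calling load_weights directly and instead collect [model, filename] pairs into weights_to_load, deferring the actual load to load_weights_safe, which skips files that do not exist yet. A standalone sketch of that deferred-load pattern; DummyModel is a stub standing in for the Keras models, and the real method resolves filenames through get_strpath_storage_for_file:

# Standalone sketch of the deferred "safe load" pattern used below.
from pathlib import Path

class DummyModel:
    def __init__(self, name):
        self.name = name

    def load_weights(self, filename):
        print("loading %s from %s" % (self.name, filename))

def load_weights_safe(storage_dir, model_filename_list):
    # Load each model's weights only if the corresponding file exists,
    # so a fresh run with no .h5 files does not raise an exception.
    for model, filename in model_filename_list:
        path = Path(storage_dir) / filename
        if path.exists():
            model.load_weights(str(path))

encoder, decoder = DummyModel("encoder"), DummyModel("decoder")
weights_to_load = [ [encoder, 'encoder.h5'],
                    [decoder, 'decoder.h5'] ]
# Nothing is printed here unless the files exist; missing files are skipped.
load_weights_safe("model_storage", weights_to_load)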
@@ -106,6 +106,7 @@ class SAEModel(ModelBase):
         target_dst_ar = [ Input ( ( bgr_shape[0] // (2**i) ,)*2 + (bgr_shape[-1],) ) for i in range(ms_count-1, -1, -1)]
         target_dstm_ar = [ Input ( ( mask_shape[0] // (2**i) ,)*2 + (mask_shape[-1],) ) for i in range(ms_count-1, -1, -1)]

+        weights_to_load = []
         if self.options['archi'] == 'liae':
             self.encoder = modelify(SAEModel.LIAEEncFlow(resolution, self.options['lighter_encoder'], ed_ch_dims=ed_ch_dims) ) (Input(bgr_shape))
@@ -122,13 +123,14 @@ class SAEModel(ModelBase):
                 self.decoderm = modelify(SAEModel.LIAEDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (inter_output_Inputs)

             if not self.is_first_run():
-                self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
-                self.inter_B.load_weights (self.get_strpath_storage_for_file(self.inter_BH5))
-                self.inter_AB.load_weights (self.get_strpath_storage_for_file(self.inter_ABH5))
-                self.decoder.load_weights (self.get_strpath_storage_for_file(self.decoderH5))
+                weights_to_load += [ [self.encoder , 'encoder.h5'],
+                                     [self.inter_B , 'inter_B.h5'],
+                                     [self.inter_AB, 'inter_AB.h5'],
+                                     [self.decoder , 'decoder.h5'],
+                                   ]
                 if self.options['learn_mask']:
-                    self.decoderm.load_weights (self.get_strpath_storage_for_file(self.decodermH5))
+                    weights_to_load += [ [self.decoderm, 'decoderm.h5'] ]

             warped_src_code = self.encoder (warped_src)
             warped_src_inter_AB_code = self.inter_AB (warped_src_code)
             warped_src_inter_code = Concatenate()([warped_src_inter_AB_code,warped_src_inter_AB_code])
@@ -162,13 +164,15 @@ class SAEModel(ModelBase):
                 self.decoder_dstm = modelify(SAEModel.DFDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)

             if not self.is_first_run():
-                self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
-                self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
-                self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
+                weights_to_load += [ [self.encoder , 'encoder.h5'],
+                                     [self.decoder_src, 'decoder_src.h5'],
+                                     [self.decoder_dst, 'decoder_dst.h5']
+                                   ]
                 if self.options['learn_mask']:
-                    self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
-                    self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
+                    weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
+                                         [self.decoder_dstm, 'decoder_dstm.h5'],
+                                       ]

             warped_src_code = self.encoder (warped_src)
             warped_dst_code = self.encoder (warped_dst)
             pred_src_src = self.decoder_src(warped_src_code)
@@ -193,12 +197,14 @@ class SAEModel(ModelBase):
                 self.decoder_dstm = modelify(SAEModel.VGDecFlow (mask_shape[2],ed_ch_dims=int(ed_ch_dims/1.5) )) (dec_Inputs)

             if not self.is_first_run():
-                self.encoder.load_weights (self.get_strpath_storage_for_file(self.encoderH5))
-                self.decoder_src.load_weights (self.get_strpath_storage_for_file(self.decoder_srcH5))
-                self.decoder_dst.load_weights (self.get_strpath_storage_for_file(self.decoder_dstH5))
+                weights_to_load += [ [self.encoder , 'encoder.h5'],
+                                     [self.decoder_src, 'decoder_src.h5'],
+                                     [self.decoder_dst, 'decoder_dst.h5']
+                                   ]
                 if self.options['learn_mask']:
-                    self.decoder_srcm.load_weights (self.get_strpath_storage_for_file(self.decoder_srcmH5))
-                    self.decoder_dstm.load_weights (self.get_strpath_storage_for_file(self.decoder_dstmH5))
+                    weights_to_load += [ [self.decoder_srcm, 'decoder_srcm.h5'],
+                                         [self.decoder_dstm, 'decoder_dstm.h5'],
+                                       ]

             warped_src_code = self.encoder (warped_src)
             warped_dst_code = self.encoder (warped_dst)
@@ -211,7 +217,9 @@ class SAEModel(ModelBase):
             pred_src_srcm = self.decoder_srcm(warped_src_code)
             pred_dst_dstm = self.decoder_dstm(warped_dst_code)
             pred_src_dstm = self.decoder_srcm(warped_dst_code)

+        self.load_weights_safe(weights_to_load)
+
         pred_src_src, pred_dst_dst, pred_src_dst, = [ [x] if type(x) != list else x for x in [pred_src_src, pred_dst_dst, pred_src_dst, ] ]

         if self.options['learn_mask']:
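
Note: across the liae, df and vg branches above, weight-file pairs are appended conditionally (the mask decoders only when learn_mask is enabled) and a single self.load_weights_safe(weights_to_load) call runs once every network has been built. An illustrative sketch of that conditional accumulation; the option key 'learn_mask' mirrors the diff, while the other names are placeholders:

# Sketch only: placeholder model names, conditional accumulation as above.
def collect_weight_files(options):
    weights_to_load = []
    weights_to_load += [ ['encoder'    , 'encoder.h5'],
                         ['decoder_src', 'decoder_src.h5'],
                         ['decoder_dst', 'decoder_dst.h5'] ]
    if options.get('learn_mask', False):
        weights_to_load += [ ['decoder_srcm', 'decoder_srcm.h5'],
                             ['decoder_dstm', 'decoder_dstm.h5'] ]
    return weights_to_load

# With learn_mask enabled, the mask decoder files are included as well.
print(collect_weight_files({'learn_mask': True}))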