added saving model_summary.txt

commit 97685ce0ae
parent f0a20b46d3
Author: iperov
Date:   2019-02-21 19:55:44 +04:00

2 changed files with 52 additions and 37 deletions

@@ -150,34 +150,39 @@ class ModelBase(object):
         if (self.sample_for_preview is None) or (self.epoch == 0):
             self.sample_for_preview = self.generate_next_sample()
 
-        print ("===== Model summary =====")
-        print ("== Model name: " + self.get_model_name())
-        print ("==")
-        print ("== Current epoch: " + str(self.epoch) )
-        print ("==")
-        print ("== Model options:")
+        model_summary_text = []
+        model_summary_text += ["===== Model summary ====="]
+        model_summary_text += ["== Model name: " + self.get_model_name()]
+        model_summary_text += ["=="]
+        model_summary_text += ["== Current epoch: " + str(self.epoch)]
+        model_summary_text += ["=="]
+        model_summary_text += ["== Model options:"]
         for key in self.options.keys():
-            print ("== |== %s : %s" % (key, self.options[key]) )
+            model_summary_text += ["== |== %s : %s" % (key, self.options[key])]
 
         if self.device_config.multi_gpu:
-            print ("== |== multi_gpu : True ")
+            model_summary_text += ["== |== multi_gpu : True "]
 
-        print ("== Running on:")
+        model_summary_text += ["== Running on:"]
         if self.device_config.cpu_only:
-            print ("== |== [CPU]")
+            model_summary_text += ["== |== [CPU]"]
         else:
             for idx in self.device_config.gpu_idxs:
-                print ("== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx)) )
+                model_summary_text += ["== |== [%d : %s]" % (idx, nnlib.device.getDeviceName(idx))]
 
         if not self.device_config.cpu_only and self.device_config.gpu_vram_gb[0] == 2:
-            print ("==")
-            print ("== WARNING: You are using 2GB GPU. Result quality may be significantly decreased.")
-            print ("== If training does not start, close all programs and try again.")
-            print ("== Also you can disable Windows Aero Desktop to get extra free VRAM.")
-            print ("==")
+            model_summary_text += ["=="]
+            model_summary_text += ["== WARNING: You are using 2GB GPU. Result quality may be significantly decreased."]
+            model_summary_text += ["== If training does not start, close all programs and try again."]
+            model_summary_text += ["== Also you can disable Windows Aero Desktop to get extra free VRAM."]
+            model_summary_text += ["=="]
 
-        print ("=========================")
+        model_summary_text += ["========================="]
+
+        model_summary_text = "\r\n".join (model_summary_text)
+        self.model_summary_text = model_summary_text
+        print(model_summary_text)
 
     #overridable
     def onInitializeOptions(self, is_first_run, ask_override):
         pass
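
The pattern introduced in this hunk is simple: collect the summary as a list of strings, join it once with "\r\n", keep the result on the instance, and print it; save() later writes the same string to disk. A minimal standalone sketch of that pattern (the model name, epoch value, and output path below are placeholders, not taken from the repo):

    from pathlib import Path

    # Collect the summary as individual lines...
    model_summary_text = []
    model_summary_text += ["===== Model summary ====="]
    model_summary_text += ["== Model name: SAE"]         # placeholder model name
    model_summary_text += ["== Current epoch: 0"]        # placeholder epoch
    model_summary_text += ["========================="]

    # ...then join once and reuse the same string for console and file output.
    model_summary_text = "\r\n".join(model_summary_text)
    print(model_summary_text)

    # write_text() replaces the file on every call, so the saved summary
    # always reflects the most recent run.
    Path("summary.txt").write_text(model_summary_text)   # illustrative path
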
@@ -258,7 +263,8 @@ class ModelBase(object):
         if self.supress_std_once:
             supressor = std_utils.suppress_stdout_stderr()
             supressor.__enter__()
 
+        Path( self.get_strpath_storage_for_file('summary.txt') ).write_text(self.model_summary_text)
         self.onSave()
 
         if self.supress_std_once:
@@ -274,6 +280,7 @@ class ModelBase(object):
     def load_weights_safe(self, model_filename_list):
         for model, filename in model_filename_list:
             filename = self.get_strpath_storage_for_file(filename)
-            model.load_weights(filename)
+            if Path(filename).exists():
+                model.load_weights(filename)
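
The existence check added to load_weights_safe means that missing weight files, for example on a model's first run before anything has been saved, are simply skipped instead of erroring out when load_weights() is called. A minimal sketch of that behaviour with a stand-in model class (DummyModel and the filename are illustrative, not from the repo):

    from pathlib import Path

    class DummyModel:
        # Stand-in for a model object exposing load_weights().
        def load_weights(self, filename):
            print("loaded", filename)

    def load_weights_safe(model_filename_list):
        # Same idea as above: only load weight files that actually exist,
        # otherwise keep the freshly initialized weights.
        for model, filename in model_filename_list:
            if Path(filename).exists():
                model.load_weights(filename)

    # No "encoder.h5" on disk -> the entry is skipped without an error.
    load_weights_safe([(DummyModel(), "encoder.h5")])
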