mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-31 04:00:11 -07:00
leras ModelBase.summary()
This commit is contained in:
parent
1fd4a25882
commit
80036f950f
1 changed files with 77 additions and 2 deletions
|
@ -270,6 +270,81 @@ def initialize_layers(nn):
|
||||||
|
|
||||||
return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict)
|
return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict)
|
||||||
|
|
||||||
|
def summary(self):
|
||||||
|
layers = self.get_layers()
|
||||||
|
layers_names = []
|
||||||
|
layers_params = []
|
||||||
|
|
||||||
|
max_len_str = 0
|
||||||
|
max_len_param_str = 0
|
||||||
|
delim_str = "-"
|
||||||
|
|
||||||
|
total_params = 0
|
||||||
|
|
||||||
|
#Get layers names and str lenght for delim
|
||||||
|
for l in layers:
|
||||||
|
if len(str(l))>max_len_str:
|
||||||
|
max_len_str = len(str(l))
|
||||||
|
layers_names+=[str(l).capitalize()]
|
||||||
|
|
||||||
|
#Get params for each layer
|
||||||
|
layers_params = [ int(np.sum(np.prod(w.shape) for w in l.get_weights())) for l in layers ]
|
||||||
|
total_params = np.sum(layers_params)
|
||||||
|
|
||||||
|
#Get str lenght for delim
|
||||||
|
for p in layers_params:
|
||||||
|
if len(str(p))>max_len_param_str:
|
||||||
|
max_len_param_str=len(str(p))
|
||||||
|
|
||||||
|
#Set delim
|
||||||
|
for i in range(max_len_str+max_len_param_str+3):
|
||||||
|
delim_str += "-"
|
||||||
|
|
||||||
|
output = "\n"+delim_str+"\n"
|
||||||
|
|
||||||
|
#Format model name str
|
||||||
|
model_name_str = "| "+self.name.capitalize()
|
||||||
|
len_model_name_str = len(model_name_str)
|
||||||
|
for i in range(len(delim_str)-len_model_name_str):
|
||||||
|
model_name_str+= " " if i!=(len(delim_str)-len_model_name_str-2) else " |"
|
||||||
|
|
||||||
|
output += model_name_str +"\n"
|
||||||
|
output += delim_str +"\n"
|
||||||
|
|
||||||
|
|
||||||
|
#Format layers table
|
||||||
|
for i in range(len(layers_names)):
|
||||||
|
output += delim_str +"\n"
|
||||||
|
|
||||||
|
l_name = layers_names[i]
|
||||||
|
l_param = str(layers_params[i])
|
||||||
|
l_param_str = ""
|
||||||
|
if len(l_name)<=max_len_str:
|
||||||
|
for i in range(max_len_str - len(l_name)):
|
||||||
|
l_name+= " "
|
||||||
|
|
||||||
|
if len(l_param)<=max_len_param_str:
|
||||||
|
for i in range(max_len_param_str - len(l_param)):
|
||||||
|
l_param_str+= " "
|
||||||
|
|
||||||
|
l_param_str += l_param
|
||||||
|
|
||||||
|
|
||||||
|
output +="| "+l_name+"|"+l_param_str+"| \n"
|
||||||
|
|
||||||
|
output += delim_str +"\n"
|
||||||
|
|
||||||
|
#Format sum of params
|
||||||
|
total_params_str = "| Total params count: "+str(total_params)
|
||||||
|
len_total_params_str = len(total_params_str)
|
||||||
|
for i in range(len(delim_str)-len_total_params_str):
|
||||||
|
total_params_str+= " " if i!=(len(delim_str)-len_total_params_str-2) else " |"
|
||||||
|
|
||||||
|
output += total_params_str +"\n"
|
||||||
|
output += delim_str +"\n"
|
||||||
|
|
||||||
|
io.log_info(output)
|
||||||
|
|
||||||
nn.ModelBase = ModelBase
|
nn.ModelBase = ModelBase
|
||||||
|
|
||||||
class Conv2D(LayerBase):
|
class Conv2D(LayerBase):
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue