leras ModelBase.summary()

This commit is contained in:
Colombo 2020-01-28 21:38:15 +04:00
parent 1fd4a25882
commit 80036f950f

View file

@ -270,6 +270,81 @@ def initialize_layers(nn):
return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict)
def summary(self):
layers = self.get_layers()
layers_names = []
layers_params = []
max_len_str = 0
max_len_param_str = 0
delim_str = "-"
total_params = 0
#Get layers names and str lenght for delim
for l in layers:
if len(str(l))>max_len_str:
max_len_str = len(str(l))
layers_names+=[str(l).capitalize()]
#Get params for each layer
layers_params = [ int(np.sum(np.prod(w.shape) for w in l.get_weights())) for l in layers ]
total_params = np.sum(layers_params)
#Get str lenght for delim
for p in layers_params:
if len(str(p))>max_len_param_str:
max_len_param_str=len(str(p))
#Set delim
for i in range(max_len_str+max_len_param_str+3):
delim_str += "-"
output = "\n"+delim_str+"\n"
#Format model name str
model_name_str = "| "+self.name.capitalize()
len_model_name_str = len(model_name_str)
for i in range(len(delim_str)-len_model_name_str):
model_name_str+= " " if i!=(len(delim_str)-len_model_name_str-2) else " |"
output += model_name_str +"\n"
output += delim_str +"\n"
#Format layers table
for i in range(len(layers_names)):
output += delim_str +"\n"
l_name = layers_names[i]
l_param = str(layers_params[i])
l_param_str = ""
if len(l_name)<=max_len_str:
for i in range(max_len_str - len(l_name)):
l_name+= " "
if len(l_param)<=max_len_param_str:
for i in range(max_len_param_str - len(l_param)):
l_param_str+= " "
l_param_str += l_param
output +="| "+l_name+"|"+l_param_str+"| \n"
output += delim_str +"\n"
#Format sum of params
total_params_str = "| Total params count: "+str(total_params)
len_total_params_str = len(total_params_str)
for i in range(len(delim_str)-len_total_params_str):
total_params_str+= " " if i!=(len(delim_str)-len_total_params_str-2) else " |"
output += total_params_str +"\n"
output += delim_str +"\n"
io.log_info(output)
nn.ModelBase = ModelBase
class Conv2D(LayerBase):
@ -284,7 +359,7 @@ def initialize_layers(nn):
if not isinstance(dilations, int):
raise ValueError ("dilations must be an int type")
kernel_size = int(kernel_size)
if dtype is None:
dtype = nn.tf_floatx
@ -388,7 +463,7 @@ def initialize_layers(nn):
if not isinstance(strides, int):
raise ValueError ("strides must be an int type")
kernel_size = int(kernel_size)
if dtype is None:
dtype = nn.tf_floatx