mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-16 10:03:41 -07:00
leras ModelBase.summary()
This commit is contained in:
parent
1fd4a25882
commit
80036f950f
1 changed files with 77 additions and 2 deletions
|
@ -270,6 +270,81 @@ def initialize_layers(nn):
|
|||
|
||||
return nn.tf_sess.run ( self.run_output, feed_dict=feed_dict)
|
||||
|
||||
def summary(self):
|
||||
layers = self.get_layers()
|
||||
layers_names = []
|
||||
layers_params = []
|
||||
|
||||
max_len_str = 0
|
||||
max_len_param_str = 0
|
||||
delim_str = "-"
|
||||
|
||||
total_params = 0
|
||||
|
||||
#Get layers names and str lenght for delim
|
||||
for l in layers:
|
||||
if len(str(l))>max_len_str:
|
||||
max_len_str = len(str(l))
|
||||
layers_names+=[str(l).capitalize()]
|
||||
|
||||
#Get params for each layer
|
||||
layers_params = [ int(np.sum(np.prod(w.shape) for w in l.get_weights())) for l in layers ]
|
||||
total_params = np.sum(layers_params)
|
||||
|
||||
#Get str lenght for delim
|
||||
for p in layers_params:
|
||||
if len(str(p))>max_len_param_str:
|
||||
max_len_param_str=len(str(p))
|
||||
|
||||
#Set delim
|
||||
for i in range(max_len_str+max_len_param_str+3):
|
||||
delim_str += "-"
|
||||
|
||||
output = "\n"+delim_str+"\n"
|
||||
|
||||
#Format model name str
|
||||
model_name_str = "| "+self.name.capitalize()
|
||||
len_model_name_str = len(model_name_str)
|
||||
for i in range(len(delim_str)-len_model_name_str):
|
||||
model_name_str+= " " if i!=(len(delim_str)-len_model_name_str-2) else " |"
|
||||
|
||||
output += model_name_str +"\n"
|
||||
output += delim_str +"\n"
|
||||
|
||||
|
||||
#Format layers table
|
||||
for i in range(len(layers_names)):
|
||||
output += delim_str +"\n"
|
||||
|
||||
l_name = layers_names[i]
|
||||
l_param = str(layers_params[i])
|
||||
l_param_str = ""
|
||||
if len(l_name)<=max_len_str:
|
||||
for i in range(max_len_str - len(l_name)):
|
||||
l_name+= " "
|
||||
|
||||
if len(l_param)<=max_len_param_str:
|
||||
for i in range(max_len_param_str - len(l_param)):
|
||||
l_param_str+= " "
|
||||
|
||||
l_param_str += l_param
|
||||
|
||||
|
||||
output +="| "+l_name+"|"+l_param_str+"| \n"
|
||||
|
||||
output += delim_str +"\n"
|
||||
|
||||
#Format sum of params
|
||||
total_params_str = "| Total params count: "+str(total_params)
|
||||
len_total_params_str = len(total_params_str)
|
||||
for i in range(len(delim_str)-len_total_params_str):
|
||||
total_params_str+= " " if i!=(len(delim_str)-len_total_params_str-2) else " |"
|
||||
|
||||
output += total_params_str +"\n"
|
||||
output += delim_str +"\n"
|
||||
|
||||
io.log_info(output)
|
||||
|
||||
nn.ModelBase = ModelBase
|
||||
|
||||
class Conv2D(LayerBase):
|
||||
|
@ -284,7 +359,7 @@ def initialize_layers(nn):
|
|||
if not isinstance(dilations, int):
|
||||
raise ValueError ("dilations must be an int type")
|
||||
kernel_size = int(kernel_size)
|
||||
|
||||
|
||||
if dtype is None:
|
||||
dtype = nn.tf_floatx
|
||||
|
||||
|
@ -388,7 +463,7 @@ def initialize_layers(nn):
|
|||
if not isinstance(strides, int):
|
||||
raise ValueError ("strides must be an int type")
|
||||
kernel_size = int(kernel_size)
|
||||
|
||||
|
||||
if dtype is None:
|
||||
dtype = nn.tf_floatx
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue