This commit is contained in:
Colombo 2020-01-21 21:05:29 +04:00
commit 9797a70fd3
5 changed files with 112 additions and 47 deletions

View file

@ -77,18 +77,25 @@ def initialize_layers(nn):
def init_weights(self):
ops = []
tuples = []
ca_tuples_w = []
ca_tuples = []
for w in self.get_weights():
initializer = w.initializer
for input in initializer.inputs:
if "_cai_" in input.name:
tuples.append ( (w, nn.initializers.ca.generate(w.shape.as_list(), dtype= w.dtype.as_numpy_dtype) ) )
ca_tuples_w.append (w)
ca_tuples.append ( (w.shape.as_list(), w.dtype.as_numpy_dtype) )
break
else:
ops.append (initializer)
nn.tf_sess.run (ops)
nn.tf_batch_set_value(tuples)
if len(ops) != 0:
nn.tf_sess.run (ops)
if len(ca_tuples) != 0:
nn.tf_batch_set_value( [*zip(ca_tuples_w, nn.initializers.ca.generate_batch (ca_tuples))] )
nn.Saveable = Saveable
class LayerBase():