mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 21:13:20 -07:00
1
This commit is contained in:
parent
56de1d9fa5
commit
9797a70fd3
5 changed files with 112 additions and 47 deletions
|
@ -77,18 +77,25 @@ def initialize_layers(nn):
|
|||
|
||||
def init_weights(self):
|
||||
ops = []
|
||||
tuples = []
|
||||
|
||||
ca_tuples_w = []
|
||||
ca_tuples = []
|
||||
for w in self.get_weights():
|
||||
initializer = w.initializer
|
||||
for input in initializer.inputs:
|
||||
if "_cai_" in input.name:
|
||||
tuples.append ( (w, nn.initializers.ca.generate(w.shape.as_list(), dtype= w.dtype.as_numpy_dtype) ) )
|
||||
ca_tuples_w.append (w)
|
||||
ca_tuples.append ( (w.shape.as_list(), w.dtype.as_numpy_dtype) )
|
||||
break
|
||||
else:
|
||||
ops.append (initializer)
|
||||
|
||||
nn.tf_sess.run (ops)
|
||||
nn.tf_batch_set_value(tuples)
|
||||
if len(ops) != 0:
|
||||
nn.tf_sess.run (ops)
|
||||
|
||||
if len(ca_tuples) != 0:
|
||||
nn.tf_batch_set_value( [*zip(ca_tuples_w, nn.initializers.ca.generate_batch (ca_tuples))] )
|
||||
|
||||
nn.Saveable = Saveable
|
||||
|
||||
class LayerBase():
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue