mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-19 13:09:56 -07:00
fix error in model saving
This commit is contained in:
parent
26c83f6e35
commit
56e70edc46
1 changed files with 5 additions and 3 deletions
|
@ -46,7 +46,9 @@ class Saveable():
|
||||||
raise Exception("name must be defined.")
|
raise Exception("name must be defined.")
|
||||||
|
|
||||||
name = self.name
|
name = self.name
|
||||||
for w, w_val in zip(weights, nn.tf_sess.run (weights)):
|
|
||||||
|
for w in weights:
|
||||||
|
w_val = nn.tf_sess.run (w).copy()
|
||||||
w_name_split = w.name.split('/', 1)
|
w_name_split = w.name.split('/', 1)
|
||||||
if name != w_name_split[0]:
|
if name != w_name_split[0]:
|
||||||
raise Exception("weight first name != Saveable.name")
|
raise Exception("weight first name != Saveable.name")
|
||||||
|
@ -97,10 +99,10 @@ class Saveable():
|
||||||
nn.batch_set_value(tuples)
|
nn.batch_set_value(tuples)
|
||||||
except:
|
except:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def init_weights(self):
|
def init_weights(self):
|
||||||
nn.init_weights(self.get_weights())
|
nn.init_weights(self.get_weights())
|
||||||
|
|
||||||
nn.Saveable = Saveable
|
nn.Saveable = Saveable
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue