Mirror of https://github.com/iperov/DeepFaceLab.git, synced 2025-08-19 21:13:20 -07:00
Global refactoring and fixes.
Removed support for extracted (aligned) PNG faces; use old builds to convert from PNG to JPG. The fanseg model file in facelib/ has been renamed.
This commit is contained in:
parent 921b464d5b
commit 61472cdaf7
82 changed files with 3838 additions and 3812 deletions
core/leras/layers/Saveable.py (new file, 103 lines)
@@ -0,0 +1,103 @@
import pickle
from pathlib import Path
from core import pathex
import numpy as np

from core.leras import nn

tf = nn.tf

class Saveable():
    def __init__(self, name=None):
        self.name = name

    #override
    def get_weights(self):
        #return tf tensors that should be initialized/loaded/saved
        return []

    #override
    def get_weights_np(self):
        weights = self.get_weights()
        if len(weights) == 0:
            return []
        return nn.tf_sess.run(weights)

    def set_weights(self, new_weights):
        weights = self.get_weights()
        if len(weights) != len(new_weights):
            raise ValueError('len of lists mismatch')

        tuples = []
        for w, new_w in zip(weights, new_weights):
            # reshape the incoming array if its shape differs from the variable's shape
            if tuple(w.shape.as_list()) != tuple(new_w.shape):
                new_w = new_w.reshape(w.shape.as_list())

            tuples.append((w, new_w))

        nn.batch_set_value(tuples)

    def save_weights(self, filename, force_dtype=None):
        d = {}
        weights = self.get_weights()

        if self.name is None:
            raise Exception("name must be defined.")

        name = self.name
        for w, w_val in zip(weights, nn.tf_sess.run(weights)):
            w_name_split = w.name.split('/', 1)
            if name != w_name_split[0]:
                raise Exception("weight first name != Saveable.name")

            if force_dtype is not None:
                w_val = w_val.astype(force_dtype)

            d[w_name_split[1]] = w_val

        d_dumped = pickle.dumps(d, 4)
        pathex.write_bytes_safe(Path(filename), d_dumped)

    def load_weights(self, filename):
        """
        returns True if file exists
        """
        filepath = Path(filename)
        if filepath.exists():
            d_dumped = filepath.read_bytes()
            d = pickle.loads(d_dumped)
        else:
            return False

        weights = self.get_weights()

        if self.name is None:
            raise Exception("name must be defined.")

        tuples = []
        for w in weights:
            w_name_split = w.name.split('/')
            if self.name != w_name_split[0]:
                raise Exception("weight first name != Saveable.name")

            sub_w_name = "/".join(w_name_split[1:])

            w_val = d.get(sub_w_name, None)

            if w_val is None:
                #io.log_err(f"Weight {w.name} was not loaded from file {filename}")
                # weight missing from the file: fall back to the variable's initializer
                tuples.append((w, w.initializer))
            else:
                w_val = np.reshape(w_val, w.shape.as_list())
                tuples.append((w, w_val))

        nn.batch_set_value(tuples)

        return True

    def init_weights(self):
        nn.init_weights(self.get_weights())

nn.Saveable = Saveable
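
For context, a minimal sketch of how a Saveable subclass might be used, based only on the API visible in this file. It assumes the leras session (nn.tf_sess) has already been created and that nn.tf exposes the TF1-style tf.variable_scope / tf.get_variable API; the class name Bias, the variable shape and the filename 'Bias.npy' are hypothetical. The one naming requirement the class imposes is that every variable returned by get_weights() is named '<Saveable.name>/...'.

    from core.leras import nn

    tf = nn.tf

    class Bias(nn.Saveable):
        # Hypothetical layer holding a single bias vector under the scope "Bias",
        # so its variable is named "Bias/b:0" as save_weights()/load_weights() expect.
        def __init__(self):
            super().__init__(name='Bias')
            with tf.variable_scope(self.name):
                self.b = tf.get_variable('b', shape=(8,), initializer=tf.zeros_initializer())

        #override
        def get_weights(self):
            return [self.b]

    layer = Bias()
    layer.init_weights()                    # runs the variables' initializers via nn.init_weights
    layer.save_weights('Bias.npy')          # writes a pickled dict like {'b:0': ndarray}
    print(layer.load_weights('Bias.npy'))   # True, because the file now exists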
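The file produced by save_weights is just a pickled dict (protocol 4) mapping each variable's sub-name, i.e. everything after the leading '<name>/', to its NumPy array, so it can be inspected without TensorFlow. A small sketch, again using the hypothetical 'Bias.npy':

    import pickle
    from pathlib import Path

    # 'Bias.npy' is a hypothetical file written by Saveable.save_weights()
    d = pickle.loads(Path('Bias.npy').read_bytes())
    for sub_name, arr in d.items():
        print(sub_name, arr.shape, arr.dtype)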