mirror of https://github.com/iperov/DeepFaceLab.git (synced 2025-07-06 13:02:15 -07:00)
commit  c485e1718a
parent  76ca79216e

    fixes

4 changed files with 10 additions and 22 deletions
@@ -182,23 +182,17 @@ class nn():
             nn.conv2d_spatial_axes = [2,3]
 
     @staticmethod
-    def get4Dshape ( w, h, c, data_format=None ):
+    def get4Dshape ( w, h, c ):
         """
         returns 4D shape based on current data_format
         """
-        if data_format is None:
-            data_format = nn.data_format
-
-        if data_format == "NHWC":
+        if nn.data_format == "NHWC":
             return (None,h,w,c)
         else:
             return (None,c,h,w)
 
     @staticmethod
-    def to_data_format( x, to_data_format, from_data_format=None):
-        if from_data_format is None:
-            from_data_format = nn.data_format
-
+    def to_data_format( x, to_data_format, from_data_format):
         if to_data_format == from_data_format:
             return x
 
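Note: with this change, get4Dshape no longer accepts a data_format argument and always derives the layout from nn.data_format, while to_data_format now requires the source layout to be passed explicitly. A minimal standalone sketch of the new get4Dshape logic (the class below is a stand-in for illustration, not DeepFaceLab's actual nn class):

    class nn:
        data_format = "NCHW"   # set once globally, as the leras module does

        @staticmethod
        def get4Dshape(w, h, c):
            # the layout now always comes from nn.data_format
            if nn.data_format == "NHWC":
                return (None, h, w, c)
            else:
                return (None, c, h, w)

    print(nn.get4Dshape(64, 64, 3))   # (None, 3, 64, 64) under NCHW

Call sites of to_data_format that previously relied on the from_data_format=None default must now pass the current layout themselves, e.g. nn.to_data_format(x, "NHWC", self.model_data_format) as in the preview code below.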
@@ -35,7 +35,7 @@ def initialize_tensor_ops(nn):
         gv = [*zip(grads,vars)]
         for g,v in gv:
             if g is None:
-                raise Exception("No gradient for variable {v.name}")
+                raise Exception(f"No gradient for variable {v.name}")
         return gv
     nn.tf_gradients = tf_gradients
 
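Note: without the f prefix the message was a plain string, so the literal text "{v.name}" was raised instead of the variable's name; the added f makes it an f-string that interpolates the value. A trivial illustration (the Var class here is a made-up stand-in for a TF variable):

    class Var:
        name = "encoder/dense/kernel:0"   # hypothetical variable name

    v = Var()
    print("No gradient for variable {v.name}")    # prints the braces literally
    print(f"No gradient for variable {v.name}")   # prints the interpolated name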
@@ -413,18 +413,15 @@ class QModel(ModelBase):
 
     #override
     def onGetPreview(self, samples):
-        n_samples = min(4, self.get_batch_size() )
-
         ( (warped_src, target_src, target_srcm),
-          (warped_dst, target_dst, target_dstm) ) = \
-                        [ [sample[0:n_samples] for sample in sample_list ]
-                          for sample_list in samples ]
+          (warped_dst, target_dst, target_dstm) ) = samples
 
         S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
         DDM, SDM, = [ np.repeat (x, (3,), -1) for x in [DDM, SDM] ]
 
         target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
 
+        n_samples = min(4, self.get_batch_size() )
         result = []
         st = []
         for i in range(n_samples):
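Note: onGetPreview now unpacks the generator output as-is and defers the preview count to just before the drawing loop, instead of pre-slicing every array in a nested comprehension. A rough sketch of the pattern (array names and sizes are illustrative only):

    import numpy as np

    batch = 8
    samples = ( [np.zeros((batch, 64, 64, 3), np.float32)] * 3,   # src triple
                [np.zeros((batch, 64, 64, 3), np.float32)] * 3 )  # dst triple

    # new style: unpack directly, clamp the preview count afterwards
    ( (warped_src, target_src, target_srcm),
      (warped_dst, target_dst, target_dstm) ) = samples

    n_samples = min(4, batch)
    previews = [target_src[i] for i in range(n_samples)]  # index only where needed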
@@ -450,9 +450,10 @@ class SAEHDModel(ModelBase):
 
             for gpu_id in range(gpu_count):
                 with tf.device( f'/GPU:{gpu_id}' if len(devices) != 0 else f'/CPU:0' ):
-                    batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                     with tf.device(f'/CPU:0'):
                         # slice on CPU, otherwise all batch data will be transfered to GPU first
+                        batch_slice = slice( gpu_id*bs_per_gpu, (gpu_id+1)*bs_per_gpu )
                         gpu_warped_src = self.warped_src [batch_slice,:,:,:]
                         gpu_warped_dst = self.warped_dst [batch_slice,:,:,:]
                         gpu_target_src = self.target_src [batch_slice,:,:,:]
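Note: moving batch_slice under the nested /CPU:0 device scope keeps the slicing op on the CPU, so each GPU receives only its own chunk of the feed rather than the whole batch. A minimal sketch of the placement pattern using the TF1-style API via tf.compat.v1 (placeholder names and sizes are illustrative, not DeepFaceLab's):

    import tensorflow.compat.v1 as tf
    tf.disable_v2_behavior()

    gpu_count, bs_per_gpu = 2, 4
    warped_src = tf.placeholder(tf.float32, (gpu_count * bs_per_gpu, 64, 64, 3))

    per_gpu_inputs = []
    for gpu_id in range(gpu_count):
        with tf.device(f'/GPU:{gpu_id}'):
            with tf.device('/CPU:0'):
                # the strided-slice op is pinned to the CPU, so only this
                # chunk is transferred to the GPU ops that consume it
                batch_slice = slice(gpu_id * bs_per_gpu, (gpu_id + 1) * bs_per_gpu)
                per_gpu_inputs.append(warped_src[batch_slice, :, :, :])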
@@ -646,7 +647,6 @@ class SAEHDModel(ModelBase):
                 model.init_weights()
 
         # initializing sample generators
-
         if self.is_training:
             t = SampleProcessor.Types
             if self.options['face_type'] == 'h':
@@ -710,12 +710,8 @@ class SAEHDModel(ModelBase):
 
     #override
    def onGetPreview(self, samples):
-        n_samples = min(4, self.get_batch_size() )
-
         ( (warped_src, target_src, target_srcm),
-          (warped_dst, target_dst, target_dstm) ) = \
-                        [ [sample[0:n_samples] for sample in sample_list ]
-                          for sample_list in samples ]
+          (warped_dst, target_dst, target_dstm) ) = samples
 
         if self.options['learn_mask']:
             S, D, SS, DD, DDM, SD, SDM = [ np.clip( nn.to_data_format(x,"NHWC", self.model_data_format), 0.0, 1.0) for x in ([target_src,target_dst] + self.AE_view (target_src, target_dst) ) ]
@@ -725,6 +721,7 @@ class SAEHDModel(ModelBase):
 
         target_srcm, target_dstm = [ nn.to_data_format(x,"NHWC", self.model_data_format) for x in ([target_srcm, target_dstm] )]
 
+        n_samples = min(4, self.get_batch_size() )
         result = []
         st = []
         for i in range(n_samples):