mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-08-14 10:46:59 -07:00
fixes
This commit is contained in:
parent
76ca79216e
commit
c485e1718a
4 changed files with 10 additions and 22 deletions
|
@ -182,23 +182,17 @@ class nn():
|
|||
nn.conv2d_spatial_axes = [2,3]
|
||||
|
||||
@staticmethod
|
||||
def get4Dshape ( w, h, c, data_format=None ):
|
||||
def get4Dshape ( w, h, c ):
|
||||
"""
|
||||
returns 4D shape based on current data_format
|
||||
"""
|
||||
if data_format is None:
|
||||
data_format = nn.data_format
|
||||
|
||||
if data_format == "NHWC":
|
||||
if nn.data_format == "NHWC":
|
||||
return (None,h,w,c)
|
||||
else:
|
||||
return (None,c,h,w)
|
||||
|
||||
@staticmethod
|
||||
def to_data_format( x, to_data_format, from_data_format=None):
|
||||
if from_data_format is None:
|
||||
from_data_format = nn.data_format
|
||||
|
||||
def to_data_format( x, to_data_format, from_data_format):
|
||||
if to_data_format == from_data_format:
|
||||
return x
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ def initialize_tensor_ops(nn):
|
|||
gv = [*zip(grads,vars)]
|
||||
for g,v in gv:
|
||||
if g is None:
|
||||
raise Exception("No gradient for variable {v.name}")
|
||||
raise Exception(f"No gradient for variable {v.name}")
|
||||
return gv
|
||||
nn.tf_gradients = tf_gradients
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue