This commit is contained in:
Colombo 2020-01-22 10:29:17 +04:00
parent 9797a70fd3
commit beed145d29
2 changed files with 17 additions and 43 deletions

View file

@ -32,8 +32,11 @@ def initialize_tensor_ops(nn):
def tf_gradients ( loss, vars ):
grads = gradients.gradients(loss, vars, colocate_gradients_with_ops=True )
#todo none gradient for var
return [*zip(grads,vars)]
gv = [*zip(grads,vars)]
for g,v in gv:
if g is None:
raise Exception("No gradient for variable {v.name}")
return gv
nn.tf_gradients = tf_gradients
def tf_average_gv_list(grad_var_list, tf_device_string=None):