fix for plaidML

This commit is contained in:
iperov 2019-05-24 09:30:39 +04:00
parent 3114ae9d7b
commit c9da4cd89b
2 changed files with 6 additions and 4 deletions

View file

@ -161,7 +161,6 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
return nnlib.code_import_keras
nnlib.backend = device_config.backend
if "tensorflow" in nnlib.backend:
nnlib._import_tf(device_config)
elif nnlib.backend == "plaidML":
@ -174,6 +173,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
import keras as keras_
nnlib.keras = keras_
if 'KERAS_BACKEND' in os.environ:
os.environ.pop('KERAS_BACKEND')
if nnlib.backend == "plaidML":
import plaidml
import plaidml.tile
@ -591,7 +593,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
nnlib.Adam = Adam
def CAInitializerMP( conv_weights_list ):
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run()
for idx, weights in result:
K.set_value ( conv_weights_list[idx], weights )
@ -696,7 +698,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
x = ReflectionPadding2D( self.pad ) (x)
return self.func(x)
nnlib.Conv2D = Conv2D
class Conv2DTranspose():
def __init__ (self, *args, **kwargs):
self.reflect_pad = False