mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 21:12:07 -07:00
fix for plaidML
This commit is contained in:
parent
3114ae9d7b
commit
c9da4cd89b
2 changed files with 6 additions and 4 deletions
|
@ -299,7 +299,7 @@ def get_plaidML_devices():
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
return plaidML_devices
|
return plaidML_devices
|
||||||
|
|
||||||
if not has_nvidia_device:
|
if not has_nvidia_device:
|
||||||
get_plaidML_devices()
|
get_plaidML_devices()
|
||||||
|
|
||||||
|
|
|
@ -161,7 +161,6 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
return nnlib.code_import_keras
|
return nnlib.code_import_keras
|
||||||
|
|
||||||
nnlib.backend = device_config.backend
|
nnlib.backend = device_config.backend
|
||||||
|
|
||||||
if "tensorflow" in nnlib.backend:
|
if "tensorflow" in nnlib.backend:
|
||||||
nnlib._import_tf(device_config)
|
nnlib._import_tf(device_config)
|
||||||
elif nnlib.backend == "plaidML":
|
elif nnlib.backend == "plaidML":
|
||||||
|
@ -174,6 +173,9 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
import keras as keras_
|
import keras as keras_
|
||||||
nnlib.keras = keras_
|
nnlib.keras = keras_
|
||||||
|
|
||||||
|
if 'KERAS_BACKEND' in os.environ:
|
||||||
|
os.environ.pop('KERAS_BACKEND')
|
||||||
|
|
||||||
if nnlib.backend == "plaidML":
|
if nnlib.backend == "plaidML":
|
||||||
import plaidml
|
import plaidml
|
||||||
import plaidml.tile
|
import plaidml.tile
|
||||||
|
@ -591,7 +593,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
nnlib.Adam = Adam
|
nnlib.Adam = Adam
|
||||||
|
|
||||||
def CAInitializerMP( conv_weights_list ):
|
def CAInitializerMP( conv_weights_list ):
|
||||||
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
#Convolution Aware Initialization https://arxiv.org/abs/1702.06295
|
||||||
result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run()
|
result = CAInitializerMPSubprocessor ( [ (i, K.int_shape(conv_weights)) for i, conv_weights in enumerate(conv_weights_list) ], K.floatx(), K.image_data_format() ).run()
|
||||||
for idx, weights in result:
|
for idx, weights in result:
|
||||||
K.set_value ( conv_weights_list[idx], weights )
|
K.set_value ( conv_weights_list[idx], weights )
|
||||||
|
@ -696,7 +698,7 @@ NLayerDiscriminator = nnlib.NLayerDiscriminator
|
||||||
x = ReflectionPadding2D( self.pad ) (x)
|
x = ReflectionPadding2D( self.pad ) (x)
|
||||||
return self.func(x)
|
return self.func(x)
|
||||||
nnlib.Conv2D = Conv2D
|
nnlib.Conv2D = Conv2D
|
||||||
|
|
||||||
class Conv2DTranspose():
|
class Conv2DTranspose():
|
||||||
def __init__ (self, *args, **kwargs):
|
def __init__ (self, *args, **kwargs):
|
||||||
self.reflect_pad = False
|
self.reflect_pad = False
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue