increased H64 decoder size for 4GB+

This commit is contained in:
iperov 2018-06-07 10:06:24 +04:00
parent 4aaac5e42e
commit adc1e701de

View file

@ -18,7 +18,7 @@ class Model(ModelBase):
tf = self.tf tf = self.tf
keras = self.keras keras = self.keras
K = keras.backend K = keras.backend
self.set_vram_batch_requirements( {1.5:2,2:2,3:4,4:8,5:16,6:32,7:32,8:32,9:48} ) self.set_vram_batch_requirements( {1.5:2,2:2,3:8,4:16,5:24,6:32,7:40,8:48} )
bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.created_vram_gb) bgr_shape, mask_shape, self.encoder, self.decoder_src, self.decoder_dst = self.Build(self.created_vram_gb)
if not self.is_first_run(): if not self.is_first_run():
@ -157,14 +157,19 @@ class Model(ModelBase):
def Decoder(): def Decoder():
if created_vram_gb >= 4: if created_vram_gb >= 4:
input_ = self.keras.layers.Input(shape=(8, 8, 512)) input_ = self.keras.layers.Input(shape=(8, 8, 512))
x = input_
x = upscale(self.keras, x, 512)
x = upscale(self.keras, x, 256)
x = upscale(self.keras, x, 128)
else: else:
input_ = self.keras.layers.Input(shape=(8, 8, 256)) input_ = self.keras.layers.Input(shape=(8, 8, 256))
x = input_ x = input_
x = upscale(self.keras, x, 256) x = upscale(self.keras, x, 256)
x = upscale(self.keras, x, 128) x = upscale(self.keras, x, 128)
x = upscale(self.keras, x, 64) x = upscale(self.keras, x, 64)
y = input_ #mask decoder y = input_ #mask decoder
y = upscale(self.keras, y, 256) y = upscale(self.keras, y, 256)
y = upscale(self.keras, y, 128) y = upscale(self.keras, y, 128)