refactoring

This commit is contained in:
Colombo 2020-03-09 11:05:26 +04:00
parent 9eb15065f9
commit 45abcff3d1

View file

@ -19,7 +19,7 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat
class TernausNet(object):
VERSION = 1
def __init__ (self, name, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
nn.initialize(data_format="NHWC")
tf = nn.tf
@ -125,12 +125,16 @@ class TernausNet(object):
self.net = Ternaus(3, 64, name='Ternaus')
self.net_weights = self.net.get_weights()
self.model_filename_list = [ [self.net, '%s_%d_%s.npy' % (name, resolution, face_type_str) ] ]
model_name = f'{name}_{resolution}'
if face_type_str is not None:
model_name += f'_{face_type_str}'
self.model_filename_list = [ [self.net, f'{model_name}.npy'] ]
if training:
self.opt = nn.TFRMSpropOptimizer(lr=0.0001, name='opt')
self.opt.initialize_variables (self.net_weights, vars_on_cpu=place_model_on_cpu)
self.model_filename_list += [ [self.opt, '%s_%d_%s_opt.npy' % (name, resolution, face_type_str) ] ]
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
else:
_, pred = self.net([self.input_t])
def net_run(input_np):