mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-05 20:42:11 -07:00
refactoring
This commit is contained in:
parent
9eb15065f9
commit
45abcff3d1
1 changed files with 7 additions and 3 deletions
|
@ -19,7 +19,7 @@ TernausNet: U-Net with VGG11 Encoder Pre-Trained on ImageNet for Image Segmentat
|
|||
|
||||
class TernausNet(object):
|
||||
VERSION = 1
|
||||
def __init__ (self, name, resolution, face_type_str, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
|
||||
def __init__ (self, name, resolution, face_type_str=None, load_weights=True, weights_file_root=None, training=False, place_model_on_cpu=False):
|
||||
nn.initialize(data_format="NHWC")
|
||||
tf = nn.tf
|
||||
|
||||
|
@ -125,12 +125,16 @@ class TernausNet(object):
|
|||
self.net = Ternaus(3, 64, name='Ternaus')
|
||||
self.net_weights = self.net.get_weights()
|
||||
|
||||
self.model_filename_list = [ [self.net, '%s_%d_%s.npy' % (name, resolution, face_type_str) ] ]
|
||||
model_name = f'{name}_{resolution}'
|
||||
if face_type_str is not None:
|
||||
model_name += f'_{face_type_str}'
|
||||
|
||||
self.model_filename_list = [ [self.net, f'{model_name}.npy'] ]
|
||||
|
||||
if training:
|
||||
self.opt = nn.TFRMSpropOptimizer(lr=0.0001, name='opt')
|
||||
self.opt.initialize_variables (self.net_weights, vars_on_cpu=place_model_on_cpu)
|
||||
self.model_filename_list += [ [self.opt, '%s_%d_%s_opt.npy' % (name, resolution, face_type_str) ] ]
|
||||
self.model_filename_list += [ [self.opt, f'{model_name}_opt.npy' ] ]
|
||||
else:
|
||||
_, pred = self.net([self.input_t])
|
||||
def net_run(input_np):
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue