mirror of
https://github.com/iperov/DeepFaceLab.git
synced 2025-07-06 04:52:13 -07:00
fix ModelBase, nnlib
This commit is contained in:
parent
90a7d4b1e7
commit
a9026ccb67
2 changed files with 27 additions and 67 deletions
|
@ -23,8 +23,9 @@ class ModelBase(object):
|
|||
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None):
|
||||
|
||||
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
|
||||
device_args['cpu_only'] = device_args.get('cpu_only',False)
|
||||
|
||||
if device_args['force_gpu_idx'] == -1:
|
||||
if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
|
||||
idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
|
||||
if len(idxs_names_list) > 1:
|
||||
io.log_info ("You have multi GPUs in a system: ")
|
||||
|
@ -34,10 +35,7 @@ class ModelBase(object):
|
|||
device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
|
||||
self.device_args = device_args
|
||||
|
||||
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **self.device_args) )
|
||||
self.device_config = nnlib.active_DeviceConfig
|
||||
self.keras = nnlib.keras
|
||||
self.K = nnlib.keras.backend
|
||||
self.device_config = nnlib.DeviceConfig(allow_growth=False, **self.device_args)
|
||||
|
||||
io.log_info ("Loading model...")
|
||||
|
||||
|
@ -127,8 +125,12 @@ class ModelBase(object):
|
|||
if self.src_scale_mod == 0:
|
||||
self.options.pop('src_scale_mod')
|
||||
|
||||
|
||||
self.onInitializeOptions(self.iter == 0, ask_override)
|
||||
|
||||
nnlib.import_all(self.device_config)
|
||||
self.keras = nnlib.keras
|
||||
self.K = nnlib.keras.backend
|
||||
|
||||
self.onInitialize()
|
||||
|
||||
self.options['batch_size'] = self.batch_size
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue