fix ModelBase, nnlib

This commit is contained in:
iperov 2019-03-13 19:50:16 +04:00
parent 90a7d4b1e7
commit a9026ccb67
2 changed files with 27 additions and 67 deletions

View file

@ -23,8 +23,9 @@ class ModelBase(object):
def __init__(self, model_path, training_data_src_path=None, training_data_dst_path=None, debug = False, device_args = None):
device_args['force_gpu_idx'] = device_args.get('force_gpu_idx',-1)
device_args['cpu_only'] = device_args.get('cpu_only',False)
if device_args['force_gpu_idx'] == -1:
if device_args['force_gpu_idx'] == -1 and not device_args['cpu_only']:
idxs_names_list = nnlib.device.getValidDevicesIdxsWithNamesList()
if len(idxs_names_list) > 1:
io.log_info ("You have multi GPUs in a system: ")
@ -34,10 +35,7 @@ class ModelBase(object):
device_args['force_gpu_idx'] = io.input_int("Which GPU idx to choose? ( skip: best GPU ) : ", -1, [ x[0] for x in idxs_names_list] )
self.device_args = device_args
nnlib.import_all ( nnlib.DeviceConfig(allow_growth=False, **self.device_args) )
self.device_config = nnlib.active_DeviceConfig
self.keras = nnlib.keras
self.K = nnlib.keras.backend
self.device_config = nnlib.DeviceConfig(allow_growth=False, **self.device_args)
io.log_info ("Loading model...")
@ -127,8 +125,12 @@ class ModelBase(object):
if self.src_scale_mod == 0:
self.options.pop('src_scale_mod')
self.onInitializeOptions(self.iter == 0, ask_override)
nnlib.import_all(self.device_config)
self.keras = nnlib.keras
self.K = nnlib.keras.backend
self.onInitialize()
self.options['batch_size'] = self.batch_size