basic hack for osx

This commit is contained in:
jkennedyvz 2022-05-22 18:02:03 -07:00
parent 0d19d8ec8e
commit 8aadbfef4a
3 changed files with 6 additions and 5 deletions

View file

@ -143,7 +143,7 @@ def gen_warp_params (w, flip=False, rotation_range=[-10,10], scale_range=[-0.5,
################ ################
#random transform #random transform
random_transform_mat = cv2.getRotationMatrix2D((w // 2, w // 2), rotation, scale) random_transform_mat = cv2.getRotationMatrix2D((int(w // 2), int(w // 2)), rotation, scale)
random_transform_mat[:, 2] += (tx*w, ty*w) random_transform_mat[:, 2] += (tx*w, ty*w)
params = dict() params = dict()
@ -178,4 +178,4 @@ def warp_by_params (params, img, can_warp, can_transform, can_flip, border_repli
img = img[...,None] img = img[...,None]
if can_flip and params['flip']: if can_flip and params['flip']:
img = img[:,::-1,...] img = img[:,::-1,...]
return img return img

View file

@ -46,7 +46,7 @@ class Devices(object):
idx_mem = 0 idx_mem = 0
for device in self.devices: for device in self.devices:
mem = device.total_mem mem = device.total_mem
if mem > idx_mem: if mem >= idx_mem:
result = device result = device
idx_mem = mem idx_mem = mem
return result return result
@ -56,7 +56,7 @@ class Devices(object):
idx_mem = sys.maxsize idx_mem = sys.maxsize
for device in self.devices: for device in self.devices:
mem = device.total_mem mem = device.total_mem
if mem < idx_mem: if mem <= idx_mem:
result = device result = device
idx_mem = mem idx_mem = mem
return result return result
@ -270,4 +270,4 @@ class Devices(object):
os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem']) os.environ[f'NN_DEVICE_{i}_TOTAL_MEM'] = str(device['total_mem'])
os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem']) os.environ[f'NN_DEVICE_{i}_FREE_MEM'] = str(device['free_mem'])
os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc']) os.environ[f'NN_DEVICE_{i}_CC'] = str(device['cc'])
""" """

View file

@ -112,6 +112,7 @@ class nn():
config.gpu_options.force_gpu_compatible = True config.gpu_options.force_gpu_compatible = True
config.gpu_options.allow_growth = True config.gpu_options.allow_growth = True
config.allow_soft_placement = True
nn.tf_sess_config = config nn.tf_sess_config = config
if nn.tf_sess is None: if nn.tf_sess is None: