Move ping pong logic into generator class

This commit is contained in:
jh 2019-08-23 17:32:09 -07:00
commit 1ecc6f1b62
5 changed files with 122 additions and 78 deletions

View file

@ -14,6 +14,7 @@ import imagelib
from interact import interact as io
from nnlib import nnlib
from samplelib import SampleGeneratorBase
from samplelib.SampleGeneratorPingPong import PingPongOptions, Paddle
from utils import Path_utils, std_utils
from utils.cv2_utils import *
@ -67,7 +68,7 @@ class ModelBase(object):
self.debug = debug
self.is_training_mode = (training_data_src_path is not None and training_data_dst_path is not None)
self.paddle = 'pong'
self.paddle = Paddle.PONG
self.iter = 0
self.options = {}
@ -210,15 +211,18 @@ class ModelBase(object):
self.onInitializeOptions(self.iter == 0, ask_override)
self.ping_pong_options = PingPongOptions(enabled=self.options['ping_pong'],
iterations=self.ping_pong_iter,
model_iter=self.iter,
paddle=self.paddle,
batch_cap=self.batch_size)
nnlib.import_all(self.device_config)
self.keras = nnlib.keras
self.K = nnlib.keras.backend
self.onInitialize()
self.options['batch_size'] = self.batch_size
self.paddle = self.options.get('paddle', 'ping')
if self.debug or self.batch_size == 0:
self.batch_size = 1
@ -413,7 +417,7 @@ class ModelBase(object):
def save(self):
self.options['batch_size'] = self.batch_size
self.options['paddle'] = self.paddle
self.options['paddle'] = self.ping_pong_options.paddle
summary_path = self.get_strpath_storage_for_file('summary.txt')
Path( summary_path ).write_text(self.model_summary_text)
self.onSave()
@ -542,12 +546,11 @@ class ModelBase(object):
return [ generator.generate_next() for generator in self.generator_list]
def train_one_iter(self):
if self.iter == 1 and self.options.get('ping_pong', False):
self.set_batch_size(1)
self.paddle = 'ping'
elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
self.set_batch_size(self.batch_cap)
# if self.iter == 1 and self.options.get('ping_pong', False):
# self.set_batch_size(1)
# self.paddle = 'ping'
# elif not self.options.get('ping_pong', False) and self.batch_cap != self.batch_size:
# self.set_batch_size(self.batch_cap)
sample = self.generate_next_sample()
iter_time = time.time()
losses = self.onTrainOneIter(sample, self.generator_list)
@ -574,21 +577,6 @@ class ModelBase(object):
img = (np.concatenate ( [preview_lh, preview], axis=0 ) * 255).astype(np.uint8)
cv2_imwrite (filepath, img )
if self.iter % self.ping_pong_iter == 0 and self.iter != 0 and self.options.get('ping_pong', False):
if self.batch_size == self.batch_cap:
self.paddle = 'pong'
if self.batch_size > self.batch_cap:
self.set_batch_size(self.batch_cap)
self.paddle = 'pong'
if self.batch_size == 1:
self.paddle = 'ping'
if self.paddle == 'ping':
self.save()
self.set_batch_size(self.batch_size + 1)
else:
self.save()
self.set_batch_size(self.batch_size - 1)
self.iter += 1
return self.iter, iter_time, self.batch_size

View file

@ -487,7 +487,8 @@ class SAEModel(ModelBase):
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)]
range(ms_count)],
ping_pong=self.ping_pong_options,
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
@ -500,7 +501,8 @@ class SAEModel(ModelBase):
range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)])
range(ms_count)],
ping_pong=self.ping_pong_options,)
])
# override
@ -557,7 +559,8 @@ class SAEModel(ModelBase):
'apply_ct': apply_random_ct} for i in range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)]
range(ms_count)],
ping_pong=self.ping_pong_options,
),
SampleGeneratorFace(training_data_dst_path, debug=self.is_debug(), batch_size=self.batch_size,
@ -570,7 +573,8 @@ class SAEModel(ModelBase):
range(ms_count)] + \
[{'types': (t.IMG_TRANSFORMED, face_type, t.MODE_M),
'resolution': resolution // (2 ** i)} for i in
range(ms_count)])
range(ms_count)],
ping_pong=self.ping_pong_options,)
])
# override

View file

@ -5,8 +5,9 @@ import cv2
import numpy as np
from facelib import LandmarksProcessor
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor,
from samplelib import (SampleGeneratorBase, SampleLoader, SampleProcessor, SampleGeneratorPingPong,
SampleType)
from samplelib.SampleGeneratorPingPong import PingPongOptions
from utils import iter_utils
@ -19,12 +20,12 @@ output_sample_types = [
'''
class SampleGeneratorFace(SampleGeneratorBase):
class SampleGeneratorFace(SampleGeneratorPingPong):
def __init__(self, samples_path, debug, batch_size, sort_by_yaw=False, sort_by_yaw_target_samples_path=None,
random_ct_samples_path=None, sample_process_options=SampleProcessor.Options(),
output_sample_types=[], add_sample_idx=False, generators_count=2, generators_random_seed=None,
**kwargs):
super().__init__(samples_path, debug, batch_size)
ping_pong=PingPongOptions(), **kwargs):
super().__init__(samples_path, debug, batch_size, ping_pong)
self.sample_process_options = sample_process_options
self.output_sample_types = output_sample_types
self.add_sample_idx = add_sample_idx
@ -64,6 +65,7 @@ class SampleGeneratorFace(SampleGeneratorBase):
def __next__(self):
self.generator_counter += 1
generator = self.generators[self.generator_counter % len(self.generators)]
super().__next__()
return next(generator)
def batch_func(self, param):

View file

@ -0,0 +1,50 @@
from enum import Enum
from samplelib import SampleGeneratorBase
class Paddle(Enum):
PING = 'ping' # Ascending
PONG = 'pong' # Descending
class PingPongOptions:
def __init__(self, enabled=False, iterations=1000, model_iter=1, paddle=Paddle.PING, batch_cap=1):
self.enabled = enabled
self.iterations = iterations
self.model_iter = model_iter
self.paddle = paddle
self.batch_cap = batch_cap
class SampleGeneratorPingPong(SampleGeneratorBase):
def __init__(self, *args, batch_size, ping_pong=PingPongOptions()):
self.ping_pong = ping_pong
super().__init__(*args, batch_size)
def __next__(self):
if self.ping_pong.enabled and self.ping_pong.model_iter % self.ping_pong.iterations == 0 \
and self.ping_pong.model_iter != 0:
# If batch size is greater then batch cap, set it to batch cap
if self.batch_size > self.ping_pong.batch_cap:
self.batch_size = self.ping_pong.batch_cap
# If we are at the batch cap, switch to PONG (descend)
if self.batch_size == self.ping_pong.batch_cap:
self.paddle = Paddle.PONG
# Else if we are at 1, switch to PING (ascend)
elif self.batch_size == 1:
self.paddle = Paddle.PING
# If PING (ascending) increase the batch size
if self.paddle is Paddle.PING:
self.batch_size += 1
# Else decrease the batch size
else:
self.batch_size -= 1
self.ping_pong.model_iter += 1
super().__next__()