"""
File: train_gender_classifier.py
Author: Octavio Arriaga
Email: arriaga.camargo@gmail.com
Github: https://github.com/oarriaga
Description: Train gender classification model
"""

from keras.callbacks import CSVLogger, ModelCheckpoint, EarlyStopping
from keras.callbacks import ReduceLROnPlateau
from utils.datasets import DataManager
from models.cnn import mini_XCEPTION
from utils.data_augmentation import ImageGenerator
from utils.datasets import split_imdb_data

# parameters
batch_size = 32
num_epochs = 1000
validation_split = .2
do_random_crop = False
patience = 100
num_classes = 2
dataset_name = 'imdb'
input_shape = (64, 64, 1)
# Derive the colour mode from the channel count. Previously ``grayscale``
# was only assigned for single-channel shapes, so any other input_shape
# crashed with a NameError when the ImageGenerator was built.
grayscale = input_shape[2] == 1
images_path = '../datasets/imdb_crop/'
log_file_path = '../trained_models/gender_models/gender_training.log'
trained_models_path = '../trained_models/gender_models/gender_mini_XCEPTION'

# data loader
data_loader = DataManager(dataset_name)
ground_truth_data = data_loader.get_data()
train_keys, val_keys = split_imdb_data(ground_truth_data, validation_split)
print('Number of training samples:', len(train_keys))
print('Number of validation samples:', len(val_keys))
# NOTE(review): the positional ``None`` is passed through to ImageGenerator
# unchanged; its meaning is defined in utils.data_augmentation — confirm there.
image_generator = ImageGenerator(ground_truth_data, batch_size,
                                 input_shape[:2],
                                 train_keys, val_keys, None,
                                 path_prefix=images_path,
                                 vertical_flip_probability=0,
                                 grayscale=grayscale,
                                 do_random_crop=do_random_crop)

# model parameters/compilation
model = mini_XCEPTION(input_shape, num_classes)
model.compile(optimizer='adam',
              loss='categorical_crossentropy',
              metrics=['accuracy'])
model.summary()

# model callbacks
early_stop = EarlyStopping('val_loss', patience=patience)
# Reduce the learning rate after half the early-stopping patience, so the
# optimiser gets a lower-LR phase before training is stopped.
reduce_lr = ReduceLROnPlateau('val_loss', factor=0.1,
                              patience=patience // 2, verbose=1)
csv_logger = CSVLogger(log_file_path, append=False)
# NOTE(review): the logged metric key for metrics=['accuracy'] is 'val_acc'
# in older Keras but 'val_accuracy' in Keras >= 2.3; if the installed version
# uses the latter, this filename template must be updated to match.
model_names = trained_models_path + '.{epoch:02d}-{val_acc:.2f}.hdf5'
model_checkpoint = ModelCheckpoint(model_names, monitor='val_loss',
                                   verbose=1, save_best_only=True,
                                   save_weights_only=False)
callbacks = [model_checkpoint, csv_logger, early_stop, reduce_lr]

# Number of whole batches per epoch. Clamp to at least one step so a split
# smaller than batch_size does not produce an invalid steps count of 0.
train_steps = max(1, len(train_keys) // batch_size)
val_steps = max(1, len(val_keys) // batch_size)

# training model
model.fit_generator(image_generator.flow(mode='train'),
                    steps_per_epoch=train_steps,
                    epochs=num_epochs, verbose=1,
                    callbacks=callbacks,
                    validation_data=image_generator.flow(mode='val'),
                    validation_steps=val_steps)