
Training a Keras model from batches of .npy files using a generator?

admin 2022-11-25

I am currently dealing with a big-data problem while training an image model with Keras. I have a directory containing batches of .npy files. Each batch contains 512 images, and each batch has a corresponding .npy label file, so the directory looks like: {image_file_1.npy, label_file_1.npy, ..., image_file_37.npy, label_file_37.npy}. Each image file has shape (512, 199, 199, 3) and each label file has shape (512, 1) (i.e. 1 or 0). Loading all the images into a single ndarray would take more than 35 GB. I have read through the Keras docs but still cannot find how to train with a custom generator. I have read about flow_from_directory and ImageDataGenerator(...).flow(), but they are not ideal in this case, or I do not know how to customize them. What I have so far:

```python
import numpy as np
import keras
from keras.models import Sequential
from keras.layers import Dense, Dropout, Flatten
from keras.layers import Conv2D, MaxPooling2D
from keras.optimizers import SGD
from keras.preprocessing.image import ImageDataGenerator

val_gen = ImageDataGenerator(rescale=1./255)
x_test = np.load("../data/val_file.npy")
y_test = np.load("../data/val_label.npy")
val_gen.fit(x_test)

model = Sequential()
...  # conv/pooling layers elided
model.add(Dense(512, activation='relu'))
model.add(Dense(1, activation='sigmoid'))

sgd = SGD(lr=0.01)
model.compile(loss='binary_crossentropy',  # single sigmoid unit -> binary loss
              optimizer=sgd,
              metrics=['acc'])

model.fit_generator(generate_batch_from_directory(),  # should give 1 image file and 1 label file
                    validation_data=val_gen.flow(x_test,
                                                 y_test,
                                                 batch_size=64),
                    validation_steps=32)
```

So here generate_batch_from_directory() should take image_file_i.npy and label_file_i.npy every time and optimise the weight until there is no batch left. Each image array in the .npy files has already been processed with augmentation, rotation and scaling. Each .npy file is properly mixed with data from class 1 and 0 (50/50).
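A plain-Python generator along these lines could look like the sketch below. The function name `generate_batch_from_directory` comes from the question; the `data_dir` and `n_batches` parameters are illustrative additions, and the infinite loop is needed because `fit_generator` expects a generator that never exhausts (the number of batches per epoch is set via `steps_per_epoch`).

```python
import os
import numpy as np

def generate_batch_from_directory(data_dir, n_batches):
    """Yield one (images, labels) pair per .npy file pair, forever.

    fit_generator expects an endless generator; steps_per_epoch
    (here it would be n_batches, e.g. 37) defines one epoch.
    """
    while True:
        for i in range(1, n_batches + 1):
            X = np.load(os.path.join(data_dir, "image_file_{}.npy".format(i)))
            y = np.load(os.path.join(data_dir, "label_file_{}.npy".format(i)))
            # same rescaling the validation ImageDataGenerator applies
            yield X / 255.0, y
```

With this, `fit_generator` would be called with `steps_per_epoch=37` so Keras knows when an epoch ends.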

If I combine all the batches into one big array, such as:

```python
X_train = np.concatenate([image_file_1, ..., image_file_37])
y_train = np.concatenate([label_file_1, ..., label_file_37])
```

it does not fit in memory. Otherwise I could have used .flow() to generate image batches to train the model.
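A quick back-of-the-envelope check with NumPy shows why concatenating everything is infeasible (37 files and float32 storage are assumed here; float64 would double the figure, which is closer to the 35 GB mentioned above once copies during concatenation are counted):

```python
import numpy as np

files = 37
shape_per_file = (512, 199, 199, 3)  # images per .npy file
bytes_per_file = np.prod(shape_per_file) * np.dtype(np.float32).itemsize
total_gb = files * bytes_per_file / 1024**3
print(round(float(total_gb), 1))  # ~8.4 GB at float32, before any copies
```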

Thanks for any advice.



1> DataPsycho:

Finally, I was able to solve the problem, but I had to go through the source code and documentation of keras.utils.Sequence to build my own generator class. That documentation helps a lot in understanding how generators work in Keras. You can read more details in my Kaggle notebook:

```python
import os
import numpy as np
import keras

all_files_loc = "datapsycho/imglake/population/train/image_files/"
all_files = os.listdir(all_files_loc)

image_label_map = {
    "image_file_{}.npy".format(i+1): "label_file_{}.npy".format(i+1)
    for i in range(int(len(all_files) / 2))}
partition = [item for item in all_files if "image_file" in item]


class DataGenerator(keras.utils.Sequence):

    def __init__(self, file_list):
        """Constructor can be expanded
        with batch size, dimension etc.
        """
        self.file_list = file_list
        self.on_epoch_end()

    def __len__(self):
        'Take all batches in each iteration'
        return int(len(self.file_list))

    def __getitem__(self, index):
        'Get next batch'
        # Generate indexes of the batch
        indexes = self.indexes[index:(index + 1)]
        # single file
        file_list_temp = [self.file_list[k] for k in indexes]
        # Set of X_train and y_train
        X, y = self.__data_generation(file_list_temp)
        return X, y

    def on_epoch_end(self):
        'Updates indexes after each epoch'
        self.indexes = np.arange(len(self.file_list))

    def __data_generation(self, file_list_temp):
        'Generates data containing batch_size samples'
        data_loc = "datapsycho/imglake/population/train/image_files/"
        # Generate data
        for ID in file_list_temp:
            x_file_path = os.path.join(data_loc, ID)
            y_file_path = os.path.join(data_loc, image_label_map.get(ID))
            # Store sample
            X = np.load(x_file_path)
            # Store class
            y = np.load(y_file_path)
        return X, y


# ====================
# train set
# ====================
training_generator = DataGenerator(partition)
validation_generator = ValDataGenerator(val_partition)  # works the same as the training generator

hst = model.fit_generator(generator=training_generator,
                          epochs=200,
                          validation_data=validation_generator,
                          use_multiprocessing=True,
                          max_queue_size=32)
```
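The keras.utils.Sequence contract only requires `__len__` and `__getitem__`, so the batching logic above can be exercised standalone. Below is a dependency-free sketch of the same idea (the class name `NpyFileSequence` and its parameters are mine, not from the answer); to use it with Keras one would additionally inherit from `keras.utils.Sequence`:

```python
import os
import numpy as np

class NpyFileSequence:
    """Mirrors the Sequence contract used above: each index maps to
    one image_file_i.npy / label_file_i.npy pair, i.e. one batch.
    """
    def __init__(self, data_dir, file_list, label_map):
        self.data_dir = data_dir
        self.file_list = list(file_list)
        self.label_map = label_map

    def __len__(self):
        # one batch per image file
        return len(self.file_list)

    def __getitem__(self, index):
        image_name = self.file_list[index]
        X = np.load(os.path.join(self.data_dir, image_name))
        y = np.load(os.path.join(self.data_dir, self.label_map[image_name]))
        return X, y
```

Because each .npy file is already one pre-shuffled, pre-augmented batch, `__getitem__` just loads a file pair; Keras calls it with indices 0..len-1 each epoch.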
