def load_2d_multi_input_dataset_generator_v2(batch_size):
    """Yield training batches for a two-input model, cycling forever.

    Loads the shuffled arrays via ``load_2d_multi_input_dataset_v2()`` and
    walks the *training* split in consecutive slices of ``batch_size``
    samples along axis 0. Each yielded item is an ``(inputs, targets)`` pair
    of dicts::

        ({'b-value_0': b0_batch, 'b-value_1000': b1000_batch},
         {'output': mask_batch})

    The dict keys presumably match the model's input/output layer names;
    verify against the model definition. The validation arrays returned by
    the loader are not used here.

    Args:
        batch_size: Number of samples per batch (a positive int). The last
            batch of each pass is smaller when the training-set size is not
            a multiple of ``batch_size``.

    Yields:
        tuple[dict, dict]: the input dict and target dict for one batch.

    Raises:
        ValueError: if ``batch_size`` < 1 or the training set is empty.
            Because this is a generator, the error surfaces on the first
            ``next()`` call, not when the function is called.

    NOTE(review): a plain Python generator is not thread-safe. Concurrent
    ``next()`` calls raise "generator already executing". If Keras is
    configured with several worker threads, guard ``next()`` with a lock or
    switch to a ``Sequence``. Confirm how the caller configures training.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be >= 1, got %r" % (batch_size,))

    # Load the dataset: shuffled train/val splits for both inputs and the mask.
    logger.info("batch size:%s", batch_size)
    (train_shuffled_b0, _val_b0,
     train_shuffled_b1000, _val_b1000,
     train_mask_shuffled, _val_mask) = load_2d_multi_input_dataset_v2()
    train_length = train_shuffled_b0.shape[0]
    logger.info("train datasets shape:%s", train_length)
    if train_length == 0:
        # With zero batches, the infinite loop below would spin without yielding.
        raise ValueError("training set is empty; cannot yield batches")

    # Integer ceiling division. True division would produce a float, which
    # range() rejects. A "+ 1" would add an empty trailing batch whenever
    # train_length is an exact multiple of batch_size.
    ranges = (train_length + batch_size - 1) // batch_size
    logger.info("ranges:%s", ranges)

    while True:
        for i in range(ranges):
            begin = i * batch_size
            # Only the final batch of a pass can be short.
            end = min(begin + batch_size, train_length)
            logger.info("slice datasets:%s-%s-%s", begin, end,
                        train_shuffled_b0[begin:end].shape)
            yield ({'b-value_0': train_shuffled_b0[begin:end],
                    'b-value_1000': train_shuffled_b1000[begin:end]},
                   {'output': train_mask_shuffled[begin:end]})
# Train from the infinite batch generator (batch_size=20).
# steps_per_epoch=200 means each "epoch" draws 200 batches (up to 4000
# samples), independent of the training-set size; the generator keeps cycling.
# NOTE(review): `fit_generator` is deprecated in TensorFlow 2.x Keras, where
# `model.fit` accepts generators directly. Confirm the Keras version in use.
model.fit_generator(load_2d_multi_input_dataset_generator_v2(20),
                    steps_per_epoch=200,
                    epochs=10,
                    validation_data=(X_val, y_val))
参考 (References):
- 解决fit_generator的batch_size=1的问题 (Fixing the batch_size=1 problem with fit_generator)
- 线程安全 (Thread safety)