已经学习了data,dataset和dataloader,不如就先实战根据自己的数据集,写好自定义的dataset吧。
1、首先将每个图数据预处理成Data需要的形式:
x是所有节点的特征,【num_nodes, embed_dim】,要注意这里所有的节点特征维度需要一致;
edge_index是邻接表,有向图:【【0,1】,【1,2】】;无向图:【【0,1,1,2】,【1,2,0,1】】;
y类别标签;
其他自定义的数据,需要是int或者float类型。
最后分别转换成numpy.array类型,使用numpy.savez()保存成npz文件,分别存放在train/eval/test路径下的graph文件夹里,后面要用。
np.savez(os.path.join(path, data_name, 'graph', file_id+'.npz'), x=x, edge_index=edge_idx, y=y, dtype=object)
2、自定义dataset,主要是__getitem__函数,逻辑是传入上面处理好的文件list,然后getitem函数按照列表下标读取,返回Data类型就好。
class GraphDataset(Dataset):
def __init__(self, root, file_list, treeLenDic, lower = 2, upper = 100000):
super(GraphDataset, self).__init__()
self.root = root
self.file_list = list(filter(lambda id: id.split('.')[0] in treeLenDic.keys() and treeLenDic[id.split('.')[0]] >= lower and treeLenDic[id.split('.')[0]] <= upper, file_list))
def __len__(self):
return len(self.file_list)
def __getitem__(self, idx):
id = self.file_list[idx]
data = np.load(os.path.join(self.root, id), allow_pickle=True)
return Data(x=torch.tensor(data['x'], dtype=torch.float32),
edge_index=torch.LongTensor(data['edge_index']),
y=torch.LongTensor([int(data['y'])]))
这里对每个图文件的长度做了筛选,要至少有两个节点,那种只有一个点的就不考虑了,TreeLenDic是个字典,{graph_id: len}.
3. 将Dataset实例化的对象传入DataLoader就可以批量读取数据了
好啦,到这里我数据预处理以及自定义Dataset就搞定了,可以开始学习torch.geometric.nn里面的网络模型啦~