scatter_add_(dim, index, src) → Tensor
Official documentation: https://pytorch.org/docs/stable/tensors.html#torch.Tensor.scatter_add_
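For reference, the PyTorch documentation describes the 3-D case with dim=2 as self[i][j][index[i][j][k]] += src[i][j][k]. The sketch below applies the same rule to a small 2-D tensor with dim=1, where it becomes self[i][index[i][k]] += src[i][k]. The tensor names and values are made up for illustration. Repeated indices accumulate instead of overwriting.

```python
import torch

# self: 2 x 5 accumulator; index and src: 2 x 3.
# Dim 0 sizes match; dim 1 is the scatter dimension, so index may be shorter there.
self_t = torch.zeros(2, 5)
index = torch.tensor([[0, 1, 1],
                      [4, 4, 2]])
src = torch.ones(2, 3)

# For dim=1: self_t[i][index[i][k]] += src[i][k]; duplicates (1, 1 and 4, 4) add up
self_t.scatter_add_(1, index, src)
print(self_t)
# tensor([[1., 2., 0., 0., 0.],
#         [0., 0., 1., 0., 2.]])
```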
- Calling torch.Tensor.scatter_add_() fails with the error: Expected index [1, 67, 3] to be smaller than self [9, 66, 5003] apart from dimension 2 and to be smaller size than src [1, 67, 3]
import torch

# index tensor of shape [1, 67, 3]
idxes = torch.tensor([[[ 59, 33, 848],
[1818, 257, 3081],
[ 0, 3234, 320],
[ 59, 21, 16],
[ 756, 516, 1311],
[4990, 1286, 2835],
[ 702, 2446, 1662],
[ 270, 1576, 2220],
[ 963, 201, 775],
[ 0, 3234, 320],
[ 359, 1007, 3563],
[4983, 3339, 2446],
[1039, 4596, 1552],
[ 448, 3075, 2003],
[ 848, 1053, 407],
[2446, 4983, 3339],
[2236, 3056, 1059],
[ 25, 346, 940],
[ 4, 1782, 4376],
[ 433, 475, 91],
[ 223, 1135, 2728],
[ 290, 2235, 610],
[3073, 2693, 3248],
[ 568, 426, 226],
[2344, 2148, 2260],
[ 601, 394, 3207],
[ 0, 3234, 320],
[3828, 1800, 3261],
[ 0, 3234, 320],
[1351, 4438, 1767],
[1852, 2284, 4906],
[4773, 3558, 1311],
[2220, 3589, 1806],
[3073, 2693, 3248],
[1405, 678, 2247],
[ 0, 3234, 320],
[2655, 2558, 3618],
[ 20, 4594, 4574],
[ 20, 775, 822],
[ 189, 106, 102],
[1311, 2234, 2548],
[ 93, 37, 491],
[ 526, 1059, 2332],
[ 0, 3234, 320],
[1282, 3268, 4381],
[3204, 941, 4946],
[3433, 1737, 3983],
[2220, 1576, 3922],
[ 642, 4518, 3075],
[2102, 3225, 1594],
[3728, 838, 3844],
[1029, 2844, 2213],
[ 739, 1025, 411],
[3515, 4990, 4652],
[4983, 3339, 2446],
[ 223, 53, 3995],
[ 408, 228, 158],
[ 290, 33, 221],
[ 126, 2678, 1674],
[ 448, 2003, 253],
[ 33, 290, 221],
[ 223, 106, 189],
[4983, 2446, 3318],
[3305, 1835, 4762],
[ 0, 3234, 320],
[ 0, 3234, 320],
[ 0, 3234, 320]]]).long()
probs = torch.ones([1,67,3])                    # src: same shape as idxes
tmp_trans_scores = torch.zeros([9, 66, 5003])   # self: only 66 entries in dim 1
tmp_trans_scores.scatter_add_(2, idxes, probs)  # raises the RuntimeError below
Output:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-16-dae8c3e390fa> in <module>
----> 1 tmp_trans_scores.scatter_add_(2, idxes, probs)
RuntimeError: Expected index [1, 67, 3] to be smaller than self [9, 66, 5003] apart from dimension 2 and to be smaller size than src [1, 67, 3]
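Cause: scatter_add_ requires index.size(d) <= src.size(d) for every dimension d, and index.size(d) <= self.size(d) for every dimension d except dim (here dim=2); the "smaller than" in the message means "no larger than". For idxes.size() = [1, 67, 3], dim 0 passes (1 <= 9) and idxes matches src exactly, but dim 1 fails because 67 > 66 = tmp_trans_scores.size(1). Dim 2 is exempt from the size check against self; along it only the index values must lie in [0, 5003), which they do (the largest is 4990). The fix is to make dim 1 of tmp_trans_scores at least 67, or to trim idxes and probs to 66 rows along dim 1.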
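A minimal sketch that encodes the size rules above in a hypothetical helper, scatter_shape_violations (it is not part of PyTorch), and then tries both fixes. For illustration it assumes idxes can be replaced by random indices of the same shape; which fix is correct depends on what the 67th row means in the original code.

```python
import torch

def scatter_shape_violations(self_shape, dim, index_shape, src_shape):
    """Hypothetical helper: list the documented scatter_add_ size rules that are broken."""
    if not (len(self_shape) == len(index_shape) == len(src_shape)):
        return ["self, index and src must have the same number of dimensions"]
    problems = []
    for d, n in enumerate(index_shape):
        if n > src_shape[d]:
            problems.append(f"dim {d}: index {n} > src {src_shape[d]}")
        # the scatter dimension itself is not checked against self
        if d != dim and n > self_shape[d]:
            problems.append(f"dim {d}: index {n} > self {self_shape[d]}")
    return problems

print(scatter_shape_violations([9, 66, 5003], 2, [1, 67, 3], [1, 67, 3]))
# ['dim 1: index 67 > self 66']

idxes = torch.randint(0, 5003, (1, 67, 3))  # stand-in for the index tensor in the post
probs = torch.ones(1, 67, 3)

# Fix 1 (assumes the accumulator may grow): give self at least 67 entries in dim 1
scores = torch.zeros(9, 67, 5003)
scores.scatter_add_(2, idxes, probs)

# Fix 2 (assumes the 67th row is surplus): trim index and src to 66 rows in dim 1
scores66 = torch.zeros(9, 66, 5003)
scores66.scatter_add_(2, idxes[:, :66], probs[:, :66])

print(scores.sum().item(), scores[1:].abs().sum().item())
# 201.0 0.0
```

Because idxes.size(0) is 1, only slice [0] of the accumulator receives values in either fix. scatter_add_ does not broadcast index or src over dim 0, so if all 9 slices are meant to be updated, index and src need size 9 in that dimension.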