Когда я пытаюсь сгенерировать список переставленных целочисленных индексов с randperm
, используя C ++ PyTorch API, полученный тензор имеет тип элемента CPUFloatType{10}
вместо целочисленного типа:
int N_SAMPLES = 10;
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES);
cout << shuffled_indices << endl;
возвращается
9
3
8
6
2
5
4
7
1
0
[ CPUFloatType{10} ]
Что нельзя использовать для индексации тензоров, потому что тип элемента - float, а не целочисленный тип. При попытке использовать my_tensor.index(shuffled_indices)
получаю
terminate called after throwing an instance of 'c10::IndexError'
what(): tensors used as indices must be long, byte or bool tensors
Среда:
- python-pytorch, версия 1.6.0-2 в Arch Linux
- g ++ (GCC) 10.1.0
Почему так происходит?