Разве randperm в API PyTorch C ++ не должно возвращать тензор с типом по умолчанию int?

Когда я пытаюсь сгенерировать список переставленных целочисленных индексов с randperm, используя C ++ PyTorch API, полученный тензор имеет тип элемента CPUFloatType{10} вместо целочисленного типа:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES); 
cout << shuffled_indices << endl;

возвращается

 9
 3
 8
 6
 2
 5
 4
 7
 1
 0
[ CPUFloatType{10} ]

Что нельзя использовать для индексации тензоров, потому что тип элемента - float, а не целочисленный тип. При попытке использовать my_tensor.index(shuffled_indices) получаю

terminate called after throwing an instance of 'c10::IndexError'
  what():  tensors used as indices must be long, byte or bool tensors

Среда:

  • python-pytorch, версия 1.6.0-2 в Arch Linux
  • g ++ (GCC) 10.1.0

Почему так происходит?


person tmaric    schedule 20.08.2020    source источник


Ответы (1)


Это потому, что типом по умолчанию любого тензора, который вы создаете с помощью torch, всегда является float. Если вы хотите иначе, вы должны указать это с помощью структуры параметра TensorOptions:

int N_SAMPLES = 10;               
torch::Tensor shuffled_indices = torch::randperm(N_SAMPLES, torch::TensorOptions().dtype(at::kLong));
cout << shuffled_indices.dtype() << endl;
>>> long
person trialNerror    schedule 20.08.2020