CuPy и Dirichlet дают мне TypeError: неподдерживаемые типы операндов для + =: 'int' и 'tuple'

Я просто хочу создать случайную матрицу A, векторы которой взяты из распределения Дирихле. Функция отлично работает с numpy:

import numpy as np
A = np.random.dirichlet(np.ones(n), n)

Когда я делаю то же самое с cupy

import cupy as cp
A = cp.random.dirichlet(cp.ones(n), n)

Я получаю сообщение об ошибке ниже:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-1-45a4f64a8b6e> in <module>
      6 n = 10000 #Size of the square matrix
      7 
----> 8 A = cp.random.dirichlet(cp.ones(n), n)
      9 
     10 print("--- %s seconds ---" % (time.time() - start_time))

~\anaconda3\envs\tensorflow\lib\site-packages\cupy\random\_distributions.py in dirichlet(alpha, size, dtype)
    112     """
    113     rs = _generator.get_random_state()
--> 114     return rs.dirichlet(alpha, size, dtype)
    115 
    116 

~\anaconda3\envs\tensorflow\lib\site-packages\cupy\random\_generator.py in dirichlet(self, alpha, size, dtype)
    144             size = alpha.shape
    145         else:
--> 146             size += alpha.shape
    147         y = cupy.empty(shape=size, dtype=dtype)
    148         _kernels.standard_gamma_kernel(alpha, self._rk_seed, y)

TypeError: unsupported operand type(s) for +=: 'int' and 'tuple'

Когда ввод представляет собой массив типа numpy, подобный этому

import cupy as cp
import numpy as np

A = cp.random.dirichlet(np.ones(n), n)

то я получаю ту же ошибку.

alpha.shape из строки 146 - (n,), когда я проверяю вручную. Это баги или я что-то упускаю?

Я использую cupy-cuda101 версии 8.5.0 для CUDA 10.1. Все остальное, что связано с cupy и tensorflow, отлично работает на моем графическом процессоре (2080ti).


person DimitrisMel    schedule 28.02.2021    source источник


Ответы (1)


Это ошибка в cupy, о которой вы должны сообщить на их GitHub.

Они неправильно обрабатывают случай целочисленного аргумента, несмотря на документацию. Они требуют, чтобы вы указали кортеж или None. Вот почему вы видите поведение, которое наблюдаете. (Если вы указали кортеж (a, b), тогда полученная форма будет правильной (a, b, n).

Обходной путь здесь - предоставить желаемую фигуру в виде кортежа длиной 1: (n,). Обратите внимание, что запятая необходима.

person Arya McCarthy    schedule 28.02.2021
comment
Спасибо. Я сообщу об этом на их GitHub. Этот обходной путь работает. - person DimitrisMel; 28.02.2021
comment
Обновление: я отправил запрос на перенос, чтобы исправить это. - person Arya McCarthy; 28.02.2021
comment
Отлично, я надеюсь, что скоро он будет объединен. Спасибо. - person DimitrisMel; 28.02.2021
comment
Спасибо за пул-реквест! Мы включим его в следующий выпуск. - person kmaehashi; 02.03.2021