Как избавиться от каждого столбца, заполненного нулями, из тензора Pytorch?

У меня есть тензор pytorch A, как показано ниже:

A = 
tensor([[  4,   3,   3,  ...,   0,   0,   0],
        [ 13,   4,  13,  ...,   0,   0,   0],
        [707, 707,   4,  ...,   0,   0,   0],
        ...,
        [  7,   7,   7,  ...,   0,   0,   0],
        [  0,   0,   0,  ...,   0,   0,   0],
        [195, 195, 195,  ...,   0,   0,   0]], dtype=torch.int32)

Я хотел бы:

  • определить все столбцы, все записи которых равны 0
  • удалить только те столбцы, в которых все записи равны 0

Я могу представить, что делаю:

zero_list = []
for j in range(A.size()[1]):
    if torch.sum(A[:,j]) == 0:
         zero_list = zero_list.append(j)

для определения столбцов, в которых для элементов только 0, но я не уверен, как удалить такие столбцы, заполненные 0, из исходного тензора.

Как я могу удалить столбцы с нулем из тензора pytorch на основе номера индекса?

Спасибо,


person chico0913    schedule 27.11.2019    source источник


Ответы (2)


Более разумно индексировать столбцы, которые вы хотите сохранить, а не то, что вы хотите удалить.

valid_cols = []
for col_idx in range(A.size(1)):
    if not torch.all(A[:, col_idx] == 0):
        valid_cols.append(col_idx)
A = A[:, valid_cols]

Или немного более загадочно

valid_cols = [col_idx for col_idx, col in enumerate(torch.split(A, 1, dim=1)) if not torch.all(col == 0)]
A = A[:, valid_cols]
person jodag    schedule 27.11.2019

Определите все столбцы, все записи которых равны 0

non_empty_mask = A.abs().sum(dim=0).bool()

Это суммирует абсолютные значения каждого столбца, а затем преобразует результат в логическое значение, то есть False, если сумма равна нулю, и True в противном случае.

Удалите только те столбцы, в которых все записи равны 0

A[:,non_empty_mask]

Это просто применяет маску к исходному тензору, то есть сохраняет строки, где non_empty_mask равно True.

person Aydo    schedule 27.10.2020