Функция перекрестной проверки и передискретизации (SMOTE)

Я написал приведенный ниже код. X - это фрейм данных с формой (1000,5), а y - это фрейм данных с формой (1000,1). y - это целевые данные для прогнозирования, и они несбалансированы. Я хочу применить перекрестную проверку и SMOTE.

def Learning(n, est, X, y):
    s_k_fold = StratifiedKFold(n_splits = n)
    acc_scores = []
    rec_scores = []
    f1_scores = []

    for train_index, test_index in s_k_fold.split(X, y): 
        X_train = X[train_index]
        y_train = y[train_index]    

        sm = SMOTE(random_state=42)
        X_resampled, y_resampled = sm.fit_resample(X_train, y_train)

        X_test = X[test_index]
        y_test = y[test_index]

        est.fit(X_resampled, y_resampled)
        y_pred = est.predict(X_test)
        acc_scores.append(accuracy_score(y_test, y_pred))
        rec_scores.append(recall_score(y_test, y_pred))
        f1_scores.append(f1_score(y_test, y_pred)) 

    print('Accuracy:',np.mean(acc_scores))
    print('Recall:',np.mean(rec_scores))
    print('F1:',np.mean(f1_scores)) 

Learning(3, SGDClassifier(), X_train_s_pca, y_train)

Когда я запускаю код, я получаю следующую ошибку:

Ни один из [Int64Index ([4231, 4235, 4246, 4250, 4255, 4295, 4317, 4344, 4381, \ n 4387, \ n ... \ n 13122, 13123, 13124, 13125, 13126, 13127, 13128, 13129 , 13130, \ n
13131], \ n dtype = 'int64', length = 8754)] находятся в [столбцах] "

Помощь в его запуске приветствуется.


person BTurkeli    schedule 15.05.2019    source источник


Ответы (1)


Если вы внимательно наблюдаете за трассировкой стека ошибок (что важно, но вы не включаете), вы должны увидеть, что ошибка исходит из этой строки (и будет исходить из других похожих строк):

X_train = X[train_index]

Этот способ выбора строк применим только для массива Numpy. Поскольку вы используете Pandas DataFrame, вам следует использовать loc:

X_train = X.loc[train_index]

В качестве альтернативы вы можете преобразовать DataFrame в массив Numpy (чтобы минимизировать изменение кода) с помощью значения:

Learning(3, SGDClassifier(), X_train_s_pca.values, y_train.values)
person Yohanes Gultom    schedule 15.05.2019