Как использовать GridSearchCV для настройки параметров со стратегией train_test_split?

Я пытаюсь настроить свои модели sklearn, используя стратегию train_test_split. Мне известно о способности GridSearchCV выполнять настройка параметров, однако это было завязано на использование стратегии Cross Validation, я бы хотел использовать стратегию train_test_split для поиска параметров, так как для моего случая важна скорость обучения, я предпочитаю простые train_test_split по перекрестной проверке.

Я мог бы попытаться написать свой собственный цикл for, но это было бы неэффективно, если бы я не использовал преимущества встроенного распараллеливания, используемого в GridSearchCV.

Кто-нибудь знает, как воспользоваться GridSearchCV для этого? Или предоставьте альтернативу, которая не была слишком медленной.


person Alex Ramses    schedule 27.08.2019    source источник


Ответы (2)


Да, вы можете использовать для этого ShuffleSplit.

ShuffleSplit — это стратегия перекрестной проверки, как и KFold, но в отличие от KFold, где вам нужно обучать K моделей, здесь вы можете контролировать, сколько раз выполнять разделение обучения/тестирования, даже один раз, если хотите.

shuffle_split = ShuffleSplit(n_splits=1,test_size=.25)

n_splits определяет, сколько раз повторять эту процедуру разделения и тренировки. Теперь вы можете использовать его так:

GridSearchCV(clf,param_grid={},cv=shuffle_split)
person Shihab Shahriar Khan    schedule 27.08.2019

Я хотел бы добавить к ответу Шихаба Шахриара, предоставив пример кода.

import pandas as pd
from sklearn import datasets
from sklearn.model_selection import GridSearchCV, ShuffleSplit
from sklearn.ensemble import RandomForestClassifier

# Load iris dataset
iris = datasets.load_iris()

# Prepare X and y as dataframe
X = pd.DataFrame(data=iris.data, columns=iris.feature_names)
y = pd.DataFrame(data=iris.target, columns=['Species'])

# Train test split
shuffle_split = ShuffleSplit(n_splits=1, test_size=0.3)
# This is equivalent to: 
#   X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3)
# But, it is usable for GridSearchCV

# GridSearch without CV
params = { 'n_estimators': [16, 32] }
clf = RandomForestClassifier()
grid_search = GridSearchCV(clf, param_grid=params, cv=shuffle_split)
grid_search.fit(X, y)

Это должно помочь всем, кто сталкивается с подобной проблемой.

person Alex Ramses    schedule 28.08.2019