Обучение дерева решений, сокращение и настройка гиперпараметров.

Описание статьи

  • Что такое дерево решений?
  • Зачем их использовать?
  • Фон данных
  • Описательная статистика
  • Обучение и оценка дерева принятия решений
  • Обрезка дерева решений
  • Настройка гиперпараметров

Что такое дерево решений?

Дерево решений - это представление блок-схемы. Алгоритм дерева классификации и регрессии (он же дерево решений) был разработан Брейманом и др. 1984 (обычно сообщается), но это определенно не самое раннее. Вей-Инь Ло из Университета Висконсина написал об истории деревьев решений. Вы можете прочитать это здесь, Пятьдесят лет классификации и деревьев регрессии ».

В дереве решений верхний узел называется «корневым узлом», а нижний узел - «конечным узлом». Другие узлы называются «внутренними узлами», что включает условие двоичного разделения, в то время как каждый листовой узел содержит связанные метки классов.

Дерево классификации использует условие разделения для прогнозирования метки класса на основе предоставленных входных переменных. Процесс разделения начинается с верхнего узла (корневого узла), и на каждом узле он проверяет, рекурсивно ли продолжаются предоставленные входные значения влево или вправо в соответствии с заданным условием разделения (коэффициент Джини или информационное усиление). Этот процесс завершается при достижении конечного или конечного узла.

Зачем их использовать?

Единую модель на основе дерева решений легко построить, построить и интерпретировать, что делает этот алгоритм настолько популярным. Вы можете использовать этот алгоритм для выполнения классификации, а также для задачи регрессии.

Фон данных

В этом примере мы собираемся использовать набор данных Индийский диабет 2, полученный из репозитория UCI баз данных машинного обучения (Newman et al. др. 1998).

Этот набор данных взят из Национального института диабета, болезней органов пищеварения и почек. Цель набора данных - диагностически предсказать, есть ли у пациента диабет, на основе определенных диагностических измерений, включенных в набор данных. На выбор этих экземпляров из более крупной базы данных было наложено несколько ограничений. В частности, все пациенты здесь - женщины старше 21 года, принадлежащие к индейцам пима.

Набор данных Pima Indian Diabetes 2 представляет собой уточненную версию (все пропущенные значения были присвоены как NA) данных по диабету Pima Indian. Набор данных содержит следующие независимые и зависимые переменные.

Независимые переменные (символ: I)

  • I1: беременна: количество беременных.
  • I2: глюкоза: концентрация глюкозы в плазме (тест на толерантность к глюкозе).
  • I3: давление: диастолическое артериальное давление (мм рт. Ст.).
  • I4: трицепс: толщина кожной складки трицепса (мм)
  • I5: инсулин: 2-часовой сывороточный инсулин (мкЕ / мл)
  • I6: масса: индекс массы тела (вес в кг / (рост в м) \ ²)
  • I7: родословная: функция родословной диабета.
  • I8: возраст: возраст (лет)

Зависимая переменная (символ: D)

  • D1: диабет: случай диабета (положительный / отрицательный)

Цель моделирования

  • подгонка модели машинного обучения классификации дерева решений, которая точно предсказывает, есть ли у пациентов в наборе данных диабет
  • Обрезка дерева решений для уменьшения переобучения
  • Настройка гиперпараметров дерева решений

Загрузка соответствующих библиотек

Первый шаг анализа данных начинается с загрузки соответствующих библиотек.

library(mlbench) # Diabetes dataset
library(rpart) # Decision tree
library(rpart.plot) # Plotting decision tree
library(caret) # Accuracy estimation
library(Metrics) # For diferent model evaluation metrics

Загрузка набора данных

Следующим шагом будет загрузка данных в среду R. Поскольку он поставляется с пакетом mlbench, можно загружать данные, вызывающие data ().

# load the diabetes dataset
data(PimaIndiansDiabetes2)

Предварительная обработка данных

Следующим шагом будет поисковый анализ. Во-первых, нам нужно удалить недостающие значения с помощью функции na.omit (). Распечатайте типы данных, используя метод glimpse () из библиотеки dplyr. Вы можете видеть, что все переменные, кроме зависимой переменной (диабет: категориальный / факторный), относятся к двойному типу.

Diabetes <- na.omit(PimaIndiansDiabetes2) # Data for modeling
dplyr::glimpse(Diabetes)

Тренировка и тестовый сплит

Следующим шагом является разделение набора данных на 80% поездов и 20% тестов. Здесь мы используем метод sample () для случайного выбора индекса наблюдения для разделения на обучение и тест с заменой. Далее на основе индексации разделяем данные поезда и теста.

set.seed(123)
index <- sample(2, nrow(Diabetes), prob = c(0.8, 0.2), replace = TRUE)
Diabetes_train <- Diabetes[index==1, ] # Train data
Diabetes_test <- Diabetes[index == 2, ] # Test data

Данные поезда включают 318 наблюдений, а тестовые данные включают 74 наблюдения. Оба содержат 9 переменных.

print(dim(Diabetes_train))
print(dim(Diabetes_test))

Модельное обучение

Следующий шаг - обучение модели и оценка ее работоспособности.

Обучение дерева решений

Для обучения дерева решений мы будем использовать функцию rpart () из библиотеки rpart. Аргументы включают; формула для модели, данных и метода.

формула = диабет ~. т.е. диабет предсказывается всеми независимыми переменными (за исключением диабета)

Здесь метод должен быть указан как класс для задачи классификации.

# Train a decision tree model
Diabetes_model <- rpart(formula = diabetes ~., 
                        data = Diabetes_train, 
                        method = "class")

Построение модели

Основное преимущество древовидной модели состоит в том, что вы можете построить древовидную структуру и определить механизм принятия решения.

# type: 0; Draw a split label at each split and a node label at each leaf.
# yesno = 2; provides spli yes or no
# Extra = 0; no extra information
rpart.plot(x = Diabetes_model, yesno = 2, type = 0, extra = 0)

Оценка производительности модели

Затем необходимо посмотреть, как наша обученная модель работает с тестовым / невидимым набором данных. Для прогнозирования класса тестовых данных нам необходимо предоставить объект модели, тестовый набор данных и type = «class» внутри функции pred ().

# class prediction
class_predicted <- predict(object = Diabetes_model,  
                            newdata = Diabetes_test,   
                            type = "class")

(а) Матрица неточностей

Чтобы оценить производительность теста, мы собираемся использовать confusionMatrix () из библиотеки caret. Мы можем заметить, что из 74 наблюдений он ошибочно предсказывает 17 наблюдений. Модель достигла точности 77,03% при использовании одного дерева решений.

# Generate a confusion matrix for the test data
confusionMatrix(data = class_predicted,       
                reference = Diabetes_test$diabetes)

(б) Точность теста

Мы также можем предоставить предсказанные метки классов и метки исходных тестовых наборов данных для функции precision () для оценки точности модели.

accuracy(actual = class_predicted,       
         predicted = Diabetes_test$diabetes)

Сравнение моделей на основе критериев разделения

При построении модели алгоритм дерева решений использует критерии разделения. В деревьях решений используются два популярных критерия разделения; один называется «джини», а другой - «получение информации». Здесь мы пытаемся сравнить производительность модели на тестовом наборе после обучения с различными критериями разделения. Критерии разделения предоставляются с использованием аргумента parms в виде списка.

# Model training based on gini-based splitting criteria
Diabetes_model1 <- rpart(formula = diabetes ~ ., 
                         data = Diabetes_train, 
                         method = "class",
                         parms = list(split = "gini"))
# Model training based on information gain-based splitting criteria
Diabetes_model2 <- rpart(formula = diabetes ~ ., 
                         data = Diabetes_train, 
                         method = "class",
                         parms = list(split = "information"))

Оценка модели на тестовых данных

После обучения модели следующим шагом является прогнозирование меток классов тестового набора данных.

# Generate class predictions on the test data using gini-based splitting criteria
pred1 <- predict(object = Diabetes_model1, 
                 newdata = Diabetes_test,
                 type = "class")
# Generate class predictions on test data using information gain based splitting criteria
pred2 <- predict(object = Diabetes_model2, 
                 newdata = Diabetes_test,
                 type = "class")

Сравнение точности прогнозов

Далее мы сравниваем точность моделей. Здесь мы можем заметить, что критерий разделения на основе «gini» обеспечивает более точную модель, чем «информация» основанное расщепление.

# Compare classification accuracy on test data
accuracy(actual = Diabetes_test$diabetes, 
   predicted = pred1)
accuracy(actual = Diabetes_test$diabetes, 
   predicted = pred2)

Исходная модель (Diabetes_model) и модель на основе «gini» (Diabetes_model1 ) с такой же точностью, поскольку модель rpart использует «gini» в качестве критерии разделения по умолчанию.

Обрезка дерева решений

График исходной модели (Diabetes_model) показывает, что древовидная структура является глубокой и хрупкой, что может затруднить легкую интерпретацию в процессе принятия решений. Таким образом, здесь мы попытаемся изучить другие способы сделать дерево более интерпретируемым без потери производительности. Один из способов сделать это - обрезать хрупкую часть дерева (часть способствует переобучению модели).

(а) График зависимости ошибки от параметра сложности

В дереве решений есть один параметр, называемый параметр сложности (cp), который управляет размером дерева решений. Если стоимость добавления другой переменной к дереву решений из текущего узла превышает значение cp, построение дерева не продолжается. Мы можем построить график зависимости cp от ошибки, используя библиотеку plotcp ().

# Plotting Cost Parameter (CP) Table
plotcp(Diabetes_model1)

(б) Создание таблицы параметров сложности

Мы также можем создать таблицу cp, вызвав model $ cptable. Здесь вы можете заметить, что xerror минимален со значением CP 0,025.

# Plotting the Cost Parameter (CP) Table
print(Diabetes_model1$cptable)

(c) Получение оптимальной обрезанной модели

Мы можем отфильтровать оптимальное значение CP, определив индекс минимального xerror и указав его в таблице CP.

# Retrieve of optimal cp value based on cross-validated error
index <- which.min(Diabetes_model1$cptable[, "xerror"])
cp_optimal <- Diabetes_model1$cptable[index, "CP"]

Следующим шагом является сокращение дерева с помощью функции prune (), задав оптимальное значение CP. Если мы построим оптимальное обрезанное дерево, мы увидим, что дерево очень просто и легко интерпретируется.

Если у человека уровень глюкозы выше 128 и возраст старше 25, будет считаться диабетом положительным, иначе отрицательным.

# Pruning tree based on optimal CP value
Diabetes_model1_opt <- prune(tree = Diabetes_model1, cp = cp_optimal)
rpart.plot(x = Diabetes_model1_opt, yesno = 2, type = 0, extra = 0)

(г) Производительность сокращенного дерева

Следующим шагом является проверка того, имеет ли дерево обрезки аналогичную производительность или производительность была скомпрометирована. После проверки производительности мы видим, что обрезанное дерево так же способно, как и более раннее хрупкое дерево, но теперь его просто и легко интерпретировать.

pred3 <- predict(object = Diabetes_model1_opt, 
                 newdata = Diabetes_test,
                 type = "class")
accuracy(actual = Diabetes_test$diabetes, 
         predicted = pred3)

Настройка гиперпараметров дерева решений

Затем мы попытаемся повысить производительность модели дерева решений, настроив ее гиперпараметры. Rpart () предлагает разные гиперпараметры, но здесь мы попытаемся настроить два важных параметра: minsplit и maxdepth.

  • minsplit: минимальное количество наблюдений, которое должно существовать в узле для попытки разделения.
  • maxdepth: максимальная глубина любого узла окончательного дерева.

(а) Создание сетки гиперпараметров

Сначала мы генерируем последовательность от 1 до 20 для minsplit и maxdepth. Затем мы строим сетку комбинации параметров с помощью функции expand.grid ().

#############################
## Hyper parameter Grid Search
#############################
# Setting values for minsplit and maxdepth
## the minimum number of observations that must exist in a node in order for a split to be attempted.
## Set the maximum depth of any node of the final tree
minsplit <- seq(1, 20, 1)
maxdepth <- seq(1, 20, 1)
# Generate a search grid 
hyperparam_grid <- expand.grid(minsplit = minsplit, maxdepth = maxdepth)

(б) Обучение сеточным моделям

Следующим шагом является обучение различных моделей на основе каждой комбинации гиперпараметров сетки. Это можно сделать, выполнив следующие действия:

  • использование цикла for для обхода каждого гиперпараметра в сетке, а затем передача его в функцию rpart () для обучения модели
  • сохранение каждой модели в пустом списке (Diade_models)
# Number of potential models in the grid
num_models <- nrow(hyperparam_grid)
# Create an empty list 
diabetes_models <- list()
# Write a loop over the rows of hyper_grid to train the grid of models
for (i in 1:num_models) {
  
  minsplit <- hyperparam_grid$minsplit[i]
  maxdepth <- hyperparam_grid$maxdepth[i]
  
  # Train a model and store in the list
  diabetes_models[[i]] <- rpart(formula = diabetes ~ ., 
                             data = Diabetes_train, 
                             method = "class",
                             minsplit = minsplit,
                             maxdepth = maxdepth)
}

(c) Вычисление точности теста

Следующим шагом является проверка производительности каждой модели на тестовых данных и получение лучшей модели. Это можно сделать, выполнив следующие действия:

  • использование цикла for для просмотра каждой модели в списке, а затем прогнозирование тестовых данных и точности вычислений
  • сохранение точности каждой модели в пустой вектор (precision_values)
# Number of models inside the grid
num_models <- length(diabetes_models)
# Create an empty vector to store accuracy values
accuracy_values <- c()
# Use for loop for models accuracy estimation
for (i in 1:num_models) {
  
  # Retrieve the model i from the list
  model <- diabetes_models[[i]]
  
  # Generate predictions on test data 
  pred <- predict(object = model,
                  newdata = Diabetes_test,
                  type = "class")
  
  # Compute test accuracy and add to the empty vector accuracy_values 
  accuracy_values[i] <- accuracy(actual = Diabetes_test$diabetes, 
                         predicted = pred)
}

(г) Определение лучшей модели

Следующим шагом является получение наиболее эффективной модели (максимальная точность) и печать ее гиперпараметров с помощью model $ control. Мы можем заметить, что при минимальном разбиении 17 и максимальной глубине 6 модель обеспечивает наиболее точные результаты при оценке на невидимом / тестовом наборе данных.

# Identify the model with maximum accuracy
best_model <- diabetes_models[[which.max(accuracy_values)]]
# Print the model hyper-parameters of the best model
best_model$control

(e) Лучшая оценка модели на основе тестовых данных

После определения наиболее эффективной модели следующим шагом будет проверка ее точности. Теперь, с лучшими гиперпараметрами, модель достигла точности 81,08%, что действительно здорово.

# Best_model accuracy on test data
pred <- predict(object = best_model,
                newdata = Diabetes_test,
                type = "class")
accuracy(actual = Diabetes_test$diabetes, 
     predicted = pred)

(f) Лучшая модель графика

Пришло время построить лучшую модель.

rpart.plot(x = best_model, yesno = 2, type = 0, extra = 0)

Даже приведенный выше график предназначен для модели с лучшими характеристиками, тем не менее, он выглядит немного хрупким. Итак, ваша следующая задача - сократить его и посмотреть, получите ли вы лучше интерпретируемое дерево решений или нет.

Надеюсь, вы узнали что-то новое. Увидимся в следующий раз!

Ссылки

[1] Брейман, Л., Фридман, Дж., Стоун, С.Дж., Олшен, Р.А., 1984. Деревья классификации и регрессии. CRC Press.

[2] Ло, В. (2014). Пятьдесят лет деревьев классификации и регрессии 1.

[3] Ньюман К. Б. и Мерц К. (1998). Репозиторий баз данных машинного обучения UCI, Технический отчет, Калифорнийский университет, Ирвин, Департамент информации и компьютерных наук.

Рахул Раоньяр

  • Если вам понравилось, подпишитесь на меня на medium, чтобы узнать больше
  • Свяжитесь со мной в Twitter, LinkedIn, YouTube и Github