Машинное обучение для аналитиков данных

Регрессия

Модель прогнозирования смертности от рака: линейная регрессия (OLS)

Загрузка библиотек

library(tidyverse)
## -- Attaching packages --------------------------------------- tidyverse 1.3.1 --
## v ggplot2 3.3.5     v purrr   0.3.4
## v tibble  3.1.6     v dplyr   1.0.8
## v tidyr   1.2.0     v stringr 1.4.0
## v readr   2.1.2     v forcats 0.5.1
## -- Conflicts ------------------------------------------ tidyverse_conflicts() --
## x dplyr::filter() masks stats::filter()
## x dplyr::lag()    masks stats::lag()
library(caret)
## Loading required package: lattice
## 
## Attaching package: 'caret'
## The following object is masked from 'package:purrr':
## 
##     lift
library(mlbench)

Загрузка набора данных

# Read the cancer regression dataset from the working directory into a tibble.
cancer_mortal <- read_csv(file = "cancer_reg.csv")
## Rows: 3047 Columns: 34
## -- Column specification --------------------------------------------------------
## Delimiter: ","
## chr  (2): binnedInc, Geography
## dbl (32): avgAnnCount, avgDeathsPerYear, TARGET_deathRate, incidenceRate, me...
## 
## i Use `spec()` to retrieve the full column specification for this data.
## i Specify the column types or set `show_col_types = FALSE` to quiet this message.

Проверка пропущенных значений: доля полностью заполненных строк

# Share of rows that have no missing value in any column.
cancer_mortal %>%
  complete.cases() %>%
  mean()
## [1] 0.1939613

В данных много пропущенных значений: полностью заполнено лишь около 19 % строк.

Удаление строк с пропущенными значениями

# Keep only the rows in which every column is observed (no NA anywhere).
cancer_mortal <- drop_na(cancer_mortal)

Изучение данных и поиск выбросов

## Rows: 591
## Columns: 34
## $ avgAnnCount             <dbl> 173, 427, 57, 146, 2265, 1390, 32, 94, 25, 58,~
## $ avgDeathsPerYear        <dbl> 70, 202, 26, 71, 901, 483, 12, 41, 19, 22, 103~
## $ TARGET_deathRate        <dbl> 161.3, 194.8, 144.4, 183.6, 171.0, 169.9, 153.~
## $ incidenceRate           <dbl> 411.6, 430.4, 350.1, 404.0, 440.7, 495.9, 463.~
## $ medIncome               <dbl> 48127, 44243, 49955, 40189, 50083, 61653, 5102~
## $ popEst2015              <dbl> 43269, 75882, 10321, 20848, 490945, 269536, 40~
## $ povertyPercent          <dbl> 18.6, 17.1, 12.5, 17.8, 16.3, 11.9, 13.9, 21.5~
## $ studyPerCap             <dbl> 23.11123, 342.63725, 0.00000, 0.00000, 462.373~
## $ binnedInc               <chr> "(48021.6, 51046.4]", "(42724.4, 45201]", "(48~
## $ MedianAge               <dbl> 33.0, 42.8, 48.3, 51.7, 37.2, 38.5, 52.1, 41.5~
## $ MedianAgeMale           <dbl> 32.2, 42.2, 47.8, 50.8, 35.7, 37.1, 51.5, 40.9~
## $ MedianAgeFemale         <dbl> 33.7, 43.4, 48.9, 52.5, 38.7, 39.9, 53.1, 42.1~
## $ Geography               <chr> "Kittitas County, Washington", "Lewis County, ~
## $ AvgHouseholdSize        <dbl> 2.340, 2.520, 2.340, 2.240, 2.450, 2.520, 2.31~
## $ PercentMarried          <dbl> 44.5, 52.7, 57.8, 52.7, 49.4, 52.4, 53.5, 52.0~
## $ PctNoHS18_24            <dbl> 6.1, 20.2, 14.9, 27.3, 10.9, 14.8, 37.2, 9.8, ~
## $ PctHS18_24              <dbl> 22.4, 41.2, 43.0, 33.9, 29.3, 31.7, 20.8, 36.1~
## $ PctSomeCol18_24         <dbl> 64.0, 36.1, 40.0, 36.5, 51.2, 46.2, 37.8, 45.8~
## $ PctBachDeg18_24         <dbl> 7.5, 2.5, 2.0, 2.2, 8.6, 7.4, 4.2, 8.3, 1.9, 8~
## $ PctHS25_Over            <dbl> 26.0, 31.6, 33.4, 31.6, 25.7, 22.7, 33.1, 47.1~
## $ PctBachDeg25_Over       <dbl> 22.7, 9.3, 15.0, 11.3, 18.1, 20.3, 10.1, 7.9, ~
## $ PctEmployed16_Over      <dbl> 55.9, 48.3, 48.2, 40.9, 55.1, 56.5, 35.7, 46.5~
## $ PctUnemployed16_Over    <dbl> 7.8, 12.1, 4.8, 8.9, 8.4, 8.5, 10.6, 9.0, 6.5,~
## $ PctPrivateCoverage      <dbl> 70.2, 58.4, 61.6, 55.8, 65.2, 74.5, 64.7, 55.6~
## $ PctPrivateCoverageAlone <dbl> 53.8, 40.3, 43.9, 33.1, 50.6, 55.4, 38.6, 40.1~
## $ PctEmpPrivCoverage      <dbl> 43.6, 35.0, 35.1, 25.9, 42.5, 43.5, 35.2, 36.5~
## $ PctPublicCoverage       <dbl> 31.1, 45.3, 44.0, 50.9, 36.5, 30.7, 49.7, 44.8~
## $ PctPublicCoverageAlone  <dbl> 15.3, 25.0, 22.7, 24.1, 21.4, 13.7, 20.4, 26.4~
## $ PctWhite                <dbl> 89.22851, 91.74469, 94.10402, 89.40664, 89.038~
## $ PctBlack                <dbl> 0.9691025, 0.7826260, 0.2701920, 0.3051586, 1.~
## $ PctAsian                <dbl> 2.24623259, 1.16135867, 0.66583036, 1.88907726~
## $ PctOtherRace            <dbl> 3.74135153, 1.36264318, 0.49213548, 2.28626786~
## $ PctMarriedHouseholds    <dbl> 45.37250, 51.02151, 54.02746, 48.96703, 48.188~
## $ BirthRate               <dbl> 4.3330956, 4.6038408, 6.7966574, 5.8891790, 5.~

1. Разделение данных на обучающую и тестовую выборки

# Fix the RNG seed so the random partition is reproducible.
set.seed(42)
# Sample 80% of the row indices based on the outcome column.
# `list = FALSE` (spelled out: `F` is an ordinary variable that can be
# reassigned) makes caret return the indices as a matrix instead of a list,
# so they can be used directly for row subsetting below.
id <- createDataPartition(
  y = cancer_mortal$TARGET_deathRate,
  p = 0.8,
  list = FALSE
)
# Selected rows form the training set; all remaining rows form the test set.
train_cancer_df <- cancer_mortal[id, ]
test_cancer_df <- cancer_mortal[-id, ]

2. Обучение модели линейной регрессии

# Fix the RNG seed so the cross-validation fold assignment is reproducible.
set.seed(42)
# 5-fold cross-validation; `verboseIter = TRUE` logs progress for each fold.
# `TRUE` is spelled out because `T` is an ordinary variable that can be
# reassigned, whereas `TRUE` is a reserved word.
ctrl <- trainControl(
  method = "cv",
  number = 5,
  verboseIter = TRUE
)
# Fit an ordinary linear regression of the death rate on 13 predictors.
# `binnedInc` is the only non-numeric predictor (a character column).
lm_model <- train(
  TARGET_deathRate ~ avgAnnCount +
    avgDeathsPerYear +
    incidenceRate +
    MedianAge +
    binnedInc +
    PctBachDeg18_24 +
    PctBachDeg25_Over +
    PctEmployed16_Over +
    PctPrivateCoverage +
    PctPublicCoverage +
    PctWhite +
    PctBlack +
    PctMarriedHouseholds,
  data = train_cancer_df,
  method = "lm",
  trControl = ctrl
)
## + Fold1: intercept=TRUE 
## - Fold1: intercept=TRUE 
## + Fold2: intercept=TRUE 
## - Fold2: intercept=TRUE 
## + Fold3: intercept=TRUE 
## - Fold3: intercept=TRUE 
## + Fold4: intercept=TRUE 
## - Fold4: intercept=TRUE 
## + Fold5: intercept=TRUE 
## - Fold5: intercept=TRUE 
## Aggregating results
## Fitting final model on full training set
# Show the resampling summary (RMSE, R-squared, MAE) of the fitted model.
print(lm_model)
## Linear Regression 
## 
## 475 samples
##  13 predictor
## 
## No pre-processing
## Resampling: Cross-Validated (5 fold) 
## Summary of sample sizes: 379, 381, 380, 380, 380 
## Resampling results:
## 
##   RMSE      Rsquared   MAE     
##   21.65098  0.3573773  15.76747
## 
## Tuning parameter 'intercept' was held constant at a value of TRUE

3. Прогноз на тестовой выборке

# Predict the death rate for the held-out test rows with the fitted model.
p_lm <- lm_model %>%
  predict(newdata = test_cancer_df)

4. Оценка модели

# Root-mean-squared error of the predictions on the held-out test set.
prediction_error <- test_cancer_df$TARGET_deathRate - p_lm
test_rmse <- sqrt(mean(prediction_error^2))
# Print the metric as plain text. `call()` only constructs an unevaluated
# call object, which is why the result used to render as `RMSE =`(...).
cat("RMSE =", test_rmse, "\n")
## RMSE = 21.1086