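Multi-class Logistic Regression
In a previous post I showed how to fit a multi-class regression model in R. That gives a coefficient and a corresponding OR for each factor, but the result is much more cumbersome to interpret than a binary logistic regression.
Building on that, today we use a machine learning method to handle the multi-class problem.
In the end this comes back to the essence of any classification problem: given a set of predictors x, decide which category the outcome y belongs to.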
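1. Data example
Here we use the HR data from the DALEX package. It records each employee's job status (fired, ok or promoted) together with gender, age, working hours, evaluation and salary. Based on 7847 records, we want to estimate what status an employee is likely to be in if, for example, he is male, 68 years old, and has both salary and evaluation at level 3.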
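Before modeling, it helps to look at the data and the class balance. The sketch below assumes DALEX is installed; the names `status_tab` and `baseline_acc` are illustrative only. The share of the largest class is the accuracy of always guessing the majority class, which is the same idea as the "No Information Rate" in caret's output further down (caret computes it on the reference labels of the confusion matrix).

```r
# Quick look at the HR data shipped with DALEX
library(DALEX)
data(HR)

str(HR)                          # 5 predictors plus the outcome, status
status_tab <- table(HR$status)   # number of employees per status
status_tab
prop.table(status_tab)           # class proportions

# Accuracy of always predicting the most frequent status
baseline_acc <- max(prop.table(status_tab))
baseline_acc
```

Any useful model should beat `baseline_acc` by a clear margin.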
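library(DALEX)         # HR data set
library(iBreakDown)    # local_attributions()
library(randomForest)  # randomForest(), importance(), partialPlot()
library(caret)         # confusionMatrix()
library(vip)           # vip()
library(pdp)           # partial()
library(car)
library(questionr)
try(data(package = "DALEX"))  # list the data sets shipped with DALEX
data(HR)
# split: each row goes to training (prob 0.9) or test (prob 0.1)
set.seed(543)
ind = sample(2, nrow(HR), replace = TRUE, prob = c(0.9, 0.1))
trainData = HR[ind == 1, ]
testData = HR[ind == 2, ]
# random forest with default settings (500 trees)
m_rf = randomForest(status ~ ., data = trainData)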
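`sample()` with `prob = c(0.9, 0.1)` gives only an approximately 90/10 split, and the class mix of the small test set can drift by chance. A common alternative, swapped in here and not what the code above does, is a split stratified by `status` with caret's `createDataPartition()`, which samples within each class. The names `trainStrat` and `testStrat` are hypothetical.

```r
library(DALEX)   # HR data
library(caret)   # createDataPartition()
data(HR)

set.seed(543)
# Row indices for a ~90% training set, sampled within each status level
idx <- createDataPartition(HR$status, p = 0.9, list = FALSE)
trainStrat <- HR[idx, ]
testStrat  <- HR[-idx, ]

# Class proportions should be almost identical in both parts
round(prop.table(table(trainStrat$status)), 3)
round(prop.table(table(testStrat$status)), 3)
```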
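2. Random forest model
We split the data above into a training set and a test set (Train and Test); the test set is used to estimate how well the random forest model performs on data it has not seen.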
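Besides a held-out test set, randomForest also provides an out-of-bag (OOB) estimate: each tree is grown on a bootstrap sample, so every training row can be predicted by the trees that never saw it, and `predict()` called without `newdata` returns these OOB predictions. The sketch below reproduces the original split and compares three accuracies; `ntree = 200` (instead of the default 500) is an assumption made only to keep it fast, so the numbers will differ slightly from the output further down.

```r
library(DALEX)
library(randomForest)
data(HR)

# Same seed and sample() call as the original code, so the same split
set.seed(543)
ind <- sample(2, nrow(HR), replace = TRUE, prob = c(0.9, 0.1))
trainData <- HR[ind == 1, ]
testData  <- HR[ind == 2, ]
m_rf <- randomForest(status ~ ., data = trainData, ntree = 200)

acc <- function(pred, truth) mean(pred == truth)

acc_resub <- acc(predict(m_rf, trainData), trainData$status)  # re-predicting training rows
acc_oob   <- acc(predict(m_rf), trainData$status)             # no newdata: OOB predictions
acc_test  <- acc(predict(m_rf, testData), testData$status)    # held-out rows

c(resubstitution = acc_resub, oob = acc_oob, test = acc_test)
```

The OOB figure is usually much closer to the test accuracy than the resubstitution figure is.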
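2.1 Model evaluation
After fitting the rf model on Train, we first predict the Train data itself to see how well the model fits: Accuracy is 0.9354 and Kappa is 0.9024 (about 90%), which looks very good.
Using the same fit to predict the Test data, however, gives Accuracy 0.7166 and Kappa 0.5692 (about 57%), which is much less satisfactory. The gap shows that accuracy measured on the training rows is optimistic; the test-set figures are the honest estimate of performance on new employees.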
# Prediction and confusion matrix - training data
pred1 <- predict(m_rf, trainData)
head(pred1)
confusionMatrix(pred1, trainData$status)
# Prediction and confusion matrix - test data
pred2 <- predict(m_rf, testData)
head(pred2)
confusionMatrix(pred2, testData$status)
> confusionMatrix(pred1, trainData$status)
Confusion Matrix and Statistics

          Reference
Prediction fired   ok promoted
  fired     2478  194       49
  ok          43 1738       80
  promoted    25   64     2375

Overall Statistics

               Accuracy : 0.9354
                 95% CI : (0.9294, 0.9411)
    No Information Rate : 0.3613
    P-Value [Acc > NIR] : < 2.2e-16

                  Kappa : 0.9024

 Mcnemar's Test P-Value : < 2.2e-16

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.9733    0.8707          0.9485
Specificity                0.9460    0.9756          0.9804
Pos Pred Value             0.9107    0.9339          0.9639
Neg Pred Value             0.9843    0.9502          0.9718
Prevalence                 0.3613    0.2833          0.3554
Detection Rate             0.3517    0.2467          0.3371
Detection Prevalence       0.3862    0.2641          0.3497
Balanced Accuracy          0.9596    0.9232          0.9644
> pred2 <- predict(m_rf, testData)
> head(pred2)
    1    20    36    42    49    56
fired fired fired fired fired    ok
Levels: fired ok promoted
> confusionMatrix(pred2, testData$status)
Confusion Matrix and Statistics

          Reference
Prediction fired  ok promoted
  fired      246  62       19
  ok          37 117       37
  promoted    26  46      211

Overall Statistics

               Accuracy : 0.7166
                 95% CI : (0.684, 0.7476)
    No Information Rate : 0.3858
    P-Value [Acc > NIR] : < 2e-16

                  Kappa : 0.5692

 Mcnemar's Test P-Value : 0.03881

Statistics by Class:

                     Class: fired Class: ok Class: promoted
Sensitivity                0.7961    0.5200          0.7903
Specificity                0.8354    0.8715          0.8652
Pos Pred Value             0.7523    0.6126          0.7456
Neg Pred Value             0.8671    0.8230          0.8919
Prevalence                 0.3858    0.2809          0.3333
Detection Rate             0.3071    0.1461          0.2634
Detection Prevalence       0.4082    0.2385          0.3533
Balanced Accuracy          0.8157    0.6958          0.8277
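To see where Accuracy, Kappa and the per-class statistics come from, here is a base-R sketch that recomputes them from the test-set confusion matrix printed above (rows are predictions, columns are the reference labels). Kappa compares the observed agreement with the agreement expected by chance from the row and column totals; for each class, specificity treats the other two classes together as "negative". Variable names such as `kappa_val` are illustrative.

```r
# Test-set confusion matrix copied from the output above
cm <- matrix(c(246,  62,  19,
                37, 117,  37,
                26,  46, 211),
             nrow = 3, byrow = TRUE,
             dimnames = list(Prediction = c("fired", "ok", "promoted"),
                             Reference  = c("fired", "ok", "promoted")))

n  <- sum(cm)
po <- sum(diag(cm)) / n                        # observed agreement = Accuracy
pe <- sum(rowSums(cm) * colSums(cm)) / n^2     # agreement expected by chance
kappa_val <- (po - pe) / (1 - pe)              # Cohen's kappa

sensitivity <- diag(cm) / colSums(cm)          # correct / all rows truly in the class
specificity <- sapply(seq_len(nrow(cm)), function(k)
  sum(cm[-k, -k]) / sum(cm[, -k]))             # true negatives / all true negatives
balanced <- (sensitivity + specificity) / 2    # Balanced Accuracy

round(c(accuracy = po, kappa = kappa_val), 4)
round(rbind(sensitivity, specificity, balanced), 4)
```

The results match caret's output: accuracy 0.7166, kappa 0.5692, and for example a sensitivity of 0.5200 for the "ok" class, which is the weakest of the three.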
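2.2 Variable importance
Ranking the predictors by importance plays a role loosely similar to P values in a regression model: it shows which factors carry the most weight when predicting y, although it is not a significance test. randomForest offers more than one importance measure; here vip() and randomForest::importance() use the default, MeanDecreaseGini (the permutation-based MeanDecreaseAccuracy is only computed when the model is fitted with importance = TRUE).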
# vip: bar chart of variable importance
vip(m_rf)
# importance table; a default fit gives MeanDecreaseGini only
var = randomForest::importance(m_rf)
var
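A sketch of the second measure mentioned above: refitting with `importance = TRUE` makes `importance()` also return the permutation-based MeanDecreaseAccuracy (overall and per class), so the two rankings can be compared. `ntree = 200` and the name `m_rf_imp` are assumptions for this example only.

```r
library(DALEX)
library(randomForest)
data(HR)

set.seed(543)
ind <- sample(2, nrow(HR), replace = TRUE, prob = c(0.9, 0.1))
trainData <- HR[ind == 1, ]

# importance = TRUE adds permutation importance computed on the OOB rows
m_rf_imp <- randomForest(status ~ ., data = trainData, ntree = 200, importance = TRUE)

imp <- importance(m_rf_imp)   # columns: one per class, MeanDecreaseAccuracy, MeanDecreaseGini
round(imp, 2)

# Rank the predictors by each measure
sort(imp[, "MeanDecreaseAccuracy"], decreasing = TRUE)
sort(imp[, "MeanDecreaseGini"], decreasing = TRUE)

varImpPlot(m_rf_imp)          # dot charts of both measures side by side
```

If the two measures disagree strongly on a predictor, the Gini ranking is the one to distrust first, since it tends to favour variables with many distinct values such as the continuous age and hours.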
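2.3 Marginal effects
We now know that hours and age are important, but how do they matter? For example, in what age range does an employee tend to be promoted or fired?
When working hours are below about 45, the probability of being fired is high; once working hours exceed 60, the employee is likely to be promoted and get a raise.
Of course, one can also draw a 2D partial dependence plot showing how two factors act together.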
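The idea behind a partial dependence curve can be spelled out directly: set age to one grid value for every row, predict, average the predicted probability of a class, and repeat for each grid value. This is a minimal sketch for P(promoted); the 10-point grid, `ntree = 200` and the helper name `pd_age` are illustrative choices, not part of the original analysis.

```r
library(DALEX)
library(randomForest)
data(HR)

set.seed(543)
ind <- sample(2, nrow(HR), replace = TRUE, prob = c(0.9, 0.1))
trainData <- HR[ind == 1, ]
m_rf <- randomForest(status ~ ., data = trainData, ntree = 200)

age_grid <- seq(min(trainData$age), max(trainData$age), length.out = 10)

# For each grid value: overwrite age in every row, predict, average P(promoted)
pd_age <- sapply(age_grid, function(a) {
  tmp <- trainData
  tmp$age <- a
  mean(predict(m_rf, tmp, type = "prob")[, "promoted"])
})

data.frame(age = round(age_grid, 1), p_promoted = round(pd_age, 3))
```

Note that randomForest's `partialPlot()` plots a logit-type score rather than this raw probability, so the shape is comparable but the y-axis scale is not.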
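# partial dependence plot (randomForest); for a classifier it shows the
# first class ("fired") unless which.class is set
partialPlot(m_rf, HR, age)
# pdp::partial() returns a data frame (centered logit of the first class by default)
head(partial(m_rf, pred.var = "age"))
# for all variables
library(edarf)
nm = rownames(var)
# Get partial dependence values for all predictors
pd_df <- partial_dependence(fit = m_rf,
                            vars = nm,
                            data = trainData,
                            n = c(100, 200))
# Plot partial dependence using edarf
plot_pd(pd_df)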
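The 2D plot mentioned above is not shown in the code, so here is one way to sketch it with pdp: pass two variables to `pred.var` and plot the resulting grid. `which.class = 3L` selects the third level of status ("promoted"), `prob = TRUE` returns probabilities instead of the centered logit, and the 10 x 10 grid and `ntree = 200` are assumptions chosen only to keep the run short.

```r
library(DALEX)
library(randomForest)
library(pdp)
data(HR)

set.seed(543)
ind <- sample(2, nrow(HR), replace = TRUE, prob = c(0.9, 0.1))
trainData <- HR[ind == 1, ]
m_rf <- randomForest(status ~ ., data = trainData, ntree = 200)

# Joint partial dependence of P(promoted) on age and hours
pd2 <- partial(m_rf,
               pred.var = c("age", "hours"),
               which.class = 3L,        # 3rd level of status = "promoted"
               prob = TRUE,             # probability scale
               grid.resolution = 10,    # 10 x 10 grid
               train = trainData)
head(pd2)
plotPartial(pd2)                        # level plot of the 2D surface
```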
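2.4 Individual prediction
Now suppose we have an employee with the following record:
      gender      age    hours evaluation salary   status
10000 female 57.96254 54.78624          4      4 promoted
We predict this employee's final status.
The prediction shows that this employee has a 97% probability of being promoted, and her actual status is indeed promoted. Note that this row is taken from HR itself, so it may also have been one of the training rows.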
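new_observation = tail(HR, 1)   # last row of HR (the employee above)
# return class probabilities so each class gets its own break-down
p_fun <- function(object, newdata){predict(object, newdata = newdata, type = "prob")}
bd_rf <- local_attributions(m_rf,
                            data = HR_test,   # HR_test ships with DALEX
                            new_observation = new_observation,
                            predict_function = p_fun)
bd_rf
plot(bd_rf)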
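The class probabilities behind the break-down plot can also be read directly with `predict(type = "prob")`. The sketch below does that for the last row of HR and then for the hypothetical employee from section 1 (male, 68, evaluation 3, salary 3). Working hours were not given there, so the median of `hours` is used as a placeholder assumption; `new_emp` is a hypothetical name, and `ntree = 200` means the probabilities will not exactly match the 97% reported above.

```r
library(DALEX)
library(randomForest)
data(HR)

set.seed(543)
ind <- sample(2, nrow(HR), replace = TRUE, prob = c(0.9, 0.1))
trainData <- HR[ind == 1, ]
m_rf <- randomForest(status ~ ., data = trainData, ntree = 200)

# Class probabilities for the last row of HR
new_observation <- tail(HR, 1)
p_last <- predict(m_rf, new_observation, type = "prob")
p_last

# Hypothetical employee from section 1; [1] <- keeps each column's type
new_emp <- HR[1, ]
new_emp$gender[1]     <- "male"
new_emp$age[1]        <- 68
new_emp$hours[1]      <- median(HR$hours)   # assumption: hours not given in the text
new_emp$evaluation[1] <- 3
new_emp$salary[1]     <- 3
p_new <- predict(m_rf, new_emp, type = "prob")
p_new
colnames(p_new)[which.max(p_new)]           # most likely status
```

Because the hours value is an assumption, it is worth repeating the prediction over a few plausible hours values before drawing any conclusion about this employee.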
> sessionInfo()
R version 3.6.2 (2019-12-12)
Platform: x86_64-apple-darwin15.6.0 (64-bit)
Running under: macOS Mojave 10.14
Matrix products: default
BLAS: /System/Library/Frameworks/Accelerate.framework/Versions/A/Frameworks/vecLib.framework/Versions/A/libBLAS.dylib
LAPACK: /Library/Frameworks/R.framework/Versions/3.6/Resources/lib/libRlapack.dylib
locale:
[1] en_US.UTF-8/en_US.UTF-8/en_US.UTF-8/C/en_US.UTF-8/en_US.UTF-8
attached base packages:
[1] stats graphics utils datasets grDevices methods base
other attached packages:
[1] edarf_1.1.1 ranger_0.12.1 questionr_0.7.0 car_3.0-7
[5] carData_3.0-3 nnet_7.3-14 DALEX_1.2.1 vip_0.2.2
[9] ggpubr_0.3.0 rstatix_0.5.0 caret_6.0-86 lattice_0.20-41
[13] pdp_0.7.0 randomForest_4.6-14 iBreakDown_1.2.0 hrbrthemes_0.8.0
[17] reshape2_1.4.4 RColorBrewer_1.1-2 forcats_0.5.0 stringr_1.4.0
[21] dplyr_0.8.5 purrr_0.3.4 readr_1.3.1 tidyr_1.0.3
[25] tibble_3.0.1 ggplot2_3.3.0 tidyverse_1.3.0