Относительная важность набора предикторов в классификации случайных лесов в R


31

Я хотел бы определить относительную важность наборов переменных для randomForestмодели классификации в R. importanceФункция предоставляет MeanDecreaseGiniметрику для каждого отдельного предиктора - это так же просто, как суммировать это для каждого предиктора в наборе?

Например:

# Assumes df has variables a1, a2, b1, b2, and outcome
rf <- randomForest(outcome ~ ., data=df)
importance(rf)
# To determine whether the "a" predictors are more important than the "b"s,
# can I sum the MeanDecreaseGini for a1 and a2 and compare to that of b1+b2?

Ответы:


46

Сначала я хотел бы уточнить, что на самом деле измеряет показатель важности.

MeanDecreaseGini - это показатель переменной важности, основанный на индексе примесей Джини, который используется для расчета расщеплений во время тренировки. Распространенным заблуждением является то, что метрика переменной важности относится к Джини, используемому для подтверждения производительности модели, которая тесно связана с AUC, но это неверно. Вот объяснение из пакета randomForest, написанное Брейманом и Катлером:

Важность Джини
Каждый раз, когда деление узла выполняется по переменной m, критерий примеси Джини для двух дочерних узлов меньше, чем родительский узел. Сложение уменьшений джини для каждой отдельной переменной по всем деревьям в лесу дает важность быстрой переменной, которая часто очень согласуется с мерой важности перестановки.

Индекс примеси Джини определяется как , где п с является число классов в целевой переменной и р я есть отношение этого класса.

G=i=1ncpi(1pi)
ncpi

Для задачи двух классов это приводит к следующей кривой, которая максимизируется для образца 50-50 и минимизируется для однородных множеств: Примесь Джини для 2 класса

Значение вычисляется как , усредненное по всем расколов в лесу , связанных с предсказателем в вопросе. Поскольку это среднее значение, оно может быть легко расширено для усреднения по всем расщеплениям по переменным, содержащимся в группе.

I=GparentGsplit1Gsplit2

Если присмотреться, мы узнаем, что важность каждой переменной - это среднее условие для используемой переменной, а meanDecreaseGini группы будет просто средним значением этих значений, взвешенных для доли, которую эта переменная использует в лесу, по сравнению с другими переменными в той же группе. Это верно, потому что свойство башни

E[E[X|Y]]=E[X]

Теперь, чтобы ответить на ваш вопрос напрямую, это не так просто, как просто суммировать все значения в каждой группе, чтобы получить комбинированный MeanDecreaseGini, но вычисление средневзвешенного значения даст вам ответ, который вы ищете. Нам просто нужно найти переменные частоты в каждой группе.

Вот простой скрипт для получения их из объекта случайного леса в R:

var.share <- function(rf.obj, members) {
  count <- table(rf.obj$forest$bestvar)[-1]
  names(count) <- names(rf.obj$forest$ncat)
  share <- count[members] / sum(count[members])
  return(share)
}

Просто передайте имена переменных в группе в качестве параметра members.

Надеюсь, это ответит на ваш вопрос. Я могу написать функцию, чтобы получить значения группы напрямую, если это представляет интерес.

РЕДАКТИРОВАТЬ:
Вот функция, которая дает важность группы для данного randomForestобъекта и список векторов с именами переменных. Он использует var.shareкак определено ранее. Я не делал никакой проверки ввода, поэтому вам нужно убедиться, что вы используете правильные имена переменных.

group.importance <- function(rf.obj, groups) {
  var.imp <- as.matrix(sapply(groups, function(g) {
    sum(importance(rf.obj, 2)[g, ]*var.share(rf.obj, g))
  }))
  colnames(var.imp) <- "MeanDecreaseGini"
  return(var.imp)
}

Пример использования:

library(randomForest)                                                          
data(iris)

rf.obj <- randomForest(Species ~ ., data=iris)

groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
               Petal=c("Petal.Width", "Petal.Length"))

group.importance(rf.obj, groups)

>

      MeanDecreaseGini
Sepal         6.187198
Petal        43.913020

Это также работает для перекрывающихся групп:

overlapping.groups <- list(Sepal=c("Sepal.Width", "Sepal.Length"), 
                           Petal=c("Petal.Width", "Petal.Length"),
                           Width=c("Sepal.Width", "Petal.Width"), 
                           Length=c("Sepal.Length", "Petal.Length"))

group.importance(rf.obj, overlapping.groups)

>

       MeanDecreaseGini
Sepal          6.187198
Petal         43.913020
Width          30.513776
Length        30.386706

Спасибо за четкий и строгий ответ! Если вы не против добавить функцию для групповых значений, это было бы здорово.
Макс Генис

Спасибо за этот ответ! Два вопроса, если у вас есть минутка: (1) Значимость тогда вычисляется как ... : в отношении определения Бреймана, я - это «уменьшение Джини», и важность будет суммой уменьшений, верно ? (2) усреднено по всем расщеплениям в лесу с участием рассматриваемого предиктора : могу ли я заменить это всеми узлами, включающими расщепление по этой конкретной функции ? Конечно, я полностью понимаю;)
Реми Мелиссон,

1
Ваш комментарий заставил меня задуматься над определениями, поэтому я перебрал код randomForest, используемый в R, чтобы правильно ответить на него. Я был немного не честен, если честно. Среднее значение производится по всем деревьям, а не по всем узлам. Я обновлю ответ, как только у меня будет время. Вот ответы на ваш вопрос на данный момент: (1) да. Вот как это определяется на уровне дерева. Сумма убываний затем усредняется по всем деревьям. (2) Да, это то, что я хотел сказать, но на самом деле это не так.
то время как

4

Функция, определенная выше как G = сумма по классам [pi (1-pi)], на самом деле является энтропией, которая является еще одним способом оценки разбиения. Разница между энтропией в дочерних узлах и родительском узле заключается в получении информации. Примесная функция GINI является G = 1-суммой по классам [pi ^ 2].

Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.