Случайный лес и прогноз


13

Я пытаюсь понять, как работает Random Forest. У меня есть представление о том, как строятся деревья, но я не могу понять, как Random Forest делает прогнозы на выборке из сумки. Может ли кто-нибудь дать мне простое объяснение, пожалуйста? :)

Ответы:


16

Каждое дерево в лесу построено из начальной загрузки наблюдений в ваших тренировочных данных. Эти наблюдения в образце начальной загрузки создают дерево, в то время как те, которые не находятся в образце начальной загрузки, образуют образцы вне пакета (или OOB).

Должно быть понятно, что для наблюдений в данных, используемых для построения дерева, доступны те же переменные, что и для наблюдений в примере OOB. Чтобы получить прогнозы для образца OOB, каждый из них передается по текущему дереву, а правила для дерева следуют до тех пор, пока он не прибудет в конечный узел. Это дает предсказания OOB для этого конкретного дерева.

Этот процесс повторяется большое количество раз, каждое дерево обучается на новой выборке начальной загрузки на основе данных обучения и прогнозов для новых выборок OOB.

По мере роста числа деревьев любая выборка будет присутствовать в выборках OOB более одного раза, поэтому «среднее» предсказаний по N деревьям, где выборка находится в OOB, используется в качестве прогнозирования OOB для каждой обучающей выборки для деревья 1, ..., N. По «среднему» мы используем среднее значение прогнозов для непрерывного ответа, или большинство голосов может быть использовано для категориального ответа (большинство голосов - это класс с большинством голосов над набором деревья 1, ..., N).

Например, предположим, что у нас были следующие прогнозы OOB для 10 образцов в обучающем наборе на 10 деревьях

set.seed(123)
oob.p <- matrix(rpois(100, lambda = 4), ncol = 10)
colnames(oob.p) <- paste0("tree", seq_len(ncol(oob.p)))
rownames(oob.p) <- paste0("samp", seq_len(nrow(oob.p)))
oob.p[sample(length(oob.p), 50)] <- NA
oob.p

> oob.p
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA     7     8     2     1    NA     5     3      2
samp2      6    NA     5     7     3    NA    NA    NA    NA     NA
samp3      3    NA     5    NA    NA    NA     3     5    NA     NA
samp4      6    NA    10     6    NA    NA     3    NA     6     NA
samp5     NA     2    NA    NA     2    NA     6     4    NA     NA
samp6     NA     7    NA     4    NA     2     4     2    NA     NA
samp7     NA    NA    NA     5    NA    NA    NA     3     9      5
samp8      7     1     4    NA    NA     5     6    NA     7     NA
samp9      4    NA    NA     3    NA     7     6     3    NA     NA
samp10     4     8     2     2    NA    NA     4    NA    NA      4

Где NAозначает, что образец был в обучающих данных для этого дерева (другими словами, он не был в выборке OOB).

Среднее значение ненулевых NAзначений для каждой строки дает прогноз OOB для каждой выборки для всего леса

> rowMeans(oob.p, na.rm = TRUE)
 samp1  samp2  samp3  samp4  samp5  samp6  samp7  samp8  samp9 samp10 
  4.00   5.25   4.00   6.20   3.50   3.80   5.50   5.00   4.60   4.00

Поскольку каждое дерево добавляется в лес, мы можем вычислить ошибку OOB, включая это дерево. Например, ниже приведены кумулятивные средства для каждого образца:

FUN <- function(x) {
  na <- is.na(x)
  cs <- cumsum(x[!na]) / seq_len(sum(!na))
  x[!na] <- cs
  x
}
t(apply(oob.p, 1, FUN))

> print(t(apply(oob.p, 1, FUN)), digits = 3)
       tree1 tree2 tree3 tree4 tree5 tree6 tree7 tree8 tree9 tree10
samp1     NA    NA  7.00  7.50  5.67  4.50    NA   4.6  4.33    4.0
samp2      6    NA  5.50  6.00  5.25    NA    NA    NA    NA     NA
samp3      3    NA  4.00    NA    NA    NA  3.67   4.0    NA     NA
samp4      6    NA  8.00  7.33    NA    NA  6.25    NA  6.20     NA
samp5     NA     2    NA    NA  2.00    NA  3.33   3.5    NA     NA
samp6     NA     7    NA  5.50    NA  4.33  4.25   3.8    NA     NA
samp7     NA    NA    NA  5.00    NA    NA    NA   4.0  5.67    5.5
samp8      7     4  4.00    NA    NA  4.25  4.60    NA  5.00     NA
samp9      4    NA    NA  3.50    NA  4.67  5.00   4.6    NA     NA
samp10     4     6  4.67  4.00    NA    NA  4.00    NA    NA    4.0

Таким образом, мы видим, как прогноз накапливается по N деревьям в лесу до заданной итерации. Если вы читаете по строкам, крайнее правое NAзначение, которое я показываю выше для прогноза OOB. Вот как можно отслеживать характеристики OOB - RMSEP может быть вычислен для выборок OOB на основе прогнозов OOB, накопленных кумулятивно по N деревьям.

Обратите внимание, что показанный R-код не берется из внутренних частей кода randomForest в пакете randomForest для R - я просто выбил некоторый простой код, чтобы вы могли следить за тем, что происходит, когда определены прогнозы для каждого дерева.

Поскольку каждое дерево построено из выборки начальной загрузки и в случайном лесу имеется большое количество деревьев, так что каждое наблюдение обучающего набора находится в выборке OOB для одного или нескольких деревьев, прогнозы OOB могут быть предоставлены для всех образцы в тренировочных данных.

Я замаскировал такие проблемы, как отсутствие данных для некоторых случаев OOB и т. Д., Но эти проблемы также относятся к одному дереву регрессии или классификации. Также обратите внимание, что каждое дерево в лесу использует только mtryслучайно выбранные переменные.


Отличный ответ Гэвин! Когда вы пишете "To get predictions for the OOB sample, each one is passed down the current tree and the rules for the tree followed until it arrives in a terminal node", у вас есть простое объяснение того, что rules for the treeесть? И sampleправильно ли я понимаю строку, если я понимаю, что образцы представляют собой groupsнаблюдения, на которые деревья делят данные?
user1665355

@ user1665355 Я полагаю, вы поняли, как строились деревья регрессии или классификации? Деревья в РФ ничем не отличаются (разве что в правилах остановки). Каждое дерево разбивает обучающие данные на группы выборок с одинаковыми «значениями» для ответа. Переменная и местоположение разделения (например, pH> 4,5), которое лучше всего предсказывает (то есть минимизирует «ошибку»), формирует первое разделение или правило в дереве. Каждая ветвь этого разбиения затем рассматривается по очереди, и идентифицируются новые разбиения / правила, которые минимизируют «ошибку» дерева. Это двоичный алгоритм рекурсивного разбиения. Расколы - это правила.
Восстановить Монику - Дж. Симпсон

@ user1665355 Да, извините, я пришел из поля, где выборка является наблюдением, строкой в ​​наборе данных. Но когда вы начинаете говорить о начальной загрузке, это набор из N наблюдений, составленный с заменой обучающих данных и, следовательно, имеющий N строк или наблюдений. Я постараюсь очистить мою терминологию позже.
Восстановить Монику - Дж. Симпсон

Благодарность! Я очень новичок в РФ, извините за, возможно, глупые вопросы :) Мне кажется, я понимаю почти все, что вы написали, очень хорошее объяснение! Мне просто интересно, где переменная и место разделения (например, pH> 4,5), которое лучше всего предсказывает (то есть минимизирует «ошибку»), формирует первое разделение или правило в дереве ... Я не могу понять, что это за ошибка. Я читаю и пытаюсь понять http://www.ime.unicamp.br/~ra109078/PED/Data%20Minig%20with%20R/Data%20Mining%20with%20R.pdf. На странице 115-116 авторы используют RF для выбора variable importanceтехнических индикаторов.
user1665355

«Ошибка» зависит от того, какой тип дерева устанавливается. Отклонение является обычной мерой для непрерывных (гауссовых) ответов. В пакете rpart коэффициент Джини является значением по умолчанию для категориальных ответов, но есть другие для различных моделей и т. Д. Если вы хотите успешно развернуть его, вам следует воспользоваться хорошей книгой по деревьям и RF. Показатели важности переменных являются чем-то другим - они измеряют «важность» каждой переменной в наборе данных, наблюдая, насколько что-то меняется, когда эта переменная используется для подгонки к дереву и когда эта переменная не используется.
Восстановить Монику - Дж. Симпсон
Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.