Что означает вывод функции model.predict из Keras?


14

Я построил модель LSTM для прогнозирования повторяющихся вопросов в официальном наборе данных Quora. Метки теста - 0 или 1. 1 означает, что пара вопросов дублируется. После построения модели с использованием model.fit, я тестирую модель, используя model.predictданные теста. Вывод представляет собой массив значений примерно так:

 [ 0.00514298]
 [ 0.15161049]
 [ 0.27588326]
 [ 0.00236167]
 [ 1.80067325]
 [ 0.01048524]
 [ 1.43425131]
 [ 1.99202418]
 [ 0.54853892]
 [ 0.02514757]

Я показываю только первые 10 значений в массиве. Я не понимаю, что означают эти значения и каков прогнозируемый ярлык для каждой пары вопросов?


1
Я думаю, что у вас есть проблемы в вашей сети .. вероятности должны быть в масштабе 0-1 .. но у вас есть 1,99!, Я думаю, что у вас что-то не так ..
Ганем,

Ответы:


8

Выход нейронной сети по умолчанию никогда не будет двоичным, т. Е. Нулями или единицами. Сеть работает с непрерывными значениями (не дискретными), чтобы более свободно оптимизировать потери в рамках градиентного спуска.

Посмотрите здесь на похожий вопрос, который также показывает некоторый код.

Без какой-либо настройки и масштабирования выход вашей сети, скорее всего, попадет где-то в диапазон вашего ввода, с точки зрения его номинального значения. В вашем случае это примерно от 0 до 2.

Теперь вы можете написать функцию, которая превращает ваши значения выше в 0 или 1, основываясь на некотором пороге. Например, масштабируйте значения, чтобы они находились в диапазоне [0, 1], затем, если значение ниже 0,5, верните 0, если выше 0,5, верните 1.


Спасибо, я тоже думал об использовании порогового значения для классификации меток. Но на чем должна основываться пороговая величина?
Dookoto_Sea

@Dookoto_Sea ты должен решить это сам
Джереми Блен

@Dookoto_Sea Пожалуйста, обратите внимание, что если ваша метка равна 0 или 1, ваше значение должно быть в этом диапазоне, интригующим является масштаб значений прогнозируемых значений [0, 2], вам нужно изменить вывод модели
Jérémy Blain

6

Если это проблема классификации, вы должны изменить свою сеть, чтобы иметь 2 выходных нейрона.

Вы можете конвертировать метки в однозначно закодированные векторы, используя

y_train_binary = keras.utils.to_categorical(y_train, num_classes)
y_test_binary = keras.utils.to_categorical(y_test, num_classes)

Затем убедитесь, что ваш выходной слой имеет два нейрона с функцией активации softmax.

model.add(Dense(num_classes, activation='softmax'))

Это приведет к тому, что вы model.predict(x_test_reshaped)будете массивом списков. Где внутренний список - это вероятность того, что экземпляр принадлежит каждому классу. Это добавит до 1, и очевидно, что выбранная метка должна быть выходным нейроном с наибольшей вероятностью.

Keras включил это в свою библиотеку, поэтому вам не нужно делать это сравнение самостоятельно. Вы можете получить метку класса напрямую с помощью model.predict_classes(x_test_reshaped).


3
«Если это проблема классификации, вы должны изменить свою сеть так, чтобы в ней было 2 выходных нейрона.» ... извините, Джа, но он не должен, он может сделать это с одним нейроном и сигмоидом вместо функции softmax.
Ганем

@Minion, оба метода по существу эквивалентны, пороговое значение, которое в противном случае вам нужно было бы сделать с одним выходным нейроном, неявно встроено в сеть. Таким образом, обеспечивая двоичный вывод.
JahKnows

1
Да, я знаю. Я прокомментировал только потому, что он упомянул: «следует изменить вашу сеть, чтобы иметь 2 выходных нейрона». .. спасибо
Ганем

1

Прогнозы основаны на том, что вы вводите в качестве результатов обучения и функции активации.

Например, с входом 0-1 и функцией активации сигмоида для выхода с двоичной потерей кроссентропии вы получите вероятность 1. В зависимости от стоимости ошибочного решения в любом направлении вы можете затем решить, как вы разобраться с этими вероятностями (например, прогнозировать категорию «1», если вероятность составляет> 0,5 или, возможно, уже, когда она> 0,1).

(-,

Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.