Как оценить потери KLD и потери реконструкции в вариационном авто-кодировщике


26

почти во всех примерах кода, которые я видел в VAE, функции потерь определяются следующим образом (это код с тензорным потоком, но я видел похожее для theano, torch и т. д. Это также для коннета, но это также не слишком актуально) , только влияет на оси, суммы принимаются):

# latent space loss. KL divergence between latent space distribution and unit gaussian, for each batch.
# first half of eq 10. in https://arxiv.org/abs/1312.6114
kl_loss = -0.5 * tf.reduce_sum(1 + log_sigma_sq - tf.square(mu) - tf.exp(log_sigma_sq), axis=1)

# reconstruction error, using pixel-wise L2 loss, for each batch
rec_loss = tf.reduce_sum(tf.squared_difference(y, x), axis=[1,2,3])

# or binary cross entropy (assuming 0...1 values)
y = tf.clip_by_value(y, 1e-8, 1-1e-8) # prevent nan on log(0)
rec_loss = -tf.reduce_sum(x * tf.log(y) + (1-x) * tf.log(1-y), axis=[1,2,3])

# sum the two and average over batches
loss = tf.reduce_mean(kl_loss + rec_loss)

Однако числовой диапазон kl_loss и rec_loss очень зависит от затемнения скрытого пространства и размера входного объекта (например, разрешение в пикселях) соответственно. Было бы разумно заменить Reduce_sum на Reduce_mean, чтобы получить для каждого z-dim KLD и для каждого пикселя (или функции) LSE или BCE? Что еще более важно, как мы взвешиваем скрытые потери с потерями реконструкции при суммировании для окончательной потери? Это просто метод проб и ошибок? или есть какая-то теория (или хотя бы эмпирическое правило) для этого? Я не мог найти информацию об этом нигде (в том числе оригинал).


Проблема, с которой я столкнулся, заключается в том, что если баланс между размерами моего входного элемента (x) и размерами скрытого пространства (z) не является «оптимальным», то мои реконструкции очень хороши, но изученное скрытое пространство не структурировано (если x измерений очень высока и ошибка реконструкции преобладает над KLD), или наоборот (реконструкции не очень хорошие, но изученное скрытое пространство хорошо структурировано, если доминирует KLD).

Я обнаружил, что должен нормализовать потери при восстановлении (деление на размер входного объекта) и KLD (деление на измерения z), а затем вручную взвешивать термин KLD с произвольным весовым коэффициентом (нормализация такова, что я могу использовать то же самое или аналогичный вес не зависит от размеров х или г ). Опытным путем я нашел около 0,1, чтобы обеспечить хороший баланс между реконструкцией и структурированным скрытым пространством, которое мне кажется «сладким пятном». Я ищу предыдущую работу в этой области.


По запросу математическая запись выше (с упором на потери L2 для ошибки восстановления)

LLaTеNT(я)знак равно-12ΣJзнак равно1J(1+журнал(σJ(я))2-(μJ(я))2-(σJ(я))2)

LресоN(я)знак равно-ΣКзнак равно1К(YК(я)-ИксК(я))2

L(м)знак равно1MΣязнак равно1M(LLaTеNT(я)+LресоN(я))

JZμσ2КM(я)яL(м)м

Ответы:


17

7

Я хотел бы добавить еще один документ, касающийся этой проблемы (я не могу комментировать из-за моей низкой репутации на данный момент).

В подразделе 3.1 статьи авторы указали, что им не удалось обучить прямому внедрению VAE, которое в равной степени взвешивало вероятность и расхождение KL. В их случае потери KL были нежелательно уменьшены до нуля, хотя ожидалось, что они будут иметь небольшое значение. Чтобы преодолеть это, они предложили использовать «отжиг стоимости KL», который медленно увеличивал весовой коэффициент члена дивергенции KL (синяя кривая) с 0 до 1.

Рис. 2. Вес члена дивергенции KL вариационной нижней границы в соответствии с типичным графиком отжига сигмовидной кишки, нанесенный вместе с (невзвешенным) значением члена дивергенции KL для нашего VAE в Penn TreeBank.

Это обходное решение также применяется в Ladder VAE.

Бумага:

Боуман С.Р., Вилнис Л., Виньяльс О., Дай А.М., Йозефович Р. и Бенжио С., 2015. Генерация предложений из непрерывного пространства . Препринт arXiv arXiv: 1511.06349.

Используя наш сайт, вы подтверждаете, что прочитали и поняли нашу Политику в отношении файлов cookie и Политику конфиденциальности.
Licensed under cc by-sa 3.0 with attribution required.