Category: Topics
LOSS: Wasserstein Distance
A TensorFlow implementation of the Wasserstein distance, modeled on SciPy's `scipy.stats.wasserstein_distance`:
tf wasserstein_distance
Python
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 |
import tensorflow.compat.v1 as tf


def wasserstein_distance(p_prob, q_prob):
    """Return the 1-D Wasserstein distance between two sets of values.

    TensorFlow port of the CDF-based computation that
    ``scipy.stats.wasserstein_distance`` uses for unweighted samples: the
    absolute difference between the two empirical CDFs is integrated over
    the merged, sorted set of values.

    Both inputs are flattened to 1-D before anything else happens, so a
    batched tensor is pooled into ONE sample; rows are not treated as
    independent distributions.

    Args:
        p_prob: Values of the first sample; any shape (flattened internally).
        q_prob: Values of the second sample; any shape (flattened internally).

    Returns:
        A scalar ``tf.float32`` tensor holding the distance.
    """
    p = tf.reshape(p_prob, [-1])
    q = tf.reshape(q_prob, [-1])

    # Sort each sample directly; equivalent to argsort followed by gather.
    p_sorted = tf.sort(p, axis=-1, direction='ASCENDING')
    q_sorted = tf.sort(q, axis=-1, direction='ASCENDING')

    # Merged support of both samples and the gaps between consecutive points.
    all_values = tf.sort(
        tf.concat((p, q), axis=0), axis=-1, direction='ASCENDING')
    deltas = all_values[1:] - all_values[:-1]

    # For every support point except the last, count how many sample values
    # are <= that point (side='right'); this is the unnormalised empirical CDF.
    p_cdf_indices = tf.searchsorted(p_sorted, all_values[:-1], side='right')
    q_cdf_indices = tf.searchsorted(q_sorted, all_values[:-1], side='right')

    # Cast explicitly before dividing: the counts and sizes are integer
    # tensors, and `/` on integers is floor division under Python 2 unless
    # true division is enabled. float64 matches what true division yields.
    p_cdf = (tf.cast(p_cdf_indices, tf.float64)
             / tf.cast(tf.shape(p)[0], tf.float64))
    q_cdf = (tf.cast(q_cdf_indices, tf.float64)
             / tf.cast(tf.shape(q)[0], tf.float64))

    # Integrate |P_cdf - Q_cdf| over the support; the result is float32.
    loss_sum = tf.reduce_sum(
        tf.multiply(
            tf.cast(tf.abs(p_cdf - q_cdf), tf.float32),
            tf.cast(deltas, tf.float32),
        )
    )
    return loss_sum