最优传输 (Optimal transport) 和 sinkhorn 迭代

2019-01-21

引言

最优传输理论在机器学习领域受到了越来越多的关注,特别是 Wasserstein-GAN [2] 的出现,引起了大家对最优传输的兴趣。 有很多近期的工作将最优传输应用在许多不同的任务中,比如小样本学习 [3],自监督学习 [4],信息检索 [5],目标检测 [6],多标签分类 [7] 等等。 但是目前关于最优传输的中文资料较少,特别是关于最优传输的 Sinkhorn 迭代解法及其 subgradient 的求法,几乎找不到相关的中文材料。

本文首先简要介绍离散情况下的最优传输问题,然后介绍文献 [1] 所提出的最优传输问题的快速解法以及对应的对应梯度的计算。 为了理解的方便,本文只介绍离散情况下的最优传输问题。 值得注意的是,有很多关于最优传输的资料会用到不同的关键词, 例如 optimal transport / wasserstein distance / earth mover's distance / sinkhorn distance 等。 离散条件下 optimal transport / wasserstein distance / earth mover's distance 可以认为是等价的,Sinkhorn iteration 是一种最优传输的快速迭代解法,后文中会介绍到。 本文涉及到的数学符号均沿用文献 [1],并补充了详细的说明。

推土机距离:earth mover's distance

Earth mover's distance

假设地面上有 $d_1$ 个土堆,第 $i$ 个土堆有 $r_i$ 数量的土; 同时有 $d_2$ 个坑,第 $j$ 个坑可以容纳的 $c_j$ 数量的土。 我们用 $$ \begin{split} r\in\mathbb{R}^{d1} \\ c\in\mathbb{R}^{d2} \end{split} $$ 来分别表示这土堆和坑,并假设 $\sum_i r_i = \sum_j c_j$,也就是 所有的 “土” 刚好够填满所有的 “坑”。 定义 $\mathbf{M}\in\mathbb{R}^{d_1\times d_2}_+$ 为距离矩阵,其中 $m_{ij}$ 表示 “从第 $i$ 个土堆搬运一份土到第 $j$ 个坑的成本 (cost)”。 现在要将土从 $r$ 搬到 $c$,定义 $\mathbf{P}\in\mathbb{R}^{d_1\times d_2}_+$ 为从 $r$ 到 $c$ 的 "传输方案" (transportation paln), 其中 $p_{i,j}$ 表示从 $r_i$ 搬运到 $c_j$ 的土的数量。 最优传输目的是求得一个最优的传输矩阵 $P^*$,使得总体的传输消耗 (cost) 最小

很显然,传输矩阵 $P$ 的每一 行/列 加起来应该等于 $r$/$c$,这也是为什么用字母 r(ow)/c(olumn) 来表示的原因。 写成公式,$P$ 满足约束: $$ \begin{equation} \begin{cases} \sum_i p_{i,j} = c \\ \sum_j p_{i,j} = r \end{cases}\label{eq:t-constrain} \end{equation} $$ 即所有满足条件的 transportation plan $P$ 的集合为: $$ U(r, c) := \big\{ P\in \mathbb{R}_+^{d_1\times d_2} | P \cdot 1_{d2} = r, P^T \cdot 1_{d1} = c \big\} \label{eq:U} $$ 其中 $1_d$ 表示 $d$ 维的全1向量。 值得一提的是,$U(r, c)$ 构成高维空间的多面体 (polytope) , 因为 式$\ref{eq:t-constrain}$ 所有约束都是线性的,平面上多个线性约束条件围成的区域是多边形, 拓展到高维空间后线性约束围成的集合就是多面体(polytope)。 这里用集合 $\mathbf{U}$ 来表示所有满足条件的传输,在优化中集合 $\mathbf{U}$ 又被称为“可行域” (feasible set)。

最优传输 $P^*$ 定义为“使得总体运输成本 $\langle P, M\rangle_F$ 最小的运输方案”: $$ \begin{equation} P^* = \inf_{P\in \mathbf{U(r, c)}} \langle P, M\rangle_F \label{eq:ot-origin} \end{equation} $$ 其中 $\langle X, Y\rangle_F = \sum_i X_iY_i$ 为 Frobenius inner product 。 使得 式$\ref{eq:ot-origin}$ 最小的传输方案 $P^*$ 称 为最优传输方案

简而言之,最优传输就是要在可行域 $U(r, c)$ 中找到找到使得 式$\ref{eq:ot-origin}$ 最小的传输方案 $P^*$。 很显然这是一个线性规划问题,因为不论是式$\ref{eq:t-constrain}$中的约束条件还是式$\ref{eq:ot-origin}$中的目标函数 都是线性的

EMD 度量概率分布的距离

如果 $r$ 和 $c$ 满足 $\sum_i r_i = 1, \ \sum_j c_j = 1 $ 且 $\forall i, j \ \ \ \ r_i \ge 0, c_j \ge 0$, 那么 $r$ 和 $c$ 可以看作两个概率分布, $U(r, c)$ 也可以看作是 边际分布分别为 $r$ 和 $c$ 的 “联合概率分布”。 $r$ 和 $c$ 之间的最优传输距离可视作概率分布之间的差异。 在很多应用中都会将输入 $r$ 和 $c$ 归一化成概率(比如分类中的 softmax),最后计算两个概率的最优传输距离。 因此最优传输也可用作度量概率分布之间的距离。

Why optimal transport?

那么我们为什么要计算 optimal transport 呢?如果为了 measure 概率分布之间的距离,有很多现成的 measurements 可以用, 比如非常简单的 KL 散度: $$ \text{KL}(p, q) = \sum_i p_i \cdot\log\frac{p_i}{q_i}. $$ 除了 “无法处理两个分布的支撑集不相交的情况”以及“不满足对称性” 等原因之外,一个重要的原因就是这种 逐点计算的度量 没有考虑分布内的结构信息。 所谓的结构信息,就是分布内的联系。例如在 KL 散度中,$p_i, i=0,1,...$ 彼此之间都是独立计算最后加起来的, 而大部分情况下它们并不是独立的。

就以我们常见的分类任务为例,分类任务通常用交叉熵损失来度量模型预测和样本标签之间的距离, 交叉熵损失实际上就是在计算 onehot 化的标签和模型预测之间的 KL 散度。 这种逐点计算的损失函数(不论是交叉熵还是 L2)都无法考虑分布内不同事件的相关性。 例如将“汽车”误分类成“卡车”显然没有把“汽车”误分类成“斑马”严重。 但是用 KL 散度来度量的话,这两种错误的损失是一样的。

假设$r \in \mathbb{R}^{d1}_+, c \in \mathbb{R}^{d2}_+$是两个离散概率分布, 满足 $r^T\mathbf{1}_{d1} = 1, c^T\mathbf{1}_{d2} = 1$,其中$\mathbf{1}_{d1},\mathbf{1}_{d1}$ 表示全1列向量, 我们可以用最优传输来计算两者之间的距离。 而概率分布内的结构信息可以通过距离矩阵 $\mathbf{M}$ “嵌入” 到距离度量中。 还是以分类为例,我们可以让 $m_{汽车,卡车}$ 远小于 $m_{汽车,斑马}$。

熵正则的最优传输问题

熵正则

在文献[1]中,Cuturi 提出了一种快速的最优传输求解方法。该方法首先引入熵正则,使得原问题的可行域更加平滑, 然后将最优传输问题转为 matrix permutation 问题,最后用 sinkhorn 迭代算法求原问题的近似解。

联合分布 $P$ 的熵 (entropy) 为: $$ \begin{equation} H(P) = -\sum_i P_i\log(P_i) \end{equation} $$ 原问题 式$\ref{eq:ot-origin}$ 增加熵正则约束之后,优化目标变为: $$ \begin{equation} D_{M, \lambda}(r,c) = \inf_{P\in \mathbf{U}} \langle P, M\rangle - \frac{1}{\lambda}H(P) \label{eq:ot-entropy} \end{equation} $$ 式$\ref{eq:ot-entropy}$ 称为“熵正则的最优传输” (entropy-regularized optimal transport),对应的最优解表示为 $P^*_{\lambda}$。 从概率和信息论的角度,均匀分布的熵最大,因此熵正则会让最优传输矩阵 $P^*_{\lambda}$ 更趋近于均匀分布。

从优化的角度, 式$\ref{eq:ot-origin}$ 中的原问题的解一定存在于多边形 $U(r, c)$ 的某一顶点处(这么说似乎不严谨,但是当 $r, c$ 的维度很高时,最优解一定在顶点处取得), 因此 $P^*$ 是一个稀疏矩阵,绝大部分元素都是0。 增加熵正则之后,相当于将原来的可行域 $U(r, c)$ 向内收缩成光滑的 $U_{\lambda}(r, c)$,对应的最优解 $P^*_{\lambda}$ 不再是稀疏矩阵。

可行域 $U(r, c)$ 和 $U_{\lambda}(r, c)$ 包含所有满足约束的点, $U(r, c)$ 是一个高维空间中的多边形 (polytope),增加熵正则之后的可行域 $U_{\lambda}(r, c)$ 是一个边界光滑的区域。 蓝色虚线表示目标函数的等高线,距离矩阵 $M$ 决定最优解所在的方向。 本图参考了 文献[1] 的 Fig.1。
当$\lambda \rightarrow \infty$ 时,$U_{\lambda}(r, c)$趋向于 $U_{\lambda}(r, c)$,最优解 $P^*_{\lambda}$ 也趋向于 $P^*$; 当$\lambda \rightarrow 0$ 时,$U_{\lambda}(r, c)$ 收缩成点 $rc^T$,最优解为 $P^*_{\lambda}=rc^T$。

为什么要增加熵正则

参考文献 [1] 中作者列出了两个使用熵正则的原因:

  1. 式$\ref{eq:ot-origin}$ 中的原线性规划问题的解一定是在可行域 $U(r, c)$ 的某个顶点 (vertex) 上的,因此得到的解 $P^*$ 是一个稀疏矩阵。 稀疏的解 $P^*$ 会使得最终的传输方案十分不均衡, 使用熵正则可以让传输矩阵更加均衡,这部分可以参考 [1] 的第三章相关内容。
  2. 更重要的是,增加了熵正则之后,原问题可以使用 sinkhorn 算法得到近似解,而且计算开销大大降低。 关于这部分,可以参考原文第四章。
总的来说,增加熵正则之后的最优传输问题得到的解更加光滑和均衡,同时可以使用 sinkhorn 迭代算法快速求解。

基于 Sinkhorn 迭代的快速解法。

Sinkhorn 迭代求解最优传输

假设有向量 $u \in \mathbb{R}^{d1}, v\in \mathbb{R}^{d2}$, 令 $u, v$ 的初始值为 $$ \begin{split} u &= \mathbf{1}_{d1} / d1 \\ v &= \mathbf{1}_{d2} / d2 \end{split} $$ 然后用 sinkhorn 迭代算法求解: $$ \begin{equation} \begin{split} u &\leftarrow r / (K\cdot v) \\ v &\leftarrow c / (K^T\cdot u) \end{split}\label{eq:sinkhorn} \end{equation} $$ 上式中 $K:=e^{-\lambda M}$ (参考[1] 中 Lemma2)。 迭代收敛后最优传输矩阵 $P^*$ 和对应的最小传输距离 $C^*$ 可以由以下公式给出: $$ \begin{equation}\begin{split} P^* &= \text{diag}(u)K \text{diag}(v) \\ C^* &= \langle P^*, M\rangle_F \end{split}\label{eq:distance} \end{equation} $$ $C^*$ 对 $r, c$ 的导数分别是: $$ \begin{equation} \begin{split} \frac{\partial C^*}{\partial r} &= \log(u) / \lambda \\ \frac{\partial C^*}{\partial c} &= \log(v) / \lambda \end{split}\label{eq:gradient} \end{equation} $$ 有了 公式$\ref{eq:distance}$ 和 公式$\ref{eq:gradient}$, 最优传输距离就可以作为损失函数来训练神经网络。

PyTorch 实现

这里 (vlkit.optimal_transport.sinkhorn) 有一个 sinkhorn 的 PyTorch 实现,我们可以计算并可视化两个分布之间的最优传输。

以高斯分布为例,首先生成两个 1d 高斯分布作为 source 和 target distribution:

import torch
import matplotlib
import matplotlib.pyplot as plt
from matplotlib import gridspec
from vlkit.optimal_transport import sinkhorn

# generate two gaussians as the source and target
def gaussian(mean=0, std=10, n=100):
    d = (-(torch.arange(n) - mean)**2 / (2 * std**2)).exp()
    d /= d.sum()
    return d

n = 20
d1 = gaussian(mean=12, std=2, n=n)
d2 = gaussian(mean=6, std=4, n=n)

dist = (torch.arange(n).view(1, n) - torch.arange(n).view(n, 1)).abs().float()
dist /= dist.max()

# visualize distr
fig, axes = plt.subplots(1, 2, figsize=(9, 3))
axes[0].bar(torch.arange(n), d1)
axes[0].set_title('Source distribution')
axes[1].bar(torch.arange(n), d2)
axes[1].set_title('Target distribution')
plt.tight_layout()
然后通过 `sinkhorn` 迭代计算最优传输,并可视化结果。
T, u, v = sinkhorn(r=d1.unsqueeze(dim=0), c=d2.unsqueeze(dim=0), reg=1e-2, M=dist.unsqueeze(dim=0))
plt.figure(figsize=(10, 10))
gs = gridspec.GridSpec(3, 3)

ax1 = plt.subplot(gs[0, 1:3])
plt.bar(torch.arange(n), d2, label='Target distribution')

ax2 = plt.subplot(gs[1:, 0])
ax2.barh(torch.arange(n), d1, label='Source distribution')

plt.gca().invert_xaxis()
plt.gca().invert_yaxis()

plt.subplot(gs[1:3, 1:3], sharex=ax1, sharey=ax2)
plt.imshow(T.squeeze(dim=0))
plt.axis('off')

plt.tight_layout()

补充阅读

这里列出了一些关于 optimal transport 的相关阅读材料供大家参考。

如果本文的内容对你撰写学术论文有帮助,希望能考虑引用:
@misc{zhao2020optimal,
title   = {最优传输与 sinkhorn-knop 迭代},
author  = {Kai Zhao},
year    = 2020,
note    = {\url{http://kaizhao.net/blog/optimal-transport}}
}

References

  1. Cuturi, Marco. "Sinkhorn distances: Lightspeed computation of optimal transport." Advances in neural information processing systems 26 (2013): 2292-2300.
  2. Arjovsky, Martin, Soumith Chintala, and Léon Bottou. "Wasserstein generative adversarial networks." International conference on machine learning. PMLR, 2017.
  3. Zhang, Chi, et al. "DeepEMD: Few-Shot Image Classification With Differentiable Earth Mover's Distance and Structured Classifiers." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2020.
  4. Liu, Songtao, Zeming Li, and Jian Sun. "Self-EMD: Self-Supervised Object Detection without ImageNet." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
  5. Xie, Yujia, et al. "Differentiable top-k operator with optimal transport." arXiv preprint arXiv:2002.06504 (2020).
  6. Ge, Zheng, et al. "OTA: Optimal Transport Assignment for Object Detection." Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition. 2021.
  7. Frogner, Charlie, et al. "Learning with a Wasserstein loss." Advances in neural information processing systems (2015).

感谢阅读🤗 本文内容谢绝任何形式的转载,如果您想和朋友分享本文内容,请分享本文链接 kaizhao.net/blog/optimal-transport。 如果您发现文中的错误,或者有任何疑问,欢迎在下方留言交流 (留言功能基于 disqus,在中国大陆的读者可能需要一些技术手段才能连接🥲)。