"""
Influence of the blur parameter, scaling strategy
=====================================================

Dating back to the work of `Schrödinger <http://www.numdam.org/article/AIHP_1932__2_4_269_0.pdf>`_ 
- see e.g. `(Léonard, 2013) <https://arxiv.org/abs/1308.0215>`_ for a modern review -
entropy-regularized Optimal Transport is all about
solving the convex primal/dual problem:
"""

##################################################
# .. math::
#   \text{OT}_\varepsilon(\alpha,\beta)~&=~
#       \min_{0 \leqslant \pi \ll \alpha\otimes\beta} ~\langle\text{C},\pi\rangle
#           ~+~\varepsilon\,\text{KL}(\pi,\alpha\otimes\beta) \quad\text{s.t.}~~
#        \pi\,\mathbf{1} = \alpha ~~\text{and}~~ \pi^\intercal \mathbf{1} = \beta\\
#    &=~ \max_{f,g} ~~\langle \alpha,f\rangle + \langle \beta,g\rangle
#         - \varepsilon\langle \alpha\otimes\beta, 
#           \exp \tfrac{1}{\varepsilon}[ f\oplus g - \text{C} ] - 1 \rangle,
#
# where the linear `Kantorovich program <https://en.wikipedia.org/wiki/Transportation_theory_(mathematics)>`_
# is convexified by the addition of an entropic penalty 
# - here, the generalized Kullback-Leibler divergence
# 
# .. math::
#   \text{KL}(\alpha,\beta) ~=~ 
#   \langle \alpha, \log \tfrac{\text{d}\alpha}{\text{d}\beta}\rangle 
#   - \langle \alpha, 1\rangle + \langle \beta, 1\rangle.  
#
# The celebrated `IPFP <https://en.wikipedia.org/wiki/Iterative_proportional_fitting>`_,
# `SoftAssign <https://en.wikipedia.org/wiki/Point_set_registration#Robust_point_matching>`_
# and `Sinkhorn <https://arxiv.org/abs/1803.00567>`_ algorithms are all equivalent 
# to a **block-coordinate ascent** on the **dual problem** above
# and can be understood as smooth generalizations of the 
# `Auction algorithm <https://en.wikipedia.org/wiki/Auction_algorithm>`_,
# where a **SoftMin operator** 
#
# .. math::
#   \text{min}_{\varepsilon, x\sim\alpha} [ \text{C}(x,y) - f(x) ]
#   ~=~ - \varepsilon \log \int_x \exp \tfrac{1}{\varepsilon}[ f(x) - \text{C}(x,y)  ]
#   \text{d}\alpha(x)
#
# is used to update prices in the bidding rounds.
# This algorithm can be shown to converge as a `Picard fixed-point iterator <https://en.wikipedia.org/wiki/Fixed-point_iteration>`_, 
# with a worst-case complexity that scales in 
# :math:`O( \max_{\alpha\otimes\beta} \text{C} \,/\,\varepsilon )` iterations
# to reach a target numerical accuracy, as :math:`\varepsilon` tends to zero.
#
# **Limitations of the (baseline) Sinkhorn algorithm.**
# In most applications, the cost function is the **squared Euclidean distance** 
# :math:`\text{C}(x,y)=\tfrac{1}{2}\|x-y\|^2`
# studied by `Brenier and subsequent authors <http://www.math.toronto.edu/mccann/papers/FiveLectures.pdf>`_,
# with a temperature :math:`\varepsilon` that is
# homogeneous to the **square** of a **blurring scale** :math:`\sigma = \sqrt{\varepsilon}`.
#
# With a complexity that scales in :math:`O( (\text{diameter}(\alpha, \beta) / \sigma)^2)` iterations
# for typical configurations,
# the Sinkhorn algorithm thus seems to be **restricted to high-temperature problems**
# where the point-spread radius :math:`\sigma` of the **fuzzy transport plan** :math:`\pi`
# does not go below ~1/20th of the configuration's diameter.
#
# **Scaling heuristic.**
# Fortunately though, as is often the case in operations research,
# `simulated annealing <https://en.wikipedia.org/wiki/Simulated_annealing>`_
# can be used to break this computational bottleneck.
# First introduced for the :math:`\text{OT}_\varepsilon` problem
# in `(Kosowsky and Yuille, 1994) <https://www.ics.uci.edu/~welling/teaching/271fall09/InvidibleHandAlg.pdf>`_,
# this heuristic is all about **decreasing the temperature** :math:`\varepsilon`
# across the Sinkhorn iterations, letting prices adjust in a coarse-to-fine fashion.
#
# The default behavior of the :mod:`SamplesLoss("sinkhorn") <geomloss.SamplesLoss>` layer
# is to let :math:`\varepsilon` decay according to an **exponential schedule**.
# Starting from a large value of :math:`\sigma = \sqrt{\varepsilon}`,
# estimated from the data or given through the **diameter** parameter,
# we multiply this blurring scale by a fixed **scaling** 
# coefficient in the :math:`(0,1)` range and loop until :math:`\sigma`
# reaches the target **blur** value.
# We thus work with decreasing values of the temperature :math:`\varepsilon` in
#
# .. math::
#   [ \text{diameter}^2,~(\text{diameter}\cdot \text{scaling})^2,
#       ~(\text{diameter}\cdot \text{scaling}^2)^2,~ \cdots~ , ~\text{blur}^2~],
#
# with an effective number of iterations that is equal to:
#
# .. math::
#   N_\text{its}~=~ \bigg\lceil \frac{ \log ( \text{diameter}/\text{blur} )}{ \log (1 / \text{scaling})} \bigg\rceil.
#
# Let us now illustrate the behavior of the Sinkhorn loop across
# these iterations, on a simple 2d problem.


##############################################
# Setup
# ---------------------
#
# Standard imports:

import numpy as np
import matplotlib.pyplot as plt
import time
import torch
import os
from torch.autograd import grad

# Pick the tensor type once: CUDA float tensors if a GPU is usable,
# regular (CPU) float tensors otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

###############################################
# Display routines:

from imageio import imread


def load_image(fname):
    """Load an image file as a 2D array of inverted grayscale intensities.

    Args:
        fname: location of the image, handed as is to ``imread``.

    Returns:
        A 2D array holding ``1 - img / 255`` with the rows in reverse order,
        where ``img`` is the channel-wise average of the image (or the image
        itself if it has no channel axis).
        Values are presumably in [0, 1] - this assumes 8-bit pixel data.
    """
    img = imread(fname)

    # Color images come as (height, width, channels) arrays: average the
    # channels. Single-channel images are already 2D and have no axis 2.
    if np.ndim(img) == 3:
        img = np.mean(img, axis=2)  # Grayscale

    img = img[::-1, :] / 255.  # Flip the rows and rescale
    return 1 - img


def draw_samples(fname, sampling, dtype=torch.FloatTensor):
    """Turn an image file into a weighted point cloud.

    Args:
        fname: location of the image, handed to ``load_image``.
        sampling: stride used to subsample the image along both axes.
        dtype: torch tensor type of the outputs.

    Returns:
        A pair ``(a_i, x_i)`` of tensors of type ``dtype``:
        ``a_i`` holds the flattened pixel values, normalized to sum up to 1,
        and ``x_i`` is the matching (N, 2) array of (x, y) coordinates.
    """
    A = load_image(fname)
    A = A[::sampling, ::sampling]
    A[A <= 0] = 1e-8  # Keep all the weights strictly positive

    a_i = A.ravel() / A.sum()

    # With its default "xy" indexing, meshgrid(xs, ys) returns arrays of shape
    # (len(ys), len(xs)): xs must follow the columns of A and ys its rows,
    # so that x.ravel() and y.ravel() are enumerated in the same order
    # as A.ravel() - even if the image is not square.
    n_rows, n_cols = A.shape
    x, y = np.meshgrid(np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))
    x += .5 / n_cols
    y += .5 / n_rows

    x_i = np.vstack((x.ravel(), y.ravel())).T

    return torch.from_numpy(a_i).type(dtype), \
           torch.from_numpy(x_i).contiguous().type(dtype)


def display_potential(ax, F, color, nlines=21):
    """Draw the contour lines of a potential sampled on a square grid.

    Args:
        ax: object whose ``contour`` method is used for drawing.
        F: flat tensor of values, reshaped as a square (N, N) image
            with N = int(sqrt(len(F))).
        color: passed as ``colors`` to ``ax.contour``.
        nlines: number of evenly spaced contour levels between -1 and 1.
    """
    # The samples are assumed to lie on a square image:
    side = int(np.sqrt(len(F)))
    grid = F.view(side, side).detach().cpu().numpy()

    # NaN (and infinite) values are replaced by finite numbers before drawing:
    grid = np.nan_to_num(grid)

    contour_levels = np.linspace(-1, 1, nlines)
    ax.contour(grid, origin='lower', linewidths=2., colors=color,
               levels=contour_levels, extent=[0, 1, 0, 1])


def display_samples(ax, x, weights, color, v=None):
    """Scatter-plot a weighted point cloud, with optional displacement arrows.

    Args:
        ax: object whose ``scatter`` and ``quiver`` methods are used for drawing.
        x: tensor of point positions; columns 0 and 1 are the plot coordinates.
        weights: tensor of point weights. Marker sizes are ``5000 * weights``,
            and weights below 1e-5 are displayed with a size of zero.
        color: passed as the color argument of ``ax.scatter``.
        v: optional tensor of vectors, drawn as arrows that start from
            the points ``x``.
    """
    x_ = x.detach().cpu().numpy()

    # For CPU tensors, .numpy() returns a view on the tensor's memory:
    # copy it, so that the thresholding below does not silently overwrite
    # the weights of the caller.
    weights_ = weights.detach().cpu().numpy().copy()

    weights_[weights_ < 1e-5] = 0
    ax.scatter(x_[:, 0], x_[:, 1], 10 * 500 * weights_, color, edgecolors='none')

    if v is not None:
        v_ = v.detach().cpu().numpy()
        # scale=1 with "xy" units: arrows are drawn in data coordinates.
        ax.quiver(x_[:, 0], x_[:, 1], v_[:, 0], v_[:, 1],
                  scale=1, scale_units="xy", color="#5CBF3A",
                  zorder=3, width=2. / len(x_))

###############################################
# Dataset
# --------------
#
# Our source and target samples are drawn from measures whose densities
# are stored in simple PNG files. They allow us to define a pair of discrete 
# probability measures:
#
# .. math::
#   \alpha ~=~ \sum_{i=1}^N \alpha_i\,\delta_{x_i}, ~~~
#   \beta  ~=~ \sum_{j=1}^M \beta_j\,\delta_{y_j}.

# Subsampling stride of the images: keep more samples if a GPU is available.
sampling = 10 if not use_cuda else 2

# Pass the tensor type selected in the setup section, so that the samples
# are stored on the GPU whenever CUDA is available
# (draw_samples would otherwise fall back to CPU float tensors).
A_i, X_i = draw_samples("ell_a.png", sampling, dtype)
B_j, Y_j = draw_samples("ell_b.png", sampling, dtype)

###############################################
# Scaling heuristic
# -------------------
#
# We now display the behavior of the Sinkhorn loss across
# our iterations.

from pykeops.torch.cluster import grid_cluster, cluster_ranges_centroids
from geomloss import SamplesLoss

plt.figure(figsize=(6, 6))
os.makedirs(os.path.dirname("output/scaling/"), exist_ok=True)

scaling = .5         # passed as the "scaling" parameter of SamplesLoss
cluster_scale = .05  # passed to both SamplesLoss and grid_cluster

# One figure per (index, blur) pair - the index is only used in the file name:
for i, blur in ((0, 1.), (3, .5**3), (5, .5**5), (7, .01)):
    Loss = SamplesLoss("sinkhorn", p=2, blur=blur, diameter=1.,
                       scaling=scaling, cluster_scale=cluster_scale,
                       backend="multiscale", verbose=True)

    # Work on fresh copies of the weights and positions...
    a_i, b_j = A_i.clone(), B_j.clone()
    x_i, y_j = X_i.clone(), Y_j.clone()

    # ... and track gradients with respect to both sets of weights
    # and to the source positions:
    a_i.requires_grad_(True)
    x_i.requires_grad_(True)
    b_j.requires_grad_(True)

    # Loss value, and its gradients with respect to a_i, b_j and x_i:
    Loss_xy = Loss(a_i, x_i, b_j, y_j)
    F_i, G_j, dx_i = grad(Loss_xy, [a_i, b_j, x_i])

    # Generalized "Brenier map": (minus) the gradient of the Sinkhorn loss
    # with respect to the Wasserstein metric, i.e. the gradient with respect
    # to x_i divided by the weights (1e-7 guards against divisions by zero).
    BrenierMap = - dx_i / (a_i.view(-1, 1) + 1e-7)

    # Coarse measures, for display only ----------------------------------------

    x_lab = grid_cluster(x_i, cluster_scale)
    _, x_c, a_c = cluster_ranges_centroids(x_i, x_lab, weights=a_i)

    y_lab = grid_cluster(y_j, cluster_scale)
    _, y_c, b_c = cluster_ranges_centroids(y_j, y_lab, weights=b_j)

    # Figure -------------------------------------------------------------------

    plt.clf()
    ax = plt.subplot(1, 1, 1)
    ax.scatter([10], [10])  # shameless hack to prevent a slight change of axis...

    display_potential(ax, G_j, "#E2C5C5")
    display_potential(ax, F_i, "#C8DFF9")

    # While the blur is larger than the cluster scale, the fine samples are
    # drawn with transparency and the coarse clusters are overlaid on top:
    if blur > cluster_scale:
        display_samples(ax, y_j, b_j, [(.55, .55, .95, .2)])
        display_samples(ax, x_i, a_i, [(.95, .55, .55, .2)], v=BrenierMap)
        display_samples(ax, y_c, b_c, [(.55, .55, .95)])
        display_samples(ax, x_c, a_c, [(.95, .55, .55)])
    else:
        display_samples(ax, y_j, b_j, [(.55, .55, .95)])
        display_samples(ax, x_i, a_i, [(.95, .55, .55)], v=BrenierMap)

    # Hide the ticks and stick to the unit square:
    ax.set_xticks([], [])
    ax.set_yticks([], [])
    ax.axis([0, 1, 0, 1])
    ax.set_aspect('equal', adjustable='box')
    plt.tight_layout()
    plt.savefig("output/scaling/it_{:02d}_blur_{:.3f}.png".format(i+1, blur),
                bbox_inches='tight', pad_inches=0)


plt.show()

