"""
Gradient flows in 2D
====================

Let's showcase the properties of **kernel MMDs**, **Hausdorff**
and **Sinkhorn** divergences on a simple toy problem:
the registration of one blob onto another.
"""



##############################################
# Setup
# ---------------------

import numpy as np
import matplotlib.pyplot as plt
import time
import os

import torch
from geomloss import SamplesLoss

# Work on the GPU whenever PyTorch can see one, and on the CPU otherwise.
use_cuda = torch.cuda.is_available()
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

###############################################
# Display routine
# ~~~~~~~~~~~~~~~~~
 

import numpy as np
import torch
from random import choices
from scipy import misc
from matplotlib import pyplot as plt


def load_image(fname) :
    """Load an image file as a 2D grayscale "ink density" array.

    ``scipy.misc.imread`` was deprecated in SciPy 1.0 and removed in SciPy 1.2,
    so the file is read with ``matplotlib.pyplot.imread`` and converted to
    grayscale by hand.

    Parameters:
        fname (str):
            Path to the image file.

    Returns:
        (height, width) float32 array:
            Rows are flipped vertically with respect to the file, and the
            intensities are inverted: black pixels map to 1, white pixels to 0.
    """
    img = np.asarray(plt.imread(fname))

    # Per matplotlib's documentation, PNG files come back as floats in [0, 1]
    # and other formats as integers: bring integer data to [0, 1] as well.
    if np.issubdtype(img.dtype, np.integer) :
        img = img / np.iinfo(img.dtype).max
    img = img.astype(np.float32)

    # Collapse colour channels to a single luminance channel
    # (ITU-R 601 weights); any alpha channel is ignored.
    if img.ndim == 3 :
        if img.shape[2] >= 3 :
            luma = np.array([.299, .587, .114], dtype=np.float32)
            img = img[:, :, :3] @ luma
        else :
            img = img[:, :, 0]

    img = img[::-1, :]  # flip the rows, so that row 0 is the bottom of the picture
    return 1 - img

def draw_samples(fname, n, dtype=torch.FloatTensor) :
    """Draw ``n`` random 2D points from the density encoded by an image.

    Pixel centers are placed on a regular grid of the unit square, one is
    picked at random ``n`` times with probability proportional to the value
    returned by ``load_image``, and a small Gaussian jitter is added.

    Parameters:
        fname (str):
            Path to the image file, forwarded to ``load_image``.
        n (int):
            Number of samples to draw.
        dtype (torch tensor type, default = torch.FloatTensor):
            Type of the returned tensor.

    Returns:
        (n, 2) tensor of type ``dtype``:
            The sampled (x, y) coordinates.

    Raises:
        ValueError: if the image carries no mass at all (e.g. a blank picture).
    """
    A = load_image(fname)
    height, width = A.shape

    total = A.sum()
    if not total > 0 :
        raise ValueError("Cannot draw samples from {!r}: the image has no mass.".format(fname))

    # With the default 'xy' indexing, meshgrid(x, y) returns arrays of shape
    # (len(y), len(x)): x must follow the columns and y the rows of A so that
    # the raveled coordinates line up with A.ravel(), even for non-square images.
    xg, yg = np.meshgrid( np.linspace(0, 1, width), np.linspace(0, 1, height) )
    xs, ys = xg.ravel(), yg.ravel()
    dens = A.ravel() / total

    # Sample pixel *indices* instead of building a Python list of H*W tuples.
    idx  = np.asarray( choices(range(dens.size), dens, k=n), dtype=np.intp )
    dots = np.stack( [xs[idx], ys[idx]], axis=1 )

    # Jitter with standard deviation .5/height, to move samples off the pixel grid.
    dots += (.5/height) * np.random.standard_normal(dots.shape)

    return torch.from_numpy(dots).type(dtype)

def display_samples(ax, x, color) :
    """Scatter-plot the point cloud ``x`` on the axes ``ax``.

    Parameters:
        ax: object with a matplotlib-like ``scatter`` method.
        x (tensor with at least two columns): point coordinates;
            column 0 is used as abscissa and column 1 as ordinate.
        color: forwarded unchanged to ``ax.scatter`` as its fourth argument.
    """
    points = x.detach().cpu().numpy()

    # The marker size is inversely proportional to the number of points.
    marker_size = 100*500 / len(points)
    abscissa, ordinate = points[:,0], points[:,1]

    ax.scatter( abscissa, ordinate, marker_size, color, edgecolors='none' )


###############################################
# Dataset
# ~~~~~~~~~~~~~~~~~~
#
# Our source and target samples are drawn from two densities on the plane,
# each encoded as a grayscale image, and define discrete probability measures:
#
# .. math::
#   \alpha ~=~ \frac{1}{N}\sum_{i=1}^N \delta_{x_i}, ~~~
#   \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j}.

# Keep the point clouds small on the CPU; use many more samples on the GPU.
N, M = (10000, 10000) if use_cuda else (100, 100)

# Source and target point clouds, one per density image.
X_i = draw_samples("density_a.png", N, dtype)
Y_j = draw_samples("density_b.png", M, dtype)


###############################################
# Wasserstein gradient flow
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# 
# To study the influence of the :math:`\text{Loss}` function in measure-fitting
# applications, we perform gradient descent on the positions
# :math:`x_i` of the samples that make up :math:`\alpha`
# as we minimize the cost :math:`\text{Loss}(\alpha,\beta)`.
# This procedure can be understood as a discrete (Lagrangian) 
# `Wasserstein gradient flow <https://arxiv.org/abs/1609.03890>`_
# and as a "model-free" machine learning program, where
# we optimize directly on the samples' locations.
 
def gradient_flow(loss, name, lr=.05) :
    """Flows along the gradient of the cost function, using a simple Euler scheme.

    The flow starts from a copy of the global samples ``X_i`` and uses a copy
    of the global samples ``Y_j`` as a fixed target; neither global tensor is
    modified. Snapshots are saved at times t = 0, .25, .50, 1. and 5. in the
    folder ``output/flow_2D/``, which is created if needed.

    Parameters:
        loss ((x_i,y_j) -> torch float number):
            Real-valued loss function.
        name (str):
            Prefix of the snapshot file names,
            ``output/flow_2D/<name>_<100*t as 3 digits>.png``.
        lr (float, default = .05):
            Learning rate, i.e. time step.
    """

    # Parameters for the gradient descent
    Nsteps = int(5/lr)+1
    display_its = {int(t/lr) for t in [0, .25, .50, 1., 5.]}

    # Use colors to identify the particles
    colors = (10*X_i[:,0]).cos() * (10*X_i[:,1]).cos()
    colors = colors.detach().cpu().numpy()

    # Make sure that we won't modify the reference samples
    x_i, y_j = X_i.clone(), Y_j.clone()

    # We're going to perform gradient descent on Loss(α, β)
    # wrt. the positions x_i of the Dirac masses that make up α:
    x_i.requires_grad = True

    t_0 = time.time()
    plt.figure(figsize=(6,6))
    os.makedirs("output/flow_2D", exist_ok=True)
    for i in range(Nsteps): # Euler scheme ===============
        # Compute cost and gradient
        cost = loss(x_i, y_j)
        [g]  = torch.autograd.grad(cost, [x_i])

        if i in display_its : # display
            plt.clf()
            ax = plt.subplot(1,1,1)
            plt.set_cmap("hsv")
            plt.scatter( [10], [10] ) # shameless hack to prevent a slight change of axis...

            display_samples(ax, y_j, [(.55,.55,.95)])
            display_samples(ax, x_i, colors)

            plt.axis([0,1,0,1])
            plt.gca().set_aspect('equal', adjustable='box')
            plt.xticks([], []); plt.yticks([], [])
            plt.tight_layout()
            plt.savefig("output/flow_2D/{}_{:03d}.png".format(name, int(100*lr*i) ),
                        bbox_inches='tight',pad_inches=0 )

        # In-place update of the positions, hidden from autograd.
        # The step is scaled by the number of samples -- presumably to
        # compensate for the 1/N weight carried by each point.
        with torch.no_grad() :
            x_i -= lr * len(x_i) * g

    # The title is only set on the last frame, after all snapshots were saved;
    # the reported time per iteration includes the rendering of the snapshots.
    t_final = lr * (Nsteps - 1)
    plt.title("t = {:1.2f}, elapsed time: {:.2f}s/it".format(t_final, (time.time() - t_0)/Nsteps ))


from pykeops.torch import generic_logsumexp

# KeOps log-sum-exp reduction, called below as GMM_loglikelihood(p, q).
# The strings use KeOps' symbolic syntax: going by the aliases, X is a
# 2-dimensional variable indexed by "i", Y a 2-dimensional variable indexed
# by "j", and the output L holds one scalar per "i".
# NOTE(review): presumably this computes, for every row x_i of the first
# argument, log( sum_j exp( -|x_i - y_j|^2 / 2 ) ) over the rows y_j of the
# second one -- confirm IntInv / SqDist and the reduction axis against the
# KeOps documentation.
GMM_loglikelihood = generic_logsumexp(
    "- IntInv(2) * SqDist(X,Y) ",
    "L = Vi(1)",
    "X = Vi(2)",
    "Y = Vj(2)",
)

def GMM_Loss(x, y, blur = .1):
    """Loss between two point clouds, built from ``GMM_loglikelihood`` potentials.

    Both clouds are rescaled by ``1/blur`` and four potentials are computed,
    each equal to ``- blur**2 * GMM_loglikelihood(p, q)``:

    - ``a_x``: cloud x against itself, ``a_y``: cloud y against cloud x;
    - ``b_x``: cloud x against cloud y, ``b_y``: cloud y against itself.

    Each potential is averaged on its own, so that the two clouds may hold a
    different number of points.

    Parameters:
        x (tensor): first point cloud, one point per row.
        y (tensor): second point cloud, one point per row.
        blur (float, default = .1): length scale that divides the coordinates.

    Returns:
        torch scalar: mean(a_y) - mean(a_x) + mean(b_x) - mean(b_y).
    """
    x, y = x / blur, y / blur

    a_x = - blur**2 * GMM_loglikelihood(x, x)
    a_y = - blur**2 * GMM_loglikelihood(y, x)
    b_x = - blur**2 * GMM_loglikelihood(x, y)
    b_y = - blur**2 * GMM_loglikelihood(y, y)

    # a_y and b_y are expected to hold one value per point of y, a_x and b_x
    # one per point of x: subtracting the tensors directly would fail (or
    # silently broadcast) when len(x) != len(y), so subtract the means instead.
    return (a_y.mean() - a_x.mean()) + (b_x.mean() - b_y.mean())


    

# Run the flow once per loss function; each run saves its snapshots
# under its own file-name prefix.

# Hand-written loss defined above:
gradient_flow( loss=GMM_Loss, name="hausdorff" )

# Losses provided by GeomLoss:
gradient_flow( loss=SamplesLoss("gaussian", blur=.1), name="gaussian" )

gradient_flow( loss=SamplesLoss("energy"), name="energy" )

gradient_flow( loss=SamplesLoss("sinkhorn", blur=.01), name="sinkhorn" )
