"""
"""



##############################################
# Setup
# ---------------------

import numpy as np
import matplotlib.pyplot as plt
import time
import os

import torch
from geomloss import SamplesLoss

use_cuda = torch.cuda.is_available()

# Legacy tensor-type handle passed to ``Tensor.type(...)``:
# float32 on the GPU when CUDA is available, float32 on the CPU otherwise.
if use_cuda:
    dtype = torch.cuda.FloatTensor
else:
    dtype = torch.FloatTensor

###############################################
# Display routines
# ~~~~~~~~~~~~~~~~~

from random import choices
from scipy import misc


def load_image(fname) :
    """Load an image file as a grayscale "darkness" map with values in [0, 1].

    ``scipy.misc.imread`` (previously used here) was removed in SciPy 1.2, so
    the file is read with ``matplotlib.pyplot.imread`` and converted to
    grayscale by hand. The conversion uses the ITU-R 601-2 luma weights that
    ``imread(..., flatten=True)`` relied on.

    Args:
        fname: Path of the image to read (e.g. a PNG file).

    Returns:
        A 2D float array ``A`` with ``A = 1 - luminance``: dark pixels map to
        values close to 1, white pixels to 0. Rows are flipped so that row 0
        is the *bottom* of the picture, i.e. the row index grows with the
        vertical plot coordinate.
    """
    img = plt.imread(fname)
    if np.issubdtype(img.dtype, np.integer):
        # Non-PNG formats come back as raw integers (e.g. uint8 in [0, 255]).
        img = img / np.iinfo(img.dtype).max
    img = np.asarray(img, dtype=np.float64)

    if img.ndim == 3:
        if img.shape[-1] >= 3:
            # RGB(A): ignore any alpha channel, L = .299 R + .587 G + .114 B.
            img = img[..., :3] @ np.array([0.299, 0.587, 0.114])
        else:
            # Gray + alpha: keep the gray channel only.
            img = img[..., 0]

    img = img[::-1, :]  # Flip vertically: row 0 <-> bottom of the picture.
    return 1 - img

def draw_samples(fname, n, dtype=torch.FloatTensor) :
    """Draw ``n`` random 2D points distributed like the darkness of an image.

    The picture is mapped onto the unit square ``[0, 1] x [0, 1]``. Pixel
    columns span the horizontal axis and pixel rows the vertical one. Pixels
    are picked with probability proportional to ``load_image(fname)``. Each
    point is then jittered by Gaussian noise of about half a pixel.

    Args:
        fname: Path of the image that encodes the (unnormalized) density.
        n: Number of samples to draw.
        dtype: Torch tensor type of the result (e.g. ``torch.FloatTensor``).

    Returns:
        A ``(n, 2)`` tensor of ``(x, y)`` coordinates.

    Raises:
        ValueError: if the image has no positive total mass.
    """
    A = load_image(fname)
    n_rows, n_cols = A.shape

    # With the default ``indexing="xy"``, both grids have shape
    # (n_rows, n_cols), exactly like ``A``. So ``xg.ravel()[k]``,
    # ``yg.ravel()[k]`` and ``A.ravel()[k]`` describe the same pixel, even
    # for non-square images.
    xg, yg = np.meshgrid(np.linspace(0, 1, n_cols), np.linspace(0, 1, n_rows))

    mass = A.sum()
    if not mass > 0:
        raise ValueError(f"Image {fname!r} defines an empty density.")
    dens = A.ravel() / mass

    # Sample pixel indices instead of building a list of (x, y) tuples.
    # ``choices`` picks indices from the weights alone, so the random stream
    # is unchanged.
    idx = np.array(choices(range(dens.size), dens, k=n), dtype=np.intp)
    dots = np.stack([xg.ravel()[idx], yg.ravel()[idx]], axis=1)

    # Jitter by ~half a pixel (per axis) so samples do not sit on the grid.
    dots += (.5 / np.array([n_cols, n_rows])) * np.random.standard_normal(dots.shape)

    return torch.from_numpy(dots).type(dtype)

def display_samples(ax, x, color) :
    """Scatter-plot a cloud of 2D points on the Matplotlib axes ``ax``.

    Args:
        ax: Axes to draw on.
        x: Tensor of points. Columns 0 and 1 are the horizontal and vertical
            coordinates. The tensor is detached and copied to the CPU first.
        color: Colour specification forwarded to ``ax.scatter`` as ``c``.

    The marker area is ``25 * 500 / len(x)``. The summed marker area therefore
    stays roughly constant whatever the number of points.
    """
    points = x.detach().cpu().numpy()
    marker_size = 25 * 500 / len(points)
    ax.scatter(points[:, 0], points[:, 1],
               s=marker_size, c=color, edgecolors='none')


###############################################
# Dataset
# ~~~~~~~~~~~~~~~~~~
#
# Our source and target samples are drawn from measures whose densities
# are stored in simple PNG files. These samples define a pair of discrete,
# uniformly weighted probability measures:
#
# .. math::
#   \alpha ~=~ \frac{1}{N}\sum_{i=1}^N \delta_{x_i}, ~~~
#   \beta  ~=~ \frac{1}{M}\sum_{j=1}^M \delta_{y_j}.

# Sample sizes: twice as many points when a GPU is available
# (presumably to keep the CPU run fast).
if use_cuda:
    N, M = 500, 500
else:
    N, M = 250, 250

x_i = draw_samples("slope_a.png", N, dtype)  # source samples
y_j = draw_samples("slope_b.png", M, dtype)  # target samples


###############################################
# Lagrangian gradient descent
# -------------------------------
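#
# We let the source samples :math:`x_i` flow along the gradient of the
# entropic OT cost, computed without debiasing, for a few iterations.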
 
blur = .1
x_i.requires_grad = True

# ``debias=False``: raw entropic OT cost, whose entropic bias this example
# illustrates.
loss = SamplesLoss("sinkhorn", blur=blur, debias=False)

n_steps = 10
for _ in range(n_steps):
    # Compute cost and gradient, and update the points' positions accordingly
    L_αβ = loss(x_i, y_j)
    [g]  = torch.autograd.grad(L_αβ, [x_i])
    # alpha gives each x_i a weight 1/N, so its gradient scales like 1/N.
    # Multiplying by N = len(x_i) gives a unit-step Lagrangian update.
    # The in-place update happens outside autograd (instead of via ``.data``).
    with torch.no_grad():
        x_i -= len(x_i) * g

################################################
# Compute the pseudo-transport plan
# ---------------------------------------------


# Dual potentials (f_i, g_j) of the same non-debiased entropic OT problem.
OTe_solver = SamplesLoss("sinkhorn", blur=blur, debias=False, potentials=True)
f_i, g_j = OTe_solver(x_i, y_j)

# Ground cost C_ij = |x_i - y_j|^2 / 2, as a dense (len(x_i), len(y_j)) matrix.
diffs = x_i.unsqueeze(1) - y_j.unsqueeze(0)
costs = .5 * diffs.pow(2).sum(2)

# Entropic plan: exp((f_i + g_j - C_ij) / eps) / (N * M), with eps = blur**2.
eps = blur ** 2
gap = f_i.view(-1, 1) + g_j.view(1, -1) - costs
density = (gap / eps).exp() / (len(x_i) * len(y_j))

# Fancy display: NumPy copies of everything we plot.
x_i_, y_j_, density_ = (t.detach().cpu().numpy() for t in (x_i, y_j, density))

from matplotlib.collections import LineCollection

# Draw a "spring" between x_i and y_j wherever the plan carries noticeable
# mass. ``density_ * len(x_i)`` rescales the plan by N (each x_i has mass
# 1/N). It serves both as the threshold and as the line opacity (clipped to 1).
# Vectorized with np.nonzero, instead of a Python loop over all N*M entries.
weights = density_ * len(x_i)
rows, cols = np.nonzero(weights > .001)  # row-major order, like a nested loop

segments = np.stack([x_i_[rows], y_j_[cols]], axis=1)  # (K, 2 endpoints, 2 coords)
springs_colors = np.empty((len(rows), 4))
springs_colors[:, :3] = (.8, .9, 1.)
springs_colors[:, 3] = np.minimum(weights[rows, cols], 1)

springs = LineCollection(segments, linewidths=(1,),
                         colors=springs_colors, linestyle='solid', zorder=0)
            

# Add a blur-circle for scale. It is centred on the source sample with the
# smallest y - x, i.e. the one furthest towards the bottom-right.
diag_offset = x_i_[:, 1] - x_i_[:, 0]
i0 = diag_offset.argmin()
center = x_i_[i0]
circle_blur = plt.Circle(center, radius=blur, color='r', fill=False)



out_path = "output/entropic_bias.png"

plt.figure(figsize=(8,5))
# Create the folder that actually receives the figure
# (not an unrelated "output/scaling/" sub-folder).
os.makedirs(os.path.dirname(out_path), exist_ok=True)

ax = plt.subplot(1,1,1)
ax.scatter( [10], [10] ) # shameless hack to prevent a slight change of axis...

# Springs at the back (zorder=0), then target (bluish) and source (reddish)
# samples, and the scale circle on top.
ax.add_collection(springs)
display_samples(ax, y_j, [(.55,.55,.95)])
display_samples(ax, x_i, [(.95,.55,.55)])
ax.add_artist(circle_blur)

plt.axis([0,1,.15,.8])
ax.set_aspect('equal', adjustable='box')

# The circle has radius blur = sqrt(eps), eps = blur**2 being the temperature
# used for the plan. The arrow points at the circle at 210 degrees from its
# centre: centre + blur * (cos 210, sin 210) = centre - blur * (sqrt(3)/2, 1/2).
ax.annotate('$\\sqrt{\\varepsilon}$', xy=(x_i_[i0,0]-blur*np.sqrt(3)/2, x_i_[i0,1] - blur/2),
            xytext=(.5, .2), textcoords='data', size=24, color="red",
            arrowprops=dict(facecolor='red', edgecolor='red', arrowstyle="->"),
            )

plt.xticks([], []); plt.yticks([], [])
plt.tight_layout()
plt.savefig(out_path, bbox_inches='tight', pad_inches=0)

plt.show()

