import torch
import numpy as np 
import torch.nn as nn
from   torch.nn import Parameter
from   torch    import optim
from torch.autograd import Variable
import itertools

import os
from model_fitting import FitModel
import matplotlib
from matplotlib import pyplot as plt
import matplotlib.font_manager as fm
# Font used for the labels and annotations drawn by this script. The path is
# relative to the current working directory, so the script presumably has to
# be launched from the directory that contains Overlock-Regular.ttf.
myfont = fm.FontProperties(fname='./Overlock-Regular.ttf')
# NOTE(review): disabled global-font alternative; `prop` is not defined in this file.
#plt.rcParams["font.family"] = prop

# Interactive mode, so figures can be displayed and refreshed while the
# script keeps running.
plt.ion() 
plt.show() 

# Shorthand for the float32 CPU tensor constructor used to build the data.
tensor = torch.FloatTensor

class Model(nn.Module) :
    """Base class for the 1-D regression models animated by this script.

    Subclasses override ``__call__`` to evaluate the model on an input
    tensor, and may override ``plot_midlayers`` to draw intermediate
    activations on top of the regression plot.
    """

    def __init__(self) :
        super(Model, self).__init__()

    def __call__(self, inp ) :
        """Evaluate the model on ``inp``; must be overridden by subclasses."""
        raise NotImplementedError

    def cost(self, inp, out) :
        """Return the mean squared error between ``out`` and ``self(inp)``."""
        return ((out - self(inp))**2).mean()

    def plot_midlayers(self, ax, t) :
        """Hook for drawing hidden activations over ``t``; no-op by default."""
        pass

    def plot(self, ax, x, y, t, focus=None) :
        """Draw the hand-made axes, the data set and (optionally) the fit on ``ax``.

        Args:
            ax: matplotlib Axes to draw on.
            x, y: tensors holding the sample inputs and targets.
            t: tensor of abscissae at which the model curve is evaluated.
            focus: ``None`` to draw the predictions, the residual segments
                and the fitted curve; otherwise the index of one sample whose
                coordinates are highlighted (the model is not evaluated then).
        """
        x_, y_ = x.data.cpu().numpy(), y.data.cpu().numpy()

        # Width and height of the axes in inches, used to give both arrowheads
        # the same on-screen look despite the non-unit aspect of the data box.
        # Use the axes' own figure rather than whatever figure is "current".
        dps  = ax.figure.dpi_scale_trans.inverted()
        bbox = ax.get_window_extent().transformed(dps)
        width, height = bbox.width, bbox.height

        xmin, xmax = -.5,.5
        ymin, ymax = -.3,.2
        # manual arrowhead width and length
        hw = 1./20.*(ymax-ymin)
        hl = 1./20.*(xmax-xmin)
        lw = 1. # axis line width
        ohg = 0.3 # arrow overhang

        # matching arrowhead width and length for the vertical axis
        yhw = hw/(ymax-ymin)*(xmax-xmin)* height/width
        yhl = hl/(xmax-xmin)*(ymax-ymin)* width/height

        # draw x and y axis
        ax.arrow(xmin, -.25, xmax-xmin, 0., fc='k', ec='k', lw = lw,
                head_width=hw, head_length=hl, overhang = ohg,
                length_includes_head= True, clip_on = False)

        ax.arrow(-.45, ymin, 0., ymax-ymin, fc='k', ec='k', lw = lw,
                head_width=yhw, head_length=yhl, overhang = ohg,
                length_includes_head= True, clip_on = False)

        ax.text(.45, -.29, r'$\mathregular{x}$', fontproperties = myfont, fontsize=35, color="#080899")
        ax.text(-.495, .15, r'$\mathregular{y}$', fontproperties = myfont, fontsize=35, color="#990808")

        if focus is None :
            # Model evaluations are only needed when drawing the fit.
            y_p, u = self(x), self(t)
            y_p_ = y_p.data.cpu().numpy()
            t_, u_ = t.data.cpu().numpy(), u.data.cpu().numpy()

            # Vertical residual segments (x_i, y_i) -> (x_i, y_pred_i); the
            # None entries break the polyline between consecutive segments.
            segs_x = [ [a, a, None] for a in x_ ]
            segs_x = np.array(list( itertools.chain.from_iterable(segs_x)))
            segs_y = [ [b, c, None] for b,c in zip(y_, y_p_) ]
            segs_y = np.array(list( itertools.chain.from_iterable(segs_y)))

            ax.scatter(x_, y_p_, 160, color = "b")
            ax.scatter(x_, y_,   160, color = "c")
            self.plot_midlayers(ax, t)
            ax.plot(t_, u_,      color = "b", linewidth=2)
            ax.plot(segs_x, segs_y, color = "#95DFEB", linewidth = 4, zorder=0)
        else :
            # Highlight sample `focus`: guide lines to both axes plus labels.
            xf, yf = x_[focus], y_[focus]
            ax.scatter(x_, y_,   160, color = "c")
            ax.plot([-.45, xf, xf], [yf, yf, -.25], color = "#95DFEB", linewidth = 4, zorder=0)
            ax.text( xf, -.29, r'$\mathregular{x_' + str(focus+1) + r'}$',
                               fontproperties = myfont, fontsize=35, color="#080899")
            ax.text(-.495, yf,  r'$\mathregular{y_' + str(focus+1) + r'}$',
                               fontproperties = myfont, fontsize=35, color="#990808")

# Colour coding of connection weights in the network diagrams: a diverging
# blue-white-red map over the range [-1, 1], wrapped in a ScalarMappable so
# that `arrow` can turn a weight into an RGBA colour.
cmap = plt.cm.bwr
cnorm = matplotlib.colors.Normalize(vmax=1, vmin=-1)
csm = matplotlib.cm.ScalarMappable(norm=cnorm, cmap=cmap)

def arrow(ax, x1,x2, y1,y2, color='k', lw = 2, weight=None) :
    """Draw an arrow from ``(x1, y1)`` to ``(x2, y2)`` on ``ax``.

    Without ``weight``, a plain arrow of colour ``color`` and line width
    ``lw`` is drawn. With ``weight``, a headless connection is drawn whose
    colour comes from ``csm`` and whose opacity and width grow with
    ``|weight|`` (capped at 1); ``color`` and ``lw`` are then ignored.
    """
    dx, dy = x2 - x1, y2 - y1
    shared = dict(head_length=.1, overhang=.2,
                  length_includes_head=True, clip_on=False)

    if weight is not None:
        strength = min(np.abs(weight), 1)
        rgba = csm.to_rgba(weight, alpha=(.3 + strength) / 1.3)
        ax.arrow(x1, y1, dx, dy, fc=rgba, ec=rgba, lw=1 + 2 * strength,
                 head_width=0., **shared)
        return

    ax.arrow(x1, y1, dx, dy, fc=color, ec=color, lw=lw,
             head_width=.1, **shared)

def node( name, ax, x, y, radius=.25, color="k" ) :
    """Draw a network node: a hollow circle at ``(x, y)`` labelled ``name``."""
    ax.add_artist(plt.Circle((x, y), radius, color=color,
                             fill=False, lw=2, zorder=3))
    ax.text(x, y, name,
            horizontalalignment='center', verticalalignment='center',
            fontproperties=myfont, fontsize=25, color=color)

def plot_io(ax) :
    """Draw the input ("entrée") and output ("sortie") ends of a network diagram."""
    for xpos, label in ((.5, "entrée"), (8.5, "sortie")):
        ax.text(xpos, 1.65, label, fontproperties=myfont, fontsize=25)

    arrow(ax, 1., 1.75, 1.5, 1.5)                               # arrow into x
    node(r'$\mathregular{x}$', ax, 2., 1.5, color="#080899")    # input node
    node(r'$\mathregular{y}$', ax, 8., 1.5, color="#990808")    # output node
    arrow(ax, 8.25, 9., 1.5, 1.5)                               # arrow out of y

class PolynomialRegression(Model) :
    """Polynomial model ``sum_k coeffs[k] * x**k`` of the given degree.

    Args:
        degree: polynomial degree; ``degree + 1`` coefficients are learned,
            all initialised to zero.
    """
    def __init__(self, degree) :
        super().__init__()
        self.coeffs = Parameter(torch.zeros(degree + 1))

    def __call__(self, inp) :
        """Evaluate the polynomial element-wise on ``inp``."""
        out = torch.zeros_like(inp)
        for (k, c) in enumerate(self.coeffs) :
            out = out + c * inp**k
        return out

    def plot_network(self, ax) :
        """Draw the model as a single x -> y connection plus its formula on ``ax``."""
        plot_io(ax)

        coeffs = self.coeffs.data.cpu().numpy()

        # The x -> y arrow is coloured by the leading coefficient.
        if len(coeffs) > 1 :
            arrow( ax, 2.25, 7.75, 1.5,1.5, weight = coeffs[-1]/.25)

        # Constant term shown in the formula: the polynomial value at the drawn
        # y-axis (x = -.45), shifted by +.25 because the drawn x-axis sits at
        # y = -.25 in Model.plot — presumably the intercept read off the picture.
        xo, yo = -.45, .25
        for (k, c) in enumerate(coeffs) :
            yo = yo + c * xo**k

        # Raw strings: mathtext uses backslashes that are not valid Python
        # escape sequences (SyntaxWarning on Python 3.12+).
        if len(coeffs) == 2 :
            str_1 = r"$\mathregular{(a,b)\,\,\,=\,\,}$"
            str_2 = " {:+04.2f}, {:+04.2f}".format( coeffs[1], yo )
            str_3 = r"$\mathregular{F(\,a,b\,;\,x\,) \,\,\,=\,\,}$"
            str_4 = r"$\mathregular{\,\,\, a\cdot x + b}$"
        elif len(coeffs) == 3 :
            str_1 = r"$\mathregular{(a,b,c)\,\,\,=\,\,}$"
            str_2 = " {:+04.2f}, {:+04.2f}, {:+04.2f}".format( coeffs[2], coeffs[1], yo )
            str_3 = r"$\mathregular{F(\,a,b,c\,;\,x\,) \,\,\,=\,\,}$"
            str_4 = r"$\mathregular{\,\,\, a\cdot x^2 + b\cdot x + c}$"

        elif len(coeffs) == 5 :
            str_1 = r"$\mathregular{(a,b,\dots,e)\,\,\,=\,\,}$"
            str_2 = " {:+04.2f}, {:+04.2f}, ..., {:+04.2f}".format( coeffs[4], coeffs[3], yo )
            str_3 = r"$\mathregular{F(\,a,b,c,d,e\,;\,x\,) \,\,\,=\,\,}$"
            str_4 = r"$\mathregular{\,\,\, a\cdot x^4 + b\cdot x^3 + \cdots + e}$"

        else :
            # No formula for other degrees.
            str_1 = ""
            str_2 = ""
            str_3 = ""
            str_4 = ""

        # Parameter values (line at y = 2) ...
        ax.text(4, 2.,  str_1,
                horizontalalignment='right',
                fontproperties = myfont, fontsize=25)
        ax.text(4, 2.,  str_2,
                horizontalalignment='left',
                fontsize=20)

        # ... and the model formula (line at y = 2.5).
        ax.text(4, 2.5,  str_3,
                horizontalalignment='right',
                fontproperties = myfont, fontsize=25)
        ax.text(4, 2.5,  str_4,
                horizontalalignment='left',
                fontproperties = myfont, fontsize=25)


class OneLinearLayer(Model) :
    """Two stacked ``nn.Linear`` layers (1 -> h -> 1) with no activation.

    Composing two affine maps gives an affine map, so this model can only
    represent straight lines whatever ``h`` is.

    Args:
        h: width of the intermediate layer. For ``h == 2`` the weights are
            set by hand so that the plotted starting state is readable.
    """
    def __init__(self, h) :
        super().__init__()
        # NOTE(review): not used by this class's forward pass; presumably kept
        # for symmetry with OneHiddenLayer.
        self.offset = Variable(tensor([-.25]))
        self.network = torch.nn.Sequential(
            torch.nn.Linear(1, h),
            torch.nn.Linear(h, 1)
        )
        if h == 2 : # we're going to plot everything, so let's make it fancy...
            self.network[0].weight.data[0,0] = -1
            self.network[0].weight.data[1,0] =  .5
            self.network[0].bias.data[0] = -0.25
            self.network[0].bias.data[1] = -0.25

            self.network[1].weight.data[0,0] = 1.
            self.network[1].weight.data[0,1] = 1.
            self.network[1].bias.data[0] = 0.

    def __call__(self, inp) :
        """Apply both linear layers to ``inp`` (reshaped to a column), return a flat tensor."""
        out = self.network[0](inp.view(-1,1))
        out = self.network[1](out)
        return out.view(-1)

    def plot_midlayers(self, ax, t) :
        """Plot the intermediate units g_i(t) when the layer has 2 or 4 units."""
        coeffs_1 = self.network[0].weight.view(-1).data.cpu().numpy()
        if len(coeffs_1) in [2,4] :
            u = self.network[0](t.view(-1,1))
            t_, u_ = t.data.cpu().numpy(), u.data.cpu().numpy()
            ax.plot(t_, u_,      color = "#990899", linewidth=1)

    def plot_network(self, ax) :
        """Draw x -> g_i -> y (at most 4 units shown), arrows coloured by weight."""
        plot_io(ax)
        coeffs_1 = self.network[0].weight.view(-1).data.cpu().numpy()
        coeffs_2 = self.network[1].weight.view(-1).data.cpu().numpy()
        nmax = min(len(coeffs_1), 4)
        coeffs_1, coeffs_2 = coeffs_1[:nmax], coeffs_2[:nmax]
        zs = np.linspace(2.5, .5, nmax)
        for (i, (z, coeff_1, coeff_2)) in enumerate(zip(zs, coeffs_1, coeffs_2)) :
            arrow( ax, 2.25, 4.75, 1.5,z, weight=coeff_1/.4)
            arrow( ax, 5.25, 7.75, z,1.5, weight=coeff_2/.4)
            node(  r'$\mathregular{g_'+ str(i+1) + r'}$', ax, 5., z, color="#990899")

        # Raw strings: the mathtext backslashes are not valid Python escapes.
        if nmax == 2 :
            ax.text(5.25, 2.7, r"$\mathregular{=\,a_1\cdot x + b_1}$",
                    horizontalalignment='left', verticalalignment='center',
                    fontproperties = myfont, fontsize=25, color="#990899")
            ax.text(5.25, .3, r"$\mathregular{=\,a_2\cdot x + b_2}$",
                    horizontalalignment='left', verticalalignment='center',
                    fontproperties = myfont, fontsize=25, color="#990899")
            ax.text(6.75, 1., r"$\mathregular{=\,c_1\cdot g_1 + c_2\cdot g_2 + d}$",
                    horizontalalignment='left', verticalalignment='center',
                    fontproperties = myfont, fontsize=25, color="#990808")




class OneHiddenLayer(Model) :
    """1 -> h -> 1 network with one rectified hidden layer.

    The forward pass clamps the hidden activations from below at
    ``self.offset`` (-0.25) with ``torch.max``; the ``nn.ReLU`` module stored
    at ``network[1]`` is not used by ``__call__``.

    Args:
        h: number of hidden units. For ``h == 2`` and ``h == 4`` the weights
            are set by hand so that the plotted starting state is readable.
    """
    def __init__(self, h) :
        super().__init__()
        self.offset = Variable(tensor([-.25]))
        self.network = torch.nn.Sequential(
            torch.nn.Linear(1, h),
            torch.nn.ReLU(),
            torch.nn.Linear(h, 1)
        )
        if h == 2 : # we're going to plot everything, so let's make it fancy...
            self.network[0].weight.data[0,0] = -1
            self.network[0].weight.data[1,0] =  .5
            self.network[0].bias.data[0] = -0.25
            self.network[0].bias.data[1] = -0.25

            self.network[2].weight.data[0,0] = 1.
            self.network[2].weight.data[0,1] = 1.
            self.network[2].bias.data[0] = 0.
        elif h == 4 :
            self.network[0].weight.data[0,0] = -1
            self.network[0].weight.data[1,0] =  .5
            self.network[0].weight.data[2,0] = -.6
            self.network[0].weight.data[3,0] = .9
            self.network[0].bias.data[0] = -0.5
            self.network[0].bias.data[1] = -0.3
            self.network[0].bias.data[2] = -0.45
            self.network[0].bias.data[3] = -0.55

            self.network[2].weight.data[0,0] = .5
            self.network[2].weight.data[0,1] = .4
            self.network[2].weight.data[0,2] = .4
            self.network[2].weight.data[0,3] = .6
            self.network[2].bias.data[0] = 0.

    def __call__(self, inp) :
        """Linear -> clamp at ``offset`` -> linear; returns a flat tensor."""
        out = self.network[0](inp.view(-1,1))
        # Rectification with threshold -0.25 instead of 0 (network[1] unused).
        out = torch.max(self.offset, out)
        out = self.network[2](out)
        return out.view(-1)

    def plot_midlayers(self, ax, t) :
        """Plot the rectified hidden units g_i(t) when there are 2 or 4 of them."""
        coeffs_1 = self.network[0].weight.view(-1).data.cpu().numpy()
        if len(coeffs_1) in [2,4] :
            u = torch.max(self.offset,  self.network[0](t.view(-1,1)) )
            t_, u_ = t.data.cpu().numpy(), u.data.cpu().numpy()
            ax.plot(t_, u_,      color = "#990899", linewidth=1)

    def plot_network(self, ax) :
        """Draw x -> g_i -> y (at most 4 hidden units shown), arrows coloured by weight."""
        plot_io(ax)
        coeffs_1 = self.network[0].weight.view(-1).data.cpu().numpy()
        coeffs_2 = self.network[2].weight.view(-1).data.cpu().numpy()
        nmax = min(len(coeffs_1), 4)
        coeffs_1, coeffs_2 = coeffs_1[:nmax], coeffs_2[:nmax]
        zs = np.linspace(2.5, .5, nmax)
        for (i, (z, coeff_1, coeff_2)) in enumerate(zip(zs, coeffs_1, coeffs_2)) :
            arrow( ax, 2.25, 4.75, 1.5,z, weight=coeff_1/.8)
            arrow( ax, 5.25, 7.75, z,1.5, weight=coeff_2/.8)
            node(  r'$\mathregular{g_'+ str(i+1) + r'}$', ax, 5., z, color="#990899")

        # Raw strings: the mathtext backslashes are not valid Python escapes.
        if nmax == 2 :
            ax.text(5.25, 2.7, r"$\mathregular{=\,(a_1\cdot x + b_1)^+}$",
                    horizontalalignment='left', verticalalignment='center',
                    fontproperties = myfont, fontsize=25, color="#990899")
            ax.text(5.25, .3, r"$\mathregular{=\,(a_2\cdot x + b_2)^+}$",
                    horizontalalignment='left', verticalalignment='center',
                    fontproperties = myfont, fontsize=25, color="#990899")
            ax.text(6.75, 1., r"$\mathregular{=\,c_1\cdot g_1 + c_2\cdot g_2 + d}$",
                    horizontalalignment='left', verticalalignment='center',
                    fontproperties = myfont, fontsize=25, color="#990808")


        


class TwoHiddenLayers(Model) :
    """1 -> h1 -> h2 -> 1 network with PReLU activations after both hidden layers.

    Args:
        h1, h2: widths of the first and second hidden layers.
    """
    def __init__(self, h1, h2) :
        super().__init__()
        self.network = torch.nn.Sequential(
            torch.nn.Linear(1, h1),
            torch.nn.PReLU(),
            torch.nn.Linear(h1, h2),
            torch.nn.PReLU(),
            torch.nn.Linear(h2, 1),
        )

    def __call__(self, inp) :
        """Run the whole stack on ``inp`` reshaped to a column; return a flat tensor."""
        return self.network(inp.view(-1,1)).view(-1)

    def plot_network(self, ax) :
        """Draw x -> g_j -> h_i -> y with at most 4 units per hidden layer shown."""
        plot_io(ax)
        coeffs_1 = self.network[0].weight.view(-1).data.cpu().numpy()
        coeffs_2 = self.network[2].weight.data.cpu().numpy()   # (h2, h1) matrix
        coeffs_3 = self.network[4].weight.view(-1).data.cpu().numpy()
        nmax_1, nmax_2 = min(len(coeffs_1), 4), min(len(coeffs_3), 4)
        coeffs_1, coeffs_2, coeffs_3 = coeffs_1[:nmax_1], coeffs_2[:nmax_2,:nmax_1], coeffs_3[:nmax_2]

        zs_1 = np.linspace(2.5, .5, nmax_1)
        zs_2 = np.linspace(2.5, .5, nmax_2)
        # input -> first hidden layer
        for (i, (z, coeff_1)) in enumerate(zip(zs_1, coeffs_1)) :
            arrow( ax, 2.25, 3.75, 1.5,z, weight=coeff_1/.5)
            node(  r'$\mathregular{g_'+ str(i+1) + r'}$', ax, 4., z, color="#990899")
        # second hidden layer -> output
        for (i, (z, coeff_3)) in enumerate(zip(zs_2, coeffs_3)) :
            arrow( ax, 6.25, 7.75, z,1.5, weight=coeff_3/.2)
            node(  r"$\mathregular{h_"+ str(i+1) + r'}$', ax, 6., z, color="#089908")
        # first -> second hidden layer: row i targets h_i, column j comes from g_j
        for (i, coeffs_i) in enumerate(coeffs_2) :
            for (j, coeff_ij) in enumerate(coeffs_i) :
                arrow( ax, 4.25, 5.75, zs_1[j],zs_2[i], weight=coeff_ij/.1)


# Toy data set: nine (x, y) samples, rescaled so that all of them fall inside
# the display window params["display"]["limits"] = [-.5, .5, -.3, .2].
X = Variable(tensor([1., 1.4, 2., 2.3, 2.9, 4., 4.5, 5., 5.4])) / 6 - .5
Y = Variable(tensor([1.5, .8, .7, .2, .42, 1., 1.4, 1.1, 1.9])) / 6 - .25

# Dense grid of abscissae on which the model curves are drawn.
T = Variable(torch.linspace(-.5, .5, 201))

# Settings handed to FitModel (optimisation / display / output location).
params = dict(
    optimization = dict(
        nits   = 5000,
        logs   = [0, 1, 2, 3, 5, 10, 20, 50, 100, 200, 500, 1000, 2000, 5000],
        tol    = 1e-12,
        lr     = .2,
        method = "Adam",
    ),
    display = dict(
        t         = T,
        limits    = [-.5, .5, -.3, .2],
        show_axis = False,
    ),
    save = dict(
        output_directory = "output_regression/",
    ),
)


# One figure holds two overlapping axes: the regression plot occupies the
# bottom 5/8 of its height, the network diagram the band from 4.5/8 to 7.5/8.
fig_model = plt.figure(num=1, figsize=(20, 20), dpi=100)

ax_model = fig_model.add_axes((0, 0, 1, 5/8))
ax_model.set_axis_off()

ax_network = fig_model.add_axes((0, 4.5/8, 1, 3/8))
ax_network.set_axis_off()

# Final figure size in inches.
fig_model.set_size_inches(10, 7.5)

# Introductory frames: highlight each sample (x_i, y_i) in turn next to the
# diagram of a constant (degree-0) polynomial model, saving one image per sample.
model = PolynomialRegression(0)
intro_directory = params["save"]["output_directory"]
os.makedirs(intro_directory, exist_ok=True)   # once, not on every frame
for it in range(len(X)) :
    ax_model.clear()
    model.plot(ax_model, X, Y, params["display"]["t"], focus=it)
    ax_model.axis(params["display"]["limits"])
    ax_model.set_aspect('equal')
    if not params["display"].get("show_axis", True) :
        ax_model.axis('off')

    ax_network.clear()
    model.plot_network(ax_network)
    ax_network.axis([0,10,0,3])
    ax_network.set_aspect('equal')
    ax_network.axis('off')

    plt.draw()
    plt.pause(0.01)

    # os.path.join avoids the "//" produced by concatenating onto a trailing "/".
    screenshot_filename = os.path.join(intro_directory, "data_" + str(it+1) + ".png")
    fig_model.savefig( screenshot_filename, bbox_inches='tight', transparent=True)


# Fit every model in turn. Each run gets its own sub-directory of
# output_regression/, passed through params["save"]["output_directory"]
# (presumably where FitModel saves its frames — confirm in model_fitting).
# All models are built up front, before any fitting starts.
for model, name in [
        (PolynomialRegression(1), "poly_1"),
        (PolynomialRegression(2), "poly_2"),
        (PolynomialRegression(4), "poly_4"),
        (OneLinearLayer(2),       "linear_2"),
        (OneHiddenLayer(2),       "single_2"),
        (OneHiddenLayer(4),       "single_4"),
        #(TwoHiddenLayers(100,100),"dual_100x100_a"),
        #(TwoHiddenLayers(100,100),"dual_100x100_b"),
        #(TwoHiddenLayers(100,100),"dual_100x100_c"),
        #(TwoHiddenLayers(100,100),"dual_100x100_d"),
        #(TwoHiddenLayers(100,100),"dual_100x100_e"),
        #(TwoHiddenLayers(100,100),"dual_100x100_f"),
    ] :
    params["save"]["output_directory"] = "output_regression/{}/".format(name)
    FitModel(params, model, X, Y)

# Keep the window open once everything has been rendered.
plt.show(block=True)

