"""
This file implements the visualisation of the results with geopandas and matplotlib

@author: Stephan Bogs, Chair of Operations Management, RWTH Aachen
"""
import os

import json
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np

from solution_helper import get_transport_cost_from_solution, get_biomass_cost_from_solution, \
    get_production_cost_from_solution, get_outside_system_cost_from_solution

# Module-wide state; refresh() stores the parsed parameter files in here
# under the keys "params" and "default_params".
data = {}

# Directories that receive the demand and availability figures.
demand_dir = 'output/demand'
availability_dir = 'output/availability'

# makedirs also creates the parent 'output' directory and, with exist_ok,
# does nothing when a directory is already present, so no separate
# existence check is needed.
os.makedirs(demand_dir, exist_ok=True)
os.makedirs(availability_dir, exist_ok=True)


def rmdir(directory):
    """
    Recursively delete a directory together with everything it contains.

    Symbolic links found inside the tree are removed as links and never
    followed, so data that lives outside ``directory`` is left untouched.

    :param directory: path (``str`` or ``Path``) of the directory to delete
    :raises OSError: if the directory does not exist or an entry cannot be removed
    """
    directory = Path(directory)
    for item in directory.iterdir():
        # is_dir() follows symlinks; without the is_symlink() guard a link to a
        # directory would be descended into and the link target emptied.
        if item.is_dir() and not item.is_symlink():
            rmdir(item)
        else:
            item.unlink()
    directory.rmdir()


def delete_files_in_directory(directory):
    """
    Remove all regular files that sit directly inside ``directory``.

    Sub-directories and their contents are left alone. The outcome is reported
    on stdout; an ``OSError`` raised while listing or deleting is caught and
    reported instead of being propagated.

    :param directory: path of the directory whose files are deleted
    """
    try:
        for entry_name in os.listdir(directory):
            candidate = os.path.join(directory, entry_name)
            # Skip anything that is not a file (e.g. nested directories).
            if not os.path.isfile(candidate):
                continue
            os.remove(candidate)
        print("All files deleted successfully.")
    except OSError:
        print("Error occurred while deleting files.")


def refresh():
    """
    Reload the parameter files and reset the output directory tree.

    Parses ``data/opt_params/params.json`` and
    ``data/opt_params/params-default.json`` and stores the results in the
    module-level ``data`` dict under ``"params"`` and ``"default_params"``.
    Afterwards any existing ``output`` directory is deleted and an empty
    tree is created again, including the demand and availability
    sub-directories the plotting functions of this module save into.

    :raises OSError: if a parameter file cannot be read or the output
        directory cannot be reset
    :raises json.JSONDecodeError: if a parameter file is not valid JSON
    """
    with open('data/opt_params/params.json') as params_file:
        params = json.load(params_file)

    with open('data/opt_params/params-default.json') as params_file:
        default_params = json.load(params_file)

    data["params"] = params
    data["default_params"] = default_params

    # Only wipe the tree when it is there, so a missing directory is not an error.
    if os.path.isdir('output'):
        rmdir('output')

    # Recreate the complete layout: wiping 'output' also removed the figure
    # sub-directories, and savefig does not create missing directories.
    os.makedirs('output', exist_ok=True)
    os.makedirs(demand_dir, exist_ok=True)
    os.makedirs(availability_dir, exist_ok=True)


# Executed on import: load the parameter files into `data` and start from a
# freshly reset output directory.
refresh()


def visualise_demand(demand_nuts_0, year):
    """
    Draw a map of the demand per region and save it as ``<demand_dir>/<year>.png``.

    Regions are coloured by their ``'demand'`` value (``Reds`` colour map, with
    legend) and each value is additionally written, without decimals, next to
    the centroid of its region.

    :param demand_nuts_0: geo table with a ``'demand'`` column; presumably a
        geopandas ``GeoDataFrame`` with a CRS set (it must provide ``plot``,
        ``to_crs`` and ``crs``) -- TODO confirm against the caller
    :param year: value used as plot title and as file name
    """
    fig, ax = plt.subplots()
    try:
        fig.set_dpi(200)
        fig.suptitle("Demands Biofuels")

        demand_nuts_0.plot(ax=ax, column='demand', legend=True, cmap='Reds', edgecolor="black")
        ax.set_title(year)

        # Centroids on EPSG:4326 might be distorted, so they are computed in an
        # equal-area projection and converted back to the original CRS.
        centroids = demand_nuts_0.to_crs('+proj=cea').centroid.to_crs(demand_nuts_0.crs)
        for x, y, demand in zip(centroids.x, centroids.y, demand_nuts_0['demand']):
            ax.annotate("%.0f" % demand,
                        xy=(x, y), xycoords='data', color='b',
                        xytext=(-5, -5), textcoords='offset points')

        # Address the axes explicitly instead of relying on pyplot's current axes.
        ax.axis('off')

        # The output tree can be wiped after this directory was first created,
        # and savefig does not create missing directories.
        os.makedirs(demand_dir, exist_ok=True)
        fig.savefig(f"{demand_dir}/{year}.png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def visualise_biomass_availability(availability, year, biomass):
    """
    Draw a map of the biomass availability per region and save it as
    ``<availability_dir>/<biomass>-<year>.png``.

    Regions are coloured by their ``'interpol'`` value (``Greens`` colour map,
    with legend) and each value is additionally written, without decimals,
    next to the centroid of its region.

    :param availability: geo table with an ``'interpol'`` column; presumably a
        geopandas ``GeoDataFrame`` with a CRS set (it must provide ``plot``,
        ``to_crs`` and ``crs``) -- TODO confirm against the caller
    :param year: value used in the plot title and the file name
    :param biomass: biomass identifier used in the plot title and the file name
    """
    fig, ax = plt.subplots()
    try:
        fig.set_dpi(200)
        fig.suptitle("Availability Biofuels")

        availability.plot(ax=ax, column='interpol', legend=True, cmap='Greens', edgecolor="black")
        ax.set_title(f"{year}-{biomass}")

        # Centroids on EPSG:4326 might be distorted, so they are computed in an
        # equal-area projection and converted back to the original CRS.
        centroids = availability.to_crs('+proj=cea').centroid.to_crs(availability.crs)
        for x, y, amount in zip(centroids.x, centroids.y, availability['interpol']):
            ax.annotate("%.0f" % amount,
                        xy=(x, y), xycoords='data', color='b',
                        xytext=(-5, -5), textcoords='offset points')

        # Address the axes explicitly instead of relying on pyplot's current axes.
        ax.axis('off')

        # The output tree can be wiped after this directory was first created,
        # and savefig does not create missing directories.
        os.makedirs(availability_dir, exist_ok=True)
        fig.savefig(f"{availability_dir}/{biomass}-{year}.png")
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)


def _save_and_close(fig, path):
    """Save ``fig`` to ``path`` and close it, also when saving fails."""
    try:
        fig.savefig(path)
    finally:
        plt.close(fig)


def _annotate_centroids(ax, xs, ys, values):
    """Write each value, formatted without decimals, next to its centroid."""
    for x, y, value in zip(xs, ys, values):
        ax.annotate("%.0f" % value,
                    xy=(x, y), xycoords='data', color='b',
                    xytext=(-5, -5), textcoords='offset points')


def _plot_flow_map(geometries, xs, ys, flows, colour, title, path):
    """Save a map of region outlines with one arrow per (origin, destination) pair."""
    fig, ax = plt.subplots()
    fig.set_dpi(300)
    ax.set_title(title)
    geometries.plot(ax=ax, edgecolor="black", facecolor="none")
    for origin, destination in flows:
        ax.arrow(xs[origin], ys[origin],
                 xs[destination] - xs[origin],
                 ys[destination] - ys[origin], length_includes_head=True,
                 head_width=0.5, head_length=0.5, fc=colour, ec=colour)
    ax.axis('off')
    _save_and_close(fig, path)


def _plot_net_import_bars(labels, net_imports, path):
    """Save a bar chart with one bar (imports minus exports) per region label."""
    fig, ax = plt.subplots()
    fig.set_dpi(300)
    fig.autofmt_xdate(rotation=45)
    ax.bar(labels, net_imports, label="Import")
    _save_and_close(fig, path)


def visualise_transport(run, epsilon="opt", multi_scenario=False):
    """
    Create all result figures of one optimisation run below ``output/<epsilon>``.

    For every combination of potential scenario and demand scenario, and for
    the selected time steps (first and middle entry of the years range), the
    following PNG files are written:

    * ``production_facilities-...``: map of the rounded facility values
      (``y0`` before the reinvestment period, ``y1`` afterwards)
    * ``export-biomass-...`` / ``export-biofuel-...``: maps with an arrow for
      every transport relation whose amount exceeds ``0.000001``
    * ``biomass_percent_usage-...`` / ``biomass_total_usage-...``: bar charts
      of used versus available biomass
    * ``biofuel_export_import-...`` / ``biomass_export_import-...``: bar
      charts of imports minus exports per region

    plus one ``pie_chart-...`` per scenario combination showing the cost
    shares returned by the ``get_*_cost_from_solution`` helpers.

    Side effect: a ``'values'`` column is written into ``run["geometries"]``.

    :param run: mapping with the solution of one run. Keys read here:
        ``biomass_range``, ``regions_range``, ``years_range``, ``geometries``,
        ``y0``, ``y1``, ``biomass_availability`` and, depending on
        ``multi_scenario``, ``scenario``/``demand_scenario``/``xb``/``xf`` or
        ``potential_scenarios``/``demand_scenarios``/``xb1``/``xf1``.
        ``geometries`` is presumably a geopandas ``GeoDataFrame`` -- TODO confirm
    :param epsilon: name of the sub-directory of ``output`` to write into
    :param multi_scenario: if true, the solution values are additionally
        indexed by ``(scenario, demand_scenario)``
    """
    biomasses = run["biomass_range"]
    regions = run["regions_range"]
    years = run["years_range"]
    params = data['params']

    output_dir = f'output/{epsilon}'
    # makedirs also creates a missing parent 'output' directory.
    os.makedirs(output_dir, exist_ok=True)

    if multi_scenario:
        potential_scenarios = run["potential_scenarios"]
        demand_scenarios = run["demand_scenarios"]
        xb, xf = run["xb1"], run["xf1"]
    else:
        potential_scenarios = [run["scenario"]]
        demand_scenarios = [run["demand_scenario"]]
        xb, xf = run["xb"], run["xf"]

    # NOTE(review): an earlier comment announced "first, middle and last
    # element", but only the first and the middle time step have ever been
    # plotted -- confirm whether the last one should be added.
    # dict.fromkeys drops a duplicate (first == middle) while keeping the order,
    # so no figure is rendered twice just to overwrite itself.
    elements = list(dict.fromkeys([years[0], int((years[-1] + years[0]) / 2)]))

    # The region centroids do not depend on scenario or year, so compute them once.
    # Centroids on EPSG:4326 might be distorted, hence the equal-area projection.
    geometries = run["geometries"]
    centroids = geometries.to_crs('+proj=cea').centroid.to_crs(geometries.crs)
    centroids_x = centroids.x
    centroids_y = centroids.y

    for scenario in potential_scenarios:
        for demand_scenario in demand_scenarios:
            # Leading part of the index tuples of the transport variables:
            # x[a, b, c] is x[(a, b, c)], so an empty prefix addresses the
            # single-scenario variables unchanged.
            key = (scenario, demand_scenario) if multi_scenario else ()
            tag = f'{scenario}-{demand_scenario}'

            for t in elements:
                year = t + params["year_start"]

                # Production facilities
                values = []
                for i in regions:
                    if t < params["reinvestment_period"]:
                        facilities = run["y0"][i]
                    elif multi_scenario:
                        facilities = run["y1"][scenario, demand_scenario, i]
                    else:
                        facilities = run["y1"][i]
                    # Treat numerical noise of the solver as "no facility".
                    values.append(round(facilities) if facilities > 0.0001 else 0)

                fig, ax = plt.subplots()
                fig.set_dpi(300)
                geometries['values'] = values
                geometries.plot(ax=ax, column='values', edgecolor="black", cmap="Oranges", legend=True)
                ax.set_title(f'Production facilities {year}')
                _annotate_centroids(ax, centroids_x, centroids_y, values)
                ax.axis('off')
                # Explicit extension, consistent with all other figures; without
                # it the format would be guessed from the file name.
                _save_and_close(fig, f'{output_dir}/production_facilities-{year}-{tag}.png')

                # Biomass transport
                biomass_flows = [
                    (i, p) for i in regions for p in regions
                    if sum(xb[key + (b, i, p, t)] for b in biomasses) > 0.000001
                ]
                _plot_flow_map(geometries, centroids_x, centroids_y, biomass_flows, 'g',
                               f'Export biomass {year}',
                               f'{output_dir}/export-biomass-{year}-{tag}.png')

                # Fuel transport (flows within one region are not drawn)
                fuel_flows = [
                    (p, j) for p in regions for j in regions
                    if xf[key + (p, j, t)] > 0.000001 and p != j
                ]
                _plot_flow_map(geometries, centroids_x, centroids_y, fuel_flows, 'r',
                               f'Export biofuel {year}',
                               f'{output_dir}/export-biofuel-{year}-{tag}.png')

            # Cost shares
            labels = ['Transport', 'Production', 'Biomass', 'Outside System']
            counts = [
                get_transport_cost_from_solution(run, scenario, multi_scenario, demand_scenario),
                get_production_cost_from_solution(run, scenario, multi_scenario, demand_scenario),
                get_biomass_cost_from_solution(run, scenario, multi_scenario, demand_scenario),
                get_outside_system_cost_from_solution(run, scenario, multi_scenario, demand_scenario),
            ]
            fig, ax = plt.subplots()
            fig.set_dpi(150)
            ax.pie(counts, labels=labels, autopct='%1.1f%%',
                   shadow=True, startangle=90)
            ax.axis('equal')
            ax.axis('off')
            _save_and_close(fig, f'{output_dir}/pie_chart-{tag}.png')

            biomass_names = data['default_params']["biomasses"]
            countries = params["countries"]

            for t in elements:
                year = t + params["year_start"]

                # Amounts per biomass type, computed once and shared by the
                # three biomass charts below.
                used = [sum(xb[key + (b, i, p, t)] for p in regions for i in regions)
                        for b in biomasses]
                available = [sum(run["biomass_availability"][scenario][b][t][i] for i in regions)
                             for b in biomasses]

                # Share used (left) and stacked used/left totals (right)
                fig, axes = plt.subplots(1, 2)
                fig.set_dpi(300)
                fig.autofmt_xdate(rotation=45)
                shares = [u / a if a > 0 else 0 for u, a in zip(used, available)]
                axes[0].bar(biomass_names, shares)
                axes[0].set_title(f'Percentage used, {year}')

                fig.autofmt_xdate(rotation=45)
                left_over = [a - u for u, a in zip(used, available)]
                axes[1].bar(biomass_names, used, label="Used")
                # Stack the unused remainder on top of the used amount.
                axes[1].bar(biomass_names, left_over, label="Left", bottom=used)
                axes[1].legend(loc="upper right")
                axes[1].set_title(f'Totals of biomass, {year}')
                _save_and_close(fig, f'{output_dir}/biomass_percent_usage-{year}-{tag}.png')

                # Total used
                fig, ax = plt.subplots()
                fig.set_dpi(300)
                fig.autofmt_xdate(rotation=45)
                ax.bar(biomass_names, used)
                ax.set_title(f'Total used, {year}')
                _save_and_close(fig, f'{output_dir}/biomass_total_usage-{year}-{tag}.png')

                # Label the bars with exactly the regions that were evaluated, so
                # labels and values always have the same length and order.
                region_labels = [countries[p] for p in regions]

                # Fuel import
                fuel_net_import = [
                    sum(xf[key + (j, p, t)] for j in regions)
                    - sum(xf[key + (p, j, t)] for j in regions)
                    for p in regions
                ]
                _plot_net_import_bars(region_labels, fuel_net_import,
                                      f'{output_dir}/biofuel_export_import-{year}-{tag}-.png')

                # Biomass import
                biomass_net_import = [
                    sum(xb[key + (b, j, p, t)] for b in biomasses for j in regions)
                    - sum(xb[key + (b, p, j, t)] for b in biomasses for j in regions)
                    for p in regions
                ]
                _plot_net_import_bars(region_labels, biomass_net_import,
                                      f'{output_dir}/biomass_export_import-{year}-{tag}.png')
