import argparse
import six
import time
import pandas as pd
import numpy as np
from pathlib import Path
from warnings import warn
from symengine import Add, sympify
from dexom_python.enum_functions.icut_functions import create_icut_constraint
from dexom_python.imat_functions import imat
from dexom_python.result_functions import read_solution
from dexom_python.model_functions import load_reaction_weights, read_model, check_model_options, DEFAULT_VALUES
from dexom_python.enum_functions.enumeration import EnumSolution, get_recent_solution_and_iteration, create_enum_variables


def create_maxdist_constraint(model, reaction_weights, prev_sol, obj_tol, name='maxdist_optimality', full=False):
    """
    Build the optimality constraint used by the maxdist enumeration.

    The constraint expression is a weighted score over the reactions with non-zero weights:
    - positive weight: (negative-direction variable + positive-direction variable) * weight
    - negative weight: (1 - variable) * abs(weight)
    It must be at least ``prev_sol.objective_value - prev_sol.objective_value * obj_tol``.
    Reactions with a zero weight do not contribute.

    Parameters
    ----------
    model: cobrapy Model
        model whose solver already contains the enumeration variables: ``xf_<rid>``, ``xr_<rid>`` and ``x_<rid>``
        when full is True, ``rh_<rid>_pos``, ``rh_<rid>_neg`` and ``rl_<rid>`` otherwise
    reaction_weights: dict
        keys = reactions and values = weights
    prev_sol: Solution
        solution whose objective_value defines the lower bound
    obj_tol: float
        fraction of the previous objective value that may be lost
    name: str
        name given to the created constraint
    full: bool
        selects which family of solver variables is used

    Returns
    -------
    opt_const: optlang Constraint
        the constraint; this function does not add it to the model
    """
    solver_vars = model.solver.variables
    high_terms = []
    low_terms = []
    for rid, weight in six.iteritems(reaction_weights):
        if weight > 0:
            if full:
                pos_var = solver_vars['xf_' + rid]
                neg_var = solver_vars['xr_' + rid]
            else:
                neg_var = solver_vars['rh_' + rid + '_neg']
                pos_var = solver_vars['rh_' + rid + '_pos']
            high_terms.append((neg_var + pos_var) * weight)
        elif weight < 0:
            prefix = 'x_' if full else 'rl_'
            complement = sympify('1') - solver_vars[prefix + rid]
            low_terms.append(complement * abs(weight))

    bound = prev_sol.objective_value - prev_sol.objective_value * obj_tol
    lhs = Add(*high_terms) + Add(*low_terms)
    return model.solver.interface.Constraint(lhs, lb=bound, name=name)


def create_maxdist_objective(model, reaction_weights, prev_sol, prev_sol_bin, only_ones=False, full=False):
    """
    Create the new objective for the maxdist algorithm.

    The objective minimizes the similarity between a new binary solution and the previous one.
    Each considered reaction has an "indicator" expression:
    - full is True: ``x_<rid>``
    - full is False, weight > 0: ``rh_<rid>_neg + rh_<rid>_pos``
    - full is False, weight < 0: ``rl_<rid>``
    If the reaction was active in the previous solution (``prev_sol_bin == 1``), the indicator is added.
    Otherwise ``1 - indicator`` is added, unless only_ones is True.
    With only_ones=True, the similarity is therefore only calculated with overlapping ones.

    Parameters
    ----------
    model: cobrapy Model
        model whose solver contains the enumeration variables listed above
    reaction_weights: dict
        keys = reactions and values = weights; only used when full is False (zero-weight reactions are skipped)
    prev_sol: Solution
        previous solution; only the index of its fluxes is used, to locate each reaction in prev_sol_bin
    prev_sol_bin: array-like of int
        binary vector of the previous solution, indexed by position in the same order as prev_sol.fluxes
    only_ones: bool
        if True, only reactions that were active in the previous solution contribute to the objective
    full: bool
        if True, all reactions of the model are considered; otherwise only reactions with non-zero weights

    Returns
    -------
    objective: optlang Objective with direction 'min'
    """
    solver_vars = model.solver.variables
    flux_index = prev_sol.fluxes.index
    terms = []

    def add_overlap_term(indicator, rid):
        # previously active -> penalize being active again; previously inactive -> penalize staying inactive
        if prev_sol_bin[flux_index.get_loc(rid)] == 1:
            terms.append(indicator)
        elif not only_ones:
            terms.append(1 - indicator)

    if full:
        for rxn in model.reactions:
            add_overlap_term(solver_vars['x_' + rxn.id], rxn.id)
    else:
        for rid, weight in six.iteritems(reaction_weights):
            if weight > 0:
                indicator = solver_vars['rh_' + rid + '_neg'] + solver_vars['rh_' + rid + '_pos']
            elif weight < 0:
                # equivalent to the former "1 - x_rl" with x_rl = 1 - rl_<rid>, which simplifies to rl_<rid>
                indicator = solver_vars['rl_' + rid]
            else:
                continue
            add_overlap_term(indicator, rid)

    # Build the sum in one call: repeated "expr += term" copies the growing sum each time (quadratic cost)
    expr = Add(*terms) if terms else sympify('0')
    return model.solver.interface.Objective(expr, direction='min')


def maxdist(model, reaction_weights, prev_sol=None, eps=DEFAULT_VALUES['epsilon'], thr=DEFAULT_VALUES['threshold'],
            obj_tol=DEFAULT_VALUES['obj_tol'], maxiter=DEFAULT_VALUES['maxiter'], icut=True, full=False, only_ones=False):
    """
    Enumerate alternative solutions by repeatedly minimizing the overlap with the previous solution.

    An optimality constraint, built once from the starting solution, keeps the weighted score of every new solution
    at or above the starting objective value minus obj_tol times that value. Optionally, an integer-cut constraint
    is added for each previous solution. All constraints added here are removed before returning, even if the loop
    is interrupted by an exception. The objective is only changed inside a ``with model:`` context, so the model's
    original objective is restored afterwards.

    Parameters
    ----------
    model: cobrapy Model
    reaction_weights: dict
        keys = reactions and values = weights
    prev_sol: imat Solution object
        an imat solution used as a starting point; if None, one is computed with imat
    eps: float
        activation threshold in imat
    thr: float
        detection threshold of activated reactions
    obj_tol: float
        variance allowed in the objective_values of the solutions
    maxiter: int
        number of iterations, i.e. maximum number of new solutions to search for
    icut: bool
        if True, icut constraints are applied
    full: bool
        if True, carries out integer-cut on all reactions; if False, only on reactions with non-zero weights
    only_ones: bool
        if True, only the ones in the binary solution are used for distance calculation (as in dexom matlab)

    Returns
    -------
    solution: EnumSolution object
        built from all found solutions (starting solution first), their binary vectors, and the objective value
        of the starting solution
    """
    if prev_sol is None:
        prev_sol = imat(model, reaction_weights, epsilon=eps, threshold=thr, full=full)
    else:
        model = create_enum_variables(model=model, reaction_weights=reaction_weights, eps=eps, thr=thr, full=full)
    tol = model.solver.configuration.tolerances.feasibility
    icut_constraints = []
    all_solutions = [prev_sol]
    prev_sol_bin = (np.abs(prev_sol.fluxes) >= thr-tol).values.astype(int)
    all_binary = [prev_sol_bin]
    # optimality constraint: new solutions may lose at most obj_tol of the starting objective value
    opt_const = create_maxdist_constraint(model, reaction_weights, prev_sol, obj_tol,
                                          name='maxdist_optimality', full=full)
    model.solver.add(opt_const)
    try:
        for i in range(maxiter):
            t0 = time.perf_counter()
            if icut:
                # adding the icut constraint to prevent the algorithm from finding the same solutions
                const = create_icut_constraint(model, reaction_weights, thr, prev_sol, name='icut_'+str(i), full=full)
                model.solver.add(const)
                icut_constraints.append(const)
            # defining the objective: minimize the number of overlapping ones (and zeros unless only_ones)
            objective = create_maxdist_objective(model, reaction_weights, prev_sol, prev_sol_bin, only_ones, full)
            try:
                with model:
                    # set inside the context so that the caller's objective is restored on exit
                    model.objective = objective
                    prev_sol = model.optimize()
                prev_sol_bin = (np.abs(prev_sol.fluxes) >= thr-tol).values.astype(int)
                all_solutions.append(prev_sol)
                all_binary.append(prev_sol_bin)
            except Exception as err:
                # best-effort enumeration: report the failed iteration and continue with the next one
                print('An error occured in iter %i of maxdist: %s' % (i+1, err))
            t1 = time.perf_counter()
            print('time for iteration '+str(i+1)+': ', t1-t0)
    finally:
        # never leave enumeration constraints behind in the caller's model
        model.solver.remove([const for const in icut_constraints if const in model.solver.constraints])
        model.solver.remove(opt_const)
    solution = EnumSolution(all_solutions, all_binary, all_solutions[0].objective_value)
    return solution


def main():
    """
    This function is called when you run this script from the commandline.
    It performs the distance-maximization enumeration algorithm and writes the binary solutions
    to <output>_solutions.csv.
    Use --help to see commandline parameters
    """
    description = 'Performs the distance-maximization enumeration algorithm'
    parser = argparse.ArgumentParser(description=description, formatter_class=argparse.RawTextHelpFormatter)
    parser.add_argument('-m', '--model', help='Metabolic model in sbml, matlab, or json format')
    parser.add_argument('-r', '--reaction_weights', default=None,
                        help='Reaction weights in csv format (first row: reaction names, second row: weights)')
    # default must be None: a list default made "is not None" always true and Path([]) raise a TypeError
    parser.add_argument('-p', '--prev_sol', default=None, help='starting solution or directory of recent solutions')
    # NOTE(review): consumed by get_recent_solution_and_iteration; default of 1 chosen here - confirm semantics
    parser.add_argument('-s', '--startsol_num', type=int, default=1,
                        help='number of starting solutions (if prev_sol is a directory)')
    parser.add_argument('-e', '--epsilon', type=float, default=DEFAULT_VALUES['epsilon'],
                        help='Activation threshold for highly expressed reactions')
    parser.add_argument('--threshold', type=float, default=DEFAULT_VALUES['threshold'],
                        help='Activation threshold for all reactions')
    parser.add_argument('-t', '--timelimit', type=int, default=DEFAULT_VALUES['timelimit'], help='Solver time limit')
    parser.add_argument('--tol', type=float, default=DEFAULT_VALUES['tolerance'], help='Solver feasibility tolerance')
    parser.add_argument('--mipgap', type=float, default=DEFAULT_VALUES['mipgap'], help='Solver MIP gap tolerance')
    parser.add_argument('--obj_tol', type=float, default=DEFAULT_VALUES['obj_tol'],
                        help='objective value tolerance, as a fraction of the original value')
    parser.add_argument('-i', '--maxiter', type=int, default=DEFAULT_VALUES['maxiter'], help='Iteration limit')
    parser.add_argument('-o', '--output', default='div_enum', help='Base name of output files, without format')
    parser.add_argument('--noicut', action='store_true', help='Use this flag to remove the icut constraint')
    parser.add_argument('--full', action='store_true', help='Use this flag to assign non-zero weights to all reactions')
    parser.add_argument('--onlyones', action='store_true', help='Use this flag for the old implementation of maxdist')
    args = parser.parse_args()

    model = read_model(args.model)
    check_model_options(model, timelimit=args.timelimit, feasibility=args.tol, mipgaptol=args.mipgap)
    reaction_weights = {}
    if args.reaction_weights is not None:
        reaction_weights = load_reaction_weights(args.reaction_weights)

    # maxdist() calls create_enum_variables itself whenever it receives a starting solution
    prev_sol_success = False
    if args.prev_sol is not None:
        prev_sol_path = Path(args.prev_sol)
        if prev_sol_path.is_file():
            prev_sol, _ = read_solution(args.prev_sol, model)
            prev_sol_success = True
        elif prev_sol_path.is_dir():
            try:
                prev_sol, _ = get_recent_solution_and_iteration(args.prev_sol, args.startsol_num)
            except Exception:
                warn('Could not find solution in directory %s, computing new starting solution' % args.prev_sol)
            else:
                prev_sol_success = True
        else:
            warn('Could not read previous solution at path %s, computing new starting solution' % args.prev_sol)
    if not prev_sol_success:
        # same starting point maxdist() would compute itself, including the --full setting
        prev_sol = imat(model, reaction_weights, epsilon=args.epsilon, threshold=args.threshold, full=args.full)
    icut = not args.noicut
    maxdist_sol = maxdist(model=model, reaction_weights=reaction_weights, prev_sol=prev_sol, eps=args.epsilon,
                          thr=args.threshold, obj_tol=args.obj_tol, maxiter=args.maxiter, icut=icut, full=args.full,
                          only_ones=args.onlyones)
    sol = pd.DataFrame(maxdist_sol.binary)
    sol.to_csv(args.output+'_solutions.csv')
    return True


# Entry point when this module is executed as a script; main() parses the commandline arguments
if __name__ == '__main__':
    main()
