/*
 * Decompiled with CFR 0.152.
 */
package model.inference;

import data.condition.FilterNumericFeature;
import data.condition.FilterSet;
import data.condition.Operator;
import data.feature.SingleValueFeature;
import data.instance.Instance;
import data.instance.Instances;
import data.value.Value;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.TreeMap;
import model.NodeSplit;
import model.criterion.cancel.CancelCriterion;

public class SplitFinder {
    /**
     * Cancel criteria consulted for every candidate split. A candidate is discarded as soon as one
     * criterion's {@code cancelCriterion} returns {@code true}.
     * NOTE(review): public mutable static state shared by every caller and not synchronized.
     */
    public static ArrayList<CancelCriterion> cancels = new ArrayList<CancelCriterion>();

    /**
     * Finds a binary split of a nominal feature for a nominal (class) target.
     *
     * <p>Each feature value is first scored on its own ("this value" vs. "all other values"). Values
     * are then moved greedily into the left branch in decreasing order of that score, for as long as
     * the combined split strictly improves the score and no cancel criterion rejects it. The score is
     * {@code (sct - scl - scr - d) / nt}, where each term is
     * {@code (n^2 - sum of squared class weights) / n} (i.e. {@code n} times the Gini impurity) for
     * all instances, the left branch, the right branch and the missing-value instances respectively.
     *
     * @param feature feature being split on; cloned into the returned filter
     * @param groups instances grouped by their value of {@code feature}
     * @param noValue instances without a value for {@code feature}; passed through as the third branch
     * @param complexityUse whether to divide every score by {@code feature.complexity()}
     * @return the last accepted split, or {@code null} if no candidate was accepted
     */
    public static NodeSplit inferNominalSplitOnNominalTarget(SingleValueFeature feature, HashMap<Value, Instances> groups, Instances noValue, boolean complexityUse) {
        // Per class label: [0] weight in left, [1] weight in right, [2] weight with a missing
        // feature value, [3] total weight.
        HashMap<Value, double[]> countsBranches = new HashMap<Value, double[]>();
        // Per feature value: class label -> summed instance weight of that value's group.
        HashMap<Value, HashMap<Value, Double>> countsPerPartition = new HashMap<Value, HashMap<Value, Double>>();
        double nl = 0.0;
        double nr = 0.0;
        double nn = 0.0;
        double nt = 0.0;
        for (Instance instance : noValue) {
            double w = instance.getWeight();
            double[] cs = classCounts(countsBranches, instance.getLabel());
            cs[2] += w;
            cs[3] += w;
            nn += w;
            nt += w;
        }
        // Every valued instance starts out in the right branch.
        for (Map.Entry<Value, Instances> entry : groups.entrySet()) {
            HashMap<Value, Double> partitionCounts = new HashMap<Value, Double>();
            countsPerPartition.put(entry.getKey(), partitionCounts);
            for (Instance inst : entry.getValue()) {
                Value classVal = inst.getLabel();
                double w = inst.getWeight();
                // Weighted, so it can be combined with nl/nr and countsBranches below.
                partitionCounts.merge(classVal, w, Double::sum);
                double[] cs = classCounts(countsBranches, classVal);
                cs[1] += w;
                cs[3] += w;
                nr += w;
                nt += w;
            }
        }
        double d = nn * nn;
        double sct = nt * nt;
        for (double[] cs : countsBranches.values()) {
            d -= cs[2] * cs[2];
            sct -= cs[3] * cs[3];
        }
        if (nn > 0.0) {
            d /= nn;
        }
        if (nt > 0.0) {
            sct /= nt;
        }

        // Score each value on its own to fix the order of the greedy phase.
        TreeMap<Double, HashSet<Value>> partitionsOrdered = new TreeMap<Double, HashSet<Value>>();
        for (Map.Entry<Value, HashMap<Value, Double>> ent : countsPerPartition.entrySet()) {
            HashMap<Value, Double> partitionCounts = ent.getValue();
            double leftWeight = 0.0;
            for (double w : partitionCounts.values()) {
                leftWeight += w;
            }
            double rightWeight = nr - leftWeight;
            double scl = leftWeight * leftWeight;
            double scr = rightWeight * rightWeight;
            // Every class contributes to the right branch, including classes absent from this value.
            for (Map.Entry<Value, double[]> classEntry : countsBranches.entrySet()) {
                double inLeft = partitionCounts.getOrDefault(classEntry.getKey(), 0.0);
                double inRight = classEntry.getValue()[1] - inLeft;
                scl -= inLeft * inLeft;
                scr -= inRight * inRight;
            }
            if (leftWeight > 0.0) {
                scl /= leftWeight;
            }
            if (rightWeight > 0.0) {
                scr /= rightWeight;
            }
            double score = penalize((sct - scl - scr - d) / nt, feature, complexityUse);
            partitionsOrdered.computeIfAbsent(score, k -> new HashSet<Value>()).add(ent.getKey());
        }

        double prevScore = -Double.MAX_VALUE;
        HashSet<Value> valuesSet = new HashSet<Value>();
        NodeSplit res = null;
        greedy:
        for (HashSet<Value> vals : partitionsOrdered.descendingMap().values()) {
            for (Value val : vals) {
                HashMap<Value, Double> partitionCounts = countsPerPartition.get(val);
                for (double w : partitionCounts.values()) {
                    nl += w;
                    nr -= w;
                }
                double scl = nl * nl;
                double scr = nr * nr;
                for (Map.Entry<Value, double[]> classEntry : countsBranches.entrySet()) {
                    double c = partitionCounts.getOrDefault(classEntry.getKey(), 0.0);
                    double[] cs = classEntry.getValue();
                    cs[0] += c;
                    cs[1] -= c;
                    scl -= cs[0] * cs[0];
                    scr -= cs[1] * cs[1];
                }
                if (nl > 0.0) {
                    scl /= nl;
                }
                if (nr > 0.0) {
                    scr /= nr;
                }
                double score = penalize((sct - scl - scr - d) / nt, feature, complexityUse);
                // Written as !(a > b) rather than a <= b so that a NaN score also stops the search.
                if (!(score > prevScore)) {
                    break greedy;
                }
                // Work on a copy so an accepted split's filter is never mutated afterwards.
                HashSet<Value> candidateValues = new HashSet<Value>(valuesSet);
                candidateValues.add(val);
                NodeSplit candSplit = nominalCandidate(feature, groups, noValue, candidateValues, score);
                if (isCancelled(candSplit)) {
                    break greedy;
                }
                res = candSplit;
                prevScore = score;
                valuesSet = candidateValues;
            }
        }
        return res;
    }

    /**
     * Finds a binary split of a nominal feature for a numeric target.
     *
     * <p>Same greedy procedure as {@link #inferNominalSplitOnNominalTarget}, but the score is the
     * weighted between-group variance of the left, right and missing-value target means, divided by
     * {@code nt^2}.
     *
     * @param feature feature being split on; cloned into the returned filter
     * @param groups instances grouped by their value of {@code feature}
     * @param noValue instances without a value for {@code feature}; passed through as the third branch
     * @param complexityUse whether to divide every score by {@code feature.complexity()}
     * @return the last accepted split, or {@code null} if no candidate was accepted
     */
    public static NodeSplit inferNominalSplitOnNumericTarget(SingleValueFeature feature, HashMap<Value, Instances> groups, Instances noValue, boolean complexityUse) {
        // Per feature value: weighted target sum and total weight of that value's group.
        HashMap<Value, Double> sumPerPartition = new HashMap<Value, Double>();
        HashMap<Value, Double> weightPerPartition = new HashMap<Value, Double>();
        double sumNoValue = 0.0;
        double sumTotal = 0.0;
        double nValued = 0.0;
        double nn = 0.0;
        double nt = 0.0;
        for (Instance instance : noValue) {
            double w = instance.getWeight();
            sumNoValue += w * instance.getLabel().getNumericValue();
            nn += w;
            nt += w;
        }
        for (Map.Entry<Value, Instances> entry : groups.entrySet()) {
            double sumForVal = 0.0;
            double weightForVal = 0.0;
            for (Instance inst : entry.getValue()) {
                double w = inst.getWeight();
                sumForVal += w * inst.getLabel().getNumericValue();
                weightForVal += w;
                nt += w;
            }
            sumPerPartition.put(entry.getKey(), sumForVal);
            weightPerPartition.put(entry.getKey(), weightForVal);
            sumTotal += sumForVal;
            nValued += weightForVal;
        }
        double mn = mean(sumNoValue, nn);

        // Score each value on its own to fix the order of the greedy phase.
        TreeMap<Double, HashSet<Value>> partitionsOrdered = new TreeMap<Double, HashSet<Value>>();
        for (Map.Entry<Value, Double> ent : sumPerPartition.entrySet()) {
            double sl = ent.getValue();
            double nl = weightPerPartition.get(ent.getKey());
            double nr = nValued - nl;
            double score = penalize(betweenGroupVariance(nl, mean(sl, nl), nr, mean(sumTotal - sl, nr), nn, mn, nt), feature, complexityUse);
            partitionsOrdered.computeIfAbsent(score, k -> new HashSet<Value>()).add(ent.getKey());
        }

        double prevScore = -Double.MAX_VALUE;
        HashSet<Value> valuesSet = new HashSet<Value>();
        double sl = 0.0;
        double sr = sumTotal;
        double nl = 0.0;
        double nr = nValued;
        NodeSplit res = null;
        greedy:
        for (HashSet<Value> vals : partitionsOrdered.descendingMap().values()) {
            for (Value value : vals) {
                double partitionSum = sumPerPartition.get(value);
                double partitionWeight = weightPerPartition.get(value);
                sl += partitionSum;
                sr -= partitionSum;
                nl += partitionWeight;
                nr -= partitionWeight;
                double score = penalize(betweenGroupVariance(nl, mean(sl, nl), nr, mean(sr, nr), nn, mn, nt), feature, complexityUse);
                // Written as !(a > b) rather than a <= b so that a NaN score also stops the search.
                if (!(score > prevScore)) {
                    break greedy;
                }
                // Work on a copy so an accepted split's filter is never mutated afterwards.
                HashSet<Value> candidateValues = new HashSet<Value>(valuesSet);
                candidateValues.add(value);
                NodeSplit candSplit = nominalCandidate(feature, groups, noValue, candidateValues, score);
                if (isCancelled(candSplit)) {
                    break greedy;
                }
                res = candSplit;
                prevScore = score;
                valuesSet = candidateValues;
            }
        }
        return res;
    }

    /**
     * Finds the best threshold split {@code feature >= thresh} of a numeric feature for a nominal
     * target.
     *
     * <p>Thresholds are the keys of {@code groups} in ascending order. For threshold {@code t}, the
     * left branch holds the groups whose key is {@code >= t} and the right branch holds the rest. A
     * candidate is built only when its score is at least the best score so far. Among the
     * non-cancelled candidates sharing the highest score, the middle one is returned.
     *
     * @param feature feature being split on; cloned into each candidate's filter
     * @param groups instances grouped and ordered by their value of {@code feature}
     * @param noValue instances without a value for {@code feature}; copied into each candidate
     * @param complexityUse whether to divide every score by {@code feature.complexity()}
     * @return the selected split, or {@code null} if every candidate was cancelled or skipped
     */
    public static NodeSplit inferNumericSplitOnNominalTarget(SingleValueFeature feature, TreeMap<Value, Instances> groups, Instances noValue, boolean complexityUse) {
        // Per class label: [0] weight in left, [1] weight in right, [2] weight with a missing
        // feature value, [3] total weight.
        HashMap<Value, double[]> counts = new HashMap<Value, double[]>();
        double nl = 0.0;
        double nr = 0.0;
        double nn = 0.0;
        double nt = 0.0;
        Instances left = new Instances();
        Instances right = new Instances();
        for (Instance inst : noValue) {
            double w = inst.getWeight();
            double[] cs = classCounts(counts, inst.getLabel());
            cs[2] += w;
            cs[3] += w;
            nn += w;
            nt += w;
        }
        // Every valued instance starts out in the left branch (threshold = smallest key).
        for (Instances insts : groups.values()) {
            for (Instance inst : insts) {
                double w = inst.getWeight();
                double[] cs = classCounts(counts, inst.getLabel());
                cs[0] += w;
                cs[3] += w;
                nl += w;
                nt += w;
            }
            left.addAll(insts);
        }
        double sct = nt * nt;
        double scn = nn * nn;
        for (double[] cs : counts.values()) {
            scn -= cs[2] * cs[2];
            sct -= cs[3] * cs[3];
        }
        if (nn > 0.0) {
            scn /= nn;
        }
        if (nt > 0.0) {
            sct /= nt;
        }
        ArrayList<NodeSplit> bestSplits = new ArrayList<NodeSplit>();
        double bestScore = -Double.MAX_VALUE;
        for (Map.Entry<Value, Instances> entry : groups.entrySet()) {
            Instances instsForThresh = entry.getValue();
            double scl = nl * nl;
            double scr = nr * nr;
            for (double[] cs : counts.values()) {
                scl -= cs[0] * cs[0];
                scr -= cs[1] * cs[1];
            }
            if (nl > 0.0) {
                scl /= nl;
            }
            if (nr > 0.0) {
                scr /= nr;
            }
            double score = penalize((sct - scl - scr - scn) / nt, feature, complexityUse);
            if (score >= bestScore) {
                NodeSplit candSplit = new NodeSplit(new Instances(left), new Instances(right), new Instances(noValue), new FilterNumericFeature(feature.clone(), Operator.GEQ, entry.getKey()), score);
                if (!isCancelled(candSplit)) {
                    if (score > bestScore) {
                        bestScore = score;
                        bestSplits.clear();
                    }
                    bestSplits.add(candSplit);
                }
            }
            // Move this threshold's group from left to right before trying the next threshold.
            left.removeAll(instsForThresh);
            right.addAll(instsForThresh);
            for (Instance inst : instsForThresh) {
                double w = inst.getWeight();
                double[] cs = counts.get(inst.getLabel());
                cs[0] -= w;
                cs[1] += w;
                nl -= w;
                nr += w;
            }
        }
        return middleOf(bestSplits);
    }

    /**
     * Finds the best threshold split {@code feature >= thresh} of a numeric feature for a numeric
     * target.
     *
     * <p>Thresholds are swept exactly as in {@link #inferNumericSplitOnNominalTarget}. The score is
     * the weighted between-group variance of the left, right and missing-value target means, divided
     * by {@code nt^2}.
     *
     * @param feature feature being split on; cloned into each candidate's filter
     * @param groups instances grouped and ordered by their value of {@code feature}
     * @param noValue instances without a value for {@code feature}; copied into each candidate
     * @param complexityUse whether to divide every score by {@code feature.complexity()}
     * @return the selected split, or {@code null} if every candidate was cancelled or skipped
     */
    public static NodeSplit inferNumericSplitOnNumericTarget(SingleValueFeature feature, TreeMap<Value, Instances> groups, Instances noValue, boolean complexityUse) {
        double sl = groups.values().stream().flatMap(gr -> gr.stream()).mapToDouble(inst -> inst.getWeight() * inst.getLabel().getNumericValue()).sum();
        double sr = 0.0;
        double sn = noValue.stream().mapToDouble(inst -> inst.getWeight() * inst.getLabel().getNumericValue()).sum();
        double nl = groups.values().stream().flatMap(gr -> gr.stream()).mapToDouble(inst -> inst.getWeight()).sum();
        double nr = 0.0;
        double nn = noValue.stream().mapToDouble(inst -> inst.getWeight()).sum();
        double nt = nl + nn;
        double mn = mean(sn, nn);
        // Every valued instance starts out in the left branch, matching sl/nl above.
        Instances left = new Instances();
        Instances right = new Instances();
        for (Instances insts : groups.values()) {
            left.addAll(insts);
        }
        ArrayList<NodeSplit> bestSplits = new ArrayList<NodeSplit>();
        double bestScore = -Double.MAX_VALUE;
        for (Map.Entry<Value, Instances> ent : groups.entrySet()) {
            Instances instsForThresh = ent.getValue();
            double score = penalize(betweenGroupVariance(nl, mean(sl, nl), nr, mean(sr, nr), nn, mn, nt), feature, complexityUse);
            if (score >= bestScore) {
                NodeSplit candSplit = new NodeSplit(new Instances(left), new Instances(right), new Instances(noValue), new FilterNumericFeature(feature.clone(), Operator.GEQ, ent.getKey()), score);
                if (!isCancelled(candSplit)) {
                    if (score > bestScore) {
                        bestScore = score;
                        bestSplits.clear();
                    }
                    bestSplits.add(candSplit);
                }
            }
            // Move this threshold's group from left to right before trying the next threshold.
            left.removeAll(instsForThresh);
            right.addAll(instsForThresh);
            for (Instance inst : instsForThresh) {
                double label = inst.getLabel().getNumericValue();
                double w = inst.getWeight();
                sl -= w * label;
                sr += w * label;
                nl -= w;
                nr += w;
            }
        }
        return middleOf(bestSplits);
    }

    /**
     * Returns the weighted Gini impurity {@code 1 - sum((w_c / W)^2)} of the labels in
     * {@code insts}, or 0 when the total weight is 0 (avoids a 0/0 NaN).
     */
    public static double scoreNominal(Instances insts) {
        HashMap<Value, Double> counts = new HashMap<Value, Double>();
        double total = 0.0;
        for (Instance inst : insts) {
            double w = inst.getWeight();
            counts.merge(inst.getLabel(), w, Double::sum);
            total += w;
        }
        if (total == 0.0) {
            return 0.0;
        }
        double res = total * total;
        for (double cnt : counts.values()) {
            res -= cnt * cnt;
        }
        return res / (total * total);
    }

    /**
     * Returns the weighted variance {@code E[v^2] - E[v]^2} of the numeric labels in {@code insts},
     * or 0 when the total weight is 0 (avoids a 0/0 NaN).
     */
    public static double scoreNumeric(Instances insts) {
        double s1 = 0.0;
        double s2 = 0.0;
        double s0 = 0.0;
        for (Instance inst : insts) {
            double v = inst.getLabel().getNumericValue();
            double w = inst.getWeight();
            s0 += w;
            s1 += w * v;
            s2 += w * v * v;
        }
        if (s0 == 0.0) {
            return 0.0;
        }
        return s2 / s0 - s1 / s0 * (s1 / s0);
    }

    /** Returns true when at least one registered cancel criterion rejects {@code candidate}. */
    private static boolean isCancelled(NodeSplit candidate) {
        for (CancelCriterion cancel : cancels) {
            if (cancel.cancelCriterion(candidate)) {
                return true;
            }
        }
        return false;
    }

    /** Returns the per-class count array for {@code classVal}, creating a zeroed one if absent. */
    private static double[] classCounts(Map<Value, double[]> counts, Value classVal) {
        return counts.computeIfAbsent(classVal, k -> new double[4]);
    }

    /** Divides {@code score} by the feature's complexity when {@code complexityUse} is set. */
    private static double penalize(double score, SingleValueFeature feature, boolean complexityUse) {
        return complexityUse ? score / feature.complexity() : score;
    }

    /** Returns {@code sum / weight}, or 0 when {@code weight} is exactly 0. */
    private static double mean(double sum, double weight) {
        return weight == 0.0 ? 0.0 : sum / weight;
    }

    /**
     * Weighted between-group variance of three groups (left, right, missing) given their weights and
     * means, expressed as the pairwise form {@code sum n_i n_j (m_i - m_j)^2 / nt^2}.
     */
    private static double betweenGroupVariance(double nl, double ml, double nr, double mr, double nn, double mn, double nt) {
        return (nl * nr * (ml - mr) * (ml - mr) + nl * nn * (ml - mn) * (ml - mn) + nn * nr * (mn - mr) * (mn - mr)) / (nt * nt);
    }

    /**
     * Builds the candidate split sending every group whose value is in {@code leftValues} left and
     * all other groups right. Missing-value instances are passed through uncopied.
     */
    private static NodeSplit nominalCandidate(SingleValueFeature feature, Map<Value, Instances> groups, Instances noValue, HashSet<Value> leftValues, double score) {
        Instances left = new Instances();
        Instances right = new Instances();
        for (Map.Entry<Value, Instances> group : groups.entrySet()) {
            if (leftValues.contains(group.getKey())) {
                left.addAll(group.getValue());
            } else {
                right.addAll(group.getValue());
            }
        }
        return new NodeSplit(left, right, noValue, new FilterSet(feature.clone(), leftValues), score);
    }

    /** Returns the middle element of equally-best splits, or {@code null} if there are none. */
    private static NodeSplit middleOf(ArrayList<NodeSplit> bestSplits) {
        return bestSplits.isEmpty() ? null : bestSplits.get(bestSplits.size() / 2);
    }
}

