/*
 * Decompiled with CFR 0.152.
 */
package org.apache.calcite.rel.rules;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import javax.annotation.Nonnull;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptRuleOperand;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Aggregate;
import org.apache.calcite.rel.core.AggregateCall;
import org.apache.calcite.rel.logical.LogicalAggregate;
import org.apache.calcite.rel.rules.TransformationRule;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexLiteral;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.sql.SqlAggFunction;
import org.apache.calcite.sql.SqlKind;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
import org.apache.calcite.sql.type.SqlTypeUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.tools.RelBuilderFactory;
import org.apache.calcite.util.CompositeList;
import org.apache.calcite.util.ImmutableBeans;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.Util;

/**
 * Planner rule that rewrites selected aggregate functions of an
 * {@link Aggregate} as arithmetic over simpler aggregate calls.
 *
 * <p>The rewrites performed by this class are:
 *
 * <ul>
 * <li>{@code SUM(x)} &rarr; {@code SUM0(x)} when the call is not nullable,
 *     otherwise {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END};
 * <li>{@code AVG(x)} &rarr; {@code SUM(x) / COUNT(x)};
 * <li>{@code STDDEV_POP}, {@code STDDEV_SAMP}, {@code VAR_POP},
 *     {@code VAR_SAMP} &rarr;
 *     {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / denominator},
 *     raised to the power 0.5 for the STDDEV variants;
 * <li>{@code COVAR_POP}, {@code COVAR_SAMP} &rarr;
 *     {@code (SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x, y)) / denominator};
 * <li>{@code REGR_SXX}, {@code REGR_SYY} &rarr;
 *     {@code SUM(z * z) - SUM(z) * SUM(z) / REGR_COUNT(z)} restricted to rows
 *     where both arguments are not null.
 * </ul>
 *
 * <p>Every other aggregate call is carried over unchanged. The result is a
 * new aggregate (optionally on top of a projection that computes extra
 * input expressions such as {@code x * x}) with a projection on top that
 * restores the original row type.
 *
 * <p>Which functions are reduced is controlled by
 * {@link Config#functionsToReduce()}.
 */
public class AggregateReduceFunctionsRule
    extends RelRule<AggregateReduceFunctionsRule.Config>
    implements TransformationRule {
    /** Kinds of aggregate function that this rule instance rewrites. */
    private final Set<SqlKind> functionsToReduce;

    /**
     * Checks that a function kind can be handled by this rule.
     *
     * @param function kind to check
     * @throws IllegalArgumentException if the kind is not supported
     */
    private static void validateFunction(SqlKind function) {
        if (!isValid(function)) {
            throw new IllegalArgumentException("AggregateReduceFunctionsRule doesn't support function: " + function.sql);
        }
    }

    /**
     * Returns whether a function kind is one this rule knows how to reduce:
     * a member of {@link SqlKind#AVG_AGG_FUNCTIONS} or
     * {@link SqlKind#COVAR_AVG_AGG_FUNCTIONS}, or {@link SqlKind#SUM}.
     */
    private static boolean isValid(SqlKind function) {
        return SqlKind.AVG_AGG_FUNCTIONS.contains(function)
            || SqlKind.COVAR_AVG_AGG_FUNCTIONS.contains(function)
            || function == SqlKind.SUM;
    }

    /**
     * Creates a rule from its configuration.
     *
     * <p>The set of functions to reduce is validated and copied, so later
     * changes to the caller's set do not affect the rule.
     *
     * @param config rule configuration
     * @throws IllegalArgumentException if the configuration names an
     *     unsupported function
     */
    protected AggregateReduceFunctionsRule(Config config) {
        super(config);
        this.functionsToReduce = ImmutableSet.copyOf(config.actualFunctionsToReduce());
    }

    /**
     * Creates a rule that matches exactly the given operand and reduces the
     * default set of functions.
     *
     * @deprecated Build the rule from a {@link Config} instead.
     */
    @Deprecated
    public AggregateReduceFunctionsRule(RelOptRuleOperand operand, RelBuilderFactory relBuilderFactory) {
        this(Config.DEFAULT
            .withRelBuilderFactory(relBuilderFactory)
            .withOperandSupplier(b -> b.exactly(operand))
            .as(Config.class)
            // null means "use DEFAULT_FUNCTIONS_TO_REDUCE"
            .withFunctionsToReduce(null));
    }

    /**
     * Creates a rule that matches the given aggregate class and reduces the
     * given functions.
     *
     * @param functionsToReduce functions to reduce; must not be null
     * @throws NullPointerException if {@code functionsToReduce} is null
     * @deprecated Build the rule from a {@link Config} instead.
     */
    @Deprecated
    public AggregateReduceFunctionsRule(Class<? extends Aggregate> aggregateClass, RelBuilderFactory relBuilderFactory, EnumSet<SqlKind> functionsToReduce) {
        this(Config.DEFAULT
            .withRelBuilderFactory(relBuilderFactory)
            .as(Config.class)
            .withOperandFor(aggregateClass)
            .withFunctionsToReduce(Objects.requireNonNull(functionsToReduce)));
    }

    /**
     * Matches only when the aggregate contains at least one call whose
     * function is in the configured set.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        if (!super.matches(call)) {
            return false;
        }
        Aggregate oldAggRel = (Aggregate) call.rels[0];
        return containsAvgStddevVarCall(oldAggRel.getAggCallList());
    }

    @Override
    public void onMatch(RelOptRuleCall ruleCall) {
        Aggregate oldAggRel = (Aggregate) ruleCall.rels[0];
        reduceAggs(ruleCall, oldAggRel);
    }

    /**
     * Returns whether any call in the list uses a function that this rule
     * instance is configured to reduce.
     */
    private boolean containsAvgStddevVarCall(List<AggregateCall> aggCallList) {
        for (AggregateCall call : aggCallList) {
            if (isReducible(call.getAggregation().getKind())) {
                return true;
            }
        }
        return false;
    }

    /** Returns whether the kind is in this rule's configured set. */
    private boolean isReducible(SqlKind kind) {
        return functionsToReduce.contains(kind);
    }

    /**
     * Rewrites an aggregate and registers the result with the rule call.
     *
     * <p>The replacement is built bottom-up: the original input, an optional
     * projection adding any extra input expressions the reductions need, the
     * new aggregate, and a final projection that reproduces the original
     * row type.
     *
     * @param ruleCall rule call being processed
     * @param oldAggRel aggregate to rewrite
     */
    private void reduceAggs(RelOptRuleCall ruleCall, Aggregate oldAggRel) {
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        List<AggregateCall> oldCalls = oldAggRel.getAggCallList();
        int groupCount = oldAggRel.getGroupCount();

        List<AggregateCall> newCalls = new ArrayList<>();
        Map<AggregateCall, RexNode> aggCallMapping = new HashMap<>();
        List<RexNode> projList = new ArrayList<>();

        // Group columns pass straight through the top projection.
        for (int i = 0; i < groupCount; ++i) {
            projList.add(rexBuilder.makeInputRef(getFieldType(oldAggRel, i), i));
        }

        RelBuilder relBuilder = ruleCall.builder();
        relBuilder.push(oldAggRel.getInput());
        // Starts as the input's own fields; reductions may append expressions.
        List<RexNode> inputExprs = new ArrayList<>(relBuilder.fields());

        for (AggregateCall oldCall : oldCalls) {
            projList.add(reduceAgg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs));
        }

        // Project the appended expressions below the aggregate; the extra
        // columns get null names so the builder assigns them.
        int extraArgCount = inputExprs.size() - relBuilder.peek().getRowType().getFieldCount();
        if (extraArgCount > 0) {
            relBuilder.project(inputExprs,
                CompositeList.of(
                    relBuilder.peek().getRowType().getFieldNames(),
                    Collections.<String>nCopies(extraArgCount, null)));
        }
        newAggregateRel(relBuilder, oldAggRel, newCalls);
        newCalcRel(relBuilder, oldAggRel.getRowType(), projList);
        ruleCall.transformTo(relBuilder.build());
    }

    /**
     * Produces the expression that replaces one aggregate call in the top
     * projection, adding whatever aggregate calls it needs to
     * {@code newCalls}.
     *
     * @param oldAggRel aggregate being rewritten
     * @param oldCall call to translate
     * @param newCalls calls of the new aggregate; appended to
     * @param aggCallMapping mapping from new calls to their references
     * @param inputExprs expressions feeding the new aggregate; may be
     *     appended to
     * @return expression over the new aggregate's output
     */
    private RexNode reduceAgg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        SqlKind kind = oldCall.getAggregation().getKind();
        if (isReducible(kind)) {
            switch (kind) {
                case SUM: {
                    return reduceSum(oldAggRel, oldCall, newCalls, aggCallMapping);
                }
                case AVG: {
                    return reduceAvg(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs);
                }
                case COVAR_POP: {
                    return reduceCovariance(oldAggRel, oldCall, true, newCalls, aggCallMapping, inputExprs);
                }
                case COVAR_SAMP: {
                    return reduceCovariance(oldAggRel, oldCall, false, newCalls, aggCallMapping, inputExprs);
                }
                case REGR_SXX: {
                    assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
                    Integer x = oldCall.getArgList().get(0);
                    Integer y = oldCall.getArgList().get(1);
                    // Sums are over the second argument; the first only
                    // contributes to the NOT NULL filter.
                    return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs, y, y, x);
                }
                case REGR_SYY: {
                    assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
                    Integer x = oldCall.getArgList().get(0);
                    Integer y = oldCall.getArgList().get(1);
                    // Sums are over the first argument; the second only
                    // contributes to the NOT NULL filter.
                    return reduceRegrSzz(oldAggRel, oldCall, newCalls, aggCallMapping, inputExprs, x, x, y);
                }
                case STDDEV_POP: {
                    return reduceStddev(oldAggRel, oldCall, true, true, newCalls, aggCallMapping, inputExprs);
                }
                case STDDEV_SAMP: {
                    return reduceStddev(oldAggRel, oldCall, false, true, newCalls, aggCallMapping, inputExprs);
                }
                case VAR_POP: {
                    return reduceStddev(oldAggRel, oldCall, true, false, newCalls, aggCallMapping, inputExprs);
                }
                case VAR_SAMP: {
                    return reduceStddev(oldAggRel, oldCall, false, false, newCalls, aggCallMapping, inputExprs);
                }
                default: {
                    throw Util.unexpected(kind);
                }
            }
        }
        // Not reducible: carry the call over unchanged.
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int nGroups = oldAggRel.getGroupCount();
        List<RelDataType> oldArgTypes = SqlTypeUtil.projectTypes(oldAggRel.getInput().getRowType(), oldCall.getArgList());
        return rexBuilder.addAggCall(oldCall, nGroups, newCalls, aggCallMapping, oldArgTypes);
    }

    /**
     * Creates a single-argument aggregate call whose return type is inferred
     * from an explicit operand type rather than from the input relation.
     *
     * <p>This is needed when the argument is an expression that has been
     * appended to {@code inputExprs} and so is not yet a field of the
     * aggregate's input.
     *
     * @param typeFactory type factory
     * @param aggFunction aggregate function to call
     * @param operandType type of the single operand
     * @param oldAggRel aggregate being rewritten (supplies the group count)
     * @param oldCall call whose distinct/approximate/ignoreNulls/collation
     *     settings are copied
     * @param argOrdinal ordinal of the argument in the new input
     * @param filter ordinal of the filter column, or negative for none
     * @return the new, unnamed aggregate call
     */
    private AggregateCall createAggregateCallWithBinding(RelDataTypeFactory typeFactory, SqlAggFunction aggFunction, RelDataType operandType, Aggregate oldAggRel, AggregateCall oldCall, int argOrdinal, int filter) {
        Aggregate.AggCallBinding binding = new Aggregate.AggCallBinding(typeFactory, aggFunction, ImmutableList.of(operandType), oldAggRel.getGroupCount(), filter >= 0);
        return AggregateCall.create(aggFunction, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), ImmutableIntList.of(argOrdinal), filter, oldCall.collation, aggFunction.inferReturnType(binding), null);
    }

    /**
     * Rewrites {@code AVG(x)} as {@code SUM(x) / COUNT(x)}, cast to the type
     * of the original call.
     *
     * @param inputExprs unused by this reduction
     * @return expression over the new aggregate's output
     */
    private RexNode reduceAvg(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        int nGroups = oldAggRel.getGroupCount();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int iAvgInput = oldCall.getArgList().get(0);
        RelDataType avgInputType = getFieldType(oldAggRel.getInput(), iAvgInput);

        AggregateCall sumCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
        AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);

        RexNode numeratorRef = rexBuilder.addAggCall(sumCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(avgInputType));
        RexNode denominatorRef = rexBuilder.addAggCall(countCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(avgInputType));

        // Bring the numerator to the AVG result type (keeping the SUM's
        // nullability) before dividing.
        RelDataTypeFactory typeFactory = oldAggRel.getCluster().getTypeFactory();
        RelDataType avgType = typeFactory.createTypeWithNullability(oldCall.getType(), numeratorRef.getType().isNullable());
        numeratorRef = rexBuilder.ensureType(avgType, numeratorRef, true);
        RexNode divideRef = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, numeratorRef, denominatorRef);
        return rexBuilder.makeCast(oldCall.getType(), divideRef);
    }

    /**
     * Rewrites {@code SUM(x)} in terms of {@code SUM0(x)}.
     *
     * <p>If the original call's type is not nullable the result is just
     * {@code SUM0(x)}; otherwise it is
     * {@code CASE WHEN COUNT(x) = 0 THEN NULL ELSE SUM0(x) END}.
     *
     * @return expression over the new aggregate's output
     */
    private RexNode reduceSum(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping) {
        int nGroups = oldAggRel.getGroupCount();
        RexBuilder rexBuilder = oldAggRel.getCluster().getRexBuilder();
        int arg = oldCall.getArgList().get(0);
        RelDataType argType = getFieldType(oldAggRel.getInput(), arg);

        AggregateCall sumZeroCall = AggregateCall.create(SqlStdOperatorTable.SUM0, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, oldCall.name);
        AggregateCall countCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel, null, null);

        RexNode sumZeroRef = rexBuilder.addAggCall(sumZeroCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(argType));
        if (!oldCall.getType().isNullable()) {
            // No NULL result is possible, so the COUNT guard is not needed
            // and the COUNT call is never registered.
            return sumZeroRef;
        }
        RexNode countRef = rexBuilder.addAggCall(countCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(argType));
        return rexBuilder.makeCall(SqlStdOperatorTable.CASE,
            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countRef, rexBuilder.makeExactLiteral(BigDecimal.ZERO)),
            rexBuilder.makeNullLiteral(sumZeroRef.getType()),
            sumZeroRef);
    }

    /**
     * Rewrites a standard deviation or variance call as
     * {@code (SUM(x * x) - SUM(x) * SUM(x) / COUNT(x)) / denominator}.
     *
     * <p>All three aggregate calls carry the original call's filter so that
     * they are evaluated over the same rows.
     *
     * @param biased true for the population variants ({@code *_POP}),
     *     false for the sample variants ({@code *_SAMP})
     * @param sqrt true to raise the result to the power 0.5 (STDDEV),
     *     false to return it as is (VAR)
     * @param inputExprs expressions feeding the new aggregate;
     *     {@code x * x} is appended if not already present
     * @return expression over the new aggregate's output, cast to the
     *     original call's type
     */
    private RexNode reduceStddev(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, boolean sqrt, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        int nGroups = oldAggRel.getGroupCount();
        RelOptCluster cluster = oldAggRel.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        RelDataTypeFactory typeFactory = cluster.getTypeFactory();

        assert oldCall.getArgList().size() == 1 : oldCall.getArgList();
        int argOrdinal = oldCall.getArgList().get(0);
        RelDataType argOrdinalType = getFieldType(oldAggRel.getInput(), argOrdinal);
        RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), argOrdinalType.isNullable());

        // x * x, computed in the result type and added to the input.
        RexNode argRef = rexBuilder.ensureType(oldCallType, inputExprs.get(argOrdinal), true);
        RexNode argSquared = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argRef, argRef);
        int argSquaredOrdinal = lookupOrAdd(inputExprs, argSquared);

        // SUM(x * x). Uses the same filter as SUM(x) and COUNT(x) below;
        // an unfiltered sum of squares would mix different row sets.
        AggregateCall sumArgSquaredAggCall = createAggregateCallWithBinding(typeFactory, SqlStdOperatorTable.SUM, argSquared.getType(), oldAggRel, oldCall, argSquaredOrdinal, oldCall.filterArg);
        RexNode sumArgSquared = rexBuilder.addAggCall(sumArgSquaredAggCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(sumArgSquaredAggCall.getType()));

        // SUM(x) * SUM(x)
        AggregateCall sumArgAggCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), ImmutableIntList.of(argOrdinal), oldCall.filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
        RexNode sumArg = rexBuilder.addAggCall(sumArgAggCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(sumArgAggCall.getType()));
        RexNode sumArgCast = rexBuilder.ensureType(oldCallType, sumArg, true);
        RexNode sumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumArgCast, sumArgCast);

        // COUNT(x)
        AggregateCall countArgAggCall = AggregateCall.create(SqlStdOperatorTable.COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), oldCall.getArgList(), oldCall.filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel, null, null);
        RexNode countArg = rexBuilder.addAggCall(countArgAggCall, nGroups, newCalls, aggCallMapping, ImmutableList.of(argOrdinalType));

        RexNode div = divide(biased, rexBuilder, sumArgSquared, sumSquaredArg, countArg);
        RexNode result;
        if (sqrt) {
            RexNode half = rexBuilder.makeExactLiteral(new BigDecimal("0.5"));
            result = rexBuilder.makeCall(SqlStdOperatorTable.POWER, div, half);
        } else {
            result = div;
        }
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Registers {@code SUM(arg)} over a field of the aggregate's input, with
     * the given filter, and returns a reference to it.
     *
     * @param argOrdinal ordinal of the summed field in the input
     * @param filterArg ordinal of the filter column, or negative for none
     */
    private RexNode getSumAggregatedRexNode(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, RexBuilder rexBuilder, int argOrdinal, int filterArg) {
        AggregateCall aggregateCall = AggregateCall.create(SqlStdOperatorTable.SUM, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), ImmutableIntList.of(argOrdinal), filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel.getInput(), null, null);
        return rexBuilder.addAggCall(aggregateCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, ImmutableList.of(aggregateCall.getType()));
    }

    /**
     * Registers {@code SUM(arg)} where the argument's type is given
     * explicitly (see {@link #createAggregateCallWithBinding}), and returns a
     * reference to it.
     *
     * @param operandType type of the summed expression
     * @param argOrdinal ordinal of the summed expression in the new input
     * @param filter ordinal of the filter column, or negative for none
     */
    private RexNode getSumAggregatedRexNodeWithBinding(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, RelDataType operandType, int argOrdinal, int filter) {
        RelOptCluster cluster = oldAggRel.getCluster();
        AggregateCall sumCall = createAggregateCallWithBinding(cluster.getTypeFactory(), SqlStdOperatorTable.SUM, operandType, oldAggRel, oldCall, argOrdinal, filter);
        return cluster.getRexBuilder().addAggCall(sumCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, ImmutableList.of(sumCall.getType()));
    }

    /**
     * Registers a {@code REGR_COUNT} call over the given arguments, with the
     * given filter, and returns a reference to it.
     *
     * @param argOrdinals ordinals of the counted arguments
     * @param operandTypes types of those arguments, in the same order
     * @param filterArg ordinal of the filter column, or negative for none
     */
    private RexNode getRegrCountRexNode(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, ImmutableIntList argOrdinals, ImmutableList<RelDataType> operandTypes, int filterArg) {
        AggregateCall countArgAggCall = AggregateCall.create(SqlStdOperatorTable.REGR_COUNT, oldCall.isDistinct(), oldCall.isApproximate(), oldCall.ignoreNulls(), argOrdinals, filterArg, oldCall.collation, oldAggRel.getGroupCount(), oldAggRel, null, null);
        return oldAggRel.getCluster().getRexBuilder().addAggCall(countArgAggCall, oldAggRel.getGroupCount(), newCalls, aggCallMapping, operandTypes);
    }

    /**
     * Rewrites {@code REGR_SXX} / {@code REGR_SYY} as
     * {@code SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x)}, where every
     * aggregate is filtered to rows in which {@code x}, {@code y} and the
     * null-filter column are all not null. When the count is zero the
     * subtracted term is NULL.
     *
     * <p>NOTE(review): the generated filter replaces the call's filter;
     * {@code oldCall.filterArg} is not read here.
     *
     * @param xIndex ordinal of the first multiplied column
     * @param yIndex ordinal of the second multiplied column (callers in this
     *     class pass the same value as {@code xIndex})
     * @param nullFilterIndex ordinal of the column that only takes part in
     *     the NOT NULL filter
     * @return expression over the new aggregate's output, cast to the
     *     original call's type
     */
    private RexNode reduceRegrSzz(Aggregate oldAggRel, AggregateCall oldCall, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs, int xIndex, int yIndex, int nullFilterIndex) {
        RelOptCluster cluster = oldAggRel.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        RelDataTypeFactory typeFactory = cluster.getTypeFactory();

        RelDataType argXType = getFieldType(oldAggRel.getInput(), xIndex);
        RelDataType argYType = xIndex == yIndex ? argXType : getFieldType(oldAggRel.getInput(), yIndex);
        // Look up the null-filter column's own type (not y's) when it is a
        // different column.
        RelDataType nullFilterIndexType = nullFilterIndex == yIndex ? argYType : getFieldType(oldAggRel.getInput(), nullFilterIndex);
        RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), argXType.isNullable() || argYType.isNullable() || nullFilterIndexType.isNullable());

        RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(xIndex), true);
        RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(yIndex), true);
        RexNode argNullFilter = rexBuilder.ensureType(oldCallType, inputExprs.get(nullFilterIndex), true);

        // x * y, added to the input.
        RexNode argXArgY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
        int argSquaredOrdinal = lookupOrAdd(inputExprs, argXArgY);

        // Filter: x IS NOT NULL AND y IS NOT NULL AND nullFilter IS NOT NULL.
        RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND,
            rexBuilder.makeCall(SqlStdOperatorTable.AND,
                rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
                rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY)),
            rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argNullFilter));
        int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);

        // SUM(x * y)
        RexNode sumXY = getSumAggregatedRexNodeWithBinding(oldAggRel, oldCall, newCalls, aggCallMapping, argXArgY.getType(), argSquaredOrdinal, argXAndYNotNullFilterOrdinal);
        RexNode sumXYCast = rexBuilder.ensureType(oldCallType, sumXY, true);

        // SUM(x) * SUM(y); the second SUM is reused when x and y coincide.
        RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, xIndex, argXAndYNotNullFilterOrdinal);
        RexNode sumY = xIndex == yIndex ? sumX : getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, yIndex, argXAndYNotNullFilterOrdinal);
        RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);

        RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, ImmutableIntList.of(xIndex), ImmutableList.of(argXType), argXAndYNotNullFilterOrdinal);

        // Guard the division against a zero count.
        RexNode zero = rexBuilder.makeExactLiteral(BigDecimal.ZERO);
        RexNode nul = rexBuilder.makeNullLiteral(zero.getType());
        RexNode avgSumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.CASE,
            rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, zero),
            nul,
            rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg));
        RexNode avgSumXSumYCast = rexBuilder.ensureType(oldCallType, avgSumXSumY, true);

        RexNode result = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXYCast, avgSumXSumYCast);
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Rewrites {@code COVAR_POP} / {@code COVAR_SAMP} as
     * {@code (SUM(x * y) - SUM(x) * SUM(y) / REGR_COUNT(x, y)) / denominator},
     * with every aggregate filtered to rows where both {@code x} and
     * {@code y} are not null.
     *
     * <p>NOTE(review): the generated filter replaces the call's filter;
     * {@code oldCall.filterArg} is not read here.
     *
     * @param biased true for {@code COVAR_POP}, false for {@code COVAR_SAMP}
     * @param inputExprs expressions feeding the new aggregate; the filter
     *     and {@code x * y} are appended if not already present
     * @return expression over the new aggregate's output, cast to the
     *     original call's type
     */
    private RexNode reduceCovariance(Aggregate oldAggRel, AggregateCall oldCall, boolean biased, List<AggregateCall> newCalls, Map<AggregateCall, RexNode> aggCallMapping, List<RexNode> inputExprs) {
        RelOptCluster cluster = oldAggRel.getCluster();
        RexBuilder rexBuilder = cluster.getRexBuilder();
        RelDataTypeFactory typeFactory = cluster.getTypeFactory();

        assert oldCall.getArgList().size() == 2 : oldCall.getArgList();
        int argXOrdinal = oldCall.getArgList().get(0);
        int argYOrdinal = oldCall.getArgList().get(1);
        RelDataType argXOrdinalType = getFieldType(oldAggRel.getInput(), argXOrdinal);
        RelDataType argYOrdinalType = getFieldType(oldAggRel.getInput(), argYOrdinal);
        RelDataType oldCallType = typeFactory.createTypeWithNullability(oldCall.getType(), argXOrdinalType.isNullable() || argYOrdinalType.isNullable());

        RexNode argX = rexBuilder.ensureType(oldCallType, inputExprs.get(argXOrdinal), true);
        RexNode argY = rexBuilder.ensureType(oldCallType, inputExprs.get(argYOrdinal), true);

        // Filter: x IS NOT NULL AND y IS NOT NULL.
        RexNode argXAndYNotNullFilter = rexBuilder.makeCall(SqlStdOperatorTable.AND,
            rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argX),
            rexBuilder.makeCall(SqlStdOperatorTable.IS_NOT_NULL, argY));
        int argXAndYNotNullFilterOrdinal = lookupOrAdd(inputExprs, argXAndYNotNullFilter);

        // x * y, added to the input.
        RexNode argXY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, argX, argY);
        int argXYOrdinal = lookupOrAdd(inputExprs, argXY);

        RexNode sumXY = getSumAggregatedRexNodeWithBinding(oldAggRel, oldCall, newCalls, aggCallMapping, argXY.getType(), argXYOrdinal, argXAndYNotNullFilterOrdinal);
        RexNode sumX = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, argXOrdinal, argXAndYNotNullFilterOrdinal);
        RexNode sumY = getSumAggregatedRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, rexBuilder, argYOrdinal, argXAndYNotNullFilterOrdinal);
        RexNode sumXSumY = rexBuilder.makeCall(SqlStdOperatorTable.MULTIPLY, sumX, sumY);

        RexNode countArg = getRegrCountRexNode(oldAggRel, oldCall, newCalls, aggCallMapping, ImmutableIntList.of(argXOrdinal, argYOrdinal), ImmutableList.of(argXOrdinalType, argYOrdinalType), argXAndYNotNullFilterOrdinal);

        RexNode result = divide(biased, rexBuilder, sumXY, sumXSumY, countArg);
        return rexBuilder.makeCast(oldCall.getType(), result);
    }

    /**
     * Builds {@code (sumXY - sumXSumY / countArg) / denominator}.
     *
     * <p>The denominator is {@code countArg} when {@code biased}, otherwise
     * {@code CASE WHEN countArg = 1 THEN NULL ELSE countArg - 1 END}.
     *
     * @param biased whether to divide by the count rather than count - 1
     * @param rexBuilder builder used to create the expressions
     * @param sumXY reference to the sum of products
     * @param sumXSumY product of the two sums
     * @param countArg reference to the row count
     * @return the quotient expression
     */
    private RexNode divide(boolean biased, RexBuilder rexBuilder, RexNode sumXY, RexNode sumXSumY, RexNode countArg) {
        RexNode avgSumSquaredArg = rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, sumXSumY, countArg);
        RexNode diff = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, sumXY, avgSumSquaredArg);
        RexNode denominator;
        if (biased) {
            denominator = countArg;
        } else {
            RexNode one = rexBuilder.makeExactLiteral(BigDecimal.ONE);
            RexNode nul = rexBuilder.makeNullLiteral(countArg.getType());
            RexNode countMinusOne = rexBuilder.makeCall(SqlStdOperatorTable.MINUS, countArg, one);
            RexNode countEqOne = rexBuilder.makeCall(SqlStdOperatorTable.EQUALS, countArg, one);
            // A NULL denominator avoids dividing by zero for a single row.
            denominator = rexBuilder.makeCall(SqlStdOperatorTable.CASE, countEqOne, nul, countMinusOne);
        }
        return rexBuilder.makeCall(SqlStdOperatorTable.DIVIDE, diff, denominator);
    }

    /**
     * Returns the index of an element in a list, appending it first if the
     * list does not already contain it.
     *
     * @param list list to search and possibly extend
     * @param element element to find or add
     * @return index of the element in the list after the call
     */
    private static <T> int lookupOrAdd(List<T> list, T element) {
        int ordinal = list.indexOf(element);
        if (ordinal == -1) {
            ordinal = list.size();
            list.add(element);
        }
        return ordinal;
    }

    /**
     * Pushes the new aggregate onto the builder, using the group set and
     * group sets of the old aggregate. Subclasses may override.
     *
     * @param relBuilder builder whose top is the new aggregate's input
     * @param oldAggregate aggregate being replaced
     * @param newCalls calls of the new aggregate
     */
    protected void newAggregateRel(RelBuilder relBuilder, Aggregate oldAggregate, List<AggregateCall> newCalls) {
        relBuilder.aggregate(relBuilder.groupKey(oldAggregate.getGroupSet(), (Iterable<? extends ImmutableBitSet>)oldAggregate.getGroupSets()), newCalls);
    }

    /**
     * Pushes the top projection onto the builder, naming its fields after
     * the given row type. Subclasses may override.
     *
     * @param relBuilder builder whose top is the new aggregate
     * @param rowType row type supplying the output field names
     * @param exprs expressions to project
     */
    protected void newCalcRel(RelBuilder relBuilder, RelDataType rowType, List<RexNode> exprs) {
        relBuilder.project(exprs, rowType.getFieldNames());
    }

    /** Returns the type of the {@code i}-th field of a relation's row type. */
    private RelDataType getFieldType(RelNode relNode, int i) {
        RelDataTypeField inputField = relNode.getRowType().getFieldList().get(i);
        return inputField.getType();
    }

    /** Rule configuration. */
    public interface Config
        extends RelRule.Config {
        /** Default configuration: matches {@link LogicalAggregate}. */
        Config DEFAULT = EMPTY.as(Config.class).withOperandFor(LogicalAggregate.class);

        /** Functions reduced when {@link #functionsToReduce()} is null. */
        Set<SqlKind> DEFAULT_FUNCTIONS_TO_REDUCE = ImmutableSet.<SqlKind>builder()
            .addAll(SqlKind.AVG_AGG_FUNCTIONS)
            .addAll(SqlKind.COVAR_AVG_AGG_FUNCTIONS)
            .add(SqlKind.SUM)
            .build();

        @Override
        default AggregateReduceFunctionsRule toRule() {
            return new AggregateReduceFunctionsRule(this);
        }

        /**
         * The configured set of functions to reduce; may be null, in which
         * case {@link #DEFAULT_FUNCTIONS_TO_REDUCE} applies.
         */
        @ImmutableBeans.Property
        Set<SqlKind> functionsToReduce();

        /** Sets {@link #functionsToReduce()}. */
        Config withFunctionsToReduce(Set<SqlKind> var1);

        /**
         * Returns the set of functions that will actually be reduced: the
         * configured set, or the default set if none was configured.
         *
         * @throws IllegalArgumentException if the set contains a function
         *     the rule does not support
         */
        @Nonnull
        default Set<SqlKind> actualFunctionsToReduce() {
            Set<SqlKind> set = Util.first(functionsToReduce(), DEFAULT_FUNCTIONS_TO_REDUCE);
            set.forEach(AggregateReduceFunctionsRule::validateFunction);
            return set;
        }

        /** Returns a configuration that matches the given aggregate class with any inputs. */
        default Config withOperandFor(Class<? extends Aggregate> aggregateClass) {
            return withOperandSupplier(b -> b.operand(aggregateClass).anyInputs()).as(Config.class);
        }
    }
}

