/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.apache.calcite.plan.RelOptCluster;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.plan.hep.HepRelVertex;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.Project;
import org.apache.calcite.rel.core.RelFactories;
import org.apache.calcite.rel.rules.MultiJoin;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rel.type.RelDataTypeFieldImpl;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.rex.RexUtil;
import org.apache.calcite.tools.RelBuilder;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.calcite.util.ImmutableIntList;
import org.apache.calcite.util.mapping.Mappings;
import org.apache.flink.calcite.shaded.com.google.common.collect.ImmutableMap;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableProjectMultiJoinTransposeRule;
import org.immutables.value.Value;

/**
 * Planner rule that prunes the inputs of a {@link MultiJoin} down to the fields that are actually
 * referenced, driven by a {@link Project} sitting on top of it.
 *
 * <p>Fields referenced by the top project, the join filter, the post-join filter and the outer
 * join conditions are collected. Every MultiJoin input with at least one unreferenced field gets a
 * new {@link Project} that keeps only the referenced ones. The MultiJoin's row type, conditions,
 * projected-field bitmaps and join field ref counts, as well as the top project, are then rewritten
 * against the narrower inputs.
 */
@Value.Enclosing
public class ProjectMultiJoinTransposeRule
extends RelRule<ProjectMultiJoinTransposeRule.ProjectMultiJoinTransposeRuleConfig> {
    public static final ProjectMultiJoinTransposeRule INSTANCE = ProjectMultiJoinTransposeRuleConfig.DEFAULT.toRule();

    public ProjectMultiJoinTransposeRule(ProjectMultiJoinTransposeRuleConfig config) {
        super(config);
    }

    /**
     * Accepts the match only when the top project is not an identity over the MultiJoin and no
     * MultiJoin input is already a {@link Project}. The second condition keeps the rule from
     * firing again on its own output, whose pruned inputs are projects.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        final Project originalProject = call.rel(0);
        final MultiJoin multiJoin = call.rel(1);
        if (RexUtil.isIdentity(originalProject.getProjects(), multiJoin.getRowType())) {
            return false;
        }
        for (RelNode input : multiJoin.getInputs()) {
            if (isProject(input)) {
                return false;
            }
        }
        return true;
    }

    @Override
    public void onMatch(RelOptRuleCall call) {
        final Project originalProject = call.rel(0);
        final MultiJoin multiJoin = call.rel(1);
        final RelBuilder relBuilder = call.builder();

        final ImmutableBitSet referencedFields = collectReferencedFields(originalProject, multiJoin);
        final TransformedInputs transformedInputs =
                createTransformedInputs(multiJoin, referencedFields, relBuilder);
        if (!transformedInputs.anyInputPruned()) {
            // Every field of every input is referenced: the rewrite would only reproduce an
            // equivalent plan, and since that plan has no Project inputs, matches() would accept
            // it again. Bail out instead of emitting a no-op transformation.
            return;
        }

        final Mappings.TargetMapping fieldMapping = createFieldMapping(multiJoin, transformedInputs);
        final MultiJoin newMultiJoin =
                createMultiJoinWithAdjustedParams(multiJoin, transformedInputs, fieldMapping);
        final List<RexNode> newProjects = RexUtil.apply(fieldMapping, originalProject.getProjects());

        relBuilder.push(newMultiJoin);
        relBuilder.project(newProjects, originalProject.getRowType().getFieldNames());
        call.transformTo(relBuilder.build());
    }

    /**
     * Collects the MultiJoin-level field indices referenced by the project, the join filter, the
     * (optional) post-join filter and the outer join conditions.
     */
    private ImmutableBitSet collectReferencedFields(Project project, MultiJoin multiJoin) {
        final ImmutableBitSet.Builder referencedFieldsBuilder = ImmutableBitSet.builder();
        // Used purely as a visitor: every input ref is recorded and returned unchanged.
        final RexShuttle fieldCollector = new RexShuttle() {
            @Override
            public RexNode visitInputRef(RexInputRef inputRef) {
                referencedFieldsBuilder.set(inputRef.getIndex());
                return inputRef;
            }
        };
        fieldCollector.apply(project.getProjects());
        fieldCollector.apply(multiJoin.getJoinFilter());
        if (multiJoin.getPostJoinFilter() != null) {
            fieldCollector.apply(multiJoin.getPostJoinFilter());
        }
        for (RexNode outerJoinCondition : multiJoin.getOuterJoinConditions()) {
            fieldCollector.apply(outerJoinCondition);
        }
        return referencedFieldsBuilder.build();
    }

    /**
     * Builds the new MultiJoin inputs. An input whose fields are all referenced is kept as is
     * (recorded with a {@code null} kept-field set). Any other input is wrapped in a projection of
     * its referenced fields; that input's kept-field set holds the retained field indices, relative
     * to the ORIGINAL input.
     */
    private TransformedInputs createTransformedInputs(
            MultiJoin multiJoin, ImmutableBitSet referencedFields, RelBuilder relBuilder) {
        final List<RelNode> inputs = multiJoin.getInputs();
        final List<RelNode> newInputs = new ArrayList<>(inputs.size());
        final List<ImmutableBitSet> keptFields = new ArrayList<>(inputs.size());
        final List<Integer> newFieldCounts = new ArrayList<>(inputs.size());
        int fieldOffset = 0;
        for (RelNode input : inputs) {
            final int inputFieldCount = input.getRowType().getFieldCount();
            final ImmutableBitSet requiredFields =
                    extractRequiredFieldsForInput(referencedFields, fieldOffset, inputFieldCount);
            if (requiredFields.cardinality() == inputFieldCount) {
                newInputs.add(input);
                keptFields.add(null);
                newFieldCounts.add(inputFieldCount);
            } else {
                newInputs.add(createProjectionForInput(input, requiredFields, relBuilder));
                keptFields.add(requiredFields);
                newFieldCounts.add(requiredFields.cardinality());
            }
            fieldOffset += inputFieldCount;
        }
        return new TransformedInputs(newInputs, keptFields, newFieldCounts);
    }

    /**
     * Returns the bits of {@code referencedFields} that fall in
     * {@code [fieldOffset, fieldOffset + inputFieldCount)}, shifted down to be input-local.
     */
    private ImmutableBitSet extractRequiredFieldsForInput(
            ImmutableBitSet referencedFields, int fieldOffset, int inputFieldCount) {
        final ImmutableBitSet.Builder requiredFieldsBuilder = ImmutableBitSet.builder();
        for (int bit : referencedFields) {
            if (bit >= fieldOffset && bit < fieldOffset + inputFieldCount) {
                requiredFieldsBuilder.set(bit - fieldOffset);
            }
        }
        return requiredFieldsBuilder.build();
    }

    /**
     * Creates a projection of {@code input} that keeps exactly {@code requiredFields}, in
     * ascending index order, preserving the original field names.
     */
    private RelNode createProjectionForInput(
            RelNode input, ImmutableBitSet requiredFields, RelBuilder relBuilder) {
        final List<RexNode> newProjects = new ArrayList<>(requiredFields.cardinality());
        final List<String> newNames = new ArrayList<>(requiredFields.cardinality());
        final List<RelDataTypeField> inputFields = input.getRowType().getFieldList();
        relBuilder.push(input);
        for (int fieldIndex : requiredFields) {
            newProjects.add(relBuilder.field(fieldIndex));
            newNames.add(inputFields.get(fieldIndex).getName());
        }
        return relBuilder.project(newProjects, newNames).build();
    }

    /**
     * Maps each surviving field of the original MultiJoin row type to its position in the row
     * type of the new (pruned) MultiJoin. Pruned fields are left unmapped.
     */
    private Mappings.TargetMapping createFieldMapping(
            MultiJoin multiJoin, TransformedInputs transformedInputs) {
        final List<RelNode> oldInputs = multiJoin.getInputs();
        final Map<Integer, Integer> oldToNewMapping = new HashMap<>();
        int oldFieldOffset = 0;
        int newFieldOffset = 0;
        for (int inputIndex = 0; inputIndex < oldInputs.size(); ++inputIndex) {
            final int oldInputFieldCount = oldInputs.get(inputIndex).getRowType().getFieldCount();
            final ImmutableBitSet kept = transformedInputs.keptFields.get(inputIndex);
            // Kept fields retain their relative order, so a running counter yields each field's
            // position inside the pruned input (equivalent to kept.indexOf(field), without the
            // per-field scan).
            int keptOrdinal = 0;
            for (int fieldIndex = 0; fieldIndex < oldInputFieldCount; ++fieldIndex) {
                if (kept == null || kept.get(fieldIndex)) {
                    oldToNewMapping.put(oldFieldOffset + fieldIndex, newFieldOffset + keptOrdinal);
                    ++keptOrdinal;
                }
            }
            oldFieldOffset += oldInputFieldCount;
            newFieldOffset += transformedInputs.newFieldCounts.get(inputIndex);
        }
        // newFieldOffset now equals the total field count of the new inputs.
        return Mappings.target(oldToNewMapping, multiJoin.getRowType().getFieldCount(), newFieldOffset);
    }

    /** Rebuilds the MultiJoin over the pruned inputs, remapping every field reference it holds. */
    private MultiJoin createMultiJoinWithAdjustedParams(
            MultiJoin originalMultiJoin,
            TransformedInputs transformedInputs,
            Mappings.TargetMapping fieldMapping) {
        final RelOptCluster cluster = originalMultiJoin.getCluster();
        final RelDataType newRowType = buildNewRowType(originalMultiJoin, fieldMapping);
        final RexNode newJoinFilter = applyMappingToRexNode(originalMultiJoin.getJoinFilter(), fieldMapping);
        final RexNode newPostJoinFilter =
                applyMappingToRexNode(originalMultiJoin.getPostJoinFilter(), fieldMapping);
        final List<RexNode> newOuterJoinConditions =
                applyMappingToOuterJoinConditions(originalMultiJoin.getOuterJoinConditions(), fieldMapping);
        final Map<Integer, ImmutableIntList> newJoinFieldRefCountsMap =
                createNewJoinFieldRefCountsMap(originalMultiJoin, transformedInputs, fieldMapping);
        return new MultiJoin(
                cluster,
                originalMultiJoin.getHints(),
                transformedInputs.newInputs,
                newJoinFilter,
                newRowType,
                originalMultiJoin.isFullOuterJoin(),
                newOuterJoinConditions,
                originalMultiJoin.getJoinTypes(),
                buildProjFields(transformedInputs),
                ImmutableMap.copyOf(newJoinFieldRefCountsMap),
                newPostJoinFilter);
    }

    /**
     * Builds the per-input projected-field bitmaps for the new MultiJoin. Per Calcite's MultiJoin
     * contract, these bitmaps index into the fields of the corresponding input. A pruned input
     * exposes exactly its kept fields, so its bitmap is {@code {0 .. k-1}}; the old-input-relative
     * kept set would reference columns that no longer exist. Unpruned inputs get {@code null}
     * ("all fields").
     */
    private List<ImmutableBitSet> buildProjFields(TransformedInputs transformedInputs) {
        final List<ImmutableBitSet> projFields = new ArrayList<>(transformedInputs.keptFields.size());
        for (int inputIndex = 0; inputIndex < transformedInputs.keptFields.size(); ++inputIndex) {
            if (transformedInputs.keptFields.get(inputIndex) == null) {
                projFields.add(null);
            } else {
                projFields.add(ImmutableBitSet.range(transformedInputs.newFieldCounts.get(inputIndex)));
            }
        }
        return projFields;
    }

    /**
     * Builds the new MultiJoin row type from the surviving fields of the original row type,
     * keeping their names and types (including any nullability the original row type carries).
     */
    private RelDataType buildNewRowType(MultiJoin originalMultiJoin, Mappings.TargetMapping fieldMapping) {
        final RelDataTypeFactory typeFactory = originalMultiJoin.getCluster().getTypeFactory();
        final List<RelNode> originalInputs = originalMultiJoin.getInputs();
        final List<RelDataTypeField> originalMultiJoinFields = originalMultiJoin.getRowType().getFieldList();
        // Walk every field contributed by the inputs, in order.
        final int totalInputFieldCount = calculateFieldOffset(originalInputs.size(), originalInputs);
        final List<RelDataTypeFieldImpl> newFields = new ArrayList<>();
        for (int globalFieldId = 0; globalFieldId < totalInputFieldCount; ++globalFieldId) {
            if (fieldMapping.getTargetOpt(globalFieldId) == -1) {
                continue;
            }
            final RelDataTypeField originalField = originalMultiJoinFields.get(globalFieldId);
            newFields.add(new RelDataTypeFieldImpl(
                    originalField.getName(), newFields.size(), originalField.getType()));
        }
        return typeFactory.createStructType(newFields);
    }

    /** Applies {@code fieldMapping} to {@code rexNode}; {@code null} is passed through. */
    private RexNode applyMappingToRexNode(RexNode rexNode, Mappings.TargetMapping fieldMapping) {
        return rexNode != null ? RexUtil.apply(fieldMapping, rexNode) : null;
    }

    /** Remaps each outer join condition; {@code null} entries stay {@code null}. */
    private List<RexNode> applyMappingToOuterJoinConditions(
            List<RexNode> outerJoinConditions, Mappings.TargetMapping fieldMapping) {
        final List<RexNode> newOuterJoinConditions = new ArrayList<>(outerJoinConditions.size());
        for (RexNode condition : outerJoinConditions) {
            newOuterJoinConditions.add(applyMappingToRexNode(condition, fieldMapping));
        }
        return newOuterJoinConditions;
    }

    /**
     * Re-indexes each input's join field reference counts onto the fields of its pruned input.
     * Counts of pruned fields are dropped; surviving fields keep their original count.
     */
    private Map<Integer, ImmutableIntList> createNewJoinFieldRefCountsMap(
            MultiJoin originalMultiJoin,
            TransformedInputs transformedInputs,
            Mappings.TargetMapping fieldMapping) {
        final Map<Integer, ImmutableIntList> newJoinFieldRefCountsMap = new HashMap<>();
        final List<RelNode> originalInputs = originalMultiJoin.getInputs();
        for (Map.Entry<Integer, ImmutableIntList> entry
                : originalMultiJoin.getJoinFieldRefCountsMap().entrySet()) {
            final int inputIndex = entry.getKey();
            final ImmutableIntList originalRefCounts = entry.getValue();
            final int newFieldCount =
                    transformedInputs.newInputs.get(inputIndex).getRowType().getFieldCount();
            // Offsets depend only on the input, not on the field: compute them once per input.
            final int oldInputOffset = calculateFieldOffset(inputIndex, originalInputs);
            final int newInputOffset = calculateFieldOffset(inputIndex, transformedInputs.newInputs);
            final int[] newRefCounts = new int[newFieldCount];
            for (int originalFieldIndex = 0; originalFieldIndex < originalRefCounts.size(); ++originalFieldIndex) {
                final int newGlobalFieldIndex = fieldMapping.getTargetOpt(oldInputOffset + originalFieldIndex);
                if (newGlobalFieldIndex == -1) {
                    continue;
                }
                final int newLocalFieldIndex = newGlobalFieldIndex - newInputOffset;
                if (newLocalFieldIndex < 0 || newLocalFieldIndex >= newFieldCount) {
                    continue;
                }
                newRefCounts[newLocalFieldIndex] = originalRefCounts.get(originalFieldIndex);
            }
            newJoinFieldRefCountsMap.put(inputIndex, ImmutableIntList.of(newRefCounts));
        }
        return newJoinFieldRefCountsMap;
    }

    /** Whether {@code relNode} is a {@link Project}, looking through a {@link HepRelVertex}. */
    private boolean isProject(RelNode relNode) {
        if (relNode instanceof Project) {
            return true;
        }
        if (relNode instanceof HepRelVertex) {
            return ((HepRelVertex) relNode).getCurrentRel() instanceof Project;
        }
        return false;
    }

    /** Sums the field counts of {@code inputs[0 .. inputIndex)}. */
    private int calculateFieldOffset(int inputIndex, List<RelNode> inputs) {
        int offset = 0;
        for (int i = 0; i < inputIndex; ++i) {
            offset += inputs.get(i).getRowType().getFieldCount();
        }
        return offset;
    }

    /** Rule configuration. */
    @Value.Immutable
    public static interface ProjectMultiJoinTransposeRuleConfig
    extends RelRule.Config {
        public static final ProjectMultiJoinTransposeRuleConfig DEFAULT = ImmutableProjectMultiJoinTransposeRule.ProjectMultiJoinTransposeRuleConfig.builder().build().withRelBuilderFactory(RelFactories.LOGICAL_BUILDER).as(ProjectMultiJoinTransposeRuleConfig.class).withOperandFor(Project.class, MultiJoin.class);

        @Override
        default public ProjectMultiJoinTransposeRule toRule() {
            return new ProjectMultiJoinTransposeRule(this);
        }

        /** Configures the operands: a project whose single input is a MultiJoin. */
        default public ProjectMultiJoinTransposeRuleConfig withOperandFor(Class<? extends Project> projectClass, Class<? extends MultiJoin> multiJoinClass) {
            return this.withOperandSupplier(b0 -> b0.operand(projectClass).oneInput(b1 -> b1.operand(multiJoinClass).anyInputs())).as(ProjectMultiJoinTransposeRuleConfig.class);
        }
    }

    /** Result of pruning the MultiJoin inputs. All lists are indexed by MultiJoin input. */
    private static final class TransformedInputs {
        /** The (possibly projected) new inputs. */
        final List<RelNode> newInputs;
        /**
         * Kept field indices relative to the ORIGINAL input, or {@code null} when the input was
         * left untouched.
         */
        final List<ImmutableBitSet> keptFields;
        /** Field count of each new input. */
        final List<Integer> newFieldCounts;

        TransformedInputs(List<RelNode> newInputs, List<ImmutableBitSet> keptFields, List<Integer> newFieldCounts) {
            this.newInputs = newInputs;
            this.keptFields = keptFields;
            this.newFieldCounts = newFieldCounts;
        }

        /** Whether at least one input was replaced by a pruning projection. */
        boolean anyInputPruned() {
            return keptFields.stream().anyMatch(kept -> kept != null);
        }
    }
}

