/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.plan.rules.logical;

import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.calcite.plan.RelOptRuleCall;
import org.apache.calcite.plan.RelOptUtil;
import org.apache.calcite.plan.RelRule;
import org.apache.calcite.rel.RelNode;
import org.apache.calcite.rel.core.JoinRelType;
import org.apache.calcite.rel.type.RelDataTypeField;
import org.apache.calcite.rex.RexBuilder;
import org.apache.calcite.rex.RexCall;
import org.apache.calcite.rex.RexInputRef;
import org.apache.calcite.rex.RexNode;
import org.apache.calcite.rex.RexShuttle;
import org.apache.calcite.util.ImmutableBitSet;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.planner.plan.metadata.FlinkRelMetadataQuery;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalJoin;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalRel;
import org.apache.flink.table.planner.plan.nodes.logical.FlinkLogicalSnapshot;
import org.apache.flink.table.planner.plan.rules.common.CommonTemporalTableJoinRule;
import org.apache.flink.table.planner.plan.rules.logical.ImmutableTemporalJoinRewriteWithUniqueKeyRule;
import org.apache.flink.table.planner.plan.utils.TemporalJoinUtil;
import org.immutables.value.Value;
import scala.collection.JavaConverters;
import scala.collection.Seq;

/**
 * Planner rule that rewrites the initial temporal join condition of a temporal table join into a
 * condition that additionally carries the versioned table's primary key.
 *
 * <p>The rule matches a {@link FlinkLogicalJoin} whose right input is a {@link
 * FlinkLogicalSnapshot}. It only fires when all of the following hold:
 *
 * <ul>
 *   <li>the join type is {@link JoinRelType#INNER} or {@link JoinRelType#LEFT};
 *   <li>the snapshot passes the inherited {@code matches(snapshot)} temporal-join check;
 *   <li>the snapshot cannot be converted into a lookup join.
 * </ul>
 *
 * <p>The primary key is derived from the upsert keys of the snapshot's input. The shortest non-empty
 * upsert key is chosen, and every one of its columns must appear among the right-side join keys.
 * Otherwise a {@link ValidationException} is thrown.
 */
@Value.Enclosing
public class TemporalJoinRewriteWithUniqueKeyRule
        extends RelRule<
                TemporalJoinRewriteWithUniqueKeyRule.TemporalJoinRewriteWithUniqueKeyRuleConfig>
        implements CommonTemporalTableJoinRule {

    public static final TemporalJoinRewriteWithUniqueKeyRule INSTANCE =
            TemporalJoinRewriteWithUniqueKeyRuleConfig.DEFAULT.toRule();

    private TemporalJoinRewriteWithUniqueKeyRule(
            TemporalJoinRewriteWithUniqueKeyRuleConfig config) {
        super(config);
    }

    /**
     * Accepts supported join types over a temporal snapshot that cannot become a lookup join.
     *
     * <p>Operand positions follow {@link TemporalJoinRewriteWithUniqueKeyRuleConfig#DEFAULT}: 0 =
     * join, 2 = snapshot, 3 = snapshot input.
     */
    @Override
    public boolean matches(RelOptRuleCall call) {
        FlinkLogicalJoin join = call.rel(0);
        FlinkLogicalSnapshot snapshot = call.rel(2);
        FlinkLogicalRel snapshotInput = call.rel(3);
        // Cheapest predicate first; the snapshot checks only run for supported join types.
        return isSupportedJoinType(join.getJoinType())
                && matches(snapshot)
                && !canConvertToLookupJoin(snapshot, snapshotInput);
    }

    /**
     * Replaces every initial temporal join condition call inside the join condition with its
     * rewritten form. The rewritten form includes the versioned table's primary key. The method
     * then registers a new join over the same inputs.
     *
     * @throws ValidationException if no usable primary key exists or it is not covered by the join
     *     keys
     */
    @Override
    public void onMatch(RelOptRuleCall call) {
        final FlinkLogicalJoin join = call.rel(0);
        final FlinkLogicalRel leftInput = call.rel(1);
        final FlinkLogicalSnapshot snapshot = call.rel(2);
        final FlinkLogicalRel snapshotInput = call.rel(3);

        RexNode newJoinCondition =
                join.getCondition()
                        .accept(
                                new RexShuttle() {
                                    @Override
                                    public RexNode visitCall(RexCall rexCall) {
                                        if (rexCall.getOperator()
                                                .equals(
                                                        TemporalJoinUtil
                                                                .INITIAL_TEMPORAL_JOIN_CONDITION())) {
                                            // A matching call is replaced as a whole; its
                                            // operands are not visited any further.
                                            return rewriteInitialTemporalJoinCondition(
                                                    rexCall,
                                                    join,
                                                    leftInput,
                                                    snapshot,
                                                    snapshotInput);
                                        }
                                        return super.visitCall(rexCall);
                                    }
                                });

        FlinkLogicalJoin rewriteJoin =
                FlinkLogicalJoin.create(
                        leftInput,
                        snapshot,
                        newJoinCondition,
                        join.getHints(),
                        join.getJoinType());
        call.transformTo(rewriteJoin);
    }

    /** Returns whether this rule rewrites joins of the given type (only INNER and LEFT). */
    private static boolean isSupportedJoinType(JoinRelType joinType) {
        return joinType == JoinRelType.INNER || joinType == JoinRelType.LEFT;
    }

    /**
     * Builds the replacement for one initial temporal join condition call.
     *
     * <p>Operand layout read from {@code initialCondition}:
     *
     * <ul>
     *   <li>row-time form, as reported by {@code isInitialRowTimeTemporalTableJoin}: {@code
     *       [timeRef, rightTimeRef, leftKeysCall, rightKeysCall]};
     *   <li>otherwise, the processing-time form: {@code [timeRef, leftKeysCall, rightKeysCall]}.
     * </ul>
     *
     * <p>Each keys call is a {@link RexCall} whose operands are the join key expressions.
     *
     * @throws ValidationException if no usable primary key exists or it is not covered by the
     *     right-side join keys
     */
    private static RexNode rewriteInitialTemporalJoinCondition(
            RexCall initialCondition,
            FlinkLogicalJoin join,
            FlinkLogicalRel leftInput,
            FlinkLogicalSnapshot snapshot,
            FlinkLogicalRel snapshotInput) {
        boolean isRowTime = TemporalJoinUtil.isInitialRowTimeTemporalTableJoin(initialCondition);
        List<RexNode> operands = initialCondition.getOperands();
        // The row-time form has an extra time reference at position 1, shifting the keys by one.
        int leftKeyPos = isRowTime ? 2 : 1;
        RexNode snapshotTimeInputRef = operands.get(0);
        List<RexNode> leftJoinKey = ((RexCall) operands.get(leftKeyPos)).getOperands();
        List<RexNode> rightJoinKey = ((RexCall) operands.get(leftKeyPos + 1)).getOperands();

        RexBuilder rexBuilder = join.getCluster().getRexBuilder();
        List<RexNode> primaryKeyInputRefs =
                validateRightPrimaryKey(
                        join,
                        rightJoinKey,
                        extractPrimaryKeyInputRefs(leftInput, snapshot, snapshotInput, rexBuilder));

        if (isRowTime) {
            RexNode rightTimeInputRef = operands.get(1);
            return TemporalJoinUtil.makeRowTimeTemporalTableJoinConCall(
                    rexBuilder,
                    snapshotTimeInputRef,
                    rightTimeInputRef,
                    toScalaSeq(primaryKeyInputRefs),
                    toScalaSeq(leftJoinKey),
                    toScalaSeq(rightJoinKey));
        }
        return TemporalJoinUtil.makeProcTimeTemporalTableJoinConCall(
                rexBuilder,
                snapshotTimeInputRef,
                toScalaSeq(primaryKeyInputRefs),
                toScalaSeq(leftJoinKey),
                toScalaSeq(rightJoinKey));
    }

    /** Converts a Java list into the Scala {@link Seq} expected by {@link TemporalJoinUtil}. */
    private static Seq<RexNode> toScalaSeq(List<RexNode> list) {
        return JavaConverters.asScalaBufferConverter(list).asScala();
    }

    /**
     * Checks that a primary key was found for the versioned table and that every primary key
     * column is also one of the right-side join keys.
     *
     * @param join the join being rewritten; used only to build error messages
     * @param rightJoinKeyExpressions right-side join keys, each cast to {@link RexInputRef}
     * @param rightPrimaryKeyInputRefs the extracted primary key, if any
     * @return the primary key input references (never absent)
     * @throws ValidationException if no primary key is available or it is not covered by the join
     *     keys
     */
    private static List<RexNode> validateRightPrimaryKey(
            FlinkLogicalJoin join,
            List<RexNode> rightJoinKeyExpressions,
            Optional<List<RexNode>> rightPrimaryKeyInputRefs) {
        if (!rightPrimaryKeyInputRefs.isPresent()) {
            throw new ValidationException(
                    "Temporal Table Join requires primary key in versioned table, "
                            + "but no primary key can be found. The physical plan is:\n"
                            + RelOptUtil.toString(join)
                            + "\n");
        }
        List<RexNode> primaryKey = rightPrimaryKeyInputRefs.get();

        Set<Integer> rightJoinKeyRefIndices =
                rightJoinKeyExpressions.stream()
                        .map(rex -> ((RexInputRef) rex).getIndex())
                        .collect(Collectors.toSet());
        List<Integer> rightPrimaryKeyRefIndices =
                primaryKey.stream()
                        .map(rex -> ((RexInputRef) rex).getIndex())
                        .collect(Collectors.toList());

        if (!rightJoinKeyRefIndices.containsAll(rightPrimaryKeyRefIndices)) {
            List<String> joinFieldNames = join.getRowType().getFieldNames();
            List<String> joinLeftFieldNames = join.getLeft().getRowType().getFieldNames();
            List<String> joinRightFieldNames = join.getRight().getRowType().getFieldNames();
            // Primary key refs index the join's row type, so names come from that row type.
            String primaryKeyNames =
                    rightPrimaryKeyRefIndices.stream()
                            .map(index -> joinFieldNames.get(index))
                            .collect(Collectors.joining(","));
            String joinEquiInfo =
                    join.analyzeCondition().pairs().stream()
                            .map(
                                    pair ->
                                            joinLeftFieldNames.get(pair.source)
                                                    + "="
                                                    + joinRightFieldNames.get(pair.target))
                            .collect(Collectors.joining(","));
            throw new ValidationException(
                    "Temporal table's primary key ["
                            + primaryKeyNames
                            + "] must be included in the equivalence condition of temporal join,"
                            + " but current temporal join condition is ["
                            + joinEquiInfo
                            + "].");
        }
        return primaryKey;
    }

    /**
     * Derives the versioned table's primary key from the upsert keys of {@code snapshotInput}.
     *
     * <p>Each non-empty upsert key is turned into input references that are offset by the left
     * input's field count. The offset makes them address the right side of the join row. The
     * shortest key is returned. If several keys share the minimal size, the choice depends on the
     * iteration order of the upsert-key set.
     *
     * @return the chosen key, or empty when the metadata query reports no non-empty upsert key
     */
    private static Optional<List<RexNode>> extractPrimaryKeyInputRefs(
            RelNode leftInput,
            FlinkLogicalSnapshot snapshot,
            FlinkLogicalRel snapshotInput,
            RexBuilder rexBuilder) {
        FlinkRelMetadataQuery fmq =
                FlinkRelMetadataQuery.reuseOrCreate(snapshot.getCluster().getMetadataQuery());
        Set<ImmutableBitSet> upsertKeySet = fmq.getUpsertKeys(snapshotInput);
        if (upsertKeySet == null || upsertKeySet.isEmpty()) {
            return Optional.empty();
        }

        List<RelDataTypeField> rightFields = snapshot.getRowType().getFieldList();
        int leftFieldCnt = leftInput.getRowType().getFieldCount();
        return upsertKeySet.stream()
                .filter(keyColumns -> !keyColumns.isEmpty())
                .map(keyColumns -> toJoinInputRefs(keyColumns, rightFields, leftFieldCnt, rexBuilder))
                .min(Comparator.comparingInt(List::size));
    }

    /**
     * Maps the column indices of one upsert key to input references on the join row.
     *
     * <p>A field's position in {@code rightFields} is its index, so the join-row position is simply
     * {@code leftFieldCnt + index}.
     */
    private static List<RexNode> toJoinInputRefs(
            ImmutableBitSet keyColumns,
            List<RelDataTypeField> rightFields,
            int leftFieldCnt,
            RexBuilder rexBuilder) {
        return Arrays.stream(keyColumns.toArray())
                .<RexNode>mapToObj(
                        index ->
                                rexBuilder.makeInputRef(
                                        rightFields.get(index).getType(), leftFieldCnt + index))
                .collect(Collectors.toList());
    }

    /**
     * Rule configuration. {@link #DEFAULT} matches {@code Join(any, Snapshot(any))}. The
     * {@link RelOptRuleCall#rel(int)} positions used above come from this operand tree.
     */
    @Value.Immutable(singleton = false)
    public interface TemporalJoinRewriteWithUniqueKeyRuleConfig extends RelRule.Config {
        TemporalJoinRewriteWithUniqueKeyRuleConfig DEFAULT =
                ImmutableTemporalJoinRewriteWithUniqueKeyRule.TemporalJoinRewriteWithUniqueKeyRuleConfig
                        .builder()
                        .build()
                        .withOperandSupplier(
                                b0 ->
                                        b0.operand(FlinkLogicalJoin.class)
                                                .inputs(
                                                        b1 ->
                                                                b1.operand(FlinkLogicalRel.class)
                                                                        .anyInputs(),
                                                        b2 ->
                                                                b2.operand(
                                                                                FlinkLogicalSnapshot
                                                                                        .class)
                                                                        .oneInput(
                                                                                b3 ->
                                                                                        b3.operand(
                                                                                                        FlinkLogicalRel
                                                                                                                .class)
                                                                                                .anyInputs())))
                        .withDescription("TemporalJoinRewriteWithUniqueKeyRule");

        @Override
        default TemporalJoinRewriteWithUniqueKeyRule toRule() {
            return new TemporalJoinRewriteWithUniqueKeyRule(this);
        }
    }
}

