/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.types.inference.strategies;

import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.flink.annotation.Internal;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.ValidationException;
import org.apache.flink.table.functions.FunctionDefinition;
import org.apache.flink.table.functions.ModelSemantics;
import org.apache.flink.table.functions.TableSemantics;
import org.apache.flink.table.types.DataType;
import org.apache.flink.table.types.inference.ArgumentCount;
import org.apache.flink.table.types.inference.CallContext;
import org.apache.flink.table.types.inference.ConstantArgumentCount;
import org.apache.flink.table.types.inference.InputTypeStrategy;
import org.apache.flink.table.types.inference.Signature;
import org.apache.flink.table.types.inference.TypeStrategy;
import org.apache.flink.table.types.logical.LogicalType;
import org.apache.flink.table.types.logical.LogicalTypeRoot;
import org.apache.flink.table.types.logical.utils.LogicalTypeCasts;
import org.apache.flink.types.ColumnList;

@Internal
public class MLPredictTypeStrategy {
    public static final InputTypeStrategy ML_PREDICT_INPUT_TYPE_STRATEGY = new InputTypeStrategy(){

        @Override
        public ArgumentCount getArgumentCount() {
            return ConstantArgumentCount.between(3, 4);
        }

        @Override
        public Optional<List<DataType>> inferInputTypes(CallContext callContext, boolean throwOnFailure) {
            return MLPredictTypeStrategy.inferMLPredictInputTypes(callContext, throwOnFailure);
        }

        @Override
        public List<Signature> getExpectedSignatures(FunctionDefinition definition) {
            return List.of(Signature.of(Signature.Argument.of("TABLE", "ROW"), Signature.Argument.of("MODEL", "MODEL"), Signature.Argument.of("ARGS", "DESCRIPTOR")), Signature.of(Signature.Argument.of("TABLE", "ROW"), Signature.Argument.of("MODEL", "MODEL"), Signature.Argument.of("ARGS", "DESCRIPTOR"), Signature.Argument.of("CONFIG", "MAP")));
        }
    };
    public static final TypeStrategy ML_PREDICT_OUTPUT_TYPE_STRATEGY = callContext -> {
        TableSemantics tableSemantics = callContext.getTableSemantics(0).orElse(null);
        if (tableSemantics == null) {
            throw new ValidationException("First argument must be a table for ML_PREDICT function.");
        }
        ModelSemantics modelSemantics = callContext.getModelSemantics(1).orElse(null);
        if (modelSemantics == null) {
            throw new ValidationException("Second argument must be a model for ML_PREDICT function.");
        }
        LogicalType tableType = tableSemantics.dataType().getLogicalType();
        LogicalType modelOutputType = modelSemantics.outputDataType().getLogicalType();
        if (!tableType.is(LogicalTypeRoot.ROW) || !modelOutputType.is(LogicalTypeRoot.ROW)) {
            throw new ValidationException("Both table and model output types must be row types for ML_PREDICT function.");
        }
        List<DataTypes.Field> tableFields = DataType.getFields(tableSemantics.dataType());
        List<DataTypes.Field> modelFields = DataType.getFields(modelSemantics.outputDataType());
        ArrayList<DataTypes.Field> outputFields = new ArrayList<DataTypes.Field>(tableFields);
        Set tableFieldNames = tableFields.stream().map(DataTypes.AbstractField::getName).collect(Collectors.toSet());
        for (DataTypes.Field modelField : modelFields) {
            Object fieldName = modelField.getName();
            if (tableFieldNames.contains(modelField.getName())) {
                fieldName = (String)fieldName + "0";
            }
            outputFields.add(DataTypes.FIELD((String)fieldName, modelField.getDataType()));
        }
        return Optional.of(DataTypes.ROW(outputFields));
    };

    private static Optional<List<DataType>> inferMLPredictInputTypes(CallContext callContext, boolean throwOnFailure) {
        TableSemantics tableSemantics = callContext.getTableSemantics(0).orElse(null);
        if (tableSemantics == null) {
            if (throwOnFailure) {
                throw new ValidationException("First argument must be a table for ML_PREDICT function.");
            }
            return Optional.empty();
        }
        ModelSemantics modelSemantics = callContext.getModelSemantics(1).orElse(null);
        if (modelSemantics == null) {
            if (throwOnFailure) {
                throw new ValidationException("Second argument must be a model for ML_PREDICT function.");
            }
            return Optional.empty();
        }
        Optional<ColumnList> descriptorColumns = callContext.getArgumentValue(2, ColumnList.class);
        if (descriptorColumns.isEmpty()) {
            if (throwOnFailure) {
                throw new ValidationException("Third argument must be a descriptor with simple column names for ML_PREDICT function.");
            }
            return Optional.empty();
        }
        if (!MLPredictTypeStrategy.validateTableAndDescriptorArguments(tableSemantics, descriptorColumns.get(), throwOnFailure)) {
            return Optional.empty();
        }
        if (!MLPredictTypeStrategy.validateModelDescriptorCompatibility(tableSemantics, modelSemantics, descriptorColumns.get(), throwOnFailure)) {
            return Optional.empty();
        }
        return Optional.of(callContext.getArgumentDataTypes());
    }

    private static boolean validateTableAndDescriptorArguments(TableSemantics tableSemantics, ColumnList descriptorColumns, boolean throwOnFailure) {
        List<String> tableFieldNames = DataType.getFieldNames(tableSemantics.dataType());
        List<String> descriptorColumnNames = descriptorColumns.getNames();
        for (String descriptorColumnName : descriptorColumnNames) {
            if (tableFieldNames.contains(descriptorColumnName)) continue;
            if (throwOnFailure) {
                throw new ValidationException(String.format("Descriptor column '%s' not found in table columns. Available columns: %s.", descriptorColumnName, String.join((CharSequence)", ", tableFieldNames)));
            }
            return false;
        }
        return true;
    }

    private static boolean validateModelDescriptorCompatibility(TableSemantics tableSemantics, ModelSemantics modelSemantics, ColumnList descriptorColumns, boolean throwOnFailure) {
        DataType modelInputDataType = modelSemantics.inputDataType();
        List<DataTypes.Field> modelInputFields = DataType.getFields(modelInputDataType);
        List<String> descriptorColumnNames = descriptorColumns.getNames();
        if (descriptorColumnNames.size() != modelInputFields.size()) {
            if (throwOnFailure) {
                throw new ValidationException(String.format("Number of descriptor columns (%d) does not match model input size (%d).", descriptorColumnNames.size(), modelInputFields.size()));
            }
            return false;
        }
        List<DataTypes.Field> tableFields = DataType.getFields(tableSemantics.dataType());
        for (int i = 0; i < descriptorColumnNames.size(); ++i) {
            LogicalType modelInputColumnType;
            String descriptorColumnName = descriptorColumnNames.get(i);
            DataTypes.Field tableField = tableFields.stream().filter(field -> field.getName().equals(descriptorColumnName)).findFirst().orElseThrow(() -> new IllegalStateException("Column should exist"));
            LogicalType tableColumnType = tableField.getDataType().getLogicalType();
            if (LogicalTypeCasts.supportsImplicitCast(tableColumnType, modelInputColumnType = modelInputFields.get(i).getDataType().getLogicalType())) continue;
            if (throwOnFailure) {
                throw new ValidationException(String.format("Descriptor column '%s' type %s cannot be assigned to model input type %s at position %d.", descriptorColumnName, tableColumnType, modelInputColumnType, i));
            }
            return false;
        }
        return true;
    }
}

