/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.examples.ml;

import org.apache.spark.ml.classification.Classifier;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.classification.OneVsRest;
import org.apache.spark.ml.classification.OneVsRestModel;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;

/**
 * Example of multiclass classification using {@link OneVsRest} with a
 * {@link LogisticRegression} base classifier.
 *
 * <p>Loads a LIBSVM-format dataset, randomly splits it 80/20 into training and
 * test sets, fits a one-vs-rest model on the training set, and prints the test
 * error ({@code 1 - accuracy}) measured on the test set.
 */
public class JavaOneVsRestExample {
    public static void main(String[] args) {
        SparkSession spark = SparkSession.builder()
            .appName("JavaOneVsRestExample")
            .getOrCreate();
        // Ensure the session is stopped even if loading, training or evaluation fails.
        try {
            // Load the multiclass data file in LIBSVM format.
            Dataset<Row> inputData = spark.read()
                .format("libsvm")
                .load("data/mllib/sample_multiclass_classification_data.txt");

            // 80% training / 20% test. No seed is supplied, so the split (and therefore
            // the reported error) can differ between runs.
            Dataset<Row>[] tmp = inputData.randomSplit(new double[]{0.8, 0.2});
            Dataset<Row> train = tmp[0];
            Dataset<Row> test = tmp[1];

            // Base binary classifier trained once per class by OneVsRest.
            LogisticRegression classifier = new LogisticRegression()
                .setMaxIter(10)
                .setTol(1e-6)
                .setFitIntercept(true);

            // LogisticRegression is already a Classifier; no cast is needed.
            OneVsRest ovr = new OneVsRest().setClassifier(classifier);
            OneVsRestModel ovrModel = ovr.fit(train);

            // Score the held-out data, keeping only the columns the evaluator reads.
            Dataset<Row> predictions = ovrModel.transform(test)
                .select("prediction", "label");

            MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator()
                .setMetricName("accuracy");

            double accuracy = evaluator.evaluate(predictions);
            System.out.println("Test Error = " + (1.0 - accuracy));
        } finally {
            spark.stop();
        }
    }
}

