package org.flinkextended.examples.tensorflow.linear;

import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;
import java.util.concurrent.ExecutionException;
import org.apache.flink.api.java.utils.MultipleParameterTool;
import org.apache.flink.streaming.api.environment.StreamExecutionEnvironment;
import org.apache.flink.table.api.DataTypes;
import org.apache.flink.table.api.Expressions;
import org.apache.flink.table.api.Schema;
import org.apache.flink.table.api.Table;
import org.apache.flink.table.api.TableDescriptor;
import org.apache.flink.table.api.bridge.java.StreamStatementSet;
import org.apache.flink.table.api.bridge.java.StreamTableEnvironment;
import org.apache.flink.table.expressions.Expression;
import org.flinkextended.flink.ml.tensorflow.client.TFClusterConfig;
import org.flinkextended.flink.ml.tensorflow.client.TFUtils;

/* loaded from: input_file:org/flinkextended/examples/tensorflow/linear/Linear.class */
/**
 * Example job that either trains, or runs inference with, a linear model through the Flink ML
 * TensorFlow client ({@link TFUtils}).
 *
 * <p>Input rows come from Flink's {@code datagen} connector. Column {@code x} is a DOUBLE generated
 * between 0 and 1, and {@code y} is the computed column {@code 2 * x + 1}. The TensorFlow side of
 * the job is the {@code linear.py} script loaded from the classpath.
 *
 * <p>Supported command-line options (all optional):
 *
 * <ul>
 *   <li>{@code --mode}: {@code train} (default) or {@code inference}; anything else is rejected.
 *   <li>{@code --model-path}: passed to the script as {@code model_save_path}; defaults to
 *       {@code /tmp/linear/<current time in millis>}.
 *   <li>{@code --epoch}: epoch count passed to {@link TFUtils#train}; defaults to 1.
 *   <li>{@code --sample-count}: number of rows produced by {@code datagen}; defaults to 512000.
 *   <li>{@code --inference-output-path}: path of the CSV {@code filesystem} sink used in inference
 *       mode; defaults to {@code /tmp/linear/output.csv}.
 * </ul>
 */
public class Linear {
    private static final String MODEL_PATH = "model-path";
    private static final String EPOCH = "epoch";
    private static final String SAMPLE_COUNT = "sample-count";
    private static final String MODE = "mode";
    private static final String INFERENCE_OUTPUT_PATH = "inference-output-path";

    private static final String TRAIN_MODE = "train";
    private static final String INFERENCE_MODE = "inference";

    /** Classpath resource of the Python script whose functions are used as TF node entries. */
    private static final String SCRIPT_RESOURCE = "linear.py";

    /** Parallelism of the Flink job. */
    private static final int JOB_PARALLELISM = 2;

    /** Number of TensorFlow workers requested in the cluster config. */
    private static final int TF_WORKER_COUNT = 2;

    /**
     * Entry point: parses the options, builds the datagen source and runs either training or
     * inference, blocking until the submitted job completes.
     *
     * @param args command-line arguments in {@link MultipleParameterTool} format
     * @throws IllegalArgumentException if {@code --mode} is neither {@code train} nor {@code
     *     inference}
     * @throws NumberFormatException if {@code --epoch} or {@code --sample-count} is not an integer
     */
    public static void main(String[] args) throws ExecutionException, InterruptedException {
        MultipleParameterTool params = MultipleParameterTool.fromArgs(args);
        String mode = params.get(MODE, TRAIN_MODE);
        // Fail fast on a bad mode, before any Flink environment is created.
        if (!TRAIN_MODE.equals(mode) && !INFERENCE_MODE.equals(mode)) {
            throw new IllegalArgumentException(String.format("Unknown mode %s", mode));
        }

        String inferenceOutputPath = params.get(INFERENCE_OUTPUT_PATH, "/tmp/linear/output.csv");
        String modelPath =
                params.get(
                        MODEL_PATH,
                        String.format("/tmp/linear/%s", Long.valueOf(System.currentTimeMillis())));
        Integer epoch = Integer.valueOf(params.get(EPOCH, "1"));
        Integer sampleCount = Integer.valueOf(params.get(SAMPLE_COUNT, "512000"));

        StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
        env.setParallelism(JOB_PARALLELISM);
        StreamTableEnvironment tEnv = StreamTableEnvironment.create(env);
        StreamStatementSet statementSet = tEnv.createStatementSet();
        Table samples = createSampleTable(tEnv, sampleCount);

        if (TRAIN_MODE.equals(mode)) {
            System.out.printf(
                    "Model will be trained with %d samples for %d epochs and saved at: %s%n",
                    sampleCount, epoch, modelPath);
            train(modelPath, epoch, statementSet, samples);
        } else {
            System.out.printf(
                    "Inference with model at %s, output will be at %s%n",
                    modelPath, inferenceOutputPath);
            inference(modelPath, statementSet, samples, inferenceOutputPath);
        }
    }

    /**
     * Creates a {@code datagen} table with DOUBLE column {@code x} in [0, 1] and computed column
     * {@code y = 2 * x + 1}, limited to {@code sampleCount} rows via {@code number-of-rows}.
     */
    private static Table createSampleTable(StreamTableEnvironment tEnv, int sampleCount) {
        Schema schema =
                Schema.newBuilder()
                        .column("x", DataTypes.DOUBLE())
                        .columnByExpression("y", "2 * x + 1")
                        .build();
        return tEnv.from(
                TableDescriptor.forConnector("datagen")
                        .schema(schema)
                        .option("fields.x.min", "0")
                        .option("fields.x.max", "1")
                        .option("number-of-rows", String.valueOf(sampleCount))
                        .build());
    }

    /**
     * Runs the script's {@code inference} entry on the {@code x} column and writes the resulting
     * {@code (x, y)} rows to a CSV {@code filesystem} sink. It then executes the statement set and
     * waits on the result.
     *
     * @param modelPath value passed to the script as {@code model_save_path}
     * @param statementSet statement set the inference pipeline and sink are added to
     * @param samples source table containing columns {@code x} and {@code y}
     * @param outputPath {@code path} option of the CSV filesystem sink
     */
    private static void inference(
            String modelPath, StreamStatementSet statementSet, Table samples, String outputPath)
            throws ExecutionException, InterruptedException {
        // Only x is fed to the model; the generated label column y is dropped.
        Table features = samples.dropColumns(Expressions.$("y"));
        TFClusterConfig config =
                TFClusterConfig.newBuilder()
                        .setWorkerCount(TF_WORKER_COUNT)
                        .setNodeEntry(getScriptPathFromResources(), "inference")
                        .setProperty("storage_type", "local_file")
                        .setProperty("model_save_path", modelPath)
                        .setProperty("input_types", "FLOAT_64")
                        .setProperty("output_types", "FLOAT_64,FLOAT_64")
                        .build();
        Schema outputSchema =
                Schema.newBuilder()
                        .column("x", DataTypes.DOUBLE())
                        .column("y", DataTypes.DOUBLE())
                        .build();
        Table predictions = TFUtils.inference(statementSet, features, config, outputSchema);

        statementSet.addInsert(
                TableDescriptor.forConnector("filesystem")
                        .format("csv")
                        .option("path", outputPath)
                        .build(),
                predictions);
        statementSet.execute().await();
    }

    /**
     * Hands the sample table to {@link TFUtils#train} using the script's {@code train} entry. It
     * then executes the statement set and waits on the result.
     *
     * @param modelPath value passed to the script as {@code model_save_path}
     * @param epoch epoch count forwarded to {@link TFUtils#train}
     * @param statementSet statement set passed to {@link TFUtils#train} and then executed
     * @param samples source table with columns {@code x} and {@code y} (both fed as FLOAT_64)
     */
    private static void train(
            String modelPath, Integer epoch, StreamStatementSet statementSet, Table samples)
            throws InterruptedException, ExecutionException {
        TFClusterConfig config =
                TFClusterConfig.newBuilder()
                        .setWorkerCount(TF_WORKER_COUNT)
                        .setNodeEntry(getScriptPathFromResources(), "train")
                        .setProperty("storage_type", "local_file")
                        .setProperty("model_save_path", modelPath)
                        .setProperty("input_types", "FLOAT_64,FLOAT_64")
                        .build();
        TFUtils.train(statementSet, samples, config, epoch);
        statementSet.execute().await();
    }

    /**
     * Resolves the {@code linear.py} classpath resource to a path string.
     *
     * <p>For {@code file:} URLs the result is a decoded filesystem path. {@link URL#getPath()}
     * alone keeps percent-escapes (a space becomes {@code %20}), which would not name an existing
     * file. Other URL kinds, such as a resource inside a jar, fall back to {@link URL#getPath()}.
     *
     * @throws RuntimeException if the resource is not on the classpath
     */
    private static String getScriptPathFromResources() {
        ClassLoader loader = Thread.currentThread().getContextClassLoader();
        if (loader == null) {
            // The context class loader may be unset on some threads.
            loader = Linear.class.getClassLoader();
        }
        URL resource = loader.getResource(SCRIPT_RESOURCE);
        if (resource == null) {
            throw new RuntimeException(String.format("Fail to find resource %s", SCRIPT_RESOURCE));
        }
        if ("file".equals(resource.getProtocol())) {
            try {
                return Paths.get(resource.toURI()).toString();
            } catch (URISyntaxException | IllegalArgumentException ignored) {
                // Not convertible to a local Path (e.g. malformed or UNC-style URL): fall back to
                // the raw URL path, matching the previous behavior.
            }
        }
        return resource.getPath();
    }
}
