package weka.knowledgeflow.steps;

import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.zip.GZIPInputStream;
import java.util.zip.GZIPOutputStream;
import org.apache.commons.codec.binary.Base64;
import weka.classifiers.evaluation.NumericPrediction;
import weka.classifiers.timeseries.AbstractForecaster;
import weka.classifiers.timeseries.WekaForecaster;
import weka.core.Attribute;
import weka.core.DenseInstance;
import weka.core.Instance;
import weka.core.Instances;
import weka.core.OptionMetadata;
import weka.core.Utils;
import weka.core.WekaException;
import weka.filters.supervised.attribute.TSLagMaker;
import weka.gui.FilePropertyMetadata;
import weka.gui.ProgrammaticProperty;
import weka.gui.knowledgeflow.TimeSeriesPerspective;
import weka.knowledgeflow.Data;

/**
 * Knowledge Flow step that encapsulates a time series forecasting model ({@link WekaForecaster})
 * and uses it to produce forecasts from incoming historical data.
 *
 * <p>Incoming data arrives either as a complete data set ("dataSet" connection, batch mode) or as
 * a stream of instances ("instance" connection, streaming mode). Historical instances are either
 * used to prime the forecaster incrementally, or - when "Rebuild forecaster" is enabled - buffered
 * and used to re-train it. Historical instances (converted to the output structure) and then the
 * forecasted instances are emitted on the single outgoing "instance" connection.
 *
 * <p>The forecaster is obtained either from a file ({@link #setFilename(File)}) or from a
 * base64-encoded, gzip-compressed serialized form ({@link #setEncodedForecaster(String)}).
 */
@KFStep(name = "TimeSeriesForecasting", category = "TimeSeries", toolTipText = "Encapsulates a time series forecasting model and uses it to produce forecasts given incoming historical data. Forecaster can optionally be rebuilt using the incoming data before a forecast is generated.", iconPath = "weka/gui/knowledgeflow/icons/DefaultClassifier.gif")
public class TimeSeriesForecasting extends BaseStep {
    private static final long serialVersionUID = -7826178727365267059L;

    /** Header (structure) of the data the forecaster was trained on. */
    protected transient Instances m_header;
    /** The forecaster in use for the current run. */
    protected transient WekaForecaster m_forecaster;
    /** Structure of emitted instances: the header plus optional interval bound attributes. */
    protected transient Instances m_outgoingStructure;
    /** When true, the forecaster is re-trained on the incoming data before forecasting. */
    protected boolean m_rebuildForecaster;
    /** Buffered instances whose target values are all missing (future overlay values). */
    protected transient Instances m_overlayData;
    /** Incoming instances buffered for re-training when rebuilding the forecaster. */
    protected transient Instances m_bufferedPrimeData;
    /** Whether the forecaster makes use of overlay data. */
    protected transient boolean m_isUsingOverlayData;
    /** The lag maker of the forecaster in use. */
    protected transient TSLagMaker m_modelLagMaker;
    /** Names of the target attributes being forecast. */
    protected transient List<String> m_fieldsToForecast;
    /** True until the first piece of data of a run has been seen. */
    protected boolean m_isReset;
    /** True when data arrives over an "instance" (streaming) connection. */
    protected boolean m_isStreaming;
    /** Reusable container for outgoing "instance" data. */
    protected Data m_streamingData;
    protected File m_fileName = new File("");
    protected File m_saveFileName = new File("");
    /** Base64 text of a gzip-compressed serialized [forecaster, header] list, or "-NONE-". */
    protected String m_encodedForecaster = "-NONE-";
    protected String m_numberOfStepsToForecast = "1";
    protected String m_artificialTimeStartOffset = "0";
    /** Name of the (non-artificial) time stamp attribute, or "" when not applicable. */
    protected transient String m_timeStampName = "";

    /**
     * Sets the encoded form of the forecaster.
     *
     * @param str base64 text of a gzip-compressed serialized [forecaster, header] list, or "-NONE-"
     */
    @ProgrammaticProperty
    public void setEncodedForecaster(String str) {
        this.m_encodedForecaster = str;
    }

    /** @return the encoded forecaster, or "-NONE-" */
    public String getEncodedForecaster() {
        return this.m_encodedForecaster;
    }

    /**
     * Sets the number of steps to forecast. The value may contain environment variables and is
     * parsed as an integer when the forecast is generated.
     *
     * @param str the number of steps
     */
    @OptionMetadata(displayName = "Number of steps to forecast", description = "The number of steps to forecast beyond the end of the incoming priming data. This will be ignored if the forecaster is using overlay data, as the number of instances for which overlay data is present (and targets are missing) in the incoming data will determine how many forecasted values are produced", displayOrder = 0)
    public void setNumStepsToForecast(String str) {
        this.m_numberOfStepsToForecast = str;
    }

    /** @return the number of steps to forecast (unparsed) */
    public String getNumStepsToForecast() {
        return this.m_numberOfStepsToForecast;
    }

    /**
     * Sets the artificial time start offset. The value may contain environment variables and is
     * parsed as an integer when the forecast is generated.
     *
     * @param str the offset
     */
    // NOTE(review): the displayOrder constant below looks like a decompiler substitution for an
    // int literal - confirm the intended value.
    @OptionMetadata(displayName = "Artificial time start offset", description = "Set the offset, from the value associated with the last training instance, for the artificial timestamp. Has no effect if an artificial timestamp is not in use by the forecaster. If in use, this needs to be set so that the forecaster knows what timestamp value corresponds to the first requested forecast (i.e. it should be equal to the number of recent historical priming instances that occur after the last training instance in time", displayOrder = TimeSeriesPerspective.TimeSeriesDefaults.SHOW_CLIPBOARD_POPUP)
    public void setArtificialTimeStartOffset(String str) {
        this.m_artificialTimeStartOffset = str;
    }

    /** @return the artificial time start offset (unparsed) */
    public String getArtificialTimeStartOffset() {
        return this.m_artificialTimeStartOffset;
    }

    /**
     * Sets the file to load the forecaster from at runtime. The path may contain environment
     * variables; a ".gz" suffix means the file is gzip compressed.
     *
     * @param file the file to load from
     */
    @FilePropertyMetadata(fileChooserDialogType = 0, directoriesOnly = false)
    @ProgrammaticProperty
    @OptionMetadata(displayName = "File to load forecaster from", description = "File to load a forecaster from at runtime", displayOrder = 2)
    public void setFilename(File file) {
        this.m_fileName = file;
    }

    /** @return the file to load the forecaster from */
    public File getFilename() {
        return this.m_fileName;
    }

    /**
     * Sets the file to save a rebuilt forecaster to. The path may contain environment variables;
     * a ".gz" suffix causes the file to be gzip compressed.
     *
     * @param file the file to save to
     */
    // NOTE(review): fileChooserDialogType constant looks like a decompiler substitution - confirm.
    @FilePropertyMetadata(fileChooserDialogType = TimeSeriesPerspective.TimeSeriesDefaults.SHOW_CLIPBOARD_POPUP, directoriesOnly = false)
    @OptionMetadata(displayName = "File to save forecaster to", description = "File to save forecaster to (only applies when rebuilding forecaster)", displayOrder = 4)
    public void setSaveFilename(File file) {
        this.m_saveFileName = file;
    }

    /** @return the file to save a rebuilt forecaster to */
    public File getSaveFilename() {
        return this.m_saveFileName;
    }

    /**
     * Sets whether the forecaster is rebuilt on the incoming data.
     *
     * @param z true to rebuild, false to prime incrementally
     */
    @OptionMetadata(displayName = "Rebuild forecaster", description = "Rebuild forecaster on incoming data", displayOrder = 3)
    public void setRebuildForecaster(boolean z) {
        this.m_rebuildForecaster = z;
    }

    /** @return true if the forecaster is rebuilt on the incoming data */
    public boolean getRebuildForecaster() {
        return this.m_rebuildForecaster;
    }

    /**
     * Resets per-run state before a flow executes.
     *
     * @throws WekaException if neither an encoded forecaster nor a file to load from is set
     */
    @Override
    public void stepInit() throws WekaException {
        boolean noEncoded = m_encodedForecaster == null || m_encodedForecaster.equals("-NONE-");
        boolean noFile = m_fileName == null || isEmpty(m_fileName.toString());
        if (noEncoded && noFile) {
            throw new WekaException("No forecaster specified!");
        }
        m_isReset = true;
        m_isStreaming = false;
        m_overlayData = null;
        m_bufferedPrimeData = null;
        m_streamingData = new Data("instance");
    }

    /**
     * Handles incoming batch or streaming data. On the first call of a run the forecaster is
     * loaded and the incoming structure is checked against the forecaster's training header.
     *
     * @param data the incoming data
     * @throws WekaException if the forecaster cannot be obtained, the structure does not match,
     *     or processing fails
     */
    @Override
    public void processIncoming(Data data) throws WekaException {
        boolean first = false;
        if (m_isReset) {
            m_isReset = false;
            first = true;
            loadOrDecodeForecaster();

            Instances incomingStructure;
            if (getStepManager().numIncomingConnectionsOfType("instance") > 0) {
                m_isStreaming = true;
                incomingStructure = ((Instance) data.getPrimaryPayload()).dataset();
            } else {
                incomingStructure = new Instances((Instances) data.getPrimaryPayload(), 0);
            }
            if (!m_header.equalHeaders(incomingStructure)) {
                throw new WekaException(m_header.equalHeadersMsg(incomingStructure));
            }
            getStepManager().logBasic("Making output structure");
        }

        if (!m_isStreaming) {
            processBatch(data);
            m_streamingData.clearPayload();
            getStepManager().throughputFinished(new Data[]{m_streamingData});
        } else if (getStepManager().isStreamFinished(data)) {
            try {
                processInstance(null, false);
                generateForecast();
            } catch (Exception e) {
                throw new WekaException(e);
            }
            m_streamingData.clearPayload();
            getStepManager().throughputFinished(new Data[]{m_streamingData});
            return;
        } else {
            processStreaming(data, first);
        }

        if (isStopRequested()) {
            getStepManager().interrupted();
        } else if (!m_isStreaming) {
            getStepManager().finished();
        }
    }

    /**
     * Processes one streamed instance. On the first instance of a run the forecaster is configured
     * before the instance itself is processed, so that the first instance is not lost.
     *
     * @param data the incoming data holding an {@link Instance}
     * @param z true if this is the first instance of the run
     * @throws WekaException if processing fails
     */
    protected void processStreaming(Data data, boolean z) throws WekaException {
        try {
            if (z) {
                processInstance(null, true);
            }
            processInstance((Instance) data.getPrimaryPayload(), false);
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /**
     * Processes a complete data set: configures the forecaster, feeds every instance, finishes the
     * historical data and generates the forecast.
     *
     * @param data the incoming data holding an {@link Instances}
     * @throws WekaException if processing fails
     */
    protected void processBatch(Data data) throws WekaException {
        try {
            processInstance(null, true);
            Instances batch = (Instances) data.getPrimaryPayload();
            for (int i = 0; i < batch.numInstances(); i++) {
                processInstance(batch.instance(i), false);
            }
            processInstance(null, false);
            generateForecast();
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /**
     * Processes a single unit of work.
     *
     * @param instance the incoming historical instance, or null to signal that all historical
     *     data has been seen (ignored when {@code z} is true)
     * @param z true to (re)configure the forecaster for a new run
     * @throws Exception if configuring, priming, rebuilding or saving fails
     */
    protected void processInstance(Instance instance, boolean z) throws Exception {
        getStepManager().throughputUpdateStart();
        if (z) {
            configureForecaster();
        } else if (instance == null) {
            finishHistoricalData();
        } else if (m_isUsingOverlayData) {
            processWithOverlay(instance);
        } else {
            bufferOrPrime(instance);
        }
        getStepManager().throughputUpdateEnd();
    }

    /** Reads settings from the forecaster and prepares buffers and the output structure. */
    private void configureForecaster() throws Exception {
        getStepManager().statusMessage("Configuring forecaster...");
        getStepManager().logBasic("Configuring forecaster.");
        m_modelLagMaker = m_forecaster.getTSLagMaker();

        // Reset first so that a name from a previous run, or a null left in this transient field
        // by deserialization, is never used.
        m_timeStampName = "";
        if (!m_modelLagMaker.isUsingAnArtificialTimeIndex() && m_modelLagMaker.getAdjustForTrends()) {
            String timeStampField = m_modelLagMaker.getTimeStampField();
            m_timeStampName = timeStampField == null ? "" : timeStampField;
        }

        m_isUsingOverlayData = m_forecaster.isUsingOverlayData();
        if (m_rebuildForecaster) {
            getStepManager().logBasic("Forecaster will be rebuilt/re-estimated on incoming data");
        } else {
            getStepManager().logBasic("Forecaster will be primed incrementally.");
            // Presumably resets any lag history before incremental priming - confirm against
            // WekaForecaster.primeForecaster.
            m_forecaster.primeForecaster(new Instances(m_header, 0));
        }
        if (m_isUsingOverlayData) {
            getStepManager().logDetailed("Forecaster is using overlay data. We expect to see overlay attribute values for the forecasting period.");
            m_overlayData = new Instances(m_header, 0);
        }
        if (m_rebuildForecaster) {
            m_bufferedPrimeData = new Instances(m_header, 0);
        }
        m_fieldsToForecast = AbstractForecaster.stringToList(m_forecaster.getFieldsToForecast());
        m_outgoingStructure = buildOutgoingStructure();
    }

    /** Builds the outgoing structure: the header, plus lower/upper bound attributes per target. */
    private Instances buildOutgoingStructure() {
        if (!m_forecaster.isProducingConfidenceIntervals()) {
            return new Instances(m_header, 0);
        }
        ArrayList<Attribute> attributes = new ArrayList<>();
        for (int i = 0; i < m_header.numAttributes(); i++) {
            attributes.add((Attribute) m_header.attribute(i).copy());
        }
        for (String target : m_fieldsToForecast) {
            attributes.add(new Attribute(target + "_lowerBound"));
            attributes.add(new Attribute(target + "_upperBound"));
        }
        return new Instances(m_header.relationName() + "_plus_forecast", attributes, 0);
    }

    /**
     * Called once all historical data has been seen. In rebuild mode, emits the buffered data,
     * re-trains and primes the forecaster on it, and optionally saves the rebuilt model.
     */
    private void finishHistoricalData() throws Exception {
        if (m_rebuildForecaster && m_bufferedPrimeData.numInstances() > 0) {
            for (int i = 0; i < m_bufferedPrimeData.numInstances(); i++) {
                // Convert so downstream sees the same structure as in incremental-priming mode.
                emit(convertToOutputFormat(m_bufferedPrimeData.instance(i)));
            }
            getStepManager().statusMessage("Rebuilding the forecasting model...");
            getStepManager().logBasic("Rebuilding the forecasting model");
            m_forecaster.buildForecaster(m_bufferedPrimeData, new PrintStream[0]);
            getStepManager().statusMessage("Priming the forecasting model...");
            getStepManager().logBasic("Priming the forecasting model");
            m_forecaster.primeForecaster(m_bufferedPrimeData);
        }
        if (m_rebuildForecaster && m_saveFileName != null && !isEmpty(m_saveFileName.toString())) {
            saveForecaster();
        }
    }

    /** Serializes the forecaster followed by its header to the (env-substituted) save file. */
    private void saveForecaster() throws Exception {
        String path = environmentSubstitute(m_saveFileName.toString());
        getStepManager().statusMessage("Saving rebuilt forecasting model...");
        getStepManager().logBasic("Saving rebuilt forecasting model to \"" + path + "\"");
        try (OutputStream fileOut = new FileOutputStream(path);
             OutputStream out = path.toLowerCase().endsWith(".gz") ? new GZIPOutputStream(fileOut) : fileOut;
             ObjectOutputStream objectOut = new ObjectOutputStream(new BufferedOutputStream(out))) {
            objectOut.writeObject(m_forecaster);
            objectOut.writeObject(m_header);
        }
    }

    /**
     * Handles an instance when the forecaster uses overlay data. Instances with all targets
     * missing are buffered as overlay data; a non-overlay instance arriving after buffered overlay
     * instances turns the whole buffer back into historical data.
     */
    private void processWithOverlay(Instance instance) throws Exception {
        if (allTargetsMissing(instance)) {
            m_overlayData.add(instance);
            getStepManager().statusMessage("buffering overlay instance...");
        } else if (m_overlayData.numInstances() > 0) {
            m_overlayData.add(instance);
            getStepManager().logWarning("Encountered a supposed overlay instance with non-missing target values - converting buffered overlay data into " + (m_rebuildForecaster ? "training" : "priming") + " data...");
            getStepManager().statusMessage("Flushing overlay buffer.");
            for (int i = 0; i < m_overlayData.numInstances(); i++) {
                bufferOrPrime(m_overlayData.instance(i));
            }
            m_overlayData = new Instances(m_header, 0);
        } else {
            bufferOrPrime(instance);
        }
    }

    /** @return true if every target attribute of the instance is missing */
    private boolean allTargetsMissing(Instance instance) {
        for (String target : m_fieldsToForecast) {
            if (!instance.isMissing(m_header.attribute(target))) {
                return false;
            }
        }
        return true;
    }

    /** Buffers a historical instance (rebuild mode) or primes with it and passes it downstream. */
    private void bufferOrPrime(Instance instance) throws Exception {
        if (m_rebuildForecaster) {
            m_bufferedPrimeData.add(instance);
        } else {
            m_forecaster.primeForecasterIncremental(instance);
            emit(convertToOutputFormat(instance));
        }
    }

    /** Sends a single instance on the outgoing "instance" connection. */
    private void emit(Instance instance) throws Exception {
        m_streamingData.setPayloadElement("instance", instance);
        getStepManager().outputData(new Data[]{m_streamingData});
    }

    /**
     * Converts a historical instance to the outgoing structure. Interval bound attributes, if any,
     * are set to missing; the instance weight is preserved.
     */
    private Instance convertToOutputFormat(Instance instance) {
        Instance converted;
        if (m_forecaster.isProducingConfidenceIntervals()) {
            int numOriginal = instance.numAttributes();
            double[] values = new double[numOriginal + m_fieldsToForecast.size() * 2];
            for (int i = 0; i < numOriginal; i++) {
                values[i] = instance.value(i);
            }
            for (int i = numOriginal; i < values.length; i++) {
                values[i] = Utils.missingValue();
            }
            converted = new DenseInstance(instance.weight(), values);
        } else {
            converted = (Instance) instance.copy();
        }
        converted.setDataset(m_outgoingStructure);
        return converted;
    }

    /**
     * Produces the forecast and emits one instance per forecasted step. When overlay data is
     * buffered, the number of steps equals the number of overlay instances; otherwise the
     * configured number of steps is used.
     */
    private void generateForecast() throws Exception {
        boolean adjustForTrends = m_modelLagMaker.getAdjustForTrends();
        boolean artificialIndex = m_modelLagMaker.isUsingAnArtificialTimeIndex();
        double lastTimeStamp = -1.0d;
        if (adjustForTrends && m_modelLagMaker.getTimeStampField() != null
                && m_modelLagMaker.getTimeStampField().length() > 0 && !artificialIndex) {
            lastTimeStamp = m_modelLagMaker.getCurrentTimeStampValue();
        } else if (adjustForTrends && artificialIndex) {
            int offset = parseIntegerOption(environmentSubstitute(m_artificialTimeStartOffset), "artificial time start offset");
            m_modelLagMaker.setArtificialTimeStartValue(m_modelLagMaker.getArtificialTimeStartValue() + offset);
        }

        boolean useOverlay = m_overlayData != null && m_overlayData.numInstances() > 0 && m_isUsingOverlayData;
        int numSteps = useOverlay
            ? m_overlayData.numInstances()
            : parseIntegerOption(environmentSubstitute(m_numberOfStepsToForecast), "number of steps to forecast");
        List<List<NumericPrediction>> forecast = useOverlay
            ? m_forecaster.forecast(numSteps, m_overlayData, new PrintStream[0])
            : m_forecaster.forecast(numSteps, new PrintStream[0]);

        int timeStampIndex = -1;
        if (m_timeStampName != null && m_timeStampName.length() > 0) {
            Attribute timeStampAtt = m_outgoingStructure.attribute(m_timeStampName);
            if (timeStampAtt == null) {
                throw new WekaException("couldn't find time stamp: " + m_timeStampName + " in the input data");
            }
            timeStampIndex = timeStampAtt.index();
        }

        getStepManager().statusMessage("Generating forecast...");
        getStepManager().logBasic("Generating forecast.");
        int numOutAtts = m_outgoingStructure.numAttributes();
        double timeStamp = lastTimeStamp;
        for (int step = 0; step < numSteps; step++) {
            if (m_isStreaming) {
                getStepManager().throughputUpdateStart();
            }
            Instance overlay = useOverlay ? m_overlayData.instance(step) : null;
            double[] values = new double[numOutAtts];
            for (int j = 0; j < numOutAtts; j++) {
                // The outgoing structure may hold extra interval attributes beyond the overlay's.
                values[j] = overlay != null && j < overlay.numAttributes() ? overlay.value(j) : Utils.missingValue();
            }
            if (timeStampIndex != -1) {
                timeStamp = m_modelLagMaker.advanceSuppliedTimeValue(timeStamp);
                values[timeStampIndex] = timeStamp;
            }
            fillPredictions(values, forecast.get(step));

            Instance forecastInstance = new DenseInstance(1.0d, values);
            forecastInstance.setDataset(m_outgoingStructure);
            emit(forecastInstance);
            if (m_isStreaming) {
                getStepManager().throughputUpdateEnd();
            }
        }
        getStepManager().logBasic("Finished. Generated " + numSteps + " forecasted values.");
    }

    /** Writes the predicted values (and first interval, if any) for each target into values. */
    private void fillPredictions(double[] values, List<NumericPrediction> predictions) {
        for (int t = 0; t < m_fieldsToForecast.size(); t++) {
            String target = m_fieldsToForecast.get(t);
            NumericPrediction prediction = predictions.get(t);
            double predicted = prediction.predicted();
            if (!Utils.isMissingValue(predicted)) {
                values[m_outgoingStructure.attribute(target).index()] = predicted;
            }
            double[][] intervals = prediction.predictionIntervals();
            if (intervals != null && intervals.length > 0) {
                Attribute lower = m_outgoingStructure.attribute(target + "_lowerBound");
                Attribute upper = m_outgoingStructure.attribute(target + "_upperBound");
                if (lower != null && upper != null) {
                    values[lower.index()] = intervals[0][0];
                    values[upper.index()] = intervals[0][1];
                }
            }
        }
    }

    /** Parses an integer option value, reporting which option was malformed. */
    private static int parseIntegerOption(String value, String optionName) throws WekaException {
        if (value == null) {
            throw new WekaException("No value set for " + optionName);
        }
        try {
            return Integer.parseInt(value.trim());
        } catch (NumberFormatException e) {
            throw new WekaException("Unable to parse " + optionName + " as an integer: \"" + value + "\"", e);
        }
    }

    /**
     * Accepts a single "dataSet" or "instance" connection.
     *
     * @return the acceptable incoming connection types
     */
    @Override
    public List<String> getIncomingConnectionTypes() {
        List<String> types = new ArrayList<>();
        if (getStepManager().numIncomingConnections() == 0) {
            types.add("dataSet");
            types.add("instance");
        }
        return types;
    }

    /**
     * Produces "instance" output once there is an incoming connection.
     *
     * @return the outgoing connection types
     */
    @Override
    public List<String> getOutgoingConnectionTypes() {
        List<String> types = new ArrayList<>();
        if (getStepManager().numIncomingConnections() > 0) {
            types.add("instance");
        }
        return types;
    }

    /**
     * Loads a forecaster and its training header from a file written as two consecutive Java
     * serialized objects. The path may contain environment variables; ".gz" (any case) means gzip.
     * NOTE: uses Java deserialization - only load model files from trusted sources.
     *
     * @param file the file to load from
     * @return a list holding the {@link WekaForecaster} followed by its header {@link Instances}
     * @throws WekaException if no file is given or loading fails
     */
    protected List<Object> loadModel(File file) throws WekaException {
        if (file == null || isEmpty(file.toString()) || file.toString().equals("-NONE-")) {
            throw new WekaException("Model is null or no filename specified to load from!");
        }
        try {
            String path = environmentSubstitute(file.toString());
            try (InputStream fileIn = new FileInputStream(path);
                 InputStream in = path.toLowerCase().endsWith(".gz") ? new GZIPInputStream(fileIn) : fileIn;
                 ObjectInputStream objectIn = new ObjectInputStream(new BufferedInputStream(in))) {
                WekaForecaster forecaster = (WekaForecaster) objectIn.readObject();
                Instances header = (Instances) objectIn.readObject();
                List<Object> model = new ArrayList<>();
                model.add(forecaster);
                model.add(header);
                return model;
            }
        } catch (Exception e) {
            throw new WekaException(e);
        }
    }

    /**
     * Returns the current forecaster, decoding it from the encoded form if not yet available.
     *
     * @return the forecaster, or null if there is no encoded forecaster
     * @throws Exception if decoding fails
     */
    public WekaForecaster getForecaster() throws Exception {
        if (m_forecaster != null) {
            return m_forecaster;
        }
        List<Object> model = getForecaster(m_encodedForecaster);
        if (model == null) {
            return null;
        }
        m_forecaster = (WekaForecaster) model.get(0);
        m_header = (Instances) model.get(1);
        return m_forecaster;
    }

    /** Obtains the forecaster from the configured file if set, otherwise from the encoded form. */
    private void loadOrDecodeForecaster() throws WekaException {
        if (m_fileName != null && !isEmpty(m_fileName.toString())) {
            List<Object> model = loadModel(m_fileName);
            if (model == null) {
                throw new WekaException("problem loading forecasting model.");
            }
            m_forecaster = (WekaForecaster) model.get(0);
            m_header = (Instances) model.get(1);
            return;
        }
        if (m_encodedForecaster == null || m_encodedForecaster.length() <= 0 || m_encodedForecaster.equals("-NONE-")) {
            throw new WekaException("unable to obtain a forecasting model to use.");
        }
        try {
            getForecaster();
        } catch (Exception e) {
            throw new WekaException("a problem occurred while decoding the model.", e);
        }
    }

    /**
     * Decodes an encoded forecaster produced by {@link #encodeForecasterToBase64}.
     * NOTE: uses Java deserialization - only decode data from trusted sources.
     *
     * @param str the encoded text
     * @return a list holding the forecaster and its header, or null if {@code str} is null, empty
     *     or "-NONE-"
     * @throws Exception if decoding or deserialization fails
     */
    @SuppressWarnings("unchecked")
    public static List<Object> getForecaster(String str) throws Exception {
        if (str == null || str.length() <= 0 || str.equals("-NONE-")) {
            return null;
        }
        try (ObjectInputStream objectIn = new ObjectInputStream(new ByteArrayInputStream(decodeFromBase64(str)))) {
            return (List<Object>) objectIn.readObject();
        }
    }

    /**
     * Base64-decodes the string and gunzips the result.
     *
     * @param str the base64 text (may be null)
     * @return the decompressed bytes; empty if {@code str} is null or decodes to nothing
     * @throws Exception if the decoded bytes are not valid gzip data
     */
    protected static byte[] decodeFromBase64(String str) throws Exception {
        byte[] decoded = str == null ? new byte[0] : Base64.decodeBase64(str.getBytes());
        if (decoded.length == 0) {
            return decoded;
        }
        ByteArrayOutputStream inflated = new ByteArrayOutputStream();
        try (InputStream in = new GZIPInputStream(new ByteArrayInputStream(decoded))) {
            byte[] buffer = new byte[8192];
            int read;
            while ((read = in.read(buffer)) != -1) {
                inflated.write(buffer, 0, read);
            }
        }
        return inflated.toByteArray();
    }

    /**
     * Gzips the bytes and base64-encodes the result.
     *
     * @param bArr the bytes to encode (may be null)
     * @return the base64 text, or null if {@code bArr} is null
     * @throws IOException if compression fails
     */
    protected static String encodeToBase64(byte[] bArr) throws IOException {
        if (bArr == null) {
            return null;
        }
        ByteArrayOutputStream compressed = new ByteArrayOutputStream();
        // Closing the gzip stream (end of try) writes the gzip trailer before the bytes are read.
        try (OutputStream gzipOut = new GZIPOutputStream(compressed)) {
            gzipOut.write(bArr);
        }
        return new String(Base64.encodeBase64(compressed.toByteArray()));
    }

    /**
     * Serializes the forecaster and header as a list and encodes it via {@link #encodeToBase64}.
     *
     * @param wekaForecaster the forecaster
     * @param instances the forecaster's training header
     * @return the encoded text
     * @throws Exception if either argument is null or serialization fails
     */
    public static String encodeForecasterToBase64(WekaForecaster wekaForecaster, Instances instances) throws Exception {
        if (wekaForecaster == null || instances == null) {
            throw new Exception("[TimeSeriesForecasting] unable to encode model!");
        }
        List<Object> model = new ArrayList<>();
        model.add(wekaForecaster);
        model.add(instances);
        ByteArrayOutputStream serialized = new ByteArrayOutputStream();
        try (ObjectOutputStream objectOut = new ObjectOutputStream(new BufferedOutputStream(serialized))) {
            objectOut.writeObject(model);
        }
        return encodeToBase64(serialized.toByteArray());
    }

    /**
     * @param str the string to test
     * @return true if {@code str} is null or has length zero
     */
    public static boolean isEmpty(String str) {
        return str == null || str.length() == 0;
    }

    /** @return the fully qualified class name of this step's custom editor dialog */
    @Override
    public String getCustomEditorForStep() {
        return "weka.gui.knowledgeflow.steps.TimeSeriesForecastingStepEditorDialog";
    }
}
