package hex.deepwater;

import deepwater.backends.BackendModel;
import deepwater.backends.BackendTrain;
import hex.FrameTask;
import hex.deepwater.DeepWaterParameters;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Random;
import water.Futures;
import water.H2O;
import water.Job;
import water.fvec.Chunk;
import water.fvec.NewChunk;
import water.parser.BufferedString;
import water.util.Log;
import water.util.RandomUtils;

/* loaded from: input_file:hex/deepwater/DeepWaterTask.class */
/**
 * Map/reduce task that performs Deep Water training.
 *
 * <p>All of the work happens in {@link #setupLocal()}. The task reads training samples from the
 * whole frame {@code _fr} row by row (not only node-local chunks) and optionally shuffles them.
 * It then feeds them to the native backend in mini-batches. {@link #map} is intentionally empty.
 * {@link #reduce} merges the per-node models. {@link #postGlobal()} divides the merged model by
 * the number of contributing nodes (unless the task ran locally) and publishes the result through
 * {@link #model_info()}.
 */
public class DeepWaterTask extends FrameTask<DeepWaterTask> {
    /** Model trained by this task instance; set in setupLocal(), cleared in postGlobal(). */
    private DeepWaterModelInfo _localmodel;
    /** Model handed in by the caller and, after postGlobal(), the trained result. */
    private DeepWaterModelInfo _sharedmodel;
    /** Number of per-node models merged into {@code _localmodel} by reduce(). */
    private int _chunk_node_count;
    /** Passes over the data: the integer part gives full passes, the remainder is sampled randomly. */
    private float _useFraction;
    /** Whether collected samples are shuffled before training. */
    private boolean _shuffle;
    /** Owning job; training stops early once it reports isStopping(). */
    private final Job _job;
    /** Rate limiting for the "nodes not contributing" warning; shared by all instances in this JVM. */
    private static long _lastWarn;
    private static long _warnCount;

    /** Runs one native backend training step on a single mini-batch. */
    private static class NativeTrainTask extends H2O.H2OCountedCompleter<NativeTrainTask> {
        /** Accumulated wall-clock time spent inside the backend's train call. */
        long _timeInMillis;
        final BackendTrain _backend;
        final BackendModel _model;
        float[] _data;
        float[] _labels;

        NativeTrainTask(BackendTrain backend, BackendModel model, float[] data, float[] labels) {
            _backend = backend;
            _model = model;
            _data = data;
            _labels = labels;
        }

        @Override
        public void compute2() {
            final long start = System.currentTimeMillis();
            _backend.train(_model, _data, _labels);
            _timeInMillis += System.currentTimeMillis() - start;
            tryComplete();
        }
    }

    /**
     * Returns the shared model. It is non-null before setupLocal() hands it over to the local
     * phase and again after postGlobal() publishes the trained result.
     */
    public final DeepWaterModelInfo model_info() {
        assert _sharedmodel != null;
        return _sharedmodel;
    }

    /**
     * @param inputModel  model to (continue to) train; its data info is passed to FrameTask
     * @param useFraction number of passes over the training data (may be fractional)
     * @param job         owning job, polled to stop training early
     */
    public DeepWaterTask(DeepWaterModelInfo inputModel, float useFraction, Job job) {
        super(job._key, inputModel._dataInfo);
        _chunk_node_count = 1;
        _sharedmodel = inputModel;
        _useFraction = useFraction;
        _shuffle = model_info().get_params()._shuffle_training_data;
        _job = job;
    }

    /**
     * Collects training samples, optionally shuffles them, and trains the backend model batch by
     * batch until the data is exhausted or the job is stopping.
     */
    @Override
    protected void setupLocal() {
        assert _localmodel == null;
        _localmodel = _sharedmodel;
        _sharedmodel = null;
        _localmodel.set_processed_local(0L);

        final DeepWaterParameters params = _localmodel.get_params();
        final DeepWaterParameters.ProblemType problemType = params._problem_type;
        final int weightIdx = _fr.find(params._weights_column);
        final int respIdx = _fr.find(params._response_column);
        final int batchSize = params._mini_batch_size;
        // The seed depends on global progress, so successive calls sample and shuffle differently.
        final long seed = 912559L + 53261L * _localmodel.get_processed_global();
        final Random rng = RandomUtils.getRNG(seed);

        if (_fr.numRows() > Integer.MAX_VALUE) {
            throw H2O.unimpl("Need to implement batching into int-sized chunks.");
        }
        final int numRows = (int) _fr.numRows();

        final Futures fs = new Futures();
        try {
            final DeepWaterIterator iter;
            if (problemType == DeepWaterParameters.ProblemType.image
                    || problemType == DeepWaterParameters.ProblemType.text) {
                final boolean isImage = problemType == DeepWaterParameters.ProblemType.image;
                Log.debug("Using column " + _fr.name(0) + " for "
                        + (isImage ? "path to image data" : "text data"));
                final ArrayList<String> data = new ArrayList<>();
                final ArrayList<Float> labels = new ArrayList<>();
                collectStringSamples(data, labels, weightIdx, respIdx, batchSize, numRows, rng);
                if (_shuffle) {
                    shuffleAligned(data, labels, seed, rng);
                }
                if (isImage) {
                    iter = new DeepWaterImageIterator(data, labels, _localmodel._meanData, batchSize,
                            _localmodel._width, _localmodel._height, _localmodel._channels,
                            params._cache_data);
                } else {
                    // 56 is hard-coded; presumably the maximum text length per sample -- TODO confirm.
                    iter = new DeepWaterTextIterator(data, labels, batchSize, 56, params._cache_data);
                }
            } else if (problemType == DeepWaterParameters.ProblemType.dataset) {
                assert _localmodel._dataInfo != null;
                final ArrayList<Integer> rows = new ArrayList<>();
                final ArrayList<Float> labels = new ArrayList<>();
                collectDatasetSamples(rows, labels, weightIdx, respIdx, batchSize, numRows, rng);
                if (_shuffle) {
                    shuffleAligned(rows, labels, seed, rng);
                }
                iter = new DeepWaterDatasetIterator(rows, labels, _localmodel._dataInfo, batchSize,
                        params._cache_data);
            } else {
                throw new IllegalStateException("Unsupported Deep Water problem type: " + problemType);
            }

            while (iter.Next(fs) && !_job.isStopping()) {
                final long processed = _localmodel.get_processed_total();
                final BackendModel model = _localmodel.getModel().get();
                _localmodel._backend.setParameter(model, "learning_rate", params.learningRate(processed));
                _localmodel._backend.setParameter(model, "momentum", params.momentum(processed));
                fs.add(H2O.submitTask(new NativeTrainTask(_localmodel._backend, model,
                        iter.getData(), iter.getLabel())));
                _localmodel.add_processed_local(iter._batch_size);
            }
        } catch (IOException e) {
            // Best effort, as before: report the failure and continue with what was trained so far.
            Log.err("DeepWater training failed on this node", e);
        } finally {
            // Never leave native training tasks running once the local phase is over.
            fs.blockForPending();
        }
    }

    /**
     * Fills {@code data}/{@code labels} from string column 0 (image paths or text) and the response
     * column. It makes {@code (int) _useFraction} full passes, then adds random rows until
     * {@link #needsMoreSamples} is false.
     */
    private void collectStringSamples(ArrayList<String> data, ArrayList<Float> labels, int weightIdx,
                                      int respIdx, int batchSize, int numRows, Random rng) {
        final BufferedString scratch = new BufferedString();
        final int fullPasses = (int) _useFraction;
        for (int pass = 0; pass < fullPasses; pass++) {
            for (int row = 0; row < numRows; row++) {
                addStringSample(row, data, labels, weightIdx, respIdx, scratch);
            }
        }
        while (needsMoreSamples(data.size(), numRows, batchSize)) {
            assert _shuffle;
            addStringSample(rng.nextInt(numRows), data, labels, weightIdx, respIdx, scratch);
        }
    }

    /**
     * Appends {@code row}'s string payload and label. It skips zero-weight rows and rows whose
     * string is missing. Dropping the whole row keeps {@code data} and {@code labels} index-aligned.
     */
    private void addStringSample(int row, ArrayList<String> data, ArrayList<Float> labels,
                                 int weightIdx, int respIdx, BufferedString scratch) {
        if (rowWeight(weightIdx, row) == 0.0d) {
            return;
        }
        final BufferedString value = _fr.vec(0).atStr(scratch, row);
        if (value == null) {
            return;
        }
        data.add(value.toString());
        labels.add((float) _fr.vec(respIdx).at(row));
    }

    /**
     * Fills {@code rows}/{@code labels} with row indices and responses. Each response is
     * normalized as {@code (y - sub) / mul} using the data info's response normalization, when
     * present. It makes {@code (int) _useFraction} full passes, then adds random rows until
     * {@link #needsMoreSamples} is false.
     */
    private void collectDatasetSamples(ArrayList<Integer> rows, ArrayList<Float> labels, int weightIdx,
                                       int respIdx, int batchSize, int numRows, Random rng) {
        final double respMul = _localmodel._dataInfo._normRespMul != null
                ? _localmodel._dataInfo._normRespMul[0] : 1.0d;
        final double respSub = _localmodel._dataInfo._normRespSub != null
                ? _localmodel._dataInfo._normRespSub[0] : 0.0d;
        final int fullPasses = (int) _useFraction;
        for (int pass = 0; pass < fullPasses; pass++) {
            for (int row = 0; row < numRows; row++) {
                addDatasetSample(row, rows, labels, weightIdx, respIdx, respSub, respMul);
            }
        }
        while (needsMoreSamples(rows.size(), numRows, batchSize)) {
            addDatasetSample(rng.nextInt(numRows), rows, labels, weightIdx, respIdx, respSub, respMul);
        }
    }

    /** Appends {@code row} and its normalized response unless the row has zero weight. */
    private void addDatasetSample(int row, ArrayList<Integer> rows, ArrayList<Float> labels,
                                  int weightIdx, int respIdx, double respSub, double respMul) {
        if (rowWeight(weightIdx, row) == 0.0d) {
            return;
        }
        rows.add(row);
        labels.add((float) ((_fr.vec(respIdx).at(row) - respSub) / respMul));
    }

    /** Weight of {@code row}, or 1 when the frame has no weights column ({@code weightIdx == -1}). */
    private double rowWeight(int weightIdx, int row) {
        return weightIdx == -1 ? 1.0d : _fr.vec(weightIdx).at(row);
    }

    /**
     * Whether random top-up sampling must continue. It stops once at least
     * {@code _useFraction * numRows} samples exist and their count is a multiple of the
     * mini-batch size.
     */
    private boolean needsMoreSamples(int collected, int numRows, int batchSize) {
        return !(collected >= _useFraction * numRows && collected % batchSize == 0);
    }

    /**
     * Shuffles both lists with the same permutation by reseeding {@code rng} before each shuffle.
     * This keeps sample i paired with label i only because both lists have the same size.
     */
    private static void shuffleAligned(ArrayList<?> data, ArrayList<Float> labels, long seed, Random rng) {
        assert data.size() == labels.size();
        rng.setSeed(seed);
        Collections.shuffle(labels, rng);
        rng.setSeed(seed);
        Collections.shuffle(data, rng);
    }

    /** Intentionally empty: all training happens in setupLocal(). */
    @Override
    public void map(Chunk[] chunks, NewChunk[] outputs) {
    }

    @Override
    protected void closeLocal() {
        _sharedmodel = null;
    }

    /**
     * Merges another task's model into this one. Models that processed no rows are ignored. If
     * this side processed nothing, it adopts the other model instead of merging.
     */
    @Override
    public void reduce(DeepWaterTask other) {
        final DeepWaterModelInfo theirs = other._localmodel;
        if (_localmodel == null || theirs == null || theirs.get_processed_local() <= 0 || theirs == _localmodel) {
            return;
        }
        if (_localmodel.get_processed_local() == 0) {
            _localmodel = theirs;
            _chunk_node_count = other._chunk_node_count;
        } else {
            _localmodel.add(theirs);
            _chunk_node_count += other._chunk_node_count;
        }
    }

    /**
     * Finalizes the iteration and publishes {@code _localmodel} as the shared model.
     *
     * <p>On multi-node clouds without replicated training data, it warns when some nodes did not
     * contribute. The warning is shown at most 3 times per JVM and at least 5 s apart. For
     * non-local runs it folds the local row count into the global count and divides the merged
     * model by the number of contributing nodes.
     */
    @Override
    public void postGlobal() {
        final DeepWaterParameters params = _localmodel.get_params();
        final int cloudSize = H2O.CLOUD.size();
        if (cloudSize > 1 && !params._replicate_training_data) {
            final long now = System.currentTimeMillis();
            if (_chunk_node_count < cloudSize && now - _lastWarn > 5000 && _warnCount < 3) {
                Log.warn((cloudSize - _chunk_node_count) + " node(s) (out of " + cloudSize
                        + ") are not contributing to model updates. Consider setting replicate_training_data to true or using a larger training dataset (or fewer H2O nodes).");
                _lastWarn = now;
                _warnCount++;
            }
        }
        assert (params._replicate_training_data && cloudSize > 1) == _run_local;
        if (_run_local) {
            _sharedmodel = _localmodel;
        } else {
            _localmodel.add_processed_global(_localmodel.get_processed_local());
            _localmodel.set_processed_local(0L);
            if (_chunk_node_count > 1) {
                _localmodel.div(_chunk_node_count);
            }
        }
        if (_sharedmodel == null) {
            _sharedmodel = _localmodel;
        }
        _localmodel = null;
    }
}
