package com.amazon.randomcutforest.examples.dynamicinference;

import com.amazon.randomcutforest.CommonUtils;
import com.amazon.randomcutforest.PredictiveRandomCutForest;
import com.amazon.randomcutforest.config.TransformMethod;
import com.amazon.randomcutforest.examples.Example;
import com.amazon.randomcutforest.returntypes.SampleSummary;
import com.amazon.randomcutforest.summarization.Summarizer;
import java.io.BufferedWriter;
import java.io.FileWriter;
import java.util.Arrays;
import java.util.Random;

/* loaded from: input_file:com/amazon/randomcutforest/examples/dynamicinference/ConditionalPredictive.class */
/**
 * Example showing how {@link PredictiveRandomCutForest} can impute (predict) fields of a record
 * that are missing, conditioned on the fields that are present.
 *
 * <p>Each synthetic record has five input dimensions plus a hidden label:
 * <ul>
 *   <li>dimensions 0-2 are known up front and produced by {@link #generateRecordKey(Random)};</li>
 *   <li>dimensions 3-4 are unknown when the forest is queried and are only revealed afterwards by
 *       {@link #fillInValues(float[], Random, NormalDistribution)};</li>
 *   <li>index 5 of the generated array is a label identifying which generating branch was taken;
 *       it is never shown to the forest and is only written to the output file.</li>
 * </ul>
 *
 * <p>For every record the forest first predicts dimensions 3-4, is then updated with the completed
 * record, and the L2 distance from the completed record to the closest predicted summary point is
 * written, together with the record and its label, to the file {@code predictive_example}.
 */
public class ConditionalPredictive implements Example {

    /** Number of dimensions fed to the forest (the label at index 5 is excluded). */
    private static final int INPUT_DIMENSIONS = 5;

    /** Forest sample size; the stream length and normalization warm-up are derived from it. */
    private static final int SAMPLE_SIZE = 256;

    /** Number of records streamed through the forest. */
    private static final int DATA_SIZE = 40 * SAMPLE_SIZE;

    /** Fixed seed so that the generated data (and therefore the output file) is reproducible. */
    private static final long SEED = 17L;

    /** Name of the file the per-record results are written to. */
    private static final String OUTPUT_FILE = "predictive_example";

    /**
     * Minimal Gaussian sampler based on the Box-Muller transform. Each transform consumes two
     * uniform draws and yields two independent standard normal samples; the second one is buffered
     * and returned by the following call.
     *
     * <p>Not thread-safe: it mutates an internal buffer and cursor on every call.
     */
    /* JADX INFO: Access modifiers changed from: package-private */
    public static class NormalDistribution {
        private final Random rng;
        /** Holds the pair of samples produced by the most recent Box-Muller transform. */
        private final double[] buffer = new double[2];
        /** Position of the next sample to hand out from {@link #buffer} (0 or 1). */
        private int index = 0;

        NormalDistribution(Random random) {
            this.rng = random;
        }

        /**
         * Returns the next standard normal (mean 0, standard deviation 1) sample.
         *
         * @return a sample from N(0, 1)
         */
        double nextDouble() {
            if (index == 0) {
                // Box-Muller requires u1 in (0, 1]. Random.nextDouble() returns values in [0, 1),
                // so 0.0 is possible and would make log(u1) = -Infinity (an infinite sample).
                // Redrawing only in that case keeps the draw order, and hence seeded output,
                // identical otherwise.
                double u1;
                do {
                    u1 = rng.nextDouble();
                } while (u1 == 0.0d);
                double u2 = rng.nextDouble();
                double radius = Math.sqrt(-2.0d * Math.log(u1));
                double angle = 2.0d * Math.PI * u2;
                buffer[0] = radius * Math.cos(angle);
                buffer[1] = radius * Math.sin(angle);
            }
            double sample = buffer[index];
            index = (index + 1) % 2;
            return sample;
        }

        /**
         * Returns a sample from N(mean, stddev^2).
         *
         * @param mean the mean of the distribution
         * @param stddev the standard deviation of the distribution
         * @return {@code mean + stddev * z} where {@code z} is a standard normal sample
         */
        double nextDouble(double mean, double stddev) {
            return mean + (stddev * nextDouble());
        }
    }

    public static void main(String[] strArr) throws Exception {
        new ConditionalPredictive().run();
    }

    @Override // com.amazon.randomcutforest.examples.Example
    public String command() {
        return "Conditional_predictive_example";
    }

    @Override // com.amazon.randomcutforest.examples.Example
    public String description() {
        return "An example that uses imputation for prediction";
    }

    /**
     * Streams {@value #DATA_SIZE} synthetic records through the forest. For each record it
     * predicts the missing dimensions 3 and 4, reveals and learns the true values, and writes one
     * line per record to {@value #OUTPUT_FILE}.
     *
     * <p>Each output line holds, space separated: the five completed input values, the distance to
     * the closest predicted summary point ({@code Double.MAX_VALUE} if the prediction had no summary
     * points), and the record's label.
     *
     * @throws Exception if writing the output file fails or the forest rejects an input
     */
    @Override // com.amazon.randomcutforest.examples.Example
    public void run() throws Exception {
        PredictiveRandomCutForest forest = new PredictiveRandomCutForest.Builder()
                .inputDimensions(INPUT_DIMENSIONS)
                .randomSeed(0L)
                .numberOfTrees(100)
                .shingleSize(1)
                .sampleSize(SAMPLE_SIZE)
                .startNormalization(SAMPLE_SIZE / 2)
                .transformMethod(TransformMethod.NORMALIZE)
                .build();

        System.out.println("seed = " + SEED);
        NormalDistribution normal = new NormalDistribution(new Random(SEED));
        Random random = new Random(SEED + 10);

        // try-with-resources guarantees the file is flushed and closed even if predict/update
        // throws part-way through the stream.
        try (BufferedWriter writer = new BufferedWriter(new FileWriter(OUTPUT_FILE))) {
            for (int n = 0; n < DATA_SIZE; n++) {
                float[] labeled = generateRecordKey(random);
                // Drop the label at index 5; dimensions 3 and 4 are still zero at this point.
                float[] record = Arrays.copyOf(labeled, INPUT_DIMENSIONS);
                CommonUtils.checkArgument(record[3] == 0.0f, " should not be filled");
                CommonUtils.checkArgument(record[4] == 0.0f, " should not be filled");

                // Query the forest for dimensions 3 and 4 before their true values are revealed.
                SampleSummary prediction = forest.predict(record, 0L, new int[] {3, 4});
                fillInValues(record, random, normal);
                forest.update(record, 0L);

                double distance = closestSummaryDistance(record, prediction);
                writer.write(record[0] + " " + record[1] + " " + record[2] + " " + record[3] + " "
                        + record[4] + " " + distance + " " + labeled[5] + "\n");
            }
        }
    }

    /**
     * Returns the smallest L2 distance between {@code record} and any summary point of
     * {@code prediction}, or {@link Double#MAX_VALUE} when there are no summary points.
     */
    private static double closestSummaryDistance(float[] record, SampleSummary prediction) {
        double best = Double.MAX_VALUE;
        for (int j = 0; j < prediction.summaryPoints.length; j++) {
            best = Math.min(best, Summarizer.L2distance(record, prediction.summaryPoints[j]).doubleValue());
        }
        return best;
    }

    /**
     * Generates the known part of a synthetic record.
     *
     * <p>The returned array has length 6. Indices 0-2 hold the known input values, indices 3-4 are
     * left at 0 to be filled later by {@link #fillInValues}, and index 5 is a label in {0, 1, 2, 3}
     * identifying which of the four generating branches was taken:
     * <ul>
     *   <li>p=0.8: [0]=1, then with p=0.8 [1]=19 (label 0) else [1]=25 (label 1); [2] uniform in
     *       [0, 10);</li>
     *   <li>p=0.2: [0]=0, then with p=0.3 [1]=16, [2]=12 (label 2) else [1]=20, [2]=4 (label 3).</li>
     * </ul>
     *
     * @param random source of uniform randomness; exactly three draws are consumed per call
     * @return a new length-6 array as described above
     */
    float[] generateRecordKey(Random random) {
        float[] record = new float[6];
        // All three draws are taken up front so every call consumes the same amount of randomness.
        double branch = random.nextDouble();
        double subBranch = random.nextDouble();
        double uniform = random.nextDouble();
        if (branch < 0.8d) {
            record[0] = 1.0f;
            if (subBranch < 0.8d) {
                record[1] = 19.0f;
                record[5] = 0.0f;
            } else {
                record[1] = 25.0f;
                record[5] = 1.0f;
            }
            record[2] = ((float) uniform) * 10.0f;
        } else {
            record[0] = 0.0f;
            if (subBranch < 0.3d) {
                record[1] = 16.0f;
                record[2] = 12.0f;
                record[5] = 2.0f;
            } else {
                record[1] = 20.0f;
                record[2] = 4.0f;
                record[5] = 3.0f;
            }
        }
        return record;
    }

    /**
     * Fills in dimensions 3 and 4 of {@code record} in place. The distribution they are drawn from
     * depends on the known dimensions 0-2, which makes the imputation task conditional.
     *
     * @param record a record produced by {@link #generateRecordKey(Random)}; indices 3 and 4 are
     *     overwritten
     * @param random uniform source used to choose between mixture components
     * @param normal Gaussian source used for the actual values
     */
    void fillInValues(float[] record, Random random, NormalDistribution normal) {
        if (record[0] < 0.5d) {
            // Dimension 3 is a 50/50 mixture of N(20, 5^2) and N(40, 5^2).
            record[3] = (float) (random.nextDouble() < 0.5d ? normal.nextDouble(20.0d, 5.0d) : normal.nextDouble(40.0d, 5.0d));
            record[4] = (float) normal.nextDouble(-30.0d, 3.0d);
            return;
        }
        if (record[1] < 20.0f) {
            record[3] = (float) normal.nextDouble(30.0d, 10.0d);
            record[4] = (float) normal.nextDouble(-10.0d, 3.0d);
        } else if (record[2] < 6.0f) {
            // Dimension 3 is a 30/70 mixture of N(20, 5^2) and N(40, 3^2).
            record[3] = (float) (random.nextDouble() < 0.3d ? normal.nextDouble(20.0d, 5.0d) : normal.nextDouble(40.0d, 3.0d));
            record[4] = (float) normal.nextDouble(-50.0d, 1.0d);
        } else {
            // The mixture choice for dimension 4 is drawn before either Gaussian sample is taken.
            double mixture = random.nextDouble();
            record[3] = (float) normal.nextDouble(30.0d, 1.0d);
            record[4] = (float) (mixture < 0.7d ? normal.nextDouble(-10.0d, 3.0d) : normal.nextDouble(-30.0d, 5.0d));
        }
    }
}
