package com.github.chen0040.art.clustering;

import com.github.chen0040.art.core.ART1;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.transforms.Standardization;

/**
 * Clusters the rows of a {@link DataFrame} with an {@link ART1} network.
 *
 * <p>Every row presented to the network goes through two steps first:
 * <ol>
 *   <li>It is passed through a {@link Standardization} fitted on the training frame.</li>
 *   <li>It is binarized: strictly positive values become {@code 1.0}, and everything else
 *       (including {@code NaN}) becomes {@code 0.0}.</li>
 * </ol>
 *
 * <p>The integer returned by the network is used as the cluster label. It is also written back
 * into the row as the categorical target cell {@code "predicted"}.
 *
 * <p>NOTE(review): instances are mutable and perform no synchronization; confine each instance to
 * a single thread or synchronize externally.
 */
public class ART1Clustering implements Cloneable {
    /** Name of the categorical target cell that receives the predicted cluster label. */
    private static final String PREDICTED_COLUMN = "predicted";

    private ART1 net;
    private Standardization inputNormalization;
    private int initialNodeCount = 1;
    private boolean allowNewNodeInPrediction = false;
    private double alpha = 0.1d;
    private double rho0 = 0.9d;
    private double beta = 0.3d;

    /**
     * Returns a deep copy of this clusterer.
     *
     * <p>The network and the fitted standardization are cloned rather than shared.
     *
     * @return an independent {@code ART1Clustering} with the same configuration and fitted state
     * @throws CloneNotSupportedException if a nested component cannot be cloned
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        ART1Clustering cloned = (ART1Clustering) super.clone();
        cloned.copy(this);
        return cloned;
    }

    /**
     * Makes this instance a deep copy of {@code that}.
     *
     * <p>All hyperparameters are copied, together with clones of the fitted network and
     * standardization (or {@code null} when {@code that} has not been fitted).
     *
     * @param that the instance to copy from
     * @throws CloneNotSupportedException if a nested component cannot be cloned
     */
    public void copy(ART1Clustering that) throws CloneNotSupportedException {
        this.net = that.net == null ? null : (ART1) that.net.clone();
        this.inputNormalization = that.inputNormalization == null
                ? null
                : (Standardization) that.inputNormalization.clone();
        this.initialNodeCount = that.initialNodeCount;
        this.allowNewNodeInPrediction = that.allowNewNodeInPrediction;
        // Previously omitted: without these, copy() silently kept the target's own hyperparameters.
        this.alpha = that.alpha;
        this.rho0 = that.rho0;
        this.beta = that.beta;
    }

    /**
     * Assigns {@code dataRow} to a cluster using the fitted network.
     *
     * <p>Whether a new cluster may be created is controlled by
     * {@link #setAllowNewNodeInPrediction(boolean)}. The label is also written to the row's
     * {@code "predicted"} categorical target cell.
     *
     * @param dataRow the row to classify; it is modified in place
     * @return the cluster index reported by the network
     * @throws IllegalStateException if {@link #fit(DataFrame)} has not been called
     */
    public int transform(DataRow dataRow) {
        return simulate(dataRow, this.allowNewNodeInPrediction);
    }

    /**
     * Builds and trains a fresh network on every row of {@code dataFrame}.
     *
     * <p>The input dimension is taken from the first row. The standardization is fitted on the
     * whole frame. Each row is then presented once, with new cluster nodes allowed, and is
     * labelled in place with its assigned cluster.
     *
     * @param dataFrame the training data; must contain at least one row
     * @throws IllegalArgumentException if {@code dataFrame} is {@code null} or has no rows
     */
    public void fit(DataFrame dataFrame) {
        if (dataFrame == null || dataFrame.rowCount() == 0) {
            throw new IllegalArgumentException("dataFrame must contain at least one row");
        }
        int inputDimension = dataFrame.row(0).toArray().length;
        this.inputNormalization = new Standardization(dataFrame);
        this.net = new ART1(inputDimension, this.initialNodeCount);
        this.net.setAlpha(this.alpha);
        this.net.setBeta(this.beta);
        this.net.setRho(this.rho0);
        int rowCount = dataFrame.rowCount();
        for (int i = 0; i < rowCount; i++) {
            // Training pass: the network may grow new nodes for unmatched inputs.
            simulate(dataFrame.row(i), true);
        }
    }

    /**
     * Standardizes and binarizes {@code dataRow}, then presents it to the network.
     *
     * <p>The resulting cluster index is stored in the row's {@code "predicted"} categorical target
     * cell.
     *
     * @param dataRow      the row to present; it is modified in place
     * @param allowNewNode passed through to {@code ART1.simulate}; {@code true} lets the network
     *                     create a new node
     * @return the cluster index reported by the network
     * @throws IllegalStateException if {@link #fit(DataFrame)} has not been called
     */
    public int simulate(DataRow dataRow, boolean allowNewNode) {
        if (this.net == null || this.inputNormalization == null) {
            throw new IllegalStateException("fit(DataFrame) must be called before simulate/transform");
        }
        double[] input = binarize(this.inputNormalization.standardize(dataRow.toArray()));
        int cluster = this.net.simulate(input, allowNewNode);
        // Integer.toString is locale-independent, unlike String.format("%d", ...), which can emit
        // localized digits under some default locales.
        dataRow.setCategoricalTargetCell(PREDICTED_COLUMN, Integer.toString(cluster));
        return cluster;
    }

    /**
     * Maps each value to {@code 1.0} if it is strictly positive, otherwise to {@code 0.0}.
     *
     * <p>NOTE(review): with a mean-centering standardization this presumably means "above the
     * training mean"; confirm against {@code Standardization.standardize}.
     */
    private static double[] binarize(double[] values) {
        double[] binary = new double[values.length];
        for (int i = 0; i < values.length; i++) {
            binary[i] = values[i] > 0.0d ? 1.0d : 0.0d;
        }
        return binary;
    }

    /** @return the number of nodes the network is created with on the next {@link #fit(DataFrame)} */
    public int getInitialNodeCount() {
        return initialNodeCount;
    }

    /** @param initialNodeCount node count used when the network is created on the next fit */
    public void setInitialNodeCount(int initialNodeCount) {
        this.initialNodeCount = initialNodeCount;
    }

    /** @return whether {@link #transform(DataRow)} may create new nodes */
    public boolean isAllowNewNodeInPrediction() {
        return allowNewNodeInPrediction;
    }

    /** @param allowNewNodeInPrediction whether {@link #transform(DataRow)} may create new nodes */
    public void setAllowNewNodeInPrediction(boolean allowNewNodeInPrediction) {
        this.allowNewNodeInPrediction = allowNewNodeInPrediction;
    }

    /** @return the value passed to {@code ART1.setAlpha} on the next fit */
    public double getAlpha() {
        return alpha;
    }

    /** @param alpha value passed to {@code ART1.setAlpha} on the next fit */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** @return the value passed to {@code ART1.setRho} on the next fit */
    public double getRho0() {
        return rho0;
    }

    /** @param rho0 value passed to {@code ART1.setRho} on the next fit */
    public void setRho0(double rho0) {
        this.rho0 = rho0;
    }

    /** @return the value passed to {@code ART1.setBeta} on the next fit */
    public double getBeta() {
        return beta;
    }

    /** @param beta value passed to {@code ART1.setBeta} on the next fit */
    public void setBeta(double beta) {
        this.beta = beta;
    }
}
