/*
 * Decompiled with CFR 0.152.
 */
package com.github.chen0040.art.clustering;

import com.github.chen0040.art.core.ART1;
import com.github.chen0040.data.frame.DataFrame;
import com.github.chen0040.data.frame.DataRow;
import com.github.chen0040.data.utils.transforms.Standardization;

/**
 * Clusters {@link DataFrame} rows with an {@link ART1} network.
 *
 * <p>Every row is first standardized with a {@link Standardization} fitted on the training batch.
 * It is then binarized: a component becomes {@code 1.0} when its standardized value is strictly
 * positive, otherwise {@code 0.0}. The binary vector is presented to the network, whose result is
 * used as the cluster id. That id is also written back into the row as the categorical target
 * cell {@code "predicted"}.
 *
 * <p>Hyperparameters set through the setters take effect the next time {@link #fit} builds a new
 * network.
 */
public class ART1Clustering implements Cloneable {
    /** Network built by {@link #fit}; {@code null} until the model has been fitted. */
    private ART1 net;
    /** Node count passed to the {@link ART1} constructor in {@link #fit}. */
    private int initialNodeCount = 1;
    /** Whether {@link #transform} may let the network create new nodes. */
    private boolean allowNewNodeInPrediction = false;
    /** Standardization fitted in {@link #fit}; {@code null} until the model has been fitted. */
    private Standardization inputNormalization;
    /** Value passed to {@code ART1.setAlpha} in {@link #fit}. */
    private double alpha = 0.1;
    /** Value passed to {@code ART1.setRho} in {@link #fit}. */
    private double rho0 = 0.9;
    /** Value passed to {@code ART1.setBeta} in {@link #fit}. */
    private double beta = 0.3;

    /**
     * Returns a copy of this clusterer.
     *
     * <p>All hyperparameters are copied. The fitted network and standardization are duplicated
     * through their own {@code clone()} methods, so the copy does not share those references with
     * this instance.
     */
    @Override
    public Object clone() throws CloneNotSupportedException {
        ART1Clustering clone = (ART1Clustering) super.clone();
        clone.copy(this);
        return clone;
    }

    /**
     * Overwrites this instance's state with a copy of {@code rhs}'s state.
     *
     * @param rhs the clusterer to copy from
     * @throws CloneNotSupportedException if cloning the network or the standardization fails
     */
    public void copy(ART1Clustering rhs) throws CloneNotSupportedException {
        this.net = rhs.net == null ? null : (ART1) rhs.net.clone();
        this.inputNormalization = rhs.inputNormalization == null
                ? null
                : (Standardization) rhs.inputNormalization.clone();
        this.initialNodeCount = rhs.initialNodeCount;
        this.allowNewNodeInPrediction = rhs.allowNewNodeInPrediction;
        // These were previously not copied, so copy() left stale hyperparameters behind.
        this.alpha = rhs.alpha;
        this.rho0 = rhs.rho0;
        this.beta = rhs.beta;
    }

    /**
     * Assigns {@code tuple} to a cluster using the fitted model.
     *
     * <p>New nodes are created only if {@link #isAllowNewNodeInPrediction()} is {@code true}. The
     * row is mutated: its {@code "predicted"} categorical target cell is set to the cluster id.
     *
     * @param tuple the row to classify
     * @return the cluster id reported by the network
     * @throws IllegalStateException if {@link #fit} has not been called
     */
    public int transform(DataRow tuple) {
        return this.simulate(tuple, this.allowNewNodeInPrediction);
    }

    /**
     * Fits the standardization on {@code batch}, builds a fresh {@link ART1} network from the
     * current hyperparameters, and presents every row to it (allowing new nodes).
     *
     * <p>Each row is mutated: its {@code "predicted"} cell is set to its assigned cluster.
     *
     * @param batch the training data; must contain at least one row
     * @throws IllegalArgumentException if {@code batch} has no rows
     */
    public void fit(DataFrame batch) {
        int rowCount = batch.rowCount();
        if (rowCount == 0) {
            // The input dimension is taken from the first row, so an empty batch cannot be fitted.
            throw new IllegalArgumentException("Cannot fit ART1Clustering on an empty batch");
        }
        int dimension = batch.row(0).toArray().length;
        this.inputNormalization = new Standardization(batch);
        this.net = new ART1(dimension, this.initialNodeCount);
        this.net.setAlpha(this.alpha);
        this.net.setBeta(this.beta);
        this.net.setRho(this.rho0);
        for (int i = 0; i < rowCount; ++i) {
            this.simulate(batch.row(i), true);
        }
    }

    /**
     * Standardizes and binarizes {@code tuple}, then presents it to the network.
     *
     * <p>The resulting cluster id is written into the row's {@code "predicted"} categorical
     * target cell.
     *
     * @param tuple the row to present
     * @param canCreateNewNode forwarded to the network; whether it may create a new node
     * @return the cluster id reported by the network
     * @throws IllegalStateException if {@link #fit} has not been called
     */
    public int simulate(DataRow tuple, boolean canCreateNewNode) {
        if (this.net == null || this.inputNormalization == null) {
            throw new IllegalStateException(
                    "ART1Clustering must be fitted before simulate/transform is called");
        }
        double[] standardized = this.inputNormalization.standardize(tuple.toArray());
        double[] binary = binarize(standardized);
        int clusterId = this.net.simulate(binary, canCreateNewNode);
        // Integer.toString is locale-independent. String.format("%d") may emit
        // locale-specific (non-ASCII) digits.
        tuple.setCategoricalTargetCell("predicted", Integer.toString(clusterId));
        return clusterId;
    }

    /** Maps each component to {@code 1.0} if it is strictly positive, else {@code 0.0}. */
    private static double[] binarize(double[] x) {
        double[] y = new double[x.length];
        for (int i = 0; i < x.length; ++i) {
            y[i] = x[i] > 0.0 ? 1.0 : 0.0;
        }
        return y;
    }

    /** @return the node count used when {@link #fit} constructs the network */
    public int getInitialNodeCount() {
        return this.initialNodeCount;
    }

    /** @param initialNodeCount node count used when {@link #fit} constructs the network */
    public void setInitialNodeCount(int initialNodeCount) {
        this.initialNodeCount = initialNodeCount;
    }

    /** @return whether {@link #transform} may create new nodes */
    public boolean isAllowNewNodeInPrediction() {
        return this.allowNewNodeInPrediction;
    }

    /** @param allowNewNodeInPrediction whether {@link #transform} may create new nodes */
    public void setAllowNewNodeInPrediction(boolean allowNewNodeInPrediction) {
        this.allowNewNodeInPrediction = allowNewNodeInPrediction;
    }

    /** @return the value passed to {@code ART1.setAlpha} by {@link #fit} */
    public double getAlpha() {
        return this.alpha;
    }

    /** @param alpha value passed to {@code ART1.setAlpha} by the next {@link #fit} */
    public void setAlpha(double alpha) {
        this.alpha = alpha;
    }

    /** @return the value passed to {@code ART1.setRho} by {@link #fit} */
    public double getRho0() {
        return this.rho0;
    }

    /** @param rho0 value passed to {@code ART1.setRho} by the next {@link #fit} */
    public void setRho0(double rho0) {
        this.rho0 = rho0;
    }

    /** @return the value passed to {@code ART1.setBeta} by {@link #fit} */
    public double getBeta() {
        return this.beta;
    }

    /** @param beta value passed to {@code ART1.setBeta} by the next {@link #fit} */
    public void setBeta(double beta) {
        this.beta = beta;
    }
}

