package org.datavec.api.transform.transform.categorical;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.datavec.api.transform.ColumnType;
import org.datavec.api.transform.metadata.CategoricalMetaData;
import org.datavec.api.transform.metadata.ColumnMetaData;
import org.datavec.api.transform.schema.Schema;
import org.datavec.api.transform.transform.BaseTransform;
import org.datavec.api.writable.DoubleWritable;
import org.datavec.api.writable.FloatWritable;
import org.datavec.api.writable.IntWritable;
import org.datavec.api.writable.LongWritable;
import org.datavec.api.writable.NDArrayWritable;
import org.datavec.api.writable.NullWritable;
import org.datavec.api.writable.Text;
import org.datavec.api.writable.Writable;

/* loaded from: input_file:org/datavec/api/transform/transform/categorical/PivotTransform.class */
/**
 * Pivots a (categorical key column, value column) pair into one column per key state.
 *
 * <p>For every state {@code s} of the categorical {@code keyColumn}, the output schema contains a
 * column named {@code keyColumn + "[" + s + "]"} whose metadata is a copy of the value column's
 * metadata. In each output record, the pivoted column matching the record's key state holds the
 * record's value, and every other pivoted column holds the default value. The value column itself is
 * removed. All other columns pass through unchanged and in their original order.
 *
 * <p>If no default value is supplied, {@link #transform(Schema)} infers one from the value column's
 * type the first time it runs. That call mutates this instance and therefore affects
 * {@link #equals(Object)} and {@link #hashCode()}.
 */
public class PivotTransform extends BaseTransform {
    private final String keyColumn;
    private final String valueColumn;
    private Writable defaultValue;

    /**
     * Creates a pivot transform whose default value will be inferred from the value column's type.
     *
     * @param keyColumn   name of the categorical key column whose states become new columns
     * @param valueColumn name of the column whose values are spread across the pivoted columns
     */
    public PivotTransform(String keyColumn, String valueColumn) {
        this(keyColumn, valueColumn, null);
    }

    /**
     * Creates a pivot transform.
     *
     * @param keyColumn    name of the categorical key column whose states become new columns
     * @param valueColumn  name of the column whose values are spread across the pivoted columns
     * @param defaultValue value written into pivoted columns that do not match a record's key state;
     *                     {@code null} means "infer from the value column's type in
     *                     {@link #transform(Schema)}"
     */
    public PivotTransform(String keyColumn, String valueColumn, Writable defaultValue) {
        this.keyColumn = keyColumn;
        this.valueColumn = valueColumn;
        this.defaultValue = defaultValue;
    }

    /**
     * Computes the output schema: the key column is replaced by one {@code key[state]} column per
     * state (each a clone of the value column's metadata), and the value column is dropped.
     *
     * @throws UnsupportedOperationException if the key or value column is missing, the key column is
     *                                       not categorical, or no default value can be inferred
     */
    @Override // org.datavec.api.transform.ColumnOp
    public Schema transform(Schema schema) {
        if (!schema.hasColumn(this.keyColumn) || !schema.hasColumn(this.valueColumn)) {
            throw new UnsupportedOperationException("Key or value column not found: " + this.keyColumn + ", " + this.valueColumn + " in " + schema.getColumnNames());
        }
        int keyIdx = schema.getIndexOfColumn(this.keyColumn);
        int valueIdx = schema.getIndexOfColumn(this.valueColumn);

        ColumnMetaData keyMeta = schema.getMetaData(keyIdx);
        if (!(keyMeta instanceof CategoricalMetaData)) {
            throw new UnsupportedOperationException("Key column \"" + this.keyColumn + "\" must be categorical, but has type " + keyMeta.getColumnType());
        }
        List<String> stateNames = ((CategoricalMetaData) keyMeta).getStateNames();
        ColumnMetaData valueMeta = schema.getMetaData(valueIdx);

        List<String> columnNames = schema.getColumnNames();
        List<ColumnMetaData> columnMetaData = schema.getColumnMetaData();
        List<ColumnMetaData> newMeta = new ArrayList<>(schema.numColumns() + stateNames.size());
        // Exactly one index step per input column; the value column is skipped, not double-counted.
        for (int i = 0; i < columnNames.size(); i++) {
            if (i == keyIdx) {
                String keyName = columnNames.get(i);
                for (String state : stateNames) {
                    ColumnMetaData pivoted = valueMeta.mo40clone();
                    pivoted.setName(keyName + "[" + state + "]");
                    newMeta.add(pivoted);
                }
            } else if (i != valueIdx) {
                newMeta.add(columnMetaData.get(i));
            }
        }

        if (this.defaultValue == null) {
            this.defaultValue = inferDefaultValue(valueMeta.getColumnType());
        }
        return schema.newSchema(newMeta);
    }

    /**
     * Chooses the default value used for a value column of the given type when none was supplied.
     *
     * @throws UnsupportedOperationException for {@code Bytes} and any type not listed below
     */
    private static Writable inferDefaultValue(ColumnType type) {
        switch (type) {
            case String:
                return new Text("");
            case Integer:
                return new IntWritable(0);
            case Long:
                return new LongWritable(0L);
            case Double:
                return new DoubleWritable(0.0d);
            case Float:
                return new FloatWritable(0.0f);
            case Categorical:
                return new NullWritable();
            case Time:
                return new LongWritable(0L);
            case Bytes:
                throw new UnsupportedOperationException("Cannot infer default value for bytes");
            case Boolean:
                return new Text("false");
            default:
                throw new UnsupportedOperationException("Cannot infer default value for " + type);
        }
    }

    /** Not supported: this transform produces one output column per key state. */
    @Override // org.datavec.api.transform.ColumnOp
    public String outputColumnName() {
        throw new UnsupportedOperationException("Output column name will be more than 1");
    }

    /**
     * Returns the names of the pivoted columns, in state order. These match the
     * {@code keyColumn[state]} names created by {@link #transform(Schema)}.
     */
    @Override // org.datavec.api.transform.ColumnOp
    public String[] outputColumnNames() {
        List<String> stateNames = keyStateNames();
        String[] names = new String[stateNames.size()];
        for (int i = 0; i < names.length; i++) {
            names[i] = this.keyColumn + "[" + stateNames.get(i) + "]";
        }
        return names;
    }

    /** Returns the two input columns consumed by this transform: key, then value. */
    @Override // org.datavec.api.transform.ColumnOp
    public String[] columnNames() {
        return new String[]{this.keyColumn, this.valueColumn};
    }

    /** Not supported: this transform consumes two input columns. */
    @Override // org.datavec.api.transform.ColumnOp
    public String columnName() {
        throw new UnsupportedOperationException("Multiple input columns");
    }

    /**
     * Pivots one record. The key column is replaced by one writable per key state (the record's value
     * at the matching state, the default value elsewhere); the value column is dropped; all other
     * writables pass through in order.
     *
     * @throws IllegalStateException if {@code list} does not have one writable per input column
     * @throws RuntimeException      if the record's key is not a known state
     */
    @Override // org.datavec.api.transform.Transform
    public List<Writable> map(List<Writable> list) {
        if (list.size() != this.inputSchema.numColumns()) {
            throw new IllegalStateException("Cannot execute transform: input writables list length (" + list.size() + ") does not match expected number of elements (schema: " + this.inputSchema.numColumns() + "). Transform = " + toString());
        }
        int keyIdx = this.inputSchema.getIndexOfColumn(this.keyColumn);
        int valueIdx = this.inputSchema.getIndexOfColumn(this.valueColumn);
        List<String> stateNames = keyStateNames();

        List<Writable> out = new ArrayList<>(list.size() + stateNames.size());
        int i = 0;
        // Exactly one index step per input writable; the value column is skipped, not double-counted.
        for (Writable w : list) {
            if (i == keyIdx) {
                appendPivoted(out, stateNames, w, list.get(valueIdx));
            } else if (i != valueIdx) {
                out.add(w);
            }
            i++;
        }
        return out;
    }

    /**
     * Pivots a single {@code [key, value]} pair into one writable per key state.
     *
     * @param obj a {@code List} whose element 0 is the key writable and element 1 the value writable
     * @return a {@code List<Writable>} with the value at the key's state and the default elsewhere
     * @throws RuntimeException if the key is not a known state (consistent with {@link #map(List)})
     */
    @Override // org.datavec.api.transform.Transform
    public Object map(Object obj) {
        List<?> pair = (List<?>) obj;
        Writable key = (Writable) pair.get(0);
        Writable value = (Writable) pair.get(1);
        List<String> stateNames = keyStateNames();
        List<Writable> out = new ArrayList<>(stateNames.size());
        appendPivoted(out, stateNames, key, value);
        return out;
    }

    /** Applies {@link #map(Object)} to every {@code [key, value]} step of a sequence. */
    @Override // org.datavec.api.transform.Transform
    public Object mapSequence(Object obj) {
        List<?> steps = (List<?>) obj;
        List<Object> out = new ArrayList<>(steps.size());
        for (Object step : steps) {
            out.add(map(step));
        }
        return out;
    }

    /** Returns the state names of the key column in the input schema. */
    private List<String> keyStateNames() {
        return ((CategoricalMetaData) this.inputSchema.getMetaData(this.inputSchema.getIndexOfColumn(this.keyColumn))).getStateNames();
    }

    /**
     * Appends one writable per state to {@code out}: {@code value} at the position of {@code key}'s
     * state and the default value everywhere else.
     *
     * @throws RuntimeException if {@code key.toString()} is not one of {@code stateNames}
     */
    private void appendPivoted(List<Writable> out, List<String> stateNames, Writable key, Writable value) {
        String state = key.toString();
        int stateIdx = stateNames.indexOf(state);
        if (stateIdx < 0) {
            throw new RuntimeException("Unknown state (index not found): " + state);
        }
        for (int j = 0; j < stateNames.size(); j++) {
            out.add(j == stateIdx ? value : this.defaultValue);
        }
    }

    /** Returns the name of the categorical key column. */
    public String getKeyColumn() {
        return this.keyColumn;
    }

    /** Returns the name of the value column. */
    public String getValueColumn() {
        return this.valueColumn;
    }

    /** Returns the default value; {@code null} until supplied or inferred by {@link #transform(Schema)}. */
    public Writable getDefaultValue() {
        return this.defaultValue;
    }

    /** Sets the value written into pivoted columns that do not match a record's key state. */
    public void setDefaultValue(Writable writable) {
        this.defaultValue = writable;
    }

    @Override
    public boolean equals(Object obj) {
        if (obj == this) {
            return true;
        }
        if (!(obj instanceof PivotTransform)) {
            return false;
        }
        PivotTransform other = (PivotTransform) obj;
        if (!other.canEqual(this)) {
            return false;
        }
        String keyColumn = getKeyColumn();
        String otherKeyColumn = other.getKeyColumn();
        if (keyColumn == null ? otherKeyColumn != null : !keyColumn.equals(otherKeyColumn)) {
            return false;
        }
        String valueColumn = getValueColumn();
        String otherValueColumn = other.getValueColumn();
        if (valueColumn == null ? otherValueColumn != null : !valueColumn.equals(otherValueColumn)) {
            return false;
        }
        Writable defaultValue = getDefaultValue();
        Writable otherDefaultValue = other.getDefaultValue();
        return defaultValue == null ? otherDefaultValue == null : defaultValue.equals(otherDefaultValue);
    }

    /** Lets subclasses opt out of equality with plain {@code PivotTransform} instances. */
    protected boolean canEqual(Object obj) {
        return obj instanceof PivotTransform;
    }

    @Override
    public int hashCode() {
        String keyColumn = getKeyColumn();
        int result = (1 * 59) + (keyColumn == null ? 43 : keyColumn.hashCode());
        String valueColumn = getValueColumn();
        result = (result * 59) + (valueColumn == null ? 43 : valueColumn.hashCode());
        Writable defaultValue = getDefaultValue();
        return (result * 59) + (defaultValue == null ? 43 : defaultValue.hashCode());
    }

    @Override // org.datavec.api.transform.transform.BaseTransform
    public String toString() {
        return "PivotTransform(keyColumn=" + getKeyColumn() + ", valueColumn=" + getValueColumn() + ", defaultValue=" + getDefaultValue() + ")";
    }
}
