package com.gengoai.apollo.math.linalg;

import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonProperty;
import com.fasterxml.jackson.annotation.JsonValue;
import com.gengoai.Copyable;
import com.gengoai.Validation;
import java.io.Serializable;
import java.util.Arrays;
import java.util.Objects;
import java.util.stream.Collectors;
import java.util.stream.IntStream;

/* loaded from: input_file:com/gengoai/apollo/math/linalg/Shape.class */
/**
 * Dimensions of an array with at most four axes, held in a fixed four-slot array ordered
 * {@code [kernels, channels, rows, columns]}.
 *
 * <p>Shapes built through the public constructor always have at least one row and one column; kernel
 * and channel slots that were not supplied stay {@code 0}. Only {@link #empty()} yields a shape whose
 * slots are all {@code 0}.
 *
 * <p>Instances are mutable: {@link #reshape(int...)} rewrites the dimensions in place, which also
 * changes the value returned by {@link #hashCode()}.
 */
public class Shape implements Serializable, Copyable<Shape> {
    private static final long serialVersionUID = 1;

    /** Number of dimension slots every shape carries. */
    private static final int MAX_ORDER = 4;
    /** Slot positions inside {@link #shape}. */
    private static final int KERNEL = 0;
    private static final int CHANNEL = 1;
    private static final int ROW = 2;
    private static final int COLUMN = 3;

    /** {@code rows * columns}; {@code 0} only for the shape returned by {@link #empty()}. */
    public final int matrixLength;
    /** {@code max(1, kernels) * max(1, channels)}; {@code 0} only for the shape returned by {@link #empty()}. */
    public final int sliceLength;

    /** Backing dimensions, always of length four; serialized as the JSON value of this object. */
    @JsonValue
    final int[] shape;

    /**
     * Creates a shape whose four slots, {@link #matrixLength} and {@link #sliceLength} are all zero.
     *
     * @return a new empty shape
     */
    public static Shape empty() {
        return new Shape();
    }

    /**
     * Static factory equivalent to {@code new Shape(iArr)}.
     *
     * @param iArr up to four dimensions, right-aligned onto {@code [kernels, channels, rows, columns]}
     * @return the new shape
     */
    public static Shape shape(int... iArr) {
        return new Shape(iArr);
    }

    /** Builds the all-zero shape used by {@link #empty()}. */
    private Shape() {
        this.shape = new int[MAX_ORDER];
        this.matrixLength = 0;
        this.sliceLength = 0;
    }

    /**
     * Builds a shape from up to four dimensions. The values are right-aligned, so one value is the
     * column count, two values are {@code rows, columns}, and so on. A {@code null} or empty argument
     * produces a 1x1 shape. Row and column counts of {@code 0} are raised to {@code 1}.
     *
     * <p>More than four values, or any negative value, is rejected through
     * {@code Validation.checkArgument} (presumably an {@code IllegalArgumentException} — confirm
     * against that helper).
     *
     * @param iArr the dimensions, at most four non-negative values
     */
    @JsonCreator
    public Shape(@JsonProperty int... iArr) {
        this.shape = new int[MAX_ORDER];
        if (iArr != null && iArr.length > 0) {
            // Reject early: arraycopy would otherwise fail with a negative destination offset.
            Validation.checkArgument(iArr.length <= MAX_ORDER, "Invalid number of dimensions: " + iArr.length);
            System.arraycopy(iArr, 0, this.shape, MAX_ORDER - iArr.length, iArr.length);
        }
        // Validate the raw values before clamping; clamping first would hide negative rows/columns.
        Validation.checkArgument(this.shape[KERNEL] >= 0, "Invalid Kernel: " + this.shape[KERNEL]);
        Validation.checkArgument(this.shape[CHANNEL] >= 0, "Invalid Channel: " + this.shape[CHANNEL]);
        Validation.checkArgument(this.shape[ROW] >= 0, "Invalid Row: " + this.shape[ROW]);
        Validation.checkArgument(this.shape[COLUMN] >= 0, "Invalid Column: " + this.shape[COLUMN]);
        this.shape[ROW] = Math.max(1, this.shape[ROW]);
        this.shape[COLUMN] = Math.max(1, this.shape[COLUMN]);
        this.sliceLength = Math.max(1, this.shape[KERNEL]) * Math.max(1, this.shape[CHANNEL]);
        this.matrixLength = this.shape[ROW] * this.shape[COLUMN];
    }

    /**
     * @return the channel dimension (slot 1); {@code 0} when not set
     */
    public int channels() {
        return this.shape[CHANNEL];
    }

    /**
     * @return the column dimension (slot 3)
     */
    public int columns() {
        return this.shape[COLUMN];
    }

    /**
     * Creates an independent shape with the same dimensions. An empty shape copies to an empty shape.
     *
     * <p>NOTE(review): the decompiler renamed this method to {@code m4copy}; it presumably implements
     * {@code Copyable#copy()} — confirm against that interface and add {@code @Override} if so.
     *
     * @return a new shape equal to this one
     */
    public Shape copy() {
        if (this.shape[ROW] == 0) {
            // Only the empty shape has zero rows; the public constructor would turn it into 1x1.
            return new Shape();
        }
        return new Shape(this.shape);
    }

    /**
     * Decompiler-generated alias of {@link #copy()}, kept so existing callers keep working.
     *
     * @return a new shape equal to this one
     */
    public Shape m4copy() {
        return copy();
    }

    /**
     * Two shapes are equal when they are of the same class and all four slots match.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        return Arrays.equals(this.shape, ((Shape) obj).shape);
    }

    @Override
    public int hashCode() {
        return Arrays.hashCode(this.shape);
    }

    /** @return {@code true} when neither the kernel nor the channel slot is set */
    private boolean hasNoSlices() {
        return this.shape[KERNEL] == 0 && this.shape[CHANNEL] == 0;
    }

    /**
     * @return {@code true} for a shape without kernels/channels that has more than one row and exactly
     *     one column
     */
    public boolean isColumnVector() {
        return hasNoSlices() && this.shape[ROW] > 1 && this.shape[COLUMN] == 1;
    }

    /**
     * @return {@code true} for a shape without kernels/channels that has exactly one row and more than
     *     one column
     */
    public boolean isRowVector() {
        return hasNoSlices() && this.shape[ROW] == 1 && this.shape[COLUMN] > 1;
    }

    /**
     * @return {@code true} for a shape without kernels/channels that is 1x1
     */
    public boolean isScalar() {
        return hasNoSlices() && this.shape[ROW] == 1 && this.shape[COLUMN] == 1;
    }

    /**
     * @return {@code true} for a shape without kernels/channels whose row and column counts are equal
     */
    public boolean isSquare() {
        return hasNoSlices() && this.shape[ROW] == this.shape[COLUMN];
    }

    /**
     * @return {@code true} when the kernel or channel slot is greater than zero
     */
    public boolean isTensor() {
        return this.shape[KERNEL] > 0 || this.shape[CHANNEL] > 0;
    }

    /**
     * @return {@code true} when this shape is a row vector or a column vector
     */
    public boolean isVector() {
        // Rows and columns are never 0 on a constructed shape, so testing "> 0" on each side
        // can never distinguish them; a vector is exactly one of the two vector layouts.
        return isRowVector() || isColumnVector();
    }

    /**
     * Reads a raw slot of the backing array.
     *
     * @param i slot position, {@code 0} (kernels) through {@code 3} (columns)
     * @return the value stored in that slot
     * @throws ArrayIndexOutOfBoundsException if {@code i} is outside {@code 0..3}
     */
    public int get(int i) {
        return this.shape[i];
    }

    /**
     * @return the kernel dimension (slot 0); {@code 0} when not set
     */
    public int kernels() {
        return this.shape[KERNEL];
    }

    /**
     * Flattens a row/column pair in column-major order: {@code row + rows * column}.
     *
     * @param i the row
     * @param i2 the column
     * @return the flat position inside one matrix
     */
    public int matrixIndex(int i, int i2) {
        return i + (this.shape[ROW] * i2);
    }

    /**
     * @return how many of the four slots hold a value of at least one
     */
    public int order() {
        int populated = 0;
        for (int dimension : this.shape) {
            if (dimension >= 1) {
                populated++;
            }
        }
        return populated;
    }

    /**
     * Replaces the dimensions of this shape in place. The new dimensions must give the same
     * {@link #sliceLength} and {@link #matrixLength} as the current ones.
     *
     * @param iArr the new dimensions, interpreted as in {@link #Shape(int...)}
     * @throws IllegalArgumentException if either length differs from the current one
     */
    public void reshape(int... iArr) {
        Shape target = new Shape(iArr);
        if (this.sliceLength != target.sliceLength) {
            throw new IllegalArgumentException("Invalid slice length: " + this.sliceLength + " != " + target.sliceLength);
        }
        if (this.matrixLength != target.matrixLength) {
            throw new IllegalArgumentException("Invalid matrix length: " + this.matrixLength + " != " + target.matrixLength);
        }
        System.arraycopy(target.shape, 0, this.shape, 0, this.shape.length);
    }

    /**
     * @return the row dimension (slot 2)
     */
    public int rows() {
        return this.shape[ROW];
    }

    /**
     * Kernel count used for slice arithmetic. An unset kernel slot counts as one, matching how
     * {@link #sliceLength} is computed, so slice math never multiplies or divides by zero.
     */
    private int kernelStride() {
        return Math.max(1, this.shape[KERNEL]);
    }

    /**
     * Flattens a kernel/channel pair: {@code kernel + max(1, kernels) * channel}.
     *
     * @param i the kernel
     * @param i2 the channel
     * @return the flat slice position
     */
    public int sliceIndex(int i, int i2) {
        return i + (kernelStride() * i2);
    }

    /**
     * Inverse of {@link #sliceIndex(int, int)} for the channel component.
     *
     * @param i a flat slice position
     * @return the channel that position belongs to
     */
    public int toChannel(int i) {
        return i / kernelStride();
    }

    /**
     * Inverse of {@link #matrixIndex(int, int)} for the column component.
     *
     * @param i a flat matrix position
     * @return the column that position belongs to
     */
    public int toColumn(int i) {
        return i / this.shape[ROW];
    }

    /**
     * Inverse of {@link #sliceIndex(int, int)} for the kernel component.
     *
     * @param i a flat slice position
     * @return the kernel that position belongs to
     */
    public int toKernel(int i) {
        return i % kernelStride();
    }

    /**
     * @param j a flat position
     * @return {@code j / sliceLength}, narrowed to {@code int}
     */
    public int toMatrixIndex(long j) {
        return (int) (j / this.sliceLength);
    }

    /**
     * Inverse of {@link #matrixIndex(int, int)} for the row component.
     *
     * @param i a flat matrix position
     * @return the row that position belongs to
     */
    public int toRow(int i) {
        return i % this.shape[ROW];
    }

    /**
     * @param j a flat position
     * @return {@code j % sliceLength}, narrowed to {@code int}
     */
    public int toSliceIndex(long j) {
        return (int) (j % this.sliceLength);
    }

    /**
     * Renders the slots that are greater than zero as {@code (d1, d2, ...)}.
     */
    @Override
    public String toString() {
        return IntStream.of(this.shape)
                .filter(dimension -> dimension > 0)
                .mapToObj(Integer::toString)
                .collect(Collectors.joining(", ", "(", ")"));
    }
}
