package ai.djl.nn.core;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.LayoutType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Parameter;
import ai.djl.nn.ParameterBlock;
import ai.djl.nn.ParameterType;
import ai.djl.training.ParameterStore;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;

/* loaded from: input_file:ai/djl/nn/core/Linear.class */
/**
 * A fully-connected (linear) block.
 *
 * <p>The block owns a {@code weight} parameter of shape {@code (outChannels, inputDimension)} and,
 * unless disabled through {@link Builder#optBias(boolean)}, a {@code bias} parameter of shape
 * {@code (outChannels)}. The forward pass delegates to the engine's {@code fullyConnected}
 * operator (presumably computing {@code x * weight^T + bias}; confirm against the engine docs).
 *
 * <p>When {@code flatten} is {@code true}, every non-batch axis of the input is folded into a
 * single input dimension. Otherwise only the last axis is consumed, and the leading axes are
 * carried through to the output shape.
 */
public class Linear extends ParameterBlock {

    /** Serialization version written by {@link #saveParameters(DataOutputStream)}. */
    private static final byte VERSION = 3;

    /** Oldest serialization version {@link #loadParameters} still understands. */
    private static final byte MIN_SUPPORTED_VERSION = 1;

    private long outChannels;
    private long inputDimension;
    private boolean flatten;
    /**
     * Input shape recorded by {@link #beforeInitialize(Shape[])} (or restored on load). It is used
     * for output-shape inference, {@link #describeInput()} and serialization.
     */
    private Shape inputShape;

    private final Parameter weight = new Parameter("weight", this, ParameterType.WEIGHT);
    /** {@code null} when the block was built with {@code optBias(false)}. */
    private final Parameter bias;

    /** Builder for {@link Linear}; obtain one through {@link Linear#builder()}. */
    public static final class Builder {
        private long outChannels;
        private boolean bias = true;
        private boolean flatten;

        Builder() {
        }

        /**
         * Sets the number of output channels. This setting is required.
         *
         * @param outChannels the number of output channels; must be positive
         * @return this builder
         */
        public Builder setOutChannels(long outChannels) {
            this.outChannels = outChannels;
            return this;
        }

        /**
         * Sets whether the block has a bias parameter. Defaults to {@code true}.
         *
         * @param bias {@code true} to create a bias parameter
         * @return this builder
         */
        public Builder optBias(boolean bias) {
            this.bias = bias;
            return this;
        }

        /**
         * Sets whether all non-batch input axes are folded into one input dimension. Defaults to
         * {@code false}, which means only the last axis is used.
         *
         * @param flatten {@code true} to flatten the non-batch axes
         * @return this builder
         */
        public Builder optFlatten(boolean flatten) {
            this.flatten = flatten;
            return this;
        }

        /**
         * Builds the {@link Linear} block.
         *
         * @return a new block configured from this builder
         * @throws IllegalArgumentException if {@code outChannels} was not set or is negative
         */
        public Linear build() {
            if (outChannels == 0) {
                throw new IllegalArgumentException("You must specify outChannels");
            }
            if (outChannels < 0) {
                // A negative size can never form a valid weight/bias shape; fail fast here.
                throw new IllegalArgumentException(
                        "outChannels must be positive, got: " + outChannels);
            }
            return new Linear(this);
        }
    }

    Linear(Builder builder) {
        outChannels = builder.outChannels;
        flatten = builder.flatten;
        bias = builder.bias ? new Parameter("bias", this, ParameterType.BIAS) : null;
    }

    /**
     * Runs the fully-connected operator on the single input array.
     *
     * @throws IllegalArgumentException if {@code inputs} does not contain exactly one array
     */
    @Override
    public NDList forward(
            ParameterStore parameterStore,
            NDList inputs,
            boolean training,
            PairList<String, Object> params) {
        NDList opInputs = opInputs(parameterStore, inputs);
        return opInputs.head()
                .getNDArrayInternal()
                .fullyConnected(opInputs, outChannels, flatten, bias == null, params);
    }

    /**
     * Returns {@code (batch, outChannels)} when flattening, where batch is taken from axis 0 of
     * the input. Otherwise returns the recorded leading input axes followed by
     * {@code outChannels}.
     */
    @Override
    public Shape[] getOutputShapes(NDManager manager, Shape[] inputShapes) {
        if (flatten) {
            return new Shape[] {new Shape(inputShapes[0].get(0), outChannels)};
        }
        return new Shape[] {inputShape.addAll(new Shape(outChannels))};
    }

    /** Returns the weight parameter, followed by the bias parameter when one exists. */
    @Override
    public List<Parameter> getDirectParameters() {
        return bias != null
                ? Arrays.asList(weight, bias)
                : Collections.singletonList(weight);
    }

    /** Describes the single input, named {@code linearInput}, with the recorded input shape. */
    @Override
    public PairList<String, Shape> describeInput() {
        return new PairList<>(
                Collections.singletonList("linearInput"), Collections.singletonList(inputShape));
    }

    /**
     * Records the input shape and derives {@code inputDimension} from it before the parameters
     * are initialized.
     */
    @Override
    public void beforeInitialize(Shape[] inputShapes) {
        this.inputShapes = inputShapes;
        Shape input = inputShapes[0];

        if (!flatten) {
            // Only the last axis is consumed; the leading axes pass through to the output.
            inputDimension = input.get(input.dimension() - 1);
            inputShape = input.slice(0, input.dimension() - 1);
            return;
        }

        Shape featureShape;
        if (input.isLayoutKnown()) {
            // Fold every non-batch axis. The batch axis is stored as -1 in the recorded shape,
            // presumably so that the saved shape does not pin a specific batch size.
            featureShape =
                    input.filterByLayoutType(layout -> !layout.equals(LayoutType.BATCH));
            inputShape =
                    input.map(
                            pair ->
                                    new Pair<>(
                                            pair.getValue().equals(LayoutType.BATCH)
                                                    ? Long.valueOf(-1L)
                                                    : pair.getKey(),
                                            pair.getValue()));
        } else if (input.dimension() > 1) {
            // Unknown layout: treat axis 0 as the batch axis.
            featureShape = input.slice(1);
            inputShape =
                    new Shape(new long[] {-1}, new LayoutType[] {LayoutType.BATCH})
                            .addAll(featureShape);
        } else {
            featureShape = input;
            inputShape = input;
        }
        inputDimension = featureShape.size();
    }

    /**
     * Returns the shape of the named parameter.
     *
     * @throws IllegalArgumentException if {@code name} is neither {@code weight} nor {@code bias}
     */
    @Override
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        switch (name) {
            case "weight":
                return new Shape(outChannels, inputDimension);
            case "bias":
                return new Shape(outChannels);
            default:
                throw new IllegalArgumentException("Invalid parameter name");
        }
    }

    /**
     * Writes the block in format {@link #VERSION}. The fields are written in this order: the
     * version byte, {@code outChannels}, {@code flatten}, {@code inputDimension}, the encoded
     * input shape, the weight and, if present, the bias.
     */
    @Override
    public void saveParameters(DataOutputStream os) throws IOException {
        os.writeByte(VERSION);
        os.writeLong(outChannels);
        os.writeBoolean(flatten);
        os.writeLong(inputDimension);
        os.write(inputShape.getEncoded());
        weight.save(os);
        if (bias != null) {
            bias.save(os);
        }
    }

    /**
     * Reads a block written in any version from {@link #MIN_SUPPORTED_VERSION} to
     * {@link #VERSION}. Versions below 3 do not store {@code outChannels}, so the value from the
     * builder is kept in that case.
     *
     * @throws MalformedModelException if the stored version is not supported
     */
    @Override
    public void loadParameters(NDManager manager, DataInputStream is)
            throws IOException, MalformedModelException {
        byte version = is.readByte();
        if (version < MIN_SUPPORTED_VERSION || version > VERSION) {
            throw new MalformedModelException("Unsupported encoding version: " + ((int) version));
        }
        if (version == VERSION) {
            outChannels = is.readLong();
            flatten = is.readBoolean();
            inputDimension = is.readLong();
        } else if (version == 2) {
            flatten = is.readBoolean();
            inputDimension = is.readLong();
        } else {
            // Version 1 stored the full feature shape instead of its flattened size.
            flatten = false;
            inputDimension = Shape.decode(is).size();
        }
        inputShape = Shape.decode(is);
        weight.load(manager, is);
        if (bias != null) {
            bias.load(manager, is);
        }
    }

    /**
     * Builds the operator input list: the single data array, then the weight and the optional
     * bias, fetched on the data array's device.
     */
    private NDList opInputs(ParameterStore parameterStore, NDList inputs) {
        if (inputs.size() != 1) {
            throw new IllegalArgumentException("Linear requires exactly 1 NDArray");
        }
        Device device = inputs.head().getDevice();
        NDList result = new NDList(inputs);
        result.add(parameterStore.getValue(weight, device));
        if (bias != null) {
            result.add(parameterStore.getValue(bias, device));
        }
        return result;
    }

    /**
     * Creates a builder for a {@link Linear} block.
     *
     * @return a new builder
     */
    public static Builder builder() {
        return new Builder();
    }
}
