package ai.djl.nn.transformer;

import ai.djl.MalformedModelException;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.AbstractBlock;
import ai.djl.nn.Block;
import ai.djl.nn.BlockList;
import ai.djl.nn.Parameter;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.function.Function;

/* loaded from: input_file:ai/djl/nn/transformer/TransformerBaseBlock.class */
/**
 * Base class for transformer blocks that track their own named child blocks and directly owned
 * parameters, and that serialize them behind a version number.
 *
 * <p>Children and parameters are kept in {@link LinkedHashMap}s, so they are saved and loaded in
 * registration order. {@link #saveParameters} and {@link #loadParameters} iterate these maps in
 * the same order, so a block must register the same children and parameters, in the same order,
 * when it is rebuilt for loading.
 */
public abstract class TransformerBaseBlock extends AbstractBlock {
    /** Version written by {@link #saveParameters} and required by {@link #loadParameters}. */
    protected int version;

    /** Child blocks keyed by name, in registration order. */
    protected LinkedHashMap<String, Block> children = new LinkedHashMap<>();

    /** Parameters owned directly by this block, keyed by {@link Parameter#getName()}. */
    protected LinkedHashMap<String, Parameter> parameters = new LinkedHashMap<>();

    /**
     * Shape callbacks keyed by parameter name. A value is {@code null} when the parameter was
     * registered without a shape initializer.
     */
    protected LinkedHashMap<String, Function<Shape[], Shape>> parameterShapeCallbacks =
            new LinkedHashMap<>();

    /**
     * Creates a block with the given serialization version.
     *
     * @param version the version written when saving and required when loading parameters
     */
    public TransformerBaseBlock(int version) {
        this.version = version;
    }

    /**
     * Returns the serialization version of this block.
     *
     * @return the version written by {@link #saveParameters}
     */
    public int getVersion() {
        return this.version;
    }

    /**
     * Registers a child block under the given name. A child already registered under the same
     * name is replaced.
     *
     * @param name the name of the child block
     * @param block the child block to register
     * @param <B> the concrete type of the child block
     * @return {@code block}, to allow assignment in a single expression
     */
    public <B extends Block> B addChildBlock(String name, B block) {
        this.children.put(name, block);
        return block;
    }

    /**
     * Registers a parameter without a shape initializer. Subclasses that use this overload must
     * override {@link #getParameterShape} to supply the shape.
     *
     * @param parameter the parameter to register
     * @param <P> the concrete parameter type
     * @return {@code parameter}
     */
    protected <P extends Parameter> P addParameter(P parameter) {
        // The typed null selects the Function overload; an untyped null would be ambiguous.
        return addParameter(parameter, (Function<Shape[], Shape>) null);
    }

    /**
     * Registers a parameter whose shape is fixed, regardless of the input shapes.
     *
     * @param parameter the parameter to register
     * @param shape the shape to report for this parameter
     * @param <P> the concrete parameter type
     * @return {@code parameter}
     */
    protected <P extends Parameter> P addParameter(P parameter, Shape shape) {
        return addParameter(parameter, inputShapes -> shape);
    }

    /**
     * Registers a parameter whose shape is computed from the block's input shapes. A parameter
     * with the same name replaces any earlier registration.
     *
     * @param parameter the parameter to register
     * @param shapeCallback maps input shapes to the parameter's shape; may be {@code null}
     * @param <P> the concrete parameter type
     * @return {@code parameter}
     */
    protected <P extends Parameter> P addParameter(
            P parameter, Function<Shape[], Shape> shapeCallback) {
        this.parameters.put(parameter.getName(), parameter);
        this.parameterShapeCallbacks.put(parameter.getName(), shapeCallback);
        return parameter;
    }

    /**
     * Resolves a parameter's shape through the callback registered with it.
     *
     * @param name the parameter name
     * @param inputShapes the input shapes of this block
     * @return the shape computed by the parameter's callback
     * @throws IllegalArgumentException if no parameter with this name is registered
     * @throws IllegalStateException if the parameter was registered without a shape callback
     */
    @Override // ai.djl.nn.Block
    public Shape getParameterShape(String name, Shape[] inputShapes) {
        Function<Shape[], Shape> shapeCallback = this.parameterShapeCallbacks.get(name);
        if (shapeCallback != null) {
            return shapeCallback.apply(inputShapes);
        }
        if (this.parameters.get(name) == null) {
            throw new IllegalArgumentException(
                    "No parameter named " + name + " found in this block.");
        }
        // The space before "found" was missing in the original message.
        throw new IllegalStateException(
                "No shape initializer for parameter "
                        + name
                        + " found. Either pass an initializer for the shape when adding the"
                        + " parameter or override getParameterShape in the subclass.");
    }

    /**
     * Returns the registered child blocks, in registration order.
     *
     * @return a new {@link BlockList} built from the registered children
     */
    @Override // ai.djl.nn.Block
    public BlockList getChildren() {
        return new BlockList(this.children);
    }

    /**
     * Initializes this block. The steps run in this order: {@code beforeInitialize}, then every
     * direct parameter, then {@link #initializeChildBlocks}.
     *
     * @param manager the manager used to allocate parameter arrays
     * @param dataType the data type for the parameters
     * @param inputShapes the input shapes of this block
     * @return the output shapes computed by {@code getOutputShapes}
     */
    @Override // ai.djl.nn.Block
    public Shape[] initialize(NDManager manager, DataType dataType, Shape... inputShapes) {
        beforeInitialize(inputShapes);
        for (Parameter parameter : getDirectParameters()) {
            parameter.initialize(manager, dataType, inputShapes);
        }
        initializeChildBlocks(manager, dataType, inputShapes);
        return getOutputShapes(manager, inputShapes);
    }

    /**
     * Initializes the child blocks. Called by {@link #initialize} after this block's own
     * parameters have been initialized.
     *
     * @param nDManager the manager used to allocate parameter arrays
     * @param dataType the data type for the parameters
     * @param shapeArr the input shapes of this block
     */
    public abstract void initializeChildBlocks(
            NDManager nDManager, DataType dataType, Shape... shapeArr);

    /**
     * Returns the parameters registered directly on this block (children's parameters are not
     * included).
     *
     * @return a new mutable list of the direct parameters, in registration order
     */
    @Override // ai.djl.nn.Block
    public List<Parameter> getDirectParameters() {
        return new ArrayList<>(this.parameters.values());
    }

    /**
     * Writes the version as a 4-byte int, then every direct parameter, then every child block's
     * parameters, each group in registration order.
     *
     * @param dataOutputStream the stream to write to
     * @throws IOException if writing fails
     */
    @Override // ai.djl.nn.Block
    public void saveParameters(DataOutputStream dataOutputStream) throws IOException {
        // writeInt (4 bytes) matches the readInt in loadParameters. The original used
        // write(int), which emits a single byte and broke the save/load round trip.
        dataOutputStream.writeInt(this.version);
        for (Parameter parameter : this.parameters.values()) {
            parameter.save(dataOutputStream);
        }
        for (Block child : this.children.values()) {
            child.saveParameters(dataOutputStream);
        }
    }

    /**
     * Reads data in the layout written by {@link #saveParameters}: a 4-byte version, then the
     * direct parameters, then the child blocks' parameters.
     *
     * @param nDManager the manager used to allocate the loaded arrays
     * @param dataInputStream the stream to read from
     * @throws IOException if reading fails
     * @throws MalformedModelException if the stored version differs from {@link #getVersion()}
     */
    @Override // ai.djl.nn.Block
    public void loadParameters(NDManager nDManager, DataInputStream dataInputStream)
            throws IOException, MalformedModelException {
        int storedVersion = dataInputStream.readInt();
        if (storedVersion != getVersion()) {
            throw new MalformedModelException(
                    "Cannot load parameters for "
                            + getClass().getCanonicalName()
                            + ", expected version "
                            + getVersion()
                            + ", got "
                            + storedVersion
                            + ".");
        }
        for (Parameter parameter : this.parameters.values()) {
            parameter.load(nDManager, dataInputStream);
        }
        for (Block child : this.children.values()) {
            child.loadParameters(nDManager, dataInputStream);
        }
    }
}
