package ai.djl.mxnet.engine;

import ai.djl.Device;
import ai.djl.MalformedModelException;
import ai.djl.Model;
import ai.djl.inference.Predictor;
import ai.djl.mxnet.jna.JnaUtils;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.ndarray.NDManager;
import ai.djl.ndarray.types.DataType;
import ai.djl.ndarray.types.Shape;
import ai.djl.nn.Block;
import ai.djl.nn.Parameter;
import ai.djl.training.Trainer;
import ai.djl.training.TrainingConfig;
import ai.djl.translate.Translator;
import ai.djl.util.Pair;
import ai.djl.util.PairList;
import ai.djl.util.Utils;
import java.io.DataInputStream;
import java.io.DataOutputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.net.URL;
import java.nio.charset.StandardCharsets;
import java.nio.file.FileVisitOption;
import java.nio.file.Files;
import java.nio.file.LinkOption;
import java.nio.file.OpenOption;
import java.nio.file.Path;
import java.nio.file.attribute.FileAttribute;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/* loaded from: input_file:ai/djl/mxnet/engine/MxModel.class */
/**
 * MXNet-engine implementation of {@link Model}.
 *
 * <p>An {@code MxModel} holds the model's {@link Block}, its {@link DataType}, a map of string
 * properties and a cache of loaded artifacts. Arrays it creates are attached to an
 * {@link MxNDManager} obtained as a sub-manager of the system manager; {@link #close()} releases
 * that manager.
 *
 * <p>Parameter files written by {@link #save(Path, String)} use a DJL-specific layout: the magic
 * header {@code "DJL@"}, a format version, the model name, the data type name, the input
 * descriptions, the properties, and finally the block parameters. {@link #load(Path, String, Map)}
 * reads that layout and falls back to MXNet's native parameter format when the magic header is
 * absent.
 */
public class MxModel implements Model {
    private static final Logger logger = LoggerFactory.getLogger(MxModel.class);

    /** Format version written after the magic header; files with another version are rejected. */
    private static final int MODEL_VERSION = 1;

    /** Magic header identifying parameter files in the DJL layout. */
    private static final String MAGIC_HEADER = "DJL@";

    private Path modelDir;
    private String modelName;
    private final MxNDManager manager;
    private Block block;
    private DataType dataType = DataType.FLOAT32;
    private final Map<String, String> properties = new ConcurrentHashMap<>();
    private PairList<String, Shape> inputData;
    private final Map<String, Object> artifacts = new ConcurrentHashMap<>();
    /** Consumed by the first {@link #newPredictor} call made while thread-safe mode is off. */
    private final AtomicBoolean first = new AtomicBoolean(true);

    /**
     * Creates an empty model bound to {@code device}.
     *
     * <p>This constructor was package-private in the original source; the decompiler widened it
     * to {@code public}, which is kept so existing callers still compile.
     *
     * @param device the device for the model's arrays; it is passed through
     *     {@link Device#defaultIfNull(Device)}, presumably substituting a default for {@code null}
     */
    public MxModel(Device device) {
        manager = MxNDManager.getSystemManager().mo6newSubManager(Device.defaultIfNull(device));
    }

    /**
     * Loads the model structure (if not already set) and its parameters.
     *
     * <p>When no block has been set, {@code <modelName>-symbol.json} inside {@code path} is
     * loaded into an {@link MxSymbolBlock}.
     *
     * @param path directory containing the model files, or a parameter file itself (in which case
     *     a block must already be set, since no symbol file can be resolved under a file)
     * @param str the model name, used as the file-name prefix
     * @param map optional load options; the {@code "epoch"} entry selects which
     *     {@code <modelName>-NNNN.params} file to read. May be {@code null}.
     * @throws FileNotFoundException if no block is set and the symbol file does not exist
     * @throws IOException if the parameter file cannot be located or read
     * @throws MalformedModelException if the parameters do not match the block
     */
    public void load(Path path, String str, Map<String, String> map)
            throws IOException, MalformedModelException {
        modelDir = path.toAbsolutePath();
        modelName = str;
        if (block == null) {
            Path symbolFile = modelDir.resolve(str + "-symbol.json");
            if (Files.notExists(symbolFile)) {
                throw new FileNotFoundException(
                        "Symbol file not found in: " + path + ", please set block manually.");
            }
            block =
                    new MxSymbolBlock(
                            manager, Symbol.load(manager, symbolFile.toAbsolutePath().toString()));
        }
        loadParameters(str, map);
    }

    /**
     * Saves the block parameters, input descriptions and properties to
     * {@code <path>/<str>-NNNN.params} in the DJL layout.
     *
     * <p>The epoch number {@code NNNN} comes from the {@code "Epoch"} property when set; otherwise
     * it is one past {@link Utils#getCurrentEpoch(Path, String)} for {@code path}.
     *
     * @param path target directory; created if it does not exist
     * @param str the model name, used as the file-name prefix
     * @throws IllegalStateException if the block is missing or not initialized
     * @throws IOException if the file cannot be written
     */
    public void save(Path path, String str) throws IOException {
        // Validate first so a failed save does not leave an empty directory behind.
        if (block == null || !block.isInitialized()) {
            throw new IllegalStateException("Model has not been trained or loaded yet.");
        }
        if (Files.notExists(path)) {
            Files.createDirectories(path);
        }
        String epochProperty = getProperty("Epoch");
        int epoch =
                epochProperty == null
                        ? Utils.getCurrentEpoch(path, str) + 1
                        : Integer.parseInt(epochProperty);
        Path paramFile = path.resolve(String.format("%s-%04d.params", str, epoch));

        try (DataOutputStream dos = new DataOutputStream(Files.newOutputStream(paramFile))) {
            dos.writeBytes(MAGIC_HEADER);
            dos.writeInt(MODEL_VERSION);
            dos.writeUTF(str);
            dos.writeUTF(dataType.name());

            inputData = block.describeInput();
            dos.writeInt(inputData.size());
            for (Pair<String, Shape> input : inputData) {
                String inputName = input.getKey();
                dos.writeUTF(inputName == null ? "" : inputName);
                dos.write(input.getValue().getEncoded());
            }

            // Snapshot so the written count always matches the entries that follow, even if
            // another thread updates the (concurrent) properties map during the save.
            Map<String, String> propertySnapshot = new LinkedHashMap<>(properties);
            dos.writeInt(propertySnapshot.size());
            for (Map.Entry<String, String> entry : propertySnapshot.entrySet()) {
                dos.writeUTF(entry.getKey());
                dos.writeUTF(entry.getValue());
            }

            block.saveParameters(dos);
        }
        modelName = str;
        modelDir = path.toAbsolutePath();
    }

    /** @return the model's block, or {@code null} if none has been set or loaded */
    public Block getBlock() {
        return block;
    }

    /** @param block the block to use for this model */
    public void setBlock(Block block) {
        this.block = block;
    }

    /** @return the model name from the last load/save, or {@code null} */
    public String getName() {
        return modelName;
    }

    /**
     * Creates a trainer for this model, first applying the config's initializer to the block.
     *
     * @param trainingConfig the training configuration
     * @return a new {@link MxTrainer}
     */
    public Trainer newTrainer(TrainingConfig trainingConfig) {
        block.setInitializer(trainingConfig.getInitializer());
        return new MxTrainer(this, trainingConfig);
    }

    /**
     * Creates a predictor that uses {@code translator}.
     *
     * <p>The boolean passed to {@link MxPredictor} is {@code true} only when thread-safe predictor
     * mode is off and this is not the first predictor created in that mode; the "first" flag is
     * not consumed while thread-safe mode is on.
     *
     * @param translator converts between user objects and {@link NDList}s
     * @param <I> input type
     * @param <O> output type
     * @return a new predictor
     */
    public <I, O> Predictor<I, O> newPredictor(Translator<I, O> translator) {
        boolean notFirstOrThreadSafe =
                !JnaUtils.useThreadSafePredictor() && !first.getAndSet(false);
        return new MxPredictor<>(this, translator, notFirstOrThreadSafe);
    }

    /** @param dataType the data type recorded for this model */
    public void setDataType(DataType dataType) {
        this.dataType = dataType;
    }

    /** @return the data type recorded for this model */
    public DataType getDataType() {
        return dataType;
    }

    /**
     * Not supported by this engine yet.
     *
     * @throws UnsupportedOperationException always
     */
    public void cast(DataType dataType) {
        throw new UnsupportedOperationException("Not implemented yet.");
    }

    /**
     * Returns the input names and shapes, computing them from the block on first use.
     *
     * @return the input descriptions
     */
    public PairList<String, Shape> describeInput() {
        if (inputData == null) {
            inputData = block.describeInput();
        }
        return inputData;
    }

    /**
     * Returns the block's output shapes for the current input shapes.
     *
     * <p>NOTE(review): the output shapes are keyed by the <em>input</em> names, which only lines
     * up when the block has as many outputs as inputs — kept as-is for compatibility.
     *
     * @return output shapes keyed by input name
     */
    public PairList<String, Shape> describeOutput() {
        // Go through describeInput() so this works before save()/readParameters() populated it.
        PairList<String, Shape> inputs = describeInput();
        Shape[] inputShapes = inputs.values().toArray(new Shape[0]);
        Shape[] outputShapes = block.getOutputShapes(manager, inputShapes);
        return new PairList<>(inputs.keys(), Arrays.asList(outputShapes));
    }

    /**
     * Lists regular files under the model directory, excluding parameter ({@code .params}) and
     * symbol ({@code -symbol.json}) files.
     *
     * @return paths relative to the model directory; empty if no model directory is set yet
     */
    public String[] getArtifactNames() {
        if (modelDir == null) {
            return new String[0];
        }
        // Files.walk holds open directory handles until the stream is closed.
        try (Stream<Path> files = Files.walk(modelDir)) {
            return files.filter(Files::isRegularFile)
                    .filter(MxModel::isArtifactFile)
                    .map(file -> modelDir.relativize(file).toString())
                    .toArray(String[]::new);
        } catch (IOException e) {
            throw new AssertionError("Failed list files", e);
        }
    }

    /** Returns {@code true} unless {@code file} is a parameter or symbol file. */
    private static boolean isArtifactFile(Path file) {
        String fileName = file.getFileName().toString();
        return !fileName.endsWith(".params") && !fileName.endsWith("-symbol.json");
    }

    /**
     * Loads an artifact once via {@code function} and caches the result under {@code str}.
     *
     * <p>The stream handed to {@code function} is closed afterwards, even if it throws. The
     * result is cast to {@code T} unchecked; requesting the same name with a different type
     * fails with a {@link ClassCastException} at the call site. A {@code null} result is not
     * cached. The load runs inside {@link ConcurrentHashMap#computeIfAbsent}, so concurrent
     * callers for the same name wait for it.
     *
     * @param str the artifact name, relative to the model directory
     * @param function converts the artifact stream into the cached value
     * @param <T> the cached value type
     * @return the cached or newly loaded value
     * @throws IOException if the artifact cannot be opened or read
     */
    public <T> T getArtifact(String str, Function<InputStream, T> function) throws IOException {
        try {
            Object artifact =
                    artifacts.computeIfAbsent(
                            str,
                            name -> {
                                try (InputStream is = getArtifactAsStream(name)) {
                                    return function.apply(is);
                                } catch (IOException e) {
                                    throw new UncheckedIOException(e);
                                }
                            });
            @SuppressWarnings("unchecked")
            T typed = (T) artifact;
            return typed;
        } catch (RuntimeException e) {
            // Surface I/O failures from inside the mapping function as checked IOExceptions.
            if (e.getCause() instanceof IOException) {
                throw (IOException) e.getCause();
            }
            throw e;
        }
    }

    /**
     * Resolves an artifact name against the model directory.
     *
     * @param str the artifact name, relative to the model directory
     * @return a URL for the readable artifact file
     * @throws IllegalArgumentException if {@code str} is {@code null}
     * @throws FileNotFoundException if no model directory is set, or the file is missing or
     *     unreadable
     */
    public URL getArtifact(String str) throws IOException {
        if (str == null) {
            throw new IllegalArgumentException("artifactName cannot be null");
        }
        if (modelDir == null) {
            throw new FileNotFoundException("Model directory is not set, cannot resolve: " + str);
        }
        Path artifactFile = modelDir.resolve(str);
        if (Files.exists(artifactFile) && Files.isReadable(artifactFile)) {
            return artifactFile.toUri().toURL();
        }
        throw new FileNotFoundException("File not found: " + artifactFile);
    }

    /**
     * Opens an artifact as a stream; the caller must close it.
     *
     * @param str the artifact name, relative to the model directory
     * @return an open stream over the artifact
     * @throws IOException if the artifact cannot be resolved or opened
     */
    public InputStream getArtifactAsStream(String str) throws IOException {
        return getArtifact(str).openStream();
    }

    /** @return the manager that owns this model's arrays */
    public NDManager getNDManager() {
        return manager;
    }

    /** Stores a string property; properties are written into saved parameter files. */
    public void setProperty(String str, String str2) {
        properties.put(str, str2);
    }

    /** @return the property value for {@code str}, or {@code null} if absent */
    public String getProperty(String str) {
        return properties.get(str);
    }

    /** Waits for pending engine work via {@link JnaUtils#waitAll()}, then closes the manager. */
    public void close() {
        JnaUtils.waitAll();
        manager.close();
    }

    /** Safety net: closes the manager (with a warning) if {@link #close()} was never called. */
    @Override
    protected void finalize() throws Throwable {
        try {
            if (manager.isOpen()) {
                logger.warn("MxModel was not closed explicitly.");
                manager.close();
            }
        } finally {
            super.finalize();
        }
    }

    /**
     * Loads parameters from the resolved parameter file, trying the DJL layout first and falling
     * back to MXNet's native format.
     */
    private void loadParameters(String str, Map<String, String> map)
            throws IOException, MalformedModelException {
        Path paramFile = resolveParamFile(str, map);
        logger.debug("Try to load model from {}", paramFile);
        if (readParameters(paramFile)) {
            return;
        }
        logger.debug("DJL formatted model not found, try to find MXNet model");
        loadMxNetParameters(str, paramFile);
    }

    /**
     * Picks the parameter file: the model path itself when it is a regular file, otherwise
     * {@code <prefix>-NNNN.params} for the requested or latest epoch.
     */
    private Path resolveParamFile(String prefix, Map<String, String> options) throws IOException {
        if (Files.isRegularFile(modelDir)) {
            return modelDir;
        }
        String epochOption = options == null ? null : options.get("epoch");
        int epoch;
        if (epochOption == null) {
            epoch = Utils.getCurrentEpoch(modelDir, prefix);
            if (epoch == -1) {
                throw new IOException(
                        "Parameter file not found in: "
                                + modelDir
                                + ". If you only specified model path, make sure path name"
                                + " matches your saved model file name.");
            }
        } else {
            epoch = Integer.parseInt(epochOption);
        }
        return modelDir.resolve(String.format("%s-%04d.params", prefix, epoch));
    }

    /**
     * Loads an MXNet-native parameter file into the {@link MxSymbolBlock}.
     *
     * <p>Array names are expected as {@code <kind>:<paramName>}; the part after the first colon
     * is matched against the block's parameter names. Block parameters not present in the file
     * are passed to {@link MxSymbolBlock#setInputNames}.
     */
    private void loadMxNetParameters(String prefix, Path paramFile)
            throws IOException, MalformedModelException {
        if (!(block instanceof MxSymbolBlock)) {
            throw new MalformedModelException(
                    "Parameter file is not in DJL format and the block is not an MxSymbolBlock: "
                            + paramFile);
        }
        MxSymbolBlock symbolBlock = (MxSymbolBlock) block;
        NDList arrays =
                JnaUtils.loadNdArray(manager, paramFile.toAbsolutePath(), manager.getDevice());

        Map<String, Parameter> unassigned = new LinkedHashMap<>();
        for (Parameter parameter : symbolBlock.getAllParameters()) {
            unassigned.put(parameter.getName(), parameter);
        }

        for (NDArray array : arrays) {
            String key = array.getName();
            if (key == null) {
                throw new IllegalArgumentException(
                        "Array names must be present in parameter file");
            }
            int separator = key.indexOf(':');
            String paramName = separator >= 0 ? key.substring(separator + 1) : key;
            Parameter parameter = unassigned.remove(paramName);
            if (parameter == null) {
                throw new MalformedModelException(
                        "Unexpected parameter " + key + " in " + paramFile);
            }
            parameter.setArray(array);
        }
        symbolBlock.setInputNames(new ArrayList<>(unassigned.keySet()));

        // The model's data type is taken from the first stored array, when there is one.
        if (!arrays.isEmpty()) {
            dataType = arrays.head().getDataType();
        }
        logger.debug("MXNet Model {} ({}) loaded successfully.", prefix, dataType);
    }

    /**
     * Reads a parameter file in the DJL layout written by {@link #save(Path, String)}.
     *
     * @return {@code false} if the file does not start with the DJL magic header, {@code true}
     *     once the file has been fully loaded
     * @throws IOException if the file cannot be read or has an unsupported version
     */
    private boolean readParameters(Path path) throws IOException, MalformedModelException {
        try (DataInputStream dis = new DataInputStream(Files.newInputStream(path))) {
            byte[] header = new byte[4];
            dis.readFully(header);
            if (!MAGIC_HEADER.equals(new String(header, StandardCharsets.US_ASCII))) {
                return false;
            }
            int version = dis.readInt();
            if (version != MODEL_VERSION) {
                throw new IOException("Unsupported model version: " + version);
            }
            modelName = dis.readUTF();
            logger.debug("Loading model parameter: {}", modelName);
            dataType = DataType.valueOf(dis.readUTF());

            int inputCount = dis.readInt();
            inputData = new PairList<>();
            for (int i = 0; i < inputCount; i++) {
                inputData.add(dis.readUTF(), Shape.decode(dis));
            }

            int propertyCount = dis.readInt();
            for (int i = 0; i < propertyCount; i++) {
                properties.put(dis.readUTF(), dis.readUTF());
            }

            block.loadParameters(manager, dis);
            logger.debug("DJL model loaded successfully");
            return true;
        }
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder(200);
        sb.append("Model (\n\tName: ").append(modelName);
        if (modelDir != null) {
            sb.append("\n\tModel location: ").append(modelDir.toAbsolutePath());
        }
        sb.append("\n\tData Type: ").append(dataType);
        for (Map.Entry<String, String> entry : properties.entrySet()) {
            sb.append("\n\t").append(entry.getKey()).append(": ").append(entry.getValue());
        }
        sb.append("\n)");
        return sb.toString();
    }
}
