package org.deeplearning4j;

import java.lang.management.ManagementFactory;
import java.lang.reflect.Method;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import org.bytedeco.javacpp.Pointer;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.DisplayName;
import org.junit.jupiter.api.TestInfo;
import org.junit.jupiter.api.Timeout;
import org.nd4j.common.base.Preconditions;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.memory.MemoryWorkspace;
import org.nd4j.linalg.api.ops.executioner.OpExecutioner;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.profiler.ProfilerConfig;
import org.slf4j.ILoggerFactory;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for DL4J tests.
 *
 * <p>Before each test it configures ND4J from the overridable hooks below: profiling mode, default
 * data types, debug/verbose flags and the number of master threads. After each test it destroys
 * the current thread's workspaces and terminates the JVM if a workspace was left open. It then
 * logs the test's duration, thread-count delta and memory statistics.
 */
@DisplayName("Base DL 4 J Test")
public abstract class BaseDL4JTest {
    private static final Logger log = LoggerFactory.getLogger(BaseDL4JTest.class);

    /** Wall-clock time ({@link System#currentTimeMillis()}) captured at the end of {@code beforeTest}. */
    protected long startTime;
    /** JVM live thread count captured at the end of {@code beforeTest}. */
    protected int threadCountBefore;

    private final int DEFAULT_THREADS = Runtime.getRuntime().availableProcessors();

    /** Cached result of {@link #isIntegrationTests()}; {@code null} until first queried. */
    protected static Boolean integrationTest;

    /**
     * Returns the maximum number of ND4J master threads applied in {@code beforeTest}.
     *
     * @return thread count; defaults to the number of available processors and must be {@code > 0}
     */
    public int numThreads() {
        return DEFAULT_THREADS;
    }

    /**
     * Returns the intended per-test timeout in milliseconds.
     * NOTE(review): not referenced inside this class; presumably consumed by subclasses/tooling.
     *
     * @return timeout in milliseconds (90 seconds by default)
     */
    public long getTimeoutMilliseconds() {
        return 90000L;
    }

    /**
     * Returns the profiling mode applied to the ND4J executioner before each test.
     *
     * @return {@link OpExecutioner.ProfilingMode#SCOPE_PANIC} by default
     */
    public OpExecutioner.ProfilingMode getProfilingMode() {
        return OpExecutioner.ProfilingMode.SCOPE_PANIC;
    }

    /**
     * Returns the first data type passed to {@code Nd4j.setDefaultDataTypes} before each test.
     *
     * @return {@link DataType#DOUBLE} by default
     */
    public DataType getDataType() {
        return DataType.DOUBLE;
    }

    /**
     * Returns the default floating-point data type passed to {@code Nd4j.setDefaultDataTypes}.
     *
     * @return {@link #getDataType()} by default
     */
    public DataType getDefaultFPDataType() {
        return getDataType();
    }

    /**
     * Reports whether integration tests are enabled via the {@code DL4J_INTEGRATION_TESTS}
     * environment variable. The value is read once and cached in {@link #integrationTest}.
     *
     * @return {@code true} only if the variable parses as boolean {@code true}
     */
    public static boolean isIntegrationTests() {
        if (integrationTest == null) {
            integrationTest = Boolean.parseBoolean(System.getenv("DL4J_INTEGRATION_TESTS"));
        }
        return integrationTest;
    }

    /** Aborts the calling test (JUnit assumption failure) unless integration tests are enabled. */
    public static void skipUnlessIntegrationTests() {
        Assumptions.assumeTrue(isIntegrationTests(), "Skipping integration test - integration profile is not enabled");
    }

    /**
     * Configures ND4J for the upcoming test and records the start time and thread count.
     *
     * @param testInfo JUnit metadata for the test about to run
     */
    @BeforeEach
    @Timeout(90) // JUnit's default unit is SECONDS: 90 s, matching getTimeoutMilliseconds()
    void beforeTest(TestInfo testInfo) {
        log.info("{}.{}", getClass().getSimpleName(), testMethodName(testInfo));
        System.setProperty("org.nd4j.log.initialization", "false");
        System.setProperty("org.nd4j.avx.ignore", "true");
        Nd4j.getExecutioner().setProfilingMode(getProfilingMode());
        Nd4j.getExecutioner().setProfilingConfig(ProfilerConfig.builder().build());
        Nd4j.setDefaultDataTypes(getDataType(), getDefaultFPDataType());
        Nd4j.getExecutioner().enableDebugMode(false);
        Nd4j.getExecutioner().enableVerboseMode(false);

        int numThreads = numThreads();
        Preconditions.checkState(numThreads > 0, "Number of threads must be > 0");
        if (numThreads != Nd4j.getEnvironment().maxMasterThreads()) {
            Nd4j.getEnvironment().setMaxMasterThreads(numThreads);
        }

        this.startTime = System.currentTimeMillis();
        this.threadCountBefore = ManagementFactory.getThreadMXBean().getThreadCount();
    }

    /**
     * Tears down per-thread workspaces and logs a one-line summary of the finished test.
     * If a workspace was still set as current, the JVM is terminated with exit status 1.
     *
     * @param testInfo JUnit metadata for the test that just ran
     */
    @AfterEach
    void afterTest(TestInfo testInfo) {
        // Keep workspaces isolated between tests.
        Nd4j.getWorkspaceManager().destroyAllWorkspacesForCurrentThread();
        MemoryWorkspace currentWorkspace = Nd4j.getMemoryManager().getCurrentWorkspace();
        Nd4j.getMemoryManager().setCurrentWorkspace((MemoryWorkspace) null);
        if (currentWorkspace != null) {
            exitOnLeakedWorkspace(currentWorkspace);
        }

        StringBuilder sb = new StringBuilder();
        sb.append(getClass().getSimpleName()).append(".").append(testMethodName(testInfo))
                .append(": ").append(System.currentTimeMillis() - this.startTime).append(" ms")
                .append(", threadCount: (").append(this.threadCountBefore).append("->")
                .append(ManagementFactory.getThreadMXBean().getThreadCount()).append(")")
                .append(", jvmTotal=").append(Runtime.getRuntime().totalMemory())
                .append(", jvmMax=").append(Runtime.getRuntime().maxMemory())
                .append(", totalBytes=").append(Pointer.totalBytes())
                .append(", maxBytes=").append(Pointer.maxBytes())
                .append(", currPhys=").append(Pointer.physicalBytes())
                .append(", maxPhys=").append(Pointer.maxPhysicalBytes());
        appendWorkspaceSizes(sb);
        appendGpuMemory(sb);
        log.info(sb.toString());
    }

    /** Returns the test method's name, falling back to the display name if no method is bound. */
    private static String testMethodName(TestInfo testInfo) {
        return testInfo.getTestMethod().map(Method::getName).orElse(testInfo.getDisplayName());
    }

    /**
     * Reports a workspace left open by a test, gives logging a chance to flush, then exits the JVM.
     * Continuing would make later tests fail with errors that are hard to trace back here.
     */
    private static void exitOnLeakedWorkspace(MemoryWorkspace ws) {
        log.error("Open workspace leaked from test! Exiting - {}, isOpen = {} - {}", ws.getId(), ws.isScopeActive(), ws);
        System.out.println("Open workspace leaked from test! Exiting - " + ws.getId() + ", isOpen = " + ws.isScopeActive() + " - " + ws);
        System.out.flush();
        sleepQuietly(1000L);
        stopLogbackIfPresent();
        sleepQuietly(1000L);
        System.exit(1);
    }

    /**
     * Stops the logback {@code LoggerContext}, if that is the active SLF4J backend, so its appenders
     * flush. Reflection avoids a compile-time dependency on logback. Failures are reported to stderr
     * and otherwise ignored because the caller exits immediately afterwards.
     */
    private static void stopLogbackIfPresent() {
        ILoggerFactory factory = LoggerFactory.getILoggerFactory();
        if (!"ch.qos.logback.classic.LoggerContext".equals(factory.getClass().getName())) {
            return;
        }
        try {
            Method stop = factory.getClass().getMethod("stop");
            stop.setAccessible(true);
            stop.invoke(factory);
        } catch (ReflectiveOperationException | RuntimeException e) {
            System.err.println("Unable to stop logback LoggerContext: " + e);
        }
    }

    /** Sleeps for the given duration; on interruption, restores the thread's interrupt flag. */
    private static void sleepQuietly(long millis) {
        try {
            Thread.sleep(millis);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        }
    }

    /** Appends the summed current size of this thread's workspaces, when non-zero. */
    private static void appendWorkspaceSizes(StringBuilder sb) {
        List<MemoryWorkspace> workspaces = Nd4j.getWorkspaceManager().getAllWorkspacesForCurrentThread();
        if (workspaces == null || workspaces.isEmpty()) {
            return;
        }
        long totalSize = 0;
        for (MemoryWorkspace ws : workspaces) {
            totalSize += ws.getCurrentSize();
        }
        if (totalSize > 0) {
            sb.append(", threadWSSize=").append(totalSize).append(" (").append(workspaces.size()).append(" WSs)");
        }
    }

    /**
     * Appends free/total memory for each device listed under {@code cuda.devicesInformation}.
     * Nothing is appended when that entry is absent or not a list (presumably non-CUDA backends).
     */
    private static void appendGpuMemory(StringBuilder sb) {
        Object info = Nd4j.getExecutioner().getEnvironmentInformation().get("cuda.devicesInformation");
        if (!(info instanceof List)) {
            return;
        }
        List<?> devices = (List<?>) info;
        if (devices.isEmpty()) {
            return;
        }
        sb.append(" [").append(devices.size()).append(" GPUs: ");
        for (int i = 0; i < devices.size(); i++) {
            Map<?, ?> device = (Map<?, ?>) devices.get(i);
            if (i > 0) {
                sb.append(",");
            }
            sb.append("(").append(device.get("cuda.freeMemory")).append(" free, ")
                    .append(device.get("cuda.totalMemory")).append(" total)");
        }
        sb.append("]");
    }
}
