package ai.databand;

import ai.databand.config.DbndConfig;
import ai.databand.log.HistogramRequest;
import java.lang.reflect.Method;
import java.util.ArrayDeque;
import java.util.Arrays;
import java.util.Deque;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import javassist.ClassPool;
import javassist.Loader;
import org.apache.log4j.Level;
import org.apache.log4j.PatternLayout;
import org.apache.log4j.spi.LoggingEvent;
import org.apache.spark.scheduler.SparkListenerEvent;
import org.apache.spark.scheduler.SparkListenerStageCompleted;
import org.apache.spark.sql.Dataset;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Process-wide bridge between user code and Databand tracking.
 *
 * <p>The wrapper owns a single current {@link DbndRun}.
 * <ul>
 *   <li><b>Agent mode:</b> the run is created when {@link #beforeTask} first resolves a pipeline
 *       method (via {@link #beforePipeline}). Nested traced methods are kept on a stack of method
 *       names.</li>
 *   <li><b>Agentless mode:</b> {@code log*} called without a run creates one lazily. The run is
 *       based on the {@code main} frame of the current thread's stack, or the outermost frame if
 *       there is no {@code main} frame.</li>
 * </ul>
 *
 * <p>Construction attaches a {@link DbndLogAppender} bound to this wrapper, with an INFO threshold,
 * to the log4j loggers {@code org.apache.spark}, {@code org.spark_project} and {@code ai.databand}.
 *
 * <p>Thread-safety: only run creation through {@code getOrCreateRun} is synchronized. The run
 * field, task stack and method cache are otherwise accessed without locking.
 */
public class DbndWrapper {
    private static final Logger LOG = LoggerFactory.getLogger(DbndWrapper.class);
    private static final DbndWrapper INSTANCE = new DbndWrapper();

    /** Current tracking run; {@code null} until one is created and again after {@link #cleanup()}. */
    private DbndRun run;
    /**
     * Set once {@link #beforePipeline} has resolved the pipeline method and obtained a run.
     * Reset by {@link #cleanup()} and by failed method/class lookups.
     */
    private boolean pipelineInitialized;
    private final DbndConfig config = new DbndConfig();
    private final DbndClient dbnd = new DbndClient(this.config);
    /** Declared methods of every loaded class, keyed by {@link Method#toGenericString()}. */
    private final Map<String, Method> methodsCache = new HashMap<>(1);
    /** Names of traced methods that have started and not yet finished; top = most recently pushed. */
    private final Deque<String> stack = new ArrayDeque<>(1);
    /** Names of classes whose declared methods are already in {@link #methodsCache}. */
    private final Set<String> loadedClasses = new HashSet<>(1);

    /** Returns the process-wide shared wrapper instance. */
    public static DbndWrapper instance() {
        return INSTANCE;
    }

    /**
     * Creates a wrapper and attaches a {@link DbndLogAppender} bound to it (INFO threshold) to the
     * Spark and Databand log4j loggers.
     */
    public DbndWrapper() {
        DbndLogAppender dbndLogAppender = new DbndLogAppender(this);
        dbndLogAppender.setLayout(new PatternLayout("[%d] {%c{2}} %p - %m%n"));
        dbndLogAppender.setThreshold(Level.INFO);
        dbndLogAppender.activateOptions();
        org.apache.log4j.Logger.getLogger("org.apache.spark").addAppender(dbndLogAppender);
        org.apache.log4j.Logger.getLogger("org.spark_project").addAppender(dbndLogAppender);
        org.apache.log4j.Logger.getLogger("ai.databand").addAppender(dbndLogAppender);
    }

    /**
     * Loads a class by name.
     *
     * <p>Tries {@link Class#forName(String)} first, then a Javassist {@link Loader} over the
     * default class pool.
     *
     * @param className fully qualified class name
     * @return the loaded class, or empty if neither loader can find it
     */
    public Optional<Class<?>> loadClass(String className) {
        try {
            return Optional.of(Class.forName(className));
        } catch (ClassNotFoundException ignored) {
            // Not visible to this class's loader; fall back to Javassist.
            try {
                return Optional.of(new Loader(ClassPool.getDefault()).loadClass(className));
            } catch (ClassNotFoundException notFound) {
                return Optional.empty();
            }
        }
    }

    /**
     * Resolves the pipeline method and creates the run for it if none exists yet.
     *
     * @param className class declaring the pipeline method
     * @param methodName method signature (e.g. {@code a.b.C.run(java.lang.String)}) or plain name
     * @param args arguments passed to the run's initialization
     */
    public void beforePipeline(String className, String methodName, Object[] args) {
        Method pipelineMethod = findMethodByName(methodName, className);
        if (pipelineMethod == null) {
            this.pipelineInitialized = false;
            return;
        }
        System.out.println("Running Databand!");
        System.out.printf("TRACKER URL: %s%n", this.config.databandUrl());
        System.out.printf("CMD: %s%n", this.config.cmd());
        getOrCreateRun(pipelineMethod, args);
        this.pipelineInitialized = true;
    }

    /**
     * Looks up a cached method, first loading the declared methods of {@code className} if needed.
     *
     * <p>If {@code methodName} contains {@code '('}, any cached method whose generic signature
     * contains the text up to and including that parenthesis matches. Otherwise the name must
     * equal the method's simple name or its {@code declaringClass.name} form. An empty prefix
     * would match every cached method.
     *
     * @param methodName signature or plain method name
     * @param className class to load into the cache first; {@code null} to search the cache only
     * @return the first matching method, or {@code null} if none matches
     */
    protected Method findMethodByName(String methodName, String className) {
        if (className != null && !this.loadedClasses.contains(className)) {
            loadMethods(className);
        }
        int parenIndex = methodName.indexOf('(');
        if (parenIndex < 0) {
            for (Method method : this.methodsCache.values()) {
                String simpleName = method.getName();
                if (simpleName.equals(methodName)
                        || (method.getDeclaringClass().getName() + "." + simpleName).equals(methodName)) {
                    return method;
                }
            }
            return null;
        }
        String signaturePrefix = methodName.substring(0, parenIndex + 1);
        for (Map.Entry<String, Method> entry : this.methodsCache.entrySet()) {
            if (entry.getKey().contains(signaturePrefix)) {
                return entry.getValue();
            }
        }
        return null;
    }

    /**
     * Adds all declared methods of {@code className} to the method cache.
     *
     * <p>If the class cannot be loaded, logs an error and marks the pipeline as not initialized.
     *
     * @param className fully qualified class name
     */
    protected void loadMethods(String className) {
        Optional<Class<?>> loaded = loadClass(className);
        if (!loaded.isPresent()) {
            LOG.error("Unable to build method cache for class {}", className);
            this.pipelineInitialized = false;
            return;
        }
        for (Method method : loaded.get().getDeclaredMethods()) {
            this.methodsCache.put(method.toGenericString(), method);
        }
        this.loadedClasses.add(className);
    }

    /** Stops the current run (if any) and resets wrapper state. */
    public void afterPipeline() {
        DbndRun current = currentRun();
        if (current != null) {
            current.stop();
        }
        cleanup();
    }

    /**
     * Reports {@code error} to the current run (if any) and resets wrapper state.
     *
     * @param error failure that ended the pipeline
     */
    public void errorPipeline(Throwable error) {
        DbndRun current = currentRun();
        if (current != null) {
            current.error(error);
        }
        cleanup();
    }

    /** Drops the run, the method cache and the loaded-class set; the task stack is left as is. */
    protected void cleanup() {
        this.run = null;
        this.methodsCache.clear();
        this.pipelineInitialized = false;
        this.loadedClasses.clear();
    }

    /**
     * Records entry into a traced method.
     *
     * <p>If a pipeline is initialized, starts a task for the method and pushes its name. Otherwise
     * the call tries to initialize the pipeline.
     *
     * @param className class declaring the method
     * @param methodName method signature
     * @param args actual arguments of the call
     */
    public void beforeTask(String className, String methodName, Object[] args) {
        if (this.pipelineInitialized) {
            DbndRun current = currentRun();
            Method taskMethod = findMethodByName(methodName, className);
            LOG.info("Running task {}", current.getTaskName(taskMethod));
            current.startTask(taskMethod, args);
            this.stack.push(methodName);
            return;
        }
        if (this.stack.isEmpty()) {
            beforePipeline(className, methodName, args);
            this.stack.push(methodName);
        } else {
            // Not (or no longer) initialized while names remain on the stack: retry with the top one.
            beforePipeline(className, this.stack.peek(), args);
        }
    }

    /**
     * Records successful exit from a traced method.
     *
     * <p>When the stack empties, this ends the pipeline. An exit with an empty stack is logged
     * and ignored instead of throwing into the traced code.
     *
     * @param methodName method signature
     * @param result value returned by the method
     */
    public void afterTask(String methodName, Object result) {
        if (this.stack.isEmpty()) {
            LOG.warn("Task {} completed without a matching start; ignoring", methodName);
            return;
        }
        this.stack.pop();
        if (this.stack.isEmpty()) {
            afterPipeline();
            return;
        }
        DbndRun current = currentRun();
        if (current == null) {
            return;
        }
        Method taskMethod = findMethodByName(methodName, null);
        current.completeTask(taskMethod, result);
        LOG.info("Task {} has been completed!", current.getTaskName(taskMethod));
    }

    /**
     * Records a traced method exiting with an exception.
     *
     * <p>When the stack empties, this fails the pipeline. A failure with an empty stack is logged
     * and ignored instead of throwing into the traced code.
     *
     * @param methodName method signature
     * @param error exception thrown by the method
     */
    public void errorTask(String methodName, Throwable error) {
        if (this.stack.isEmpty()) {
            LOG.warn("Task {} failed without a matching start; ignoring", methodName, error);
            return;
        }
        LOG.info("Task {} returned error!", this.stack.pop());
        if (this.stack.isEmpty()) {
            errorPipeline(error);
            return;
        }
        DbndRun current = currentRun();
        if (current != null) {
            current.errorTask(findMethodByName(methodName, null), error);
        }
    }

    /**
     * Forwards a log4j event to the current run; dropped when there is no run.
     *
     * @param loggingEvent the log4j event
     * @param str text passed through to {@code DbndRun.saveLog} (presumably the formatted line — confirm in DbndLogAppender)
     */
    public void logTask(LoggingEvent loggingEvent, String str) {
        DbndRun current = currentRun();
        if (current == null) {
            return;
        }
        current.saveLog(loggingEvent, str);
    }

    /** Logs a single metric to the current run, creating an agentless run if needed. */
    public void logMetric(String key, Object value) {
        currentOrAgentlessRun().logMetric(key, value);
        LOG.info("Metric logged: [{}: {}]", key, value);
    }

    /** Logs several metrics to the current run, creating an agentless run if needed. */
    public void logMetrics(Map<String, Object> map) {
        logMetrics(map, null);
    }

    /**
     * Logs several metrics to the current run, creating an agentless run if needed.
     *
     * @param map metric values
     * @param source passed unchanged to {@code DbndRun.logMetrics}; may be {@code null} (meaning defined there)
     */
    public void logMetrics(Map<String, Object> map, String source) {
        currentOrAgentlessRun().logMetrics(map, source);
    }

    /** Logs a dataset to the current run with the given histogram request, creating an agentless run if needed. */
    public void logDataframe(String key, Dataset<?> dataset, HistogramRequest histogramRequest) {
        currentOrAgentlessRun().logDataframe(key, dataset, histogramRequest);
    }

    /** Logs precomputed histogram data to the current run, creating an agentless run if needed. */
    public void logHistogram(Map<String, Object> map) {
        currentOrAgentlessRun().logHistogram(map);
    }

    /**
     * Logs a dataset to the current run, creating an agentless run if needed.
     *
     * @param withHistograms passed to {@code new HistogramRequest(boolean)}
     */
    public void logDataframe(String key, Dataset<?> dataset, boolean withHistograms) {
        currentOrAgentlessRun().logDataframe(key, dataset, new HistogramRequest(withHistograms));
        LOG.info("Dataframe logged");
    }

    /**
     * Saves Spark metrics from stage-completed events.
     *
     * <p>A run (agentless if necessary) is ensured for every event type.
     */
    public void logSpark(SparkListenerEvent sparkListenerEvent) {
        DbndRun current = currentOrAgentlessRun();
        if (sparkListenerEvent instanceof SparkListenerStageCompleted) {
            current.saveSparkMetrics((SparkListenerStageCompleted) sparkListenerEvent);
            LOG.info("Spark metrics saved");
        }
    }

    /** Creates the run for {@code method} unless one already exists; returns the current run. */
    private synchronized DbndRun getOrCreateRun(Method method, Object[] args) {
        if (currentRun() == null) {
            initRun(method, args);
        }
        return currentRun();
    }

    /** Returns the current run, creating an agentless one if none exists yet. */
    private DbndRun currentOrAgentlessRun() {
        DbndRun current = currentRun();
        return current != null ? current : createAgentlessRun();
    }

    /**
     * Creates a run for the thread's entry point and registers a shutdown hook that stops it.
     *
     * <p>If no run can be created, falls back to a {@link NoopDbndRun}, so callers never receive
     * {@code null}.
     */
    private DbndRun createAgentlessRun() {
        StackTraceElement entryPoint = findEntryPointFrame();
        if (entryPoint != null) {
            String className = entryPoint.getClassName();
            String methodName = entryPoint.getMethodName();
            try {
                for (Method method : Class.forName(className).getMethods()) {
                    if (!method.getName().equals(methodName)) {
                        continue;
                    }
                    // Arguments are unknown here; pass nulls with the method's arity.
                    beforePipeline(className, methodName, new Object[method.getParameterCount()]);
                    if (this.pipelineInitialized) {
                        break;
                    }
                }
            } catch (ClassNotFoundException e) {
                LOG.warn("Unable to load entry point class {} for agentless tracking", className, e);
            }
        }
        if (this.run == null) {
            LOG.warn("Unable to create agentless run; falling back to a no-op run");
            this.run = new NoopDbndRun();
        }
        DbndRun agentlessRun = this.run;
        Runtime.getRuntime().addShutdownHook(new Thread(agentlessRun::stop));
        return agentlessRun;
    }

    /** Returns the first {@code main} frame of the current thread, else its outermost frame, else {@code null}. */
    private static StackTraceElement findEntryPointFrame() {
        StackTraceElement[] frames = Thread.currentThread().getStackTrace();
        for (StackTraceElement frame : frames) {
            if (frame.getMethodName().equals("main")) {
                return frame;
            }
        }
        // getStackTrace may legitimately return an empty array.
        return frames.length == 0 ? null : frames[frames.length - 1];
    }

    private DbndRun currentRun() {
        return this.run;
    }

    /**
     * Creates and initializes the run.
     *
     * <p>Uses a {@link NoopDbndRun} when tracking is disabled or initialization throws.
     */
    private void initRun(Method method, Object[] args) {
        if (!this.config.isTrackingEnabled()) {
            this.run = new NoopDbndRun();
            System.out.println("Tracking is not enabled. Set DBND__TRACKING variable to True if you want to enable it.");
            return;
        }
        this.run = new DefaultDbndRun(this.dbnd, this.config);
        try {
            this.run.init(method, args);
            System.out.printf("Running pipeline %s%n", this.run.getTaskName(method));
        } catch (Exception e) {
            this.run = new NoopDbndRun();
            System.out.printf("Unable to init run: %s%n", e.getMessage());
            LOG.error("Unable to init run", e);
        }
    }

    /** Logs the task stack at INFO, top first, as {@code [ a  b ]}. */
    protected void printStack() {
        StringBuilder sb = new StringBuilder(3);
        sb.append('[');
        for (String name : this.stack) {
            sb.append(' ').append(name).append(' ');
        }
        sb.append(']');
        LOG.info(sb.toString());
    }
}
