package ai.h2o.mojos.runtime.utils;

import java.lang.reflect.Method;
import java.nio.charset.StandardCharsets;

/**
 * Utility class that hides the differences
 * between Java 7 and Java 8+.
 *
 * Java 7 does not provide the class `java.util.Base64`; the suggested
 * workaround is `javax.xml.bind.DatatypeConverter`, which is not
 * available by default starting with Java 9.
 * Hence, this utility class delegates its implementation to `Base64` or
 * `DatatypeConverter`, depending on whether `java.util.Base64` exists.
 */
public final class Base64Utils {

  /** Implementation selected once, during class initialization. */
  private static final Base64Iface DELEGATE;

  static {
    ClassLoader loader = Base64Utils.class.getClassLoader();
    Base64Iface delegate = null;
    // First try the Java 8+ implementation (java.util.Base64).
    Class<?> base64Klazz = tryLoadFirst("java.util.Base64", loader);
    if (base64Klazz != null) {
      delegate = Base64Java8.create(base64Klazz);
    }
    // Fall back to the Java 7 implementation (JAXB DatatypeConverter) when the
    // Java 8 class is missing or its reflective setup failed.
    if (delegate == null) {
      Class<?> datatypeConverterKlazz = tryLoadFirst("javax.xml.bind.DatatypeConverter", loader);
      if (datatypeConverterKlazz != null) {
        delegate = Base64Java7.create(datatypeConverterKlazz);
      }
    }
    if (delegate == null) {
      // Fail class initialization since something is really wrong.
      throw new LinkageError("Cannot find suitable implementation for Base64 utilities!");
    }
    DELEGATE = delegate;
  }

  private Base64Utils() {}

  /**
   * Encodes bytes using the basic Base64 alphabet (with padding).
   *
   * @param src bytes to encode
   * @return encoded bytes, or {@code null} if the underlying implementation fails
   */
  public static byte[] encode(byte[] src) {
    return DELEGATE.encode(src);
  }

  /**
   * Decodes bytes written in the basic Base64 alphabet.
   *
   * @param src Base64 bytes to decode
   * @return decoded bytes, or {@code null} if {@code src} is {@code null} or cannot be decoded
   */
  public static byte[] decode(byte[] src) {
    return DELEGATE.decode(src);
  }

  /**
   * Encodes bytes using the URL- and filename-safe Base64 alphabet ({@code -} and {@code _}
   * instead of {@code +} and {@code /}), with padding.
   *
   * @param src bytes to encode
   * @return encoded bytes, or {@code null} if the underlying implementation fails
   */
  public static byte[] encodeUrl(byte[] src) {
    return DELEGATE.encodeUrl(src);
  }

  /**
   * Decodes bytes written in the URL- and filename-safe Base64 alphabet.
   *
   * @param src Base64 bytes to decode
   * @return decoded bytes, or {@code null} if {@code src} is {@code null} or cannot be decoded
   */
  public static byte[] decodeUrl(byte[] src) {
    return DELEGATE.decodeUrl(src);
  }

  /**
   * Decodes a string written in the URL- and filename-safe Base64 alphabet.
   *
   * @param src Base64 text to decode
   * @return decoded bytes, or {@code null} if {@code src} is {@code null}, empty, or cannot be
   *     decoded
   */
  public static byte[] decodeUrl(String src) {
    if (isEmpty(src)) {
      return null;
    }
    // Base64 text is ASCII; ISO-8859-1 maps each char to exactly one byte (the same choice
    // java.util.Base64 makes for String input) instead of relying on the platform charset.
    return DELEGATE.decodeUrl(src.getBytes(StandardCharsets.ISO_8859_1));
  }

  /** Common contract of the Java 7 and Java 8+ implementations. */
  private interface Base64Iface {
    byte[] encode(byte[] src);
    byte[] decode(byte[] src);
    byte[] encodeUrl(byte[] src);
    byte[] decodeUrl(byte[] src);
  }

  /** Implementation backed by {@code java.util.Base64}, accessed reflectively. */
  private static final class Base64Java8 implements Base64Iface {

    private final Object encoder;
    private final Method encoderEncode;
    private final Object encoderUrl;
    private final Method encoderUrlEncode;
    private final Object decoder;
    private final Method decoderDecode;
    private final Object decoderUrl;
    private final Method decoderUrlDecode;

    private Base64Java8(Object encoder, Method encoderEncode,
                        Object encoderUrl, Method encoderUrlEncode,
                        Object decoder, Method decoderDecode,
                        Object decoderUrl, Method decoderUrlDecode) {
      this.encoder = encoder;
      this.encoderEncode = encoderEncode;
      this.encoderUrl = encoderUrl;
      this.encoderUrlEncode = encoderUrlEncode;
      this.decoder = decoder;
      this.decoderDecode = decoderDecode;
      this.decoderUrl = decoderUrl;
      this.decoderUrlDecode = decoderUrlDecode;
    }

    @Override
    public byte[] encode(byte[] src) {
      return call(encoder, encoderEncode, src);
    }

    @Override
    public byte[] decode(byte[] src) {
      return call(decoder, decoderDecode, src);
    }

    @Override
    public byte[] encodeUrl(byte[] src) {
      return call(encoderUrl, encoderUrlEncode, src);
    }

    @Override
    public byte[] decodeUrl(byte[] src) {
      return call(decoderUrl, decoderUrlDecode, src);
    }

    // Invokes `m` on `o` with a single byte[] argument; returns null on failure.
    private static byte[] call(Object o, Method m, byte[] src) {
      try {
        return (byte[]) m.invoke(o, src);
      } catch (ReflectiveOperationException e) {
        // Errors raised by the target (e.g. malformed input, null input) arrive wrapped in
        // InvocationTargetException; the public contract reports them as null.
        return null;
      }
    }

    /** Builds the delegate, or returns {@code null} if any reflective lookup fails. */
    static Base64Java8 create(Class<?> base64Class) {
      Object encoder = scall(base64Class, "getEncoder");
      Method encoderEncode = getMethodBA(encoder, "encode");
      Object encoderUrl = scall(base64Class, "getUrlEncoder");
      Method encoderUrlEncode = getMethodBA(encoderUrl, "encode");
      Object decoder = scall(base64Class, "getDecoder");
      Method decoderDecode = getMethodBA(decoder, "decode");
      Object decoderUrl = scall(base64Class, "getUrlDecoder");
      Method decoderUrlDecode = getMethodBA(decoderUrl, "decode");
      if (allNotNull(encoder, encoderEncode,
                     encoderUrl, encoderUrlEncode,
                     decoder, decoderDecode,
                     decoderUrl, decoderUrlDecode)) {
        return new Base64Java8(encoder, encoderEncode,
                               encoderUrl, encoderUrlEncode,
                               decoder, decoderDecode,
                               decoderUrl, decoderUrlDecode);
      } else {
        return null;
      }
    }

    // Gets the public method `name(byte[])` of o's class, or null if o is null or has no such
    // method (so a failed scall() above does not turn into an NPE during class init).
    private static Method getMethodBA(Object o, String name) {
      if (o == null) {
        return null;
      }
      try {
        return o.getClass().getMethod(name, byte[].class);
      } catch (NoSuchMethodException e) {
        return null;
      }
    }
  }

  /** Implementation backed by {@code javax.xml.bind.DatatypeConverter}, accessed reflectively. */
  private static final class Base64Java7 implements Base64Iface {

    /** {@code DatatypeConverter.printBase64Binary(byte[])}. */
    private final Method encoder;
    /** {@code DatatypeConverter.parseBase64Binary(String)}. */
    private final Method decoder;

    private Base64Java7(Method encoder, Method decoder) {
      this.encoder = encoder;
      this.decoder = decoder;
    }

    @Override
    public byte[] encode(byte[] src) {
      String encoded = print(src);
      return encoded != null ? encoded.getBytes(StandardCharsets.ISO_8859_1) : null;
    }

    @Override
    public byte[] decode(byte[] src) {
      return src != null ? parse(new String(src, StandardCharsets.ISO_8859_1)) : null;
    }

    @Override
    public byte[] encodeUrl(byte[] src) {
      String encoded = print(src);
      if (encoded == null) {
        return null;
      }
      // Map the basic alphabet to the URL-safe one (RFC 4648 section 5), keeping padding,
      // the same as java.util.Base64.getUrlEncoder().
      return encoded.replace('+', '-').replace('/', '_').getBytes(StandardCharsets.ISO_8859_1);
    }

    @Override
    public byte[] decodeUrl(byte[] src) {
      return src != null ? parseUrl(new String(src, StandardCharsets.ISO_8859_1)) : null;
    }

    // Encodes with the basic alphabet; returns null on failure.
    private String print(byte[] src) {
      try {
        return (String) encoder.invoke(null, src);
      } catch (ReflectiveOperationException e) {
        return null;
      }
    }

    // Decodes basic-alphabet text; returns null for null input or on failure.
    private byte[] parse(String in) {
      if (in == null) {
        return null;
      }
      try {
        return (byte[]) decoder.invoke(null, in);
      } catch (ReflectiveOperationException e) {
        return null;
      }
    }

    // Decodes URL-safe-alphabet text by mapping it back to the basic alphabet first.
    private byte[] parseUrl(String in) {
      if (in == null) {
        return null;
      }
      return parse(in.replace('-', '+').replace('_', '/'));
    }

    /** Builds the delegate, or returns {@code null} if any reflective lookup fails. */
    static Base64Java7 create(Class<?> datatypeConvertorKlazz) {
      Method encoder = getMethod(datatypeConvertorKlazz, "printBase64Binary", byte[].class);
      Method decoder = getMethod(datatypeConvertorKlazz, "parseBase64Binary", String.class);
      if (allNotNull(encoder, decoder)) {
        return new Base64Java7(encoder, decoder);
      } else {
        return null;
      }
    }

    static Method getMethod(Class<?> klazz, String methodName, Class<?> paramType) {
      try {
        return klazz.getMethod(methodName, paramType);
      } catch (NoSuchMethodException e) {
        return null;
      }
    }
  }

  /**
   * Loads the class {@code klazzName} (without initializing it) from the first class loader
   * that knows it. A {@code null} loader stands for the bootstrap class loader.
   *
   * @return the loaded class, or {@code null} if no given loader can find it
   */
  static Class<?> tryLoadFirst(String klazzName, ClassLoader ...classLoaders) {
    for (ClassLoader c : classLoaders) {
      try {
        return Class.forName(klazzName, false, c);
      } catch (ClassNotFoundException e) {
        // try next one
      }
    }
    return null;
  }

  /** Returns {@code true} iff none of the given objects is {@code null}. */
  static boolean allNotNull(Object ...objs) {
    for (Object o : objs) {
      if (o == null) return false;
    }
    return true;
  }

  // Calls static method without any parameters.
  // In case of exception always return `null`.
  static Object scall(Class<?> klazz, String name) {
    try {
      Method m = klazz.getMethod(name);
      return m.invoke(null);
    } catch (Exception e) {
      return null;
    }
  }

  /** Returns {@code true} if {@code s} is {@code null} or has length zero. */
  public static boolean isEmpty(String s) {
    return s == null || s.isEmpty();
  }
}
