package org.bytedeco.pytorch;

import org.bytedeco.javacpp.Loader;
import org.bytedeco.javacpp.Pointer;
import org.bytedeco.javacpp.annotation.ByRef;
import org.bytedeco.javacpp.annotation.ByVal;
import org.bytedeco.javacpp.annotation.Cast;
import org.bytedeco.javacpp.annotation.Const;
import org.bytedeco.javacpp.annotation.Namespace;
import org.bytedeco.javacpp.annotation.NoDeallocator;
import org.bytedeco.javacpp.annotation.NoOffset;
import org.bytedeco.javacpp.annotation.Properties;
import org.bytedeco.pytorch.presets.torch;

/**
 * JavaCPP binding for the native {@code torch::nn::MultiheadAttentionImpl} module.
 *
 * <p>Apart from the constructors, every member is a {@code native} declaration; the annotations
 * tell JavaCPP how each argument and return value is marshalled across JNI. Accessor pairs such
 * as {@code in_proj_weight()} / {@code in_proj_weight(Tensor)} map onto a native data member:
 * the no-argument form reads it, the one-argument form writes it and returns this object.
 */
@Namespace("torch::nn")
@NoOffset
@Properties(inherit = torch.class)
public class MultiheadAttentionImpl extends MultiheadAttentionImplCloneable {
    // Make sure the native library backing this binding is loaded before any native call.
    static {
        Loader.load();
    }

    /**
     * Wraps an already existing native object; nothing new is allocated.
     *
     * @param p native pointer to wrap
     */
    public MultiheadAttentionImpl(Pointer p) {
        super(p);
    }

    /**
     * Allocates a native module through its {@code (int64_t, int64_t)} constructor.
     *
     * @param embedDim first native constructor argument ({@code embed_dim} in the libtorch API)
     * @param numHeads second native constructor argument ({@code num_heads} in the libtorch API)
     */
    public MultiheadAttentionImpl(@Cast("int64_t") long embedDim, @Cast("int64_t") long numHeads) {
        super((Pointer) null);
        allocate(embedDim, numHeads);
    }

    // @NoDeallocator: JavaCPP registers no deallocator for the object created here.
    @NoDeallocator
    private native void allocate(@Cast("int64_t") long embedDim, @Cast("int64_t") long numHeads);

    /**
     * Allocates a native module from an options object, passed to C++ as a const reference.
     *
     * @param options construction options
     */
    public MultiheadAttentionImpl(@Const @ByRef MultiheadAttentionOptions options) {
        super((Pointer) null);
        allocate(options);
    }

    // @NoDeallocator: JavaCPP registers no deallocator for the object created here.
    @NoDeallocator
    private native void allocate(@Const @ByRef MultiheadAttentionOptions options);

    /**
     * Runs the native {@code forward} with every argument supplied explicitly.
     *
     * <p>Passing {@code null} for {@code keyPaddingMask} or {@code attnMask} makes the native side
     * use an empty {@code at::Tensor{}} instead.
     *
     * @param query first input tensor
     * @param key second input tensor
     * @param value third input tensor
     * @param keyPaddingMask optional mask tensor ({@code null} means {@code at::Tensor{}})
     * @param needWeights flag forwarded to C++ as {@code bool}
     * @param attnMask optional mask tensor ({@code null} means {@code at::Tensor{}})
     * @param averageAttnWeights flag forwarded to C++ as {@code bool}
     * @return the native result returned by value; presumably (attention output, attention
     *     weights) as in libtorch — TODO confirm
     */
    @ByVal
    public native TensorTensorTuple forward(
            @Const @ByRef Tensor query,
            @Const @ByRef Tensor key,
            @Const @ByRef Tensor value,
            @Const @ByRef(nullValue = "at::Tensor{}") Tensor keyPaddingMask,
            @Cast("bool") boolean needWeights,
            @Const @ByRef(nullValue = "at::Tensor{}") Tensor attnMask,
            @Cast("bool") boolean averageAttnWeights);

    /**
     * Runs the native {@code forward} with only the three leading tensors; the trailing
     * parameters take whatever defaults the C++ declaration provides.
     *
     * @param query first input tensor
     * @param key second input tensor
     * @param value third input tensor
     * @return the native result returned by value
     */
    @ByVal
    public native TensorTensorTuple forward(
            @Const @ByRef Tensor query, @Const @ByRef Tensor key, @Const @ByRef Tensor value);

    /** Invokes the native {@code reset()} of this module. */
    @Override
    public native void reset();

    /** Invokes the native {@code _reset_parameters()} of this module. */
    public native void _reset_parameters();

    // ---- Native data members exposed as getter/setter pairs ----

    /** @return reference to the native {@code options} member */
    @ByRef
    public native MultiheadAttentionOptions options();

    /** Overwrites the native {@code options} member; returns this object. */
    public native MultiheadAttentionImpl options(MultiheadAttentionOptions setter);

    /** @return the native {@code _qkv_same_embed_dim} flag */
    @Cast("bool")
    public native boolean _qkv_same_embed_dim();

    /** Overwrites the native {@code _qkv_same_embed_dim} flag; returns this object. */
    public native MultiheadAttentionImpl _qkv_same_embed_dim(boolean setter);

    /** @return reference to the native {@code in_proj_weight} tensor */
    @ByRef
    public native Tensor in_proj_weight();

    /** Overwrites the native {@code in_proj_weight} tensor; returns this object. */
    public native MultiheadAttentionImpl in_proj_weight(Tensor setter);

    /** @return reference to the native {@code in_proj_bias} tensor */
    @ByRef
    public native Tensor in_proj_bias();

    /** Overwrites the native {@code in_proj_bias} tensor; returns this object. */
    public native MultiheadAttentionImpl in_proj_bias(Tensor setter);

    /** @return reference to the native {@code bias_k} tensor */
    @ByRef
    public native Tensor bias_k();

    /** Overwrites the native {@code bias_k} tensor; returns this object. */
    public native MultiheadAttentionImpl bias_k(Tensor setter);

    /** @return reference to the native {@code bias_v} tensor */
    @ByRef
    public native Tensor bias_v();

    /** Overwrites the native {@code bias_v} tensor; returns this object. */
    public native MultiheadAttentionImpl bias_v(Tensor setter);

    /** @return reference to the native {@code out_proj} submodule */
    @ByRef
    public native Linear out_proj();

    /** Overwrites the native {@code out_proj} submodule; returns this object. */
    public native MultiheadAttentionImpl out_proj(Linear setter);

    /** @return reference to the native {@code q_proj_weight} tensor */
    @ByRef
    public native Tensor q_proj_weight();

    /** Overwrites the native {@code q_proj_weight} tensor; returns this object. */
    public native MultiheadAttentionImpl q_proj_weight(Tensor setter);

    /** @return reference to the native {@code k_proj_weight} tensor */
    @ByRef
    public native Tensor k_proj_weight();

    /** Overwrites the native {@code k_proj_weight} tensor; returns this object. */
    public native MultiheadAttentionImpl k_proj_weight(Tensor setter);

    /** @return reference to the native {@code v_proj_weight} tensor */
    @ByRef
    public native Tensor v_proj_weight();

    /** Overwrites the native {@code v_proj_weight} tensor; returns this object. */
    public native MultiheadAttentionImpl v_proj_weight(Tensor setter);

    /** @return the native {@code head_dim} member, converted from {@code int64_t} */
    @Cast("int64_t")
    public native long head_dim();

    /** Overwrites the native {@code head_dim} member; returns this object. */
    public native MultiheadAttentionImpl head_dim(long setter);
}
