package org.tensorflow.framework.losses;

import org.tensorflow.Operand;
import org.tensorflow.framework.losses.impl.AbstractLoss;
import org.tensorflow.framework.losses.impl.LossesHelper;
import org.tensorflow.op.Ops;
import org.tensorflow.types.family.TNumber;

/**
 * Loss that scores labels against predictions by cosine similarity.
 *
 * <p>The unreduced loss is the negation of {@link Losses#cosineSimilarity} evaluated over the
 * configured axis (or axes). It is then weighted and reduced by
 * {@link LossesHelper#computeWeightedLoss} using this loss's {@link Reduction}.
 */
public class CosineSimilarity extends AbstractLoss {
    /** Axis used when none is supplied. */
    public static final int DEFAULT_AXIS = -1;

    /** Reduction used when none is supplied. */
    public static final Reduction DEFAULT_REDUCTION = Reduction.AUTO;

    /** Axes passed to {@link Losses#cosineSimilarity}; a private copy of the caller's array. */
    private final int[] axis;

    /** Creates a loss with a default name, {@link #DEFAULT_AXIS} and {@link #DEFAULT_REDUCTION}. */
    public CosineSimilarity() {
        this((String) null, DEFAULT_AXIS, DEFAULT_REDUCTION);
    }

    /**
     * Creates a loss with {@link #DEFAULT_AXIS} and {@link #DEFAULT_REDUCTION}.
     *
     * @param name the name of the loss; may be {@code null}
     */
    public CosineSimilarity(String name) {
        this(name, DEFAULT_AXIS, DEFAULT_REDUCTION);
    }

    /**
     * Creates a loss over a single axis with {@link #DEFAULT_REDUCTION}.
     *
     * @param axis the axis along which to compute the cosine similarity
     */
    public CosineSimilarity(int axis) {
        this((String) null, axis, DEFAULT_REDUCTION);
    }

    /**
     * Creates a loss over the given axes with {@link #DEFAULT_REDUCTION}.
     *
     * @param axis the axes along which to compute the cosine similarity; the array is copied
     */
    public CosineSimilarity(int[] axis) {
        this((String) null, axis, DEFAULT_REDUCTION);
    }

    /**
     * Creates a named loss over a single axis with {@link #DEFAULT_REDUCTION}.
     *
     * @param name the name of the loss; may be {@code null}
     * @param axis the axis along which to compute the cosine similarity
     */
    public CosineSimilarity(String name, int axis) {
        this(name, axis, DEFAULT_REDUCTION);
    }

    /**
     * Creates a named loss over the given axes with {@link #DEFAULT_REDUCTION}.
     *
     * @param name the name of the loss; may be {@code null}
     * @param axis the axes along which to compute the cosine similarity; the array is copied
     */
    public CosineSimilarity(String name, int[] axis) {
        this(name, axis, DEFAULT_REDUCTION);
    }

    /**
     * Creates a loss with {@link #DEFAULT_AXIS} and the given reduction.
     *
     * @param reduction how the weighted per-element losses are reduced
     */
    public CosineSimilarity(Reduction reduction) {
        this((String) null, DEFAULT_AXIS, reduction);
    }

    /**
     * Creates a named loss with {@link #DEFAULT_AXIS} and the given reduction.
     *
     * @param name the name of the loss; may be {@code null}
     * @param reduction how the weighted per-element losses are reduced
     */
    public CosineSimilarity(String name, Reduction reduction) {
        this(name, DEFAULT_AXIS, reduction);
    }

    /**
     * Creates a loss over a single axis with the given reduction.
     *
     * @param axis the axis along which to compute the cosine similarity
     * @param reduction how the weighted per-element losses are reduced
     */
    public CosineSimilarity(int axis, Reduction reduction) {
        this((String) null, new int[] {axis}, reduction);
    }

    /**
     * Creates a loss over the given axes with the given reduction.
     *
     * @param axis the axes along which to compute the cosine similarity; the array is copied
     * @param reduction how the weighted per-element losses are reduced
     */
    public CosineSimilarity(int[] axis, Reduction reduction) {
        this((String) null, axis, reduction);
    }

    /**
     * Creates a named loss over a single axis with the given reduction.
     *
     * @param name the name of the loss; may be {@code null}
     * @param axis the axis along which to compute the cosine similarity
     * @param reduction how the weighted per-element losses are reduced
     */
    public CosineSimilarity(String name, int axis, Reduction reduction) {
        this(name, new int[] {axis}, reduction);
    }

    /**
     * Creates a named loss over the given axes with the given reduction.
     *
     * @param name the name of the loss; may be {@code null}
     * @param axis the axes along which to compute the cosine similarity. The array is
     *     defensively copied, so later changes by the caller do not affect this loss.
     *     {@code null} is stored as-is, as before.
     * @param reduction how the weighted per-element losses are reduced
     */
    public CosineSimilarity(String name, int[] axis, Reduction reduction) {
        super(name, reduction);
        // Copy so a caller mutating its array cannot change the axes used by later call()s.
        this.axis = axis == null ? null : axis.clone();
    }

    /**
     * Computes the weighted, reduced cosine similarity loss.
     *
     * @param ops the TensorFlow Ops accessor used to build the graph
     * @param labels the ground-truth values
     * @param predictions the predicted values
     * @param sampleWeights weights forwarded to {@link LossesHelper#computeWeightedLoss}
     *     (NOTE(review): presumably {@code null} means unweighted; confirm against LossesHelper)
     * @param <T> the data type of the predictions and of the resulting loss
     * @return the negated cosine similarity, weighted and reduced per {@link #getReduction()}
     */
    @Override
    public <T extends TNumber> Operand<T> call(
            Ops ops,
            Operand<? extends TNumber> labels,
            Operand<T> predictions,
            Operand<T> sampleWeights) {
        Operand<T> similarity = Losses.cosineSimilarity(ops, labels, predictions, axis);
        // Negate so that greater similarity yields a smaller loss value.
        Operand<T> losses = ops.math.neg(similarity);
        return LossesHelper.computeWeightedLoss(ops, losses, getReduction(), sampleWeights);
    }
}
