package org.apache.sysds.runtime.privacy.propagation;

import java.util.stream.Stream;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.privacy.PrivacyConstraint;
import org.apache.sysds.runtime.privacy.finegrained.FineGrainedPrivacy;

/* loaded from: input_file:org/apache/sysds/runtime/privacy/propagation/MatrixMultiplicationPropagator.class */
/**
 * Base class for propagating privacy constraints through a matrix multiplication of
 * {@code input1} and {@code input2}.
 *
 * <p>If either input carries an overall {@link PrivacyConstraint.PrivacyLevel#Private} level, the
 * whole result is {@code Private}. Otherwise a fresh result constraint is created, and subclasses
 * fill in its fine-grained part from:
 * <ul>
 *   <li>the row privacy levels of {@code input1} and the column privacy levels of {@code input2},
 *       where a missing constraint counts as all {@code None};</li>
 *   <li>an {@link OperatorType} classification of every row of {@code input1} and every column of
 *       {@code input2}.</li>
 * </ul>
 */
public abstract class MatrixMultiplicationPropagator implements Propagator {
    MatrixBlock input1;
    MatrixBlock input2;
    PrivacyConstraint privacyConstraint1;
    PrivacyConstraint privacyConstraint2;

    /** Creates a propagator whose inputs must be supplied later through {@link #setFields}. */
    public MatrixMultiplicationPropagator() {
    }

    /**
     * Creates a propagator for the product of the two given inputs.
     *
     * @param matrixBlock         left input matrix
     * @param privacyConstraint   privacy constraint of the left input (may be {@code null})
     * @param matrixBlock2        right input matrix
     * @param privacyConstraint2  privacy constraint of the right input (may be {@code null})
     */
    public MatrixMultiplicationPropagator(MatrixBlock matrixBlock, PrivacyConstraint privacyConstraint, MatrixBlock matrixBlock2, PrivacyConstraint privacyConstraint2) {
        setFields(matrixBlock, privacyConstraint, matrixBlock2, privacyConstraint2);
    }

    /**
     * Sets the inputs and their privacy constraints.
     *
     * @param matrixBlock         left input matrix
     * @param privacyConstraint   privacy constraint of the left input (may be {@code null})
     * @param matrixBlock2        right input matrix
     * @param privacyConstraint2  privacy constraint of the right input (may be {@code null})
     */
    public void setFields(MatrixBlock matrixBlock, PrivacyConstraint privacyConstraint, MatrixBlock matrixBlock2, PrivacyConstraint privacyConstraint2) {
        this.input1 = matrixBlock;
        this.privacyConstraint1 = privacyConstraint;
        this.input2 = matrixBlock2;
        this.privacyConstraint2 = privacyConstraint2;
    }

    /**
     * Computes the privacy constraint of the multiplication result.
     *
     * <p>NOTE(review): when neither input is {@code Private}, the result is built with the no-arg
     * {@link PrivacyConstraint} constructor. Only the inputs' fine-grained row/column levels are
     * forwarded; any other top-level input level is not copied. Confirm this is intended.
     *
     * @return a {@code Private} constraint if either input is {@code Private}; otherwise a new
     *         constraint whose fine-grained part is filled by
     *         {@link #generateFineGrainedConstraints}
     */
    @Override
    public PrivacyConstraint propagate() {
        if (isPrivate(this.privacyConstraint1) || isPrivate(this.privacyConstraint2)) {
            return new PrivacyConstraint(PrivacyConstraint.PrivacyLevel.Private);
        }
        PrivacyConstraint result = new PrivacyConstraint();
        // Evaluate in the same order the original argument list did.
        FineGrainedPrivacy resultFineGrained = result.getFineGrainedPrivacy();
        PrivacyConstraint.PrivacyLevel[] rowPrivacy = rowPrivacyOfInput1();
        PrivacyConstraint.PrivacyLevel[] colPrivacy = colPrivacyOfInput2();
        OperatorType[] rowTypes = getOperatorTypesRow();
        OperatorType[] colTypes = getOperatorTypesCol();
        generateFineGrainedConstraints(resultFineGrained, rowPrivacy, colPrivacy, rowTypes, colTypes);
        return result;
    }

    /** Returns true if the given constraint is non-null and has overall level {@code Private}. */
    private static boolean isPrivate(PrivacyConstraint constraint) {
        return constraint != null && constraint.getPrivacyLevel() == PrivacyConstraint.PrivacyLevel.Private;
    }

    /** Returns true if the given constraint is non-null and has a fine-grained privacy part. */
    private static boolean hasFineGrained(PrivacyConstraint constraint) {
        return constraint != null && constraint.getFineGrainedPrivacy() != null;
    }

    /** Returns an array of {@code length} entries, all {@code PrivacyLevel.None}. */
    private static PrivacyConstraint.PrivacyLevel[] uniformNone(int length) {
        return Stream.generate(() -> PrivacyConstraint.PrivacyLevel.None)
            .limit(length)
            .toArray(PrivacyConstraint.PrivacyLevel[]::new);
    }

    /**
     * Returns the per-row privacy levels of {@code input1}. These are all {@code None} when
     * {@code input1} has no fine-grained constraint.
     */
    private PrivacyConstraint.PrivacyLevel[] rowPrivacyOfInput1() {
        int rows = this.input1.getNumRows();
        int cols = this.input1.getNumColumns();
        if (hasFineGrained(this.privacyConstraint1)) {
            return this.privacyConstraint1.getFineGrainedPrivacy().getRowPrivacy(rows, cols);
        }
        return uniformNone(rows);
    }

    /**
     * Returns the per-column privacy levels of {@code input2}. These are all {@code None} when
     * {@code input2} has no fine-grained constraint.
     */
    private PrivacyConstraint.PrivacyLevel[] colPrivacyOfInput2() {
        int rows = this.input2.getNumRows();
        int cols = this.input2.getNumColumns();
        if (hasFineGrained(this.privacyConstraint2)) {
            return this.privacyConstraint2.getFineGrainedPrivacy().getColPrivacy(rows, cols);
        }
        return uniformNone(cols);
    }

    /**
     * Classifies every row of {@code input1} with {@link #getOperatorType}.
     *
     * @return one {@link OperatorType} per row of {@code input1}
     */
    public OperatorType[] getOperatorTypesRow() {
        int rows = this.input1.getNumRows();
        OperatorType[] types = new OperatorType[rows];
        for (int r = 0; r < rows; r++) {
            types[r] = getOperatorType(this.input1.slice(r, r));
        }
        return types;
    }

    /**
     * Classifies every column of {@code input2} with {@link #getOperatorType}.
     *
     * @return one {@link OperatorType} per column of {@code input2}
     */
    public OperatorType[] getOperatorTypesCol() {
        int cols = this.input2.getNumColumns();
        int lastRow = this.input2.getNumRows() - 1;
        OperatorType[] types = new OperatorType[cols];
        for (int c = 0; c < cols; c++) {
            // Slice the full column c into a fresh block.
            types[c] = getOperatorType(this.input2.slice(0, lastRow, c, c, (CacheBlock) new MatrixBlock()));
        }
        return types;
    }

    /**
     * Merges two operator types. The result is {@code NonAggregate} if either argument is
     * {@code NonAggregate}, otherwise {@code Aggregate}.
     */
    public OperatorType mergeOperatorType(OperatorType operatorType, OperatorType operatorType2) {
        return (operatorType == OperatorType.NonAggregate || operatorType2 == OperatorType.NonAggregate) ? OperatorType.NonAggregate : OperatorType.Aggregate;
    }

    /**
     * Classifies a single row or column slice.
     *
     * <p>A slice with exactly one non-zero is {@code NonAggregate}; any other non-zero count,
     * including zero, is {@code Aggregate}. NOTE(review): this relies on the slice's
     * {@code getNonZeros()} being up to date after slicing. Confirm against {@code MatrixBlock.slice}.
     */
    protected OperatorType getOperatorType(MatrixBlock matrixBlock) {
        return matrixBlock.getNonZeros() == 1 ? OperatorType.NonAggregate : OperatorType.Aggregate;
    }

    /**
     * Fills in the fine-grained privacy of the multiplication result.
     *
     * @param fineGrainedPrivacy   fine-grained privacy object of the result constraint, to be populated
     * @param privacyLevelArr      privacy level of each row of {@code input1}
     * @param privacyLevelArr2     privacy level of each column of {@code input2}
     * @param operatorTypeArr      operator type of each row of {@code input1}
     * @param operatorTypeArr2     operator type of each column of {@code input2}
     */
    protected abstract void generateFineGrainedConstraints(FineGrainedPrivacy fineGrainedPrivacy, PrivacyConstraint.PrivacyLevel[] privacyLevelArr, PrivacyConstraint.PrivacyLevel[] privacyLevelArr2, OperatorType[] operatorTypeArr, OperatorType[] operatorTypeArr2);
}
