package org.apache.mahout.math;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map;
import java.util.Random;
import org.apache.mahout.common.RandomUtils;
import org.apache.mahout.math.function.Functions;
import org.apache.mahout.math.function.VectorFunction;
import org.junit.Before;
import org.junit.Test;

/**
 * Conformance tests shared by {@link Matrix} implementations.
 *
 * <p>Subclasses supply the implementation under test through {@link #matrixFactory(double[][])}.
 * Before each test, {@link #setUp()} builds a fresh 3x2 fixture from {@code values} into
 * {@link #test}.
 */
public abstract class MatrixTest extends MahoutTestCase {
    /** Position of the row count in the array returned by {@link Matrix#size()}. */
    protected static final int ROW = 0;
    /** Position of the column count in the array returned by {@link Matrix#size()}. */
    protected static final int COL = 1;

    /** Absolute tolerance for element-wise comparisons. */
    private static final double TOLERANCE = 1.0e-6;
    /** Tighter tolerance for products whose expected values are known in closed form. */
    private static final double TIGHT_TOLERANCE = 1.0e-12;

    /** Aggregation function returning the sum of a vector's elements. */
    private static final VectorFunction Z_SUM = new VectorFunction() {
        @Override
        public double apply(Vector vector) {
            return vector.zSum();
        }
    };

    /** Fixture contents: 3 rows by 2 columns, equal to 1.1 * {{1, 2}, {3, 4}, {5, 6}}. */
    private final double[][] values = {{1.1, 2.2}, {3.3, 4.4}, {5.5, 6.6}};
    /** Chosen so that {@code values * vectorAValues} equals {5, 11, 17}. */
    private final double[] vectorAValues = {1.0 / 1.1, 2.0 / 1.1};
    /** Fresh matrix under test, rebuilt from {@code values} before every test. */
    protected Matrix test;

    @Override
    @Before
    public void setUp() throws Exception {
        super.setUp();
        test = matrixFactory(values);
    }

    /**
     * Creates an instance of the matrix implementation under test.
     *
     * @param values row-major contents; {@code values[r][c]} is the element at row r, column c
     * @return a new matrix holding {@code values}
     */
    public abstract Matrix matrixFactory(double[][] values);

    /** Formats the assertion message used for a single matrix cell. */
    private static String cell(int row, int col) {
        return "value[" + row + "][" + col + ']';
    }

    /** Returns a fresh 3x3 matrix (determinant 43) built by the implementation under test. */
    private Matrix squareMatrix() {
        return matrixFactory(new double[][] {{1.0, 3.0, 4.0}, {5.0, 2.0, 3.0}, {1.0, 4.0, 2.0}});
    }

    /** Builds a label map assigning the indices 0, 1, 2, ... to {@code labels} in order. */
    private static Map<String, Integer> bindings(String... labels) {
        Map<String, Integer> result = new HashMap<String, Integer>();
        for (int i = 0; i < labels.length; i++) {
            result.put(labels[i], i);
        }
        return result;
    }

    @Test
    public void testCardinality() {
        int[] size = test.size();
        assertEquals("row cardinality", values.length, size[ROW]);
        assertEquals("col cardinality", values[0].length, size[COL]);
    }

    @Test
    public void testCopy() {
        int[] size = test.size();
        Matrix copy = test.clone();
        assertSame("wrong class", copy.getClass(), test.getClass());
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), test.getQuick(row, col), copy.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testIterate() {
        for (MatrixSlice slice : test) {
            Vector fromIterator = slice.vector();
            // SparseColumnMatrix slices are compared against columns; all others against rows.
            Vector fromRandomAccess = test instanceof SparseColumnMatrix
                    ? test.getColumn(slice.index())
                    : test.getRow(slice.index());
            assertEquals("iterator: " + fromIterator + ", randomAccess: " + fromRandomAccess,
                    fromIterator, fromRandomAccess);
        }
    }

    @Test
    public void testGetQuick() {
        int[] size = test.size();
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), values[row][col], test.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testLike() {
        Matrix like = test.like();
        assertSame("type", like.getClass(), test.getClass());
        assertEquals("rows", test.size()[ROW], like.size()[ROW]);
        assertEquals("columns", test.size()[COL], like.size()[COL]);
    }

    @Test
    public void testLikeIntInt() {
        Matrix like = test.like(4, 4);
        assertSame("type", like.getClass(), test.getClass());
        assertEquals("rows", 4L, like.size()[ROW]);
        assertEquals("columns", 4L, like.size()[COL]);
    }

    @Test
    public void testSetQuick() {
        int[] size = test.size();
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                test.setQuick(row, col, 1.23);
                assertEquals(cell(row, col), 1.23, test.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testSize() {
        int[] nonDefault = test.getNumNondefaultElements();
        assertEquals("row size", values.length, nonDefault[ROW]);
        assertEquals("col size", values[0].length, nonDefault[COL]);
    }

    @Test
    public void testViewPart() {
        // 2x1 view whose top-left corner is (1, 1): it covers values[1][1] and values[2][1].
        Matrix view = test.viewPart(new int[] {1, 1}, new int[] {2, 1});
        assertEquals(2L, view.rowSize());
        assertEquals(1L, view.columnSize());
        int[] size = view.size();
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), values[row + 1][col + 1], view.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test(expected = IndexException.class)
    public void testViewPartCardinality() {
        test.viewPart(new int[] {1, 1}, new int[] {3, 3});
    }

    @Test(expected = IndexException.class)
    public void testViewPartIndexOver() {
        test.viewPart(new int[] {1, 1}, new int[] {2, 2});
    }

    @Test(expected = IndexException.class)
    public void testViewPartIndexUnder() {
        test.viewPart(new int[] {-1, -1}, new int[] {2, 2});
    }

    @Test
    public void testAssignDouble() {
        int[] size = test.size();
        test.assign(4.53);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), 4.53, test.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testAssignDoubleArrayArray() {
        int[] size = test.size();
        test.assign(new double[size[ROW]][size[COL]]);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), 0.0, test.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test(expected = CardinalityException.class)
    public void testAssignDoubleArrayArrayCardinality() {
        int[] size = test.size();
        // One row too many.
        test.assign(new double[size[ROW] + 1][size[COL]]);
    }

    @Test
    public void testAssignMatrixBinaryFunction() {
        int[] size = test.size();
        test.assign(test, Functions.PLUS);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), 2.0 * values[row][col], test.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test(expected = CardinalityException.class)
    public void testAssignMatrixBinaryFunctionCardinality() {
        test.assign(test.transpose(), Functions.PLUS);
    }

    @Test
    public void testAssignMatrix() {
        int[] size = test.size();
        Matrix target = test.like();
        target.assign(test);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), test.getQuick(row, col), target.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test(expected = CardinalityException.class)
    public void testAssignMatrixCardinality() {
        test.assign(test.transpose());
    }

    @Test
    public void testAssignUnaryFunction() {
        int[] size = test.size();
        test.assign(Functions.mult(-1.0));
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), -values[row][col], test.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testRowView() {
        int[] size = test.size();
        for (int row = 0; row < size[ROW]; row++) {
            assertEquals(0.0, test.getRow(row).minus(test.viewRow(row)).norm(1.0), 0.0);
        }
        // NOTE(review): the fixture has 3 rows, so rows 3 and 5 are out of range. These checks
        // only pass if viewRow does not bounds-check; confirm this is intended.
        assertEquals(size[COL], test.viewRow(3).size());
        assertEquals(size[COL], test.viewRow(5).size());

        // Writes through a row view must be visible in the matrix, and vice versa.
        Random random = RandomUtils.getRandom();
        for (int row = 0; row < size[ROW]; row++) {
            int col = random.nextInt(size[COL]);
            double old = test.get(row, col);
            double v = random.nextGaussian();
            test.viewRow(row).set(col, v);
            assertEquals(v, test.get(row, col), 0.0);
            assertEquals(v, test.viewRow(row).get(col), 0.0);
            test.set(row, col, old);
            assertEquals(old, test.get(row, col), 0.0);
            assertEquals(old, test.viewRow(row).get(col), 0.0);
        }
    }

    @Test
    public void testColumnView() {
        int[] size = test.size();
        for (int col = 0; col < size[COL]; col++) {
            assertEquals(0.0, test.getColumn(col).minus(test.viewColumn(col)).norm(1.0), 0.0);
        }
        // NOTE(review): the fixture has 2 columns, so columns 3 and 5 are out of range. These
        // checks only pass if viewColumn does not bounds-check; confirm this is intended.
        assertEquals(size[ROW], test.viewColumn(3).size());
        assertEquals(size[ROW], test.viewColumn(5).size());

        // Writes through a column view must be visible in the matrix, and vice versa. The random
        // index selects a row within the column, so it must range over the row count, and the
        // saved value must come from the same (row, col) cell that is later restored.
        Random random = RandomUtils.getRandom();
        for (int col = 0; col < size[COL]; col++) {
            int row = random.nextInt(size[ROW]);
            double old = test.get(row, col);
            double v = random.nextGaussian();
            test.viewColumn(col).set(row, v);
            assertEquals(v, test.get(row, col), 0.0);
            assertEquals(v, test.viewColumn(col).get(row), 0.0);
            test.set(row, col, old);
            assertEquals(old, test.get(row, col), 0.0);
            assertEquals(old, test.viewColumn(col).get(row), 0.0);
        }
    }

    @Test
    public void testAggregateRows() {
        Vector rowSums = test.aggregateRows(Z_SUM);
        for (int row = 0; row < test.numRows(); row++) {
            assertEquals(test.getRow(row).zSum(), rowSums.get(row), TOLERANCE);
        }
    }

    @Test
    public void testAggregateCols() {
        Vector columnSums = test.aggregateColumns(Z_SUM);
        for (int col = 0; col < test.numCols(); col++) {
            assertEquals(test.getColumn(col).zSum(), columnSums.get(col), TOLERANCE);
        }
    }

    @Test
    public void testAggregate() {
        double sumOfRowSums = test.aggregateRows(Z_SUM).zSum();
        assertEquals(sumOfRowSums, test.aggregate(Functions.PLUS, Functions.IDENTITY), TOLERANCE);
    }

    @Test
    public void testDivide() {
        int[] size = test.size();
        Matrix quotient = test.divide(4.53);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), values[row][col] / 4.53, quotient.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testGet() {
        int[] size = test.size();
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), values[row][col], test.get(row, col), TOLERANCE);
            }
        }
    }

    @Test(expected = IndexException.class)
    public void testGetIndexUnder() {
        test.get(-1, 0);
    }

    @Test(expected = IndexException.class)
    public void testGetIndexOver() {
        test.get(test.size()[ROW], 0);
    }

    @Test
    public void testMinus() {
        int[] size = test.size();
        Matrix difference = test.minus(test);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), 0.0, difference.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test(expected = CardinalityException.class)
    public void testMinusCardinality() {
        test.minus(test.transpose());
    }

    @Test
    public void testPlusDouble() {
        int[] size = test.size();
        Matrix sum = test.plus(4.53);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), values[row][col] + 4.53, sum.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testPlusMatrix() {
        int[] size = test.size();
        Matrix sum = test.plus(test);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), values[row][col] * 2.0, sum.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test(expected = CardinalityException.class)
    public void testPlusMatrixCardinality() {
        test.plus(test.transpose());
    }

    @Test(expected = IndexException.class)
    public void testSetUnder() {
        test.set(-1, 0, 1.23);
    }

    @Test(expected = IndexException.class)
    public void testSetOver() {
        test.set(test.size()[ROW], 0, 1.23);
    }

    @Test
    public void testTimesDouble() {
        int[] size = test.size();
        Matrix product = test.times(4.53);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), values[row][col] * 4.53, product.getQuick(row, col), TOLERANCE);
            }
        }
    }

    @Test
    public void testTimesMatrix() {
        int[] size = test.size();
        Matrix product = test.times(test.transpose());
        int[] productSize = product.size();
        assertEquals("rows", size[ROW], productSize[ROW]);
        assertEquals("cols", size[ROW], productSize[COL]);

        // values == 1.1 * B with B = {{1, 2}, {3, 4}, {5, 6}}, so values * values' == 1.21 * B * B'.
        Matrix expected = new DenseMatrix(new double[][] {
                {5.0, 11.0, 17.0},
                {11.0, 25.0, 39.0},
                {17.0, 39.0, 61.0}}).times(1.21);
        for (int row = 0; row < expected.numRows(); row++) {
            for (int col = 0; col < expected.numCols(); col++) {
                assertTrue("Matrix times transpose not correct: " + row + ", " + col
                                + "\nexpected:\n\t" + expected + "\nactual:\n\t" + product,
                        Math.abs(expected.get(row, col) - product.get(row, col)) < TIGHT_TOLERANCE);
            }
        }

        // Smoke test: (10x1)' * (10x1) must be a valid product and must not throw.
        DenseMatrix column = new DenseMatrix(10, 1);
        column.transpose().times(column);
    }

    @Test(expected = CardinalityException.class)
    public void testTimesVector() {
        DenseVector v = new DenseVector(vectorAValues);
        Vector product = test.times(v);
        Vector expected = new DenseVector(new double[] {5.0, 11.0, 17.0});
        assertTrue("Matrix times vector not equals: " + expected + " != " + product,
                expected.minus(product).norm(2.0) < TIGHT_TOLERANCE);
        // product has one entry per row (3), but times expects one per column (2).
        test.times(product);
    }

    @Test
    public void testTimesSquaredTimesVector() {
        DenseVector v = new DenseVector(vectorAValues);
        Vector timesSquared = test.timesSquared(v);
        Vector explicit = test.transpose().times(test.times(v));
        assertTrue("M'Mv != M.timesSquared(v): " + timesSquared + " != " + explicit,
                explicit.minus(timesSquared).norm(2.0) < TIGHT_TOLERANCE);
    }

    @Test(expected = CardinalityException.class)
    public void testTimesMatrixCardinality() {
        test.times(test.like(5, 8));
    }

    @Test
    public void testTranspose() {
        int[] size = test.size();
        Matrix transpose = test.transpose();
        int[] transposeSize = transpose.size();
        assertEquals("rows", size[COL], transposeSize[ROW]);
        assertEquals("cols", size[ROW], transposeSize[COL]);
        for (int row = 0; row < size[ROW]; row++) {
            for (int col = 0; col < size[COL]; col++) {
                assertEquals(cell(row, col), test.getQuick(row, col), transpose.getQuick(col, row), TOLERANCE);
            }
        }
    }

    @Test
    public void testZSum() {
        assertEquals("zsum", 23.1, test.zSum(), TOLERANCE);
    }

    @Test
    public void testAssignRow() {
        test.assignRow(1, new DenseVector(new double[] {2.1, 3.2}));
        assertEquals("test[1][0]", 2.1, test.getQuick(1, 0), TOLERANCE);
        assertEquals("test[1][1]", 3.2, test.getQuick(1, 1), TOLERANCE);
    }

    @Test(expected = CardinalityException.class)
    public void testAssignRowCardinality() {
        test.assignRow(1, new DenseVector(new double[] {2.1, 3.2, 4.3}));
    }

    @Test
    public void testAssignColumn() {
        test.assignColumn(1, new DenseVector(new double[] {2.1, 3.2, 4.3}));
        assertEquals("test[0][1]", 2.1, test.getQuick(0, 1), TOLERANCE);
        assertEquals("test[1][1]", 3.2, test.getQuick(1, 1), TOLERANCE);
        assertEquals("test[2][1]", 4.3, test.getQuick(2, 1), TOLERANCE);
    }

    @Test(expected = CardinalityException.class)
    public void testAssignColumnCardinality() {
        test.assignColumn(1, new DenseVector(new double[] {2.1, 3.2}));
    }

    @Test
    public void testGetRow() {
        assertEquals("row size", 2L, test.getRow(1).getNumNondefaultElements());
    }

    @Test(expected = IndexException.class)
    public void testGetRowIndexUnder() {
        test.getRow(-1);
    }

    @Test(expected = IndexException.class)
    public void testGetRowIndexOver() {
        test.getRow(5);
    }

    @Test
    public void testGetColumn() {
        assertEquals("col size", 3L, test.getColumn(1).getNumNondefaultElements());
    }

    @Test(expected = IndexException.class)
    public void testGetColumnIndexUnder() {
        test.getColumn(-1);
    }

    @Test(expected = IndexException.class)
    public void testGetColumnIndexOver() {
        test.getColumn(5);
    }

    /** Determinant of the 3x3 fixture; the method name's spelling is kept for stable test IDs. */
    @Test
    public void testDetermitant() {
        assertEquals("determinant", 43.0, squareMatrix().determinant(), TOLERANCE);
    }

    @Test
    public void testLabelBindings() {
        Matrix matrix = squareMatrix();
        assertNull("row bindings", matrix.getRowLabelBindings());
        assertNull("col bindings", matrix.getColumnLabelBindings());

        Map<String, Integer> rowBindings = bindings("Fee", "Fie", "Foe");
        matrix.setRowLabelBindings(rowBindings);
        assertEquals("row", rowBindings, matrix.getRowLabelBindings());

        Map<String, Integer> columnBindings = bindings("Foo", "Bar", "Baz");
        matrix.setColumnLabelBindings(columnBindings);
        assertEquals("col", columnBindings, matrix.getColumnLabelBindings());

        assertEquals("Fee", matrix.get(0, 1), matrix.get("Fee", "Bar"), TOLERANCE);
        // Overwrite the row labelled "Foe" (index 2) and check that the write landed there.
        matrix.set("Foe", new double[] {9.0, 8.0, 7.0});
        assertEquals("FeeBaz", matrix.get(0, 2), matrix.get("Fee", "Baz"), TOLERANCE);
        assertEquals("FoeFoo", 9.0, matrix.get("Foe", "Foo"), TOLERANCE);
    }

    @Test(expected = UnboundLabelException.class)
    public void testSettingLabelBindings() {
        Matrix matrix = squareMatrix();
        assertNull("row bindings", matrix.getRowLabelBindings());
        assertNull("col bindings", matrix.getColumnLabelBindings());

        // Setting by label together with explicit indices creates both bindings.
        matrix.set("Fee", "Foo", 1, 2, 9.0);
        assertNotNull("row", matrix.getRowLabelBindings());
        assertNotNull("col", matrix.getColumnLabelBindings());
        assertEquals("Fee", Integer.valueOf(1), matrix.getRowLabelBindings().get("Fee"));
        assertEquals("Foo", Integer.valueOf(2), matrix.getColumnLabelBindings().get("Foo"));
        assertEquals("FeeFoo", matrix.get(1, 2), matrix.get("Fee", "Foo"), TOLERANCE);

        // Neither label was ever bound, so this must throw UnboundLabelException.
        matrix.get("Fie", "Foe");
    }

    @Test
    public void testLabelBindingSerialization() {
        Matrix matrix = squareMatrix();
        assertNull("row bindings", matrix.getRowLabelBindings());
        assertNull("col bindings", matrix.getColumnLabelBindings());

        Map<String, Integer> rowBindings = bindings("Fee", "Fie", "Foe");
        matrix.setRowLabelBindings(rowBindings);
        assertEquals("row", rowBindings, matrix.getRowLabelBindings());

        Map<String, Integer> columnBindings = bindings("Foo", "Bar", "Baz");
        matrix.setColumnLabelBindings(columnBindings);
        assertEquals("col", columnBindings, matrix.getColumnLabelBindings());
    }
}
