package com.github.chen0040.rl.utils;

import com.alibaba.fastjson.annotation.JSONField;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/* loaded from: input_file:com/github/chen0040/rl/utils/Matrix.class */
/**
 * A row-sparse matrix of doubles.
 *
 * <p>Rows are kept in a map keyed by row index and created lazily. A row that has never been
 * stored behaves like a row built by {@code new Vec(columnCount)} followed by
 * {@code setAll(defaultValue)}. {@link #rowAt(int)}, {@link #get(int, int)} and
 * {@link #set(int, int, double)} store such rows on demand. The read-only queries
 * ({@link #equals(Object)}, {@link #isSymmetric()}, {@link #columnVectors()} and the
 * {@code multiply} methods) never insert rows.
 *
 * <p>Instances are not synchronized.
 */
public class Matrix implements Serializable {
    /** Stored rows keyed by row index; rows absent from this map are implicitly default-filled. */
    private Map<Integer, Vec> rows;
    /** Logical number of rows (may exceed the number of stored rows). */
    private int rowCount;
    /** Logical number of columns. */
    private int columnCount;
    /** Fill value used for rows that have not been stored. */
    private double defaultValue;

    /** Creates an empty 0x0 matrix with a default value of 0. */
    public Matrix() {
        this.rows = new HashMap<>();
    }

    /**
     * Creates a matrix holding the given values.
     *
     * <p>The dimensions are fixed before filling ({@code dArr.length} rows, longest inner array
     * columns) so that every row vector is created with the full column dimension. Columns past
     * the end of a shorter inner array are left unset.
     *
     * @param dArr row-major values; each inner array is one row
     */
    public Matrix(double[][] dArr) {
        this.rows = new HashMap<>();
        int widest = 0;
        for (double[] row : dArr) {
            widest = Math.max(widest, row.length);
        }
        // Size first: rowAt() creates each row with the current columnCount, which would
        // otherwise still be 0 when the first row is created.
        this.rowCount = dArr.length;
        this.columnCount = widest;
        for (int i = 0; i < dArr.length; i++) {
            double[] row = dArr[i];
            for (int j = 0; j < row.length; j++) {
                set(i, j, row[j]);
            }
        }
    }

    /**
     * Creates a {@code rowCount x columnCount} matrix with no stored rows and default value 0.
     *
     * @param i number of rows
     * @param i2 number of columns
     */
    public Matrix(int i, int i2) {
        this.rows = new HashMap<>();
        this.rowCount = i;
        this.columnCount = i2;
        this.defaultValue = 0.0d;
    }

    /**
     * Stores {@code vec} as row {@code i}, replacing any existing row, and sets its id to
     * {@code i}. Like {@link #set(int, int, double)}, this grows the row count when {@code i}
     * lies beyond it.
     *
     * @param i row index
     * @param vec the row to store (stored by reference, not copied)
     */
    public void setRow(int i, Vec vec) {
        vec.setId(i);
        this.rows.put(i, vec);
        if (i >= this.rowCount) {
            this.rowCount = i + 1;
        }
    }

    /**
     * Builds an {@code i x i} identity matrix.
     *
     * @param i the dimension
     * @return a matrix with 1.0 on the diagonal and the default value 0 elsewhere
     */
    public static Matrix identity(int i) {
        Matrix matrix = new Matrix(i, i);
        for (int k = 0; k < matrix.getRowCount(); k++) {
            matrix.set(k, k, 1.0d);
        }
        return matrix;
    }

    /**
     * Compares two matrices.
     *
     * <p>They must have the same row and column counts. If their default values differ, every
     * entry is compared element-wise. Otherwise each row stored in either matrix is compared with
     * {@code Vec.equals} against the corresponding row of the other. A row missing on one side is
     * compared as a default-filled row. Neither matrix is modified.
     */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (!(obj instanceof Matrix)) {
            return false;
        }
        Matrix matrix = (Matrix) obj;
        if (this.rowCount != matrix.rowCount || this.columnCount != matrix.columnCount) {
            return false;
        }
        if (this.defaultValue != matrix.defaultValue) {
            for (int i = 0; i < this.rowCount; i++) {
                for (int j = 0; j < this.columnCount; j++) {
                    if (valueAt(i, j) != matrix.valueAt(i, j)) {
                        return false;
                    }
                }
            }
            return true;
        }
        // Same default: a lazily materialised default row on one side must not make the
        // matrices unequal, so absent rows are compared as transient default rows.
        for (Integer key : this.rows.keySet()) {
            if (!rowOrDefault(key).equals(matrix.rowOrDefault(key))) {
                return false;
            }
        }
        for (Integer key : matrix.rows.keySet()) {
            if (!matrix.rowOrDefault(key).equals(rowOrDefault(key))) {
                return false;
            }
        }
        return true;
    }

    /**
     * Hashes only the dimensions. That is coarse, but it is consistent with {@link #equals(Object)},
     * which never considers matrices of different dimensions equal.
     */
    @Override
    public int hashCode() {
        return 31 * this.rowCount + this.columnCount;
    }

    /**
     * Returns a deep copy: dimensions, default value, and a {@code makeCopy()} of every stored row.
     */
    public Matrix makeCopy() {
        Matrix matrix = new Matrix(this.rowCount, this.columnCount);
        matrix.copy(this);
        return matrix;
    }

    /**
     * Replaces this matrix's contents with a deep copy of {@code matrix}.
     *
     * @param matrix the source; copying a matrix onto itself is a no-op
     */
    public void copy(Matrix matrix) {
        if (matrix == this) {
            // The row map is shared with the source here; clearing it would destroy the data.
            return;
        }
        this.rowCount = matrix.rowCount;
        this.columnCount = matrix.columnCount;
        this.defaultValue = matrix.defaultValue;
        this.rows.clear();
        for (Map.Entry<Integer, Vec> entry : matrix.rows.entrySet()) {
            this.rows.put(entry.getKey(), entry.getValue().makeCopy());
        }
    }

    /**
     * Sets entry ({@code i}, {@code i2}), storing the row if needed and growing the dimensions
     * to include the entry.
     */
    public void set(int i, int i2, double d) {
        rowAt(i).set(i2, d);
        if (i >= this.rowCount) {
            this.rowCount = i + 1;
        }
        if (i2 >= this.columnCount) {
            this.columnCount = i2 + 1;
        }
    }

    /**
     * Returns row {@code i}, creating and storing a default-filled row if it is absent.
     *
     * @param i row index
     * @return the stored row (never null)
     */
    public Vec rowAt(int i) {
        Vec row = this.rows.get(i);
        if (row == null) {
            row = newDefaultRow(i);
            this.rows.put(i, row);
        }
        return row;
    }

    /** Builds (but does not store) row {@code i} filled with the default value. */
    private Vec newDefaultRow(int i) {
        Vec row = new Vec(this.columnCount);
        row.setAll(this.defaultValue);
        row.setId(i);
        return row;
    }

    /** Returns the stored row {@code i}, or a transient default row without storing it. */
    private Vec rowOrDefault(int i) {
        Vec row = this.rows.get(i);
        return row != null ? row : newDefaultRow(i);
    }

    /** Reads entry ({@code i}, {@code j}) with the same result as {@link #get}, but never stores a row. */
    private double valueAt(int i, int j) {
        return rowOrDefault(i).get(j);
    }

    /**
     * Sets the default value and calls {@code setAll(d)} on every stored row.
     *
     * @param d the new fill value
     */
    public void setAll(double d) {
        this.defaultValue = d;
        for (Vec row : this.rows.values()) {
            row.setAll(d);
        }
    }

    /** Returns entry ({@code i}, {@code i2}); stores a default row if row {@code i} is absent. */
    public double get(int i, int i2) {
        return rowAt(i).get(i2);
    }

    /**
     * Returns the columns of this matrix as new vectors. Vector {@code j} has id {@code j} and
     * dimension {@code getRowCount()}. This matrix is not modified.
     */
    public List<Vec> columnVectors() {
        int columns = getColumnCount();
        int height = getRowCount();
        List<Vec> result = new ArrayList<>(columns);
        for (int j = 0; j < columns; j++) {
            Vec column = new Vec(height);
            column.setAll(this.defaultValue);
            column.setId(j);
            for (int i = 0; i < height; i++) {
                column.set(i, valueAt(i, j));
            }
            result.add(column);
        }
        return result;
    }

    /**
     * Computes {@code this * matrix}.
     *
     * <p>The result's default value is this matrix's default value. Every row of this matrix that
     * can affect the product is computed explicitly: stored rows always, and unstored rows when the
     * default value is non-zero. Neither operand is modified.
     *
     * @param matrix right-hand operand
     * @return the product, or {@code null} (after printing to stderr) if
     *     {@code getColumnCount() != matrix.getRowCount()}
     */
    public Matrix multiply(Matrix matrix) {
        if (getColumnCount() != matrix.getRowCount()) {
            System.err.println("A.columnCount must be equal to B.rowCount in multiplication");
            return null;
        }
        Matrix product = new Matrix(getRowCount(), matrix.getColumnCount());
        product.setAll(this.defaultValue);
        List<Vec> columns = matrix.columnVectors();
        for (Map.Entry<Integer, Vec> entry : this.rows.entrySet()) {
            fillProductRow(product, entry.getKey(), entry.getValue(), columns);
        }
        // An unstored row is implicitly filled with defaultValue. When that value is non-zero,
        // its product with a column is generally not defaultValue, so compute it explicitly.
        if (this.defaultValue != 0.0d) {
            for (int i = 0; i < this.rowCount; i++) {
                if (!this.rows.containsKey(i)) {
                    fillProductRow(product, i, rowOrDefault(i), columns);
                }
            }
        }
        return product;
    }

    /** Writes {@code row.multiply(columns[j])} into {@code product[rowIndex][j]} for every j. */
    private static void fillProductRow(Matrix product, int rowIndex, Vec row, List<Vec> columns) {
        for (int j = 0; j < columns.size(); j++) {
            product.set(rowIndex, j, row.multiply(columns.get(j)));
        }
    }

    /**
     * Reports whether the matrix is square and every stored off-diagonal entry (i, j) matches
     * (j, i) according to {@code DoubleUtils.equals}. The matrix is not modified.
     */
    @JSONField(serialize = false)
    public boolean isSymmetric() {
        if (getRowCount() != getColumnCount()) {
            return false;
        }
        for (Map.Entry<Integer, Vec> entry : this.rows.entrySet()) {
            int i = entry.getKey();
            Vec row = entry.getValue();
            for (Integer column : row.getData().keySet()) {
                int j = column;
                // valueAt (not get) so no row is inserted into this.rows while iterating it.
                if (i != j && !DoubleUtils.equals(row.get(j), valueAt(j, i))) {
                    return false;
                }
            }
        }
        return true;
    }

    /**
     * Computes {@code this * vec}.
     *
     * <p>A dimension mismatch is reported on stderr, but the computation still proceeds.
     * Unstored rows are computed explicitly when the default value is non-zero. In all other
     * cases the result's entries for unstored rows are left as created by {@code new Vec(int)}.
     *
     * @param vec right-hand operand
     * @return a new vector of dimension {@code getRowCount()}
     */
    public Vec multiply(Vec vec) {
        if (getColumnCount() != vec.getDimension()) {
            System.err.println("columnCount must be equal to the size of the vector for multiplication");
        }
        Vec result = new Vec(getRowCount());
        for (Map.Entry<Integer, Vec> entry : this.rows.entrySet()) {
            result.set(entry.getKey(), entry.getValue().multiply(vec));
        }
        if (this.defaultValue != 0.0d) {
            for (int i = 0; i < this.rowCount; i++) {
                if (!this.rows.containsKey(i)) {
                    result.set(i, rowOrDefault(i).multiply(vec));
                }
            }
        }
        return result;
    }

    /** Returns the live (not copied) map of stored rows. */
    public Map<Integer, Vec> getRows() {
        return this.rows;
    }

    /** Returns the logical number of rows. */
    public int getRowCount() {
        return this.rowCount;
    }

    /** Returns the logical number of columns. */
    public int getColumnCount() {
        return this.columnCount;
    }

    /** Returns the fill value for unstored rows. */
    public double getDefaultValue() {
        return this.defaultValue;
    }

    /** Replaces the row map by reference. */
    public void setRows(Map<Integer, Vec> map) {
        this.rows = map;
    }

    /** Overwrites the logical row count without touching stored rows. */
    public void setRowCount(int i) {
        this.rowCount = i;
    }

    /** Overwrites the logical column count without touching stored rows. */
    public void setColumnCount(int i) {
        this.columnCount = i;
    }

    /** Overwrites the default value without updating stored rows (unlike {@link #setAll}). */
    public void setDefaultValue(double d) {
        this.defaultValue = d;
    }
}
