package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.HashMap;
import java.util.Map;
import org.apache.commons.lang3.NotImplementedException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.frame.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.encode.ColumnEncoder;
import org.apache.sysds.runtime.util.UtilFunctions;

/* loaded from: input_file:org/apache/sysds/runtime/transform/encode/ColumnEncoderWordEmbedding.class */
/**
 * Column encoder that replaces each string token of one input column with a fixed-length
 * embedding vector.
 *
 * <p>Tokens are mapped to recode IDs through {@link #_rcdMap}. The recode ID minus one selects
 * a row of the embedding matrix {@link #_wordEmbeddings}. That row, whose length equals the
 * matrix width, is the token's embedding. After {@link #initEmbeddings(MatrixBlock)} the vectors
 * are cached per token in {@link #_embMap} for use by
 * {@link #applyDense(CacheBlock, MatrixBlock, int, int, int)}.
 */
public class ColumnEncoderWordEmbedding extends ColumnEncoder {
    /** Embedding matrix: row (recodeId - 1) holds the embedding of that token; width = embedding length. */
    private MatrixBlock _wordEmbeddings = new MatrixBlock();
    /** Token -> recode ID, as supplied by metadata or read back by {@link #readExternal}. */
    private Map<Object, Long> _rcdMap = new HashMap<>();
    /** Token -> embedding vector; populated by {@link #initEmbeddings(MatrixBlock)}, null before that. */
    private HashMap<String, double[]> _embMap;

    public ColumnEncoderWordEmbedding() {
        super(-1);
    }

    /**
     * Creates an encoder for the given input column.
     *
     * <p>The collections are initialized by their field initializers, so this constructor and the
     * no-arg constructor leave the object in the same state.
     *
     * @param colID the column ID passed to {@link ColumnEncoder}; the input column read is colID - 1
     */
    public ColumnEncoderWordEmbedding(int colID) {
        super(colID);
    }

    /**
     * Looks up the recode ID of a token.
     *
     * @param key token to look up
     * @return its recode ID, or -1 if the token is not in the recode map
     */
    private long lookupRCDMap(Object key) {
        return _rcdMap.getOrDefault(key, -1L).longValue();
    }

    /** Returns the embedding length, i.e. the number of columns of the embedding matrix. */
    @Override
    public int getDomainSize() {
        return _wordEmbeddings.getNumColumns();
    }

    /** Returns the number of rows (distinct embeddings) in the embedding matrix. */
    public int getNrDistinctEmbeddings() {
        return _wordEmbeddings.getNumRows();
    }

    @Override
    protected double getCode(CacheBlock<?> in, int row) {
        throw new NotImplementedException();
    }

    @Override
    protected double[] getCodeCol(CacheBlock<?> in, int startInd, int blkSize, double[] tmp) {
        throw new NotImplementedException();
    }

    /**
     * Copies one row of the embedding matrix into a new array.
     *
     * @param rowIdx 0-based row index into the embedding matrix
     * @return a fresh array of length {@link #getDomainSize()} holding that row
     */
    private double[] getEmbeddedingFromEmbeddingMatrix(long rowIdx) {
        final int dim = getDomainSize();
        final int row = (int) rowIdx;
        final double[] embedding = new double[dim];
        // The embedding matrix is this encoder's own matrix, so columns are read from 0 to dim-1.
        // Offsetting by the input column ID would step past the matrix width whenever _colID >= 2.
        for (int c = 0; c < dim; c++) {
            embedding[c] = _wordEmbeddings.quickGetValue(row, c);
        }
        return embedding;
    }

    /**
     * Writes the embedding of each token in rows [rowStart, endIndex) into the output.
     *
     * <p>Rows whose token is null, empty, or has no embedding are skipped, so the output is left
     * unchanged for them. Requires {@link #initEmbeddings(MatrixBlock)} to have been called first.
     *
     * <p>NOTE(review): {@code outputCol} is not used. The embedding is written as the whole output
     * row via {@code quickSetRow}. Confirm this is intended when other encoders share the output.
     */
    @Override
    public void applyDense(CacheBlock<?> in, MatrixBlock out, int outputCol, int rowStart, int blk) {
        final int rowEnd = UtilFunctions.getEndIndex(in.getNumRows(), rowStart, blk);
        for (int r = rowStart; r < rowEnd; r++) {
            final String token = in.getString(r, _colID - 1);
            if (token == null || token.isEmpty()) {
                continue;
            }
            final double[] embedding = _embMap.get(token);
            if (embedding != null) {
                out.quickSetRow(r, embedding);
            }
        }
    }

    @Override
    protected ColumnEncoder.TransformType getTransformType() {
        return ColumnEncoder.TransformType.WORD_EMBEDDING;
    }

    @Override
    public void build(CacheBlock<?> in) {
        throw new NotImplementedException();
    }

    @Override
    public void allocateMetaData(FrameBlock meta) {
        throw new NotImplementedException();
    }

    @Override
    public FrameBlock getMetaData(FrameBlock meta) {
        throw new NotImplementedException();
    }

    /**
     * Takes the recode map for this encoder's column from the metadata frame.
     *
     * <p>A null or empty frame leaves the current recode map unchanged.
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        if (meta == null || meta.getNumRows() <= 0) {
            return;
        }
        _rcdMap = meta.getRecodeMap(_colID - 1);
    }

    /**
     * Installs the embedding matrix and builds the token-to-vector cache from the recode map.
     *
     * @param embeddings matrix whose row (recodeId - 1) is the embedding of the token with that ID
     */
    @Override
    public void initEmbeddings(MatrixBlock embeddings) {
        _wordEmbeddings = embeddings;
        final HashMap<String, double[]> embMap = new HashMap<>();
        for (Map.Entry<Object, Long> entry : _rcdMap.entrySet()) {
            final Object key = entry.getKey();
            // applyDense looks tokens up by their string form (CacheBlock.getString), so the
            // key is stored the same way; this also avoids a ClassCastException for non-String keys.
            final String token = (key == null) ? null : key.toString();
            // The stored recode value minus one is the 0-based row in the embedding matrix.
            embMap.put(token, getEmbeddedingFromEmbeddingMatrix(entry.getValue().longValue() - 1));
        }
        _embMap = embMap;
    }

    /**
     * Serializes the base encoder state, then the recode map as (size, [UTF key, long id]*),
     * then the embedding matrix.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        super.writeExternal(out);
        out.writeInt(_rcdMap.size());
        for (Map.Entry<Object, Long> entry : _rcdMap.entrySet()) {
            out.writeUTF(entry.getKey().toString());
            out.writeLong(entry.getValue().longValue());
        }
        _wordEmbeddings.write(out);
    }

    /**
     * Reads back the layout produced by {@link #writeExternal(ObjectOutput)} and rebuilds the
     * embedding cache.
     *
     * <p>A fresh recode map and a fresh embedding matrix are created, rather than writing into
     * whatever objects the fields currently reference. Those objects may be shared with a
     * metadata frame or with the caller of {@link #initEmbeddings(MatrixBlock)}.
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        super.readExternal(in);
        final int size = in.readInt();
        final Map<Object, Long> rcdMap = new HashMap<>();
        for (int k = 0; k < size; k++) {
            final String token = in.readUTF();
            final long id = in.readLong();
            rcdMap.put(token, Long.valueOf(id));
        }
        _rcdMap = rcdMap;
        final MatrixBlock embeddings = new MatrixBlock();
        embeddings.readExternal(in);
        initEmbeddings(embeddings);
    }
}
