/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.instructions.fed;

import java.util.ArrayList;
import org.apache.commons.lang3.ArrayUtils;
import org.apache.sysds.hops.fedplanner.FTypes;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederationMap;
import org.apache.sysds.runtime.controlprogram.federated.FederationUtils;
import org.apache.sysds.runtime.controlprogram.federated.MatrixLineagePair;
import org.apache.sysds.runtime.instructions.cp.CPOperand;
import org.apache.sysds.runtime.instructions.fed.FEDInstruction;
import org.apache.sysds.runtime.instructions.fed.QuaternaryFEDInstruction;
import org.apache.sysds.runtime.matrix.operators.Operator;

/**
 * Federated execution of the quaternary WUMM instruction over inputs {@code (X, U, V)}.
 *
 * <p>NOTE(review): "WUMM" presumably stands for weighted unary matrix multiplication,
 * i.e. {@code X * uop(U %*% t(V))} — confirm against the CP implementation.
 *
 * <p>The federated input {@code X} determines the partitioning. Its federation map is reused for
 * the output, so the result stays federated on the same workers.
 */
public class QuaternaryWUMMFEDInstruction extends QuaternaryFEDInstruction {

    /**
     * Creates the instruction. All arguments are forwarded to {@link QuaternaryFEDInstruction}
     * together with {@link FEDInstruction.FEDType#Quaternary}.
     *
     * @param operator        operator carried by the instruction
     * @param in1             operand for the federated matrix {@code X}
     * @param in2             operand for the factor {@code U}
     * @param in3             operand for the factor {@code V}
     * @param out             output operand
     * @param opcode          instruction opcode
     * @param instruction_str serialized instruction string, re-sent to the federated workers
     */
    protected QuaternaryWUMMFEDInstruction(Operator operator, CPOperand in1, CPOperand in2, CPOperand in3, CPOperand out, String opcode, String instruction_str) {
        super(FEDInstruction.FEDType.Quaternary, operator, in1, in2, in3, out, opcode, instruction_str);
    }

    /**
     * Executes the instruction on the federated workers holding {@code X}.
     *
     * <ul>
     *   <li>Row-partitioned {@code X}: {@code U} is reused in place if it is row-federated and
     *       row-aligned with {@code X}; otherwise it is sent to the workers via
     *       {@code broadcastSliced(U, false)}. {@code V} is broadcast in full.</li>
     *   <li>Column-partitioned {@code X}: {@code U} is broadcast in full. {@code V} is reused in
     *       place if it is federated and aligned with {@code X} (COL or COL_T); otherwise it is sent
     *       via {@code broadcastSliced(V, true)}.</li>
     * </ul>
     *
     * <p>The output matrix receives a copy of {@code X}'s federation map under the ID of the
     * compute request, so the result remains federated.
     *
     * @param ec execution context used to resolve the input and output variables
     * @throws DMLRuntimeException if {@code X} is not federated, or is neither ROW nor COL
     *                             partitioned
     */
    @Override
    public void processInstruction(ExecutionContext ec) {
        MatrixObject X = ec.getMatrixObject(this.input1);
        MatrixLineagePair U = ec.getMatrixLineagePair(this.input2);
        MatrixLineagePair V = ec.getMatrixLineagePair(this.input3);

        // Only a federated X is supported; the local factors are shipped to X's workers.
        if (!X.isFederated()) {
            throw new DMLRuntimeException("Unsupported federated inputs (X, U, V) = (" + X.isFederated() + ", " + U.isFederated() + ", " + V.isFederated() + ")");
        }

        FederationMap fedMap = X.getFedMapping();
        // Sliced broadcast of the factor that is not already aligned with X; stays null when
        // that factor can be reused in place.
        FederatedRequest[] frSliced = null;
        // Full broadcast of the factor that every worker needs entirely.
        FederatedRequest frB;
        // Worker-side variable IDs that replace (X, U, V) in the re-issued instruction.
        long[] varNewIn = new long[3];
        varNewIn[0] = fedMap.getID();

        if (X.isFederated(FTypes.FType.ROW)) {
            if (U.isFederated(FTypes.FType.ROW) && fedMap.isAligned(U.getFedMapping(), FTypes.AlignType.ROW)) {
                varNewIn[1] = U.getFedMapping().getID();
            } else {
                frSliced = fedMap.broadcastSliced(U, false);
                varNewIn[1] = frSliced[0].getID();
            }
            frB = fedMap.broadcast(V);
            varNewIn[2] = frB.getID();
        } else if (X.isFederated(FTypes.FType.COL)) {
            frB = fedMap.broadcast(U);
            varNewIn[1] = frB.getID();
            if (V.isFederated() && fedMap.isAligned(V.getFedMapping(), FTypes.AlignType.COL, FTypes.AlignType.COL_T)) {
                varNewIn[2] = V.getFedMapping().getID();
            } else {
                frSliced = fedMap.broadcastSliced(V, true);
                varNewIn[2] = frSliced[0].getID();
            }
        } else {
            throw new DMLRuntimeException("Federated WUMM only supported for ROW or COLUMN partitioned federated data.");
        }

        FederatedRequest frComp = FederationUtils.callInstruction(this.instString, this.output, new CPOperand[]{this.input1, this.input2, this.input3}, varNewIn);

        // Broadcast first, then compute. No cleanup requests are issued here for the broadcast
        // variables. NOTE(review): presumably their worker-side lifetime is managed elsewhere
        // (e.g. lineage-based reuse) — confirm.
        FederatedRequest[] frAll = new FederatedRequest[]{frB, frComp};
        if (frSliced == null) {
            fedMap.execute(this.getTID(), true, frAll);
        } else {
            fedMap.execute(this.getTID(), true, frSliced, frAll);
        }

        MatrixObject out = ec.getMatrixObject(this.output);
        out.setFedMapping(fedMap.copyWithNewID(frComp.getID()));
        this.setOutputDataCharacteristics(X, U.getMO(), V.getMO(), ec);
    }
}

