/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.controlprogram.paramserv.dp;

import java.util.List;
import java.util.concurrent.Future;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.MatrixObject;
import org.apache.sysds.runtime.controlprogram.context.ExecutionContext;
import org.apache.sysds.runtime.controlprogram.federated.FederatedData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest;
import org.apache.sysds.runtime.controlprogram.federated.FederatedResponse;
import org.apache.sysds.runtime.controlprogram.federated.FederatedUDF;
import org.apache.sysds.runtime.controlprogram.paramserv.ParamservUtils;
import org.apache.sysds.runtime.controlprogram.paramserv.dp.DataPartitionFederatedScheme;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.lineage.LineageItem;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.meta.DataCharacteristics;

/**
 * Federated data-partitioning scheme that pads every federated partition up to the row
 * count of the largest partition.
 *
 * <p>For each partition with fewer rows than the maximum, a replication UDF is executed on
 * the owning federated worker. The UDF appends rows, chosen via a seeded replication matrix
 * from {@link ParamservUtils#generateReplicationMatrix}, to both the features and the
 * labels. The coordinator-side data characteristics of every partition are then set to the
 * maximum row count.
 */
public class ReplicateToMaxFederatedScheme
extends DataPartitionFederatedScheme {
    /** Prefix shared by all error messages raised by this scheme. */
    private static final String ERR_PREFIX = "FederatedDataPartitioner ReplicateToMaxFederatedScheme: ";

    /**
     * Slices the federated features/labels per worker and replicates rows on every worker
     * whose partition is smaller than the largest one.
     *
     * @param features federated feature matrix
     * @param labels   federated label matrix; must slice into as many partitions as {@code features}
     * @param seed     seed forwarded to the replication-matrix generator on each worker
     * @return the (now equally sized) partitions, their count, the balance metrics after
     *         replication, and weighting factors computed from the pre-replication sizes
     * @throws DMLRuntimeException if the partition counts differ, or a replicate UDF fails,
     *         returns an unsuccessful response, or the wait for it is interrupted
     */
    @Override
    public DataPartitionFederatedScheme.Result partition(MatrixObject features, MatrixObject labels, int seed) {
        List<MatrixObject> pFeatures = ReplicateToMaxFederatedScheme.sliceFederatedMatrix(features);
        List<MatrixObject> pLabels = ReplicateToMaxFederatedScheme.sliceFederatedMatrix(labels);
        if (pFeatures.size() != pLabels.size()) {
            throw new DMLRuntimeException(ERR_PREFIX + "number of feature partitions (" + pFeatures.size()
                + ") does not match number of label partitions (" + pLabels.size() + ")");
        }

        // Weighting factors are derived from the partition sizes BEFORE replication;
        // the balance metrics returned in the Result are recomputed afterwards.
        List<Double> weightingFactors = ReplicateToMaxFederatedScheme.getWeightingFactors(
            pFeatures, ReplicateToMaxFederatedScheme.getBalanceMetrics(pFeatures));
        int maxRows = getMaxRows(pFeatures);

        for (int i = 0; i < pFeatures.size(); ++i) {
            MatrixObject pFeature = pFeatures.get(i);
            MatrixObject pLabel = pLabels.get(i);
            // NOTE(review): assumes each sliced partition maps to a single federated range,
            // since only index 0 of the federated data array is used.
            FederatedData featuresData = pFeature.getFedMapping().getFederatedData()[0];
            FederatedData labelsData = pLabel.getFedMapping().getFederatedData()[0];

            Future<FederatedResponse> udfResponse = featuresData.executeFederatedOperation(
                new FederatedRequest(FederatedRequest.RequestType.EXEC_UDF, featuresData.getVarID(),
                    new replicateDataOnFederatedWorker(
                        new long[]{featuresData.getVarID(), labelsData.getVarID()}, seed, maxRows)));
            awaitReplication(udfResponse);

            // Reflect the worker-side replication in the coordinator's metadata.
            DataCharacteristics update = pFeature.getDataCharacteristics().setRows(maxRows);
            pFeature.updateDataCharacteristics(update);
            update = pLabel.getDataCharacteristics().setRows(maxRows);
            pLabel.updateDataCharacteristics(update);
        }
        return new DataPartitionFederatedScheme.Result(pFeatures, pLabels, pFeatures.size(),
            ReplicateToMaxFederatedScheme.getBalanceMetrics(pFeatures), weightingFactors);
    }

    /**
     * Returns the largest row count among the given partitions, or 0 if the list is empty.
     *
     * @throws ArithmeticException if a row count does not fit into an {@code int}
     */
    private static int getMaxRows(List<MatrixObject> partitions) {
        long maxRows = 0;
        for (MatrixObject partition : partitions) {
            maxRows = Math.max(maxRows, partition.getNumRows());
        }
        return Math.toIntExact(maxRows);
    }

    /**
     * Blocks until the replicate UDF completes and verifies that it succeeded.
     *
     * @throws DMLRuntimeException with the original cause attached if the UDF could not be
     *         executed or the wait was interrupted, or if the response reports failure
     */
    private static void awaitReplication(Future<FederatedResponse> udfResponse) {
        FederatedResponse response;
        try {
            response = udfResponse.get();
        }
        catch (InterruptedException e) {
            // Restore the interrupt flag so callers up the stack can still observe it.
            Thread.currentThread().interrupt();
            throw new DMLRuntimeException(ERR_PREFIX + "interrupted while waiting for replicate UDF", e);
        }
        catch (Exception e) {
            throw new DMLRuntimeException(ERR_PREFIX + "executing replicate UDF failed: " + e.getMessage(), e);
        }
        // Checked outside the try so this failure is not re-wrapped as "executing ... failed".
        if (!response.isSuccessful()) {
            throw new DMLRuntimeException(ERR_PREFIX + "replicate UDF returned fail");
        }
    }

    /**
     * UDF executed on a federated worker that pads its local features and labels up to
     * {@code _max_rows} rows by appending rows selected through a seeded replication matrix.
     *
     * <p>The class name and field names are intentionally left unchanged. Instances are
     * Java-serialized to workers, so renaming them would alter the serialized form.
     */
    private static class replicateDataOnFederatedWorker
    extends FederatedUDF {
        private static final long serialVersionUID = -6930898456315100587L;
        /** Seed for the replication-matrix generator. */
        private final int _seed;
        /** Target row count for the local partition. */
        private final int _max_rows;

        /**
         * @param inIDs    variable IDs of the worker-local inputs: {features, labels}
         * @param seed     seed for the replication-matrix generator
         * @param max_rows target row count after replication
         */
        protected replicateDataOnFederatedWorker(long[] inIDs, int seed, int max_rows) {
            super(inIDs);
            this._seed = seed;
            this._max_rows = max_rows;
        }

        /**
         * Replicates rows of the local features and labels if the local partition is smaller
         * than the target; otherwise it leaves them untouched.
         *
         * @param ec   worker execution context (unused here)
         * @param data resolved inputs in the order of {@code inIDs}: features, then labels
         * @return a {@code SUCCESS} response
         */
        @Override
        public FederatedResponse execute(ExecutionContext ec, Data ... data) {
            MatrixObject features = (MatrixObject)data[0];
            MatrixObject labels = (MatrixObject)data[1];
            long localRows = features.getNumRows();
            if (localRows < (long)this._max_rows) {
                // localRows < _max_rows (an int), so the exact conversion cannot overflow.
                int localRowsInt = Math.toIntExact(localRows);
                int numRowsNeeded = this._max_rows - localRowsInt;
                MatrixBlock replicateMatrixBlock = ParamservUtils.generateReplicationMatrix(numRowsNeeded, localRowsInt, this._seed);
                // The same replication matrix is used for both, keeping feature/label rows aligned.
                DataPartitionFederatedScheme.replicateTo(features, replicateMatrixBlock);
                DataPartitionFederatedScheme.replicateTo(labels, replicateMatrixBlock);
            }
            return new FederatedResponse(FederatedResponse.ResponseType.SUCCESS);
        }

        /** No lineage is tracked for this UDF; it returns {@code null}. */
        @Override
        public Pair<String, LineageItem> getLineageItem(ExecutionContext ec) {
            return null;
        }
    }
}

