/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.scripts.datagen;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.scripts.datagen.genranddata4logreg_ltstats.GenerateWeights_output;

/**
 * MLContext wrapper around the DML script {@code scripts/datagen/genRandData4LogReg_LTstats.dml}.
 *
 * <p>Constructing an instance loads the full DML script from the classpath and installs it as this
 * {@link Script}'s script string. The {@code straightenX} and {@code generateWeights} methods each
 * run a small standalone script that {@code source}s the same DML file and invokes one of its
 * functions. The {@code __docs} and {@code __source} methods return the DML text of those functions.
 */
public class GenRandData4LogReg_LTstats
extends Script {

    /** Classpath-relative location of the wrapped DML script (no leading slash). */
    private static final String SCRIPT_PATH = "scripts/datagen/genRandData4LogReg_LTstats.dml";

    /** Prefix that imports the wrapped DML file under the {@code mlcontextns} namespace. */
    private static final String SOURCE_PREFIX = "source('" + SCRIPT_PATH + "') as mlcontextns;";

    /** Size of the character buffer used while reading the script resource. */
    private static final int READ_BUFFER_SIZE = 1024;

    /** DML text of the {@code straightenX} function (shared by its docs and source accessors). */
    private static final String STRAIGHTEN_X_DML = "straightenX =\n    function (Matrix[double] X)\n    return   (Matrix[double] w)\n{\n    w_X = t(colSums(X));\n    lambda_LS = 0.000001 * sum(X ^ 2) / ncol(X);\n    eps = 0.000000001 * nrow(X);\n\n    # BEGIN LEAST SQUARES\n    \n    r_LS = - w_X;\n    z_LS = matrix (0.0, rows = ncol(X), cols = 1);\n    p_LS = - r_LS;\n    norm_r2_LS = sum (r_LS ^ 2);\n    i_LS = 0;\n    while (i_LS < 50 & i_LS < ncol(X) & norm_r2_LS >= eps)\n    {\n        temp_LS = X %*% p_LS;\n        q_LS = (t(X) %*% temp_LS) + lambda_LS * p_LS;\n        alpha_LS = norm_r2_LS / sum (p_LS * q_LS);\n        z_LS = z_LS + alpha_LS * p_LS;\n        old_norm_r2_LS = norm_r2_LS;\n        r_LS = r_LS + alpha_LS * q_LS;\n        norm_r2_LS = sum (r_LS ^ 2);\n        p_LS = -r_LS + (norm_r2_LS / old_norm_r2_LS) * p_LS;\n        i_LS = i_LS + 1;\n    }\n    \n    # END LEAST SQUARES\n    \n    w = (nrow(X) / sum (w_X * z_LS)) * z_LS;\n}\n";

    /** DML text of the {@code generateWeights} function (shared by its docs and source accessors). */
    private static final String GENERATE_WEIGHTS_DML = "generateWeights = \n    function (Matrix[double] X, Matrix[double] meanLT, Matrix[double] sigmaLT)\n    return   (Matrix[double] W, Matrix[double] new_sigmaLT)\n{\n    num_w = ncol (meanLT);  # Number of output weight vectors\n    dim_w = ncol (X);       # Number of features / dimensions in a weight vector\n    w_X = t(colSums(X));    # \"Prohibited\" weight shift direction that changes meanLT\n                            # (all orthogonal shift directions do not affect meanLT)\n\n    # Compute \"w_1\" with meanLT = 1 and with the smallest possible sigmaLT\n\n    w_1 = straightenX (X);\n    r_1 = (X %*% w_1) - 1.0;\n    norm_r_1_sq = sum (r_1 ^ 2);\n    \n    # For each W[, i] generate uniformly random directions to shift away from \"w_1\"\n    \n    DW_raw = Rand (rows = dim_w, cols = num_w, pdf = \"normal\");\n    DW = DW_raw - (w_X %*% t(w_X) %*% DW_raw) / sum (w_X ^ 2); # Orthogonal to w_X\n    XDW = X %*% DW;\n    \n    # Determine how far to shift in the chosen directions to satisfy the constraints\n    # Use the positive root of the quadratic equation; relax sigmaLT where needed\n    \n    a_qe = colSums (XDW ^ 2);\n    b_qe = 2.0 * meanLT * (t(r_1) %*% XDW);\n    c_qe = meanLT^2 * norm_r_1_sq - sigmaLT^2 * nrow(X);\n\n    is_sigmaLT_OK = (c_qe <= 0);\n    new_sigmaLT = is_sigmaLT_OK * sigmaLT + (1 - is_sigmaLT_OK) * abs (meanLT) * sqrt (norm_r_1_sq / nrow(X));\n    c_qe = is_sigmaLT_OK * c_qe;\n    x_qe = (- b_qe + sqrt (b_qe * b_qe - 4.0 * a_qe * c_qe)) / (2.0 * a_qe);\n    \n    # Scale and shift \"w_1\" in the \"DW\" directions to produce the result:\n    \n    ones = matrix (1.0, rows = dim_w, cols = 1);\n    W = w_1 %*% meanLT + DW * (ones %*% x_qe);\n}\n";

    /**
     * Loads the DML script from the classpath and installs it as this script's text.
     *
     * @throws NullPointerException if the script resource is not on the classpath
     * @throws IllegalStateException if reading the resource fails; the cause is attached
     */
    public GenRandData4LogReg_LTstats() {
        String resource = "/" + SCRIPT_PATH;
        InputStream stream = Objects.requireNonNull(
                Script.class.getResourceAsStream(resource),
                "DML script resource not found on classpath: " + resource);
        this.setScriptString(readFully(stream, resource));
    }

    /**
     * Reads the whole stream as UTF-8 text and always closes it.
     *
     * @param stream the stream to drain; it is closed before this method returns
     * @param resource resource name, used only in the error message
     * @return the full stream contents
     * @throws IllegalStateException if an I/O error occurs while reading
     */
    private static String readFully(InputStream stream, String resource) {
        StringBuilder text = new StringBuilder();
        char[] buffer = new char[READ_BUFFER_SIZE];
        // try-with-resources closes the reader and, through it, the underlying stream.
        try (InputStreamReader reader = new InputStreamReader(stream, StandardCharsets.UTF_8)) {
            int n;
            while ((n = reader.read(buffer)) != -1) {
                text.append(buffer, 0, n);
            }
        } catch (IOException e) {
            // Fail fast: a silently truncated script would only fail later and obscurely.
            throw new IllegalStateException("Failed to read DML script resource " + resource, e);
        }
        return text.toString();
    }

    /**
     * Runs the DML function {@code straightenX(X)} and returns its output {@code w}.
     *
     * <p>Per the DML source, {@code w} is a rescaled regularized least-squares solution. The
     * rescaling makes {@code sum(X %*% w)} equal {@code nrow(X)}.
     *
     * @param object value bound to the DML input {@code X} via {@link Script#in}
     * @return the output matrix {@code w}
     */
    public Matrix straightenX(Object object) {
        Script script = new Script(SOURCE_PREFIX + "w = mlcontextns::straightenX(X);");
        script.in("X", object).out("w");
        MLResults results = script.execute();
        return results.getMatrix("w");
    }

    /**
     * Returns the DML text of {@code straightenX}.
     *
     * @return the DML function definition
     */
    public String straightenX__docs() {
        return STRAIGHTEN_X_DML;
    }

    /**
     * Returns the DML source of {@code straightenX}. The text is the same as {@link #straightenX__docs()}.
     *
     * @return the DML function definition
     */
    public String straightenX__source() {
        return STRAIGHTEN_X_DML;
    }

    /**
     * Runs the DML function {@code generateWeights(X, meanLT, sigmaLT)}.
     *
     * <p>Per the DML source, {@code W} has {@code ncol(X)} rows and {@code ncol(meanLT)} columns.
     * {@code new_sigmaLT} is {@code sigmaLT}, relaxed wherever the requested value cannot be met.
     * The DML uses {@code Rand(..., pdf = "normal")}, so results are presumably non-deterministic
     * unless the runtime seeds it.
     *
     * @param object value bound to the DML input {@code X}
     * @param object2 value bound to the DML input {@code meanLT}
     * @param object3 value bound to the DML input {@code sigmaLT}
     * @return a holder for the outputs {@code W} and {@code new_sigmaLT}
     */
    public GenerateWeights_output generateWeights(Object object, Object object2, Object object3) {
        Script script = new Script(SOURCE_PREFIX
                + "[W, new_sigmaLT] = mlcontextns::generateWeights(X, meanLT, sigmaLT);");
        script.in("X", object).in("meanLT", object2).in("sigmaLT", object3).out("W").out("new_sigmaLT");
        MLResults results = script.execute();
        Matrix weights = results.getMatrix("W");
        Matrix newSigmaLT = results.getMatrix("new_sigmaLT");
        return new GenerateWeights_output(weights, newSigmaLT);
    }

    /**
     * Returns the DML text of {@code generateWeights}.
     *
     * @return the DML function definition
     */
    public String generateWeights__docs() {
        return GENERATE_WEIGHTS_DML;
    }

    /**
     * Returns the DML source of {@code generateWeights}. The text is the same as
     * {@link #generateWeights__docs()}.
     *
     * @return the DML function definition
     */
    public String generateWeights__source() {
        return GENERATE_WEIGHTS_DML;
    }
}

