/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysml.scripts.nn.examples;

import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.util.Objects;
import org.apache.sysml.api.mlcontext.MLResults;
import org.apache.sysml.api.mlcontext.Matrix;
import org.apache.sysml.api.mlcontext.Script;
import org.apache.sysml.scripts.nn.examples.mnist_softmax.Eval_output;
import org.apache.sysml.scripts.nn.examples.mnist_softmax.Generate_dummy_data_output;
import org.apache.sysml.scripts.nn.examples.mnist_softmax.Train_output;

public class Mnist_softmax
extends Script {
    public Mnist_softmax() {
        String string = "scripts/nn/examples/mnist_softmax.dml";
        InputStream inputStream = Script.class.getResourceAsStream(new StringBuffer().append("/").append(string).toString());
        InputStreamReader inputStreamReader = new InputStreamReader(inputStream);
        char[] cArray = new char[1024];
        StringBuilder stringBuilder = new StringBuilder();
        try {
            int n;
            while ((n = inputStreamReader.read(cArray)) > 0) {
                stringBuilder.append(cArray, 0, n);
            }
        }
        catch (IOException iOException) {
            iOException.printStackTrace();
        }
        this.setScriptString(stringBuilder.toString());
    }

    public Generate_dummy_data_output generate_dummy_data() {
        String string = "source('scripts/nn/examples/mnist_softmax.dml') as mlcontextns;[X, Y, C, Hin, Win] = mlcontextns::generate_dummy_data();";
        Script script = new Script(string);
        script.out("X").out("Y").out("C").out("Hin").out("Win");
        MLResults mLResults = script.execute();
        Matrix matrix = mLResults.getMatrix("X");
        Matrix matrix2 = mLResults.getMatrix("Y");
        long l = mLResults.getLong("C");
        long l2 = mLResults.getLong("Hin");
        long l3 = mLResults.getLong("Win");
        Generate_dummy_data_output generate_dummy_data_output = new Generate_dummy_data_output(matrix, matrix2, l, l2, l3);
        return generate_dummy_data_output;
    }

    public String generate_dummy_data__docs() {
        String string = "generate_dummy_data = function()\n    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {\n  /*\n   * Generate a dummy dataset similar to the MNIST dataset.\n   *\n   * Outputs:\n   *  - X: Input data matrix, of shape (N, D).\n   *  - Y: Target matrix, of shape (N, K).\n   *  - C: Number of input channels (dimensionality of input depth).\n   *  - Hin: Input height.\n   *  - Win: Input width.\n   */\n";
        return string;
    }

    public String generate_dummy_data__source() {
        String string = "generate_dummy_data = function()\n    return (matrix[double] X, matrix[double] Y, int C, int Hin, int Win) {\n  /*\n   * Generate a dummy dataset similar to the MNIST dataset.\n   *\n   * Outputs:\n   *  - X: Input data matrix, of shape (N, D).\n   *  - Y: Target matrix, of shape (N, K).\n   *  - C: Number of input channels (dimensionality of input depth).\n   *  - Hin: Input height.\n   *  - Win: Input width.\n   */\n  # Generate dummy input data\n  N = 1024  # num examples\n  C = 1  # num input channels\n  Hin = 28  # input height\n  Win = 28  # input width\n  T = 10  # num targets\n  X = rand(rows=N, cols=C*Hin*Win, pdf=\"normal\")\n  classes = round(rand(rows=N, cols=1, min=1, max=T, pdf=\"uniform\"))\n  Y = table(seq(1, N), classes)  # one-hot encoding\n}\n";
        return string;
    }

    public Eval_output eval(Object object, Object object2) {
        String string = "source('scripts/nn/examples/mnist_softmax.dml') as mlcontextns;[loss, accuracy] = mlcontextns::eval(probs, Y);";
        Script script = new Script(string);
        script.in("probs", object).in("Y", object2).out("loss").out("accuracy");
        MLResults mLResults = script.execute();
        double d = mLResults.getDouble("loss");
        double d2 = mLResults.getDouble("accuracy");
        Eval_output eval_output = new Eval_output(d, d2);
        return eval_output;
    }

    public String eval__docs() {
        String string = "eval = function(matrix[double] probs, matrix[double] Y)\n    return (double loss, double accuracy) {\n  /*\n   * Evaluates a softmax classifier.\n   *\n   * The probs matrix contains the class probability predictions\n   * of K classes over N examples.  The targets, Y, have K classes,\n   * and are one-hot encoded.\n   *\n   * Inputs:\n   *  - probs: Class probabilities, of shape (N, K).\n   *  - Y: Target matrix, of shape (N, K).\n   *\n   * Outputs:\n   *  - loss: Scalar loss, of shape (1).\n   *  - accuracy: Scalar accuracy, of shape (1).\n   */\n";
        return string;
    }

    public String eval__source() {
        String string = "eval = function(matrix[double] probs, matrix[double] Y)\n    return (double loss, double accuracy) {\n  /*\n   * Evaluates a softmax classifier.\n   *\n   * The probs matrix contains the class probability predictions\n   * of K classes over N examples.  The targets, Y, have K classes,\n   * and are one-hot encoded.\n   *\n   * Inputs:\n   *  - probs: Class probabilities, of shape (N, K).\n   *  - Y: Target matrix, of shape (N, K).\n   *\n   * Outputs:\n   *  - loss: Scalar loss, of shape (1).\n   *  - accuracy: Scalar accuracy, of shape (1).\n   */\n  # Compute loss & accuracy\n  loss = cross_entropy_loss::forward(probs, Y)\n  correct_pred = rowIndexMax(probs) == rowIndexMax(Y)\n  accuracy = mean(correct_pred)\n}\n";
        return string;
    }

    public Matrix predict(Object object, Object object2, Object object3) {
        String string = "source('scripts/nn/examples/mnist_softmax.dml') as mlcontextns;probs = mlcontextns::predict(X, W, b);";
        Script script = new Script(string);
        script.in("X", object).in("W", object2).in("b", object3).out("probs");
        MLResults mLResults = script.execute();
        Matrix matrix = mLResults.getMatrix("probs");
        return matrix;
    }

    public String predict__docs() {
        String string = "predict = function(matrix[double] X, matrix[double] W, matrix[double] b)\n    return (matrix[double] probs) {\n  /*\n   * Computes the class probability predictions of a softmax classifier.\n   *\n   * The input matrix, X, has N examples, each with D features.\n   *\n   * Inputs:\n   *  - X: Input data matrix, of shape (N, D).\n   *  - W: Weights (parameters) matrix, of shape (D, M).\n   *  - b: Biases vector, of shape (1, M).\n   *\n   * Outputs:\n   *  - probs: Class probabilities, of shape (N, K).\n   */\n";
        return string;
    }

    public String predict__source() {
        String string = "predict = function(matrix[double] X, matrix[double] W, matrix[double] b)\n    return (matrix[double] probs) {\n  /*\n   * Computes the class probability predictions of a softmax classifier.\n   *\n   * The input matrix, X, has N examples, each with D features.\n   *\n   * Inputs:\n   *  - X: Input data matrix, of shape (N, D).\n   *  - W: Weights (parameters) matrix, of shape (D, M).\n   *  - b: Biases vector, of shape (1, M).\n   *\n   * Outputs:\n   *  - probs: Class probabilities, of shape (N, K).\n   */\n  # Compute forward pass\n  ## affine & softmax:\n  out = affine::forward(X, W, b)\n  probs = softmax::forward(out)\n}\n";
        return string;
    }

    public Train_output train(Object object, Object object2, Object object3, Object object4, Object object5) {
        String string = "source('scripts/nn/examples/mnist_softmax.dml') as mlcontextns;[W, b] = mlcontextns::train(X, Y, X_val, Y_val, epochs);";
        Script script = new Script(string);
        script.in("X", object).in("Y", object2).in("X_val", object3).in("Y_val", object4).in("epochs", object5).out("W").out("b");
        MLResults mLResults = script.execute();
        Matrix matrix = mLResults.getMatrix("W");
        Matrix matrix2 = mLResults.getMatrix("b");
        Train_output train_output = new Train_output(matrix, matrix2);
        return train_output;
    }

    public String train__docs() {
        String string = "train = function(matrix[double] X, matrix[double] Y,\n                 matrix[double] X_val, matrix[double] Y_val,\n                 int epochs)\n    return (matrix[double] W, matrix[double] b) {\n  /*\n   * Trains a softmax classifier.\n   *\n   * The input matrix, X, has N examples, each with D features.\n   * The targets, Y, have K classes, and are one-hot encoded.\n   *\n   * Inputs:\n   *  - X: Input data matrix, of shape (N, D).\n   *  - Y: Target matrix, of shape (N, K).\n   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).\n   *  - Y_val: Target validation matrix, of shape (N, K).\n   *  - epochs: Total number of full training loops over the full data set.\n   *\n   * Outputs:\n   *  - W: Weights (parameters) matrix, of shape (D, M).\n   *  - b: Biases vector, of shape (1, M).\n   */\n";
        return string;
    }

    public String train__source() {
        String string = "train = function(matrix[double] X, matrix[double] Y,\n                 matrix[double] X_val, matrix[double] Y_val,\n                 int epochs)\n    return (matrix[double] W, matrix[double] b) {\n  /*\n   * Trains a softmax classifier.\n   *\n   * The input matrix, X, has N examples, each with D features.\n   * The targets, Y, have K classes, and are one-hot encoded.\n   *\n   * Inputs:\n   *  - X: Input data matrix, of shape (N, D).\n   *  - Y: Target matrix, of shape (N, K).\n   *  - X_val: Input validation data matrix, of shape (N, C*Hin*Win).\n   *  - Y_val: Target validation matrix, of shape (N, K).\n   *  - epochs: Total number of full training loops over the full data set.\n   *\n   * Outputs:\n   *  - W: Weights (parameters) matrix, of shape (D, M).\n   *  - b: Biases vector, of shape (1, M).\n   */\n  N = nrow(X)  # num examples\n  D = ncol(X)  # num features\n  K = ncol(Y)  # num classes\n\n  # Create softmax classifier:\n  # affine -> softmax\n  [W, b] = affine::init(D, K)\n  W = W / sqrt(2.0/(D)) * sqrt(1/(D))\n\n  # Initialize SGD w/ Nesterov momentum optimizer\n  lr = 0.2  # learning rate\n  mu = 0  # momentum\n  decay = 0.99  # learning rate decay constant\n  vW = sgd_nesterov::init(W)  # optimizer momentum state for W\n  vb = sgd_nesterov::init(b)  # optimizer momentum state for b\n\n  # Optimize\n  print(\"Starting optimization\")\n  batch_size = 50\n  iters = 1000 #ceil(N / batch_size)\n  for (e in 1:epochs) {\n    for(i in 1:iters) {\n      # Get next batch\n      beg = ((i-1) * batch_size) %% N + 1\n      end = min(N, beg + batch_size - 1)\n      X_batch = X[beg:end,]\n      y_batch = Y[beg:end,]\n\n      # Compute forward pass\n      ## affine & softmax:\n      out = affine::forward(X_batch, W, b)\n      probs = softmax::forward(out)\n\n      # Compute loss & accuracy for training & validation data\n      loss = cross_entropy_loss::forward(probs, y_batch)\n      accuracy = mean(rowIndexMax(probs) == 
rowIndexMax(y_batch))\n      probs_val = predict(X_val, W, b)\n      loss_val = cross_entropy_loss::forward(probs_val, Y_val)\n      accuracy_val = mean(rowIndexMax(probs_val) == rowIndexMax(Y_val))\n      print(\"Epoch: \" + e + \", Iter: \" + i + \", Train Loss: \" + loss + \", Train Accuracy: \" +\n            accuracy + \", Val Loss: \" + loss_val + \", Val Accuracy: \" + accuracy_val)\n\n      # Compute backward pass\n      ## loss:\n      dprobs = cross_entropy_loss::backward(probs, y_batch)\n      ## affine & softmax:\n      dout = softmax::backward(dprobs, out)\n      [dX_batch, dW, db] = affine::backward(dout, X_batch, W, b)\n\n      # Optimize with SGD w/ Nesterov momentum\n      [W, vW] = sgd_nesterov::update(W, dW, lr, mu, vW)\n      [b, vb] = sgd_nesterov::update(b, db, lr, mu, vb)\n    }\n    # Anneal momentum towards 0.999\n    mu = mu + (0.999 - mu)/(1+epochs-e)\n    # Decay learning rate\n    lr = lr * decay\n  }\n}\n";
        return string;
    }
}

