Source: mSigSDKScripts/machineLearning.js

import * as MLR from "https://esm.sh/ml-regression-multivariate-linear@2.0.4"

import * as CV from "https://esm.sh/ml-cross-validation@1.3.0";

import { groupBy } from "./utils.js";

/**
 * Turns raw mutational and exposure records into regression-ready matrices.
 *
 * The data source name is matched case-insensitively:
 *  - "MSIGPORTAL": delegates to the mSigPortal preprocessor and returns its
 *    `{ Xs, Ys }` result (input rows and output rows for regression).
 *  - "ICGC": recognised, but preprocessing is not implemented yet, so `null`
 *    is returned.
 * Any other name is rejected.
 *
 * @memberof machineLearning
 * @param {Array} mutationalData - Mutational records to use as regression inputs.
 * @param {Array} exposureData - Exposure records to use as regression outputs.
 * @param {string} dataSource - Identifier of where the records came from.
 * @returns {Object|null} `{ Xs, Ys }` for "MSIGPORTAL", `null` for "ICGC".
 * @throws {Error} When `dataSource` names an unrecognised source.
 */

export function preprocessData(mutationalData, exposureData, dataSource) {
  const normalizedSource = dataSource.toUpperCase();

  if (normalizedSource === "MSIGPORTAL") {
    return preprocessMSIGPORTALExposureData(mutationalData, exposureData);
  }

  if (normalizedSource === "ICGC") {
    // ICGC support is a placeholder: callers currently receive null.
    return null;
  }

  throw new Error("Unknown data source: " + dataSource);
}

/**
 * Joins two key-grouped collections on the keys they have in common.
 *
 * Only each object's own keys are considered. A plain `in` check would also
 * match inherited Object.prototype members (e.g. a group named "constructor"
 * or "toString"), then try to `.map` over a function and throw.
 *
 * @param {Object<string, Array<Object>>} dict1 - Groups whose rows are read for their `mutations` field.
 * @param {Object<string, Array<Object>>} dict2 - Groups whose rows are read for their `exposure` field.
 * @returns {Object<string, Array<Array>>} For each key present in both inputs,
 *   a pair `[mutationsOfDict1Rows, exposuresOfDict2Rows]`.
 */
function intersectByKeys(dict1, dict2) {
  const hasOwn = (obj, key) => Object.prototype.hasOwnProperty.call(obj, key);
  const intersection = {};
  for (const key of Object.keys(dict1)) {
    if (!hasOwn(dict2, key)) {
      continue;
    }
    intersection[key] = [
      dict1[key].map((row) => row["mutations"]),
      dict2[key].map((row) => row["exposure"]),
    ];
  }
  return intersection;
}

/**
 * Builds regression inputs/outputs from mSigPortal mutational and exposure rows.
 *
 * Both datasets are grouped (via `groupBy`) on their "sample" field and
 * restricted to the samples present in both. Each remaining sample contributes
 * one X row (its rows' `mutations` values) and one Y row (its rows' `exposure`
 * values). A sample is skipped when its X or Y length differs from that of the
 * first intersected sample, so every returned row has a consistent width.
 *
 * NOTE(review): the reference widths come from whichever sample is enumerated
 * first; if that sample is incomplete, well-formed samples get dropped instead.
 * A previous comment mentioned expected widths of 96 (mutations) and 65
 * (exposures) — confirm whether fixed widths should be enforced.
 *
 * @param {Array<Object>} mutationalData - Rows carrying `sample` and `mutations` fields.
 * @param {Array<Object>} exposureData - Rows carrying `sample` and `exposure` fields.
 * @returns {{Xs: Array<Array>, Ys: Array<Array>}} Index-aligned input and output rows.
 */
function preprocessMSIGPORTALExposureData(mutationalData, exposureData) {
  const groupedMutationalData = groupBy(mutationalData, "sample");
  const groupedExposureData = groupBy(exposureData, "sample");

  const intersectedData = intersectByKeys(
    groupedMutationalData,
    groupedExposureData
  );

  const Xs = [];
  const Ys = [];
  const sampleKeys = Object.keys(intersectedData);
  if (sampleKeys.length === 0) {
    return { Xs, Ys };
  }

  // Reference widths are looked up once, keeping the scan linear in the
  // number of samples.
  const [referenceX, referenceY] = intersectedData[sampleKeys[0]];

  for (const key of sampleKeys) {
    const [x, y] = intersectedData[key];
    if (x.length === referenceX.length && y.length === referenceY.length) {
      Xs.push(x);
      Ys.push(y);
    }
  }
  return { Xs, Ys };
}

/**
 * Performs k-fold cross-validation for a multivariate regression model.
 *
 * Folds come from `CV.getFolds` (ml-cross-validation); they are not
 * stratified. For each fold, a fresh model is trained on the fold's training
 * rows and evaluated on its held-out rows.
 *
 * @memberof machineLearning
 * @param {Array<Array<number>>} Xs - Input rows, one per sample.
 * @param {Array<Array<number>>} Ys - Output rows, index-aligned with `Xs`.
 * @param {number} [k=10] - Requested number of folds (passed to `CV.getFolds`).
 * @param {string} [modelType="MLR"] - Regression model type, case-insensitive.
 *   Only "MLR" (multivariate linear regression) is available. "MLP" is
 *   recognised but rejected, because no MLP implementation is imported.
 * @returns {{model: Array, MSE: Array<number>, averageMSE: number}} The model
 *   trained in each fold, each fold's test error (the per-sample
 *   `meanSquaredError` averaged over the fold's test samples), and the mean of
 *   those fold errors (`NaN` when no folds were produced).
 * @throws {Error} If `Xs` and `Ys` differ in length, or the model type is
 *   unknown or unsupported.
 */
export function kFoldCV(Xs, Ys, k = 10, modelType = "MLR") {
  if (Xs.length !== Ys.length) {
    throw new Error(
      `Xs and Ys must have the same length (got ${Xs.length} and ${Ys.length})`
    );
  }

  // Resolve the model type before doing any work, so a bad type fails fast.
  const createModel = resolveModelFactory(modelType);

  // Each fold exposes trainIndex/testIndex arrays that index into Xs/Ys.
  const folds = CV.getFolds(Xs, k);

  const models = [];
  const mses = [];
  for (const { trainIndex, testIndex } of folds) {
    const xTrain = trainIndex.map((index) => Xs[index]);
    const yTrain = trainIndex.map((index) => Ys[index]);
    const xTest = testIndex.map((index) => Xs[index]);
    const yTest = testIndex.map((index) => Ys[index]);

    const regression = createModel(xTrain, yTrain);
    models.push(regression);

    let foldError = 0;
    for (let i = 0; i < xTest.length; i++) {
      foldError += meanSquaredError(regression.predict(xTest[i]), yTest[i]);
    }
    mses.push(foldError / xTest.length);
  }

  // Average over the folds actually evaluated rather than over `k`: nothing
  // here guarantees getFolds returns exactly k folds.
  const averageMSE =
    mses.length === 0
      ? NaN
      : mses.reduce((sum, value) => sum + value, 0) / mses.length;

  return { model: models, MSE: mses, averageMSE: averageMSE };
}

/**
 * Maps a model-type name to a `(X, Y) => trainedModel` factory.
 * @param {string} modelType - Case-insensitive model type name.
 * @returns {function(Array, Array): Object} Factory that trains a model.
 * @throws {Error} For unsupported or unknown model types.
 */
function resolveModelFactory(modelType) {
  switch (modelType.toUpperCase()) {
    case "MLR":
      return (X, Y) => new MLR.default(X, Y);
    case "MLP":
      // No MLP implementation is imported in this module. Fail with an
      // explicit message instead of a ReferenceError on an undefined `MLP`.
      throw new Error(
        "Unsupported model type: MLP (no MLP implementation is available)"
      );
    default:
      throw new Error("Unknown model type: " + modelType);
  }
}

/**
 * Mean of the squared element-wise differences between two vectors.
 *
 * @param {Array<number>} prediction - Predicted values.
 * @param {Array<number>} actual - Target values, read at the same indices as `prediction`.
 * @returns {number} The sum of (prediction[i] - actual[i])^2 divided by
 *   `prediction.length`, or 0 when `prediction` is empty.
 */
function meanSquaredError(prediction, actual) {
  if (prediction.length === 0) {
    return 0;
  }
  let sumOfSquares = 0;
  for (let i = 0; i < prediction.length; i++) {
    const diff = prediction[i] - actual[i];
    sumOfSquares += diff * diff;
  }
  // Divide by the vector length so the result is a mean, as the name says,
  // rather than a sum of squared errors.
  return sumOfSquares / prediction.length;
}