fit method

void fit(List<List<double>> data)

Computes the mean and population standard deviation (dividing by the number of samples) of each feature (column) in `data`.

Implementation

/// Computes the per-feature (column-wise) mean and standard deviation of
/// [data] and stores them as this object's fitted statistics.
///
/// [data] is a list of samples (rows), each holding one value per feature.
/// The number of features is taken from the first row; extra trailing values
/// in longer rows are ignored. Input is first passed to `validate2DInput`,
/// which presumably rejects empty or ragged input — TODO confirm against its
/// definition.
///
/// The standard deviation is the population form: the square root of the
/// mean squared deviation, dividing by the number of samples (not `n - 1`).
///
/// All statistics are accumulated into local buffers and only published once
/// every row has been processed. If an exception is thrown part-way through
/// (for example a [RangeError] from a row shorter than the first one that
/// slipped past validation), any previously fitted statistics and the fitted
/// flag are left untouched instead of being half-overwritten.
void fit(List<List<double>> data) {
  validate2DInput(data);
  final numFeatures = data[0].length;
  final numSamples = data.length;

  final mean = List<double>.filled(numFeatures, 0.0);
  final std = List<double>.filled(numFeatures, 0.0);

  // First pass: column sums, then divide to get the means.
  for (final row in data) {
    for (var i = 0; i < numFeatures; i++) {
      mean[i] += row[i];
    }
  }
  for (var i = 0; i < numFeatures; i++) {
    mean[i] /= numSamples;
  }

  // Second pass: sum of squared deviations from the already-computed mean.
  // The two-pass form avoids the cancellation error of E[x^2] - E[x]^2.
  for (final row in data) {
    for (var i = 0; i < numFeatures; i++) {
      final diff = row[i] - mean[i];
      std[i] += diff * diff;
    }
  }
  for (var i = 0; i < numFeatures; i++) {
    // `sqrtSafe` is an extension defined elsewhere (the original code labels
    // it "sqrt-safe"); its exact guard is not visible here.
    std[i] = (std[i] / numSamples).sqrtSafe();
  }

  // Publish the new state only after all computation has succeeded.
  _mean = mean;
  _std = std;
  _isFitted = true;
}