/* * Copyright (c) Meta Platforms, Inc. and affiliates. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #pragma once #include #include #include #include #include #include #include namespace folly { // Robust and efficient online computation of statistics, // using Welford's method for variance. // https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Welford's_online_algorithm template class StreamingStats final { // Caclulated statistic result has to be floating point type static_assert(std::is_floating_point_v); public: struct StreamingState { size_t count = 0; StatsType mean = 0; StatsType m2 = 0; SampleDataType min = std::numeric_limits::max(); SampleDataType max = std::numeric_limits::lowest(); }; template StreamingStats(Iterator first, Iterator last) noexcept { add(first, last); } explicit StreamingStats(StreamingState state) : count_(state.count), mean_(state.mean), m2_(state.m2), min_(state.min), max_(state.max) {} StreamingStats() = default; ~StreamingStats() = default; /// Add sample data via iteratation template void add(Iterator first, Iterator last) noexcept { for (auto it = first; it != last; ++it) { add(*it); } } /// Add a single sample void add(SampleDataType value) noexcept { max_ = std::max(max_, value); min_ = std::min(min_, value); ++count_; StatsType const delta = value - mean_; mean_ += delta / count_; StatsType const delta2 = value - mean_; m2_ += delta * delta2; } /// Merge with an existing StreamingStats object void merge(StreamingStats const& other) { if (other.count_ == 0) { return; } max_ = std::max(max_, other.max_); min_ = std::min(min_, other.min_); size_t const new_size = count_ + other.count_; StatsType const new_mean = (mean_ * count_ + other.mean_ * other.count_) / new_size; // Each cumulant must be corrected. // * from: sum((x_i - mean_)²) // * to: sum((x_i - new_mean)²) auto delta = [&](auto const& stats) { return stats.count_ * (new_mean * (new_mean - 2 * stats.mean_) + stats.mean_ * stats.mean_); }; m2_ = m2_ + delta(*this) + other.m2_ + delta(other); mean_ = new_mean; count_ = new_size; } size_t count() const noexcept { return count_; } SampleDataType minimum() const { checkMinimumDataSize(1); return min_; } SampleDataType maximum() const { checkMinimumDataSize(1); return max_; } StatsType mean() const { checkMinimumDataSize(1); return mean_; } StatsType m2() const { checkMinimumDataSize(1); return m2_; } StatsType populationVariance() const { checkMinimumDataSize(2); return var_(0); } StatsType sampleVariance() const { checkMinimumDataSize(2); return var_(1); } StatsType populationStandardDeviation() const { checkMinimumDataSize(2); return std_(0); } StatsType sampleStandardDeviation() const { checkMinimumDataSize(2); return std_(1); } StreamingState state() const { StreamingState state; state.count = count_; state.m2 = m2_; state.max = max_; state.mean = mean_; state.min = min_; return state; } private: void checkMinimumDataSize(size_t const minElements) const { if (count_ < minElements) { throw_exception("stats: unavailable with no samples"); } } StatsType var_(size_t bias) const noexcept { return m2_ / (count_ - bias); } StatsType std_(size_t bias) const noexcept { return std::sqrt(var_(bias)); } size_t count_ = 0; StatsType mean_ = 0; StatsType m2_ = 0; SampleDataType min_ = std::numeric_limits::max(); SampleDataType max_ = std::numeric_limits::lowest(); }; } // namespace folly