Kaydet (Commit) a474afdd authored tarafından Steven D'Aprano's avatar Steven D'Aprano

Add harmonic mean and tests.

üst 95e0df83
...@@ -28,6 +28,7 @@ Calculating averages ...@@ -28,6 +28,7 @@ Calculating averages
Function Description Function Description
================== ============================================= ================== =============================================
mean Arithmetic mean (average) of data. mean Arithmetic mean (average) of data.
harmonic_mean Harmonic mean of data.
median Median (middle value) of data. median Median (middle value) of data.
median_low Low median of data. median_low Low median of data.
median_high High median of data. median_high High median of data.
...@@ -95,16 +96,17 @@ A single exception is defined: StatisticsError is a subclass of ValueError. ...@@ -95,16 +96,17 @@ A single exception is defined: StatisticsError is a subclass of ValueError.
__all__ = [ 'StatisticsError', __all__ = [ 'StatisticsError',
'pstdev', 'pvariance', 'stdev', 'variance', 'pstdev', 'pvariance', 'stdev', 'variance',
'median', 'median_low', 'median_high', 'median_grouped', 'median', 'median_low', 'median_high', 'median_grouped',
'mean', 'mode', 'mean', 'mode', 'harmonic_mean',
] ]
import collections import collections
import decimal
import math import math
import numbers
from fractions import Fraction from fractions import Fraction
from decimal import Decimal from decimal import Decimal
from itertools import groupby from itertools import groupby, chain
from bisect import bisect_left, bisect_right from bisect import bisect_left, bisect_right
...@@ -135,7 +137,8 @@ def _sum(data, start=0): ...@@ -135,7 +137,8 @@ def _sum(data, start=0):
Some sources of round-off error will be avoided: Some sources of round-off error will be avoided:
>>> _sum([1e50, 1, -1e50] * 1000) # Built-in sum returns zero. # Built-in sum returns zero.
>>> _sum([1e50, 1, -1e50] * 1000)
(<class 'float'>, Fraction(1000, 1), 3000) (<class 'float'>, Fraction(1000, 1), 3000)
Fractions and Decimals are also supported: Fractions and Decimals are also supported:
...@@ -291,6 +294,15 @@ def _find_rteq(a, l, x): ...@@ -291,6 +294,15 @@ def _find_rteq(a, l, x):
return i-1 return i-1
raise ValueError raise ValueError
def _fail_neg(values, errmsg='negative value'):
"""Iterate over values, failing if any are less than zero."""
for x in values:
if x < 0:
raise StatisticsError(errmsg)
yield x
# === Measures of central tendency (averages) === # === Measures of central tendency (averages) ===
def mean(data): def mean(data):
...@@ -319,6 +331,52 @@ def mean(data): ...@@ -319,6 +331,52 @@ def mean(data):
return _convert(total/n, T) return _convert(total/n, T)
def harmonic_mean(data):
"""Return the harmonic mean of data.
The harmonic mean, sometimes called the subcontrary mean, is the
reciprocal of the arithmetic mean of the reciprocals of the data,
and is often appropriate when averaging quantities which are rates
or ratios, for example speeds. Example:
Suppose an investor purchases an equal value of shares in each of
three companies, with P/E (price/earning) ratios of 2.5, 3 and 10.
What is the average P/E ratio for the investor's portfolio?
>>> harmonic_mean([2.5, 3, 10]) # For an equal investment portfolio.
3.6
Using the arithmetic mean would give an average of about 5.167, which
is too high.
If ``data`` is empty, or any element is less than zero,
``harmonic_mean`` will raise ``StatisticsError``.
"""
# For a justification for using harmonic mean for P/E ratios, see
# http://fixthepitch.pellucid.com/comps-analysis-the-missing-harmony-of-summary-statistics/
# http://papers.ssrn.com/sol3/papers.cfm?abstract_id=2621087
if iter(data) is data:
data = list(data)
errmsg = 'harmonic mean does not support negative values'
n = len(data)
if n < 1:
raise StatisticsError('harmonic_mean requires at least one data point')
elif n == 1:
x = data[0]
if isinstance(x, (numbers.Real, Decimal)):
if x < 0:
raise StatisticsError(errmsg)
return x
else:
raise TypeError('unsupported type')
try:
T, total, count = _sum(1/x for x in _fail_neg(data, errmsg))
except ZeroDivisionError:
return 0
assert count == n
return _convert(n/total, T)
# FIXME: investigate ways to calculate medians without sorting? Quickselect? # FIXME: investigate ways to calculate medians without sorting? Quickselect?
def median(data): def median(data):
"""Return the median (middle value) of numeric data. """Return the median (middle value) of numeric data.
......
...@@ -21,6 +21,10 @@ import statistics ...@@ -21,6 +21,10 @@ import statistics
# === Helper functions and class === # === Helper functions and class ===
def sign(x):
"""Return -1.0 for negatives, including -0.0, otherwise +1.0."""
return math.copysign(1, x)
def _nan_equal(a, b): def _nan_equal(a, b):
"""Return True if a and b are both the same kind of NAN. """Return True if a and b are both the same kind of NAN.
...@@ -264,6 +268,13 @@ class NumericTestCase(unittest.TestCase): ...@@ -264,6 +268,13 @@ class NumericTestCase(unittest.TestCase):
# === Test the helpers === # === Test the helpers ===
# ======================== # ========================
class TestSign(unittest.TestCase):
"""Test that the helper function sign() works correctly."""
def testZeroes(self):
# Test that signed zeroes report their sign correctly.
self.assertEqual(sign(0.0), +1)
self.assertEqual(sign(-0.0), -1)
# --- Tests for approx_equal --- # --- Tests for approx_equal ---
...@@ -659,7 +670,7 @@ class DocTests(unittest.TestCase): ...@@ -659,7 +670,7 @@ class DocTests(unittest.TestCase):
@unittest.skipIf(sys.flags.optimize >= 2, @unittest.skipIf(sys.flags.optimize >= 2,
"Docstrings are omitted with -OO and above") "Docstrings are omitted with -OO and above")
def test_doc_tests(self): def test_doc_tests(self):
failed, tried = doctest.testmod(statistics) failed, tried = doctest.testmod(statistics, optionflags=doctest.ELLIPSIS)
self.assertGreater(tried, 0) self.assertGreater(tried, 0)
self.assertEqual(failed, 0) self.assertEqual(failed, 0)
...@@ -971,6 +982,34 @@ class ConvertTest(unittest.TestCase): ...@@ -971,6 +982,34 @@ class ConvertTest(unittest.TestCase):
self.assertTrue(_nan_equal(x, nan)) self.assertTrue(_nan_equal(x, nan))
class FailNegTest(unittest.TestCase):
"""Test _fail_neg private function."""
def test_pass_through(self):
# Test that values are passed through unchanged.
values = [1, 2.0, Fraction(3), Decimal(4)]
new = list(statistics._fail_neg(values))
self.assertEqual(values, new)
def test_negatives_raise(self):
# Test that negatives raise an exception.
for x in [1, 2.0, Fraction(3), Decimal(4)]:
seq = [-x]
it = statistics._fail_neg(seq)
self.assertRaises(statistics.StatisticsError, next, it)
def test_error_msg(self):
# Test that a given error message is used.
msg = "badness #%d" % random.randint(10000, 99999)
try:
next(statistics._fail_neg([-1], msg))
except statistics.StatisticsError as e:
errmsg = e.args[0]
else:
self.fail("expected exception, but it didn't happen")
self.assertEqual(errmsg, msg)
# === Tests for public functions === # === Tests for public functions ===
class UnivariateCommonMixin: class UnivariateCommonMixin:
...@@ -1082,13 +1121,13 @@ class UnivariateTypeMixin: ...@@ -1082,13 +1121,13 @@ class UnivariateTypeMixin:
Not all tests to do with types need go in this class. Only those that Not all tests to do with types need go in this class. Only those that
rely on the function returning the same type as its input data. rely on the function returning the same type as its input data.
""" """
def test_types_conserved(self): def prepare_types_for_conservation_test(self):
# Test that functions keeps the same type as their data points. """Return the types which are expected to be conserved."""
# (Excludes mixed data types.) This only tests the type of the return
# result, not the value.
class MyFloat(float): class MyFloat(float):
def __truediv__(self, other): def __truediv__(self, other):
return type(self)(super().__truediv__(other)) return type(self)(super().__truediv__(other))
def __rtruediv__(self, other):
return type(self)(super().__rtruediv__(other))
def __sub__(self, other): def __sub__(self, other):
return type(self)(super().__sub__(other)) return type(self)(super().__sub__(other))
def __rsub__(self, other): def __rsub__(self, other):
...@@ -1098,9 +1137,14 @@ class UnivariateTypeMixin: ...@@ -1098,9 +1137,14 @@ class UnivariateTypeMixin:
def __add__(self, other): def __add__(self, other):
return type(self)(super().__add__(other)) return type(self)(super().__add__(other))
__radd__ = __add__ __radd__ = __add__
return (float, Decimal, Fraction, MyFloat)
def test_types_conserved(self):
# Test that functions keeps the same type as their data points.
# (Excludes mixed data types.) This only tests the type of the return
# result, not the value.
data = self.prepare_data() data = self.prepare_data()
for kind in (float, Decimal, Fraction, MyFloat): for kind in self.prepare_types_for_conservation_test():
d = [kind(x) for x in data] d = [kind(x) for x in data]
result = self.func(d) result = self.func(d)
self.assertIs(type(result), kind) self.assertIs(type(result), kind)
...@@ -1275,12 +1319,16 @@ class AverageMixin(UnivariateCommonMixin): ...@@ -1275,12 +1319,16 @@ class AverageMixin(UnivariateCommonMixin):
for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')): for x in (23, 42.5, 1.3e15, Fraction(15, 19), Decimal('0.28')):
self.assertEqual(self.func([x]), x) self.assertEqual(self.func([x]), x)
def prepare_values_for_repeated_single_test(self):
return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712'))
def test_repeated_single_value(self): def test_repeated_single_value(self):
# The average of a single repeated value is the value itself. # The average of a single repeated value is the value itself.
for x in (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.9712')): for x in self.prepare_values_for_repeated_single_test():
for count in (2, 5, 10, 20): for count in (2, 5, 10, 20):
data = [x]*count with self.subTest(x=x, count=count):
self.assertEqual(self.func(data), x) data = [x]*count
self.assertEqual(self.func(data), x)
class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
...@@ -1304,7 +1352,7 @@ class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): ...@@ -1304,7 +1352,7 @@ class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
self.assertEqual(self.func(data), 22.015625) self.assertEqual(self.func(data), 22.015625)
def test_decimals(self): def test_decimals(self):
# Test mean with ints. # Test mean with Decimals.
D = Decimal D = Decimal
data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")] data = [D("1.634"), D("2.517"), D("3.912"), D("4.072"), D("5.813")]
random.shuffle(data) random.shuffle(data)
...@@ -1379,6 +1427,97 @@ class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin): ...@@ -1379,6 +1427,97 @@ class TestMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
self.assertEqual(statistics.mean([tiny]*n), tiny) self.assertEqual(statistics.mean([tiny]*n), tiny)
class TestHarmonicMean(NumericTestCase, AverageMixin, UnivariateTypeMixin):
def setUp(self):
self.func = statistics.harmonic_mean
def prepare_data(self):
# Override mixin method.
values = super().prepare_data()
values.remove(0)
return values
def prepare_values_for_repeated_single_test(self):
# Override mixin method.
return (3.5, 17, 2.5e15, Fraction(61, 67), Decimal('4.125'))
def test_zero(self):
# Test that harmonic mean returns zero when given zero.
values = [1, 0, 2]
self.assertEqual(self.func(values), 0)
def test_negative_error(self):
# Test that harmonic mean raises when given a negative value.
exc = statistics.StatisticsError
for values in ([-1], [1, -2, 3]):
with self.subTest(values=values):
self.assertRaises(exc, self.func, values)
def test_ints(self):
# Test harmonic mean with ints.
data = [2, 4, 4, 8, 16, 16]
random.shuffle(data)
self.assertEqual(self.func(data), 6*4/5)
def test_floats_exact(self):
# Test harmonic mean with some carefully chosen floats.
data = [1/8, 1/4, 1/4, 1/2, 1/2]
random.shuffle(data)
self.assertEqual(self.func(data), 1/4)
self.assertEqual(self.func([0.25, 0.5, 1.0, 1.0]), 0.5)
def test_singleton_lists(self):
# Test that harmonic mean([x]) returns (approximately) x.
for x in range(1, 101):
if x in (49, 93, 98, 99):
self.assertApproxEqual(self.func([x]), x, tol=2e-14)
else:
self.assertEqual(self.func([x]), x)
def test_decimals_exact(self):
# Test harmonic mean with some carefully chosen Decimals.
D = Decimal
self.assertEqual(self.func([D(15), D(30), D(60), D(60)]), D(30))
data = [D("0.05"), D("0.10"), D("0.20"), D("0.20")]
random.shuffle(data)
self.assertEqual(self.func(data), D("0.10"))
data = [D("1.68"), D("0.32"), D("5.94"), D("2.75")]
random.shuffle(data)
self.assertEqual(self.func(data), D(66528)/70723)
def test_fractions(self):
# Test harmonic mean with Fractions.
F = Fraction
data = [F(1, 2), F(2, 3), F(3, 4), F(4, 5), F(5, 6), F(6, 7), F(7, 8)]
random.shuffle(data)
self.assertEqual(self.func(data), F(7*420, 4029))
def test_inf(self):
# Test harmonic mean with infinity.
values = [2.0, float('inf'), 1.0]
self.assertEqual(self.func(values), 2.0)
def test_nan(self):
# Test harmonic mean with NANs.
values = [2.0, float('nan'), 1.0]
self.assertTrue(math.isnan(self.func(values)))
def test_multiply_data_points(self):
# Test multiplying every data point by a constant.
c = 111
data = [3.4, 4.5, 4.9, 6.7, 6.8, 7.2, 8.0, 8.1, 9.4]
expected = self.func(data)*c
result = self.func([x*c for x in data])
self.assertEqual(result, expected)
def test_doubled_data(self):
# Harmonic mean of [a,b...z] should be same as for [a,a,b,b...z,z].
data = [random.uniform(1, 5) for _ in range(1000)]
expected = self.func(data)
actual = self.func(data*2)
self.assertApproxEqual(actual, expected)
class TestMedian(NumericTestCase, AverageMixin): class TestMedian(NumericTestCase, AverageMixin):
# Common tests for median and all median.* functions. # Common tests for median and all median.* functions.
def setUp(self): def setUp(self):
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment