Kaydet (Commit) 9a2be91c authored tarafından Steven D'Aprano's avatar Steven D'Aprano

Issue27181 add geometric mean.

üst e7fef52f
......@@ -303,6 +303,230 @@ def _fail_neg(values, errmsg='negative value'):
yield x
class _nroot_NS:
"""Hands off! Don't touch!
Everything inside this namespace (class) is an even-more-private
implementation detail of the private _nth_root function.
"""
# This class exists only to be used as a namespace, for convenience
# of being able to keep the related functions together, and to
# collapse the group in an editor. If this were C# or C++, I would
# use a Namespace, but the closest Python has is a class.
#
# FIXME possibly move this out into a separate module?
# That feels like overkill, and may encourage people to treat it as
# a public feature.
def __init__(self):
raise TypeError('namespace only, do not instantiate')
def nth_root(x, n):
"""Return the positive nth root of numeric x.
This may be more accurate than ** or pow():
>>> math.pow(1000, 1.0/3) #doctest:+SKIP
9.999999999999998
>>> _nth_root(1000, 3)
10.0
>>> _nth_root(11**5, 5)
11.0
>>> _nth_root(2, 12)
1.0594630943592953
"""
if not isinstance(n, int):
raise TypeError('degree n must be an int')
if n < 2:
raise ValueError('degree n must be 2 or more')
if isinstance(x, decimal.Decimal):
return _nroot_NS.decimal_nroot(x, n)
elif isinstance(x, numbers.Real):
return _nroot_NS.float_nroot(x, n)
else:
raise TypeError('expected a number, got %s') % type(x).__name__
def float_nroot(x, n):
"""Handle nth root of Reals, treated as a float."""
assert isinstance(n, int) and n > 1
if x < 0:
if n%2 == 0:
raise ValueError('domain error: even root of negative number')
else:
return -_nroot_NS.nroot(-x, n)
elif x == 0:
return math.copysign(0.0, x)
elif x > 0:
try:
isinfinity = math.isinf(x)
except OverflowError:
return _nroot_NS.bignum_nroot(x, n)
else:
if isinfinity:
return float('inf')
else:
return _nroot_NS.nroot(x, n)
else:
assert math.isnan(x)
return float('nan')
def nroot(x, n):
"""Calculate x**(1/n), then improve the answer."""
# This uses math.pow() to calculate an initial guess for the root,
# then uses the iterated nroot algorithm to improve it.
#
# By my testing, about 8% of the time the iterated algorithm ends
# up converging to a result which is less accurate than the initial
# guess. [FIXME: is this still true?] In that case, we use the
# guess instead of the "improved" value. This way, we're never
# less accurate than math.pow().
r1 = math.pow(x, 1.0/n)
eps1 = abs(r1**n - x)
if eps1 == 0.0:
# r1 is the exact root, so we're done. By my testing, this
# occurs about 80% of the time for x < 1 and 30% of the
# time for x > 1.
return r1
else:
try:
r2 = _nroot_NS.iterated_nroot(x, n, r1)
except RuntimeError:
return r1
else:
eps2 = abs(r2**n - x)
if eps1 < eps2:
return r1
return r2
def iterated_nroot(a, n, g):
"""Return the nth root of a, starting with guess g.
This is a special case of Newton's Method.
https://en.wikipedia.org/wiki/Nth_root_algorithm
"""
np = n - 1
def iterate(r):
try:
return (np*r + a/math.pow(r, np))/n
except OverflowError:
# If r is large enough, r**np may overflow. If that
# happens, r**-np will be small, but not necessarily zero.
return (np*r + a*math.pow(r, -np))/n
# With a good guess, such as g = a**(1/n), this will converge in
# only a few iterations. However a poor guess can take thousands
# of iterations to converge, if at all. We guard against poor
# guesses by setting an upper limit to the number of iterations.
r1 = g
r2 = iterate(g)
for i in range(1000):
if r1 == r2:
break
# Use Floyd's cycle-finding algorithm to avoid being trapped
# in a cycle.
# https://en.wikipedia.org/wiki/Cycle_detection#Tortoise_and_hare
r1 = iterate(r1)
r2 = iterate(iterate(r2))
else:
# If the guess is particularly bad, the above may fail to
# converge in any reasonable time.
raise RuntimeError('nth-root failed to converge')
return r2
def decimal_nroot(x, n):
"""Handle nth root of Decimals."""
assert isinstance(x, decimal.Decimal)
assert isinstance(n, int)
if x.is_snan():
# Signalling NANs always raise.
raise decimal.InvalidOperation('nth-root of snan')
if x.is_qnan():
# Quiet NANs only raise if the context is set to raise,
# otherwise return a NAN.
ctx = decimal.getcontext()
if ctx.traps[decimal.InvalidOperation]:
raise decimal.InvalidOperation('nth-root of nan')
else:
# Preserve the input NAN.
return x
if x.is_infinite():
return x
# FIXME this hasn't had the extensive testing of the float
# version _iterated_nroot so there's possibly some buggy
# corner cases buried in here. Can it overflow? Fail to
# converge or get trapped in a cycle? Converge to a less
# accurate root?
np = n - 1
def iterate(r):
return (np*r + x/r**np)/n
r0 = x**(decimal.Decimal(1)/n)
assert isinstance(r0, decimal.Decimal)
r1 = iterate(r0)
while True:
if r1 == r0:
return r1
r0, r1 = r1, iterate(r1)
def bignum_nroot(x, n):
"""Return the nth root of a positive huge number."""
assert x > 0
# I state without proof that ⁿ√x ≈ ⁿ√2·ⁿ√(x//2)
# and that for sufficiently big x the error is acceptible.
# We now halve x until it is small enough to get the root.
m = 0
while True:
x //= 2
m += 1
try:
y = float(x)
except OverflowError:
continue
break
a = _nroot_NS.nroot(y, n)
# At this point, we want the nth-root of 2**m, or 2**(m/n).
# We can write that as 2**(q + r/n) = 2**q * ⁿ√2**r where q = m//n.
q, r = divmod(m, n)
b = 2**q * _nroot_NS.nroot(2**r, n)
return a * b
# This is the (private) function for calculating nth roots:
_nth_root = _nroot_NS.nth_root
assert type(_nth_root) is type(lambda: None)
def _product(values):
"""Return product of values as (exponent, mantissa)."""
errmsg = 'mixed Decimal and float is not supported'
prod = 1
for x in values:
if isinstance(x, float):
break
prod *= x
else:
return (0, prod)
if isinstance(prod, Decimal):
raise TypeError(errmsg)
# Since floats can overflow easily, we calculate the product as a
# sort of poor-man's BigFloat. Given that:
#
# x = 2**p * m # p == power or exponent (scale), m = mantissa
#
# we can calculate the product of two (or more) x values as:
#
# x1*x2 = 2**p1*m1 * 2**p2*m2 = 2**(p1+p2)*(m1*m2)
#
mant, scale = 1, 0 #math.frexp(prod) # FIXME
for y in chain([x], values):
if isinstance(y, Decimal):
raise TypeError(errmsg)
m1, e1 = math.frexp(y)
m2, e2 = math.frexp(mant)
scale += (e1 + e2)
mant = m1*m2
return (scale, mant)
# === Measures of central tendency (averages) ===
def mean(data):
......@@ -331,6 +555,49 @@ def mean(data):
return _convert(total/n, T)
def geometric_mean(data):
"""Return the geometric mean of data.
The geometric mean is appropriate when averaging quantities which
are multiplied together rather than added, for example growth rates.
Suppose an investment grows by 10% in the first year, falls by 5% in
the second, then grows by 12% in the third, what is the average rate
of growth over the three years?
>>> geometric_mean([1.10, 0.95, 1.12])
1.0538483123382172
giving an average growth of 5.385%. Using the arithmetic mean will
give approximately 5.667%, which is too high.
``StatisticsError`` will be raised if ``data`` is empty, or any
element is less than zero.
"""
if iter(data) is data:
data = list(data)
errmsg = 'geometric mean does not support negative values'
n = len(data)
if n < 1:
raise StatisticsError('geometric_mean requires at least one data point')
elif n == 1:
x = data[0]
if isinstance(g, (numbers.Real, Decimal)):
if x < 0:
raise StatisticsError(errmsg)
return x
else:
raise TypeError('unsupported type')
else:
scale, prod = _product(_fail_neg(data, errmsg))
r = _nth_root(prod, n)
if scale:
p, q = divmod(scale, n)
s = 2**p * _nth_root(2**q, n)
else:
s = 1
return s*r
def harmonic_mean(data):
"""Return the harmonic mean of data.
......
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment