aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--.gitignore3
-rw-r--r--LICENSE.rst28
-rw-r--r--README.rst9
-rw-r--r--vquad/__init__.py7
-rw-r--r--vquad/core.py240
-rw-r--r--vquad/tables.py183
-rw-r--r--vquad/test/__init__.py0
-rw-r--r--vquad/test/test_core.py132
-rw-r--r--vquad/test/test_tables.py53
9 files changed, 655 insertions, 0 deletions
diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..982c4c1
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,3 @@
+*~
+__pycache__/
+.cache/
diff --git a/LICENSE.rst b/LICENSE.rst
new file mode 100644
index 0000000..8e2acd6
--- /dev/null
+++ b/LICENSE.rst
@@ -0,0 +1,28 @@
+=============
+Vquad license
+=============
+
+Copyright 2017 Christoph Groth (CEA). All rights reserved.
+
+(CEA = Commissariat à l'énergie atomique et aux énergies alternatives)
+
+Redistribution and use in source and binary forms, with or without
+modification, are permitted provided that the following conditions are met:
+
+* Redistributions of source code must retain the above copyright notice, this
+ list of conditions and the following disclaimer.
+
+* Redistributions in binary form must reproduce the above copyright notice,
+ this list of conditions and the following disclaimer in the documentation
+ and/or other materials provided with the distribution.
+
+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
+ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
+WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
+ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
+LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
+ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
+SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
diff --git a/README.rst b/README.rst
new file mode 100644
index 0000000..12eba6d
--- /dev/null
+++ b/README.rst
@@ -0,0 +1,9 @@
+Vquad
+=====
+
+Vquad is a work-in-progress towards a general-purpose, robust, efficient, and parallel library for numerical integration.
+
+History
+-------
+
+Vquad is based on algorithm 4 from the article "Increasing the Reliability of Adaptive Quadrature Using Explicit Interpolants", Pedro Gonnet, ACM TOMS 37, 26 (2010).
diff --git a/vquad/__init__.py b/vquad/__init__.py
new file mode 100644
index 0000000..3313b0f
--- /dev/null
+++ b/vquad/__init__.py
@@ -0,0 +1,7 @@
+# Copyright 2017 Christoph Groth (CEA).
+#
+# This file is part of Vquad. It is subject to the license terms in the file
+# LICENSE.rst found in the top-level directory of this distribution.
+
+
+from .core import vquad
diff --git a/vquad/core.py b/vquad/core.py
new file mode 100644
index 0000000..1866276
--- /dev/null
+++ b/vquad/core.py
@@ -0,0 +1,240 @@
+# Copyright 2017 Christoph Groth (CEA).
+#
+# This file is part of Vquad. It is subject to the license terms in the file
+# LICENSE.rst found in the top-level directory of this distribution.
+
+import numpy as np
+from scipy.linalg import norm
+
+from . import tables as tbls
+
+eps = np.spacing(1)
+
+# If the relative difference between two consecutive approximations is
+# lower than this value, the error estimate is considered reliable.
+# See section 6.2 of Pedro Gonnet's thesis.
+hint = 0.1
+
+# Smallest acceptable relative difference of points in a rule. This was chosen
+# such that no artifacts are apparent in plots of (i, log(a_i)), where a_i is
+# the sequence of estimates of the integral value of an interval and all its
+# ancestors.
+min_sep = 16 * eps
+
+min_level = 1
+max_level = 4
+
+ndiv_max = 20
+max_ivals = 200
+
+_sqrt_one_half = np.sqrt(0.5)
+
+
+def _calc_coeffs(vals, level):
+ nans = np.flatnonzero(~np.isfinite(vals))
+ if nans.size:
+ # Replace vals by a copy and zero-out non-finite elements.
+ vals = vals.copy()
+ vals[nans] = 0
+ # Prepare things for the loop further down.
+ b = tbls.newton_coeffs[level].copy()
+ m = len(b) - 2 # = len(tbls.nodes[level]) - 1
+ coeffs = tbls.inv_Vs[level] @ vals
+
+ # This is a variant of Algorithm 7 from the thesis of Pedro Gonnet where no
+ # linear system has to be solved explicitly. Instead, Algorithm 5 is used.
+ for i in nans:
+ b[m + 1] /= tbls.alpha[m]
+ x = tbls.nodes[level][i]
+ b[m] = (b[m] + x * b[m + 1]) / tbls.alpha[m - 1]
+ for j in range(m - 1, 0, -1):
+ b[j] = ((b[j] + x * b[j + 1] - tbls.gamma[j + 1] * b[j + 2])
+ / tbls.alpha[j - 1])
+ b = b[1:]
+
+ coeffs[:m] -= coeffs[m] / b[m] * b[:m]
+ coeffs[m] = 0
+ m -= 1
+
+ return coeffs
+
+
+class DivergentIntegralError(ValueError):
+ def __init__(self, msg, igral, err, nr_points):
+ self.igral = igral
+ self.err = err
+ self.nr_points = nr_points
+ super().__init__(msg)
+
+
+class _Interval:
+ __slots__ = ['a', 'b', 'coeffs', 'vals', 'igral', 'err', 'level', 'depth',
+ 'ndiv', 'c00', 'unreliable_err']
+
+ def __init__(self, a, b, level, depth):
+ self.a = a
+ self.b = b
+ self.level = level
+ self.depth = depth
+
+ def points(self):
+ a = self.a
+ b = self.b
+ return (a + b) / 2 + (b - a) * tbls.nodes[self.level] / 2
+
+ def interpolate(self, vals, coeffs_old=None):
+ self.vals = vals
+ self.coeffs = coeffs = _calc_coeffs(self.vals, self.level)
+ if self.level == min_level:
+ self.c00 = coeffs[0]
+ if coeffs_old is None:
+ coeffs_diff = norm(coeffs)
+ else:
+ coeffs_diff = np.zeros(max(len(coeffs_old), len(coeffs)))
+ coeffs_diff[:len(coeffs_old)] = coeffs_old
+ coeffs_diff[:len(coeffs)] -= coeffs
+ coeffs_diff = norm(coeffs_diff)
+ w = self.b - self.a
+ self.igral = w * coeffs[0] * _sqrt_one_half
+ self.err = w * coeffs_diff
+ self.unreliable_err = coeffs_diff > hint * norm(coeffs)
+
+
+class Vquad:
+ """Evaluate an integral using adaptive quadrature.
+
+ The algorithm uses Clenshaw-Curtis quadrature rules of increasing
+ degree in each interval. The error estimate is
+ sqrt(integrate((f0(x) - f1(x))**2)), where f0 and f1 are two
+ successive interpolations of the integrand. To fall below the
+ desired total error, intervals are worked on ranked by their own
+ absolute error: either the degree of the rule is increased or the
+ interval is split if either the function does not appear to be
+ smooth or a rule of maximum degree has been reached.
+
+    Reference: "Increasing the Reliability of Adaptive Quadrature
+    Using Explicit Interpolants", P. Gonnet, ACM Transactions on
+    Mathematical Software, 37 (3), art. no. 26, 2010.
+ """
+
+ def __init__(self, f, a, b, level=max_level - 1):
+ ival = _Interval(a, b, level, 1)
+ vals = f(ival.points())
+ ival.interpolate(vals)
+ ival.c00 = 0.0 # Will go away.
+ ival.ndiv = 0
+
+ self.ivals = [ival]
+ self.f = f
+ self.nr_points = len(vals)
+ self.igral_excess = 0
+ self.err_excess = 0
+ self.i_max = 0
+
+ def split(self, ival):
+ m = (ival.a + ival.b) / 2
+ f_center = ival.vals[(len(ival.vals) - 1) // 2]
+
+ depth = ival.depth + 1
+ children = [_Interval(ival.a, m, min_level, depth),
+ _Interval(m, ival.b, min_level, depth)]
+ points = np.concatenate([child.points()[1:-1] for child in children])
+ self.nr_points += len(points)
+ valss = np.empty((2, tbls.sizes[min_level]))
+ valss[:, 0] = ival.vals[0], f_center
+ valss[:, -1] = f_center, ival.vals[-1]
+ valss[:, 1:-1] = self.f(points).reshape((2, -1))
+
+ for child, vals, T in zip(children, valss, tbls.Ts):
+ child.interpolate(vals, T[:, :ival.coeffs.shape[0]] @ ival.coeffs)
+ child.ndiv = (ival.ndiv
+ + (ival.c00 and child.c00 / ival.c00 > 2))
+ if child.ndiv > ndiv_max and 2*child.ndiv > child.depth:
+ msg = ('Possibly divergent integral in the interval '
+ '[{}, {}]! (h={})')
+ raise DivergentIntegralError(
+ msg.format(child.a, child.b, child.b - child.a),
+ child.igral * np.inf, None, self.nr_points)
+ return children
+
+ def refine(self, ival):
+ """Increase degree of interval."""
+ ival.level += 1
+ points = ival.points()
+ vals = np.empty(points.shape)
+ vals[0::2] = ival.vals
+ vals[1::2] = self.f(points[1::2])
+ self.nr_points += (len(vals) - 1) // 2
+ ival.interpolate(vals, ival.coeffs)
+ return points
+
+ def improve(self):
+ i_max = self.i_max
+
+ if self.ivals[i_max].level == max_level:
+ split = True
+ else:
+ points = self.refine(self.ivals[i_max])
+ split = self.ivals[i_max].unreliable_err
+
+ if (points[1] - points[0] < points[0] * min_sep
+ or points[-1] - points[-2] < points[-2] * min_sep
+ or (self.ivals[i_max].err
+ < (abs(self.ivals[i_max].igral) * eps
+ * tbls.V_cond_nums[self.ivals[i_max].level]))):
+ # Remove the interval (while remembering the excess integral
+ # and error), since it is either too narrow, or the estimated
+ # relative error is already at the limit of numerical accuracy
+ # and cannot be reduced further.
+ self.err_excess += self.ivals[i_max].err
+ self.igral_excess += self.ivals[i_max].igral
+ self.ivals[i_max] = self.ivals[-1]
+ self.ivals.pop()
+ return
+
+ if split:
+ self.ivals.extend(self.split(self.ivals[i_max]))
+ self.ivals[i_max] = self.ivals.pop()
+
+ def totals(self):
+ # Compute the total error and new max.
+ i_max = 0
+ i_min = 0
+ err = self.err_excess
+ igral = self.igral_excess
+ for i in range(len(self.ivals)):
+ if self.ivals[i].err > self.ivals[i_max].err:
+ i_max = i
+ elif self.ivals[i].err < self.ivals[i_min].err:
+ i_min = i
+ err += self.ivals[i].err
+ igral += self.ivals[i].igral
+
+ # If there are too many intervals, remove the one with smallest
+ # contribution to the error.
+ if len(self.ivals) > max_ivals:
+ self.err_excess += self.ivals[i_min].err
+ self.igral_excess += self.ivals[i_min].igral
+ self.ivals[i_min] = self.ivals[-1]
+ self.ivals.pop()
+ if i_max == len(self.ivals):
+ i_max = i_min
+
+ self.i_max = i_max
+ return igral, err
+
+ def improve_until(self, tol):
+ while True:
+ self.improve()
+ igral, err = self.totals()
+
+ if (err == 0
+ or err < abs(igral) * tol
+ or err - self.err_excess < abs(igral) * tol < self.err_excess
+ or not self.ivals):
+ return igral, err, self.nr_points
+
+
+def vquad(f, a, b, tol):
+ igrator = Vquad(f, a, b)
+ return igrator.improve_until(tol)
diff --git a/vquad/tables.py b/vquad/tables.py
new file mode 100644
index 0000000..267ac51
--- /dev/null
+++ b/vquad/tables.py
@@ -0,0 +1,183 @@
+# Copyright 2017 Christoph Groth (CEA).
+#
+# This file is part of Vquad. It is subject to the license terms in the file
+# LICENSE.rst found in the top-level directory of this distribution.
+
+from fractions import Fraction as Frac
+from collections import defaultdict
+import numpy as np
+from scipy.linalg import norm, inv
+
+__all__ = ['sizes', 'nodes', 'newton_coeffs', 'inv_Vs', 'V_cond_nums', 'Ts',
+ 'alpha', 'gamma']
+
+def legendre(n):
+ """Return the first n Legendre polynomials.
+
+ The polynomials have *standard* normalization, i.e.
+ int_{-1}^1 dx L_n(x) L_m(x) = delta(m, n) * 2 / (2 * n + 1).
+
+    The return value is a list of lists of fractions.Fraction instances.
+ """
+ result = [[Frac(1)], [Frac(0), Frac(1)]]
+ if n <= 2:
+ return result[:n]
+ for i in range(2, n):
+ # Use Bonnet's recursion formula.
+ new = (i + 1) * [Frac(0)]
+ new[1:] = (r * (2*i - 1) for r in result[-1])
+ new[:-2] = (n - r * (i - 1) for n, r in zip(new[:-2], result[-2]))
+ new[:] = (n / i for n in new)
+ result.append(new)
+ return result
+
+
+def newton(n):
+ """Compute the monomial coefficients of the Newton polynomial over the
+ nodes of the n-point Clenshaw-Curtis quadrature rule.
+ """
+ # The nodes of the Clenshaw-Curtis rule are x_i = -cos(i * Pi / (n-1)).
+ # Here, we calculate the coefficients c_i such that sum_i c_i * x^i
+ # = prod_i (x - x_i). The coefficients are thus sums of products of
+ # cosines.
+ #
+ # This routine uses the relation
+ # cos(a) cos(b) = (cos(a + b) + cos(a - b)) / 2
+ # to efficiently calculate the coefficients.
+ #
+    # The dictionary 'terms' describes the terms that make up the
+    # monomial coefficients. Each item ((d, a), m) corresponds to a
+    # term m * cos(a * Pi / (n-1)) to be added to the prefactor of the
+    # monomial x^(n-d).
+
+ mod = 2 * (n-1)
+ terms = defaultdict(int)
+ terms[0, 0] += 1
+
+ for i in range(n):
+ newterms = []
+ for (d, a), m in terms.items():
+ for b in [i, -i]:
+ # In order to reduce the number of terms, cosine
+                # arguments are mapped back to the interval [0, pi/2).
+ arg = (a + b) % mod
+ if arg > n-1:
+ arg = mod - arg
+ if arg >= n // 2:
+ if n % 2 and arg == n // 2:
+ # Zero term: ignore
+ continue
+ newterms.append((d + 1, n - 1 - arg, -m))
+ else:
+ newterms.append((d + 1, arg, m))
+ for d, s, m in newterms:
+ terms[d, s] += m
+
+ c = (n + 1) * [0]
+ for (d, a), m in terms.items():
+ if m and a != 0:
+ raise ValueError("Newton polynomial cannot be represented exactly.")
+ c[n - d] += m
+ # The check could be removed and the above line replaced by
+ # the following, but then the result would be no longer exact.
+ # c[n - d] += m * np.cos(a * np.pi / (n - 1))
+
+ cf = np.array(c, float)
+ assert all(int(cfe) == ce for cfe, ce in zip(cf, c)), 'Precision loss'
+
+ cf /= 2.**np.arange(n, -1, -1)
+ return cf
+
+
+def scalar_product(a, b):
+ """Compute the polynomial scalar product int_-1^1 dx a(x) b(x).
+
+ The args must be sequences of polynomial coefficients. This
+ function is careful to use the input data type for calculations.
+ """
+ la = len(a)
+ lc = len(b) + la + 1
+
+ # Compute the even coefficients of the product of a and b.
+ c = lc * [a[0].__class__()]
+ for i, bi in enumerate(b):
+ if bi == 0:
+ continue
+ for j in range(i % 2, la, 2):
+ c[i + j] += a[j] * bi
+
+ # Calculate the definite integral from -1 to 1.
+ return 2 * sum(c[i] / (i + 1) for i in range(0, lc, 2))
+
+
+def newton_legendre(sizes):
+ """Calculate the decompositions of Newton polynomials (over the nodes
+ of the n-point Clenshaw-Curtis quadrature rule) in terms of
+    Legendre polynomials.
+
+    The parameter 'sizes' is a sequence of numbers of points of the
+ quadrature rule. The return value is a corresponding sequence of
+ normalized Legendre polynomial coefficients.
+ """
+ legs = legendre(max(sizes) + 1)
+ result = []
+ for n in sizes:
+ poly = []
+ a = list(map(Frac, newton(n)))
+ for b in legs[:n + 1]:
+ igral = scalar_product(a, b)
+
+ # Normalize & store. (The polynomials returned by
+ # legendre() have standard normalization that is not
+ # orthonormal.)
+ poly.append(np.sqrt((2*len(b) - 1) / 2) * igral)
+
+ result.append(np.array(poly))
+ return result
+
+
+def vandermonde(nodes):
+ V = [np.ones(nodes.shape), nodes.copy()]
+ for i in range(2, len(nodes)):
+ V.append((2*i-1) / i * nodes * V[-1] - (i-1) / i * V[-2])
+ for i in np.arange(len(nodes)):
+ V[i] *= np.sqrt(i + 0.5)
+ return np.array(V).T
+
+
+def precalculate(n_levels=5):
+ """Precalculate tables for adaptive quadrature based on the
+ Clenshaw-Curtis rule and Legendre polynomials.
+
+ The Clenshaw-Curtis rule with three points is the lowest rule that
+ contains the center of the interval, hence we define it as level 0.
+ """
+ global sizes, nodes, newton_coeffs, inv_Vs, V_cond_nums, Ts, alpha, gamma
+
+ # Points of the Clenshaw-Curtis rule.
+ sizes = [2**(level + 1) + 1 for level in range(n_levels)]
+ nodes = [-np.cos(np.pi / (n - 1) * np.arange(n)) for n in sizes]
+ # Set central rule points precisely to zero. This does not really
+ # matter in practice, but is useful for tests.
+ for l in nodes:
+ l[len(l) // 2] = 0.0
+
+ # Vandermonde-like matrices and their condition numbers
+ V = list(map(vandermonde, nodes))
+ inv_Vs = inv_Vs = list(map(inv, V))
+ V_cond_nums = [norm(a, 2) * norm(b, 2) for a, b in zip(V, inv_Vs)]
+
+ # Shift matrices
+ Ts = [inv_Vs[-1] @ vandermonde((nodes[-1] + a) / 2) for a in [-1, 1]]
+
+ # Newton polynomials
+ newton_coeffs = newton_legendre(sizes)
+
+ # Other downdate matrices
+ k = np.arange(sizes[-1])
+ alpha = np.sqrt((k+1)**2 / (2*k+1) / (2*k+3))
+ gamma = np.concatenate([[0, 0],
+ np.sqrt(k[2:]**2 / (4*k[2:]**2-1))])
+
+
+precalculate()
diff --git a/vquad/test/__init__.py b/vquad/test/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/vquad/test/__init__.py
diff --git a/vquad/test/test_core.py b/vquad/test/test_core.py
new file mode 100644
index 0000000..70b0c10
--- /dev/null
+++ b/vquad/test/test_core.py
@@ -0,0 +1,132 @@
+# Copyright 2017 Christoph Groth (CEA).
+#
+# This file is part of Vquad. It is subject to the license terms in the file
+# LICENSE.rst found in the top-level directory of this distribution.
+
+import numpy as np
+from numpy.testing import assert_allclose
+
+from .. import core
+
+def f0(x):
+ return x * np.sin(1/x) * np.sqrt(abs(1 - x))
+
+
+def f7(x):
+ return x**-0.5
+
+
+def f24(x):
+ return np.floor(np.exp(x))
+
+
+def f21(x):
+ y = 0
+ for i in range(1, 4):
+ y += 1 / np.cosh(20**i * (x - 2 * i / 10))
+ return y
+
+
+def f63(x, alpha, beta):
+ return abs(x - beta) ** alpha
+
+
+def F63(x, alpha, beta):
+ return (x - beta) * abs(x - beta) ** alpha / (alpha + 1)
+
+
+def fdiv(x):
+ return abs(x - 0.987654321) ** -1.1
+
+
+def f_one_with_nan(x):
+ x = np.asarray(x)
+ result = np.ones(x.shape)
+ result[x == 0] = np.inf
+ return result
+
+
+def test_downdate(level=3):
+ vals = np.abs(core.tbls.nodes[level])
+ vals[1::2] = np.nan
+ c_downdated = core._calc_coeffs(vals, level)
+
+ level -= 1
+ vals = np.abs(core.tbls.nodes[level])
+ c = core._calc_coeffs(vals, level)
+
+ assert_allclose(c_downdated[:len(c)], c, rtol=0, atol=1e-9)
+
+
+def test_integration():
+ old_settings = np.seterr(all='ignore')
+
+ igral, err, nr_points = core.vquad(f0, 0, 3, 1e-5)
+ assert_allclose(igral, 1.98194117954329, 1e-15)
+ assert_allclose(err, 1.9563545589988155e-05, 1e-10)
+ assert nr_points == 1129
+
+ igral, err, nr_points = core.vquad(f7, 0, 1, 1e-6)
+ assert_allclose(igral, 1.9999998579359648, 1e-15)
+ assert_allclose(err, 1.8561437334964041e-06, 1e-10)
+ assert nr_points == 693
+
+ igral, err, nr_points = core.vquad(f24, 0, 3, 1e-3)
+ assert_allclose(igral, 17.664696186312934, 1e-15)
+ assert_allclose(err, 0.017602618074957457, 1e-10)
+ assert nr_points == 4519
+
+ igral, err, nr_points = core.vquad(f21, 0, 1, 1e-3)
+ assert_allclose(igral, 0.16310022131213361, 1e-15)
+ assert_allclose(err, 0.00011848806384952786, 1e-10)
+ assert nr_points == 191
+
+ igral, err, nr_points = core.vquad(f_one_with_nan, -1, 1, 1e-12)
+ assert_allclose(igral, 2, 1e-15)
+ assert_allclose(err, 2.4237853822937613e-15, 1e-7)
+ assert nr_points == 33
+
+ try:
+ igral, err, nr_points = core.vquad(fdiv, 0, 1, 1e-6)
+ except core.DivergentIntegralError as e:
+ assert e.igral == np.inf
+ assert e.err is None
+ assert e.nr_points == 431
+
+ np.seterr(**old_settings)
+
+
+def test_analytic(n=200):
+ def f(x):
+ return f63(x, alpha, beta)
+
+ def F(x):
+ return F63(x, alpha, beta)
+
+ old_settings = np.seterr(all='ignore')
+
+ np.random.seed(123)
+ params = np.empty((n, 2))
+ params[:, 0] = np.linspace(-0.5, -1.5, n)
+ params[:, 1] = np.random.random_sample(n)
+
+ false_negatives = 0
+ false_positives = 0
+
+ for alpha, beta in params:
+ try:
+ igral, err, nr_points = core.vquad(f, 0, 1, 1e-3)
+ except core.DivergentIntegralError:
+ assert alpha < -0.8
+ false_negatives += alpha > -1
+ else:
+ if alpha <= -1:
+ false_positives += 1
+ else:
+ igral_exact = F(1) - F(0)
+ assert alpha < -0.7 or abs(igral - igral_exact) < err
+
+ assert false_negatives < 0.05 * n
+ assert false_positives < 0.05 * n
+
+ np.seterr(**old_settings)
diff --git a/vquad/test/test_tables.py b/vquad/test/test_tables.py
new file mode 100644
index 0000000..79b8ffe
--- /dev/null
+++ b/vquad/test/test_tables.py
@@ -0,0 +1,53 @@
+# Copyright 2017 Christoph Groth (CEA).
+#
+# This file is part of Vquad. It is subject to the license terms in the file
+# LICENSE.rst found in the top-level directory of this distribution.
+
+from fractions import Fraction as Frac
+from itertools import combinations
+import numpy as np
+from numpy.testing import assert_allclose
+
+from .. import tables
+
+
+def test_legendre():
+ legs = tables.legendre(11)
+ comparisons = [(legs[0], [1], 1),
+ (legs[1], [0, 1], 1),
+ (legs[10], [-63, 0, 3465, 0, -30030, 0,
+ 90090, 0, -109395, 0, 46189], 256)]
+ for a, b, div in comparisons:
+ for c, d in zip(a, b):
+ assert c * div == d
+
+
+def test_scalar_product(n=33):
+ legs = tables.legendre(n)
+ selection = [0, 5, 7, n-1]
+ for i in selection:
+ for j in selection:
+ assert (tables.scalar_product(legs[i], legs[j])
+ == ((i == j) and Frac(2, 2*i + 1)))
+
+
+def simple_newton(n):
+ """Slower than 'newton()' and prone to numerical error."""
+ nodes = -np.cos(np.arange(n) / (n-1) * np.pi)
+ return [sum(np.prod(-np.asarray(sel))
+ for sel in combinations(nodes, n - d))
+ for d in range(n + 1)]
+
+
+def test_newton():
+ assert_allclose(tables.newton(9), simple_newton(9), atol=1e-15)
+
+
+def test_newton_legendre(level=1):
+ legs = [np.array(leg, float)
+ for leg in tables.legendre(tables.sizes[level] + 1)]
+ result = np.zeros(len(legs[-1]))
+ for factor, leg in zip(tables.newton_coeffs[level], legs):
+ factor *= np.sqrt((2*len(leg) - 1) / 2)
+ result[:len(leg)] += factor * leg
+ assert_allclose(result, tables.newton(tables.sizes[level]), rtol=1e-15)