diff options
| author | Christoph Groth <christoph.groth@cea.fr> | 2018-02-22 16:51:29 +0100 |
|---|---|---|
| committer | Christoph Groth <christoph.groth@cea.fr> | 2018-02-28 21:01:31 +0100 |
| commit | 2f95f2f991956abac93cf6b273b1628fbad22f09 (patch) | |
| tree | 2434455ec919f338a0421ce83d9703e747f9c5d1 | |
| parent | 581bad7324c4c1164ba4adec69c7b11b0e25fb7e (diff) | |
implement evaluation of the interpolated integrand
| -rw-r--r-- | vquad/__init__.py | 2 | ||||
| -rw-r--r-- | vquad/core.py | 89 | ||||
| -rw-r--r-- | vquad/test/test_core.py | 24 |
3 files changed, 104 insertions, 11 deletions
diff --git a/vquad/__init__.py b/vquad/__init__.py index 3313b0f..8555ba5 100644 --- a/vquad/__init__.py +++ b/vquad/__init__.py @@ -4,4 +4,4 @@ # LICENSE.rst found in the top-level directory of this distribution. -from .core import vquad +from .core import vquad, Vquad diff --git a/vquad/core.py b/vquad/core.py index f6baedb..e7faef2 100644 --- a/vquad/core.py +++ b/vquad/core.py @@ -3,7 +3,7 @@ # This file is part of Vquad. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution. -from bisect import insort +import bisect import numpy as np from scipy.linalg import norm @@ -31,6 +31,28 @@ ndiv_max = 20 _sqrt_one_half = np.sqrt(0.5) +def _eval_legendre(c, x): + """Evaluate _orthonormal_ Legendre polynomial. + + This uses the three-term recurrence relation from page 63 of Perdo + Gonnet's thesis. + """ + if len(c) <= 1: + c0 = c[0] + c1 = 0 + else: + n = len(c) + c0 = c[-2] # = c[k + 0] + c1 = c[-1] # = c[k + 1] + for k in range(len(c) - 2, 0, -1): + a = (2*k + 3) / (k + 1)**2 + tmp = c0 + c0 = c[k - 1] - c1 * np.sqrt(a * k**2 / (2*k - 1)) + c1 = tmp + c1 * x * np.sqrt(a * (2*k + 1)) + + return np.sqrt(1/2) * c0 + np.sqrt(3/2) * c1 * x + + def _calc_coeffs(vals, level): nans = np.flatnonzero(~np.isfinite(vals)) if nans.size: @@ -67,9 +89,13 @@ class DivergentIntegralError(ValueError): super().__init__(msg) +class _Terminator: + __slots__ = ['prev', 'next'] + + class _Interval: __slots__ = ['a', 'b', 'coeffs', 'vals', 'igral', 'err', 'level', 'depth', - 'ndiv', 'c00', 'unreliable_err'] + 'ndiv', 'c00', 'unreliable_err', 'prev', 'next'] def __init__(self, a, b, level, depth): self.a = a @@ -102,6 +128,11 @@ class _Interval: def __lt__(self, other): return self.err < other.err + def __call__(self, x): + a = self.a + b = self.b + x = (2 * x - (a + b)) / (b - a) + return _eval_legendre(self.coeffs, x) class Vquad: """Evaluate an integral using adaptive quadrature. @@ -128,11 +159,16 @@ class Vquad: ival.ndiv = 0 self.ivals = [ival] # Active intervals - self.attic = [] # Inactive intervals self.f = f self.igral_excess = 0 self.err_excess = 0 + # Initialize linked list. + ival.prev = self.begin = _Terminator() + self.begin.next = ival + ival.next = self.end = _Terminator() + self.end.prev = ival + def split(self, ival): m = (ival.a + ival.b) / 2 f_center = ival.vals[(len(ival.vals) - 1) // 2] @@ -191,17 +227,27 @@ class Vquad: # error. self.err_excess += ival.err self.igral_excess += ival.igral - self.attic.append(self.ivals.pop()) + self.ivals.pop() return split = ival.unreliable_err if split: # Replace current interval by its children. - for new in self.split(self.ivals.pop()): - insort(self.ivals, new) + self.ivals.pop() + child0, child1 = self.split(ival) + bisect.insort(self.ivals, child0) + bisect.insort(self.ivals, child1) + + # Maintain linked list. + ival.prev.next = child0 + ival.next.prev = child1 + child0.prev = ival.prev + child0.next = child1 + child1.prev = child0 + child1.next = ival.next else: # The error estimate of the current interval has changed. - insort(self.ivals, self.ivals.pop()) + bisect.insort(self.ivals, self.ivals.pop()) def totals(self): igral = self.igral_excess @@ -222,6 +268,35 @@ class Vquad: or not self.ivals): return igral, err + def __call__(self, xs): + xs = np.asarray(xs) + shape = xs.shape + if xs.size == 0: + return np.empty(shape) + xs = xs.flatten() + + # Sort xs, but remember inverse permutation. + perm = np.argsort(xs) + xs = xs[perm] + inv_perm = np.empty(len(perm), int) + inv_perm[perm] = np.arange(len(perm)) + + # Evaluate points interval by interval. + results = [] + ival = self.begin.next + end = self.end + if xs[0] < ival.a or xs[-1] > end.prev.b: + raise ValueError("Point lies outside of integration interval.") + i = 0 + while ival is not end: + j = bisect.bisect(xs, ival.b, i) + if j != i: + results.append(ival(xs[i:j])) + i = j + ival = ival.next + + return np.concatenate(results)[inv_perm].reshape(shape) + def vquad(f, a, b, tol): igrator = Vquad(f, a, b) diff --git a/vquad/test/test_core.py b/vquad/test/test_core.py index 648a445..46fd823 100644 --- a/vquad/test/test_core.py +++ b/vquad/test/test_core.py @@ -1,10 +1,11 @@ -# Copyright 2017 Christoph Groth (CEA). +# Copyright 2017, 2018 Christoph Groth (CEA). # # This file is part of Vquad. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution. import numpy as np from numpy.testing import assert_allclose +from pytest import raises from .. import core @@ -46,7 +47,7 @@ def f_one_with_nan(x): return result -def test_downdate(level=3): +def test_coeffs(level=3): vals = np.abs(core.tbls.nodes[level]) vals[1::2] = np.nan c_downdated = core._calc_coeffs(vals, level) @@ -54,8 +55,10 @@ def test_downdate(level=3): level -= 1 vals = np.abs(core.tbls.nodes[level]) c = core._calc_coeffs(vals, level) + assert_allclose(c_downdated[:len(c)], c, rtol=0, atol=1e-12) - assert_allclose(c_downdated[:len(c)], c, rtol=0, atol=1e-9) + vals_from_c = core._eval_legendre(c, core.tbls.nodes[level]) + assert_allclose(vals_from_c, vals, rtol=0, atol=1e-15) def test_integration(): @@ -90,6 +93,21 @@ def test_integration(): np.seterr(**old_settings) +def test_interpolation(): + vquad = core.Vquad(f24, 0, 3) + vquad.improve_until(1e-6) + + rng = np.random.RandomState(123) + x = np.linspace(0, 3, 100) + rng.shuffle(x) + + for x in [x, 1.23, [[2, 0], [1, 2]], [], [[]]]: + assert_allclose(vquad(x), f24(x)) + + for x in [-1e-100, 3.0001, 1e100, [1, 2, -1e50]]: + raises(ValueError, vquad, x) + + def test_analytic(n=200): def f(x): return f63(x, alpha, beta) |
