aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristoph Groth <christoph.groth@cea.fr>2018-02-22 16:51:29 +0100
committerChristoph Groth <christoph.groth@cea.fr>2018-02-28 21:01:31 +0100
commit2f95f2f991956abac93cf6b273b1628fbad22f09 (patch)
tree2434455ec919f338a0421ce83d9703e747f9c5d1
parent581bad7324c4c1164ba4adec69c7b11b0e25fb7e (diff)
implement evaluation of the interpolated integrand
-rw-r--r--vquad/__init__.py2
-rw-r--r--vquad/core.py89
-rw-r--r--vquad/test/test_core.py24
3 files changed, 104 insertions, 11 deletions
diff --git a/vquad/__init__.py b/vquad/__init__.py
index 3313b0f..8555ba5 100644
--- a/vquad/__init__.py
+++ b/vquad/__init__.py
@@ -4,4 +4,4 @@
# LICENSE.rst found in the top-level directory of this distribution.
-from .core import vquad
+from .core import vquad, Vquad
diff --git a/vquad/core.py b/vquad/core.py
index f6baedb..e7faef2 100644
--- a/vquad/core.py
+++ b/vquad/core.py
@@ -3,7 +3,7 @@
# This file is part of Vquad. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution.
-from bisect import insort
+import bisect
import numpy as np
from scipy.linalg import norm
@@ -31,6 +31,28 @@ ndiv_max = 20
_sqrt_one_half = np.sqrt(0.5)
+def _eval_legendre(c, x):
+ """Evaluate _orthonormal_ Legendre polynomial.
+
+ This uses the three-term recurrence relation from page 63 of Perdo
+ Gonnet's thesis.
+ """
+ if len(c) <= 1:
+ c0 = c[0]
+ c1 = 0
+ else:
+ n = len(c)
+ c0 = c[-2] # = c[k + 0]
+ c1 = c[-1] # = c[k + 1]
+ for k in range(len(c) - 2, 0, -1):
+ a = (2*k + 3) / (k + 1)**2
+ tmp = c0
+ c0 = c[k - 1] - c1 * np.sqrt(a * k**2 / (2*k - 1))
+ c1 = tmp + c1 * x * np.sqrt(a * (2*k + 1))
+
+ return np.sqrt(1/2) * c0 + np.sqrt(3/2) * c1 * x
+
+
def _calc_coeffs(vals, level):
nans = np.flatnonzero(~np.isfinite(vals))
if nans.size:
@@ -67,9 +89,13 @@ class DivergentIntegralError(ValueError):
super().__init__(msg)
+class _Terminator:
+ __slots__ = ['prev', 'next']
+
+
class _Interval:
__slots__ = ['a', 'b', 'coeffs', 'vals', 'igral', 'err', 'level', 'depth',
- 'ndiv', 'c00', 'unreliable_err']
+ 'ndiv', 'c00', 'unreliable_err', 'prev', 'next']
def __init__(self, a, b, level, depth):
self.a = a
@@ -102,6 +128,11 @@ class _Interval:
def __lt__(self, other):
return self.err < other.err
+ def __call__(self, x):
+ a = self.a
+ b = self.b
+ x = (2 * x - (a + b)) / (b - a)
+ return _eval_legendre(self.coeffs, x)
class Vquad:
"""Evaluate an integral using adaptive quadrature.
@@ -128,11 +159,16 @@ class Vquad:
ival.ndiv = 0
self.ivals = [ival] # Active intervals
- self.attic = [] # Inactive intervals
self.f = f
self.igral_excess = 0
self.err_excess = 0
+ # Initialize linked list.
+ ival.prev = self.begin = _Terminator()
+ self.begin.next = ival
+ ival.next = self.end = _Terminator()
+ self.end.prev = ival
+
def split(self, ival):
m = (ival.a + ival.b) / 2
f_center = ival.vals[(len(ival.vals) - 1) // 2]
@@ -191,17 +227,27 @@ class Vquad:
# error.
self.err_excess += ival.err
self.igral_excess += ival.igral
- self.attic.append(self.ivals.pop())
+ self.ivals.pop()
return
split = ival.unreliable_err
if split:
# Replace current interval by its children.
- for new in self.split(self.ivals.pop()):
- insort(self.ivals, new)
+ self.ivals.pop()
+ child0, child1 = self.split(ival)
+ bisect.insort(self.ivals, child0)
+ bisect.insort(self.ivals, child1)
+
+ # Maintain linked list.
+ ival.prev.next = child0
+ ival.next.prev = child1
+ child0.prev = ival.prev
+ child0.next = child1
+ child1.prev = child0
+ child1.next = ival.next
else:
# The error estimate of the current interval has changed.
- insort(self.ivals, self.ivals.pop())
+ bisect.insort(self.ivals, self.ivals.pop())
def totals(self):
igral = self.igral_excess
@@ -222,6 +268,35 @@ class Vquad:
or not self.ivals):
return igral, err
+ def __call__(self, xs):
+ xs = np.asarray(xs)
+ shape = xs.shape
+ if xs.size == 0:
+ return np.empty(shape)
+ xs = xs.flatten()
+
+ # Sort xs, but remember inverse permutation.
+ perm = np.argsort(xs)
+ xs = xs[perm]
+ inv_perm = np.empty(len(perm), int)
+ inv_perm[perm] = np.arange(len(perm))
+
+ # Evaluate points interval by interval.
+ results = []
+ ival = self.begin.next
+ end = self.end
+ if xs[0] < ival.a or xs[-1] > end.prev.b:
+ raise ValueError("Point lies outside of integration interval.")
+ i = 0
+ while ival is not end:
+ j = bisect.bisect(xs, ival.b, i)
+ if j != i:
+ results.append(ival(xs[i:j]))
+ i = j
+ ival = ival.next
+
+ return np.concatenate(results)[inv_perm].reshape(shape)
+
def vquad(f, a, b, tol):
igrator = Vquad(f, a, b)
diff --git a/vquad/test/test_core.py b/vquad/test/test_core.py
index 648a445..46fd823 100644
--- a/vquad/test/test_core.py
+++ b/vquad/test/test_core.py
@@ -1,10 +1,11 @@
-# Copyright 2017 Christoph Groth (CEA).
+# Copyright 2017, 2018 Christoph Groth (CEA).
#
# This file is part of Vquad. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution.
import numpy as np
from numpy.testing import assert_allclose
+from pytest import raises
from .. import core
@@ -46,7 +47,7 @@ def f_one_with_nan(x):
return result
-def test_downdate(level=3):
+def test_coeffs(level=3):
vals = np.abs(core.tbls.nodes[level])
vals[1::2] = np.nan
c_downdated = core._calc_coeffs(vals, level)
@@ -54,8 +55,10 @@ def test_downdate(level=3):
level -= 1
vals = np.abs(core.tbls.nodes[level])
c = core._calc_coeffs(vals, level)
+ assert_allclose(c_downdated[:len(c)], c, rtol=0, atol=1e-12)
- assert_allclose(c_downdated[:len(c)], c, rtol=0, atol=1e-9)
+ vals_from_c = core._eval_legendre(c, core.tbls.nodes[level])
+ assert_allclose(vals_from_c, vals, rtol=0, atol=1e-15)
def test_integration():
@@ -90,6 +93,21 @@ def test_integration():
np.seterr(**old_settings)
+def test_interpolation():
+ vquad = core.Vquad(f24, 0, 3)
+ vquad.improve_until(1e-6)
+
+ rng = np.random.RandomState(123)
+ x = np.linspace(0, 3, 100)
+ rng.shuffle(x)
+
+ for x in [x, 1.23, [[2, 0], [1, 2]], [], [[]]]:
+ assert_allclose(vquad(x), f24(x))
+
+ for x in [-1e-100, 3.0001, 1e100, [1, 2, -1e50]]:
+ raises(ValueError, vquad, x)
+
+
def test_analytic(n=200):
def f(x):
return f63(x, alpha, beta)