aboutsummaryrefslogtreecommitdiff
path: root/vquad/core.py
diff options
context:
space:
mode:
authorChristoph Groth <christoph.groth@cea.fr>2018-02-22 16:51:29 +0100
committerChristoph Groth <christoph.groth@cea.fr>2018-02-28 21:01:31 +0100
commit2f95f2f991956abac93cf6b273b1628fbad22f09 (patch)
tree2434455ec919f338a0421ce83d9703e747f9c5d1 /vquad/core.py
parent581bad7324c4c1164ba4adec69c7b11b0e25fb7e (diff)
implement evaluation of the interpolated integrand
Diffstat (limited to 'vquad/core.py')
-rw-r--r--vquad/core.py89
1 files changed, 82 insertions, 7 deletions
diff --git a/vquad/core.py b/vquad/core.py
index f6baedb..e7faef2 100644
--- a/vquad/core.py
+++ b/vquad/core.py
@@ -3,7 +3,7 @@
# This file is part of Vquad. It is subject to the license terms in the file
# LICENSE.rst found in the top-level directory of this distribution.
-from bisect import insort
+import bisect
import numpy as np
from scipy.linalg import norm
@@ -31,6 +31,28 @@ ndiv_max = 20
_sqrt_one_half = np.sqrt(0.5)
+def _eval_legendre(c, x):
+ """Evaluate _orthonormal_ Legendre polynomial.
+
+ This uses the three-term recurrence relation from page 63 of Perdo
+ Gonnet's thesis.
+ """
+ if len(c) <= 1:
+ c0 = c[0]
+ c1 = 0
+ else:
+ n = len(c)
+ c0 = c[-2] # = c[k + 0]
+ c1 = c[-1] # = c[k + 1]
+ for k in range(len(c) - 2, 0, -1):
+ a = (2*k + 3) / (k + 1)**2
+ tmp = c0
+ c0 = c[k - 1] - c1 * np.sqrt(a * k**2 / (2*k - 1))
+ c1 = tmp + c1 * x * np.sqrt(a * (2*k + 1))
+
+ return np.sqrt(1/2) * c0 + np.sqrt(3/2) * c1 * x
+
+
def _calc_coeffs(vals, level):
nans = np.flatnonzero(~np.isfinite(vals))
if nans.size:
@@ -67,9 +89,13 @@ class DivergentIntegralError(ValueError):
super().__init__(msg)
+class _Terminator:
+ __slots__ = ['prev', 'next']
+
+
class _Interval:
__slots__ = ['a', 'b', 'coeffs', 'vals', 'igral', 'err', 'level', 'depth',
- 'ndiv', 'c00', 'unreliable_err']
+ 'ndiv', 'c00', 'unreliable_err', 'prev', 'next']
def __init__(self, a, b, level, depth):
self.a = a
@@ -102,6 +128,11 @@ class _Interval:
def __lt__(self, other):
return self.err < other.err
+ def __call__(self, x):
+ a = self.a
+ b = self.b
+ x = (2 * x - (a + b)) / (b - a)
+ return _eval_legendre(self.coeffs, x)
class Vquad:
"""Evaluate an integral using adaptive quadrature.
@@ -128,11 +159,16 @@ class Vquad:
ival.ndiv = 0
self.ivals = [ival] # Active intervals
- self.attic = [] # Inactive intervals
self.f = f
self.igral_excess = 0
self.err_excess = 0
+ # Initialize linked list.
+ ival.prev = self.begin = _Terminator()
+ self.begin.next = ival
+ ival.next = self.end = _Terminator()
+ self.end.prev = ival
+
def split(self, ival):
m = (ival.a + ival.b) / 2
f_center = ival.vals[(len(ival.vals) - 1) // 2]
@@ -191,17 +227,27 @@ class Vquad:
# error.
self.err_excess += ival.err
self.igral_excess += ival.igral
- self.attic.append(self.ivals.pop())
+ self.ivals.pop()
return
split = ival.unreliable_err
if split:
# Replace current interval by its children.
- for new in self.split(self.ivals.pop()):
- insort(self.ivals, new)
+ self.ivals.pop()
+ child0, child1 = self.split(ival)
+ bisect.insort(self.ivals, child0)
+ bisect.insort(self.ivals, child1)
+
+ # Maintain linked list.
+ ival.prev.next = child0
+ ival.next.prev = child1
+ child0.prev = ival.prev
+ child0.next = child1
+ child1.prev = child0
+ child1.next = ival.next
else:
# The error estimate of the current interval has changed.
- insort(self.ivals, self.ivals.pop())
+ bisect.insort(self.ivals, self.ivals.pop())
def totals(self):
igral = self.igral_excess
@@ -222,6 +268,35 @@ class Vquad:
or not self.ivals):
return igral, err
+ def __call__(self, xs):
+ xs = np.asarray(xs)
+ shape = xs.shape
+ if xs.size == 0:
+ return np.empty(shape)
+ xs = xs.flatten()
+
+ # Sort xs, but remember inverse permutation.
+ perm = np.argsort(xs)
+ xs = xs[perm]
+ inv_perm = np.empty(len(perm), int)
+ inv_perm[perm] = np.arange(len(perm))
+
+ # Evaluate points interval by interval.
+ results = []
+ ival = self.begin.next
+ end = self.end
+ if xs[0] < ival.a or xs[-1] > end.prev.b:
+ raise ValueError("Point lies outside of integration interval.")
+ i = 0
+ while ival is not end:
+ j = bisect.bisect(xs, ival.b, i)
+ if j != i:
+ results.append(ival(xs[i:j]))
+ i = j
+ ival = ival.next
+
+ return np.concatenate(results)[inv_perm].reshape(shape)
+
def vquad(f, a, b, tol):
igrator = Vquad(f, a, b)