diff options
Diffstat (limited to 'vquad/core.py')
| -rw-r--r-- | vquad/core.py | 89 |
1 files changed, 82 insertions, 7 deletions
diff --git a/vquad/core.py b/vquad/core.py index f6baedb..e7faef2 100644 --- a/vquad/core.py +++ b/vquad/core.py @@ -3,7 +3,7 @@ # This file is part of Vquad. It is subject to the license terms in the file # LICENSE.rst found in the top-level directory of this distribution. -from bisect import insort +import bisect import numpy as np from scipy.linalg import norm @@ -31,6 +31,28 @@ ndiv_max = 20 _sqrt_one_half = np.sqrt(0.5) +def _eval_legendre(c, x): + """Evaluate _orthonormal_ Legendre polynomial. + + This uses the three-term recurrence relation from page 63 of Perdo + Gonnet's thesis. + """ + if len(c) <= 1: + c0 = c[0] + c1 = 0 + else: + n = len(c) + c0 = c[-2] # = c[k + 0] + c1 = c[-1] # = c[k + 1] + for k in range(len(c) - 2, 0, -1): + a = (2*k + 3) / (k + 1)**2 + tmp = c0 + c0 = c[k - 1] - c1 * np.sqrt(a * k**2 / (2*k - 1)) + c1 = tmp + c1 * x * np.sqrt(a * (2*k + 1)) + + return np.sqrt(1/2) * c0 + np.sqrt(3/2) * c1 * x + + def _calc_coeffs(vals, level): nans = np.flatnonzero(~np.isfinite(vals)) if nans.size: @@ -67,9 +89,13 @@ class DivergentIntegralError(ValueError): super().__init__(msg) +class _Terminator: + __slots__ = ['prev', 'next'] + + class _Interval: __slots__ = ['a', 'b', 'coeffs', 'vals', 'igral', 'err', 'level', 'depth', - 'ndiv', 'c00', 'unreliable_err'] + 'ndiv', 'c00', 'unreliable_err', 'prev', 'next'] def __init__(self, a, b, level, depth): self.a = a @@ -102,6 +128,11 @@ class _Interval: def __lt__(self, other): return self.err < other.err + def __call__(self, x): + a = self.a + b = self.b + x = (2 * x - (a + b)) / (b - a) + return _eval_legendre(self.coeffs, x) class Vquad: """Evaluate an integral using adaptive quadrature. @@ -128,11 +159,16 @@ class Vquad: ival.ndiv = 0 self.ivals = [ival] # Active intervals - self.attic = [] # Inactive intervals self.f = f self.igral_excess = 0 self.err_excess = 0 + # Initialize linked list. + ival.prev = self.begin = _Terminator() + self.begin.next = ival + ival.next = self.end = _Terminator() + self.end.prev = ival + def split(self, ival): m = (ival.a + ival.b) / 2 f_center = ival.vals[(len(ival.vals) - 1) // 2] @@ -191,17 +227,27 @@ class Vquad: # error. self.err_excess += ival.err self.igral_excess += ival.igral - self.attic.append(self.ivals.pop()) + self.ivals.pop() return split = ival.unreliable_err if split: # Replace current interval by its children. - for new in self.split(self.ivals.pop()): - insort(self.ivals, new) + self.ivals.pop() + child0, child1 = self.split(ival) + bisect.insort(self.ivals, child0) + bisect.insort(self.ivals, child1) + + # Maintain linked list. + ival.prev.next = child0 + ival.next.prev = child1 + child0.prev = ival.prev + child0.next = child1 + child1.prev = child0 + child1.next = ival.next else: # The error estimate of the current interval has changed. - insort(self.ivals, self.ivals.pop()) + bisect.insort(self.ivals, self.ivals.pop()) def totals(self): igral = self.igral_excess @@ -222,6 +268,35 @@ class Vquad: or not self.ivals): return igral, err + def __call__(self, xs): + xs = np.asarray(xs) + shape = xs.shape + if xs.size == 0: + return np.empty(shape) + xs = xs.flatten() + + # Sort xs, but remember inverse permutation. + perm = np.argsort(xs) + xs = xs[perm] + inv_perm = np.empty(len(perm), int) + inv_perm[perm] = np.arange(len(perm)) + + # Evaluate points interval by interval. + results = [] + ival = self.begin.next + end = self.end + if xs[0] < ival.a or xs[-1] > end.prev.b: + raise ValueError("Point lies outside of integration interval.") + i = 0 + while ival is not end: + j = bisect.bisect(xs, ival.b, i) + if j != i: + results.append(ival(xs[i:j])) + i = j + ival = ival.next + + return np.concatenate(results)[inv_perm].reshape(shape) + def vquad(f, a, b, tol): igrator = Vquad(f, a, b) |
