aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorChristoph Groth <christoph.groth@cea.fr>2018-02-23 15:55:33 +0100
committerChristoph Groth <christoph.groth@cea.fr>2018-02-28 21:02:54 +0100
commitca7951f99e22f9d8e70f93f4bf590cd251350c4e (patch)
tree2aff6f0e3da5b37c448bacded91b9f950691539a
parent2f95f2f991956abac93cf6b273b1628fbad22f09 (diff)
turn split and refine into coroutines and move them to _Interval
-rw-r--r--vquad/core.py111
1 files changed, 62 insertions, 49 deletions
diff --git a/vquad/core.py b/vquad/core.py
index e7faef2..2f3cac0 100644
--- a/vquad/core.py
+++ b/vquad/core.py
@@ -134,6 +134,64 @@ class _Interval:
x = (2 * x - (a + b)) / (b - a)
return _eval_legendre(self.coeffs, x)
+ def split(self):
+ """Split this interval in the center into two children.
+
+ This is a coroutine that initially yields an array of x values
+ of points to be evaluated. Once the corresponding values have
+ been sent back a tuple containing the child intervals is
+ yielded and execution ends.
+ """
+ m = (self.a + self.b) / 2
+ f_center = self.vals[(len(self.vals) - 1) // 2]
+
+ depth = self.depth + 1
+ children = [_Interval(self.a, m, min_level, depth),
+ _Interval(m, self.b, min_level, depth)]
+ points = np.concatenate([child.points()[1:-1] for child in children])
+ valss = np.empty((2, tbls.sizes[min_level]))
+ valss[:, 0] = self.vals[0], f_center
+ valss[:, -1] = f_center, self.vals[-1]
+ valss[:, 1:-1] = (yield points).reshape((2, -1))
+
+ for child, vals, T in zip(children, valss, tbls.Ts):
+ child.interpolate(vals, T[:, :self.coeffs.shape[0]] @ self.coeffs)
+ child.ndiv = (self.ndiv
+ + (self.c00 and child.c00 / self.c00 > 2))
+ if child.ndiv > ndiv_max and 2*child.ndiv > child.depth:
+ msg = ('Possibly divergent integral in the interval '
+ '[{}, {}]! (h={})')
+ raise DivergentIntegralError(
+ msg.format(child.a, child.b, child.b - child.a),
+ child.igral * np.inf, None)
+ yield children
+
+ def refine(self):
+ """Increase degree of interval.
+
+ This is a coroutine that initially yields an array of x values
+ of points to be evaluated. Once the corresponding values have
+ been sent back, a bool is yielded and execution ends.
+
+ It is "true" if further refinements/splits of the interval seem
+ promising, and "false" otherwise. This is the case when
+ neigboring points can be resolved only barely by floating point
+ numbers, or when the estimated relative error is already at the
+ limit of numerical accuracy and cannot be reduced further.
+ """
+ self.level += 1
+ points = self.points()
+ vals = np.empty(points.shape)
+ vals[0::2] = self.vals
+ vals[1::2] = (yield points[1::2])
+ self.interpolate(vals, self.coeffs)
+
+ yield (points[1] - points[0] > points[0] * min_sep
+ and points[-1] - points[-2] > points[-2] * min_sep
+ and self.err > (abs(self.igral)
+ * eps * tbls.V_cond_nums[self.level]))
+
+
class Vquad:
"""Evaluate an integral using adaptive quadrature.
@@ -169,60 +227,14 @@ class Vquad:
ival.next = self.end = _Terminator()
self.end.prev = ival
- def split(self, ival):
- m = (ival.a + ival.b) / 2
- f_center = ival.vals[(len(ival.vals) - 1) // 2]
-
- depth = ival.depth + 1
- children = [_Interval(ival.a, m, min_level, depth),
- _Interval(m, ival.b, min_level, depth)]
- points = np.concatenate([child.points()[1:-1] for child in children])
- valss = np.empty((2, tbls.sizes[min_level]))
- valss[:, 0] = ival.vals[0], f_center
- valss[:, -1] = f_center, ival.vals[-1]
- valss[:, 1:-1] = self.f(points).reshape((2, -1))
-
- for child, vals, T in zip(children, valss, tbls.Ts):
- child.interpolate(vals, T[:, :ival.coeffs.shape[0]] @ ival.coeffs)
- child.ndiv = (ival.ndiv
- + (ival.c00 and child.c00 / ival.c00 > 2))
- if child.ndiv > ndiv_max and 2*child.ndiv > child.depth:
- msg = ('Possibly divergent integral in the interval '
- '[{}, {}]! (h={})')
- raise DivergentIntegralError(
- msg.format(child.a, child.b, child.b - child.a),
- child.igral * np.inf, None)
- return children
-
- def refine(self, ival):
- """Increase degree of interval.
-
- Returns True if the refined interval is OK, and False if it is
- borderline and should not be refined/split further. This
- happens when neigboring points can be only barely resolved by
- floating point numbers, or when the estimated relative error is
- already at the limit of numerical accuracy and cannot be reduced
- further.
- """
- ival.level += 1
- points = ival.points()
- vals = np.empty(points.shape)
- vals[0::2] = ival.vals
- vals[1::2] = self.f(points[1::2])
- ival.interpolate(vals, ival.coeffs)
-
- return (points[1] - points[0] > points[0] * min_sep
- and points[-1] - points[-2] > points[-2] * min_sep
- and ival.err > (abs(ival.igral)
- * eps * tbls.V_cond_nums[ival.level]))
-
def improve(self):
ival = self.ivals[-1]
if ival.level == max_level:
split = True
else:
- if not self.refine(ival):
+ refine = ival.refine()
+ if not refine.send(self.f(next(refine))):
# Remove the interval but remember the excess integral and
# error.
self.err_excess += ival.err
@@ -234,7 +246,8 @@ class Vquad:
if split:
# Replace current interval by its children.
self.ivals.pop()
- child0, child1 = self.split(ival)
+ split = ival.split()
+ child0, child1 = split.send(self.f(next(split)))
bisect.insort(self.ivals, child0)
bisect.insort(self.ivals, child1)