diff options
| -rw-r--r-- | vquad/core.py | 111 |
1 files changed, 62 insertions, 49 deletions
diff --git a/vquad/core.py b/vquad/core.py index e7faef2..2f3cac0 100644 --- a/vquad/core.py +++ b/vquad/core.py @@ -134,6 +134,64 @@ class _Interval: x = (2 * x - (a + b)) / (b - a) return _eval_legendre(self.coeffs, x) + def split(self): + """Split this interval in the center into two children. + + This is a coroutine that initially yields an array of x values + of points to be evaluated. Once the corresponding values have + been sent back a tuple containing the child intervals is + yielded and execution ends. + """ + m = (self.a + self.b) / 2 + f_center = self.vals[(len(self.vals) - 1) // 2] + + depth = self.depth + 1 + children = [_Interval(self.a, m, min_level, depth), + _Interval(m, self.b, min_level, depth)] + points = np.concatenate([child.points()[1:-1] for child in children]) + valss = np.empty((2, tbls.sizes[min_level])) + valss[:, 0] = self.vals[0], f_center + valss[:, -1] = f_center, self.vals[-1] + valss[:, 1:-1] = (yield points).reshape((2, -1)) + + for child, vals, T in zip(children, valss, tbls.Ts): + child.interpolate(vals, T[:, :self.coeffs.shape[0]] @ self.coeffs) + child.ndiv = (self.ndiv + + (self.c00 and child.c00 / self.c00 > 2)) + if child.ndiv > ndiv_max and 2*child.ndiv > child.depth: + msg = ('Possibly divergent integral in the interval ' + '[{}, {}]! (h={})') + raise DivergentIntegralError( + msg.format(child.a, child.b, child.b - child.a), + child.igral * np.inf, None) + yield children + + def refine(self): + """Increase degree of interval. + + This is a coroutine that initially yields an array of x values + of points to be evaluated. Once the corresponding values have + been sent back, a bool is yielded and execution ends. + + It is "true" if further refinements/splits of the interval seem + promising, and "false" otherwise. This is the case when + neigboring points can be resolved only barely by floating point + numbers, or when the estimated relative error is already at the + limit of numerical accuracy and cannot be reduced further. + """ + self.level += 1 + points = self.points() + vals = np.empty(points.shape) + vals[0::2] = self.vals + vals[1::2] = (yield points[1::2]) + self.interpolate(vals, self.coeffs) + + yield (points[1] - points[0] > points[0] * min_sep + and points[-1] - points[-2] > points[-2] * min_sep + and self.err > (abs(self.igral) + * eps * tbls.V_cond_nums[self.level])) + + class Vquad: """Evaluate an integral using adaptive quadrature. @@ -169,60 +227,14 @@ class Vquad: ival.next = self.end = _Terminator() self.end.prev = ival - def split(self, ival): - m = (ival.a + ival.b) / 2 - f_center = ival.vals[(len(ival.vals) - 1) // 2] - - depth = ival.depth + 1 - children = [_Interval(ival.a, m, min_level, depth), - _Interval(m, ival.b, min_level, depth)] - points = np.concatenate([child.points()[1:-1] for child in children]) - valss = np.empty((2, tbls.sizes[min_level])) - valss[:, 0] = ival.vals[0], f_center - valss[:, -1] = f_center, ival.vals[-1] - valss[:, 1:-1] = self.f(points).reshape((2, -1)) - - for child, vals, T in zip(children, valss, tbls.Ts): - child.interpolate(vals, T[:, :ival.coeffs.shape[0]] @ ival.coeffs) - child.ndiv = (ival.ndiv - + (ival.c00 and child.c00 / ival.c00 > 2)) - if child.ndiv > ndiv_max and 2*child.ndiv > child.depth: - msg = ('Possibly divergent integral in the interval ' - '[{}, {}]! (h={})') - raise DivergentIntegralError( - msg.format(child.a, child.b, child.b - child.a), - child.igral * np.inf, None) - return children - - def refine(self, ival): - """Increase degree of interval. - - Returns True if the refined interval is OK, and False if it is - borderline and should not be refined/split further. This - happens when neigboring points can be only barely resolved by - floating point numbers, or when the estimated relative error is - already at the limit of numerical accuracy and cannot be reduced - further. - """ - ival.level += 1 - points = ival.points() - vals = np.empty(points.shape) - vals[0::2] = ival.vals - vals[1::2] = self.f(points[1::2]) - ival.interpolate(vals, ival.coeffs) - - return (points[1] - points[0] > points[0] * min_sep - and points[-1] - points[-2] > points[-2] * min_sep - and ival.err > (abs(ival.igral) - * eps * tbls.V_cond_nums[ival.level])) - def improve(self): ival = self.ivals[-1] if ival.level == max_level: split = True else: - if not self.refine(ival): + refine = ival.refine() + if not refine.send(self.f(next(refine))): # Remove the interval but remember the excess integral and # error. self.err_excess += ival.err @@ -234,7 +246,8 @@ class Vquad: if split: # Replace current interval by its children. self.ivals.pop() - child0, child1 = self.split(ival) + split = ival.split() + child0, child1 = split.send(self.f(next(split))) bisect.insort(self.ivals, child0) bisect.insort(self.ivals, child1) |
