Coverage for src / zooc / dsp / surface_extrapolator.py: 94%

77 statements  

« prev     ^ index     » next       coverage.py v7.13.0, created at 2025-12-11 21:45 +0000

1"""2D-surface with interpolation and extrapolation.""" 

2from __future__ import annotations 

3 

4from abc import ABC, abstractmethod 

5from collections.abc import Callable, Sequence 

6from functools import cached_property 

7from typing import Any, Final, TypeVar, override 

8 

9import numpy as np 

10import numpy.typing as npt 

11from numpy import dtype, float64, ndarray 

12from scipy.interpolate import CloughTocher2DInterpolator, RegularGridInterpolator, interp1d 

13from scipy.spatial import Delaunay # pylint: disable=no-name-in-module 

14 

15_MIN_POINTS: Final[int] = 2 

16"""Minimum number of points required to define a line.""" 

17 

18T = TypeVar('T', bound=npt.NDArray[np.float64] | float) 

19 

20 

class SurfaceExtrapolator(ABC):  # pylint: disable=too-few-public-methods
    """Base class for 2D-surface extrapolators.

    A concrete extrapolator is a callable that maps X/Y coordinates to Z-values.
    """

    @abstractmethod
    def __call__(self, xx: T, yy: T) -> npt.NDArray[np.float64]:
        """Evaluate the surface at the given points.

        ``xx`` and ``yy`` may each be a float or an array of floats. Implementations are
        expected to accept shapes that broadcast against each other and to return one
        Z-value per broadcast (x, y) pair.

        :param xx: X-coordinates.
        :param yy: Y-coordinates.
        :return: Z-values at the given x, y coordinates.
        """

35 

36 

class SurfaceExtrapolator2d(SurfaceExtrapolator):
    """2D-surface interpolator with linear extrapolation outside the convex hull.

    Inside the convex hull of the input points, the surface is evaluated with a
    ``CloughTocher2DInterpolator``. Where that interpolator yields NaN, a
    ``RegularGridInterpolator`` is used instead. That interpolator is built from a dense
    regular resampling of the Clough-Tocher surface and extrapolates linearly.

    The interpolation works for any scattered input coordinates. The extrapolation
    requires the convex hull of the input coordinates to be an upright rectangle, i.e.
    the four corners of the bounding box must be among the input points.
    """

    def __init__(self, xy: Sequence[tuple[float, float]], z: Sequence[float]) -> None:
        """Initialize with given XY and Z data.

        :param xy: XY-data in order.
        :param z: Z-data in order, one value per XY pair.
        :raises ValueError: If the corners of the bounding rectangle are not among the input points.
        """
        super().__init__()
        self.xy = xy
        # Clough-Tocher is used rather than LinearNDInterpolator, which produces biased
        # results with few data points.
        self.interp = CloughTocher2DInterpolator(xy, z)
        self._validate()

    def get_bounds(self) -> tuple[ndarray[tuple[int, ...], dtype[float64]], ndarray[tuple[int, ...], dtype[float64]]]:
        """Get the XY-data bounds.

        :return: Tuple of min and max bounds of the xy-data: [min_x, min_y], [max_x, max_y].
        """
        tri = Delaunay(np.vstack(self.xy))
        return tri.min_bound, tri.max_bound

    @cached_property
    def extrapolator(self) -> RegularGridInterpolator:
        """Get the extrapolator for points outside the xy-data range.

        The interpolated surface is sampled on a regular ``num x num`` grid spanning the
        bounding box of the input points. ``RegularGridInterpolator`` with
        ``bounds_error=False`` and ``fill_value=None`` then extrapolates linearly beyond
        that grid. The object is built on first access and cached.

        :return: Extrapolator object.
        """
        # Sample densely (lin_factor grid points per input point along each axis) to limit
        # the precision lost when the smooth surface is linearized between grid nodes.
        # lin_factor is even, which is intended to make the grid hit the original points
        # exactly when those are already evenly spaced.
        # NOTE(review): the grid holds num**2 samples, so its cost grows quadratically
        # with len(self.xy); this may become expensive for large inputs.
        lin_factor: int = 8
        num = len(self.xy) * lin_factor + (len(self.xy) % 2)
        min_v, max_v = self.get_bounds()

        x_lin = np.linspace(min_v[0], max_v[0], num=num)
        y_lin = np.linspace(min_v[1], max_v[1], num=num)
        grid_x_lin, grid_y_lin = np.meshgrid(x_lin, y_lin, indexing='ij')
        z_lin = self.interp(grid_x_lin, grid_y_lin)

        # https://docs.scipy.org/doc/scipy/tutorial/interpolate.html
        return RegularGridInterpolator(
            (x_lin, y_lin),
            z_lin,
            bounds_error=False,
            fill_value=None,  # None: extrapolate instead of filling with a constant
            method="linear",  # also available: "nearest", "slinear", "cubic", "quintic", "pchip"
        )

    @override
    def __call__(self, xx: T, yy: T) -> npt.NDArray[np.float64]:
        """Evaluate the surface at the given coordinates.

        Points are first evaluated with the Clough-Tocher interpolator. Points where it
        returns NaN (e.g. outside the convex hull) are re-evaluated with ``extrapolator``.

        :param xx: X-coordinates, scalar or array, broadcastable against ``yy``.
        :param yy: Y-coordinates, scalar or array, broadcastable against ``xx``.
        :return: Z-values with the broadcast shape of ``xx`` and ``yy``.
        """
        # Broadcast first: the NaN mask has the broadcast shape, so it can only index
        # coordinate arrays of that same shape (e.g. when one input is a scalar).
        xb, yb = np.broadcast_arrays(xx, yy)
        zz = self.interp((xb, yb))
        nans = np.isnan(zz)
        if nans.any():
            zz[nans] = self.extrapolator((xb[nans], yb[nans]))
        return zz

    def _validate(self) -> None:
        """Ensure the corners of the bounding rectangle are among the input points.

        When all four corners are present, the convex hull of the points equals their
        axis-aligned bounding rectangle. ``extrapolator`` samples the interpolator over
        that whole rectangle. If the hull were smaller, those samples would include NaNs
        from outside the hull.

        :raises ValueError: If the input points do not form an upright rectangular convex hull.
        """
        (min_x, min_y), (max_x, max_y) = self.get_bounds()
        corners = {(min_x, min_y), (min_x, max_y), (max_x, min_y), (max_x, max_y)}
        if not corners.issubset(set(map(tuple, self.xy))):
            raise ValueError("Input points must form a upright rectangular convex hull for extrapolation.")

110 

111 

class SurfaceExtrapolator1d(SurfaceExtrapolator):  # pylint: disable=too-few-public-methods
    """Adapter exposing a 1D interpolator through the 2D-surface interface.

    Meant for data whose surface varies along only one coordinate (the other one is
    constant across the data). Both interpolation and linear extrapolation are
    delegated to ``scipy.interpolate.interp1d``.
    """

    def __init__(self, get_lin: Callable[[Any, Any], Any], a: npt.NDArray[np.float64], z: Sequence[float]) -> None:
        """Build the 1D interpolator and remember which coordinate to feed it.

        :param get_lin: Picks the varying coordinate out of ``(xx, yy)``, e.g. ``lambda x, y: x``.
        :param a: X- or Y-data in order (the varying coordinate of each data point).
        :param z: Z-data in order, one value per entry of ``a``.
        """
        super().__init__()
        self.interp = interp1d(a, z, kind='linear', fill_value="extrapolate")
        self.get_lin = get_lin

    @override
    def __call__(self, xx: T, yy: T) -> npt.NDArray[np.float64]:
        """Evaluate the 1D interpolator at the coordinate selected by ``get_lin``.

        :param xx: X-coordinates.
        :param yy: Y-coordinates.
        :return: Z-values, interpolated or linearly extrapolated along the selected coordinate.
        """
        selected = self.get_lin(xx, yy)
        result = self.interp(selected)
        return result  # type: ignore[return-value]

132 

133 

def get_non_collinear(points_2d: npt.NDArray[np.float64]) -> tuple[npt.NDArray[np.float64] | None, npt.NDArray[np.float64] | None]:
    """Return the x- and y-coordinate columns, replacing any constant column with None.

    A dimension is treated as collinear when every point has the same coordinate in it,
    i.e. the points lie on one vertical line (constant x) or one horizontal line (constant y).

    :param points_2d: Array-like of shape (N, 2) holding one (x, y) point per row.
    :return: A tuple ``(x, y)`` where:

        - ``x`` is the column of x-coordinates, or None if all points share the same x.
        - ``y`` is the column of y-coordinates, or None if all points share the same y.

        For an empty input (N == 0) both are None.
    """
    points = np.asarray(points_2d)
    xs = points[:, 0]
    ys = points[:, 1]
    # Compare against a length-1 slice instead of element [0]. For N >= 1 this is the
    # same element-wise check against the first point. For N == 0 the comparison is
    # empty and therefore vacuously true, instead of raising IndexError.
    x = None if np.all(xs == xs[:1]) else xs  # constant x: vertical line
    y = None if np.all(ys == ys[:1]) else ys  # constant y: horizontal line
    return x, y

154 

155 

def create(xy: Sequence[tuple[float, float]], z: Sequence[float]) -> SurfaceExtrapolator:
    """Create a surface extrapolator matching the layout of the input XY and Z data.

    - x and y both vary across the points: ``SurfaceExtrapolator2d``.
    - only x varies (all points share one y): ``SurfaceExtrapolator1d`` along x.
    - only y varies (all points share one x): ``SurfaceExtrapolator1d`` along y.

    :param xy: XY-data in order, as (x, y) pairs.
    :param z: Z-data in order, one value per (x, y) pair.
    :return: SurfaceExtrapolator object.
    :raises ValueError: If there are fewer than two data points, if ``xy`` is not a sequence
        of (x, y) pairs, if ``z`` does not hold one value per point, or if the points are
        collinear in both x and y dimensions (all points coincide).
    """
    array = np.array(xy)
    if array.ndim == 0 or array.shape[0] < _MIN_POINTS:  # Need at least two points to define a line
        raise ValueError("Not enough data points")
    if array.ndim != 2 or array.shape[1] != 2:
        raise ValueError("xy must be a sequence of (x, y) coordinate pairs")
    if len(z) != array.shape[0]:
        raise ValueError("z must contain exactly one value per (x, y) point")

    x_cl, y_cl = get_non_collinear(array)
    if x_cl is not None and y_cl is not None:
        return SurfaceExtrapolator2d(xy, z)

    # Data is constant in one dimension: use a 1D extrapolator along the other one.
    if x_cl is not None:
        return SurfaceExtrapolator1d(lambda x, y: x, x_cl, z)
    if y_cl is not None:
        return SurfaceExtrapolator1d(lambda x, y: y, y_cl, z)

    # Both dimensions are constant: every point coincides, so there is no line or surface.
    raise ValueError("Input points are collinear in both x and y dimensions, cannot create a valid surface extrapolator")