diff --git a/pyproject.toml b/pyproject.toml index 719e167..e846414 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,7 +7,7 @@ name = "source-modelling" authors = [{ name = "ucgmsim" }] description = "Source modelling library" readme = "README.md" -requires-python = ">=3.11,<3.14" +requires-python = ">=3.11" dynamic = ["version"] dependencies = [ "fiona", diff --git a/source_modelling/gc2_distances.py b/source_modelling/gc2_distances.py new file mode 100644 index 0000000..bedcfe2 --- /dev/null +++ b/source_modelling/gc2_distances.py @@ -0,0 +1,405 @@ +"""Implementation of gc2 distance metrics from NGA-West-3. + +All functions from this module are derived from the following paper: + +Spudich, P. A., & Chiou, B. (2015). Strike-parallel and strike-normal +coordinate system around geometrically complicated rupture traces: Use +by NGA-West2 and further improvements (No. 2015-1028). US Geological +Survey. + +All referenced pages and equations are in this paper. +""" + +import itertools + +import numpy as np +import shapely +from numba import float64, int64, njit, uint64 + + +def segment_rx_ry( + bounds: np.ndarray, points: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Calculate segment rx and ry distances. + + Parameters + ---------- + bounds : np.ndarray + Bounds for n segments, an array of compatible shape to (m, 2, + 2) = (num_segments, num_trace_points, x & y). + points : np.ndarray + Points to measure rx and ry. Has shape (n, 2). + + Returns + ------- + rx : np.ndarray + The Rx distance measure, an array of shape (m, n). + ry : np.ndarray + The Ry distance measure, an array of shape (m, n). + """ + bounds = np.atleast_3d(bounds).reshape((-1, 2, 2)) + points = np.atleast_2d(points) + + q = bounds[:, 0, :] + r = bounds[:, 1, :] + + qr = (r - q)[:, np.newaxis, :] + + qp = points[np.newaxis, :, :] - q[:, np.newaxis, :] + + dot_qp_qr = np.vecdot(qp, qr) + dot_qr_qr = np.vecdot(qr, qr) + + t = dot_qp_qr / dot_qr_qr + + closest_point = q[:, np.newaxis, :] + t[:, :, np.newaxis] * qr + + rx = np.linalg.norm(points[np.newaxis, :, :] - closest_point, axis=-1) + + sign = np.sign(qp[..., 1] * qr[..., 0] - qp[..., 0] * qr[..., 1]) + rx *= sign + + ry = t * np.sqrt(dot_qr_qr) + + return rx, ry + + +def segment_weights( + trace_lengths: np.ndarray, rx: np.ndarray, ry: np.ndarray +) -> np.ndarray: + """Calculate segment weights from trace lengths. + + Segment weights implement Equation (1) of Section 3.4. + + Parameters + ---------- + trace_lengths : np.ndarray + Length of each segment trace in strike order, an array of shape (m,). + rx : np.ndarray + The rx distances to each point (n points in total), an array of shape (m, n). + ry : np.ndarray + The ry distances to each point (n points in total), an array of shape (m, n). + + Returns + ------- + np.ndarray + An array of shape (m, n) containing distance weights for each segment point pair. + """ + trace_lengths = trace_lengths[:, np.newaxis] + + mask_zero = np.isclose(rx, 0.0) + + theta = np.arctan( + np.divide(trace_lengths - ry, rx, where=~mask_zero, out=np.zeros_like(rx)) + ) - np.arctan(np.divide(-ry, rx, where=~mask_zero, out=np.zeros_like(rx))) + + w = np.divide(theta, rx, where=~mask_zero, out=np.zeros_like(rx)) + + if np.any(mask_zero): + special_case = np.divide( + 1.0, + ry - trace_lengths, + where=(ry != trace_lengths), + out=np.full_like(w, np.nan), + ) - np.divide(1.0, ry, where=(ry != 0), out=np.full_like(w, np.nan)) + + w = np.where(mask_zero, special_case, w) + + return w + + +@njit([float64[:](float64[:], uint64[:]), float64[:](float64[:], int64[:])], cache=True) +def cumulative_reduction(data: np.ndarray, indices: np.ndarray) -> np.ndarray: + """Calculate sum between indices. + + Parameters + ---------- + data : np.ndarray + The data to sum. + indices : np.ndarray + The indices to sum between. + + Returns + ------- + np.ndarray + The equivalent to ``np.cumulative_sum(data[indices[i]:indices[i + 1]], include_initial=True)`` for each index i. + """ + out = np.zeros_like(data) + # indices defines the start of each trace. + # We append the total length to handle the last block. + for j in range(len(indices) - 1): + start = indices[j] + end = indices[j + 1] + + acc = 0.0 + for i in range(start, end): + out[i] = acc + acc += data[i] + return out + + +@njit( + [float64[:, :](float64[:, :], uint64[:]), float64[:, :](float64[:, :], int64[:])], + cache=True, +) +def diff_reduction(points: np.ndarray, indices: np.ndarray) -> np.ndarray: + """Calculate a diff reduction points between the given indices. + + Parameters + ---------- + points : np.ndarray + The points take a difference for. Differences are done in pairs so that ``diff[i] = points[2*i + 1] - points[2*i]``. + indices : np.ndarray + Indices of the boundary. Differences are not computed between boundaries. + + Returns + ------- + np.ndarray + The differences between pairs of elements in ``points`` considering the boundaries ``indices``. + """ + n_segments = len(points) // 2 + out = np.zeros((n_segments, points.shape[1])) + + write_ptr = 0 + for j in range(len(indices) - 1): + start = indices[j] + end = indices[j + 1] + for i in range(start, end, 2): + out[write_ptr] = points[i + 1] - points[i] + write_ptr += 1 + return out + + +def calculate_gc2_u_origins( + segment_lengths: np.ndarray, # List of lengths per trace j + segment_indices: np.ndarray, # Index of segments + trace_starts: np.ndarray, # p_1,j for each trace, shape (num_traces, 2) + p_origin: np.ndarray, # The back-end antipodal point + b_hat: np.ndarray, # Nominal strike unit vector +) -> np.ndarray: + """Calculates shifted origins for GC2 Rx/Ry calculations in multi-trace systems. + + Follows Equation (12), page 6 of the GC2 distance metric specification to transform + local segment distances into a globalised U-coordinate system. + + Parameters + ---------- + segment_lengths : np.ndarray + The lengths of every individual segment across all traces, + shape (m,). + segment_indices : np.ndarray + The indices marking the start of each trace within the flattened + segment array, shape (num_traces + 1,). This must include the + total number of segments as the final element. + trace_starts : np.ndarray + The (x, y) coordinates for the first point of each trace, + shape (num_traces, 2). + p_origin : np.ndarray + The global origin point, shape (2,). + b_hat : np.ndarray + The nominal strike unit vector for the entire system, shape (2,). + + Returns + ------- + np.ndarray + The cumulative U-coordinate distance at the start of each segment + relative to p_origin, shape (m,). + """ + local_shifts = cumulative_reduction(segment_lengths, segment_indices) + segment_counts = np.diff(segment_indices) + + segment_counts = segment_counts.astype(np.int64, casting="safe") + + global_shifts = np.repeat(np.dot(trace_starts - p_origin, b_hat), segment_counts) + return global_shifts + local_shifts + + +def generalised_t_u_coordinates( + trace_lengths: np.ndarray, + rx: np.ndarray, + ry: np.ndarray, + segment_u_origins: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Generalised rx and ry calculations for multiple segments. + + Implements Equations (3) and (9) of page 4. + + Parameters + ---------- + trace_lengths : np.ndarray + Lengths of each segment, an array of shape (m,). + rx : np.ndarray + The rx distances to each point (n points in total), an array of shape (m, n). + ry : np.ndarray + The ry distances to each point (n points in total), an array of shape (m, n). + segment_u_origins : np.ndarray + The U-coordinate origin shift for each segment, shape (m,). + Calculated via Equation 12. + + Returns + ------- + t : np.ndarray + The weighted average rx distance, an array of shape (m, n). + u : np.ndarray + The weighted average ry distance, an array of shape (m, n). + """ + w = segment_weights(trace_lengths, rx, ry) + + # Apply the Equation 12 shifts to ry + ry_global = ry + segment_u_origins[:, np.newaxis] + + # Calculate weighted averages + t = np.average(rx, axis=0, weights=w) + u = np.average(ry_global, axis=0, weights=w) + + return t, u + + +def antipodal_points(points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Calculate the farthest pair of points in a given set of points. + + Parameters + ---------- + points : np.ndarray + The points to search. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + The farthest pair of points in the set. + + Raises + ------ + ValueError + If ``len(points) == 1``. + """ + # Implementation following the observation that complex algorithms + # => many complex bugs. Therefore, only pay the cost of a complex + # implementation when the bugs are worth the speed increase. We + # could use the rotating calipers method here but it would only be + # marginally faster given the input size and we cannot implement + # it from well-tested libraries. + + if len(points) == 1: + raise ValueError("Cannot find antipodal pair with only one point.") + hull = shapely.convex_hull(shapely.multipoints(points)) + hull_points = shapely.get_coordinates(hull) + + # Shapely polygons repeat the first point at the end to close the + # exterior ring. Removing this doesn't change the result because + # the point is included twice, but it is faster. + if isinstance(hull, shapely.Polygon): + hull_points = hull_points[:-1] + + return max( + (pair for pair in itertools.combinations(hull_points, 2)), + key=lambda pair: np.square(pair[1] - pair[0]).sum(), + ) + + +def trial_strike_vector(trace_endpoints: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Calculate a trial strike vector between all trace endpoints. + + Implements the algorithm described in page 6. + + Parameters + ---------- + trace_endpoints : np.ndarray + The trace endpoints of the rupture. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + The pair (a, b) of farthest endpoints ordered such that a -> b is in the + east direction. + """ + a, b = antipodal_points(trace_endpoints) + # NZTM (y, x) coordinate system, b east of a ideally. + if b[1] < a[1]: + a, b = b, a + return a, b + + +def strike_corrected_directions( + trace_directions: np.ndarray, trial_unit_vector: np.ndarray +) -> np.ndarray: + """Correct strike directions by reversing discordant strikes. + + Implements the algorithm described in page 6. + + Parameters + ---------- + trace_directions : np.ndarray + The trace directions to correct. + trial_unit_vector : np.ndarray + The unit vector for directions. + + Returns + ------- + np.ndarray + Trace endpoints re-ordered so that they always point in or against the + direction of the unit vector. + """ + trace_directions = trace_directions.copy() + e = np.vecdot(trial_unit_vector, trace_directions) + e_sum = e.sum() + trace_directions[np.sign(e) != np.sign(e_sum)] *= -1 + return trace_directions + + +def multi_trace_rx_ry( + trace_points: np.ndarray, + trace_indices: np.ndarray, + rx: np.ndarray, + ry: np.ndarray, +) -> tuple[np.ndarray, np.ndarray]: + """Multi-trace rx and ry calculations. + + Parameters + ---------- + trace_points : np.ndarray + Trace points. + trace_indices : np.ndarray + Trace indices for end of segments + rx : np.ndarray + The rx values calculated to points. + ry : np.ndarray + The ry values calculated to points. + + Returns + ------- + tuple[np.ndarray, np.ndarray] + The (T, U) generalised coordinates for Rx and Ry calculations. + """ + direction_vectors = diff_reduction(trace_points, trace_indices) + segment_lengths = np.linalg.norm(direction_vectors, axis=1) + + directions_per_trace = np.diff(trace_indices) // 2 + directions_per_trace = directions_per_trace.astype(np.int64, casting="safe") + trace_direction_start_indices = np.cumulative_sum( + directions_per_trace, include_initial=True + ) + + trace_starts = trace_points[trace_indices[:-1]] + end_indices = trace_indices[1:] - 1 + trace_ends = trace_points[end_indices] + trace_endpoints = np.concatenate((trace_starts, trace_ends)) + + p_origin, p_end = trial_strike_vector(trace_endpoints) + trial_unit_vector = p_end - p_origin + trial_unit_vector /= np.linalg.norm(trial_unit_vector) + direction_vectors = strike_corrected_directions( + direction_vectors, trial_unit_vector + ) + b_hat = np.sum(direction_vectors, axis=0) + b_hat /= np.linalg.norm(b_hat) + + u_shift_origins = calculate_gc2_u_origins( + segment_lengths, + trace_direction_start_indices, + trace_starts, + p_origin, + b_hat, + ) + + return generalised_t_u_coordinates(segment_lengths, rx, ry, u_shift_origins) diff --git a/source_modelling/sources.py b/source_modelling/sources.py index 3af5e8d..de8f310 100644 --- a/source_modelling/sources.py +++ b/source_modelling/sources.py @@ -30,6 +30,7 @@ import shapely from qcore import coordinates, geo, grid +from source_modelling import gc2_distances _KM_TO_M = 1000 @@ -850,6 +851,18 @@ def rjb_distance(self, point: np.ndarray) -> float: shapely.Point(coordinates.wgs_depth_to_nztm(point)) ) + def rx_ry_distance(self, point: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + trace = self.bounds[:2, :2] + point = coordinates.wgs_depth_to_nztm(point)[..., :2] + rx, ry = gc2_distances.segment_rx_ry(trace, point) + return rx.squeeze(), ry.squeeze() + + def rx_distance(self, point: np.ndarray) -> np.ndarray: + return self.rx_ry_distance(point)[0] + + def ry_distance(self, point: np.ndarray) -> np.ndarray: + return self.rx_ry_distance(point)[1] + @dataclasses.dataclass class Fault: @@ -1330,6 +1343,61 @@ def rjb_distance(self, point: np.ndarray) -> float: shapely.Point(coordinates.wgs_depth_to_nztm(point)) ) + def rx_ry_distance(self, points: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + """Calculate the rx and ry distance between the fault and a given set of points + + Parameters + ---------- + points : np.ndarray + Points to calculate distance to, has shape (n, 2). + + Returns + ------- + rx : np.ndarray + The generalised rx distance (in metres) between the faults and the points. Has shape (n,) + ry : np.ndarray + The generalised ry distance (in metres) between the faults and the points. Has shape (n,) + """ + trace = self.trace[:, :2].reshape((-1, 2, 2)) + points = coordinates.wgs_depth_to_nztm(points)[..., :2] + + rx, ry = gc2_distances.segment_rx_ry(trace, points) + p_start = trace[:, 0, :] + p_end = trace[:, 1, :] + trace_lengths = np.linalg.norm(p_end - p_start, axis=-1) + origins = np.cumulative_sum(trace_lengths[:-1], include_initial=True) + t, u = gc2_distances.generalised_t_u_coordinates(trace_lengths, rx, ry, origins) + return t.squeeze(), u.squeeze() + + +def multi_fault_rx_ry_distance( + faults: list[Fault | Plane], points: np.ndarray +) -> tuple[np.ndarray, np.ndarray]: + """Calculate the rx-ry distance between a set of (possibly disconnected) faults and a set of points. + + Parameters + ---------- + faults : list[Fault | Plane] + Faults to calculate distances from. + points : np.ndarray + Points to calculate to, has shape (n, 2). + + Returns + ------- + rx : np.ndarray + The generalised rx distance (in metres) between the faults and the points. Has shape (n,) + ry : np.ndarray + The generalised ry distance (in metres) between the faults and the points. Has shape (n,) + """ + points = coordinates.wgs_depth_to_nztm(points[:, :2]) + traces = [fault.trace[:, :2] for fault in faults] + trace_points = np.concatenate(traces, axis=0) + trace_indices = np.cumulative_sum( + [len(trace) for trace in traces], include_initial=True + ) + rx, ry = gc2_distances.segment_rx_ry(trace_points, points) + return gc2_distances.multi_trace_rx_ry(trace_points, trace_indices, rx, ry) + IsSource = Plane | Fault | Point diff --git a/tests/test_rx_ry.py b/tests/test_rx_ry.py new file mode 100644 index 0000000..4b8fedc --- /dev/null +++ b/tests/test_rx_ry.py @@ -0,0 +1,363 @@ +import numpy as np +import pytest +import scipy as sp +from hypothesis import assume, given +from hypothesis import strategies as st +from hypothesis.extra.numpy import arrays + +from source_modelling import gc2_distances, sources + + +def test_segment_rx_ry_output_shapes(): + m, n = 5, 10 + bounds = np.random.rand(m, 2, 2) + points = np.random.rand(n, 2) + + rx, ry = gc2_distances.segment_rx_ry(bounds, points) + + assert rx.shape == (m, n), f"Expected rx shape ({m}, {n}), got {rx.shape}" + assert ry.shape == (m, n), f"Expected ry shape ({m}, {n}), got {ry.shape}" + + +@given( + bounds=arrays( + np.float64, + (1, 2, 2), + elements=st.floats(-1e3, 1e3, allow_nan=False, allow_infinity=False), + ), + t=st.floats(min_value=-100, max_value=100), + u=st.floats(min_value=-100, max_value=100), +) +def test_rx_ry(bounds, t, u): + q = bounds[0, 0, :] + r = bounds[0, 1, :] + qr = r - q + segment_length = np.linalg.norm(qr) + + assume(segment_length > 1e-7) + + norm = np.array([-qr[1], qr[0]]) / segment_length + + point = q + t * qr + u * norm + + rx_arr, ry_arr = gc2_distances.segment_rx_ry(bounds, point) + rx = rx_arr.item() + ry = ry_arr.item() + + assert rx == pytest.approx(u, abs=1e-5, rel=1e-5) + + expected_ry = t * segment_length + assert ry == pytest.approx(expected_ry, abs=1e-5, rel=1e-5) + + +@given( + lengths=st.floats(min_value=0.1, max_value=1000.0), + # Set rx != 0 to avoid the edge case where equation 5 applies. + rx=st.one_of( + st.floats(min_value=0.1, max_value=500.0), + st.floats(min_value=-500, max_value=-0.1), + ), + # ry: longitudinal distance from the start of the segment + ry=st.floats(min_value=-500.0, max_value=1500.0), +) +def test_segment_weights_integral_match(lengths, rx, ry) -> None: + trace_lengths = np.array([lengths]) + rx_arr = np.array([[rx]]) + ry_arr = np.array([[ry]]) + + def integrand(u: float) -> float: + dist_sq = rx**2 + (u - ry) ** 2 + return 1.0 / dist_sq + + expected_w, _ = sp.integrate.quad(integrand, 0, lengths) + + actual_w_arr = gc2_distances.segment_weights(trace_lengths, rx_arr, ry_arr) + actual_w = actual_w_arr.item() + + assert actual_w == pytest.approx(expected_w, rel=1e-4, abs=1e-8) + + +@given( + lengths=st.floats(min_value=0.1, max_value=1000.0), + ry=st.floats(min_value=-500.0, max_value=1500.0), +) +def test_segment_weights_at_zero(lengths, ry): + assume(not np.isclose(ry, 0.0) and not np.isclose(ry, lengths)) + trace_lengths = np.array([lengths]) + rx = 0.0 + rx_arr = np.array([[rx]]) + ry_arr = np.array([[ry]]) + + expected_w = 1 / (ry - lengths) - (1.0 / ry) + + actual_w_arr = gc2_distances.segment_weights(trace_lengths, rx_arr, ry_arr) + actual_w = actual_w_arr.item() + + assert actual_w == pytest.approx(expected_w, rel=1e-4, abs=1e-8) + + +def test_rx_ry_plane() -> None: + """Test that strike order is correctly interpreted in plane context""" + plane = sources.Plane.from_centroid_strike_dip( + np.array([-43.538, 172.6474, 5.0]), 90.0, 10.0, width=5.0, strike=0.0 + ) + # Conveniently, strike = 0 implies that west = left and east = right + point_left = np.array([-43.538, 172.6174]) + # meanwhile, south of the bottom edge implies negative ry. + point_right = np.array([-43.60, 172.6574]) + rx_left, ry_left = plane.rx_ry_distance(point_left) + rx_right, ry_right = plane.rx_ry_distance(point_right) + + assert rx_left < 0 + assert rx_right > 0 + + assert ry_left > 0 + assert ry_right < 0 + + +def test_rx_ry_fault() -> None: + """Test that strike order is correctly interpreted in fault context""" + + # Deliberately setup as an extension of the plane case with the + # fault striking 0 degrees initially and then tending 15 degrees + # east. + fault = sources.Fault.from_trace_points( + np.array( + [ + [-43.58302058, 172.64739978], + [-43.49297907, 172.6474], + [-43.44832311516278, 172.7160675056574], + ] + ), + dtop=0.0, + dbottom=10.0, + dip=90.0, + dip_dir=90.0, + ) + + # Conveniently, strike = 0 implies that west = left and east = right + point_left = np.array([-43.538, 172.6174]) + # meanwhile, south of the bottom edge implies negative ry. + point_right = np.array([-43.60, 172.6574]) + rx_left, ry_left = fault.rx_ry_distance(point_left) + rx_right, ry_right = fault.rx_ry_distance(point_right) + + assert rx_left < 0 + assert rx_right > 0 + + assert ry_left > 0 + assert ry_right < 0 + + +def test_rx_ry_fault_single_plane() -> None: + """Test that fault rx, ry == plane rx, ry in the single plane fault case""" + plane = sources.Plane.from_centroid_strike_dip( + np.array([-43.538, 172.6474, 5.0]), 90.0, 10.0, width=5.0, strike=0.0 + ) + fault = sources.Fault([plane]) + + point = np.array([-43.538, 172.6174]) + + rx_plane, ry_plane = plane.rx_ry_distance(point) + rx_fault, ry_fault = fault.rx_ry_distance(point) + + assert rx_plane == pytest.approx(rx_fault) + assert ry_plane == pytest.approx(ry_fault) + + +def test_antipodal_points_simple_square(): + """Tests a unit square; antipodal points should be diagonal corners.""" + points = np.array([[0, 0], [1, 0], [1, 1], [0, 1]]) + p1, p2 = gc2_distances.antipodal_points(points) + + distance = np.linalg.norm(p1 - p2) + # Diagonal of a 1x1 square is sqrt(2) + assert distance == pytest.approx(np.sqrt(2)) + + +def test_antipodal_points_with_interior_points(): + """Tests that points inside the hull don't affect the result.""" + # A triangle with a point right in the middle + points = np.array([[0, 0], [10, 0], [5, 10], [5, 5]]) + p1, p2 = gc2_distances.antipodal_points(points) + + distance = np.linalg.norm(p1 - p2) + # The furthest distance is between (0,0) and (5,10) or (10,0) and (5,10) + # both are sqrt(125) ~ 11.18 + assert distance == pytest.approx(np.sqrt(125)) + + +def test_antipodal_points_collinear_points(): + """Tests points on a straight line.""" + points = np.array([[0, 0], [1, 1], [2, 2], [5, 5]]) + p1, p2 = gc2_distances.antipodal_points(points) + + # Furthest points are the ends of the line + assert np.linalg.norm(p1 - p2) == pytest.approx(np.sqrt(50)) + + +def test_antipodal_points_single_point_error(): + """Tests that a single point raises an error (as combinations requires 2).""" + points = np.array([[1, 1]]) + with pytest.raises(ValueError): + gc2_distances.antipodal_points(points) + + +@given( + data=arrays(np.float64, st.integers(2, 20), elements=st.floats(0.1, 100)), +) +def test_cumulative_reduction_resets_per_trace(data) -> None: + # Split data into two traces at a random point + split = len(data) // 2 + indices = np.array([0, split, len(data)], dtype=np.uint64) + + out = gc2_distances.cumulative_reduction(data, indices) + + # Invariant: Each trace starts at 0.0 + assert out[0] == 0.0 + assert out[split] == 0.0 + # Invariant: It is a prefix sum within the trace + assert out[split - 1] == pytest.approx(np.sum(data[0 : split - 1])) + assert out[-1] == pytest.approx(np.sum(data[split:-1])) + + +def test_diff_reduction_isolates_traces() -> None: + # 5 points, 2 traces: [P0, P1, P2] and [P3, P4] + points = np.array( + [ + [0, 0], + [1, 0], + [1, 0], + [2, 0], # Trace 1 + [10, 10], + [11, 11], # Trace 2 + ], + dtype=np.float64, + ) + indices = np.array([0, 4, len(points)], dtype=np.uint64) + + # Expected segments: (P1-P0), (P2-P1), (P4-P3) + expected = np.array([[1, 0], [1, 0], [1, 1]]) + actual = gc2_distances.diff_reduction(points, indices) + + assert actual == pytest.approx(expected) + + +@given( + points=arrays( + np.float64, + shape=st.tuples(st.integers(2, 10), st.just(2)), + elements=st.floats(-1e5, 1e5), + ) +) +def test_trial_strike_vector_invariants(points: np.ndarray) -> None: + assume(len(np.unique(points, axis=0)) == len(points)) + # Antipodal distance should be max distance in set + a, b = gc2_distances.trial_strike_vector(points) + dist_sq = np.sum((a - b) ** 2) + + # Check against a few random pairs + for _ in range(20): + p1, p2 = points[np.random.choice(len(points), 2)] + assert dist_sq >= np.sum((p1 - p2) ** 2) - 1e-7 + + # Invariant: Canonical orientation (NZTM x-axis/Easting is index 1) + assert b[1] >= a[1] + + +def test_strike_corrected_directions_flips_appropriately() -> None: + trial_unit = np.array([1.0, 0.0]) # East + # Trace 1: East, Trace 2: Slightly North East, Trace 3: West (Opposing) + dirs = np.array([[1.0, 0], [1.0, 0.2], [-1.0, 0]]) + + # Invariant: Majority direction wins. West should flip to East. + corrected = gc2_distances.strike_corrected_directions(dirs, trial_unit) + + assert corrected[2, 0] > 0 # The third segment should have flipped + # Invariant: Dot product sum of corrected dirs with trial vector must be positive + assert np.sum(np.vecdot(corrected, trial_unit)) > 0 + + +def test_calculate_gc2_u_origins_logic(): + # Setup: 2 traces. Trace 1 is 10 units long. Trace 2 starts 50 units East. + seg_lengths = np.array([5.0, 5.0, 3.0, 3.0]) + seg_indices = np.array([0, 2, 4], dtype=np.uint64) + trace_starts = np.array([[0.0, 0.0], [0.0, 50.0]]) + p_origin = np.array([0.0, 0.0]) + b_hat = np.array([0.0, 1.0]) # Strike is East (x-axis) + + # Trace 1 global: 0.0. Local: [0, 5] + # Trace 2 global: 50.0. Local: [0, 3] + expected = np.array([0.0, 5.0, 50.0, 53.0]) + + actual = gc2_distances.calculate_gc2_u_origins( + seg_lengths, seg_indices, trace_starts, p_origin, b_hat + ) + + assert actual == pytest.approx(expected) + + +def test_multi_trace_rx_ry_origin_alignment() -> None: + """ + Test that origin shifts correctly align colinear traces. + + Setup: + Trace 1: (0, 0) to (10, 0) [Length 10] + Gap: (10, 0) to (20, 0) [Length 10] + Trace 2: (20, 0) to (30, 0) [Length 10] + + Observation Point P: (25, 5) + """ + # 1. Define geometry + trace_points = np.array( + [ + [0.0, 0.0], + [10.0, 0.0], # Trace 1 (2 points) + [20.0, 0.0], + [30.0, 0.0], # Trace 2 (2 points) + ] + ) + trace_indices = np.array([0, 2, 4], dtype=np.uint64) + + # 2. Mock rx and ry for point P(25, 5) relative to each segment + # Segment 1 (0,0 -> 10,0): + # P is 25 units 'along' the infinite line (ry=25) + # P is 5 units 'away' (rx=5) + # Segment 2 (20,0 -> 30,0): + # P is 5 units 'along' the start of this segment (ry=5) + # P is 5 units 'away' (rx=5) + rx = np.array([[5.0], [5.0]]) # Shape (m=2 segments, n=1 point) + ry = np.array([[25.0], [5.0]]) + + # 3. Calculate + t, u = gc2_distances.multi_trace_rx_ry(trace_points, trace_indices, rx, ry) + + # 4. Assertions + # T (weighted rx) should remain 5.0 + assert t.item() == pytest.approx(5.0) + + # U (weighted global ry) calculation check: + # Trace 1 origin shift: 0 (start) + 0 (global) = 0. Global U = 25 + 0 = 25. + # Trace 2 origin shift: 0 (start) + 20 (global) = 20. Global U = 5 + 20 = 25. + # Since both segments agree the global U is 25, the average must be 25. + assert u.item() == pytest.approx(25.0) + + +def test_multi_trace_rx_ry_local_shift_accumulation() -> None: + """Test that multiple segments within a single trace accumulate local shifts.""" + # Single trace with two segments: (0,0) -> (5,0) -> (15,0) + trace_points = np.array([[0.0, 0.0], [5.0, 0.0], [5.0, 0.0], [15.0, 0.0]]) + trace_indices = np.array([0, 2, 4], dtype=np.uint64) + + # Point P(10, 2) + # Seg 1: ry = 10, rx = 2 + # Seg 2: ry = 5, rx = 2 (10 units from start of Seg 2, which is at x=5) + rx = np.array([[2.0], [2.0]]) + ry = np.array([[10.0], [5.0]]) + + _, u = gc2_distances.multi_trace_rx_ry(trace_points, trace_indices, rx, ry) + + # Local shift for Seg 2 should be length of Seg 1 (5.0) + # Seg 1 Global U: 10 + 0 = 10 + # Seg 2 Global U: 5 + 5 = 10 + assert u.item() == pytest.approx(10.0)