diff --git a/bilby/core/series.py b/bilby/core/series.py index ba1d0ffcb..e607b5f57 100644 --- a/bilby/core/series.py +++ b/bilby/core/series.py @@ -128,3 +128,7 @@ def start_time(self): def start_time(self, start_time): self._start_time = start_time self._time_array_updated = False + + @property + def end_time(self): + return self.start_time + self.duration diff --git a/bilby/gw/detector/interferometer.py b/bilby/gw/detector/interferometer.py index 9e9c23bdf..ae619fa63 100644 --- a/bilby/gw/detector/interferometer.py +++ b/bilby/gw/detector/interferometer.py @@ -43,6 +43,10 @@ class Interferometer(object): minimum_frequency = PropertyAccessor('strain_data', 'minimum_frequency') maximum_frequency = PropertyAccessor('strain_data', 'maximum_frequency') frequency_mask = PropertyAccessor('strain_data', 'frequency_mask') + time_mask = PropertyAccessor('strain_data', 'time_mask') + crop_duration = PropertyAccessor('strain_data', 'crop_duration') + cropped_duration = PropertyAccessor('strain_data', 'cropped_duration') + cropped_frequency_mask = PropertyAccessor('strain_data', 'cropped_frequency_mask') frequency_domain_strain = PropertyAccessor('strain_data', 'frequency_domain_strain') time_domain_strain = PropertyAccessor('strain_data', 'time_domain_strain') @@ -597,12 +601,9 @@ def optimal_snr_squared(self, signal): Returns ======= - float: The optimal signal to noise ratio possible squared + float: The optimal signal-to-noise ratio squared of the signal """ - return gwutils.optimal_snr_squared( - signal=signal[self.strain_data.frequency_mask], - power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + return (abs(self.whiten_frequency_series(signal))**2).sum() def inner_product(self, signal): """ @@ -614,13 +615,13 @@ def inner_product(self, signal): Returns ======= - float: The optimal signal to noise ratio possible squared + float: + The noise-weighted inner product between the passed signal + and the data stored in the :code:`Interferometer`. """ - return gwutils.noise_weighted_inner_product( - aa=signal[self.strain_data.frequency_mask], - bb=self.strain_data.frequency_domain_strain[self.strain_data.frequency_mask], - power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + whitened_signal = self.whiten_frequency_series(signal) + whitened_data = self.whitened_frequency_domain_strain + return (whitened_signal.T * whitened_data.conj()).sum() def template_template_inner_product(self, signal_1, signal_2): """A noise weighted inner product between two templates, using this ifo's PSD. @@ -636,11 +637,9 @@ def template_template_inner_product(self, signal_1, signal_2): ======= float: The noise weighted inner product of the two templates """ - return gwutils.noise_weighted_inner_product( - aa=signal_1[self.strain_data.frequency_mask], - bb=signal_2[self.strain_data.frequency_mask], - power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + whitened_1 = self.whiten_frequency_series(signal_1) + whitened_2 = self.whiten_frequency_series(signal_2) + return (whitened_1 * whitened_2.conj()).sum() def matched_filter_snr(self, signal): """ @@ -655,13 +654,9 @@ def matched_filter_snr(self, signal): complex: The matched filter signal to noise ratio """ - return gwutils.matched_filter_snr( - signal=signal[self.strain_data.frequency_mask], - frequency_domain_strain=self.strain_data.frequency_domain_strain[self.strain_data.frequency_mask], - power_spectral_density=self.power_spectral_density_array[self.strain_data.frequency_mask], - duration=self.strain_data.duration) + return self.inner_product(signal) / self.optimal_snr_squared(signal)**0.5 - def whiten_frequency_series(self, frequency_series : np.array) -> np.array: + def whiten_frequency_series(self, frequency_series: np.array) -> np.array: """Whitens a frequency series with the noise properties of the detector .. math:: @@ -680,7 +675,21 @@ def whiten_frequency_series(self, frequency_series : np.array) -> np.array: frequency_series : np.array The frequency series, whitened by the ASD """ - return frequency_series / (self.amplitude_spectral_density_array * np.sqrt(self.duration / 4)) + if self.crop_duration == 0: + return gwutils.frequency_domain_whiten( + frequency_series=frequency_series, + amplitude_spectral_density=self.amplitude_spectral_density_array, + frequency_mask=self.frequency_mask, + duration=self.duration, + ) + else: + return gwutils.whiten_and_crop( + frequency_series=frequency_series, + amplitude_spectral_density=self.amplitude_spectral_density_array, + frequency_mask=self.frequency_mask, + time_mask=self.time_mask, + duration=self.duration, + ) def get_whitened_time_series_from_whitened_frequency_series( self, @@ -718,7 +727,9 @@ def get_whitened_time_series_from_whitened_frequency_series( whitened_time_series = ( np.fft.irfft(whitened_frequency_series) - * np.sqrt(np.sum(self.frequency_mask)) / frequency_window_factor + * np.sqrt(np.sum(self.frequency_mask)) + / frequency_window_factor + * self.time_mask.mean()**0.5 ) return whitened_time_series diff --git a/bilby/gw/detector/strain_data.py b/bilby/gw/detector/strain_data.py index bca7acced..9ce05df0b 100644 --- a/bilby/gw/detector/strain_data.py +++ b/bilby/gw/detector/strain_data.py @@ -3,6 +3,7 @@ from ...core import utils from ...core.series import CoupledTimeAndFrequencySeries from ...core.utils import logger, PropertyAccessor +from ...core.utils.series import create_frequency_series from .. import utils as gwutils @@ -12,11 +13,12 @@ class InterferometerStrainData(object): duration = PropertyAccessor('_times_and_frequencies', 'duration') sampling_frequency = PropertyAccessor('_times_and_frequencies', 'sampling_frequency') start_time = PropertyAccessor('_times_and_frequencies', 'start_time') + end_time = PropertyAccessor('_times_and_frequencies', 'end_time') frequency_array = PropertyAccessor('_times_and_frequencies', 'frequency_array') time_array = PropertyAccessor('_times_and_frequencies', 'time_array') def __init__(self, minimum_frequency=0, maximum_frequency=np.inf, - roll_off=0.2, notch_list=None): + roll_off=0.2, notch_list=None, crop_duration=0): """ Initiate an InterferometerStrainData object The initialised object contains no data, this should be added using one @@ -33,6 +35,11 @@ def __init__(self, minimum_frequency=0, maximum_frequency=np.inf, This corresponds to alpha * duration / 2 for scipy tukey window. notch_list: bilby.gw.detector.strain_data.NotchList A list of notches + crop_duration: float | tuple + The duration of data to crop at the beginning/end of the segment + to avoid whitening artifacts. If a float, that duration is excluded + at each end, if a tuple, this specifies the truncation duration + at the beginning and end. """ @@ -41,11 +48,14 @@ def __init__(self, minimum_frequency=0, maximum_frequency=np.inf, self.notch_list = notch_list self.roll_off = roll_off self.window_factor = 1 + self._crop_duration = crop_duration self._times_and_frequencies = CoupledTimeAndFrequencySeries() self._frequency_mask_updated = False self._frequency_mask = None + self._time_mask_updated = False + self._time_mask = None self._frequency_domain_strain = None self._time_domain_strain = None self._channel = None @@ -135,6 +145,33 @@ def notch_list(self, notch_list): raise ValueError("notch_list {} not understood".format(notch_list)) self._frequency_mask_updated = False + @property + def crop_duration(self): + """ + The duration of data to crop at the beginning/end of the segment + to avoid conditioning artifacts. If a float, that duration is + excluded at each end, if a tuple, this specifies the truncation + duration at the beginning and end. + """ + return self._crop_duration + + @crop_duration.setter + def crop_duration(self, crop_duration): + if not isinstance(self.crop_duration, (float, int, list, tuple)): + raise TypeError(f"Invalid crop specification {self.crop_duration}") + self._crop_duration = crop_duration + self._time_mask_updated = False + + @property + def cropped_duration(self): + """ + The duration after applying the time-domain mask. + """ + if isinstance(self.crop_duration, (float, int)): + return self.duration - 2 * self.crop_duration + else: + return self.duration - sum(self.crop_duration[:2]) + @property def frequency_mask(self): """ Masking array for limiting the frequency band. @@ -145,20 +182,69 @@ def frequency_mask(self): An array of boolean values """ if not self._frequency_mask_updated: - frequency_array = self._times_and_frequencies.frequency_array + self._update_frequency_mask() + return self._frequency_mask + + def _update_frequency_mask(self): + def calculate_frequency_mask(frequency_array): mask = ((frequency_array >= self.minimum_frequency) & (frequency_array <= self.maximum_frequency)) for notch in self.notch_list: mask[notch.get_idxs(frequency_array)] = False - self._frequency_mask = mask - self._frequency_mask_updated = True - return self._frequency_mask + return mask + + self._frequency_mask = calculate_frequency_mask( + self._times_and_frequencies.frequency_array + ) + + cropped_frequencies = create_frequency_series( + duration=self.cropped_duration, + sampling_frequency=self.sampling_frequency + ) + self._cropped_frequency_mask = calculate_frequency_mask( + cropped_frequencies + ) + + self._frequency_mask_updated = True @frequency_mask.setter def frequency_mask(self, mask): self._frequency_mask = mask self._frequency_mask_updated = True + @property + def cropped_frequency_mask(self): + if not self._frequency_mask_updated: + self._update_frequency_mask + return self._cropped_frequency_mask + + @property + def time_mask(self): + """ Masking array for cropping corrupted data at the edges. + + Returns + ======= + mask: np.ndarray + An array of boolean values + """ + if not self._time_mask_updated: + if isinstance(self.crop_duration, (tuple, list)): + crop_start, crop_end = self.crop_duration + elif isinstance(self.crop_duration, (float, int)): + crop_start = crop_end = self.crop_duration + + time_array = self._times_and_frequencies.time_array + mask = ((time_array > self.start_time + crop_start) & + (time_array <= self.end_time - crop_end)) + self._time_mask = mask + self._time_mask_updated = True + return self._time_mask + + @time_mask.setter + def time_mask(self, mask): + self._time_mask = mask + self._time_mask_updated = True + @property def alpha(self): return 2 * self.roll_off / self.duration diff --git a/bilby/gw/likelihood/base.py b/bilby/gw/likelihood/base.py index 8dfbcdbf5..39121527e 100644 --- a/bilby/gw/likelihood/base.py +++ b/bilby/gw/likelihood/base.py @@ -11,7 +11,7 @@ from ...core.prior import Interped, Prior, Uniform, DeltaFunction from ..detector import InterferometerList, get_empty_interferometer, calibration from ..prior import BBHPriorDict, Cosmological -from ..utils import noise_weighted_inner_product, zenith_azimuth_to_ra_dec, ln_i0 +from ..utils import zenith_azimuth_to_ra_dec, ln_i0 class GravitationalWaveTransient(Likelihood): @@ -277,63 +277,55 @@ def calculate_snrs(self, waveform_polarizations, interferometer, *, return_array interferometer=interferometer, parameters=parameters, ) - _mask = interferometer.frequency_mask if 'recalib_index' in parameters: - signal[_mask] *= self.calibration_draws[interferometer.name][int(parameters['recalib_index'])] + signal *= self.calibration_draws[interferometer.name][int(parameters['recalib_index'])] - d_inner_h = interferometer.inner_product(signal=signal) - optimal_snr_squared = interferometer.optimal_snr_squared(signal=signal) + whitened_signal = interferometer.whiten_frequency_series(signal) + + d_inner_h = (interferometer.whitened_frequency_domain_strain.conjugate() * whitened_signal).sum() + optimal_snr_squared = (abs(whitened_signal)**2).sum() complex_matched_filter_snr = d_inner_h / (optimal_snr_squared**0.5) d_inner_h_array = None optimal_snr_squared_array = None - normalization = 4 / self.waveform_generator.duration - if return_array is False: d_inner_h_array = None optimal_snr_squared_array = None elif self.time_marginalization and self.calibration_marginalization: d_inner_h_integrand = np.tile( - interferometer.frequency_domain_strain.conjugate() * signal / - interferometer.power_spectral_density_array, (self.number_of_response_curves, 1)).T + interferometer.whitened_frequency_domain_strain.conjugate() * whitened_signal, + (self.number_of_response_curves, 1) + ).T - d_inner_h_integrand[_mask] *= self.calibration_draws[interferometer.name].T + d_inner_h_integrand[interferometer.frequency_mask] *= self.calibration_draws[interferometer.name].T - d_inner_h_array = 4 / self.waveform_generator.duration * np.fft.fft( - d_inner_h_integrand[0:-1], axis=0 - ).T + d_inner_h_array = np.fft.fft(d_inner_h_integrand[:-1], axis=0).T - optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array - ) + optimal_snr_squared_integrand = np.abs(whitened_signal)**2 optimal_snr_squared_array = np.dot( - optimal_snr_squared_integrand[_mask], + optimal_snr_squared_integrand[interferometer.frequency_mask], self.calibration_abs_draws[interferometer.name].T ) elif self.time_marginalization and not self.calibration_marginalization: - d_inner_h_array = normalization * np.fft.fft( - signal[0:-1] - * interferometer.frequency_domain_strain.conjugate()[0:-1] - / interferometer.power_spectral_density_array[0:-1] + d_inner_h_integrand = ( + whitened_signal + * interferometer.whitened_frequency_domain_strain.conjugate() ) + d_inner_h_array = np.fft.fft(d_inner_h_integrand[:-1]) elif self.calibration_marginalization and ('recalib_index' not in parameters): d_inner_h_integrand = ( - normalization * - interferometer.frequency_domain_strain.conjugate() * signal - / interferometer.power_spectral_density_array - ) - d_inner_h_array = np.dot(d_inner_h_integrand[_mask], self.calibration_draws[interferometer.name].T) + interferometer.whitened_frequency_domain_strain.conjugate() * whitened_signal + )[interferometer.frequency_mask] + d_inner_h_array = np.dot(d_inner_h_integrand, self.calibration_draws[interferometer.name].T) - optimal_snr_squared_integrand = ( - normalization * np.abs(signal)**2 / interferometer.power_spectral_density_array - ) + optimal_snr_squared_integrand = np.abs(whitened_signal)**2 optimal_snr_squared_array = np.dot( - optimal_snr_squared_integrand[_mask], + optimal_snr_squared_integrand[interferometer.frequency_mask], self.calibration_abs_draws[interferometer.name].T ) @@ -390,12 +382,9 @@ def priors(self, priors): def _calculate_noise_log_likelihood(self): log_l = 0 for interferometer in self.interferometers: - mask = interferometer.frequency_mask - log_l -= noise_weighted_inner_product( - interferometer.frequency_domain_strain[mask], - interferometer.frequency_domain_strain[mask], - interferometer.power_spectral_density_array[mask], - self.waveform_generator.duration) / 2 + log_l -= ( + abs(interferometer.whitened_frequency_domain_strain)**2 + ).sum() / 2 return float(np.real(log_l)) def noise_log_likelihood(self): diff --git a/bilby/gw/likelihood/basic.py b/bilby/gw/likelihood/basic.py index da67481f0..e04f51b0d 100644 --- a/bilby/gw/likelihood/basic.py +++ b/bilby/gw/likelihood/basic.py @@ -43,9 +43,9 @@ def noise_log_likelihood(self): """ log_l = 0 for interferometer in self.interferometers: - log_l -= 2. / self.waveform_generator.duration * np.sum( - abs(interferometer.frequency_domain_strain) ** 2 / - interferometer.power_spectral_density_array) + log_l -= ( + abs(interferometer.whitened_frequency_domain_strain)**2 + ).sum() / 2 return log_l.real def log_likelihood(self, parameters): @@ -85,8 +85,8 @@ def log_likelihood_interferometer(self, waveform_polarizations, signal_ifo = interferometer.get_detector_response( waveform_polarizations, parameters) - log_l = - 2. / self.waveform_generator.duration * np.vdot( - interferometer.frequency_domain_strain - signal_ifo, - (interferometer.frequency_domain_strain - signal_ifo) / - interferometer.power_spectral_density_array) + residual = interferometer.frequency_domain_strain - signal_ifo + white_residual = interferometer.whiten_frequency_series(residual) + + log_l = - (abs(white_residual)**2).sum() / 2 return log_l.real diff --git a/bilby/gw/utils.py b/bilby/gw/utils.py index 420e1fc04..33738b6e5 100644 --- a/bilby/gw/utils.py +++ b/bilby/gw/utils.py @@ -228,6 +228,104 @@ def overlap(signal_a, signal_b, power_spectral_density=None, delta_frequency=Non return sum(integral).real +def frequency_domain_whiten(frequency_series, amplitude_spectral_density, frequency_mask, duration): + """ + Whitens a frequency series with the provided asd in the frequency domain + + .. math:: + \\tilde{a}_w(f) = \\tilde{a}(f) \\sqrt{\\frac{4}{T S_n(f)}} + + Such that + + .. math:: + Var(n) = \\frac{1}{N} \\sum_{k=0}^N n_W(f_k)n_W^*(f_k) = 2 + + Where the factor of two is due to the independent real and imaginary + components. + + Parameters + ========== + frequency_series : np.ndarray + The frequency series to whiten + amplitude_spectral_density: np.ndarray + The amplitude spectral density array to use for whitening + frequency_mask: np.ndarray + A boolean mask to apply to the whitened strain + duration: float + The effective duration of the passed frequency series + + Returns + ======= + np.ndarray + The whitened frequency series + """ + whitened = frequency_series / amplitude_spectral_density + return np.nan_to_num(whitened) * frequency_mask * (4 / duration)**0.5 + + +def whiten_and_crop(frequency_series, amplitude_spectral_density, frequency_mask, time_mask, duration): + """ + Whitens a frequency series with the noise properties and applies + the time mask to the whitened time-domain strain[1]_. + + First, we naively whiten the data in the frequency domain and apply + our frequency mask :math:`\\tilde{m}(f)` + + .. math:: + \\tilde{a}_w(f) = \\tilde{a}(f) \\tilde{m}(f) / \\sqrt{S_n(f)}. + + We then inverse discrete Fourier transform to get the whitened + time-domain strain :math:`w(t)` and apply the time-domain mask + :math:`m(t)`. We then perform a final discrete Fourier transform. + + .. math:: + \\bar{a}_{w}(f) = \\mathcal{F}(\\mathcal{iF}(\\tilde{a}_w(f))[m(t)]) + + Finally, we normalize the whitened strain by a factor :math:`N` + + .. math:: + N = \\sqrt{\\frac{4}{T_{c}}} + + where :math:`T_{c}` is the cropped duration such that + + .. math:: + Var(n) = \\frac{1}{N} \\sum_{k=0}^N n_W(f_k)n_W^*(f_k) = 2. + + Where the factor of two is due to the independent real and imaginary + components. + + + Parameters + ========== + frequency_series : np.ndarray + The frequency series to whiten + amplitude_spectral_density: np.ndarray + The amplitude spectral density array to use for whitening + frequency_mask: np.ndarray + A boolean mask to apply to the whitened strain + time_mask: np.ndarray + A boolean mask used to crop the whitened time series + duration: float + The effective duration of the passed frequency series + + Returns + ======= + np.ndarray + The whitened frequency series + + .. [1] C. Talbot et al. 2025 + `Class. Quantum Grav. 42 235023 `_ + """ + whitened = frequency_series / amplitude_spectral_density + initial_white = np.nan_to_num(whitened) * frequency_mask + time_series = np.fft.irfft(initial_white) + cropped_time_series = time_series[time_mask] + cropped_white = np.fft.rfft(cropped_time_series) + cropped_duration = duration * len(cropped_time_series) / len(time_series) + + return cropped_white * (4 / cropped_duration)**0.5 + + def zenith_azimuth_to_theta_phi(zenith, azimuth, ifos): """ Convert from the 'detector frame' to the Earth frame. diff --git a/test/core/series_test.py b/test/core/series_test.py index bf1b19c43..abea7e646 100644 --- a/test/core/series_test.py +++ b/test/core/series_test.py @@ -42,6 +42,10 @@ def test_sampling_from_init(self): def test_start_time_from_init(self): self.assertEqual(self.start_time, self.series.start_time) + def test_end_time(self): + expected = self.start_time + self.duration + self.assertEqual(expected, self.series.end_time) + def test_frequency_array_type(self): self.assertIsInstance(self.series.frequency_array, np.ndarray) diff --git a/test/gw/detector/interferometer_test.py b/test/gw/detector/interferometer_test.py index cb1666320..6de3bca42 100644 --- a/test/gw/detector/interferometer_test.py +++ b/test/gw/detector/interferometer_test.py @@ -6,6 +6,7 @@ import lalsimulation import pytest from shutil import rmtree +from parameterized import parameterized_class import numpy as np @@ -302,27 +303,6 @@ def test_inject_signal_raises_value_error(self): with self.assertRaises(ValueError): self.ifo.inject_signal(injection_polarizations=None, parameters=None) - def test_optimal_snr_squared(self): - """ - Merely checks parameters are given in the right order and the frequency - mask is applied. - """ - with mock.patch("bilby.gw.utils.noise_weighted_inner_product") as m: - m.side_effect = lambda a, b, c, d: [a, b, c, d] - signal = np.ones_like(self.ifo.power_spectral_density_array) - mask = self.ifo.frequency_mask - expected = [ - signal[mask], - signal[mask], - self.ifo.power_spectral_density_array[mask], - self.ifo.strain_data.duration, - ] - actual = self.ifo.optimal_snr_squared(signal=signal) - self.assertTrue(np.array_equal(expected[0], actual[0])) - self.assertTrue(np.array_equal(expected[1], actual[1])) - self.assertTrue(np.array_equal(expected[2], actual[2])) - self.assertEqual(expected[3], actual[3]) - def test_template_template_inner_product(self): signal_1 = np.ones_like(self.ifo.power_spectral_density_array) signal_2 = np.ones_like(self.ifo.power_spectral_density_array) * 2 @@ -596,6 +576,7 @@ def test_time_delay_vs_lal(self): @pytest.mark.flaky(reruns=3, only_rerun=["AssertionError"]) +@parameterized_class(("crop_duration",), [(0,), (4,), (16,)]) class TestInterferometerWhitenedStrain(unittest.TestCase): def setUp(self): self.duration = 64 @@ -603,6 +584,7 @@ def setUp(self): self.ifo = bilby.gw.detector.get_empty_interferometer('H1') self.ifo.set_strain_data_from_power_spectral_density( sampling_frequency=self.sampling_frequency, duration=self.duration) + self.ifo.crop_duration = self.crop_duration self.waveform_generator = bilby.gw.waveform_generator.WaveformGenerator( duration=self.duration, sampling_frequency=self.sampling_frequency, @@ -647,8 +629,7 @@ def _check_time_series_whiteness(self, time_series): self.assertAlmostEqual(std, 1, places=2) def test_frequency_domain_whitened_strain(self): - mask = self.ifo.frequency_mask - white = self.ifo.whitened_frequency_domain_strain[mask] + white = self.ifo.whitened_frequency_domain_strain[self.ifo.cropped_frequency_mask] self._check_frequency_series_whiteness(white) def test_time_domain_whitened_strain(self): @@ -666,8 +647,10 @@ def test_frequency_domain_noise_and_signal_whitening(self): ) # Whiten the template whitened_signal_ifo = self.ifo.whiten_frequency_series(signal_ifo) - mask = self.ifo.frequency_mask - white = self.ifo.whitened_frequency_domain_strain[mask] - whitened_signal_ifo[mask] + white = ( + self.ifo.whitened_frequency_domain_strain[self.ifo.cropped_frequency_mask] + - whitened_signal_ifo[self.ifo.cropped_frequency_mask] + ) self._check_frequency_series_whiteness(white) def test_time_domain_noise_and_signal_whitening(self): diff --git a/test/gw/detector/strain_data_test.py b/test/gw/detector/strain_data_test.py index 0f82a40a2..e04cfd8e0 100644 --- a/test/gw/detector/strain_data_test.py +++ b/test/gw/detector/strain_data_test.py @@ -77,6 +77,25 @@ def test_notches_frequency_mask(self): idxs = (freqs > 100) * (freqs < 101) self.assertTrue(len(freqs[idxs]) == 0) + def test_time_mask(self): + strain_data = bilby.gw.detector.InterferometerStrainData( + minimum_frequency=20, maximum_frequency=512, crop_duration=0.1) + strain_data.set_from_time_domain_strain( + time_domain_strain=np.random.normal(0, 1, 4096), + time_array=np.arange(0, 4, 4 / 4096), + ) + + # Test from init + times = strain_data.time_array[strain_data.time_mask] + self.assertTrue(all(times > 0.1)) + self.assertTrue(all(times <= 3.9)) + + # Test from update + strain_data.crop_duration = (0.5, 1) + times = strain_data.time_array[strain_data.time_mask] + self.assertTrue(all(times > 0.5)) + self.assertTrue(all(times <= 3)) + def test_set_data_fails(self): with mock.patch("bilby.core.utils.create_frequency_series") as m: m.return_value = [1, 2, 3] diff --git a/test/gw/likelihood_test.py b/test/gw/likelihood_test.py index 9d7a7e36f..ee16ac4dd 100644 --- a/test/gw/likelihood_test.py +++ b/test/gw/likelihood_test.py @@ -60,11 +60,27 @@ def test_noise_log_likelihood(self): -4014.1787704539474, self.likelihood.noise_log_likelihood(), 3 ) + def test_noise_log_likelihood_with_cropping(self): + """Test noise log likelihood matches precomputed value""" + self.interferometers[0].crop_duration = 1.0 + self.likelihood.noise_log_likelihood() + self.assertAlmostEqual( + -1991.4986550018828, self.likelihood.noise_log_likelihood(), 3 + ) + def test_log_likelihood(self): """Test log likelihood matches precomputed value""" self.likelihood.log_likelihood(self.parameters) self.assertAlmostEqual(self.likelihood.log_likelihood(self.parameters), -4032.4397343470005, 3) + def test_log_likelihood_with_cropping(self): + """Test noise log likelihood matches precomputed value""" + self.interferometers[0].crop_duration = 1.0 + self.likelihood.log_likelihood() + self.assertAlmostEqual( + -2009.7686313437, self.likelihood.log_likelihood(), 3 + ) + def test_log_likelihood_ratio(self): """Test log likelihood ratio returns the correct value""" self.assertAlmostEqual( @@ -158,6 +174,21 @@ def test_log_likelihood_ratio(self): 3, ) + def test_noise_log_likelihood_with_cropping(self): + """Test noise log likelihood matches precomputed value""" + self.interferometers[0].crop_duration = 1.0 + self.likelihood.noise_log_likelihood() + self.assertAlmostEqual( + -1991.4986550018828, self.likelihood.noise_log_likelihood(), 3 + ) + + def test_log_likelihood_with_cropping(self): + """Test log likelihood matches precomputed value""" + self.interferometers[0].crop_duration = 1.0 + self.likelihood.log_likelihood() + self.assertAlmostEqual(self.likelihood.log_likelihood(), + -2009.7686313436998, 3) + def test_likelihood_zero_when_waveform_is_none(self): """Test log likelihood returns np.nan_to_num(-np.inf) when the waveform is None"""