diff --git a/docs/user/function.rst b/docs/user/function.rst index f2e49d427..4b2d1541d 100644 --- a/docs/user/function.rst +++ b/docs/user/function.rst @@ -90,9 +90,13 @@ plotted as follows: .. important:: - The ``Function`` class only supports interpolation ``shepard`` and \ - extrapolation ``natural`` for datasets higher than one dimension (more than \ - one input). + For datasets higher than one dimension (more than one input), the + ``Function`` class supports interpolation ``linear``, ``shepard``, ``rbf`` + and ``regular_grid``. + + The ``regular_grid`` interpolation requires a complete Cartesian grid and + must be provided as ``(axes, grid_data)``. See the ``Function`` API + documentation for details. CSV File ^^^^^^^^ @@ -183,7 +187,7 @@ In this section we are going to delve deeper on ``Function`` creation and its pa - source: the ``Function`` data source. We have explored this parameter in the section above; - inputs: a list of strings containing each input variable name. If the source only has one input, may be abbreviated as a string (e.g. "speed (m/s)"); - outputs: a list of strings containing each output variable name. If the source only has one output, may be abbreviated as a string (e.g. "total energy (J)"); -- interpolation: a string that is the interpolation method to be used if the source is a dataset. Defaults to ``spline``; +- interpolation: a string that is the interpolation method to be used if the source is a dataset. For N-D datasets, supported options are ``linear``, ``shepard``, ``rbf`` and ``regular_grid``. Defaults to ``spline`` for 1-D and ``shepard`` for N-D datasets; - extrapolation: a string that is the extrapolation method to be used if the source is a dataset. Defaults to ``constant``; - title: the title to be shown in the plots. diff --git a/docs/user/rocket/generic_surface.rst b/docs/user/rocket/generic_surface.rst index f2e166148..997f5a178 100644 --- a/docs/user/rocket/generic_surface.rst +++ b/docs/user/rocket/generic_surface.rst @@ -243,16 +243,17 @@ independent variables: - ``beta``: Side slip angle. - ``mach``: Mach number. - ``reynolds``: Reynolds number. -- ``q``: Pitch rate. -- ``r``: Yaw rate. -- ``p``: Roll rate. +- ``pitch_rate``: Pitch rate. +- ``yaw_rate``: Yaw rate. +- ``roll_rate``: Roll rate. The last column must be the coefficient value, and must contain a header, though the header name can be anything. .. important:: Not all columns need to be present in the file, but the columns that are - present must be named, **and ordered**, as described above. + present must be correctly named as described above. Independent variable + columns can be in any order. An example of a ``.csv`` file is shown below: diff --git a/rocketpy/mathutils/function.py b/rocketpy/mathutils/function.py index fbece42dd..bdc66da08 100644 --- a/rocketpy/mathutils/function.py +++ b/rocketpy/mathutils/function.py @@ -108,8 +108,8 @@ def __init__( interpolation : string, optional Interpolation method to be used if source type is ndarray. For 1-D functions, linear, polynomial, akima and spline are - supported. For N-D functions, linear, shepard and rbf are - supported. + supported. For N-D functions, linear, shepard, rbf and + regular_grid are supported. Default for 1-D functions is spline and for N-D functions is shepard. extrapolation : string, optional @@ -310,8 +310,8 @@ def set_interpolation(self, method="spline"): method : string, optional Interpolation method to be used if source type is ndarray. For 1-D functions, linear, polynomial, akima and spline is - supported. For N-D functions, linear, shepard and rbf are - supported. + supported. For N-D functions, linear, shepard, rbf and + regular_grid are supported. Default for 1-D functions is spline and for N-D functions is shepard. @@ -368,17 +368,77 @@ def set_extrapolation(self, method="constant"): self.__set_extrapolation_func() return self - @property - def is_multidimensional(self): - """Return True when the Function has domain dimension greater than 1. + def __process_grid_source(self, source): + """Validate and process a ``(axes, grid_data)`` tuple into a flat + scatter :class:`numpy.ndarray` ready for :meth:`set_source`. - This abstracts checks for multi-dimensionality so callers don't need - to inspect internal attributes like ``__inputs__`` or ``__dom_dim__``. + As a side-effect, stores ``self._grid_axes`` and ``self._grid_data`` + so that :meth:`__set_interpolation_func` (case 6) and + :meth:`__set_extrapolation_func` can build the + :class:`~scipy.interpolate.RegularGridInterpolator`. + + Parameters + ---------- + source : tuple + A 2-element tuple ``(axes, grid_data)`` where *axes* is a list of + 1-D arrays sorted in ascending order (one per input dimension) and + *grid_data* is a matching N-dimensional :class:`numpy.ndarray` of + values. + + Returns + ------- + flat_source : numpy.ndarray + Array of shape ``(n_points, n_dims + 1)`` with all grid points + unrolled in row-major (C) order. + + Raises + ------ + ValueError + If *source* is not a 2-element tuple, if the number of axes + mismatches the grid dimensionality, or if an axis length mismatches + the corresponding grid dimension. """ - try: - return int(self.__dom_dim__) > 1 - except (AttributeError, TypeError): - return False + if not (isinstance(source, Iterable) and len(source) == 2): + raise ValueError( + "For 'regular_grid' interpolation, source must be a " + "(axes, grid_data) tuple where axes is a list of 1-D arrays " + "and grid_data is a matching N-dimensional ndarray." + ) + + raw_axes, raw_data = source + if not isinstance(raw_axes, Iterable): + raise ValueError( + "The first element of the source tuple must be a list or tuple " + "of 1-D arrays representing the grid axes." + ) + + axes = [np.asarray(ax) for ax in raw_axes] + grid_data = np.asarray(raw_data, dtype=np.float64) + + if len(axes) != grid_data.ndim: + raise ValueError( + f"Number of axes ({len(axes)}) must match grid_data dimensions " + f"({grid_data.ndim})." + ) + for i, ax in enumerate(axes): + if len(ax) != grid_data.shape[i]: + raise ValueError( + f"Axis {i} has {len(ax)} points but grid dimension {i} has " + f"{grid_data.shape[i]} points." + ) + if not np.all(np.diff(ax) > 0): + warnings.warn( + f"Axis {i} is not strictly sorted in ascending order. " + "RegularGridInterpolator requires sorted axes.", + UserWarning, + ) + + self._grid_axes = axes + self._grid_data = grid_data + + mesh = np.meshgrid(*axes, indexing="ij") + domain_points = np.column_stack([m.ravel() for m in mesh]) + return np.column_stack([domain_points, grid_data.ravel()]) def __set_interpolation_func(self): # pylint: disable=too-many-statements """Defines interpolation function used by the Function. Each @@ -466,39 +526,25 @@ def rbf_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disab self._interpolation_func = rbf_interpolation case 6: # regular_grid (RegularGridInterpolator) - # For grid interpolation, the actual interpolator is stored separately - # This function is a placeholder that should not be called directly - # since __get_value_opt_grid is used instead - if hasattr(self, "_grid_interpolator"): - - def grid_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument - return self._grid_interpolator(x) - - self._interpolation_func = grid_interpolation - else: - # Fallback to shepard if grid interpolator not available - warnings.warn( - "Grid interpolator not found, falling back to shepard interpolation" + if not hasattr(self, "_grid_axes") or not hasattr(self, "_grid_data"): + raise AttributeError( + "The 'regular_grid' interpolation requires '_grid_axes' and " + "'_grid_data' to be set on the Function instance before calling " + "set_interpolation('regular_grid')." ) + grid_interpolator = RegularGridInterpolator( + self._grid_axes, + self._grid_data, + method="linear", + bounds_error=True, + ) + # Store so extrapolation funcs can reuse it + self._grid_interpolator = grid_interpolator + + def grid_interpolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument + return grid_interpolator(x) - def shepard_fallback(x, x_min, x_max, x_data, y_data, _): - # pylint: disable=unused-argument - arg_qty, arg_dim = x.shape - result = np.empty(arg_qty) - x = x.reshape((arg_qty, 1, arg_dim)) - sub_matrix = x_data - x - distances_squared = np.sum(sub_matrix**2, axis=2) - zero_distances = np.where(distances_squared == 0) - valid_indexes = np.ones(arg_qty, dtype=bool) - valid_indexes[zero_distances[0]] = False - weights = distances_squared[valid_indexes] ** (-1.5) - numerator_sum = np.sum(y_data * weights, axis=1) - denominator_sum = np.sum(weights, axis=1) - result[valid_indexes] = numerator_sum / denominator_sum - result[~valid_indexes] = y_data[zero_distances[1]] - return result - - self._interpolation_func = shepard_fallback + self._interpolation_func = grid_interpolation case _: raise ValueError( @@ -609,6 +655,20 @@ def natural_extrapolation( # pylint: disable=function-redefined ): # pylint: disable=unused-argument return interpolator(x) + case 6: # regular_grid + grid_extrapolator = RegularGridInterpolator( + self._grid_axes, + self._grid_data, + method="linear", + bounds_error=False, + fill_value=None, # linear extrapolation beyond edges + ) + + def natural_extrapolation( # pylint: disable=function-redefined + x, x_min, x_max, x_data, y_data, coeffs + ): # pylint: disable=unused-argument + return grid_extrapolator(x) + case _: raise ValueError( f"Natural extrapolation not defined for {interpolation}." @@ -621,11 +681,23 @@ def natural_extrapolation( # pylint: disable=function-redefined def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument return y_data[0] if x < x_min else y_data[-1] + elif self.__interpolation__ == "regular_grid": + grid_axes = self._grid_axes + grid_interpolator_const = self._grid_interpolator + + def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument + # Clamp each coordinate to its axis bounds, then interpolate + x_clamped = np.copy(x) + for i, axis in enumerate(grid_axes): + x_clamped[:, i] = np.clip( + x_clamped[:, i], axis[0], axis[-1] + ) + return grid_interpolator_const(x_clamped) + else: extrapolator = NearestNDInterpolator(self._domain, self._image) - def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): - # pylint: disable=unused-argument + def constant_extrapolation(x, x_min, x_max, x_data, y_data, coeffs): # pylint: disable=unused-argument return extrapolator(x) self._extrapolation_func = constant_extrapolation @@ -711,66 +783,6 @@ def __get_value_opt_nd(self, *args): return result - def __get_value_opt_grid(self, *args): # pylint: disable=unused-private-member - """Evaluate the Function using RegularGridInterpolator for structured grids. - - This method is dynamically assigned in from_grid() class method. - - Parameters - ---------- - args : tuple - Values where the Function is to be evaluated. Must match the number - of dimensions of the grid. - - Returns - ------- - result : scalar or ndarray - Value of the Function at the specified points. - """ - # Check if we have the grid interpolator - if not hasattr(self, "_grid_interpolator"): - raise RuntimeError( - "Grid interpolator not initialized. Use from_grid() to create " - "a Function with grid interpolation." - ) - - # Convert args to appropriate format for RegularGridInterpolator - # RegularGridInterpolator expects points as (N, ndim) array - if len(args) != self.__dom_dim__: - raise ValueError( - f"Expected {self.__dom_dim__} arguments but got {len(args)}" - ) - - # Handle single point evaluation - point = np.array(args).reshape(1, -1) - - # Handle extrapolation based on the extrapolation setting - if self.__extrapolation__ == "constant": - # Clamp point to grid boundaries for constant extrapolation - for i, axis in enumerate(self._grid_axes): - point[0, i] = np.clip(point[0, i], axis[0], axis[-1]) - result = self._grid_interpolator(point) - elif self.__extrapolation__ == "zero": - # Check if point is outside bounds - outside_bounds = False - for i, axis in enumerate(self._grid_axes): - if point[0, i] < axis[0] or point[0, i] > axis[-1]: - outside_bounds = True - break - if outside_bounds: - result = np.array([0.0]) - else: - result = self._grid_interpolator(point) - else: - # Natural or other extrapolation - use interpolator directly - result = self._grid_interpolator(point) - - # Return scalar for single evaluation - if result.size == 1: - return float(result[0]) - - return result - def __determine_1d_domain_bounds(self, lower, upper): """Determine domain bounds for 1-D function discretization. @@ -912,8 +924,8 @@ def set_discrete( interpolation : string Interpolation method to be used if source type is ndarray. For 1-D functions, linear, polynomial, akima and spline are - supported. For N-D functions, linear, shepard and rbf are - supported. + supported. For N-D functions, linear, shepard, rbf and + regular_grid are supported. Default for 1-D functions is spline and for N-D functions is shepard. extrapolation : string, optional @@ -3912,8 +3924,11 @@ def __validate_source(self, source): # pylint: disable=too-many-statements "Could not read the csv or txt file to create Function source." ) from e - if isinstance(source, (list, np.ndarray)): + if isinstance(source, Iterable): # Triggers an error if source is not a list of numbers + if self.__interpolation__ == "regular_grid": + return self.__process_grid_source(source) + source = np.array(source, dtype=np.float64) # Checks if 2D array @@ -4095,250 +4110,6 @@ def to_dict(self, **kwargs): # pylint: disable=unused-argument "extrapolation": self.__extrapolation__, } - @classmethod - def from_grid( - cls, - grid_data, - axes, - inputs=None, - outputs=None, - interpolation="regular_grid", - extrapolation="constant", - flatten_for_compatibility=True, - **kwargs, - ): # pylint: disable=too-many-statements #TODO: Refactor this method into smaller methods - """Creates a Function from N-dimensional grid data. - - This method is designed for structured grid data, such as CFD simulation - results where values are computed on a regular grid. It uses - scipy.interpolate.RegularGridInterpolator for efficient interpolation. - - Parameters - ---------- - grid_data : ndarray - N-dimensional array containing the function values on the grid. - For example, for a 3D function Cd(M, Re, α), this would be a 3D array - where grid_data[i, j, k] = Cd(M[i], Re[j], α[k]). - axes : list of ndarray - List of 1D arrays defining the grid points along each axis. - Each array should be sorted in ascending order. - For example: [M_axis, Re_axis, alpha_axis]. - inputs : list of str, optional - Names of the input variables. If None, generic names will be used. - For example: ['Mach', 'Reynolds', 'Alpha']. - outputs : str, optional - Name of the output variable. For example: 'Cd'. - interpolation : str, optional - Interpolation method. Default is 'regular_grid'. - Currently only 'regular_grid' is supported for grid data. - extrapolation : str, optional - Extrapolation behavior. Default is ``'constant'`` which clamps to - edge values. Supported options are:: - - 'constant' - Use nearest edge value for out-of-bounds points (clamp). - 'zero' - Return zero for out-of-bounds points. - 'natural' - Use the interpolator's natural behavior: when the - underlying ``RegularGridInterpolator`` is created with - ``fill_value=None`` and ``method='linear'``, this results - in linear extrapolation based on the edge gradients. - - If an unsupported extrapolation value is supplied a ``ValueError`` - is raised. - flatten_for_compatibility : bool, optional - If True (default), creates flattened ``_domain``, ``_image``, and - ``source`` arrays for backward compatibility with existing Function - methods and serialization. For large N-dimensional grids (e.g., - 100x100x100 points), this requires O(n^d) additional memory where n - is the typical axis length and d is the number of dimensions. - Set to False to skip this flattening and reduce memory usage if - compatibility with legacy code paths is not required. - **kwargs : dict, optional - Additional arguments passed to the Function constructor. - - Returns - ------- - Function - A Function object using RegularGridInterpolator for evaluation. - - Notes - ----- - - Grid data must be on a regular (structured) grid. - - For unstructured data, use the regular Function constructor with - scattered points. - - Extrapolation with 'constant' mode uses the nearest edge values, - which is appropriate for aerodynamic coefficients where extrapolation - beyond the data range should be avoided. - - Examples - -------- - >>> import numpy as np - >>> # Create 3D drag coefficient data - >>> mach = np.array([0.0, 0.5, 1.0, 1.5, 2.0]) - >>> reynolds = np.array([1e5, 5e5, 1e6]) - >>> alpha = np.array([0.0, 2.0, 4.0, 6.0]) - >>> # Create a simple drag coefficient function - >>> M, Re, A = np.meshgrid(mach, reynolds, alpha, indexing='ij') - >>> cd_data = 0.3 + 0.1 * M + 1e-7 * Re + 0.01 * A - >>> # Create Function object - >>> cd_func = Function.from_grid( - ... cd_data, - ... [mach, reynolds, alpha], - ... inputs=['Mach', 'Reynolds', 'Alpha'], - ... outputs='Cd' - ... ) - >>> # Evaluate at a point - >>> cd_func(1.2, 3e5, 3.0) - 0.48000000000000004 - - """ - # Validate inputs - if not isinstance(grid_data, np.ndarray): - grid_data = np.array(grid_data) - - if not isinstance(axes, (list, tuple)): - raise ValueError("axes must be a list or tuple of 1D arrays") - - # Ensure all axes are numpy arrays - axes = [ - np.array(axis) if not isinstance(axis, np.ndarray) else axis - for axis in axes - ] - - # Check dimensions match - if len(axes) != grid_data.ndim: - raise ValueError( - f"Number of axes ({len(axes)}) must match grid_data dimensions " - f"({grid_data.ndim})" - ) - - # Check each axis matches corresponding grid dimension and is sorted - for i, axis in enumerate(axes): - if len(axis) != grid_data.shape[i]: - raise ValueError( - f"Axis {i} has {len(axis)} points but grid dimension {i} " - f"has {grid_data.shape[i]} points" - ) - # Check if axis is sorted in ascending order - if not np.all(np.diff(axis) > 0): - warnings.warn( - f"Axis {i} is not strictly sorted in ascending order. " - "RegularGridInterpolator requires sorted axes. " - "This may cause unexpected interpolation results.", - UserWarning, - ) - - # Set default inputs if not provided - if inputs is None: - inputs = [f"x{i}" for i in range(len(axes))] - elif len(inputs) != len(axes): - raise ValueError( - f"Number of inputs ({len(inputs)}) must match number of axes ({len(axes)})" - ) - - # Create a new Function instance - func = cls.__new__(cls) - - # Validate extrapolation option for grid-based interpolation - allowed_extrap = ("constant", "zero", "natural") - if extrapolation not in allowed_extrap: - raise ValueError( - "Unsupported extrapolation for grid interpolation. " - f"Supported values: {allowed_extrap}" - ) - - # Store grid-specific data first - func._grid_axes = axes - func._grid_data = grid_data - - # Create RegularGridInterpolator - # We handle extrapolation manually in __get_value_opt_grid, - # so we set bounds_error=False and let it extrapolate linearly - # (which we'll override when needed) - func._grid_interpolator = RegularGridInterpolator( - axes, - grid_data, - method="linear", - bounds_error=False, - fill_value=None, # Linear extrapolation (will be overridden by manual handling) - ) - - # Create placeholder domain and image for compatibility. - # For large grids this requires O(n^d) memory; set flatten_for_compatibility=False - # to skip this if legacy code compatibility is not required. - if flatten_for_compatibility: - mesh = np.meshgrid(*axes, indexing="ij") - domain_points = np.column_stack([m.ravel() for m in mesh]) - func._domain = domain_points - func._image = grid_data.ravel() - # Set source as flattened data array (for compatibility with serialization) - func.source = np.column_stack([domain_points, func._image]) - else: - # Minimal placeholders - grid interpolator is the primary data source - func._domain = None - func._image = None - func.source = None - - # Initialize basic attributes - func.__inputs__ = inputs - func.__outputs__ = outputs if outputs is not None else "f" - func.__interpolation__ = interpolation - func.__extrapolation__ = extrapolation - func.title = kwargs.get("title", None) - func.__img_dim__ = 1 - func.__cropped_domain__ = (None, None) - func._source_type = SourceType.ARRAY - func.__dom_dim__ = len(axes) - - # Set basic array attributes for compatibility - func.x_array = axes[0] - func.x_initial, func.x_final = axes[0][0], axes[0][-1] - if flatten_for_compatibility: - # For grid-based (N-D) functions, a 1-D `y_array` is not a meaningful - # representation of the function values. Some legacy code paths and - # serialization expect a `y_array` attribute to exist, so provide the - # full flattened image for compatibility rather than a truncated slice. - # Callers should avoid relying on `y_array` for multidimensional - # Functions; use the interpolator / `get_value_opt` instead. - func.y_array = func._image - # Use the global min/max of the flattened image as a sensible - # `y_initial`/`y_final` for compatibility with code that inspects - # scalar bounds. These describe the image range, not an ordering - # along any particular axis. - func.y_initial, func.y_final = ( - float(func._image.min()), - float(func._image.max()), - ) - else: - # Minimal placeholders when flattening is disabled - func.y_array = None - func.y_initial, func.y_final = ( - float(grid_data.min()), - float(grid_data.max()), - ) - if len(axes) > 2: - func.z_array = axes[2] - func.z_initial, func.z_final = axes[2][0], axes[2][-1] - - # Set get_value_opt to use grid interpolation - func.get_value_opt = func.__get_value_opt_grid - - # Set interpolation and extrapolation functions - func.__set_interpolation_func() - # Only set extrapolation function if we have flattened data, otherwise - # extrapolation is handled by __get_value_opt_grid directly - if flatten_for_compatibility: - func.__set_extrapolation_func() - - # Set inputs and outputs properly - func.set_inputs(inputs) - func.set_outputs(outputs) - func.set_title(func.title) - - return func - @classmethod def from_dict(cls, func_dict): """Creates a Function instance from a dictionary. diff --git a/rocketpy/plots/rocket_plots.py b/rocketpy/plots/rocket_plots.py index 0f087a5d6..e208c775f 100644 --- a/rocketpy/plots/rocket_plots.py +++ b/rocketpy/plots/rocket_plots.py @@ -104,20 +104,26 @@ def drag_curves(self, *, filename=None): """ try: - x_power_drag_on = self.rocket.power_on_drag.x_array - y_power_drag_on = self.rocket.power_on_drag.y_array + x_power_drag_on = self.rocket.power_on_drag_by_mach.x_array + y_power_drag_on = self.rocket.power_on_drag_by_mach.y_array except AttributeError: x_power_drag_on = np.linspace(0, 2, 50) y_power_drag_on = np.array( - [self.rocket.power_on_drag.source(x) for x in x_power_drag_on] + [ + self.rocket.power_on_drag_by_mach.get_value_opt(x) + for x in x_power_drag_on + ] ) try: - x_power_drag_off = self.rocket.power_off_drag.x_array - y_power_drag_off = self.rocket.power_off_drag.y_array + x_power_drag_off = self.rocket.power_off_drag_by_mach.x_array + y_power_drag_off = self.rocket.power_off_drag_by_mach.y_array except AttributeError: x_power_drag_off = np.linspace(0, 2, 50) y_power_drag_off = np.array( - [self.rocket.power_off_drag.source(x) for x in x_power_drag_off] + [ + self.rocket.power_off_drag_by_mach.get_value_opt(x) + for x in x_power_drag_off + ] ) _, ax = plt.subplots() diff --git a/rocketpy/rocket/aero_surface/generic_surface.py b/rocketpy/rocket/aero_surface/generic_surface.py index d1982ae04..23ccb0d77 100644 --- a/rocketpy/rocket/aero_surface/generic_surface.py +++ b/rocketpy/rocket/aero_surface/generic_surface.py @@ -1,11 +1,11 @@ import copy -import csv import math import numpy as np from rocketpy.mathutils import Function from rocketpy.mathutils.vector_matrix import Matrix, Vector +from rocketpy.tools import load_generic_surface_csv class GenericSurface: @@ -32,7 +32,8 @@ def __init__( angle of attack, angle of sideslip, Mach number, Reynolds number, pitch rate, yaw rate and roll rate. For CSV files, the header must contain at least one of the following: "alpha", "beta", "mach", - "reynolds", "pitch_rate", "yaw_rate" and "roll_rate". + "reynolds", "pitch_rate", "yaw_rate" and "roll_rate". The + independent variable columns can be provided in any order. See Also -------- @@ -327,7 +328,7 @@ def _process_input(self, input_data, coeff_name): """ if isinstance(input_data, str): # Input is assumed to be a file path to a CSV - return self.__load_csv(input_data, coeff_name) + return load_generic_surface_csv(input_data, coeff_name) elif isinstance(input_data, Function): if input_data.__dom_dim__ != 7: raise ValueError( @@ -342,7 +343,21 @@ def _process_input(self, input_data, coeff_name): f"{coeff_name} function must have 7 input arguments" " (alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate)." ) - return input_data + return Function( + input_data, + [ + "alpha", + "beta", + "mach", + "reynolds", + "pitch_rate", + "yaw_rate", + "roll_rate", + ], + [coeff_name], + interpolation="linear", + extrapolation="natural", + ) elif input_data == 0: return Function( lambda alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate: 0, @@ -364,82 +379,3 @@ def _process_input(self, input_data, coeff_name): f"Invalid input for {coeff_name}: must be a CSV file path" " or a callable." ) - - def __load_csv(self, file_path, coeff_name): - """Load a CSV file and create a Function object with the correct number - of arguments. The CSV file must have a header that specifies the - independent variables that are used. - - Parameters - ---------- - file_path : str - Path to the CSV file. - coeff_name : str - Name of the coefficient being processed. - - Returns - ------- - Function - Function object with 7 input arguments (alpha, beta, mach, reynolds, - pitch_rate, yaw_rate, roll_rate). - """ - try: - with open(file_path, mode="r") as file: - reader = csv.reader(file) - header = next(reader) - except (FileNotFoundError, IOError) as e: - raise ValueError(f"Error reading {coeff_name} CSV file: {e}") from e - - if not header: - raise ValueError(f"Invalid or empty CSV file for {coeff_name}.") - - # TODO make header strings flexible (e.g. 'alpha', 'Alpha', 'ALPHA') - independent_vars = [ - "alpha", - "beta", - "mach", - "reynolds", - "pitch_rate", - "yaw_rate", - "roll_rate", - ] - present_columns = [col for col in independent_vars if col in header] - - # Check that the last column is not an independent variable - if header[-1] in independent_vars: - raise ValueError( - f"Last column in {coeff_name} CSV must be the coefficient" - " value, not an independent variable." - ) - - # Ensure that at least one independent variable is present - if not present_columns: - raise ValueError(f"No independent variables found in {coeff_name} CSV.") - - # Initialize the CSV-based function - csv_func = Function( - file_path, - interpolation="linear", - extrapolation="natural", - ) - - # Create a mask for the presence of each independent variable - # save on self to avoid loss of scope - _mask = [1 if col in present_columns else 0 for col in independent_vars] - - # Generate a lambda that applies only the relevant arguments to csv_func - def wrapper(alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate): - args = [alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate] - # Select arguments that correspond to present variables - selected_args = [arg for arg, m in zip(args, _mask) if m] - return csv_func(*selected_args) - - # Create the interpolation function - func = Function( - wrapper, - independent_vars, - [coeff_name], - interpolation="linear", - extrapolation="natural", - ) - return func diff --git a/rocketpy/rocket/point_mass_rocket.py b/rocketpy/rocket/point_mass_rocket.py index 9cdf86d47..eaddaadec 100644 --- a/rocketpy/rocket/point_mass_rocket.py +++ b/rocketpy/rocket/point_mass_rocket.py @@ -21,10 +21,16 @@ class PointMassRocket(Rocket): center_of_mass_without_motor : float Position, in meters, of the rocket's center of mass without motor relative to the rocket's coordinate system. - power_off_drag : float, callable, array, string, Function - Drag coefficient as a function of Mach number when the motor is off. - power_on_drag : float, callable, array, string, Function - Drag coefficient as a function of Mach number when the motor is on. + power_off_drag : int, float, callable, array, string, Function + Drag coefficient input when the motor is off. Accepts the same formats + as :class:`rocketpy.Rocket`, including 1D (Mach-only) and 7D + (alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate) + definitions. + power_on_drag : int, float, callable, array, string, Function + Drag coefficient input when the motor is on. Accepts the same formats + as :class:`rocketpy.Rocket`, including 1D (Mach-only) and 7D + (alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate) + definitions. Attributes ---------- @@ -35,10 +41,16 @@ class PointMassRocket(Rocket): center_of_mass_without_motor : float Position, in meters, of the rocket's center of mass without motor relative to the rocket's coordinate system. - power_off_drag : Function - Drag coefficient as a function of Mach number when the motor is off. - power_on_drag : Function - Drag coefficient as a function of Mach number when the motor is on. + power_off_drag_7d : Function + Drag coefficient function with seven inputs in the order: + alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate. + power_on_drag_7d : Function + Drag coefficient function with seven inputs in the order: + alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate. + power_off_drag_by_mach : Function + Convenience wrapper for power-off drag as a Mach-only function. + power_on_drag_by_mach : Function + Convenience wrapper for power-on drag as a Mach-only function. """ def __init__( diff --git a/rocketpy/rocket/rocket.py b/rocketpy/rocket/rocket.py index 93ab46321..185ef96db 100644 --- a/rocketpy/rocket/rocket.py +++ b/rocketpy/rocket/rocket.py @@ -1,3 +1,4 @@ +import inspect import math import warnings from typing import Iterable @@ -26,6 +27,7 @@ from rocketpy.tools import ( deprecated, find_obj_from_hash, + load_rocket_drag_csv, parallel_axis_theorem_from_com, ) @@ -145,12 +147,22 @@ class Rocket: Rocket.static_margin : float Float value corresponding to rocket static margin when loaded with propellant in units of rocket diameter or calibers. - Rocket.power_off_drag : Function - Rocket's drag coefficient as a function of Mach number when the - motor is off. - Rocket.power_on_drag : Function - Rocket's drag coefficient as a function of Mach number when the - motor is on. + Rocket.power_off_drag : int, float, callable, string, array, Function + Original user input for rocket's drag coefficient when the motor is + off. This is preserved for reconstruction and Monte Carlo workflows. + Rocket.power_on_drag : int, float, callable, string, array, Function + Original user input for rocket's drag coefficient when the motor is + on. This is preserved for reconstruction and Monte Carlo workflows. + Rocket.power_off_drag_7d : Function + Rocket's drag coefficient with motor off as a 7D function of + (alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate). + Rocket.power_on_drag_7d : Function + Rocket's drag coefficient with motor on as a 7D function of + (alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate). + Rocket.power_off_drag_by_mach : Function + Rocket's drag coefficient with motor off as a function of Mach number. + Rocket.power_on_drag_by_mach : Function + Rocket's drag coefficient with motor on as a function of Mach number. Rocket.rail_buttons : RailButtons RailButtons object containing the rail buttons information. Rocket.motor : Motor @@ -342,28 +354,30 @@ def __init__( # pylint: disable=too-many-statements ) # Define aerodynamic drag coefficients - # If already a Function, use it directly (preserves multi-dimensional drag) - if isinstance(power_off_drag, Function): - self.power_off_drag = power_off_drag - else: - self.power_off_drag = Function( - power_off_drag, - "Mach Number", - "Drag Coefficient with Power Off", - "linear", - "constant", - ) - - if isinstance(power_on_drag, Function): - self.power_on_drag = power_on_drag - else: - self.power_on_drag = Function( - power_on_drag, - "Mach Number", - "Drag Coefficient with Power On", - "linear", - "constant", - ) + # Coefficients used during flight simulation + self.power_off_drag_7d = self.__process_drag_input( + power_off_drag, "Drag Coefficient with Power Off" + ) + self.power_on_drag_7d = self.__process_drag_input( + power_on_drag, "Drag Coefficient with Power On" + ) + self.power_on_drag_by_mach = Function( + lambda mach: self.power_on_drag_7d(0, 0, mach, 0, 0, 0, 0), + inputs="Mach Number", + outputs="Drag Coefficient with Power On", + interpolation="linear", + extrapolation="constant", + ) + self.power_off_drag_by_mach = Function( + lambda mach: self.power_off_drag_7d(0, 0, mach, 0, 0, 0, 0), + inputs="Mach Number", + outputs="Drag Coefficient with Power Off", + interpolation="linear", + extrapolation="constant", + ) + # Saving user input for monte carlo + self.power_off_drag = power_off_drag + self.power_on_drag = power_on_drag # Create a, possibly, temporary empty motor # self.motors = Components() # currently unused, only 1 motor is supported @@ -1974,11 +1988,8 @@ def all_info(self): def to_dict(self, **kwargs): discretize = kwargs.get("discretize", False) - power_off_drag = self.power_off_drag - power_on_drag = self.power_on_drag - if discretize: - power_off_drag = power_off_drag.set_discrete(0, 4, 50, mutate_self=False) - power_on_drag = power_on_drag.set_discrete(0, 4, 50, mutate_self=False) + power_off_drag = self.power_off_drag_7d + power_on_drag = self.power_on_drag_7d rocket_dict = { "radius": self.radius, @@ -2143,3 +2154,145 @@ def from_dict(cls, data): rocket._add_controllers(controller) return rocket + + def __process_drag_input(self, input_data, coeff_name): + """Process drag coefficient input and normalize it to a 7D Function. + + Parameters + ---------- + input_data : int, float, str, callable, Function + Input data to be processed. + coeff_name : str + Name of the coefficient being processed for error reporting. + + Returns + ------- + Function + Function object with 7 input arguments in the following order: + alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate. + """ + inputs = [ + "alpha", + "beta", + "mach", + "reynolds", + "pitch_rate", + "yaw_rate", + "roll_rate", + ] + + # Helper: lift a 1D Mach-only source into the required 7D signature. + def _wrap_mach_only_source(mach_source): + return Function( + lambda alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate: ( + mach_source(mach) + ), + inputs, + [coeff_name], + interpolation="linear", + extrapolation="constant", + ) + + # Helper: enforce that Function-based inputs are either 1D (Mach) or 7D. + def _validate_function_domain_dimension(function): + if function.__dom_dim__ not in (1, 7): + raise ValueError( + f"{coeff_name} function must have either 1 input argument " + "(mach) or 7 input arguments (alpha, beta, mach, reynolds, " + "pitch_rate, yaw_rate, roll_rate), in that order." + ) + + # Helper: count required positional arguments in a callable. + def _count_positional_args(callable_obj): + signature = inspect.signature(callable_obj) + positional_params = [ + parameter + for parameter in signature.parameters.values() + if parameter.kind + in ( + inspect.Parameter.POSITIONAL_ONLY, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + ) + and parameter.default is inspect.Parameter.empty + ] + return len(positional_params) + + # Case 1: string input can be a CSV path or any Function-supported source. + if isinstance(input_data, str): + if input_data.lower().endswith(".csv"): + return load_rocket_drag_csv(input_data, coeff_name) + + function_data = Function(input_data) + _validate_function_domain_dimension(function_data) + if function_data.__dom_dim__ == 7: + function_data.set_extrapolation("constant") + return function_data + return _wrap_mach_only_source(function_data.get_value_opt) + + # Case 2: Function input is accepted directly after domain validation. + if isinstance(input_data, Function): + _validate_function_domain_dimension(input_data) + if input_data.__dom_dim__ == 7: + input_data.set_extrapolation("constant") + return input_data + return _wrap_mach_only_source(input_data.get_value_opt) + + # Case 3: callable input must expose either 1 (Mach) or 7 arguments. + if callable(input_data): + n_positional_args = _count_positional_args(input_data) + if n_positional_args not in (1, 7): + raise ValueError( + f"{coeff_name} callable must have either 1 positional " + "argument (mach) or 7 positional arguments (alpha, beta, " + "mach, reynolds, pitch_rate, yaw_rate, roll_rate), in that " + "order." + ) + + if n_positional_args == 1: + return _wrap_mach_only_source(input_data) + + return Function( + input_data, + inputs, + [coeff_name], + interpolation="linear", + extrapolation="constant", + ) + + # Case 4: scalar input means a constant drag coefficient in all conditions. + if isinstance(input_data, (int, float)): + return Function( + lambda alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate: ( + float(input_data) + ), + inputs, + [coeff_name], + interpolation="linear", + extrapolation="constant", + ) + + # If is list/tuple try to pass it to a function. + # If composed of lists/tuples len 2, then interpret as function of mach + # Otherwise interpret it as function of all 7 variables + # This reuses Function's parser and then feeds back into this same pipeline. + if isinstance(input_data, (list, tuple)): + if all( + isinstance(item, (list, tuple)) and (len(item) == 2 or len(item) == 8) + for item in input_data + ): + try: + return self.__process_drag_input( + Function(list(input_data)), coeff_name + ) + except (TypeError, ValueError) as e: + raise ValueError( + f"Invalid list/tuple format for {coeff_name}. Expected " + "a list of [mach, coefficient] pairs or a list of " + "[alpha, beta, mach, reynolds, pitch_rate, yaw_rate, " + "roll_rate, coefficient] entries." + ) from e + + raise TypeError( + f"Invalid input for {coeff_name}: must be int, float, CSV file path, " + "Function, or callable." + ) diff --git a/rocketpy/simulation/flight.py b/rocketpy/simulation/flight.py index 36e299ea0..218fc2620 100644 --- a/rocketpy/simulation/flight.py +++ b/rocketpy/simulation/flight.py @@ -1700,89 +1700,28 @@ def lateral_surface_wind(self): return -wind_u * np.cos(heading_rad) + wind_v * np.sin(heading_rad) - def __get_drag_coefficient(self, drag_function, mach, z, freestream_velocity_body): - """Calculate drag coefficient, handling both 1D and multi-dimensional functions. - - Parameters - ---------- - drag_function : Function - The drag coefficient function (power_on_drag or power_off_drag) - mach : float - Mach number - z : float - Altitude in meters - freestream_velocity_body : Vector or array-like - Freestream velocity in body frame [stream_vx_b, stream_vy_b, stream_vz_b] - - Returns - ------- - float - Drag coefficient value - """ - # Early return for 1D drag functions (only mach number) - if not isinstance(drag_function, Function) or not getattr( - drag_function, "is_multidimensional", False - ): - return drag_function.get_value_opt(mach) - - # Multi-dimensional drag function - calculate additional parameters - - # Calculate Reynolds number: Re = rho * V * L / mu - # where L is characteristic length (rocket diameter) - rho = self.env.density.get_value_opt(z) - mu = self.env.dynamic_viscosity.get_value_opt(z) - freestream_speed = np.linalg.norm(freestream_velocity_body) - characteristic_length = 2 * self.rocket.radius # Diameter - # Defensive: avoid division by zero or non-finite viscosity values. - # Use a small epsilon fallback if `mu` is zero, negative, NaN or infinite. - try: - mu_val = float(mu) - except (TypeError, ValueError, OverflowError): - # Only catch errors related to invalid numeric conversion. - # Avoid catching broad Exception to satisfy linters and - # allow other unexpected errors to surface. - mu_val = 0.0 - if not np.isfinite(mu_val) or mu_val <= 0.0: - mu_safe = 1e-10 - else: - mu_safe = mu_val - - reynolds = rho * freestream_speed * characteristic_length / mu_safe - - # Calculate angle of attack - # Angle between freestream velocity and rocket axis (z-axis in body frame) - # The z component of freestream velocity in body frame - if hasattr(freestream_velocity_body, "z"): - stream_vz_b = -freestream_velocity_body.z - else: - stream_vz_b = -freestream_velocity_body[2] - - # Normalize and calculate angle - if freestream_speed > 1e-6: - cos_alpha = stream_vz_b / freestream_speed - # Clamp to [-1, 1] to avoid numerical issues - cos_alpha = np.clip(cos_alpha, -1.0, 1.0) - alpha_rad = np.arccos(cos_alpha) - alpha_deg = np.rad2deg(alpha_rad) - else: - alpha_deg = 0.0 - - # Determine which parameters to pass based on input names - input_names = [name.lower() for name in drag_function.__inputs__] - args = [] - - for name in input_names: - if "mach" in name or name == "m": - args.append(mach) - elif "reynolds" in name or name == "re": - args.append(reynolds) - elif "alpha" in name or name == "a" or "attack" in name: - args.append(alpha_deg) - else: - # Unknown parameter, default to mach - args.append(mach) - - return drag_function.get_value_opt(*args) + def __compute_drag_7d_inputs( + self, + stream_velocity_body, + stream_speed, + stream_mach, + density, + dynamic_viscosity, + ): + """Build drag-model inputs in the 7D order used by Rocket drag functions.""" + aerodynamic_stream_velocity = -stream_velocity_body + alpha = np.arctan2( + aerodynamic_stream_velocity[1], aerodynamic_stream_velocity[2] + ) + beta = np.arctan2( + aerodynamic_stream_velocity[0], aerodynamic_stream_velocity[2] + ) + reynolds = ( + density * stream_speed * (2 * self.rocket.radius) / dynamic_viscosity + if dynamic_viscosity > 0 + else 0 + ) + return alpha, beta, stream_mach, reynolds def udot_rail1(self, t, u, post_processing=False): """Calculates derivative of u state vector with respect to time @@ -1814,38 +1753,28 @@ def udot_rail1(self, t, u, post_processing=False): total_mass_at_t = self.rocket.total_mass.get_value_opt(t) # Get freestream speed - free_stream_speed = ( - (self.env.wind_velocity_x.get_value_opt(z) - vx) ** 2 - + (self.env.wind_velocity_y.get_value_opt(z) - vy) ** 2 - + (vz) ** 2 - ) ** 0.5 + free_stream_velocity = Vector( + [ + self.env.wind_velocity_x.get_value_opt(z) - vx, + self.env.wind_velocity_y.get_value_opt(z) - vy, + -vz, + ] + ) + free_stream_speed = abs(free_stream_velocity) free_stream_mach = free_stream_speed / self.env.speed_of_sound.get_value_opt(z) - - # For rail motion, rocket is constrained - velocity mostly along z-axis in body frame - # Calculate velocity in body frame (simplified for rail) - a11 = 1 - 2 * (e2**2 + e3**2) - a12 = 2 * (e1 * e2 - e0 * e3) - a13 = 2 * (e1 * e3 + e0 * e2) - a21 = 2 * (e1 * e2 + e0 * e3) - a22 = 1 - 2 * (e1**2 + e3**2) - a23 = 2 * (e2 * e3 - e0 * e1) - a31 = 2 * (e1 * e3 - e0 * e2) - a32 = 2 * (e2 * e3 + e0 * e1) - a33 = 1 - 2 * (e1**2 + e2**2) - - # Freestream velocity in body frame - wind_vx = self.env.wind_velocity_x.get_value_opt(z) - wind_vy = self.env.wind_velocity_y.get_value_opt(z) - stream_vx_b = a11 * (wind_vx - vx) + a21 * (wind_vy - vy) + a31 * (-vz) - stream_vy_b = a12 * (wind_vx - vx) + a22 * (wind_vy - vy) + a32 * (-vz) - stream_vz_b = a13 * (wind_vx - vx) + a23 * (wind_vy - vy) + a33 * (-vz) - - drag_coeff = self.__get_drag_coefficient( - self.rocket.power_on_drag, + rho = self.env.density.get_value_opt(z) + stream_velocity_body = ( + Matrix.transformation([e0, e1, e2, e3]).transpose @ free_stream_velocity + ) + dynamic_viscosity = self.env.dynamic_viscosity.get_value_opt(z) + alpha, beta, mach, reynolds = self.__compute_drag_7d_inputs( + stream_velocity_body, + free_stream_speed, free_stream_mach, - z, - [stream_vx_b, stream_vy_b, stream_vz_b], + rho, + dynamic_viscosity, ) + drag_coeff = self.rocket.power_on_drag_7d(alpha, beta, mach, reynolds, 0, 0, 0) # Calculate Forces pressure = self.env.pressure.get_value_opt(z) @@ -1854,7 +1783,6 @@ def udot_rail1(self, t, u, post_processing=False): + self.rocket.motor.pressure_thrust(pressure), 0, ) - rho = self.env.density.get_value_opt(z) R3 = -0.5 * rho * (free_stream_speed**2) * self.rocket.area * (drag_coeff) # Calculate Linear acceleration @@ -2010,44 +1938,42 @@ def u_dot(self, t, u, post_processing=False): # pylint: disable=too-many-locals wind_velocity_x = self.env.wind_velocity_x.get_value_opt(z) wind_velocity_y = self.env.wind_velocity_y.get_value_opt(z) speed_of_sound = self.env.speed_of_sound.get_value_opt(z) - free_stream_speed = ( - (wind_velocity_x - vx) ** 2 + (wind_velocity_y - vy) ** 2 + (vz) ** 2 - ) ** 0.5 + free_stream_velocity = Vector([wind_velocity_x - vx, wind_velocity_y - vy, -vz]) + free_stream_speed = abs(free_stream_velocity) free_stream_mach = free_stream_speed / speed_of_sound - - # Get rocket velocity in body frame (needed for drag calculation) - vx_b = a11 * vx + a21 * vy + a31 * vz - vy_b = a12 * vx + a22 * vy + a32 * vz - vz_b = a13 * vx + a23 * vy + a33 * vz - - # Calculate freestream velocity in body frame - stream_vx_b = ( - a11 * (wind_velocity_x - vx) + a21 * (wind_velocity_y - vy) + a31 * (-vz) - ) - stream_vy_b = ( - a12 * (wind_velocity_x - vx) + a22 * (wind_velocity_y - vy) + a32 * (-vz) - ) - stream_vz_b = ( - a13 * (wind_velocity_x - vx) + a23 * (wind_velocity_y - vy) + a33 * (-vz) - ) + stream_velocity_body = Kt @ free_stream_velocity # Determine aerodynamics forces # Determine Drag Force + rho = self.env.density.get_value_opt(z) + dynamic_viscosity = self.env.dynamic_viscosity.get_value_opt(z) + alpha, beta, mach, reynolds = self.__compute_drag_7d_inputs( + stream_velocity_body, + free_stream_speed, + free_stream_mach, + rho, + dynamic_viscosity, + ) if t < self.rocket.motor.burn_out_time: - drag_coeff = self.__get_drag_coefficient( - self.rocket.power_on_drag, - free_stream_mach, - z, - [stream_vx_b, stream_vy_b, stream_vz_b], + drag_coeff = self.rocket.power_on_drag_7d( + alpha, + beta, + mach, + reynolds, + omega1, + omega2, + omega3, ) else: - drag_coeff = self.__get_drag_coefficient( - self.rocket.power_off_drag, - free_stream_mach, - z, - [stream_vx_b, stream_vy_b, stream_vz_b], + drag_coeff = self.rocket.power_off_drag_7d( + alpha, + beta, + mach, + reynolds, + omega1, + omega2, + omega3, ) - rho = self.env.density.get_value_opt(z) R3 = -0.5 * rho * (free_stream_speed**2) * self.rocket.area * drag_coeff for air_brakes in self.rocket.air_brakes: if air_brakes.deployment_level > 0: @@ -2068,6 +1994,10 @@ def u_dot(self, t, u, post_processing=False): # pylint: disable=too-many-locals # Off center moment M1 += self.rocket.cp_eccentricity_y * R3 M2 -= self.rocket.cp_eccentricity_x * R3 + # Get rocket velocity in body frame + vx_b = a11 * vx + a21 * vy + a31 * vz + vy_b = a12 * vx + a22 * vy + a32 * vz + vz_b = a13 * vx + a23 * vy + a33 * vz # Calculate lift and moment for each component of the rocket velocity_in_body_frame = Vector([vx_b, vy_b, vz_b]) w = Vector([omega1, omega2, omega3]) @@ -2088,11 +2018,13 @@ def u_dot(self, t, u, post_processing=False): # pylint: disable=too-many-locals # Reynolds at component altitude # TODO: Reynolds is only used in generic surfaces. This calculation # should be moved to the surface class for efficiency + comp_density = self.env.density.get_value_opt(comp_z) + comp_dynamic_viscosity = self.env.dynamic_viscosity.get_value_opt(comp_z) comp_reynolds = ( - self.env.density.get_value_opt(comp_z) + comp_density * comp_stream_speed * aero_surface.reference_length - / self.env.dynamic_viscosity.get_value_opt(comp_z) + / comp_dynamic_viscosity ) # Forces and moments X, Y, Z, M, N, L = aero_surface.compute_forces_and_moments( @@ -2286,12 +2218,25 @@ def u_dot_generalized_3dof(self, t, u, post_processing=False): free_stream_speed = abs(free_stream_velocity) speed_of_sound = self.env.speed_of_sound.get_value_opt(z) mach = free_stream_speed / speed_of_sound + stream_velocity_body = Kt @ free_stream_velocity + dynamic_viscosity = self.env.dynamic_viscosity.get_value_opt(z) + alpha, beta, mach, reynolds = self.__compute_drag_7d_inputs( + stream_velocity_body, + free_stream_speed, + mach, + rho, + dynamic_viscosity, + ) # Drag computation if t < self.rocket.motor.burn_out_time: - cd = self.rocket.power_on_drag.get_value_opt(mach) + cd = self.rocket.power_on_drag_7d( + alpha, beta, mach, reynolds, omega1, omega2, omega3 + ) else: - cd = self.rocket.power_off_drag.get_value_opt(mach) + cd = self.rocket.power_off_drag_7d( + alpha, beta, mach, reynolds, omega1, omega2, omega3 + ) R1, R2 = 0, 0 R3 = -0.5 * rho * free_stream_speed**2 * self.rocket.area * cd @@ -2325,11 +2270,13 @@ def u_dot_generalized_3dof(self, t, u, post_processing=False): rel_speed = abs(rel_velocity) rel_mach = rel_speed / speed_of_sound + comp_density = self.env.density.get_value_opt(comp_z) + comp_dynamic_viscosity = self.env.dynamic_viscosity.get_value_opt(comp_z) reynolds = ( - self.env.density.get_value_opt(comp_z) + comp_density * rel_speed * surface.reference_length - / self.env.dynamic_viscosity.get_value_opt(comp_z) + / comp_dynamic_viscosity ) fx, fy, fz, *_ = surface.compute_forces_and_moments( @@ -2514,15 +2461,19 @@ def u_dot_generalized(self, t, u, post_processing=False): # pylint: disable=too wind_velocity_x = self.env.wind_velocity_x.get_value_opt(z) wind_velocity_y = self.env.wind_velocity_y.get_value_opt(z) wind_velocity = Vector([wind_velocity_x, wind_velocity_y, 0]) - free_stream_speed = abs((wind_velocity - Vector(v))) + free_stream_velocity = wind_velocity - v + free_stream_speed = abs(free_stream_velocity) speed_of_sound = self.env.speed_of_sound.get_value_opt(z) free_stream_mach = free_stream_speed / speed_of_sound - - # Get rocket velocity in body frame (needed for drag calculation) - velocity_in_body_frame = Kt @ v - # Calculate freestream velocity in body frame - freestream_velocity = wind_velocity - v - freestream_velocity_body = Kt @ freestream_velocity + stream_velocity_body = Kt @ free_stream_velocity + dynamic_viscosity = self.env.dynamic_viscosity.get_value_opt(z) + alpha, beta, mach, reynolds = self.__compute_drag_7d_inputs( + stream_velocity_body, + free_stream_speed, + free_stream_mach, + rho, + dynamic_viscosity, + ) if self.rocket.motor.burn_start_time < t < self.rocket.motor.burn_out_time: pressure = self.env.pressure.get_value_opt(z) @@ -2531,19 +2482,25 @@ def u_dot_generalized(self, t, u, post_processing=False): # pylint: disable=too + self.rocket.motor.pressure_thrust(pressure), 0, ) - drag_coeff = self.__get_drag_coefficient( - self.rocket.power_on_drag, - free_stream_mach, - z, - freestream_velocity_body, + drag_coeff = self.rocket.power_on_drag_7d( + alpha, + beta, + mach, + reynolds, + omega1, + omega2, + omega3, ) else: net_thrust = 0 - drag_coeff = self.__get_drag_coefficient( - self.rocket.power_off_drag, - free_stream_mach, - z, - freestream_velocity_body, + drag_coeff = self.rocket.power_off_drag_7d( + alpha, + beta, + mach, + reynolds, + omega1, + omega2, + omega3, ) R3 += -0.5 * rho * (free_stream_speed**2) * self.rocket.area * drag_coeff for air_brakes in self.rocket.air_brakes: @@ -2562,6 +2519,8 @@ def u_dot_generalized(self, t, u, post_processing=False): # pylint: disable=too R3 = air_brakes_force # Substitutes rocket drag coefficient else: R3 += air_brakes_force + # Get rocket velocity in body frame + velocity_in_body_frame = Kt @ v # Calculate lift and moment for each component of the rocket for aero_surface, _ in self.rocket.aerodynamic_surfaces: # Component cp relative to CDM in body frame @@ -2580,11 +2539,13 @@ def u_dot_generalized(self, t, u, post_processing=False): # pylint: disable=too # Reynolds at component altitude # TODO: Reynolds is only used in generic surfaces. This calculation # should be moved to the surface class for efficiency + comp_density = self.env.density.get_value_opt(comp_z) + comp_dynamic_viscosity = self.env.dynamic_viscosity.get_value_opt(comp_z) comp_reynolds = ( - self.env.density.get_value_opt(comp_z) + comp_density * comp_stream_speed * aero_surface.reference_length - / self.env.dynamic_viscosity.get_value_opt(comp_z) + / comp_dynamic_viscosity ) # Forces and moments X, Y, Z, M, N, L = aero_surface.compute_forces_and_moments( diff --git a/rocketpy/stochastic/stochastic_rocket.py b/rocketpy/stochastic/stochastic_rocket.py index 9aad8872b..794a66c85 100644 --- a/rocketpy/stochastic/stochastic_rocket.py +++ b/rocketpy/stochastic/stochastic_rocket.py @@ -742,8 +742,8 @@ def create_object(self): "coordinate_system_orientation" ], ) - rocket.power_off_drag *= generated_dict["power_off_drag_factor"] - rocket.power_on_drag *= generated_dict["power_on_drag_factor"] + rocket.power_off_drag_7d *= generated_dict["power_off_drag_factor"] + rocket.power_on_drag_7d *= generated_dict["power_on_drag_factor"] if hasattr(self, "cp_eccentricity_x") and hasattr(self, "cp_eccentricity_y"): cp_ecc_x, cp_ecc_y = self._create_eccentricities( diff --git a/rocketpy/tools.py b/rocketpy/tools.py index 68ab3404a..ef6dbfcc7 100644 --- a/rocketpy/tools.py +++ b/rocketpy/tools.py @@ -7,6 +7,7 @@ """ import base64 +import csv import functools import importlib import importlib.metadata @@ -116,6 +117,292 @@ def tuple_handler(value): raise ValueError("value must be a list or tuple of length 1 or 2.") +def create_regular_grid_function( + csv_source, + variable_names, + coeff_name, + extrapolation, +): + """Create a regular-grid Function when CSV samples form a full grid. + + Parameters + ---------- + csv_source : str + Path to the CSV file. + variable_names : list[str] + Ordered independent variable names present in the CSV. + coeff_name : str + Name of the coefficient output. + extrapolation : str + Extrapolation method passed to the Function constructor. + + Returns + ------- + Function or None + A ``Function`` configured with ``regular_grid`` interpolation when the + CSV data forms a strict Cartesian grid, otherwise ``None``. + """ + from rocketpy.mathutils.function import ( + Function, # pylint: disable=import-outside-toplevel + ) + + data = np.loadtxt(csv_source, delimiter=",", skiprows=1, dtype=float) + + data = np.atleast_2d(data) + expected_columns = len(variable_names) + 1 + if data.shape[1] != expected_columns: + return None + + coordinates = data[:, :-1] + values = data[:, -1] + + if np.unique(coordinates, axis=0).shape[0] != coordinates.shape[0]: + return None + + axes = [np.unique(coordinates[:, i]) for i in range(len(variable_names))] + expected_size = int(np.prod([axis.size for axis in axes])) + if expected_size != coordinates.shape[0]: + return None + + sorting_keys = [coordinates[:, i] for i in range(len(variable_names) - 1, -1, -1)] + sorted_indices = np.lexsort(tuple(sorting_keys)) + sorted_coordinates = coordinates[sorted_indices] + sorted_values = values[sorted_indices] + + expected_coordinates = np.column_stack( + [axis_values.ravel() for axis_values in np.meshgrid(*axes, indexing="ij")] + ) + if not np.allclose(sorted_coordinates, expected_coordinates, rtol=0, atol=1e-12): + return None + + grid_data = sorted_values.reshape(tuple(axis.size for axis in axes)) + return Function( + (axes, grid_data), + inputs=variable_names, + outputs=[coeff_name], + interpolation="regular_grid", + extrapolation=extrapolation, + ) + + +def load_generic_surface_csv(file_path, coeff_name): # pylint: disable=too-many-statements + """Load GenericSurface coefficient CSV into a 7D Function. + + This loader expects header-based CSV data with one or more independent + variables among: alpha, beta, mach, reynolds, pitch_rate, yaw_rate, + roll_rate. + """ + from rocketpy.mathutils.function import ( + Function, # pylint: disable=import-outside-toplevel + ) + + independent_vars = [ + "alpha", + "beta", + "mach", + "reynolds", + "pitch_rate", + "yaw_rate", + "roll_rate", + ] + + try: + with open(file_path, mode="r") as file: + reader = csv.reader(file) + header = next(reader) + except (FileNotFoundError, IOError) as e: + raise ValueError(f"Error reading {coeff_name} CSV file: {e}") from e + except StopIteration as e: + raise ValueError(f"Invalid or empty CSV file for {coeff_name}.") from e + + if not header: + raise ValueError(f"Invalid or empty CSV file for {coeff_name}.") + + header = [column.strip() for column in header] + present_columns = [col for col in independent_vars if col in header] + + invalid_columns = [col for col in header[:-1] if col not in independent_vars] + if invalid_columns: + raise ValueError( + f"Invalid independent variable(s) in {coeff_name} CSV: " + f"{invalid_columns}. Valid options are: {independent_vars}." + ) + + if header[-1] in independent_vars: + raise ValueError( + f"Last column in {coeff_name} CSV must be the coefficient" + " value, not an independent variable." + ) + + if not present_columns: + raise ValueError(f"No independent variables found in {coeff_name} CSV.") + + ordered_present_columns = [col for col in header[:-1] if col in independent_vars] + + csv_func = create_regular_grid_function( + file_path, + ordered_present_columns, + coeff_name, + extrapolation="natural", + ) + if csv_func is None: + csv_func = Function( + file_path, + interpolation="linear", + extrapolation="natural", + ) + + def wrapper(alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate): + args_by_name = { + "alpha": alpha, + "beta": beta, + "mach": mach, + "reynolds": reynolds, + "pitch_rate": pitch_rate, + "yaw_rate": yaw_rate, + "roll_rate": roll_rate, + } + selected_args = [args_by_name[col] for col in ordered_present_columns] + return csv_func(*selected_args) + + return Function( + wrapper, + independent_vars, + [coeff_name], + interpolation="linear", + extrapolation="natural", + ) + + +def load_rocket_drag_csv(file_path, coeff_name): # pylint: disable=too-many-statements + """Load Rocket drag CSV into a 7D Function. + + Supports either headerless two-column (mach, coefficient) tables or + header-based multi-variable CSV tables. + """ + from rocketpy.mathutils.function import ( + Function, # pylint: disable=import-outside-toplevel + ) + + independent_vars = [ + "alpha", + "beta", + "mach", + "reynolds", + "pitch_rate", + "yaw_rate", + "roll_rate", + ] + + def _is_numeric(value): + try: + float(value) + return True + except (TypeError, ValueError): + try: + int(value) + return True + except (TypeError, ValueError): + return False + + try: + with open(file_path, mode="r") as file: + reader = csv.reader(file) + first_row = next(reader) + except (FileNotFoundError, IOError) as e: + raise ValueError(f"Error reading {coeff_name} CSV file: {e}") from e + except StopIteration as e: + raise ValueError(f"Invalid or empty CSV file for {coeff_name}.") from e + + if not first_row: + raise ValueError(f"Invalid or empty CSV file for {coeff_name}.") + + is_headerless_two_column = len(first_row) == 2 and all( + _is_numeric(cell) for cell in first_row + ) + + if is_headerless_two_column: + csv_func = Function( + file_path, + interpolation="linear", + extrapolation="constant", + ) + + def mach_wrapper( + _alpha, + _beta, + mach, + _reynolds, + _pitch_rate, + _yaw_rate, + _roll_rate, + ): + return csv_func(mach) + + return Function( + mach_wrapper, + independent_vars, + [coeff_name], + interpolation="linear", + extrapolation="constant", + ) + + header = [column.strip() for column in first_row] + present_columns = [col for col in independent_vars if col in header] + + invalid_columns = [col for col in header[:-1] if col not in independent_vars] + if invalid_columns: + raise ValueError( + f"Invalid independent variable(s) in {coeff_name} CSV: " + f"{invalid_columns}. Valid options are: {independent_vars}." + ) + + if header[-1] in independent_vars: + raise ValueError( + f"Last column in {coeff_name} CSV must be the coefficient " + "value, not an independent variable." + ) + + if not present_columns: + raise ValueError(f"No independent variables found in {coeff_name} CSV.") + + ordered_present_columns = [col for col in header[:-1] if col in independent_vars] + + csv_func = create_regular_grid_function( + file_path, + ordered_present_columns, + coeff_name, + extrapolation="constant", + ) + if csv_func is None: + csv_func = Function( + file_path, + interpolation="linear", + extrapolation="constant", + ) + + def wrapper(alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate): + args_by_name = { + "alpha": alpha, + "beta": beta, + "mach": mach, + "reynolds": reynolds, + "pitch_rate": pitch_rate, + "yaw_rate": yaw_rate, + "roll_rate": roll_rate, + } + selected_args = [args_by_name[col] for col in ordered_present_columns] + return csv_func(*selected_args) + + return Function( + wrapper, + independent_vars, + [coeff_name], + interpolation="linear", + extrapolation="constant", + ) + + def calculate_cubic_hermite_coefficients(x0, x1, y0, yp0, y1, yp1): """Calculate the coefficients of a cubic Hermite interpolation function. The function is defined as ax**3 + bx**2 + cx + d. diff --git a/tests/acceptance/test_bella_lui_rocket.py b/tests/acceptance/test_bella_lui_rocket.py index a67547780..bcfe325bc 100644 --- a/tests/acceptance/test_bella_lui_rocket.py +++ b/tests/acceptance/test_bella_lui_rocket.py @@ -141,7 +141,7 @@ def drogue_trigger(p, h, y): ) # Define aerodynamic drag coefficients - BellaLui.power_off_drag = Function( + power_off_drag_by_mach = Function( [ (0.01, 0.51), (0.02, 0.46), @@ -156,7 +156,7 @@ def drogue_trigger(p, h, y): "linear", "constant", ) - BellaLui.power_on_drag = Function( + power_on_drag_by_mach = Function( [ (0.01, 0.51), (0.02, 0.46), @@ -171,8 +171,42 @@ def drogue_trigger(p, h, y): "linear", "constant", ) - BellaLui.power_off_drag *= parameters.get("power_off_drag")[0] - BellaLui.power_on_drag *= parameters.get("power_on_drag")[0] + BellaLui.power_off_drag_7d = Function( + lambda alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate: ( + power_off_drag_by_mach.get_value_opt(mach) + ), + [ + "alpha", + "beta", + "mach", + "reynolds", + "pitch_rate", + "yaw_rate", + "roll_rate", + ], + ["Drag Coefficient with Power Off"], + interpolation="linear", + extrapolation="constant", + ) + BellaLui.power_on_drag_7d = Function( + lambda alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate: ( + power_on_drag_by_mach.get_value_opt(mach) + ), + [ + "alpha", + "beta", + "mach", + "reynolds", + "pitch_rate", + "yaw_rate", + "roll_rate", + ], + ["Drag Coefficient with Power On"], + interpolation="linear", + extrapolation="constant", + ) + BellaLui.power_off_drag_7d *= parameters.get("power_off_drag")[0] + BellaLui.power_on_drag_7d *= parameters.get("power_on_drag")[0] # Flight test_flight = Flight( diff --git a/tests/fixtures/flight/flight_fixtures.py b/tests/fixtures/flight/flight_fixtures.py index 65acb8661..b13b52b6b 100644 --- a/tests/fixtures/flight/flight_fixtures.py +++ b/tests/fixtures/flight/flight_fixtures.py @@ -1,7 +1,6 @@ -import numpy as np import pytest -from rocketpy import Flight, Function, Rocket +from rocketpy import Flight from rocketpy.motors.point_mass_motor import PointMassMotor from rocketpy.rocket.point_mass_rocket import PointMassRocket @@ -298,93 +297,6 @@ def flight_calisto_with_sensors(calisto_with_sensors, example_plain_env): ) -@pytest.fixture -def flight_alpha(example_plain_env, cesaroni_m1670): - """Fixture that returns a Flight using an alpha-dependent 3D Cd function.""" - # Create grid data - mach = np.array([0.0, 0.5, 1.0, 1.5]) - reynolds = np.array([1e5, 1e6]) - alpha = np.array([0.0, 5.0, 10.0, 15.0]) - M, _, A = np.meshgrid(mach, reynolds, alpha, indexing="ij") - cd_data = 0.3 + 0.05 * M + 0.03 * A - cd_data = np.clip(cd_data, 0.2, 2.0) - - drag_func = Function.from_grid( - cd_data, - [mach, reynolds, alpha], - inputs=["Mach", "Reynolds", "Alpha"], - outputs="Cd", - ) - - env = example_plain_env - env.set_atmospheric_model(type="standard_atmosphere") - - # Build rocket and flight - rocket = Rocket( - radius=0.0635, - mass=16.24, - inertia=(6.321, 6.321, 0.034), - power_off_drag=drag_func, - power_on_drag=drag_func, - center_of_mass_without_motor=0, - coordinate_system_orientation="tail_to_nose", - ) - rocket.set_rail_buttons(0.2, -0.5, 30) - rocket.add_motor(cesaroni_m1670, position=-1.255) - - return Flight( - rocket=rocket, - environment=env, - rail_length=5.2, - inclination=85, - heading=0, - ) - - -@pytest.fixture -def flight_flat(example_plain_env, cesaroni_m1670): - """Fixture that returns a Flight using an alpha-averaged (flat) Cd function.""" - # Create grid data - mach = np.array([0.0, 0.5, 1.0, 1.5]) - reynolds = np.array([1e5, 1e6]) - alpha = np.array([0.0, 5.0, 10.0, 15.0]) - M, _, A = np.meshgrid(mach, reynolds, alpha, indexing="ij") - cd_data = 0.3 + 0.05 * M + 0.03 * A - cd_data = np.clip(cd_data, 0.2, 2.0) - - cd_flat = cd_data.mean(axis=2) - drag_flat = Function.from_grid( - cd_flat, - [mach, reynolds], - inputs=["Mach", "Reynolds"], - outputs="Cd", - ) - - env = example_plain_env - env.set_atmospheric_model(type="standard_atmosphere") - - # Build rocket and flight - rocket = Rocket( - radius=0.0635, - mass=16.24, - inertia=(6.321, 6.321, 0.034), - power_off_drag=drag_flat, - power_on_drag=drag_flat, - center_of_mass_without_motor=0, - coordinate_system_orientation="tail_to_nose", - ) - rocket.set_rail_buttons(0.2, -0.5, 30) - rocket.add_motor(cesaroni_m1670, position=-1.255) - - return Flight( - rocket=rocket, - environment=env, - rail_length=5.2, - inclination=85, - heading=0, - ) - - # 3 DOF Flight Fixtures # These fixtures are for testing the 3 DOF flight simulation mode # Based on Bella Lui rocket parameters for realistic acceptance testing diff --git a/tests/integration/simulation/test_flight_3dof.py b/tests/integration/simulation/test_flight_3dof.py index 94a4e33eb..fc3035caa 100644 --- a/tests/integration/simulation/test_flight_3dof.py +++ b/tests/integration/simulation/test_flight_3dof.py @@ -200,6 +200,76 @@ def test_weathercock_coeff_default(flight_3dof): assert flight_3dof.weathercock_coeff == 0.0 +def test_point_mass_rocket_3dof_uses_7d_drag_inputs( + example_plain_env, point_mass_motor +): + """Ensure PointMassRocket uses the 7D drag interface in 3-DOF dynamics.""" + + def drag_7d(alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate): + return ( + 0.2 + + 0.01 * abs(alpha) + + 0.01 * abs(beta) + + 1e-7 * reynolds + + 0.001 * (abs(pitch_rate) + abs(yaw_rate) + abs(roll_rate)) + + 0.01 * mach + ) + + rocket = PointMassRocket( + radius=0.05, + mass=2.0, + center_of_mass_without_motor=0.1, + power_off_drag=drag_7d, + power_on_drag=drag_7d, + ) + rocket.add_motor(point_mass_motor, position=0) + + flight = Flight( + rocket=rocket, + environment=example_plain_env, + rail_length=1, + simulation_mode="3 DOF", + ) + + t = 10.0 + u = [0, 0, 100, 50, 5, 0, 1, 0, 0, 0, 0.3, -0.2, 0.1] + u_dot = flight.u_dot_generalized_3dof(t, u) + + z = u[2] + vx, vy, vz = u[3], u[4], u[5] + omega1, omega2, omega3 = u[10], u[11], u[12] + + rho = flight.env.density.get_value_opt(z) + dynamic_viscosity = flight.env.dynamic_viscosity.get_value_opt(z) + wind_vx = flight.env.wind_velocity_x.get_value_opt(z) + wind_vy = flight.env.wind_velocity_y.get_value_opt(z) + speed_of_sound = flight.env.speed_of_sound.get_value_opt(z) + gravity = flight.env.gravity.get_value_opt(z) + + free_stream_velocity = np.array([wind_vx - vx, wind_vy - vy, -vz]) + free_stream_speed = np.linalg.norm(free_stream_velocity) + mach = free_stream_speed / speed_of_sound + + stream_velocity_body = free_stream_velocity + aerodynamic_stream_velocity = -stream_velocity_body + alpha = np.arctan2(aerodynamic_stream_velocity[1], aerodynamic_stream_velocity[2]) + beta = np.arctan2(aerodynamic_stream_velocity[0], aerodynamic_stream_velocity[2]) + + reynolds = ( + rho * free_stream_speed * (2 * rocket.radius) / dynamic_viscosity + if dynamic_viscosity > 0 + else 0 + ) + + cd_expected = drag_7d(alpha, beta, mach, reynolds, omega1, omega2, omega3) + r3_expected = -0.5 * rho * free_stream_speed**2 * rocket.area * cd_expected + az_expected = ( + r3_expected - rocket.total_mass.get_value_opt(t) * gravity + ) / rocket.total_mass.get_value_opt(t) + + assert u_dot[5] == pytest.approx(az_expected) + + def test_weathercock_zero_gives_fixed_attitude(flight_weathercock_zero): """Tests that weathercock_coeff=0 results in fixed attitude (no quaternion change). When weathercock_coeff is 0, the quaternion derivatives should be zero, diff --git a/tests/integration/test_multidim_drag.py b/tests/integration/test_multidim_drag.py deleted file mode 100644 index a6280e53b..000000000 --- a/tests/integration/test_multidim_drag.py +++ /dev/null @@ -1,157 +0,0 @@ -"""Integration tests for multi-dimensional drag coefficient support.""" - -import numpy as np - -from rocketpy import Flight, Function, Rocket - - -def test_flight_with_1d_drag(flight_calisto): - """Test that flights with 1D drag curves still work (backward compatibility).""" - - # `flight_calisto` is a fixture that already runs the simulation - flight = flight_calisto - - # Check that flight completed successfully - assert flight.t_final > 0 - assert flight.apogee > 0 - assert flight.apogee_time > 0 - - -def test_flight_with_3d_drag_basic(example_plain_env, cesaroni_m1670): - """Test that a simple 3D drag function works.""" - # Use fixtures for environment and motor - env = example_plain_env - env.set_atmospheric_model(type="standard_atmosphere") - motor = cesaroni_m1670 - - # Create 3D drag - mach = np.array([0.0, 0.5, 1.0, 1.5, 2.0]) - reynolds = np.array([1e5, 5e5, 1e6]) - alpha = np.array([0.0, 2.0, 4.0, 6.0]) - - M, Re, A = np.meshgrid(mach, reynolds, alpha, indexing="ij") - cd_data = 0.3 + 0.1 * M - 1e-7 * Re + 0.01 * A - cd_data = np.clip(cd_data, 0.2, 1.0) - - power_off_drag = Function.from_grid( - cd_data, - [mach, reynolds, alpha], - inputs=["Mach", "Reynolds", "Alpha"], - outputs="Cd", - ) - power_on_drag = Function.from_grid( - cd_data * 1.1, - [mach, reynolds, alpha], - inputs=["Mach", "Reynolds", "Alpha"], - outputs="Cd", - ) - - # Create rocket - rocket = Rocket( - radius=0.0635, - mass=16.24, - inertia=(6.321, 6.321, 0.034), - power_off_drag=power_off_drag, - power_on_drag=power_on_drag, - center_of_mass_without_motor=0, - coordinate_system_orientation="tail_to_nose", - ) - rocket.set_rail_buttons(0.2, -0.5, 30) - rocket.add_motor(motor, position=-1.255) - - # Run flight - flight = Flight( - rocket=rocket, - environment=env, - rail_length=5.2, - inclination=85, - heading=0, - ) - - # Check results - should launch and have non-zero apogee - assert flight.apogee > 100, f"Apogee too low: {flight.apogee}m" - assert flight.apogee < 5000, f"Apogee too high: {flight.apogee}m" - assert hasattr(flight, "angle_of_attack") - - -def test_3d_drag_with_varying_alpha(): - """Test that 3D drag responds to angle of attack changes. - - This test only verifies the Function mapping from alpha -> Cd. The - integration-level comparison is placed in a separate test to keep each - test function small and easier to lint/maintain. - """ - # Create drag function with strong alpha dependency - mach = np.array([0.0, 0.5, 1.0, 1.5]) - reynolds = np.array([1e5, 1e6]) - alpha = np.array([0.0, 5.0, 10.0, 15.0]) - - M, _, A = np.meshgrid(mach, reynolds, alpha, indexing="ij") - # Strong alpha dependency: Cd increases significantly with alpha - cd_data = 0.3 + 0.05 * M + 0.03 * A - cd_data = np.clip(cd_data, 0.2, 2.0) - - drag_func = Function.from_grid( - cd_data, - [mach, reynolds, alpha], - inputs=["Mach", "Reynolds", "Alpha"], - outputs="Cd", - ) - - # Test at different angles of attack (direct function call) - # At zero alpha, Cd should be lower - cd_0 = drag_func(0.8, 5e5, 0.0) - cd_10 = drag_func(0.8, 5e5, 10.0) - - # Cd should increase with alpha - assert cd_10 > cd_0 - assert cd_10 - cd_0 > 0.2 # Should show significant difference - - -def test_flight_apogee_diff(flight_alpha, flight_flat): - """Run paired flights (fixtures) and assert their apogees differ.""" - - # Flights should both launch - assert flight_alpha.apogee > 100 - assert flight_flat.apogee > 100 - - # Apogees should differ - assert flight_alpha.apogee != flight_flat.apogee - - -def test_flight_cd_sample_consistency(flight_alpha, flight_flat): - """Sample Cd during a flight and ensure Cd difference matches apogee ordering. - - Uses the `flight_alpha` and `flight_flat` fixtures which provide paired - flights constructed with alpha-dependent and alpha-averaged Cd functions. - """ - - # Sample a mid-ascent time and compare Cd evaluations - speeds = flight_alpha.free_stream_speed[:, 1] - idx_candidates = np.where(speeds > 5)[0] - assert idx_candidates.size > 0 - idx = idx_candidates[len(idx_candidates) // 2] - t_sample = flight_alpha.time[idx] - - mach_sample = flight_alpha.mach_number.get_value_opt(t_sample) - v_sample = flight_alpha.free_stream_speed.get_value_opt(t_sample) - reynolds_sample = ( - flight_alpha.density.get_value_opt(t_sample) - * v_sample - * (2 * flight_alpha.rocket.radius) - / flight_alpha.dynamic_viscosity.get_value_opt(t_sample) - ) - alpha_sample = flight_alpha.angle_of_attack.get_value_opt(t_sample) - - cd_alpha_sample = flight_alpha.rocket.power_on_drag.get_value_opt( - mach_sample, reynolds_sample, alpha_sample - ) - cd_flat_sample = flight_flat.rocket.power_on_drag.get_value_opt( - mach_sample, reynolds_sample - ) - - assert cd_alpha_sample != cd_flat_sample - if cd_alpha_sample > cd_flat_sample: - assert flight_alpha.apogee < flight_flat.apogee - else: - assert flight_alpha.apogee > flight_flat.apogee diff --git a/tests/unit/mathutils/test_function.py b/tests/unit/mathutils/test_function.py index 96acf45f5..93c439def 100644 --- a/tests/unit/mathutils/test_function.py +++ b/tests/unit/mathutils/test_function.py @@ -1306,3 +1306,202 @@ def test_short_time_fft( else: assert np.all(frequencies >= -sampling_frequency / 2) assert np.all(frequencies <= sampling_frequency / 2) + + +@pytest.fixture +def bilinear_grid_2d(): + """Return a 2-D regular_grid Function for f(x, y) = 2x + 3y. + + Because the true function is bilinear, RegularGridInterpolator can + reproduce it exactly both on grid nodes and at any interior point. + + Returns + ------- + Function + Regular-grid Function object sampled over x in [0, 2] and y in [0, 2]. + """ + x_axis = np.array([0.0, 1.0, 2.0]) + y_axis = np.array([0.0, 1.0, 2.0]) + X, Y = np.meshgrid(x_axis, y_axis, indexing="ij") + data = 2.0 * X + 3.0 * Y + return Function( + ([x_axis, y_axis], data), + inputs=["x", "y"], + outputs=["z"], + interpolation="regular_grid", + ) + + +def test_regular_grid_constructor_sets_metadata(bilinear_grid_2d): + """Test that a regular_grid Function is initialised with correct metadata. + + Checks that the interpolation method, domain dimension, inputs and outputs + are all stored correctly after construction via the ``(axes, grid_data)`` + tuple form. + """ + assert bilinear_grid_2d.get_interpolation_method() == "regular_grid" + assert bilinear_grid_2d.get_extrapolation_method() == "natural" + assert bilinear_grid_2d.get_domain_dim() == 2 + assert bilinear_grid_2d.get_inputs() == ["x", "y"] + assert bilinear_grid_2d.get_outputs() == ["z"] + + +@pytest.mark.parametrize( + "x, y, expected", + [ + (0.0, 0.0, 0.0), + (1.0, 0.0, 2.0), + (0.0, 1.0, 3.0), + (2.0, 2.0, 10.0), + (0.5, 0.5, 2.5), + (1.5, 1.0, 6.0), + ([0.0, 1.0], [0.0, 1.0], [0.0, 5.0]), + ], +) +def test_2d_regular_grid_interpolation(bilinear_grid_2d, x, y, expected): + """Test in-domain evaluation of a 2-D regular_grid Function. + + Parameters + ---------- + bilinear_grid_2d : Function + 2-D regular_grid Function for f(x, y) = 2x + 3y. + x : float or list + First input coordinate(s). + y : float or list + Second input coordinate(s). + expected : float or list + Expected function value(s). + """ + result = bilinear_grid_2d(x, y) + result_opt = bilinear_grid_2d.get_value_opt(x, y) + + assert np.isclose(result, expected, atol=1e-10).all() + assert np.isclose(result_opt, expected, atol=1e-10).all() + + +@pytest.mark.parametrize( + "x, y, z, expected", + [ + (0.0, 0.0, 0.0, 0.0), + (1.0, 0.0, 0.0, 1.0), + (0.0, 1.0, 0.0, 2.0), + (0.0, 0.0, 1.0, 3.0), + (0.5, 0.5, 0.5, 3.0), + ([0.0, 1.0], [0.0, 0.0], [0.0, 0.0], [0.0, 1.0]), + ], +) +def test_3d_regular_grid_interpolation(x, y, z, expected): + """Test in-domain evaluation of a 3-D regular_grid Function. + + The sampled function is f(x, y, z) = x + 2y + 3z, which is trilinear so + RegularGridInterpolator reproduces it exactly everywhere inside the domain. + + Parameters + ---------- + x : float or list + First input coordinate(s). + y : float or list + Second input coordinate(s). + z : float or list + Third input coordinate(s). + expected : float or list + Expected function value(s). + """ + x_axis = np.array([0.0, 1.0]) + y_axis = np.array([0.0, 1.0]) + z_axis = np.array([0.0, 1.0]) + X, Y, Z = np.meshgrid(x_axis, y_axis, z_axis, indexing="ij") + data = X + 2.0 * Y + 3.0 * Z + func = Function( + ([x_axis, y_axis, z_axis], data), + inputs=["x", "y", "z"], + outputs=["w"], + interpolation="regular_grid", + ) + + result = func(x, y, z) + result_opt = func.get_value_opt(x, y, z) + + assert np.isclose(result, expected, atol=1e-10).all() + assert np.isclose(result_opt, expected, atol=1e-10).all() + + +@pytest.mark.parametrize( + "extrapolation, x_out, y_out, expected", + [ + # constant: clamp to boundary face and interpolate + ("constant", 3.0, 0.0, 4.0), # x clamped to 2 → 2*2 + 3*0 = 4 + ("constant", -1.0, 1.0, 3.0), # x clamped to 0 → 2*0 + 3*1 = 3 + ("constant", 1.0, 5.0, 8.0), # y clamped to 2 → 2*1 + 3*2 = 8 + # zero: always returns 0 outside domain + ("zero", 3.0, 0.0, 0.0), + ("zero", -1.0, 1.0, 0.0), + # natural: linear continuation — exact for a bilinear underlying function + ("natural", 3.0, 0.0, 6.0), # 2*3 + 3*0 = 6 + ("natural", -1.0, 0.0, -2.0), # 2*(-1) + 3*0 = -2 + ("natural", 0.0, 3.0, 9.0), # 2*0 + 3*3 = 9 + ], +) +def test_regular_grid_extrapolation(extrapolation, x_out, y_out, expected): + """Test out-of-domain behaviour for all extrapolation modes. + + Parameters + ---------- + extrapolation : str + Extrapolation mode: ``'constant'``, ``'zero'``, or ``'natural'``. + x_out : float + Out-of-domain x coordinate. + y_out : float + Out-of-domain y coordinate. + expected : float + Expected function value at the out-of-domain point. + """ + x_axis = np.array([0.0, 1.0, 2.0]) + y_axis = np.array([0.0, 1.0, 2.0]) + X, Y = np.meshgrid(x_axis, y_axis, indexing="ij") + data = 2.0 * X + 3.0 * Y + func = Function( + ([x_axis, y_axis], data), + inputs=["x", "y"], + outputs=["z"], + interpolation="regular_grid", + extrapolation=extrapolation, + ) + + result = func(x_out, y_out) + + assert np.isclose(result, expected, atol=1e-10) + + +@pytest.mark.parametrize( + "bad_source, match", + [ + # axes count doesn't match grid_data.ndim + ( + ([np.array([0.0, 1.0])], np.ones((2, 2))), + "Number of axes", + ), + # first axis length doesn't match first grid dimension + ( + ([np.array([0.0, 1.0, 2.0]), np.array([0.0, 1.0])], np.ones((2, 2))), + "Axis 0", + ), + ], +) +def test_regular_grid_invalid_source_raises(bad_source, match): + """Test that a malformed ``(axes, grid_data)`` source raises ValueError. + + Parameters + ---------- + bad_source : tuple + An ``(axes, grid_data)`` tuple that is structurally invalid. + match : str + Substring expected in the exception message. + """ + with pytest.raises(ValueError, match=match): + Function( + bad_source, + inputs=["x", "y"], + outputs=["z"], + interpolation="regular_grid", + ) diff --git a/tests/unit/mathutils/test_function_from_grid.py b/tests/unit/mathutils/test_function_from_grid.py deleted file mode 100644 index cc0e214c5..000000000 --- a/tests/unit/mathutils/test_function_from_grid.py +++ /dev/null @@ -1,40 +0,0 @@ -import numpy as np -import pytest - -from rocketpy.mathutils.function import Function - - -def test_from_grid_unsupported_extrapolation_raises(): - """from_grid should reject unsupported extrapolation names with ValueError.""" - mach = np.array([0.0, 1.0]) - reynolds = np.array([1e5, 2e5]) - grid = np.zeros((mach.size, reynolds.size)) - - with pytest.raises(ValueError): - Function.from_grid(grid, [mach, reynolds], extrapolation="unsupported_mode") - - -def test_from_grid_is_multidimensional_property(): - """Ensure `is_multidimensional` is True for ND grid Functions and False for 1D.""" - mach = np.array([0.0, 0.5, 1.0]) - reynolds = np.array([1e5, 2e5, 3e5]) - alpha = np.array([0.0, 2.0]) - - M, R, A = np.meshgrid(mach, reynolds, alpha, indexing="ij") - cd_data = 0.1 + 0.2 * M + 1e-7 * R + 0.01 * A - - func_nd = Function.from_grid( - cd_data, - [mach, reynolds, alpha], - inputs=["Mach", "Reynolds", "Alpha"], - outputs="Cd", - ) - - assert hasattr(func_nd, "is_multidimensional") - assert func_nd.is_multidimensional is True - - # 1D Function constructed from a two-column array should not be multidimensional - src = np.column_stack((mach, 0.5 + 0.1 * mach)) - func_1d = Function(src, inputs=["Mach"], outputs="Cd") - assert hasattr(func_1d, "is_multidimensional") - assert func_1d.is_multidimensional is False diff --git a/tests/unit/mathutils/test_function_grid.py b/tests/unit/mathutils/test_function_grid.py deleted file mode 100644 index 3752d4de0..000000000 --- a/tests/unit/mathutils/test_function_grid.py +++ /dev/null @@ -1,352 +0,0 @@ -"""Unit tests for Function.from_grid() method and grid interpolation.""" - -import warnings - -import numpy as np -import pytest - -from rocketpy import Function - - -def test_from_grid_1d(): - """Test from_grid with 1D data (edge case).""" - x = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) - y_data = np.array([0.0, 1.0, 4.0, 9.0, 16.0]) # y = x^2 - - func = Function.from_grid(y_data, [x], inputs=["x"], outputs="y") - - # Test interpolation - assert abs(func(1.5) - 2.25) < 0.5 # Should be close to 1.5^2 - - -def test_from_grid_2d(): - """Test from_grid with 2D data.""" - x = np.array([0.0, 1.0, 2.0]) - y = np.array([0.0, 1.0, 2.0]) - - # Create grid: f(x, y) = x + 2*y - X, Y = np.meshgrid(x, y, indexing="ij") - z_data = X + 2 * Y - - func = Function.from_grid(z_data, [x, y], inputs=["x", "y"], outputs="z") - - # Test exact points - assert func(0.0, 0.0) == 0.0 - assert func(1.0, 1.0) == 3.0 - assert func(2.0, 2.0) == 6.0 - - # Test interpolation - result = func(1.0, 0.5) - expected = 1.0 + 2 * 0.5 # = 2.0 - assert abs(result - expected) < 0.01 - - -def test_from_grid_3d_drag_coefficient(): - """Test from_grid with 3D drag coefficient data (Mach, Reynolds, Alpha).""" - # Create sample aerodynamic data - mach = np.array([0.0, 0.5, 1.0, 1.5, 2.0]) - reynolds = np.array([1e5, 5e5, 1e6]) - alpha = np.array([0.0, 2.0, 4.0, 6.0]) - - # Create a simple drag coefficient model - # Cd increases with Mach and alpha, slight dependency on Reynolds - M, Re, A = np.meshgrid(mach, reynolds, alpha, indexing="ij") - cd_data = 0.3 + 0.1 * M - 1e-7 * Re + 0.01 * A - - cd_func = Function.from_grid( - cd_data, - [mach, reynolds, alpha], - inputs=["Mach", "Reynolds", "Alpha"], - outputs="Cd", - ) - - # Test at grid points - assert abs(cd_func(0.0, 1e5, 0.0) - 0.29) < 0.01 # 0.3 - 1e-7*1e5 - assert abs(cd_func(1.0, 5e5, 0.0) - 0.35) < 0.01 # 0.3 + 0.1*1 - 1e-7*5e5 - - # Test interpolation between points - result = cd_func(0.5, 3e5, 1.0) - # Expected roughly: 0.3 + 0.1*0.5 - 1e-7*3e5 + 0.01*1.0 = 0.32 - assert 0.31 < result < 0.34 - - -def test_from_grid_extrapolation_constant(): - """Test that constant extrapolation clamps to edge values.""" - x = np.array([0.0, 1.0, 2.0]) - y = np.array([0.0, 1.0, 4.0]) # y = x^2 - - func = Function.from_grid( - y, [x], inputs=["x"], outputs="y", extrapolation="constant" - ) - - # Test below lower bound - should return value at x=0 - assert func(-1.0) == 0.0 - - # Test above upper bound - should return value at x=2 - assert func(3.0) == 4.0 - - -def test_from_grid_validation_errors(): - """Test that from_grid raises appropriate errors for invalid inputs.""" - x = np.array([0.0, 1.0, 2.0]) - y = np.array([0.0, 1.0, 2.0]) - - # Mismatched dimensions - X, Y = np.meshgrid(x, y, indexing="ij") - z_data = X + Y - - # Wrong number of axes - with pytest.raises(ValueError, match="Number of axes"): - Function.from_grid(z_data, [x], inputs=["x"], outputs="z") - - # Wrong axis length - with pytest.raises(ValueError, match="Axis 1 has"): - Function.from_grid( - z_data, [x, np.array([0.0, 1.0])], inputs=["x", "y"], outputs="z" - ) - - # Wrong number of inputs - with pytest.raises(ValueError, match="Number of inputs"): - Function.from_grid(z_data, [x, y], inputs=["x"], outputs="z") - - -def test_from_grid_default_inputs(): - """Test that from_grid uses default input names when not provided.""" - x = np.array([0.0, 1.0, 2.0]) - y = np.array([0.0, 1.0, 2.0]) - - X, Y = np.meshgrid(x, y, indexing="ij") - z_data = X + Y - - func = Function.from_grid(z_data, [x, y]) - - # Should use default names - assert "x0" in func.__inputs__ - assert "x1" in func.__inputs__ - - -def test_from_grid_backward_compatibility(): - """Test that regular Function creation still works after adding from_grid.""" - # Test 1D function from list - func1 = Function([[0, 0], [1, 1], [2, 4], [3, 9]]) - assert func1(1.5) > 0 # Should interpolate - - # Test 2D function from array - data = np.array([[0, 0, 0], [1, 0, 1], [0, 1, 2], [1, 1, 3]]) - func2 = Function(data) - assert func2(0.5, 0.5) > 0 # Should interpolate - - # Test callable function - func3 = Function(lambda x: x**2) - assert func3(2) == 4 - - -def test_regular_grid_without_grid_interpolator_warns(): - """Test that setting `regular_grid` without a grid interpolator warns. - - This test constructs a Function from scattered points (no structured - grid). If `regular_grid` interpolation is later selected without a - grid interpolator being configured, the implementation currently - falls back to shepard interpolation and should emit a warning. The - test ensures a warning is raised in this scenario. - """ - # Create a 2D function with scattered points (not structured grid) - source = [(0, 0, 0), (1, 0, 1), (0, 1, 2), (1, 1, 3)] - func = Function( - source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" - ) - - # Now manually change interpolation to regular_grid without setting up the grid - # This simulates the fallback scenario - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - func.set_interpolation("regular_grid") - - # Check that a warning was issued - assert len(w) == 1 - assert "falling back to shepard interpolation" in str(w[0].message) - - -def test_shepard_fallback_2d_interpolation(): - """Test that shepard_fallback produces correct interpolation for 2D data. - - This test verifies the fallback interpolation works correctly when - regular_grid is set without a grid interpolator. - """ - # Create a 2D function: z = x + y - source = [ - (0, 0, 0), # f(0, 0) = 0 - (1, 0, 1), # f(1, 0) = 1 - (0, 1, 1), # f(0, 1) = 1 - (1, 1, 2), # f(1, 1) = 2 - ] - - # First, create with shepard to get baseline results - func_shepard = Function( - source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" - ) - - # Create another function and trigger the fallback - func_fallback = Function( - source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" - ) - - # Trigger fallback - with warnings.catch_warnings(): - warnings.simplefilter("ignore") # Suppress warnings for this test - func_fallback.set_interpolation("regular_grid") - - # Test that both produce the same results at exact points - assert func_fallback(0, 0) == func_shepard(0, 0) - assert func_fallback(1, 1) == func_shepard(1, 1) - - # Test interpolation at an intermediate point - result_fallback = func_fallback(0.5, 0.5) - result_shepard = func_shepard(0.5, 0.5) - assert np.isclose(result_fallback, result_shepard, atol=1e-6) - - -def test_shepard_fallback_3d_interpolation(): - """Test that shepard_fallback produces correct interpolation for 3D data. - - This test verifies the fallback interpolation works correctly for - 3-dimensional input data. - """ - # Create a 3D function: w = x + y + z - source = [ - (0, 0, 0, 0), # f(0, 0, 0) = 0 - (1, 0, 0, 1), # f(1, 0, 0) = 1 - (0, 1, 0, 1), # f(0, 1, 0) = 1 - (0, 0, 1, 1), # f(0, 0, 1) = 1 - (1, 1, 1, 3), # f(1, 1, 1) = 3 - ] - - # Create with shepard to get baseline results - func_shepard = Function( - source=source, - inputs=["x", "y", "z"], - outputs="w", - interpolation="shepard", - ) - - # Create another function and trigger the fallback - func_fallback = Function( - source=source, - inputs=["x", "y", "z"], - outputs="w", - interpolation="shepard", - ) - - # Trigger fallback - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - func_fallback.set_interpolation("regular_grid") - - # Test that both produce the same results at exact points - assert func_fallback(0, 0, 0) == func_shepard(0, 0, 0) - assert func_fallback(1, 1, 1) == func_shepard(1, 1, 1) - - # Test interpolation at an intermediate point - result_fallback = func_fallback(0.5, 0.5, 0.5) - result_shepard = func_shepard(0.5, 0.5, 0.5) - assert np.isclose(result_fallback, result_shepard, atol=1e-6) - - -def test_shepard_fallback_at_exact_data_points(): - """Test that shepard_fallback returns exact values at data points. - - When querying at exact data points, the fallback should return the - exact value stored at that point. - """ - # Create a 2D function - source = [ - (0, 0, 10), - (1, 0, 20), - (0, 1, 30), - (1, 1, 40), - ] - - func = Function( - source=source, inputs=["x", "y"], outputs="z", interpolation="shepard" - ) - - # Trigger fallback - with warnings.catch_warnings(): - warnings.simplefilter("ignore") - func.set_interpolation("regular_grid") - - # Test exact data points - should return exact values - assert func(0, 0) == 10 - assert func(1, 0) == 20 - assert func(0, 1) == 30 - assert func(1, 1) == 40 - - -def test_from_grid_unsorted_axis_warns(): - """Test that from_grid warns when axes are not sorted in ascending order.""" - y_data = np.array([0.0, 1.0, 4.0]) - - # Test with unsorted axis (descending order) - unsorted_axis = np.array([2.0, 1.0, 0.0]) - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - Function.from_grid(y_data, [unsorted_axis], inputs=["x"], outputs="y") - - # Check that a warning was issued - assert len(w) == 1 - assert "not strictly sorted in ascending order" in str(w[0].message) - - -def test_from_grid_repeated_values_warns(): - """Test that from_grid warns when axes have repeated values. - - Note: RegularGridInterpolator requires strictly ascending or descending - axes. Repeated values will cause scipy to raise a ValueError after our - warning is issued. - """ - y_data = np.array([0.0, 1.0, 4.0]) - - # Test with repeated values (not strictly ascending) - repeated_axis = np.array([0.0, 1.0, 1.0]) - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - # Scipy will raise ValueError after our warning, so we expect both - try: - Function.from_grid(y_data, [repeated_axis], inputs=["x"], outputs="y") - except ValueError as e: - # scipy raises this error for non-strictly-sorted axes - assert "strictly ascending" in str(e).lower() or "dimension 0" in str(e) - - # Check that a warning was issued before the error - assert len(w) == 1 - assert "not strictly sorted in ascending order" in str(w[0].message) - - -def test_from_grid_flatten_for_compatibility_false(): - """Test that flatten_for_compatibility=False skips flattening.""" - x = np.array([0.0, 1.0, 2.0]) - y = np.array([0.0, 1.0]) - - X, Y = np.meshgrid(x, y, indexing="ij") - z_data = X + Y - - func = Function.from_grid( - z_data, - [x, y], - inputs=["x", "y"], - outputs="z", - flatten_for_compatibility=False, - ) - - # Check that flattened attributes are None - assert func._domain is None - assert func._image is None - assert func.source is None - assert func.y_array is None - - # But the function should still work correctly - assert func(0.0, 0.0) == 0.0 - assert func(1.0, 1.0) == 2.0 - assert func(2.0, 1.0) == 3.0 diff --git a/tests/unit/rocket/aero_surface/test_generic_surfaces.py b/tests/unit/rocket/aero_surface/test_generic_surfaces.py index a04ed429d..267b64d18 100644 --- a/tests/unit/rocket/aero_surface/test_generic_surfaces.py +++ b/tests/unit/rocket/aero_surface/test_generic_surfaces.py @@ -74,6 +74,32 @@ def test_valid_initialization_from_csv(filename_valid_coeff): ) +def test_csv_independent_variables_accept_any_order(tmp_path): + """Checks if GenericSurface correctly maps CSV columns by header names, + regardless of independent variable column order.""" + filename = tmp_path / "valid_coefficients_shuffled_order.csv" + filename.write_text( + "mach,alpha,cL\n0,0,0\n0,1,10\n2,0,2\n2,1,12\n", + encoding="utf-8", + ) + + generic_surface = GenericSurface( + reference_area=REFERENCE_AREA, + reference_length=REFERENCE_LENGTH, + coefficients={"cL": str(filename)}, + ) + + closure = generic_surface.cL.source.__closure__ + csv_function = next( + cell.cell_contents + for cell in closure + if isinstance(cell.cell_contents, Function) + ) + + assert generic_surface.cL(1, 0, 2, 0, 0, 0, 0) == pytest.approx(12) + assert csv_function.get_interpolation_method() == "regular_grid" + + def test_compute_forces_and_moments(): """Checks if there are not logical errors in compute forces and moments""" diff --git a/tests/unit/rocket/test_rocket.py b/tests/unit/rocket/test_rocket.py index 5c3fed359..3c7725fa5 100644 --- a/tests/unit/rocket/test_rocket.py +++ b/tests/unit/rocket/test_rocket.py @@ -1,4 +1,5 @@ import warnings +from itertools import product from unittest.mock import patch import numpy as np @@ -689,3 +690,148 @@ def test_coordinate_system_orientation( static_margin_nose_to_tail = rocket_nose_to_tail.static_margin assert np.array_equal(static_margin_tail_to_nose, static_margin_nose_to_tail) + + +def test_drag_csv_header_order_independent_for_multivariable_input(tmp_path): + """Ensure drag CSV independent-variable columns are interpreted by name. + + This test checks that swapping the order of header-defined variables + (mach and reynolds) yields equivalent drag interpolation results. + """ + + ordered_csv = tmp_path / "drag_mach_reynolds.csv" + ordered_csv.write_text( + "mach,reynolds,cd\n0.5,0.1,0.6\n1.0,0.1,1.1\n0.5,0.2,0.7\n1.0,0.2,1.2\n", + encoding="utf-8", + ) + + swapped_csv = tmp_path / "drag_reynolds_mach.csv" + swapped_csv.write_text( + "reynolds,mach,cd\n0.1,0.5,0.6\n0.1,1.0,1.1\n0.2,0.5,0.7\n0.2,1.0,1.2\n", + encoding="utf-8", + ) + + rocket_ordered = Rocket( + radius=0.05, + mass=1.0, + inertia=(1.0, 1.0, 1.0), + power_off_drag=str(ordered_csv), + power_on_drag=str(ordered_csv), + center_of_mass_without_motor=0.0, + ) + + rocket_swapped = Rocket( + radius=0.05, + mass=1.0, + inertia=(1.0, 1.0, 1.0), + power_off_drag=str(swapped_csv), + power_on_drag=str(swapped_csv), + center_of_mass_without_motor=0.0, + ) + + drag_ordered = rocket_ordered.power_off_drag_7d(0, 0, 0.8, 0.15, 0, 0, 0) + drag_swapped = rocket_swapped.power_off_drag_7d(0, 0, 0.8, 0.15, 0, 0, 0) + + ordered_closure = rocket_ordered.power_off_drag_7d.source.__closure__ + swapped_closure = rocket_swapped.power_off_drag_7d.source.__closure__ + ordered_csv_function = next( + cell.cell_contents + for cell in ordered_closure + if isinstance(cell.cell_contents, Function) + ) + swapped_csv_function = next( + cell.cell_contents + for cell in swapped_closure + if isinstance(cell.cell_contents, Function) + ) + + assert drag_ordered == pytest.approx(0.95) + assert drag_swapped == pytest.approx(0.95) + assert drag_swapped == pytest.approx(drag_ordered) + assert ordered_csv_function.get_interpolation_method() == "regular_grid" + assert swapped_csv_function.get_interpolation_method() == "regular_grid" + + +def test_drag_input_types_supported_for_power_on_and_power_off(tmp_path): + """Ensure drag input processing accepts all supported input types. + + This test validates that both ``power_off_drag`` and ``power_on_drag`` + accept and correctly evaluate all supported input categories. + """ + query = (1.0, 0.0, 0.8, 0.15, 0.0, 0.0, 0.0) + + csv_drag = tmp_path / "drag_mach_reynolds.csv" + csv_drag.write_text( + "mach,reynolds,cd\n0.5,0.1,0.6\n1.0,0.1,1.1\n0.5,0.2,0.7\n1.0,0.2,1.2\n", + encoding="utf-8", + ) + + txt_drag = tmp_path / "drag_curve.txt" + txt_drag.write_text("0.0,0.2\n1.0,0.4\n", encoding="utf-8") + + function_1d = Function( + lambda mach: 0.2 + mach, + inputs=["mach"], + outputs=["cd"], + interpolation="linear", + ) + function_7d = Function( + lambda alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate: ( + mach + reynolds + ), + inputs=[ + "alpha", + "beta", + "mach", + "reynolds", + "pitch_rate", + "yaw_rate", + "roll_rate", + ], + outputs=["cd"], + interpolation="linear", + ) + + drag_7d_table = [ + (*coords, float(sum(coords))) for coords in product((0.0, 1.0), repeat=7) + ] + drag_7d_query = (1.0, 0.0, 1.0, 0.0, 1.0, 0.0, 1.0) + + test_cases = [ + ("int", 1, query, 1.0), + ("float", 0.37, query, 0.37), + ("csv_path", str(csv_drag), query, 0.95), + ("txt_path", str(txt_drag), query, 0.36), + ("function_1d", function_1d, query, 1.0), + ("function_7d", function_7d, query, 0.95), + ( + "callable_1d", + lambda mach: 0.2 + mach, + query, + 1.0, + ), + ( + "callable_7d", + lambda alpha, beta, mach, reynolds, pitch_rate, yaw_rate, roll_rate: ( + mach + reynolds + ), + query, + 0.95, + ), + ("list_pairs", [[0.0, 0.2], [1.0, 0.4]], query, 0.36), + ("tuple_pairs", ((0.0, 0.2), (1.0, 0.4)), query, 0.36), + ("list_7d_entries", drag_7d_table, drag_7d_query, 4.0), + ] + + for _, drag_input, query_point, expected in test_cases: + rocket = Rocket( + radius=0.05, + mass=1.0, + inertia=(1.0, 1.0, 1.0), + power_off_drag=drag_input, + power_on_drag=drag_input, + center_of_mass_without_motor=0.0, + ) + + assert rocket.power_off_drag_7d(*query_point) == pytest.approx(expected) + assert rocket.power_on_drag_7d(*query_point) == pytest.approx(expected)