Source code for fitting.custom_function_evaluator

"""
Custom Function Evaluator Module.

This module provides safe runtime evaluation of custom mathematical functions
for curve fitting.
"""

# Standard library
import ast
import re
from typing import Any, Callable, List, Optional, Tuple, Union

# Third-party packages
import numpy as np
import pandas as pd
from numpy.typing import NDArray

# Local imports
from config import MATH_FUNCTION_REPLACEMENTS_COMPILED
from i18n import t
from utils import (
    EquationError,
    ValidationError,
    get_logger,
    validate_parameter_names,
)

logger = get_logger(__name__)

# Allowlists consulted when a user formula is validated. Anything that is not
# listed here is rejected before the expression is compiled.

# AST node classes that may appear in a formula.
_ALLOWED_AST_NODES = {
    ast.Expression,
    ast.BinOp, ast.UnaryOp,
    ast.Call, ast.Attribute,
    ast.Name, ast.Load,
    ast.Constant,
}

# Binary operators: +, -, *, /, ** and %.
_ALLOWED_BIN_OPS = {ast.Add, ast.Sub, ast.Mult, ast.Div, ast.Pow, ast.Mod}

# Unary operators: +x and -x.
_ALLOWED_UNARY_OPS = {ast.UAdd, ast.USub}

# Attribute names that may be accessed on ``np``.
_ALLOWED_NUMPY_MEMBERS = {
    # Trigonometric and inverse trigonometric
    "sin", "cos", "tan", "arcsin", "arccos", "arctan",
    # Hyperbolic and inverse hyperbolic
    "sinh", "cosh", "tanh", "arcsinh", "arccosh", "arctanh",
    # Exponentials, roots and powers
    "exp", "sqrt", "cbrt", "power",
    # Logarithms
    "log", "log10", "log2",
    # Magnitude and rounding
    "abs", "floor", "ceil", "round",
    # Reductions
    "max", "min", "mean", "sum",
    # Constants
    "pi", "e",
}


class CustomFunctionEvaluator:
    """Evaluate user-supplied mathematical expressions safely at runtime.

    The expression is rewritten with the configured replacement patterns,
    parsed to an AST, checked node by node against the module-level
    allowlists, compiled once, and wrapped in a callable ``f(x, *params)``.

    Attributes set by ``__init__``:
        original_equation_str: The expression exactly as supplied.
        equation_str: The expression after the replacement patterns ran.
        parameter_names: Names of the fit parameters, in positional order.
        num_independent_vars: Number of independent variables.
        function: The generated callable (also returned by ``get_function``).
    """

    def __init__(
        self,
        equation_str: str,
        parameter_names: List[str],
        num_independent_vars: int = 1,
    ):
        """Validate the inputs and build the callable for ``equation_str``.

        Args:
            equation_str: Expression to evaluate. With one independent
                variable it may reference ``x``; with several it may
                reference ``x_0`` ... ``x_{n-1}``.
            parameter_names: Names the expression may use as parameters.
            num_independent_vars: Number of independent variables (>= 1).

        Raises:
            ValidationError: If the equation is empty or blank, or if
                ``num_independent_vars`` is below 1. Anything raised by
                ``validate_parameter_names`` also propagates unchanged.
            EquationError: If preparing, validating or compiling the
                expression fails for any reason.
        """
        logger.info("Creating custom function evaluator: '%s'", equation_str)
        logger.debug(
            "Parameters: %s, Independent vars: %s",
            parameter_names,
            num_independent_vars,
        )

        if not equation_str or not equation_str.strip():
            logger.error("Empty equation string provided")
            raise ValidationError(t("error.equation_empty"))

        if num_independent_vars < 1:
            logger.error(
                "Invalid number of independent variables: %s", num_independent_vars
            )
            raise ValidationError(
                t("error.invalid_independent_vars", num=num_independent_vars)
            )

        validate_parameter_names(parameter_names)

        self.original_equation_str = equation_str
        self.parameter_names = parameter_names
        self.num_independent_vars = num_independent_vars

        try:
            self.equation_str = self._prepare_formula(equation_str)
            self._compiled_code = self._compile_and_validate_formula(self.equation_str)
            self.function = self._create_function()
            logger.info("Custom function evaluator created successfully")
        except Exception as e:
            logger.error("Failed to create custom function: %s", str(e), exc_info=True)
            # Chain explicitly so the root cause stays attached to the wrapper.
            raise EquationError(
                t("error.equation_create_error", error=str(e))
            ) from e

    @staticmethod
    def _rejection(detail: str) -> Exception:
        """Build the EquationError used for every validation failure."""
        return EquationError(t("error.evaluation_error", error=detail))

    def _prepare_formula(self, equation_str: str) -> str:
        """Apply every configured replacement pattern to the expression.

        Args:
            equation_str: Expression as typed by the user.

        Returns:
            The expression after all substitutions from
            ``MATH_FUNCTION_REPLACEMENTS_COMPILED`` were applied in order.

        Raises:
            EquationError: If any substitution raises.
        """
        logger.debug("Preparing formula: %s", equation_str)
        try:
            prepared = equation_str
            for compiled_pattern, replacement in MATH_FUNCTION_REPLACEMENTS_COMPILED:
                prepared = compiled_pattern.sub(replacement, prepared)
            logger.debug("Formula prepared: %s", prepared)
            return prepared
        except Exception as e:
            logger.error("Error preparing formula: %s", str(e), exc_info=True)
            raise EquationError(
                t("error.equation_prepare_error", error=str(e))
            ) from e

    def _compile_and_validate_formula(self, expression: str) -> Any:
        """Check the expression's AST against the allowlists and compile it.

        Args:
            expression: Prepared expression (output of ``_prepare_formula``).

        Returns:
            A code object compiled in ``eval`` mode.

        Raises:
            EquationError: On a syntax error, or when the expression uses a
                node type, operator, identifier, attribute, call form or
                constant that is not explicitly allowed.
        """
        try:
            tree = ast.parse(expression, mode="eval")
        except SyntaxError as e:
            logger.error("Formula has syntax error: %s", str(e))
            raise EquationError(
                t("error.equation_syntax_error", error=str(e))
            ) from e

        # Identifiers the expression may reference: the parameters, ``np``
        # and the independent variable name(s).
        allowed_names = set(self.parameter_names)
        allowed_names.add("np")
        if self.num_independent_vars == 1:
            allowed_names.add("x")
        else:
            allowed_names.update(f"x_{i}" for i in range(self.num_independent_vars))

        for node in ast.walk(tree):
            node_type = type(node)

            # Operator nodes (ast.Add, ast.USub, ...) are walked as well; they
            # are vetted through their parent BinOp / UnaryOp just below.
            if node_type not in _ALLOWED_AST_NODES and not isinstance(
                node, (ast.operator, ast.unaryop)
            ):
                raise self._rejection(f"Unsupported syntax: {node_type.__name__}")

            if isinstance(node, ast.BinOp) and type(node.op) not in _ALLOWED_BIN_OPS:
                raise self._rejection(
                    f"Unsupported operator: {type(node.op).__name__}"
                )

            if (
                isinstance(node, ast.UnaryOp)
                and type(node.op) not in _ALLOWED_UNARY_OPS
            ):
                raise self._rejection(
                    f"Unsupported unary operator: {type(node.op).__name__}"
                )

            if isinstance(node, ast.Name) and node.id not in allowed_names:
                raise self._rejection(f"Unknown identifier: {node.id}")

            if isinstance(node, ast.Attribute):
                # Only a direct ``np.<member>`` access is accepted; chained
                # attributes such as ``np.a.b`` fail the first test.
                if not isinstance(node.value, ast.Name) or node.value.id != "np":
                    raise self._rejection("Only numpy attributes are allowed")
                if node.attr not in _ALLOWED_NUMPY_MEMBERS:
                    raise self._rejection(
                        f"Unsupported numpy member: np.{node.attr}"
                    )

            if isinstance(node, ast.Call):
                if node.keywords:
                    raise self._rejection("Keyword arguments are not allowed")
                func = node.func
                if not isinstance(func, ast.Attribute):
                    raise self._rejection("Only numpy function calls are allowed")
                if not isinstance(func.value, ast.Name) or func.value.id != "np":
                    raise self._rejection("Only numpy function calls are allowed")
                if func.attr not in _ALLOWED_NUMPY_MEMBERS:
                    raise self._rejection(
                        f"Unsupported numpy function: np.{func.attr}"
                    )

            if isinstance(node, ast.Constant) and not isinstance(
                node.value, (int, float)
            ):
                raise self._rejection("Only numeric constants are allowed")

        return compile(tree, "<custom_formula>", "eval")

    def _create_function(self) -> Callable:
        """Wrap the compiled expression in a callable ``f(x, *params)``.

        Returns:
            A function that takes the independent data followed by one
            positional value per entry of ``parameter_names`` and returns
            the evaluated expression as a NumPy array.

            ``x`` may be any array-like. With a single independent variable
            it is exposed to the expression as ``x``; with several it must
            have at least two dimensions with ``num_independent_vars``
            columns, and column ``i`` is exposed as ``x_i``.

            The returned function raises ``EquationError`` when the number
            of parameters is wrong, when ``x`` has the wrong shape, or when
            evaluation fails.
        """
        num_indep = self.num_independent_vars

        def custom_func(x: NDArray, *params: float) -> NDArray:
            if len(params) != len(self.parameter_names):
                error_msg = f"Expected {len(self.parameter_names)} parameters, got {len(params)}"
                logger.error(error_msg)
                raise EquationError(error_msg)

            # Accept lists and other array-likes, not only ndarrays; ndarray
            # subclasses pass through untouched.
            x = np.asanyarray(x)
            param_values = dict(zip(self.parameter_names, params))

            namespace: dict[str, Any]
            if num_indep == 1:
                namespace = {"np": np, "x": x, **param_values}
            else:
                # Fewer than two dimensions means there is no column axis.
                if x.ndim < 2:
                    raise EquationError(
                        f"Expected 2D array for {num_indep} independent variables, got {x.ndim}D"
                    )
                if x.shape[1] != num_indep:
                    raise EquationError(
                        f"Expected {num_indep} columns in x array, got {x.shape[1]}"
                    )
                namespace = {
                    "np": np,
                    **{f"x_{i}": x[:, i] for i in range(num_indep)},
                    **param_values,
                }

            try:
                # SECURITY: this evaluates a user-supplied expression. It is
                # only acceptable because the code object passed the AST
                # allowlist in _compile_and_validate_formula and builtins are
                # emptied here. Do not relax either side without review.
                # NOTE(review): ``**`` on large integer constants is allowed
                # by the allowlist and is not bounded here.
                result = eval(self._compiled_code, {"__builtins__": {}}, namespace)
                return np.asarray(result)
            except ZeroDivisionError as e:
                logger.error("Division by zero in formula: %s", str(e))
                raise EquationError(
                    t("error.equation_division_by_zero", error=str(e))
                ) from e
            except OverflowError as e:
                logger.error("Overflow in formula evaluation: %s", str(e))
                raise EquationError(
                    t("error.equation_overflow_error", error=str(e))
                ) from e
            except Exception as e:
                logger.error("Error evaluating formula: %s", str(e), exc_info=True)
                raise EquationError(
                    t("error.evaluation_error", error=str(e))
                ) from e

        return custom_func

    def _generate_equation_template(self) -> str:
        """Build the display template for the fitted equation.

        Returns:
            ``"y="`` followed by the original expression in which every
            whole-word occurrence of a parameter name is wrapped in braces
            (``a`` becomes ``{a}``). Without parameters the expression is
            returned unchanged after the ``"y="`` prefix.
        """
        if not self.parameter_names:
            return "y=" + self.original_equation_str

        # Word boundaries keep a parameter such as ``a`` from matching inside
        # longer identifiers such as ``tan``.
        pattern = re.compile(
            "|".join(r"\b" + re.escape(p) + r"\b" for p in self.parameter_names)
        )
        template = pattern.sub(
            lambda m: "{" + m.group(0) + "}", self.original_equation_str
        )
        return "y=" + template

    def fit(
        self,
        data: Union[dict, pd.DataFrame],
        x_name: Union[str, List[str]],
        y_name: str,
        initial_guess_override: Optional[List[Optional[float]]] = None,
        bounds_override: Optional[
            Tuple[List[Optional[float]], List[Optional[float]]]
        ] = None,
    ) -> Tuple[str, NDArray, str, Optional[dict]]:
        """Perform curve fitting using this custom expression.

        Args:
            data: Data set handed to ``generic_fit`` unchanged.
            x_name: Name(s) of the independent variable(s) in ``data``.
            y_name: Name of the dependent variable in ``data``.
            initial_guess_override: Optional values merged over the default
                initial guess of ``1.0`` per parameter by
                ``merge_initial_guess``.
            bounds_override: Optional ``(lower, upper)`` pair handed to
                ``merge_bounds``; when ``None`` no bounds are passed on.

        Returns:
            Whatever ``generic_fit`` returns.

        Raises:
            Exception: Any failure is logged and re-raised unchanged.
        """
        logger.info("Performing custom fit: %s", self.original_equation_str)

        # Imported at call time, as in the original code.
        # NOTE(review): presumably this avoids a circular import - confirm.
        from fitting.fitting_utils import generic_fit, merge_bounds, merge_initial_guess

        try:
            equation_template = self._generate_equation_template()
            logger.debug("Equation template: %s", equation_template)

            computed_guess = [1.0] * len(self.parameter_names)
            initial_guess = merge_initial_guess(computed_guess, initial_guess_override)

            bounds = (
                merge_bounds(
                    None,
                    bounds_override[0],
                    bounds_override[1],
                    len(self.parameter_names),
                )
                if bounds_override is not None
                else None
            )

            result = generic_fit(
                data=data,
                x_name=x_name,
                y_name=y_name,
                fit_func=self.function,
                param_names=self.parameter_names,
                equation_template=equation_template,
                equation_formula=self.original_equation_str,
                initial_guess=initial_guess,
                bounds=bounds,
            )
            logger.info("Custom fit completed successfully")
            return result
        except Exception as e:
            logger.error("Custom fit failed: %s", str(e), exc_info=True)
            raise

    def get_function(self) -> Callable:
        """Return the generated callable ``f(x, *params)`` for direct use."""
        return self.function

    def __repr__(self) -> str:
        """Return a short description with the original formula and parameters."""
        return (
            f"CustomFunctionEvaluator(formula='{self.original_equation_str}', "
            f"params={self.parameter_names})"
        )