Contributing

We'd love you to contribute to the TNP Statistic Library!

Issues

Questions, feature requests and bug reports are all welcome as discussions or issues.

To make it as simple as possible for us to help you, please include the output of the following call in your issue:

python -c "import tnp_statistic_library.version; print(tnp_statistic_library.version.version_info())"

Please try to always include the above unless you're unable to install the TNP Statistic Library or know it's not relevant to your question or feature request.

Adding New Metrics

This guide walks you through the complete process of adding a new statistical metric to the TNP Statistic Library. The library follows a consistent pattern that makes it straightforward to add new metrics while maintaining code quality and consistency.

Overview of the Metric Architecture

The library uses a layered architecture:

  1. Internal Implementation (tnp_statistic_library/_internal/metrics/): Core metric classes and configuration
  2. Public API (tnp_statistic_library/metrics/): User-friendly helper functions
  3. Tests (tests/metrics/): Comprehensive test coverage
  4. Documentation (docs/): API documentation and examples
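
For orientation, a new metric typically touches one file in each of these layers. The file names below are illustrative placeholders; place your code alongside the existing modules in each directory:

tnp_statistic_library/_internal/metrics/my_module.py   # configuration classes and the metric class
tnp_statistic_library/metrics/my_module.py             # public my_metric() helper function
tests/metrics/test_my_module.py                        # TestMyMetric test coverage
docs/api/my_module.md                                  # API documentation page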

Step 1: Design Your Metric

Before coding, define:

  • Metric name: Use descriptive, lowercase names with underscores (e.g., my_custom_metric)
  • Data formats supported: Decide if your metric supports:
    • record_level: Individual observation data
    • summary_level: Pre-aggregated summary data
    • Both formats
  • Required inputs: What columns/parameters does your metric need?
  • Outputs: What statistics will your metric return?

Step 2: Create Helper Functions

Create reusable helper functions for your calculations. These should:

  • Be pure functions that operate on Polars expressions
  • Handle edge cases (zero division, null values, etc.)
  • Be placed at the top of your metric module
  • Follow the naming convention: calculate_{metric}_expressions for shared logic

Example helper function structure:

import polars as pl

def calculate_my_metric_expressions(input_expr1: pl.Expr, input_expr2: pl.Expr) -> dict[str, pl.Expr]:
    """Calculate shared expressions for my metric.

    Args:
        input_expr1: First input expression
        input_expr2: Second input expression

    Returns:
        Dictionary mapping output column names to Polars expressions
    """
    # Handle division by zero using Polars when() expression
    result_expr = pl.when(input_expr2 != 0).then(input_expr1 / input_expr2).otherwise(None)

    return {
        "input1": input_expr1,
        "input2": input_expr2,
        "result": result_expr,
    }
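
Continuing the hypothetical example above, the returned expressions can be evaluated directly against a small frame as a quick sanity check (column names and values are made up):

df = pl.DataFrame({"a": [1.0, 2.0], "b": [2.0, 0.0]})
exprs = calculate_my_metric_expressions(pl.col("a").sum(), pl.col("b").sum())

# One aggregated row: input1 = 3.0, input2 = 2.0, result = 1.5
print(df.select(**exprs))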

Step 3: Create Configuration Classes

For each data format you support, create a configuration class that inherits from ExpandedMetricConfig:

Record-Level Configuration

from typing import Annotated, Literal

import polars as pl
from pydantic import Field

from tnp_statistic_library._internal.common import (
    InputCol,
    IsNumeric,
    PositiveNumber,
)
from tnp_statistic_library._internal.metrics._base import (
    BaseMetric,
    ExpandedMetricConfig,
)

class RecordLevelMyMetricConfig(ExpandedMetricConfig):
    """Configuration for record-level my metric calculation."""

    data_format: Literal["record_level"] = Field("record_level", frozen=True)
    input_col1: Annotated[InputCol, IsNumeric()]  # Add appropriate validators
    input_col2: Annotated[InputCol, IsNumeric()]

    def get_aggregation_expressions(self) -> dict[str, pl.Expr]:
        """Return aggregation expressions for record-level data."""
        # Use your helper function or define calculations here
        return calculate_my_metric_expressions(
            pl.col(self.input_col1),
            pl.col(self.input_col2)
        )

    def compute(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Compute the metric for record-level data."""
        return self._return_grouped(lf, self.segment).agg(**self.get_aggregation_expressions())

Summary-Level Configuration

class SummaryLevelMyMetricConfig(ExpandedMetricConfig):
    """Configuration for summary-level my metric calculation."""

    data_format: Literal["summary_level"] = Field("summary_level", frozen=True)
    # Field names must match the keyword arguments forwarded by the public helper in Step 5
    input_col1: Annotated[InputCol, PositiveNumber()]
    input_col2: Annotated[InputCol, IsNumeric()]

    def get_aggregation_expressions(self) -> dict[str, pl.Expr]:
        """Return aggregation expressions for summary-level data."""
        return calculate_my_metric_expressions(
            pl.col(self.input_col1),
            pl.col(self.input_col2)
        )

    def compute(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Compute the metric for summary-level data."""
        return self._return_grouped(lf, self.segment).agg(**self.get_aggregation_expressions())

Step 4: Create the Configuration Union and Metric Class

Create the discriminated union and main metric class:

# Configuration union for discriminated access
MyMetricConfig = Annotated[
    RecordLevelMyMetricConfig | SummaryLevelMyMetricConfig,
    Field(discriminator="data_format"),
]

# Main metric class
class MyMetric(BaseMetric[MyMetricConfig]):
    """A metric class that computes my custom statistic.

    Provide a comprehensive docstring explaining:
    - What the metric calculates
    - Use cases and applications
    - Interpretation of results
    - References to academic literature if applicable

    Attributes:
        config: Configuration specifying data format and column names
        output_columns: List of output column names produced by this metric
        metric_type: Always set to "my_metric"
    """

    config: MyMetricConfig

    output_columns: list[str] = Field(
        default_factory=lambda: ["input1", "input2", "result"]  # Your actual outputs
    )
    metric_type: Literal["my_metric"] = Field("my_metric", frozen=True)

Step 5: Create Public API Functions

Create user-friendly helper functions in the appropriate module (tnp_statistic_library/metrics/):

  1. Choose the appropriate module based on metric type:
    • accuracy.py: Model accuracy metrics (e.g., default accuracy, EAD accuracy, calibration tests)
    • discrimination.py: Model discrimination metrics (e.g., AUC, Gini, KS test)
    • normality.py: Normality tests (e.g., Shapiro-Wilk)
    • stability.py: Data stability metrics (e.g., Population Stability Index)
    • summary.py: General summary statistics (e.g., mean, median)
    • Or create a new module if your metric doesn't fit existing categories

  2. Create the helper function:

import polars as pl
from typing import Literal, Any

def my_metric(
    name: str,
    dataset: pl.DataFrame | pl.LazyFrame,
    data_format: Literal["record_level", "summary_level"],
    input_col1: str,
    input_col2: str,
    segment: list[str] | None = None,
    **additional_kwargs: Any,
) -> pl.DataFrame:
    """Calculate my custom metric.

    Provide a clear description of what this metric does and how to interpret results.

    Args:
        name: Name identifier for this metric instance
        dataset: Input data as DataFrame or LazyFrame
        data_format: Either "record_level" or "summary_level"
        input_col1: Name of the first input column
        input_col2: Name of the second input column
        segment: Optional list of columns for grouping/segmentation
        **additional_kwargs: Additional metric-specific parameters

    Returns:
        DataFrame with computed metric results

    Examples:
        >>> import polars as pl
        >>> from tnp_statistic_library.metrics.summary import my_metric
        >>> 
        >>> data = pl.DataFrame({
        ...     "col1": [1, 2, 3, 4, 5],
        ...     "col2": [2, 4, 6, 8, 10]
        ... })
        >>> 
        >>> result = my_metric(
        ...     name="test_metric",
        ...     dataset=data,
        ...     data_format="record_level",
        ...     input_col1="col1",
        ...     input_col2="col2"
        ... )
        >>> print(result)
    """
    return MyMetric.build(
        dataset=dataset,
        name=name,
        data_format=data_format,
        input_col1=input_col1,
        input_col2=input_col2,
        segment=segment,
        **additional_kwargs,
    ).run_metric().collect()
  3. Update the module's __all__ list to include your new function
  4. Update the package __init__.py to import and expose your function
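
For example (module names are illustrative; mirror the pattern the existing helpers already follow in the files you edit):

# In tnp_statistic_library/metrics/summary.py
__all__ = ["mean", "my_metric"]  # existing entries plus your new function

# In tnp_statistic_library/metrics/__init__.py
from tnp_statistic_library.metrics.summary import my_metric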

Step 6: Write Comprehensive Tests

Create tests in the appropriate test file (e.g., tests/metrics/test_summary.py):

import polars as pl
import pytest
from polars.testing import assert_frame_equal
from tnp_statistic_library.metrics.summary import my_metric
from tnp_statistic_library._internal.exceptions import DataValidationError

class TestMyMetric:
    """Test class for MyMetric with comprehensive coverage."""

    def test_my_metric_record_level_without_segments(self):
        """Test with record-level data without segments."""
        result = my_metric(
            name="test_my_metric",
            dataset=pl.DataFrame({
                "col1": [1, 2, 3, 4, 5],
                "col2": [2, 4, 6, 8, 10],
            }),
            data_format="record_level",
            input_col1="col1",
            input_col2="col2",
        )

        assert result.shape == (1, 4)  # group_key + your outputs
        assert result["result"][0] == 0.5  # Expected result

    def test_my_metric_with_segments(self):
        """Test with segmented data."""
        result = my_metric(
            name="test_segmented",
            dataset=pl.DataFrame({
                "col1": [1, 2, 3, 4],
                "col2": [2, 4, 6, 8],
                "segment": ["A", "A", "B", "B"],
            }),
            data_format="record_level",
            input_col1="col1",
            input_col2="col2",
            segment=["segment"],
        )

        assert result.shape == (2, 4)
        assert set(result["segment"].to_list()) == {"A", "B"}

    def test_my_metric_edge_cases(self):
        """Test edge cases like zero division, null values, etc."""
        # Test with zero values
        result = my_metric(
            name="test_zero_division",
            dataset=pl.DataFrame({
                "col1": [0, 1, 2],
                "col2": [0, 0, 1],
            }),
            data_format="record_level",
            input_col1="col1",
            input_col2="col2",
        )

        # Should handle division by zero gracefully
        assert result["result"].null_count() > 0

    def test_my_metric_summary_level(self):
        """Test with summary-level data."""
        result = my_metric(
            name="test_summary",
            dataset=pl.DataFrame({
                "sum_col1": [10, 20],
                "mean_col2": [2.0, 4.0],
            }),
            data_format="summary_level",
            input_col1="sum_col1",
            input_col2="mean_col2",
        )

        assert result.shape == (1, 4)

    def test_my_metric_validation_errors(self):
        """Test that appropriate validation errors are raised."""
        with pytest.raises(DataValidationError):
            my_metric(
                name="test_validation",
                dataset=pl.DataFrame({
                    "col1": ["a", "b", "c"],  # Non-numeric data
                    "col2": [1, 2, 3],
                }),
                data_format="record_level",
                input_col1="col1",
                input_col2="col2",
            )

Step 7: Add Documentation

Create documentation in the appropriate API file:

  1. API Documentation: Add to the relevant file in docs/api/ (e.g., docs/api/accuracy.md)
  2. Workflow Examples: Add practical examples in docs/workflows/
  3. Update the index: Ensure your new metric is listed in the appropriate index

Step 8: Validate Your Implementation

Run these commands to ensure your metric is properly implemented:

# Run tests for your specific metric
uv run pytest tests/metrics/test_{your_module}.py::TestYourMetric -v

# Run all tests to ensure no regressions
uv run pytest

# Run type checking (if available)
uv run mypy tnp_statistic_library/metrics/{your_module}.py

# Run linting
uv run ruff check tnp_statistic_library/metrics/{your_module}.py

# Test documentation builds
uv run mkdocs serve

You can also use the just command runner to run these commands easily:

just all

Common Patterns and Best Practices

1. Handling Edge Cases

Always handle these scenarios:

  • Division by zero using pl.when(denominator != 0).then(numerator / denominator).otherwise(None)
  • Null/missing values
  • Empty datasets
  • Insufficient sample sizes
  • Invalid input ranges
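
A minimal sketch covering the first two cases (the safe_ratio helper name is hypothetical):

import polars as pl

def safe_ratio(numerator: pl.Expr, denominator: pl.Expr) -> pl.Expr:
    """Return numerator / denominator, or null when the denominator is null or zero."""
    return (
        pl.when(denominator.is_not_null() & (denominator != 0))
        .then(numerator / denominator)
        .otherwise(None)
    )

df = pl.DataFrame({"num": [1.0, 2.0, None], "den": [2.0, 0.0, 4.0]})
print(df.select(safe_ratio(pl.col("num"), pl.col("den")).alias("ratio")))
# ratio column: 0.5, null, null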

2. Column Validation

Use appropriate validators for input columns:

  • IsNumeric(): Ensures column contains numeric data
  • PositiveNumber(): Ensures positive numeric values
  • InclusiveRange(min, max): Ensures values fall within range
  • IsIndicator(): Ensures binary indicator values (0/1)
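
As an illustration, a hypothetical record-level configuration might combine these validators as follows (the class and column names are invented; the compute methods from Step 3 are omitted for brevity):

from typing import Annotated, Literal

from pydantic import Field

from tnp_statistic_library._internal.common import InclusiveRange, InputCol, IsIndicator
from tnp_statistic_library._internal.metrics._base import ExpandedMetricConfig

class RecordLevelDefaultRateConfig(ExpandedMetricConfig):
    data_format: Literal["record_level"] = Field("record_level", frozen=True)
    default_flag: Annotated[InputCol, IsIndicator()]        # binary 0/1 indicator column
    pd_estimate: Annotated[InputCol, InclusiveRange(0, 1)]  # values must fall within [0, 1]
    # get_aggregation_expressions() and compute() omitted; see Step 3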

3. Segmentation Support

All metrics support segmentation by default. The segment parameter allows users to group results by specified columns.
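
For instance, using the mean helper shown at the end of this guide (the data here is illustrative):

import polars as pl
from tnp_statistic_library.metrics.summary import mean

data = pl.DataFrame({
    "value": [1.0, 2.0, 3.0, 4.0],
    "region": ["north", "north", "south", "south"],
})

# One output row per region; the segment columns are carried through to the result.
result = mean(name="mean_by_region", dataset=data, variable="value", segment=["region"])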

4. LazyFrame Support

All calculations use Polars LazyFrames for performance. Your compute method should return a LazyFrame that can be collected when needed.
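
A minimal illustration of the lazy pattern in plain Polars (independent of the library internals):

import polars as pl

lf = pl.DataFrame({"value": [1.0, 2.0, 3.0]}).lazy()
out = lf.select(pl.col("value").mean().alias("mean"))  # still a LazyFrame; nothing is computed yet
print(out.collect())                                   # work happens only at collection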

5. Naming Conventions

  • Metric names: Use lowercase with underscores (e.g., my_custom_metric)
  • Configuration classes: Use descriptive names ending with Config
  • Helper functions: Use descriptive names with calculate_ prefix
  • Test classes: Use Test{MetricName} format

Example: Real Metric Implementation

Here's how the existing mean metric is implemented as a reference:

# From tnp_statistic_library/_internal/metrics/summary.py

class MeanConfig(ExpandedMetricConfig):
    """Configuration for computing the mean statistic.

    This configuration specifies the column for which to compute the mean value,
    optionally grouped by segment columns.

    Attributes:
        variable: The name of the column for which to compute the mean.
            Must be a valid input column.
    """

    variable: Annotated[InputCol, IsNumeric()]

    def get_aggregation_expressions(self) -> dict[str, pl.Expr]:
        """Get aggregation expressions for mean calculation.

        Returns:
            A dictionary mapping output column names to Polars expressions
            for computing the mean statistic.
        """
        variable = self.variable
        lit_variable = pl.lit(variable)
        mean = pl.col(variable).mean()

        return {
            "variable_col": lit_variable,
            "mean": mean,
        }

    def compute(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Compute the mean metric.

        Args:
            lf: The input LazyFrame containing the data to analyze.

        Returns:
            A LazyFrame containing the computed metric results.
        """
        return self._return_grouped(lf, self.segment).agg(**self.get_aggregation_expressions())

class MeanSummary(BaseMetric[MeanConfig]):
    """A metric that calculates the mean of the specified values in a grouped dataset.

    This metric computes the arithmetic mean of a specified column, optionally
    grouped by segment columns. It handles missing values according to Polars
    default behavior.

    Attributes:
        config: Configuration for the mean calculation, including the column to aggregate.
        output_columns: List of output column names: ["variable_col", "mean"].
        metric_type: Always set to "mean".
    """

    config: MeanConfig
    output_columns: list[str] = Field(default_factory=lambda: ["variable_col", "mean"])
    metric_type: Literal["mean"] = Field("mean", frozen=True)

# From tnp_statistic_library/metrics/summary.py

def mean(
    name: str,
    dataset: pl.LazyFrame | pl.DataFrame,
    variable: str,
    segment: list[str] | None = None,
) -> pl.DataFrame:
    """Calculate the mean summary for the given dataset and parameters.

    Args:
        name: Name of the metric.
        dataset: Dataset to compute the mean on.
        variable: Column name for which to compute the mean.
        segment: Segmentation groups for calculation.

    Returns:
        DataFrame containing the mean summary and associated metadata.
    """
    kw = {k: v for k, v in locals().items() if v is not None}
    return MeanSummary.build(**kw).run_metric().collect()

This comprehensive guide should help you successfully add new metrics while maintaining consistency with the existing codebase. If you have questions or need clarification on any step, please open an issue on GitHub (recommended) or post in the Teams Channel.