# Contributing
We'd love you to contribute to the TNP Statistic Library!
## Issues
Questions, feature requests and bug reports are all welcome as discussions or issues.
To make it as simple as possible for us to help you, please include the output of the following call in your issue:
```bash
python -c "import tnp_statistic_library.version; print(tnp_statistic_library.version.version_info())"
```
Please try to always include the above unless you're unable to install the TNP Statistic Library or know it's not relevant to your question or feature request.
## Adding New Metrics
This guide walks you through the complete process of adding a new statistical metric to the TNP Statistic Library. The library follows a consistent pattern that makes it straightforward to add new metrics while maintaining code quality and consistency.
### Overview of the Metric Architecture
The library uses a layered architecture:
- **Internal Implementation** (`tnp_statistic_library/_internal/metrics/`): Core metric classes and configuration
- **Public API** (`tnp_statistic_library/metrics/`): User-friendly helper functions
- **Tests** (`tests/metrics/`): Comprehensive test coverage
- **Documentation** (`docs/`): API documentation and examples
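For a metric added to, say, the summary module, one file per layer is typically touched (paths below are illustrative):

```text
tnp_statistic_library/_internal/metrics/summary.py   # config classes + metric class
tnp_statistic_library/metrics/summary.py             # public helper function
tests/metrics/test_summary.py                        # tests
docs/api/summary.md                                  # API documentation
```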
### Step 1: Design Your Metric
Before coding, define:
- **Metric name**: Use descriptive, lowercase names with underscores (e.g., `my_custom_metric`)
- **Data formats supported**: Decide if your metric supports:
    - `record_level`: Individual observation data
    - `summary_level`: Pre-aggregated summary data
    - Both formats
- **Required inputs**: What columns/parameters does your metric need?
- **Outputs**: What statistics will your metric return?
### Step 2: Create Helper Functions (Recommended)
Create reusable helper functions for your calculations. These should:
- Be pure functions that operate on Polars expressions
- Handle edge cases (zero division, null values, etc.)
- Be placed at the top of your metric module
- Follow the naming convention `calculate_{metric}_expressions` for shared logic
Example helper function structure:
```python
import polars as pl


def calculate_my_metric_expressions(input_expr1: pl.Expr, input_expr2: pl.Expr) -> dict[str, pl.Expr]:
    """Calculate shared expressions for my metric.

    Args:
        input_expr1: First input expression
        input_expr2: Second input expression

    Returns:
        Dictionary mapping output column names to Polars expressions
    """
    # Handle division by zero using Polars when() expression
    result_expr = pl.when(input_expr2 != 0).then(input_expr1 / input_expr2).otherwise(None)
    return {
        "input1": input_expr1,
        "input2": input_expr2,
        "result": result_expr,
    }
```
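Because the helper returns plain expressions, you can sanity-check it on a small frame before wiring it into a metric class (the data here is illustrative):

```python
import polars as pl

df = pl.DataFrame({"a": [1.0, 2.0], "b": [2.0, 0.0]})
exprs = calculate_my_metric_expressions(pl.col("a"), pl.col("b"))

# The named expressions plug straight into select() or agg();
# the row where b == 0 yields null instead of raising an error.
print(df.select(**exprs))
```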
### Step 3: Create Configuration Classes
For each data format you support, create a configuration class that inherits from `ExpandedMetricConfig`:
#### Record-Level Configuration
```python
from typing import Annotated, Literal

import polars as pl
from pydantic import Field

from tnp_statistic_library._internal.common import (
    InclusiveRange,
    InputCol,
    IsIndicator,
    IsNumeric,
    PositiveNumber,
)
from tnp_statistic_library._internal.metrics._base import (
    BaseMetric,
    ExpandedMetricConfig,
)


class RecordLevelMyMetricConfig(ExpandedMetricConfig):
    """Configuration for record-level my metric calculation."""

    data_format: Literal["record_level"] = Field("record_level", frozen=True)
    input_col1: Annotated[InputCol, IsNumeric()]  # Add appropriate validators
    input_col2: Annotated[InputCol, IsNumeric()]

    def get_aggregation_expressions(self) -> dict[str, pl.Expr]:
        """Return aggregation expressions for record-level data."""
        # Use your helper function or define calculations here
        return calculate_my_metric_expressions(
            pl.col(self.input_col1),
            pl.col(self.input_col2),
        )

    def compute(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Compute the metric for record-level data."""
        return self._return_grouped(lf, self.segment).agg(**self.get_aggregation_expressions())
```
#### Summary-Level Configuration
```python
class SummaryLevelMyMetricConfig(ExpandedMetricConfig):
    """Configuration for summary-level my metric calculation."""

    data_format: Literal["summary_level"] = Field("summary_level", frozen=True)
    sum_col1: Annotated[InputCol, PositiveNumber()]
    mean_col2: Annotated[InputCol, IsNumeric()]

    def get_aggregation_expressions(self) -> dict[str, pl.Expr]:
        """Return aggregation expressions for summary-level data."""
        return calculate_my_metric_expressions(
            pl.col(self.sum_col1),
            pl.col(self.mean_col2),
        )

    def compute(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Compute the metric for summary-level data."""
        return self._return_grouped(lf, self.segment).agg(**self.get_aggregation_expressions())
```
### Step 4: Create the Configuration Union and Metric Class
Create the discriminated union and main metric class:
```python
# Configuration union for discriminated access
MyMetricConfig = Annotated[
    RecordLevelMyMetricConfig | SummaryLevelMyMetricConfig,
    Field(discriminator="data_format"),
]


# Main metric class
class MyMetric(BaseMetric[MyMetricConfig]):
    """A metric class that computes my custom statistic.

    Provide a comprehensive docstring explaining:
    - What the metric calculates
    - Use cases and applications
    - Interpretation of results
    - References to academic literature if applicable

    Attributes:
        config: Configuration specifying data format and column names
        output_columns: List of output column names produced by this metric
        metric_type: Always set to "my_metric"
    """

    config: MyMetricConfig
    output_columns: list[str] = Field(
        default_factory=lambda: ["input1", "input2", "result"]  # Your actual outputs
    )
    metric_type: Literal["my_metric"] = Field("my_metric", frozen=True)
```
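A side benefit of the discriminated union is that Pydantic selects the concrete config class from the `data_format` value alone. A minimal sketch, assuming the base config requires no further fields (column names are illustrative):

```python
from pydantic import TypeAdapter

adapter = TypeAdapter(MyMetricConfig)
config = adapter.validate_python({
    "data_format": "summary_level",
    "sum_col1": "exposure_sum",  # illustrative column names
    "mean_col2": "pd_mean",
})
assert isinstance(config, SummaryLevelMyMetricConfig)
```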
### Step 5: Create Public API Functions
Create user-friendly helper functions in the appropriate module (`tnp_statistic_library/metrics/`):
- Choose the appropriate module based on metric type:
    - `accuracy.py`: Model accuracy metrics (e.g., default accuracy, EAD accuracy, calibration tests)
    - `discrimination.py`: Model discrimination metrics (e.g., AUC, Gini, KS test)
    - `normality.py`: Normality tests (e.g., Shapiro-Wilk)
    - `stability.py`: Data stability metrics (e.g., Population Stability Index)
    - `summary.py`: General summary statistics (e.g., mean, median)
- Or create a new module if your metric doesn't fit the existing categories
- Create the helper function:
```python
from typing import Any, Literal

import polars as pl

# Import your metric class from wherever you defined it in Step 3/4 (path is illustrative)
from tnp_statistic_library._internal.metrics.my_metric import MyMetric


def my_metric(
    name: str,
    dataset: pl.DataFrame | pl.LazyFrame,
    data_format: Literal["record_level", "summary_level"],
    input_col1: str,
    input_col2: str,
    segment: list[str] | None = None,
    **additional_kwargs: Any,
) -> pl.DataFrame:
    """Calculate my custom metric.

    Provide a clear description of what this metric does and how to interpret results.

    Args:
        name: Name identifier for this metric instance
        dataset: Input data as DataFrame or LazyFrame
        data_format: Either "record_level" or "summary_level"
        input_col1: Name of the first input column
        input_col2: Name of the second input column
        segment: Optional list of columns for grouping/segmentation
        **additional_kwargs: Additional metric-specific parameters

    Returns:
        DataFrame with computed metric results

    Examples:
        >>> import polars as pl
        >>> from tnp_statistic_library.metrics.summary import my_metric
        >>>
        >>> data = pl.DataFrame({
        ...     "col1": [1, 2, 3, 4, 5],
        ...     "col2": [2, 4, 6, 8, 10],
        ... })
        >>>
        >>> result = my_metric(
        ...     name="test_metric",
        ...     dataset=data,
        ...     data_format="record_level",
        ...     input_col1="col1",
        ...     input_col2="col2",
        ... )
        >>> print(result)
    """
    return MyMetric.build(
        dataset=dataset,
        name=name,
        data_format=data_format,
        input_col1=input_col1,
        input_col2=input_col2,
        segment=segment,
        **additional_kwargs,
    ).run_metric().collect()
```
- Update the module's `__all__` list to include your new function
- Update the package `__init__.py` to import and expose your function
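For example, assuming the function lives in `tnp_statistic_library/metrics/summary.py` (paths and existing names are illustrative):

```python
# tnp_statistic_library/metrics/summary.py
__all__ = ["mean", "my_metric"]  # add your new function

# tnp_statistic_library/metrics/__init__.py
from tnp_statistic_library.metrics.summary import my_metric  # re-export for users
```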
### Step 6: Write Comprehensive Tests
Create tests in the appropriate test file (e.g., `tests/metrics/test_summary.py`):
```python
import polars as pl
import pytest
from polars.testing import assert_frame_equal

from tnp_statistic_library._internal.exceptions import DataValidationError
from tnp_statistic_library.metrics.summary import my_metric


class TestMyMetric:
    """Test class for MyMetric with comprehensive coverage."""

    def test_my_metric_record_level_without_segments(self):
        """Test with record-level data without segments."""
        result = my_metric(
            name="test_my_metric",
            dataset=pl.DataFrame({
                "col1": [1, 2, 3, 4, 5],
                "col2": [2, 4, 6, 8, 10],
            }),
            data_format="record_level",
            input_col1="col1",
            input_col2="col2",
        )
        assert result.shape == (1, 4)  # group key + your outputs
        assert result["result"][0] == 0.5  # Expected result

    def test_my_metric_with_segments(self):
        """Test with segmented data."""
        result = my_metric(
            name="test_segmented",
            dataset=pl.DataFrame({
                "col1": [1, 2, 3, 4],
                "col2": [2, 4, 6, 8],
                "segment": ["A", "A", "B", "B"],
            }),
            data_format="record_level",
            input_col1="col1",
            input_col2="col2",
            segment=["segment"],
        )
        assert result.shape == (2, 4)
        assert set(result["segment"].to_list()) == {"A", "B"}

    def test_my_metric_edge_cases(self):
        """Test edge cases like zero division, null values, etc."""
        # Test with zero values
        result = my_metric(
            name="test_zero_division",
            dataset=pl.DataFrame({
                "col1": [0, 1, 2],
                "col2": [0, 0, 1],
            }),
            data_format="record_level",
            input_col1="col1",
            input_col2="col2",
        )
        # Should handle division by zero gracefully
        assert result["result"].null_count() > 0

    def test_my_metric_summary_level(self):
        """Test with summary-level data."""
        result = my_metric(
            name="test_summary",
            dataset=pl.DataFrame({
                "sum_col1": [10, 20],
                "mean_col2": [2.0, 4.0],
            }),
            data_format="summary_level",
            input_col1="sum_col1",
            input_col2="mean_col2",
        )
        assert result.shape == (1, 4)

    def test_my_metric_validation_errors(self):
        """Test that appropriate validation errors are raised."""
        with pytest.raises(DataValidationError):
            my_metric(
                name="test_validation",
                dataset=pl.DataFrame({
                    "col1": ["a", "b", "c"],  # Non-numeric data
                    "col2": [1, 2, 3],
                }),
                data_format="record_level",
                input_col1="col1",
                input_col2="col2",
            )
```
### Step 7: Add Documentation
Create documentation in the appropriate API file:
- **API Documentation**: Add to the relevant file in `docs/api/` (e.g., `docs/api/accuracy.md`)
- **Workflow Examples**: Add practical examples in `docs/workflows/`
- **Update the index**: Ensure your new metric is listed in the appropriate index
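If the API pages are generated with mkdocstrings (an assumption based on the `docs/api/` layout), registering the function can be a one-line directive:

```markdown
<!-- docs/api/summary.md (illustrative path; assumes mkdocstrings is in use) -->
::: tnp_statistic_library.metrics.summary.my_metric
```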
### Step 8: Validate Your Implementation
Run these commands to ensure your metric is properly implemented:
```bash
# Run tests for your specific metric
uv run pytest tests/metrics/test_{your_module}.py::TestYourMetric -v

# Run all tests to ensure no regressions
uv run pytest

# Run type checking (if available)
uv run mypy tnp_statistic_library/metrics/{your_module}.py

# Run linting
uv run ruff check tnp_statistic_library/metrics/{your_module}.py

# Test documentation builds
uv run mkdocs serve
```
You can also use the `just` command runner to run these commands easily.
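For example, a `justfile` recipe could bundle the checks; the recipe below is hypothetical, so check the repository's `justfile` for the actual names:

```text
# Hypothetical justfile recipe -- see the repository's justfile for the real ones
check:
    uv run pytest
    uv run ruff check .
```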
### Common Patterns and Best Practices
#### 1. Handling Edge Cases
Always handle these scenarios:
- Division by zero, using `pl.when(denominator != 0).then(numerator / denominator).otherwise(None)` (a combined guard is sketched below)
- Null/missing values
- Empty datasets
- Insufficient sample sizes
- Invalid input ranges
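For instance, a guard that covers both a zero denominator and a too-small sample might look like this (a sketch; the helper name and threshold are illustrative):

```python
import polars as pl


def safe_ratio(numerator: pl.Expr, denominator: pl.Expr, min_n: int = 1) -> pl.Expr:
    """Return numerator / denominator, or null when the denominator is zero
    or fewer than min_n non-null observations are available (illustrative helper)."""
    return (
        pl.when((denominator != 0) & (numerator.count() >= min_n))
        .then(numerator / denominator)
        .otherwise(None)
    )
```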
#### 2. Column Validation
Use appropriate validators for input columns:
- `IsNumeric()`: Ensures the column contains numeric data
- `PositiveNumber()`: Ensures positive numeric values
- `InclusiveRange(min, max)`: Ensures values fall within a range
- `IsIndicator()`: Ensures binary indicator values (0/1)
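For example, a configuration for a probability-of-default metric might combine these (class and field names are illustrative):

```python
from typing import Annotated

from tnp_statistic_library._internal.common import InclusiveRange, InputCol, IsIndicator
from tnp_statistic_library._internal.metrics._base import ExpandedMetricConfig


class ExampleConfig(ExpandedMetricConfig):
    """Illustrative config: a probability column in [0, 1] and a 0/1 indicator."""

    pd_col: Annotated[InputCol, InclusiveRange(0, 1)]  # e.g., predicted PDs
    default_col: Annotated[InputCol, IsIndicator()]    # observed default flags
```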
#### 3. Segmentation Support
All metrics support segmentation by default. The `segment` parameter allows users to group results by specified columns.
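For example, with the `mean` helper shown later in this guide (data and column names are illustrative):

```python
import polars as pl

from tnp_statistic_library.metrics.summary import mean

data = pl.DataFrame({"value": [1.0, 2.0, 3.0, 4.0], "region": ["N", "N", "S", "S"]})

# One row per region instead of a single overall mean
result = mean(name="mean_by_region", dataset=data, variable="value", segment=["region"])
```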
#### 4. LazyFrame Support
All calculations use Polars LazyFrames for performance. Your `compute` method should return a LazyFrame that can be collected when needed.
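On the caller side, the public helpers accept either a DataFrame or a LazyFrame, so a lazy scan can flow straight through; collection happens inside the helper (the file name here is hypothetical):

```python
import polars as pl

from tnp_statistic_library.metrics.summary import mean

lazy_data = pl.scan_csv("observations.csv")  # hypothetical input file
result = mean(name="mean_value", dataset=lazy_data, variable="value")  # returns a collected DataFrame
```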
#### 5. Naming Conventions
- **Metric names**: Use lowercase with underscores (e.g., `my_custom_metric`)
- **Configuration classes**: Use descriptive names ending with `Config`
- **Helper functions**: Use descriptive names with a `calculate_` prefix
- **Test classes**: Use the `Test{MetricName}` format
### Example: Real Metric Implementation
Here's how the existing `mean` metric is implemented, as a reference:
```python
# From tnp_statistic_library/_internal/metrics/summary.py
class MeanConfig(ExpandedMetricConfig):
    """Configuration for computing the mean statistic.

    This configuration specifies the column for which to compute the mean value,
    optionally grouped by segment columns.

    Attributes:
        variable: The name of the column for which to compute the mean.
            Must be a valid input column.
    """

    variable: Annotated[InputCol, IsNumeric()]

    def get_aggregation_expressions(self) -> dict[str, pl.Expr]:
        """Get aggregation expressions for mean calculation.

        Returns:
            A dictionary mapping output column names to Polars expressions
            for computing the mean statistic.
        """
        variable = self.variable
        lit_variable = pl.lit(variable)
        mean = pl.col(variable).mean()
        return {
            "variable_col": lit_variable,
            "mean": mean,
        }

    def compute(self, lf: pl.LazyFrame) -> pl.LazyFrame:
        """Compute the mean metric.

        Args:
            lf: The input LazyFrame containing the data to analyze.

        Returns:
            A LazyFrame containing the computed metric results.
        """
        return self._return_grouped(lf, self.segment).agg(**self.get_aggregation_expressions())


class MeanSummary(BaseMetric[MeanConfig]):
    """A metric that calculates the mean of the specified values in a grouped dataset.

    This metric computes the arithmetic mean of a specified column, optionally
    grouped by segment columns. It handles missing values according to Polars
    default behavior.

    Attributes:
        config: Configuration for the mean calculation, including the column to aggregate.
        output_columns: List of output column names: ["variable_col", "mean"].
        metric_type: Always set to "mean".
    """

    config: MeanConfig
    output_columns: list[str] = Field(default_factory=lambda: ["variable_col", "mean"])
    metric_type: Literal["mean"] = Field("mean", frozen=True)
```
```python
# From tnp_statistic_library/metrics/summary.py
def mean(
    name: str,
    dataset: pl.LazyFrame | pl.DataFrame,
    variable: str,
    segment: list[str] | None = None,
) -> pl.DataFrame:
    """Calculate the mean summary for the given dataset and parameters.

    Args:
        name: Name of the metric.
        dataset: Dataset to compute the mean on.
        variable: Column name for which to compute the mean.
        segment: Segmentation groups for calculation.

    Returns:
        DataFrame containing the mean summary and associated metadata.
    """
    kw = {k: v for k, v in locals().items() if v is not None}
    return MeanSummary.build(**kw).run_metric().collect()
```
This comprehensive guide should help you successfully add new metrics while maintaining consistency with the existing codebase. If you have questions or need clarification on any step, please open an issue on GitHub (recommended) or post in the Teams channel.