"""Central Jinja2 template environment for GAIn.
Provides a singleton Environment that resolves templates in two stages:
1. Physical files under gain/templates/template_files/ via PackageLoader.
2. Strings supplied by callables registered under the
"gain.templates.providers" entry-point group. Each callable must
return a ``dict[str, str]`` mapping template name to template source.
All provider dictionaries are merged lazily on first miss.
Raises ``jinja2.TemplateNotFound`` if a name is not found in either stage.
"""
from __future__ import annotations
from dataclasses import dataclass, field
from importlib.metadata import entry_points
from typing import TYPE_CHECKING
from jinja2 import (
BaseLoader,
ChoiceLoader,
Environment,
PackageLoader,
Template,
TemplateNotFound,
)
if TYPE_CHECKING:
from collections.abc import Callable
@dataclass
class _TemplateCache:
    """Mutable holder for the module's lazily created shared objects.

    Both fields start out as ``None`` and are filled in on first use by the
    accessor functions in this module.
    """

    # Shared Jinja2 environment; ``None`` until it has been constructed.
    env: Environment | None = field(default=None)

    # Merged ``{template name: template source}`` mapping collected from the
    # entry-point providers; ``None`` until the providers have been loaded.
    provider_cache: dict[str, str] | None = field(default=None)
# Module-level holder for the lazily built environment and provider cache.
_state = _TemplateCache()
def _get_provider_templates() -> dict[str, str]:
    """Return the merged template sources contributed by all providers.

    On the first call every callable registered under the
    ``"gain.templates.providers"`` entry-point group is loaded and invoked,
    and the returned ``{name: source}`` dictionaries are merged into one
    mapping that is stored on ``_state.provider_cache``. Later calls return
    that stored mapping without touching the entry points again.

    Returns:
        The merged mapping of template name to template source.

    Raises:
        ValueError: If two providers register the same template name with
            different sources. Identical duplicates are accepted. The cache
            is only stored after a complete merge, so a failed call leaves
            it unset.
    """
    cached = _state.provider_cache
    if cached is not None:
        return cached

    collected: dict[str, str] = {}
    for ep in entry_points(group="gain.templates.providers"):
        factory = ep.load()
        for name, text in factory().items():
            # Re-registering the very same source is harmless; only a
            # differing source under an already-known name is a conflict.
            if name in collected and collected[name] != text:
                raise ValueError(
                    f"Template name conflict: '{name}' registered by "
                    f"provider '{ep.name}' conflicts with an existing "
                    f"provider registration.",
                )
            collected[name] = text

    _state.provider_cache = collected
    return collected
class _ProviderLoader(BaseLoader):
    """Jinja2 loader that reads templates from entry-point provider dicts.

    Template sources come from ``_get_provider_templates()``, so the
    provider entry points are only loaded the first time this loader is
    actually consulted.
    """

    def get_source(
        self, environment: Environment, template: str,  # noqa: ARG002
    ) -> tuple[str, None, Callable[[], bool]]:
        """Return the source of *template* from the merged provider mapping.

        Args:
            environment: Unused; present to satisfy the ``BaseLoader``
                interface.
            template: Name of the template to look up.

        Returns:
            A ``(source, filename, uptodate)`` tuple. ``filename`` is
            ``None`` because the source is not read from a file, and
            ``uptodate`` always reports ``True`` because the provider
            mapping is not rebuilt once it has been merged.

        Raises:
            TemplateNotFound: If no provider registered *template*.
        """
        source = _get_provider_templates().get(template)
        if source is None:
            raise TemplateNotFound(template)
        return source, None, lambda: True

    def list_templates(self) -> list[str]:
        """Return the sorted names of all provider-supplied templates.

        ``BaseLoader.list_templates`` raises ``TypeError`` unless it is
        overridden, and ``ChoiceLoader.list_templates`` queries every child
        loader. Without this override, ``Environment.list_templates()``
        would fail for any environment that includes this loader.
        """
        return sorted(_get_provider_templates())
[docs]
def get_jinja_env() -> Environment:
    """Return the shared GAIn Jinja2 ``Environment``, creating it if needed.

    The environment is built on the first call and stored on ``_state``;
    every later call returns that same object. Its loader is a
    ``ChoiceLoader`` listing the packaged files under
    ``gain/templates/template_files`` first and the entry-point provider
    loader second.

    Returns:
        The singleton ``Environment`` instance.
    """
    existing = _state.env
    if existing is not None:
        return existing

    loader = ChoiceLoader([
        PackageLoader("gain.templates", "template_files"),
        _ProviderLoader(),
    ])
    # NOTE: no ``autoescape`` argument is passed; the S701 lint suppression
    # acknowledges that.
    created = Environment(loader=loader)  # noqa: S701
    _state.env = created
    return created
[docs]
def get_template(name: str) -> Template:
    """Load the template called *name* from the shared environment.

    Args:
        name: Template name to resolve.

    Returns:
        The loaded ``Template``.

    Raises:
        TemplateNotFound: If *name* is absent.
    """
    env = get_jinja_env()
    return env.get_template(name)