# Source code for gain.templates

"""Central Jinja2 template environment for GAIn.

Provides a singleton Environment that resolves templates in two stages:

1. Physical files under gain/templates/template_files/ via PackageLoader.
2. Strings supplied by callables registered under the
   "gain.templates.providers" entry-point group.  Each callable must
   return a ``dict[str, str]`` mapping template name to template source.
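   Provider dictionaries are merged lazily, on the first lookup that
   misses stage 1, and the merged result is cached.  Registering the
   same name with different sources raises ``ValueError``.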

Raises ``jinja2.TemplateNotFound`` if a name is not found in either stage.
"""
from __future__ import annotations

from dataclasses import dataclass, field
from importlib.metadata import entry_points
from typing import TYPE_CHECKING

from jinja2 import (
    BaseLoader,
    ChoiceLoader,
    Environment,
    PackageLoader,
    Template,
    TemplateNotFound,
)

if TYPE_CHECKING:
    from collections.abc import Callable


@dataclass
class _TemplateCache:
    env: Environment | None = field(default=None)
    provider_cache: dict[str, str] | None = field(default=None)


# Module-level cache instance shared by the functions below; it holds the
# lazily built Environment and the merged provider templates.
_state = _TemplateCache()


def _get_provider_templates() -> dict[str, str]:
    """Return the merged template mapping supplied by entry-point providers.

    On the first call every callable registered under the
    "gain.templates.providers" entry-point group is loaded and invoked,
    and the dictionaries they return are folded into a single mapping
    that is stored on ``_state`` and reused by later calls.

    Returns:
        Mapping of template name to template source.

    Raises:
        ValueError: If a provider supplies a name that is already present
            with a different source.  Nothing is cached in that case.
    """
    if _state.provider_cache is not None:
        return _state.provider_cache

    combined: dict[str, str] = {}
    for ep in entry_points(group="gain.templates.providers"):
        provided = ep.load()()
        for name, text in provided.items():
            # Re-registering identical source is tolerated; only a
            # differing source for the same name counts as a conflict.
            clashes = name in combined and combined[name] != text
            if clashes:
                raise ValueError(
                    f"Template name conflict: '{name}' registered by "
                    f"provider '{ep.name}' conflicts with an existing "
                    f"provider registration.",
                )
            combined[name] = text

    _state.provider_cache = combined
    return combined


class _ProviderLoader(BaseLoader):
    """Jinja2 loader backed by the merged entry-point provider mapping."""

    def get_source(
        self, environment: Environment, template: str,  # noqa: ARG002
    ) -> tuple[str, None, Callable[[], bool]]:
        """Look up ``template`` in the provider templates.

        Args:
            environment: Unused; accepted to match the loader signature.
            template: Name of the template to look up.

        Returns:
            A ``(source, None, uptodate)`` tuple: the second item is
            ``None`` and ``uptodate`` always returns ``True``.

        Raises:
            TemplateNotFound: If no provider supplies ``template``.
        """
        found = _get_provider_templates().get(template)
        if found is not None:
            return found, None, lambda: True
        raise TemplateNotFound(template)


def get_jinja_env() -> Environment:
    """Return the singleton GAIn Jinja2 Environment.

    The environment is created on the first call and stored on ``_state``;
    later calls return that same object.  Its loader is a ``ChoiceLoader``
    that tries the packaged files under ``gain.templates/template_files``
    first and the entry-point provider templates second.

    Returns:
        The shared ``Environment`` instance.
    """
    if _state.env is None:
        _state.env = Environment(  # noqa: S701
            loader=ChoiceLoader([
                PackageLoader("gain.templates", "template_files"),
                _ProviderLoader(),
            ]),
        )
    return _state.env


def get_template(name: str) -> Template:
    """Convenience wrapper — raises TemplateNotFound if name is absent.

    Args:
        name: Name of the template to load from the shared environment.

    Returns:
        The template loaded by ``get_jinja_env().get_template(name)``.

    Raises:
        TemplateNotFound: If ``name`` is not found by the environment's
            loaders.
    """
    return get_jinja_env().get_template(name)