from __future__ import annotations
import datetime
import logging
import os
import pickle # noqa: S403
from abc import abstractmethod
from dataclasses import dataclass
from enum import Enum
from typing import Any, cast
import fsspec
from gain.task_graph.graph import Task, TaskDesc
from gain.utils import fs_utils
logger = logging.getLogger(__name__)
[docs]
class CacheRecordType(Enum):
NEEDS_COMPUTE = 0
COMPUTED = 1
ERROR = 2
[docs]
@dataclass(frozen=True)
class CacheRecord:
"""Encapsulate information about a task in the cache."""
type: CacheRecordType
result_or_error: Any = None
@property
def result(self) -> Any:
assert self.type == CacheRecordType.COMPUTED
return self.result_or_error
@property
def error(self) -> Any:
assert self.type == CacheRecordType.ERROR
return self.result_or_error
[docs]
def invalidate(self) -> CacheRecord:
"""Return a new instance that needs to be recomputed."""
return CacheRecord(CacheRecordType.NEEDS_COMPUTE, self.result_or_error)
[docs]
class TaskCache:
"""Store the result of a task in a file and reuse it if possible."""
[docs]
@abstractmethod
def get_record(
self, task_desc: TaskDesc,
) -> CacheRecord:
"""For task in the `graph` load and yield the cache record."""
[docs]
@abstractmethod
def cache(
self, task: Task, *,
is_error: bool, result: Any,
) -> None:
"""Cache the result or exception of a task."""
[docs]
@staticmethod
def create(
*,
force: bool = False,
task_progress_mode: bool = True,
cache_dir: str | None = None,
) -> TaskCache:
"""Create the appropriate task cache."""
if force or not task_progress_mode:
return NoTaskCache()
if cache_dir is None:
cache_dir = os.getcwd()
return FileTaskCache(cache_dir=cache_dir)
[docs]
class NoTaskCache(dict[Any, Any], TaskCache):
"""Don't check any conditions and just run any task."""
[docs]
def get_record(
self, task_desc: TaskDesc, # noqa: ARG002
) -> CacheRecord:
return CacheRecord(CacheRecordType.NEEDS_COMPUTE)
[docs]
def cache(
self, task: Task, *,
is_error: bool, result: Any,
) -> None:
pass
[docs]
class FileTaskCache(TaskCache):
"""Use file modification timestamps to determine if a task needs to run."""
def __init__(self, cache_dir: str):
self.cache_dir = cache_dir
[docs]
def get_record(self, task_desc: TaskDesc) -> CacheRecord:
"""Get the cache record for a task."""
flag_filename = self._get_flag_filename(task_desc.task)
try:
with fsspec.open(flag_filename, "rb") as cache_file:
task_record = cast(
CacheRecord,
pickle.load(cache_file)) # pyright: ignore
except FileNotFoundError:
return CacheRecord(CacheRecordType.NEEDS_COMPUTE)
except Exception: # pylint: disable=broad-except
logger.exception(
"Cannot read status for task %s. Ignoring and continuing.",
task_desc,
)
return CacheRecord(CacheRecordType.NEEDS_COMPUTE)
if task_record.type != CacheRecordType.COMPUTED:
return task_record
if self._needs_recompute(task_desc):
task_record = CacheRecord(
CacheRecordType.NEEDS_COMPUTE,
result_or_error=task_record.result,
)
return task_record
[docs]
def cache(self, task: Task, *, is_error: bool, result: Any) -> None:
record_type = (
CacheRecordType.ERROR if is_error else CacheRecordType.COMPUTED
)
record = CacheRecord(
record_type,
result,
)
cache_fn = self._get_flag_filename(task)
try:
with fsspec.open(cache_fn, "wb") as cache_file:
pickle.dump(record, cache_file) # pyright: ignore
except Exception: # pylint: disable=broad-except
logger.exception(
"Cannot write cache for task %s. Ignoring and continuing.",
task,
)
def _get_flag_filename(self, task: Task) -> str:
return fs_utils.join(self.cache_dir, f"{task.task_id}.flag")
def _needs_recompute(
self, task: TaskDesc,
) -> bool:
"""
Determine if a task needs to be recomputed.
Phase 1 — output_files (final outputs the user may delete):
If any are missing → recompute. If all exist, compare against inputs.
Phase 2 — intermediate_output_files (consumed by downstream tasks):
If all exist, compare against inputs. If missing → fall through to
the flag-file check so the task is not needlessly recomputed.
Phase 3 — flag-file check (fallback when no output files are declared
or when intermediate outputs are missing).
"""
input_files = task.input_files
if task.output_files:
out_mtime = self._get_oldest_mod_time(task.output_files)
if out_mtime is None:
return True # missing final output → recompute
in_mtime = self._get_newest_mod_time(input_files)
if len(input_files) == 0:
return False # no inputs, outputs exist
if in_mtime is None:
return True # missing input file
return in_mtime > out_mtime
if task.intermediate_output_files:
out_mtime = self._get_oldest_mod_time(
task.intermediate_output_files)
if out_mtime is not None:
in_mtime = self._get_newest_mod_time(input_files)
if len(input_files) == 0:
return False # no inputs, outputs exist
if in_mtime is None:
return True # missing input file
return in_mtime > out_mtime
# files missing (consumed by downstream) → fall through
output_files = [self._get_flag_filename(task.task)]
input_mtime = self._get_newest_mod_time(input_files)
output_mtime = self._get_oldest_mod_time(output_files)
if len(input_files) == 0 and output_mtime is not None:
return False # no inputs, flag exists
if input_mtime is None or output_mtime is None:
return True # cannot determine mod times
return input_mtime > output_mtime
def _get_oldest_mod_time(
self, filenames: list[str],
) -> datetime.datetime | None:
mtimes = [self._safe_getmtime(path) for path in filenames]
if any(p is None for p in mtimes):
# cannot determine the mtime of a filename. Assume it needs recalc.
return None
if len(mtimes) > 0:
return min(cast(list[datetime.datetime], mtimes))
return None
def _get_newest_mod_time(
self, filenames: list[str],
) -> datetime.datetime | None:
mtimes = [self._safe_getmtime(path) for path in filenames]
if any(p is None for p in mtimes):
# cannot determine the mtime of a filename. Assume it needs recalc.
return None
if len(mtimes) > 0:
return max(cast(list[datetime.datetime], mtimes))
return None
def _safe_getmtime(self, path: str) -> datetime.datetime | None:
if fs_utils.exists(path):
return fs_utils.modified(path)
return None