Source code for dae.task_graph.cache

from __future__ import annotations

import datetime
import logging
import os
import pickle  # noqa: S403
from abc import abstractmethod
from collections.abc import Generator, Iterator
from copy import copy
from dataclasses import dataclass
from enum import Enum
from typing import Any, cast

import fsspec

from dae.task_graph.graph import Task, TaskGraph
from dae.utils import fs_utils

logger = logging.getLogger(__name__)


class CacheRecordType(Enum):
    NEEDS_COMPUTE = 0
    COMPUTED = 1
    ERROR = 2

@dataclass(frozen=True)
class CacheRecord:
    """Encapsulate information about a task in the cache."""

    type: CacheRecordType
    result_or_error: Any = None

    @property
    def result(self) -> Any:
        assert self.type == CacheRecordType.COMPUTED
        return self.result_or_error

    @property
    def error(self) -> Any:
        assert self.type == CacheRecordType.ERROR
        return self.result_or_error
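
# Illustrative sketch (not part of the original module): how a CacheRecord
# is constructed and read. The `result` and `error` properties assert that
# the record has the matching type before returning the payload.
def _example_cache_record_usage() -> None:
    computed = CacheRecord(CacheRecordType.COMPUTED, result_or_error=42)
    assert computed.result == 42  # ok: type is COMPUTED

    failed = CacheRecord(
        CacheRecordType.ERROR, result_or_error=ValueError("boom"))
    assert isinstance(failed.error, ValueError)  # ok: type is ERROR
    # accessing `failed.result` here would trip the assertion in `result`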

class TaskCache:
    """Store the result of a task in a file and reuse it if possible."""

    @abstractmethod
    def load(self, graph: TaskGraph) -> Iterator[tuple[Task, CacheRecord]]:
        """For each task in the `graph`, load and yield the cache record."""

    @abstractmethod
    def cache(
        self,
        task_node: Task,
        *,
        is_error: bool,
        result: Any,
    ) -> None:
        """Cache the result or exception of a task."""

    @staticmethod
    def create(
        *,
        force: bool | None = None,
        cache_dir: str | None = None,
    ) -> TaskCache:
        """Create the appropriate task cache."""
        if force is None:
            # the force_mode is set to 'always'
            return NoTaskCache()
        if cache_dir is None:
            cache_dir = os.getcwd()
        return FileTaskCache(force=force, cache_dir=cache_dir)
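
# Illustrative sketch (not part of the original module): selecting a cache
# implementation via `TaskCache.create`. `force=None` yields a NoTaskCache
# (always recompute); otherwise a FileTaskCache is created that stores its
# *.flag files in `cache_dir` (the current working directory if omitted).
def _example_create_cache() -> None:
    always_run = TaskCache.create(force=None)
    assert isinstance(always_run, NoTaskCache)

    # "/tmp/task-cache" is a hypothetical directory, for illustration only
    file_cache = TaskCache.create(force=False, cache_dir="/tmp/task-cache")
    assert isinstance(file_cache, FileTaskCache)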

class NoTaskCache(dict, TaskCache):
    """Don't check any conditions and just run any task."""

    def load(
        self,
        graph: TaskGraph,
    ) -> Generator[tuple[Task, CacheRecord], None, None]:
        for task in graph.tasks:
            yield task, CacheRecord(CacheRecordType.NEEDS_COMPUTE)

    def cache(
        self,
        task_node: Task,
        *,
        is_error: bool,
        result: Any,
    ) -> None:
        pass
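
# Illustrative sketch (not part of the original module): NoTaskCache always
# reports NEEDS_COMPUTE and discards results, so every task is re-executed.
def _example_no_task_cache(graph: TaskGraph) -> None:
    cache = NoTaskCache()
    for _task, record in cache.load(graph):
        assert record.type == CacheRecordType.NEEDS_COMPUTE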

class FileTaskCache(TaskCache):
    """Use file modification timestamps to determine if a task needs to run."""

    def __init__(self, cache_dir: str, *, force: bool = False):
        self.force = force
        self.cache_dir = cache_dir
        self._global_dependancies: list[str] | None = None
        self._mtime_cache: dict[str, datetime.datetime] = {}

    def load(
        self,
        graph: TaskGraph,
    ) -> Generator[tuple[Task, CacheRecord], None, None]:
        assert self._global_dependancies is None
        self._global_dependancies = graph.input_files
        task2record: dict[Task, CacheRecord] = {}
        for task in graph.tasks:
            yield task, self._get_record(task, task2record)
        self._global_dependancies = None
        self._mtime_cache = {}

    def _get_record(
        self,
        task_node: Task,
        task2record: dict[Task, CacheRecord],
    ) -> CacheRecord:
        if self.force:
            return CacheRecord(CacheRecordType.NEEDS_COMPUTE)
        record = task2record.get(task_node)
        if record is not None:
            return record

        unsatisfied_deps = False
        for dep in task_node.deps:
            dep_rec = self._get_record(dep, task2record)
            if dep_rec.type != CacheRecordType.COMPUTED:
                unsatisfied_deps = True
                break

        if unsatisfied_deps or self._needs_compute(task_node):
            res_record = CacheRecord(CacheRecordType.NEEDS_COMPUTE)
            task2record[task_node] = res_record
            return res_record

        output_fn = self._get_flag_filename(task_node)
        with fsspec.open(output_fn, "rb") as cache_file:
            res_record = cast(
                CacheRecord, pickle.load(cache_file))  # noqa: S301
        task2record[task_node] = res_record
        return res_record

    def _needs_compute(self, task: Task) -> bool:
        # check _global_dependancies only for first-level task_nodes
        if len(task.deps) == 0:
            in_files = copy(self._global_dependancies)
        else:
            in_files = []
        assert in_files is not None
        in_files.extend(task.input_files)
        for dep in task.deps:
            in_files.append(self._get_flag_filename(dep))
        output_fn = self._get_flag_filename(task)
        return self._should_recompute_output(in_files, [output_fn])

    def cache(self, task_node: Task, *, is_error: bool, result: Any) -> None:
        record_type = (
            CacheRecordType.ERROR if is_error else CacheRecordType.COMPUTED
        )
        record = CacheRecord(
            record_type,
            result,
        )
        cache_fn = self._get_flag_filename(task_node)
        try:
            with fsspec.open(cache_fn, "wb") as cache_file:
                pickle.dump(record, cache_file)
        except Exception:  # pylint: disable=broad-except
            logger.exception(
                "Cannot write cache for task %s. Ignoring and continuing.",
                task_node,
            )

    def _get_flag_filename(self, task_node: Task) -> str:
        return fs_utils.join(self.cache_dir, f"{task_node.task_id}.flag")

    def _should_recompute_output(
        self,
        input_files: list[str],
        output_files: list[str],
    ) -> bool:
        input_mtime = self._get_last_mod_time(input_files)
        output_mtime = self._get_last_mod_time(output_files)
        if len(input_files) == 0 and output_mtime is not None:
            # no input, but the output file exists: don't recompute
            return False
        if input_mtime is None or output_mtime is None:
            # cannot determine mod times; always run
            return True
        should_run: bool = input_mtime > output_mtime
        return should_run

    def _get_last_mod_time(
        self,
        filenames: list[str],
    ) -> datetime.datetime | None:
        mtimes = [self._safe_getmtime(path) for path in filenames]
        if any(p is None for p in mtimes):
            # cannot determine the mtime of a file; assume it needs recalc
            return None
        if len(mtimes) > 0:
            return max(cast(list[datetime.datetime], mtimes))
        return None

    def _safe_getmtime(self, path: str) -> datetime.datetime | None:
        assert self._mtime_cache is not None
        if path in self._mtime_cache:
            return self._mtime_cache[path]
        if fs_utils.exists(path):
            mtime = fs_utils.modified(path)
            self._mtime_cache[path] = mtime
            return mtime
        return None
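
# Illustrative sketch (not part of the original module): a load/cache round
# trip with FileTaskCache. Calling `task.func(*task.args)` assumes the Task
# interface from dae.task_graph.graph exposes `func` and `args`; treat it as
# a stand-in for however the executor actually invokes a task.
def _example_file_task_cache(graph: TaskGraph) -> None:
    cache = FileTaskCache(cache_dir="/tmp/task-cache", force=False)
    for task, record in cache.load(graph):
        if record.type == CacheRecordType.NEEDS_COMPUTE:
            try:
                result = task.func(*task.args)  # assumed Task API
                cache.cache(task, is_error=False, result=result)
            except Exception as exc:  # pylint: disable=broad-except
                cache.cache(task, is_error=True, result=exc)
        elif record.type == CacheRecordType.COMPUTED:
            _ = record.result  # reuse the previously computed result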