from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Collection, TypeVar, Generic, Optional, Iterable, Mapping
from openmnglab.datamodel.exceptions import DataSchemeCompatibilityError
from openmnglab.model.datamodel.interface import IInputDataScheme, IOutputDataScheme
from openmnglab.model.functions.interface import IFunctionDefinition
from openmnglab.planning.exceptions import InvalidFunctionArgumentCountError, FunctionArgumentSchemaError, PlanningError
from openmnglab.model.planning.interface import IExecutionPlanner, IProxyData
from openmnglab.model.planning.plan.interface import IExecutionPlan, IStage, IPlannedData, IPlannedElement
class ProxyData(IProxyData):
    """Minimal :class:`IProxyData` implementation that only carries a hash.

    The hash given to the constructor is exposed unchanged through
    :attr:`calculated_hash`.
    """

    def __init__(self, planned_hash: bytes):
        # Stored verbatim; no validation or recomputation happens here.
        self._planned_hash = planned_hash

    @property
    def calculated_hash(self) -> bytes:
        """The hash this proxy was constructed with."""
        return self._planned_hash

    @staticmethod
    def copy_from(other: IProxyData) -> ProxyData:
        """Create a new :class:`ProxyData` holding the same hash as ``other``.

        :param other: any proxy; only its ``calculated_hash`` is read.
        :return: a fresh ``ProxyData`` (always this exact class).
        """
        source_hash = other.calculated_hash
        return ProxyData(source_hash)
class ExecutionPlan(IExecutionPlan):
    """Plan holding stages and planned data, each keyed by ``bytes``.

    Both constructor arguments accept either a mapping (used as the key/value
    source directly) or an iterable of elements. Elements from an iterable are
    keyed by their ``calculated_hash``; a later element with the same hash
    replaces an earlier one.

    Mapping arguments are shallow-copied into a new ``dict``. Without the copy,
    the plan would alias the caller's mapping, and any later mutation by the
    caller would silently change this plan.
    """

    def __init__(self, functions: Iterable[IStage] | Mapping[bytes, IStage],
                 data: Iterable[IPlannedData] | Mapping[bytes, IPlannedData]):
        def to_mapping(param: Iterable[IPlannedElement] | Mapping[bytes, IPlannedElement]
                       ) -> dict[bytes, IPlannedElement]:
            # The Mapping check must come first: a Mapping is also an Iterable,
            # but it iterates over its keys, not its elements.
            if isinstance(param, Mapping):
                return dict(param)
            return {element.calculated_hash: element for element in param}

        self._functions: Mapping[bytes, IStage] = to_mapping(functions)
        self._proxy_data: Mapping[bytes, IPlannedData] = to_mapping(data)

    @property
    def stages(self) -> Mapping[bytes, IStage]:
        """Stages of this plan, keyed as given (or by ``calculated_hash``)."""
        return self._functions

    @property
    def planned_data(self) -> Mapping[bytes, IPlannedData]:
        """Planned data of this plan, keyed as given (or by ``calculated_hash``)."""
        return self._proxy_data
# Stage type stored by a concrete planner; restricted to IStage subtypes.
_FuncT = TypeVar("_FuncT", bound=IStage)
# Planned-data type stored by a concrete planner; restricted to IPlannedData subtypes.
_DataT = TypeVar("_DataT", bound=IPlannedData)
class PlannerBase(IExecutionPlanner, ABC, Generic[_FuncT, _DataT]):
    """Common bookkeeping for execution planners.

    Holds the planned stages and planned data in two dicts keyed by ``bytes``.
    It resolves caller-supplied :class:`IProxyData` handles into this planner's
    concrete planned data and leaves the actual planning of a function to
    subclasses through :meth:`_add_function`.
    """

    def __init__(self):
        # Planned stages; keys are bytes (presumably each stage's hash — verify in subclasses).
        self._functions: dict[bytes, _FuncT] = {}
        # Planned data, looked up by a proxy's ``calculated_hash`` in _proxy_data_to_concrete.
        self._data: dict[bytes, _DataT] = {}

    def get_plan(self) -> ExecutionPlan:
        """Snapshot the current state as an :class:`ExecutionPlan`.

        Both dicts are shallow-copied, so later additions to this planner do
        not show up in the returned plan.
        """
        stage_snapshot = self._functions.copy()
        data_snapshot = self._data.copy()
        return ExecutionPlan(stage_snapshot, data_snapshot)

    @abstractmethod
    def _add_function(self, function: IFunctionDefinition, *inp_data: _DataT) -> tuple[IProxyData, ...]:
        """Plan ``function`` on inputs already resolved to this planner's data.

        Implemented by subclasses. Returns a (possibly empty) tuple of
        :class:`IProxyData`.
        """

    def add_function(self, function: IFunctionDefinition, *inp_data: IProxyData) -> Optional[tuple[IProxyData, ...]]:
        """Add ``function`` to the plan, using the given proxies as its arguments.

        :raises PlanningError: if any proxy is not known to this planner; this
            happens before :meth:`_add_function` is called.
        :return: the tuple produced by :meth:`_add_function`, or ``None`` if
            that tuple is empty.
        """
        # Resolve every argument up front so an unknown proxy aborts the call early.
        concrete_args = tuple(self._proxy_data_to_concrete(*inp_data))
        produced = self._add_function(function, *concrete_args)
        if len(produced) == 0:
            return None
        return produced

    def _proxy_data_to_concrete(self, *inp_data: IProxyData) -> Iterable[_DataT]:
        """Lazily map each proxy to the planned data stored under its hash.

        :raises PlanningError: while iterating, at the first proxy whose hash
            has no (non-``None``) entry in this planner's data.
        """
        for index, proxy in enumerate(inp_data):
            match = self._data.get(proxy.calculated_hash)
            if match is None:
                raise PlanningError(
                    f"Argument at position {index} with hash {proxy.calculated_hash.hex()} "
                    f"is not part of this plan and therefore cannot be used as an argument in it")
            yield match