"""Typed program parameters for ``continuum.frontend.param``."""

from dataclasses import dataclass
from dataclasses import field
from typing import Any


@dataclass
class Param:
    """Typed program parameter tracked by the Continuum optimizer.

    A ``Param`` is a tagged value: ``kind`` says how the optimizer should
    treat it, ``value`` holds the current (initial) payload, and
    ``metadata`` carries kind-specific configuration.
    """

    # Discriminator tag; the factories below use "tensor", "text",
    # "discrete", and "continuous".
    kind: str
    # Current payload; its type depends on ``kind`` (may be None for
    # uninitialized tensors).
    value: Any
    # Kind-specific configuration (shape, dtype, choices, bounds, ...).
    metadata: dict[str, Any] = field(default_factory=dict)

    @staticmethod
    def tensor(shape: Any = None, dtype: str = "f32", initial: Any = None,
               **metadata: Any) -> "Param":
        """Create a tensor parameter.

        Args:
            shape: Iterable of dimension sizes, or ``None`` if unknown.
                Normalized to a tuple when given.
            dtype: Element type tag (default ``"f32"``).
            initial: Optional initial tensor value.
            **metadata: Extra entries merged into the metadata dict;
                user keys override ``shape``/``dtype`` on collision.
        """
        merged: dict[str, Any] = {
            "shape": tuple(shape) if shape is not None else None,
            "dtype": dtype,
        }
        merged.update(metadata)
        return Param(kind="tensor", value=initial, metadata=merged)

    @staticmethod
    def text(initial: str, **metadata: Any) -> "Param":
        """Create a free-text parameter seeded with ``initial``."""
        return Param(kind="text", value=initial, metadata=metadata)

    @staticmethod
    def fewshot(k: int = 3) -> "Param":
        """Create an empty text parameter requesting ``k`` few-shot examples."""
        return Param(kind="text", value="", metadata={"fewshot": k})

    @staticmethod
    def lora(rank: int = 8) -> "Param":
        """Create an uninitialized tensor parameter for a LoRA adapter of ``rank``."""
        return Param(kind="tensor", value=None, metadata={"adapter": "lora", "rank": rank})

    @staticmethod
    def discrete(initial: Any = 0, *, choices: Any = None,
                 **metadata: Any) -> "Param":
        """Create a discrete parameter.

        Args:
            initial: Starting value (default ``0``).
            choices: Optional iterable of allowed values; stored as a list
                under ``metadata["choices"]`` when given. Membership of
                ``initial`` is not validated here.
            **metadata: Extra metadata entries.
        """
        merged = dict(metadata)
        if choices is not None:
            merged["choices"] = list(choices)
        return Param(kind="discrete", value=initial, metadata=merged)

    @staticmethod
    def continuous(initial: float = 0.0, *, min_value: Any = None,
                   max_value: Any = None, **metadata: Any) -> "Param":
        """Create a continuous (float) parameter.

        Args:
            initial: Starting value; coerced to ``float``.
            min_value: Optional lower bound, stored as ``metadata["min"]``.
            max_value: Optional upper bound, stored as ``metadata["max"]``.
            **metadata: Extra metadata entries. Bounds are not enforced
                against ``initial`` here.
        """
        merged = dict(metadata)
        if min_value is not None:
            merged["min"] = float(min_value)
        if max_value is not None:
            merged["max"] = float(max_value)
        return Param(kind="continuous", value=float(initial), metadata=merged)