Skip to content

LightRL API Reference

Welcome to the detailed API reference for LightRL. Below you'll find documentation for the key classes and functions available in the library, complete with usage guidelines and examples.

Bandits Module

LightRL includes a variety of bandit algorithms, each tailored for specific use cases in reinforcement learning environments. The following classes are part of the lightrl.bandits module:

Base Bandit Class

Bandit: The foundational class for all bandit algorithms. Subclasses provide specialized implementations.

Bases: ABC

Source code in lightrl/bandits.py
class Bandit(ABC):
    """Abstract base class for multi-armed bandit algorithms.

    Keeps one value estimate (``q_values``) and one pull counter (``counts``)
    per arm. Subclasses implement :meth:`select_arm`; the default
    :meth:`update` maintains either a running sample average or, when
    ``ema_alpha`` is positive, an exponential moving average of the rewards.

    Args:
        arms: The arms to choose between; ``select_arm`` returns an index into
            this list.
        priors: Optional initial value estimate for each arm. Must have the
            same length as ``arms``. Defaults to 0.0 for every arm.
        ema_alpha: Step size of the exponential moving average, in [0, 1].
            0 (the default) selects the plain sample average instead.

    Raises:
        ValueError: If ``arms`` is empty, ``priors`` has the wrong length, or
            ``ema_alpha`` lies outside [0, 1].
    """

    def __init__(self, arms: list, priors: Optional[List[float]] = None, ema_alpha: float = 0.0):
        # Selection (``max`` in _exploit, ``random.randint`` in subclasses) fails
        # on an empty arm list, so reject it up front instead of at first use.
        if not arms:
            raise ValueError("arms must contain at least one arm")
        if priors is not None and len(priors) != len(arms):
            raise ValueError(f"priors length {len(priors)} != arms length {len(arms)}")
        # Values above 1 overshoot the observed reward; negative values would
        # silently fall through to the sample-average branch of update().
        if not 0.0 <= ema_alpha <= 1.0:
            raise ValueError(f"ema_alpha must be in [0, 1], got {ema_alpha}")
        self.arms = arms
        # Copy so later updates never mutate the caller's priors list.
        self.q_values = list(priors) if priors is not None else [0.0] * len(arms)
        self.counts = [0] * len(arms)
        self.ema_alpha = ema_alpha

    @abstractmethod
    def select_arm(self) -> int:
        """Return the index (into ``self.arms``) of the arm to pull next."""

    def update(self, arm_index: int, reward: float) -> None:
        """Record ``reward`` for the arm at ``arm_index``.

        Increments the arm's pull count and moves its value estimate toward
        ``reward``: by ``ema_alpha`` times the error when ``ema_alpha > 0``,
        otherwise to the running sample average (so a prior only lasts until
        the arm's first update).
        """
        self.counts[arm_index] += 1
        if self.ema_alpha > 0:
            self.q_values[arm_index] += self.ema_alpha * (reward - self.q_values[arm_index])
        else:
            n = self.counts[arm_index]
            # Incremental mean: ((n - 1) * old_mean + reward) / n.
            self.q_values[arm_index] = ((n - 1) * self.q_values[arm_index] + reward) / n

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(arms={self.arms})"

    def report(self) -> None:
        """Print the current value estimate and pull count of every arm."""
        print("Q-values per arm:")
        for arm, q, cnt in zip(self.arms, self.q_values, self.counts):
            print(f"  {arm}: avg_reward={q:.5f}, count={cnt}")

    def _exploit(self) -> int:
        """Return the index of a highest-valued arm, breaking ties at random."""
        max_q = max(self.q_values)
        return random.choice([i for i, q in enumerate(self.q_values) if q == max_q])

    def save(self, path: Union[str, Path]) -> None:
        """Write the class name and all instance attributes to ``path`` as JSON.

        Values that JSON cannot encode (e.g. custom arm objects) are written
        via ``str()`` and therefore come back as strings from :meth:`load`;
        tuples come back as lists.
        """
        data = {"class": self.__class__.__name__, "state": self.__dict__.copy()}
        Path(path).write_text(json.dumps(data, default=str), encoding="utf-8")

    @classmethod
    def load(cls, path: Union[str, Path]) -> "Bandit":
        """Recreate a bandit previously written by :meth:`save`.

        The concrete class is looked up by the name stored in the file, so the
        result may be a different subclass than ``cls``. ``__init__`` is not
        called; the saved attributes are restored directly onto a new object.

        Raises:
            KeyError: If the stored class name is not in the registry built
                from ``_all_bandit_classes()``.
        """
        data = json.loads(Path(path).read_text(encoding="utf-8"))
        registry: Dict[str, Type["Bandit"]] = {c.__name__: c for c in _all_bandit_classes()}
        klass = registry[data["class"]]
        obj: "Bandit" = object.__new__(klass)
        obj.__dict__.update(data["state"])
        return obj

Epsilon-Based Bandits

These bandits use epsilon strategies to balance exploration and exploitation.

EpsilonGreedyBandit: Implements an epsilon-greedy algorithm, allowing for a tunable exploration rate.

Bases: Bandit

Source code in lightrl/bandits.py
class EpsilonGreedyBandit(Bandit):
    """Epsilon-greedy bandit with a fixed exploration rate.

    On each call to :meth:`select_arm`, a uniformly random arm is returned with
    probability ``epsilon``. Otherwise the arm with the highest value estimate
    is returned, with ties broken at random.

    Args:
        arms: The arms to choose between.
        epsilon: Probability of exploring on each selection.
        **kwargs: Forwarded to :class:`Bandit` (``priors``, ``ema_alpha``).
    """

    def __init__(self, arms: list, epsilon: float = 0.1, **kwargs) -> None:
        super().__init__(arms, **kwargs)
        self.epsilon = epsilon

    def select_arm(self) -> int:
        """Pick an arm index: random with probability ``epsilon``, else greedy."""
        should_explore = random.random() < self.epsilon
        if not should_explore:
            return self._exploit()
        last_index = len(self.arms) - 1
        return random.randint(0, last_index)

EpsilonFirstBandit: Explores uniformly at random for a fixed number of initial steps (exploration_steps), then switches to exploitation while still exploring with probability epsilon.

Bases: Bandit

Source code in lightrl/bandits.py
class EpsilonFirstBandit(Bandit):
    """Bandit that explores for a fixed number of steps before exploiting.

    The first ``exploration_steps`` calls to :meth:`select_arm` return a
    uniformly random arm. After that the bandit behaves epsilon-greedily: it
    picks a random arm with probability ``epsilon`` and otherwise the
    best-valued arm.

    Args:
        arms: The arms to choose between.
        exploration_steps: Number of initial selections that are always random.
        epsilon: Exploration probability once the initial phase is over.
        **kwargs: Forwarded to :class:`Bandit` (``priors``, ``ema_alpha``).
    """

    def __init__(self, arms: list, exploration_steps: int = 100, epsilon: float = 0.1, **kwargs):
        super().__init__(arms, **kwargs)
        self.exploration_steps = exploration_steps
        self.epsilon = epsilon
        self.step = 0  # number of select_arm() calls made so far

    def select_arm(self) -> int:
        """Return a random arm during the initial phase, epsilon-greedy afterwards."""
        in_exploration_phase = self.step < self.exploration_steps
        # Count every call, including exploratory ones; counting only exploit
        # calls would leave ``step`` stuck at 0 and never end the initial phase.
        self.step += 1
        if in_exploration_phase or random.random() < self.epsilon:
            return random.randint(0, len(self.arms) - 1)
        return self._exploit()

EpsilonDecreasingBandit: Decays epsilon exponentially from initial_epsilon toward limit_epsilon, halving the remaining gap every half_decay_steps steps, so that exploration decreases over time.

Bases: Bandit

Source code in lightrl/bandits.py
class EpsilonDecreasingBandit(Bandit):
    """Epsilon-greedy bandit whose exploration rate decays over time.

    Epsilon starts at ``initial_epsilon`` and decays exponentially toward
    ``limit_epsilon``. After ``t`` selections, the gap to the limit has been
    multiplied by ``0.5 ** (t / half_decay_steps)``, so it halves every
    ``half_decay_steps`` selections.

    Args:
        arms: The arms to choose between.
        initial_epsilon: Exploration probability before any decay.
        limit_epsilon: Value that epsilon approaches as selections accumulate.
        half_decay_steps: Number of selections over which the gap halves.
        **kwargs: Forwarded to :class:`Bandit` (``priors``, ``ema_alpha``).

    Raises:
        ValueError: If ``half_decay_steps`` is not positive. Zero would divide
            by zero in the decay formula, and a negative value would make
            epsilon move away from ``limit_epsilon`` instead of toward it.
    """

    def __init__(
        self,
        arms: list,
        initial_epsilon: float = 1.0,
        limit_epsilon: float = 0.1,
        half_decay_steps: int = 100,
        **kwargs,
    ):
        if half_decay_steps <= 0:
            raise ValueError(f"half_decay_steps must be positive, got {half_decay_steps}")
        super().__init__(arms, **kwargs)
        self.epsilon = initial_epsilon
        self.initial_epsilon = initial_epsilon
        self.limit_epsilon = limit_epsilon
        self.half_decay_steps = half_decay_steps
        self.step = 0  # number of select_arm() calls made so far

    def _update_epsilon(self) -> None:
        """Recompute ``epsilon`` from the current step count."""
        decay = 0.5 ** (self.step / self.half_decay_steps)
        self.epsilon = self.limit_epsilon + (self.initial_epsilon - self.limit_epsilon) * decay

    def select_arm(self) -> int:
        """Advance the step count, decay epsilon, then choose epsilon-greedily."""
        self.step += 1
        self._update_epsilon()
        if random.random() < self.epsilon:
            return random.randint(0, len(self.arms) - 1)
        return self._exploit()

Other Bandit Strategies

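UCB1Bandit: Implements the UCB1 algorithm. Each arm is pulled once; after that, the arm with the highest upper confidence bound on its estimated reward is selected. Rewards must lie in the range [0, 1].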

Bases: Bandit

Source code in lightrl/bandits.py
class UCB1Bandit(Bandit):
    """Upper-confidence-bound (UCB1) bandit.

    Every arm is pulled once first. After that, the arm maximising
    ``q + sqrt(2 * ln(total_count) / count)`` is chosen, with ties going to
    the lowest index. Rewards passed to :meth:`update` must lie in [0, 1].
    """

    def __init__(self, arms: list, **kwargs):
        super().__init__(arms, **kwargs)
        # Total number of accepted updates across all arms.
        self.total_count = 0

    def select_arm(self) -> int:
        """Return the first never-pulled arm, otherwise the arm with the best UCB score."""
        first_unpulled = next((idx for idx, pulls in enumerate(self.counts) if pulls == 0), None)
        if first_unpulled is not None:
            return first_unpulled
        scores = [
            self.q_values[idx] + math.sqrt(2 * math.log(self.total_count) / self.counts[idx])
            for idx in range(len(self.arms))
        ]
        # max() keeps the first of equal maxima, matching list.index(max(...)).
        return max(range(len(scores)), key=scores.__getitem__)

    def update(self, arm_index: int, reward: float) -> None:
        """Check that ``reward`` is within [0, 1], then record it for ``arm_index``."""
        reward_in_range = 0 <= reward <= 1
        if not reward_in_range:
            raise ValueError("Reward must be in the range [0, 1].")
        self.total_count += 1
        super().update(arm_index, reward)

GreedyBanditWithHistory: A greedy variant that estimates each arm's value from a sliding window of its most recent rewards (up to history_length), sampling arms at random until every arm's window is full.

Bases: Bandit

Source code in lightrl/bandits.py
class GreedyBanditWithHistory(Bandit):
    """Greedy bandit that scores each arm by the mean of its recent rewards.

    Each arm keeps a sliding window of at most ``history_length`` rewards.
    Until every window is full, :meth:`select_arm` picks uniformly among the
    arms whose windows are not yet full. After that it exploits the arm with
    the highest windowed mean, breaking ties at random.

    ``counts[i]`` holds the current window size of arm ``i`` (capped at
    ``history_length``), not its total number of pulls. The ``ema_alpha``
    option of :class:`Bandit` has no effect here because :meth:`update` is
    overridden.

    Args:
        arms: The arms to choose between.
        history_length: Maximum number of recent rewards kept per arm.
        **kwargs: Forwarded to :class:`Bandit` (``priors``, ``ema_alpha``).

    Raises:
        ValueError: If ``history_length`` is less than 1. With an empty
            window, :meth:`update` would try to pop from an empty list.
    """

    def __init__(self, arms: list, history_length: int = 100, **kwargs):
        if history_length < 1:
            raise ValueError(f"history_length must be at least 1, got {history_length}")
        super().__init__(arms, **kwargs)
        self.history_length = history_length
        # Plain lists (not collections.deque) so the state stays JSON-encodable
        # for Bandit.save(), which would otherwise stringify the windows.
        self.history: List[List[float]] = [[] for _ in range(len(arms))]

    def select_arm(self) -> int:
        """Fill every arm's window first (random order), then act greedily."""
        incomplete = [i for i, h in enumerate(self.history) if len(h) < self.history_length]
        if incomplete:
            return random.choice(incomplete)
        return self._exploit()

    def update(self, arm_index: int, reward: float) -> None:
        """Append ``reward`` to the arm's window and recompute its mean."""
        h = self.history[arm_index]
        if len(h) >= self.history_length:
            h.pop(0)  # drop the oldest reward to keep the window bounded
        h.append(reward)
        self.counts[arm_index] = len(h)
        self.q_values[arm_index] = sum(h) / len(h)

Runners Module

two_state_time_dependent_process: Drives a bandit through a two-state process, switching between an ALIVE state and a WAITING state to probe rewards (tasks per second multiplied by reward_factor). In the ALIVE state, the bandit selects an arm whose arguments are passed to fun. If the failure rate reaches failure_threshold, the process enters the WAITING state. There, fun is called with waiting_args (for example, a lower number of tasks) until the failure rate drops below the threshold again.

Source code in lightrl/runners.py
def two_state_time_dependent_process(
    bandit,
    fun,
    failure_threshold=0.1,
    default_wait_time=5,
    extra_wait_time=10,
    waiting_args=None,
    max_steps=500,
    verbose=False,
    reward_factor=1e-6,
):
    """Run a bandit through an ALIVE/WAITING probing loop.

    In the ALIVE state, the bandit selects an arm, ``fun`` is called with that
    arm's arguments, and the loop sleeps ``default_wait_time`` seconds. If the
    failure rate ``fail / (ok + fail)`` is below ``failure_threshold``, the arm
    is rewarded with ``ok / elapsed_seconds * reward_factor``.

    Otherwise the loop switches to the WAITING state. There it calls
    ``fun(*waiting_args)`` and sleeps ``default_wait_time + extra_wait_time``
    seconds per step until the failure rate drops below the threshold. The arm
    that triggered WAITING is then rewarded with its ALIVE successes divided by
    all time elapsed since it was selected, and the loop returns to ALIVE.

    Args:
        bandit: Object providing ``select_arm()``, ``update(index, reward)``,
            ``report()`` and an ``arms`` sequence.
        fun: Callable returning ``(ok, fail)`` counts for one batch of work.
        failure_threshold: Failure rate at or above which the loop enters WAITING.
        default_wait_time: Seconds to sleep after every call to ``fun``; must be
            positive because it is the divisor of the ALIVE reward.
        extra_wait_time: Additional seconds to sleep per WAITING step.
        waiting_args: Arguments for ``fun`` while WAITING (required).
        max_steps: Number of loop iterations.
        verbose: Show a tqdm progress bar and print ``bandit.report()`` each step.
        reward_factor: Scale applied to the successes-per-second reward.

    Raises:
        ValueError: If ``waiting_args`` is None or ``default_wait_time`` is not
            positive.
    """

    def _failure_rate(ok, fail):
        # A batch that processed no tasks reports no failures; treat its rate as
        # 0.0 instead of dividing by zero.
        total = ok + fail
        return fail / total if total else 0.0

    if waiting_args is None:
        raise ValueError("waiting_args must be provided")
    if default_wait_time <= 0:
        raise ValueError(f"default_wait_time must be positive, got {default_wait_time}")
    waiting_args = _ensure_tuple(waiting_args)

    state = "ALIVE"
    last_alive_successes = 0.0
    last_arm_index = None
    # Seconds elapsed since the currently evaluated arm was selected.
    waiting_time = 0.0

    iterator = tqdm(range(max_steps)) if verbose else range(max_steps)

    for _ in iterator:
        if verbose:
            bandit.report()

        if state == "ALIVE":
            arm_idx = bandit.select_arm()
            fun_args = _ensure_tuple(bandit.arms[arm_idx])
            ok, fail = fun(*fun_args)
            time.sleep(default_wait_time)
            waiting_time += default_wait_time

            if _failure_rate(ok, fail) >= failure_threshold:
                # Defer the reward until the system recovers, so the time spent
                # waiting is charged to the arm that caused the failures.
                last_alive_successes = ok
                last_arm_index = arm_idx
                state = "WAITING"
            else:
                bandit.update(arm_idx, ok / waiting_time * reward_factor)
                waiting_time = 0.0
        else:
            ok, fail = fun(*waiting_args)
            time.sleep(default_wait_time + extra_wait_time)
            waiting_time += default_wait_time + extra_wait_time

            if _failure_rate(ok, fail) < failure_threshold:
                bandit.update(last_arm_index, last_alive_successes / waiting_time * reward_factor)
                waiting_time = 0.0
                state = "ALIVE"

    if verbose:
        bandit.report()

If you have any questions or require further assistance, feel free to open an issue.