Source code for jaxkuramoto.ode

from typing import Callable

from jax.lax import fori_loop
from jax import jit
import jax.numpy as jnp

from .solution import Solution

# Vector field of an ODE, called as ``vector_fn(t, state)``.
VECTOR_FN = Callable[[float, jnp.ndarray], jnp.ndarray]
# One-step solver, called by ``odeint`` as ``solver(vector_fn, t, dt, state)``;
# its return value is used as the next state.
SOLVER = Callable[[VECTOR_FN, float, float, jnp.ndarray], jnp.ndarray]
# Observable, called by ``odeint`` as ``observable_fn(t, state)`` (two
# arguments: the time comes first, then the state).
OBSERVABLE_FN = Callable[[float, jnp.ndarray], jnp.ndarray]

class ODE:
    """Base class for an Ordinary Differential Equation (ODE).

    The base implementation of :meth:`vector_fn` always raises
    ``NotImplementedError``; it is meant to be overridden.
    """

    def __init__(self) -> None:
        """Create the object; no attributes are set here."""

    def vector_fn(self, t: float, state: jnp.ndarray) -> jnp.ndarray:
        """Vector field of the ODE.

        Args:
            t (float): Time.
            state (jnp.ndarray): State.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError

def odeint(
    vector_fn: VECTOR_FN,
    solver: SOLVER,
    t0: float,
    t1: float,
    dt: float,
    init_state: jnp.ndarray,
    observable_fn: OBSERVABLE_FN = None,
):
    """Integrate the ODE.

    The time grid is ``ts = jnp.arange(t0, t1, dt)``. Starting from
    ``init_state`` at ``ts[0]``, the solver is applied once per grid interval
    and ``observables[i]`` holds the observable of the state at ``ts[i]``.

    Args:
        vector_fn (VECTOR_FN): Vector field of the ODE.
        solver (SOLVER): Solver of the ODE. Called as
            ``solver(vector_fn, t, dt, state)``; its return value is used as
            the state one step later.
        t0 (float): Initial time.
        t1 (float): Final time.
        dt (float): Time step.
        init_state (jnp.ndarray): Initial state.
        observable_fn (OBSERVABLE_FN, optional): Function to calculate
            observables, called as ``observable_fn(t, state)``. Defaults to
            None, in which case the state itself is recorded.

    Returns:
        Solution: Solution class built from ``t0``, ``t1``, ``ts``,
        ``init_state``, ``final_state`` (the state at ``ts[-1]``) and
        ``observables``.

    Raises:
        ValueError: If t0 >= t1 or dt <= 0.
    """
    # check input
    if t0 >= t1:
        raise ValueError("t0 must be smaller than t1.")
    if dt <= 0:
        raise ValueError("dt must be positive.")

    # create update function: a single solver step of size dt
    update_fn = jit(lambda t, state: solver(vector_fn, t, dt, state))

    ts = jnp.arange(t0, t1, dt)
    n_step = ts.shape[0]

    # check observable_fn: by default, record the state itself
    if observable_fn is None:
        def observable_fn(_t, _x):
            return _x

    # create observables buffer; asarray so that a scalar-valued observable
    # still provides .shape and .dtype
    observable_1st = jnp.asarray(observable_fn(t0, init_state))
    observables = jnp.zeros(
        shape=(n_step, *observable_1st.shape),
        dtype=observable_1st.dtype,
    )
    observables = observables.at[0].set(observable_1st)

    @jit
    def body_fn(i, val):
        state, _observables = val
        # `state` is the state at ts[i - 1], so the step starts from that time
        next_state = update_fn(ts[i - 1], state)
        # index i of the buffer corresponds to ts[i], i.e. the state *after*
        # the step; recording the pre-step state would duplicate entry 0 and
        # shift every later sample by one step
        observable_i = observable_fn(ts[i], next_state)
        _observables = _observables.at[i].set(observable_i)
        return (next_state, _observables)

    # run: entry 0 is already filled, so iterate over indices 1 .. n_step - 1
    final_state, observables = fori_loop(
        1, n_step, body_fn, (init_state, observables)
    )

    # store the result to Solution class
    sol = Solution(
        t0=t0,
        t1=t1,
        ts=ts,
        init_state=init_state,
        final_state=final_state,
        observables=observables,
    )
    return sol
def odeint_resume():
    """Placeholder with no implementation; takes no arguments and returns None."""
    return None