# Source code for jaxkuramoto.solver.ode_solver

from typing import Callable

from jax.tree_util import tree_map
import jax.numpy as jnp

# Signature of an ODE right-hand side: func(t, x) -> dx/dt.
VECTOR_FN = Callable[[float, jnp.ndarray], jnp.ndarray]


def _tmul(a, x):
    """Multiply every leaf of the pytree ``x`` by ``a``."""
    return tree_map(lambda leaf: a * leaf, x)


def _tsum(x, y):
    """Add the pytrees ``x`` and ``y`` leaf by leaf."""
    return tree_map(lambda lhs, rhs: lhs + rhs, x, y)

def euler(
    func: VECTOR_FN, t: float, dt: float, state: jnp.ndarray
) -> jnp.ndarray:
    """Advance ``state`` by one explicit (forward) Euler step.

    Computes ``state + dt * func(t, state)``. The arithmetic is applied
    leaf by leaf through ``tree_map``, so ``state`` may be a single array
    or a pytree, provided ``func`` returns a matching structure.

    Args:
        func: A function of the form func(t, x) -> dx/dt.
        t: Current time.
        dt: Time step.
        state: Current state.

    Returns:
        Next state.
    """
    diff = _tmul(dt, func(t, state))
    return _tsum(state, diff)
def runge_kutta(
    func: VECTOR_FN, t: float, dt: float, state: jnp.ndarray
) -> jnp.ndarray:
    """Advance ``state`` by one classical four-stage Runge-Kutta (RK4) step.

    The four slopes are evaluated as

        k1 = func(t,          state)
        k2 = func(t + dt / 2, state + dt / 2 * k1)
        k3 = func(t + dt / 2, state + dt / 2 * k2)
        k4 = func(t + dt,     state + dt * k3)

    and combined into ``state + dt / 6 * (k1 + 2 * k2 + 2 * k3 + k4)``.
    The arithmetic is applied leaf by leaf through ``tree_map``, so
    ``state`` may be a single array or a pytree, provided ``func`` returns
    a matching structure.

    Args:
        func: A function of the form func(t, x) -> dx/dt.
        t: Current time.
        dt: Time step.
        state: Current state.

    Returns:
        Next state.
    """
    # Both midpoint stages share the same half step and evaluation time.
    half_dt = 0.5 * dt
    t_mid = t + half_dt

    k1 = func(t, state)
    k2 = func(t_mid, _tsum(state, _tmul(half_dt, k1)))
    k3 = func(t_mid, _tsum(state, _tmul(half_dt, k2)))
    k4 = func(t + dt, _tsum(state, _tmul(dt, k3)))

    # Weighted average of the slopes (weights 1, 2, 2, 1 over 6), scaled by dt.
    diff = tree_map(
        lambda s1, s2, s3, s4: (s1 + 2.0 * s2 + 2.0 * s3 + s4) * dt / 6.0,
        k1,
        k2,
        k3,
        k4,
    )
    return _tsum(state, diff)