# Source code for jaxkuramoto.theory.ott_antonsen

import jax.numpy as jnp

from jaxkuramoto.ode import ODE
from jaxkuramoto.distribution import Distribution

class OttAntonsen(ODE):
    """Ott-Antonsen reduction of the Kuramoto model.

    The reduced dynamics are selected from the class name of ``dist``:

    * ``"Cauchy"``: a single complex order parameter ``z``.
    * ``"CauchyMultiply"``: two complex sub-order-parameters ``(z1, z2)``
      that are combined into the global order parameter by ``zs2z``.

    After construction the instance exposes ``vector_fn(t, state)`` (the
    right-hand side of the reduced ODE) and ``to_orderparam(t, state)``
    (the modulus of the global order parameter).
    """

    def __init__(self, dist: Distribution, K: float) -> None:
        """Set up the reduced vector field for the given distribution.

        Args:
            dist (Distribution): Distribution of natural frequencies. Its
                class must be named ``Cauchy`` (attributes ``loc`` and
                ``gamma`` are read) or ``CauchyMultiply`` (attributes
                ``Omega``, ``gamma1`` and ``gamma2`` are read).
            K (float): Coupling strength.

        Raises:
            ValueError: If ``dist`` is neither ``Cauchy`` nor
                ``CauchyMultiply``.
        """
        super().__init__()
        self.dist = dist
        self.dist_name = dist.__class__.__name__
        self.K = K

        if self.dist_name == "Cauchy":
            self.vector_fn = self._vector_fn_cauchy
            self.to_orderparam = lambda _, z: jnp.abs(z)
        elif self.dist_name == "CauchyMultiply":
            self.vector_fn = self._vector_fn_cauchymultiply
            # Complex weights of the two sub-order-parameters in the global
            # order parameter.  By construction k1 + k2 == 1.
            gamma_sum = dist.gamma1 + dist.gamma2
            denom = 2 * dist.Omega + 1j * (dist.gamma1 - dist.gamma2)
            self.k1 = (
                dist.gamma2 * (2 * dist.Omega - 1j * gamma_sum) / gamma_sum / denom
            )
            self.k2 = (
                dist.gamma1 * (2 * dist.Omega + 1j * gamma_sum) / gamma_sum / denom
            )
            self.zs2z = lambda zs: self.k1 * zs[0] + self.k2 * zs[1]
            self.to_orderparam = lambda _, zs: jnp.abs(self.zs2z(zs))
        else:
            raise ValueError("Distribution must be Cauchy or CauchyMultiply.")

    def _vector_fn_cauchy(self, t: float, z: jnp.ndarray) -> jnp.ndarray:
        """Reduced vector field for the Cauchy distribution.

        Evaluates::

            dz/dt = (i * loc + K / 2 - gamma) * z - (K / 2) * |z|^2 * z

        Args:
            t (float): Time (unused; the system is autonomous).
            z (jnp.ndarray): Complex order parameter.

        Returns:
            jnp.ndarray: Time derivative of the order parameter.
        """
        half_K = 0.5 * self.K
        # The frequency term i * loc multiplies z (a rigid rotation of the
        # order parameter); adding it as a constant would make |z| depend
        # on loc, which the reduction does not allow.
        linear = (1j * self.dist.loc + half_K - self.dist.gamma) * z
        cubic = half_K * jnp.abs(z) ** 2 * z
        return linear - cubic

    def _vector_fn_cauchymultiply(self, t: float, zs: jnp.ndarray) -> jnp.ndarray:
        """Reduced vector field for the CauchyMultiply distribution.

        Args:
            t (float): Time (unused; the system is autonomous).
            zs (jnp.ndarray): The two complex sub-order-parameters
                ``(z1, z2)``.

        Returns:
            jnp.ndarray: Time derivatives ``(dz1/dt, dz2/dt)``.
        """
        z1, z2 = zs
        # Both components are driven by the same global order parameter.
        z = self.zs2z(zs)
        half_K = 0.5 * self.K
        dz1 = (1j * self.dist.Omega - self.dist.gamma1) * z1 - half_K * (
            z.conj() * z1 * z1 - z
        )
        dz2 = -(1j * self.dist.Omega + self.dist.gamma2) * z2 - half_K * (
            z.conj() * z2 * z2 - z
        )
        return jnp.array([dz1, dz2])