"""Common data structures for simulator I/O."""
import chex
import jax.numpy as jnp
import jax_md
@chex.dataclass()
class SimulatorTrajectory:
    """A trajectory of a simulation run.

    The trajectory is stored as a single ``jax_md.rigid_body.RigidBody`` whose
    ``center`` and ``orientation.vec`` arrays are indexed along their leading
    axis; that axis is what ``slice``, ``length`` and ``+`` operate on.
    """

    rigid_body: jax_md.rigid_body.RigidBody

    def slice(self, key: int | slice) -> "SimulatorTrajectory":
        """Return a new trajectory restricted to ``key`` along the leading axis.

        Args:
            key: Either a ``slice`` object, applied as-is, or an ``int``. An
                integer selects a single entry but keeps the leading axis, so
                the result is a trajectory of length one rather than a bare
                state. Negative integers count from the end.

        Returns:
            A new ``SimulatorTrajectory``; ``self`` is left unmodified.
        """
        if isinstance(key, int):
            # ``slice(-1, 0)`` is empty, so the last element needs an
            # open-ended stop instead of ``key + 1``.
            stop = None if key == -1 else key + 1
            # Inside a method body ``slice`` resolves to the builtin, not to
            # this method (class scope is not an enclosing scope).
            key = slice(key, stop)

        body = self.rigid_body
        return self.replace(
            rigid_body=jax_md.rigid_body.RigidBody(
                center=body.center[key, ...],
                orientation=jax_md.rigid_body.Quaternion(
                    vec=body.orientation.vec[key, ...],
                ),
            )
        )

    def length(self) -> int:
        """Return the length of the trajectory.

        The length is the size of the leading axis of ``rigid_body.center``.

        Note that this may have been more natural to implement as the built-in
        ``__len__`` method. However, the ``chex.dataclass`` decorator overrides
        that method to be compatible with the ``abc.Mapping`` interface.

        See here:
        https://github.com/google-deepmind/chex/blob/8af2c9e8a19f3a57d9bd283c2a34148aef952f60/chex/_src/dataclass.py#L50
        """
        return self.rigid_body.center.shape[0]

    def __add__(self, other: "SimulatorTrajectory") -> "SimulatorTrajectory":
        """Concatenate two trajectories along the leading axis.

        Args:
            other: The trajectory whose entries are appended after those of
                ``self``.

        Returns:
            A new ``SimulatorTrajectory`` holding the entries of ``self``
            followed by the entries of ``other``.
        """
        mine = self.rigid_body
        theirs = other.rigid_body
        # ``jnp.concatenate`` is used for both fields; ``jnp.concat`` is a
        # newer alias that is not available in older JAX releases.
        return self.replace(
            rigid_body=jax_md.rigid_body.RigidBody(
                center=jnp.concatenate([mine.center, theirs.center], axis=0),
                orientation=jax_md.rigid_body.Quaternion(
                    vec=jnp.concatenate(
                        [mine.orientation.vec, theirs.orientation.vec],
                        axis=0,
                    ),
                ),
            )
        )