Source code for earthkit.hydro.data_structures._network
import numpy as np
from earthkit.utils.array import to_device
from ._network_storage import RiverNetworkStorage
[docs]
class RiverNetwork:
    """
    A class representing a river network for hydrological processing.

    Attributes
    ----------
    n_nodes : int
        The number of nodes in the river network.
    n_edges : int
        The number of edges in the river network.
    sinks : array-like
        Nodes with no downstream connections.
    sources : array-like
        Nodes with no upstream connections.
    bifurcates : bool
        Whether the river network has bifurcations.
    edge_weights : array-like
        The edge weights, as provided by the underlying storage.
    shape : tuple
        The size of the river network grid. None if the network is vector-based.
    mask : array-like
        Flattened 1D indices on the raster grid corresponding to river network nodes.
    coords
        The coordinates, as provided by the underlying storage.
    data : list
        One-element list wrapping the storage's ``sorted_data`` array.
    groups : list
        The storage's ``sorted_data`` array split along axis 1 at the
        storage's ``splits`` positions.
    array_backend : str
        The array backend of the river network.
    device : str
        The device of the river network.
    return_type : str
        The default return type of the river network. Either "gridded" or "masked".
    """

    # Short spellings accepted for ``array_backend`` -> canonical backend name.
    _BACKEND_ALIASES = {
        "np": "numpy",
        "cp": "cupy",
        "jnp": "jax",
        "tf": "tensorflow",
        "pytorch": "torch",
        "mx": "mlx",
    }

    def __init__(self, river_network_storage: "RiverNetworkStorage"):
        """
        Build a RiverNetwork view on top of a storage object.

        Parameters
        ----------
        river_network_storage : RiverNetworkStorage
            The storage object whose fields are exposed by this network.
            It is kept as-is (in `_storage`) so that `export` can save it.
        """
        self._storage = river_network_storage
        self.n_nodes = self._storage.n_nodes
        self.n_edges = self._storage.n_edges
        self.sources = self._storage.sources
        self.sinks = self._storage.sinks
        self.bifurcates = self._storage.bifurcates
        self.edge_weights = self._storage.edge_weights
        self.mask = self._storage.mask
        self.shape = self._storage.shape
        # A freshly constructed network always starts as numpy on cpu;
        # use `to_device` to change this.
        self.array_backend = "numpy"
        self.device = "cpu"
        self.return_type = "gridded"
        self.coords = self._storage.coords
        self.data = [self._storage.sorted_data]
        self.groups = np.split(self._storage.sorted_data, self._storage.splits, axis=1)

    def __str__(self):
        return f"RiverNetwork with {self.n_nodes} nodes and {self.n_edges} edges."

    def __repr__(self):
        return self.__str__()

    @staticmethod
    def _require_cpu(device, array_backend):
        """Reject non-cpu devices for backends that are only handled on cpu here."""
        if device != "cpu":
            # Raised explicitly (rather than via `assert`) so the check is
            # not stripped under `python -O`. The exception type is kept as
            # AssertionError for backward compatibility.
            raise AssertionError(
                f'Only device "cpu" is supported for array_backend '
                f'"{array_backend}", got {device!r}.'
            )

    def to_device(self, device=None, array_backend=None):
        """
        Change the RiverNetwork's array backend and/or move it to a
        different device.

        The network is modified in place: `groups`, `mask` and `data` are
        converted, and `array_backend` and `device` are updated.

        Parameters
        ----------
        device : str, optional
            The device to which to transfer. Default is None, which is `'cpu'` for all backends except cupy, which is `'gpu'`.
        array_backend : str, optional
            The array backend.
            One of "numpy", "np", "cupy", "cp", "pytorch", "torch", "jax", "jnp", "tensorflow", "tf", "mlx" or "mx".
            Default is None, which uses `self.array_backend` (or "cupy" if the
            network is currently numpy and `device` is "gpu" or "cuda").

        Returns
        -------
        RiverNetwork
            The modified RiverNetwork.

        Raises
        ------
        NotImplementedError
            If `array_backend` is not one of the supported backends.
        AssertionError
            If a device other than "cpu" is requested for "jax" or "tensorflow".
        """
        # TODO: use xp.asarray
        array_backend = self._BACKEND_ALIASES.get(array_backend, array_backend)

        # Resolve the backend BEFORE the device default: the default device
        # depends on the backend actually used, which may be the network's
        # current one when `array_backend` is not given.
        if array_backend is None:
            if self.array_backend == "numpy" and device in ("gpu", "cuda"):
                # numpy arrays cannot live on a gpu; switch to cupy.
                array_backend = "cupy"
            else:
                array_backend = self.array_backend
        if device is None:
            device = "gpu" if array_backend == "cupy" else "cpu"

        if array_backend in ("torch", "cupy", "numpy"):
            self.groups = [
                to_device(group, device, array_backend=array_backend)
                for group in self.groups
            ]
            self.mask = to_device(self.mask, device, array_backend=array_backend)
            self.data = [to_device(self.data[0], device, array_backend=array_backend)]
        elif array_backend == "jax":
            self._require_cpu(device, array_backend)
            import jax.numpy as jnp

            self.groups = [jnp.array(x) for x in self.groups]
            self.mask = jnp.array(self.mask)
            self.data = [jnp.array(self.data[0])]
        elif array_backend == "tensorflow":
            self._require_cpu(device, array_backend)
            import tensorflow as tf

            self.groups = [tf.convert_to_tensor(x, dtype=tf.int32) for x in self.groups]
            self.mask = tf.convert_to_tensor(self.mask, dtype=tf.int32)
            self.data = [tf.convert_to_tensor(self.data[0], dtype=tf.int32)]
        elif array_backend == "mlx":
            # NOTE: `device` is not used for mlx.
            import mlx.core as mx

            self.groups = [mx.array(x) for x in self.groups]
            self.mask = mx.array(self.mask)
            self.data = [mx.array(self.data[0])]
        else:
            raise NotImplementedError

        self.array_backend = array_backend
        # For every backend except mlx, record the device reported by the
        # converted arrays themselves; for mlx no device is recorded.
        if self.array_backend != "mlx":
            self.device = self.groups[0].device
        else:
            self.device = None
        return self

    def set_default_return_type(self, return_type):
        """
        Set the default return type for the river network.

        Parameters
        ----------
        return_type : str
            The default return_type to use. Either "gridded" or "masked".

        Returns
        -------
        None

        Raises
        ------
        ValueError
            If `return_type` is not "gridded" or "masked".
        """
        if return_type not in ("gridded", "masked"):
            raise ValueError(
                f'Invalid return_type {return_type}. Valid types are "gridded", "masked"'
            )
        self.return_type = return_type

    def export(self, fpath="river_network.joblib", compression=1):
        """
        Save the river network to a local file.

        Only the underlying storage object is written, via `joblib.dump`.

        Parameters
        ----------
        fpath : str, optional
            The filepath specifying where to save the RiverNetwork. Default is `'river_network.joblib'`.
        compression : int, optional
            The compression factor used for saving, forwarded to
            `joblib.dump` as `compress`. Default is 1.

        Returns
        -------
        None
        """
        import joblib

        joblib.dump(self._storage, fpath, compress=compression)