Documentation for parajax

Enabling parallel execution on CPU

By default, JAX exposes the CPU as a single device, so there is nothing to parallelize across. To make every CPU core available as its own device, set the jax_num_cpu_devices configuration option to the number of cores. Do this at the start of your program, right after importing JAX:

import multiprocessing
import jax

jax.config.update("jax_num_cpu_devices", multiprocessing.cpu_count())

parajax

Parallelization utilities for JAX.

parallelize(func: Callable[_P, _T] | None = None, /, *, max_devices: int | None = None, remainder_strategy: Literal['pad', 'drop', 'strict'] = 'pad') -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]

parallelize(func: Callable[_P, _T], /, *, max_devices: int | None = ..., remainder_strategy: Literal['pad', 'drop', 'strict'] = ...) -> Callable[_P, _T]
parallelize(*, max_devices: int | None = ..., remainder_strategy: Literal['pad', 'drop', 'strict'] = ...) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]

Automatic parallelizing map.

Creates a parallelized version of func that distributes computation of the leading axis of array arguments across multiple devices.

Parameters:

func: The function to be parallelized. It should accept array arguments with a leading batch dimension. If your function cannot work in a batched manner, you can wrap it with jax.vmap first. For passing non-batched arguments, consider using functools.partial or a lambda function.

max_devices: The maximum number of JAX devices to use for parallelization. Defaults to None, which uses all available devices. It must be at least 1 and no greater than the number of available devices; otherwise, a ValueError is raised.

remainder_strategy: Specifies how to handle cases where the batch size is not divisible by the number of devices, which is not directly supported by JAX. The available strategies are:

- "pad" (default): Transparently pad the input arrays along the leading axis to make the batch size divisible by the number of devices. The padding is done by repeating the last element. The output is then automatically unpadded to match the original batch size, with no visible effect to the caller.
- "drop": Use with caution. The extra elements that do not fit evenly into the devices are dropped from the computation, resulting in a smaller output size.
- "strict": Will only work if the batch size is divisible by the number of devices. Otherwise, a ValueError will be raised.
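The size arithmetic behind the three remainder strategies can be sketched in plain Python. This is only an illustration of the description above: plan_batch is a hypothetical helper that is not part of parajax, and it does not move any data between devices.

```python
def plan_batch(
    batch_size: int, devices: int, remainder_strategy: str = "pad"
) -> dict[str, int]:
    """Sizes involved in splitting a batch across devices (hypothetical helper)."""
    remainder = batch_size % devices
    if remainder_strategy == "pad":
        # Round up to the next multiple of `devices`; the output is trimmed back.
        processed = batch_size + (-batch_size) % devices
        output = batch_size
    elif remainder_strategy == "drop":
        # Round down; the trailing `remainder` elements are never computed.
        processed = batch_size - remainder
        output = processed
    elif remainder_strategy == "strict":
        if remainder != 0:
            msg = f"batch size {batch_size} is not divisible by {devices} devices"
            raise ValueError(msg)
        processed = output = batch_size
    else:
        msg = f"invalid remainder_strategy: {remainder_strategy}"
        raise ValueError(msg)
    return {
        "processed": processed,
        "per_device": processed // devices,
        "output": output,
    }


for strategy in ("pad", "drop"):
    print(strategy, plan_batch(12_345, 4, strategy))
```

With 12,345 elements on 4 devices, "pad" processes 12,348 elements (3,087 per device) and returns 12,345 results, while "drop" processes and returns 12,344. With "strict", the same input raises a ValueError.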

Returns:

Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]: A parallel version of func with the same signature. When parallelize is called without func (with keyword options only), it returns a decorator that produces such a version.
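The two call forms in the signature above (bare @parallelize, and @parallelize(...) with keyword options) follow a common Python pattern: the function is an optional positional-only argument, and when it is omitted a decorator is returned instead. Here is a minimal sketch of that pattern using a hypothetical tagged decorator; it illustrates the calling convention only and says nothing about how parajax distributes work.

```python
from __future__ import annotations

import functools
from collections.abc import Callable
from typing import Any


def tagged(func: Callable[..., Any] | None = None, /, *, label: str = "default"):
    """Hypothetical decorator usable as `@tagged` or `@tagged(label=...)`."""

    def decorator(inner: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(inner)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            return inner(*args, **kwargs)

        wrapper.label = label  # record the option so its effect is observable
        return wrapper

    # Bare use passes the function directly; otherwise return the decorator.
    return decorator(func) if func is not None else decorator


@tagged
def double(x):
    return 2 * x


@tagged(label="custom")
def triple(x):
    return 3 * x


print(double(2), double.label)
print(triple(2), triple.label)
```

Making the options keyword-only is what keeps the two forms unambiguous: a positional argument can only ever be the function being decorated.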

Basic usage

import jax.numpy as jnp
from parajax import parallelize

@parallelize
def square(xs):
    return xs ** 2

xs = jnp.arange(12_345)
ys = square(xs)  # This will run in parallel across available JAX devices

Setting options

import jax.numpy as jnp
from parajax import parallelize

@parallelize(max_devices=4)
def square(xs):
    return xs ** 2

xs = jnp.arange(12_345)
ys = square(xs)  # Parallelized across 4 devices

Composability with vmap

import jax
import jax.numpy as jnp
from parajax import parallelize

@parallelize
@jax.vmap
def relu_single(x):
    return jnp.maximum(x, 0)

xs = jnp.arange(-6_000, 6_000)
ys = relu_single(xs)  # Parallelized over the batch

Source code in parajax/__init__.py
def parallelize(
    func: Callable[_P, _T] | None = None,
    /,
    *,
    max_devices: int | None = None,
    remainder_strategy: Literal["pad", "drop", "strict"] = "pad",
) -> Callable[_P, _T] | Callable[[Callable[_P, _T]], Callable[_P, _T]]:
    """Automatic parallelizing map.

    Creates a parallelized version of `func` that distributes computation of the
    leading axis of array arguments across multiple devices.

    Args:
        func: The function to be parallelized. It should accept array arguments with
            a leading batch dimension. If your function cannot work in a batched
            manner, you can wrap it with `jax.vmap` first. For passing non-batched
            arguments, consider using `functools.partial` or a lambda function.
        max_devices: The maximum number of JAX devices to use for parallelization.
        remainder_strategy: Specifies how to handle cases where the batch size is not
            divisible by the number of devices, which is not directly supported by
            JAX. The available strategies are:
            - `"pad"` (default): Transparently pad the input arrays along the leading
                axis to make the batch size divisible by the number of devices. The
                padding is done by repeating the last element. The output is then
                automatically unpadded to match the original batch size, with no
                visible effect to the caller.
            - `"drop"`: Use with caution. The extra elements that do not fit evenly
                into the devices are dropped from the computation, resulting in a
                smaller output size.
            - `"strict"`: Will only work if the batch size is divisible by the number
                of devices. Otherwise, a `ValueError` will be raised.

    Returns:
        The decorator returns a parallel version of `func` with the same signature.

    Basic usage:
        ```python
        import jax.numpy as jnp
        from parajax import parallelize

        @parallelize
        def square(xs):
            return xs ** 2

        xs = jnp.arange(12_345)
        ys = square(xs)  # This will run in parallel across available JAX devices
        ```

    Setting options:
        ```python
        import jax.numpy as jnp
        from parajax import parallelize

        @parallelize(max_devices=4)
        def square(xs):
            return xs ** 2

        xs = jnp.arange(12_345)
        ys = square(xs)  # Parallelized across 4 devices
        ```

    Composability with vmap:
        ```python
        import jax
        import jax.numpy as jnp
        from parajax import parallelize

        @parallelize
        @jax.vmap
        def relu_single(x):
            return jnp.maximum(x, 0)

        xs = jnp.arange(-6_000, 6_000)
        ys = relu_single(xs)  # Parallelized over the batch
        ```
    """
    if max_devices is not None and max_devices < 1:
        msg = "max_devices must be at least 1"
        raise ValueError(msg)

    if remainder_strategy not in {"pad", "drop", "strict"}:
        msg = f"invalid remainder_strategy: {remainder_strategy}"
        raise ValueError(msg)

    def parallelize_decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
        @functools.wraps(func)
        @jax.jit
        def parallelize_wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
            device_count = jax.device_count()

            if max_devices != 1 and device_count == 1:
                msg = (
                    "parajax.parallelize: parallelization requested but only a single"
                    " JAX device is available."
                )
                if jax.default_backend() == "cpu" and multiprocessing.cpu_count() > 1:
                    msg += (
                        '\nSet \'jax.config.update("jax_num_cpu_devices",'
                        f" {multiprocessing.cpu_count()})' right after importing JAX to"
                        " enable all available CPUs."
                        "\nSee https://parajax.readthedocs.io for more information."
                    )
                warnings.warn(msg, UserWarning, stacklevel=2)

            if max_devices is not None and max_devices > device_count:
                msg = (
                    "max_devices cannot be greater than the number of available JAX"
                    f" devices (={device_count})"
                )
                raise ValueError(msg)

            devices = max_devices if max_devices is not None else device_count

            flat_args, _ = jax.tree.flatten((args, kwargs))
            batch_sizes = {jnp.shape(arg)[0] for arg in flat_args}
            if len(batch_sizes) > 1:
                msg = f"mismatched sizes for mapped axes: {batch_sizes}"
                raise ValueError(msg)
            try:
                batch_size = batch_sizes.pop()
            except KeyError:
                msg = "no arguments to map over"
                raise ValueError(msg) from None

            devices = min(devices, batch_size)
            pfunc = _parallelize_strict(func, devices=devices)

            match remainder_strategy:
                case "strict":
                    if batch_size % devices != 0:
                        msg = (
                            f"remainder_strategy='strict' but batch size {batch_size}"
                            f" is not divisible by the number of devices {devices}"
                        )
                        raise ValueError(msg)

                    return pfunc(*args, **kwargs)

                case "drop":
                    remainder_size = batch_size % devices
                    even_size = batch_size - remainder_size

                    args_even, kwargs_even = jax.tree.map(
                        lambda x: x[:even_size], (args, kwargs)
                    )

                    return pfunc(*args_even, **kwargs_even)

                case "pad":
                    pad_size = (-batch_size) % devices

                    padded_args, padded_kwargs = jax.tree.map(
                        lambda x: jnp.pad(
                            x, [(0, pad_size)] + [(0, 0)] * (x.ndim - 1), mode="edge"
                        ),
                        (args, kwargs),
                    )

                    padded_output = pfunc(*padded_args, **padded_kwargs)

                    return jax.tree.map(lambda x: x[:batch_size], padded_output)

        return parallelize_wrapper

    return parallelize_decorator(func) if func is not None else parallelize_decorator
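The wrapper in the source above reads the leading-axis size of every flattened argument, which suggests why scalars and other non-batched values need to be bound beforehand with functools.partial or a lambda. The sketch below reproduces only that size check with plain JAX. The names leading_sizes and affine are hypothetical, and parallelize itself is not called.

```python
import functools

import jax
import jax.numpy as jnp


def leading_sizes(*args, **kwargs):
    """Collect the leading-axis size of every argument leaf, as the wrapper does."""
    flat_args, _ = jax.tree.flatten((args, kwargs))
    return {jnp.shape(arg)[0] for arg in flat_args}


def affine(xs, scale):
    return xs * scale


xs = jnp.arange(8.0)

print(leading_sizes(xs, jnp.ones((8, 3))))  # one shared batch size
print(leading_sizes(xs, jnp.ones((4, 3))))  # two sizes, which the wrapper rejects

# A scalar has shape (), so it has no leading axis to index.
try:
    leading_sizes(xs, 2.0)
except IndexError as exc:
    print("scalar argument:", exc)

# Binding the scalar first leaves a function of the batched argument only.
affine_by_two = functools.partial(affine, scale=2.0)
print(leading_sizes(xs))
print(affine_by_two(xs))
```

Under these assumptions, affine_by_two is the kind of function one would hand to parallelize, because every argument it still accepts carries the batch dimension.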