Array API Standard Support: signal
This page explains some caveats of the signal module and provides (currently
incomplete) tables about the CPU, GPU and JIT support.
Caveats
JAX and CuPy provide alternative implementations for some signal functions.
When such a function is called, a decorator decides which implementation to
use by inspecting the array namespace (``xp``) determined from the arguments.
Hence, especially during CI testing, behavior can differ between the default
NumPy-based implementation and the JAX and CuPy backends. Skipping the
incompatible backends in unit tests, as described in the Adding tests section,
is the currently recommended workaround.
The functions are decorated by the code in file
scipy/signal/_support_alternative_backends.py:
1import functools
2import types
3from scipy._lib._array_api import (
4 is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API, xp_capabilities
5)
6
7from ._signal_api import * # noqa: F403
8from . import _signal_api
9from . import _delegators
__all__ = _signal_api.__all__


MODULE_NAME = 'signal'

# Names that may be forwarded to ``jax.scipy.signal``. JAX implements only a
# part of ``scipy.signal``, so delegation is opt-in, one name at a time; see
# https://jax.readthedocs.io/en/latest/jax.scipy.html
JAX_SIGNAL_FUNCS = [
    'fftconvolve',
    'convolve',
    'convolve2d',
    'correlate',
    'correlate2d',
    'csd',
    'detrend',
    'istft',
    'welch',
]

# Names that are never forwarded to ``cupyx.scipy.signal`` because the CuPy
# versions are incompatible with their SciPy counterparts.
CUPY_BLACKLIST = [
    'lfilter_zi',
    'sosfilt_zi',
    'get_window',
    'besselap',
    'envelope',
    'remez',
    'bessel',
]

# SciPy name -> name to look up in cupyx. ``freqz_sos`` is the new name of
# ``sosfreqz`` and CuPy (v13.x) only provides the old one.
CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}


32def delegate_xp(delegator, module_name):
33 def inner(func):
34 @functools.wraps(func)
35 def wrapper(*args, **kwds):
36 try:
37 xp = delegator(*args, **kwds)
38 except TypeError:
39 # object arrays
40 if func.__name__ == "tf2ss":
41 import numpy as np
42 xp = np
43 else:
44 raise
45
46 # try delegating to a cupyx/jax namesake
47 if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
48 func_name = CUPY_RENAMES.get(func.__name__, func.__name__)
49
50 # https://github.com/cupy/cupy/issues/8336
51 import importlib
52 cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
53 cupyx_func = getattr(cupyx_module, func_name)
54 kwds.pop('xp', None)
55 return cupyx_func(*args, **kwds)
56 elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
57 spx = scipy_namespace_for(xp)
58 jax_module = getattr(spx, module_name)
59 jax_func = getattr(jax_module, func.__name__)
60 kwds.pop('xp', None)
61 return jax_func(*args, **kwds)
62 else:
63 # the original function
64 return func(*args, **kwds)
65 return wrapper
66 return inner
67
68
# Although most of these functions currently exist in CuPy and some in JAX,
# there are no alternative backend tests for any of them in the current
# test suite. Each will be documented as np_only until tests are added.
# Membership is tested by exact public name (``func_name in untested``), so
# every entry must be spelled exactly as in ``_signal_api.__all__``; a
# misspelled entry silently has no effect.
untested = {
    "argrelextrema",
    "argrelmax",
    "argrelmin",
    "band_stop_obj",
    "check_NOLA",
    "chirp",
    "coherence",
    "csd",
    "czt_points",
    "dbode",
    "dfreqresp",
    "dlsim",
    "dstep",
    "find_peaks",
    "find_peaks_cwt",
    "freqresp",
    "gausspulse",
    "lombscargle",
    "lsim",
    "max_len_seq",
    "peak_prominences",
    "peak_widths",
    "periodogram",
    "place_poles",  # was misspelled "place_pols", which never matched
    "sawtooth",
    "sepfir2d",
    "square",
    "ss2tf",
    "ss2zpk",
    "step",
    "sweep_poly",
    "symiirorder1",
    "symiirorder2",
    "tf2ss",
    "unit_impulse",
    "welch",
    "zoom_fft",
    "zpk2ss",
}


def get_default_capabilities(func_name, delegator):
    """Return the ``xp_capabilities`` decorator for a name without an override.

    Names with no delegator, or listed in ``untested``, are declared
    NumPy-only; every other name gets ``xp_capabilities()`` with its defaults.
    """
    np_only = delegator is None or func_name in untested
    return xp_capabilities(np_only=True) if np_only else xp_capabilities()


# Extra notes attached via ``extra_note=`` to some entries of
# ``capabilities_overrides``.

# used by "bilinear"
bilinear_extra_note = \
    """CuPy does not accept complex inputs.

    """

# used by "convolve" and "correlate"
uses_choose_conv_extra_note = \
    """CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
    but does support higher dimensional arrays for ``method="direct"``
    and ``method="fft"``.

    """

# used by "resample_poly"
resample_poly_extra_note = \
    """CuPy only supports ``padtype="constant"``.

    """

# used by "upfirdn"
upfirdn_extra_note = \
    """CuPy only supports ``mode="constant"`` and ``cval=0.0``.

    """

# used by "buttord", "cheb1ord" and "cheb2ord"
xord_extra_note = \
    """The ``torch`` backend on GPU does not support the case where
    `wp` and `ws` specify a Bandstop filter.

    """

# used by "convolve2d" and "correlate2d"
convolve2d_extra_note = \
    """The JAX backend only supports ``boundary="fill"`` and ``fillvalue=0``.

    """

# used by "zpk2tf"
zpk2tf_extra_note = \
    """The CuPy and JAX backends both support only 1d input.

    """

# Per-name ``xp_capabilities`` decorators. The decorate loop looks each public
# name up here and falls back to ``get_default_capabilities`` when the name is
# absent.
capabilities_overrides = {
    "bessel": xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True),
    "bilinear": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="Uses np.polynomial.Polynomial",
                                extra_note=bilinear_extra_note),
    "bilinear_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                    jax_jit=False, allow_dask_compute=True),
    "butter": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "buttord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                               jax_jit=False, allow_dask_compute=True,
                               extra_note=xord_extra_note),
    "cheb1ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheb2ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheby1": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),

    "cheby2": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "cont2discrete": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "convolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                allow_dask_compute=True,
                                extra_note=uses_choose_conv_extra_note),
    "convolve2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                  allow_dask_compute=True,
                                  extra_note=convolve2d_extra_note),
    "correlate": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                 allow_dask_compute=True,
                                 extra_note=uses_choose_conv_extra_note),
    "correlate2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                   allow_dask_compute=True,
                                   extra_note=convolve2d_extra_note),
    "correlation_lags": xp_capabilities(out_of_scope=True),
    "cspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "cspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "cspline2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "czt": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "deconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                  allow_dask_compute=True,
                                  skip_backends=[("jax.numpy", "item assignment")]),
    "decimate": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "detrend": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                               allow_dask_compute=True),
    "dimpulse": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "dlti": xp_capabilities(np_only=True,
                            reason="works in CuPy but delegation isn't set up yet"),
    "ellip": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                             allow_dask_compute=True,
                             reason="scipy.special.ellipk"),
    "ellipord": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="scipy.special.ellipk"),
    "findfreqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "firls": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False,
                             reason="lstsq"),
    "firwin": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                              jax_jit=False, allow_dask_compute=True),
    # NOTE(review): the reason text names ``firwin`` although this entry is
    # for ``firwin2``; presumably intentional shorthand -- confirm.
    "firwin2": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               jax_jit=False, allow_dask_compute=True,
                               reason="firwin uses np.interp"),
    "fftconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"]),
    "freqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "freqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqz_sos": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "group_delay": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                   jax_jit=False, allow_dask_compute=True),
    "hilbert": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "hilbert2": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "item assignment")],
    ),
    "invres": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "invresz": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "iircomb": xp_capabilities(xfail_backends=[("jax.numpy", "inaccurate")]),
    "iirfilter": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "kaiser_atten": xp_capabilities(
        out_of_scope=True, reason="scalars in, scalars out"
    ),
    "kaiser_beta": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "kaiserord": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "lfilter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False),
    "lfilter_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "lfiltic": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "lp2bp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2bs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2bs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2lp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2lp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2hp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True,
                             skip_backends=[("jax.numpy", "in-place item assignment")]),
    "lp2hp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lti": xp_capabilities(np_only=True,
                           reason="works in CuPy but delegation isn't set up yet"),
    "medfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False,
                               reason="uses scipy.ndimage.rank_filter"),
    "medfilt2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 allow_dask_compute=True, jax_jit=False,
                                 reason="c extension module"),
    "minimum_phase": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                     allow_dask_compute=True, jax_jit=False),
    "normalize": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "oaconvolve": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        skip_backends=[("jax.numpy", "fails all around")],
        xfail_backends=[("dask.array", "wrong answer")],
    ),
    "order_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                    allow_dask_compute=True, jax_jit=False,
                                    reason="uses scipy.ndimage.rank_filter"),
    "qspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "qspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "qspline2d": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "remez": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False),
    "resample": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            ("dask.array", "XXX something in dask"),
            ("jax.numpy", "XXX: immutable arrays"),
        ]
    ),
    "resample_poly": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        jax_jit=False, skip_backends=[("dask.array", "XXX something in dask")],
        extra_note=resample_poly_extra_note,
    ),
    "residue": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "residuez": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "savgol_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False,
                                     reason="convolve1d is cpu-only"),
    # NOTE(review): "sepfir2d" is also listed in ``untested``; this override
    # takes precedence, so that listing is redundant.
    "sepfir2d": xp_capabilities(np_only=True),
    "sos2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sos2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "sosfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "sosfiltfilt": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            (
                "dask.array",
                "sosfiltfilt directly sets shape attributes on arrays"
                " which dask doesn't like"
            ),
            ("torch", "negative strides"),
            ("jax.numpy", "sosfilt works in-place"),
        ],
    ),
    "sosfreqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True),
    "spline_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False, allow_dask_compute=True),
    "tf2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "tf2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "unique_roots": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "upfirdn": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True,
                               reason="Cython implementation",
                               extra_note=upfirdn_extra_note),
    "vectorstrength": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                      allow_dask_compute=True, jax_jit=False),
    "wiener": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                              allow_dask_compute=True, jax_jit=False,
                              reason="uses scipy.signal.correlate"),
    "zpk2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "zpk2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True,
                              extra_note=zpk2tf_extra_note),
    # Legacy functions are declared out of scope. NOTE(review): "istft" is
    # also in JAX_SIGNAL_FUNCS, so it may still be delegated to JAX.
    "spectrogram": xp_capabilities(out_of_scope=True),  # legacy
    "stft": xp_capabilities(out_of_scope=True),  # legacy
    "istft": xp_capabilities(out_of_scope=True),  # legacy
    "check_COLA": xp_capabilities(out_of_scope=True),  # legacy
}


# ### decorate ###
# Every public name of ``_signal_api`` is optionally wrapped by the CuPy/JAX
# delegating decorator, then by its ``xp_capabilities`` decorator, and the
# result is published in this module's namespace.
for obj_name in _signal_api.__all__:
    bare_obj = getattr(_signal_api, obj_name)
    delegator = getattr(_delegators, f"{obj_name}_signature", None)

    # Delegation requires SCIPY_ARRAY_API and a ``<name>_signature`` helper.
    f = bare_obj
    if SCIPY_ARRAY_API and delegator is not None:
        f = delegate_xp(delegator, MODULE_NAME)(f)

    # Module objects are exported as-is; everything else gets its override
    # from ``capabilities_overrides`` or the default capabilities.
    if not isinstance(f, types.ModuleType):
        capabilities = capabilities_overrides.get(
            obj_name, get_default_capabilities(obj_name, delegator)
        )
        f = capabilities(f)

    # add the decorated function to the namespace, to be imported in __init__.py
    globals()[obj_name] = f
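Note that a function is only wrapped by the delegating decorator
(``delegate_xp``) if the environment variable SCIPY_ARRAY_API is set and a
matching signature function (``<name>_signature``) is defined in
scipy/signal/_delegators.py; the capabilities decorator is applied in either
case. E.g., for firwin, the signature function looks like this: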
340def firwin_signature(numtaps, cutoff, *args, **kwds):
341 if isinstance(cutoff, int | float):
342 xp = np_compat
343 else:
344 xp = array_namespace(cutoff)
345 return xp
Support on CPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
torch |
jax |
dask |
|---|---|---|---|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
Support on GPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
cupy |
torch |
jax |
|---|---|---|---|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
Support with JIT
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function |
jax |
|---|---|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✔️ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
N/A |
|
N/A |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |