Array API Standard Support: `signal`
This page explains some caveats of the `signal` module and provides (currently
incomplete) tables about its CPU, GPU and JIT support.
Caveats
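JAX and CuPy provide alternative
implementations for some `signal` functions. When such a function is called, a
decorator first determines the array namespace (`xp`) of the call's arguments
via a per-function signature function. Based on that namespace, it then
dispatches to the corresponding `cupyx.scipy.signal` or `jax.scipy.signal`
function, or falls back to the original SciPy implementation.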
Hence, the behavior of the default NumPy-based implementation and of the JAX and CuPy backends can differ, which shows up especially during CI testing. Skipping the incompatible backends in unit tests, as described in the Adding tests section, is currently the recommended workaround.
The functions are decorated by the following code in the file
`scipy/signal/_support_alternative_backends.py`:
1import functools
2import types
3from scipy._lib._array_api import (
4 is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API, xp_capabilities
5)
6
7from ._signal_api import * # noqa: F403
8from . import _signal_api
9from . import _delegators
10__all__ = _signal_api.__all__
11
12
13MODULE_NAME = 'signal'
14
15# jax.scipy.signal has only partial coverage of scipy.signal, so we keep the list
16# of functions we can delegate to JAX
17# https://jax.readthedocs.io/en/latest/jax.scipy.html
18JAX_SIGNAL_FUNCS = [
19 'fftconvolve', 'convolve', 'convolve2d', 'correlate', 'correlate2d',
20 'csd', 'detrend', 'istft', 'welch'
21]
22
23# some cupyx.scipy.signal functions are incompatible with their scipy counterparts
24CUPY_BLACKLIST = [
25 'abcd_normalize', 'bessel', 'besselap', 'envelope', 'get_window', 'lfilter_zi',
26 'sosfilt_zi', 'remez',
27]
28
29# freqz_sos is a sosfreqz rename, and cupy does not have the new name yet (in v13.x)
30CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}
31
32
33def delegate_xp(delegator, module_name):
34 def inner(func):
35 @functools.wraps(func)
36 def wrapper(*args, **kwds):
37 try:
38 xp = delegator(*args, **kwds)
39 except TypeError:
40 # object arrays
41 if func.__name__ == "tf2ss":
42 import numpy as np
43 xp = np
44 else:
45 raise
46
47 # try delegating to a cupyx/jax namesake
48 if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
49 func_name = CUPY_RENAMES.get(func.__name__, func.__name__)
50
51 # https://github.com/cupy/cupy/issues/8336
52 import importlib
53 cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
54 cupyx_func = getattr(cupyx_module, func_name)
55 kwds.pop('xp', None)
56 return cupyx_func(*args, **kwds)
57 elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
58 spx = scipy_namespace_for(xp)
59 jax_module = getattr(spx, module_name)
60 jax_func = getattr(jax_module, func.__name__)
61 kwds.pop('xp', None)
62 return jax_func(*args, **kwds)
63 else:
64 # the original function
65 return func(*args, **kwds)
66 return wrapper
67 return inner
68
69
70# Although most of these functions currently exist in CuPy and some in JAX,
71# there are no alternative backend tests for any of them in the current
72# test suite. Each will be documented as np_only until tests are added.
73untested = {
74 "argrelextrema",
75 "argrelmax",
76 "argrelmin",
77 "band_stop_obj",
78 "bode",
79 "check_NOLA",
80 "chirp",
81 "coherence",
82 "csd",
83 "czt",
84 "czt_points",
85 "dbode",
86 "dfreqresp",
87 "dlsim",
88 "dstep",
89 "find_peaks",
90 "find_peaks_cwt",
91 "freqresp",
92 "gausspulse",
93 "iirdesign", # There's no reason this shouldn't work. It just needs tests.
94 "istft",
95 "lombscargle",
96 "lsim",
97 "max_len_seq",
98 "peak_prominences",
99 "peak_widths",
100 "periodogram",
101 "place_poles",
102 "sepfir2d",
103 "ss2tf",
104 "ss2zpk",
105 "step",
106 "sweep_poly",
107 "symiirorder1",
108 "symiirorder2",
109 "tf2ss",
110 "unit_impulse",
111 "zoom_fft",
112 "zpk2ss",
113}
114
115
116def get_default_capabilities(func_name, delegator):
117 if delegator is None or func_name in untested:
118 return xp_capabilities(np_only=True)
119 return xp_capabilities()
120
121bilinear_extra_note = \
122 """CuPy does not accept complex inputs.
123
124 """
125
126uses_choose_conv_extra_note = \
127 """CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
128 but does support higher dimensional arrays for ``method="direct"``
129 and ``method="fft"``.
130
131 """
132
133resample_poly_extra_note = \
134 """CuPy only supports ``padtype="constant"``.
135
136 """
137
138upfirdn_extra_note = \
139 """CuPy only supports ``mode="constant"`` and ``cval=0.0``.
140
141 """
142
143xord_extra_note = \
144 """The ``torch`` backend on GPU does not support the case where
145 `wp` and `ws` specify a Bandstop filter.
146
147 """
148
149convolve2d_extra_note = \
150 """The JAX backend only supports ``boundary="fill"`` and ``fillvalue=0``.
151
152 """
153
154zpk2tf_extra_note = \
155 """The CuPy and JAX backends both support only 1d input.
156
157 """
158
159abcd_normalize_extra_note = \
160 """The result dtype when all array inputs are of integer dtype is the
161 backend's current default floating point dtype.
162
163 """
164
165welch_extra_note = \
166 """Support for CuPy and JAX is provided by delegation to
167 ``cupyx.scipy.signal.welch`` and ``jax.scipy.signal.welch``.
168
169 For single-precision input (``float32`` or ``complex64``), JAX returns the sample
170 frequencies in ``float32``, whereas SciPy and CuPy always return them in
171 ``float64``.
172 """
173
174capabilities_overrides = {
175 "abcd_normalize": xp_capabilities(extra_note=abcd_normalize_extra_note),
176 "bessel": xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True),
177 "bilinear": xp_capabilities(cpu_only=True, exceptions=["cupy"],
178 jax_jit=False, allow_dask_compute=True,
179 reason="Uses np.polynomial.Polynomial",
180 extra_note=bilinear_extra_note),
181 "bilinear_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
182 jax_jit=False, allow_dask_compute=True),
183 "butter": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
184 allow_dask_compute=True),
185 "buttord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
186 jax_jit=False, allow_dask_compute=True,
187 extra_note=xord_extra_note),
188 "cheb1ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
189 jax_jit=False, allow_dask_compute=True,
190 extra_note=xord_extra_note),
191 "cheb2ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
192 jax_jit=False, allow_dask_compute=True,
193 extra_note=xord_extra_note),
194 "cheby1": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
195 allow_dask_compute=True),
196
197 "cheby2": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
198 allow_dask_compute=True),
199 "cont2discrete": xp_capabilities(np_only=True, exceptions=["cupy"]),
200 "convolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
201 allow_dask_compute=True,
202 extra_note=uses_choose_conv_extra_note),
203 "convolve2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
204 allow_dask_compute=True,
205 extra_note=convolve2d_extra_note),
206 "correlate": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
207 allow_dask_compute=True,
208 extra_note=uses_choose_conv_extra_note),
209 "correlate2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
210 allow_dask_compute=True,
211 extra_note=convolve2d_extra_note),
212 "correlation_lags": xp_capabilities(out_of_scope=True),
213 "cspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
214 jax_jit=False, allow_dask_compute=True),
215 "cspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
216 jax_jit=False, allow_dask_compute=True),
217 "cspline2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
218 jax_jit=False, allow_dask_compute=True),
219 "deconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy"],
220 jax_jit=False, allow_dask_compute=True),
221 "decimate": xp_capabilities(np_only=True, exceptions=["cupy"]),
222 "detrend": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
223 allow_dask_compute=True),
224 "dimpulse": xp_capabilities(np_only=True, exceptions=["cupy"]),
225 "dlti": xp_capabilities(np_only=True,
226 reason="works in CuPy but delegation isn't set up yet"),
227 "ellip": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
228 allow_dask_compute=True,
229 reason="scipy.special.ellipk"),
230 "ellipord": xp_capabilities(cpu_only=True, exceptions=["cupy"],
231 jax_jit=False, allow_dask_compute=True,
232 reason="scipy.special.ellipk"),
233 "filtfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
234 allow_dask_compute=True, jax_jit=False),
235 "findfreqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
236 jax_jit=False, allow_dask_compute=True),
237 "firls": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False,
238 reason="lstsq"),
239 "firwin": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
240 jax_jit=False, allow_dask_compute=True),
241 "firwin2": xp_capabilities(cpu_only=True, exceptions=["cupy"],
242 jax_jit=False, allow_dask_compute=True,
243 reason="firwin2 uses np.interp"),
244 "fftconvolve": xp_capabilities(cpu_only=True,
245 exceptions=["cupy", "jax.numpy", "torch"]),
246 "freqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
247 jax_jit=False, allow_dask_compute=True),
248 "freqs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
249 jax_jit=False, allow_dask_compute=True),
250 "freqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
251 jax_jit=False, allow_dask_compute=True),
252 "freqz_sos": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
253 jax_jit=False, allow_dask_compute=True),
254 "group_delay": xp_capabilities(cpu_only=True, exceptions=["cupy"],
255 jax_jit=False, allow_dask_compute=True),
256 "invres": xp_capabilities(np_only=True, exceptions=["cupy"]),
257 "invresz": xp_capabilities(np_only=True, exceptions=["cupy"]),
258 "iircomb": xp_capabilities(xfail_backends=[("jax.numpy", "inaccurate")]),
259 "iirfilter": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
260 jax_jit=False, allow_dask_compute=True),
261 "kaiser_atten": xp_capabilities(
262 out_of_scope=True, reason="scalars in, scalars out"
263 ),
264 "kaiser_beta": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
265 "kaiserord": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
266 "lfilter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
267 allow_dask_compute=True, jax_jit=False),
268 "lfilter_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
269 jax_jit=False),
270 "lfiltic": xp_capabilities(cpu_only=True, exceptions=["cupy"],
271 allow_dask_compute=True, jax_jit=False),
272 "lp2bp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
273 allow_dask_compute=True, jax_jit=False),
274 "lp2bp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
275 allow_dask_compute=True, jax_jit=False),
276 "lp2bs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
277 allow_dask_compute=True, jax_jit=False),
278 "lp2bs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
279 allow_dask_compute=True, jax_jit=False),
280 "lp2lp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
281 allow_dask_compute=True, jax_jit=False),
282 "lp2lp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
283 allow_dask_compute=True, jax_jit=False),
284 "lp2hp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
285 allow_dask_compute=True, jax_jit=False),
286 "lp2hp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
287 allow_dask_compute=True, jax_jit=False),
288 "lti": xp_capabilities(np_only=True,
289 reason="works in CuPy but delegation isn't set up yet"),
290 "medfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
291 allow_dask_compute=True, jax_jit=False,
292 reason="uses scipy.ndimage.rank_filter"),
293 "medfilt2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
294 allow_dask_compute=True, jax_jit=False,
295 reason="c extension module"),
296 "minimum_phase": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
297 allow_dask_compute=True, jax_jit=False),
298 "normalize": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
299 jax_jit=False, allow_dask_compute=True),
300 "oaconvolve": xp_capabilities(
301 cpu_only=True, exceptions=["cupy", "torch"],
302 xfail_backends=[("dask.array", "wrong answer")],
303 ),
304 "order_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
305 allow_dask_compute=True, jax_jit=False,
306 reason="uses scipy.ndimage.rank_filter"),
307 "qspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
308 jax_jit=False, allow_dask_compute=True),
309 "qspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
310 jax_jit=False, allow_dask_compute=True),
311 "qspline2d": xp_capabilities(np_only=True, exceptions=["cupy"]),
312 "remez": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False),
313 "resample_poly": xp_capabilities(
314 cpu_only=True, exceptions=["cupy"],
315 jax_jit=False, skip_backends=[("dask.array", "XXX something in dask")],
316 extra_note=resample_poly_extra_note,
317 ),
318 "residue": xp_capabilities(np_only=True, exceptions=["cupy"]),
319 "residuez": xp_capabilities(np_only=True, exceptions=["cupy"]),
320 "savgol_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
321 jax_jit=False,
322 reason="convolve1d is cpu-only"),
323 "sawtooth": xp_capabilities(jax_jit=False,
324 skip_backends=[("dask.array", "dask tests fail")]),
325 "sepfir2d": xp_capabilities(np_only=True),
326 "sos2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
327 allow_dask_compute=True),
328 "sos2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
329 allow_dask_compute=True),
330 "sosfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
331 allow_dask_compute=True),
332 "sosfilt_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
333 jax_jit=False),
334 "sosfiltfilt": xp_capabilities(
335 cpu_only=True, exceptions=["cupy"], jax_jit=False,
336 skip_backends=[
337 (
338 "dask.array",
339 "sosfiltfilt directly sets shape attributes on arrays"
340 " which dask doesn't like"
341 ),
342 ("torch", "negative strides"),
343 ],
344 ),
345 "sosfreqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
346 jax_jit=False, allow_dask_compute=True),
347 "spline_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
348 jax_jit=False, allow_dask_compute=True),
349 "tf2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
350 allow_dask_compute=True),
351 "tf2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
352 allow_dask_compute=True),
353 "unique_roots": xp_capabilities(np_only=True, exceptions=["cupy"]),
354 "upfirdn": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
355 allow_dask_compute=True,
356 reason="Cython implementation",
357 extra_note=upfirdn_extra_note),
358 "vectorstrength": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
359 allow_dask_compute=True, jax_jit=False),
360 "welch": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
361 allow_dask_compute=True,
362 extra_note=welch_extra_note),
363 "wiener": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
364 allow_dask_compute=True, jax_jit=False,
365 reason="uses scipy.signal.correlate"),
366 "zpk2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
367 allow_dask_compute=True),
368 "zpk2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
369 allow_dask_compute=True,
370 extra_note=zpk2tf_extra_note),
371 "spectrogram": xp_capabilities(out_of_scope=True), # legacy
372 "stft": xp_capabilities(out_of_scope=True), # legacy
373 "istft": xp_capabilities(out_of_scope=True), # legacy
374 "check_COLA": xp_capabilities(out_of_scope=True), # legacy
375}
376
377
378# ### decorate ###
379for obj_name in _signal_api.__all__:
380 bare_obj = getattr(_signal_api, obj_name)
381 delegator = getattr(_delegators, obj_name + "_signature", None)
382
383 if SCIPY_ARRAY_API and delegator is not None:
384 f = delegate_xp(delegator, MODULE_NAME)(bare_obj)
385 else:
386 f = bare_obj
387
388 if not isinstance(f, types.ModuleType):
389 capabilities = capabilities_overrides.get(
390 obj_name, get_default_capabilities(obj_name, delegator)
391 )
392 f = capabilities(f) # pyrefly:ignore[not-callable]
393
394 # add the decorated function to the namespace, to be imported in __init__.py
395 vars()[obj_name] = f
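Note that a function is only wrapped for delegation to an alternative backend
if two conditions hold:

- The environment variable `SCIPY_ARRAY_API` is set.
- A corresponding signature function named `<function name>_signature` is
  defined in the file `scipy/signal/_delegators.py`.

The `xp_capabilities` decorator, in contrast, is applied to every exported
function and class regardless. For example, the signature function for
`firwin` looks like this: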
339def firwin_signature(numtaps, cutoff, *args, **kwds):
340 if isinstance(cutoff, int | float):
341 xp = np_compat
342 else:
343 xp = array_namespace(cutoff)
344 return xp
Support on CPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
torch |
jax |
dask |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
Support on GPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
cupy |
torch |
jax |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
Support with JIT
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
jax |
|---|---|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✔️ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
N/A |
|
N/A |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |