Array API Standard Support: `signal`
This page explains some caveats of the `signal` module and provides (currently
incomplete) tables summarizing its CPU, GPU, and JIT support.
Caveats
JAX and CuPy provide alternative implementations for some `signal` functions.
When such a function is called, a decorator determines the array namespace
(`xp`) of the call's arguments and uses it to decide which implementation to run.
Hence, the default NumPy-based implementation and the JAX and CuPy backends can
behave differently, which shows up especially during CI testing. The currently
recommended workaround is to skip the incompatible backends in unit tests, as
described in the Adding tests section.
The functions are decorated by the following code in the file
`scipy/signal/_support_alternative_backends.py`:
import functools
import types
from scipy._lib._array_api import (
    is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API, xp_capabilities
)

from ._signal_api import *  # noqa: F403
from . import _signal_api
from . import _delegators
# Export the same public names as ``_signal_api``; each of them is rebound to
# its decorated version by the loop at the end of this module.
__all__ = _signal_api.__all__


# Submodule name looked up in ``cupyx.scipy`` and in the JAX SciPy namespace
# when a call is delegated to an alternative backend.
MODULE_NAME = 'signal'

# jax.scipy.signal has only partial coverage of scipy.signal, so we keep the list
# of functions we can delegate to JAX
# https://jax.readthedocs.io/en/latest/jax.scipy.html
JAX_SIGNAL_FUNCS = [
    'fftconvolve', 'convolve', 'convolve2d', 'correlate', 'correlate2d',
    'csd', 'detrend', 'istft', 'welch'
]

# some cupyx.scipy.signal functions are incompatible with their scipy counterparts;
# calls to these are never delegated to CuPy, even for CuPy inputs
CUPY_BLACKLIST = [
    'abcd_normalize', 'bessel', 'besselap', 'envelope', 'get_window', 'lfilter_zi',
    'sosfilt_zi', 'remez',
]

# freqz_sos is a sosfreqz rename, and cupy does not have the new name yet (in v13.x)
CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}


33def delegate_xp(delegator, module_name):
34 def inner(func):
35 @functools.wraps(func)
36 def wrapper(*args, **kwds):
37 try:
38 xp = delegator(*args, **kwds)
39 except TypeError:
40 # object arrays
41 if func.__name__ == "tf2ss":
42 import numpy as np
43 xp = np
44 else:
45 raise
46
47 # try delegating to a cupyx/jax namesake
48 if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
49 func_name = CUPY_RENAMES.get(func.__name__, func.__name__)
50
51 # https://github.com/cupy/cupy/issues/8336
52 import importlib
53 cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
54 cupyx_func = getattr(cupyx_module, func_name)
55 kwds.pop('xp', None)
56 return cupyx_func(*args, **kwds)
57 elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
58 spx = scipy_namespace_for(xp)
59 jax_module = getattr(spx, module_name)
60 jax_func = getattr(jax_module, func.__name__)
61 kwds.pop('xp', None)
62 return jax_func(*args, **kwds)
63 else:
64 # the original function
65 return func(*args, **kwds)
66 return wrapper
67 return inner
68
69
# Although most of these functions currently exist in CuPy and some in JAX,
# there are no alternative backend tests for any of them in the current
# test suite. Each will be documented as np_only until tests are added:
# ``get_default_capabilities`` marks every name in this set as np_only, unless
# ``capabilities_overrides`` has an entry for it.
untested = {
    "argrelextrema",
    "argrelmax",
    "argrelmin",
    "band_stop_obj",
    "bode",
    "check_NOLA",
    "chirp",
    "coherence",
    "csd",
    "czt",
    "czt_points",
    "dbode",
    "dfreqresp",
    "dlsim",
    "dstep",
    "find_peaks",
    "find_peaks_cwt",
    "freqresp",
    "gausspulse",
    "iirdesign",  # There's no reason this shouldn't work. It just needs tests.
    "istft",
    "lombscargle",
    "lsim",
    "max_len_seq",
    "peak_prominences",
    "peak_widths",
    "periodogram",
    "place_poles",
    "sawtooth",
    "sepfir2d",
    "ss2tf",
    "ss2zpk",
    "step",
    "sweep_poly",
    "symiirorder1",
    "symiirorder2",
    "tf2ss",
    "unit_impulse",
    "welch",
    "zoom_fft",
    "zpk2ss",
}


def get_default_capabilities(func_name, delegator):
    """Return the default ``xp_capabilities`` decorator for ``func_name``.

    A function with no delegator, or whose name is in ``untested``, is marked
    NumPy-only; every other function gets ``xp_capabilities()`` defaults.
    """
    if delegator is None or func_name in untested:
        return xp_capabilities(np_only=True)
    return xp_capabilities()

# Backend-specific caveats, passed as ``extra_note`` to ``xp_capabilities`` by
# entries of ``capabilities_overrides``.
bilinear_extra_note = \
    """CuPy does not accept complex inputs.

    """

uses_choose_conv_extra_note = \
    """CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
    but does support higher dimensional arrays for ``method="direct"``
    and ``method="fft"``.

    """

resample_poly_extra_note = \
    """CuPy only supports ``padtype="constant"``.

    """

upfirdn_extra_note = \
    """CuPy only supports ``mode="constant"`` and ``cval=0.0``.

    """

xord_extra_note = \
    """The ``torch`` backend on GPU does not support the case where
    `wp` and `ws` specify a Bandstop filter.

    """

convolve2d_extra_note = \
    """The JAX backend only supports ``boundary="fill"`` and ``fillvalue=0``.

    """

zpk2tf_extra_note = \
    """The CuPy and JAX backends both support only 1d input.

    """

abcd_normalize_extra_note = \
    """The result dtype when all array inputs are of integer dtype is the
    backend's current default floating point dtype.

    """

# Per-function ``xp_capabilities`` decorators. The decorate loop uses an entry
# here instead of ``get_default_capabilities(...)`` for that name, so an entry
# takes precedence even when the name also appears in ``untested``
# (e.g. ``sawtooth``, ``sepfir2d``, ``istft``).
capabilities_overrides = {
    "abcd_normalize": xp_capabilities(extra_note=abcd_normalize_extra_note),
    "bessel": xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True),
    "bilinear": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="Uses np.polynomial.Polynomial",
                                extra_note=bilinear_extra_note),
    "bilinear_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                    jax_jit=False, allow_dask_compute=True),
    "butter": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "buttord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                               jax_jit=False, allow_dask_compute=True,
                               extra_note=xord_extra_note),
    "cheb1ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheb2ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheby1": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "cheby2": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "cont2discrete": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "convolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                allow_dask_compute=True,
                                extra_note=uses_choose_conv_extra_note),
    "convolve2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                  allow_dask_compute=True,
                                  extra_note=convolve2d_extra_note),
    "correlate": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                 allow_dask_compute=True,
                                 extra_note=uses_choose_conv_extra_note),
    "correlate2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                   allow_dask_compute=True,
                                   extra_note=convolve2d_extra_note),
    "correlation_lags": xp_capabilities(out_of_scope=True),
    "cspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "cspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "cspline2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "deconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                  jax_jit=False, allow_dask_compute=True),
    "decimate": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "detrend": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                               allow_dask_compute=True),
    "dimpulse": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "dlti": xp_capabilities(np_only=True,
                            reason="works in CuPy but delegation isn't set up yet"),
    "ellip": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                             allow_dask_compute=True,
                             reason="scipy.special.ellipk"),
    "ellipord": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="scipy.special.ellipk"),
    "findfreqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "firls": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False,
                             reason="lstsq"),
    "firwin": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                              jax_jit=False, allow_dask_compute=True),
    "firwin2": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               jax_jit=False, allow_dask_compute=True,
                               reason="firwin uses np.interp"),
    "fftconvolve": xp_capabilities(cpu_only=True,
                                   exceptions=["cupy", "jax.numpy", "torch"]),
    "freqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "freqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqz_sos": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "group_delay": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                   jax_jit=False, allow_dask_compute=True),
    "invres": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "invresz": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "iircomb": xp_capabilities(xfail_backends=[("jax.numpy", "inaccurate")]),
    "iirfilter": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "kaiser_atten": xp_capabilities(
        out_of_scope=True, reason="scalars in, scalars out"
    ),
    "kaiser_beta": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "kaiserord": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "lfilter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False),
    "lfilter_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "lfiltic": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True),
    "lp2bp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2bp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2bs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2bs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2lp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2lp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2hp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2hp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lti": xp_capabilities(np_only=True,
                           reason="works in CuPy but delegation isn't set up yet"),
    "medfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False,
                               reason="uses scipy.ndimage.rank_filter"),
    "medfilt2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 allow_dask_compute=True, jax_jit=False,
                                 reason="c extension module"),
    "minimum_phase": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                     allow_dask_compute=True, jax_jit=False),
    "normalize": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "oaconvolve": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        xfail_backends=[("dask.array", "wrong answer")],
    ),
    "order_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                    allow_dask_compute=True, jax_jit=False,
                                    reason="uses scipy.ndimage.rank_filter"),
    "qspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "qspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "qspline2d": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "remez": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False),
    "resample": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            ("dask.array", "XXX something in dask"),
            ("jax.numpy", "XXX: immutable arrays"),
        ]
    ),
    "resample_poly": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        jax_jit=False, skip_backends=[("dask.array", "XXX something in dask")],
        extra_note=resample_poly_extra_note,
    ),
    "residue": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "residuez": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "savgol_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False,
                                     reason="convolve1d is cpu-only"),
    "sawtooth": xp_capabilities(jax_jit=False,
                                skip_backends=[("dask.array", "dask tests fail")]),
    "sepfir2d": xp_capabilities(np_only=True),
    "sos2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sos2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "sosfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sosfilt_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "sosfiltfilt": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        skip_backends=[
            (
                "dask.array",
                "sosfiltfilt directly sets shape attributes on arrays"
                " which dask doesn't like"
            ),
            ("torch", "negative strides"),
            ("jax.numpy", "sosfilt works in-place"),
        ],
    ),
    "sosfreqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True),
    "spline_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False, allow_dask_compute=True),
    "tf2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "tf2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "unique_roots": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "upfirdn": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True,
                               reason="Cython implementation",
                               extra_note=upfirdn_extra_note),
    "vectorstrength": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                      allow_dask_compute=True, jax_jit=False),
    "wiener": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                              allow_dask_compute=True, jax_jit=False,
                              reason="uses scipy.signal.correlate"),
    "zpk2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "zpk2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True,
                              extra_note=zpk2tf_extra_note),
    "spectrogram": xp_capabilities(out_of_scope=True),  # legacy
    "stft": xp_capabilities(out_of_scope=True),  # legacy
    "istft": xp_capabilities(out_of_scope=True),  # legacy
    "check_COLA": xp_capabilities(out_of_scope=True),  # legacy
}


# ### decorate ###
# Wrap every public object of ``_signal_api`` in two stages:
#   1. with ``delegate_xp``, but only when SCIPY_ARRAY_API is set and
#      ``_delegators`` defines a matching ``<name>_signature`` function;
#   2. for anything that is not a submodule, with its ``xp_capabilities``
#      decorator: the ``capabilities_overrides`` entry if there is one,
#      otherwise ``get_default_capabilities(...)``.
for obj_name in _signal_api.__all__:
    bare_obj = getattr(_signal_api, obj_name)
    delegator = getattr(_delegators, obj_name + "_signature", None)

    if SCIPY_ARRAY_API and delegator is not None:
        f = delegate_xp(delegator, MODULE_NAME)(bare_obj)
    else:
        f = bare_obj

    if not isinstance(f, types.ModuleType):
        capabilities = capabilities_overrides.get(
            obj_name, get_default_capabilities(obj_name, delegator)
        )
        f = capabilities(f)  # pyrefly:ignore[not-callable]

    # add the decorated function to the namespace, to be imported in __init__.py
    # (at module level, ``vars()`` is this module's global namespace)
    vars()[obj_name] = f
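Note that a function is only wrapped by the delegating decorator (`delegate_xp`)
if the environment variable SCIPY_ARRAY_API is set and a matching
`<name>_signature` function is defined in the file
`scipy/signal/_delegators.py`. The `xp_capabilities` decorator is applied to
every non-module object either way. For example, the signature function for
firwin looks like this: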
def firwin_signature(numtaps, cutoff, *args, **kwds):
    # A Python scalar cutoff carries no array namespace, so use ``np_compat``;
    # otherwise derive the namespace from the ``cutoff`` array.
    if isinstance(cutoff, int | float):
        xp = np_compat
    else:
        xp = array_namespace(cutoff)
    return xp
Support on CPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
torch |
jax |
dask |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
Support on GPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
cupy |
torch |
jax |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
Support with JIT
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
function/class |
jax |
|---|---|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✔️ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
N/A |
|
N/A |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |