Array API Standard Support: signal
This page explains some caveats of the signal module and provides (currently
incomplete) tables about its CPU, GPU, and JIT support.
Caveats
JAX and CuPy provide alternative
implementations for some signal functions. When such a function is called, a
decorator decides which implementation to use by determining the array namespace
(xp) of the arguments passed to it.
Hence, discrepancies in behavior between the default NumPy-based implementation and the JAX and CuPy backends can occur, especially during CI testing. Skipping the incompatible backends in unit tests, as described in the Adding tests section, is the currently recommended workaround.
The functions are decorated by the following code in the file
scipy/signal/_support_alternative_backends.py:
import functools
import types
from scipy._lib._array_api import (
    is_cupy, is_jax, scipy_namespace_for, SCIPY_ARRAY_API, xp_capabilities
)

from ._signal_api import *  # noqa: F403
from . import _signal_api
from . import _delegators
__all__ = _signal_api.__all__


MODULE_NAME = 'signal'

# jax.scipy.signal has only partial coverage of scipy.signal, so we keep the list
# of functions we can delegate to JAX
# https://jax.readthedocs.io/en/latest/jax.scipy.html
JAX_SIGNAL_FUNCS = [
    'fftconvolve', 'convolve', 'convolve2d', 'correlate', 'correlate2d',
    'csd', 'detrend', 'istft', 'welch'
]

# some cupyx.scipy.signal functions are incompatible with their scipy counterparts
CUPY_BLACKLIST = [
    'abcd_normalize', 'bessel', 'besselap', 'envelope', 'get_window', 'lfilter_zi',
    'sosfilt_zi', 'remez',
]

# freqz_sos is a sosfreqz rename, and cupy does not have the new name yet (in v13.x)
CUPY_RENAMES = {'freqz_sos': 'sosfreqz'}


def delegate_xp(delegator, module_name):
    """Build a decorator that forwards calls to a CuPy or JAX namesake.

    On every call, the wrapped function first asks ``delegator`` (called with
    the same arguments) for the array namespace ``xp``. Then:

    - if ``xp`` is CuPy and the function is not in ``CUPY_BLACKLIST``, the call
      is forwarded to ``cupyx.scipy.<module_name>`` (names remapped through
      ``CUPY_RENAMES``);
    - if ``xp`` is JAX and the function is in ``JAX_SIGNAL_FUNCS``, the call is
      forwarded to the JAX SciPy namespace returned by ``scipy_namespace_for``;
    - otherwise the original function is called unchanged.

    Any ``xp`` keyword is dropped before forwarding to CuPy or JAX.

    Parameters
    ----------
    delegator : callable
        Returns the array namespace for a given set of call arguments.
    module_name : str
        SciPy submodule name (e.g. ``'signal'``) to look up in the backend.

    Raises
    ------
    TypeError
        Re-raised from ``delegator``, except for ``tf2ss``, which falls back
        to NumPy.
    """
    def inner(func):
        @functools.wraps(func)
        def wrapper(*args, **kwds):
            try:
                xp = delegator(*args, **kwds)
            except TypeError:
                # object arrays
                if func.__name__ == "tf2ss":
                    import numpy as np
                    xp = np
                else:
                    raise

            # try delegating to a cupyx/jax namesake
            if is_cupy(xp) and func.__name__ not in CUPY_BLACKLIST:
                func_name = CUPY_RENAMES.get(func.__name__, func.__name__)

                # https://github.com/cupy/cupy/issues/8336
                import importlib
                cupyx_module = importlib.import_module(f"cupyx.scipy.{module_name}")
                cupyx_func = getattr(cupyx_module, func_name)
                kwds.pop('xp', None)
                return cupyx_func(*args, **kwds)
            elif is_jax(xp) and func.__name__ in JAX_SIGNAL_FUNCS:
                spx = scipy_namespace_for(xp)
                jax_module = getattr(spx, module_name)
                jax_func = getattr(jax_module, func.__name__)
                kwds.pop('xp', None)
                return jax_func(*args, **kwds)
            else:
                # the original function
                return func(*args, **kwds)
        return wrapper
    return inner


# Although most of these functions currently exist in CuPy and some in JAX,
# there are no alternative backend tests for any of them in the current
# test suite. Each will be documented as np_only until tests are added.
untested = {
    "argrelextrema",
    "argrelmax",
    "argrelmin",
    "band_stop_obj",
    "bode",
    "check_NOLA",
    "chirp",
    "coherence",
    "csd",
    "czt",
    "czt_points",
    "dbode",
    "dfreqresp",
    "dlsim",
    "dstep",
    "find_peaks",
    "find_peaks_cwt",
    "freqresp",
    "gausspulse",
    "iirdesign",  # There's no reason this shouldn't work. It just needs tests.
    "istft",
    "lombscargle",
    "lsim",
    "max_len_seq",
    "peak_prominences",
    "peak_widths",
    "periodogram",
    "place_poles",
    "sepfir2d",
    "ss2tf",
    "ss2zpk",
    "step",
    "sweep_poly",
    "symiirorder1",
    "symiirorder2",
    "tf2ss",
    "unit_impulse",
    "zoom_fft",
    "zpk2ss",
}


def get_default_capabilities(func_name, delegator):
    """Return the default ``xp_capabilities`` decorator for ``func_name``.

    Functions without a delegator, or listed in ``untested``, are marked
    NumPy-only; all others get ``xp_capabilities()`` with default arguments.
    """
    if delegator is None or func_name in untested:
        return xp_capabilities(np_only=True)
    return xp_capabilities()

bilinear_extra_note = \
    """CuPy does not accept complex inputs.

    """

uses_choose_conv_extra_note = \
    """CuPy does not support inputs with ``ndim>1`` when ``method="auto"``
    but does support higher dimensional arrays for ``method="direct"``
    and ``method="fft"``.

    """

resample_poly_extra_note = \
    """CuPy only supports ``padtype="constant"``.

    """

upfirdn_extra_note = \
    """CuPy only supports ``mode="constant"`` and ``cval=0.0``.

    """

xord_extra_note = \
    """The ``torch`` backend on GPU does not support the case where
    `wp` and `ws` specify a Bandstop filter.

    """

convolve2d_extra_note = \
    """The JAX backend only supports ``boundary="fill"`` and ``fillvalue=0``.

    """

zpk2tf_extra_note = \
    """The CuPy and JAX backends both support only 1d input.

    """

abcd_normalize_extra_note = \
    """The result dtype when all array inputs are of integer dtype is the
    backend's current default floating point dtype.

    """

welch_extra_note = \
    """Support for CuPy and JAX is provided by delegation to
    ``cupyx.scipy.signal.welch`` and ``jax.scipy.signal.welch``.

    For single-precision input (``float32`` or ``complex64``), JAX returns the sample
    frequencies in ``float32``, whereas SciPy and CuPy always return them in
    ``float64``.
    """

capabilities_overrides = {
    "abcd_normalize": xp_capabilities(extra_note=abcd_normalize_extra_note),
    "bessel": xp_capabilities(cpu_only=True, jax_jit=False, allow_dask_compute=True),
    "bilinear": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="Uses np.polynomial.Polynomial",
                                extra_note=bilinear_extra_note),
    "bilinear_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                    jax_jit=False, allow_dask_compute=True),
    "butter": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "buttord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                               jax_jit=False, allow_dask_compute=True,
                               extra_note=xord_extra_note),
    "cheb1ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheb2ord": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True,
                                extra_note=xord_extra_note),
    "cheby1": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),

    "cheby2": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "cont2discrete": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "convolve": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                allow_dask_compute=True,
                                extra_note=uses_choose_conv_extra_note),
    "convolve2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                  allow_dask_compute=True,
                                  extra_note=convolve2d_extra_note),
    "correlate": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                 allow_dask_compute=True,
                                 extra_note=uses_choose_conv_extra_note),
    "correlate2d": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                                   allow_dask_compute=True,
                                   extra_note=convolve2d_extra_note),
    "correlation_lags": xp_capabilities(out_of_scope=True),
    "cspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "cspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "cspline2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "deconvolve": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                  jax_jit=False, allow_dask_compute=True),
    "decimate": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "detrend": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                               allow_dask_compute=True),
    "dimpulse": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "dlti": xp_capabilities(np_only=True,
                            reason="works in CuPy but delegation isn't set up yet"),
    "ellip": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                             allow_dask_compute=True,
                             reason="scipy.special.ellipk"),
    "ellipord": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                jax_jit=False, allow_dask_compute=True,
                                reason="scipy.special.ellipk"),
    "filtfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                allow_dask_compute=True, jax_jit=False),
    "findfreqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "firls": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False,
                             reason="lstsq"),
    "firwin": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                              jax_jit=False, allow_dask_compute=True),
    "firwin2": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               jax_jit=False, allow_dask_compute=True,
                               reason="firwin2 uses np.interp"),
    "freqs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "freqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             jax_jit=False, allow_dask_compute=True),
    "freqz_sos": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "group_delay": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                   jax_jit=False, allow_dask_compute=True),
    "invres": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "invresz": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "iircomb": xp_capabilities(xfail_backends=[("jax.numpy", "inaccurate")]),
    "iirfilter": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "kaiser_atten": xp_capabilities(
        out_of_scope=True, reason="scalars in, scalars out"
    ),
    "kaiser_beta": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "kaiserord": xp_capabilities(out_of_scope=True, reason="scalars in, scalars out"),
    "lfilter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False),
    "lfilter_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "lfiltic": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False),
    "lp2bp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2bp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2bs": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2bs_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2lp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2lp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lp2hp": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                             allow_dask_compute=True, jax_jit=False),
    "lp2hp_zpk": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 allow_dask_compute=True, jax_jit=False),
    "lti": xp_capabilities(np_only=True,
                           reason="works in CuPy but delegation isn't set up yet"),
    "medfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                               allow_dask_compute=True, jax_jit=False,
                               reason="uses scipy.ndimage.rank_filter"),
    "medfilt2d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 allow_dask_compute=True, jax_jit=False,
                                 reason="c extension module"),
    "minimum_phase": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                     allow_dask_compute=True, jax_jit=False),
    "normalize": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                 jax_jit=False, allow_dask_compute=True),
    "oaconvolve": xp_capabilities(
        cpu_only=True, exceptions=["cupy", "torch"],
        xfail_backends=[("dask.array", "wrong answer")],
    ),
    "order_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                    allow_dask_compute=True, jax_jit=False,
                                    reason="uses scipy.ndimage.rank_filter"),
    "qspline1d": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                 jax_jit=False, allow_dask_compute=True),
    "qspline1d_eval": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                      jax_jit=False, allow_dask_compute=True),
    "qspline2d": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "remez": xp_capabilities(cpu_only=True, allow_dask_compute=True, jax_jit=False),
    "resample_poly": xp_capabilities(
        cpu_only=True, exceptions=["cupy"],
        jax_jit=False, skip_backends=[("dask.array", "XXX something in dask")],
        extra_note=resample_poly_extra_note,
    ),
    "residue": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "residuez": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "savgol_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False,
                                     reason="convolve1d is cpu-only"),
    "sawtooth": xp_capabilities(jax_jit=False,
                                skip_backends=[("dask.array", "dask tests fail")]),
    "sos2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sos2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "sosfilt": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "sosfilt_zi": xp_capabilities(cpu_only=True, allow_dask_compute=True,
                                  jax_jit=False),
    "sosfiltfilt": xp_capabilities(
        cpu_only=True, exceptions=["cupy"], jax_jit=False,
        skip_backends=[
            (
                "dask.array",
                "sosfiltfilt directly sets shape attributes on arrays"
                " which dask doesn't like"
            ),
            ("torch", "negative strides"),
        ],
    ),
    "sosfreqz": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                jax_jit=False, allow_dask_compute=True),
    "spline_filter": xp_capabilities(cpu_only=True, exceptions=["cupy"],
                                     jax_jit=False, allow_dask_compute=True),
    "tf2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "tf2zpk": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True),
    "unique_roots": xp_capabilities(np_only=True, exceptions=["cupy"]),
    "upfirdn": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True,
                               reason="Cython implementation",
                               extra_note=upfirdn_extra_note),
    "vectorstrength": xp_capabilities(cpu_only=True, exceptions=["cupy", "torch"],
                                      allow_dask_compute=True, jax_jit=False),
    "welch": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                             allow_dask_compute=True,
                             extra_note=welch_extra_note),
    "wiener": xp_capabilities(cpu_only=True, exceptions=["cupy", "jax.numpy"],
                              allow_dask_compute=True, jax_jit=False,
                              reason="uses scipy.signal.correlate"),
    "zpk2sos": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                               allow_dask_compute=True),
    "zpk2tf": xp_capabilities(cpu_only=True, exceptions=["cupy"], jax_jit=False,
                              allow_dask_compute=True,
                              extra_note=zpk2tf_extra_note),
    "spectrogram": xp_capabilities(out_of_scope=True),  # legacy
    "stft": xp_capabilities(out_of_scope=True),  # legacy
    "istft": xp_capabilities(out_of_scope=True),  # legacy
    "check_COLA": xp_capabilities(out_of_scope=True),  # legacy
}


# ### decorate ###
# For every exported name: wrap it with backend delegation only when
# SCIPY_ARRAY_API is set AND a `<name>_signature` delegator exists in
# _delegators; then (for anything that is not a module) attach its
# capabilities, taken from `capabilities_overrides` or the defaults.
for obj_name in _signal_api.__all__:
    bare_obj = getattr(_signal_api, obj_name)
    delegator = getattr(_delegators, obj_name + "_signature", None)

    if SCIPY_ARRAY_API and delegator is not None:
        f = delegate_xp(delegator, MODULE_NAME)(bare_obj)
    else:
        f = bare_obj

    if not isinstance(f, types.ModuleType):
        capabilities = capabilities_overrides.get(
            obj_name, get_default_capabilities(obj_name, delegator)
        )
        f = capabilities(f)  # pyrefly:ignore[not-callable]

    # add the decorated function to the namespace, to be imported in __init__.py
    vars()[obj_name] = f
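Note that a function is only wrapped by the delegating decorator if the
environment variable SCIPY_ARRAY_API is set and a matching signature function
(named <function name>_signature) is defined in the file
scipy/signal/_delegators.py. The xp_capabilities annotations are attached to
every exported function and class regardless. For example, for firwin, the
signature function looks like this: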
def firwin_signature(numtaps, cutoff, *args, **kwds):
    """Return the array namespace that ``firwin`` should dispatch on.

    A plain Python ``int``/``float`` cutoff carries no array namespace, so
    ``np_compat`` is used; otherwise the namespace is inferred from ``cutoff``.
    """
    if isinstance(cutoff, int | float):
        xp = np_compat
    else:
        xp = array_namespace(cutoff)
    return xp
Support on CPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
| function/class | torch | jax | dask |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
N/A |
N/A |
N/A |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
Support on GPU
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
| function/class | cupy | torch | jax |
|---|---|---|---|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
✔️ |
✖ |
✔️ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
N/A |
N/A |
N/A |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✔️ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
N/A |
N/A |
N/A |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✔️ |
✔️ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✔️ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
|
✖ |
✖ |
✖ |
|
✔️ |
✖ |
✖ |
Support with JIT
Legend
✔️ = supported
✖ = unsupported
N/A = out-of-scope
| function/class | jax |
|---|---|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✔️ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✔️ |
|
✖ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✔️ |
|
✔️ |
|
✖ |
|
✖ |
|
N/A |
|
N/A |
|
N/A |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
N/A |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✔️ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |
|
✖ |