Sharding: "Hip error: 'operation would make the legacy stream ..." #275

PhilipVinc opened this issue Mar 11, 2025 · 8 comments
Open
Labels
bug Something isn't working Under Investigation

PhilipVinc commented Mar 11, 2025

Description

I am running on bare metal with jax/jaxlib 0.4.35, installed from the releases in this repository.
I had to work around the missing libsuitesparseconfig.so.4 by installing it manually.

When I run a relatively straightforward script that uses sharding in a single process addressing 4 local GPUs, I get the errors shown below.

The script and pyproject file can be found at this gist.

export NETKET_ENABLE_X64=1
export NETKET_EXPERIMENTAL_SHARDING=1
export NETKET_MPI=0

[cad14908] fvicentini@a1004:~/rep2$ uv run simple_nocnn.py
Config:
Global configurations for NetKet
 - NETKET_DEBUG = False
 - NETKET_EXPERIMENTAL = False
 - NETKET_MPI_WARNING = True
 - NETKET_MPI = False
 - NETKET_USE_PLAIN_RHAT = False
 - NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION = False
 - NETKET_EXPERIMENTAL_DISABLE_ODE_JIT = True
 - NETKET_EXPERIMENTAL_SHARDING_CPU = 0
 - NETKET_ENABLE_X64 = True
 - NETKET_SPHINX_BUILD = False
 - NETKET_EXPERIMENTAL_SHARDING = True
 - NETKET_MPI_AUTODETECT_LOCAL_GPU = False
 - NETKET_RANDOM_STATE_FALLBACK_WARNING = True
 - NETKET_EXPERIMENTAL_SHARDING_FAST_SERIALIZATION = False
 - NETKET_SPIN_ORDERING_WARNING = True

jax global devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
jax local devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
  0%|                                                                                                                                                | 0/500 [00:00<?, ?it/s]
Hip error: 'operation would make the legacy stream depend on a capturing blocking stream'(906) at /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/hipBLASLt/library/src/amd_detail/hipblaslt.cpp:135

Hip error: 'operation would make the legacy stream depend on a capturing blocking stream'(906) at /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/hipBLASLt/library/src/amd_detail/hipblaslt.cpp:135

rocBLAS error: Could not initialize Tensile host:
_Map_base::at

System info (python version, jaxlib version, accelerator, etc.)

[cad14908] fvicentini@a1002:~/rep2/lib$ uv run python -c 'import jax; jax.print_environment_info()'
jax:    0.4.33
jaxlib: 0.4.33
numpy:  2.1.3
python: 3.12.1 (main, Sep 18 2024, 23:46:30) [GCC 12.2.1 20221121 (Red Hat 12.2.1-7)]
jax.devices (4 total, 4 local): [RocmDevice(id=0) RocmDevice(id=1) RocmDevice(id=2) RocmDevice(id=3)]
process_count: 1
platform: uname_result(system='Linux', node='a1002', release='5.14.0-427.26.1.el9_4.x86_64', version='#1 SMP PREEMPT_DYNAMIC Fri Jul 5 11:34:54 EDT 2024', machine='x86_64')

============================================= ROCm System Management Interface =============================================
======================================================= Concise Info =======================================================
Device  Node  IDs              Temp        Power     Partitions          SCLK    MCLK     Fan  Perf    PwrCap  VRAM%  GPU%
              (DID,     GUID)  (Junction)  (Socket)  (Mem, Compute, ID)
============================================================================================================================
0       4     0x74a0,   16722  45.0°C      110.0W    NPS1, SPX, 0        102Mhz  1200Mhz  0%   manual  550.0W  0%     0%
1       5     0x74a0,   8346   46.0°C      70.0W     NPS1, SPX, 0        94Mhz   900Mhz   0%   manual  550.0W  0%     0%
2       6     0x74a0,   33475  45.0°C      106.0W    NPS1, SPX, 0        94Mhz   1200Mhz  0%   manual  550.0W  0%     0%
3       7     0x74a0,   25611  47.0°C      72.0W     NPS1, SPX, 0        95Mhz   900Mhz   0%   manual  550.0W  0%     0%
============================================================================================================================
=================================================== End of ROCm SMI Log ====================================================
PhilipVinc added the bug label on Mar 11, 2025
@PhilipVinc
Author

Hello @ppanchad-amd, is there any news about this issue? It is a total blocker for us.

@lucbruni-amd

Hi @PhilipVinc, I am investigating this. Could you clarify whether you are still using JAX 0.4.35 for this, or 0.5.0?

PhilipVinc commented Mar 27, 2025

Hi @lucbruni-amd, I was using jax 0.4.35 because of the issues in #268.
With the workaround from #268, I am now able to run it with jax 0.5.0.

The code still fails with jax 0.5.0, but the error is different.

See the output below:

[cad14908] fvicentini@a1018:~/rep2$ uv run simple_nocnn.py
Config:
Global configurations for NetKet
 - NETKET_DEBUG = False
 - NETKET_EXPERIMENTAL = False
 - NETKET_MPI_WARNING = True
 - NETKET_MPI = False
 - NETKET_USE_PLAIN_RHAT = False
 - NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION = False
 - NETKET_EXPERIMENTAL_DISABLE_ODE_JIT = True
 - NETKET_EXPERIMENTAL_SHARDING_CPU = 0
 - NETKET_ENABLE_X64 = True
 - NETKET_SPHINX_BUILD = False
 - NETKET_EXPERIMENTAL_SHARDING = True
 - NETKET_MPI_AUTODETECT_LOCAL_GPU = False
 - NETKET_RANDOM_STATE_FALLBACK_WARNING = True
 - NETKET_EXPERIMENTAL_SHARDING_FAST_SERIALIZATION = False
 - NETKET_SPIN_ORDERING_WARNING = True

jax global devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
jax local devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
  0%|                                                                                                                                                             | 0/500 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "/lus/home/CT5/cad14908/fvicentini/rep2/simple_nocnn.py", line 44, in <module>
    gs.run(n_iter=500, out=log, timeit=True)
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/driver/abstract_variational_driver.py", line 347, in run
    for step in self.iter(n_iter, step_size):
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/driver/abstract_variational_driver.py", line 232, in iter
    self._dp = self._forward_and_backward()
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/utils/timing.py", line 276, in timed_function
    result = fun(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/driver/vmc.py", line 122, in _forward_and_backward
    self._loss_stats, self._loss_grad = self.state.expect_and_grad(self._ham)
                                        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/utils/timing.py", line 276, in timed_function
    result = fun(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/state.py", line 710, in expect_and_grad
    return expect_and_grad(
           ^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
                    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/expect_grad.py", line 58, in expect_and_grad_default_formula
    Ō, Ō_grad = expect_and_forces(vstate, Ô, chunk_size, *args, mutable=mutable)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
                    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/expect_forces.py", line 48, in expect_and_forces
    σ, args = get_local_kernel_arguments(vstate, Ô)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/plum/function.py", line 383, in __call__
    return _convert(method(*args, **kw_args), return_type)
                    ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/expect.py", line 61, in get_local_kernel_arguments
    σ = vstate.samples
        ^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/state.py", line 611, in samples
    self.sample()
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/utils/timing.py", line 276, in timed_function
    result = fun(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/vqs/mc/mc_state/state.py", line 575, in sample
    self.sampler_state = self.sampler.reset(
                         ^^^^^^^^^^^^^^^^^^^
  File "/lus/home/CT5/cad14908/fvicentini/rep2/.venv/lib/python3.12/site-packages/netket/sampler/base.py", line 276, in reset
    return sampler._reset(wrap_afun(machine), parameters, state)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch ROCm kernel: redzone_checker with block dimensions: 1024x1x1: hipError_t(98)
--------------------
For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The jax version is:

[cad14908] fvicentini@a1018:~/rep2$ uv tree | grep jax
Resolved 46 packages in 4ms
├── jax[rocm] v0.5.0
│   ├── jaxlib v0.5.0
│   ├── jax-rocm60-plugin v0.5.0 (extra: rocm)
│   │   └── jax-rocm60-pjrt v0.5.0
│   └── jaxlib v0.5.0 (extra: rocm) (*)
├── jax-rocm60-pjrt v0.5.0
├── jax-rocm60-plugin v0.5.0 (*)
├── jaxlib v0.5.0 (*)
    │   ├── jax v0.5.0 (*)
    │   ├── jaxtyping v0.3.0
    │   ├── jax v0.5.0 (*)
    │   │   │   ├── jax v0.5.0 (*)
    │   │   │   ├── jaxlib v0.5.0 (*)
    │   │   ├── jax v0.5.0 (*)
    │   │   ├── jaxlib v0.5.0 (*)
    │   │   ├── jax v0.5.0 (*)
    ├── jax v0.5.0 (*)
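
Editorial note, not part of the original comment: the same version information can be collected from inside the interpreter that runs the script, which avoids any mismatch between the environment that was inspected and the one that actually failed. The sketch below uses only the standard library; the helper name `jax_related_versions` is hypothetical.

```python
from importlib import metadata


def jax_related_versions(prefixes=("jax",)):
    """Return {distribution name: version} for installed packages whose
    name starts with one of `prefixes` (case-insensitive)."""
    found = {}
    for dist in metadata.distributions():
        name = dist.metadata["Name"]
        # Skip distributions with broken metadata (name missing).
        if name and name.lower().startswith(tuple(prefixes)):
            found[name] = dist.version
    return dict(sorted(found.items()))


if __name__ == "__main__":
    # Prints one "name==version" line per matching distribution,
    # e.g. jax, jaxlib and any installed ROCm plugin packages.
    for name, version in jax_related_versions().items():
        print(f"{name}=={version}")
```

Like `grep jax`, the prefix match also picks up unrelated packages such as jaxtyping; that is harmless for a bug report.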

@lucbruni-amd

Thanks for the helpful info. I'll respond here when I have updates on this.

@lucbruni-amd

Hi @PhilipVinc,

Just wanted to let you know I'm able to consistently reproduce your latter error jaxlib.xla_extension.XlaRuntimeError: INTERNAL: Failed to launch ROCm kernel: redzone_checker with block dimensions: 1024x1x1: hipError_t(98), and I'm working with the team to get to the bottom of it. Thanks for your patience.

@PhilipVinc
Author

Thank you for the update.
Please do let me know if you manage to solve it.

PhilipVinc commented Apr 11, 2025

Hi @lucbruni-amd, is there any update on your side?

I just tried rerunning the example from this bug report on:
ROCM/6.3.3
'jax[rocm]' (installed from PyPI)
Note that the previous reports were using ROCM/6.2.1 and jax installed from the release tarballs of this GitHub repository.

The outcome is different again.
It now reports a deadlock (and remains deadlocked after 20 minutes).

export NETKET_ENABLE_X64=1
export NETKET_EXPERIMENTAL_SHARDING=1
export NETKET_MPI=0

[cad14908] fvicentini@a1004:~/rep4$ uv run simple_nocnn.py
Config:
Global configurations for NetKet
 - NETKET_DEBUG = False
 - NETKET_EXPERIMENTAL = False
 - NETKET_MPI_WARNING = True
 - NETKET_MPI = False
 - NETKET_USE_PLAIN_RHAT = False
 - NETKET_EXPERIMENTAL_FFT_AUTOCORRELATION = False
 - NETKET_EXPERIMENTAL_DISABLE_ODE_JIT = True
 - NETKET_EXPERIMENTAL_SHARDING_CPU = 0
 - NETKET_ENABLE_X64 = True
 - NETKET_SPHINX_BUILD = False
 - NETKET_EXPERIMENTAL_SHARDING = True
 - NETKET_MPI_AUTODETECT_LOCAL_GPU = False
 - NETKET_RANDOM_STATE_FALLBACK_WARNING = True
 - NETKET_EXPERIMENTAL_SHARDING_FAST_SERIALIZATION = False
 - NETKET_SPIN_ORDERING_WARNING = True

jax global devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
jax local devices [RocmDevice(id=0), RocmDevice(id=1), RocmDevice(id=2), RocmDevice(id=3)]
  0%|                                                                                                           | 0/500 [00:00<?, ?it/s]2025-04-11 17:35:12.675020: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 2; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152` for 10 seconds and may be stuck. All 4 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
2025-04-11 17:35:12.676137: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 3; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152` for 10 seconds and may be stuck. All 4 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.
2025-04-11 17:35:12.676165: E external/xla/xla/service/rendezvous.cc:90] This thread has been waiting for `initialize clique for rank 1; clique=devices=[0,1,2,3]; stream=0; groups=[[0,1,2,3]]; root_device=-1; run_id=1452947152` for 10 seconds and may be stuck. All 4 threads joined the rendezvous, however the leader has not marked the rendezvous as completed. Leader can be deadlocked inside the rendezvous callback.

Do you have a timeline for a fix? This is a 100% blocker for us.
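
Editorial note, not part of the original comment: when a run hangs like this, it can help to record where the Python side is blocked before killing the job. A minimal sketch using the standard-library `faulthandler` module follows; the helper names are hypothetical. It only dumps Python frames, so it shows which call the main thread is stuck in, not the state of the native XLA threads named in the rendezvous messages.

```python
import faulthandler
import sys


def arm_hang_watchdog(timeout_s=600.0, file=None):
    """If the process is still running after `timeout_s` seconds, dump the
    Python stack of every thread to `file` (default: stderr) and keep going.
    `file` must stay open until the watchdog fires or is disarmed."""
    target = file if file is not None else sys.stderr
    faulthandler.dump_traceback_later(timeout_s, repeat=False, file=target)


def disarm_hang_watchdog():
    """Cancel a pending dump, e.g. once the run has finished normally."""
    faulthandler.cancel_dump_traceback_later()


# Intended use (names from the traceback earlier in this thread):
#   arm_hang_watchdog(600.0)
#   gs.run(n_iter=500, out=log, timeit=True)
#   disarm_hang_watchdog()
```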

@PhilipVinc
Author

It seems to be an x64 (double precision) issue, as setting

export NETKET_ENABLE_X64=0
export NETKET_EXPERIMENTAL_SHARDING=1
export NETKET_MPI=0

makes the code run.
However, when I benchmark against a node with 4 NVIDIA V100 GPUs, the AMD node is 50 times slower.
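
Editorial note, not part of the original comment: one way to check whether double precision alone triggers the failure, independently of NetKet, is a small sharded computation run once in float32 and once in float64. The sketch below is an assumption-laden reduction, not the reporter's script: it assumes that `NETKET_ENABLE_X64=1` corresponds to JAX's `jax_enable_x64` flag, and the function names and array sizes are hypothetical. It shards a matrix of ones by rows across all visible devices, then runs a matmul and a global sum under `jit`.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.sharding import Mesh, NamedSharding, PartitionSpec

# Assumption: this is the JAX flag that NETKET_ENABLE_X64=1 switches on.
jax.config.update("jax_enable_x64", True)

ROWS_PER_DEVICE = 8  # hypothetical sizes, chosen only to keep the check tiny
COLS = 64


def sharded_gram_sum(dtype):
    """Shard a matrix of ones by rows over all devices, then compute
    sum(A @ A.T) under jit (a matmul plus a cross-device reduction)."""
    devices = np.array(jax.devices())
    mesh = Mesh(devices, axis_names=("shard",))
    row_sharded = NamedSharding(mesh, PartitionSpec("shard", None))
    n_rows = ROWS_PER_DEVICE * devices.size
    a = jax.device_put(jnp.ones((n_rows, COLS), dtype=dtype), row_sharded)
    total = jax.jit(lambda m: jnp.sum(m @ m.T))(a)
    return total.block_until_ready()


def expected_gram_sum():
    """Each entry of A @ A.T equals COLS, and there are n_rows**2 entries."""
    n_rows = ROWS_PER_DEVICE * jax.device_count()
    return float(COLS * n_rows * n_rows)


if __name__ == "__main__":
    for dtype in (jnp.float32, jnp.float64):
        value = sharded_gram_sum(dtype)
        print(np.dtype(dtype).name, value.dtype, float(value),
              "expected", expected_gram_sum())
```

If the float32 pass completes and the float64 pass fails or hangs, that would support the double-precision hypothesis; if both pass, the trigger is more likely specific to the operations NetKet issues.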
