[train] Can not start training on more than one node #54065
Open
@Rex-dby

Description


Hi, I'm just trying to run the training example code as a test. It works well on a single node, but fails with two nodes.
Is there any known reason for this? Thanks!
Logs:

(swift) (base) bydeng@iv-ydgjd36ghsk36d113bdn:/data2/workspace/bydeng/ray_serve$ python train.py 
2025-06-25 09:15:13,372 INFO worker.py:1723 -- Connecting to existing Ray cluster at address: 172.31.0.7:5001...
2025-06-25 09:15:13,384 INFO worker.py:1908 -- Connected to Ray cluster. View the dashboard at 127.0.0.1:7779 
2025-06-25 09:15:13,503 WARNING tune_controller.py:2132 -- The maximum number of pending trials has been automatically set to the number of available cluster CPUs, which is high (396 CPUs/pending trials). If you're running an experiment with a large number of trials, this could lead to scheduling overhead. In this case, consider setting the `TUNE_MAX_PENDING_TRIALS_PG` environment variable to the desired maximum number of concurrent pending trials.
2025-06-25 09:15:13,505 WARNING tune_controller.py:2132 -- The maximum number of pending trials has been automatically set to the number of available cluster CPUs, which is high (396 CPUs/pending trials). If you're running an experiment with a large number of trials, this could lead to scheduling overhead. In this case, consider setting the `TUNE_MAX_PENDING_TRIALS_PG` environment variable to the desired maximum number of concurrent pending trials.

View detailed results here: /data2/users/bydeng/ray_results/TorchTrainer_2025-06-25_09-15-13
To visualize your results with TensorBoard, run: `tensorboard --logdir /tmp/ray/session_2025-06-24_16-12-09_684429_738797/artifacts/2025-06-25_09-15-13/TorchTrainer_2025-06-25_09-15-13/driver_artifacts`

Training started with configuration:
╭─────────────────────────────────────────────────╮
│ Training config                                 │
├─────────────────────────────────────────────────┤
│ train_loop_config/batch_size_per_worker      10 │
│ train_loop_config/epochs                     10 │
│ train_loop_config/lr                      0.001 │
╰─────────────────────────────────────────────────╯
(RayTrainWorker pid=1674498) Setting up process group for: env:// [rank=0, world_size=4]
(TorchTrainer pid=1674340) Started distributed worker processes: 
(TorchTrainer pid=1674340) - (node_id=2c133bf1692d7d9387e22f6ed2fbf02370f66c98872406f83d42ebdd, ip=172.31.0.7, pid=1674498) world_rank=0, local_rank=0, node_rank=0
(TorchTrainer pid=1674340) - (node_id=2c133bf1692d7d9387e22f6ed2fbf02370f66c98872406f83d42ebdd, ip=172.31.0.7, pid=1674496) world_rank=1, local_rank=1, node_rank=0
(TorchTrainer pid=1674340) - (node_id=2c133bf1692d7d9387e22f6ed2fbf02370f66c98872406f83d42ebdd, ip=172.31.0.7, pid=1674497) world_rank=2, local_rank=2, node_rank=0
(TorchTrainer pid=1674340) - (node_id=37414c22c369ec5ddd61cbe81e7f869521068094b127d10af1bfeb69, ip=172.31.16.4, pid=3871886) world_rank=3, local_rank=0, node_rank=1
  0%|          | 0.00/26.4M [00:00<?, ?B/s]
  0%|          | 32.8k/26.4M [00:00<02:12, 199kB/s]
  0%|          | 65.5k/26.4M [00:00<02:13, 198kB/s]
  0%|          | 131k/26.4M [00:00<01:31, 288kB/s] 
100%|██████████| 26.4M/26.4M [00:02<00:00, 10.6MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 188kB/s]
  0%|          | 0.00/4.42M [00:00<?, ?B/s] [repeated 3x across cluster] (Ray deduplicates logs by default. Set RAY_DEDUP_LOGS=0 to disable log deduplication, or see https://docs.ray.io/en/master/ray-observability/user-guides/configure-logging.html#log-deduplication for more options.)
  5%|▌         | 229k/4.42M [00:00<00:12, 343kB/s] [repeated 40x across cluster]
(RayTrainWorker pid=3871886, ip=172.31.16.4) Moving model to device: cuda:0
(RayTrainWorker pid=3871886, ip=172.31.16.4) Wrapping provided model in DistributedDataParallel.
 92%|█████████▏| 24.4M/26.4M [00:06<00:00, 4.42MB/s]
 96%|█████████▌| 25.3M/26.4M [00:06<00:00, 4.45MB/s]
100%|██████████| 26.4M/26.4M [00:07<00:00, 3.74MB/s]
100%|██████████| 29.5k/29.5k [00:00<00:00, 185kB/s]
  0%|          | 0.00/4.42M [00:00<?, ?B/s] [repeated 3x across cluster]
  1%|▏         | 65.5k/4.42M [00:00<00:20, 211kB/s] [repeated 15x across cluster]
100%|██████████| 4.42M/4.42M [00:01<00:00, 2.53MB/s]
(RayTrainWorker pid=1674498) Moving model to device: cuda:0
(RayTrainWorker pid=1674498) Wrapping provided model in DistributedDataParallel.
(RayTrainWorker pid=3871886, ip=172.31.16.4) *** SIGSEGV received at time=1750814135 on cpu 91 ***
(RayTrainWorker pid=3871886, ip=172.31.16.4) PC: @     0x7ef253354ff3  (unknown)  ncclTopoGetLocal()
(RayTrainWorker pid=3871886, ip=172.31.16.4)     @     0x7f235940c420  (unknown)  (unknown)
(RayTrainWorker pid=3871886, ip=172.31.16.4) [2025-06-25 09:15:35,532 E 3871886 3872364] logging.cc:496: *** SIGSEGV received at time=1750814135 on cpu 91 ***
(RayTrainWorker pid=3871886, ip=172.31.16.4) [2025-06-25 09:15:35,532 E 3871886 3872364] logging.cc:496: PC: @     0x7ef253354ff3  (unknown)  ncclTopoGetLocal()
(RayTrainWorker pid=3871886, ip=172.31.16.4) [2025-06-25 09:15:35,532 E 3871886 3872364] logging.cc:496:     @     0x7f235940c420  (unknown)  (unknown)
(RayTrainWorker pid=3871886, ip=172.31.16.4) Fatal Python error: Segmentation fault
(RayTrainWorker pid=3871886, ip=172.31.16.4) 
(RayTrainWorker pid=3871886, ip=172.31.16.4) 
(RayTrainWorker pid=3871886, ip=172.31.16.4) Extension modules: msgpack._cmsgpack, google._upb._message, psutil._psutil_linux, psutil._psutil_posix, setproctitle, yaml._yaml, charset_normalizer.md, requests.packages.charset_normalizer.md, requests.packages.chardet.md, ray._raylet, numpy._core._multiarray_umath, numpy.linalg._umath_linalg, pyarrow.lib, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pyarrow._compute, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, pandas._libs.parsers, pandas._libs.testing, pyarrow._fs, pyarrow._azurefs, pyarrow._hdfs, pyarrow._gcsfs, pyarrow._s3fs, pyarrow._parquet, pyarrow._json, torch._C, torch._C._dynamo.autograd_compiler, torch._C._dynamo.eval_frame, torch._C._dynamo.guards, torch._C._dynamo.utils, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, PIL._imaging, PIL._imagingft (total: 83)
(raylet) A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff275554ab0c907cb06bd2b3b10f000000 Worker ID: c698eb8395e48159df9d0358a72f76065dac39bce1f0dc8b8a0a8008 Node ID: 37414c22c369ec5ddd61cbe81e7f869521068094b127d10af1bfeb69 Worker IP address: 172.31.16.4 Worker port: 10013 Worker PID: 3871886 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
(TorchTrainer pid=1674340) Worker 3 has failed.
2025-06-25 09:15:35,897 ERROR tune_controller.py:1331 -- Trial task failed for trial TorchTrainer_d887c_00000
Traceback (most recent call last):
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/air/execution/_internal/event_manager.py", line 110, in resolve_future
    result = ray.get(future)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/_private/auto_init_hook.py", line 22, in auto_init_wrapper
    return fn(*args, **kwargs)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/_private/client_mode_hook.py", line 104, in wrapper
    return func(*args, **kwargs)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/_private/worker.py", line 2849, in get
    values, debugger_breakpoint = worker.get_objects(object_refs, timeout=timeout)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/_private/worker.py", line 937, in get_objects
    raise value.as_instanceof_cause()
ray.exceptions.RayTaskError(ActorDiedError): ray::_Inner.train() (pid=1674340, ip=172.31.0.7, actor_id=5d7c5e8e421ac5fc06b351790f000000, repr=TorchTrainer)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 330, in train
    raise skipped from exception_cause(skipped)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/air/_internal/util.py", line 107, in run
    self._ret = self._target(*self._args, **self._kwargs)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/tune/trainable/function_trainable.py", line 45, in <lambda>
    training_func=lambda: self._trainable_func(self.config),
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/base_trainer.py", line 883, in _trainable_func
    super()._trainable_func(self._merged_config)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/tune/trainable/function_trainable.py", line 261, in _trainable_func
    output = fn()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/base_trainer.py", line 123, in _train_coordinator_fn
    trainer.training_loop()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/data_parallel_trainer.py", line 470, in training_loop
    self._run_training(training_iterator)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/data_parallel_trainer.py", line 369, in _run_training
    for training_results in training_iterator:
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/trainer.py", line 129, in __next__
    next_results = self._run_with_error_handling(self._fetch_next_result)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/trainer.py", line 94, in _run_with_error_handling
    return func()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/trainer.py", line 174, in _fetch_next_result
    results = self._backend_executor.get_next_results()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/backend_executor.py", line 616, in get_next_results
    results = self.get_with_failure_handling(futures)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/backend_executor.py", line 729, in get_with_failure_handling
    self._increment_failures()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/backend_executor.py", line 791, in _increment_failures
    raise failure
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 57, in check_for_failure
    ray.get(object_ref)
ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
        class_name: RayTrainWorker
        actor_id: 275554ab0c907cb06bd2b3b10f000000
        pid: 3871886
        namespace: 08f59208-9f05-45d5-97e7-914caee8db65
        ip: 172.31.16.4
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.

Training errored after 0 iterations at 2025-06-25 09:15:35. Total running time: 22s
Error file: /tmp/ray/session_2025-06-24_16-12-09_684429_738797/artifacts/2025-06-25_09-15-13/TorchTrainer_2025-06-25_09-15-13/driver_artifacts/TorchTrainer_d887c_00000_0_2025-06-25_09-15-13/error.txt
2025-06-25 09:15:35,902 INFO tune.py:1009 -- Wrote the latest version of all result files and experiment state to '/data2/users/bydeng/ray_results/TorchTrainer_2025-06-25_09-15-13' in 0.0018s.

2025-06-25 09:15:35,902 ERROR tune.py:1037 -- Trials did not complete: [TorchTrainer_d887c_00000]
ray.exceptions.RayTaskError(ActorDiedError): ray::_Inner.train() (pid=1674340, ip=172.31.0.7, actor_id=5d7c5e8e421ac5fc06b351790f000000, repr=TorchTrainer)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/tune/trainable/trainable.py", line 330, in train
    raise skipped from exception_cause(skipped)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/air/_internal/util.py", line 107, in run
    self._ret = self._target(*self._args, **self._kwargs)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/tune/trainable/function_trainable.py", line 45, in <lambda>
    training_func=lambda: self._trainable_func(self.config),
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/base_trainer.py", line 883, in _trainable_func
    super()._trainable_func(self._merged_config)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/tune/trainable/function_trainable.py", line 261, in _trainable_func
    output = fn()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/base_trainer.py", line 123, in _train_coordinator_fn
    trainer.training_loop()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/data_parallel_trainer.py", line 470, in training_loop
    self._run_training(training_iterator)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/data_parallel_trainer.py", line 369, in _run_training
    for training_results in training_iterator:
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/trainer.py", line 129, in __next__
    next_results = self._run_with_error_handling(self._fetch_next_result)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/trainer.py", line 94, in _run_with_error_handling
    return func()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/trainer.py", line 174, in _fetch_next_result
    results = self._backend_executor.get_next_results()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/backend_executor.py", line 616, in get_next_results
    results = self.get_with_failure_handling(futures)
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/backend_executor.py", line 729, in get_with_failure_handling
    self._increment_failures()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/backend_executor.py", line 791, in _increment_failures
    raise failure
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/_internal/utils.py", line 57, in check_for_failure
    ray.get(object_ref)
ray.exceptions.ActorDiedError: The actor died unexpectedly before finishing this task.
        class_name: RayTrainWorker
        actor_id: 275554ab0c907cb06bd2b3b10f000000
        pid: 3871886
        namespace: 08f59208-9f05-45d5-97e7-914caee8db65
        ip: 172.31.16.4
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/data2/workspace/bydeng/ray_serve/train.py", line 154, in <module>
    train_fashion_mnist(num_workers=4, use_gpu=True)
  File "/data2/workspace/bydeng/ray_serve/train.py", line 148, in train_fashion_mnist
    result = trainer.fit()
  File "/data2/users/bydeng/.conda/envs/swift/lib/python3.10/site-packages/ray/train/base_trainer.py", line 722, in fit
    raise TrainingFailedError(
ray.train.base_trainer.TrainingFailedError: The Ray Train run failed. Please inspect the previous error messages for a cause. After fixing the issue (assuming that the error is not caused by your own application logic, but rather an error such as OOM), you can restart the run from scratch or continue this run.
To continue this run, you can use: `trainer = TorchTrainer.restore("/data2/users/bydeng/ray_results/TorchTrainer_2025-06-25_09-15-13")`.
To start a new run that will retry on training failures, set `train.RunConfig(failure_config=train.FailureConfig(max_failures))` in the Trainer's `run_config` with `max_failures > 0`, or `max_failures = -1` for unlimited retries.
100%|██████████| 5.15k/5.15k [00:00<00:00, 66.2MB/s]
 82%|████████▏ | 3.60M/4.42M [00:01<00:00, 3.64MB/s] [repeated 8x across cluster]
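
For reference, the retry option that the final error message points at would look roughly like the sketch below. This is only an illustration of the RunConfig/FailureConfig names printed in the traceback, not a fix for the NCCL segfault itself; max_failures=3 is an arbitrary example value.

from ray.train import FailureConfig, RunConfig

# Sketch only: enable automatic retries on worker failures, as the error message suggests.
# max_failures > 0 retries the run up to that many times; max_failures = -1 retries indefinitely.
run_config = RunConfig(failure_config=FailureConfig(max_failures=3))
# The TorchTrainer below would then be constructed with run_config=run_config.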

The training code is as follows:

import os
from typing import Dict

import torch
from filelock import FileLock
from torch import nn
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from torchvision.transforms import Normalize, ToTensor
from tqdm import tqdm

import ray.train
from ray.train import ScalingConfig
from ray.train.torch import TorchTrainer


def get_dataloaders(batch_size):
    # Transform to normalize the input images
    transform = transforms.Compose([ToTensor(), Normalize((0.28604,), (0.32025,))])

    with FileLock(os.path.expanduser("data.lock")):
        # Download training data from open datasets
        training_data = datasets.FashionMNIST(
            root="data",
            train=True,
            download=True,
            transform=transform,
        )

        # Download test data from open datasets
        test_data = datasets.FashionMNIST(
            root="data",
            train=False,
            download=True,
            transform=transform,
        )

    # Create data loaders
    train_dataloader = DataLoader(training_data, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_data, batch_size=batch_size)

    return train_dataloader, test_dataloader


# Model Definition
class NeuralNetwork(nn.Module):
    def __init__(self):
        super(NeuralNetwork, self).__init__()
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28 * 28, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(512, 10),
            nn.ReLU(),
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits


def train_func_per_worker(config: Dict):
    lr = config["lr"]
    epochs = config["epochs"]
    batch_size = config["batch_size_per_worker"]

    # Get dataloaders inside the worker training function
    train_dataloader, test_dataloader = get_dataloaders(batch_size=batch_size)

    # [1] Prepare Dataloader for distributed training
    # Shard the datasets among workers and move batches to the correct device
    # =======================================================================
    train_dataloader = ray.train.torch.prepare_data_loader(train_dataloader)
    test_dataloader = ray.train.torch.prepare_data_loader(test_dataloader)

    model = NeuralNetwork()

    # [2] Prepare and wrap your model with DistributedDataParallel
    # Move the model to the correct GPU/CPU device
    # ============================================================
    model = ray.train.torch.prepare_model(model)

    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)

    # Model training loop
    for epoch in range(epochs):
        if ray.train.get_context().get_world_size() > 1:
            # Required for the distributed sampler to shuffle properly across epochs.
            train_dataloader.sampler.set_epoch(epoch)

        model.train()
        for X, y in tqdm(train_dataloader, desc=f"Train Epoch {epoch}"):
            pred = model(X)
            loss = loss_fn(pred, y)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        model.eval()
        test_loss, num_correct, num_total = 0, 0, 0
        with torch.no_grad():
            for X, y in tqdm(test_dataloader, desc=f"Test Epoch {epoch}"):
                pred = model(X)
                loss = loss_fn(pred, y)

                test_loss += loss.item()
                num_total += y.shape[0]
                num_correct += (pred.argmax(1) == y).sum().item()

        test_loss /= len(test_dataloader)
        accuracy = num_correct / num_total

        # [3] Report metrics to Ray Train
        # ===============================
        ray.train.report(metrics={"loss": test_loss, "accuracy": accuracy})


def train_fashion_mnist(num_workers=2, use_gpu=False):
    global_batch_size = 40

    train_config = {
        "lr": 1e-3,
        "epochs": 10,
        "batch_size_per_worker": global_batch_size // num_workers,
    }

    # Configure computation resources
    scaling_config = ScalingConfig(num_workers=num_workers, use_gpu=use_gpu)

    # Initialize a Ray TorchTrainer
    trainer = TorchTrainer(
        train_loop_per_worker=train_func_per_worker,
        train_loop_config=train_config,
        scaling_config=scaling_config,
        # max_failures=3,
    )

    # [4] Start distributed training
    # Run `train_func_per_worker` on all workers
    # =============================================
    result = trainer.fit()
    print(f"Training result: {result}")


if __name__ == "__main__":
    ray.init(address="auto", runtime_env={'env_vars': { 'GRPC_ENABLE_FORK_SUPPORT': 'True', 'GRPC_POLL_STRATEGY': 'epoll1', 'RAY_start_python_importer_thread': '0', }})
    train_fashion_mnist(num_workers=4, use_gpu=True)

Python version: 3.10.16
Ray version: 2.47.1
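
The segfault above is reported at ncclTopoGetLocal() in the worker running on the second node (172.31.16.4). Below is a minimal sketch for getting more diagnostic output out of that crash by forwarding NCCL's debug environment variables through the same runtime_env mechanism the script already uses; the interface name "eth0" is only a placeholder, not a known value for this cluster.

import ray

# Sketch: forward NCCL diagnostics to all Ray workers via the job-level runtime_env.
# NCCL_DEBUG=INFO prints topology/transport detection logs around the crash site;
# NCCL_SOCKET_IFNAME pins the network interface NCCL uses between nodes (placeholder value).
ray.init(
    address="auto",
    runtime_env={
        "env_vars": {
            "NCCL_DEBUG": "INFO",
            "NCCL_SOCKET_IFNAME": "eth0",  # replace with the actual inter-node interface
        }
    },
)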
