From 102ee5a997a29e6f799d2ea0caaa11a397ae5ee6 Mon Sep 17 00:00:00 2001 From: Chongxiao Cao Date: Thu, 21 Jul 2022 14:12:57 -0700 Subject: [PATCH] Spark/Lightning: add reader_worker_count and reader_pool_type --- horovod/spark/lightning/datamodule.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/horovod/spark/lightning/datamodule.py b/horovod/spark/lightning/datamodule.py index ef9d2ffe39..1e9049dc58 100644 --- a/horovod/spark/lightning/datamodule.py +++ b/horovod/spark/lightning/datamodule.py @@ -62,6 +62,8 @@ def setup(self, stage=None): reader_factory = make_batch_reader self.train_reader = reader_factory(self.train_dir, num_epochs=self.num_reader_epochs, + reader_pool_type=self.reader_pool_type, + workers_count=self.reader_worker_count, cur_shard=self.cur_shard, shard_count=self.shard_count, hdfs_driver=PETASTORM_HDFS_DRIVER, schema_fields=self.schema_fields, @@ -72,6 +74,8 @@ def setup(self, stage=None): **reader_factory_kwargs) if self.has_val: self.val_reader = reader_factory(self.val_dir, num_epochs=self.num_reader_epochs, + reader_pool_type=self.reader_pool_type, + workers_count=self.reader_worker_count, cur_shard=self.cur_shard, shard_count=self.shard_count, hdfs_driver=PETASTORM_HDFS_DRIVER, schema_fields=self.schema_fields,