diff --git a/bert_pytorch/dataset/dataset.py b/bert_pytorch/dataset/dataset.py index 7d787f3..cf80a52 100644 --- a/bert_pytorch/dataset/dataset.py +++ b/bert_pytorch/dataset/dataset.py @@ -115,10 +115,10 @@ def get_random_line(self): if self.on_memory: return self.lines[random.randrange(len(self.lines))][1] - line = self.file.__next__() + line = self.random_file.__next__() if line is None: - self.file.close() - self.file = open(self.corpus_path, "r", encoding=self.encoding) + self.random_file.close() + self.random_file = open(self.corpus_path, "r", encoding=self.encoding) for _ in range(random.randint(self.corpus_lines if self.corpus_lines < 1000 else 1000)): self.random_file.__next__() line = self.random_file.__next__()