Closed
Description
These two lines need a bit of attention, because `batch` can be larger than the number of samples in the test set here:
```python
if batch == -1 or batch > dataset['train_input'].shape[0]:
    batch_size = dataset['train_input'].shape[0]
    batch_size_test = dataset['test_input'].shape[0]
else:
    batch_size = batch
    batch_size_test = batch
```
then this call is invalid, since `np.random.choice` with `replace=False` cannot draw more samples than the test set contains:

```python
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
```
On the other hand, what is the purpose of batching the test dataset at all if it is only used for evaluating the model? Surely batching should only be done on the training dataset, i.e. this line is not required:
```python
test_loss = loss_fn_eval(self.forward(dataset['test_input'][test_id].to(device)), dataset['test_label'][test_id].to(device))
```
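For reference, a minimal sketch of one possible fix: clamp each batch size to its own dataset so that `np.random.choice(..., replace=False)` never asks for more samples than exist. The helper name `get_batch_sizes` is hypothetical and only mirrors the variable names in the snippet above; it is not the library's API.

```python
import numpy as np

def get_batch_sizes(dataset, batch):
    # Hypothetical helper: clamp train/test batch sizes independently,
    # so batch > n_test can no longer break the test-set sampling.
    n_train = dataset['train_input'].shape[0]
    n_test = dataset['test_input'].shape[0]
    if batch == -1:
        return n_train, n_test
    return min(batch, n_train), min(batch, n_test)

# Toy dataset: 100 training samples but only 20 test samples.
dataset = {
    'train_input': np.zeros((100, 2)),
    'test_input': np.zeros((20, 2)),
}
batch_size, batch_size_test = get_batch_sizes(dataset, 50)
# batch_size == 50, batch_size_test == 20: drawing 20 of 20 without
# replacement is valid, whereas drawing 50 of 20 would raise ValueError.
test_id = np.random.choice(dataset['test_input'].shape[0], batch_size_test, replace=False)
```

Clamping each split separately keeps the current batching behaviour intact while avoiding the `ValueError`; dropping test-set batching altogether, as suggested above, would be the simpler alternative.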