Running OCR with a tensor produces different (and wrong) results w.r.t using a numpy array · Issue #1914 · mindee/doctr · GitHub
Open
@git-artes

Description

Bug description

First off, thanks for the great software. To give some context, we are building an end-to-end system consisting of a (sort of) denoiser followed by doctr. We want to fine-tune only the denoiser so that its outputs become more readable, using a loss computed on doctr's predictions as the training signal.

We would therefore like to feed tensors directly into the doctr modules. However, with tensor inputs the results are very different from (and much worse than) those obtained with numpy arrays. The code below shows a minimal example, tested on Google Colab.

Code snippet to reproduce the bug

!pip3 install -U pip

# Uninstall TensorFlow to avoid the "RuntimeError: Given input size: (128x1x16). Calculated output size: (128x0x8). Output size is too small" error on Colab
# See https://github.com/mindee/doctr/discussions/1884
!pip3 uninstall -y tensorflow
# Then install doctr with the PyTorch backend and visualization extras
!pip3 install "python-doctr[torch,viz]"

import cv2
import numpy as np
import torch
from doctr.models import ocr_predictor


# Function to generate a text image
def generate_text_image(text="HELLO", size=(128, 128)):
    img = np.ones(size, dtype=np.uint8) * 255  # White background
    font = cv2.FONT_HERSHEY_SIMPLEX
    font_scale = 1
    thickness = 2
    text_size = cv2.getTextSize(text, font, font_scale, thickness)[0]
    text_x = (size[1] - text_size[0]) // 2
    text_y = (size[0] + text_size[1]) // 2
    cv2.putText(img, text, (text_x, text_y), font, font_scale, (0,), thickness)
    return img

# Generate image
original_image = generate_text_image(text="HELLO how are you", size=(600, 600))

# Convert to a uint8 tensor with 3 identical channels
input_tensor = torch.tensor(original_image).unsqueeze(0).unsqueeze(0).repeat(1,3,1,1)  # Shape: [1, 3, H, W]
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load OCR model
ocr_model = ocr_predictor(pretrained=True).to(device)

# Apply the OCR model to the tensor (wrong: only a few "-" are detected)
print(ocr_model(input_tensor))
# Apply the OCR model to the numpy array (works well: all the words are detected)
threech_original_image_batch = np.expand_dims(np.stack((original_image,)*3, axis=-1), axis=0)  # Shape: [1, H, W, 3]
print(ocr_model(threech_original_image_batch))
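
The two calls above feed the predictor the same pixels in two layouts: the tensor is a channels-first uint8 batch of shape [1, 3, H, W], while the numpy input is a channels-last batch of shape [1, H, W, 3]. The numpy-only sketch below checks that moving the channel axis of one layout recovers exactly the other, which suggests the discrepancy lies in how each input type is processed rather than in the data itself. It is an illustration under assumptions: `make_test_image` is a hypothetical stand-in for `generate_text_image` (no OpenCV needed), and `np.repeat` is used to mimic `torch.Tensor.repeat` on a size-1 axis.

```python
import numpy as np


def make_test_image(h: int = 600, w: int = 600) -> np.ndarray:
    """Hypothetical stand-in for generate_text_image: white page with a dark bar."""
    img = np.full((h, w), 255, dtype=np.uint8)
    img[h // 2 - 10 : h // 2 + 10, w // 4 : 3 * w // 4] = 0
    return img


gray = make_test_image()

# Channels-first batch [1, 3, H, W], mimicking
# torch.tensor(img).unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1).
# On a size-1 axis, np.repeat gives the same result as torch's tiling repeat.
chw_batch = np.repeat(gray[np.newaxis, np.newaxis, :, :], 3, axis=1)

# Channels-last batch [1, H, W, 3], built exactly as in the snippet above
hwc_batch = np.expand_dims(np.stack((gray,) * 3, axis=-1), axis=0)

# Moving the channel axis to the end turns one layout into the other
hwc_from_chw = chw_batch.transpose(0, 2, 3, 1)

print(chw_batch.shape, hwc_batch.shape)  # (1, 3, 600, 600) (1, 600, 600, 3)
print(np.array_equal(hwc_from_chw, hwc_batch))  # True
```

If the check holds, both inputs contain identical uint8 values, so the differing OCR results cannot be explained by the pixel content alone.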

Error traceback

When run on the tensor, the predictor returns the following (wrong) output:

Document(
  (pages): [Page(
    dimensions=torch.Size([600, 600])
    (blocks): [Block(
      (lines): [
        Line(
          (words): [Word(value='-', confidence=1.0)]
        ),
        Line(
          (words): [
            Word(value='-', confidence=1.0),
            Word(value='-', confidence=1.0),
            Word(value='-', confidence=1.0),
          ]
        ),
      ]
      (artefacts): []
    )]
  )]
)

When run on the numpy array, it returns the correct output:

Document(
  (pages): [Page(
    dimensions=(600, 600)
    (blocks): [Block(
      (lines): [
        Line(
          (words): [Word(value='HELLO', confidence=1.0)]
        ),
        Line(
          (words): [
            Word(value='how', confidence=0.58),
            Word(value='are', confidence=0.97),
            Word(value='you', confidence=0.88),
          ]
        ),
      ]
      (artefacts): []
    )]
  )]
)
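
To compare the two results programmatically rather than by eye, one option is to flatten the recognized words into a list. The sketch below assumes the nested dictionary layout returned by doctr's `Document.export()` (pages, then blocks, then lines, then words, each word carrying `value` and `confidence`); the helper names `flatten_words` and `text_of` are hypothetical, and the hard-coded dictionaries are trimmed excerpts mirroring the two printed outputs above, not real exported data.

```python
from typing import Any


def flatten_words(exported: dict[str, Any]) -> list[tuple[str, float]]:
    """Collect (value, confidence) pairs in reading order.

    Assumes the nesting pages -> blocks -> lines -> words used by
    Document.export(); other keys (geometry, artefacts, ...) are ignored.
    """
    words = []
    for page in exported["pages"]:
        for block in page["blocks"]:
            for line in block["lines"]:
                for word in line["words"]:
                    words.append((word["value"], word["confidence"]))
    return words


def text_of(exported: dict[str, Any]) -> str:
    """Join the recognized word values with single spaces."""
    return " ".join(value for value, _ in flatten_words(exported))


# Trimmed excerpt mirroring the numpy-array result printed above
numpy_result = {
    "pages": [{"blocks": [{"lines": [
        {"words": [{"value": "HELLO", "confidence": 1.0}]},
        {"words": [
            {"value": "how", "confidence": 0.58},
            {"value": "are", "confidence": 0.97},
            {"value": "you", "confidence": 0.88},
        ]},
    ]}]}]
}

# Trimmed excerpt mirroring the tensor result printed above
tensor_result = {
    "pages": [{"blocks": [{"lines": [
        {"words": [{"value": "-", "confidence": 1.0}]},
        {"words": [{"value": "-", "confidence": 1.0}] * 3},
    ]}]}]
}

print(text_of(numpy_result))   # HELLO how are you
print(text_of(tensor_result))  # - - - -
```

With a real predictor, the same helper would presumably be applied as `flatten_words(ocr_model(batch).export())`, assuming the exported keys match the layout above.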

Environment

Collecting environment information...

DocTR version: v0.11.0
TensorFlow version: N/A
PyTorch version: 2.6.0+cu124 (torchvision 0.21.0+cu124)
OpenCV version: 4.11.0
OS: Ubuntu 22.04.4 LTS
Python version: 3.11.11
Is CUDA available (TensorFlow): N/A
Is CUDA available (PyTorch): No
CUDA runtime version: 12.5.82
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_adv.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_cnn.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_precompiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_engines_runtime_compiled.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_graph.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_heuristic.so.9.2.1
/usr/lib/x86_64-linux-gnu/libcudnn_ops.so.9.2.1

Deep Learning backend

is_tf_available: False
is_torch_available: True

Metadata

Assignees

No one assigned

    Labels

    type: bug (Something isn't working)

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests
