Error when using torch.fx on bert · Issue #67970 · pytorch/pytorch
Open
@lethean1

Description


I am trying to use symbolic_trace on a BERT model and get this error:

Traceback (most recent call last):
  File "/workspace/fuse/hugging_face_competiton/train.py", line 196, in <module>
    model = fusion.fuse(model)
  File "/workspace/fuse/hugging_face_competiton/fusion.py", line 72, in fuse
    fx_model = fx.symbolic_trace(model)
  File "/root/miniconda3/lib/python3.9/site-packages/torch/fx/symbolic_trace.py", line 858, in symbolic_trace
    graph = tracer.trace(root, concrete_args)
  File "/root/miniconda3/lib/python3.9/site-packages/torch/fx/symbolic_trace.py", line 570, in trace
    self.create_node('output', 'output', (self.create_arg(fn(*args)),), {},
  File "/root/miniconda3/lib/python3.9/site-packages/transformers/models/bert/modeling_bert.py", line 941, in forward
    raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
ValueError: You cannot specify both input_ids and inputs_embeds at the same time
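The last frame shows the input check in BertModel.forward running while fx is tracing the model. The sketch below uses a hypothetical Probe module (not part of torch or transformers) to record what torch.fx.symbolic_trace passes in when no concrete_args are given: every forward parameter, including ones that default to None, arrives as a torch.fx.Proxy. Since "x is not None" is an ordinary Python identity test that fx cannot intercept, it evaluates to True for both arguments.

```python
import torch
import torch.fx as fx

# Filled in by Probe.forward while symbolic_trace runs it once.
seen = {}


class Probe(torch.nn.Module):
    """Hypothetical module with the same optional arguments as the failing check."""

    def forward(self, input_ids=None, inputs_embeds=None):
        # These lines execute as plain Python during tracing; fx records only
        # operations performed on Proxies, not these identity checks.
        seen["input_ids"] = type(input_ids)
        seen["inputs_embeds"] = type(inputs_embeds)
        seen["both_not_none"] = input_ids is not None and inputs_embeds is not None
        return input_ids


fx.symbolic_trace(Probe())
print(seen)
```

In BertModel.forward the same condition guards a raise, so tracing stops with the ValueError shown above.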

Model

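    from transformers import BertConfig, BertModel

    # args: the script's parsed command-line options; fusion: the local fusion.py seen in the traceback
    configuration = BertConfig(hidden_size=args.hidden_size, num_hidden_layers=args.num_hidden_layers,
                               num_attention_heads=args.num_attention_heads, intermediate_size=4 * args.hidden_size)
    model = BertModel(configuration).cuda()
    model.eval()
    fused_model = fusion.fuse(model)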

Is this because fx passes Proxies as the inputs, which causes this conflict?
Any help on how to run this correctly would be appreciated. Thanks!
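One way past this particular check is a minimal sketch on a hypothetical ToyEncoder that copies only BertModel's input_ids/inputs_embeds check (it is not the real model). Passing concrete_args, a real parameter of torch.fx.symbolic_trace, fixes the argument you are not using to None at trace time, so only input_ids is traced as a Proxy. The module, its sizes, and the variable names are illustrative assumptions.

```python
import torch
import torch.fx as fx


class ToyEncoder(torch.nn.Module):
    """Hypothetical stand-in: copies only BertModel's input_ids/inputs_embeds check."""

    def __init__(self, vocab_size: int = 100, hidden: int = 8):
        super().__init__()
        self.embed = torch.nn.Embedding(vocab_size, hidden)
        self.proj = torch.nn.Linear(hidden, hidden)

    def forward(self, input_ids=None, inputs_embeds=None):
        # Plain Python control flow: fx cannot record it, so it is evaluated
        # once at trace time against whatever values the tracer passed in.
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
        if inputs_embeds is None:
            inputs_embeds = self.embed(input_ids)
        return self.proj(inputs_embeds)


model = ToyEncoder().eval()

# fx.symbolic_trace(model) alone would raise the same ValueError as in the
# report, because both arguments arrive as Proxies. Pinning inputs_embeds to
# None leaves input_ids as the only symbolic input.
gm = fx.symbolic_trace(model, concrete_args={"inputs_embeds": None})

ids = torch.randint(0, 100, (2, 5))
print(gm.code)
print(torch.allclose(gm(ids), model(ids)))
```

This only addresses the input_ids/inputs_embeds check. The real BertModel.forward also does Python work on tensor shapes (for example unpacking input_ids.size() into batch_size, seq_length), which plain torch.fx may also be unable to trace. transformers ships its own fx tracer in transformers.utils.fx, including a symbolic_trace helper that takes the names of the inputs to trace; check its signature for the installed transformers version.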

Labels: module: fx, triaged
