Setting `requires_grad = True` for optimizer parameters · Issue #5106 · allenai/allennlp · GitHub
This repository was archived by the owner on Dec 16, 2022. It is now read-only.
Setting requires_grad = True for optimizer parameters #5106
Open
@nelson-liu

Description

Is your feature request related to a problem? Please describe.

I'd like the ability to set requires_grad=True in the optimizer parameter groups. For instance:

    ...
    "text_field_embedder": {
      "token_embedders": {
        "tokens": {
          "type": "pretrained_transformer",
          "model_name": transformer_model,
          "max_length": 512,
          "train_parameters": false,
        }
      }
    },
    ....
    "optimizer": {
       ...
      "parameter_groups": [
        # This turns on grad for the attention query bias vectors and the intermediate MLP bias vectors.
        # Since we set train_parameters to false in the token_embedder, these are the only weights that will be updated
        # in the token_embedder.
        [["^_text_field_embedder.token_embedder_tokens.transformer_model.*attention.self.query.bias$"], {"requires_grad": true}],
        [["^_text_field_embedder.token_embedder_tokens.transformer_model.*intermediate.dense.bias$"], {"requires_grad": true}]
      ]
    },

In this config, I set the token embedder's train_parameters to false, so it is frozen. However, I want to train a small subset of its parameters (the ones matched by the regexes). The intended outcome is that the token embedder parameters are non-trainable by default (since train_parameters = false), but the subset selected by the regexes is trainable.

The current behavior is that these groups are simply ignored: because the non-trainable parameters are never passed to the optimizer, the regexes match nothing, and accordingly those parameters can't have their requires_grad value changed.
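
For reference, the parameter_groups regexes are only ever matched against the parameters the trainer hands to the optimizer. Here is a minimal sketch of that filtering step (simplified, not the actual AllenNLP source; the Linear model is just a stand-in for the real Model built from the config above):

    import torch

    # Hypothetical stand-in for the real model built from the config above.
    model = torch.nn.Linear(4, 2)
    model.weight.requires_grad = False  # analogous to train_parameters = false

    # Frozen parameters are dropped before the optimizer (and its
    # parameter_groups regexes) are constructed.
    parameters = [(name, param) for name, param in model.named_parameters() if param.requires_grad]

    print([name for name, _ in parameters])  # ['bias'] -- 'weight' was filtered out

Because of this, a group that sets {"requires_grad": true} has nothing left to match against.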

(I realize I can get this behavior by setting train_parameters = true and then writing a regex that selects every parameter not matched by the regexes above, setting {"requires_grad": false} on that group. However, that regex is borderline unmaintainable and certainly not very readable.)
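
For comparison, that workaround would look roughly like this. This is only a sketch: the negative-lookahead regex is illustrative and untested, which is exactly the maintainability problem.

    "text_field_embedder": {
      "token_embedders": {
        "tokens": {
          "type": "pretrained_transformer",
          "model_name": transformer_model,
          "max_length": 512,
          "train_parameters": true,
        }
      }
    },
    ...
    "optimizer": {
       ...
      "parameter_groups": [
        # Freeze everything in the transformer *except* the two bias patterns above.
        [["^_text_field_embedder.token_embedder_tokens.transformer_model.(?!.*(attention.self.query.bias|intermediate.dense.bias)$).*$"], {"requires_grad": false}]
      ]
    },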
