Shortcuts

Source code for torch.jit._freeze

"""Freezing

This is not intended to be imported directly; please use the exposed
functionalities in `torch.jit`.
"""

from typing import Optional, List

import torch
from torch.jit._script import RecursiveScriptModule, ScriptModule


def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize_numerics: bool = True):
    r"""
    Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
    module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
    By default, `forward` will be preserved, as well as attributes & methods specified in
    `preserved_attrs`. Additionally, any attribute that is modified within a preserved
    method will be preserved.

    Freezing currently only accepts ScriptModules that are in eval mode.

    Args:
        mod (:class:`ScriptModule`): a module to be frozen

        preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the
            forward method. Attributes modified in preserved methods will also be preserved.

        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not
            strictly preserve numerics. Full details of optimization can be found at
            `torch.jit.optimize_frozen_module`.

    Returns:
        Frozen :class:`ScriptModule`.

    Raises:
        RuntimeError: if ``mod`` is not a :class:`ScriptModule`, or if it is in training mode.

    Example (Freezing a simple module with a Parameter):

    .. testcode::
        import torch
        class MyModule(torch.nn.Module):
            def __init__(self, N, M):
                super(MyModule, self).__init__()
                self.weight = torch.nn.Parameter(torch.rand(N, M))
                self.linear = torch.nn.Linear(N, M)

            def forward(self, input):
                output = self.weight.mm(input)
                output = self.linear(output)
                return output

        scripted_module = torch.jit.script(MyModule(2, 3).eval())
        frozen_module = torch.jit.freeze(scripted_module)
        # parameters have been removed and inlined into the Graph as constants
        assert len(list(frozen_module.named_parameters())) == 0
        # See the compiled graph as Python code
        print(frozen_module.code)

    Example (Freezing a module with preserved attributes)

    .. testcode::
        import torch
        class MyModule2(torch.nn.Module):
            def __init__(self):
                super(MyModule2, self).__init__()
                self.modified_tensor = torch.tensor(10.)
                self.version = 1

            def forward(self, input):
                self.modified_tensor += 1
                return input + self.modified_tensor

        scripted_module = torch.jit.script(MyModule2().eval())
        frozen_module = torch.jit.freeze(scripted_module, preserved_attrs=["version"])
        # we've manually preserved `version`, so it still exists on the frozen module and can be modified
        assert frozen_module.version == 1
        frozen_module.version = 2
        # `modified_tensor` is detected as being mutated in the forward, so freezing preserves
        # it to retain model semantics
        assert frozen_module(torch.tensor(1)) == torch.tensor(12)
        # now that we've run it once, the next result will be incremented by one
        assert frozen_module(torch.tensor(1)) == torch.tensor(13)

    Note:
        If you're not sure why an attribute is not being inlined as a constant, you can run
        `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
        attribute is being modified.
    """
    # Freezing works on the TorchScript IR, so plain nn.Modules must be scripted/traced first.
    if not isinstance(mod, ScriptModule):
        raise RuntimeError(
            "Freezing expects a ScriptModule as input. "
            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'."
        )

    # Only eval-mode modules are supported; training-mode behavior is not frozen.
    if mod.training:
        raise RuntimeError(
            "Freezing is currently only implemented for modules in eval mode. "
            "Please call .eval() on your module before freezing."
        )

    preserved_attrs = preserved_attrs if preserved_attrs is not None else []

    # The C++ pass clones the underlying C module and inlines it; wrap the result back
    # into a Python-side RecursiveScriptModule so it behaves like any scripted module.
    out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
    RecursiveScriptModule._finalize_scriptmodule(out)

    # Post-freezing graph optimizations (applied in place on `out`).
    optimize_frozen_module(out, optimize_numerics)

    return out
def optimize_frozen_module(mod, optimize_numerics: bool = True):
    r"""
    Runs a series of optimizations looking for patterns that occur in frozen graphs.
    The current set of optimizations is:
        - Dropout Removal
        - Conv -> Batchnorm folding
        - Conv -> Add/Sub folding
        - Conv -> Mul/Div folding

    Dropout removal always runs. The three folding passes run only when
    ``optimize_numerics`` is ``True``. The module is modified in place.

    Args:
        mod (:class:`ScriptModule`): a frozen module to be optimized

        optimize_numerics (bool): If ``True``, a set of optimization passes will be run that do not
            strictly preserve numerics. These optimizations preserve default rtol and atol of
            `torch.testing.assert_allclose` when applied on a single transformation, however in a
            module where many transformations are applied the rtol or atol may no longer fall within
            the default `assert_allclose` tolerance. Conv -> Batchnorm folding, Conv -> Add/Sub folding,
            and Conv -> Mul/Div folding all may alter numerics.

    Returns:
        None

    Note:
        On rare occasions, this can result in slower execution.

    Example (Freezing a module with Conv->Batchnorm)

    .. code-block:: python

        import torch
        in_channels, out_channels = 3, 32
        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
        mod = torch.nn.Sequential(conv, bn)
        # set optimize_numerics to False here; by default freezing runs optimize_frozen_module
        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
        # inspect frozen mod
        assert "batch_norm" in str(frozen_mod.graph)
        torch.jit.optimize_frozen_module(frozen_mod)
        assert "batch_norm" not in str(frozen_mod.graph)
    """
    # xxx: keep in sync with frozen_graph_optimization.cpp
    # intentionally duplicated to make it easier to create a custom optimization sequence
    torch._C._jit_pass_remove_dropout(mod._c)
    if optimize_numerics:
        # run a couple times to capture Conv -> Mul -> Add etc
        # (one fold can expose a new foldable pattern for another pass)
        for _ in range(2):
            torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
            torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
            torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)

Docs

Access comprehensive developer documentation for PyTorch

View Docs

Tutorials

Get in-depth tutorials for beginners and advanced developers

View Tutorials

Resources

Find development resources and get your questions answered

View Resources