Spaces:
Running
on
Zero
Running
on
Zero
File size: 23,656 Bytes
83897ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 
528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 |
# Copyright (c) Facebook, Inc. and its affiliates.
import numpy as np
import fvcore.nn.weight_init as weight_init
import torch
import torch.nn.functional as F
from torch import nn
from detectron2.layers import (
CNNBlockBase,
Conv2d,
DeformConv,
ModulatedDeformConv,
ShapeSpec,
get_norm,
)
from .backbone import Backbone
from .build import BACKBONE_REGISTRY
__all__ = [
"ResNetBlockBase",
"BasicBlock",
"BottleneckBlock",
"DeformBottleneckBlock",
"BasicStem",
"ResNet",
"make_stage",
"build_resnet_backbone",
]
class BasicBlock(CNNBlockBase):
    """
    Residual block built from two 3x3 convolutions, as used by ResNet-18 and
    ResNet-34 in :paper:`ResNet`. When the number of channels changes, a 1x1
    projection is inserted on the shortcut path.
    """

    def __init__(self, in_channels, out_channels, *, stride=1, norm="BN"):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            stride (int): Stride of the first 3x3 conv (and of the projection
                shortcut, when one is created).
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
        """
        super().__init__(in_channels, out_channels, stride)

        # The shortcut is the identity unless the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if needs_projection
            else None
        )

        # Only the first conv carries the stride; the second keeps resolution.
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=3,
            stride=stride,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )
        self.conv2 = Conv2d(
            out_channels,
            out_channels,
            kernel_size=3,
            stride=1,
            padding=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, applied in the order conv1, conv2, shortcut.
        weight_init.c2_msra_fill(self.conv1)
        weight_init.c2_msra_fill(self.conv2)
        if self.shortcut is not None:
            weight_init.c2_msra_fill(self.shortcut)

    def forward(self, x):
        """Apply conv1 -> relu -> conv2, add the shortcut, then a final relu."""
        residual = F.relu_(self.conv1(x))
        residual = self.conv2(residual)

        identity = x if self.shortcut is None else self.shortcut(x)

        residual += identity
        return F.relu_(residual)
class BottleneckBlock(CNNBlockBase):
    """
    Bottleneck residual block of ResNet-50/101/152 from :paper:`ResNet`:
    a 1x1 conv, a 3x3 conv and another 1x1 conv, plus a 1x1 projection on
    the shortcut path when the number of channels changes.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of output channels for the 3x3
                "bottleneck" conv layers.
            stride (int): Stride of the block, applied by exactly one of the
                first two convs (see ``stride_in_1x1``).
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for all conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the bottleneck 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
        """
        super().__init__(in_channels, out_channels, stride)

        # The shortcut is the identity unless the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if needs_projection
            else None
        )

        # The original MSRA ResNet models stride in the first 1x1 conv, while
        # the later fb.torch.resnet and Caffe2 ResNe[X]t implementations
        # stride in the 3x3 conv.
        if stride_in_1x1:
            first_stride, mid_stride = stride, 1
        else:
            first_stride, mid_stride = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=first_stride,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv2 = Conv2d(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=mid_stride,
            # Padding equal to the dilation keeps a 3x3 conv size-preserving.
            padding=dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, applied in the order conv1, conv2, conv3, shortcut.
        weight_init.c2_msra_fill(self.conv1)
        weight_init.c2_msra_fill(self.conv2)
        weight_init.c2_msra_fill(self.conv3)
        if self.shortcut is not None:
            weight_init.c2_msra_fill(self.shortcut)

        # NOTE: the last normalization of the residual branch is deliberately
        # NOT zero-initialized here. Doing so, i.e.
        #     nn.init.constant_(self.conv3.norm.weight, 0)
        # would make every block start out as an identity, as suggested in
        # Sec 5.1 of "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
        #   "For BN layers, the learnable scaling coefficient γ is initialized
        #    to be 1, except for each residual block's last BN
        #    where γ is initialized to be 0."
        # TODO this somehow hurts performance when training GN models from scratch.
        # Add it as an option when we need to use this code to train a backbone.

    def forward(self, x):
        """Apply the three convs (relu after the first two), add the shortcut, then relu."""
        residual = F.relu_(self.conv1(x))
        residual = F.relu_(self.conv2(residual))
        residual = self.conv3(residual)

        identity = x if self.shortcut is None else self.shortcut(x)

        residual += identity
        return F.relu_(residual)
class DeformBottleneckBlock(CNNBlockBase):
    """
    Variant of :class:`BottleneckBlock` whose 3x3 convolution is a
    :paper:`deformable conv <deformconv>`. The sampling offsets (and, in the
    modulated case, the mask) are predicted by an extra regular 3x3 conv.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        *,
        bottleneck_channels,
        stride=1,
        num_groups=1,
        norm="BN",
        stride_in_1x1=False,
        dilation=1,
        deform_modulated=False,
        deform_num_groups=1,
    ):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels.
            bottleneck_channels (int): number of channels of the 3x3 conv.
            stride (int): Stride of the block, applied by exactly one of the
                first two convs (see ``stride_in_1x1``).
            num_groups (int): number of groups for the 3x3 conv layer.
            norm (str or callable): normalization for the conv layers.
                See :func:`layers.get_norm` for supported format.
            stride_in_1x1 (bool): when stride>1, whether to put stride in the
                first 1x1 convolution or the 3x3 convolution.
            dilation (int): the dilation rate of the 3x3 conv layer.
            deform_modulated (bool): if True use :class:`ModulatedDeformConv`
                (offsets plus mask), otherwise :class:`DeformConv`.
            deform_num_groups (int): passed to the deformable conv as
                ``deformable_groups``; also scales the offset conv's output
                channels.
        """
        super().__init__(in_channels, out_channels, stride)
        self.deform_modulated = deform_modulated

        # The shortcut is the identity unless the channel count changes.
        needs_projection = in_channels != out_channels
        self.shortcut = (
            Conv2d(
                in_channels,
                out_channels,
                kernel_size=1,
                stride=stride,
                bias=False,
                norm=get_norm(norm, out_channels),
            )
            if needs_projection
            else None
        )

        if stride_in_1x1:
            first_stride, mid_stride = stride, 1
        else:
            first_stride, mid_stride = 1, stride

        self.conv1 = Conv2d(
            in_channels,
            bottleneck_channels,
            kernel_size=1,
            stride=first_stride,
            bias=False,
            norm=get_norm(norm, bottleneck_channels),
        )

        # offset channels are 2 or 3 (if with modulated) * kernel_size * kernel_size
        deform_conv_op, offset_channels = (
            (ModulatedDeformConv, 27) if deform_modulated else (DeformConv, 18)
        )

        # Regular conv that predicts the offsets (and mask) consumed by conv2.
        self.conv2_offset = Conv2d(
            bottleneck_channels,
            offset_channels * deform_num_groups,
            kernel_size=3,
            stride=mid_stride,
            padding=dilation,
            dilation=dilation,
        )
        self.conv2 = deform_conv_op(
            bottleneck_channels,
            bottleneck_channels,
            kernel_size=3,
            stride=mid_stride,
            padding=dilation,
            bias=False,
            groups=num_groups,
            dilation=dilation,
            deformable_groups=deform_num_groups,
            norm=get_norm(norm, bottleneck_channels),
        )
        self.conv3 = Conv2d(
            bottleneck_channels,
            out_channels,
            kernel_size=1,
            bias=False,
            norm=get_norm(norm, out_channels),
        )

        # MSRA fill, applied in the order conv1, conv2, conv3, shortcut.
        weight_init.c2_msra_fill(self.conv1)
        weight_init.c2_msra_fill(self.conv2)
        weight_init.c2_msra_fill(self.conv3)
        if self.shortcut is not None:
            weight_init.c2_msra_fill(self.shortcut)

        # The offset predictor starts at zero (weights and bias).
        nn.init.constant_(self.conv2_offset.weight, 0)
        nn.init.constant_(self.conv2_offset.bias, 0)

    def forward(self, x):
        """Run the bottleneck with a deformable 3x3 conv and add the shortcut."""
        feat = F.relu_(self.conv1(x))

        predicted = self.conv2_offset(feat)
        if self.deform_modulated:
            # The prediction holds three equal channel chunks: two offset
            # chunks followed by the mask logits.
            offset_x, offset_y, mask_logits = torch.chunk(predicted, 3, dim=1)
            offset = torch.cat((offset_x, offset_y), dim=1)
            feat = self.conv2(feat, offset, mask_logits.sigmoid())
        else:
            feat = self.conv2(feat, predicted)

        feat = F.relu_(feat)
        feat = self.conv3(feat)

        identity = x if self.shortcut is None else self.shortcut(x)

        feat += identity
        return F.relu_(feat)
class BasicStem(CNNBlockBase):
    """
    The standard ResNet stem, i.e. everything before the first residual
    block: a 7x7 stride-2 conv, a relu and a 3x3 stride-2 max pool.
    """

    def __init__(self, in_channels=3, out_channels=64, norm="BN"):
        """
        Args:
            in_channels (int): Number of input channels.
            out_channels (int): Number of output channels of the conv.
            norm (str or callable): norm after the first conv layer.
                See :func:`layers.get_norm` for supported format.
        """
        # Overall stride is 4: stride 2 from the conv times stride 2 from the pool.
        super().__init__(in_channels, out_channels, 4)
        self.in_channels = in_channels

        stem_norm = get_norm(norm, out_channels)
        self.conv1 = Conv2d(
            in_channels,
            out_channels,
            kernel_size=7,
            stride=2,
            padding=3,
            bias=False,
            norm=stem_norm,
        )
        weight_init.c2_msra_fill(self.conv1)

    def forward(self, x):
        """Apply conv -> relu -> max pool."""
        activated = F.relu_(self.conv1(x))
        return F.max_pool2d(activated, kernel_size=3, stride=2, padding=1)
class ResNet(Backbone):
    """
    Implement :paper:`ResNet`.
    """

    def __init__(self, stem, stages, num_classes=None, out_features=None, freeze_at=0):
        """
        Args:
            stem (nn.Module): a stem module
            stages (list[list[CNNBlockBase]]): several (typically 4) stages,
                each contains multiple :class:`CNNBlockBase`.
            num_classes (None or int): if None, will not perform classification.
                Otherwise, will create a linear layer.
            out_features (list[str]): name of the layers whose outputs should
                be returned in forward. Can be anything in "stem", "linear", or "res2" ...
                If None, will return the output of the last layer.
                Stages that are not needed to compute the requested features
                are dropped. Requesting "linear" (with ``num_classes`` set)
                keeps every stage, because the classifier sits on top of the
                last one.
            freeze_at (int): The number of stages at the beginning to freeze.
                see :meth:`freeze` for detailed explanation.
        """
        super().__init__()
        self.stem = stem
        self.num_classes = num_classes

        current_stride = self.stem.stride
        self._out_feature_strides = {"stem": current_stride}
        self._out_feature_channels = {"stem": self.stem.out_channels}

        self.stage_names, self.stages = [], []

        if out_features is not None:
            # Avoid keeping unused layers in this module. They consume extra memory
            # and may cause allreduce to fail
            if num_classes is not None and "linear" in out_features:
                # The classification head consumes the output of the final
                # stage, so nothing may be pruned. Mapping "linear" to 0
                # stages would drop every stage (and leave the head without
                # an input width) when it is the only requested feature.
                num_stages = len(stages)
            else:
                num_stages = max(
                    [{"res2": 1, "res3": 2, "res4": 3, "res5": 4}.get(f, 0) for f in out_features]
                )
            stages = stages[:num_stages]

        # Defaults for a network without residual stages: the stem is then
        # the last layer and defines the width seen by the optional head.
        name = "stem"
        curr_channels = self.stem.out_channels
        for i, blocks in enumerate(stages):
            assert len(blocks) > 0, len(blocks)
            for block in blocks:
                assert isinstance(block, CNNBlockBase), block

            name = "res" + str(i + 2)
            stage = nn.Sequential(*blocks)

            self.add_module(name, stage)
            self.stage_names.append(name)
            self.stages.append(stage)

            # The stride of a stage's output is the product of all strides so far.
            self._out_feature_strides[name] = current_stride = int(
                current_stride * np.prod([k.stride for k in blocks])
            )
            self._out_feature_channels[name] = curr_channels = blocks[-1].out_channels
        self.stage_names = tuple(self.stage_names)  # Make it static for scripting

        if num_classes is not None:
            self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
            self.linear = nn.Linear(curr_channels, num_classes)

            # Sec 5.1 in "Accurate, Large Minibatch SGD: Training ImageNet in 1 Hour":
            # "The 1000-way fully-connected layer is initialized by
            # drawing weights from a zero-mean Gaussian with standard deviation of 0.01."
            nn.init.normal_(self.linear.weight, std=0.01)
            name = "linear"

        if out_features is None:
            # Default to the last layer that was created.
            out_features = [name]
        self._out_features = out_features
        assert len(self._out_features)
        children = [x[0] for x in self.named_children()]
        for out_feature in self._out_features:
            assert out_feature in children, "Available children: {}".format(", ".join(children))
        self.freeze(freeze_at)

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.

        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert x.dim() == 4, f"ResNet takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        x = self.stem(x)
        if "stem" in self._out_features:
            outputs["stem"] = x
        for name, stage in zip(self.stage_names, self.stages):
            x = stage(x)
            if name in self._out_features:
                outputs[name] = x
        if self.num_classes is not None:
            # Classification head: global average pool, flatten, linear.
            x = self.avgpool(x)
            x = torch.flatten(x, 1)
            x = self.linear(x)
            if "linear" in self._out_features:
                outputs["linear"] = x
        return outputs

    def output_shape(self):
        """
        Returns:
            dict[str->ShapeSpec]: channels and stride of every requested output feature.
        """
        return {
            name: ShapeSpec(
                channels=self._out_feature_channels[name], stride=self._out_feature_strides[name]
            )
            for name in self._out_features
        }

    def freeze(self, freeze_at=0):
        """
        Freeze the first several stages of the ResNet. Commonly used in
        fine-tuning.

        Layers that produce the same feature map spatial size are defined as one
        "stage" by :paper:`FPN`.

        Args:
            freeze_at (int): number of stages to freeze.
                `1` means freezing the stem. `2` means freezing the stem and
                one residual stage, etc.

        Returns:
            nn.Module: this ResNet itself
        """
        if freeze_at >= 1:
            self.stem.freeze()
        # Residual stages are numbered from 2, matching their "resN" names.
        for idx, stage in enumerate(self.stages, start=2):
            if freeze_at >= idx:
                for block in stage.children():
                    block.freeze()
        return self

    @staticmethod
    def make_stage(block_class, num_blocks, *, in_channels, out_channels, **kwargs):
        """
        Create a list of blocks of the same type that forms one ResNet stage.

        Args:
            block_class (type): a subclass of CNNBlockBase that's used to create all blocks in this
                stage. A module of this type must not change spatial resolution of inputs unless its
                stride != 1.
            num_blocks (int): number of blocks in this stage
            in_channels (int): input channels of the entire stage.
            out_channels (int): output channels of **every block** in the stage.
            kwargs: other arguments passed to the constructor of
                `block_class`. If the argument name is "xx_per_block", the
                argument is a list of values to be passed to each block in the
                stage. Otherwise, the same argument is passed to every block
                in the stage.

        Returns:
            list[CNNBlockBase]: a list of block module.

        Examples:
        ::
            stage = ResNet.make_stage(
                BottleneckBlock, 3, in_channels=16, out_channels=64,
                bottleneck_channels=16, num_groups=1,
                stride_per_block=[2, 1, 1],
                dilations_per_block=[1, 1, 2]
            )

        Usually, layers that produce the same feature map spatial size are defined as one
        "stage" (in :paper:`FPN`). Under such definition, ``stride_per_block[1:]`` should
        all be 1.
        """
        blocks = []
        for i in range(num_blocks):
            curr_kwargs = {}
            for k, v in kwargs.items():
                if k.endswith("_per_block"):
                    assert len(v) == num_blocks, (
                        f"Argument '{k}' of make_stage should have the "
                        f"same length as num_blocks={num_blocks}."
                    )
                    # "xx_per_block" supplies block i with the argument "xx".
                    newk = k[: -len("_per_block")]
                    assert newk not in kwargs, f"Cannot call make_stage with both {k} and {newk}!"
                    curr_kwargs[newk] = v[i]
                else:
                    curr_kwargs[k] = v

            blocks.append(
                block_class(in_channels=in_channels, out_channels=out_channels, **curr_kwargs)
            )
            # Every block after the first consumes the stage's output width.
            in_channels = out_channels
        return blocks

    @staticmethod
    def make_default_stages(depth, block_class=None, **kwargs):
        """
        Create list of ResNet stages from pre-defined depth (one of 18, 34, 50, 101, 152).
        If it doesn't create the ResNet variant you need, please use :meth:`make_stage`
        instead for fine-grained customization.

        Args:
            depth (int): depth of ResNet
            block_class (type): the CNN block class. Has to accept
                `bottleneck_channels` argument for depth >= 50.
                By default it is BasicBlock or BottleneckBlock, based on the
                depth.
            kwargs:
                other arguments to pass to `make_stage`. Should not contain
                stride and channels, as they are predefined for each depth.

        Returns:
            list[list[CNNBlockBase]]: modules in all stages; see arguments of
                :class:`ResNet.__init__`.
        """
        num_blocks_per_stage = {
            18: [2, 2, 2, 2],
            34: [3, 4, 6, 3],
            50: [3, 4, 6, 3],
            101: [3, 4, 23, 3],
            152: [3, 8, 36, 3],
        }[depth]
        if block_class is None:
            block_class = BasicBlock if depth < 50 else BottleneckBlock
        if depth < 50:
            in_channels = [64, 64, 128, 256]
            out_channels = [64, 128, 256, 512]
        else:
            in_channels = [64, 256, 512, 1024]
            out_channels = [256, 512, 1024, 2048]
        ret = []
        for n, s, i, o in zip(num_blocks_per_stage, [1, 2, 2, 2], in_channels, out_channels):
            if depth >= 50:
                # Bottleneck width is a quarter of the stage's output width.
                kwargs["bottleneck_channels"] = o // 4
            ret.append(
                ResNet.make_stage(
                    block_class=block_class,
                    num_blocks=n,
                    # Only the first block of a stage may downsample.
                    stride_per_block=[s] + [1] * (n - 1),
                    in_channels=i,
                    out_channels=o,
                    **kwargs,
                )
            )
        return ret
ResNetBlockBase = CNNBlockBase
"""
Alias for backward compatibility.
"""
def make_stage(*args, **kwargs):
    """
    Deprecated alias for backward compatibility.

    Forwards all positional and keyword arguments unchanged to
    :meth:`ResNet.make_stage` and returns its result.
    """
    return ResNet.make_stage(*args, **kwargs)
@BACKBONE_REGISTRY.register()
def build_resnet_backbone(cfg, input_shape):
    """
    Create a ResNet instance from config.

    Args:
        cfg: config node; the options are read from ``cfg.MODEL.RESNETS`` and
            ``cfg.MODEL.BACKBONE.FREEZE_AT``.
        input_shape: object whose ``channels`` attribute gives the number of
            input channels of the stem.

    Returns:
        ResNet: a :class:`ResNet` instance.
    """
    # need registration of new blocks/stems?
    resnets_cfg = cfg.MODEL.RESNETS

    norm = resnets_cfg.NORM
    stem = BasicStem(
        in_channels=input_shape.channels,
        out_channels=resnets_cfg.STEM_OUT_CHANNELS,
        norm=norm,
    )

    freeze_at = cfg.MODEL.BACKBONE.FREEZE_AT
    out_features = resnets_cfg.OUT_FEATURES
    depth = resnets_cfg.DEPTH
    num_groups = resnets_cfg.NUM_GROUPS
    width_per_group = resnets_cfg.WIDTH_PER_GROUP
    bottleneck_channels = num_groups * width_per_group
    in_channels = resnets_cfg.STEM_OUT_CHANNELS
    out_channels = resnets_cfg.RES2_OUT_CHANNELS
    stride_in_1x1 = resnets_cfg.STRIDE_IN_1X1
    res5_dilation = resnets_cfg.RES5_DILATION
    deform_on_per_stage = resnets_cfg.DEFORM_ON_PER_STAGE
    deform_modulated = resnets_cfg.DEFORM_MODULATED
    deform_num_groups = resnets_cfg.DEFORM_NUM_GROUPS

    assert res5_dilation in {1, 2}, "res5_dilation cannot be {}.".format(res5_dilation)

    num_blocks_per_stage = {
        18: [2, 2, 2, 2],
        34: [3, 4, 6, 3],
        50: [3, 4, 6, 3],
        101: [3, 4, 23, 3],
        152: [3, 8, 36, 3],
    }[depth]

    # R18 and R34 are built from BasicBlock, which supports none of the
    # bottleneck-only options checked below.
    uses_basic_block = depth in (18, 34)
    if uses_basic_block:
        assert out_channels == 64, "Must set MODEL.RESNETS.RES2_OUT_CHANNELS = 64 for R18/R34"
        assert not any(
            deform_on_per_stage
        ), "MODEL.RESNETS.DEFORM_ON_PER_STAGE unsupported for R18/R34"
        assert res5_dilation == 1, "Must set MODEL.RESNETS.RES5_DILATION = 1 for R18/R34"
        assert num_groups == 1, "Must set MODEL.RESNETS.NUM_GROUPS = 1 for R18/R34"

    stages = []
    for idx, num_blocks in enumerate(num_blocks_per_stage):
        stage_idx = idx + 2  # stages are named res2 .. res5

        # res5_dilation is used this way as a convention in R-FCN & Deformable Conv paper
        dilation = res5_dilation if stage_idx == 5 else 1
        # res2 never downsamples; a dilated res5 trades its stride for dilation.
        keeps_resolution = idx == 0 or (stage_idx == 5 and dilation == 2)
        first_stride = 1 if keeps_resolution else 2

        stage_kargs = dict(
            num_blocks=num_blocks,
            stride_per_block=[first_stride] + [1] * (num_blocks - 1),
            in_channels=in_channels,
            out_channels=out_channels,
            norm=norm,
        )
        if uses_basic_block:
            stage_kargs["block_class"] = BasicBlock
        else:
            stage_kargs.update(
                bottleneck_channels=bottleneck_channels,
                stride_in_1x1=stride_in_1x1,
                dilation=dilation,
                num_groups=num_groups,
            )
            if deform_on_per_stage[idx]:
                stage_kargs.update(
                    block_class=DeformBottleneckBlock,
                    deform_modulated=deform_modulated,
                    deform_num_groups=deform_num_groups,
                )
            else:
                stage_kargs["block_class"] = BottleneckBlock

        stages.append(ResNet.make_stage(**stage_kargs))

        # Channel widths double from one stage to the next.
        in_channels = out_channels
        out_channels *= 2
        bottleneck_channels *= 2

    return ResNet(stem, stages, out_features=out_features, freeze_at=freeze_at)
|