diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 000000000..763d77f04 --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,41 @@ +name: Tests + +on: + push: + branches: [master, main] + pull_request: + branches: [master, main] + +jobs: + test: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python-version: ["3.9", "3.10", "3.11"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + cache: pip + + - name: Install CPU-only PyTorch + run: | + pip install torch torchvision --index-url https://download.pytorch.org/whl/cpu + + - name: Install package and test deps + run: | + pip install pytest + pip install -e . --no-build-isolation + + - name: Run tests (CPU-safe, no external data required) + run: | + pytest tests/ -x -v \ + --ignore=tests/test_archs \ + --ignore=tests/test_models \ + --ignore=tests/test_data \ + --ignore=tests/test_losses diff --git a/.github/workflows/publish-pip.yml b/.github/workflows/publish-pip.yml index 06047f748..bdc012507 100644 --- a/.github/workflows/publish-pip.yml +++ b/.github/workflows/publish-pip.yml @@ -1,30 +1,44 @@ name: PyPI Publish -on: push +on: + push: + tags: + - 'v*' jobs: - build-n-publish: + build: + name: Build distribution runs-on: ubuntu-latest - if: startsWith(github.event.ref, 'refs/tags') - steps: - - uses: actions/checkout@v2 - - name: Set up Python 3.8 - uses: actions/setup-python@v1 + - uses: actions/checkout@v4 + - name: Set up Python + uses: actions/setup-python@v5 + with: + python-version: "3.11" + - name: Install build tools + run: pip install build + - name: Build sdist and wheel + run: python -m build + - name: Store distribution packages + uses: actions/upload-artifact@v4 with: - python-version: 3.8 - - name: Upgrade pip - run: pip install pip --upgrade - - name: Install PyTorch (cpu) - run: pip install 
torch==1.7.0+cpu torchvision==0.8.1+cpu -f https://download.pytorch.org/whl/torch_stable.html - - name: Install dependencies - run: pip install -r requirements.txt - - name: Build and install - run: rm -rf .eggs && pip install -e . - - name: Build for distribution - # remove bdist_wheel for pip installation with compiling cuda extensions - run: python setup.py sdist - - name: Publish distribution to PyPI - uses: pypa/gh-action-pypi-publish@master + name: python-package-distributions + path: dist/ + + publish-to-pypi: + name: Publish to PyPI + needs: build + runs-on: ubuntu-latest + environment: + name: pypi + url: https://pypi.org/project/basicsr/ + permissions: + id-token: write # OIDC trusted publishing + steps: + - name: Download distributions + uses: actions/download-artifact@v4 with: - password: ${{ secrets.PYPI_API_TOKEN }} + name: python-package-distributions + path: dist/ + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index 239344af6..e7fa91b61 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -1,30 +1,33 @@ -name: PyLint +name: Lint on: [push, pull_request] jobs: - build: - + lint: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: ["3.11"] steps: - - uses: actions/checkout@v2 - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v2 - with: - python-version: ${{ matrix.python-version }} + - uses: actions/checkout@v4 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v5 + with: + python-version: ${{ matrix.python-version }} + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff codespell - - name: Install dependencies - run: | - python -m pip install --upgrade pip - pip install codespell flake8 isort yapf + - name: Lint with ruff + run: | + ruff check basicsr/ options/ scripts/ tests/ inference/ + 
ruff format --check basicsr/ options/ scripts/ tests/ inference/ - - name: Lint - run: | - codespell - flake8 . - isort --check-only --diff basicsr/ options/ scripts/ tests/ inference/ setup.py - yapf -r -d basicsr/ options/ scripts/ tests/ inference/ setup.py + - name: Spell check + run: | + codespell \ + --skip=".git,./docs/build,*.cfg,*.toml,*.md" \ + --ignore-words-list="uper,ramdom,propgation,simuator" diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index d22231098..2648468e6 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -1,41 +1,31 @@ -name: release +name: Release + on: push: tags: - - '*' + - 'v*' jobs: - build: - permissions: write-all - name: Create Release + release: + name: Create GitHub Release runs-on: ubuntu-latest + permissions: + contents: write steps: - - name: Checkout code - uses: actions/checkout@v2 + - uses: actions/checkout@v4 - name: Create Release - id: create_release - uses: actions/create-release@v1 - env: - GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} + uses: softprops/action-gh-release@v2 with: - tag_name: ${{ github.ref }} - release_name: BasicSR ${{ github.ref }} Release Note + draft: true + prerelease: false + generate_release_notes: true body: | - 🚀 See you again 😸 - 🚀Have a nice day 😸 and happy everyday 😃 - 🚀 Long time no see ☄️ - - ✨ **Highlights** - ✅ [Features] Support ... + ## BasicSR ${{ github.ref_name }} - 🐛 **Bug Fixes** + ### Highlights + - See commits for full change list. - 🌴 **Improvements** - - 📢📢📢 - -
-
-
- (start from 2022-11-06)
diff --git a/VERSION b/VERSION
index 9df886c42..428b770e3 100644
--- a/VERSION
+++ b/VERSION
@@ -1 +1 @@
-1.4.2
+1.4.3
diff --git a/basicsr/archs/arch_util.py b/basicsr/archs/arch_util.py
index 493e56611..202c729a7 100644
--- a/basicsr/archs/arch_util.py
+++ b/basicsr/archs/arch_util.py
@@ -1,10 +1,11 @@
import collections.abc
import math
-import torch
-import torchvision
import warnings
from distutils.version import LooseVersion
from itertools import repeat
+
+import torch
+import torchvision
from torch import nn as nn
from torch.nn import functional as F
from torch.nn import init as init
@@ -178,13 +179,14 @@ def resize_flow(flow, size_type, sizes, interp_mode='bilinear', align_corners=Fa
input_flow[:, 0, :, :] *= ratio_w
input_flow[:, 1, :, :] *= ratio_h
resized_flow = F.interpolate(
- input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners)
+ input=input_flow, size=(output_h, output_w), mode=interp_mode, align_corners=align_corners
+ )
return resized_flow
# TODO: may write a cpp file
def pixel_unshuffle(x, scale):
- """ Pixel unshuffle.
+ """Pixel unshuffle.
Args:
x (Tensor): Input feature with shape (b, c, hh, hw).
@@ -224,11 +226,22 @@ def forward(self, x, feat):
logger.warning(f'Offset abs mean is {offset_absmean}, larger than 50.')
if LooseVersion(torchvision.__version__) >= LooseVersion('0.9.0'):
- return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
- self.dilation, mask)
+ return torchvision.ops.deform_conv2d(
+ x, offset, self.weight, self.bias, self.stride, self.padding, self.dilation, mask
+ )
else:
- return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding,
- self.dilation, self.groups, self.deformable_groups)
+ return modulated_deform_conv(
+ x,
+ offset,
+ mask,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
def _no_grad_trunc_normal_(tensor, mean, std, a, b):
@@ -237,13 +250,14 @@ def _no_grad_trunc_normal_(tensor, mean, std, a, b):
# Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf
def norm_cdf(x):
# Computes standard normal cumulative distribution function
- return (1. + math.erf(x / math.sqrt(2.))) / 2.
+ return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0
if (mean < a - 2 * std) or (mean > b + 2 * std):
warnings.warn(
'mean is more than 2 std from [a, b] in nn.init.trunc_normal_. '
'The distribution of values may be incorrect.',
- stacklevel=2)
+ stacklevel=2,
+ )
with torch.no_grad():
# Values are generated by using a truncated uniform distribution and
@@ -261,7 +275,7 @@ def norm_cdf(x):
tensor.erfinv_()
# Transform to proper mean, std
- tensor.mul_(std * math.sqrt(2.))
+ tensor.mul_(std * math.sqrt(2.0))
tensor.add_(mean)
# Clamp to ensure it's in the proper range
@@ -269,7 +283,7 @@ def norm_cdf(x):
return tensor
-def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.):
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
r"""Fills the input Tensor with values drawn from a truncated
normal distribution.
diff --git a/basicsr/archs/basicvsr_arch.py b/basicsr/archs/basicvsr_arch.py
index ed7b824ea..a67a49199 100644
--- a/basicsr/archs/basicvsr_arch.py
+++ b/basicsr/archs/basicvsr_arch.py
@@ -3,6 +3,7 @@
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import ResidualBlockNoBN, flow_warp, make_layer
from .edvr_arch import PCDAlignment, TSAFusion
from .spynet_arch import SpyNet
@@ -110,8 +111,10 @@ class ConvResidualBlocks(nn.Module):
def __init__(self, num_in_ch=3, num_out_ch=64, num_block=15):
super().__init__()
self.main = nn.Sequential(
- nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True), nn.LeakyReLU(negative_slope=0.1, inplace=True),
- make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch))
+ nn.Conv2d(num_in_ch, num_out_ch, 3, 1, 1, bias=True),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ make_layer(ResidualBlockNoBN, num_block, num_feat=num_out_ch),
+ )
def forward(self, fea):
return self.main(fea)
@@ -130,13 +133,9 @@ class IconVSR(nn.Module):
edvr_path (str): Path to the pretrained EDVR model. Default: None.
"""
- def __init__(self,
- num_feat=64,
- num_block=15,
- keyframe_stride=5,
- temporal_padding=2,
- spynet_path=None,
- edvr_path=None):
+ def __init__(
+ self, num_feat=64, num_block=15, keyframe_stride=5, temporal_padding=2, spynet_path=None, edvr_path=None
+ ):
super().__init__()
self.num_feat = num_feat
@@ -210,7 +209,7 @@ def get_keyframe_feature(self, x, keyframe_idx):
num_frames = 2 * self.temporal_padding + 1
feats_keyframe = {}
for i in keyframe_idx:
- feats_keyframe[i] = self.edvr(x[:, i:i + num_frames].contiguous())
+ feats_keyframe[i] = self.edvr(x[:, i : i + num_frames].contiguous())
return feats_keyframe
def forward(self, x):
@@ -265,7 +264,7 @@ def forward(self, x):
out += base
out_l[i] = out
- return torch.stack(out_l, dim=1)[..., :4 * h_input, :4 * w_input]
+ return torch.stack(out_l, dim=1)[..., : 4 * h_input, : 4 * w_input]
class EDVRFeatureExtractor(nn.Module):
@@ -321,13 +320,16 @@ def forward(self, x):
# PCD alignment
ref_feat_l = [ # reference feature list
- feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
- feat_l3[:, self.center_frame_idx, :, :, :].clone()
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l3[:, self.center_frame_idx, :, :, :].clone(),
]
aligned_feat = []
for i in range(n):
nbr_feat_l = [ # neighboring feature list
- feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+ feat_l1[:, i, :, :, :].clone(),
+ feat_l2[:, i, :, :, :].clone(),
+ feat_l3[:, i, :, :, :].clone(),
]
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
diff --git a/basicsr/archs/basicvsrpp_arch.py b/basicsr/archs/basicvsrpp_arch.py
index 2a9952e4b..340c738da 100644
--- a/basicsr/archs/basicvsrpp_arch.py
+++ b/basicsr/archs/basicvsrpp_arch.py
@@ -1,8 +1,9 @@
+import warnings
+
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
-import warnings
from basicsr.archs.arch_util import flow_warp
from basicsr.archs.basicvsr_arch import ConvResidualBlocks
@@ -40,13 +41,15 @@ class BasicVSRPlusPlus(nn.Module):
Default: 100.
"""
- def __init__(self,
- mid_channels=64,
- num_blocks=7,
- max_residue_magnitude=10,
- is_low_res_input=True,
- spynet_path=None,
- cpu_cache_length=100):
+ def __init__(
+ self,
+ mid_channels=64,
+ num_blocks=7,
+ max_residue_magnitude=10,
+ is_low_res_input=True,
+ spynet_path=None,
+ cpu_cache_length=100,
+ ):
super().__init__()
self.mid_channels = mid_channels
@@ -61,9 +64,12 @@ def __init__(self,
self.feat_extract = ConvResidualBlocks(3, mid_channels, 5)
else:
self.feat_extract = nn.Sequential(
- nn.Conv2d(3, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
- nn.Conv2d(mid_channels, mid_channels, 3, 2, 1), nn.LeakyReLU(negative_slope=0.1, inplace=True),
- ConvResidualBlocks(mid_channels, mid_channels, 5))
+ nn.Conv2d(3, mid_channels, 3, 2, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ nn.Conv2d(mid_channels, mid_channels, 3, 2, 1),
+ nn.LeakyReLU(negative_slope=0.1, inplace=True),
+ ConvResidualBlocks(mid_channels, mid_channels, 5),
+ )
# propagation branches
self.deform_align = nn.ModuleDict()
@@ -77,7 +83,8 @@ def __init__(self,
3,
padding=1,
deformable_groups=16,
- max_residue_magnitude=max_residue_magnitude)
+ max_residue_magnitude=max_residue_magnitude,
+ )
self.backbone[module] = ConvResidualBlocks((2 + i) * mid_channels, mid_channels, num_blocks)
# upsampling module
@@ -102,9 +109,11 @@ def __init__(self,
self.is_with_alignment = True
else:
self.is_with_alignment = False
- warnings.warn('Deformable alignment module is not added. '
- 'Probably your CUDA is not configured correctly. DCN can only '
- 'be used with CUDA enabled. Alignment is skipped now.')
+ warnings.warn(
+ 'Deformable alignment module is not added. '
+ 'Probably your CUDA is not configured correctly. DCN can only '
+ 'be used with CUDA enabled. Alignment is skipped now.'
+ )
def check_if_mirror_extended(self, lqs):
"""Check whether the input is a mirror-extended sequence.
@@ -296,8 +305,9 @@ def forward(self, lqs):
if self.is_low_res_input:
lqs_downsample = lqs.clone()
else:
- lqs_downsample = F.interpolate(
- lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(n, t, c, h // 4, w // 4)
+ lqs_downsample = F.interpolate(lqs.view(-1, c, h, w), scale_factor=0.25, mode='bicubic').view(
+ n, t, c, h // 4, w // 4
+ )
# check whether the input is an extended sequence
self.check_if_mirror_extended(lqs)
@@ -318,8 +328,8 @@ def forward(self, lqs):
# compute optical flow using the low-res inputs
assert lqs_downsample.size(3) >= 64 and lqs_downsample.size(4) >= 64, (
- 'The height and width of low-res inputs must be at least 64, '
- f'but got {h} and {w}.')
+ f'The height and width of low-res inputs must be at least 64, but got {h} and {w}.'
+ )
flows_forward, flows_backward = self.compute_flow(lqs_downsample)
# feature propgation
@@ -404,8 +414,9 @@ def forward(self, x, extra_feat, flow_1, flow_2):
# mask
mask = torch.sigmoid(mask)
- return torchvision.ops.deform_conv2d(x, offset, self.weight, self.bias, self.stride, self.padding,
- self.dilation, mask)
+ return torchvision.ops.deform_conv2d(
+ x, offset, self.weight, self.bias, self.stride, self.padding, self.dilation, mask
+ )
# if __name__ == '__main__':
diff --git a/basicsr/archs/dfdnet_arch.py b/basicsr/archs/dfdnet_arch.py
index 4751434c2..b2040b321 100644
--- a/basicsr/archs/dfdnet_arch.py
+++ b/basicsr/archs/dfdnet_arch.py
@@ -5,6 +5,7 @@
from torch.nn.utils.spectral_norm import spectral_norm
from basicsr.utils.registry import ARCH_REGISTRY
+
from .dfdnet_util import AttentionBlock, Blur, MSDilationBlock, UpResBlock, adaptive_instance_normalization
from .vgg_arch import VGGFeatureExtractor
@@ -35,11 +36,16 @@ def __init__(self, in_channel, out_channel, kernel_size=3, padding=1):
# for SFT scale and shift
self.scale_block = nn.Sequential(
- spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
- spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)))
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
+ nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
+ )
self.shift_block = nn.Sequential(
- spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
- spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)), nn.Sigmoid())
+ spectral_norm(nn.Conv2d(in_channel, out_channel, 3, 1, 1)),
+ nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(out_channel, out_channel, 3, 1, 1)),
+ nn.Sigmoid(),
+ )
# The official codes use sigmoid for shift block, do not know why
def forward(self, x, updated_feat):
@@ -78,11 +84,8 @@ def __init__(self, num_feat, dict_path):
# vgg face extractor
self.vgg_extractor = VGGFeatureExtractor(
- layer_name_list=self.vgg_layers,
- vgg_type='vgg19',
- use_input_norm=True,
- range_norm=True,
- requires_grad=False)
+ layer_name_list=self.vgg_layers, vgg_type='vgg19', use_input_norm=True, range_norm=True, requires_grad=False
+ )
# attention block for fusing dictionary features and input features
self.attn_blocks = nn.ModuleDict()
@@ -99,13 +102,18 @@ def __init__(self, num_feat, dict_path):
self.upsample2 = SFTUpBlock(num_feat * 4, num_feat * 2)
self.upsample3 = SFTUpBlock(num_feat * 2, num_feat)
self.upsample4 = nn.Sequential(
- spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)), nn.LeakyReLU(0.2, True), UpResBlock(num_feat),
- UpResBlock(num_feat), nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1), nn.Tanh())
+ spectral_norm(nn.Conv2d(num_feat, num_feat, 3, 1, 1)),
+ nn.LeakyReLU(0.2, True),
+ UpResBlock(num_feat),
+ UpResBlock(num_feat),
+ nn.Conv2d(num_feat, 3, kernel_size=3, stride=1, padding=1),
+ nn.Tanh(),
+ )
def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_size):
"""swap the features from the dictionary."""
# get the original vgg features
- part_feat = vgg_feat[:, :, location[1]:location[3], location[0]:location[2]].clone()
+ part_feat = vgg_feat[:, :, location[1] : location[3], location[0] : location[2]].clone()
# resize original vgg features
part_resize_feat = F.interpolate(part_feat, dict_feat.size()[2:4], mode='bilinear', align_corners=False)
# use adaptive instance normalization to adjust color and illuminations
@@ -115,12 +123,12 @@ def swap_feat(self, vgg_feat, updated_feat, dict_feat, location, part_name, f_si
similarity_score = F.softmax(similarity_score.view(-1), dim=0)
# select the most similar features in the dict (after norm)
select_idx = torch.argmax(similarity_score)
- swap_feat = F.interpolate(dict_feat[select_idx:select_idx + 1], part_feat.size()[2:4])
+ swap_feat = F.interpolate(dict_feat[select_idx : select_idx + 1], part_feat.size()[2:4])
# attention
attn = self.attn_blocks[f'{part_name}_' + str(f_size)](swap_feat - part_feat)
attn_feat = attn * swap_feat
# update features
- updated_feat[:, :, location[1]:location[3], location[0]:location[2]] = attn_feat + part_feat
+ updated_feat[:, :, location[1] : location[3], location[0] : location[2]] = attn_feat + part_feat
return updated_feat
def put_dict_to_device(self, x):
@@ -152,8 +160,9 @@ def forward(self, x, part_locations):
# swap features from dictionary
for part_idx, part_name in enumerate(self.parts):
location = (part_locations[part_idx][batch] // (512 / f_size)).int()
- updated_feat = self.swap_feat(vgg_feat, updated_feat, dict_features[part_name], location, part_name,
- f_size)
+ updated_feat = self.swap_feat(
+ vgg_feat, updated_feat, dict_features[part_name], location, part_name, f_size
+ )
updated_vgg_features.append(updated_feat)
diff --git a/basicsr/archs/dfdnet_util.py b/basicsr/archs/dfdnet_util.py
index b4dc0ff73..1edbc8b8f 100644
--- a/basicsr/archs/dfdnet_util.py
+++ b/basicsr/archs/dfdnet_util.py
@@ -6,7 +6,6 @@
class BlurFunctionBackward(Function):
-
@staticmethod
def forward(ctx, grad_output, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
@@ -21,7 +20,6 @@ def backward(ctx, gradgrad_output):
class BlurFunction(Function):
-
@staticmethod
def forward(ctx, x, kernel, kernel_flip):
ctx.save_for_backward(kernel, kernel_flip)
@@ -39,7 +37,6 @@ def backward(ctx, grad_output):
class Blur(nn.Module):
-
def __init__(self, channel):
super().__init__()
kernel = torch.tensor([[1, 2, 1], [2, 4, 2], [1, 2, 1]], dtype=torch.float32)
@@ -90,8 +87,10 @@ def adaptive_instance_normalization(content_feat, style_feat):
def AttentionBlock(in_channel):
return nn.Sequential(
- spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)), nn.LeakyReLU(0.2, True),
- spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)))
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+ nn.LeakyReLU(0.2, True),
+ spectral_norm(nn.Conv2d(in_channel, in_channel, 3, 1, 1)),
+ )
def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, bias=True):
@@ -106,7 +105,9 @@ def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, b
stride=stride,
dilation=dilation,
padding=((kernel_size - 1) // 2) * dilation,
- bias=bias)),
+ bias=bias,
+ )
+ ),
nn.LeakyReLU(0.2),
spectral_norm(
nn.Conv2d(
@@ -116,7 +117,9 @@ def conv_block(in_channels, out_channels, kernel_size=3, stride=1, dilation=1, b
stride=stride,
dilation=dilation,
padding=((kernel_size - 1) // 2) * dilation,
- bias=bias)),
+ bias=bias,
+ )
+ ),
)
@@ -136,7 +139,9 @@ def __init__(self, in_channels, kernel_size=3, dilation=(1, 1, 1, 1), bias=True)
kernel_size=kernel_size,
stride=1,
padding=(kernel_size - 1) // 2,
- bias=bias))
+ bias=bias,
+ )
+ )
def forward(self, x):
out = []
@@ -148,7 +153,6 @@ def forward(self, x):
class UpResBlock(nn.Module):
-
def __init__(self, in_channel):
super(UpResBlock, self).__init__()
self.body = nn.Sequential(
diff --git a/basicsr/archs/discriminator_arch.py b/basicsr/archs/discriminator_arch.py
index 33f9a8f1b..f4217e59d 100644
--- a/basicsr/archs/discriminator_arch.py
+++ b/basicsr/archs/discriminator_arch.py
@@ -20,7 +20,8 @@ def __init__(self, num_in_ch, num_feat, input_size=128):
super(VGGStyleDiscriminator, self).__init__()
self.input_size = input_size
assert self.input_size == 128 or self.input_size == 256, (
- f'input size must be 128 or 256, but received {input_size}')
+ f'input size must be 128 or 256, but received {input_size}'
+ )
self.conv0_0 = nn.Conv2d(num_in_ch, num_feat, 3, 1, 1, bias=True)
self.conv0_1 = nn.Conv2d(num_feat, num_feat, 4, 2, 1, bias=False)
@@ -59,7 +60,7 @@ def __init__(self, num_in_ch, num_feat, input_size=128):
self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)
def forward(self, x):
- assert x.size(2) == self.input_size, (f'Input size must be identical to input_size, but received {x.size()}.')
+ assert x.size(2) == self.input_size, f'Input size must be identical to input_size, but received {x.size()}.'
feat = self.lrelu(self.conv0_0(x))
feat = self.lrelu(self.bn0_1(self.conv0_1(feat))) # output spatial size: /2
diff --git a/basicsr/archs/duf_arch.py b/basicsr/archs/duf_arch.py
index e2b3ab7df..31abf1ebf 100644
--- a/basicsr/archs/duf_arch.py
+++ b/basicsr/archs/duf_arch.py
@@ -28,32 +28,47 @@ def __init__(self, num_feat=64, num_grow_ch=32, adapt_official_weights=False):
momentum = 0.1
self.temporal_reduce1 = nn.Sequential(
- nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
nn.Conv3d(num_feat, num_feat, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True),
- nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
- nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+ nn.BatchNorm3d(num_feat, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True),
+ )
self.temporal_reduce2 = nn.Sequential(
- nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + num_grow_ch,
- num_feat + num_grow_ch, (1, 1, 1),
+ num_feat + num_grow_ch,
+ (1, 1, 1),
stride=(1, 1, 1),
padding=(0, 0, 0),
- bias=True), nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
- nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+ bias=True,
+ ),
+ nn.BatchNorm3d(num_feat + num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
+ nn.Conv3d(num_feat + num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True),
+ )
self.temporal_reduce3 = nn.Sequential(
- nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + 2 * num_grow_ch,
- num_feat + 2 * num_grow_ch, (1, 1, 1),
+ num_feat + 2 * num_grow_ch,
+ (1, 1, 1),
stride=(1, 1, 1),
padding=(0, 0, 0),
- bias=True), nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
+ bias=True,
+ ),
+ nn.BatchNorm3d(num_feat + 2 * num_grow_ch, eps=eps, momentum=momentum),
nn.ReLU(inplace=True),
nn.Conv3d(
- num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True))
+ num_feat + 2 * num_grow_ch, num_grow_ch, (3, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True
+ ),
+ )
def forward(self, x):
"""
@@ -76,7 +91,7 @@ def forward(self, x):
class DenseBlocks(nn.Module):
- """ A concatenation of N dense blocks.
+ """A concatenation of N dense blocks.
Args:
num_feat (int): Number of channels in the blocks. Default: 64.
@@ -102,20 +117,28 @@ def __init__(self, num_block, num_feat=64, num_grow_ch=16, adapt_official_weight
for i in range(0, num_block):
self.dense_blocks.append(
nn.Sequential(
- nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum), nn.ReLU(inplace=True),
+ nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
+ nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + i * num_grow_ch,
- num_feat + i * num_grow_ch, (1, 1, 1),
+ num_feat + i * num_grow_ch,
+ (1, 1, 1),
stride=(1, 1, 1),
padding=(0, 0, 0),
- bias=True), nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
+ bias=True,
+ ),
+ nn.BatchNorm3d(num_feat + i * num_grow_ch, eps=eps, momentum=momentum),
nn.ReLU(inplace=True),
nn.Conv3d(
num_feat + i * num_grow_ch,
- num_grow_ch, (3, 3, 3),
+ num_grow_ch,
+ (3, 3, 3),
stride=(1, 1, 1),
padding=(1, 1, 1),
- bias=True)))
+ bias=True,
+ ),
+ )
+ )
def forward(self, x):
"""
@@ -170,9 +193,11 @@ def forward(self, x, filters):
n, filter_prod, upsampling_square, h, w = filters.size()
kh, kw = self.filter_size
expanded_input = F.conv2d(
- x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3) # (n, 3*filter_prod, h, w)
- expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(0, 3, 4, 1,
- 2) # (n, h, w, 3, filter_prod)
+ x, self.expansion_filter.to(x), padding=(kh // 2, kw // 2), groups=3
+ ) # (n, 3*filter_prod, h, w)
+ expanded_input = expanded_input.view(n, 3, filter_prod, h, w).permute(
+ 0, 3, 4, 1, 2
+ ) # (n, h, w, 3, filter_prod)
filters = filters.permute(0, 3, 4, 1, 2) # (n, h, w, filter_prod, upsampling_square]
out = torch.matmul(expanded_input, filters) # (n, h, w, 3, upsampling_square)
return out.permute(0, 3, 4, 1, 2).view(n, 3 * upsampling_square, h, w)
@@ -227,10 +252,11 @@ def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
raise ValueError(f'Only supported (16, 28, 52) layers, but got {num_layer}.')
self.dense_block1 = DenseBlocks(
- num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch,
- adapt_official_weights=adapt_official_weights) # T = 7
+ num_block=num_block, num_feat=64, num_grow_ch=num_grow_ch, adapt_official_weights=adapt_official_weights
+ ) # T = 7
self.dense_block2 = DenseBlocksTemporalReduce(
- 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights) # T = 1
+ 64 + num_grow_ch * num_block, num_grow_ch, adapt_official_weights=adapt_official_weights
+ ) # T = 1
channels = 64 + num_grow_ch * num_block + num_grow_ch * 3
self.bn3d2 = nn.BatchNorm3d(channels, eps=eps, momentum=momentum)
self.conv3d2 = nn.Conv3d(channels, 256, (1, 3, 3), stride=(1, 1, 1), padding=(0, 1, 1), bias=True)
@@ -240,7 +266,8 @@ def __init__(self, scale=4, num_layer=52, adapt_official_weights=False):
self.conv3d_f1 = nn.Conv3d(256, 512, (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
self.conv3d_f2 = nn.Conv3d(
- 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True)
+ 512, 1 * 5 * 5 * (scale**2), (1, 1, 1), stride=(1, 1, 1), padding=(0, 0, 0), bias=True
+ )
def forward(self, x):
"""
diff --git a/basicsr/archs/ecbsr_arch.py b/basicsr/archs/ecbsr_arch.py
index fe20e7725..1c0001dd3 100644
--- a/basicsr/archs/ecbsr_arch.py
+++ b/basicsr/archs/ecbsr_arch.py
@@ -44,7 +44,7 @@ def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(scale)
bias = torch.randn(self.out_channels) * 1e-3
- bias = torch.reshape(bias, (self.out_channels, ))
+ bias = torch.reshape(bias, (self.out_channels,))
self.bias = nn.Parameter(bias)
# init mask
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
@@ -66,7 +66,7 @@ def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(torch.FloatTensor(scale))
bias = torch.randn(self.out_channels) * 1e-3
- bias = torch.reshape(bias, (self.out_channels, ))
+ bias = torch.reshape(bias, (self.out_channels,))
self.bias = nn.Parameter(torch.FloatTensor(bias))
# init mask
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
@@ -88,7 +88,7 @@ def __init__(self, seq_type, in_channels, out_channels, depth_multiplier=1):
scale = torch.randn(size=(self.out_channels, 1, 1, 1)) * 1e-3
self.scale = nn.Parameter(torch.FloatTensor(scale))
bias = torch.randn(self.out_channels) * 1e-3
- bias = torch.reshape(bias, (self.out_channels, ))
+ bias = torch.reshape(bias, (self.out_channels,))
self.bias = nn.Parameter(torch.FloatTensor(bias))
# init mask
self.mask = torch.zeros((self.out_channels, 1, 3, 3), dtype=torch.float32)
@@ -138,7 +138,12 @@ def rep_params(self):
rep_weight = F.conv2d(input=self.k1, weight=self.k0.permute(1, 0, 2, 3))
# re-param conv bias
rep_bias = torch.ones(1, self.mid_planes, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
- rep_bias = F.conv2d(input=rep_bias, weight=self.k1).view(-1, ) + self.b1
+ rep_bias = (
+ F.conv2d(input=rep_bias, weight=self.k1).view(
+ -1,
+ )
+ + self.b1
+ )
else:
tmp = self.scale * self.mask
k1 = torch.zeros((self.out_channels, self.out_channels, 3, 3), device=device)
@@ -149,7 +154,12 @@ def rep_params(self):
rep_weight = F.conv2d(input=k1, weight=self.k0.permute(1, 0, 2, 3))
# re-param conv bias
rep_bias = torch.ones(1, self.out_channels, 3, 3, device=device) * self.b0.view(1, -1, 1, 1)
- rep_bias = F.conv2d(input=rep_bias, weight=k1).view(-1, ) + b1
+ rep_bias = (
+ F.conv2d(input=rep_bias, weight=k1).view(
+ -1,
+ )
+ + b1
+ )
return rep_weight, rep_bias
@@ -217,8 +227,10 @@ def rep_params(self):
weight2, bias2 = self.conv1x1_sbx.rep_params()
weight3, bias3 = self.conv1x1_sby.rep_params()
weight4, bias4 = self.conv1x1_lpl.rep_params()
- rep_weight, rep_bias = (weight0 + weight1 + weight2 + weight3 + weight4), (
- bias0 + bias1 + bias2 + bias3 + bias4)
+ rep_weight, rep_bias = (
+ (weight0 + weight1 + weight2 + weight3 + weight4),
+ (bias0 + bias1 + bias2 + bias3 + bias4),
+ )
if self.with_idt:
device = rep_weight.get_device()
diff --git a/basicsr/archs/edsr_arch.py b/basicsr/archs/edsr_arch.py
index b80566f11..29b7fb357 100644
--- a/basicsr/archs/edsr_arch.py
+++ b/basicsr/archs/edsr_arch.py
@@ -27,15 +27,17 @@ class EDSR(nn.Module):
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
"""
- def __init__(self,
- num_in_ch,
- num_out_ch,
- num_feat=64,
- num_block=16,
- upscale=4,
- res_scale=1,
- img_range=255.,
- rgb_mean=(0.4488, 0.4371, 0.4040)):
+ def __init__(
+ self,
+ num_in_ch,
+ num_out_ch,
+ num_feat=64,
+ num_block=16,
+ upscale=4,
+ res_scale=1,
+ img_range=255.0,
+ rgb_mean=(0.4488, 0.4371, 0.4040),
+ ):
super(EDSR, self).__init__()
self.img_range = img_range
diff --git a/basicsr/archs/edvr_arch.py b/basicsr/archs/edvr_arch.py
index b0c4f47de..1ddf2247a 100644
--- a/basicsr/archs/edvr_arch.py
+++ b/basicsr/archs/edvr_arch.py
@@ -3,6 +3,7 @@
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import DCNv2Pack, ResidualBlockNoBN, make_layer
@@ -268,18 +269,20 @@ class EDVR(nn.Module):
with_tsa (bool): Whether has TSA module. Default: True.
"""
- def __init__(self,
- num_in_ch=3,
- num_out_ch=3,
- num_feat=64,
- num_frame=5,
- deformable_groups=8,
- num_extract_block=5,
- num_reconstruct_block=10,
- center_frame_idx=None,
- hr_in=False,
- with_predeblur=False,
- with_tsa=True):
+ def __init__(
+ self,
+ num_in_ch=3,
+ num_out_ch=3,
+ num_feat=64,
+ num_frame=5,
+ deformable_groups=8,
+ num_extract_block=5,
+ num_reconstruct_block=10,
+ center_frame_idx=None,
+ hr_in=False,
+ with_predeblur=False,
+ with_tsa=True,
+ ):
super(EDVR, self).__init__()
if center_frame_idx is None:
self.center_frame_idx = num_frame // 2
@@ -325,9 +328,9 @@ def __init__(self,
def forward(self, x):
b, t, c, h, w = x.size()
if self.hr_in:
- assert h % 16 == 0 and w % 16 == 0, ('The height and width must be multiple of 16.')
+ assert h % 16 == 0 and w % 16 == 0, 'The height and width must be multiple of 16.'
else:
- assert h % 4 == 0 and w % 4 == 0, ('The height and width must be multiple of 4.')
+ assert h % 4 == 0 and w % 4 == 0, 'The height and width must be multiple of 4.'
x_center = x[:, self.center_frame_idx, :, :, :].contiguous()
@@ -354,13 +357,16 @@ def forward(self, x):
# PCD alignment
ref_feat_l = [ # reference feature list
- feat_l1[:, self.center_frame_idx, :, :, :].clone(), feat_l2[:, self.center_frame_idx, :, :, :].clone(),
- feat_l3[:, self.center_frame_idx, :, :, :].clone()
+ feat_l1[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l2[:, self.center_frame_idx, :, :, :].clone(),
+ feat_l3[:, self.center_frame_idx, :, :, :].clone(),
]
aligned_feat = []
for i in range(t):
nbr_feat_l = [ # neighboring feature list
- feat_l1[:, i, :, :, :].clone(), feat_l2[:, i, :, :, :].clone(), feat_l3[:, i, :, :, :].clone()
+ feat_l1[:, i, :, :, :].clone(),
+ feat_l2[:, i, :, :, :].clone(),
+ feat_l3[:, i, :, :, :].clone(),
]
aligned_feat.append(self.pcd_align(nbr_feat_l, ref_feat_l))
aligned_feat = torch.stack(aligned_feat, dim=1) # (b, t, c, h, w)
diff --git a/basicsr/archs/hifacegan_arch.py b/basicsr/archs/hifacegan_arch.py
index 098e3ed43..bfb9101b0 100644
--- a/basicsr/archs/hifacegan_arch.py
+++ b/basicsr/archs/hifacegan_arch.py
@@ -4,21 +4,24 @@
import torch.nn.functional as F
from basicsr.utils.registry import ARCH_REGISTRY
+
from .hifacegan_util import BaseNetwork, LIPEncoder, SPADEResnetBlock, get_nonspade_norm_layer
class SPADEGenerator(BaseNetwork):
"""Generator with SPADEResBlock"""
- def __init__(self,
- num_in_ch=3,
- num_feat=64,
- use_vae=False,
- z_dim=256,
- crop_size=512,
- norm_g='spectralspadesyncbatch3x3',
- is_train=True,
- init_train_phase=3): # progressive training disabled
+ def __init__(
+ self,
+ num_in_ch=3,
+ num_feat=64,
+ use_vae=False,
+ z_dim=256,
+ crop_size=512,
+ norm_g='spectralspadesyncbatch3x3',
+ is_train=True,
+ init_train_phase=3,
+ ): # progressive training disabled
super().__init__()
self.nf = num_feat
self.input_nc = num_in_ch
@@ -42,19 +45,23 @@ def __init__(self,
self.g_middle_0 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
self.g_middle_1 = SPADEResnetBlock(16 * self.nf, 16 * self.nf, norm_g)
- self.ups = nn.ModuleList([
- SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
- SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
- SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
- SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g)
- ])
-
- self.to_rgbs = nn.ModuleList([
- nn.Conv2d(8 * self.nf, 3, 3, padding=1),
- nn.Conv2d(4 * self.nf, 3, 3, padding=1),
- nn.Conv2d(2 * self.nf, 3, 3, padding=1),
- nn.Conv2d(1 * self.nf, 3, 3, padding=1)
- ])
+ self.ups = nn.ModuleList(
+ [
+ SPADEResnetBlock(16 * self.nf, 8 * self.nf, norm_g),
+ SPADEResnetBlock(8 * self.nf, 4 * self.nf, norm_g),
+ SPADEResnetBlock(4 * self.nf, 2 * self.nf, norm_g),
+ SPADEResnetBlock(2 * self.nf, 1 * self.nf, norm_g),
+ ]
+ )
+
+ self.to_rgbs = nn.ModuleList(
+ [
+ nn.Conv2d(8 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(4 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(2 * self.nf, 3, 3, padding=1),
+ nn.Conv2d(1 * self.nf, 3, 3, padding=1),
+ ]
+ )
self.up = nn.Upsample(scale_factor=2)
@@ -148,15 +155,17 @@ class HiFaceGAN(SPADEGenerator):
Current encoder design: LIPEncoder
"""
- def __init__(self,
- num_in_ch=3,
- num_feat=64,
- use_vae=False,
- z_dim=256,
- crop_size=512,
- norm_g='spectralspadesyncbatch3x3',
- is_train=True,
- init_train_phase=3):
+ def __init__(
+ self,
+ num_in_ch=3,
+ num_feat=64,
+ use_vae=False,
+ z_dim=256,
+ crop_size=512,
+ norm_g='spectralspadesyncbatch3x3',
+ is_train=True,
+ init_train_phase=3,
+ ):
super().__init__(num_in_ch, num_feat, use_vae, z_dim, crop_size, norm_g, is_train, init_train_phase)
self.lip_encoder = LIPEncoder(num_in_ch, num_feat, self.sw, self.sh, self.scale_ratio)
@@ -185,15 +194,17 @@ class HiFaceGANDiscriminator(BaseNetwork):
Default: True.
"""
- def __init__(self,
- num_in_ch=3,
- num_out_ch=3,
- conditional_d=True,
- num_d=2,
- n_layers_d=4,
- num_feat=64,
- norm_d='spectralinstance',
- keep_features=True):
+ def __init__(
+ self,
+ num_in_ch=3,
+ num_out_ch=3,
+ conditional_d=True,
+ num_d=2,
+ n_layers_d=4,
+ num_feat=64,
+ norm_d='spectralinstance',
+ keep_features=True,
+ ):
super().__init__()
self.num_d = num_d
@@ -237,10 +248,12 @@ def __init__(self, input_nc, n_layers_d, num_feat, norm_d, keep_features):
nf_prev = nf
nf = min(nf * 2, 512)
stride = 1 if n == n_layers_d - 1 else 2
- sequence += [[
- norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
- nn.LeakyReLU(0.2, False)
- ]]
+ sequence += [
+ [
+ norm_layer(nn.Conv2d(nf_prev, nf, kernel_size=kw, stride=stride, padding=padw)),
+ nn.LeakyReLU(0.2, False),
+ ]
+ ]
sequence += [[nn.Conv2d(nf, 1, kernel_size=kw, stride=1, padding=padw)]]
diff --git a/basicsr/archs/hifacegan_util.py b/basicsr/archs/hifacegan_util.py
index 35cbef3f5..a578d1b53 100644
--- a/basicsr/archs/hifacegan_util.py
+++ b/basicsr/archs/hifacegan_util.py
@@ -1,8 +1,10 @@
import re
+
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init
+
# Warning: spectral norm could be buggy
# under eval mode and multi-GPU inference
# A workaround is sticking to single-GPU inference and train mode
@@ -10,7 +12,6 @@
class SPADE(nn.Module):
-
def __init__(self, config_text, norm_nc, label_nc):
super().__init__()
@@ -67,7 +68,7 @@ class SPADEResnetBlock(nn.Module):
def __init__(self, fin, fout, norm_g='spectralspadesyncbatch3x3', semantic_nc=3):
super().__init__()
# Attributes
- self.learned_shortcut = (fin != fout)
+ self.learned_shortcut = fin != fout
fmiddle = min(fin, fout)
# create conv layers
@@ -111,7 +112,7 @@ def act(self, x):
class BaseNetwork(nn.Module):
- """ A basis for hifacegan archs with custom initialization """
+ """A basis for hifacegan archs with custom initialization"""
def init_weights(self, init_type='normal', gain=0.02):
@@ -164,12 +165,13 @@ def forward(self, x):
class SimplifiedLIP(nn.Module):
-
def __init__(self, channels):
super(SimplifiedLIP, self).__init__()
self.logit = nn.Sequential(
- nn.Conv2d(channels, channels, 3, padding=1, bias=False), nn.InstanceNorm2d(channels, affine=True),
- SoftGate())
+ nn.Conv2d(channels, channels, 3, padding=1, bias=False),
+ nn.InstanceNorm2d(channels, affine=True),
+ SoftGate(),
+ )
def init_layer(self):
self.logit[0].weight.data.fill_(0.0)
@@ -226,7 +228,7 @@ def add_norm_layer(layer):
nonlocal norm_type
if norm_type.startswith('spectral'):
layer = spectral_norm(layer)
- subnorm_type = norm_type[len('spectral'):]
+ subnorm_type = norm_type[len('spectral') :]
if subnorm_type == 'none' or len(subnorm_type) == 0:
return layer
diff --git a/basicsr/archs/inception.py b/basicsr/archs/inception.py
index de1abef67..07c7f222d 100644
--- a/basicsr/archs/inception.py
+++ b/basicsr/archs/inception.py
@@ -2,6 +2,7 @@
# For FID metric
import os
+
import torch
import torch.nn as nn
import torch.nn.functional as F
@@ -10,7 +11,9 @@
# Inception weights ported to Pytorch from
# http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz
-FID_WEIGHTS_URL = 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
+FID_WEIGHTS_URL = (
+ 'https://github.com/mseitzer/pytorch-fid/releases/download/fid_weights/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
+)
LOCAL_FID_WEIGHTS = 'experiments/pretrained_models/pt_inception-2015-12-05-6726825d.pth' # noqa: E501
@@ -26,15 +29,17 @@ class InceptionV3(nn.Module):
64: 0, # First max pooling features
192: 1, # Second max pooling features
768: 2, # Pre-aux classifier features
- 2048: 3 # Final average pooling features
+ 2048: 3, # Final average pooling features
}
- def __init__(self,
- output_blocks=(DEFAULT_BLOCK_INDEX),
- resize_input=True,
- normalize_input=True,
- requires_grad=False,
- use_fid_inception=True):
+ def __init__(
+ self,
+ output_blocks=(DEFAULT_BLOCK_INDEX),
+ resize_input=True,
+ normalize_input=True,
+ requires_grad=False,
+ use_fid_inception=True,
+ ):
"""Build pretrained InceptionV3.
Args:
@@ -71,7 +76,7 @@ def __init__(self,
self.output_blocks = sorted(output_blocks)
self.last_needed_block = max(output_blocks)
- assert self.last_needed_block <= 3, ('Last possible output block index is 3')
+ assert self.last_needed_block <= 3, 'Last possible output block index is 3'
self.blocks = nn.ModuleList()
@@ -86,8 +91,10 @@ def __init__(self,
# Block 0: input to maxpool1
block0 = [
- inception.Conv2d_1a_3x3, inception.Conv2d_2a_3x3, inception.Conv2d_2b_3x3,
- nn.MaxPool2d(kernel_size=3, stride=2)
+ inception.Conv2d_1a_3x3,
+ inception.Conv2d_2a_3x3,
+ inception.Conv2d_2b_3x3,
+ nn.MaxPool2d(kernel_size=3, stride=2),
]
self.blocks.append(nn.Sequential(*block0))
@@ -113,8 +120,10 @@ def __init__(self,
# Block 3: aux classifier to final avgpool
if self.last_needed_block >= 3:
block3 = [
- inception.Mixed_7a, inception.Mixed_7b, inception.Mixed_7c,
- nn.AdaptiveAvgPool2d(output_size=(1, 1))
+ inception.Mixed_7a,
+ inception.Mixed_7b,
+ inception.Mixed_7c,
+ nn.AdaptiveAvgPool2d(output_size=(1, 1)),
]
self.blocks.append(nn.Sequential(*block3))
diff --git a/basicsr/archs/rcan_arch.py b/basicsr/archs/rcan_arch.py
index 48872e680..7d1daf2f9 100644
--- a/basicsr/archs/rcan_arch.py
+++ b/basicsr/archs/rcan_arch.py
@@ -2,6 +2,7 @@
from torch import nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import Upsample, make_layer
@@ -16,8 +17,12 @@ class ChannelAttention(nn.Module):
def __init__(self, num_feat, squeeze_factor=16):
super(ChannelAttention, self).__init__()
self.attention = nn.Sequential(
- nn.AdaptiveAvgPool2d(1), nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
- nn.ReLU(inplace=True), nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0), nn.Sigmoid())
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(num_feat, num_feat // squeeze_factor, 1, padding=0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(num_feat // squeeze_factor, num_feat, 1, padding=0),
+ nn.Sigmoid(),
+ )
def forward(self, x):
y = self.attention(x)
@@ -38,8 +43,11 @@ def __init__(self, num_feat, squeeze_factor=16, res_scale=1):
self.res_scale = res_scale
self.rcab = nn.Sequential(
- nn.Conv2d(num_feat, num_feat, 3, 1, 1), nn.ReLU(True), nn.Conv2d(num_feat, num_feat, 3, 1, 1),
- ChannelAttention(num_feat, squeeze_factor))
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1),
+ nn.ReLU(True),
+ nn.Conv2d(num_feat, num_feat, 3, 1, 1),
+ ChannelAttention(num_feat, squeeze_factor),
+ )
def forward(self, x):
res = self.rcab(x) * self.res_scale
@@ -60,7 +68,8 @@ def __init__(self, num_feat, num_block, squeeze_factor=16, res_scale=1):
super(ResidualGroup, self).__init__()
self.residual_group = make_layer(
- RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale)
+ RCAB, num_block, num_feat=num_feat, squeeze_factor=squeeze_factor, res_scale=res_scale
+ )
self.conv = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
def forward(self, x):
@@ -93,17 +102,19 @@ class RCAN(nn.Module):
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
"""
- def __init__(self,
- num_in_ch,
- num_out_ch,
- num_feat=64,
- num_group=10,
- num_block=16,
- squeeze_factor=16,
- upscale=4,
- res_scale=1,
- img_range=255.,
- rgb_mean=(0.4488, 0.4371, 0.4040)):
+ def __init__(
+ self,
+ num_in_ch,
+ num_out_ch,
+ num_feat=64,
+ num_group=10,
+ num_block=16,
+ squeeze_factor=16,
+ upscale=4,
+ res_scale=1,
+ img_range=255.0,
+ rgb_mean=(0.4488, 0.4371, 0.4040),
+ ):
super(RCAN, self).__init__()
self.img_range = img_range
@@ -116,7 +127,8 @@ def __init__(self,
num_feat=num_feat,
num_block=num_block,
squeeze_factor=squeeze_factor,
- res_scale=res_scale)
+ res_scale=res_scale,
+ )
self.conv_after_body = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
diff --git a/basicsr/archs/ridnet_arch.py b/basicsr/archs/ridnet_arch.py
index 85bb9ae03..c831107e6 100644
--- a/basicsr/archs/ridnet_arch.py
+++ b/basicsr/archs/ridnet_arch.py
@@ -2,11 +2,12 @@
import torch.nn as nn
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import ResidualBlockNoBN, make_layer
class MeanShift(nn.Conv2d):
- """ Data normalization with mean and std.
+ """Data normalization with mean and std.
Args:
rgb_range (int): Maximum value of RGB.
@@ -53,7 +54,7 @@ def forward(self, x):
class MergeRun(nn.Module):
- """ Merge-and-run unit.
+ """Merge-and-run unit.
This unit contains two branches with different dilated convolutions,
followed by a convolution to process the concatenated features.
@@ -66,14 +67,21 @@ def __init__(self, in_channels, out_channels, kernel_size=3, stride=1, padding=1
super(MergeRun, self).__init__()
self.dilation1 = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True),
- nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2), nn.ReLU(inplace=True))
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 2, 2),
+ nn.ReLU(inplace=True),
+ )
self.dilation2 = nn.Sequential(
- nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3), nn.ReLU(inplace=True),
- nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4), nn.ReLU(inplace=True))
+ nn.Conv2d(in_channels, out_channels, kernel_size, stride, 3, 3),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(out_channels, out_channels, kernel_size, stride, 4, 4),
+ nn.ReLU(inplace=True),
+ )
self.aggregation = nn.Sequential(
- nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True))
+ nn.Conv2d(out_channels * 2, out_channels, kernel_size, stride, padding), nn.ReLU(inplace=True)
+ )
def forward(self, x):
dilation1 = self.dilation1(x)
@@ -95,8 +103,12 @@ class ChannelAttention(nn.Module):
def __init__(self, mid_channels, squeeze_factor=16):
super(ChannelAttention, self).__init__()
self.attention = nn.Sequential(
- nn.AdaptiveAvgPool2d(1), nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
- nn.ReLU(inplace=True), nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0), nn.Sigmoid())
+ nn.AdaptiveAvgPool2d(1),
+ nn.Conv2d(mid_channels, mid_channels // squeeze_factor, 1, padding=0),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(mid_channels // squeeze_factor, mid_channels, 1, padding=0),
+ nn.Sigmoid(),
+ )
def forward(self, x):
y = self.attention(x)
@@ -151,14 +163,16 @@ class RIDNet(nn.Module):
Default: (0.4488, 0.4371, 0.4040), calculated from DIV2K dataset.
"""
- def __init__(self,
- in_channels,
- mid_channels,
- out_channels,
- num_block=4,
- img_range=255.,
- rgb_mean=(0.4488, 0.4371, 0.4040),
- rgb_std=(1.0, 1.0, 1.0)):
+ def __init__(
+ self,
+ in_channels,
+ mid_channels,
+ out_channels,
+ num_block=4,
+ img_range=255.0,
+ rgb_mean=(0.4488, 0.4371, 0.4040),
+ rgb_std=(1.0, 1.0, 1.0),
+ ):
super(RIDNet, self).__init__()
self.sub_mean = MeanShift(img_range, rgb_mean, rgb_std)
@@ -166,7 +180,8 @@ def __init__(self,
self.head = nn.Conv2d(in_channels, mid_channels, 3, 1, 1)
self.body = make_layer(
- EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels)
+ EAM, num_block, in_channels=mid_channels, mid_channels=mid_channels, out_channels=mid_channels
+ )
self.tail = nn.Conv2d(mid_channels, out_channels, 3, 1, 1)
self.relu = nn.ReLU(inplace=True)
diff --git a/basicsr/archs/rrdbnet_arch.py b/basicsr/archs/rrdbnet_arch.py
index 63d07080c..be64f3c67 100644
--- a/basicsr/archs/rrdbnet_arch.py
+++ b/basicsr/archs/rrdbnet_arch.py
@@ -3,6 +3,7 @@
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import default_init_weights, make_layer, pixel_unshuffle
diff --git a/basicsr/archs/spynet_arch.py b/basicsr/archs/spynet_arch.py
index 4c7af133d..8874c876e 100644
--- a/basicsr/archs/spynet_arch.py
+++ b/basicsr/archs/spynet_arch.py
@@ -1,25 +1,31 @@
import math
+
import torch
from torch import nn as nn
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import flow_warp
class BasicModule(nn.Module):
- """Basic Module for SpyNet.
- """
+ """Basic Module for SpyNet."""
def __init__(self):
super(BasicModule, self).__init__()
self.basic_module = nn.Sequential(
- nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
- nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
- nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
- nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3), nn.ReLU(inplace=False),
- nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+ nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3),
+ nn.ReLU(inplace=False),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3),
+ )
def forward(self, tensor_input):
return self.basic_module(tensor_input)
@@ -57,9 +63,8 @@ def process(self, ref, supp):
supp.insert(0, F.avg_pool2d(input=supp[0], kernel_size=2, stride=2, count_include_pad=False))
flow = ref[0].new_zeros(
- [ref[0].size(0), 2,
- int(math.floor(ref[0].size(2) / 2.0)),
- int(math.floor(ref[0].size(3) / 2.0))])
+ [ref[0].size(0), 2, int(math.floor(ref[0].size(2) / 2.0)), int(math.floor(ref[0].size(3) / 2.0))]
+ )
for level in range(len(ref)):
upsampled_flow = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
@@ -69,12 +74,24 @@ def process(self, ref, supp):
if upsampled_flow.size(3) != ref[level].size(3):
upsampled_flow = F.pad(input=upsampled_flow, pad=[0, 1, 0, 0], mode='replicate')
- flow = self.basic_module[level](torch.cat([
- ref[level],
- flow_warp(
- supp[level], upsampled_flow.permute(0, 2, 3, 1), interp_mode='bilinear', padding_mode='border'),
- upsampled_flow
- ], 1)) + upsampled_flow
+ flow = (
+ self.basic_module[level](
+ torch.cat(
+ [
+ ref[level],
+ flow_warp(
+ supp[level],
+ upsampled_flow.permute(0, 2, 3, 1),
+ interp_mode='bilinear',
+ padding_mode='border',
+ ),
+ upsampled_flow,
+ ],
+ 1,
+ )
+ )
+ + upsampled_flow
+ )
return flow
diff --git a/basicsr/archs/srresnet_arch.py b/basicsr/archs/srresnet_arch.py
index 7f571557c..fcaa4dc13 100644
--- a/basicsr/archs/srresnet_arch.py
+++ b/basicsr/archs/srresnet_arch.py
@@ -2,6 +2,7 @@
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import ResidualBlockNoBN, default_init_weights, make_layer
diff --git a/basicsr/archs/stylegan2_arch.py b/basicsr/archs/stylegan2_arch.py
index 9ab37f5a3..603e1abbc 100644
--- a/basicsr/archs/stylegan2_arch.py
+++ b/basicsr/archs/stylegan2_arch.py
@@ -1,5 +1,6 @@
import math
import random
+
import torch
from torch import nn
from torch.nn import functional as F
@@ -10,7 +11,6 @@
class NormStyleCode(nn.Module):
-
def forward(self, x):
"""Normalize the style codes.
@@ -66,7 +66,7 @@ def forward(self, x):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(factor={self.factor})')
+ return f'{self.__class__.__name__}(factor={self.factor})'
class UpFirDnDownsample(nn.Module):
@@ -91,7 +91,7 @@ def forward(self, x):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(factor={self.factor})')
+ return f'{self.__class__.__name__}(factor={self.factor})'
class UpFirDnSmooth(nn.Module):
@@ -127,8 +127,10 @@ def forward(self, x):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
- f', downsample_factor={self.downsample_factor})')
+ return (
+ f'{self.__class__.__name__}(upsample_factor={self.upsample_factor}'
+ f', downsample_factor={self.downsample_factor})'
+ )
class EqualLinear(nn.Module):
@@ -152,8 +154,9 @@ def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul
self.lr_mul = lr_mul
self.activation = activation
if self.activation not in ['fused_lrelu', None]:
- raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
- "Supported ones are: ['fused_lrelu', None].")
+ raise ValueError(
+ f"Wrong activation value in EqualLinear: {activation}Supported ones are: ['fused_lrelu', None]."
+ )
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
@@ -175,8 +178,10 @@ def forward(self, x):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
- f'out_channels={self.out_channels}, bias={self.bias is not None})')
+ return (
+ f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, bias={self.bias is not None})'
+ )
class ModulatedConv2d(nn.Module):
@@ -199,15 +204,17 @@ class ModulatedConv2d(nn.Module):
Default: 1e-8.
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- resample_kernel=(1, 3, 3, 1),
- eps=1e-8):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=(1, 3, 3, 1),
+ eps=1e-8,
+ ):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -218,20 +225,24 @@ def __init__(self,
if self.sample_mode == 'upsample':
self.smooth = UpFirDnSmooth(
- resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size)
+ resample_kernel, upsample_factor=2, downsample_factor=1, kernel_size=kernel_size
+ )
elif self.sample_mode == 'downsample':
self.smooth = UpFirDnSmooth(
- resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
+ resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size
+ )
elif self.sample_mode is None:
pass
else:
- raise ValueError(f'Wrong sample mode {self.sample_mode}, '
- "supported ones are ['upsample', 'downsample', None].")
+ raise ValueError(
+ f"Wrong sample mode {self.sample_mode}, supported ones are ['upsample', 'downsample', None]."
+ )
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
# modulation inside each modulated conv
self.modulation = EqualLinear(
- num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+ num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None
+ )
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
self.padding = kernel_size // 2
@@ -279,10 +290,12 @@ def forward(self, x, style):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
- f'out_channels={self.out_channels}, '
- f'kernel_size={self.kernel_size}, '
- f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+ return (
+ f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size}, '
+ f'demodulate={self.demodulate}, sample_mode={self.sample_mode})'
+ )
class StyleConv(nn.Module):
@@ -300,14 +313,16 @@ class StyleConv(nn.Module):
magnitude. Default: (1, 3, 3, 1).
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- resample_kernel=(1, 3, 3, 1)):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ resample_kernel=(1, 3, 3, 1),
+ ):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(
in_channels,
@@ -316,7 +331,8 @@ def __init__(self,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode,
- resample_kernel=resample_kernel)
+ resample_kernel=resample_kernel,
+ )
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
self.activate = FusedLeakyReLU(out_channels)
@@ -351,7 +367,8 @@ def __init__(self, in_channels, num_style_feat, upsample=True, resample_kernel=(
else:
self.upsample = None
self.modulated_conv = ModulatedConv2d(
- in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None)
+ in_channels, 3, kernel_size=1, num_style_feat=num_style_feat, demodulate=False, sample_mode=None
+ )
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, x, style, skip=None):
@@ -408,14 +425,16 @@ class StyleGAN2Generator(nn.Module):
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
- def __init__(self,
- out_size,
- num_style_feat=512,
- num_mlp=8,
- channel_multiplier=2,
- resample_kernel=(1, 3, 3, 1),
- lr_mlp=0.01,
- narrow=1):
+ def __init__(
+ self,
+ out_size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=2,
+ resample_kernel=(1, 3, 3, 1),
+ lr_mlp=0.01,
+ narrow=1,
+ ):
super(StyleGAN2Generator, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
@@ -423,8 +442,9 @@ def __init__(self,
for i in range(num_mlp):
style_mlp_layers.append(
EqualLinear(
- num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
- activation='fused_lrelu'))
+ num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
self.style_mlp = nn.Sequential(*style_mlp_layers)
channels = {
@@ -436,7 +456,7 @@ def __init__(self,
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
- '1024': int(16 * channel_multiplier * narrow)
+ '1024': int(16 * channel_multiplier * narrow),
}
self.channels = channels
@@ -448,7 +468,8 @@ def __init__(self,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
- resample_kernel=resample_kernel)
+ resample_kernel=resample_kernel,
+ )
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, resample_kernel=resample_kernel)
self.log_size = int(math.log(out_size, 2))
@@ -462,7 +483,7 @@ def __init__(self,
in_channels = channels['4']
# noise
for layer_idx in range(self.num_layers):
- resolution = 2**((layer_idx + 5) // 2)
+ resolution = 2 ** ((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
# style convs and to_rgbs
@@ -477,7 +498,8 @@ def __init__(self,
demodulate=True,
sample_mode='upsample',
resample_kernel=resample_kernel,
- ))
+ )
+ )
self.style_convs.append(
StyleConv(
out_channels,
@@ -486,7 +508,9 @@ def __init__(self,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
- resample_kernel=resample_kernel))
+ resample_kernel=resample_kernel,
+ )
+ )
self.to_rgbs.append(ToRGB(out_channels, num_style_feat, upsample=True, resample_kernel=resample_kernel))
in_channels = out_channels
@@ -509,15 +533,17 @@ def mean_latent(self, num_latent):
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
- def forward(self,
- styles,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False):
+ def forward(
+ self,
+ styles,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False,
+ ):
"""Forward function for StyleGAN2Generator.
Args:
@@ -571,8 +597,9 @@ def forward(self,
skip = self.to_rgb1(out, latent[:, 1])
i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
- noise[2::2], self.to_rgbs):
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.style_convs[::2], self.style_convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
@@ -644,11 +671,13 @@ def forward(self, x):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
- f'out_channels={self.out_channels}, '
- f'kernel_size={self.kernel_size},'
- f' stride={self.stride}, padding={self.padding}, '
- f'bias={self.bias is not None})')
+ return (
+ f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size},'
+ f' stride={self.stride}, padding={self.padding}, '
+ f'bias={self.bias is not None})'
+ )
class ConvLayer(nn.Sequential):
@@ -668,19 +697,22 @@ class ConvLayer(nn.Sequential):
activate (bool): Whether use activateion. Default: True.
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- downsample=False,
- resample_kernel=(1, 3, 3, 1),
- bias=True,
- activate=True):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ downsample=False,
+ resample_kernel=(1, 3, 3, 1),
+ bias=True,
+ activate=True,
+ ):
layers = []
# downsample
if downsample:
layers.append(
- UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size))
+ UpFirDnSmooth(resample_kernel, upsample_factor=1, downsample_factor=2, kernel_size=kernel_size)
+ )
stride = 2
self.padding = 0
else:
@@ -689,8 +721,9 @@ def __init__(self,
# conv
layers.append(
EqualConv2d(
- in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
- and not activate))
+ in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias and not activate
+ )
+ )
# activation
if activate:
if bias:
@@ -718,9 +751,11 @@ def __init__(self, in_channels, out_channels, resample_kernel=(1, 3, 3, 1)):
self.conv1 = ConvLayer(in_channels, in_channels, 3, bias=True, activate=True)
self.conv2 = ConvLayer(
- in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True)
+ in_channels, out_channels, 3, downsample=True, resample_kernel=resample_kernel, bias=True, activate=True
+ )
self.skip = ConvLayer(
- in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False)
+ in_channels, out_channels, 1, downsample=True, resample_kernel=resample_kernel, bias=False, activate=False
+ )
def forward(self, x):
out = self.conv1(x)
@@ -757,7 +792,7 @@ def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1),
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
- '1024': int(16 * channel_multiplier * narrow)
+ '1024': int(16 * channel_multiplier * narrow),
}
log_size = int(math.log(out_size, 2))
@@ -766,7 +801,7 @@ def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1),
in_channels = channels[f'{out_size}']
for i in range(log_size, 2, -1):
- out_channels = channels[f'{2**(i - 1)}']
+ out_channels = channels[f'{2 ** (i - 1)}']
conv_body.append(ResBlock(in_channels, out_channels, resample_kernel))
in_channels = out_channels
self.conv_body = nn.Sequential(*conv_body)
@@ -774,7 +809,8 @@ def __init__(self, out_size, channel_multiplier=2, resample_kernel=(1, 3, 3, 1),
self.final_conv = ConvLayer(in_channels + 1, channels['4'], 3, bias=True, activate=True)
self.final_linear = nn.Sequential(
EqualLinear(
- channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'),
+ channels['4'] * 4 * 4, channels['4'], bias=True, bias_init_val=0, lr_mul=1, activation='fused_lrelu'
+ ),
EqualLinear(channels['4'], 1, bias=True, bias_init_val=0, lr_mul=1, activation=None),
)
self.stddev_group = stddev_group
diff --git a/basicsr/archs/stylegan2_bilinear_arch.py b/basicsr/archs/stylegan2_bilinear_arch.py
index 239517041..e7c65375e 100644
--- a/basicsr/archs/stylegan2_bilinear_arch.py
+++ b/basicsr/archs/stylegan2_bilinear_arch.py
@@ -1,5 +1,6 @@
import math
import random
+
import torch
from torch import nn
from torch.nn import functional as F
@@ -9,7 +10,6 @@
class NormStyleCode(nn.Module):
-
def forward(self, x):
"""Normalize the style codes.
@@ -43,8 +43,9 @@ def __init__(self, in_channels, out_channels, bias=True, bias_init_val=0, lr_mul
self.lr_mul = lr_mul
self.activation = activation
if self.activation not in ['fused_lrelu', None]:
- raise ValueError(f'Wrong activation value in EqualLinear: {activation}'
- "Supported ones are: ['fused_lrelu', None].")
+ raise ValueError(
+ f"Wrong activation value in EqualLinear: {activation}Supported ones are: ['fused_lrelu', None]."
+ )
self.scale = (1 / math.sqrt(in_channels)) * lr_mul
self.weight = nn.Parameter(torch.randn(out_channels, in_channels).div_(lr_mul))
@@ -66,8 +67,10 @@ def forward(self, x):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
- f'out_channels={self.out_channels}, bias={self.bias is not None})')
+ return (
+ f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, bias={self.bias is not None})'
+ )
class ModulatedConv2d(nn.Module):
@@ -88,15 +91,17 @@ class ModulatedConv2d(nn.Module):
Default: 1e-8.
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- eps=1e-8,
- interpolation_mode='bilinear'):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ eps=1e-8,
+ interpolation_mode='bilinear',
+ ):
super(ModulatedConv2d, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -113,7 +118,8 @@ def __init__(self,
self.scale = 1 / math.sqrt(in_channels * kernel_size**2)
# modulation inside each modulated conv
self.modulation = EqualLinear(
- num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None)
+ num_style_feat, in_channels, bias=True, bias_init_val=1, lr_mul=1, activation=None
+ )
self.weight = nn.Parameter(torch.randn(1, out_channels, in_channels, kernel_size, kernel_size))
self.padding = kernel_size // 2
@@ -154,10 +160,12 @@ def forward(self, x, style):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
- f'out_channels={self.out_channels}, '
- f'kernel_size={self.kernel_size}, '
- f'demodulate={self.demodulate}, sample_mode={self.sample_mode})')
+ return (
+ f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size}, '
+ f'demodulate={self.demodulate}, sample_mode={self.sample_mode})'
+ )
class StyleConv(nn.Module):
@@ -173,14 +181,16 @@ class StyleConv(nn.Module):
Default: None.
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- num_style_feat,
- demodulate=True,
- sample_mode=None,
- interpolation_mode='bilinear'):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ num_style_feat,
+ demodulate=True,
+ sample_mode=None,
+ interpolation_mode='bilinear',
+ ):
super(StyleConv, self).__init__()
self.modulated_conv = ModulatedConv2d(
in_channels,
@@ -189,7 +199,8 @@ def __init__(self,
num_style_feat,
demodulate=demodulate,
sample_mode=sample_mode,
- interpolation_mode=interpolation_mode)
+ interpolation_mode=interpolation_mode,
+ )
self.weight = nn.Parameter(torch.zeros(1)) # for noise injection
self.activate = FusedLeakyReLU(out_channels)
@@ -230,7 +241,8 @@ def __init__(self, in_channels, num_style_feat, upsample=True, interpolation_mod
num_style_feat=num_style_feat,
demodulate=False,
sample_mode=None,
- interpolation_mode=interpolation_mode)
+ interpolation_mode=interpolation_mode,
+ )
self.bias = nn.Parameter(torch.zeros(1, 3, 1, 1))
def forward(self, x, style, skip=None):
@@ -249,7 +261,8 @@ def forward(self, x, style, skip=None):
if skip is not None:
if self.upsample:
skip = F.interpolate(
- skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners)
+ skip, scale_factor=2, mode=self.interpolation_mode, align_corners=self.align_corners
+ )
out = out + skip
return out
@@ -285,14 +298,16 @@ class StyleGAN2GeneratorBilinear(nn.Module):
narrow (float): Narrow ratio for channels. Default: 1.0.
"""
- def __init__(self,
- out_size,
- num_style_feat=512,
- num_mlp=8,
- channel_multiplier=2,
- lr_mlp=0.01,
- narrow=1,
- interpolation_mode='bilinear'):
+ def __init__(
+ self,
+ out_size,
+ num_style_feat=512,
+ num_mlp=8,
+ channel_multiplier=2,
+ lr_mlp=0.01,
+ narrow=1,
+ interpolation_mode='bilinear',
+ ):
super(StyleGAN2GeneratorBilinear, self).__init__()
# Style MLP layers
self.num_style_feat = num_style_feat
@@ -300,8 +315,9 @@ def __init__(self,
for i in range(num_mlp):
style_mlp_layers.append(
EqualLinear(
- num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp,
- activation='fused_lrelu'))
+ num_style_feat, num_style_feat, bias=True, bias_init_val=0, lr_mul=lr_mlp, activation='fused_lrelu'
+ )
+ )
self.style_mlp = nn.Sequential(*style_mlp_layers)
channels = {
@@ -313,7 +329,7 @@ def __init__(self,
'128': int(128 * channel_multiplier * narrow),
'256': int(64 * channel_multiplier * narrow),
'512': int(32 * channel_multiplier * narrow),
- '1024': int(16 * channel_multiplier * narrow)
+ '1024': int(16 * channel_multiplier * narrow),
}
self.channels = channels
@@ -325,7 +341,8 @@ def __init__(self,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
- interpolation_mode=interpolation_mode)
+ interpolation_mode=interpolation_mode,
+ )
self.to_rgb1 = ToRGB(channels['4'], num_style_feat, upsample=False, interpolation_mode=interpolation_mode)
self.log_size = int(math.log(out_size, 2))
@@ -339,7 +356,7 @@ def __init__(self,
in_channels = channels['4']
# noise
for layer_idx in range(self.num_layers):
- resolution = 2**((layer_idx + 5) // 2)
+ resolution = 2 ** ((layer_idx + 5) // 2)
shape = [1, 1, resolution, resolution]
self.noises.register_buffer(f'noise{layer_idx}', torch.randn(*shape))
# style convs and to_rgbs
@@ -353,7 +370,9 @@ def __init__(self,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode='upsample',
- interpolation_mode=interpolation_mode))
+ interpolation_mode=interpolation_mode,
+ )
+ )
self.style_convs.append(
StyleConv(
out_channels,
@@ -362,9 +381,12 @@ def __init__(self,
num_style_feat=num_style_feat,
demodulate=True,
sample_mode=None,
- interpolation_mode=interpolation_mode))
+ interpolation_mode=interpolation_mode,
+ )
+ )
self.to_rgbs.append(
- ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode))
+ ToRGB(out_channels, num_style_feat, upsample=True, interpolation_mode=interpolation_mode)
+ )
in_channels = out_channels
def make_noise(self):
@@ -386,15 +408,17 @@ def mean_latent(self, num_latent):
latent = self.style_mlp(latent_in).mean(0, keepdim=True)
return latent
- def forward(self,
- styles,
- input_is_latent=False,
- noise=None,
- randomize_noise=True,
- truncation=1,
- truncation_latent=None,
- inject_index=None,
- return_latents=False):
+ def forward(
+ self,
+ styles,
+ input_is_latent=False,
+ noise=None,
+ randomize_noise=True,
+ truncation=1,
+ truncation_latent=None,
+ inject_index=None,
+ return_latents=False,
+ ):
"""Forward function for StyleGAN2Generator.
Args:
@@ -448,8 +472,9 @@ def forward(self,
skip = self.to_rgb1(out, latent[:, 1])
i = 1
- for conv1, conv2, noise1, noise2, to_rgb in zip(self.style_convs[::2], self.style_convs[1::2], noise[1::2],
- noise[2::2], self.to_rgbs):
+ for conv1, conv2, noise1, noise2, to_rgb in zip(
+ self.style_convs[::2], self.style_convs[1::2], noise[1::2], noise[2::2], self.to_rgbs
+ ):
out = conv1(out, latent[:, i], noise=noise1)
out = conv2(out, latent[:, i + 1], noise=noise2)
skip = to_rgb(out, latent[:, i + 2], skip)
@@ -521,11 +546,13 @@ def forward(self, x):
return out
def __repr__(self):
- return (f'{self.__class__.__name__}(in_channels={self.in_channels}, '
- f'out_channels={self.out_channels}, '
- f'kernel_size={self.kernel_size},'
- f' stride={self.stride}, padding={self.padding}, '
- f'bias={self.bias is not None})')
+ return (
+ f'{self.__class__.__name__}(in_channels={self.in_channels}, '
+ f'out_channels={self.out_channels}, '
+ f'kernel_size={self.kernel_size},'
+ f' stride={self.stride}, padding={self.padding}, '
+ f'bias={self.bias is not None})'
+ )
class ConvLayer(nn.Sequential):
@@ -541,14 +568,16 @@ class ConvLayer(nn.Sequential):
activate (bool): Whether use activateion. Default: True.
"""
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- downsample=False,
- bias=True,
- activate=True,
- interpolation_mode='bilinear'):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ downsample=False,
+ bias=True,
+ activate=True,
+ interpolation_mode='bilinear',
+ ):
layers = []
self.interpolation_mode = interpolation_mode
# downsample
@@ -559,14 +588,16 @@ def __init__(self,
self.align_corners = False
layers.append(
- torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners))
+ torch.nn.Upsample(scale_factor=0.5, mode=interpolation_mode, align_corners=self.align_corners)
+ )
stride = 1
self.padding = kernel_size // 2
# conv
layers.append(
EqualConv2d(
- in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias
- and not activate))
+ in_channels, out_channels, kernel_size, stride=stride, padding=self.padding, bias=bias and not activate
+ )
+ )
# activation
if activate:
if bias:
@@ -596,7 +627,8 @@ def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
downsample=True,
interpolation_mode=interpolation_mode,
bias=True,
- activate=True)
+ activate=True,
+ )
self.skip = ConvLayer(
in_channels,
out_channels,
@@ -604,7 +636,8 @@ def __init__(self, in_channels, out_channels, interpolation_mode='bilinear'):
downsample=True,
interpolation_mode=interpolation_mode,
bias=False,
- activate=False)
+ activate=False,
+ )
def forward(self, x):
out = self.conv1(x)
diff --git a/basicsr/archs/swinir_arch.py b/basicsr/archs/swinir_arch.py
index 3917fa2c7..dfbae9fa6 100644
--- a/basicsr/archs/swinir_arch.py
+++ b/basicsr/archs/swinir_arch.py
@@ -3,23 +3,25 @@
# Originally Written by Ze Liu, Modified by Jingyun Liang.
import math
+
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import to_2tuple, trunc_normal_
-def drop_path(x, drop_prob: float = 0., training: bool = False):
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
From: https://github.com/rwightman/pytorch-image-models/blob/master/timm/models/layers/drop.py
"""
- if drop_prob == 0. or not training:
+ if drop_prob == 0.0 or not training:
return x
keep_prob = 1 - drop_prob
- shape = (x.shape[0], ) + (1, ) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
+ shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
random_tensor.floor_() # binarize
output = x.div(keep_prob) * random_tensor
@@ -41,8 +43,7 @@ def forward(self, x):
class Mlp(nn.Module):
-
- def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
+ def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.0):
super().__init__()
out_features = out_features or in_features
hidden_features = hidden_features or in_features
@@ -93,7 +94,7 @@ def window_reverse(windows, window_size, h, w):
class WindowAttention(nn.Module):
- r""" Window based multi-head self attention (W-MSA) module with relative position bias.
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
It supports both of shifted and non-shifted window.
Args:
@@ -106,7 +107,7 @@ class WindowAttention(nn.Module):
proj_drop (float, optional): Dropout ratio of output. Default: 0.0
"""
- def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
+ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0.0, proj_drop=0.0):
super().__init__()
self.dim = dim
@@ -117,7 +118,8 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at
# define a parameter table of relative position bias
self.relative_position_bias_table = nn.Parameter(
- torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)) # 2*Wh-1 * 2*Ww-1, nH
+ torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads)
+ ) # 2*Wh-1 * 2*Ww-1, nH
# get pair-wise relative position index for each token inside the window
coords_h = torch.arange(self.window_size[0])
@@ -138,7 +140,7 @@ def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, at
self.proj_drop = nn.Dropout(proj_drop)
- trunc_normal_(self.relative_position_bias_table, std=.02)
+ trunc_normal_(self.relative_position_bias_table, std=0.02)
self.softmax = nn.Softmax(dim=-1)
def forward(self, x, mask=None):
@@ -152,10 +154,11 @@ def forward(self, x, mask=None):
q, k, v = qkv[0], qkv[1], qkv[2] # make torchscript happy (cannot use tensor as tuple)
q = q * self.scale
- attn = (q @ k.transpose(-2, -1))
+ attn = q @ k.transpose(-2, -1)
relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
- self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1) # Wh*Ww,Wh*Ww,nH
+ self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1
+ ) # Wh*Ww,Wh*Ww,nH
relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous() # nH, Wh*Ww, Wh*Ww
attn = attn + relative_position_bias.unsqueeze(0)
@@ -192,7 +195,7 @@ def flops(self, n):
class SwinTransformerBlock(nn.Module):
- r""" Swin Transformer Block.
+ r"""Swin Transformer Block.
Args:
dim (int): Number of input channels.
@@ -210,20 +213,22 @@ class SwinTransformerBlock(nn.Module):
norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
"""
- def __init__(self,
- dim,
- input_resolution,
- num_heads,
- window_size=7,
- shift_size=0,
- mlp_ratio=4.,
- qkv_bias=True,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- act_layer=nn.GELU,
- norm_layer=nn.LayerNorm):
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ ):
super().__init__()
self.dim = dim
self.input_resolution = input_resolution
@@ -245,9 +250,10 @@ def __init__(self,
qkv_bias=qkv_bias,
qk_scale=qk_scale,
attn_drop=attn_drop,
- proj_drop=drop)
+ proj_drop=drop,
+ )
- self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
self.norm2 = norm_layer(dim)
mlp_hidden_dim = int(dim * mlp_ratio)
self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
@@ -263,10 +269,16 @@ def calculate_mask(self, x_size):
# calculate attention mask for SW-MSA
h, w = x_size
img_mask = torch.zeros((1, h, w, 1)) # 1 h w 1
- h_slices = (slice(0, -self.window_size), slice(-self.window_size,
- -self.shift_size), slice(-self.shift_size, None))
- w_slices = (slice(0, -self.window_size), slice(-self.window_size,
- -self.shift_size), slice(-self.shift_size, None))
+ h_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
+ w_slices = (
+ slice(0, -self.window_size),
+ slice(-self.window_size, -self.shift_size),
+ slice(-self.shift_size, None),
+ )
cnt = 0
for h in h_slices:
for w in w_slices:
@@ -323,8 +335,10 @@ def forward(self, x, x_size):
return x
def extra_repr(self) -> str:
- return (f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
- f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}')
+ return (
+ f'dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, '
+ f'window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}'
+ )
def flops(self):
flops = 0
@@ -342,7 +356,7 @@ def flops(self):
class PatchMerging(nn.Module):
- r""" Patch Merging Layer.
+ r"""Patch Merging Layer.
Args:
input_resolution (tuple[int]): Resolution of input feature.
@@ -391,7 +405,7 @@ def flops(self):
class BasicLayer(nn.Module):
- """ A basic Swin Transformer layer for one stage.
+ """A basic Swin Transformer layer for one stage.
Args:
dim (int): Number of input channels.
@@ -410,21 +424,23 @@ class BasicLayer(nn.Module):
use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
"""
- def __init__(self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- mlp_ratio=4.,
- qkv_bias=True,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False):
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ ):
super().__init__()
self.dim = dim
@@ -433,21 +449,25 @@ def __init__(self,
self.use_checkpoint = use_checkpoint
# build blocks
- self.blocks = nn.ModuleList([
- SwinTransformerBlock(
- dim=dim,
- input_resolution=input_resolution,
- num_heads=num_heads,
- window_size=window_size,
- shift_size=0 if (i % 2 == 0) else window_size // 2,
- mlp_ratio=mlp_ratio,
- qkv_bias=qkv_bias,
- qk_scale=qk_scale,
- drop=drop,
- attn_drop=attn_drop,
- drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
- norm_layer=norm_layer) for i in range(depth)
- ])
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=window_size,
+ shift_size=0 if (i % 2 == 0) else window_size // 2,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ qk_scale=qk_scale,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
+ norm_layer=norm_layer,
+ )
+ for i in range(depth)
+ ]
+ )
# patch merging layer
if downsample is not None:
@@ -500,24 +520,26 @@ class RSTB(nn.Module):
resi_connection: The convolutional block before residual connection.
"""
- def __init__(self,
- dim,
- input_resolution,
- depth,
- num_heads,
- window_size,
- mlp_ratio=4.,
- qkv_bias=True,
- qk_scale=None,
- drop=0.,
- attn_drop=0.,
- drop_path=0.,
- norm_layer=nn.LayerNorm,
- downsample=None,
- use_checkpoint=False,
- img_size=224,
- patch_size=4,
- resi_connection='1conv'):
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ img_size=224,
+ patch_size=4,
+ resi_connection='1conv',
+ ):
super(RSTB, self).__init__()
self.dim = dim
@@ -537,22 +559,28 @@ def __init__(self,
drop_path=drop_path,
norm_layer=norm_layer,
downsample=downsample,
- use_checkpoint=use_checkpoint)
+ use_checkpoint=use_checkpoint,
+ )
if resi_connection == '1conv':
self.conv = nn.Conv2d(dim, dim, 3, 1, 1)
elif resi_connection == '3conv':
# to save parameters and memory
self.conv = nn.Sequential(
- nn.Conv2d(dim, dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(dim // 4, dim, 3, 1, 1))
+ nn.Conv2d(dim, dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(dim // 4, dim, 3, 1, 1),
+ )
self.patch_embed = PatchEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None
+ )
self.patch_unembed = PatchUnEmbed(
- img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None)
+ img_size=img_size, patch_size=patch_size, in_chans=0, embed_dim=dim, norm_layer=None
+ )
def forward(self, x, x_size):
return self.patch_embed(self.conv(self.patch_unembed(self.residual_group(x, x_size), x_size))) + x
@@ -569,7 +597,7 @@ def flops(self):
class PatchEmbed(nn.Module):
- r""" Image to Patch Embedding
+ r"""Image to Patch Embedding
Args:
img_size (int): Image size. Default: 224.
@@ -612,7 +640,7 @@ def flops(self):
class PatchUnEmbed(nn.Module):
- r""" Image to Patch Unembedding
+ r"""Image to Patch Unembedding
Args:
img_size (int): Image size. Default: 224.
@@ -692,7 +720,7 @@ def flops(self):
@ARCH_REGISTRY.register()
class SwinIR(nn.Module):
- r""" SwinIR
+ r"""SwinIR
A PyTorch impl of : `SwinIR: Image Restoration Using Swin Transformer`, based on Swin Transformer.
Args:
@@ -719,29 +747,31 @@ class SwinIR(nn.Module):
resi_connection: The convolutional block before residual connection. '1conv'/'3conv'
"""
- def __init__(self,
- img_size=64,
- patch_size=1,
- in_chans=3,
- embed_dim=96,
- depths=(6, 6, 6, 6),
- num_heads=(6, 6, 6, 6),
- window_size=7,
- mlp_ratio=4.,
- qkv_bias=True,
- qk_scale=None,
- drop_rate=0.,
- attn_drop_rate=0.,
- drop_path_rate=0.1,
- norm_layer=nn.LayerNorm,
- ape=False,
- patch_norm=True,
- use_checkpoint=False,
- upscale=2,
- img_range=1.,
- upsampler='',
- resi_connection='1conv',
- **kwargs):
+ def __init__(
+ self,
+ img_size=64,
+ patch_size=1,
+ in_chans=3,
+ embed_dim=96,
+ depths=(6, 6, 6, 6),
+ num_heads=(6, 6, 6, 6),
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ qk_scale=None,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ upscale=2,
+ img_range=1.0,
+ upsampler='',
+ resi_connection='1conv',
+ **kwargs,
+ ):
super(SwinIR, self).__init__()
num_in_ch = in_chans
num_out_ch = in_chans
@@ -772,7 +802,8 @@ def __init__(self,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
+ norm_layer=norm_layer if self.patch_norm else None,
+ )
num_patches = self.patch_embed.num_patches
patches_resolution = self.patch_embed.patches_resolution
self.patches_resolution = patches_resolution
@@ -783,12 +814,13 @@ def __init__(self,
patch_size=patch_size,
in_chans=embed_dim,
embed_dim=embed_dim,
- norm_layer=norm_layer if self.patch_norm else None)
+ norm_layer=norm_layer if self.patch_norm else None,
+ )
# absolute position embedding
if self.ape:
self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
- trunc_normal_(self.absolute_pos_embed, std=.02)
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
self.pos_drop = nn.Dropout(p=drop_rate)
@@ -809,13 +841,14 @@ def __init__(self,
qk_scale=qk_scale,
drop=drop_rate,
attn_drop=attn_drop_rate,
- drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])], # no impact on SR results
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], # no impact on SR results
norm_layer=norm_layer,
downsample=None,
use_checkpoint=use_checkpoint,
img_size=img_size,
patch_size=patch_size,
- resi_connection=resi_connection)
+ resi_connection=resi_connection,
+ )
self.layers.append(layer)
self.norm = norm_layer(self.num_features)
@@ -825,26 +858,32 @@ def __init__(self,
elif resi_connection == '3conv':
# to save parameters and memory
self.conv_after_body = nn.Sequential(
- nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0), nn.LeakyReLU(negative_slope=0.2, inplace=True),
- nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1))
+ nn.Conv2d(embed_dim, embed_dim // 4, 3, 1, 1),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim // 4, 1, 1, 0),
+ nn.LeakyReLU(negative_slope=0.2, inplace=True),
+ nn.Conv2d(embed_dim // 4, embed_dim, 3, 1, 1),
+ )
# ------------------------- 3, high quality image reconstruction ------------------------- #
if self.upsampler == 'pixelshuffle':
# for classical SR
self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
+ )
self.upsample = Upsample(upscale, num_feat)
self.conv_last = nn.Conv2d(num_feat, num_out_ch, 3, 1, 1)
elif self.upsampler == 'pixelshuffledirect':
# for lightweight SR (to save parameters)
- self.upsample = UpsampleOneStep(upscale, embed_dim, num_out_ch,
- (patches_resolution[0], patches_resolution[1]))
+ self.upsample = UpsampleOneStep(
+ upscale, embed_dim, num_out_ch, (patches_resolution[0], patches_resolution[1])
+ )
elif self.upsampler == 'nearest+conv':
# for real-world SR (less artifacts)
assert self.upscale == 4, 'only support x4 now.'
self.conv_before_upsample = nn.Sequential(
- nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True))
+ nn.Conv2d(embed_dim, num_feat, 3, 1, 1), nn.LeakyReLU(inplace=True)
+ )
self.conv_up1 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_up2 = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
self.conv_hr = nn.Conv2d(num_feat, num_feat, 3, 1, 1)
@@ -858,7 +897,7 @@ def __init__(self,
def _init_weights(self, m):
if isinstance(m, nn.Linear):
- trunc_normal_(m.weight, std=.02)
+ trunc_normal_(m.weight, std=0.02)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
@@ -942,12 +981,13 @@ def flops(self):
upscale=2,
img_size=(height, width),
window_size=window_size,
- img_range=1.,
+ img_range=1.0,
depths=[6, 6, 6, 6],
embed_dim=60,
num_heads=[6, 6, 6, 6],
mlp_ratio=2,
- upsampler='pixelshuffledirect')
+ upsampler='pixelshuffledirect',
+ )
print(model)
print(height, width, model.flops() / 1e9)
diff --git a/basicsr/archs/tof_arch.py b/basicsr/archs/tof_arch.py
index a90a64d89..b0a5ee5a0 100644
--- a/basicsr/archs/tof_arch.py
+++ b/basicsr/archs/tof_arch.py
@@ -3,6 +3,7 @@
from torch.nn import functional as F
from basicsr.utils.registry import ARCH_REGISTRY
+
from .arch_util import flow_warp
@@ -17,14 +18,19 @@ def __init__(self):
super(BasicModule, self).__init__()
self.basic_module = nn.Sequential(
nn.Conv2d(in_channels=8, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
- nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=64, kernel_size=7, stride=1, padding=3, bias=False),
- nn.BatchNorm2d(64), nn.ReLU(inplace=True),
+ nn.BatchNorm2d(64),
+ nn.ReLU(inplace=True),
nn.Conv2d(in_channels=64, out_channels=32, kernel_size=7, stride=1, padding=3, bias=False),
- nn.BatchNorm2d(32), nn.ReLU(inplace=True),
+ nn.BatchNorm2d(32),
+ nn.ReLU(inplace=True),
nn.Conv2d(in_channels=32, out_channels=16, kernel_size=7, stride=1, padding=3, bias=False),
- nn.BatchNorm2d(16), nn.ReLU(inplace=True),
- nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3))
+ nn.BatchNorm2d(16),
+ nn.ReLU(inplace=True),
+ nn.Conv2d(in_channels=16, out_channels=2, kernel_size=7, stride=1, padding=3),
+ )
def forward(self, tensor_input):
"""
@@ -86,7 +92,8 @@ def forward(self, ref, supp):
for i in range(4):
flow_up = F.interpolate(input=flow, scale_factor=2, mode='bilinear', align_corners=True) * 2.0
flow = flow_up + self.basic_module[i](
- torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1))
+ torch.cat([ref[i], flow_warp(supp[i], flow_up.permute(0, 2, 3, 1)), flow_up], 1)
+ )
return flow
diff --git a/basicsr/archs/vgg_arch.py b/basicsr/archs/vgg_arch.py
index 05200334e..328a8ebe4 100644
--- a/basicsr/archs/vgg_arch.py
+++ b/basicsr/archs/vgg_arch.py
@@ -1,6 +1,7 @@
import os
-import torch
from collections import OrderedDict
+
+import torch
from torch import nn as nn
from torchvision.models import vgg as vgg
@@ -9,27 +10,127 @@
VGG_PRETRAIN_PATH = 'experiments/pretrained_models/vgg19-dcbb9e9d.pth'
NAMES = {
'vgg11': [
- 'conv1_1', 'relu1_1', 'pool1', 'conv2_1', 'relu2_1', 'pool2', 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2',
- 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2',
- 'pool5'
+ 'conv1_1',
+ 'relu1_1',
+ 'pool1',
+ 'conv2_1',
+ 'relu2_1',
+ 'pool2',
+ 'conv3_1',
+ 'relu3_1',
+ 'conv3_2',
+ 'relu3_2',
+ 'pool3',
+ 'conv4_1',
+ 'relu4_1',
+ 'conv4_2',
+ 'relu4_2',
+ 'pool4',
+ 'conv5_1',
+ 'relu5_1',
+ 'conv5_2',
+ 'relu5_2',
+ 'pool5',
],
'vgg13': [
- 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
- 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2', 'relu4_2', 'pool4',
- 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'pool5'
+ 'conv1_1',
+ 'relu1_1',
+ 'conv1_2',
+ 'relu1_2',
+ 'pool1',
+ 'conv2_1',
+ 'relu2_1',
+ 'conv2_2',
+ 'relu2_2',
+ 'pool2',
+ 'conv3_1',
+ 'relu3_1',
+ 'conv3_2',
+ 'relu3_2',
+ 'pool3',
+ 'conv4_1',
+ 'relu4_1',
+ 'conv4_2',
+ 'relu4_2',
+ 'pool4',
+ 'conv5_1',
+ 'relu5_1',
+ 'conv5_2',
+ 'relu5_2',
+ 'pool5',
],
'vgg16': [
- 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
- 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'pool3', 'conv4_1', 'relu4_1', 'conv4_2',
- 'relu4_2', 'conv4_3', 'relu4_3', 'pool4', 'conv5_1', 'relu5_1', 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3',
- 'pool5'
+ 'conv1_1',
+ 'relu1_1',
+ 'conv1_2',
+ 'relu1_2',
+ 'pool1',
+ 'conv2_1',
+ 'relu2_1',
+ 'conv2_2',
+ 'relu2_2',
+ 'pool2',
+ 'conv3_1',
+ 'relu3_1',
+ 'conv3_2',
+ 'relu3_2',
+ 'conv3_3',
+ 'relu3_3',
+ 'pool3',
+ 'conv4_1',
+ 'relu4_1',
+ 'conv4_2',
+ 'relu4_2',
+ 'conv4_3',
+ 'relu4_3',
+ 'pool4',
+ 'conv5_1',
+ 'relu5_1',
+ 'conv5_2',
+ 'relu5_2',
+ 'conv5_3',
+ 'relu5_3',
+ 'pool5',
],
'vgg19': [
- 'conv1_1', 'relu1_1', 'conv1_2', 'relu1_2', 'pool1', 'conv2_1', 'relu2_1', 'conv2_2', 'relu2_2', 'pool2',
- 'conv3_1', 'relu3_1', 'conv3_2', 'relu3_2', 'conv3_3', 'relu3_3', 'conv3_4', 'relu3_4', 'pool3', 'conv4_1',
- 'relu4_1', 'conv4_2', 'relu4_2', 'conv4_3', 'relu4_3', 'conv4_4', 'relu4_4', 'pool4', 'conv5_1', 'relu5_1',
- 'conv5_2', 'relu5_2', 'conv5_3', 'relu5_3', 'conv5_4', 'relu5_4', 'pool5'
- ]
+ 'conv1_1',
+ 'relu1_1',
+ 'conv1_2',
+ 'relu1_2',
+ 'pool1',
+ 'conv2_1',
+ 'relu2_1',
+ 'conv2_2',
+ 'relu2_2',
+ 'pool2',
+ 'conv3_1',
+ 'relu3_1',
+ 'conv3_2',
+ 'relu3_2',
+ 'conv3_3',
+ 'relu3_3',
+ 'conv3_4',
+ 'relu3_4',
+ 'pool3',
+ 'conv4_1',
+ 'relu4_1',
+ 'conv4_2',
+ 'relu4_2',
+ 'conv4_3',
+ 'relu4_3',
+ 'conv4_4',
+ 'relu4_4',
+ 'pool4',
+ 'conv5_1',
+ 'relu5_1',
+ 'conv5_2',
+ 'relu5_2',
+ 'conv5_3',
+ 'relu5_3',
+ 'conv5_4',
+ 'relu5_4',
+ 'pool5',
+ ],
}
@@ -75,14 +176,16 @@ class VGGFeatureExtractor(nn.Module):
pooling_stride (int): The stride of max pooling operation. Default: 2.
"""
- def __init__(self,
- layer_name_list,
- vgg_type='vgg19',
- use_input_norm=True,
- range_norm=False,
- requires_grad=False,
- remove_pooling=False,
- pooling_stride=2):
+ def __init__(
+ self,
+ layer_name_list,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ requires_grad=False,
+ remove_pooling=False,
+ pooling_stride=2,
+ ):
super(VGGFeatureExtractor, self).__init__()
self.layer_name_list = layer_name_list
@@ -107,7 +210,7 @@ def __init__(self,
else:
vgg_net = getattr(vgg, vgg_type)(pretrained=True)
- features = vgg_net.features[:max_idx + 1]
+ features = vgg_net.features[: max_idx + 1]
modified_net = OrderedDict()
for k, v in zip(self.names, features):
diff --git a/basicsr/data/__init__.py b/basicsr/data/__init__.py
index 510df1677..b171277f2 100644
--- a/basicsr/data/__init__.py
+++ b/basicsr/data/__init__.py
@@ -1,12 +1,13 @@
import importlib
-import numpy as np
import random
-import torch
-import torch.utils.data
from copy import deepcopy
from functools import partial
from os import path as osp
+import numpy as np
+import torch
+import torch.utils.data
+
from basicsr.data.prefetch_dataloader import PrefetchDataLoader
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.dist_util import get_dist_info
@@ -69,11 +70,13 @@ def build_dataloader(dataset, dataset_opt, num_gpu=1, dist=False, sampler=None,
shuffle=False,
num_workers=num_workers,
sampler=sampler,
- drop_last=True)
+ drop_last=True,
+ )
if sampler is None:
dataloader_args['shuffle'] = True
- dataloader_args['worker_init_fn'] = partial(
- worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ dataloader_args['worker_init_fn'] = (
+ partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed) if seed is not None else None
+ )
elif phase in ['val', 'test']: # validation
dataloader_args = dict(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
else:
diff --git a/basicsr/data/data_sampler.py b/basicsr/data/data_sampler.py
index 575452d9f..faee122bb 100644
--- a/basicsr/data/data_sampler.py
+++ b/basicsr/data/data_sampler.py
@@ -1,4 +1,5 @@
import math
+
import torch
from torch.utils.data.sampler import Sampler
@@ -36,7 +37,7 @@ def __iter__(self):
indices = [v % dataset_size for v in indices]
# subsample
- indices = indices[self.rank:self.total_size:self.num_replicas]
+ indices = indices[self.rank : self.total_size : self.num_replicas]
assert len(indices) == self.num_samples
return iter(indices)
diff --git a/basicsr/data/data_util.py b/basicsr/data/data_util.py
index bf4c494b7..5a3bf7cfd 100644
--- a/basicsr/data/data_util.py
+++ b/basicsr/data/data_util.py
@@ -1,7 +1,8 @@
+from os import path as osp
+
import cv2
import numpy as np
import torch
-from os import path as osp
from torch.nn import functional as F
from basicsr.data.transforms import mod_crop
@@ -26,7 +27,7 @@ def read_img_seq(path, require_mod_crop=False, scale=1, return_imgname=False):
img_paths = path
else:
img_paths = sorted(list(scandir(path, full_path=True)))
- imgs = [cv2.imread(v).astype(np.float32) / 255. for v in img_paths]
+ imgs = [cv2.imread(v).astype(np.float32) / 255.0 for v in img_paths]
if require_mod_crop:
imgs = [mod_crop(img, scale) for img in imgs]
@@ -129,16 +130,17 @@ def paired_paths_from_lmdb(folders, keys):
Returns:
list[str]: Returned path list.
"""
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
- f'But got {len(folders)}')
+ assert len(folders) == 2, f'The len of folders should be 2 with [input_folder, gt_folder]. But got {len(folders)}'
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
if not (input_folder.endswith('.lmdb') and gt_folder.endswith('.lmdb')):
- raise ValueError(f'{input_key} folder and {gt_key} folder should both in lmdb '
- f'formats. But received {input_key}: {input_folder}; '
- f'{gt_key}: {gt_folder}')
+ raise ValueError(
+ f'{input_key} folder and {gt_key} folder should both in lmdb '
+ f'formats. But received {input_key}: {input_folder}; '
+ f'{gt_key}: {gt_folder}'
+ )
# ensure that the two meta_info files are the same
with open(osp.join(input_folder, 'meta_info.txt')) as fin:
input_lmdb_keys = [line.split('.')[0] for line in fin]
@@ -178,8 +180,7 @@ def paired_paths_from_meta_info_file(folders, keys, meta_info_file, filename_tmp
Returns:
list[str]: Returned path list.
"""
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
- f'But got {len(folders)}')
+ assert len(folders) == 2, f'The len of folders should be 2 with [input_folder, gt_folder]. But got {len(folders)}'
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
@@ -212,16 +213,16 @@ def paired_paths_from_folder(folders, keys, filename_tmpl):
Returns:
list[str]: Returned path list.
"""
- assert len(folders) == 2, ('The len of folders should be 2 with [input_folder, gt_folder]. '
- f'But got {len(folders)}')
+ assert len(folders) == 2, f'The len of folders should be 2 with [input_folder, gt_folder]. But got {len(folders)}'
assert len(keys) == 2, f'The len of keys should be 2 with [input_key, gt_key]. But got {len(keys)}'
input_folder, gt_folder = folders
input_key, gt_key = keys
input_paths = list(scandir(input_folder))
gt_paths = list(scandir(gt_folder))
- assert len(input_paths) == len(gt_paths), (f'{input_key} and {gt_key} datasets have different number of images: '
- f'{len(input_paths)}, {len(gt_paths)}.')
+ assert len(input_paths) == len(gt_paths), (
+ f'{input_key} and {gt_key} datasets have different number of images: {len(input_paths)}, {len(gt_paths)}.'
+ )
paths = []
for gt_path in gt_paths:
basename, ext = osp.splitext(osp.basename(gt_path))
@@ -275,6 +276,7 @@ def generate_gaussian_kernel(kernel_size=13, sigma=1.6):
np.array: The Gaussian kernel.
"""
from scipy.ndimage import filters as filters
+
kernel = np.zeros((kernel_size, kernel_size))
# set element at the middle to one, a dirac delta
kernel[kernel_size // 2, kernel_size // 2] = 1
diff --git a/basicsr/data/degradations.py b/basicsr/data/degradations.py
index 14319605d..4c20f5a51 100644
--- a/basicsr/data/degradations.py
+++ b/basicsr/data/degradations.py
@@ -1,7 +1,8 @@
-import cv2
import math
-import numpy as np
import random
+
+import cv2
+import numpy as np
import torch
from scipy import special
from scipy.stats import multivariate_normal
@@ -40,10 +41,11 @@ def mesh_grid(kernel_size):
xx (ndarray): with the shape (kernel_size, kernel_size)
yy (ndarray): with the shape (kernel_size, kernel_size)
"""
- ax = np.arange(-kernel_size // 2 + 1., kernel_size // 2 + 1.)
+ ax = np.arange(-kernel_size // 2 + 1.0, kernel_size // 2 + 1.0)
xx, yy = np.meshgrid(ax, ax)
- xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size,
- 1))).reshape(kernel_size, kernel_size, 2)
+ xy = np.hstack((xx.reshape((kernel_size * kernel_size, 1)), yy.reshape(kernel_size * kernel_size, 1))).reshape(
+ kernel_size, kernel_size, 2
+ )
return xy, xx, yy
@@ -173,12 +175,9 @@ def bivariate_plateau(kernel_size, sig_x, sig_y, theta, beta, grid=None, isotrop
return kernel
-def random_bivariate_Gaussian(kernel_size,
- sigma_x_range,
- sigma_y_range,
- rotation_range,
- noise_range=None,
- isotropic=True):
+def random_bivariate_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=None, isotropic=True
+):
"""Randomly generate bivariate isotropic or anisotropic Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
@@ -217,13 +216,9 @@ def random_bivariate_Gaussian(kernel_size,
return kernel
-def random_bivariate_generalized_Gaussian(kernel_size,
- sigma_x_range,
- sigma_y_range,
- rotation_range,
- beta_range,
- noise_range=None,
- isotropic=True):
+def random_bivariate_generalized_Gaussian(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, beta_range, noise_range=None, isotropic=True
+):
"""Randomly generate bivariate generalized Gaussian kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
@@ -269,13 +264,9 @@ def random_bivariate_generalized_Gaussian(kernel_size,
return kernel
-def random_bivariate_plateau(kernel_size,
- sigma_x_range,
- sigma_y_range,
- rotation_range,
- beta_range,
- noise_range=None,
- isotropic=True):
+def random_bivariate_plateau(
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, beta_range, noise_range=None, isotropic=True
+):
"""Randomly generate bivariate plateau kernels.
In the isotropic mode, only `sigma_x_range` is used. `sigma_y_range` and `rotation_range` is ignored.
@@ -321,15 +312,17 @@ def random_bivariate_plateau(kernel_size,
return kernel
-def random_mixed_kernels(kernel_list,
- kernel_prob,
- kernel_size=21,
- sigma_x_range=(0.6, 5),
- sigma_y_range=(0.6, 5),
- rotation_range=(-math.pi, math.pi),
- betag_range=(0.5, 8),
- betap_range=(0.5, 8),
- noise_range=None):
+def random_mixed_kernels(
+ kernel_list,
+ kernel_prob,
+ kernel_size=21,
+ sigma_x_range=(0.6, 5),
+ sigma_y_range=(0.6, 5),
+ rotation_range=(-math.pi, math.pi),
+ betag_range=(0.5, 8),
+ betap_range=(0.5, 8),
+ noise_range=None,
+):
"""Randomly generate mixed kernels.
Args:
@@ -352,10 +345,12 @@ def random_mixed_kernels(kernel_list,
kernel_type = random.choices(kernel_list, kernel_prob)[0]
if kernel_type == 'iso':
kernel = random_bivariate_Gaussian(
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True)
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=True
+ )
elif kernel_type == 'aniso':
kernel = random_bivariate_Gaussian(
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False)
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, noise_range=noise_range, isotropic=False
+ )
elif kernel_type == 'generalized_iso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
@@ -364,7 +359,8 @@ def random_mixed_kernels(kernel_list,
rotation_range,
betag_range,
noise_range=noise_range,
- isotropic=True)
+ isotropic=True,
+ )
elif kernel_type == 'generalized_aniso':
kernel = random_bivariate_generalized_Gaussian(
kernel_size,
@@ -373,13 +369,16 @@ def random_mixed_kernels(kernel_list,
rotation_range,
betag_range,
noise_range=noise_range,
- isotropic=False)
+ isotropic=False,
+ )
elif kernel_type == 'plateau_iso':
kernel = random_bivariate_plateau(
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True)
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=True
+ )
elif kernel_type == 'plateau_aniso':
kernel = random_bivariate_plateau(
- kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False)
+ kernel_size, sigma_x_range, sigma_y_range, rotation_range, betap_range, noise_range=None, isotropic=False
+ )
return kernel
@@ -398,9 +397,13 @@ def circular_lowpass_kernel(cutoff, kernel_size, pad_to=0):
"""
assert kernel_size % 2 == 1, 'Kernel size must be an odd number.'
kernel = np.fromfunction(
- lambda x, y: cutoff * special.j1(cutoff * np.sqrt(
- (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)) / (2 * np.pi * np.sqrt(
- (x - (kernel_size - 1) / 2)**2 + (y - (kernel_size - 1) / 2)**2)), [kernel_size, kernel_size])
+ lambda x, y: (
+ cutoff
+ * special.j1(cutoff * np.sqrt((x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2))
+ / (2 * np.pi * np.sqrt((x - (kernel_size - 1) / 2) ** 2 + (y - (kernel_size - 1) / 2) ** 2))
+ ),
+ [kernel_size, kernel_size],
+ )
kernel[(kernel_size - 1) // 2, (kernel_size - 1) // 2] = cutoff**2 / (4 * np.pi)
kernel = kernel / np.sum(kernel)
if pad_to > kernel_size:
@@ -428,10 +431,10 @@ def generate_gaussian_noise(img, sigma=10, gray_noise=False):
float32.
"""
if gray_noise:
- noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.
+ noise = np.float32(np.random.randn(*(img.shape[0:2]))) * sigma / 255.0
noise = np.expand_dims(noise, axis=2).repeat(3, axis=2)
else:
- noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.
+ noise = np.float32(np.random.randn(*(img.shape))) * sigma / 255.0
return noise
@@ -449,11 +452,11 @@ def add_gaussian_noise(img, sigma=10, clip=True, rounds=False, gray_noise=False)
noise = generate_gaussian_noise(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
@@ -478,11 +481,11 @@ def generate_gaussian_noise_pt(img, sigma=10, gray_noise=0):
cal_gray_noise = torch.sum(gray_noise) > 0
if cal_gray_noise:
- noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.
+ noise_gray = torch.randn(*img.size()[2:4], dtype=img.dtype, device=img.device) * sigma / 255.0
noise_gray = noise_gray.view(b, 1, h, w)
# always calculate color noise
- noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.
+ noise = torch.randn(*img.size(), dtype=img.dtype, device=img.device) * sigma / 255.0
if cal_gray_noise:
noise = noise * (1 - gray_noise) + noise_gray * gray_noise
@@ -503,11 +506,11 @@ def add_gaussian_noise_pt(img, sigma=10, gray_noise=0, clip=True, rounds=False):
noise = generate_gaussian_noise_pt(img, sigma, gray_noise)
out = img + noise
if clip and rounds:
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
@@ -525,17 +528,18 @@ def random_add_gaussian_noise(img, sigma_range=(0, 1.0), gray_prob=0, clip=True,
noise = random_generate_gaussian_noise(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
def random_generate_gaussian_noise_pt(img, sigma_range=(0, 10), gray_prob=0):
- sigma = torch.rand(
- img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
+ sigma = (
+ torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (sigma_range[1] - sigma_range[0]) + sigma_range[0]
+ )
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_gaussian_noise_pt(img, sigma, gray_noise)
@@ -545,11 +549,11 @@ def random_add_gaussian_noise_pt(img, sigma_range=(0, 1.0), gray_prob=0, clip=Tr
noise = random_generate_gaussian_noise_pt(img, sigma_range, gray_prob)
out = img + noise
if clip and rounds:
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
@@ -573,9 +577,9 @@ def generate_poisson_noise(img, scale=1.0, gray_noise=False):
if gray_noise:
img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
# round and clip image for counting vals correctly
- img = np.clip((img * 255.0).round(), 0, 255) / 255.
+ img = np.clip((img * 255.0).round(), 0, 255) / 255.0
vals = len(np.unique(img))
- vals = 2**np.ceil(np.log2(vals))
+ vals = 2 ** np.ceil(np.log2(vals))
out = np.float32(np.random.poisson(img * vals) / float(vals))
noise = out - img
if gray_noise:
@@ -598,11 +602,11 @@ def add_poisson_noise(img, scale=1.0, clip=True, rounds=False, gray_noise=False)
noise = generate_poisson_noise(img, scale, gray_noise)
out = img + noise
if clip and rounds:
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
@@ -629,10 +633,10 @@ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
if cal_gray_noise:
img_gray = rgb_to_grayscale(img, num_output_channels=1)
# round and clip image for counting vals correctly
- img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.
+ img_gray = torch.clamp((img_gray * 255.0).round(), 0, 255) / 255.0
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img_gray[i, :, :, :])) for i in range(b)]
- vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
vals = img_gray.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img_gray * vals) / vals
noise_gray = out - img_gray
@@ -640,10 +644,10 @@ def generate_poisson_noise_pt(img, scale=1.0, gray_noise=0):
# always calculate color noise
# round and clip image for counting vals correctly
- img = torch.clamp((img * 255.0).round(), 0, 255) / 255.
+ img = torch.clamp((img * 255.0).round(), 0, 255) / 255.0
# use for-loop to get the unique values for each sample
vals_list = [len(torch.unique(img[i, :, :, :])) for i in range(b)]
- vals_list = [2**np.ceil(np.log2(vals)) for vals in vals_list]
+ vals_list = [2 ** np.ceil(np.log2(vals)) for vals in vals_list]
vals = img.new_tensor(vals_list).view(b, 1, 1, 1)
out = torch.poisson(img * vals) / vals
noise = out - img
@@ -671,11 +675,11 @@ def add_poisson_noise_pt(img, scale=1.0, clip=True, rounds=False, gray_noise=0):
noise = generate_poisson_noise_pt(img, scale, gray_noise)
out = img + noise
if clip and rounds:
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
@@ -695,17 +699,18 @@ def random_add_poisson_noise(img, scale_range=(0, 1.0), gray_prob=0, clip=True,
noise = random_generate_poisson_noise(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
- out = np.clip((out * 255.0).round(), 0, 255) / 255.
+ out = np.clip((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = np.clip(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
def random_generate_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0):
- scale = torch.rand(
- img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
+ scale = (
+ torch.rand(img.size(0), dtype=img.dtype, device=img.device) * (scale_range[1] - scale_range[0]) + scale_range[0]
+ )
gray_noise = torch.rand(img.size(0), dtype=img.dtype, device=img.device)
gray_noise = (gray_noise < gray_prob).float()
return generate_poisson_noise_pt(img, scale, gray_noise)
@@ -715,11 +720,11 @@ def random_add_poisson_noise_pt(img, scale_range=(0, 1.0), gray_prob=0, clip=Tru
noise = random_generate_poisson_noise_pt(img, scale_range, gray_prob)
out = img + noise
if clip and rounds:
- out = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ out = torch.clamp((out * 255.0).round(), 0, 255) / 255.0
elif clip:
out = torch.clamp(out, 0, 1)
elif rounds:
- out = (out * 255.0).round() / 255.
+ out = (out * 255.0).round() / 255.0
return out
@@ -742,8 +747,8 @@ def add_jpg_compression(img, quality=90):
"""
img = np.clip(img, 0, 1)
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), quality]
- _, encimg = cv2.imencode('.jpg', img * 255., encode_param)
- img = np.float32(cv2.imdecode(encimg, 1)) / 255.
+ _, encimg = cv2.imencode('.jpg', img * 255.0, encode_param)
+ img = np.float32(cv2.imdecode(encimg, 1)) / 255.0
return img
diff --git a/basicsr/data/ffhq_dataset.py b/basicsr/data/ffhq_dataset.py
index 23992eb87..7b7925902 100644
--- a/basicsr/data/ffhq_dataset.py
+++ b/basicsr/data/ffhq_dataset.py
@@ -1,6 +1,7 @@
import random
import time
from os import path as osp
+
from torch.utils import data as data
from torchvision.transforms.functional import normalize
diff --git a/basicsr/data/paired_image_dataset.py b/basicsr/data/paired_image_dataset.py
index 9f5c8c6ad..ea2cc384f 100644
--- a/basicsr/data/paired_image_dataset.py
+++ b/basicsr/data/paired_image_dataset.py
@@ -55,8 +55,9 @@ def __init__(self, opt):
self.io_backend_opt['client_keys'] = ['lq', 'gt']
self.paths = paired_paths_from_lmdb([self.lq_folder, self.gt_folder], ['lq', 'gt'])
elif 'meta_info_file' in self.opt and self.opt['meta_info_file'] is not None:
- self.paths = paired_paths_from_meta_info_file([self.lq_folder, self.gt_folder], ['lq', 'gt'],
- self.opt['meta_info_file'], self.filename_tmpl)
+ self.paths = paired_paths_from_meta_info_file(
+ [self.lq_folder, self.gt_folder], ['lq', 'gt'], self.opt['meta_info_file'], self.filename_tmpl
+ )
else:
self.paths = paired_paths_from_folder([self.lq_folder, self.gt_folder], ['lq', 'gt'], self.filename_tmpl)
@@ -91,7 +92,7 @@ def __getitem__(self, index):
# crop the unmatched GT images during validation or testing, especially for SR benchmark datasets
# TODO: It is better to update the datasets, rather than force to crop
if self.opt['phase'] != 'train':
- img_gt = img_gt[0:img_lq.shape[0] * scale, 0:img_lq.shape[1] * scale, :]
+ img_gt = img_gt[0 : img_lq.shape[0] * scale, 0 : img_lq.shape[1] * scale, :]
# BGR to RGB, HWC to CHW, numpy to tensor
img_gt, img_lq = img2tensor([img_gt, img_lq], bgr2rgb=True, float32=True)
diff --git a/basicsr/data/prefetch_dataloader.py b/basicsr/data/prefetch_dataloader.py
index 332abd32f..e53226ee2 100644
--- a/basicsr/data/prefetch_dataloader.py
+++ b/basicsr/data/prefetch_dataloader.py
@@ -1,5 +1,6 @@
import queue as Queue
import threading
+
import torch
from torch.utils.data import DataLoader
@@ -58,7 +59,7 @@ def __iter__(self):
return PrefetchGenerator(super().__iter__(), self.num_prefetch_queue)
-class CPUPrefetcher():
+class CPUPrefetcher:
"""CPU prefetcher.
Args:
@@ -79,7 +80,7 @@ def reset(self):
self.loader = iter(self.ori_loader)
-class CUDAPrefetcher():
+class CUDAPrefetcher:
"""CUDA prefetcher.
Reference: https://github.com/NVIDIA/apex/issues/304#
diff --git a/basicsr/data/realesrgan_dataset.py b/basicsr/data/realesrgan_dataset.py
index 1616e9b91..c289c7b38 100644
--- a/basicsr/data/realesrgan_dataset.py
+++ b/basicsr/data/realesrgan_dataset.py
@@ -1,10 +1,11 @@
-import cv2
import math
-import numpy as np
import os
import os.path as osp
import random
import time
+
+import cv2
+import numpy as np
import torch
from torch.utils import data as data
@@ -124,7 +125,7 @@ def __getitem__(self, index):
# randomly choose top and left coordinates
top = random.randint(0, h - crop_pad_size)
left = random.randint(0, w - crop_pad_size)
- img_gt = img_gt[top:top + crop_pad_size, left:left + crop_pad_size, ...]
+ img_gt = img_gt[top : top + crop_pad_size, left : left + crop_pad_size, ...]
# ------------------------ Generate kernels (used in the first degradation) ------------------------ #
kernel_size = random.choice(self.kernel_range)
@@ -141,10 +142,12 @@ def __getitem__(self, index):
self.kernel_prob,
kernel_size,
self.blur_sigma,
- self.blur_sigma, [-math.pi, math.pi],
+ self.blur_sigma,
+ [-math.pi, math.pi],
self.betag_range,
self.betap_range,
- noise_range=None)
+ noise_range=None,
+ )
# pad kernel
pad_size = (21 - kernel_size) // 2
kernel = np.pad(kernel, ((pad_size, pad_size), (pad_size, pad_size)))
@@ -163,10 +166,12 @@ def __getitem__(self, index):
self.kernel_prob2,
kernel_size,
self.blur_sigma2,
- self.blur_sigma2, [-math.pi, math.pi],
+ self.blur_sigma2,
+ [-math.pi, math.pi],
self.betag_range2,
self.betap_range2,
- noise_range=None)
+ noise_range=None,
+ )
# pad kernel
pad_size = (21 - kernel_size) // 2
diff --git a/basicsr/data/realesrgan_paired_dataset.py b/basicsr/data/realesrgan_paired_dataset.py
index 604b026d5..ee2d62548 100644
--- a/basicsr/data/realesrgan_paired_dataset.py
+++ b/basicsr/data/realesrgan_paired_dataset.py
@@ -1,4 +1,5 @@
import os
+
from torch.utils import data as data
from torchvision.transforms.functional import normalize
diff --git a/basicsr/data/reds_dataset.py b/basicsr/data/reds_dataset.py
index fabef1d7e..284e8ae13 100644
--- a/basicsr/data/reds_dataset.py
+++ b/basicsr/data/reds_dataset.py
@@ -1,7 +1,8 @@
-import numpy as np
import random
-import torch
from pathlib import Path
+
+import numpy as np
+import torch
from torch.utils import data as data
from basicsr.data.transforms import augment, paired_random_crop
@@ -51,7 +52,7 @@ def __init__(self, opt):
self.opt = opt
self.gt_root, self.lq_root = Path(opt['dataroot_gt']), Path(opt['dataroot_lq'])
self.flow_root = Path(opt['dataroot_flow']) if opt['dataroot_flow'] is not None else None
- assert opt['num_frame'] % 2 == 1, (f'num_frame should be odd number, but got {opt["num_frame"]}')
+ assert opt['num_frame'] % 2 == 1, f'num_frame should be odd number, but got {opt["num_frame"]}'
self.num_frame = opt['num_frame']
self.num_half_frames = opt['num_frame'] // 2
@@ -67,8 +68,9 @@ def __init__(self, opt):
elif opt['val_partition'] == 'official':
val_partition = [f'{v:03d}' for v in range(240, 270)]
else:
- raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
- f"Supported ones are ['official', 'REDS4'].")
+ raise ValueError(
+                f"Wrong validation partition {opt['val_partition']}. Supported ones are ['official', 'REDS4']."
+ )
self.keys = [v for v in self.keys if v.split('/')[0] not in val_partition]
# file client (io backend)
@@ -89,8 +91,7 @@ def __init__(self, opt):
self.random_reverse = opt['random_reverse']
interval_str = ','.join(str(x) for x in opt['interval_list'])
logger = get_root_logger()
- logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
- f'random reverse is {self.random_reverse}.')
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
@@ -111,7 +112,7 @@ def __getitem__(self, index):
# each clip has 100 frames starting from 0 to 99
while (start_frame_idx < 0) or (end_frame_idx > 99):
center_frame_idx = random.randint(0, 99)
- start_frame_idx = (center_frame_idx - self.num_half_frames * interval)
+ start_frame_idx = center_frame_idx - self.num_half_frames * interval
end_frame_idx = center_frame_idx + self.num_half_frames * interval
frame_name = f'{center_frame_idx:08d}'
neighbor_list = list(range(start_frame_idx, end_frame_idx + 1, interval))
@@ -119,7 +120,7 @@ def __getitem__(self, index):
if self.random_reverse and random.random() < 0.5:
neighbor_list.reverse()
- assert len(neighbor_list) == self.num_frame, (f'Wrong length of neighbor list: {len(neighbor_list)}')
+ assert len(neighbor_list) == self.num_frame, f'Wrong length of neighbor list: {len(neighbor_list)}'
# get the GT frame (as the center frame)
if self.is_lmdb:
@@ -148,7 +149,7 @@ def __getitem__(self, index):
if self.is_lmdb:
flow_path = f'{clip_name}/{frame_name}_p{i}'
else:
- flow_path = (self.flow_root / clip_name / f'{frame_name}_p{i}.png')
+ flow_path = self.flow_root / clip_name / f'{frame_name}_p{i}.png'
img_bytes = self.file_client.get(flow_path, 'flow')
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
dx, dy = np.split(cat_flow, 2, axis=0)
@@ -159,7 +160,7 @@ def __getitem__(self, index):
if self.is_lmdb:
flow_path = f'{clip_name}/{frame_name}_n{i}'
else:
- flow_path = (self.flow_root / clip_name / f'{frame_name}_n{i}.png')
+ flow_path = self.flow_root / clip_name / f'{frame_name}_n{i}.png'
img_bytes = self.file_client.get(flow_path, 'flow')
cat_flow = imfrombytes(img_bytes, flag='grayscale', float32=False) # uint8, [0, 255]
dx, dy = np.split(cat_flow, 2, axis=0)
@@ -173,7 +174,7 @@ def __getitem__(self, index):
# randomly crop
img_gt, img_lqs = paired_random_crop(img_gt, img_lqs, gt_size, scale, img_gt_path)
if self.flow_root is not None:
- img_lqs, img_flows = img_lqs[:self.num_frame], img_lqs[self.num_frame:]
+ img_lqs, img_flows = img_lqs[: self.num_frame], img_lqs[self.num_frame :]
# augmentation - flip, rotate
img_lqs.append(img_gt)
@@ -259,8 +260,9 @@ def __init__(self, opt):
elif opt['val_partition'] == 'official':
val_partition = [f'{v:03d}' for v in range(240, 270)]
else:
- raise ValueError(f'Wrong validation partition {opt["val_partition"]}.'
- f"Supported ones are ['official', 'REDS4'].")
+ raise ValueError(
+                f"Wrong validation partition {opt['val_partition']}. Supported ones are ['official', 'REDS4']."
+ )
if opt['test_mode']:
self.keys = [v for v in self.keys if v.split('/')[0] in val_partition]
else:
@@ -284,8 +286,7 @@ def __init__(self, opt):
self.random_reverse = opt.get('random_reverse', False)
interval_str = ','.join(str(x) for x in self.interval_list)
logger = get_root_logger()
- logger.info(f'Temporal augmentation interval list: [{interval_str}]; '
- f'random reverse is {self.random_reverse}.')
+ logger.info(f'Temporal augmentation interval list: [{interval_str}]; random reverse is {self.random_reverse}.')
def __getitem__(self, index):
if self.file_client is None:
@@ -340,8 +341,8 @@ def __getitem__(self, index):
img_results = augment(img_lqs, self.opt['use_hflip'], self.opt['use_rot'])
img_results = img2tensor(img_results)
- img_gts = torch.stack(img_results[len(img_lqs) // 2:], dim=0)
- img_lqs = torch.stack(img_results[:len(img_lqs) // 2], dim=0)
+ img_gts = torch.stack(img_results[len(img_lqs) // 2 :], dim=0)
+ img_lqs = torch.stack(img_results[: len(img_lqs) // 2], dim=0)
# img_lqs: (t, c, h, w)
# img_gts: (t, c, h, w)
diff --git a/basicsr/data/single_image_dataset.py b/basicsr/data/single_image_dataset.py
index acbc7d921..ce9542317 100644
--- a/basicsr/data/single_image_dataset.py
+++ b/basicsr/data/single_image_dataset.py
@@ -1,4 +1,5 @@
from os import path as osp
+
from torch.utils import data as data
from torchvision.transforms.functional import normalize
diff --git a/basicsr/data/transforms.py b/basicsr/data/transforms.py
index d9bbb5fb7..dce3842c6 100644
--- a/basicsr/data/transforms.py
+++ b/basicsr/data/transforms.py
@@ -1,5 +1,6 @@
-import cv2
import random
+
+import cv2
import torch
@@ -17,7 +18,7 @@ def mod_crop(img, scale):
if img.ndim in (2, 3):
h, w = img.shape[0], img.shape[1]
h_remainder, w_remainder = h % scale, w % scale
- img = img[:h - h_remainder, :w - w_remainder, ...]
+ img = img[: h - h_remainder, : w - w_remainder, ...]
else:
raise ValueError(f'Wrong img ndim: {img.ndim}.')
return img
@@ -61,12 +62,15 @@ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
lq_patch_size = gt_patch_size // scale
if h_gt != h_lq * scale or w_gt != w_lq * scale:
- raise ValueError(f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ',
- f'multiplication of LQ ({h_lq}, {w_lq}).')
+ raise ValueError(
+ f'Scale mismatches. GT ({h_gt}, {w_gt}) is not {scale}x ', f'multiplication of LQ ({h_lq}, {w_lq}).'
+ )
if h_lq < lq_patch_size or w_lq < lq_patch_size:
- raise ValueError(f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
- f'({lq_patch_size}, {lq_patch_size}). '
- f'Please remove {gt_path}.')
+ raise ValueError(
+ f'LQ ({h_lq}, {w_lq}) is smaller than patch size '
+ f'({lq_patch_size}, {lq_patch_size}). '
+ f'Please remove {gt_path}.'
+ )
# randomly choose top and left coordinates for lq patch
top = random.randint(0, h_lq - lq_patch_size)
@@ -74,16 +78,16 @@ def paired_random_crop(img_gts, img_lqs, gt_patch_size, scale, gt_path=None):
# crop lq patch
if input_type == 'Tensor':
- img_lqs = [v[:, :, top:top + lq_patch_size, left:left + lq_patch_size] for v in img_lqs]
+ img_lqs = [v[:, :, top : top + lq_patch_size, left : left + lq_patch_size] for v in img_lqs]
else:
- img_lqs = [v[top:top + lq_patch_size, left:left + lq_patch_size, ...] for v in img_lqs]
+ img_lqs = [v[top : top + lq_patch_size, left : left + lq_patch_size, ...] for v in img_lqs]
# crop corresponding gt patch
top_gt, left_gt = int(top * scale), int(left * scale)
if input_type == 'Tensor':
- img_gts = [v[:, :, top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size] for v in img_gts]
+ img_gts = [v[:, :, top_gt : top_gt + gt_patch_size, left_gt : left_gt + gt_patch_size] for v in img_gts]
else:
- img_gts = [v[top_gt:top_gt + gt_patch_size, left_gt:left_gt + gt_patch_size, ...] for v in img_gts]
+ img_gts = [v[top_gt : top_gt + gt_patch_size, left_gt : left_gt + gt_patch_size, ...] for v in img_gts]
if len(img_gts) == 1:
img_gts = img_gts[0]
if len(img_lqs) == 1:
diff --git a/basicsr/data/video_test_dataset.py b/basicsr/data/video_test_dataset.py
index 929f7d974..abaeff3da 100644
--- a/basicsr/data/video_test_dataset.py
+++ b/basicsr/data/video_test_dataset.py
@@ -1,6 +1,7 @@
import glob
-import torch
from os import path as osp
+
+import torch
from torch.utils import data as data
from basicsr.data.data_util import duf_downsample, generate_frame_indices, read_img_seq
@@ -74,8 +75,9 @@ def __init__(self, opt):
img_paths_gt = sorted(list(scandir(subfolder_gt, full_path=True)))
max_idx = len(img_paths_lq)
- assert max_idx == len(img_paths_gt), (f'Different number of images in lq ({max_idx})'
- f' and gt folders ({len(img_paths_gt)})')
+ assert max_idx == len(img_paths_gt), (
+ f'Different number of images in lq ({max_idx}) and gt folders ({len(img_paths_gt)})'
+ )
self.data_info['lq_path'].extend(img_paths_lq)
self.data_info['gt_path'].extend(img_paths_gt)
@@ -123,7 +125,7 @@ def __getitem__(self, index):
'folder': folder, # folder name
'idx': self.data_info['idx'][index], # e.g., 0/99
'border': border, # 1 for border, 0 for non-border
- 'lq_path': lq_path # center frame
+ 'lq_path': lq_path, # center frame
}
def __len__(self):
@@ -191,7 +193,7 @@ def __getitem__(self, index):
'folder': self.data_info['folder'][index], # folder name
'idx': self.data_info['idx'][index], # e.g., 0/843
'border': self.data_info['border'][index], # 0 for non-border
- 'lq_path': lq_path[self.opt['num_frame'] // 2] # center frame
+ 'lq_path': lq_path[self.opt['num_frame'] // 2], # center frame
}
def __len__(self):
@@ -200,7 +202,7 @@ def __len__(self):
@DATASET_REGISTRY.register()
class VideoTestDUFDataset(VideoTestDataset):
- """ Video test dataset for DUF dataset.
+ """Video test dataset for DUF dataset.
Args:
opt (dict): Config for train dataset. Most of keys are the same as VideoTestDataset.
@@ -244,7 +246,7 @@ def __getitem__(self, index):
'folder': folder, # folder name
'idx': self.data_info['idx'][index], # e.g., 0/99
'border': border, # 1 for border, 0 for non-border
- 'lq_path': lq_path # center frame
+ 'lq_path': lq_path, # center frame
}
diff --git a/basicsr/data/vimeo90k_dataset.py b/basicsr/data/vimeo90k_dataset.py
index e5e33e108..045bced41 100644
--- a/basicsr/data/vimeo90k_dataset.py
+++ b/basicsr/data/vimeo90k_dataset.py
@@ -1,6 +1,7 @@
import random
-import torch
from pathlib import Path
+
+import torch
from torch.utils import data as data
from basicsr.data.transforms import augment, paired_random_crop
@@ -135,7 +136,6 @@ def __len__(self):
@DATASET_REGISTRY.register()
class Vimeo90KRecurrentDataset(Vimeo90KDataset):
-
def __init__(self, opt):
super(Vimeo90KRecurrentDataset, self).__init__(opt)
diff --git a/basicsr/losses/__init__.py b/basicsr/losses/__init__.py
index 70a172aee..e9efb74d8 100644
--- a/basicsr/losses/__init__.py
+++ b/basicsr/losses/__init__.py
@@ -4,6 +4,7 @@
from basicsr.utils import get_root_logger, scandir
from basicsr.utils.registry import LOSS_REGISTRY
+
from .gan_loss import g_path_regularize, gradient_penalty_loss, r1_penalty
__all__ = ['build_loss', 'gradient_penalty_loss', 'r1_penalty', 'g_path_regularize']
diff --git a/basicsr/losses/basic_loss.py b/basicsr/losses/basic_loss.py
index c95b2b77c..47b0c62ac 100644
--- a/basicsr/losses/basic_loss.py
+++ b/basicsr/losses/basic_loss.py
@@ -4,6 +4,7 @@
from basicsr.archs.vgg_arch import VGGFeatureExtractor
from basicsr.utils.registry import LOSS_REGISTRY
+
from .loss_util import weighted_loss
_reduction_modes = ['none', 'mean', 'sum']
@@ -21,7 +22,7 @@ def mse_loss(pred, target):
@weighted_loss
def charbonnier_loss(pred, target, eps=1e-12):
- return torch.sqrt((pred - target)**2 + eps)
+ return torch.sqrt((pred - target) ** 2 + eps)
@LOSS_REGISTRY.register()
@@ -167,14 +168,16 @@ class PerceptualLoss(nn.Module):
criterion (str): Criterion used for perceptual loss. Default: 'l1'.
"""
- def __init__(self,
- layer_weights,
- vgg_type='vgg19',
- use_input_norm=True,
- range_norm=False,
- perceptual_weight=1.0,
- style_weight=0.,
- criterion='l1'):
+ def __init__(
+ self,
+ layer_weights,
+ vgg_type='vgg19',
+ use_input_norm=True,
+ range_norm=False,
+ perceptual_weight=1.0,
+ style_weight=0.0,
+ criterion='l1',
+ ):
super(PerceptualLoss, self).__init__()
self.perceptual_weight = perceptual_weight
self.style_weight = style_weight
@@ -183,7 +186,8 @@ def __init__(self,
layer_name_list=list(layer_weights.keys()),
vgg_type=vgg_type,
use_input_norm=use_input_norm,
- range_norm=range_norm)
+ range_norm=range_norm,
+ )
self.criterion_type = criterion
if self.criterion_type == 'l1':
@@ -226,11 +230,15 @@ def forward(self, x, gt):
style_loss = 0
for k in x_features.keys():
if self.criterion_type == 'fro':
- style_loss += torch.norm(
- self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro') * self.layer_weights[k]
+ style_loss += (
+ torch.norm(self._gram_mat(x_features[k]) - self._gram_mat(gt_features[k]), p='fro')
+ * self.layer_weights[k]
+ )
else:
- style_loss += self.criterion(self._gram_mat(x_features[k]), self._gram_mat(
- gt_features[k])) * self.layer_weights[k]
+ style_loss += (
+ self.criterion(self._gram_mat(x_features[k]), self._gram_mat(gt_features[k]))
+ * self.layer_weights[k]
+ )
style_loss *= self.style_weight
else:
style_loss = None
diff --git a/basicsr/losses/gan_loss.py b/basicsr/losses/gan_loss.py
index 870baa222..e6d5a16ba 100644
--- a/basicsr/losses/gan_loss.py
+++ b/basicsr/losses/gan_loss.py
@@ -1,4 +1,5 @@
import math
+
import torch
from torch import autograd as autograd
from torch import nn as nn
@@ -83,7 +84,7 @@ def get_target_label(self, input, target_is_real):
if self.gan_type in ['wgan', 'wgan_softplus']:
return target_is_real
- target_val = (self.real_label_val if target_is_real else self.fake_label_val)
+ target_val = self.real_label_val if target_is_real else self.fake_label_val
return input.new_ones(input.size()) * target_val
def forward(self, input, target_is_real, is_disc=False):
@@ -142,15 +143,15 @@ def forward(self, input, target_is_real, is_disc=False):
def r1_penalty(real_pred, real_img):
"""R1 regularization for discriminator. The core idea is to
- penalize the gradient on real data alone: when the
- generator distribution produces the true data distribution
- and the discriminator is equal to 0 on the data manifold, the
- gradient penalty ensures that the discriminator cannot create
- a non-zero gradient orthogonal to the data manifold without
- suffering a loss in the GAN game.
-
- Reference: Eq. 9 in Which training methods for GANs do actually converge.
- """
+ penalize the gradient on real data alone: when the
+ generator distribution produces the true data distribution
+ and the discriminator is equal to 0 on the data manifold, the
+ gradient penalty ensures that the discriminator cannot create
+ a non-zero gradient orthogonal to the data manifold without
+ suffering a loss in the GAN game.
+
+ Reference: Eq. 9 in Which training methods for GANs do actually converge.
+ """
grad_real = autograd.grad(outputs=real_pred.sum(), inputs=real_img, create_graph=True)[0]
grad_penalty = grad_real.pow(2).view(grad_real.shape[0], -1).sum(1).mean()
return grad_penalty
@@ -185,7 +186,7 @@ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
alpha = real_data.new_tensor(torch.rand(batch_size, 1, 1, 1))
# interpolate between real_data and fake_data
- interpolates = alpha * real_data + (1. - alpha) * fake_data
+ interpolates = alpha * real_data + (1.0 - alpha) * fake_data
interpolates = autograd.Variable(interpolates, requires_grad=True)
disc_interpolates = discriminator(interpolates)
@@ -195,12 +196,13 @@ def gradient_penalty_loss(discriminator, real_data, fake_data, weight=None):
grad_outputs=torch.ones_like(disc_interpolates),
create_graph=True,
retain_graph=True,
- only_inputs=True)[0]
+ only_inputs=True,
+ )[0]
if weight is not None:
gradients = gradients * weight
- gradients_penalty = ((gradients.norm(2, dim=1) - 1)**2).mean()
+ gradients_penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
if weight is not None:
gradients_penalty /= torch.mean(weight)
diff --git a/basicsr/losses/loss_util.py b/basicsr/losses/loss_util.py
index fd293ff9e..27d757b92 100644
--- a/basicsr/losses/loss_util.py
+++ b/basicsr/losses/loss_util.py
@@ -1,4 +1,5 @@
import functools
+
import torch
from torch.nn import functional as F
@@ -136,7 +137,7 @@ def get_refined_artifact_map(img_gt, img_output, img_ema, ksize):
residual_ema = torch.sum(torch.abs(img_gt - img_ema), 1, keepdim=True)
residual_sr = torch.sum(torch.abs(img_gt - img_output), 1, keepdim=True)
- patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True)**(1 / 5)
+ patch_level_weight = torch.var(residual_sr.clone(), dim=(-1, -2, -3), keepdim=True) ** (1 / 5)
pixel_level_weight = get_local_weights(residual_sr.clone(), ksize)
overall_weight = patch_level_weight * pixel_level_weight
diff --git a/basicsr/metrics/__init__.py b/basicsr/metrics/__init__.py
index 4fb044a93..5e6bc794c 100644
--- a/basicsr/metrics/__init__.py
+++ b/basicsr/metrics/__init__.py
@@ -1,6 +1,7 @@
from copy import deepcopy
from basicsr.utils.registry import METRIC_REGISTRY
+
from .niqe import calculate_niqe
from .psnr_ssim import calculate_psnr, calculate_ssim
diff --git a/basicsr/metrics/fid.py b/basicsr/metrics/fid.py
index 1b0ba6df1..184348c99 100644
--- a/basicsr/metrics/fid.py
+++ b/basicsr/metrics/fid.py
@@ -64,7 +64,7 @@ def calculate_fid(mu1, sigma1, mu2, sigma2, eps=1e-6):
float: The Frechet Distance.
"""
assert mu1.shape == mu2.shape, 'Two mean vectors have different lengths'
- assert sigma1.shape == sigma2.shape, ('Two covariances have different dimensions')
+ assert sigma1.shape == sigma2.shape, 'Two covariances have different dimensions'
cov_sqrt, _ = linalg.sqrtm(sigma1 @ sigma2, disp=False)
diff --git a/basicsr/metrics/metric_util.py b/basicsr/metrics/metric_util.py
index 2a27c70a0..474dd68e0 100644
--- a/basicsr/metrics/metric_util.py
+++ b/basicsr/metrics/metric_util.py
@@ -38,8 +38,8 @@ def to_y_channel(img):
Returns:
(ndarray): Images with range [0, 255] (float type) without round.
"""
- img = img.astype(np.float32) / 255.
+ img = img.astype(np.float32) / 255.0
if img.ndim == 3 and img.shape[2] == 3:
img = bgr2ycbcr(img, y_only=True)
img = img[..., None]
- return img * 255.
+ return img * 255.0
diff --git a/basicsr/metrics/niqe.py b/basicsr/metrics/niqe.py
index e3c1467f6..c0ab74d83 100644
--- a/basicsr/metrics/niqe.py
+++ b/basicsr/metrics/niqe.py
@@ -1,7 +1,8 @@
-import cv2
import math
-import numpy as np
import os
+
+import cv2
+import numpy as np
from scipy.ndimage import convolve
from scipy.special import gamma
@@ -25,12 +26,12 @@ def estimate_aggd_param(block):
gam_reciprocal = np.reciprocal(gam)
r_gam = np.square(gamma(gam_reciprocal * 2)) / (gamma(gam_reciprocal) * gamma(gam_reciprocal * 3))
- left_std = np.sqrt(np.mean(block[block < 0]**2))
- right_std = np.sqrt(np.mean(block[block > 0]**2))
+ left_std = np.sqrt(np.mean(block[block < 0] ** 2))
+ right_std = np.sqrt(np.mean(block[block > 0] ** 2))
gammahat = left_std / right_std
- rhat = (np.mean(np.abs(block)))**2 / np.mean(block**2)
- rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1)**2)
- array_position = np.argmin((r_gam - rhatnorm)**2)
+ rhat = (np.mean(np.abs(block))) ** 2 / np.mean(block**2)
+ rhatnorm = (rhat * (gammahat**3 + 1) * (gammahat + 1)) / ((gammahat**2 + 1) ** 2)
+ array_position = np.argmin((r_gam - rhatnorm) ** 2)
alpha = gam[array_position]
beta_l = left_std * np.sqrt(gamma(1 / alpha) / gamma(3 / alpha))
@@ -95,12 +96,12 @@ def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, b
block_size_w (int): Width of the blocks in to which image is divided.
Default: 96 (the official recommended value).
"""
- assert img.ndim == 2, ('Input image must be a gray or Y (of YCbCr) image with shape (h, w).')
+ assert img.ndim == 2, 'Input image must be a gray or Y (of YCbCr) image with shape (h, w).'
# crop image
h, w = img.shape
num_block_h = math.floor(h / block_size_h)
num_block_w = math.floor(w / block_size_w)
- img = img[0:num_block_h * block_size_h, 0:num_block_w * block_size_w]
+ img = img[0 : num_block_h * block_size_h, 0 : num_block_w * block_size_w]
distparam = [] # dist param is actually the multiscale features
for scale in (1, 2): # perform on two scales (1, 2)
@@ -113,15 +114,17 @@ def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, b
for idx_w in range(num_block_w):
for idx_h in range(num_block_h):
# process ecah block
- block = img_nomalized[idx_h * block_size_h // scale:(idx_h + 1) * block_size_h // scale,
- idx_w * block_size_w // scale:(idx_w + 1) * block_size_w // scale]
+ block = img_nomalized[
+ idx_h * block_size_h // scale : (idx_h + 1) * block_size_h // scale,
+ idx_w * block_size_w // scale : (idx_w + 1) * block_size_w // scale,
+ ]
feat.append(compute_feature(block))
distparam.append(np.array(feat))
if scale == 1:
- img = imresize(img / 255., scale=0.5, antialiasing=True)
- img = img * 255.
+ img = imresize(img / 255.0, scale=0.5, antialiasing=True)
+ img = img * 255.0
distparam = np.concatenate(distparam, axis=1)
@@ -134,7 +137,8 @@ def niqe(img, mu_pris_param, cov_pris_param, gaussian_window, block_size_h=96, b
# compute niqe quality, Eq. 10 in the paper
invcov_param = np.linalg.pinv((cov_pris_param + cov_distparam) / 2)
quality = np.matmul(
- np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam)))
+ np.matmul((mu_pris_param - mu_distparam), invcov_param), np.transpose((mu_pris_param - mu_distparam))
+ )
quality = np.sqrt(quality)
quality = float(np.squeeze(quality))
@@ -185,7 +189,7 @@ def calculate_niqe(img, crop_border, input_order='HWC', convert_to='y', **kwargs
if convert_to == 'y':
img = to_y_channel(img)
elif convert_to == 'gray':
- img = cv2.cvtColor(img / 255., cv2.COLOR_BGR2GRAY) * 255.
+ img = cv2.cvtColor(img / 255.0, cv2.COLOR_BGR2GRAY) * 255.0
img = np.squeeze(img)
if crop_border != 0:
diff --git a/basicsr/metrics/psnr_ssim.py b/basicsr/metrics/psnr_ssim.py
index ab03113f8..d76420036 100644
--- a/basicsr/metrics/psnr_ssim.py
+++ b/basicsr/metrics/psnr_ssim.py
@@ -25,7 +25,7 @@ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
float: PSNR result.
"""
- assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ assert img.shape == img2.shape, f'Image shapes are different: {img.shape}, {img2.shape}.'
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
@@ -42,10 +42,10 @@ def calculate_psnr(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
img = img.astype(np.float64)
img2 = img2.astype(np.float64)
- mse = np.mean((img - img2)**2)
+ mse = np.mean((img - img2) ** 2)
if mse == 0:
return float('inf')
- return 10. * np.log10(255. * 255. / mse)
+ return 10.0 * np.log10(255.0 * 255.0 / mse)
@METRIC_REGISTRY.register()
@@ -64,7 +64,7 @@ def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
float: PSNR result.
"""
- assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ assert img.shape == img2.shape, f'Image shapes are different: {img.shape}, {img2.shape}.'
if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
@@ -77,8 +77,8 @@ def calculate_psnr_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
img = img.to(torch.float64)
img2 = img2.to(torch.float64)
- mse = torch.mean((img - img2)**2, dim=[1, 2, 3])
- return 10. * torch.log10(1. / (mse + 1e-8))
+ mse = torch.mean((img - img2) ** 2, dim=[1, 2, 3])
+ return 10.0 * torch.log10(1.0 / (mse + 1e-8))
@METRIC_REGISTRY.register()
@@ -105,7 +105,7 @@ def calculate_ssim(img, img2, crop_border, input_order='HWC', test_y_channel=Fal
float: SSIM result.
"""
- assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ assert img.shape == img2.shape, f'Image shapes are different: {img.shape}, {img2.shape}.'
if input_order not in ['HWC', 'CHW']:
raise ValueError(f'Wrong input_order {input_order}. Supported input_orders are "HWC" and "CHW"')
img = reorder_image(img, input_order=input_order)
@@ -150,7 +150,7 @@ def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
float: SSIM result.
"""
- assert img.shape == img2.shape, (f'Image shapes are different: {img.shape}, {img2.shape}.')
+ assert img.shape == img2.shape, f'Image shapes are different: {img.shape}, {img2.shape}.'
if crop_border != 0:
img = img[:, :, crop_border:-crop_border, crop_border:-crop_border]
@@ -163,7 +163,7 @@ def calculate_ssim_pt(img, img2, crop_border, test_y_channel=False, **kwargs):
img = img.to(torch.float64)
img2 = img2.to(torch.float64)
- ssim = _ssim_pth(img * 255., img2 * 255.)
+ ssim = _ssim_pth(img * 255.0, img2 * 255.0)
return ssim
@@ -180,8 +180,8 @@ def _ssim(img, img2):
float: SSIM result.
"""
- c1 = (0.01 * 255)**2
- c2 = (0.03 * 255)**2
+ c1 = (0.01 * 255) ** 2
+ c2 = (0.03 * 255) ** 2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
@@ -210,8 +210,8 @@ def _ssim_pth(img, img2):
Returns:
float: SSIM result.
"""
- c1 = (0.01 * 255)**2
- c2 = (0.03 * 255)**2
+ c1 = (0.01 * 255) ** 2
+ c2 = (0.03 * 255) ** 2
kernel = cv2.getGaussianKernel(11, 1.5)
window = np.outer(kernel, kernel.transpose())
diff --git a/basicsr/metrics/test_metrics/test_psnr_ssim.py b/basicsr/metrics/test_metrics/test_psnr_ssim.py
index 18b05a73a..48f8fa572 100644
--- a/basicsr/metrics/test_metrics/test_psnr_ssim.py
+++ b/basicsr/metrics/test_metrics/test_psnr_ssim.py
@@ -16,8 +16,8 @@ def test(img_path, img_path2, crop_border, test_y_channel=False):
print(f'\tNumpy\tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
# --------------------- PyTorch (CPU) ---------------------
- img = img2tensor(img / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
- img2 = img2tensor(img2 / 255., bgr2rgb=True, float32=True).unsqueeze_(0)
+ img = img2tensor(img / 255.0, bgr2rgb=True, float32=True).unsqueeze_(0)
+ img2 = img2tensor(img2 / 255.0, bgr2rgb=True, float32=True).unsqueeze_(0)
psnr_pth = calculate_psnr_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
ssim_pth = calculate_ssim_pt(img, img2, crop_border=crop_border, test_y_channel=test_y_channel)
@@ -34,14 +34,18 @@ def test(img_path, img_path2, crop_border, test_y_channel=False):
torch.repeat_interleave(img, 2, dim=0),
torch.repeat_interleave(img2, 2, dim=0),
crop_border=crop_border,
- test_y_channel=test_y_channel)
+ test_y_channel=test_y_channel,
+ )
ssim_pth = calculate_ssim_pt(
torch.repeat_interleave(img, 2, dim=0),
torch.repeat_interleave(img2, 2, dim=0),
crop_border=crop_border,
- test_y_channel=test_y_channel)
- print(f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
- f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}')
+ test_y_channel=test_y_channel,
+ )
+ print(
+ f'\tTensor (GPU batch) \tPSNR: {psnr_pth[0]:.6f}, {psnr_pth[1]:.6f} dB,'
+ f'\tSSIM: {ssim_pth[0]:.6f}, {ssim_pth[1]:.6f}'
+ )
if __name__ == '__main__':
diff --git a/basicsr/models/base_model.py b/basicsr/models/base_model.py
index fbf8229f5..e4bc3055e 100644
--- a/basicsr/models/base_model.py
+++ b/basicsr/models/base_model.py
@@ -1,8 +1,9 @@
import os
import time
-import torch
from collections import OrderedDict
from copy import deepcopy
+
+import torch
from torch.nn.parallel import DataParallel, DistributedDataParallel
from basicsr.models import lr_scheduler as lr_scheduler
@@ -10,7 +11,7 @@
from basicsr.utils.dist_util import master_only
-class BaseModel():
+class BaseModel:
"""Base model."""
def __init__(self, opt):
@@ -95,7 +96,8 @@ def model_to_device(self, net):
if self.opt['dist']:
find_unused_parameters = self.opt.get('find_unused_parameters', False)
net = DistributedDataParallel(
- net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters)
+ net, device_ids=[torch.cuda.current_device()], find_unused_parameters=find_unused_parameters
+ )
elif self.opt['num_gpu'] > 1:
net = DataParallel(net)
return net
@@ -171,8 +173,7 @@ def _set_lr(self, lr_groups_l):
param_group['lr'] = lr
def _get_init_lr(self):
- """Get the initial lr, which is set by the scheduler.
- """
+ """Get the initial lr, which is set by the scheduler."""
init_lr_groups_l = []
for optimizer in self.optimizers:
init_lr_groups_l.append([v['initial_lr'] for v in optimizer.param_groups])
@@ -282,8 +283,9 @@ def _print_different_keys_loading(self, crt_net, load_net, strict=True):
common_keys = crt_net_keys & load_net_keys
for k in common_keys:
if crt_net[k].size() != load_net[k].size():
- logger.warning(f'Size different, ignore [{k}]: crt_net: '
- f'{crt_net[k].shape}; load_net: {load_net[k].shape}')
+ logger.warning(
+ f'Size different, ignore [{k}]: crt_net: {crt_net[k].shape}; load_net: {load_net[k].shape}'
+ )
load_net[k + '.ignore'] = load_net.pop(k)
def load_network(self, net, load_path, strict=True, param_key='params'):
diff --git a/basicsr/models/edvr_model.py b/basicsr/models/edvr_model.py
index 9bdbf7b94..be6a384a2 100644
--- a/basicsr/models/edvr_model.py
+++ b/basicsr/models/edvr_model.py
@@ -1,5 +1,6 @@
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
+
from .video_base_model import VideoBaseModel
@@ -33,12 +34,9 @@ def setup_optimizers(self):
optim_params = [
{ # add normal params first
'params': normal_params,
- 'lr': train_opt['optim_g']['lr']
- },
- {
- 'params': dcn_params,
- 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul
+ 'lr': train_opt['optim_g']['lr'],
},
+ {'params': dcn_params, 'lr': train_opt['optim_g']['lr'] * dcn_lr_mul},
]
optim_type = train_opt['optim_g'].pop('type')
diff --git a/basicsr/models/esrgan_model.py b/basicsr/models/esrgan_model.py
index 3d746d0e2..c682fd29b 100644
--- a/basicsr/models/esrgan_model.py
+++ b/basicsr/models/esrgan_model.py
@@ -1,7 +1,9 @@
-import torch
from collections import OrderedDict
+import torch
+
from basicsr.utils.registry import MODEL_REGISTRY
+
from .srgan_model import SRGANModel
@@ -19,7 +21,7 @@ def optimize_parameters(self, current_iter):
l_g_total = 0
loss_dict = OrderedDict()
- if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters:
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
diff --git a/basicsr/models/hifacegan_model.py b/basicsr/models/hifacegan_model.py
index 435a2b179..909036e3e 100644
--- a/basicsr/models/hifacegan_model.py
+++ b/basicsr/models/hifacegan_model.py
@@ -1,6 +1,7 @@
-import torch
from collections import OrderedDict
from os import path as osp
+
+import torch
from tqdm import tqdm
from basicsr.archs import build_network
@@ -8,6 +9,7 @@
from basicsr.metrics import calculate_metric
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
+
from .sr_model import SRModel
@@ -101,15 +103,15 @@ def _divide_pred(pred):
The prediction contains the intermediate outputs of multiscale GAN,
so it's usually a list
"""
- if type(pred) == list:
+ if isinstance(pred, list):
fake = []
real = []
for p in pred:
- fake.append([tensor[:tensor.size(0) // 2] for tensor in p])
- real.append([tensor[tensor.size(0) // 2:] for tensor in p])
+ fake.append([tensor[: tensor.size(0) // 2] for tensor in p])
+ real.append([tensor[tensor.size(0) // 2 :] for tensor in p])
else:
- fake = pred[:pred.size(0) // 2]
- real = pred[pred.size(0) // 2:]
+ fake = pred[: pred.size(0) // 2]
+ real = pred[pred.size(0) // 2 :]
return fake, real
@@ -124,7 +126,7 @@ def optimize_parameters(self, current_iter):
l_g_total = 0
loss_dict = OrderedDict()
- if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters:
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
@@ -209,8 +211,10 @@ def validation(self, dataloader, current_iter, tb_logger, save_img=False):
if self.opt['dist']:
self.dist_validation(dataloader, current_iter, tb_logger, save_img)
else:
- print('In HiFaceGANModel: The new metrics package is under development.' +
- 'Using super method now (Only PSNR & SSIM are supported)')
+ print(
+ 'In HiFaceGANModel: The new metrics package is under development.'
+ + 'Using super method now (Only PSNR & SSIM are supported)'
+ )
super().nondist_validation(dataloader, current_iter, tb_logger, save_img)
def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
@@ -253,15 +257,20 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
if save_img:
if self.opt['is_train']:
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
- f'{img_name}_{current_iter}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'], img_name, f'{img_name}_{current_iter}.png'
+ )
else:
if self.opt['val']['suffix']:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'],
+ dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png',
+ )
else:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["name"]}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'], dataset_name, f'{img_name}_{self.opt["name"]}.png'
+ )
imwrite(tensor2img(visuals['result']), save_img_path)
diff --git a/basicsr/models/lr_scheduler.py b/basicsr/models/lr_scheduler.py
index 11e1c6c7a..1753c6b36 100644
--- a/basicsr/models/lr_scheduler.py
+++ b/basicsr/models/lr_scheduler.py
@@ -1,10 +1,11 @@
import math
from collections import Counter
+
from torch.optim.lr_scheduler import _LRScheduler
class MultiStepRestartLR(_LRScheduler):
- """ MultiStep with restarts learning rate scheme.
+ """MultiStep with restarts learning rate scheme.
Args:
optimizer (torch.nn.optimizer): Torch optimizer.
@@ -16,7 +17,7 @@ class MultiStepRestartLR(_LRScheduler):
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
- def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0, ), restart_weights=(1, ), last_epoch=-1):
+ def __init__(self, optimizer, milestones, gamma=0.1, restarts=(0,), restart_weights=(1,), last_epoch=-1):
self.milestones = Counter(milestones)
self.gamma = gamma
self.restarts = restarts
@@ -30,7 +31,7 @@ def get_lr(self):
return [group['initial_lr'] * weight for group in self.optimizer.param_groups]
if self.last_epoch not in self.milestones:
return [group['lr'] for group in self.optimizer.param_groups]
- return [group['lr'] * self.gamma**self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
+ return [group['lr'] * self.gamma ** self.milestones[self.last_epoch] for group in self.optimizer.param_groups]
def get_position_from_periods(iteration, cumulative_period):
@@ -55,7 +56,7 @@ def get_position_from_periods(iteration, cumulative_period):
class CosineAnnealingRestartLR(_LRScheduler):
- """ Cosine annealing with restarts learning rate scheme.
+ """Cosine annealing with restarts learning rate scheme.
An example of config:
periods = [10, 10, 10, 10]
@@ -74,13 +75,14 @@ class CosineAnnealingRestartLR(_LRScheduler):
last_epoch (int): Used in _LRScheduler. Default: -1.
"""
- def __init__(self, optimizer, periods, restart_weights=(1, ), eta_min=0, last_epoch=-1):
+ def __init__(self, optimizer, periods, restart_weights=(1,), eta_min=0, last_epoch=-1):
self.periods = periods
self.restart_weights = restart_weights
self.eta_min = eta_min
- assert (len(self.periods) == len(
- self.restart_weights)), 'periods and restart_weights should have the same length.'
- self.cumulative_period = [sum(self.periods[0:i + 1]) for i in range(0, len(self.periods))]
+ assert len(self.periods) == len(self.restart_weights), (
+ 'periods and restart_weights should have the same length.'
+ )
+ self.cumulative_period = [sum(self.periods[0 : i + 1]) for i in range(0, len(self.periods))]
super(CosineAnnealingRestartLR, self).__init__(optimizer, last_epoch)
def get_lr(self):
@@ -90,7 +92,10 @@ def get_lr(self):
current_period = self.periods[idx]
return [
- self.eta_min + current_weight * 0.5 * (base_lr - self.eta_min) *
- (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
+ self.eta_min
+ + current_weight
+ * 0.5
+ * (base_lr - self.eta_min)
+ * (1 + math.cos(math.pi * ((self.last_epoch - nearest_restart) / current_period)))
for base_lr in self.base_lrs
]
diff --git a/basicsr/models/realesrgan_model.py b/basicsr/models/realesrgan_model.py
index c74b28fb1..7e62f900b 100644
--- a/basicsr/models/realesrgan_model.py
+++ b/basicsr/models/realesrgan_model.py
@@ -1,7 +1,8 @@
-import numpy as np
import random
-import torch
from collections import OrderedDict
+
+import numpy as np
+import torch
from torch.nn import functional as F
from basicsr.data.degradations import random_add_gaussian_noise_pt, random_add_poisson_noise_pt
@@ -61,14 +62,13 @@ def _dequeue_and_enqueue(self):
self.gt = gt_dequeue
else:
# only do enqueue
- self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
- self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_lr[self.queue_ptr : self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr : self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def feed_data(self, data):
- """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
- """
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images."""
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
@@ -97,14 +97,12 @@ def feed_data(self, data):
gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
- out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob
+ )
else:
out = random_add_poisson_noise_pt(
- out,
- scale_range=self.opt['poisson_scale_range'],
- gray_prob=gray_noise_prob,
- clip=True,
- rounds=False)
+ out, scale_range=self.opt['poisson_scale_range'], gray_prob=gray_noise_prob, clip=True, rounds=False
+ )
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
@@ -124,19 +122,22 @@ def feed_data(self, data):
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
- out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode
+ )
# add noise
gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
- out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob
+ )
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
- rounds=False)
+ rounds=False,
+ )
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
@@ -165,12 +166,13 @@ def feed_data(self, data):
out = filter2D(out, self.sinc_kernel)
# clamp and round
- self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.0
# random crop
gt_size = self.opt['gt_size']
- (self.gt, self.gt_usm), self.lq = paired_random_crop([self.gt, self.gt_usm], self.lq, gt_size,
- self.opt['scale'])
+ (self.gt, self.gt_usm), self.lq = paired_random_crop(
+ [self.gt, self.gt_usm], self.lq, gt_size, self.opt['scale']
+ )
# training pair pool
self._dequeue_and_enqueue()
@@ -213,7 +215,7 @@ def optimize_parameters(self, current_iter):
l_g_total = 0
loss_dict = OrderedDict()
- if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters:
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, l1_gt)
diff --git a/basicsr/models/realesrnet_model.py b/basicsr/models/realesrnet_model.py
index f5790918b..be1deceaf 100644
--- a/basicsr/models/realesrnet_model.py
+++ b/basicsr/models/realesrnet_model.py
@@ -1,5 +1,6 @@
-import numpy as np
import random
+
+import numpy as np
import torch
from torch.nn import functional as F
@@ -60,14 +61,13 @@ def _dequeue_and_enqueue(self):
self.gt = gt_dequeue
else:
# only do enqueue
- self.queue_lr[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.lq.clone()
- self.queue_gt[self.queue_ptr:self.queue_ptr + b, :, :, :] = self.gt.clone()
+ self.queue_lr[self.queue_ptr : self.queue_ptr + b, :, :, :] = self.lq.clone()
+ self.queue_gt[self.queue_ptr : self.queue_ptr + b, :, :, :] = self.gt.clone()
self.queue_ptr = self.queue_ptr + b
@torch.no_grad()
def feed_data(self, data):
- """Accept data from dataloader, and then add two-order degradations to obtain LQ images.
- """
+ """Accept data from dataloader, and then add two-order degradations to obtain LQ images."""
if self.is_train and self.opt.get('high_order_degradation', True):
# training data synthesis
self.gt = data['gt'].to(self.device)
@@ -98,14 +98,12 @@ def feed_data(self, data):
gray_noise_prob = self.opt['gray_noise_prob']
if np.random.uniform() < self.opt['gaussian_noise_prob']:
out = random_add_gaussian_noise_pt(
- out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ out, sigma_range=self.opt['noise_range'], clip=True, rounds=False, gray_prob=gray_noise_prob
+ )
else:
out = random_add_poisson_noise_pt(
- out,
- scale_range=self.opt['poisson_scale_range'],
- gray_prob=gray_noise_prob,
- clip=True,
- rounds=False)
+ out, scale_range=self.opt['poisson_scale_range'], gray_prob=gray_noise_prob, clip=True, rounds=False
+ )
# JPEG compression
jpeg_p = out.new_zeros(out.size(0)).uniform_(*self.opt['jpeg_range'])
out = torch.clamp(out, 0, 1) # clamp to [0, 1], otherwise JPEGer will result in unpleasant artifacts
@@ -125,19 +123,22 @@ def feed_data(self, data):
scale = 1
mode = random.choice(['area', 'bilinear', 'bicubic'])
out = F.interpolate(
- out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode)
+ out, size=(int(ori_h / self.opt['scale'] * scale), int(ori_w / self.opt['scale'] * scale)), mode=mode
+ )
# add noise
gray_noise_prob = self.opt['gray_noise_prob2']
if np.random.uniform() < self.opt['gaussian_noise_prob2']:
out = random_add_gaussian_noise_pt(
- out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob)
+ out, sigma_range=self.opt['noise_range2'], clip=True, rounds=False, gray_prob=gray_noise_prob
+ )
else:
out = random_add_poisson_noise_pt(
out,
scale_range=self.opt['poisson_scale_range2'],
gray_prob=gray_noise_prob,
clip=True,
- rounds=False)
+ rounds=False,
+ )
# JPEG compression + the final sinc filter
# We also need to resize images to desired sizes. We group [resize back + sinc filter] together
@@ -166,7 +167,7 @@ def feed_data(self, data):
out = filter2D(out, self.sinc_kernel)
# clamp and round
- self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.
+ self.lq = torch.clamp((out * 255.0).round(), 0, 255) / 255.0
# random crop
gt_size = self.opt['gt_size']
diff --git a/basicsr/models/sr_model.py b/basicsr/models/sr_model.py
index 787f1fd2e..394eb5686 100644
--- a/basicsr/models/sr_model.py
+++ b/basicsr/models/sr_model.py
@@ -1,6 +1,7 @@
-import torch
from collections import OrderedDict
from os import path as osp
+
+import torch
from tqdm import tqdm
from basicsr.archs import build_network
@@ -8,6 +9,7 @@
from basicsr.metrics import calculate_metric
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
+
from .base_model import BaseModel
@@ -219,15 +221,20 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
if save_img:
if self.opt['is_train']:
- save_img_path = osp.join(self.opt['path']['visualization'], img_name,
- f'{img_name}_{current_iter}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'], img_name, f'{img_name}_{current_iter}.png'
+ )
else:
if self.opt['val']['suffix']:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'],
+ dataset_name,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png',
+ )
else:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name,
- f'{img_name}_{self.opt["name"]}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'], dataset_name, f'{img_name}_{self.opt["name"]}.png'
+ )
imwrite(sr_img, save_img_path)
if with_metrics:
@@ -242,7 +249,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
if with_metrics:
for metric in self.metric_results.keys():
- self.metric_results[metric] /= (idx + 1)
+ self.metric_results[metric] /= idx + 1
# update the best metric result
self._update_best_metric_result(dataset_name, metric, self.metric_results[metric], current_iter)
@@ -253,8 +260,10 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
for metric, value in self.metric_results.items():
log_str += f'\t # {metric}: {value:.4f}'
if hasattr(self, 'best_metric_results'):
- log_str += (f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
- f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+ log_str += (
+ f'\tBest: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter'
+ )
log_str += '\n'
logger = get_root_logger()
diff --git a/basicsr/models/srgan_model.py b/basicsr/models/srgan_model.py
index 45387ca79..dee207708 100644
--- a/basicsr/models/srgan_model.py
+++ b/basicsr/models/srgan_model.py
@@ -1,10 +1,12 @@
-import torch
from collections import OrderedDict
+import torch
+
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
+
from .sr_model import SRModel
@@ -92,7 +94,7 @@ def optimize_parameters(self, current_iter):
l_g_total = 0
loss_dict = OrderedDict()
- if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters:
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
diff --git a/basicsr/models/stylegan2_model.py b/basicsr/models/stylegan2_model.py
index d7da70812..b2551252e 100644
--- a/basicsr/models/stylegan2_model.py
+++ b/basicsr/models/stylegan2_model.py
@@ -1,16 +1,18 @@
-import cv2
import math
-import numpy as np
import random
-import torch
from collections import OrderedDict
from os import path as osp
+import cv2
+import numpy as np
+import torch
+
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.losses.gan_loss import g_path_regularize, r1_penalty
from basicsr.utils import imwrite, tensor2img
from basicsr.utils.registry import MODEL_REGISTRY
+
from .base_model import BaseModel
@@ -105,25 +107,21 @@ def setup_optimizers(self):
optim_params_g = [
{ # add normal params first
'params': normal_params,
- 'lr': train_opt['optim_g']['lr']
- },
- {
- 'params': style_mlp_params,
- 'lr': train_opt['optim_g']['lr'] * 0.01
+ 'lr': train_opt['optim_g']['lr'],
},
- {
- 'params': modulation_conv_params,
- 'lr': train_opt['optim_g']['lr'] / 3
- }
+ {'params': style_mlp_params, 'lr': train_opt['optim_g']['lr'] * 0.01},
+ {'params': modulation_conv_params, 'lr': train_opt['optim_g']['lr'] / 3},
]
else:
normal_params = []
for name, param in self.net_g.named_parameters():
normal_params.append(param)
- optim_params_g = [{ # add normal params first
- 'params': normal_params,
- 'lr': train_opt['optim_g']['lr']
- }]
+ optim_params_g = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_g']['lr'],
+ }
+ ]
optim_type = train_opt['optim_g'].pop('type')
lr = train_opt['optim_g']['lr'] * net_g_reg_ratio
@@ -144,21 +142,20 @@ def setup_optimizers(self):
optim_params_d = [
{ # add normal params first
'params': normal_params,
- 'lr': train_opt['optim_d']['lr']
+ 'lr': train_opt['optim_d']['lr'],
},
- {
- 'params': linear_params,
- 'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512))
- }
+ {'params': linear_params, 'lr': train_opt['optim_d']['lr'] * (1 / math.sqrt(512))},
]
else:
normal_params = []
for name, param in self.net_d.named_parameters():
normal_params.append(param)
- optim_params_d = [{ # add normal params first
- 'params': normal_params,
- 'lr': train_opt['optim_d']['lr']
- }]
+ optim_params_d = [
+ { # add normal params first
+ 'params': normal_params,
+ 'lr': train_opt['optim_d']['lr'],
+ }
+ ]
optim_type = train_opt['optim_d'].pop('type')
lr = train_opt['optim_d']['lr'] * net_d_reg_ratio
@@ -209,7 +206,7 @@ def optimize_parameters(self, current_iter):
self.real_img.requires_grad = True
real_pred = self.net_d(self.real_img)
l_d_r1 = r1_penalty(real_pred, self.real_img)
- l_d_r1 = (self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0])
+ l_d_r1 = self.r1_reg_weight / 2 * l_d_r1 * self.net_d_reg_every + 0 * real_pred[0]
# TODO: why do we need to add 0 * real_pred, otherwise, a runtime
# error will arise: RuntimeError: Expected to have finished
# reduction in the prior iteration before starting a new one.
@@ -240,7 +237,7 @@ def optimize_parameters(self, current_iter):
fake_img, latents = self.net_g(noise, return_latents=True)
l_g_path, path_lengths, self.mean_path_length = g_path_regularize(fake_img, latents, self.mean_path_length)
- l_g_path = (self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0])
+ l_g_path = self.path_reg_weight * self.net_g_reg_every * l_g_path + 0 * fake_img[0, 0, 0, 0]
# TODO: why do we need to add 0 * fake_img[0, 0, 0, 0]
l_g_path.backward()
loss_dict['l_g_path'] = l_g_path.detach().mean()
@@ -251,7 +248,7 @@ def optimize_parameters(self, current_iter):
self.log_dict = self.reduce_loss_dict(loss_dict)
# EMA
- self.model_ema(decay=0.5**(32 / (10 * 1000)))
+ self.model_ema(decay=0.5 ** (32 / (10 * 1000)))
def test(self):
with torch.no_grad():
@@ -272,7 +269,7 @@ def nondist_validation(self, dataloader, current_iter, tb_logger, save_img):
save_img_path = osp.join(self.opt['path']['visualization'], 'test', f'test_{self.opt["name"]}.png')
imwrite(result, save_img_path)
# add sample images to tb_logger
- result = (result / 255.).astype(np.float32)
+ result = (result / 255.0).astype(np.float32)
result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
if tb_logger is not None:
tb_logger.add_image('samples', result, global_step=current_iter, dataformats='HWC')
diff --git a/basicsr/models/swinir_model.py b/basicsr/models/swinir_model.py
index 5ac182f23..688ad9822 100644
--- a/basicsr/models/swinir_model.py
+++ b/basicsr/models/swinir_model.py
@@ -2,12 +2,12 @@
from torch.nn import functional as F
from basicsr.utils.registry import MODEL_REGISTRY
+
from .sr_model import SRModel
@MODEL_REGISTRY.register()
class SwinIRModel(SRModel):
-
def test(self):
# pad to multiplication of window_size
window_size = self.opt['network_g']['window_size']
@@ -30,4 +30,4 @@ def test(self):
self.net_g.train()
_, _, h, w = self.output.size()
- self.output = self.output[:, :, 0:h - mod_pad_h * scale, 0:w - mod_pad_w * scale]
+ self.output = self.output[:, :, 0 : h - mod_pad_h * scale, 0 : w - mod_pad_w * scale]
diff --git a/basicsr/models/video_base_model.py b/basicsr/models/video_base_model.py
index 9f7993a15..72fb5e31e 100644
--- a/basicsr/models/video_base_model.py
+++ b/basicsr/models/video_base_model.py
@@ -1,6 +1,7 @@
-import torch
from collections import Counter
from os import path as osp
+
+import torch
from torch import distributed as dist
from tqdm import tqdm
@@ -8,6 +9,7 @@
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
+
from .sr_model import SRModel
@@ -30,7 +32,8 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
- num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda'
+ )
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
@@ -77,11 +80,19 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
img_name = osp.splitext(osp.basename(lq_path))[0]
if self.opt['val']['suffix']:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
- f'{img_name}_{self.opt["val"]["suffix"]}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'],
+ dataset_name,
+ folder,
+ f'{img_name}_{self.opt["val"]["suffix"]}.png',
+ )
else:
- save_img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
- f'{img_name}_{self.opt["name"]}.png')
+ save_img_path = osp.join(
+ self.opt['path']['visualization'],
+ dataset_name,
+ folder,
+ f'{img_name}_{self.opt["name"]}.png',
+ )
imwrite(result_img, save_img_path)
if with_metrics:
@@ -123,8 +134,7 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
# 'folder2': tensor (len(metrics))
# }
metric_results_avg = {
- folder: torch.mean(tensor, dim=0).cpu()
- for (folder, tensor) in self.metric_results.items()
+ folder: torch.mean(tensor, dim=0).cpu() for (folder, tensor) in self.metric_results.items()
}
# total_avg_results is a dict: {
# 'metric1': float,
@@ -147,8 +157,10 @@ def _log_validation_metric_values(self, current_iter, dataset_name, tb_logger):
for folder, tensor in metric_results_avg.items():
log_str += f'\t # {folder}: {tensor[metric_idx].item():.4f}'
if hasattr(self, 'best_metric_results'):
- log_str += (f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
- f'{self.best_metric_results[dataset_name][metric]["iter"]} iter')
+ log_str += (
+ f'\n\t Best: {self.best_metric_results[dataset_name][metric]["val"]:.4f} @ '
+ f'{self.best_metric_results[dataset_name][metric]["iter"]} iter'
+ )
log_str += '\n'
logger = get_root_logger()
diff --git a/basicsr/models/video_gan_model.py b/basicsr/models/video_gan_model.py
index a2adcdeee..e12aef207 100644
--- a/basicsr/models/video_gan_model.py
+++ b/basicsr/models/video_gan_model.py
@@ -1,4 +1,5 @@
from basicsr.utils.registry import MODEL_REGISTRY
+
from .srgan_model import SRGANModel
from .video_base_model import VideoBaseModel
diff --git a/basicsr/models/video_recurrent_gan_model.py b/basicsr/models/video_recurrent_gan_model.py
index 74cf81145..5e1f6ef3d 100644
--- a/basicsr/models/video_recurrent_gan_model.py
+++ b/basicsr/models/video_recurrent_gan_model.py
@@ -1,16 +1,17 @@
-import torch
from collections import OrderedDict
+import torch
+
from basicsr.archs import build_network
from basicsr.losses import build_loss
from basicsr.utils import get_root_logger
from basicsr.utils.registry import MODEL_REGISTRY
+
from .video_recurrent_model import VideoRecurrentModel
@MODEL_REGISTRY.register()
class VideoRecurrentGANModel(VideoRecurrentModel):
-
def init_training_settings(self):
train_opt = self.opt['train']
@@ -79,12 +80,9 @@ def setup_optimizers(self):
optim_params = [
{ # add flow params first
'params': flow_params,
- 'lr': train_opt['lr_flow']
- },
- {
- 'params': normal_params,
- 'lr': train_opt['optim_g']['lr']
+ 'lr': train_opt['lr_flow'],
},
+ {'params': normal_params, 'lr': train_opt['optim_g']['lr']},
]
else:
optim_params = self.net_g.parameters()
@@ -121,7 +119,7 @@ def optimize_parameters(self, current_iter):
l_g_total = 0
loss_dict = OrderedDict()
- if (current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters):
+ if current_iter % self.net_d_iters == 0 and current_iter > self.net_d_init_iters:
# pixel loss
if self.cri_pix:
l_g_pix = self.cri_pix(self.output, self.gt)
diff --git a/basicsr/models/video_recurrent_model.py b/basicsr/models/video_recurrent_model.py
index 796ee57d5..d6634132d 100644
--- a/basicsr/models/video_recurrent_model.py
+++ b/basicsr/models/video_recurrent_model.py
@@ -1,6 +1,7 @@
-import torch
from collections import Counter
from os import path as osp
+
+import torch
from torch import distributed as dist
from tqdm import tqdm
@@ -8,12 +9,12 @@
from basicsr.utils import get_root_logger, imwrite, tensor2img
from basicsr.utils.dist_util import get_dist_info
from basicsr.utils.registry import MODEL_REGISTRY
+
from .video_base_model import VideoBaseModel
@MODEL_REGISTRY.register()
class VideoRecurrentModel(VideoBaseModel):
-
def __init__(self, opt):
super(VideoRecurrentModel, self).__init__(opt)
if self.is_train:
@@ -37,12 +38,9 @@ def setup_optimizers(self):
optim_params = [
{ # add normal params first
'params': normal_params,
- 'lr': train_opt['optim_g']['lr']
- },
- {
- 'params': flow_params,
- 'lr': train_opt['optim_g']['lr'] * flow_lr_mul
+ 'lr': train_opt['optim_g']['lr'],
},
+ {'params': flow_params, 'lr': train_opt['optim_g']['lr'] * flow_lr_mul},
]
optim_type = train_opt['optim_g'].pop('type')
@@ -78,7 +76,8 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
num_frame_each_folder = Counter(dataset.data_info['folder'])
for folder, num_frame in num_frame_each_folder.items():
self.metric_results[folder] = torch.zeros(
- num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda')
+ num_frame, len(self.opt['val']['metrics']), dtype=torch.float32, device='cuda'
+ )
# initialize the best metric results
self._initialize_best_metric_results(dataset_name)
# zero self.metric_results
@@ -140,11 +139,19 @@ def dist_validation(self, dataloader, current_iter, tb_logger, save_img):
clip_ = val_data['lq_path'].split('/')[-3]
seq_ = val_data['lq_path'].split('/')[-2]
name_ = f'{clip_}_{seq_}'
- img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
- f"{name_}_{self.opt['name']}.png")
+ img_path = osp.join(
+ self.opt['path']['visualization'],
+ dataset_name,
+ folder,
+ f'{name_}_{self.opt["name"]}.png',
+ )
else: # others
- img_path = osp.join(self.opt['path']['visualization'], dataset_name, folder,
- f"{idx:08d}_{self.opt['name']}.png")
+ img_path = osp.join(
+ self.opt['path']['visualization'],
+ dataset_name,
+ folder,
+ f'{idx:08d}_{self.opt["name"]}.png',
+ )
# image name only for REDS dataset
imwrite(result_img, img_path)
diff --git a/basicsr/ops/dcn/__init__.py b/basicsr/ops/dcn/__init__.py
index 32e3592f8..dee92b694 100644
--- a/basicsr/ops/dcn/__init__.py
+++ b/basicsr/ops/dcn/__init__.py
@@ -1,7 +1,17 @@
-from .deform_conv import (DeformConv, DeformConvPack, ModulatedDeformConv, ModulatedDeformConvPack, deform_conv,
- modulated_deform_conv)
+from .deform_conv import (
+ DeformConv,
+ DeformConvPack,
+ ModulatedDeformConv,
+ ModulatedDeformConvPack,
+ deform_conv,
+ modulated_deform_conv,
+)
__all__ = [
- 'DeformConv', 'DeformConvPack', 'ModulatedDeformConv', 'ModulatedDeformConvPack', 'deform_conv',
- 'modulated_deform_conv'
+ 'DeformConv',
+ 'DeformConvPack',
+ 'ModulatedDeformConv',
+ 'ModulatedDeformConvPack',
+ 'deform_conv',
+ 'modulated_deform_conv',
]
diff --git a/basicsr/ops/dcn/deform_conv.py b/basicsr/ops/dcn/deform_conv.py
index 6268ca825..1cf71abe3 100644
--- a/basicsr/ops/dcn/deform_conv.py
+++ b/basicsr/ops/dcn/deform_conv.py
@@ -1,5 +1,6 @@
import math
import os
+
import torch
from torch import nn as nn
from torch.autograd import Function
@@ -10,6 +11,7 @@
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
+
module_path = os.path.dirname(__file__)
deform_conv_ext = load(
'deform_conv',
@@ -31,18 +33,10 @@
class DeformConvFunction(Function):
-
@staticmethod
- def forward(ctx,
- input,
- offset,
- weight,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- deformable_groups=1,
- im2col_step=64):
+ def forward(
+ ctx, input, offset, weight, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1, im2col_step=64
+ ):
if input is not None and input.dim() != 4:
raise ValueError(f'Expected 4D tensor as input, got {input.dim()}D tensor instead.')
ctx.stride = _pair(stride)
@@ -63,11 +57,25 @@ def forward(ctx,
else:
cur_im2col_step = min(ctx.im2col_step, input.shape[0])
assert (input.shape[0] % cur_im2col_step) == 0, 'im2col step must divide batchsize'
- deform_conv_ext.deform_conv_forward(input, weight,
- offset, output, ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
- weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
- ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
- ctx.deformable_groups, cur_im2col_step)
+ deform_conv_ext.deform_conv_forward(
+ input,
+ weight,
+ offset,
+ output,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ weight.size(3),
+ weight.size(2),
+ ctx.stride[1],
+ ctx.stride[0],
+ ctx.padding[1],
+ ctx.padding[0],
+ ctx.dilation[1],
+ ctx.dilation[0],
+ ctx.groups,
+ ctx.deformable_groups,
+ cur_im2col_step,
+ )
return output
@staticmethod
@@ -86,20 +94,49 @@ def backward(ctx, grad_output):
if ctx.needs_input_grad[0] or ctx.needs_input_grad[1]:
grad_input = torch.zeros_like(input)
grad_offset = torch.zeros_like(offset)
- deform_conv_ext.deform_conv_backward_input(input, offset, grad_output, grad_input,
- grad_offset, weight, ctx.bufs_[0], weight.size(3),
- weight.size(2), ctx.stride[1], ctx.stride[0], ctx.padding[1],
- ctx.padding[0], ctx.dilation[1], ctx.dilation[0], ctx.groups,
- ctx.deformable_groups, cur_im2col_step)
+ deform_conv_ext.deform_conv_backward_input(
+ input,
+ offset,
+ grad_output,
+ grad_input,
+ grad_offset,
+ weight,
+ ctx.bufs_[0],
+ weight.size(3),
+ weight.size(2),
+ ctx.stride[1],
+ ctx.stride[0],
+ ctx.padding[1],
+ ctx.padding[0],
+ ctx.dilation[1],
+ ctx.dilation[0],
+ ctx.groups,
+ ctx.deformable_groups,
+ cur_im2col_step,
+ )
if ctx.needs_input_grad[2]:
grad_weight = torch.zeros_like(weight)
- deform_conv_ext.deform_conv_backward_parameters(input, offset, grad_output, grad_weight,
- ctx.bufs_[0], ctx.bufs_[1], weight.size(3),
- weight.size(2), ctx.stride[1], ctx.stride[0],
- ctx.padding[1], ctx.padding[0], ctx.dilation[1],
- ctx.dilation[0], ctx.groups, ctx.deformable_groups, 1,
- cur_im2col_step)
+ deform_conv_ext.deform_conv_backward_parameters(
+ input,
+ offset,
+ grad_output,
+ grad_weight,
+ ctx.bufs_[0],
+ ctx.bufs_[1],
+ weight.size(3),
+ weight.size(2),
+ ctx.stride[1],
+ ctx.stride[0],
+ ctx.padding[1],
+ ctx.padding[0],
+ ctx.dilation[1],
+ ctx.dilation[0],
+ ctx.groups,
+ ctx.deformable_groups,
+ 1,
+ cur_im2col_step,
+ )
return (grad_input, grad_offset, grad_weight, None, None, None, None, None)
@@ -112,26 +149,17 @@ def _output_size(input, weight, padding, dilation, stride):
pad = padding[d]
kernel = dilation[d] * (weight.size(d + 2) - 1) + 1
stride_ = stride[d]
- output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1, )
+ output_size += ((in_size + (2 * pad) - kernel) // stride_ + 1,)
if not all(map(lambda s: s > 0, output_size)):
raise ValueError(f'convolution input is too small (output would be {"x".join(map(str, output_size))})')
return output_size
class ModulatedDeformConvFunction(Function):
-
@staticmethod
- def forward(ctx,
- input,
- offset,
- mask,
- weight,
- bias=None,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- deformable_groups=1):
+ def forward(
+ ctx, input, offset, mask, weight, bias=None, stride=1, padding=0, dilation=1, groups=1, deformable_groups=1
+ ):
ctx.stride = stride
ctx.padding = padding
ctx.dilation = dilation
@@ -146,10 +174,27 @@ def forward(ctx,
ctx.save_for_backward(input, offset, mask, weight, bias)
output = input.new_empty(ModulatedDeformConvFunction._infer_shape(ctx, input, weight))
ctx._bufs = [input.new_empty(0), input.new_empty(0)]
- deform_conv_ext.modulated_deform_conv_forward(input, weight, bias, ctx._bufs[0], offset, mask, output,
- ctx._bufs[1], weight.shape[2], weight.shape[3], ctx.stride,
- ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
- ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ deform_conv_ext.modulated_deform_conv_forward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ output,
+ ctx._bufs[1],
+ weight.shape[2],
+ weight.shape[3],
+ ctx.stride,
+ ctx.stride,
+ ctx.padding,
+ ctx.padding,
+ ctx.dilation,
+ ctx.dilation,
+ ctx.groups,
+ ctx.deformable_groups,
+ ctx.with_bias,
+ )
return output
@staticmethod
@@ -163,11 +208,32 @@ def backward(ctx, grad_output):
grad_mask = torch.zeros_like(mask)
grad_weight = torch.zeros_like(weight)
grad_bias = torch.zeros_like(bias)
- deform_conv_ext.modulated_deform_conv_backward(input, weight, bias, ctx._bufs[0], offset, mask, ctx._bufs[1],
- grad_input, grad_weight, grad_bias, grad_offset, grad_mask,
- grad_output, weight.shape[2], weight.shape[3], ctx.stride,
- ctx.stride, ctx.padding, ctx.padding, ctx.dilation, ctx.dilation,
- ctx.groups, ctx.deformable_groups, ctx.with_bias)
+ deform_conv_ext.modulated_deform_conv_backward(
+ input,
+ weight,
+ bias,
+ ctx._bufs[0],
+ offset,
+ mask,
+ ctx._bufs[1],
+ grad_input,
+ grad_weight,
+ grad_bias,
+ grad_offset,
+ grad_mask,
+ grad_output,
+ weight.shape[2],
+ weight.shape[3],
+ ctx.stride,
+ ctx.stride,
+ ctx.padding,
+ ctx.padding,
+ ctx.dilation,
+ ctx.dilation,
+ ctx.groups,
+ ctx.deformable_groups,
+ ctx.with_bias,
+ )
if not ctx.with_bias:
grad_bias = None
@@ -189,17 +255,18 @@ def _infer_shape(ctx, input, weight):
class DeformConv(nn.Module):
-
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- deformable_groups=1,
- bias=False):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=False,
+ ):
super(DeformConv, self).__init__()
assert not bias
@@ -226,22 +293,23 @@ def reset_parameters(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
- stdv = 1. / math.sqrt(n)
+ stdv = 1.0 / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
def forward(self, x, offset):
# To fix an assert error in deform_conv_cuda.cpp:128
# input image is smaller than kernel
- input_pad = (x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1])
+ input_pad = x.size(2) < self.kernel_size[0] or x.size(3) < self.kernel_size[1]
if input_pad:
pad_h = max(self.kernel_size[0] - x.size(2), 0)
pad_w = max(self.kernel_size[1] - x.size(3), 0)
x = F.pad(x, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
offset = F.pad(offset, (0, pad_w, 0, pad_h), 'constant', 0).contiguous()
- out = deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
- self.deformable_groups)
+ out = deform_conv(
+ x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups
+ )
if input_pad:
- out = out[:, :, :out.size(2) - pad_h, :out.size(3) - pad_w].contiguous()
+ out = out[:, :, : out.size(2) - pad_h, : out.size(3) - pad_w].contiguous()
return out
@@ -273,7 +341,8 @@ def __init__(self, *args, **kwargs):
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
- bias=True)
+ bias=True,
+ )
self.init_offset()
def init_offset(self):
@@ -282,22 +351,24 @@ def init_offset(self):
def forward(self, x):
offset = self.conv_offset(x)
- return deform_conv(x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups,
- self.deformable_groups)
+ return deform_conv(
+ x, offset, self.weight, self.stride, self.padding, self.dilation, self.groups, self.deformable_groups
+ )
class ModulatedDeformConv(nn.Module):
-
- def __init__(self,
- in_channels,
- out_channels,
- kernel_size,
- stride=1,
- padding=0,
- dilation=1,
- groups=1,
- deformable_groups=1,
- bias=True):
+ def __init__(
+ self,
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ dilation=1,
+ groups=1,
+ deformable_groups=1,
+ bias=True,
+ ):
super(ModulatedDeformConv, self).__init__()
self.in_channels = in_channels
self.out_channels = out_channels
@@ -323,14 +394,24 @@ def init_weights(self):
n = self.in_channels
for k in self.kernel_size:
n *= k
- stdv = 1. / math.sqrt(n)
+ stdv = 1.0 / math.sqrt(n)
self.weight.data.uniform_(-stdv, stdv)
if self.bias is not None:
self.bias.data.zero_()
def forward(self, x, offset, mask):
- return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
- self.groups, self.deformable_groups)
+ return modulated_deform_conv(
+ x,
+ offset,
+ mask,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
class ModulatedDeformConvPack(ModulatedDeformConv):
@@ -361,7 +442,8 @@ def __init__(self, *args, **kwargs):
stride=_pair(self.stride),
padding=_pair(self.padding),
dilation=_pair(self.dilation),
- bias=True)
+ bias=True,
+ )
self.init_weights()
def init_weights(self):
@@ -375,5 +457,15 @@ def forward(self, x):
o1, o2, mask = torch.chunk(out, 3, dim=1)
offset = torch.cat((o1, o2), dim=1)
mask = torch.sigmoid(mask)
- return modulated_deform_conv(x, offset, mask, self.weight, self.bias, self.stride, self.padding, self.dilation,
- self.groups, self.deformable_groups)
+ return modulated_deform_conv(
+ x,
+ offset,
+ mask,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ self.deformable_groups,
+ )
diff --git a/basicsr/ops/fused_act/fused_act.py b/basicsr/ops/fused_act/fused_act.py
index 88edc4454..77b62982b 100644
--- a/basicsr/ops/fused_act/fused_act.py
+++ b/basicsr/ops/fused_act/fused_act.py
@@ -1,6 +1,7 @@
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/fused_act.py # noqa:E501
import os
+
import torch
from torch import nn
from torch.autograd import Function
@@ -8,6 +9,7 @@
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
+
module_path = os.path.dirname(__file__)
fused_act_ext = load(
'fused',
@@ -28,7 +30,6 @@
class FusedLeakyReLUFunctionBackward(Function):
-
@staticmethod
def forward(ctx, grad_output, out, negative_slope, scale):
ctx.save_for_backward(out)
@@ -50,15 +51,15 @@ def forward(ctx, grad_output, out, negative_slope, scale):
@staticmethod
def backward(ctx, gradgrad_input, gradgrad_bias):
- out, = ctx.saved_tensors
- gradgrad_out = fused_act_ext.fused_bias_act(gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope,
- ctx.scale)
+ (out,) = ctx.saved_tensors
+ gradgrad_out = fused_act_ext.fused_bias_act(
+ gradgrad_input, gradgrad_bias, out, 3, 1, ctx.negative_slope, ctx.scale
+ )
return gradgrad_out, None, None, None
class FusedLeakyReLUFunction(Function):
-
@staticmethod
def forward(ctx, input, bias, negative_slope, scale):
empty = input.new_empty(0)
@@ -71,7 +72,7 @@ def forward(ctx, input, bias, negative_slope, scale):
@staticmethod
def backward(ctx, grad_output):
- out, = ctx.saved_tensors
+ (out,) = ctx.saved_tensors
grad_input, grad_bias = FusedLeakyReLUFunctionBackward.apply(grad_output, out, ctx.negative_slope, ctx.scale)
@@ -79,7 +80,6 @@ def backward(ctx, grad_output):
class FusedLeakyReLU(nn.Module):
-
def __init__(self, channel, negative_slope=0.2, scale=2**0.5):
super().__init__()
diff --git a/basicsr/ops/upfirdn2d/upfirdn2d.py b/basicsr/ops/upfirdn2d/upfirdn2d.py
index d6122d59a..bf955add7 100644
--- a/basicsr/ops/upfirdn2d/upfirdn2d.py
+++ b/basicsr/ops/upfirdn2d/upfirdn2d.py
@@ -1,6 +1,7 @@
# modify from https://github.com/rosinality/stylegan2-pytorch/blob/master/op/upfirdn2d.py # noqa:E501
import os
+
import torch
from torch.autograd import Function
from torch.nn import functional as F
@@ -8,6 +9,7 @@
BASICSR_JIT = os.getenv('BASICSR_JIT')
if BASICSR_JIT == 'True':
from torch.utils.cpp_extension import load
+
module_path = os.path.dirname(__file__)
upfirdn2d_ext = load(
'upfirdn2d',
@@ -28,7 +30,6 @@
class UpFirDn2dBackward(Function):
-
@staticmethod
def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size, out_size):
@@ -71,7 +72,7 @@ def forward(ctx, grad_output, kernel, grad_kernel, up, down, pad, g_pad, in_size
@staticmethod
def backward(ctx, gradgrad_input):
- kernel, = ctx.saved_tensors
+ (kernel,) = ctx.saved_tensors
gradgrad_input = gradgrad_input.reshape(-1, ctx.in_size[2], ctx.in_size[3], 1)
@@ -95,7 +96,6 @@ def backward(ctx, gradgrad_input):
class UpFirDn2d(Function):
-
@staticmethod
def forward(ctx, input, kernel, up, down, pad):
up_x, up_y = up
@@ -171,7 +171,12 @@ def upfirdn2d_native(input, kernel, up_x, up_y, down_x, down_y, pad_x0, pad_x1,
out = out.view(-1, in_h * up_y, in_w * up_x, minor)
out = F.pad(out, [0, 0, max(pad_x0, 0), max(pad_x1, 0), max(pad_y0, 0), max(pad_y1, 0)])
- out = out[:, max(-pad_y0, 0):out.shape[1] - max(-pad_y1, 0), max(-pad_x0, 0):out.shape[2] - max(-pad_x1, 0), :, ]
+ out = out[
+ :,
+ max(-pad_y0, 0) : out.shape[1] - max(-pad_y1, 0),
+ max(-pad_x0, 0) : out.shape[2] - max(-pad_x1, 0),
+ :,
+ ]
out = out.permute(0, 3, 1, 2)
out = out.reshape([-1, 1, in_h * up_y + pad_y0 + pad_y1, in_w * up_x + pad_x0 + pad_x1])
diff --git a/basicsr/test.py b/basicsr/test.py
index 53cb3b7aa..24e2c0ce2 100644
--- a/basicsr/test.py
+++ b/basicsr/test.py
@@ -1,7 +1,8 @@
import logging
-import torch
from os import path as osp
+import torch
+
from basicsr.data import build_dataloader, build_dataset
from basicsr.models import build_model
from basicsr.utils import get_env_info, get_root_logger, get_time_str, make_exp_dirs
@@ -17,7 +18,7 @@ def test_pipeline(root_path):
# mkdir and initialize loggers
make_exp_dirs(opt)
- log_file = osp.join(opt['path']['log'], f"test_{opt['name']}_{get_time_str()}.log")
+ log_file = osp.join(opt['path']['log'], f'test_{opt["name"]}_{get_time_str()}.log')
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
logger.info(get_env_info())
logger.info(dict2str(opt))
@@ -27,8 +28,9 @@ def test_pipeline(root_path):
for _, dataset_opt in sorted(opt['datasets'].items()):
test_set = build_dataset(dataset_opt)
test_loader = build_dataloader(
- test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
- logger.info(f"Number of test images in {dataset_opt['name']}: {len(test_set)}")
+ test_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']
+ )
+ logger.info(f'Number of test images in {dataset_opt["name"]}: {len(test_set)}')
test_loaders.append(test_loader)
# create model
diff --git a/basicsr/train.py b/basicsr/train.py
index e02d98fe0..114516bb2 100644
--- a/basicsr/train.py
+++ b/basicsr/train.py
@@ -2,23 +2,38 @@
import logging
import math
import time
-import torch
from os import path as osp
+import torch
+
from basicsr.data import build_dataloader, build_dataset
from basicsr.data.data_sampler import EnlargedSampler
from basicsr.data.prefetch_dataloader import CPUPrefetcher, CUDAPrefetcher
from basicsr.models import build_model
-from basicsr.utils import (AvgTimer, MessageLogger, check_resume, get_env_info, get_root_logger, get_time_str,
- init_tb_logger, init_wandb_logger, make_exp_dirs, mkdir_and_rename, scandir)
+from basicsr.utils import (
+ AvgTimer,
+ MessageLogger,
+ check_resume,
+ get_env_info,
+ get_root_logger,
+ get_time_str,
+ init_tb_logger,
+ init_wandb_logger,
+ make_exp_dirs,
+ mkdir_and_rename,
+ scandir,
+)
from basicsr.utils.options import copy_opt_file, dict2str, parse_options
def init_tb_loggers(opt):
# initialize wandb logger before tensorboard logger to allow proper sync
- if (opt['logger'].get('wandb') is not None) and (opt['logger']['wandb'].get('project')
- is not None) and ('debug' not in opt['name']):
- assert opt['logger'].get('use_tb_logger') is True, ('should turn on tensorboard when using wandb')
+ if (
+ (opt['logger'].get('wandb') is not None)
+ and (opt['logger']['wandb'].get('project') is not None)
+ and ('debug' not in opt['name'])
+ ):
+ assert opt['logger'].get('use_tb_logger') is True, 'should turn on tensorboard when using wandb'
init_wandb_logger(opt)
tb_logger = None
if opt['logger'].get('use_tb_logger') and 'debug' not in opt['name']:
@@ -40,23 +55,28 @@ def create_train_val_dataloader(opt, logger):
num_gpu=opt['num_gpu'],
dist=opt['dist'],
sampler=train_sampler,
- seed=opt['manual_seed'])
+ seed=opt['manual_seed'],
+ )
num_iter_per_epoch = math.ceil(
- len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size']))
+ len(train_set) * dataset_enlarge_ratio / (dataset_opt['batch_size_per_gpu'] * opt['world_size'])
+ )
total_iters = int(opt['train']['total_iter'])
total_epochs = math.ceil(total_iters / (num_iter_per_epoch))
- logger.info('Training statistics:'
- f'\n\tNumber of train images: {len(train_set)}'
- f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
- f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
- f'\n\tWorld size (gpu number): {opt["world_size"]}'
- f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
- f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.')
+ logger.info(
+ 'Training statistics:'
+ f'\n\tNumber of train images: {len(train_set)}'
+ f'\n\tDataset enlarge ratio: {dataset_enlarge_ratio}'
+ f'\n\tBatch size per gpu: {dataset_opt["batch_size_per_gpu"]}'
+ f'\n\tWorld size (gpu number): {opt["world_size"]}'
+ f'\n\tRequire iter number per epoch: {num_iter_per_epoch}'
+ f'\n\tTotal epochs: {total_epochs}; iters: {total_iters}.'
+ )
elif phase.split('_')[0] == 'val':
val_set = build_dataset(dataset_opt)
val_loader = build_dataloader(
- val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed'])
+ val_set, dataset_opt, num_gpu=opt['num_gpu'], dist=opt['dist'], sampler=None, seed=opt['manual_seed']
+ )
logger.info(f'Number of val images/folders in {dataset_opt["name"]}: {len(val_set)}')
val_loaders.append(val_loader)
else:
@@ -109,7 +129,7 @@ def train_pipeline(root_path):
# WARNING: should not use get_root_logger in the above codes, including the called functions
# Otherwise the logger will not be properly initialized
- log_file = osp.join(opt['path']['log'], f"train_{opt['name']}_{get_time_str()}.log")
+ log_file = osp.join(opt['path']['log'], f'train_{opt["name"]}_{get_time_str()}.log')
logger = get_root_logger(logger_name='basicsr', log_level=logging.INFO, log_file=log_file)
logger.info(get_env_info())
logger.info(dict2str(opt))
@@ -124,7 +144,7 @@ def train_pipeline(root_path):
model = build_model(opt)
if resume_state: # resume training
model.resume_training(resume_state) # handle optimizers and schedulers
- logger.info(f"Resuming training from epoch: {resume_state['epoch']}, iter: {resume_state['iter']}.")
+ logger.info(f'Resuming training from epoch: {resume_state["epoch"]}, iter: {resume_state["iter"]}.')
start_epoch = resume_state['epoch']
current_iter = resume_state['iter']
else:
diff --git a/basicsr/utils/__init__.py b/basicsr/utils/__init__.py
index 9569c5078..077ed14c2 100644
--- a/basicsr/utils/__init__.py
+++ b/basicsr/utils/__init__.py
@@ -43,5 +43,5 @@
'USMSharp',
'usm_sharp',
# options
- 'yaml_load'
+ 'yaml_load',
]
diff --git a/basicsr/utils/color_util.py b/basicsr/utils/color_util.py
index 4740d5c98..445c85023 100644
--- a/basicsr/utils/color_util.py
+++ b/basicsr/utils/color_util.py
@@ -29,8 +29,11 @@ def rgb2ycbcr(img, y_only=False):
if y_only:
out_img = np.dot(img, [65.481, 128.553, 24.966]) + 16.0
else:
- out_img = np.matmul(
- img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [16, 128, 128]
+ out_img = np.matmul(img, [[65.481, -37.797, 112.0], [128.553, -74.203, -93.786], [24.966, 112.0, -18.214]]) + [
+ 16,
+ 128,
+ 128,
+ ]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
@@ -62,8 +65,11 @@ def bgr2ycbcr(img, y_only=False):
if y_only:
out_img = np.dot(img, [24.966, 128.553, 65.481]) + 16.0
else:
- out_img = np.matmul(
- img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [16, 128, 128]
+ out_img = np.matmul(img, [[24.966, 112.0, -18.214], [128.553, -74.203, -93.786], [65.481, -37.797, 112.0]]) + [
+ 16,
+ 128,
+ 128,
+ ]
out_img = _convert_output_type_range(out_img, img_type)
return out_img
@@ -91,8 +97,9 @@ def ycbcr2rgb(img):
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
- out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071],
- [0.00625893, -0.00318811, 0]]) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
+ out_img = np.matmul(
+ img, [[0.00456621, 0.00456621, 0.00456621], [0, -0.00153632, 0.00791071], [0.00625893, -0.00318811, 0]]
+ ) * 255.0 + [-222.921, 135.576, -276.836] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
@@ -120,8 +127,9 @@ def ycbcr2bgr(img):
"""
img_type = img.dtype
img = _convert_input_type_range(img) * 255
- out_img = np.matmul(img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0],
- [0, -0.00318811, 0.00625893]]) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
+ out_img = np.matmul(
+ img, [[0.00456621, 0.00456621, 0.00456621], [0.00791071, -0.00153632, 0], [0, -0.00318811, 0.00625893]]
+ ) * 255.0 + [-276.836, 135.576, -222.921] # noqa: E126
out_img = _convert_output_type_range(out_img, img_type)
return out_img
@@ -147,7 +155,7 @@ def _convert_input_type_range(img):
if img_type == np.float32:
pass
elif img_type == np.uint8:
- img /= 255.
+ img /= 255.0
else:
raise TypeError(f'The img type should be np.float32 or np.uint8, but got {img_type}')
return img
@@ -179,7 +187,7 @@ def _convert_output_type_range(img, dst_type):
if dst_type == np.uint8:
img = img.round()
else:
- img /= 255.
+ img /= 255.0
return img.astype(dst_type)
@@ -204,5 +212,5 @@ def rgb2ycbcr_pt(img, y_only=False):
bias = torch.tensor([16, 128, 128]).view(1, 3, 1, 1).to(img)
out_img = torch.matmul(img.permute(0, 2, 3, 1), weight).permute(0, 3, 1, 2) + bias
- out_img = out_img / 255.
+ out_img = out_img / 255.0
return out_img
diff --git a/basicsr/utils/diffjpeg.py b/basicsr/utils/diffjpeg.py
index 65f96b44f..00fd9cdb9 100644
--- a/basicsr/utils/diffjpeg.py
+++ b/basicsr/utils/diffjpeg.py
@@ -4,7 +4,9 @@
For images not divisible by 8
https://dsp.stackexchange.com/questions/35339/jpeg-dct-padding/35343#35343
"""
+
import itertools
+
import numpy as np
import torch
import torch.nn as nn
@@ -12,10 +14,18 @@
# ------------------------ utils ------------------------#
y_table = np.array(
- [[16, 11, 10, 16, 24, 40, 51, 61], [12, 12, 14, 19, 26, 58, 60, 55], [14, 13, 16, 24, 40, 57, 69, 56],
- [14, 17, 22, 29, 51, 87, 80, 62], [18, 22, 37, 56, 68, 109, 103, 77], [24, 35, 55, 64, 81, 104, 113, 92],
- [49, 64, 78, 87, 103, 121, 120, 101], [72, 92, 95, 98, 112, 100, 103, 99]],
- dtype=np.float32).T
+ [
+ [16, 11, 10, 16, 24, 40, 51, 61],
+ [12, 12, 14, 19, 26, 58, 60, 55],
+ [14, 13, 16, 24, 40, 57, 69, 56],
+ [14, 17, 22, 29, 51, 87, 80, 62],
+ [18, 22, 37, 56, 68, 109, 103, 77],
+ [24, 35, 55, 64, 81, 104, 113, 92],
+ [49, 64, 78, 87, 103, 121, 120, 101],
+ [72, 92, 95, 98, 112, 100, 103, 99],
+ ],
+ dtype=np.float32,
+).T
y_table = nn.Parameter(torch.from_numpy(y_table))
c_table = np.empty((8, 8), dtype=np.float32)
c_table.fill(99)
@@ -24,13 +34,12 @@
def diff_round(x):
- """ Differentiable rounding function
- """
- return torch.round(x) + (x - torch.round(x))**3
+ """Differentiable rounding function"""
+ return torch.round(x) + (x - torch.round(x)) ** 3
def quality_to_factor(quality):
- """ Calculate factor corresponding to quality
+ """Calculate factor corresponding to quality
Args:
quality(float): Quality for jpeg compression.
@@ -39,22 +48,22 @@ def quality_to_factor(quality):
float: Compression factor.
"""
if quality < 50:
- quality = 5000. / quality
+ quality = 5000.0 / quality
else:
- quality = 200. - quality * 2
- return quality / 100.
+ quality = 200.0 - quality * 2
+ return quality / 100.0
# ------------------------ compression ------------------------#
class RGB2YCbCrJpeg(nn.Module):
- """ Converts RGB image to YCbCr
- """
+ """Converts RGB image to YCbCr"""
def __init__(self):
super(RGB2YCbCrJpeg, self).__init__()
- matrix = np.array([[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]],
- dtype=np.float32).T
- self.shift = nn.Parameter(torch.tensor([0., 128., 128.]))
+ matrix = np.array(
+ [[0.299, 0.587, 0.114], [-0.168736, -0.331264, 0.5], [0.5, -0.418688, -0.081312]], dtype=np.float32
+ ).T
+ self.shift = nn.Parameter(torch.tensor([0.0, 128.0, 128.0]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
@@ -71,8 +80,7 @@ def forward(self, image):
class ChromaSubsampling(nn.Module):
- """ Chroma subsampling on CbCr channels
- """
+ """Chroma subsampling on CbCr channels"""
def __init__(self):
super(ChromaSubsampling, self).__init__()
@@ -96,8 +104,7 @@ def forward(self, image):
class BlockSplitting(nn.Module):
- """ Splitting image into patches
- """
+ """Splitting image into patches"""
def __init__(self):
super(BlockSplitting, self).__init__()
@@ -119,15 +126,14 @@ def forward(self, image):
class DCT8x8(nn.Module):
- """ Discrete Cosine Transformation
- """
+ """Discrete Cosine Transformation"""
def __init__(self):
super(DCT8x8, self).__init__()
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
tensor[x, y, u, v] = np.cos((2 * x + 1) * u * np.pi / 16) * np.cos((2 * y + 1) * v * np.pi / 16)
- alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
self.tensor = nn.Parameter(torch.from_numpy(tensor).float())
self.scale = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha) * 0.25).float())
@@ -146,7 +152,7 @@ def forward(self, image):
class YQuantize(nn.Module):
- """ JPEG Quantization for Y channel
+ """JPEG Quantization for Y channel
Args:
rounding(function): rounding function to use
@@ -176,7 +182,7 @@ def forward(self, image, factor=1):
class CQuantize(nn.Module):
- """ JPEG Quantization for CbCr channels
+ """JPEG Quantization for CbCr channels
Args:
rounding(function): rounding function to use
@@ -245,8 +251,7 @@ def forward(self, image, factor=1):
class YDequantize(nn.Module):
- """Dequantize Y channel
- """
+ """Dequantize Y channel"""
def __init__(self):
super(YDequantize, self).__init__()
@@ -270,8 +275,7 @@ def forward(self, image, factor=1):
class CDequantize(nn.Module):
- """Dequantize CbCr channel
- """
+ """Dequantize CbCr channel"""
def __init__(self):
super(CDequantize, self).__init__()
@@ -295,12 +299,11 @@ def forward(self, image, factor=1):
class iDCT8x8(nn.Module):
- """Inverse discrete Cosine Transformation
- """
+ """Inverse discrete Cosine Transformation"""
def __init__(self):
super(iDCT8x8, self).__init__()
- alpha = np.array([1. / np.sqrt(2)] + [1] * 7)
+ alpha = np.array([1.0 / np.sqrt(2)] + [1] * 7)
self.alpha = nn.Parameter(torch.from_numpy(np.outer(alpha, alpha)).float())
tensor = np.zeros((8, 8, 8, 8), dtype=np.float32)
for x, y, u, v in itertools.product(range(8), repeat=4):
@@ -322,8 +325,7 @@ def forward(self, image):
class BlockMerging(nn.Module):
- """Merge patches into image
- """
+ """Merge patches into image"""
def __init__(self):
super(BlockMerging, self).__init__()
@@ -346,8 +348,7 @@ def forward(self, patches, height, width):
class ChromaUpsampling(nn.Module):
- """Upsample chroma layers
- """
+ """Upsample chroma layers"""
def __init__(self):
super(ChromaUpsampling, self).__init__()
@@ -376,14 +377,13 @@ def repeat(x, k=2):
class YCbCr2RGBJpeg(nn.Module):
- """Converts YCbCr image to RGB JPEG
- """
+ """Converts YCbCr image to RGB JPEG"""
def __init__(self):
super(YCbCr2RGBJpeg, self).__init__()
- matrix = np.array([[1., 0., 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
- self.shift = nn.Parameter(torch.tensor([0, -128., -128.]))
+ matrix = np.array([[1.0, 0.0, 1.402], [1, -0.344136, -0.714136], [1, 1.772, 0]], dtype=np.float32).T
+ self.shift = nn.Parameter(torch.tensor([0, -128.0, -128.0]))
self.matrix = nn.Parameter(torch.from_numpy(matrix))
def forward(self, image):
@@ -496,11 +496,11 @@ def forward(self, x, quality):
from basicsr.utils import img2tensor, tensor2img
- img_gt = cv2.imread('test.png') / 255.
+ img_gt = cv2.imread('test.png') / 255.0
# -------------- cv2 -------------- #
encode_param = [int(cv2.IMWRITE_JPEG_QUALITY), 20]
- _, encimg = cv2.imencode('.jpg', img_gt * 255., encode_param)
+ _, encimg = cv2.imencode('.jpg', img_gt * 255.0, encode_param)
img_lq = np.float32(cv2.imdecode(encimg, 1))
cv2.imwrite('cv2_JPEG_20.png', img_lq)
diff --git a/basicsr/utils/dist_util.py b/basicsr/utils/dist_util.py
index 0fab887b2..4392e17d3 100644
--- a/basicsr/utils/dist_util.py
+++ b/basicsr/utils/dist_util.py
@@ -2,6 +2,7 @@
import functools
import os
import subprocess
+
import torch
import torch.distributed as dist
import torch.multiprocessing as mp
diff --git a/basicsr/utils/download_util.py b/basicsr/utils/download_util.py
index f73abd0e1..47832767b 100644
--- a/basicsr/utils/download_util.py
+++ b/basicsr/utils/download_util.py
@@ -1,9 +1,10 @@
import math
import os
+from urllib.parse import urlparse
+
import requests
from torch.hub import download_url_to_file, get_dir
from tqdm import tqdm
-from urllib.parse import urlparse
from .misc import sizeof_fmt
diff --git a/basicsr/utils/file_client.py b/basicsr/utils/file_client.py
index 89d83ab9e..4b9207dcf 100644
--- a/basicsr/utils/file_client.py
+++ b/basicsr/utils/file_client.py
@@ -32,6 +32,7 @@ class MemcachedBackend(BaseStorageBackend):
def __init__(self, server_list_cfg, client_cfg, sys_path=None):
if sys_path is not None:
import sys
+
sys.path.append(sys_path)
try:
import mc
@@ -47,6 +48,7 @@ def __init__(self, server_list_cfg, client_cfg, sys_path=None):
def get(self, filepath):
filepath = str(filepath)
import mc
+
self._client.Get(filepath, self._mc_buffer)
value_buf = mc.ConvertBuffer(self._mc_buffer)
return value_buf
@@ -104,8 +106,10 @@ def __init__(self, db_paths, client_keys='default', readonly=True, lock=False, r
self.db_paths = [str(v) for v in db_paths]
elif isinstance(db_paths, str):
self.db_paths = [str(db_paths)]
- assert len(client_keys) == len(self.db_paths), ('client_keys and db_paths should have the same length, '
- f'but received {len(client_keys)} and {len(self.db_paths)}.')
+ assert len(client_keys) == len(self.db_paths), (
+ 'client_keys and db_paths should have the same length, '
+ f'but received {len(client_keys)} and {len(self.db_paths)}.'
+ )
self._client = {}
for client, path in zip(client_keys, self.db_paths):
@@ -119,7 +123,7 @@ def get(self, filepath, client_key):
client_key (str): Used for distinguishing different lmdb envs.
"""
filepath = str(filepath)
- assert client_key in self._client, (f'client_key {client_key} is not in lmdb clients.')
+ assert client_key in self._client, f'client_key {client_key} is not in lmdb clients.'
client = self._client[client_key]
with client.begin(write=False) as txn:
value_buf = txn.get(filepath.encode('ascii'))
@@ -150,8 +154,9 @@ class FileClient(object):
def __init__(self, backend='disk', **kwargs):
if backend not in self._backends:
- raise ValueError(f'Backend {backend} is not supported. Currently supported ones'
- f' are {list(self._backends.keys())}')
+ raise ValueError(
+ f'Backend {backend} is not supported. Currently supported ones are {list(self._backends.keys())}'
+ )
self.backend = backend
self.client = self._backends[backend](**kwargs)
diff --git a/basicsr/utils/flow_util.py b/basicsr/utils/flow_util.py
index 3d7180b4e..7b9f654c4 100644
--- a/basicsr/utils/flow_util.py
+++ b/basicsr/utils/flow_util.py
@@ -1,7 +1,8 @@
# Modified from https://github.com/open-mmlab/mmcv/blob/master/mmcv/video/optflow.py # noqa: E501
+import os
+
import cv2
import numpy as np
-import os
def flowread(flow_path, quantize=False, concat_axis=0, *args, **kwargs):
diff --git a/basicsr/utils/img_process_util.py b/basicsr/utils/img_process_util.py
index 52e02f099..cdd79f6da 100644
--- a/basicsr/utils/img_process_util.py
+++ b/basicsr/utils/img_process_util.py
@@ -61,7 +61,6 @@ def usm_sharp(img, weight=0.5, radius=50, threshold=10):
class USMSharp(torch.nn.Module):
-
def __init__(self, radius=50, sigma=0):
super(USMSharp, self).__init__()
if radius % 2 == 0:
diff --git a/basicsr/utils/img_util.py b/basicsr/utils/img_util.py
index 3a5f1da09..3293a017e 100644
--- a/basicsr/utils/img_util.py
+++ b/basicsr/utils/img_util.py
@@ -1,7 +1,8 @@
-import cv2
import math
-import numpy as np
import os
+
+import cv2
+import numpy as np
import torch
from torchvision.utils import make_grid
@@ -128,7 +129,7 @@ def imfrombytes(content, flag='color', float32=False):
imread_flags = {'color': cv2.IMREAD_COLOR, 'grayscale': cv2.IMREAD_GRAYSCALE, 'unchanged': cv2.IMREAD_UNCHANGED}
img = cv2.imdecode(img_np, imread_flags[flag])
if float32:
- img = img.astype(np.float32) / 255.
+ img = img.astype(np.float32) / 255.0
return img
diff --git a/basicsr/utils/lmdb_util.py b/basicsr/utils/lmdb_util.py
index a2b45ce01..4e529bc46 100644
--- a/basicsr/utils/lmdb_util.py
+++ b/basicsr/utils/lmdb_util.py
@@ -1,20 +1,23 @@
-import cv2
-import lmdb
import sys
from multiprocessing import Pool
from os import path as osp
+
+import cv2
+import lmdb
from tqdm import tqdm
-def make_lmdb_from_imgs(data_path,
- lmdb_path,
- img_path_list,
- keys,
- batch=5000,
- compress_level=1,
- multiprocessing_read=False,
- n_thread=40,
- map_size=None):
+def make_lmdb_from_imgs(
+ data_path,
+ lmdb_path,
+ img_path_list,
+ keys,
+ batch=5000,
+ compress_level=1,
+ multiprocessing_read=False,
+ n_thread=40,
+ map_size=None,
+):
"""Make lmdb from images.
Contents of lmdb. The file structure is:
@@ -61,8 +64,9 @@ def make_lmdb_from_imgs(data_path,
estimated size from images. Default: None
"""
- assert len(img_path_list) == len(keys), ('img_path_list and keys should have the same length, '
- f'but got {len(img_path_list)} and {len(keys)}')
+ assert len(img_path_list) == len(keys), (
+ f'img_path_list and keys should have the same length, but got {len(img_path_list)} and {len(keys)}'
+ )
print(f'Create lmdb for {data_path}, save to {lmdb_path}...')
print(f'Totoal images: {len(img_path_list)}')
if not lmdb_path.endswith('.lmdb'):
@@ -156,7 +160,7 @@ def read_img_worker(path, key, compress_level):
return (key, img_byte, (h, w, c))
-class LmdbMaker():
+class LmdbMaker:
"""LMDB Maker.
Args:
diff --git a/basicsr/utils/logger.py b/basicsr/utils/logger.py
index 73553dc66..988e65cb8 100644
--- a/basicsr/utils/logger.py
+++ b/basicsr/utils/logger.py
@@ -7,8 +7,7 @@
initialized_logger = {}
-class AvgTimer():
-
+class AvgTimer:
def __init__(self, window=200):
self.window = window # average window
self.current_time = 0
@@ -42,7 +41,7 @@ def get_avg_time(self):
return self.avg_time
-class MessageLogger():
+class MessageLogger:
"""Message logger for printing.
Args:
@@ -86,7 +85,7 @@ def __call__(self, log_vars):
current_iter = log_vars.pop('iter')
lrs = log_vars.pop('lrs')
- message = (f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:(')
+ message = f'[{self.exp_name[:5]}..][epoch:{epoch:3d}, iter:{current_iter:8,d}, lr:('
for v in lrs:
message += f'{v:.3e},'
message += ')] '
@@ -118,6 +117,7 @@ def __call__(self, log_vars):
@master_only
def init_tb_logger(log_dir):
from torch.utils.tensorboard import SummaryWriter
+
tb_logger = SummaryWriter(log_dir=log_dir)
return tb_logger
@@ -126,6 +126,7 @@ def init_tb_logger(log_dir):
def init_wandb_logger(opt):
"""We now only use wandb to sync tensorboard log."""
import wandb
+
logger = get_root_logger()
project = opt['logger']['wandb']['project']
@@ -194,6 +195,7 @@ def get_env_info():
import torchvision
from basicsr.version import __version__
+
msg = r"""
____ _ _____ ____
/ __ ) ____ _ _____ (_)_____/ ___/ / __ \
@@ -206,8 +208,10 @@ def get_env_info():
/ /_/ // /_/ // /_/ // /_/ / / /___/ /_/ // /__ / /< /_/
\____/ \____/ \____/ \____/ /_____/\____/ \___//_/|_| (_)
"""
- msg += ('\nVersion Information: '
- f'\n\tBasicSR: {__version__}'
- f'\n\tPyTorch: {torch.__version__}'
- f'\n\tTorchVision: {torchvision.__version__}')
+ msg += (
+ '\nVersion Information: '
+ f'\n\tBasicSR: {__version__}'
+ f'\n\tPyTorch: {torch.__version__}'
+ f'\n\tTorchVision: {torchvision.__version__}'
+ )
return msg
diff --git a/basicsr/utils/matlab_functions.py b/basicsr/utils/matlab_functions.py
index a201f79aa..d5b600a44 100644
--- a/basicsr/utils/matlab_functions.py
+++ b/basicsr/utils/matlab_functions.py
@@ -1,4 +1,5 @@
import math
+
import numpy as np
import torch
@@ -8,9 +9,9 @@ def cubic(x):
absx = torch.abs(x)
absx2 = absx**2
absx3 = absx**3
- return (1.5 * absx3 - 2.5 * absx2 + 1) * (
- (absx <= 1).type_as(absx)) + (-0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2) * (((absx > 1) *
- (absx <= 2)).type_as(absx))
+ return (1.5 * absx3 - 2.5 * absx2 + 1) * ((absx <= 1).type_as(absx)) + (
+ -0.5 * absx3 + 2.5 * absx2 - 4 * absx + 2
+ ) * (((absx > 1) * (absx <= 2)).type_as(absx))
def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width, antialiasing):
@@ -49,7 +50,8 @@ def calculate_weights_indices(in_length, out_length, scale, kernel, kernel_width
# The indices of the input pixels involved in computing the k-th output
# pixel are in row k of the indices matrix.
indices = left.view(out_length, 1).expand(out_length, p) + torch.linspace(0, p - 1, p).view(1, p).expand(
- out_length, p)
+ out_length, p
+ )
# The weights used to compute the k-th output pixel are in row k of the
# weights matrix.
@@ -120,10 +122,12 @@ def imresize(img, scale, antialiasing=True):
kernel = 'cubic'
# get weights and indices
- weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(in_h, out_h, scale, kernel, kernel_width,
- antialiasing)
- weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(in_w, out_w, scale, kernel, kernel_width,
- antialiasing)
+ weights_h, indices_h, sym_len_hs, sym_len_he = calculate_weights_indices(
+ in_h, out_h, scale, kernel, kernel_width, antialiasing
+ )
+ weights_w, indices_w, sym_len_ws, sym_len_we = calculate_weights_indices(
+ in_w, out_w, scale, kernel, kernel_width, antialiasing
+ )
# process H dimension
# symmetric copying
img_aug = torch.FloatTensor(in_c, in_h + sym_len_hs + sym_len_he, in_w)
@@ -144,7 +148,7 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_h):
idx = int(indices_h[i][0])
for j in range(in_c):
- out_1[j, i, :] = img_aug[j, idx:idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
+ out_1[j, i, :] = img_aug[j, idx : idx + kernel_width, :].transpose(0, 1).mv(weights_h[i])
# process W dimension
# symmetric copying
@@ -166,7 +170,7 @@ def imresize(img, scale, antialiasing=True):
for i in range(out_w):
idx = int(indices_w[i][0])
for j in range(in_c):
- out_2[j, :, i] = out_1_aug[j, :, idx:idx + kernel_width].mv(weights_w[i])
+ out_2[j, :, i] = out_1_aug[j, :, idx : idx + kernel_width].mv(weights_w[i])
if squeeze_flag:
out_2 = out_2.squeeze(0)
diff --git a/basicsr/utils/misc.py b/basicsr/utils/misc.py
index c8d4a1403..3e775b870 100644
--- a/basicsr/utils/misc.py
+++ b/basicsr/utils/misc.py
@@ -1,10 +1,11 @@
-import numpy as np
import os
import random
import time
-import torch
from os import path as osp
+import numpy as np
+import torch
+
from .dist_util import master_only
@@ -111,10 +112,11 @@ def check_resume(opt, resume_iter):
for network in networks:
name = f'pretrain_{network}'
basename = network.replace('network_', '')
- if opt['path'].get('ignore_resume_networks') is None or (network
- not in opt['path']['ignore_resume_networks']):
+ if opt['path'].get('ignore_resume_networks') is None or (
+ network not in opt['path']['ignore_resume_networks']
+ ):
opt['path'][name] = osp.join(opt['path']['models'], f'net_{basename}_{resume_iter}.pth')
- print(f"Set {name} to {opt['path'][name]}")
+ print(f'Set {name} to {opt["path"][name]}')
# change param_key to params in resume
param_keys = [key for key in opt['path'].keys() if key.startswith('param_key')]
diff --git a/basicsr/utils/options.py b/basicsr/utils/options.py
index 5c7155ecc..0d1369796 100644
--- a/basicsr/utils/options.py
+++ b/basicsr/utils/options.py
@@ -1,11 +1,12 @@
import argparse
import os
import random
-import torch
-import yaml
from collections import OrderedDict
from os import path as osp
+import torch
+import yaml
+
from basicsr.utils import set_random_seed
from basicsr.utils.dist_util import get_dist_info, init_dist, master_only
@@ -104,7 +105,8 @@ def parse_options(root_path, is_train=True):
parser.add_argument('--debug', action='store_true')
parser.add_argument('--local_rank', type=int, default=0)
parser.add_argument(
- '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999')
+ '--force_yml', nargs='+', default=None, help='Force to update yml files. Examples: train:ema_decay=0.999'
+ )
args = parser.parse_args()
# parse yml to dict
@@ -207,6 +209,7 @@ def copy_opt_file(opt_file, experiments_root):
import sys
import time
from shutil import copyfile
+
cmd = ' '.join(sys.argv)
filename = osp.join(experiments_root, osp.basename(opt_file))
copyfile(opt_file, filename)
diff --git a/basicsr/utils/plot_util.py b/basicsr/utils/plot_util.py
index 1e6da5bc2..c60e3bd40 100644
--- a/basicsr/utils/plot_util.py
+++ b/basicsr/utils/plot_util.py
@@ -66,7 +66,7 @@ def read_data_from_txt_1v(path, pattern):
def smooth_data(values, smooth_weight):
- """ Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
+ """Smooth data using 1st-order IIR low-pass filter (what tensorflow does).
Reference: https://github.com/tensorflow/tensorboard/blob/f801ebf1f9fbfe2baee1ddd65714d0bccc640fb1/tensorboard/plugins/scalar/vz_line_chart/vz-line-chart.ts#L704 # noqa: E501
diff --git a/basicsr/utils/registry.py b/basicsr/utils/registry.py
index 5e72ef7ff..008f1ee6b 100644
--- a/basicsr/utils/registry.py
+++ b/basicsr/utils/registry.py
@@ -1,7 +1,7 @@
# Modified from: https://github.com/facebookresearch/fvcore/blob/master/fvcore/common/registry.py # noqa: E501
-class Registry():
+class Registry:
"""
The registry that provides name -> object mapping, to support third-party
users' custom modules.
@@ -39,8 +39,7 @@ def _do_register(self, name, obj, suffix=None):
if isinstance(suffix, str):
name = name + '_' + suffix
- assert (name not in self._obj_map), (f"An object named '{name}' was already registered "
- f"in '{self._name}' registry!")
+ assert name not in self._obj_map, f"An object named '{name}' was already registered in '{self._name}' registry!"
self._obj_map[name] = obj
def register(self, obj=None, suffix=None):
diff --git a/docs/requirements.txt b/docs/requirements.txt
index 1be4caf6d..ea076deb6 100644
--- a/docs/requirements.txt
+++ b/docs/requirements.txt
@@ -1,21 +1,18 @@
-# add all requirements to auto generate the docs
-addict
-future
+# Requirements for building the docs
lmdb
-numpy
+numpy>=1.23
opencv-python
-Pillow
+Pillow>=9.4
pyyaml
-recommonmark
requests
-scikit-image
-scipy
-sphinx
-sphinx_intl
-sphinx_markdown_tables
-sphinx_rtd_theme
-tb-nightly
-torch>=1.7
-torchvision
+scikit-image>=0.19
+scipy>=1.10
+tensorboard>=2.12
+torch>=2.0
+torchvision>=0.15
tqdm
-yapf
+# Sphinx docs
+myst-parser
+sphinx>=7.0
+sphinx-rtd-theme
+sphinx-intl
diff --git a/inference/inference_basicvsr.py b/inference/inference_basicvsr.py
index 7b5e4b945..435ed11ed 100644
--- a/inference/inference_basicvsr.py
+++ b/inference/inference_basicvsr.py
@@ -1,8 +1,9 @@
import argparse
-import cv2
import glob
import os
import shutil
+
+import cv2
import torch
from basicsr.archs.basicvsr_arch import BasicVSR
@@ -25,7 +26,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSR_REDS4.pth')
parser.add_argument(
- '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
+ '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder'
+ )
parser.add_argument('--save_path', type=str, default='results/BasicVSR', help='save image path')
parser.add_argument('--interval', type=int, default=15, help='interval size')
args = parser.parse_args()
@@ -60,7 +62,7 @@ def main():
else:
for idx in range(0, num_imgs, args.interval):
interval = min(args.interval, num_imgs - idx)
- imgs, imgnames = read_img_seq(imgs_list[idx:idx + interval], return_imgname=True)
+ imgs, imgnames = read_img_seq(imgs_list[idx : idx + interval], return_imgname=True)
imgs = imgs.unsqueeze(0).to(device)
inference(imgs, imgnames, model, args.save_path)
diff --git a/inference/inference_basicvsrpp.py b/inference/inference_basicvsrpp.py
index b44aaa482..0f9707cca 100644
--- a/inference/inference_basicvsrpp.py
+++ b/inference/inference_basicvsrpp.py
@@ -1,8 +1,9 @@
import argparse
-import cv2
import glob
import os
import shutil
+
+import cv2
import torch
from basicsr.archs.basicvsrpp_arch import BasicVSRPlusPlus
@@ -25,7 +26,8 @@ def main():
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, default='experiments/pretrained_models/BasicVSRPP_REDS4.pth')
parser.add_argument(
- '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder')
+ '--input_path', type=str, default='datasets/REDS4/sharp_bicubic/000', help='input test image folder'
+ )
parser.add_argument('--save_path', type=str, default='results/BasicVSRPP/000', help='save image path')
parser.add_argument('--interval', type=int, default=100, help='interval size')
args = parser.parse_args()
@@ -60,7 +62,7 @@ def main():
else:
for idx in range(0, num_imgs, args.interval):
interval = min(args.interval, num_imgs - idx)
- imgs, imgnames = read_img_seq(imgs_list[idx:idx + interval], return_imgname=True)
+ imgs, imgnames = read_img_seq(imgs_list[idx : idx + interval], return_imgname=True)
imgs = imgs.unsqueeze(0).to(device)
inference(imgs, imgnames, model, args.save_path)
diff --git a/inference/inference_dfdnet.py b/inference/inference_dfdnet.py
index 64a7a6456..c66111e76 100644
--- a/inference/inference_dfdnet.py
+++ b/inference/inference_dfdnet.py
@@ -1,7 +1,8 @@
import argparse
import glob
-import numpy as np
import os
+
+import numpy as np
import torch
import torchvision.transforms as transforms
from skimage import io
@@ -27,7 +28,8 @@ def get_part_location(landmarks):
# left eye
mean_left_eye = np.mean(landmarks[map_left_eye], 0) # (x, y)
half_len_left_eye = np.max(
- (np.max(np.max(landmarks[map_left_eye], 0) - np.min(landmarks[map_left_eye], 0)) / 2, 16)) # A number
+ (np.max(np.max(landmarks[map_left_eye], 0) - np.min(landmarks[map_left_eye], 0)) / 2, 16)
+ ) # A number
loc_left_eye = np.hstack((mean_left_eye - half_len_left_eye + 1, mean_left_eye + half_len_left_eye)).astype(int)
loc_left_eye = torch.from_numpy(loc_left_eye).unsqueeze(0)
# (1, 4), the four numbers forms two coordinates in the diagonal
@@ -35,20 +37,20 @@ def get_part_location(landmarks):
# right eye
mean_right_eye = np.mean(landmarks[map_right_eye], 0)
half_len_right_eye = np.max(
- (np.max(np.max(landmarks[map_right_eye], 0) - np.min(landmarks[map_right_eye], 0)) / 2, 16))
- loc_right_eye = np.hstack(
- (mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(int)
+ (np.max(np.max(landmarks[map_right_eye], 0) - np.min(landmarks[map_right_eye], 0)) / 2, 16)
+ )
+ loc_right_eye = np.hstack((mean_right_eye - half_len_right_eye + 1, mean_right_eye + half_len_right_eye)).astype(
+ int
+ )
loc_right_eye = torch.from_numpy(loc_right_eye).unsqueeze(0)
# nose
mean_nose = np.mean(landmarks[map_nose], 0)
- half_len_nose = np.max(
- (np.max(np.max(landmarks[map_nose], 0) - np.min(landmarks[map_nose], 0)) / 2, 16)) # noqa: E126
+    half_len_nose = np.max((np.max(np.max(landmarks[map_nose], 0) - np.min(landmarks[map_nose], 0)) / 2, 16))
loc_nose = np.hstack((mean_nose - half_len_nose + 1, mean_nose + half_len_nose)).astype(int)
loc_nose = torch.from_numpy(loc_nose).unsqueeze(0)
# mouth
mean_mouth = np.mean(landmarks[map_mouth], 0)
- half_len_mouth = np.max(
- (np.max(np.max(landmarks[map_mouth], 0) - np.min(landmarks[map_mouth], 0)) / 2, 16)) # noqa: E126
+    half_len_mouth = np.max((np.max(np.max(landmarks[map_mouth], 0) - np.min(landmarks[map_mouth], 0)) / 2, 16))
loc_mouth = np.hstack((mean_mouth - half_len_mouth + 1, mean_mouth + half_len_mouth)).astype(int)
loc_mouth = torch.from_numpy(loc_mouth).unsqueeze(0)
@@ -67,13 +69,15 @@ def get_part_location(landmarks):
parser.add_argument(
'--model_path',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth')
+ # noqa: E251
+ default='experiments/pretrained_models/DFDNet/DFDNet_official-d1fa5650.pth',
+ )
parser.add_argument(
'--dict_path',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth')
+ # noqa: E251
+ default='experiments/pretrained_models/DFDNet/DFDNet_dict_512-f79685f0.pth',
+ )
parser.add_argument('--test_path', type=str, default='datasets/TestWhole')
parser.add_argument('--upsample_num_times', type=int, default=1)
parser.add_argument('--save_inverse_affine', action='store_true')
@@ -89,20 +93,20 @@ def get_part_location(landmarks):
parser.add_argument(
'--detection_path',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat' # noqa: E501
+ # noqa: E251
+ default='experiments/pretrained_models/dlib/mmod_human_face_detector-4cb19393.dat', # noqa: E501
)
parser.add_argument(
'--landmark5_path',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat' # noqa: E501
+ # noqa: E251
+ default='experiments/pretrained_models/dlib/shape_predictor_5_face_landmarks-c4b1e980.dat', # noqa: E501
)
parser.add_argument(
'--landmark68_path',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat' # noqa: E501
+ # noqa: E251
+ default='experiments/pretrained_models/dlib/shape_predictor_68_face_landmarks-fbdc2cb8.dat', # noqa: E501
)
args = parser.parse_args()
@@ -137,7 +141,8 @@ def get_part_location(landmarks):
face_helper.init_dlib(args.detection_path, args.landmark5_path, args.landmark68_path)
# detect faces
num_det_faces = face_helper.detect_faces(
- img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest)
+ img_path, upsample_num_times=args.upsample_num_times, only_keep_largest=args.only_keep_largest
+ )
# get 5 face landmarks for each face
num_landmarks = face_helper.get_face_landmarks_5()
print(f'\tDetect {num_det_faces} faces, {num_landmarks} landmarks.')
diff --git a/inference/inference_esrgan.py b/inference/inference_esrgan.py
index e425b137f..3ffa979b2 100644
--- a/inference/inference_esrgan.py
+++ b/inference/inference_esrgan.py
@@ -1,8 +1,9 @@
import argparse
-import cv2
import glob
-import numpy as np
import os
+
+import cv2
+import numpy as np
import torch
from basicsr.archs.rrdbnet_arch import RRDBNet
@@ -13,8 +14,8 @@ def main():
parser.add_argument(
'--model_path',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth' # noqa: E501
+ # noqa: E251
+ default='experiments/pretrained_models/ESRGAN/ESRGAN_SRx4_DF2KOST_official-ff704c30.pth', # noqa: E501
)
parser.add_argument('--input', type=str, default='datasets/Set14/LRbicx4', help='input test image folder')
parser.add_argument('--output', type=str, default='results/ESRGAN', help='output folder')
@@ -32,7 +33,7 @@ def main():
imgname = os.path.splitext(os.path.basename(path))[0]
print('Testing', idx, imgname)
# read image
- img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
+ img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img = img.unsqueeze(0).to(device)
# inference
diff --git a/inference/inference_ridnet.py b/inference/inference_ridnet.py
index 9825ba898..dfe7bf840 100644
--- a/inference/inference_ridnet.py
+++ b/inference/inference_ridnet.py
@@ -1,8 +1,9 @@
import argparse
-import cv2
import glob
-import numpy as np
import os
+
+import cv2
+import numpy as np
import torch
from tqdm import tqdm
@@ -17,8 +18,9 @@
parser.add_argument(
'--model_path',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/RIDNet/RIDNet.pth')
+ # noqa: E251
+ default='experiments/pretrained_models/RIDNet/RIDNet.pth',
+ )
args = parser.parse_args()
if args.test_path.endswith('/'): # solve when path ends with /
args.test_path = args.test_path[:-1]
diff --git a/inference/inference_stylegan2.py b/inference/inference_stylegan2.py
index 52545acec..2e11834f2 100644
--- a/inference/inference_stylegan2.py
+++ b/inference/inference_stylegan2.py
@@ -1,6 +1,7 @@
import argparse
import math
import os
+
import torch
from torchvision import utils
@@ -15,10 +16,9 @@ def generate(args, g_ema, device, mean_latent, randomize_noise):
for i in range(args.pics):
sample_z = torch.randn(args.sample, args.latent, device=device)
- sample, _ = g_ema([sample_z],
- truncation=args.truncation,
- randomize_noise=randomize_noise,
- truncation_latent=mean_latent)
+ sample, _ = g_ema(
+ [sample_z], truncation=args.truncation, randomize_noise=randomize_noise, truncation_latent=mean_latent
+ )
utils.save_image(
sample,
@@ -42,8 +42,8 @@ def generate(args, g_ema, device, mean_latent, randomize_noise):
parser.add_argument(
'--ckpt',
type=str,
- default= # noqa: E251
- 'experiments/pretrained_models/StyleGAN/stylegan2_ffhq_config_f_1024_official-3ab41b38.pth' # noqa: E501
+ # noqa: E251
+ default='experiments/pretrained_models/StyleGAN/stylegan2_ffhq_config_f_1024_official-3ab41b38.pth', # noqa: E501
)
parser.add_argument('--channel_multiplier', type=int, default=2)
parser.add_argument('--randomize_noise', type=bool, default=True)
@@ -55,8 +55,9 @@ def generate(args, g_ema, device, mean_latent, randomize_noise):
os.makedirs('samples', exist_ok=True)
set_random_seed(2020)
- g_ema = StyleGAN2Generator(
- args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(device)
+ g_ema = StyleGAN2Generator(args.size, args.latent, args.n_mlp, channel_multiplier=args.channel_multiplier).to(
+ device
+ )
checkpoint = torch.load(args.ckpt)['params_ema']
g_ema.load_state_dict(checkpoint)
diff --git a/inference/inference_swinir.py b/inference/inference_swinir.py
index 28e9bdeca..cb9039585 100644
--- a/inference/inference_swinir.py
+++ b/inference/inference_swinir.py
@@ -1,9 +1,10 @@
# Modified from https://github.com/JingyunLiang/SwinIR
import argparse
-import cv2
import glob
-import numpy as np
import os
+
+import cv2
+import numpy as np
import torch
from torch.nn import functional as F
@@ -18,7 +19,8 @@ def main():
'--task',
type=str,
default='classical_sr',
- help='classical_sr, lightweight_sr, real_sr, gray_dn, color_dn, jpeg_car')
+ help='classical_sr, lightweight_sr, real_sr, gray_dn, color_dn, jpeg_car',
+ )
# dn: denoising; car: compression artifact removal
# TODO: it now only supports sr, need to adapt to dn and jpeg_car
parser.add_argument('--patch_size', type=int, default=64, help='training patch size')
@@ -29,7 +31,8 @@ def main():
parser.add_argument(
'--model_path',
type=str,
- default='experiments/pretrained_models/SwinIR/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth')
+ default='experiments/pretrained_models/SwinIR/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth',
+ )
args = parser.parse_args()
os.makedirs(args.output, exist_ok=True)
@@ -49,7 +52,7 @@ def main():
imgname = os.path.splitext(os.path.basename(path))[0]
print('Testing', idx, imgname)
# read image
- img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
+ img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
img = img.unsqueeze(0).to(device)
@@ -66,7 +69,7 @@ def main():
output = model(img)
_, _, h, w = output.size()
- output = output[:, :, 0:h - mod_pad_h * args.scale, 0:w - mod_pad_w * args.scale]
+ output = output[:, :, 0 : h - mod_pad_h * args.scale, 0 : w - mod_pad_w * args.scale]
# save image
output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
@@ -84,13 +87,14 @@ def define_model(args):
in_chans=3,
img_size=args.patch_size,
window_size=8,
- img_range=1.,
+ img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler='pixelshuffle',
- resi_connection='1conv')
+ resi_connection='1conv',
+ )
# 002 lightweight image sr
# use 'pixelshuffledirect' to save parameters
@@ -100,13 +104,14 @@ def define_model(args):
in_chans=3,
img_size=64,
window_size=8,
- img_range=1.,
+ img_range=1.0,
depths=[6, 6, 6, 6],
embed_dim=60,
num_heads=[6, 6, 6, 6],
mlp_ratio=2,
upsampler='pixelshuffledirect',
- resi_connection='1conv')
+ resi_connection='1conv',
+ )
# 003 real-world image sr
elif args.task == 'real_sr':
@@ -117,13 +122,14 @@ def define_model(args):
in_chans=3,
img_size=64,
window_size=8,
- img_range=1.,
+ img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler='nearest+conv',
- resi_connection='1conv')
+ resi_connection='1conv',
+ )
else:
# larger model size; use '3conv' to save parameters and memory; use ema for GAN training
model = SwinIR(
@@ -131,13 +137,14 @@ def define_model(args):
in_chans=3,
img_size=64,
window_size=8,
- img_range=1.,
+ img_range=1.0,
depths=[6, 6, 6, 6, 6, 6, 6, 6, 6],
embed_dim=248,
num_heads=[8, 8, 8, 8, 8, 8, 8, 8, 8],
mlp_ratio=2,
upsampler='nearest+conv',
- resi_connection='3conv')
+ resi_connection='3conv',
+ )
# 004 grayscale image denoising
elif args.task == 'gray_dn':
@@ -146,13 +153,14 @@ def define_model(args):
in_chans=1,
img_size=128,
window_size=8,
- img_range=1.,
+ img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler='',
- resi_connection='1conv')
+ resi_connection='1conv',
+ )
# 005 color image denoising
elif args.task == 'color_dn':
@@ -161,13 +169,14 @@ def define_model(args):
in_chans=3,
img_size=128,
window_size=8,
- img_range=1.,
+ img_range=1.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler='',
- resi_connection='1conv')
+ resi_connection='1conv',
+ )
# 006 JPEG compression artifact reduction
# use window_size=7 because JPEG encoding uses 8x8; use img_range=255 because it's slightly better than 1
@@ -177,13 +186,14 @@ def define_model(args):
in_chans=1,
img_size=126,
window_size=7,
- img_range=255.,
+ img_range=255.0,
depths=[6, 6, 6, 6, 6, 6],
embed_dim=180,
num_heads=[6, 6, 6, 6, 6, 6],
mlp_ratio=2,
upsampler='',
- resi_connection='1conv')
+ resi_connection='1conv',
+ )
loadnet = torch.load(args.model_path)
if 'params_ema' in loadnet:
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 000000000..946371866
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,88 @@
+[build-system]
+requires = ["setuptools>=68", "wheel", "cython", "numpy"]
+build-backend = "setuptools.build_meta"
+
+[project]
+name = "basicsr"
+dynamic = ["version"]
+description = "Open Source Image and Video Super-Resolution Toolbox"
+readme = { file = "README.md", content-type = "text/markdown" }
+license = { text = "Apache License 2.0" }
+authors = [
+ { name = "Xintao Wang", email = "xintao.wang@outlook.com" },
+]
+keywords = ["computer vision", "restoration", "super resolution"]
+requires-python = ">=3.9"
+classifiers = [
+ "Development Status :: 4 - Beta",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ "Programming Language :: Python :: 3",
+ "Programming Language :: Python :: 3.9",
+ "Programming Language :: Python :: 3.10",
+ "Programming Language :: Python :: 3.11",
+ "Programming Language :: Python :: 3.12",
+]
+dependencies = [
+ "numpy>=1.23",
+ "lmdb",
+ "opencv-python",
+ "Pillow>=9.4",
+ "pyyaml",
+ "requests",
+ "scikit-image>=0.19",
+ "scipy>=1.10",
+ "tensorboard>=2.12",
+ "torch>=2.0",
+ "torchvision>=0.15",
+ "tqdm",
+]
+
+[project.optional-dependencies]
+dev = [
+ "pytest>=7.0",
+ "ruff>=0.4",
+ "pre-commit>=3.0",
+]
+
+[project.urls]
+Homepage = "https://github.com/XPixelGroup/BasicSR"
+Repository = "https://github.com/XPixelGroup/BasicSR"
+
+[tool.setuptools]
+include-package-data = true
+
+[tool.setuptools.packages.find]
+exclude = ["options*", "datasets*", "experiments*", "results*", "tb_logger*", "wandb*"]
+
+[tool.setuptools.dynamic]
+version = { attr = "basicsr.version.__version__" }
+
+[tool.ruff]
+line-length = 120
+target-version = "py39"
+
+[tool.ruff.lint]
+select = ["E", "F", "W", "I"]
+ignore = ["E501"]
+
+[tool.ruff.format]
+quote-style = "single"
+
+[tool.isort]
+line_length = 120
+multi_line_output = 0
+known_standard_library = ["pkg_resources", "setuptools"]
+known_first_party = ["basicsr"]
+known_third_party = ["PIL", "cv2", "lmdb", "numpy", "pytest", "requests", "scipy", "skimage", "torch", "torchvision", "tqdm", "yaml"]
+no_lines_before = ["STDLIB", "LOCALFOLDER"]
+default_section = "THIRDPARTY"
+
+[tool.pytest.ini_options]
+testpaths = ["tests"]
+
+[tool.codespell]
+skip = ".git,./docs/build,*.cfg,*.toml"
+count = true
+quiet-level = 3
+ignore-words-list = "gool"
diff --git a/requirements.txt b/requirements.txt
index 82437c2b6..b2c297218 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,15 +1,12 @@
-addict
-future
lmdb
-numpy>=1.17
+numpy>=1.23
opencv-python
-Pillow
+Pillow>=9.4
pyyaml
requests
-scikit-image
-scipy
-tb-nightly
-torch>=1.7
-torchvision
+scikit-image>=0.19
+scipy>=1.10
+tensorboard>=2.12
+torch>=2.0
+torchvision>=0.15
tqdm
-yapf
diff --git a/scripts/data_preparation/create_lmdb.py b/scripts/data_preparation/create_lmdb.py
index d25ddc46a..33801451e 100644
--- a/scripts/data_preparation/create_lmdb.py
+++ b/scripts/data_preparation/create_lmdb.py
@@ -161,7 +161,8 @@ def prepare_keys_vimeo90k(folder_path, train_list_path, mode):
parser.add_argument(
'--dataset',
type=str,
- help=("Options: 'DIV2K', 'REDS', 'Vimeo90K' You may need to modify the corresponding configurations in codes."))
+ help=("Options: 'DIV2K', 'REDS', 'Vimeo90K' You may need to modify the corresponding configurations in codes."),
+ )
args = parser.parse_args()
dataset = args.dataset.lower()
if dataset == 'div2k':
diff --git a/scripts/data_preparation/download_datasets.py b/scripts/data_preparation/download_datasets.py
index c97e2f477..94071ee37 100644
--- a/scripts/data_preparation/download_datasets.py
+++ b/scripts/data_preparation/download_datasets.py
@@ -30,6 +30,7 @@ def download_dataset(dataset, file_ids):
extracted_path = save_path.replace('.zip', '')
print(f'Extract {save_path} to {extracted_path}')
import zipfile
+
with zipfile.ZipFile(save_path, 'r') as zip_ref:
zip_ref.extractall(extracted_path)
@@ -38,6 +39,7 @@ def download_dataset(dataset, file_ids):
if osp.isdir(subfolder):
print(f'Move {subfolder} to {extracted_path}')
import shutil
+
for path in glob.glob(osp.join(subfolder, '*')):
shutil.move(path, extracted_path)
shutil.rmtree(subfolder)
@@ -47,10 +49,8 @@ def download_dataset(dataset, file_ids):
parser = argparse.ArgumentParser()
parser.add_argument(
- 'dataset',
- type=str,
- help=("Options: 'Set5', 'Set14'. "
- "Set to 'all' if you want to download all the dataset."))
+ 'dataset', type=str, help=("Options: 'Set5', 'Set14'. Set to 'all' if you want to download all the dataset.")
+ )
args = parser.parse_args()
file_ids = {
@@ -60,7 +60,7 @@ def download_dataset(dataset, file_ids):
},
'Set14': {
'Set14.zip': '1vsw07sV8wGrRQ8UARe2fO5jjgy9QJy_E',
- }
+ },
}
if args.dataset == 'all':
diff --git a/scripts/data_preparation/extract_images_from_tfrecords.py b/scripts/data_preparation/extract_images_from_tfrecords.py
index 12cc3edba..be1c07681 100644
--- a/scripts/data_preparation/extract_images_from_tfrecords.py
+++ b/scripts/data_preparation/extract_images_from_tfrecords.py
@@ -1,9 +1,10 @@
import argparse
-import cv2
import glob
-import numpy as np
import os
+import cv2
+import numpy as np
+
from basicsr.utils.lmdb_util import LmdbMaker
@@ -159,7 +160,8 @@ def make_ffhq_lmdb_from_imgs(folder_path, log_resolution, save_root, save_type='
"""
parser = argparse.ArgumentParser()
parser.add_argument(
- '--dataset', type=str, default='ffhq', help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'.")
+ '--dataset', type=str, default='ffhq', help="Dataset name. Options: 'ffhq' | 'celeba'. Default: 'ffhq'."
+ )
parser.add_argument(
'--tf_file',
type=str,
@@ -169,13 +171,16 @@ def make_ffhq_lmdb_from_imgs(folder_path, log_resolution, save_root, save_type='
'Put quotes around the wildcard argument to prevent the shell '
'from expanding it.'
"Example: 'datasets/celeba/celeba_tfrecords/validation/validation-r08-s-*-of-*.tfrecords'" # noqa:E501
- ))
+ ),
+ )
parser.add_argument('--log_resolution', type=int, default=10, help='Log scale of resolution.')
parser.add_argument('--save_root', type=str, default='datasets/ffhq/', help='Save root path.')
parser.add_argument(
- '--save_type', type=str, default='img', help="Save type. Options: 'img' | 'lmdb'. Default: 'img'.")
+ '--save_type', type=str, default='img', help="Save type. Options: 'img' | 'lmdb'. Default: 'img'."
+ )
parser.add_argument(
- '--compress_level', type=int, default=1, help='Compress level when encoding images. Default: 1.')
+ '--compress_level', type=int, default=1, help='Compress level when encoding images. Default: 1.'
+ )
args = parser.parse_args()
try:
@@ -189,11 +194,13 @@ def make_ffhq_lmdb_from_imgs(folder_path, log_resolution, save_root, save_type='
args.log_resolution,
args.save_root,
save_type=args.save_type,
- compress_level=args.compress_level)
+ compress_level=args.compress_level,
+ )
else:
convert_celeba_tfrecords(
args.tf_file,
args.log_resolution,
args.save_root,
save_type=args.save_type,
- compress_level=args.compress_level)
+ compress_level=args.compress_level,
+ )
diff --git a/scripts/data_preparation/extract_subimages.py b/scripts/data_preparation/extract_subimages.py
index 55e9f6ba6..0dc8a2f98 100644
--- a/scripts/data_preparation/extract_subimages.py
+++ b/scripts/data_preparation/extract_subimages.py
@@ -1,9 +1,10 @@
-import cv2
-import numpy as np
import os
import sys
from multiprocessing import Pool
from os import path as osp
+
+import cv2
+import numpy as np
from tqdm import tqdm
from basicsr.utils import scandir
@@ -143,11 +144,13 @@ def worker(path, opt):
for x in h_space:
for y in w_space:
index += 1
- cropped_img = img[x:x + crop_size, y:y + crop_size, ...]
+ cropped_img = img[x : x + crop_size, y : y + crop_size, ...]
cropped_img = np.ascontiguousarray(cropped_img)
cv2.imwrite(
- osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'), cropped_img,
- [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']])
+ osp.join(opt['save_folder'], f'{img_name}_s{index:03d}{extension}'),
+ cropped_img,
+ [cv2.IMWRITE_PNG_COMPRESSION, opt['compression_level']],
+ )
process_info = f'Processing {img_name} ...'
return process_info
diff --git a/scripts/data_preparation/generate_meta_info.py b/scripts/data_preparation/generate_meta_info.py
index 7bb1aed35..95254ed2f 100644
--- a/scripts/data_preparation/generate_meta_info.py
+++ b/scripts/data_preparation/generate_meta_info.py
@@ -1,12 +1,12 @@
from os import path as osp
+
from PIL import Image
from basicsr.utils import scandir
def generate_meta_info_div2k():
- """Generate meta info for DIV2K dataset.
- """
+ """Generate meta info for DIV2K dataset."""
gt_folder = 'datasets/DIV2K/DIV2K_train_HR_sub/'
meta_info_txt = 'basicsr/data/meta_info/meta_info_DIV2K800sub_GT.txt'
diff --git a/scripts/data_preparation/prepare_hifacegan_dataset.py b/scripts/data_preparation/prepare_hifacegan_dataset.py
index cab2afad5..78c12df7d 100644
--- a/scripts/data_preparation/prepare_hifacegan_dataset.py
+++ b/scripts/data_preparation/prepare_hifacegan_dataset.py
@@ -1,5 +1,6 @@
-import cv2
import os
+
+import cv2
from tqdm import tqdm
@@ -16,8 +17,8 @@ def augment_image(self, x):
irange, jrange = (h + 15) // 16, (w + 15) // 16
for i in range(irange):
for j in range(jrange):
- mean = x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16].mean(axis=(0, 1))
- x[i * 16:(i + 1) * 16, j * 16:(j + 1) * 16] = mean
+ mean = x[i * 16 : (i + 1) * 16, j * 16 : (j + 1) * 16].mean(axis=(0, 1))
+ x[i * 16 : (i + 1) * 16, j * 16 : (j + 1) * 16] = mean
return x.astype('uint8')
@@ -32,41 +33,40 @@ class DegradationSimulator:
Custom degradation is possible by passing an inherited class from ia.augmentors
"""
- def __init__(self, ):
+ def __init__(
+ self,
+ ):
import imgaug.augmenters as ia
+
self.default_deg_templates = {
- 'sr4x':
- ia.Sequential([
- # It's almost like a 4x bicubic downsampling
- ia.Resize((0.25000, 0.25001), cv2.INTER_AREA),
- ia.Resize({
- 'height': 512,
- 'width': 512
- }, cv2.INTER_CUBIC),
- ]),
- 'sr4x8x':
- ia.Sequential([
- ia.Resize((0.125, 0.25), cv2.INTER_AREA),
- ia.Resize({
- 'height': 512,
- 'width': 512
- }, cv2.INTER_CUBIC),
- ]),
- 'denoise':
- ia.OneOf([
- ia.AdditiveGaussianNoise(scale=(20, 40), per_channel=True),
- ia.AdditiveLaplaceNoise(scale=(20, 40), per_channel=True),
- ia.AdditivePoissonNoise(lam=(15, 30), per_channel=True),
- ]),
- 'deblur':
- ia.OneOf([
- ia.MotionBlur(k=(10, 20)),
- ia.GaussianBlur((3.0, 8.0)),
- ]),
- 'jpeg':
- ia.JpegCompression(compression=(50, 85)),
- '16x':
- Mosaic16x(),
+ 'sr4x': ia.Sequential(
+ [
+ # It's almost like a 4x bicubic downsampling
+ ia.Resize((0.25000, 0.25001), cv2.INTER_AREA),
+ ia.Resize({'height': 512, 'width': 512}, cv2.INTER_CUBIC),
+ ]
+ ),
+ 'sr4x8x': ia.Sequential(
+ [
+ ia.Resize((0.125, 0.25), cv2.INTER_AREA),
+ ia.Resize({'height': 512, 'width': 512}, cv2.INTER_CUBIC),
+ ]
+ ),
+ 'denoise': ia.OneOf(
+ [
+ ia.AdditiveGaussianNoise(scale=(20, 40), per_channel=True),
+ ia.AdditiveLaplaceNoise(scale=(20, 40), per_channel=True),
+ ia.AdditivePoissonNoise(lam=(15, 30), per_channel=True),
+ ]
+ ),
+ 'deblur': ia.OneOf(
+ [
+ ia.MotionBlur(k=(10, 20)),
+ ia.GaussianBlur((3.0, 8.0)),
+ ]
+ ),
+ 'jpeg': ia.JpegCompression(compression=(50, 85)),
+ '16x': Mosaic16x(),
}
rand_deg_list = [
@@ -79,6 +79,7 @@ def __init__(self, ):
def create_training_dataset(self, deg, gt_folder, lq_folder=None):
from imgaug.augmenters.meta import Augmenter # baseclass
+
"""
Create a degradation simulator and apply it to GT images on the fly
Save the degraded result in the lq_folder (if None, name it as GT_deg)
@@ -91,7 +92,8 @@ def create_training_dataset(self, deg, gt_folder, lq_folder=None):
if isinstance(deg, str):
assert deg in self.default_deg_templates, (
- f'Degration type {deg} not recognized: {"|".join(list(self.default_deg_templates.keys()))}')
+            f'Degradation type {deg} not recognized: {"|".join(list(self.default_deg_templates.keys()))}'
+ )
deg = self.default_deg_templates[deg]
else:
assert isinstance(deg, Augmenter), f'Deg must be either str|Augmenter, got {deg}'
diff --git a/scripts/download_pretrained_models.py b/scripts/download_pretrained_models.py
index 7ed25e51a..0ad2775f4 100644
--- a/scripts/download_pretrained_models.py
+++ b/scripts/download_pretrained_models.py
@@ -31,15 +31,18 @@ def download_pretrained_models(method, file_ids):
parser.add_argument(
'method',
type=str,
- help=("Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', 'dlib', 'TOF', 'flownet', 'BasicVSR'. "
- "Set to 'all' to download all the models."))
+ help=(
+ "Options: 'ESRGAN', 'EDVR', 'StyleGAN', 'EDSR', 'DUF', 'DFDNet', 'dlib', 'TOF', 'flownet', 'BasicVSR'. "
+ "Set to 'all' to download all the models."
+ ),
+ )
args = parser.parse_args()
file_ids = {
'ESRGAN': {
'ESRGAN_SRx4_DF2KOST_official-ff704c30.pth': # file name
'1b3_bWZTjNO3iL2js1yWkJfjZykcQgvzT', # file id
- 'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM'
+ 'ESRGAN_PSNR_SRx4_DF2K_official-150ff491.pth': '1swaV5iBMFfg-DL6ZyiARztbhutDCWXMM',
},
'EDVR': {
'EDVR_L_x4_SR_REDS_official-9f5f5039.pth': '127KXEjlCwfoPC1aXyDkluNwr9elwyHNb',
@@ -48,7 +51,7 @@ def download_pretrained_models(method, file_ids):
'EDVR_M_x4_SR_REDS_official-32075921.pth': '1dd6aFj-5w2v08VJTq5mS9OFsD-wALYD6',
'EDVR_L_x4_SRblur_REDS_official-983d7b8e.pth': '1GZz_87ybR8eAAY3X2HWwI3L6ny7-5Yvl',
'EDVR_L_deblur_REDS_official-ca46bd8c.pth': '1_ma2tgHscZtkIY2tEJkVdU-UP8bnqBRE',
- 'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW'
+ 'EDVR_L_deblurcomp_REDS_official-0e988e5c.pth': '1fEoSeLFnHSBbIs95Au2W197p8e4ws4DW',
},
'StyleGAN': {
'stylegan2_ffhq_config_f_1024_official-3ab41b38.pth': '1qtdsT1FrvKQsFiW3OqOcIb-VS55TVy1g',
@@ -61,7 +64,7 @@ def download_pretrained_models(method, file_ids):
'stylegan2_car_config_f_512_official-e8fcab4f.pth': '14jS-nWNTguDSd1kTIX-tBHp2WdvK7hva',
'stylegan2_car_config_f_512_discriminator_official-5008e3d1.pth': '1UxkAzZ0zvw4KzBVOUpShCivsdXBS8Zi2',
'stylegan2_horse_config_f_256_official-26d57fee.pth': '12QsZ-mrO8_4gC0UrO15Jb3ykcQ88HxFx',
- 'stylegan2_horse_config_f_256_discriminator_official-be6c4c33.pth': '1me4ybSib72xA9ZxmzKsHDtP-eNCKw_X4'
+ 'stylegan2_horse_config_f_256_discriminator_official-be6c4c33.pth': '1me4ybSib72xA9ZxmzKsHDtP-eNCKw_X4',
},
'EDSR': {
'EDSR_Mx2_f64b16_DIV2K_official-3ba7b086.pth': '1mREMGVDymId3NzIc2u90sl_X4-pb4ZcV',
@@ -69,30 +72,26 @@ def download_pretrained_models(method, file_ids):
'EDSR_Mx4_f64b16_DIV2K_official-0c287733.pth': '1bCK6cFYU01uJudLgUUe-jgx-tZ3ikOWn',
'EDSR_Lx2_f256b32_DIV2K_official-be38e77d.pth': '15257lZCRZ0V6F9LzTyZFYbbPrqNjKyMU',
'EDSR_Lx3_f256b32_DIV2K_official-3660f70d.pth': '18q_D434sLG_rAZeHGonAX8dkqjoyZ2su',
- 'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl'
+ 'EDSR_Lx4_f256b32_DIV2K_official-76ee1c8f.pth': '1GCi30YYCzgMCcgheGWGusP9aWKOAy5vl',
},
'DUF': {
'DUF_x2_16L_official-39537cb9.pth': '1e91cEZOlUUk35keK9EnuK0F54QegnUKo',
'DUF_x3_16L_official-34ce53ec.pth': '1XN6aQj20esM7i0hxTbfiZr_SL8i4PZ76',
'DUF_x4_16L_official-bf8f0cfa.pth': '1V_h9U1CZgLSHTv1ky2M3lvuH-hK5hw_J',
'DUF_x4_28L_official-cbada450.pth': '1M8w0AMBJW65MYYD-_8_be0cSH_SHhDQ4',
- 'DUF_x4_52L_official-483d2c78.pth': '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T'
- },
- 'TOF': {
- 'tof_x4_vimeo90k_official-32c9e01f.pth': '1TgQiXXsvkTBFrQ1D0eKPgL10tQGu0gKb'
+ 'DUF_x4_52L_official-483d2c78.pth': '1GcmEWNr7mjTygi-QCOVgQWOo5OCNbh_T',
},
+ 'TOF': {'tof_x4_vimeo90k_official-32c9e01f.pth': '1TgQiXXsvkTBFrQ1D0eKPgL10tQGu0gKb'},
'DFDNet': {
'DFDNet_dict_512-f79685f0.pth': '1iH00oMsoN_1OJaEQw3zP7_wqiAYMnY79',
- 'DFDNet_official-d1fa5650.pth': '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe'
+ 'DFDNet_official-d1fa5650.pth': '1u6Sgcp8gVoy4uVTrOJKD3y9RuqH2JBAe',
},
'dlib': {
'mmod_human_face_detector-4cb19393.dat': '1FUM-hcoxNzFCOpCWbAUStBBMiU4uIGIL',
'shape_predictor_5_face_landmarks-c4b1e980.dat': '1PNPSmFjmbuuUDd5Mg5LDxyk7tu7TQv2F',
- 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni'
- },
- 'flownet': {
- 'spynet_sintel_final-3d2a1287.pth': '1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF'
+ 'shape_predictor_68_face_landmarks-fbdc2cb8.dat': '1IneH-O-gNkG0SQpNCplwxtOAtRCkG2ni',
},
+ 'flownet': {'spynet_sintel_final-3d2a1287.pth': '1VZz1cikwTRVX7zXoD247DB7n5Tj_LQpF'},
'BasicVSR': {
'BasicVSR_REDS4-543c8261.pth': '1wLWdz18lWf9Z7lomHPkdySZ-_GV2920p',
'BasicVSR_Vimeo90K_BDx4-e9bf46eb.pth': '1baaf4RSpzs_zcDAF_s2CyArrGvLgmXxW',
@@ -101,8 +100,8 @@ def download_pretrained_models(method, file_ids):
'EDVR_Vimeo90K_pretrained_for_IconVSR-ee48ee92.pth': '16vR262NDVyVv5Q49xp2Sb-Llu05f63tt',
'IconVSR_REDS-aaa5367f.pth': '1b8ir754uIAFUSJ8YW_cmPzqer19AR7Hz',
'IconVSR_Vimeo90K_BDx4-cfcb7e00.pth': '13lp55s-YTd-fApx8tTy24bbHsNIGXdAH',
- 'IconVSR_Vimeo90K_BIx4-35fec07c.pth': '1lWUB36ERjFbAspr-8UsopJ6xwOuWjh2g'
- }
+ 'IconVSR_Vimeo90K_BIx4-35fec07c.pth': '1lWUB36ERjFbAspr-8UsopJ6xwOuWjh2g',
+ },
}
if args.method == 'all':
diff --git a/scripts/metrics/calculate_fid_folder.py b/scripts/metrics/calculate_fid_folder.py
index 71b02e1fe..26c004eb0 100644
--- a/scripts/metrics/calculate_fid_folder.py
+++ b/scripts/metrics/calculate_fid_folder.py
@@ -1,5 +1,6 @@
import argparse
import math
+
import numpy as np
import torch
from torch.utils.data import DataLoader
@@ -40,7 +41,8 @@ def calculate_fid_folder():
shuffle=False,
num_workers=args.num_workers,
sampler=None,
- drop_last=False)
+ drop_last=False,
+ )
args.num_sample = min(args.num_sample, len(dataset))
total_batch = math.ceil(args.num_sample / args.batch_size)
@@ -54,7 +56,7 @@ def data_generator(data_loader, total_batch):
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
- features = features[:args.num_sample]
+ features = features[: args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
sample_mean = np.mean(features, 0)
diff --git a/scripts/metrics/calculate_fid_stats_from_datasets.py b/scripts/metrics/calculate_fid_stats_from_datasets.py
index 56e352920..9b210b76d 100644
--- a/scripts/metrics/calculate_fid_stats_from_datasets.py
+++ b/scripts/metrics/calculate_fid_stats_from_datasets.py
@@ -1,5 +1,6 @@
import argparse
import math
+
import numpy as np
import torch
from torch.utils.data import DataLoader
@@ -34,7 +35,8 @@ def calculate_stats_from_dataset():
# create dataloader
data_loader = DataLoader(
- dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, sampler=None, drop_last=False)
+ dataset=dataset, batch_size=args.batch_size, shuffle=False, num_workers=4, sampler=None, drop_last=False
+ )
total_batch = math.ceil(args.num_sample / args.batch_size)
def data_generator(data_loader, total_batch):
@@ -47,14 +49,15 @@ def data_generator(data_loader, total_batch):
features = extract_inception_features(data_generator(data_loader, total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
- features = features[:args.num_sample]
+ features = features[: args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
mean = np.mean(features, 0)
cov = np.cov(features, rowvar=False)
save_path = f'inception_{opt["name"]}_{args.size}.pth'
torch.save(
- dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False)
+ dict(name=opt['name'], size=args.size, mean=mean, cov=cov), save_path, _use_new_zipfile_serialization=False
+ )
if __name__ == '__main__':
diff --git a/scripts/metrics/calculate_lpips.py b/scripts/metrics/calculate_lpips.py
index 4170fb40e..5b0d46163 100644
--- a/scripts/metrics/calculate_lpips.py
+++ b/scripts/metrics/calculate_lpips.py
@@ -1,7 +1,8 @@
-import cv2
import glob
-import numpy as np
import os.path as osp
+
+import cv2
+import numpy as np
from torchvision.transforms.functional import normalize
from basicsr.utils import img2tensor
@@ -28,9 +29,11 @@ def main():
std = [0.5, 0.5, 0.5]
for i, img_path in enumerate(img_list):
basename, ext = osp.splitext(osp.basename(img_path))
- img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
- img_restored = cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(
- np.float32) / 255.
+ img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
+ img_restored = (
+ cv2.imread(osp.join(folder_restored, basename + suffix + ext), cv2.IMREAD_UNCHANGED).astype(np.float32)
+ / 255.0
+ )
img_gt, img_restored = img2tensor([img_gt, img_restored], bgr2rgb=True, float32=True)
# norm to [-1, 1]
@@ -40,7 +43,7 @@ def main():
# calculate lpips
lpips_val = loss_fn_vgg(img_restored.unsqueeze(0).cuda(), img_gt.unsqueeze(0).cuda())
- print(f'{i+1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
+ print(f'{i + 1:3d}: {basename:25}. \tLPIPS: {lpips_val:.6f}.')
lpips_all.append(lpips_val)
print(f'Average: LPIPS: {sum(lpips_all) / len(lpips_all):.6f}')
diff --git a/scripts/metrics/calculate_niqe.py b/scripts/metrics/calculate_niqe.py
index 1149b0faa..9c31a3539 100644
--- a/scripts/metrics/calculate_niqe.py
+++ b/scripts/metrics/calculate_niqe.py
@@ -1,8 +1,9 @@
import argparse
-import cv2
import os
import warnings
+import cv2
+
from basicsr.metrics import calculate_niqe
from basicsr.utils import scandir
@@ -19,7 +20,7 @@ def main(args):
with warnings.catch_warnings():
warnings.simplefilter('ignore', category=RuntimeWarning)
niqe_score = calculate_niqe(img, args.crop_border, input_order='HWC', convert_to='y')
- print(f'{i+1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
+ print(f'{i + 1:3d}: {basename:25}. \tNIQE: {niqe_score:.6f}')
niqe_all.append(niqe_score)
print(args.input)
diff --git a/scripts/metrics/calculate_psnr_ssim.py b/scripts/metrics/calculate_psnr_ssim.py
index 16de34622..aa63fcddd 100644
--- a/scripts/metrics/calculate_psnr_ssim.py
+++ b/scripts/metrics/calculate_psnr_ssim.py
@@ -1,15 +1,15 @@
import argparse
+from os import path as osp
+
import cv2
import numpy as np
-from os import path as osp
from basicsr.metrics import calculate_psnr, calculate_ssim
from basicsr.utils import bgr2ycbcr, scandir
def main(args):
- """Calculate PSNR and SSIM for images.
- """
+ """Calculate PSNR and SSIM for images."""
psnr_all = []
ssim_all = []
img_list_gt = sorted(list(scandir(args.gt, recursive=True, full_path=True)))
@@ -22,12 +22,12 @@ def main(args):
for i, img_path in enumerate(img_list_gt):
basename, ext = osp.splitext(osp.basename(img_path))
- img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+ img_gt = cv2.imread(img_path, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
if args.suffix == '':
img_path_restored = img_list_restored[i]
else:
img_path_restored = osp.join(args.restored, basename + args.suffix + ext)
- img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.
+ img_restored = cv2.imread(img_path_restored, cv2.IMREAD_UNCHANGED).astype(np.float32) / 255.0
if args.correct_mean_var:
mean_l = []
@@ -54,7 +54,7 @@ def main(args):
# calculate PSNR and SSIM
psnr = calculate_psnr(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
ssim = calculate_ssim(img_gt * 255, img_restored * 255, crop_border=args.crop_border, input_order='HWC')
- print(f'{i+1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
+ print(f'{i + 1:3d}: {basename:25}. \tPSNR: {psnr:.6f} dB, \tSSIM: {ssim:.6f}')
psnr_all.append(psnr)
ssim_all.append(ssim)
print(args.gt)
@@ -71,7 +71,8 @@ def main(args):
parser.add_argument(
'--test_y_channel',
action='store_true',
- help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.')
+ help='If True, test Y channel (In MatLab YCbCr format). If False, test RGB channels.',
+ )
parser.add_argument('--correct_mean_var', action='store_true', help='Correct the mean and var of restored images.')
args = parser.parse_args()
main(args)
diff --git a/scripts/metrics/calculate_stylegan2_fid.py b/scripts/metrics/calculate_stylegan2_fid.py
index c5564b8a3..5ddde14ee 100644
--- a/scripts/metrics/calculate_stylegan2_fid.py
+++ b/scripts/metrics/calculate_stylegan2_fid.py
@@ -1,5 +1,6 @@
import argparse
import math
+
import numpy as np
import torch
from torch import nn
@@ -28,7 +29,8 @@ def calculate_stylegan2_fid():
num_style_feat=512,
num_mlp=8,
channel_multiplier=args.channel_multiplier,
- resample_kernel=(1, 3, 3, 1))
+ resample_kernel=(1, 3, 3, 1),
+ )
generator.load_state_dict(torch.load(args.ckpt)['params_ema'])
generator = nn.DataParallel(generator).eval().to(device)
@@ -53,7 +55,7 @@ def sample_generator(total_batch):
features = extract_inception_features(sample_generator(total_batch), inception, total_batch, device)
features = features.numpy()
total_len = features.shape[0]
- features = features[:args.num_sample]
+ features = features[: args.num_sample]
print(f'Extracted {total_len} features, use the first {features.shape[0]} features to calculate stats.')
sample_mean = np.mean(features, 0)
sample_cov = np.cov(features, rowvar=False)
diff --git a/scripts/model_conversion/convert_dfdnet.py b/scripts/model_conversion/convert_dfdnet.py
index 5371d142b..eee37241a 100644
--- a/scripts/model_conversion/convert_dfdnet.py
+++ b/scripts/model_conversion/convert_dfdnet.py
@@ -34,7 +34,7 @@ def convert_net(ori_net, crt_net):
elif 'multi_scale_dilation' in crt_k:
if 'conv_blocks' in crt_k:
_, _, c, d, e = crt_k.split('.')
- ori_k = f'MSDilate.conv{int(c)+1}.{d}.{e}'
+ ori_k = f'MSDilate.conv{int(c) + 1}.{d}.{e}'
else:
ori_k = crt_k.replace('multi_scale_dilation.conv_fusion', 'MSDilate.convi')
@@ -53,9 +53,7 @@ def convert_net(ori_net, crt_net):
# replace
if crt_net[crt_k].size() != ori_net[ori_k].size():
- raise ValueError('Wrong tensor size: \n'
- f'crt_net: {crt_net[crt_k].size()}\n'
- f'ori_net: {ori_net[ori_k].size()}')
+ raise ValueError(f'Wrong tensor size: \ncrt_net: {crt_net[crt_k].size()}\nori_net: {ori_net[ori_k].size()}')
else:
crt_net[crt_k] = ori_net[ori_k]
@@ -71,4 +69,5 @@ def convert_net(ori_net, crt_net):
torch.save(
dict(params=crt_net_params),
'experiments/pretrained_models/DFDNet/DFDNet_official.pth',
- _use_new_zipfile_serialization=False)
+ _use_new_zipfile_serialization=False,
+ )
diff --git a/scripts/model_conversion/convert_models.py b/scripts/model_conversion/convert_models.py
index 46bb085f9..fcd392788 100644
--- a/scripts/model_conversion/convert_models.py
+++ b/scripts/model_conversion/convert_models.py
@@ -36,7 +36,7 @@ def convert_edvr():
ori_k = crt_k.replace('predeblur.resblock_l', 'pre_deblur.RB_L')
elif 'predeblur.resblock_l1' in crt_k:
a, b, c, d, e = crt_k.split('.')
- ori_k = f'pre_deblur.RB_L1_{int(c)+1}.{d}.{e}'
+ ori_k = f'pre_deblur.RB_L1_{int(c) + 1}.{d}.{e}'
elif 'conv_l2' in crt_k:
ori_k = crt_k.replace('conv_l2_', 'fea_L2_conv')
@@ -63,8 +63,14 @@ def convert_edvr():
ori_k = f'pcd_align.L{level}_fea_conv.{d}'
elif 'pcd_align.cas_dcnpack' in crt_k:
ori_k = crt_k.replace('conv_offset', 'conv_offset_mask')
- elif ('conv_first' in crt_k or 'feature_extraction' in crt_k or 'pcd_align.cas_offset' in crt_k
- or 'upconv' in crt_k or 'conv_last' in crt_k or 'conv_1x1' in crt_k):
+ elif (
+ 'conv_first' in crt_k
+ or 'feature_extraction' in crt_k
+ or 'pcd_align.cas_offset' in crt_k
+ or 'upconv' in crt_k
+ or 'conv_last' in crt_k
+ or 'conv_1x1' in crt_k
+ ):
ori_k = crt_k
elif 'temporal_attn1' in crt_k:
@@ -157,7 +163,7 @@ def convert_rcan_model():
elif 'attention' in crt_k:
_, ai, _, bi, _, ci, d, di, e = crt_k.split('.')
- ori_k = f'body.{ai}.body.{bi}.body.{ci}.conv_du.{int(di)-1}.{e}'
+ ori_k = f'body.{ai}.body.{bi}.body.{ci}.conv_du.{int(di) - 1}.{e}'
elif 'rcab' in crt_k:
a, ai, b, bi, c, ci, d = crt_k.split('.')
ori_k = f'body.{ai}.body.{bi}.body.{ci}.{d}'
@@ -173,6 +179,7 @@ def convert_rcan_model():
def convert_esrgan_model():
from basicsr.archs.rrdbnet_arch import RRDBNet
+
rrdb = RRDBNet(3, 3, num_feat=64, num_block=23, num_grow_ch=32)
crt_net = rrdb.state_dict()
# for k, v in crt_net.items():
@@ -201,6 +208,7 @@ def convert_esrgan_model():
def convert_duf_model():
from basicsr.archs.duf_arch import DUF
+
scale = 2
duf = DUF(scale=scale, num_layer=16, adapt_official_weights=True)
crt_net = duf.state_dict()
@@ -211,7 +219,7 @@ def convert_duf_model():
# print('******')
# for k, v in ori_net.items():
# print(k)
- '''
+ """
for crt_k, crt_v in crt_net.items():
if 'conv3d1' in crt_k:
ori_k = crt_k.replace('conv3d1', 'conv3d_1')
@@ -269,7 +277,7 @@ def convert_duf_model():
print(crt_k)
crt_net[crt_k] = ori_net[ori_k]
- '''
+ """
# for 16 layers
for crt_k, _ in crt_net.items():
if 'conv3d1' in crt_k:
@@ -343,17 +351,17 @@ def convert_duf_model():
x1 = x[::3, ...]
x2 = x[1::3, ...]
x3 = x[2::3, ...]
- crt_net['conv3d_r2.weight'][:scale**2, ...] = x1
- crt_net['conv3d_r2.weight'][scale**2:2 * (scale**2), ...] = x2
- crt_net['conv3d_r2.weight'][2 * (scale**2):, ...] = x3
+ crt_net['conv3d_r2.weight'][: scale**2, ...] = x1
+ crt_net['conv3d_r2.weight'][scale**2 : 2 * (scale**2), ...] = x2
+ crt_net['conv3d_r2.weight'][2 * (scale**2) :, ...] = x3
x = crt_net['conv3d_r2.bias'].clone()
x1 = x[::3, ...]
x2 = x[1::3, ...]
x3 = x[2::3, ...]
- crt_net['conv3d_r2.bias'][:scale**2, ...] = x1
- crt_net['conv3d_r2.bias'][scale**2:2 * (scale**2), ...] = x2
- crt_net['conv3d_r2.bias'][2 * (scale**2):, ...] = x3
+ crt_net['conv3d_r2.bias'][: scale**2, ...] = x1
+ crt_net['conv3d_r2.bias'][scale**2 : 2 * (scale**2), ...] = x2
+ crt_net['conv3d_r2.bias'][2 * (scale**2) :, ...] = x3
torch.save(crt_net, 'experiments/pretrained_models/DUF_x2_16L_official.pth')
diff --git a/scripts/model_conversion/convert_ridnet.py b/scripts/model_conversion/convert_ridnet.py
index d0b4c428f..a78f8df30 100644
--- a/scripts/model_conversion/convert_ridnet.py
+++ b/scripts/model_conversion/convert_ridnet.py
@@ -1,11 +1,13 @@
-import torch
from collections import OrderedDict
+import torch
+
from basicsr.archs.ridnet_arch import RIDNet
if __name__ == '__main__':
ori_net_checkpoint = torch.load(
- 'experiments/pretrained_models/RIDNet/RIDNet_official_original.pt', map_location=lambda storage, loc: storage)
+ 'experiments/pretrained_models/RIDNet/RIDNet_official_original.pt', map_location=lambda storage, loc: storage
+ )
rid_net = RIDNet(3, 64, 3)
new_ridnet_dict = OrderedDict()
diff --git a/scripts/model_conversion/convert_stylegan.py b/scripts/model_conversion/convert_stylegan.py
index 01999e143..b60a80e47 100644
--- a/scripts/model_conversion/convert_stylegan.py
+++ b/scripts/model_conversion/convert_stylegan.py
@@ -37,9 +37,7 @@ def convert_net_g(ori_net, crt_net):
# replace
if crt_net[crt_k].size() != ori_net[ori_k].size():
- raise ValueError('Wrong tensor size: \n'
- f'crt_net: {crt_net[crt_k].size()}\n'
- f'ori_net: {ori_net[ori_k].size()}')
+ raise ValueError(f'Wrong tensor size: \ncrt_net: {crt_net[crt_k].size()}\nori_net: {ori_net[ori_k].size()}')
else:
crt_net[crt_k] = ori_net[ori_k]
@@ -57,9 +55,7 @@ def convert_net_d(ori_net, crt_net):
# replace
if crt_net[crt_k].size() != ori_net[ori_k].size():
- raise ValueError('Wrong tensor size: \n'
- f'crt_net: {crt_net[crt_k].size()}\n'
- f'ori_net: {ori_net[ori_k].size()}')
+ raise ValueError(f'Wrong tensor size: \ncrt_net: {crt_net[crt_k].size()}\nori_net: {ori_net[ori_k].size()}')
else:
crt_net[crt_k] = ori_net[ori_k]
return crt_net
diff --git a/scripts/plot/model_complexity_cmp_bsrn.py b/scripts/plot/model_complexity_cmp_bsrn.py
index b39cf85dc..33c8e380d 100644
--- a/scripts/plot/model_complexity_cmp_bsrn.py
+++ b/scripts/plot/model_complexity_cmp_bsrn.py
@@ -4,7 +4,7 @@ def main():
fig, ax = plt.subplots(figsize=(15, 10))
radius = 9.5
notation_size = 27
- '''0 - 10'''
+ """0 - 10"""
# BSRN-S, FSRCNN
x = [156, 13]
y = [32.16, 30.71]
@@ -12,7 +12,7 @@ def main():
ax.scatter(x, y, s=area, alpha=0.8, marker='.', c='#4D96FF', edgecolors='white', linewidths=2.0)
plt.annotate('FSRCNN', (13 + 10, 30.71 + 0.1), fontsize=notation_size)
plt.annotate('BSRN-S(Ours)', (156 - 70, 32.16 + 0.15), fontsize=notation_size)
- '''10 - 25'''
+ """10 - 25"""
# BSRN, RFDN
x = [357, 550]
y = [32.30, 32.24]
@@ -20,7 +20,7 @@ def main():
ax.scatter(x, y, s=area, alpha=1.0, marker='.', c='#FFD93D', edgecolors='white', linewidths=2.0)
plt.annotate('BSRN(Ours)', (357 - 70, 32.35 + 0.10), fontsize=notation_size)
plt.annotate('RFDN', (550 - 70, 32.24 + 0.15), fontsize=notation_size)
- '''25 - 50'''
+ """25 - 50"""
# IDN, IMDN, PAN
x = [553, 715, 272]
y = [31.82, 32.21, 32.13]
@@ -29,7 +29,7 @@ def main():
plt.annotate('IDN', (553 - 60, 31.82 + 0.15), fontsize=notation_size)
plt.annotate('IMDN', (715 + 10, 32.21 + 0.15), fontsize=notation_size)
plt.annotate('PAN', (272 - 70, 32.13 - 0.25), fontsize=notation_size)
- '''50 - 100'''
+ """50 - 100"""
# SRCNN, CARN, LAPAR-A
x = [57, 1592, 659]
y = [30.48, 32.13, 32.15]
@@ -37,7 +37,7 @@ def main():
ax.scatter(x, y, s=area, alpha=0.8, marker='.', c='#EAE7C6', edgecolors='white', linewidths=2.0)
plt.annotate('SRCNN', (57 + 30, 30.48 + 0.1), fontsize=notation_size)
plt.annotate('LAPAR-A', (659 - 75, 32.15 + 0.20), fontsize=notation_size)
- '''1M+'''
+ """1M+"""
# LapSRCN, VDSR, DRRN, MemNet
x = [502, 666, 298, 678]
y = [31.54, 31.35, 31.68, 31.74]
@@ -47,7 +47,7 @@ def main():
plt.annotate('VDSR', (666 - 70, 31.35 - 0.35), fontsize=notation_size)
plt.annotate('DRRN', (298 - 65, 31.68 - 0.35), fontsize=notation_size)
plt.annotate('MemNet', (678 + 15, 31.74 + 0.18), fontsize=notation_size)
- '''Ours marker'''
+ """Ours marker"""
x = [156]
y = [32.16]
ax.scatter(x, y, alpha=1.0, marker='*', c='r', s=300)
@@ -62,8 +62,10 @@ def main():
plt.title('PSNR vs. Parameters vs. Multi-Adds', fontsize=35)
h = [
- plt.plot([], [], color=c, marker='.', ms=i, alpha=a, ls='')[0] for i, c, a in zip(
- [40, 60, 80, 95, 110], ['#4D96FF', '#FFD93D', '#95CD41', '#EAE7C6', '#264653'], [0.8, 1.0, 0.6, 0.8, 0.3])
+ plt.plot([], [], color=c, marker='.', ms=i, alpha=a, ls='')[0]
+ for i, c, a in zip(
+ [40, 60, 80, 95, 110], ['#4D96FF', '#FFD93D', '#95CD41', '#EAE7C6', '#264653'], [0.8, 1.0, 0.6, 0.8, 0.3]
+ )
]
ax.legend(
labelspacing=0.1,
@@ -78,7 +80,8 @@ def main():
loc='lower right',
ncol=5,
shadow=True,
- handleheight=6)
+ handleheight=6,
+ )
for size in ax.get_xticklabels(): # Set fontsize for x-axis
size.set_fontsize('30')
diff --git a/scripts/publish_models.py b/scripts/publish_models.py
index 81495aa55..c2fb3c4fa 100644
--- a/scripts/publish_models.py
+++ b/scripts/publish_models.py
@@ -1,19 +1,21 @@
import glob
import subprocess
-import torch
from os import path as osp
+
+import torch
from torch.serialization import _is_zipfile, _open_file_like
def update_sha(paths):
print('# Update sha ...')
for idx, path in enumerate(paths):
- print(f'{idx+1:03d}: Processing {path}')
+ print(f'{idx + 1:03d}: Processing {path}')
net = torch.load(path, map_location=torch.device('cpu'))
basename = osp.basename(path)
if 'params' not in net and 'params_ema' not in net:
- user_response = input(f'WARN: Model {basename} does not have "params"/"params_ema" key. '
- 'Do you still want to continue? Y/N\n')
+ user_response = input(
+ f'WARN: Model {basename} does not have "params"/"params_ema" key. Do you still want to continue? Y/N\n'
+ )
if user_response.lower() == 'y':
pass
elif user_response.lower() == 'n':
@@ -45,7 +47,7 @@ def convert_to_backward_compatible_models(paths):
"""
print('# Convert to backward compatible pth files ...')
for idx, path in enumerate(paths):
- print(f'{idx+1:03d}: Processing {path}')
+ print(f'{idx + 1:03d}: Processing {path}')
flag_need_conversion = False
with _open_file_like(path, 'rb') as opened_file:
if _is_zipfile(opened_file):
diff --git a/setup.cfg b/setup.cfg
index 3af905a32..a93a20f8d 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -5,30 +5,3 @@ ignore =
# line break after binary operator (W504)
W504,
max-line-length=120
-
-[yapf]
-based_on_style = pep8
-column_limit = 120
-blank_line_before_nested_class_or_def = true
-split_before_expression_after_opening_paren = true
-
-[isort]
-line_length = 120
-multi_line_output = 0
-known_standard_library = pkg_resources,setuptools
-known_first_party = basicsr
-known_third_party = PIL,cv2,lmdb,numpy,pytest,requests,scipy,skimage,torch,torchvision,tqdm,yaml
-no_lines_before = STDLIB,LOCALFOLDER
-default_section = THIRDPARTY
-
-[codespell]
-skip = .git,./docs/build,*.cfg
-count =
-quiet-level = 3
-ignore-words-list = gool
-
-[aliases]
-test=pytest
-
-[tool:pytest]
-addopts=tests/
diff --git a/tests/test_archs/test_basicvsr_arch.py b/tests/test_archs/test_basicvsr_arch.py
index df100777f..456c8d506 100644
--- a/tests/test_archs/test_basicvsr_arch.py
+++ b/tests/test_archs/test_basicvsr_arch.py
@@ -28,14 +28,16 @@ def test_iconvsr():
# model init and forward
net = IconVSR(
- num_feat=8, num_block=1, keyframe_stride=2, temporal_padding=2, spynet_path=None, edvr_path=None).cuda()
+ num_feat=8, num_block=1, keyframe_stride=2, temporal_padding=2, spynet_path=None, edvr_path=None
+ ).cuda()
img = torch.rand((1, 6, 3, 64, 64), dtype=torch.float32).cuda()
output = net(img)
assert output.shape == (1, 6, 3, 256, 256)
# --------------------------- temporal padding 3 ------------------------- #
net = IconVSR(
- num_feat=8, num_block=1, keyframe_stride=2, temporal_padding=3, spynet_path=None, edvr_path=None).cuda()
+ num_feat=8, num_block=1, keyframe_stride=2, temporal_padding=3, spynet_path=None, edvr_path=None
+ ).cuda()
img = torch.rand((1, 8, 3, 64, 64), dtype=torch.float32).cuda()
output = net(img)
assert output.shape == (1, 8, 3, 256, 256)
diff --git a/tests/test_models/test_sr_model.py b/tests/test_models/test_sr_model.py
index 2e95cda52..8348d7795 100644
--- a/tests/test_models/test_sr_model.py
+++ b/tests/test_models/test_sr_model.py
@@ -1,4 +1,5 @@
import tempfile
+
import torch
import yaml
@@ -132,7 +133,8 @@ def test_srmodel():
dataroot_lq='tests/data/lq',
io_backend=dict(type='disk'),
scale=4,
- phase='val')
+ phase='val',
+ )
dataset = PairedImageDataset(dataset_opt)
dataloader = torch.utils.data.DataLoader(dataset=dataset, batch_size=1, shuffle=False, num_workers=0)
assert model.is_train is True