You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
83 lines
3.2 KiB
Python
83 lines
3.2 KiB
Python
"""
Code from codeformer https://github.com/sczhou/CodeFormer

S-Lab License 1.0

Copyright 2022 S-Lab

Redistribution and use for non-commercial purpose in source and
binary forms, with or without modification, are permitted provided
that the following conditions are met:

1. Redistributions of source code must retain the above copyright
notice, this list of conditions and the following disclaimer.

2. Redistributions in binary form must reproduce the above copyright
notice, this list of conditions and the following disclaimer in
the documentation and/or other materials provided with the
distribution.

3. Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived
from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

In the event that redistribution and/or use for commercial purpose in
source or binary forms, with or without modification is required,
please contact the contributor(s) of the work.
"""
|
|
|
|
|
|
import torch
|
|
import os
|
|
import torch
|
|
from torch.hub import download_url_to_file, get_dir
|
|
from .parsenet import ParseNet
|
|
from urllib.parse import urlparse
|
|
from scripts.faceswaplab_globals import FACE_PARSER_DIR
|
|
|
|
ROOT_DIR = FACE_PARSER_DIR
|
|
|
|
|
|
def load_file_from_url(url, model_dir=None, progress=True, file_name=None):
    """Return the local path for the file at ``url``, downloading it if missing.

    Ref: https://github.com/1adrianb/face-alignment/blob/master/face_alignment/utils.py

    Args:
        url: Address of the file to fetch.
        model_dir: Directory joined onto ``ROOT_DIR`` to hold the file. When
            ``None``, the ``checkpoints`` folder under ``torch.hub.get_dir()``
            is used instead.
        progress: Forwarded to ``download_url_to_file`` as its ``progress``
            argument.
        file_name: Name to store the file under. When ``None``, the basename
            of the URL path is used.

    Returns:
        Absolute path of the cached file.
    """
    if model_dir is None:
        model_dir = os.path.join(get_dir(), "checkpoints")

    # The destination directory is created up front, before the URL is parsed.
    target_dir = os.path.join(ROOT_DIR, model_dir)
    os.makedirs(target_dir, exist_ok=True)

    # The URL is always parsed, even when an explicit file name overrides it.
    url_basename = os.path.basename(urlparse(url).path)
    target_name = url_basename if file_name is None else file_name

    cached_file = os.path.abspath(os.path.join(target_dir, target_name))
    if os.path.exists(cached_file):
        # Already present on disk: skip the download.
        return cached_file

    print(f'Downloading: "{url}" to {cached_file}\n')
    download_url_to_file(url, cached_file, hash_prefix=None, progress=progress)
    return cached_file
|
|
|
|
|
|
def init_parsing_model(device="cuda"):
    """Create a ParseNet model with the CodeFormer parsing weights loaded.

    The checkpoint is fetched through ``load_file_from_url`` into the
    ``weights/facelib`` directory, loaded strictly into the network, and the
    network is switched to eval mode before being moved to ``device``.

    Args:
        device: Target passed to ``model.to`` for the returned network.

    Returns:
        The ParseNet instance in eval mode on ``device``.
    """
    parser = ParseNet(in_size=512, out_size=512, parsing_ch=19)

    weights_url = "https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/parsing_parsenet.pth"
    weights_path = load_file_from_url(
        url=weights_url,
        model_dir="weights/facelib",
        progress=True,
        file_name=None,
    )

    # Deserialize onto the CPU first; the move to `device` happens at the end.
    # NOTE(review): torch.load unpickles the downloaded file — confirm the
    # source is trusted (or consider weights_only=True on a recent torch).
    state_dict = torch.load(weights_path, map_location="cpu")
    parser.load_state_dict(state_dict, strict=True)

    parser.eval()
    return parser.to(device)
|