add tests

main
Tran Xen 2 years ago
parent ddc403db1f
commit 15e9366eb6

@ -1,4 +1,4 @@
#!/bin/bash
autoflake --in-place --remove-unused-variables -r --remove-all-unused-imports .
mypy --install-types
pre-commit run --all-files
pre-commit run --all-files

@ -12,13 +12,13 @@ address = "http://127.0.0.1:7860"
# First face unit :
unit1 = FaceSwapUnit(
source_img=pil_to_base64("../../references/man.png"), # The face you want to use
source_img=pil_to_base64("../references/man.png"), # The face you want to use
faces_index=(0,), # Replace first face
)
# Second face unit :
unit2 = FaceSwapUnit(
source_img=pil_to_base64("../../references/woman.png"), # The face you want to use
source_img=pil_to_base64("../references/woman.png"), # The face you want to use
same_gender=True,
faces_index=(0,), # Replace first woman since same gender is on
)
@ -48,5 +48,6 @@ result = requests.post(
)
response = FaceSwapResponse.parse_obj(result.json())
for img, info in zip(response.pil_images, response.infos):
img.show(title=info)
print(response.json())
for img in response.pil_images:
img.show()

Before

Width:  |  Height:  |  Size: 99 KiB

After

Width:  |  Height:  |  Size: 99 KiB

@ -13,9 +13,6 @@ from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSetting
from scripts.faceswaplab_utils.imgutils import (
base64_to_pil,
)
from scripts.faceswaplab_utils.models_utils import get_current_model
from modules.shared import opts
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions,
)
@ -135,22 +132,18 @@ def faceswaplab_api(_: gr.Blocks, app: FastAPI) -> None:
units: List[FaceSwapUnitSettings] = []
src_image: Optional[Image.Image] = base64_to_pil(request.image)
response = FaceSwapResponse(images=[], infos=[])
if request.postprocessing:
pp_options = get_postprocessing_options(request.postprocessing)
if src_image is not None:
if request.postprocessing:
pp_options = get_postprocessing_options(request.postprocessing)
units = get_faceswap_units_settings(request.units)
swapped_images = swapper.process_images_units(
get_current_model(),
images=[(src_image, None)],
units=units,
upscaled_swapper=opts.data.get("faceswaplab_upscaled_swapper", False),
swapped_images = swapper.batch_process(
[src_image], None, units=units, postprocess_options=pp_options
)
for img, info in swapped_images:
if pp_options:
img = enhance_image(img, pp_options)
for img in swapped_images:
response.images.append(encode_to_base64(img))
response.infos.append(info)
response.infos = [] # Not used atm
return response

@ -8,7 +8,7 @@ REFERENCE_PATH = os.path.join(
scripts.basedir(), "extensions", "sd-webui-faceswaplab", "references"
)
VERSION_FLAG: str = "v1.1.0"
VERSION_FLAG: str = "v1.1.1"
EXTENSION_PATH = os.path.join("extensions", "sd-webui-faceswaplab")
# The NSFW score threshold. If any part of the image has a score greater than this threshold, the image will be considered NSFW.

@ -2,6 +2,7 @@ import copy
import os
from dataclasses import dataclass
from typing import Any, Dict, List, Set, Tuple, Optional
import tempfile
import cv2
import insightface
@ -21,6 +22,12 @@ from scripts import faceswaplab_globals
from modules.shared import opts
from functools import lru_cache
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from scripts.faceswaplab_postprocessing.postprocessing_options import (
PostProcessingOptions,
)
from scripts.faceswaplab_utils.models_utils import get_current_model
providers = ["CPUExecutionProvider"]
@ -78,6 +85,53 @@ def compare_faces(img1: Image.Image, img2: Image.Image) -> float:
return -1
def batch_process(
src_images: List[Image.Image],
save_path: Optional[str],
units: List[FaceSwapUnitSettings],
postprocess_options: PostProcessingOptions,
) -> Optional[List[Image.Image]]:
try:
if save_path:
os.makedirs(save_path, exist_ok=True)
units = [u for u in units if u.enable]
if src_images is not None and len(units) > 0:
result_images = []
for src_image in src_images:
current_images = []
swapped_images = process_images_units(
get_current_model(),
images=[(src_image, None)],
units=units,
upscaled_swapper=opts.data.get(
"faceswaplab_upscaled_swapper", False
),
)
if len(swapped_images) > 0:
current_images += [img for img, _ in swapped_images]
logger.info("%s images generated", len(current_images))
for i, img in enumerate(current_images):
current_images[i] = enhance_image(img, postprocess_options)
if save_path:
for img in current_images:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=save_path
).name
img.save(path)
result_images += current_images
return result_images
except Exception as e:
logger.error("Batch Process error : %s", e)
import traceback
traceback.print_exc()
return None
class FaceModelException(Exception):
"""Exception raised when an error is encountered in the face model."""

@ -26,7 +26,6 @@ from scripts.faceswaplab_postprocessing.postprocessing import enhance_image
from dataclasses import fields
from typing import Any, Dict, List, Optional
from scripts.faceswaplab_ui.faceswaplab_unit_settings import FaceSwapUnitSettings
from scripts.faceswaplab_utils.models_utils import get_current_model
import re
@ -291,9 +290,6 @@ def batch_process(
files: List[gr.File], save_path: str, *components: List[gr.components.Component]
) -> Optional[List[Image.Image]]:
try:
if save_path is not None:
os.makedirs(save_path, exist_ok=True)
units_count = opts.data.get("faceswaplab_units_count", 3)
units: List[FaceSwapUnitSettings] = []
@ -312,36 +308,15 @@ def batch_process(
*components[shift : shift + len(fields(PostProcessingOptions))] # type: ignore
)
logger.debug("%s", pformat(postprocess_options))
units = [u for u in units if u.enable]
if files is not None and len(units) > 0:
images = []
for file in files:
current_images = []
src_image = Image.open(file.name)
swapped_images = swapper.process_images_units(
get_current_model(),
images=[(src_image, None)],
units=units,
upscaled_swapper=opts.data.get(
"faceswaplab_upscaled_swapper", False
),
)
if len(swapped_images) > 0:
current_images += [img for img, _ in swapped_images]
logger.info("%s images generated", len(current_images))
for i, img in enumerate(current_images):
current_images[i] = enhance_image(img, postprocess_options)
for img in current_images:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=save_path
).name
img.save(path)
images += current_images
return images
images = [
Image.open(file.name) for file in files
] # potentially greedy but Image.open is supposed to be lazy
return swapper.batch_process(
images,
save_path=save_path,
units=units,
postprocess_options=postprocess_options,
)
except Exception as e:
logger.error("Batch Process error : %s", e)
import traceback

@ -0,0 +1,3 @@
#!/bin/bash
./check.sh
pytest -p no:warnings

@ -0,0 +1,83 @@
from typing import List
import pytest
import requests
import sys
sys.path.append(".")
from client_api.client_utils import (
FaceSwapUnit,
FaceSwapResponse,
PostProcessingOptions,
FaceSwapRequest,
base64_to_pil,
pil_to_base64,
InpaintingWhen,
)
from PIL import Image
base_url = "http://127.0.0.1:7860"
@pytest.fixture
def face_swap_request() -> FaceSwapRequest:
# First face unit
unit1 = FaceSwapUnit(
source_img=pil_to_base64("references/man.png"), # The face you want to use
faces_index=(0,), # Replace first face
)
# Second face unit
unit2 = FaceSwapUnit(
source_img=pil_to_base64("references/woman.png"), # The face you want to use
same_gender=True,
faces_index=(0,), # Replace first woman since same gender is on
)
# Post-processing config
pp = PostProcessingOptions(
face_restorer_name="CodeFormer",
codeformer_weight=0.5,
restorer_visibility=1,
upscaler_name="Lanczos",
scale=4,
inpainting_steps=30,
inpainting_denoising_strengh=0.1,
inpainting_when=InpaintingWhen.BEFORE_RESTORE_FACE,
)
# Prepare the request
request = FaceSwapRequest(
image=pil_to_base64("tests/test_image.png"),
units=[unit1, unit2],
postprocessing=pp,
)
return request
def test_version() -> None:
response = requests.get(f"{base_url}/faceswaplab/version")
assert response.status_code == 200
assert "version" in response.json()
def test_faceswap(face_swap_request: FaceSwapRequest) -> None:
response = requests.post(
f"{base_url}/faceswaplab/swap_face",
data=face_swap_request.json(),
headers={"Content-Type": "application/json; charset=utf-8"},
)
assert response.status_code == 200
data = response.json()
assert "images" in data
assert "infos" in data
res = FaceSwapResponse.parse_obj(response.json())
images: List[Image.Image] = res.pil_images
assert len(images) == 1
image = images[0]
orig_image = base64_to_pil(face_swap_request.image)
assert image.width == orig_image.width * face_swap_request.postprocessing.scale
assert image.height == orig_image.height * face_swap_request.postprocessing.scale

Binary file not shown.

After

Width:  |  Height:  |  Size: 99 KiB

Loading…
Cancel
Save