keep name in ui batch process

main
Tran Xen 2 years ago
parent b3ea4c8b93
commit 76dbd57ad5

@ -3,7 +3,7 @@ import os
from dataclasses import dataclass from dataclasses import dataclass
from pprint import pformat from pprint import pformat
import traceback import traceback
from typing import Any, Dict, Generator, List, Set, Tuple, Optional from typing import Any, Dict, Generator, List, Set, Tuple, Optional, Union
import tempfile import tempfile
from tqdm import tqdm from tqdm import tqdm
import sys import sys
@ -110,7 +110,7 @@ def compare_faces(img1: PILImage, img2: PILImage) -> float:
def batch_process( def batch_process(
src_images: List[PILImage], src_images: List[Union[PILImage, str]], # image or filename
save_path: Optional[str], save_path: Optional[str],
units: List[FaceSwapUnitSettings], units: List[FaceSwapUnitSettings],
postprocess_options: PostProcessingOptions, postprocess_options: PostProcessingOptions,
@ -119,7 +119,7 @@ def batch_process(
Process a batch of images, apply face swapping according to the given settings, and optionally save the resulting images to a specified path. Process a batch of images, apply face swapping according to the given settings, and optionally save the resulting images to a specified path.
Args: Args:
src_images (List[PILImage]): List of source PIL Images to process. src_images (List[Union[PILImage, str]]): List of source PIL Images to process or list of images file names
save_path (Optional[str]): Destination path where the processed images will be saved. If None, no images are saved. save_path (Optional[str]): Destination path where the processed images will be saved. If None, no images are saved.
units (List[FaceSwapUnitSettings]): List of FaceSwapUnitSettings to apply to the images. units (List[FaceSwapUnitSettings]): List of FaceSwapUnitSettings to apply to the images.
postprocess_options (PostProcessingOptions): Post-processing settings to be applied to the images. postprocess_options (PostProcessingOptions): Post-processing settings to be applied to the images.
@ -138,6 +138,18 @@ def batch_process(
if src_images is not None and len(units) > 0: if src_images is not None and len(units) > 0:
result_images = [] result_images = []
for src_image in src_images: for src_image in src_images:
if isinstance(src_image, str):
if save_path:
path = os.path.join(
save_path, "swapped_" + os.path.basename(src_image)
)
src_image = Image.open(src_image)
elif save_path:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=save_path
).name
assert isinstance(src_image, Image.Image)
current_images = [] current_images = []
swapped_images = process_images_units( swapped_images = process_images_units(
get_current_model(), images=[(src_image, None)], units=units get_current_model(), images=[(src_image, None)], units=units
@ -153,9 +165,6 @@ def batch_process(
if save_path: if save_path:
for img in current_images: for img in current_images:
path = tempfile.NamedTemporaryFile(
delete=False, suffix=".png", dir=save_path
).name
img.save(path) img.save(path)
result_images += current_images result_images += current_images

@ -216,12 +216,10 @@ def batch_process(
] ]
postprocess_options = classes[-1] postprocess_options = classes[-1]
images = [ images_paths = [file.name for file in files]
Image.open(file.name) for file in files
] # potentially greedy but Image.open is supposed to be lazy
return swapper.batch_process( return swapper.batch_process(
images, images_paths,
save_path=save_path, save_path=save_path,
units=units, units=units,
postprocess_options=postprocess_options, postprocess_options=postprocess_options,

Loading…
Cancel
Save