| | import argparse |
| | import cv2 |
| | import numpy as np |
| | import os |
| | import onnxruntime as ort |
| | from axengine import InferenceSession |
| | import numpy as np |
| | import cv2 |
| | import argparse |
| | import os.path as osp |
| | from loguru import logger |
| | from numpy import ndarray |
| | import pickle as pkl |
| | import torch |
| | import torch.nn.functional as F |
| | from cropper import Cropper |
| | import imageio |
| | import subprocess |
| | from utils.timer import Timer |
| | from typing import Union |
| | from scipy.spatial import ConvexHull |
| |
|
| |
|
# Inference sessions shared by the helper functions below.
# They start as None and are assigned by main() via load_model().
appearance_feature_extractor = None
motion_extractor = None
warping_module = None
spade_generator = None
stitching_retargeting_module = None
| |
|
| |
|
def parse_args() -> argparse.Namespace:
    """Parse the command-line options of the script.

    Returns:
        Namespace with ``source``, ``driving``, ``models`` (all required)
        and ``output_dir`` (defaults to ``./output``).
    """
    parser = argparse.ArgumentParser(
        prog="LivePortrait",
        description="LivePortrait: A Real-time 3D Live Portrait Animation System"
    )
    parser.add_argument(
        "--source",
        type=str,
        required=True,
        help="Path to source image.",
    )
    parser.add_argument(
        "--driving",
        type=str,
        required=True,
        # main() accepts both still images and videos for the driving input.
        help="Path to driving image or video.",
    )
    parser.add_argument(
        "--models",
        type=str,
        required=True,
        # load_model() looks for both .axmodel and .onnx files in this directory.
        help="Path to the directory containing the axmodel/onnx models.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./output",
        help="Path to infer results.",
    )

    return parser.parse_args()
| |
|
| |
|
def images2video(images, wfp, **kwargs):
    """Encode a sequence of frames into a video file through imageio.

    Args:
        images: Iterable of frames (arrays whose last axis is the colour
            channel).
        wfp: Path of the video file to write.
        **kwargs: Optional encoder settings: ``fps`` (30), ``format``
            ('mp4'), ``codec`` ('libx264'), ``quality`` (None),
            ``pixelformat`` ('yuv420p'), ``image_mode`` ('rgb'; 'bgr'
            reverses the channel axis before writing),
            ``macro_block_size`` (2) and ``crf`` (18).
    """
    ffmpeg_params = ['-crf', str(kwargs.get('crf', 18))]
    # Decide once whether the channel order has to be flipped.
    flip_channels = kwargs.get('image_mode', 'rgb').lower() == 'bgr'

    writer = imageio.get_writer(
        wfp,
        fps=kwargs.get('fps', 30),
        format=kwargs.get('format', 'mp4'),
        codec=kwargs.get('codec', 'libx264'),
        quality=kwargs.get('quality'),
        ffmpeg_params=ffmpeg_params,
        pixelformat=kwargs.get('pixelformat', 'yuv420p'),
        macro_block_size=kwargs.get('macro_block_size', 2),
    )
    try:
        for frame in images:
            writer.append_data(frame[..., ::-1] if flip_channels else frame)
    finally:
        # Always close the writer so the file is released even if a frame
        # fails to encode.
        writer.close()
| |
|
| |
|
def has_audio_stream(video_path: str) -> bool:
    """
    Check if the video file contains an audio stream.

    The check runs ``ffprobe`` directly with an argument list (no shell), so
    paths containing spaces or quotes need no escaping.

    :param video_path: Path to the video file
    :return: True if the video contains an audio stream, False otherwise
        (also False for directories, when ffprobe cannot be started, or when
        it exits with a non-zero status)
    """
    if osp.isdir(video_path):
        return False

    cmd = [
        'ffprobe',
        '-v', 'error',
        '-select_streams', 'a',
        '-show_entries', 'stream=codec_type',
        '-of', 'default=noprint_wrappers=1:nokey=1',
        video_path,
    ]

    try:
        result = subprocess.run(cmd, capture_output=True, text=True)
    except (OSError, subprocess.SubprocessError):
        # Typically raised when the ffprobe executable is not on PATH.
        logger.info(
            f"Error occurred while probing video: {video_path}, "
            "you may need to install ffprobe! (https://ffmpeg.org/download.html) "
            "Now set audio to false!"
        )
        return False

    if result.returncode != 0:
        logger.info(f"Error occurred while probing video: {result.stderr}")
        return False

    # ffprobe prints one line per audio stream; empty output means no audio.
    return bool(result.stdout.strip())
| |
|
| |
|
def tensor_to_numpy(data: Union[np.ndarray, torch.Tensor]) -> np.ndarray:
    """Return ``data`` as a numpy array.

    A ``torch.Tensor`` is detached from the autograd graph, moved to the CPU
    and converted; any other input is returned unchanged.
    """
    if isinstance(data, torch.Tensor):
        return data.detach().cpu().numpy()
    return data
| |
|
| |
|
def calc_motion_multiplier(
    kp_source: Union[np.ndarray, torch.Tensor],
    kp_driving_initial: Union[np.ndarray, torch.Tensor]
) -> float:
    """Calculate motion_multiplier from the source image and the first driving frame.

    Both inputs have their leading axis (which must be of size 1) squeezed
    away, and a convex hull is built from the remaining points.

    Returns:
        ``sqrt(source_hull.volume) / sqrt(driving_hull.volume)`` as a plain
        Python float.
    """
    kp_source_np = tensor_to_numpy(kp_source)
    kp_driving_initial_np = tensor_to_numpy(kp_driving_initial)

    source_area = ConvexHull(kp_source_np.squeeze(0)).volume
    driving_area = ConvexHull(kp_driving_initial_np.squeeze(0)).volume

    # float() keeps the return value consistent with the annotation.
    return float(np.sqrt(source_area) / np.sqrt(driving_area))
| |
|
| |
|
def load_video(video_info, n_frames=-1):
    """Read the frames of a video through imageio's ffmpeg reader.

    Args:
        video_info: Source handed to ``imageio.get_reader``.
        n_frames: Maximum number of frames to read; values <= 0 read all
            frames.

    Returns:
        List of frames in reading order.
    """
    reader = imageio.get_reader(video_info, "ffmpeg")
    frames = []
    try:
        for idx, frame_rgb in enumerate(reader):
            if n_frames > 0 and idx >= n_frames:
                break
            frames.append(frame_rgb)
    finally:
        # Close the reader even when decoding a frame raises.
        reader.close()
    return frames
| |
|
| |
|
def fast_check_ffmpeg():
    """Return True when ``ffmpeg -version`` can be executed successfully."""
    try:
        subprocess.run(["ffmpeg", "-version"], capture_output=True, check=True)
    except (OSError, subprocess.SubprocessError):
        # OSError: executable missing or not runnable.
        # SubprocessError: ffmpeg exited with a non-zero status.
        return False
    return True
| |
|
| |
|
def is_video(file_path):
    """Return True for paths with a known video extension, and for directories.

    The extension check is case-insensitive and covers .mp4, .mov, .avi and
    .webm.
    """
    video_extensions = (".mp4", ".mov", ".avi", ".webm")
    return file_path.lower().endswith(video_extensions) or osp.isdir(file_path)
| |
|
| |
|
def is_image(file_path):
    """Return True when the path ends with a known image extension (case-insensitive)."""
    image_extensions = ('.jpg', '.jpeg', '.png', '.gif', '.bmp', '.tiff', '.webp')
    return file_path.lower().endswith(image_extensions)
| |
|
| |
|
def get_fps(filepath, default_fps=25):
    """Read the frame rate of a video with OpenCV.

    Args:
        filepath: Path handed to ``cv2.VideoCapture``.
        default_fps: Value returned when the frame rate cannot be read.

    Returns:
        The reported frame rate, or ``default_fps`` when OpenCV raises or
        reports a value that is not a positive number.
    """
    try:
        capture = cv2.VideoCapture(filepath)
        try:
            fps = capture.get(cv2.CAP_PROP_FPS)
        finally:
            # Release the underlying file handle / decoder.
            capture.release()
    except Exception as e:
        logger.info(e)
        return default_fps

    # `not fps > 0` also rejects NaN, which compares False with everything.
    if fps is None or not fps > 0:
        return default_fps
    return fps
| |
|
| |
|
def calculate_distance_ratio(lmk: np.ndarray, idx1: int, idx2: int, idx3: int, idx4: int, eps: float = 1e-6) -> np.ndarray:
    """Ratio of two point-to-point distances, computed per batch item.

    Args:
        lmk: Landmark array indexed as ``lmk[:, idx]``; the norm is taken
            over axis 1 of the selected points.
        idx1, idx2: Indices of the pair forming the numerator distance.
        idx3, idx4: Indices of the pair forming the denominator distance.
        eps: Added to the denominator to avoid division by zero.

    Returns:
        Array with one ratio per batch item (the reduced axis is kept).
    """
    numerator = np.linalg.norm(lmk[:, idx1] - lmk[:, idx2], axis=1, keepdims=True)
    denominator = np.linalg.norm(lmk[:, idx3] - lmk[:, idx4], axis=1, keepdims=True) + eps
    return numerator / denominator
| |
|
| |
|
def calc_eye_close_ratio(lmk: np.ndarray, target_eye_ratio: Union[np.ndarray, None] = None) -> np.ndarray:
    """Compute the two eye distance ratios from the landmarks.

    The first ratio uses landmark pairs (6, 18) over (0, 12), the second
    uses (30, 42) over (24, 36).

    Args:
        lmk: Landmark array accepted by ``calculate_distance_ratio``.
        target_eye_ratio: Optional array appended after the two ratios along
            axis 1.

    Returns:
        The ratios (and ``target_eye_ratio`` when given) concatenated along
        axis 1.
    """
    parts = [
        calculate_distance_ratio(lmk, 6, 18, 0, 12),
        calculate_distance_ratio(lmk, 30, 42, 24, 36),
    ]
    if target_eye_ratio is not None:
        parts.append(target_eye_ratio)
    return np.concatenate(parts, axis=1)
| |
|
| |
|
def calc_lip_close_ratio(lmk: np.ndarray) -> np.ndarray:
    """Compute the lip distance ratio: landmark pair (90, 102) over pair (48, 66)."""
    return calculate_distance_ratio(lmk, 90, 102, 48, 66)
| |
|
| |
|
def concat_frames(driving_image_lst, source_image_lst, I_p_lst):
    """Build side-by-side comparison frames.

    Each output frame is ``[driving | source | result]`` stacked
    horizontally, or ``[source | result]`` when ``driving_image_lst`` is
    None. Driving and source images are resized to the size of the first
    result frame.

    Args:
        driving_image_lst: Driving frames (one per result frame) or None.
        source_image_lst: Source frames; when it holds a single image, that
            image is reused for every output frame.
        I_p_lst: Result frames.

    Returns:
        List of concatenated frames; empty when ``I_p_lst`` is empty.
    """
    if len(I_p_lst) == 0:
        return []

    h, w = I_p_lst[0].shape[:2]
    resized_sources = [cv2.resize(img, (w, h)) for img in source_image_lst]
    per_frame_source = len(source_image_lst) > 1

    out_lst = []
    for idx, I_p in enumerate(I_p_lst):
        source_image_resized = resized_sources[idx] if per_frame_source else resized_sources[0]
        panels = [source_image_resized, I_p]
        if driving_image_lst is not None:
            panels.insert(0, cv2.resize(driving_image_lst[idx], (w, h)))
        out_lst.append(np.hstack(panels))
    return out_lst
| |
|
| |
|
def concat_feat(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """
    Flatten both keypoint tensors per batch item and concatenate them.

    kp_source: (bs, k, 3)
    kp_driving: (bs, k, 3)
    Return: (bs, 2k*3)
    """
    bs_src = kp_source.shape[0]
    bs_dri = kp_driving.shape[0]
    assert bs_src == bs_dri, 'batch size must be equal'

    # reshape (unlike view) also accepts non-contiguous tensors.
    return torch.cat(
        [kp_source.reshape(bs_src, -1), kp_driving.reshape(bs_dri, -1)],
        dim=1,
    )
| |
|
| |
|
# Default floating point type and OpenCV interpolation mode used by the
# image transform helpers.
DTYPE = np.float32
CV2_INTERP = cv2.INTER_LINEAR
| |
|
| |
|
def _transform_img(img, M, dsize, flags=CV2_INTERP, borderMode=None):
    """Apply a similarity or affine transformation to the image.

    Args:
        img: Image to warp.
        M: 2x3 matrix or 3x3 matrix; only the first two rows are used.
        dsize: Target shape ``(width, height)``; a scalar is used for both
            dimensions.
        flags: Interpolation flags passed to ``cv2.warpAffine``.
        borderMode: Optional OpenCV border mode. When given, the border
            value is black; when None, the border arguments are not passed.

    Returns:
        The warped image.
    """
    _dsize = tuple(dsize) if isinstance(dsize, (tuple, list)) else (dsize, dsize)

    warp_kwargs = {'dsize': _dsize, 'flags': flags}
    if borderMode is not None:
        warp_kwargs.update(borderMode=borderMode, borderValue=(0, 0, 0))
    return cv2.warpAffine(img, M[:2, :], **warp_kwargs)
| |
|
| |
|
def prepare_paste_back(mask_crop, crop_M_c2o, dsize):
    """Prepare the blending mask used later by ``paste_back``.

    Args:
        mask_crop: Mask image to warp.
        crop_M_c2o: Transformation matrix handed to ``_transform_img``
            (presumably crop -> original image; verify against the cropper).
        dsize: Target size ``(width, height)``.

    Returns:
        The warped mask as float32, divided by 255.
    """
    mask_ori = _transform_img(mask_crop, crop_M_c2o, dsize)
    return mask_ori.astype(np.float32) / 255.
| |
|
| |
|
def paste_back(img_crop, M_c2o, img_ori, mask_ori):
    """Warp ``img_crop`` to the size of ``img_ori`` and blend it in.

    Args:
        img_crop: Image to paste.
        M_c2o: Transformation matrix handed to ``_transform_img``.
        img_ori: Original image; defines the output size.
        mask_ori: Blending weights (1 keeps the warped crop, 0 keeps
            ``img_ori``).

    Returns:
        The blended uint8 image, clipped to [0, 255].
    """
    dsize = (img_ori.shape[1], img_ori.shape[0])
    warped = _transform_img(img_crop, M_c2o, dsize=dsize)
    blended = mask_ori * warped + (1 - mask_ori) * img_ori
    return np.clip(blended, 0, 255).astype(np.uint8)
| |
|
| |
|
def prefix(filename):
    """Strip the extension: a.jpg -> a.

    Only the last extension is removed (``a.tar.gz`` -> ``a.tar``). A dot
    inside a directory name is not treated as an extension separator, so
    ``dir.v2/file`` is returned unchanged.
    """
    return osp.splitext(filename)[0]
| |
|
| |
|
def basename(filename):
    """Return the file name without directory and extension: a/b/c.jpg -> c."""
    return prefix(osp.basename(filename))
| |
|
| |
|
def mkdir(d, log=False):
    """Create directory ``d`` (including parents) if it does not exist.

    Args:
        d: Directory path.
        log: When True, log a message if the directory had to be created.

    Returns:
        ``d`` itself, so the call can be used inline.
    """
    if not osp.exists(d):
        os.makedirs(d, exist_ok=True)
        if log:
            logger.info(f"Make dir: {d}")
    return d
| |
|
| |
|
def dct2device(dct: dict, device):
    """Turn every value of ``dct`` into a tensor on ``device``, in place.

    Existing tensors are moved with ``.to(device)``; every other value is
    first wrapped with ``torch.tensor``.

    Returns:
        The same dict object that was passed in.
    """
    for key, value in dct.items():
        if not isinstance(value, torch.Tensor):
            value = torch.tensor(value)
        # Only values are replaced, so mutating during iteration is safe.
        dct[key] = value.to(device)
    return dct
| |
|
| |
|
# Used for the degree -> radian conversion in get_rotation_matrix().
PI = np.pi
| |
|
def headpose_pred_to_degree(pred):
    """Convert a 66-bin head pose prediction into degrees.

    pred: (bs, 66) or (bs, 1) or others

    For a ``(bs, 66)`` input, a softmax is taken over the bins and the
    expected bin index is mapped to degrees as ``index * 3 - 97.5``,
    giving a ``(bs,)`` tensor. Any other input is returned unchanged.
    """
    if pred.ndim > 1 and pred.shape[1] == 66:
        # Bin indices 0..65, created directly on the prediction's device.
        idx_tensor = torch.arange(66, dtype=torch.float32, device=pred.device)
        probs = F.softmax(pred, dim=1)
        return torch.sum(probs * idx_tensor, dim=1) * 3 - 97.5

    return pred
| |
|
| |
|
def get_rotation_matrix(pitch_, yaw_, roll_):
    """Build batched 3x3 rotation matrices from Euler angles.

    The input is in degree. Each angle is a tensor of shape ``(bs,)`` or
    ``(bs, 1)``.

    Returns:
        Tensor of shape ``(bs, 3, 3)``: the transpose of
        ``rot_z(roll) @ rot_y(yaw) @ rot_x(pitch)``.
    """
    def _to_radian_column(angle):
        # Degrees -> radians, as a (bs, 1) column.
        rad = angle / 180 * PI
        return rad.unsqueeze(1) if rad.ndim == 1 else rad

    x = _to_radian_column(pitch_)
    y = _to_radian_column(yaw_)
    z = _to_radian_column(roll_)

    bs = x.shape[0]
    device = x.device
    ones = torch.ones([bs, 1], device=device)
    zeros = torch.zeros([bs, 1], device=device)

    rot_x = torch.cat([
        ones, zeros, zeros,
        zeros, torch.cos(x), -torch.sin(x),
        zeros, torch.sin(x), torch.cos(x)
    ], dim=1).reshape([bs, 3, 3])

    rot_y = torch.cat([
        torch.cos(y), zeros, torch.sin(y),
        zeros, ones, zeros,
        -torch.sin(y), zeros, torch.cos(y)
    ], dim=1).reshape([bs, 3, 3])

    rot_z = torch.cat([
        torch.cos(z), -torch.sin(z), zeros,
        torch.sin(z), torch.cos(z), zeros,
        zeros, zeros, ones
    ], dim=1).reshape([bs, 3, 3])

    rot = rot_z @ rot_y @ rot_x
    # Transposed so callers can right-multiply row-vector keypoints (kp @ R).
    return rot.permute(0, 2, 1)
| |
|
| |
|
def make_abs_path(fn):
    """Resolve ``fn`` relative to the directory that contains this file."""
    module_dir = osp.dirname(osp.realpath(__file__))
    return osp.join(module_dir, fn)
| |
|
| |
|
def load_image_rgb(image_path: str):
    """Load an image from disk as a 3-channel RGB array.

    Raises:
        FileNotFoundError: If ``image_path`` does not exist.
        ValueError: If the file exists but OpenCV cannot decode it.
    """
    if not osp.exists(image_path):
        raise FileNotFoundError(f"Image not found: {image_path}")
    img = cv2.imread(image_path, cv2.IMREAD_COLOR)
    if img is None:
        # cv2.imread signals a decode failure by returning None.
        raise ValueError(f"Failed to decode image: {image_path}")
    return cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
| |
|
| |
|
def resize_to_limit(img: np.ndarray, max_dim=1920, division=2):
    """
    Adjust the size of the image so that the maximum dimension does not exceed
    max_dim, and the width and the height of the image are multiples of division.

    :param img: the image to be processed.
    :param max_dim: the maximum dimension constraint; values <= 0 disable resizing.
    :param division: the number the width and height must be multiples of
        (values below 1 are treated as 1).
    :return: the adjusted image. The input is returned as-is when no resize or
        crop is needed, or when cropping would leave an empty image.
    """
    h, w = img.shape[:2]

    # Downscale (keeping the aspect ratio) when the longer side is too large.
    if max_dim > 0 and max(h, w) > max_dim:
        if h > w:
            new_h = max_dim
            new_w = max(int(w * (max_dim / h)), 1)  # never collapse to 0 px
        else:
            new_w = max_dim
            new_h = max(int(h * (max_dim / w)), 1)  # never collapse to 0 px
        img = cv2.resize(img, (new_w, new_h))

    # Crop the bottom/right edges so both sides are multiples of `division`.
    division = max(division, 1)
    new_h = img.shape[0] - (img.shape[0] % division)
    new_w = img.shape[1] - (img.shape[1] % division)

    if new_h == 0 or new_w == 0:
        # The image is smaller than `division`; cropping would empty it.
        return img

    if new_h != img.shape[0] or new_w != img.shape[1]:
        img = img[:new_h, :new_w]

    return img
| |
|
| |
|
def preprocess(input_data):
    """Load the image at path ``input_data`` as RGB and apply ``resize_to_limit``.

    Returns:
        A one-element list holding the processed image.
    """
    img_rgb = resize_to_limit(load_image_rgb(input_data))
    return [img_rgb]
| |
|
| |
|
def postprocess(output_data):
    """Return the model output unchanged (placeholder hook)."""
    return output_data
| |
|
| |
|
def infer(model, input_data):
    """Run ``model`` on the image at path ``input_data``.

    The first input and the first output of the session are used.

    NOTE(review): ``preprocess`` returns a Python list holding one image,
    and that list is fed to the session as-is. Confirm the runtime accepts
    it before relying on this helper.
    """
    input_name = model.get_inputs()[0].name
    output_name = model.get_outputs()[0].name
    model_input = preprocess(input_data)
    result = model.run([output_name], {input_name: model_input})
    return postprocess(result)
| |
|
| |
|
def partial_fields(target_class, kwargs):
    """Instantiate ``target_class`` with the entries of ``kwargs`` it knows.

    Only keys for which ``hasattr(target_class, key)`` is true are passed
    on as keyword arguments; all other entries are ignored.
    """
    accepted = {k: v for k, v in kwargs.items() if hasattr(target_class, k)}
    return target_class(**accepted)
| |
|
| |
|
def calc_ratio(lmk_lst):
    """Compute the eye and lip ratios for every landmark set in ``lmk_lst``.

    Each landmark set gets a leading batch axis (``lmk[None]``) before being
    handed to the ratio helpers.

    Returns:
        Tuple ``(eye_ratio_list, lip_ratio_list)`` with one entry per
        landmark set.
    """
    input_eye_ratio_lst = [calc_eye_close_ratio(lmk[None]) for lmk in lmk_lst]
    input_lip_ratio_lst = [calc_lip_close_ratio(lmk[None]) for lmk in lmk_lst]
    return input_eye_ratio_lst, input_lip_ratio_lst
| |
|
| |
|
def prepare_videos(imgs) -> torch.Tensor:
    """Construct the standard model input from a batch of frames.

    Args:
        imgs: Either a list of T frames of shape HxWx3 (a trailing axis is
            appended, giving TxHxWx3x1), or an ndarray that is already laid
            out as TxHxWx3x1.

    Returns:
        CPU tensor of shape Tx1x3xHxW with the values divided by 255 and
        clipped to [0, 1].

    Raises:
        ValueError: If ``imgs`` is neither a list nor an ndarray.
    """
    device = "cpu"
    if isinstance(imgs, list):
        _imgs = np.array(imgs)[..., np.newaxis]  # TxHxWx3x1
    elif isinstance(imgs, np.ndarray):
        _imgs = imgs
    else:
        raise ValueError(f'imgs type error: {type(imgs)}')

    y = np.clip(_imgs.astype(np.float32) / 255., 0, 1)
    y = torch.from_numpy(y).permute(0, 4, 3, 1, 2)  # TxHxWx3x1 -> Tx1x3xHxW
    return y.to(device)
| |
|
| |
|
def get_kp_info(x: torch.Tensor) -> dict:
    """Get the implicit keypoint information from the motion extractor.

    Args:
        x: Bx3xHxW, normalized to 0~1 (CPU tensor; it is converted with
            ``.numpy()``).

    Returns:
        A dict with keys 'pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp',
        filled from the first seven session outputs in that order. The three
        angles are converted to degrees with a trailing axis added, and 'kp'
        and 'exp' are reshaped to (B, -1, 3).
    """
    outs = motion_extractor.run(None, input_feed={"input": x.numpy()})

    keys = ('pitch', 'yaw', 'roll', 't', 'exp', 'scale', 'kp')
    kp_info = {key: torch.from_numpy(outs[idx]) for idx, key in enumerate(keys)}

    bs = kp_info['kp'].shape[0]
    for angle in ('pitch', 'yaw', 'roll'):
        kp_info[angle] = headpose_pred_to_degree(kp_info[angle])[:, None]
    for name in ('kp', 'exp'):
        kp_info[name] = kp_info[name].reshape(bs, -1, 3)

    return kp_info
| |
|
| |
|
def transform_keypoint(kp_info: dict):
    """
    Transform the implicit keypoints with the pose, shift, and expression deformation.

    Computes ``scale * (kp @ R + exp)`` and then adds the x/y components of
    ``t``, where ``R`` comes from ``get_rotation_matrix(pitch, yaw, roll)``.

    kp_info: dict with keys 'kp' (BxNx3 or Bx(3N)), 'pitch', 'yaw', 'roll',
        't', 'exp', 'scale'.
    Return: BxNx3
    """
    kp = kp_info['kp']
    t, exp = kp_info['t'], kp_info['exp']
    scale = kp_info['scale']

    # No-op for angles that are already in degrees.
    pitch = headpose_pred_to_degree(kp_info['pitch'])
    yaw = headpose_pred_to_degree(kp_info['yaw'])
    roll = headpose_pred_to_degree(kp_info['roll'])

    bs = kp.shape[0]
    num_kp = kp.shape[1] // 3 if kp.ndim == 2 else kp.shape[1]

    rot_mat = get_rotation_matrix(pitch, yaw, roll)

    # reshape (unlike view) also accepts non-contiguous tensors.
    kp_transformed = kp.reshape(bs, num_kp, 3) @ rot_mat + exp.reshape(bs, num_kp, 3)
    kp_transformed *= scale[..., None]
    # Only the x/y translation is applied; the z component of t is ignored.
    kp_transformed[:, :, 0:2] += t[:, None, 0:2]

    return kp_transformed
| |
|
| |
|
def make_motion_template(I_lst, c_eyes_lst, c_lip_lst, **kwargs):
    """Extract per-frame motion information from the driving frames.

    Args:
        I_lst: Tensor of driving frames; ``I_lst[i]`` is fed to
            ``get_kp_info``.
        c_eyes_lst: Per-frame eye ratios (arrays).
        c_lip_lst: Per-frame lip ratios (arrays).
        **kwargs: ``output_fps`` (default 25) is stored in the template.

    Returns:
        Dict with 'n_frames', 'output_fps', 'motion' (one dict of float32
        arrays per frame with keys 'scale', 'R', 'exp', 't', 'kp', 'x_s'),
        'c_eyes_lst' and 'c_lip_lst'.
    """
    def _as_float32(tensor):
        return tensor.cpu().numpy().astype(np.float32)

    n_frames = I_lst.shape[0]
    template_dct = {
        'n_frames': n_frames,
        'output_fps': kwargs.get('output_fps', 25),
        'motion': [],
        'c_eyes_lst': [],
        'c_lip_lst': [],
    }

    for i in range(n_frames):
        x_i_info = get_kp_info(I_lst[i])
        x_s = transform_keypoint(x_i_info)
        R_i = get_rotation_matrix(x_i_info['pitch'], x_i_info['yaw'], x_i_info['roll'])

        template_dct['motion'].append({
            'scale': _as_float32(x_i_info['scale']),
            'R': _as_float32(R_i),
            'exp': _as_float32(x_i_info['exp']),
            't': _as_float32(x_i_info['t']),
            'kp': _as_float32(x_i_info['kp']),
            'x_s': _as_float32(x_s),
        })
        template_dct['c_eyes_lst'].append(c_eyes_lst[i].astype(np.float32))
        template_dct['c_lip_lst'].append(c_lip_lst[i].astype(np.float32))

    return template_dct
| |
|
| |
|
def prepare_source(img: np.ndarray) -> torch.Tensor:
    """Construct the standard model input from a source image.

    Args:
        img: HxWx3 image, or a batch BxHxWx3. The input is not modified.

    Returns:
        CPU tensor of shape Bx3xHxW (B = 1 for a single image) with the
        values divided by 255 and clipped to [0, 1].

    Raises:
        ValueError: If ``img`` has neither 3 nor 4 dimensions.
    """
    device = "cpu"

    if img.ndim == 3:
        x = img[np.newaxis]
    elif img.ndim == 4:
        x = img
    else:
        raise ValueError(f'img ndim should be 3 or 4: {img.ndim}')

    # astype() always copies, so the caller's array is left untouched.
    x = np.clip(x.astype(np.float32) / 255., 0, 1)
    x = torch.from_numpy(x).permute(0, 3, 1, 2)  # BxHxWx3 -> Bx3xHxW
    return x.to(device)
| |
|
| |
|
def extract_feature_3d(x: torch.Tensor) -> torch.Tensor:
    """Get the appearance feature of the image from the appearance feature extractor.

    x: Bx3xHxW, normalized to 0~1 (CPU tensor; it is converted with ``.numpy()``)
    Return: the first session output as a tensor
    """
    feature = appearance_feature_extractor.run(None, input_feed={"input": x.numpy()})[0]
    return torch.from_numpy(feature)
| |
|
| |
|
def stitch(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """
    Run the stitching/retargeting session on the concatenated keypoints.

    kp_source: BxNx3
    kp_driving: BxNx3
    Return: Bx(3*num_kp+2)
    """
    feat_stiching = concat_feat(kp_source, kp_driving)
    delta = stitching_retargeting_module.run(None, input_feed={"input": feat_stiching.numpy()})[0]
    return torch.from_numpy(delta)
| |
|
| |
|
def stitching(kp_source: torch.Tensor, kp_driving: torch.Tensor) -> torch.Tensor:
    """Conduct the stitching.

    The delta predicted by ``stitch`` is split into a per-keypoint offset
    (first ``3 * num_kp`` values) and an x/y shift (next two values); both
    are added to a copy of ``kp_driving``.

    kp_source: Bxnum_kpx3
    kp_driving: Bxnum_kpx3 (not modified)
    Return: Bxnum_kpx3
    """
    bs, num_kp = kp_source.shape[:2]

    kp_driving_new = kp_driving.clone()
    delta = stitch(kp_source, kp_driving_new)

    n_exp = 3 * num_kp
    delta_exp = delta[..., :n_exp].reshape(bs, num_kp, 3)
    delta_tx_ty = delta[..., n_exp:n_exp + 2].reshape(bs, 1, 2)

    kp_driving_new += delta_exp
    kp_driving_new[..., :2] += delta_tx_ty

    return kp_driving_new
| |
|
| |
|
def warp_decode(feature_3d: torch.Tensor, kp_source: torch.Tensor, kp_driving: torch.Tensor) -> dict:
    """Get the image after the warping of the implicit keypoints.

    The warping session is run first (its time is logged at debug level),
    and its third output is then decoded by the SPADE generator session.

    feature_3d: Bx32x16x64x64, feature volume
    kp_source: BxNx3
    kp_driving: BxNx3
    Return: dict with key 'out' holding the generator output as a tensor
    """
    warp_timer = Timer()
    warp_timer.tic()
    warp_inputs = {
        "feature_3d": feature_3d.numpy(),
        "kp_driving": kp_driving.numpy(),
        "kp_source": kp_source.numpy(),
    }
    # An empty output-name list requests all outputs; index 2 is the one
    # consumed by the generator.
    warped = warping_module.run([], warp_inputs)[2]
    warp_timer.toc()
    logger.debug(f'warp time: {warp_timer.diff:.3f}s')

    generated = spade_generator.run(None, input_feed={"input": warped})[0]

    return {'out': torch.from_numpy(generated)}
| |
|
| |
|
def parse_output(out: torch.Tensor) -> np.ndarray:
    """Construct the output image(s) from the generator tensor.

    out: Bx3xHxW, values are clipped to [0, 1]
    return: BxHxWx3, uint8 (values scaled to [0, 255], fraction truncated)
    """
    arr = np.transpose(out.detach().cpu().numpy(), [0, 2, 3, 1])
    # After clipping to [0, 1] the scaled values already lie in [0, 255].
    return (np.clip(arr, 0, 1) * 255).astype(np.uint8)
| |
|
| |
|
def load_model(model_type, model_path=None):
    """Create the inference session for one of the pipeline models.

    Args:
        model_type: One of 'appearance_feature_extractor',
            'motion_extractor', 'warping_module', 'spade_generator',
            'stitching_retargeting_module'.
        model_path: Directory that contains the model files.

    Returns:
        An axengine ``InferenceSession`` for the .axmodel files, or an
        onnxruntime session (CPU provider) for the warping module.

    Raises:
        ValueError: If ``model_type`` is not one of the known names.
    """
    model_files = {
        'appearance_feature_extractor': 'feature_extractor.axmodel',
        'motion_extractor': 'motion_extractor.axmodel',
        'warping_module': 'warp.onnx',
        'spade_generator': 'spade_generator.axmodel',
        'stitching_retargeting_module': 'stitching_retargeting.axmodel',
    }
    if model_type not in model_files:
        raise ValueError(f"Unknown model type: {model_type}")

    model_file = f"{model_path}/{model_files[model_type]}"
    if model_type == 'warping_module':
        # The warping module is the only model run through onnxruntime.
        return ort.InferenceSession(model_file, providers=["CPUExecutionProvider"])
    return InferenceSession(model_file)
| |
|
| |
|
def _add_audio_to_video(silent_video_path, audio_video_path, output_video_path):
    """Mux the audio of ``audio_video_path`` onto ``silent_video_path`` with ffmpeg.

    Returns:
        True when ffmpeg wrote ``output_video_path``, False otherwise.
    """
    cmd = [
        'ffmpeg', '-y',
        '-i', silent_video_path,
        '-i', audio_video_path,
        '-map', '0:v',
        '-map', '1:a',
        '-c:v', 'copy',
        '-shortest',
        output_video_path,
    ]
    try:
        subprocess.run(cmd, capture_output=True, check=True)
    except (OSError, subprocess.CalledProcessError) as e:
        logger.warning(f"Failed to add audio to {silent_video_path}: {e}")
        return False
    return True


def main():
    """Animate the source portrait with a driving image or video.

    Steps: load the five models, read source and driving inputs, build the
    driving motion template, crop the source face, generate one output frame
    per driving frame, and write the results to ``--output-dir``. A driving
    video yields two .mp4 files (pasted-back result and side-by-side
    comparison, with the driving audio when available). A driving image
    yields two .jpg files.
    """
    args = parse_args()

    global appearance_feature_extractor, motion_extractor, warping_module
    global spade_generator, stitching_retargeting_module
    appearance_feature_extractor = load_model("appearance_feature_extractor", args.models)
    motion_extractor = load_model("motion_extractor", args.models)
    warping_module = load_model("warping_module", args.models)
    spade_generator = load_model("spade_generator", args.models)
    stitching_retargeting_module = load_model("stitching_retargeting_module", args.models)

    source = args.source
    driving = args.driving

    # Allow a local ./ffmpeg directory to provide the ffmpeg binaries.
    ffmpeg_dir = os.path.join(os.getcwd(), "ffmpeg")
    if osp.exists(ffmpeg_dir):
        os.environ["PATH"] += (os.pathsep + ffmpeg_dir)

    if not fast_check_ffmpeg():
        raise ImportError(
            "FFmpeg is not installed. Please install FFmpeg (including ffmpeg and ffprobe) before running this script. https://ffmpeg.org/download.html"
        )

    # ---- load inputs ----
    source_rgb_lst = preprocess(source)
    source_rgb = source_rgb_lst[0]
    if is_video(driving):
        flag_is_driving_video = True
        output_fps = int(get_fps(driving))
        driving_rgb_lst = load_video(driving)
    elif is_image(driving):
        flag_is_driving_video = False
        output_fps = 25
        driving_rgb_lst = [load_image_rgb(driving)]
    else:
        raise Exception(f"{args.driving} is not a supported type!")

    n_frames = len(driving_rgb_lst)
    if n_frames == 0:
        raise Exception(f"No frames could be read from {args.driving}!")

    # ---- driving motion template ----
    cropper: Cropper = Cropper()
    logger.info("Start making driving motion template...")
    driving_lmk_crop_lst = cropper.calc_lmks_from_cropped_video(driving_rgb_lst)
    driving_rgb_crop_256x256_lst = [cv2.resize(frame, (256, 256)) for frame in driving_rgb_lst]

    c_d_eyes_lst, c_d_lip_lst = calc_ratio(driving_lmk_crop_lst)
    I_d_lst = prepare_videos(driving_rgb_crop_256x256_lst)
    driving_template_dct = make_motion_template(I_d_lst, c_d_eyes_lst, c_d_lip_lst, output_fps=output_fps)

    # ---- source preparation ----
    crop_info = cropper.crop_source_image(source_rgb)
    if crop_info is None:
        raise Exception("No face detected in the source image!")
    img_crop_256x256 = crop_info['img_crop_256x256']

    I_s = prepare_source(img_crop_256x256)
    x_s_info = get_kp_info(I_s)
    x_c_s = x_s_info['kp']
    R_s = get_rotation_matrix(x_s_info['pitch'], x_s_info['yaw'], x_s_info['roll'])
    f_s = extract_feature_3d(I_s)
    x_s = transform_keypoint(x_s_info)

    mask_path = make_abs_path('./utils/resources/mask_template.png')
    mask_crop: ndarray = cv2.imread(mask_path, cv2.IMREAD_COLOR)
    if mask_crop is None:
        raise FileNotFoundError(f"Mask template not found or unreadable: {mask_path}")
    mask_ori_float = prepare_paste_back(
        mask_crop, crop_info['M_c2o'], dsize=(source_rgb.shape[1], source_rgb.shape[0])
    )
    logger.info("Prepared pasteback mask done.")

    # NOTE: pickle must only be used on trusted files; this one is a resource
    # shipped next to the script.
    with open(make_abs_path('./utils/resources/lip_array.pkl'), 'rb') as f:
        lip_array = pkl.load(f)
    device = "cpu"

    if flag_is_driving_video:
        logger.info(f"The animated video consists of {n_frames} frames.")
        lip_reference = None
    else:
        logger.info(f"The output of image-driven portrait animation is an image.")
        # For an image-driven run, the expression delta is taken relative to
        # this array instead of the first driving frame.
        lip_reference = torch.from_numpy(lip_array).to(dtype=torch.float32, device=device)

    # ---- animation loop ----
    I_p_lst = []
    I_p_pstbk_lst = []
    R_d_0, x_d_0_info = None, None
    for i in range(n_frames):
        x_d_i_info = dct2device(driving_template_dct['motion'][i], device)
        R_d_i = x_d_i_info['R'] if 'R' in x_d_i_info.keys() else x_d_i_info['R_d']

        if i == 0:
            # Cache the first driving frame; later frames are expressed
            # relative to it.
            R_d_0 = R_d_i
            x_d_0_info = x_d_i_info.copy()

        R_new = (R_d_i @ R_d_0.permute(0, 2, 1)) @ R_s
        if flag_is_driving_video:
            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - x_d_0_info['exp'])
        else:
            delta_new = x_s_info['exp'] + (x_d_i_info['exp'] - lip_reference)

        scale_new = x_s_info['scale'] * (x_d_i_info['scale'] / x_d_0_info['scale'])
        t_new = x_s_info['t'] + (x_d_i_info['t'] - x_d_0_info['t'])
        t_new[..., 2].fill_(0)  # zero the z translation
        x_d_i_new = scale_new * (x_c_s @ R_new + delta_new) + t_new

        # NOTE(review): SOURCE lost its indentation. All four statements are
        # kept inside this `if`, the only reading under which an image-driven
        # run does not hit an undefined `x_d_0_new`. Confirm against the
        # original script.
        if i == 0 and flag_is_driving_video:
            x_d_0_new = x_d_i_new
            motion_multiplier = calc_motion_multiplier(x_s, x_d_0_new)
            x_d_diff = (x_d_i_new - x_d_0_new) * motion_multiplier
            x_d_i_new = x_d_diff + x_s

        x_d_i_new = stitching(x_s, x_d_i_new)
        x_d_i_new = x_s + (x_d_i_new - x_s) * 1.0
        out = warp_decode(f_s, x_s, x_d_i_new)
        I_p_i = parse_output(out['out'])[0]
        I_p_lst.append(I_p_i)
        I_p_pstbk_lst.append(paste_back(I_p_i, crop_info['M_c2o'], source_rgb, mask_ori_float))

    # ---- write results ----
    mkdir(args.output_dir)
    frames_concatenated = concat_frames(driving_rgb_crop_256x256_lst, [img_crop_256x256], I_p_lst)
    out_stem = f'{basename(source)}--{basename(driving)}'

    if flag_is_driving_video:
        # The source is always a still image here, so the driving video is
        # the only possible audio source.
        flag_driving_has_audio = has_audio_stream(driving)

        wfp_concat = osp.join(args.output_dir, f'{out_stem}_concat.mp4')
        images2video(frames_concatenated, wfp=wfp_concat, fps=output_fps)

        if flag_driving_has_audio:
            wfp_concat_with_audio = osp.join(args.output_dir, f'{out_stem}_concat_with_audio.mp4')
            logger.info(f"Audio is selected from {driving}, concat mode")
            if _add_audio_to_video(wfp_concat, driving, wfp_concat_with_audio):
                os.replace(wfp_concat_with_audio, wfp_concat)
                logger.info(f"Replace {wfp_concat_with_audio} with {wfp_concat}")

        wfp = osp.join(args.output_dir, f'{out_stem}.mp4')
        images2video(I_p_pstbk_lst, wfp=wfp, fps=output_fps)

        if flag_driving_has_audio:
            wfp_with_audio = osp.join(args.output_dir, f'{out_stem}_with_audio.mp4')
            logger.info(f"Audio is selected from {driving}")
            if _add_audio_to_video(wfp, driving, wfp_with_audio):
                os.replace(wfp_with_audio, wfp)
                logger.info(f"Replace {wfp_with_audio} with {wfp}")

        logger.info(f'Animated video: {wfp}')
        logger.info(f'Animated video with concat: {wfp_concat}')
    else:
        # cv2.imwrite expects BGR, so the RGB frames are flipped on write.
        wfp_concat = osp.join(args.output_dir, f'{out_stem}_concat.jpg')
        cv2.imwrite(wfp_concat, frames_concatenated[0][..., ::-1])
        wfp = osp.join(args.output_dir, f'{out_stem}.jpg')
        cv2.imwrite(wfp, I_p_pstbk_lst[0][..., ::-1])

        logger.info(f'Animated image: {wfp}')
        logger.info(f'Animated image with concat: {wfp_concat}')
| |
|
| |
|
if __name__ == "__main__":
    # Usage:
    # python3 infer.py --source ../assets/examples/source/s0.jpg --driving ../assets/examples/driving/d8.jpg --models ./axmdoels --output-dir ./axmodel_infer
    timer = Timer()
    timer.tic()
    main()
    elapse = timer.toc()
    logger.debug(f'LivePortrait axmodel infer time: {elapse:.3f}s')
| |
|