| from copy import deepcopy | |
| from typing import Dict, Any | |
def layout_to_in_out_slice(layout, in_len, out_len=None):
    r"""Build per-axis slices splitting the time axis into input and output.

    Every axis is selected whole except the time axis (the position of "T"
    in ``layout``), which is cut at ``in_len``: the first ``in_len`` time
    steps go to the input, the following ones to the output.

    Parameters
    ----------
    layout: str
        Axis layout string that must contain "T", e.g., "NTHWC".
    in_len: int
        Number of leading time steps selected for the input.
    out_len: int or None
        Number of time steps selected for the output, starting right after
        the input. ``None`` selects all remaining time steps.

    Returns
    -------
    in_slice: list of slice
        One slice per axis of ``layout`` selecting the input steps.
    out_slice: list of slice
        One slice per axis of ``layout`` selecting the output steps.

    Raises
    ------
    ValueError
        If ``layout`` does not contain a "T" axis.
    """
    t_axis = layout.find("T")
    if t_axis < 0:
        # str.find returns -1 when "T" is absent; using -1 as an index would
        # silently cut the last axis instead of the time axis.
        raise ValueError(f"layout {layout!r} has no time axis 'T'")
    num_axes = len(layout)
    in_slice = [slice(None, None)] * num_axes
    # slice objects are immutable, so a shallow copy gives an independent list.
    out_slice = list(in_slice)
    in_slice[t_axis] = slice(None, in_len)
    out_end = None if out_len is None else in_len + out_len
    out_slice[t_axis] = slice(in_len, out_end)
    return in_slice, out_slice
def step_layout_to_in_out_slice(
    layout,
    in_len, in_step: int = 1,
    out_len=None, out_step: int = 1,
    in_out_diff: int = 1,
):
    r"""Build per-axis slices selecting strided input and output time steps.

    Every axis is selected whole except the time axis (the position of "T"
    in ``layout``). On that axis the input takes ``in_len`` steps at
    indices ``0, in_step, ..., (in_len - 1) * in_step``; the output starts
    ``in_out_diff`` indices after the last input step and advances by
    ``out_step``.

    Parameters
    ----------
    layout: str
        Axis layout string that must contain "T", e.g., "NTHWC".
    in_len: int
        Number of input time steps.
    in_step: int
        Stride between consecutive input steps; must be >= 1.
    out_len: int or None
        Number of output time steps. ``None`` selects every remaining
        strided step.
    out_step: int
        Stride between consecutive output steps; must be >= 1.
    in_out_diff: int
        Index distance from the last input step to the first output step.

    Returns
    -------
    in_slice: list of slice
        One slice per axis of ``layout`` selecting the input steps.
    out_slice: list of slice
        One slice per axis of ``layout`` selecting the output steps.

    Raises
    ------
    ValueError
        If ``layout`` has no "T" axis, or ``in_step``/``out_step`` < 1.
    """
    t_axis = layout.find("T")
    if t_axis < 0:
        # str.find returns -1 when "T" is absent; using -1 as an index would
        # silently cut the last axis instead of the time axis.
        raise ValueError(f"layout {layout!r} has no time axis 'T'")
    if in_step < 1 or out_step < 1:
        raise ValueError(
            f"in_step and out_step must be >= 1, got {in_step} and {out_step}"
        )
    num_axes = len(layout)
    in_slice = [slice(None, None)] * num_axes
    # slice objects are immutable, so a shallow copy gives an independent list.
    out_slice = list(in_slice)
    # Stop at in_len * in_step so exactly in_len strided steps are selected.
    in_slice[t_axis] = slice(None, in_len * in_step, in_step)
    # Last input step is at (in_len - 1) * in_step; output begins
    # in_out_diff indices after it.
    out_start = (in_len - 1) * in_step + in_out_diff
    out_end = None if out_len is None else out_start + out_len * out_step
    out_slice[t_axis] = slice(out_start, out_end, out_step)
    return in_slice, out_slice
def parse_layout_shape(layout: str) -> Dict[str, Any]:
    r"""Report the position of each named axis in a layout string.

    Parameters
    ----------
    layout: str
        e.g., "NTHWC", "NHWC".

    Returns
    -------
    ret: Dict
        Maps "batch_axis", "t_axis", "h_axis", "w_axis" and "c_axis" to the
        index of the first "N", "T", "H", "W" and "C" in ``layout``,
        respectively. A letter that does not occur yields -1, as returned
        by ``str.find``.
    """
    # Ordered (result key, layout letter) pairs; dict insertion order
    # follows this tuple.
    key_letter_pairs = (
        ("batch_axis", "N"),
        ("t_axis", "T"),
        ("h_axis", "H"),
        ("w_axis", "W"),
        ("c_axis", "C"),
    )
    return {key: layout.find(letter) for key, letter in key_letter_pairs}