File size: 1,590 Bytes
7667a87 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 |
from copy import deepcopy
from typing import Dict, Any
def layout_to_in_out_slice(layout, in_len, out_len=None):
    """Build per-axis slices that split an array along its ``"T"`` axis.

    The first ``in_len`` entries along ``"T"`` form the input slice. The
    entries immediately after them form the output slice.

    Parameters
    ----------
    layout : str
        Axis layout string, e.g. ``"NTHWC"``; must contain ``"T"``.
    in_len : int
        Number of leading entries along ``"T"`` selected by the input slice.
    out_len : int or None, optional
        Number of entries along ``"T"`` (starting at index ``in_len``) selected
        by the output slice. ``None`` selects everything from ``in_len`` on.

    Returns
    -------
    in_slice, out_slice : list of slice
        One ``slice`` per character of ``layout``. Every axis other than
        ``"T"`` is the full range ``slice(None, None)``.

    Raises
    ------
    ValueError
        If ``layout`` contains no ``"T"`` axis.
    """
    t_axis = layout.find("T")
    if t_axis < 0:
        # str.find returns -1 when "T" is absent; using -1 as an index would
        # silently slice the last axis instead of the time axis.
        raise ValueError(f"layout {layout!r} has no 'T' axis")
    num_axes = len(layout)
    in_slice = [slice(None, None)] * num_axes
    # slice objects are immutable, so a shallow copy is enough.
    out_slice = list(in_slice)
    in_slice[t_axis] = slice(None, in_len)
    if out_len is None:
        out_slice[t_axis] = slice(in_len, None)
    else:
        out_slice[t_axis] = slice(in_len, in_len + out_len)
    return in_slice, out_slice
def step_layout_to_in_out_slice(
    layout,
    in_len, in_step: int = 1,
    out_len=None, out_step: int = 1,
    in_out_diff: int = 1,
):
    """Build strided per-axis slices that split an array along its ``"T"`` axis.

    The input slice takes ``in_len`` entries along ``"T"``, spaced ``in_step``
    apart, starting at index 0. The output slice starts ``in_out_diff``
    indices after the last input entry and is spaced ``out_step`` apart.

    Parameters
    ----------
    layout : str
        Axis layout string, e.g. ``"NTHWC"``; must contain ``"T"``.
    in_len : int
        Number of entries along ``"T"`` selected by the input slice.
    in_step : int, optional
        Stride of the input slice along ``"T"`` (default 1).
    out_len : int or None, optional
        Number of entries selected by the output slice. ``None`` selects every
        ``out_step``-th entry to the end.
    out_step : int, optional
        Stride of the output slice along ``"T"`` (default 1).
    in_out_diff : int, optional
        Index distance from the last input entry to the first output entry
        (default 1, i.e. the output starts right after the last input).

    Returns
    -------
    in_slice, out_slice : list of slice
        One ``slice`` per character of ``layout``. Every axis other than
        ``"T"`` is the full range ``slice(None, None)``.

    Raises
    ------
    ValueError
        If ``layout`` contains no ``"T"`` axis.
    """
    t_axis = layout.find("T")
    if t_axis < 0:
        # str.find returns -1 when "T" is absent; using -1 as an index would
        # silently slice the last axis instead of the time axis.
        raise ValueError(f"layout {layout!r} has no 'T' axis")
    num_axes = len(layout)
    in_slice = [slice(None, None)] * num_axes
    # slice objects are immutable, so a shallow copy is enough.
    out_slice = list(in_slice)
    in_slice[t_axis] = slice(None, in_len * in_step, in_step)
    # The last input entry is at (in_len - 1) * in_step; the output begins
    # in_out_diff indices after it.
    out_start = in_len * in_step + in_out_diff - in_step
    if out_len is None:
        out_slice[t_axis] = slice(out_start, None, out_step)
    else:
        out_slice[t_axis] = slice(out_start, out_start + out_len * out_step, out_step)
    return in_slice, out_slice
def parse_layout_shape(layout: str) -> Dict[str, Any]:
    r"""Report where each known axis letter appears in a layout string.

    Parameters
    ----------
    layout: str
        Axis layout string, e.g. "NTHWC", "NHWC".

    Returns
    -------
    ret: Dict
        Maps "batch_axis", "t_axis", "h_axis", "w_axis" and "c_axis" (in that
        order) to the index of the first "N", "T", "H", "W" and "C" in
        ``layout``, respectively. An axis letter that does not occur maps
        to -1 (the ``str.find`` convention).
    """
    key_to_letter = (
        ("batch_axis", "N"),
        ("t_axis", "T"),
        ("h_axis", "H"),
        ("w_axis", "W"),
        ("c_axis", "C"),
    )
    return {key: layout.find(letter) for key, letter in key_to_letter}
|