Spaces:
Runtime error
Runtime error
| import time | |
| import atexit | |
| import inspect | |
| import torch | |
| from typing import Optional, Union, List | |
| from pathlib import Path | |
| from concurrent.futures import ThreadPoolExecutor | |
def fold_path(fn:str):
    '''
    Fold a path like `from/to/file.py` to the abbreviated `f/t/file.py`.

    Every directory component is shortened to its first character, the last
    component (the file name) is kept untouched. A leading `/` is preserved
    because the empty first component stays empty after shortening.

    A bare file name without any directory part is returned unchanged, i.e.
    `file.py` stays `file.py` instead of becoming `/file.py`.
    '''
    *folders, file_name = fn.split('/')
    # Joining everything in one go only emits a separator when there is a folder part.
    return '/'.join([folder[:1] for folder in folders] + [file_name])
def summary_frame_info(frame:inspect.FrameInfo):
    '''
    Render a FrameInfo object as `<function> @ <folded file path>:<line number>`.

    The file path is abbreviated with `fold_path` to keep the summary short.
    '''
    short_path = fold_path(frame.filename)
    func_name = frame.function
    line_no = frame.lineno
    return f'{func_name} @ {short_path}:{line_no}'
class TimeMonitorDisabled:
    '''
    No-op stand-in exposing the same call surface as `TimeMonitor`.

    Every operation silently does nothing, so profiling code can stay in place
    while the monitoring itself is switched off.
    '''

    def foo(self, *args, **kwargs):
        ''' Swallow any arguments and return None. '''
        return None

    def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
        # The arguments are accepted but never used.
        # Each public operation is bound on the instance to the no-op method.
        for method_name in ('tick', 'report', 'clear', 'dump_statistics'):
            setattr(self, method_name, self.foo)

    def __call__(self, *args, **kwargs):
        ''' Naming a block is a no-op, the monitor itself is handed back for the `with` statement. '''
        return self

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returning None never suppresses an exception raised inside the block.
        return None
class TimeMonitor:
    '''
    Wall-clock profiler for (possibly nested) named code blocks.

    It is supposed to be used like this:

        time_monitor = TimeMonitor()
        with time_monitor('test_block', 'Block that does something.') as tm:
            do_something()
        time_monitor.report()

    Calling the monitor registers a block, entering / leaving the `with` statement
    records its start / end timestamps, and `tick()` adds intermediate records.
    When `log_folder` is given, every record is written to `readable.log` and the
    raw statistics are pickled to `statistics.pkl` after each finished block.
    '''

    def __init__(self, log_folder:Optional[Union[str, Path]]=None, record_birth_block:bool=False):
        '''
        ### Args
        - log_folder: folder receiving `readable.log` and `statistics.pkl`; nothing is written when None.
        - record_birth_block: when True (and `log_folder` is given), a block named `monitor_birth` is
          opened right away and closed by the exit hook.
        '''
        if log_folder is not None:
            self.log_folder = Path(log_folder)
            self.log_folder.mkdir(parents=True, exist_ok=True)
            log_fn = self.log_folder / 'readable.log'
            # The log lines carry non-ASCII markers, so do not rely on the platform default encoding.
            self.log_fh = open(log_fn, 'w', encoding='utf-8')  # Log file handler.
            self.log_fh.write('=== New Exp ===\n')
        else:
            self.log_folder = None
        self.clear()
        self.current_block_uid_stack : List = []  # Unique block id stack for recording.
        self.current_block_aid_stack : List = []  # Block id stack for accumulated cost analysis.

        # Specially add a global start and end block.
        self.record_birth_block = record_birth_block and log_folder is not None
        if self.record_birth_block:
            self('monitor_birth', 'Since the monitor is constructed.')
            self.__enter__()

        # Register the exit hook to dump the data safely.
        atexit.register(self._die_hook)

    def __call__(self, block_name:str, block_desc:Optional[str]=None):
        ''' Set up the name of the context for a block. Returns the monitor itself for the `with` statement. '''
        # 1. Format the block name ('/' is reserved as the separator of the tree structure).
        block_name = block_name.replace('/', '-').replace(' ', '-')
        block_name_recursive = '/'.join([s.split('/')[-1] for s in self.current_block_aid_stack] + [block_name])  # Tree structure block name.
        # 2. Get a unique name for the block record.
        block_postfixed = 0
        while f'{block_name_recursive}_{block_postfixed}' in self.block_info:
            block_postfixed += 1
        # 3. Get the caller frame information.
        caller_frame = inspect.stack()[1]
        block_position = summary_frame_info(caller_frame)
        # 4. Initialize the block information.
        self.current_block_uid_stack.append(f'{block_name_recursive}_{block_postfixed}')
        self.current_block_aid_stack.append(block_name)
        self.block_info[self.current_block_uid_stack[-1]] = {
            'records'  : [],
            'position' : block_position,
            'desc'     : block_desc,
        }
        return self

    def __enter__(self):
        ''' Record the start timestamp of the innermost registered block. '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, 'Start of the block.')
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        ''' Record the end of the innermost block, accumulate its cost and dump the statistics. '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, 'End of the block.')
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
        # Finish one block.
        curr_block_uid = self.current_block_uid_stack.pop()
        curr_block_aid = '/'.join(self.current_block_aid_stack)  # Full tree path, built before popping.
        self.current_block_aid_stack.pop()
        self.finished_blocks.append(curr_block_uid)
        records = self.block_info[curr_block_uid]['records']
        elapsed = records[-1]['timestamp'] - records[0]['timestamp']
        self.block_cost[curr_block_aid] = self.block_cost.get(curr_block_aid, 0) + elapsed
        # Dump synchronously. The previous thread-pool submission was awaited before returning
        # anyway, and a thread pool may refuse new work once the interpreter is shutting down,
        # which is exactly when the exit hook closes the birth block.
        self.dump_statistics()
        # Returning None keeps exceptions raised inside the block propagating.

    def tick(self, desc:str=''):
        '''
        Record an intermediate timestamp. These records are only for in-block analysis,
        and will be ignored when analyzing in global view.
        '''
        caller_frame = inspect.stack()[1]
        record = self._tick_record(caller_frame, desc)
        self.block_info[self.current_block_uid_stack[-1]]['records'].append(record)
        return

    def report(self, level:Union[str, List[str]]='global'):
        '''
        Print the collected statistics with `rich`.

        ### Args
        - level: `'block'` prints every record of each finished block, `'global'` prints the
          accumulated cost per block path. A list is printed in the given order, other values are ignored.
        '''
        import rich

        caller_frame = inspect.stack()[1]
        caller_info = summary_frame_info(caller_frame)

        levels = [level] if isinstance(level, str) else level
        for lv in levels:  # To make sure we can output in order.
            if lv == 'block':
                rich.print(f'[bold underline][EA-B][/bold underline] {caller_info} -> blocks level records:')
                for block_name in self.finished_blocks:
                    msg = '\t' + self._generate_block_msg(block_name).replace('\n\t', '\n\t\t')
                    rich.print(msg)
            elif lv == 'global':
                rich.print(f'[bold underline][EA-G][/bold underline] {caller_info} -> global efficiency analysis:')
                for block_name, cost in self.block_cost.items():
                    rich.print(f'\t{block_name}: {cost:.2f} sec')

    def clear(self):
        '''
        Drop all collected statistics.

        NOTE(review): the information of blocks that are still open is dropped as well, so
        presumably this must not be called inside a `with` block -- confirm with callers.
        '''
        self.finished_blocks = []
        self.block_info = {}
        self.block_cost = {}

    def dump_statistics(self):
        ''' Dump the logging raw data for post analysis. Does nothing without a log folder. '''
        if self.log_folder is None:
            return

        dump_fn = self.log_folder / 'statistics.pkl'
        with open(dump_fn, 'wb') as f:
            import pickle
            pickle.dump({
                'finished_blocks' : self.finished_blocks,
                'block_info'      : self.block_info,
                'block_cost'      : self.block_cost,
                'curr_aid_stack'  : self.current_block_aid_stack,  # Nonempty when errors happen inside a block.
            }, f)

        # TODO: Draw a graph to visualize the time consumption.

    def _tick_record(self, caller_frame, desc:Optional[str]=''):
        ''' Build one timestamp record for `caller_frame` and write it to the log file, if any. '''
        # 1. Generate the record.
        # Wait for pending GPU work so that it is attributed to the right interval.
        # Guarded, because synchronizing raises on machines without CUDA.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        timestamp = time.time()
        readable_time = time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(timestamp))
        position = summary_frame_info(caller_frame)
        record = {
            'time'      : readable_time,
            'timestamp' : timestamp,
            'position'  : position,
            'desc'      : desc,
        }

        # 2. Log the record.
        # NOTE(review): the marker glyphs look like mis-decoded emoji; kept verbatim -- confirm the intended characters.
        if self.log_folder is not None:
            block_uid = self.current_block_uid_stack[-1]
            log_msg = f'[{readable_time}] ποΈ {block_uid} π {desc} π {position}'
            self.log_fh.write(log_msg + '\n')

        return record

    def _generate_block_msg(self, block_name):
        ''' Format a block header followed by one line per record, each with the time since the previous record. '''
        block_info = self.block_info[block_name]
        block_position = block_info['position']
        block_desc = block_info['desc']
        records = block_info['records']
        # NOTE(review): the marker glyphs look like mis-decoded emoji; kept verbatim -- confirm the intended characters.
        msg = f'ποΈ {block_name} π {block_desc} π {block_position}'
        for rid, record in enumerate(records):
            readable_time = record['time']
            tick_desc = record['desc']
            tick_position = record['position']
            if rid > 0:
                prev_record = records[rid-1]
                tick_elapsed = record['timestamp'] - prev_record['timestamp']
                tick_elapsed = f'{tick_elapsed:.2f} s'
            else:
                tick_elapsed = 'N/A'  # The first record has no predecessor.
            msg += f'\n\t[{readable_time}] β³ {tick_elapsed} π {tick_desc} π {tick_position}'
        return msg

    def _die_hook(self):
        ''' Exit hook: close the birth block (if recorded), dump the statistics and close the log file. '''
        if self.record_birth_block:
            self.__exit__(None, None, None)

        self.dump_statistics()

        if self.log_folder is not None:
            self.log_fh.close()