fileio
作为MMCV的一个文件操作的核心模块,提供了一套统一的API根据不同的后端实现不同格式文件的序列化(dump)和反序列化(load)。PS: 从v1.3.16之后,MMCV才开始支持不同后端文件的序列化和反序列化,具体细节在#1330可以看到。
MMCV对文件的序列化和反序列化提供了统一的接口,因此调用起来十分简便:
'''Load from disk or dump to disk
'''
import mmcv# load data from different files
data_json = mmcv.load('test.json')
data_yaml = mmcv.load('test.yaml')
data_pkl = mmcv.load('out_json.pkl')# load data from a file-like object
with open('test.json', 'r') as file:data_json = mmcv.load(file, 'json')# dump data
mmcv.dump(data_json, 'out_json.pkl')print('json:', data_json)
print('yaml:', data_yaml)
print('pkl:', data_pkl)
'''
Ouput:
json: {'a': [1, 2, 3], 'b': None, 'c': 'hello'}
yaml: {'a': [1, 2, 3], 'b': None, 'c': 'hello'}
pkl: {'a': [1, 2, 3], 'b': None, 'c': 'hello'}
'''
fileio
模块中有两个核心组件:一个是负责不同格式文件读写的FileHandler和不同后端文件获取的FileClient。
只需要自定义Handler继承BaseFileHandler,并且实现load_from_fileobj
、dump_to_fileobj
、dump_to_str
这三个方法即可。下面是自定义.log文件加载为list的读取示例:
'''Convert data joined by comma in .log file to list obj mutually.
'''
# here, we need to writer a file handler inherited from BaseFileHandler and register it@mmcv.register_handler(['txt', 'log'])
class LogFileHandler(mmcv.BaseFileHandler):str_like = Truedef __init__(self) -> None:super().__init__()def load_from_fileobj(self, file, **kwargs):res_list = []for line in file.readlines():res_list.append(line.split(','))return res_listdef dump_to_fileobj(self, obj, file, **kwargs):for line_data in obj:file.write(','.join(line_data))def dump_to_str(self, obj, **kwargs):return str(obj) + ':str'# Running script
data_txt = mmcv.load('test.log')
print(data_txt)
mmcv.dump(data_txt, 'out_txt.log')
data_str = mmcv.dump(data_txt, file_format='log')
print(data_str)
'''
Output:
In terminal:
[['a', ' b', ' c\n'], ['e', ' f']]
[['a', ' b', ' c\n'], ['e', ' f']]:strIn test.log:
a, b, c
e, fIn out_txt.log:
a, b, c
e, f
'''
接下来说明几个需要注意的点:
mmcv.load/dump
方法中通过FileClient
去解析uri从而选择对应的backend访问指定的文件,并将文件转为file object这样的字节流。 FileClient
组件主要是提供统一的不同后台文件访问API,其实做的事情就是将不同后端的文件转换为字节流。同样,自定义后台只需要继承BaseStorageBackend
类,实现get和get_text方法即可。
class BaseStorageBackend(metaclass=ABCMeta):@abstractmethoddef get(self, filepath):pass@abstractmethoddef get_text(self, filepath):pass
下面是HardDiskBackend
类的示例:
class HardDiskBackend(BaseStorageBackend):"""Raw hard disks storage backend."""_allow_symlink = Truedef get(self, filepath: Union[str, Path]) -> bytes:"""Read data from a given ``filepath`` with 'rb' mode.Args:filepath (str or Path): Path to read data.Returns:bytes: Expected bytes object."""with open(filepath, 'rb') as f:value_buf = f.read()return value_bufdef get_text(self,filepath: Union[str, Path],encoding: str = 'utf-8') -> str:"""Read data from a given ``filepath`` with 'r' mode.Args:filepath (str or Path): Path to read data.encoding (str): The encoding format used to open the ``filepath``.Default: 'utf-8'.Returns:str: Expected text reading from ``filepath``."""with open(filepath, 'r', encoding=encoding) as f:value_buf = f.read()return value_buf
最后,我们这里尝试编写一段自动选择后端读取和保存Pytorch CheckPoint的程序:
# implement an interface which automatically select the corresponding backend
import torch
from mmcv.fileio.file_client import FileClientdef load_checkpoint(path):file_client = FileClient.infer_client(uri=path)with io.BytesIO(file_client.get(path)) as buffer:ckpt = torch.load(buffer)return ckptdef save_ckpt(ckpt, path):file_client = FileClient.infer_client(uri=path)with io.BytesIO() as buffer:torch.save(ckpt, buffer)file_client.put(buffer.getvalue(), path)# Running script
ckpt_path = 'https://download.pytorch.org/models/resnet18-f37072fd.pth'
ckpt = load_checkpoint(ckpt_path) # HttpbHTTPBackend
# print(type(ckpt))
print(ckpt)save_ckpt(ckpt, 'resnet18-f37072fd.pth')
'''
Output
OrderedDict([('conv1.weight', Parameter containing:
tensor([[[[-1.0419e-02, -6.1356e-03, -1.8098e-03, ..., 5.6615e-02,1.7083e-02, -1.2694e-02], ...
'''