AutoTest
本篇为AutoTest快速上手和使用注意事项,源码逻辑与debug见:深度学习框架如何优雅的做算子对齐任务
简介
AutoTest框架相当于是一个high level的PyTorch,它的接口和PyTorch一样,但对于给定的输入会分别用OneFlow和PyTorch运行一遍,记录运行过程中得到的每个tensor以及对应梯度tensor的值,再对这些OneFlow和PyTorch分别产生的tensor检查一遍数值形状是否完全相同,以完成自动测试工作。
自动测试框架主要完成:
- 自动随机出各种合法参数组合成的Op
- 基于数值和类型完全相同的输入Tensor(PyTorch和OneFlow各有一份)分别运行PyTorch和OneFlow的代码
- 检查两份数值结果形状是否完全相同
使用
random
需要定义 random 模块的原因在于需要做到 Pytorch 和 OneFlow 输入统一的前提下随机生成输入。
k = random(1, 6)
x = random_pytorch_tensor(ndim=2, dim1=k)
y = random_pytorch_tensor(ndim=2, dim0=k)
print(x.value())
x = random(1, 4).to(int)
x.eval() # int
- 指定
k
将会在[1, 4)
这个区间进行取值 - 基于
random_pytorch_tensor
方法构造了两个2维随机tensorx
和y
,它们的维度分别是[m, k]
和[k, n]
,这些维度的值都是随机生成的。
def random_pytorch_tensor(
ndim=None, # 维度
dim0=1, # 一维大小
dim1=None, # 二维大小
dim2=None, # ……
dim3=None,
dim4=None,
low=0,
high=1,
dtype=float,
requires_grad=True,
):
- 可以粗略理解
random_pytorch_tensor
返回的是 PyTorch 与 Oneflow 两个版本的tensor,random_tensor
返回的是 Pytorch 版本的 tensor,二者的参数是一样的
@ autotest
def autotest(
n=20, # 测试次数
auto_backward=True, # 是否测试梯度回传
rtol=0.0001, # 搭配 check_allclose 使用,对应 np.allclose
atol=1e-05, # 搭配 check_allclose 使用,对应 np.allclose
check_graph=True,
check_allclose=True,
):
verbose = os.getenv("ONEFLOW_TEST_VERBOSE") is not None
def deco(f):
@functools.wraps(f)
def new_f(test_case):
nonlocal n
loop_limit = n * 20
loop = 0
while n > 0:
clear_note_fake_program()
if loop > loop_limit:
raise ValueError("autotest stuck in an endless loop!")
dual_modules_to_test.clear()
dual_objects_to_test.clear()
try:
global testing
testing = True
global testing_graph
if check_graph:
testing_graph = True
res = f(test_case) # 在给定输入的情况下去分别运行PyTorch和OneFlow的程序获得所有中间的输出tensor,包括tensor的梯度,并将它们记录到dual_modules_to_test这个列表
testing = False
testing_graph = False
except (PyTorchDoesNotSupportError, BothDoNotSupportError) as e:
if verbose:
print(f"{f.__name__}")
print(e)
loop += 1
continue
if res is not None:
if not isinstance(res, collections.abc.Sequence):
res = [res]
func_outputs = res
for x in res:
if auto_backward:
if isinstance(x.pytorch, torch_original.Tensor):
call_tensor_id.append(id(x.pytorch))
x.sum().backward()
dual_objects_to_test.append(x)
for x in dual_modules_to_test:
for key in x.pytorch.state_dict().keys():
if key not in x.oneflow.state_dict().keys():
warnings.warn(f"oneflow module don't have `{key}`")
continue
vis_parameters[key] = x.pytorch.state_dict()[key]
dual_objects_to_test.append(
GetDualObject(
"unused",
getattr(x.pytorch, key),
getattr(x.oneflow, key),
)
)
call_tensor_id.append(id(getattr(x.pytorch, key)))
dual_objects_to_test.append(
GetDualObject(
"unused",
getattr(x.pytorch, key).grad,
getattr(x.oneflow, key).grad,
)
)
call_tensor_id.append(id(getattr(x.pytorch, key).grad))
for x in dual_objects_to_test:
if (
isinstance(x.pytorch, torch_original.Tensor)
and id(x.pytorch) not in call_tensor_id
):
vis_tensor.append(x.pytorch)
# check eager
for x in dual_objects_to_test:
if check_allclose:
test_case.assertTrue(check_equality(x, rtol=rtol, atol=atol), x)
if verbose:
print(f"{f.__name__} test eager passed.")
n -= 1
loop += 1
return new_f
return deco
样例
import unittest # 导入自动测试框架unittest实现测试
import oneflow as flow # 支持flow.unittest.TestCase
import oneflow.unittest # 支持flow.unittest.TestCase
from oneflow.test_utils.automated_test_util import * # 导入autotest
class TestConvTranspose(flow.unittest.TestCase):
@autotest(n=1)
def test_ConvTranspose1d_(test_case):
channels = random(1, 6)
m = torch.nn.ConvTranspose1d(
in_channels=channels,
out_channels=random(1, 20),
kernel_size=random(1, 4),
stride=random() | nothing(),
padding=random(1, 3).to(int) | nothing(),
dilation=random(1, 5) | nothing(),
groups=random(1, 5) | nothing(),
padding_mode=constant("zeros") | nothing(),
)
m.train(random())
device = random_device()
m.to(device)
x = random_tensor(ndim=3, dim1=channels).to(device)
y = m(x)
return y
if __name__ == "__main__":
unittest.main()
注意事项
- test 方法体中的 torch 由 autotest 处理,所以调用 torch.nn.functional.conv1d 时不可以使用 import torch.nn.functional as F 进行简写
- torch.randn 不宜使用,应当使用 torch.ones 或者 random 方法替代。因为自动测试框架里会用pytorch和oneflow分别产生一次随机tensor,那很可能输入是不相等的,这样得到的结果数值检查就会失败