内容:
- S3
借助面向对象的编码风格,并加以合理的抽象,我们可以简单地模仿对象的重要特性,于是,问题和模型之间的转换就变得清晰自然。
S3对象
S3对象系统是一个简单且宽松的面向对象系统。每个基本对象的类型都有一个S3类名称。比如integer
,numeric
, character
, logical
, list
和data.frame
都属于S3类。
举例,下面vec1
类型是double
,意味其内部类型或者说存储模式是双精度浮点型数字。但它的类是numeric
。
vec1 = c(1, 2, 3)typeof(vec1)#> [1] "double"class(vec1)#> [1] "numeric"
下面data1
类型是list
,意味data1
的内部类型或者存储模式是列表,但它的S3类是data.frame
。
data1 = data.frame(x = 1:3, y = rnorm(3))typeof(data1)#> [1] "list"class(data1)#> [1] "data.frame"
理解对象的内部类型与S3类区别是一个重点。
一个类可以用多种方法定义它的行为,尤其是它与其他类的关系。在S3系统中,我们可以创建泛型函数(generic function),对于不同的类,由泛型函数决定调用哪个方法,这就是S3方法分派(method dispatch)的工作机理。
对象的类不同,其方法分派不同,因此,区别对象的类十分重要。
R中有许多基于某个通用目的定义的S3泛型函数,我们先看看head()
与tail()
。head()
展示一个数据对象的前n条记录,tail()
展示后n条。这跟x[1:n]
是不同的,因为对不同的类的对象,记录的定义是不同的。对原子向量(数值、字符向量等),前n条记录指前n个元素。但对于数据框,前n条记录指前n行而不是前n列。
查看下head
的函数内部信息:
head#> function (x, ...) #> UseMethod("head")#> <bytecode: 0x0000000018fcb138>#> <environment: namespace:utils>
我们发现函数中并没有实际的操作细节。它调用UseMethod("head")
来让泛型函数head()
执行方法分派,也就是说,对于不同的类,它可能有不同的执行方式(过程)。
num_vec = c(1, 2, 3, 4, 5)data_frame = data.frame(x = 1:5, y = rnorm(5))
调用函数:
head(num_vec, 3)#> [1] 1 2 3head(data_frame, 3)#> x y#> 1 1 0.537#> 2 2 1.072#> 3 3 0.181
我们可以使用methods()
查看head()
函数可以实现的所有方法:
methods("head")#> [1] head.data.frame* head.default* head.ftable* head.function* #> [5] head.matrix head.table* #> see '?methods' for accessing help and source code
可以看到head
不仅仅适用于向量和数据框。
注意,方法都是以method.class
形式表示,如果我们输入一个data.frame
,head()
会调用head.data.frame
方法。当没有方法可以匹配对象的类时,函数会自动转向method.default
方法。这就是方法分派的一个实际过程。
内置类和方法
S3泛型函数和方法在统一各个模型的使用方式上是最有用的。比如我们可以创建一个线性模型,以不同角度查看模型信息:
lm1 = lm(mpg ~ cyl + vs, data = mtcars)
线性模型本质上是由模型拟合产生的数据字段构成的列表,所以lm1
的类型是list
,但是它的类是lm
,因此泛型函数根据lm
选择方法:
typeof(lm1)#> [1] "list"class(lm1)#> [1] "lm"
甚至没有明确调用S3泛型函数时,S3方法分派也会自动进行。如果我们输入lm1
:
lm1#> #> Call:#> lm(formula = mpg ~ cyl + vs, data = mtcars)#> #> Coefficients:#> (Intercept) cyl vs #> 39.625 -3.091 -0.939
实际上,print()
函数被默默地调用了:
print(lm1)#> #> Call:#> lm(formula = mpg ~ cyl + vs, data = mtcars)#> #> Coefficients:#> (Intercept) cyl vs #> 39.625 -3.091 -0.939
为什么打印出来的不像列表呢?因为print()
是一个泛型函数,它为lm
选择了一个方法来打印线性模型最重要的信息。我们可以调用getS3method("print", "lm")
获取实际使用的方法与想象的进行验证:
identical(getS3method("print", "lm"), stats:::print.lm)#> [1] TRUE
print()
展示模型的一个简要版本,summary()
展示更详细的信息。summary()
也是一个泛型函数,它为模型的所有类提供了许多方法:
summary(lm1)#> #> Call:#> lm(formula = mpg ~ cyl + vs, data = mtcars)#> #> Residuals:#> Min 1Q Median 3Q Max #> -4.923 -1.953 -0.081 1.319 7.577 #> #> Coefficients:#> Estimate Std. Error t value Pr(>|t|) #> (Intercept) 39.625 4.225 9.38 2.8e-10 ***#> cyl -3.091 0.558 -5.54 5.7e-06 ***#> vs -0.939 1.978 -0.47 0.64 #> ---#> Signif. codes: 0 '***' 0.001 '**' 0.01 '*' 0.05 '.' 0.1 ' ' 1#> #> Residual standard error: 3.25 on 29 degrees of freedom#> Multiple R-squared: 0.728, Adjusted R-squared: 0.71 #> F-statistic: 38.9 on 2 and 29 DF, p-value: 6.23e-09
实际上,summary()
的输出结果也是一个对象,包含的数据都可以被访问。在这个例子里,这个对象是一个列表,是summary.lm
类,它有可供print()
选择的自己的方法:
lm1summary = summary(lm1)typeof(lm1summary)#> [1] "list"class(lm1summary)#> [1] "summary.lm"
查看列表成分:
names(lm1summary)#> [1] "call" "terms" "residuals" "coefficients" #> [5] "aliased" "sigma" "df" "r.squared" #> [9] "adj.r.squared" "fstatistic" "cov.unscaled"
还有一些其他有用的且与模型相关的泛型函数,例如plot()
,predict()
。不同的内置模型和第三方扩展包提供的模型都能实现这些泛型函数。
举例,我们可以对线性模型调用plot()
函数:
oldpar = par(mfrow = c(2, 2))plot(lm1)
par(oldpar)
为避免依次生成这4个图,我们用par()
将绘图区域划分为2x2的子区域。
利用predict()
我们可以使用模型对新数据进行预测,泛型函数predict()
自动选择正确的方法用新数据进行预测:
predict(lm1, data.frame(cyl = c(6, 8), vs = c(1, 1)))#> 1 2 #> 20.1 14.0
这个函数既可以用在样本内,又可以用在样本外。如果我们为模型提供新数据,它就进行样本外预测。
下面我们创建一幅真实值和拟合值的散点图,看一看线性模型的预测效果:
plot(mtcars$mpg, fitted(lm1))
这里的fitted()
也是泛型函数,等价于lm1$fitted.values
,拟合值等于用原始数据得到的预测值,即用原始数据构建的模型预测原始数据,predict(lm1, mtcars)
。
真实值与拟合值的差称为残差,可以通过另一个泛型函数residuals()
获得。
plot(density(residuals(lm1)), main = "Density of lm1 residuals")
这些泛型函数不仅适用于lm
、glm
和其他内置模型,也适用于其他扩展包提供的模型。
例如我们使用rpart
包,使用前面的数据和公式拟合一个回归树模型。
if(!require("rpart")) install.packages("rpart")#> 载入需要的程辑包:rpartlibrary(rpart)
tree_model = rpart(mpg ~ cyl + vs, data = mtcars)
我们之所以能够使用相同的方法,是因为这个包的作者希望函数调用的方式与调用R内置函数保持一致。
typeof(tree_model)#> [1] "list"class(tree_model)#> [1] "rpart"
打印模型:
print(tree_model)#> n= 32 #> #> node), split, n, deviance, yval#> * denotes terminal node#> #> 1) root 32 1130.0 20.1 #> 2) cyl>=5 21 198.0 16.6 #> 4) cyl>=7 14 85.2 15.1 *#> 5) cyl< 7 7 12.7 19.7 *#> 3) cyl< 5 11 203.0 26.7 *
更详细信息:
summary(tree_model)#> Call:#> rpart(formula = mpg ~ cyl + vs, data = mtcars)#> n= 32 #> #> CP nsplit rel error xerror xstd#> 1 0.6431 0 1.000 1.089 0.2579#> 2 0.0893 1 0.357 0.432 0.0811#> 3 0.0100 2 0.268 0.427 0.0818#> #> Variable importance#> cyl vs #> 65 35 #> #> Node number 1: 32 observations, complexity param=0.643#> mean=20.1, MSE=35.2 #> left son=2 (21 obs) right son=3 (11 obs)#> Primary splits:#> cyl < 5 to the right, improve=0.643, (0 missing)#> vs < 0.5 to the left, improve=0.441, (0 missing)#> Surrogate splits:#> vs < 0.5 to the left, agree=0.844, adj=0.545, (0 split)#> #> Node number 2: 21 observations, complexity param=0.0893#> mean=16.6, MSE=9.45 #> left son=4 (14 obs) right son=5 (7 obs)#> Primary splits:#> cyl < 7 to the right, improve=0.507, (0 missing)#> Surrogate splits:#> vs < 0.5 to the left, agree=0.857, adj=0.571, (0 split)#> #> Node number 3: 11 observations#> mean=26.7, MSE=18.5 #> #> Node number 4: 14 observations#> mean=15.1, MSE=6.09 #> #> Node number 5: 7 observations#> mean=19.7, MSE=1.81
下面对结果进行可视化,得到树图:
oldpar = par(xpd = NA)plot(tree_model)text(tree_model, use.n = TRUE)
par(oldpar)
为现有类定义泛型函数
在定义泛型函数时,我们创建一个函数去调用UseMethod()出发方法分派。然后对泛型函数想要作用的类创建带有method.class形式的方法函数,同时还要创建带有method.default形式的默认方法来应对所有其他情况。
下面我们创建一个新的泛型函数generic_head()
,它有两个参数:输入对象x和需要提取的记录条数n。泛型函数仅仅调用UseMethod("generic_head")
来让R根据输入对象x
的类执行方法分派。
generic_head = function(x, n) UseMethod("generic_head")
对原子向量提取前n
个元素,因此分别定义generic_head.numeric
、generic_head.character
等,另外最好定义一个默认方法捕获不能匹配的其他所有情况:
generic_head.default = function(x, n){ x[1:n]}
现在generic_head
只有一种方法,等于没有使用泛型函数:
generic_head(num_vec, 3)#> [1] 1 2 3
现在我们还没有定义针对data.frame
类的方法,所以当我们输入数据框时,函数会自动转向generic_head.default
,又因为提取的数量超出列数,所以下面的运行会报错:
generic_head(data_frame, 3)#> Error in `[.data.frame`(x, 1:n): 选择了未定义的列
下面为data.frame
定义方法:
generic_head.data.frame = function(x, n) { x[1:n, ]}
现在函数就可以正常运行了:
generic_head(data_frame, 3)#> x y#> 1 1 0.537#> 2 2 1.072#> 3 3 0.181
因为没有对参数进行检查,所以S3类执行的方法并不稳健。
定义新类并创建对象
现在我们来尝试构建新类,class(x)
获取x
的类,而class(x) = some_class
将x
的类设为some_class
。
使用列表作为底层数据结构
列表可能是创建新类时使用最广泛的数据结构,类描述了对象的类型和对象交互作用的方法,其中对象用于存储多种多样、长度不一的数据。
下面我们定义一个叫product
的函数,创建一个由name
、price
和inventory
构成的列表,该列表的类是product
。我们还将自己定义它的print
方法。
productor = function(name, price, inventory){ obj = list(name = name, price = price, inventory = inventory) class(obj) = "product" obj}
上面我们创建了一个列表,然后将它的类替换为product
。我们还可以使用structure()
:
product = function(name, price, inventory){ structure(list(name = name, price = price, inventory = inventory), class = "product")}
现在我们调用product()
函数生成product
类的实例:
laptop = product("Laptop", 499, 300)
查看它的结构和S3类方法分派:
typeof(laptop)#> [1] "list"class(laptop)#> [1] "product"
此时我们还没有为该类定义任何方法,如果print
将按默认列表输出:
print(laptop)#> $name#> [1] "Laptop"#> #> $price#> [1] 499#> #> $inventory#> [1] 300#> #> attr(,"class")#> [1] "product"
下面我们自定义一个print
方法,使得输出更紧凑:
print.product = function(x, ...){ cat("<product>\n") cat("name:", x$name, "\n") cat("price:", x$price, "\n") cat("inventory:", x$inventory, "\n") invisible(x)}
print
方法返回输入对象本身以备后用,这是一项约定。
现在我们再来看看输出:
laptop#> <product>#> name: Laptop #> price: 499 #> inventory: 300
我们可以像操作列表一样访问laptop
的成分:
laptop$name#> [1] "Laptop"laptop$price#> [1] 499laptop$inventory#> [1] 300
如果我们创建另一个对象,并将两者放入一个列表然后打印,print.product
仍然会被调用:
cellphone = product("Phone", 249, 12000)products = list(laptop, cellphone)products#> [[1]]#> <product>#> name: Laptop #> price: 499 #> inventory: 300 #> #> [[2]]#> <product>#> name: Phone #> price: 249 #> inventory: 12000
当products
以列表形式被打印时,会对每个元素调用print()
泛型函数,再由泛型函数执行方法分派。
大多数其他编程语言都对类有正式的定义,而S3没有,所以创建一个S3对象比较简单,但我们需要对输入参数进行充分的检查,以确保创建的对象与所属类内部一致。
除了定义新类,我们还可以定义新的泛型函数。下面创建一个叫value
的泛型函数,它通过测量产品的库存值来为product
调用实施方法:
value = function(x, ...) UseMethod("value")value.default = function(x, ...){ stop("Value is undefined")}value.product = function(x, ...){ x$price * x$inventory}
针对其他类,value
调用default
方法并终止运行。
value(laptop)#> [1] 149700value(cellphone)#> [1] 2988000value(data_frame)#> Error in value.default(data_frame): Value is undefined
使用原子向量作为底层数据结构
上面我们已经演示了创建S3类和泛型函数的过程,有时候我们需要使用原子向量创建新类,下面展示百分比形式向量创建过程。
首先定义一个percent
函数,它检查输入是否是数值向量并将输入对象类型改为percent
,percent
类继承numeric
类:
percent = function(x){ stopifnot(is.numeric(x)) class(x) = c("percent", "numeric") x}
这里的继承指方法分派首先在percent类中方法找,找不到就去numeric类方法中找。寻找的顺序由类名称的顺序决定。
pct = percent(c(0.1, 0.05, 0.25))pct#> [1] 0.10 0.05 0.25#> attr(,"class")#> [1] "percent" "numeric"
现在定义方法,让percent
类以百分比形式存在:
as.character.percent = function(x, ...){ paste0(as.numeric(x) * 100, "%")}
现在我们可以得到字符型了:
as.character(pct)#> [1] "10%" "5%" "25%"
也可以直接调用as.character()
为percent
提供一个format
方法:
format.percent = function(x, ...){ as.character(x, ...)}
format
现在有相同的效果:
format(pct)#> [1] "10%" "5%" "25%"
类似地,我们调用format.percent()
为percent
提供print
方法:
print.percent = function(x, ...){ print(format.percent(x), quote = FALSE)}
这里指定quote=FALSE
使得打印的格式化字符串更像数字而非字符串。
pct#> [1] 10% 5% 25%
注意,使用算术运算符操作后会自动保持输出向量类不变:
pct + 0.2#> [1] 30% 25% 45%pct * 0.5#> [1] 5% 2.5% 12.5%
可惜使用其他函数可能不会保持输入对象的类,比如sum()
、mean()
等:
sum(pct)#> [1] 0.4mean(pct)#> [1] 0.133max(pct)#> [1] 0.25min(pct)#> [1] 0.05
为了确保百分比形式保存,我们对percent
类实施一些操作:
sum.percent = function(...){ percent(NextMethod("sum"))}mean.percent = function(x, ...){ percent(NextMethod("mean"))}max.percent = function(...){ percent(NextMethod("max"))}min.percent = function(...){ percent(NextMethod("max"))}
NextMethod("sum")
对numeric
类调用sum()
函数,然后再调用percent()
函数将输出的数值向量包装为百分比形式:
sum(pct)#> [1] 40%mean(pct)#> [1] 13.3333333333333%max(pct)#> [1] 25%min(pct)#> [1] 5%
但如果我们组合一个百分比向量和其他数值型的值,percent
类又会消失掉,我们进行相同的改进:
c.percent = function(x, ...){ percent(NextMethod("c"))}
c(pct, 0.12)#> [1] 10% 5% 25% 12%
dan….我们取子集又会有问题
pct[1:3]#> [1] 0.10 0.05 0.25pct[[2]]#> [1] 0.05
同样地,我们对[
和[[
函数进行改造:
`[.percent` = function(x, i) { percent(NextMethod('['))}`[[.percent` = function(x, i){ percent(NextMethod("[["))}
此时显示就正常了:
pct[1:3]#> [1] 10% 5% 25%pct[[2]]#> [1] 5%
实现这些方法后,我们可以在数据框中使用:
data.frame(id = 1:3, pct)#> id pct#> 1 1 10%#> 2 2 5%#> 3 3 25%
S3继承
假设我们想要对一些交通工具,例如汽车、公共汽车和飞机进行建模。这些交通工具有一些共性,它们都有名称、速度、位置,而且都可以移动。为了形象化描述它们,我们定义一个基本类,称为vehichle
,用于存储公共部分,另外定义car
、bus
和airplane
这3个子类,它们继承vehichle
,但具有自定义的行为。
首先,定义一个函数来创建vehicle
对象,它本质上是一个环境。我们选择环境而不是列表,因为需要用到环境的引用语义,也就是说,我们传递一个对象,然后原地修改它,而不会创建这个对象的副本。因此无论什么位置将对象传递给函数,对象总是指向同一个交通工具。
Vehicle = function(class, name, speed) { obj = new.env(parent = emptyenv()) obj$name = name obj$speed = speed obj$position = c(0, 0, 0) class(obj) = c(class, "vehicle") obj}
这里的class(obj) = c(class, "vehicle")
似乎有点语义不明。但前者是基础函数,后者是输入参数,R能够判断好。
下面函数创建继承vehicle
的car
、bus
和airplane
的特定函数:
Car = function(...){ Vehicle(class = "car", ...)}Bus = function(...){ Vehicle(class = "bus", ...)}Airplane = function(...){ Vehicle(class = "airplane", ...)}
现在我们可以为每一个子类创建实例:
car = Car("Model-A", 80)bus = Bus("Medium-Bus", 40)airplane = Airplane("Big-Plane", 800)
下面为vehicle
提供通用的print
方法:
print.vehicle = function(x, ...){ cat(sprintf("<vehicle: %s>\n", class(x)[1])) cat("name:", x$name, "\n") cat("speed:", x$speed, "km/h\n") cat("position:", paste(x$position, collapse = ", "), "\n")}
因为我们定义的3个子类都有了继承,所以print
方法通用:
car#> <vehicle: car>#> name: Model-A #> speed: 80 km/h#> position: 0, 0, 0bus#> <vehicle: bus>#> name: Medium-Bus #> speed: 40 km/h#> position: 0, 0, 0airplane#> <vehicle: airplane>#> name: Big-Plane #> speed: 800 km/h#> position: 0, 0, 0
因为交通工具可以移动,我们创建一个泛型函数move
来表征这样的状态:
move = function(vehicle, x, y, z) { UseMethod("move")}move.vehicle = function(vehicle, movement) { if (length(movement) != 3){ stop("All three dimensions must be specified to move a vehicle") } vehicle$position = vehicle$position + movement vehicle}
这里我们将汽车和公共汽车的移动限定在二维平面上。
move.bus = move.car = function(vehicle, movement) { if (length(movement) != 2){ stop("This vehicle only supports 2d movement") } movement = c(movement, 0) NextMethod("move")}
这里我们将movement
的第3个纬度强制转换为0,然后调用NextMethod("move")
来调用move.vehicle()
。
飞机既可以在2维也可以在3维:
move.airplane = function(vehicle, movement) { if (length(movement) == 2){ movement = c(movement, 0) } NextMethod("move")}
下载3种方法都实现了,进行测试。
move(car, c(1, 2, 3))#> Error in move.car(car, c(1, 2, 3)): This vehicle only supports 2d movement
只能输入二维,所以提示报错了。
move(car, c(1, 2))#> <vehicle: car>#> name: Model-A #> speed: 80 km/h#> position: 1, 2, 0
move(airplane, c(1, 2))#> <vehicle: airplane>#> name: Big-Plane #> speed: 800 km/h#> position: 1, 2, 0
飞机,3维:
move(airplane, c(20,100,50))#> <vehicle: airplane>#> name: Big-Plane #> speed: 800 km/h#> position: 21, 102, 50
注意,airplane的位置是累积的。因为前面说过,它本质是一个环境,因此修改move.vehicle()
中的position
不会创建一个副本再修改,而是本地修改!
学习自《R语言编程指南》
内容太多,下次学习接下来的内容。
文章作者 王诗翔
上次更新 2018-08-15
许可协议 CC BY-NC-ND 4.0