背景

某业务团队提出脚本性能测试的需求:想知道脚本 A 中的耗时具体在哪里。
PS:刚拿到该需求的时候我哭笑不得,从需求来看就是普通的跑一次脚本获取具体的耗时在哪里😂(后续性能调优也还是研发,测试仅做了个执行脚本的操作)。。。但迫于刚把测试团队的性能测试推广出去(热乎的新活),总不能直接把自己的路堵死了吧。
图1


解决方案

简单扫了一遍脚本的调用, 整个脚本还是比较简单(也就涉及到 6、7 个方法), 这样直接手动添加耗时统计效率最快。
【方案】:人为在每个方法中添加耗时的统计方法,然后重新编译并执行。 这样就得到了具体脚本耗时最大头是在方法 XX 上, 并通过pprof 工具对结果进行了一些分析,最终将结果反馈给了研发。

事情进阶

本以为这个事情就结束了,后来业务团队又提出了更多的脚本性能需求(复杂的涉及到好几十个方法调用 )。。。
图2

一个个手工去添加重复性又太高,并且效率比较低, 原有的手工去添加耗时的方法就行不通了😓。由于项目没有接入 APM 这类自动获取调用链路以及耗时的工具(临时去接这种工具可行性不高),那怎么去快速的在这么多方法中注入耗时统计的代码呢 ( ̄ε  ̄) ?

此时想到了原来接触的静态扫描中的一一AST能够满足自动注入代码的需求。


优化版方案

通过对脚本分析, 脚本调用的方法都是在当前项目内, 所以新的方案只需要满足对指定目录/文件的方法进行耗时统计代码的注入即可。二话不说,撸起袖子就是干!此处不详细介绍 AST 的基础知识,有兴趣可以看网上一些资料(个人感觉都就只简单讲了开始的一点入门,没有什么具体的深入了(不同语言生成的语法树差异不大,比如:go、js、java 等):

PS:下面是 Demo 代码, 有兴趣的可以自己本地 GoLand 编译器上尝试一下(⚠️ 注意:Demo 版本代码比较粗糙,不要在意一些编程规范或实现 ( ̄ε  ̄))。

demo.go 文件(被注入的对象)

package main

import (
    "fmt"
)

func greet() {
  // test 一般代码首行存在注释
    total := 0
    for i := 1; i <= 1000; i++ {
        total += i
    }
    fmt.Printf("total = %v\n", total)
}

func test() {}

核心实现的代码 main.go

package main

import (
    "bufio"
    _ "bufio"
    "bytes"
    "fmt"
    "go/ast"
    "go/format"
    "go/parser"
    "go/token"
    "log"
    "os"
    _ "os"
    "path/filepath"
    "strconv"
)

type Visitor struct {
}

func addStartTimeCode(funcDecl *ast.FuncDecl) {
    bodyStmt := funcDecl.Body
    if len(bodyStmt.List) == 0 {
        return
    } else{
        listStmt := bodyStmt.List
        startCodePos := bodyStmt.Pos()
        xValue := &ast.Ident{
            Name: "time",
            NamePos: startCodePos,
        }
        selVaule := &ast.Ident{
            Name: "Now",
            NamePos: startCodePos,
        }
        selectorExpr := &ast.SelectorExpr{
            X:   xValue,
            Sel: selVaule,
        }
        callExpr := &ast.CallExpr{
            Fun:      selectorExpr,
            Lparen: startCodePos,
            Rparen: startCodePos,
        }
        assignStmt := &ast.AssignStmt{
            Lhs:    []ast.Expr {ast.NewIdent("startT")},
            TokPos: startCodePos,
            Tok: token.DEFINE,
            Rhs: []ast.Expr { callExpr },
        }
        listNew := []ast.Stmt {assignStmt}
        bodyStmt.List = append(listNew, listStmt...)
    }
}

func parseFuncEndPos(blockStmt *ast.BlockStmt) token.Pos {
    bodyList := blockStmt.List
    lastElement := bodyList[len(bodyList)-1]
    funcEndPos := lastElement.End()
    return funcEndPos
}
func addEndTimeCode(funcDecl *ast.FuncDecl) {
    bodyStmt := funcDecl.Body
    if len(bodyStmt.List) == 0 {
        return
    } else{
        listStmt := bodyStmt.List
        endCodePos := parseFuncEndPos(bodyStmt)
        xValue := &ast.Ident{
            Name: "time",
            NamePos: endCodePos,
        }
        selVaule := &ast.Ident{
            Name: "Since",
            NamePos: endCodePos,
        }
        selectorExpr := &ast.SelectorExpr{
            X:   xValue,
            Sel: selVaule,
        }
        callExpr := &ast.CallExpr{
            Fun:      selectorExpr,
            Args: []ast.Expr {ast.NewIdent("startT")},
            Lparen: endCodePos,
            Rparen: endCodePos,
        }
        assignStmt := &ast.AssignStmt{
            Lhs:    []ast.Expr {ast.NewIdent("tc")},
            TokPos: endCodePos,
            Tok: token.DEFINE,
            Rhs: []ast.Expr { callExpr},
        }
        bodyStmt.List = append(listStmt, assignStmt)
    }
}

func addEndPrintCode(funcDecl *ast.FuncDecl) {
    bodyStmt := funcDecl.Body
    if len(bodyStmt.List) == 0 {
        return
    } else{
        listStmt := bodyStmt.List
        endCodePos := parseFuncEndPos(bodyStmt)
        xValue := &ast.Ident{
            Name: "fmt",
            NamePos: endCodePos,
        }
        selVaule := &ast.Ident{
            Name: "Printf",
            NamePos: endCodePos,
        }
        selectorExpr := &ast.SelectorExpr{
            X:   xValue,
            Sel: selVaule,
        }
        basicList := &ast.BasicLit{
            ValuePos: endCodePos,
            Kind:     token.STRING,
            Value:    "\"time cost = %v\\n\"",
        }
        nameValue := &ast.Ident{
            Name: "tc",
            NamePos: endCodePos,
        }
        callExpr := &ast.CallExpr{
            Fun:      selectorExpr,
            Args: []ast.Expr {basicList, nameValue},
            Lparen: endCodePos,
            Rparen: endCodePos,
        }
        exprStmt := &ast.ExprStmt{callExpr}
        bodyStmt.List = append(listStmt, exprStmt)
    }
}

func (v *Visitor) Visit(node ast.Node) ast.Visitor {
    switch node.(type) {
    case *ast.GenDecl:
        genDecl := node.(*ast.GenDecl)
        // 查找有没有import time包
        if genDecl.Tok == token.IMPORT {
            v.addImport(genDecl, "time")
            v.addImport(genDecl, "fmt")
            // 不需要再遍历子树
            return nil
        }
    case *ast.FuncDecl:
        funcDecl := node.(*ast.FuncDecl)
        addStartTimeCode(funcDecl)
        addEndTimeCode(funcDecl)
        addEndPrintCode(funcDecl)
        fmt.Println(funcDecl)
    }
    return v
}

// addImport 引入time包
func (v *Visitor) addImport(genDecl *ast.GenDecl, importName string) {
    // 是否已经import
    hasImported := false
    for _, v := range genDecl.Specs {
        imptSpec := v.(*ast.ImportSpec)
        // 如果已经包含"context"
        if imptSpec.Path.Value == strconv.Quote(importName) {
            hasImported = true
        }
    }
    // 如果没有import time,则import
    if !hasImported {
        genDecl.Specs = append(genDecl.Specs, &ast.ImportSpec{
            Path: &ast.BasicLit{
                Kind:  token.STRING,
                Value: strconv.Quote(importName),
            },
        })
    }
}

func main() {
    fset := token.NewFileSet()
    path, _ := filepath.Abs("./demo.go")
    f, err := parser.ParseFile(fset, path, nil, parser.ParseComments)
    if err != nil {
        log.Println(err)
        return
    }
    // 遍历节点
    v := &Visitor{}
    ast.Walk(v, f)
    var output []byte
    buffer := bytes.NewBuffer(output)
    err = format.Node(buffer, fset, f)
    if err != nil {
        log.Fatal(err)
    }
    // 输出Go代码 到对应文件
    fmt.Println(buffer.String())
    pathNew, _ := filepath.Abs("./demo_new.go")
    file, err := os.OpenFile(pathNew, os.O_WRONLY|os.O_CREATE, 0666)
    defer file.Close()
    write := bufio.NewWriter(file)
    write.WriteString(buffer.String())
    write.Flush()
}

实现的效果满足需求,具体效果如下(仅对 greet 方法插入了耗时统计的代码):

package main

import (
    "fmt"
    "time"
)

// main方法
func greet() {
    startT := time.Now()
    // test
    total := 0
    for i := 1; i <= 1000; i++ {
        total += i
    }
    fmt.Printf("total = %v\n", total)
    tc := time.Since(startT)
    fmt.Printf("time cost = %v\n", tc)
}

func test() {}

当 Demo 完成后第一感觉就是 NiuBility, 一般这种情况不出意外的话马上就要出意外了😂。对真实的项目脚本进行实操时直接各种问题。真是:一顿操作猛如虎,实操却是二百五。。。
图3


问题汇总

通过 2 天坚持不懈的打补丁和兼容, 完整版总算完成了。。。(此处就先不贴具体代码,后续实用一段时间稳定后再贴出来

图4


思考

最后感慨下:真的羡慕那些过目不忘的人(自己这昨天晚上吃了什么第二天早上可能就忘记了), 好久没有看 AST 了,导致此次使用上也生疏很多😂。。。

精通 AST 这真的是一个比较强大的技能, 当前能想到的场景包括调用链路分析、代码注入、变异测试中的代码篡改等等,直接对类似 sonar 扫描中的原理看有点上手太难。

!!!顺便咨询下有没有比较好的入门 AST 教程&实践推荐, 可以提供相关文档方便参考!!!
!!!顺便咨询下有没有比较好的入门 AST 教程&实践推荐, 可以提供相关文档方便参考!!!
!!!顺便咨询下有没有比较好的入门 AST 教程&实践推荐, 可以提供相关文档方便参考!!!


↙↙↙阅读原文可查看相关链接,并与作者交流