commit 1d40f677760d3fbaaaac6292841f2f5426bf83bd
parent 37119ad0d24f4fb6052a7b1bc8ae2081fed621a2
Author: Siddhant Agarwal <68201519+siddhantdev@users.noreply.github.com>
Date: Mon, 18 Aug 2025 09:10:28 +0530
feat(ssh): SSH configuration parser #35027
Diffstat:
2 files changed, 315 insertions(+), 0 deletions(-)
diff --git a/runtime/lua/vim/net/_ssh.lua b/runtime/lua/vim/net/_ssh.lua
@@ -0,0 +1,237 @@
+-- Converted into Lua from https://github.com/cyjake/ssh-config
+-- TODO (siddhantdev): deal with include directives
+
+local M = {}
+
+local whitespace_pattern = '%s'
+local line_break_pattern = '[\r\n]'
+
+---@param param string
+local function is_multi_value_directive(param)
+ local multi_value_directives = {
+ 'globalknownhostsfile',
+ 'host',
+ 'ipqos',
+ 'sendenv',
+ 'userknownhostsfile',
+ 'proxycommand',
+ 'match',
+ 'canonicaldomains',
+ }
+
+ return vim.list_contains(multi_value_directives, param:lower())
+end
+
+---@param text string The ssh configuration which needs to be parsed
+---@return string[] The parsed host names in the configuration
+function M.parse_ssh_config(text)
+ local i = 1
+ local line = 1
+
+ local function consume()
+ if i <= #text then
+ local char = text:sub(i, i)
+ i = i + 1
+ return char
+ end
+ return nil
+ end
+
+ local chr = consume()
+
+ local function parse_spaces()
+ local spaces = ''
+ while chr and chr:match(whitespace_pattern) do
+ spaces = spaces .. chr
+ chr = consume()
+ end
+ return spaces
+ end
+
+ local function parse_linebreaks()
+ local breaks = ''
+ while chr and chr:match(line_break_pattern) do
+ line = line + 1
+ breaks = breaks .. chr
+ chr = consume()
+ end
+ return breaks
+ end
+
+ local function parse_parameter_name()
+ local param = ''
+ while chr and not chr:match('[ \t=]') do
+ param = param .. chr
+ chr = consume()
+ end
+ return param
+ end
+
+ local function parse_separator()
+ local sep = parse_spaces()
+ if chr == '=' then
+ sep = sep .. chr
+ chr = consume()
+ end
+ return sep .. parse_spaces()
+ end
+
+ local function parse_value()
+ local val = {}
+ local quoted, escaped = false, false
+
+ while chr and not chr:match(line_break_pattern) do
+ if escaped then
+ table.insert(val, chr == '"' and chr or '\\' .. chr)
+ escaped = false
+ elseif chr == '"' and (val == {} or quoted) then
+ quoted = not quoted
+ elseif chr == '\\' then
+ escaped = true
+ elseif chr == '#' and not quoted then
+ break
+ else
+ table.insert(val, chr)
+ end
+ chr = consume()
+ end
+
+ if quoted or escaped then
+ error('Unexpected line break at line ' .. line)
+ end
+
+ return vim.trim(table.concat(val))
+ end
+
+ local function parse_comment()
+ while chr and not chr:match(line_break_pattern) do
+ chr = consume()
+ end
+ end
+
+ ---@return string[]
+ local function parse_multiple_values()
+ local results = {}
+ local val = {}
+ local quoted = false
+ local escaped = false
+
+ while chr and not chr:match(line_break_pattern) do
+ if escaped then
+ table.insert(val, chr == '"' and chr or '\\' .. chr)
+ escaped = false
+ elseif chr == '"' then
+ quoted = not quoted
+ elseif chr == '\\' then
+ escaped = true
+ elseif quoted then
+ table.insert(val, chr)
+ elseif chr:match('[ \t=]') then
+ if val ~= {} then
+ table.insert(results, vim.trim(table.concat(val)))
+ val = {}
+ end
+ elseif chr == '#' and #results > 0 then
+ break
+ else
+ table.insert(val, chr)
+ end
+ chr = consume()
+ end
+
+ if quoted or escaped then
+ error('Unexpected line break at line ' .. line)
+ end
+
+ if val ~= {} then
+ table.insert(results, vim.trim(table.concat(val)))
+ end
+
+ return results
+ end
+
+ local function parse_directive()
+ local param = parse_parameter_name()
+ local multiple = is_multi_value_directive(param)
+ local _ = parse_separator()
+ local value = multiple and parse_multiple_values() or parse_value()
+
+ local result = {
+ param = param,
+ value = value,
+ }
+
+ return result
+ end
+
+ local function parse_line()
+ local _ = parse_spaces()
+ if chr == '#' then
+ parse_comment()
+ return nil
+ end
+ local node = parse_directive()
+ local _ = parse_linebreaks()
+
+ return node
+ end
+
+ local hostnames = {}
+
+ ---@param value string
+ local function is_valid(value)
+ return not (value:find('[?*!]') or vim.list_contains(hostnames, value))
+ end
+
+ while chr do
+ local node = parse_line()
+ if node then
+ -- This is done just to assign the type
+ node.value = node.value ---@type string[]
+ if node.param:lower() == 'match' and node.value then
+ local current = nil
+ for ind, val in ipairs(node.value) do
+ if val:lower() == 'host' and ind + 1 <= #node.value and is_valid(node.value[ind + 1]) then
+ current = node.value[ind + 1]
+ end
+ end
+ if current then
+ table.insert(hostnames, current)
+ end
+ elseif node.param:lower() == 'host' and node.value then
+ for _, value in ipairs(node.value) do
+ if is_valid(value) then
+ table.insert(hostnames, value)
+ end
+ end
+ end
+ end
+ end
+
+ return hostnames
+end
+
+---@param filename string
+---@return string[] The hostnames configured in the file located at filename
+function M.parse_config(filename)
+ local file = io.open(filename, 'r')
+ if not file then
+ error('Cannot read ssh configuration file')
+ end
+ local config_string = file:read('*a')
+ file:close()
+
+ return M.parse_ssh_config(config_string)
+end
+
+---@return string[] The hostnames configured in the ssh configuration file
+--- located at "~/.ssh/config".
+--- Note: This does not currently process `Include` directives in the
+--- configuration file.
+function M.get_hosts()
+ local config_path = vim.fs.normalize('~/.ssh/config') ---@type string
+
+ return M.parse_config(config_path)
+end
+
+return M
diff --git a/test/functional/lua/ssh_spec.lua b/test/functional/lua/ssh_spec.lua
@@ -0,0 +1,78 @@
+local t = require('test.testutil')
+local parser = require('vim.net._ssh')
+local eq = t.eq
+
+describe('SSH parser', function()
+ it('parses SSH configuration strings', function()
+ local config = [[
+ Host *
+ ConnectTimeout 10
+ ServerAliveInterval 60
+ ServerAliveCountMax 3
+ # Use a specific key for any host not otherwise specified
+ # IdentityFile ~/.ssh/id_rsa
+
+ Host=dev
+ HostName=dev.example.com
+ User=devuser
+ Port=2222
+ IdentityFile=~/.ssh/id_rsa_dev
+
+ Host prod test
+ HostName 198.51.100.10
+ User admin
+ Port 22
+ IdentityFile ~/.ssh/id_rsa_prod
+ ForwardAgent yes
+
+ Host test
+ IdentitiesOnly yes
+
+ Host "quoted string"
+ User quote
+ Port 22
+
+ Match host foo host gh
+ HostName github.com
+ User git
+ IdentityFile ~/.ssh/id_rsa_github
+ IdentitiesOnly yes
+ ]]
+
+ eq({
+ 'dev',
+ 'prod',
+ 'test',
+ 'quoted string',
+ 'gh',
+ }, parser.parse_ssh_config(config))
+ end)
+
+ it('fails when a quote is not closed', function()
+ local config = [[
+ Host prod dev "test prod my
+ HostName 198.51.100.10
+ User admin
+ Port 22
+ IdentityFile ~/.ssh/id_rsa_prod
+ ForwardAgent yes
+ ]]
+
+ local ok, _ = pcall(parser.parse_ssh_config, config)
+ eq(false, ok)
+ end)
+
+ it('fails when the line ends with a single backslash', function()
+ local config = [[
+ Host prod test
+ HostName 198.51.100.10
+ User admin\
+ Port 22
+ IdentityFile ~/.ssh/id_rsa_prod
+ ForwardAgent yes
+ ]]
+
+ local ok, _ = pcall(parser.parse_ssh_config, config)
+ eq(false, ok)
+ end)
+end)